# Page header captured along with this file (not code):
# Keras
# legal
# kevin110211's picture
# Upload 51 files
# 5d58b52
# from abc import ABC, ABCMeta, abstractclassmethod
import torch
import numpy as np
from abc import ABC, abstractmethod, ABCMeta
class Metric(metaclass=ABCMeta):
    """Abstract interface for metrics accumulated per batch and finalized per epoch.

    Expected call order:
      - ``reset()`` at the beginning of every epoch,
      - ``update_per_batch()`` after every batch,
      - ``update_per_epoch()`` after every epoch.

    Every method is abstract, so a subclass must override all four
    (including ``__init__``) before it can be instantiated.
    """

    @abstractmethod
    def __init__(self):
        """Concrete metrics must define their own constructor."""

    @abstractmethod
    def reset(self):
        """Clear all accumulated state."""

    @abstractmethod
    def update_per_batch(self, output):
        """Fold one batch's results into the running state."""

    @abstractmethod
    def update_per_epoch(self):
        """Finalize the accumulated state for the epoch."""
class Top_K_Metric(Metric):
    """Accumulate loss, soft top-k accuracy, mean rank (MR) and mean reciprocal rank (MRR).

    Per epoch: call ``reset()``, then ``update_per_batch(...)`` once per batch,
    then ``update_per_epoch()`` to turn the running sums into averages.

    Attributes set by ``update_per_epoch()``:
        acc_list: top-k accuracy in percent, one entry per value in ``topnum``.
        acc_all: sum of ``acc_list``.
        mr / mrr: mean rank / mean reciprocal rank of the ground-truth answer.
        total_loss: accumulated loss divided by the number of batches.
    """

    def __init__(self, topnum=(1, 3, 10)):
        """Create the metric.

        Args:
            topnum: the k values to report top-k accuracy for. Copied into a
                new list, so neither the default nor the caller's sequence is
                shared between instances.
        """
        super().__init__()
        self.topnum = list(topnum)
        self.k_num = len(self.topnum)
        self.reset()

    def reset(self):
        """Clear all running sums; call at the beginning of every epoch."""
        self.total_loss = 0
        self.correct_list = [0] * self.k_num
        self.acc_list = [0] * self.k_num
        self.acc_all = 0
        self.num_examples = 0
        # NOTE: despite its name this counts batches: update_per_batch adds 1 per call.
        self.num_epoch = 0
        self.mrr = 0
        self.mr = 0
        self.mrr_all = 0
        self.mr_all = 0

    def update_per_batch(self, loss, ans, pred):
        """Add one batch's statistics to the running sums.

        Args:
            loss: batch loss, added as-is to ``total_loss``.
            ans: ground-truth per-candidate scores (see ``batch_accuracy``).
            pred: predicted per-candidate scores; dim 0 is the batch.
        """
        self.total_loss += loss
        self.num_epoch += 1
        self.top_k_list = self.batch_accuracy(pred, ans)
        # Every per-k tensor has one row per example, i.e. pred.shape[0] rows.
        self.num_examples += pred.shape[0]
        for i, hits in enumerate(self.top_k_list):
            self.correct_list[i] += hits.sum().item()
        mrr_tmp, mr_tmp = self.batch_mr_mrr(pred, ans)
        self.mrr_all += mrr_tmp.sum().item()
        self.mr_all += mr_tmp.sum().item()

    def update_per_epoch(self):
        """Convert the running sums into epoch averages.

        If there were no examples (or no batches), the corresponding averages
        keep their current values instead of raising ZeroDivisionError.
        """
        if self.num_examples:
            for i, correct in enumerate(self.correct_list):
                self.acc_list[i] = 100 * (correct / self.num_examples)
            self.mr = self.mr_all / self.num_examples
            self.mrr = self.mrr_all / self.num_examples
        if self.num_epoch:
            # NOTE: overwrites the running sum in place, so calling this twice
            # in one epoch divides twice.
            self.total_loss = self.total_loss / self.num_epoch
        self.acc_all = sum(self.acc_list)

    def batch_accuracy(self, predicted, true):
        """Compute soft top-k accuracies for a batch of predictions and answers.

        For each k in ``topnum`` the per-example score is
        ``min(0.3 * sum(true[i] for i in top-k predicted indices), 1)``.
        NOTE(review): the 0.3 factor looks like the VQA "min(#annotators/3, 1)"
        convention, with ``true`` holding annotator counts. Confirm with the caller.

        Args:
            predicted: scores indexed as (batch, candidate).
            true: per-candidate target values, indexed like ``predicted``.
                If it is 3-D, only ``true[0]`` is used.

        Returns:
            One float32 tensor of shape (batch, 1) per entry of ``topnum``,
            in the same order as ``topnum``.
        """
        if len(true.shape) == 3:
            true = true[0]
        if not self.topnum:
            return []
        _, ok = predicted.topk(max(self.topnum), dim=1)
        # Target value of each top-ranked prediction, accumulated in float32 on
        # the inputs' own device (works on CPU as well as any GPU).
        gathered = true.gather(dim=1, index=ok).to(torch.float)
        # cumulative[:, j] = total target value of the top (j + 1) predictions.
        cumulative = gathered.cumsum(dim=1)
        return [(cumulative[:, k - 1:k] * 0.3).clamp(max=1) for k in self.topnum]

    def batch_mr_mrr(self, predicted, true):
        """Compute the rank and reciprocal rank of the ground-truth answer.

        The ground truth for each row is the candidate with the highest value
        in ``true``. Its rank is its 1-based position when the candidates are
        sorted by descending ``predicted`` score.

        Args:
            predicted: scores indexed as (batch, candidate).
            true: per-candidate target values. If it is 3-D, only ``true[0]`` is used.

        Returns:
            ``(mrr, mr)``: two float64 tensors of shape (batch, 1) holding
            ``1 / rank`` and ``rank``.
        """
        if len(true.shape) == 3:
            true = true[0]
        # Candidate indices ordered by descending predicted score.
        _, order = predicted.topk(predicted.shape[1], dim=1)
        # Index of the ground-truth answer per row, shape (batch, 1).
        _, real_ans = true.topk(1, dim=1)
        # `order` is a permutation of the candidates, so each row contains
        # exactly one match; argmax returns its 0-based position.
        hits = (order == real_ans).to(torch.float)
        position = hits.argmax(dim=1, keepdim=True)
        mr = (position + 1).to(torch.float64)
        mrr = 1.0 / mr
        return mrr, mr
if __name__ == '__main__':
    # No script entry point: the guard body is intentionally empty.
    pass