import datetime
import logging
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange

from models.criterions import get_sim
from utils.basic_utils import MetricLogger
from utils.distributed import get_rank, get_world_size

logger = logging.getLogger(__name__)


def extract_text_feats(texts, max_txt_l, tokenizer, model, device):
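    """Tokenize and encode all texts in mini-batches; return the concatenated
    token-level text features, attention masks, and input ids."""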
    num_text = len(texts)
    text_bs = 256
    text_feats = []
    text_atts = []
    text_ids = []

    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_txt_l,
            return_tensors="pt",
        ).to(device)

        text_feat = model.encode_text(text_input)[0]
        text_feats.append(text_feat)
        text_atts.append(text_input.attention_mask)
        text_ids.append(text_input.input_ids)

    text_feats = torch.cat(text_feats, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    return text_feats, text_atts, text_ids


def extract_vision_feats(data_loader, model, device, config):
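    """Encode every video/image in the loader; return per-clip token features and
    pooled features. Features are kept on CPU when `evaluation.eval_offload` is set."""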
    image_feats_all = []
    pooled_image_feats_all = []
    metric_logger = MetricLogger(delimiter=" ")
    header = "extracting image feats"
    iterator = metric_logger.log_every(data_loader, 100, header)
    for image, img_id in iterator:
        image = image.to(device, non_blocking=True)
        image_feat, pooled_image_feat = model.encode_vision(image, test=True)
        if config.evaluation.eval_frame_ensemble == "concat":
            if len(image_feat.shape) == 4:
                # (B, T, L, C) -> (B, T*L, C): concatenate all frame tokens
                image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
            image_feat = image_feat.unsqueeze(1)  # (B, 1, T*L, C), a single clip
        else:
            assert config.video_input.num_frames == 1, "only support single-frame"
            assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
        if config.evaluation.eval_offload:
            # offload to CPU to save GPU memory; moved back to GPU when used
            image_feats_all.append(image_feat.cpu())
            pooled_image_feats_all.append(pooled_image_feat.cpu())
        else:
            image_feats_all.append(image_feat)
            pooled_image_feats_all.append(pooled_image_feat)

    image_feats_all = torch.cat(image_feats_all, dim=0)
    pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
    return image_feats_all, pooled_image_feats_all


@torch.no_grad()
def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""):
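    """Run retrieval evaluation on rank 0 only and broadcast the result dict to all ranks."""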
    if dist.get_rank() == 0:
        # CUDA autocast only supports half-precision dtypes, so use fp16 when enabled
        with torch.cuda.amp.autocast(enabled=config.fp16, dtype=torch.float16):
            if config.model.model_cls in ["VindLU_VideoCLIP", "ViCLIP"]:
                i2t_x, t2i_x, i2t_emb, t2i_emb = evaluation_video_clip(
                    model, data_loader, tokenizer, device, config
                )
            else:
                i2t_x, t2i_x, i2t_emb, t2i_emb = evaluation(
                    model, data_loader, tokenizer, device, config
                )
        score_pairs = [
            (prefix + "/", i2t_x, t2i_x),
            (prefix + "_emb/", i2t_emb, t2i_emb),
        ]
        res = dict()
        for name, i2t, t2i in score_pairs:
            if i2t is not None:
                txt2img_ids = data_loader.dataset.txt2img
                img2txt_ids = data_loader.dataset.img2txt
                res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids)
    else:
        res = dict()

    # share rank-0 results with every other rank
    res_list = [res]
    dist.broadcast_object_list(res_list, src=0)
    res = res_list[0]
    return res


@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
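    """Dual-encoder retrieval followed by cross-encoder reranking.

    Returns the i2t/t2i reranked score matrices and the raw i2t/t2i ITC score
    matrices, all as numpy arrays.
    """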
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"
    dtype = torch.half if config.fp16 else torch.float
    media_type = data_loader.dataset.media_type
    logger.info(f"Start evaluation for media_type={media_type}")

    logger.info("Computing dual encoder features...")
    start_time = time.time()

    texts = data_loader.dataset.text
    max_txt_l = config.inputs.max_txt_l
    if not isinstance(max_txt_l, int):
        max_txt_l = max_txt_l[media_type]
    text_feats, text_atts, text_ids = extract_text_feats(
        texts, max_txt_l, tokenizer, model, device
    )

    image_feats, pooled_image_feats = extract_vision_feats(
        data_loader, model, device, config
    )
    logger.info("Finished feature extraction")

    logger.info("Computing ITC scores [dot-product]")
    _pooled_image_feats = (
        pooled_image_feats.to(device, non_blocking=True)
        if config.evaluation.eval_offload
        else pooled_image_feats
    )
    # similarity between projected pooled vision features and the [CLS] text features
    i2t_scores, t2i_scores = get_sim(
        model.vision_proj(_pooled_image_feats),
        model.text_proj(text_feats[:, 0]),
        agg_method=config.model.get("agg_method", "mean"),
    )
    logger.info("Computing ITC scores [dot-product], done!")
    num_images = len(data_loader.dataset.image)
    # -100 marks (image, text) pairs that are never reranked
    i2t_scores_x = torch.full((num_images, len(texts)), -100.0).to(
        device, torch.float, non_blocking=True
    )

    logger.info("Rerank dual-encoder results with cross-encoder...")
    num_tasks = get_world_size()
    rank = get_rank()

    # each rank reranks a contiguous slice of the images
    step = num_images // num_tasks + 1
    start = rank * step
    end = min(num_images, start + step)

    text_encoder = model.get_text_encoder()
    iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
    logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")

    n_clip_per_video = (
        image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
    )

    logger.info(
        f"n_clip_per_video={n_clip_per_video}, "
        f"with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
    )
    for i, sims in enumerate(iterator):
        k = min(len(sims), config.evaluation.k_test)
        topk_sim, topk_idx = sims.topk(k=k, dim=0)

        clip_scores = []
        for clip_idx in range(n_clip_per_video):
            if config.deep_fusion:
                encoder_output = [
                    feat[start + i, clip_idx].to(device, non_blocking=True)
                    for feat in image_feats
                ]
            else:
                encoder_output = (
                    image_feats[start + i, clip_idx].to(device, non_blocking=True)
                    if config.evaluation.eval_offload
                    else image_feats[start + i, clip_idx]
                )

            # Rerank the top-k texts for this video in mini-batches.
            bs = 128
            itm_embeds = []

            # treat the single-stream case as a one-element list so both branches share code
            if not config.deep_fusion:
                encoder_output = [encoder_output]
            encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
            encoder_att = [
                torch.ones(feat.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
                for feat in encoder_output
            ]

            for j in range(0, len(topk_idx), bs):
                cur_bs = min(bs, len(topk_idx) - j)
                # slicing is a no-op for full batches and trims the last partial batch
                encoder_output = [feat[:cur_bs] for feat in encoder_output]
                encoder_att = [att[:cur_bs] for att in encoder_att]

                batch_encoder_output = encoder_output if config.deep_fusion else encoder_output[0]
                batch_encoder_att = encoder_att if config.deep_fusion else encoder_att[0]

                if "VindLU_BLIP" in config.model.get("model_cls", ""):
                    output = model.vtm_embed(
                        text_ids=text_ids[topk_idx[j : j + bs]],
                        text_atts=text_atts[topk_idx[j : j + bs]],
                        vision_embeds=batch_encoder_output,
                        vision_atts=batch_encoder_att,
                    )
                else:
                    output = text_encoder(
                        encoder_embeds=text_feats[topk_idx[j : j + bs]],
                        attention_mask=text_atts[topk_idx[j : j + bs]],
                        encoder_hidden_states=batch_encoder_output,
                        encoder_attention_mask=batch_encoder_att,
                        return_dict=True,
                        mode="fusion",
                    ).last_hidden_state[:, 0]

                itm_embeds.append(output)

            itm_embeds = torch.cat(itm_embeds, dim=0)

            # logit of the "match" class from the ITM head
            score = model.itm_head(itm_embeds)[:, 1]
            clip_scores.append(score)

        if len(clip_scores) == 1:
            score = clip_scores[0]
        else:
            assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
            clip_scores = torch.stack(clip_scores)  # (#clips, k)
            if config.evaluation.eval_frame_ensemble == "mean":
                score = clip_scores.mean(0)
            elif config.evaluation.eval_frame_ensemble == "max":
                score = clip_scores.max(0)[0]
            elif config.evaluation.eval_frame_ensemble == "lse":
                score = torch.logsumexp(clip_scores, dim=0)
            else:
                raise ValueError(
                    "config.evaluation.eval_frame_ensemble must be in [mean, max, lse] when #clip > 1."
                )

        i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)

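    # Repeat the reranking in the text-to-image direction: for each text, rerank
    # its top-k candidate videos, sharding the texts across ranks.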
    num_text = len(data_loader.dataset.text)
    t2i_scores_x = torch.full((num_text, len(data_loader.dataset.image)), -100.0).to(
        device, torch.float, non_blocking=True
    )

    step = num_text // num_tasks + 1
    start = rank * step
    end = min(num_text, start + step)

    iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
    logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")

    n_clip_per_video = (
        image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
    )
    k = config.evaluation.k_test
    logger.info(f"Top-{k} matching")
    for i, sims in enumerate(iterator):
        k = min(len(sims), config.evaluation.k_test)
        topk_sim, topk_idx = sims.topk(k=k, dim=0)

        clip_scores = []
        for clip_idx in range(n_clip_per_video):
            # Rerank the top-k videos for this text in mini-batches.
            bs = 128
            itm_embeds = []
            for j in range(0, len(topk_idx), bs):
                # treat the single-stream case as a one-element list so both branches share code
                fake_image_feats = [image_feats] if not config.deep_fusion else image_feats

                encoder_output = [
                    feat[topk_idx[j : j + bs], clip_idx].to(device, non_blocking=True)
                    if config.evaluation.eval_offload
                    else feat[topk_idx[j : j + bs], clip_idx]
                    for feat in fake_image_feats
                ]
                encoder_att = [
                    torch.ones(feat.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
                    for feat in encoder_output
                ]
                cur_bs = min(bs, len(topk_idx) - j)

                batch_encoder_output = encoder_output if config.deep_fusion else encoder_output[0]
                batch_encoder_att = encoder_att if config.deep_fusion else encoder_att[0]

                if "VindLU_BLIP" in config.model.get("model_cls", ""):
                    output = model.vtm_embed(
                        text_ids=text_ids[start + i].repeat(cur_bs, 1),
                        text_atts=text_atts[start + i].repeat(cur_bs, 1),
                        vision_embeds=batch_encoder_output,
                        vision_atts=batch_encoder_att,
                    )
                else:
                    output = text_encoder(
                        encoder_embeds=text_feats[start + i].repeat(cur_bs, 1, 1),
                        attention_mask=text_atts[start + i].repeat(cur_bs, 1),
                        encoder_hidden_states=batch_encoder_output,
                        encoder_attention_mask=batch_encoder_att,
                        return_dict=True,
                        mode="fusion",
                    ).last_hidden_state[:, 0]

                itm_embeds.append(output)

            itm_embeds = torch.cat(itm_embeds, dim=0)

            score = model.itm_head(itm_embeds)[:, 1]
            clip_scores.append(score)

        if len(clip_scores) == 1:
            score = clip_scores[0]
        else:
            assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
            clip_scores = torch.stack(clip_scores)  # (#clips, k)
            if config.evaluation.eval_frame_ensemble == "mean":
                score = clip_scores.mean(0)
            elif config.evaluation.eval_frame_ensemble == "max":
                score = clip_scores.max(0)[0]
            elif config.evaluation.eval_frame_ensemble == "lse":
                score = torch.logsumexp(clip_scores, dim=0)
            else:
                raise ValueError(
                    "config.evaluation.eval_frame_ensemble must be in [mean, max, lse] when #clip > 1."
                )

        t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)

    if config.distributed:
        # merge the per-rank partial score matrices
        dist.barrier()
        dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
        dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Evaluation time {total_time_str}")

    return (
        i2t_scores_x.cpu().float().numpy(),
        t2i_scores_x.cpu().float().numpy(),
        i2t_scores.cpu().float().numpy(),
        i2t_scores.T.cpu().float().numpy(),
    )


@torch.no_grad()
def evaluation_video_clip(model, data_loader, tokenizer, device, config):
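    """Dual-encoder-only evaluation for CLIP-style video models (no cross-encoder reranking)."""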
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"

    dtype = torch.float32
    media_type = data_loader.dataset.media_type
    logger.info(f"Start evaluation for media_type={media_type}")

    logger.info("Computing dual encoder features...")

    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_feats = []
    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_feat = model.encode_text(text)
        text_feats.append(text_feat.cpu())
    text_feats = torch.cat(text_feats, dim=0)
    logger.info("Finished computing text features")

    if hasattr(data_loader.dataset, "num_prompts"):
        num_prompts = data_loader.dataset.num_prompts
        logger.info("Using {} prompts".format(num_prompts))
        nt = len(data_loader.dataset.text) // num_prompts
        text_feats = text_feats.view(nt, num_prompts, -1)

    image_feats = []
    metric_logger = MetricLogger(delimiter=" ")
    header = "extracting image feats"
    iterator = metric_logger.log_every(data_loader, 100, header)
    for image, _ in iterator:
        image = image.to(device, non_blocking=True)
        image_feat = model.encode_vision(image, test=True)
        image_feats.append(image_feat.cpu())
    image_feats = torch.cat(image_feats, dim=0)
    logger.info("Finished feature extraction")

    logger.info("Computing ITC scores [dot-product]")
    i2t_scores, t2i_scores = get_sim(image_feats, text_feats)
    del image_feats, text_feats
    logger.info("Computing ITC scores [dot-product], done!")

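    # Dual-softmax (DSL) reweighting: scale each direction's scores by the softmax
    # taken over the opposite retrieval direction before ranking.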
    i2t_scores_dsl = i2t_scores * i2t_scores.softmax(dim=0)
    i2t_scores_dsl_T = i2t_scores.T * i2t_scores.T.softmax(dim=0)

    return (
        i2t_scores.cpu().float().numpy(),
        i2t_scores.T.cpu().float().numpy(),
        i2t_scores_dsl.cpu().float().numpy(),
        i2t_scores_dsl_T.cpu().float().numpy(),
    )


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
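    """Compute R@1/5/10 and their means for image-to-text and text-to-image retrieval
    from the given score matrices and ground-truth mappings."""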
    # Image -> Text: rank of the best-matching ground-truth caption per image
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]

        gt_txt_ids = img2txt[index]
        if isinstance(gt_txt_ids, int):
            ranks[index] = np.where(inds == gt_txt_ids)[0][0]
        else:
            rank = 1e20
            for i in gt_txt_ids:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    # Text -> Image: rank of the best-matching ground-truth image per text
    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        gt_img_ids = txt2img[index]
        if isinstance(gt_img_ids, int):
            ranks[index] = np.where(inds == gt_img_ids)[0][0]
        else:
            rank = 1e20
            for i in gt_img_ids:
                tmp = np.where(inds == i)[0][0]
                if tmp < rank:
                    rank = tmp
            ranks[index] = rank

    ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {
        "txt_r1": tr1,
        "txt_r5": tr5,
        "txt_r10": tr10,
        "txt_r_mean": tr_mean,
        "img_r1": ir1,
        "img_r5": ir5,
        "img_r10": ir10,
        "img_r_mean": ir_mean,
        "r_mean": r_mean,
    }
    eval_result = {k: round(v, 2) for k, v in eval_result.items()}
    return eval_result