# SMILE/tasks/retrieval_utils.py
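"""Retrieval evaluation utilities: feature extraction, ITC/ITM scoring, and recall metrics."""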
import datetime
import logging
import time
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from models.criterions import get_sim
from utils.basic_utils import MetricLogger
from utils.distributed import get_rank, get_world_size
logger = logging.getLogger(__name__)
def extract_text_feats(texts, max_txt_l, tokenizer, model, device):
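    """Encode all captions with the text encoder in mini-batches of 256.

    Returns token-level text features, attention masks, and input ids,
    each concatenated over the full text set.
    """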
num_text = len(texts)
text_bs = 256
text_feats = []
text_atts = []
text_ids = []
for i in range(0, num_text, text_bs):
text = texts[i : min(num_text, i + text_bs)]
text_input = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=max_txt_l,
return_tensors="pt",
).to(device)
text_feat = model.encode_text(text_input)[0]
text_feats.append(text_feat)
text_atts.append(text_input.attention_mask)
text_ids.append(text_input.input_ids)
text_feats = torch.cat(text_feats, dim=0)
text_atts = torch.cat(text_atts, dim=0)
text_ids = torch.cat(text_ids, dim=0)
return text_feats, text_atts, text_ids
def extract_vision_feats(data_loader, model, device, config):
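    """Encode every clip in the loader with the vision encoder.

    Returns (image_feats, pooled_image_feats); features are kept on CPU when
    config.evaluation.eval_offload is set to save GPU memory.
    """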
image_feats_all = []
pooled_image_feats_all = []
metric_logger = MetricLogger(delimiter=" ")
header = "extracting image feats"
iterator = metric_logger.log_every(data_loader, 100, header)
for image, img_id in iterator:
image = image.to(device, non_blocking=True)
image_feat, pooled_image_feat = model.encode_vision(image, test=True)
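        # "concat" (default) flattens all frame tokens into one sequence for joint
        # cross-attention later; mean/max/lse instead keep per-clip features and
        # ensemble the per-clip ITM scores inside evaluation().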
if config.evaluation.eval_frame_ensemble == "concat": # default
if len(image_feat.shape) == 4:
image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
image_feat = image_feat.unsqueeze(1) # (bsz, 1, #frm*L, d)
else:
assert config.video_input.num_frames == 1, "only support single-frame"
assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
if config.evaluation.eval_offload:
image_feats_all.append(image_feat.cpu())
pooled_image_feats_all.append(pooled_image_feat.cpu())
else:
image_feats_all.append(image_feat)
pooled_image_feats_all.append(pooled_image_feat)
image_feats_all = torch.cat(image_feats_all, dim=0)
pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
return image_feats_all, pooled_image_feats_all
@torch.no_grad()
def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""):
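    """Run retrieval evaluation on rank 0 only and broadcast the metric dict to all ranks.

    Dual-encoder models (VindLU_VideoCLIP, ViCLIP) are scored by similarity alone;
    other models are additionally reranked with the cross-encoder in `evaluation`.
    """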
if dist.get_rank() == 0: # Only on one rank
        with torch.cuda.amp.autocast(enabled=config.fp16, dtype=torch.float):
            if config.model.model_cls in ("VindLU_VideoCLIP", "ViCLIP"):
i2t_x, t2i_x, i2t_emb, t2i_emb = evaluation_video_clip(
model, data_loader, tokenizer, device, config
)
else:
i2t_x, t2i_x, i2t_emb, t2i_emb = evaluation(
model, data_loader, tokenizer, device, config
)
score_pairs = [
(prefix + "/", i2t_x, t2i_x),
(prefix + "_emb/", i2t_emb, t2i_emb),
]
res = dict()
for name, i2t, t2i in score_pairs:
if i2t is not None:
txt2img_ids = data_loader.dataset.txt2img
img2txt_ids = data_loader.dataset.img2txt
res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids)
else:
res = dict()
res_list = [res]
dist.broadcast_object_list(res_list, src=0)
res = res_list[0]
return res
@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
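    """Two-stage retrieval evaluation.

    1) Compute dual-encoder (ITC) similarities between all videos and texts.
    2) Rerank each query's top-k candidates with the cross-encoder (ITM head),
       sharding queries across ranks and summing the partial score matrices.

    Returns the reranked i2t/t2i score matrices and the raw i2t/t2i similarities,
    all as numpy arrays.
    """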
model.eval()
metric_logger = MetricLogger(delimiter=" ")
header = "Evaluation:"
dtype = torch.half if config.fp16 else torch.float
media_type = data_loader.dataset.media_type
logger.info(f"Start evaluation for media_type={media_type}")
logger.info("Computing dual encoder features...")
start_time = time.time()
# this computes all features in each GPU
texts = data_loader.dataset.text
max_txt_l = config.inputs.max_txt_l
if not isinstance(max_txt_l, int):
max_txt_l = max_txt_l[media_type]
text_feats, text_atts, text_ids = extract_text_feats(
texts, max_txt_l, tokenizer, model, device
) # (bsz, Lt, d), (bsz, Lt)
image_feats, pooled_image_feats = extract_vision_feats(
data_loader, model, device, config
) # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
logger.info("Finished feature extraction")
logger.info("Computing ITC scores [dot-product]")
_pooled_image_feats = (
pooled_image_feats.to(device, non_blocking=True)
if config.evaluation.eval_offload
else pooled_image_feats
)
i2t_scores, t2i_scores = get_sim(
model.vision_proj(_pooled_image_feats), model.text_proj(text_feats[:, 0]),
agg_method=config.model.get("agg_method", "mean"),
)
logger.info("Computing ITC scores [dot-product], done!")
num_images = len(data_loader.dataset.image)
i2t_scores_x = torch.full((num_images, len(texts)), -100.0).to(
device, torch.float, non_blocking=True
)
# computes only part of the scores at each GPU, gather at the end
logger.info("Rerank dual-encoder results with cross-encoder...")
num_tasks = get_world_size()
rank = get_rank()
# only uses the part associated with the raw eval set
# compute image2text #
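    # shard the query rows across ranks: rank r reranks rows [r*step, min(num_images, (r+1)*step))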
step = num_images // num_tasks + 1
start = rank * step
end = min(num_images, start + step)
text_encoder = model.get_text_encoder()
iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")
# generate score for each clip, and aggregate all clip scores for a video
n_clip_per_video = (
image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
)
logger.info(
f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
)
for i, sims in enumerate(iterator):
k = min(len(sims), config.evaluation.k_test)
topk_sim, topk_idx = sims.topk(k=k, dim=0)
clip_scores = []
for clip_idx in range(n_clip_per_video):
if config.deep_fusion:
encoder_output = [
feat[start + i, clip_idx].to(device, non_blocking=True)
for feat in image_feats
]
else:
encoder_output = (
image_feats[start + i, clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else image_feats[start + i, clip_idx]
) # (#frm*Li, d)
""" original
encoder_output = encoder_output.repeat(k, 1, 1) # (k=128, #frm*Li, d)
encoder_att = torch.ones(
encoder_output.size()[:-1], dtype=torch.long
).to(device, non_blocking=True)
output = text_encoder(
encoder_embeds=text_feats[topk_idx],
attention_mask=text_atts[topk_idx],
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_att,
return_dict=True,
mode="fusion"
)
itm_embeds = output.last_hidden_state[:, 0]
"""
            # score the top-k candidates with the cross-encoder (ITM) in mini-batches
            bs = 128  # mini-batch size; config.batch_size_test.video could be used instead
            itm_embeds = []
            if not config.deep_fusion:  # wrap in a list so both branches share one code path
                encoder_output = [encoder_output]
encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
for feat in encoder_output
]
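            # the last mini-batch may be smaller than bs; trim the tiled vision features to match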
for j in range(0, len(topk_idx), bs):
cur_bs = min(bs, len(topk_idx) - j)
encoder_output = [feat[:cur_bs] for feat in encoder_output]
encoder_att = [att[:cur_bs] for att in encoder_att]
batch_encoder_output = encoder_output if config.deep_fusion else encoder_output[0]
batch_encoder_att = encoder_att if config.deep_fusion else encoder_att[0]
if "VindLU_BLIP" in config.model.get("model_cls", ""):
output = model.vtm_embed(
text_ids=text_ids[topk_idx[j:j+bs]],
text_atts=text_atts[topk_idx[j:j+bs]],
vision_embeds=batch_encoder_output,
vision_atts=batch_encoder_att,
)
else:
output = text_encoder(
encoder_embeds=text_feats[topk_idx[j:j+bs]],
attention_mask=text_atts[topk_idx[j:j+bs]],
encoder_hidden_states=batch_encoder_output,
encoder_attention_mask=batch_encoder_att,
return_dict=True,
mode="fusion",
).last_hidden_state[:, 0]
itm_embeds.append(output)
itm_embeds = torch.cat(itm_embeds, dim=0)
""" Original
if config.deep_fusion:
encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in encoder_output
]
else:
encoder_output = encoder_output.repeat(bs, 1, 1)
encoder_att = torch.ones(
encoder_output.size()[:-1], dtype=torch.long
).to(device, non_blocking=True)
if config.deep_fusion:
if len(topk_idx) % bs != 0:
left = len(topk_idx) % bs
left_encoder_output = [feat.repeat(left, 1, 1) for feat in encoder_output]
left_encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in left_encoder_output
]
encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in encoder_output
]
else:
if len(topk_idx) % bs != 0:
left = len(topk_idx) % bs
left_encoder_output = encoder_output.repeat(left, 1, 1) # (k=128, #frm*Li, d)
left_encoder_att = torch.ones(left_encoder_output.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
encoder_output = encoder_output.repeat(bs, 1, 1) # (k=128, #frm*Li, d)
encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for j in range(0, len(topk_idx), bs):
if j + bs > len(topk_idx):
output = text_encoder(
encoder_embeds=text_feats[topk_idx[j:]],
attention_mask=text_atts[topk_idx[j:]],
encoder_hidden_states=left_encoder_output,
encoder_attention_mask=left_encoder_att,
return_dict=True,
mode="fusion",
)
else:
output = text_encoder(
encoder_embeds=text_feats[topk_idx[j : j + bs]],
attention_mask=text_atts[topk_idx[j : j + bs]],
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_att,
return_dict=True,
mode="fusion",
)
batch_itm_embeds = output.last_hidden_state[:, 0]
itm_embeds.append(batch_itm_embeds)
itm_embeds = torch.cat(itm_embeds, dim=0)
# end new
"""
score = model.itm_head(itm_embeds)[:, 1]
clip_scores.append(score)
if len(clip_scores) == 1:
score = clip_scores[0]
else:
assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
clip_scores = torch.stack(clip_scores) # (#clips, k)
if config.evaluation.eval_frame_ensemble == "mean":
score = clip_scores.mean(0)
elif config.evaluation.eval_frame_ensemble == "max":
score = clip_scores.max(0)[0]
elif config.evaluation.eval_frame_ensemble == "lse": # LogSumExp
score = torch.logsumexp(clip_scores, dim=0)
else:
                raise ValueError(
                    "config.evaluation.eval_frame_ensemble must be one of [mean, max, lse] when #clip > 1."
                )
i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)
# compute text2image #
num_text = len(data_loader.dataset.text)
t2i_scores_x = torch.full((num_text, len(data_loader.dataset.image)), -100.0).to(
device, torch.float, non_blocking=True
)
step = num_text // num_tasks + 1
start = rank * step
end = min(num_text, start + step)
iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
# generate score for each clip, and aggregate all clip scores for a video
n_clip_per_video = (
image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
)
k = config.evaluation.k_test
logger.info(f"Top-{k} matching")
for i, sims in enumerate(iterator):
k = min(len(sims), config.evaluation.k_test)
topk_sim, topk_idx = sims.topk(k=k, dim=0)
clip_scores = []
for clip_idx in range(n_clip_per_video):
"""old
encoder_output = image_feats[topk_idx, clip_idx].to(device, non_blocking=True) \
if config.evaluation.eval_offload else image_feats[topk_idx, clip_idx]
encoder_att = torch.ones(
encoder_output.size()[:-1], dtype=torch.long
).to(device, non_blocking=True)
output = text_encoder(
encoder_embeds=text_feats[start+i].repeat(k, 1, 1),
attention_mask=text_atts[start+i].repeat(k, 1),
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_att,
return_dict=True,
mode="fusion"
)
itm_embeds = output.last_hidden_state[:, 0]
"""
            # score the top-k candidate videos with the cross-encoder (ITM) in mini-batches
            bs = 128  # mini-batch size; config.batch_size_test.video could be used instead
            itm_embeds = []
for j in range(0, len(topk_idx), bs):
fake_image_feats = [image_feats] if not config.deep_fusion else image_feats
encoder_output = [
feat[topk_idx[j : j + bs], clip_idx].to(device, non_blocking=True)
if config.evaluation.eval_offload
else feat[topk_idx[j : j + bs], clip_idx]
for feat in fake_image_feats
]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in encoder_output
]
cur_bs = min(bs, len(topk_idx) - j)
batch_encoder_output = encoder_output if config.deep_fusion else encoder_output[0]
batch_encoder_att = encoder_att if config.deep_fusion else encoder_att[0]
if "VindLU_BLIP" in config.model.get("model_cls", ""):
output = model.vtm_embed(
text_ids=text_ids[start + i].repeat(cur_bs, 1),
text_atts=text_atts[start + i].repeat(cur_bs, 1),
vision_embeds=batch_encoder_output,
vision_atts=batch_encoder_att,
)
else:
output = text_encoder(
encoder_embeds=text_feats[start + i].repeat(cur_bs, 1, 1),
attention_mask=text_atts[start + i].repeat(cur_bs, 1),
encoder_hidden_states=batch_encoder_output,
encoder_attention_mask=batch_encoder_att,
return_dict=True,
mode="fusion",
).last_hidden_state[:, 0]
itm_embeds.append(output)
""" old
if config.deep_fusion:
encoder_output = [
feat[topk_idx[j : j + bs], clip_idx].to(device, non_blocking=True)
for feat in image_feats
]
encoder_att = [
torch.ones(feat.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
for feat in encoder_output
]
else:
encoder_output = (
image_feats[topk_idx[j : j + bs], clip_idx].to(
device, non_blocking=True
)
if config.evaluation.eval_offload
else image_feats[topk_idx[j : j + bs], clip_idx]
)
encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
device, non_blocking=True
)
cur_bs = (
encoder_output.shape[0]
if not config.deep_fusion
else encoder_output[0].shape[0]
)
output = text_encoder(
encoder_embeds=text_feats[start + i].repeat(cur_bs, 1, 1),
attention_mask=text_atts[start + i].repeat(cur_bs, 1),
encoder_hidden_states=encoder_output,
encoder_attention_mask=encoder_att,
return_dict=True,
mode="fusion",
)
batch_itm_embeds = output.last_hidden_state[:, 0]
itm_embeds.append(batch_itm_embeds)
"""
itm_embeds = torch.cat(itm_embeds, dim=0)
score = model.itm_head(itm_embeds)[:, 1]
clip_scores.append(score)
if len(clip_scores) == 1:
score = clip_scores[0]
else:
assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
clip_scores = torch.stack(clip_scores) # (#clips, k)
if config.evaluation.eval_frame_ensemble == "mean":
score = clip_scores.mean(0)
elif config.evaluation.eval_frame_ensemble == "max":
score = clip_scores.max(0)[0]
elif config.evaluation.eval_frame_ensemble == "lse": # LogSumExp
score = torch.logsumexp(clip_scores, dim=0)
else:
                raise ValueError(
                    "config.evaluation.eval_frame_ensemble must be one of [mean, max, lse] when #clip > 1."
                )
t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)
if config.distributed:
        # gather across GPUs
dist.barrier()
dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info(f"Evaluation time {total_time_str}")
return (
i2t_scores_x.cpu().float().numpy(),
t2i_scores_x.cpu().float().numpy(),
i2t_scores.cpu().float().numpy(),
i2t_scores.T.cpu().float().numpy(),
)
@torch.no_grad()
def evaluation_video_clip(model, data_loader, tokenizer, device, config):
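    """Similarity-only evaluation for CLIP-style dual encoders.

    Returns the raw i2t/t2i similarities plus dual-softmax reweighted versions,
    all as numpy arrays.
    """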
model.eval()
metric_logger = MetricLogger(delimiter=" ")
header = "Evaluation:"
    dtype = torch.float32
media_type = data_loader.dataset.media_type
logger.info(f"Start evaluation for media_type={media_type}")
logger.info("Computing dual encoder features...")
# this computes all features in each GPU
texts = data_loader.dataset.text
num_text = len(texts)
text_bs = 256
text_feats = []
for i in range(0, num_text, text_bs):
text = texts[i : min(num_text, i + text_bs)]
text_feat = model.encode_text(text)
text_feats.append(text_feat.cpu())
text_feats = torch.cat(text_feats, dim=0)
logger.info("Finished computing text features")
    if hasattr(data_loader.dataset, "num_prompts"):
        # use a local name that does not shadow the module-level numpy alias `np`
        num_prompts = data_loader.dataset.num_prompts
        logger.info("Using {} prompts".format(num_prompts))
        nt = len(data_loader.dataset.text) // num_prompts
        text_feats = text_feats.view(nt, num_prompts, -1)
image_feats = []
metric_logger = MetricLogger(delimiter=" ")
header = "extracting image feats"
iterator = metric_logger.log_every(data_loader, 100, header)
for image, _ in iterator:
image = image.to(device, non_blocking=True)
image_feat = model.encode_vision(image, test=True)
image_feats.append(image_feat.cpu())
image_feats = torch.cat(image_feats, dim=0)
logger.info("Finished feature extraction")
logger.info("Computing ITC scores [dot-product]")
i2t_scores, t2i_scores = get_sim(image_feats, text_feats)
del image_feats, text_feats
logger.info("Computing ITC scores [dot-product], done!")
i2t_scores_dsl = i2t_scores * i2t_scores.softmax(dim=0)
i2t_scores_dsl_T = i2t_scores.T * i2t_scores.T.softmax(dim=0)
return (
i2t_scores.cpu().float().numpy(),
i2t_scores.T.cpu().float().numpy(),
i2t_scores_dsl.cpu().float().numpy(),
i2t_scores_dsl_T.cpu().float().numpy(),
)
@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
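    """Compute recall metrics (R@1/5/10 and mean recalls) from retrieval score matrices.

    `txt2img` / `img2txt` map each text (image) index to its ground-truth image (text)
    index, or to a list of indices when a query has multiple ground truths.
    """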
# Images->Text
ranks = np.zeros(scores_i2t.shape[0])
for index, score in enumerate(scores_i2t):
inds = np.argsort(score)[::-1]
# Score
gt_txt_ids = img2txt[index]
if isinstance(gt_txt_ids, int):
ranks[index] = np.where(inds == gt_txt_ids)[0][0]
else:
rank = 1e20
for i in gt_txt_ids:
tmp = np.where(inds == i)[0][0]
if tmp < rank:
rank = tmp
ranks[index] = rank
# Compute metrics
tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
# Text->Images
ranks = np.zeros(scores_t2i.shape[0])
for index, score in enumerate(scores_t2i):
inds = np.argsort(score)[::-1]
gt_img_ids = txt2img[index]
if isinstance(gt_img_ids, int):
ranks[index] = np.where(inds == gt_img_ids)[0][0]
else: # list, used in the case each caption has multiple GT images
# Score
rank = 1e20
for i in gt_img_ids:
tmp = np.where(inds == i)[0][0]
if tmp < rank:
rank = tmp
ranks[index] = rank
# Compute metrics
ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
tr_mean = (tr1 + tr5 + tr10) / 3
ir_mean = (ir1 + ir5 + ir10) / 3
r_mean = (tr_mean + ir_mean) / 2
eval_result = {
"txt_r1": tr1,
"txt_r5": tr5,
"txt_r10": tr10,
"txt_r_mean": tr_mean,
"img_r1": ir1,
"img_r5": ir5,
"img_r10": ir10,
"img_r_mean": ir_mean,
"r_mean": r_mean,
}
eval_result = {k: round(v, 2) for k, v in eval_result.items()}
return eval_result