import os
import math
import sys
import gc
import pickle
from typing import Iterable, Optional

import numpy as np
import torch
from scipy.special import softmax
from timm.utils import accuracy, ModelEma

from mixup import Mixup
import utils_mae as utils


def train_class_batch(model, samples, target, criterion):
    outputs = model(samples)
    loss = criterion(outputs, target)
    return loss, outputs


def get_loss_scale_for_deepspeed(model):
    # DeepSpeed keeps the dynamic loss scale on its wrapped optimizer; the
    # attribute name differs between FP16 optimizer variants.
    optimizer = model.optimizer
    return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale


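# Fine-tuning train loop. Two execution paths are supported:
#   * loss_scaler is None  -> the model is a DeepSpeed engine that owns the
#     optimizer, loss scaling, and gradient accumulation (model.backward() /
#     model.step()).
#   * loss_scaler given    -> native torch.cuda.amp autocast; gradients are
#     accumulated over `update_freq` micro-batches and the optimizer steps
#     once per accumulation window.
# Learning rate and weight decay are read per iteration from
# lr_schedule_values / wd_schedule_values and written into each param group.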
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
                    start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
                    num_training_steps_per_epoch=None, update_freq=None):
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if loss_scaler is None:
        # DeepSpeed path: the engine owns the optimizer state.
        model.zero_grad()
        model.micro_steps = 0
    else:
        optimizer.zero_grad()

    for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        if step >= num_training_steps_per_epoch:
            continue
        it = start_steps + step  # global iteration, used to index the schedules

        # Apply the LR / weight-decay schedules once per accumulation window.
        if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
            for i, param_group in enumerate(optimizer.param_groups):
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if loss_scaler is None:
            samples = samples.half()
            loss, output = train_class_batch(
                model, samples, targets, criterion)
        else:
            with torch.cuda.amp.autocast():
                loss, output = train_class_batch(
                    model, samples, targets, criterion)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        if loss_scaler is None:
            # DeepSpeed performs backward, loss scaling, and the optimizer step.
            loss /= update_freq
            model.backward(loss)
            model.step()

            if (data_iter_step + 1) % update_freq == 0:
                if model_ema is not None:
                    model_ema.update(model)
            grad_norm = None
            loss_scale_value = get_loss_scale_for_deepspeed(model)
        else:
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss /= update_freq
            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                    parameters=model.parameters(), create_graph=is_second_order,
                                    update_grad=(data_iter_step + 1) % update_freq == 0)
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
            loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        if mixup_fn is None:
            class_acc = (output.max(-1)[-1] == targets).float().mean()
        else:
            # With mixup the targets are soft labels, so batch accuracy is undefined.
            class_acc = None
        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)
        metric_logger.update(loss_scale=loss_scale_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(class_acc=class_acc, head="loss")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")

            log_writer.set_step()

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


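# Single-view validation: cross-entropy loss plus top-1/top-5 accuracy,
# averaged across distributed processes by the MetricLogger.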
@torch.no_grad()
def validation_one_epoch(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Val:'

    model.eval()

    for batch in metric_logger.log_every(data_loader, 10, header):
        videos = batch[0]
        target = batch[1]
        videos = videos.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            output = model(videos)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = videos.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


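# Multi-view test pass. Each rank writes its own text file: a header line with
# the last batch's accuracies (skipped by merge()), then one line per clip view
# in the format "<id> <logit list> <label> <chunk_nb> <split_nb>".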
@torch.no_grad()
def final_test(data_loader, model, device, file):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    model.eval()
    final_result = []

    for batch in metric_logger.log_every(data_loader, 10, header):
        videos = batch[0]
        target = batch[1]
        ids = batch[2]
        chunk_nb = batch[3]
        split_nb = batch[4]
        videos = videos.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            output = model(videos)
            loss = criterion(output, target)

        for i in range(output.size(0)):
            string = "{} {} {} {} {}\n".format(ids[i],
                                               str(output.data[i].cpu().numpy().tolist()),
                                               str(int(target[i].cpu().numpy())),
                                               str(int(chunk_nb[i].cpu().numpy())),
                                               str(int(split_nb[i].cpu().numpy())))
            final_result.append(string)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = videos.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    if not os.path.exists(file):
        os.mknod(file)
    with open(file, 'w') as f:
        f.write("{}, {}\n".format(acc1, acc5))
        for line in final_result:
            f.write(line)

    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


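# Aggregate the per-rank files written by final_test: softmax scores for the
# different (chunk, split) views of a video are averaged before the argmax,
# then top-1/top-5 accuracy is averaged over videos.
#
# Typical use after distributed testing (a sketch; the output directory, rank,
# and world size names are placeholders):
#   final_test(test_loader, model, device, os.path.join(eval_dir, f"{rank}.txt"))
#   top1, top5 = merge(eval_dir, num_tasks=world_size)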
def merge(eval_path, num_tasks):
    dict_feats = {}
    dict_label = {}
    dict_pos = {}
    print("Reading individual output files")

    for x in range(num_tasks):
        file = os.path.join(eval_path, str(x) + '.txt')
        lines = open(file, 'r').readlines()[1:]  # skip the header line
        for line in lines:
            line = line.strip()
            name = line.split('[')[0]
            label = line.split(']')[1].split(' ')[1]
            chunk_nb = line.split(']')[1].split(' ')[2]
            split_nb = line.split(']')[1].split(' ')[3]
            data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')
            data = softmax(data)
            if name not in dict_feats:
                dict_feats[name] = []
                dict_label[name] = 0
                dict_pos[name] = []
            if chunk_nb + split_nb in dict_pos[name]:
                # This (chunk, split) view of the video was already collected.
                continue
            dict_feats[name].append(data)
            dict_pos[name].append(chunk_nb + split_nb)
            dict_label[name] = label
    print("Computing final results")

    input_lst = []
    print(len(dict_feats))
    for i, item in enumerate(dict_feats):
        input_lst.append([i, item, dict_feats[item], dict_label[item]])
    from multiprocessing import Pool
    p = Pool(64)
    ans = p.map(compute_video, input_lst)
    p.close()
    top1 = [x[1] for x in ans]
    top5 = [x[2] for x in ans]
    pred = [x[0] for x in ans]
    label = [x[3] for x in ans]
    final_top1, final_top5 = np.mean(top1), np.mean(top5)
    return final_top1 * 100, final_top5 * 100


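# Per-video scoring helper used by merge() and merge_mean_per_class(): averages
# the softmax scores of all views and checks the argmax / top-5 against the label.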
def compute_video(lst):
    i, video_id, data, label = lst
    feat = [x for x in data]
    feat = np.mean(feat, axis=0)  # average the view-level softmax scores
    pred = np.argmax(feat)
    top1 = (int(pred) == int(label)) * 1.0
    top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
    return [pred, top1, top5, int(label)]


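# Same aggregation as merge(), but accuracies are first computed per class and
# then averaged over the classes actually present in the test set (mean class
# accuracy). The per-class top-1 list is also pickled to
# <eval_path>/classwise_top1.pkl.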
def merge_mean_per_class(eval_path, num_tasks, nb_classes):
    dict_feats = {}
    dict_label = {}
    dict_pos = {}

    for x in range(num_tasks):
        file = os.path.join(eval_path, str(x) + '.txt')
        lines = open(file, 'r').readlines()[1:]  # skip the header line
        for line in lines:
            line = line.strip()
            name = line.split('[')[0]
            label = line.split(']')[1].split(' ')[1]
            chunk_nb = line.split(']')[1].split(' ')[2]
            split_nb = line.split(']')[1].split(' ')[3]
            data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')
            data = softmax(data)
            if name not in dict_feats:
                dict_feats[name] = []
                dict_label[name] = 0
                dict_pos[name] = []
            if chunk_nb + split_nb in dict_pos[name]:
                continue
            dict_feats[name].append(data)
            dict_pos[name].append(chunk_nb + split_nb)
            dict_label[name] = label
    print("Computing mean per class results")

    input_lst = []
    all_pred = []
    all_label = []

    classes = torch.arange(nb_classes)
    classwise_top1 = [0 for c in classes]
    classwise_top5 = [0 for c in classes]
    actual_nb_classes = nb_classes
    cnt = 0

    for c in classes:
        # Collect every video whose ground-truth label is class c.
        input_lst = []
        for i, item in enumerate(dict_feats):
            if int(dict_label[item]) == c:
                input_lst.append([i, item, dict_feats[item], dict_label[item]])
        cnt += len(input_lst)

        if len(input_lst) == 0:
            actual_nb_classes -= 1
            print(f"Class {c} is not present in test set, skip")
            continue

        ans = []
        for i in input_lst:
            ans.append(compute_video(i))
        top1 = [x[1] for x in ans]
        top5 = [x[2] for x in ans]
        pred = [x[0] for x in ans]
        label = [x[3] for x in ans]

        final_top1, final_top5 = np.mean(top1), np.mean(top5)

        classwise_top1[c] = final_top1 * 100
        classwise_top5[c] = final_top5 * 100

        del input_lst
        del ans
        del top1
        del top5
        del pred
        del label
        gc.collect()

    # Every video must have been assigned to exactly one class bucket.
    assert cnt == len(dict_feats)

    classwise_top1_path = os.path.join(eval_path, "classwise_top1.pkl")
    with open(classwise_top1_path, 'wb') as file:
        pickle.dump(classwise_top1, file)

    # Average over the classes that actually appear in the test set.
    classwise_top1 = np.sum(classwise_top1) / actual_nb_classes
    classwise_top5 = np.sum(classwise_top5) / actual_nb_classes

    return classwise_top1, classwise_top5