import copy
import logging
import os
import os.path as osp
from os.path import join
import io
from copy import deepcopy

import torch
from torch.utils.data import ConcatDataset, DataLoader

from models_viclip.backbones.clip.clip_vision import interpolate_pos_embed_vit
from models_viclip.backbones.bert.tokenization_bert import BertTokenizer

# from utils_viclip.optimizer import create_optimizer
# from utils_viclip.scheduler import create_scheduler
# from utils_viclip.distributed import get_world_size
# import deepspeed

# logger = logging.getLogger(__name__)

def get_media_types(datasources):
    """Get the media types for all the dataloaders.

    Args:
        datasources (List): List of dataloaders or datasets.

    Returns:
        List: The media types, one per datasource.
    """
    if isinstance(datasources[0], DataLoader):
        datasets = [dataloader.dataset for dataloader in datasources]
    else:
        datasets = datasources

    media_types = [
        dataset.datasets[0].media_type
        if isinstance(dataset, ConcatDataset)
        else dataset.media_type
        for dataset in datasets
    ]
    return media_types
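
# Illustrative usage sketch (not part of the original pipeline): the dummy dataset
# below is hypothetical and only mimics the `media_type` attribute this helper reads.
def _example_get_media_types():
    from torch.utils.data import Dataset

    class _DummyDataset(Dataset):
        """Minimal dataset exposing a `media_type` attribute ("image" or "video")."""

        def __init__(self, media_type):
            self.media_type = media_type

        def __len__(self):
            return 1

        def __getitem__(self, idx):
            return torch.zeros(1)

    loaders = [
        DataLoader(_DummyDataset("video")),
        DataLoader(ConcatDataset([_DummyDataset("image"), _DummyDataset("image")])),
    ]
    return get_media_types(loaders)  # -> ["video", "image"]
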
def setup_model(
    config,
    model_cls,
    has_decoder=False,
    pretrain=False,
    find_unused_parameters=False,
    num_steps_per_epoch=-1,
    num_classes=101,
):
    # logger.info("Creating model")
    config = copy.deepcopy(config)
    # tokenizer = BertTokenizer.from_pretrained(config.model.text_encoder.pretrained)
    # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", local_files_only=True)
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = model_cls(config=config, tokenizer=tokenizer, is_pretrain=pretrain, num_classes=num_classes)
    # model = model.to(torch.device(config.device))
    # model_without_ddp = model
    # if hasattr(config, "deepspeed") and config.deepspeed.enable:
    #     optimizer_params = create_optimizer(config.optimizer, model, return_group=True)
    #     scheduler = None
    #     scaler = None
    # else:
    #     if config.distributed:
    #         model = torch.nn.parallel.DistributedDataParallel(
    #             model,
    #             device_ids=[config.gpu],
    #             find_unused_parameters=find_unused_parameters,  # `False` for image-only tasks
    #         )
    #     optimizer = create_optimizer(config.optimizer, model)
    #     scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)  # never actually used
    #     scheduler = create_scheduler(config.scheduler, optimizer)
    # start_epoch = 0
    # global_step = 0
    # Auto-resume from the latest checkpoint.
    # if config.get("auto_resume", False):
    #     logger.info("Auto resuming")
    #     model_latest = join(config.output_dir, "ckpt_latest.pth")
    #     model_best = join(config.output_dir, "ckpt_best.pth")
    #     large_num = -1
    #     for p in os.listdir(config.output_dir):
    #         if 'ckpt' in p:
    #             num = p.split('_')[1].split('.')[0]
    #             if str.isnumeric(num):
    #                 if int(num) > large_num:
    #                     large_num = int(num)
    #     if large_num != -1:
    #         model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
    #     if osp.exists(model_latest):
    #         config.pretrained_path = model_latest
    #         config.resume = True
    #     elif osp.exists(model_best):
    #         config.pretrained_path = model_best
    #         config.resume = True
    #     else:
    #         logger.info(f"No checkpoint found in {config.output_dir}")
    # if config.pretrained_path.strip() and (osp.isfile(config.pretrained_path) or "s3://" in config.pretrained_path):
    #     logger.info(f"Loading checkpoint from {config.pretrained_path}")
    #     checkpoint = torch.load(config.pretrained_path, map_location="cpu")
    #     try:
    #         state_dict = checkpoint["model"]
    #     except KeyError:  # this is a deepspeed stage-1 checkpoint
    #         state_dict = checkpoint["module"]
    #     if config.resume:
    #         optimizer.load_state_dict(checkpoint["optimizer"])
    #         scheduler.load_state_dict(checkpoint["scheduler"])
    #         scaler.load_state_dict(checkpoint["scaler"])
    #         start_epoch = checkpoint["epoch"] + 1
    #         global_step = checkpoint["global_step"]
    #     elif config.evaluate or (not pretrain):  # downstream init from a pretrained ckpt
    #         is_blip_model = "VindLU_BLIP" in config.model.get("model_cls", "")
    #         # Interpolate positional embeddings.
    #         if "vit" in config.model.vision_encoder.name:
    #             state_dict = interpolate_pos_embed_vit(state_dict, model_without_ddp)
    #         else:
    #             raise ValueError(
    #                 f"vision encoder: {config.model.vision_encoder.name} not implemented"
    #             )
    #         if not config.evaluate or config.get("zero_shot", False):  # finetuning from pretrained weights
    #             for key in list(state_dict.keys()):
    #                 if "bert" in key and not is_blip_model:
    #                     encoder_key = key.replace("bert.", "")
    #                     state_dict[encoder_key] = state_dict[key]
    #                     if not has_decoder:
    #                         del state_dict[key]
    #                 # Init the text decoder as the multimodal encoder (last 6 layers of
    #                 # model.text_encoder); only for generation tasks such as VQA.
    #                 if has_decoder and "text_encoder" in key and not is_blip_model:
    #                     if "layer" in key:
    #                         encoder_keys = key.split(".")
    #                         layer_num = int(encoder_keys[4])
    #                         if layer_num < config.model.text_encoder.fusion_layer:
    #                             del state_dict[key]
    #                             continue
    #                         else:
    #                             decoder_layer_num = layer_num - config.model.text_encoder.fusion_layer
    #                             encoder_keys[4] = str(decoder_layer_num)
    #                             encoder_key = ".".join(encoder_keys)
    #                     else:
    #                         encoder_key = key
    #                     decoder_key = encoder_key.replace("text_encoder", "text_decoder")
    #                     state_dict[decoder_key] = state_dict[key]
    #                     del state_dict[key]
    #     if hasattr(config, "wiseft") and config.wiseft.enable:
    #         logger.info(f"Wiseft with coefficient {config.wiseft.coef}")
    #         missing_keys_in_pretrained = [k for k in model_without_ddp.state_dict().keys() if k not in state_dict]
    #         missing_keys_in_model = [k for k in state_dict.keys() if k not in model_without_ddp.state_dict()]
    #         mismatch_keys = [k for k in state_dict.keys() if k in model_without_ddp.state_dict()
    #                          and state_dict[k].shape != model_without_ddp.state_dict()[k].shape]
    #         common_keys = [k for k in state_dict.keys() if k in model_without_ddp.state_dict()]
    #         logger.info(f"Missing keys in pretrained: {missing_keys_in_pretrained}")
    #         logger.info(f"Missing keys in model: {missing_keys_in_model}")
    #         logger.info(f"Mismatch keys: {mismatch_keys}")
    #         logger.info(f"Keys to exclude: {config.wiseft.keys_to_exclude}")
    #         for k in common_keys:
    #             if k in config.wiseft.keys_to_exclude:
    #                 continue
    #             state_dict[k] = config.wiseft.coef * state_dict[k] + (1 - config.wiseft.coef) * model_without_ddp.state_dict()[k].cpu()
    #     msg = model_without_ddp.load_state_dict(state_dict, strict=False)
    #     print(msg)
    #     logger.info(msg)
    #     logger.info(f"Loaded checkpoint from {config.pretrained_path}")
    # else:
    #     logger.warning("No pretrained checkpoint provided, training from scratch")
    # if scaler is None:
    #     model = model_without_ddp
    #     model, optimizer, _, _ = deepspeed.initialize(
    #         args=config, model=model, model_parameters=optimizer_params, dist_init_required=not config.distributed,
    #         lr_scheduler=lambda opt: create_scheduler(config.scheduler, opt)
    #     )
    #
    # if config.resume and osp.isdir(config.pretrained_path):
    #     output_dir, tag = os.path.split(config.pretrained_path)
    #     model.load_checkpoint(output_dir, tag=tag)
    #     global_step = model.global_steps
    #     assert num_steps_per_epoch > 0, "Please provide num_steps_per_epoch"
    #     start_epoch = global_step // num_steps_per_epoch
    #
    # return (
    #     model,
    #     model_without_ddp,
    #     optimizer,
    #     scheduler,
    #     scaler,
    #     tokenizer,
    #     start_epoch,
    #     global_step,
    # )
    return model
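
# Illustrative call site (hypothetical: `cfg` is the parsed training config and
# `ViCLIP` stands in for whatever model class this repo passes as `model_cls`;
# neither name is defined in this file):
#
#   model = setup_model(cfg, model_cls=ViCLIP, pretrain=False, num_classes=101)
#   model = model.to(torch.device(cfg.device))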