import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import AutoProcessor, CLIPTokenizer

from .configuration_florence2 import Florence2Config
from .modeling_clip import CustomCLIPTextModel
from .modeling_florence2 import Florence2ForConditionalGeneration


def load_sd_model(training_args):
    """Load the frozen Stable Diffusion 2.1-base components used for training."""
    repo_id = "stabilityai/stable-diffusion-2-1-base"
    text_encoder = CustomCLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", revision=None)
    scheduler = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
    unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", revision=None)

    # All Stable Diffusion components stay frozen; only Florence-2 is trained.
    for module in (vae, text_encoder, unet):
        for param in module.parameters():
            param.requires_grad = False

    return vae, tokenizer, text_encoder, unet, scheduler


def load_Florence2_model(training_args):
    """Load Florence-2-large and freeze/unfreeze parameters per training_args."""
    config = Florence2Config.from_pretrained("microsoft/Florence-2-large")
    config.vision_config.model_type = "davit"
    config._attn_implementation = "eager"

    # Load the model with pre-trained weights.
    model = Florence2ForConditionalGeneration.from_pretrained(
        "microsoft/Florence-2-large", config=config
    )
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-large", trust_remote_code=True
    )

    if training_args.unfreeze_florence2_all:
        # Train every parameter, vision tower included.
        for param in model.parameters():
            param.requires_grad = True
    elif training_args.unfreeze_florence2_language_model:
        # Train only the language model; keep the vision tower and lm_head frozen.
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.parameters():
            param.requires_grad = True
        # At this point lm_head.weight is still tied to the shared embedding
        # table, so freezing it also freezes the embeddings.
        for param in model.language_model.lm_head.parameters():
            param.requires_grad = False
        # Untie lm_head from the embeddings by replacing its weight with a
        # detached copy, then freeze the untied head.
        model.language_model.lm_head.weight = torch.nn.Parameter(
            model.language_model.lm_head.weight.detach().clone()
        )
        for param in model.language_model.lm_head.parameters():
            param.requires_grad = False
    elif training_args.unfreeze_florence2_language_model_decoder:
        # Give the decoder its own embedding table so it can be trained
        # independently of the shared encoder embeddings.
        original_embeddings = model.language_model.model.shared
        new_decoder_embeddings = torch.nn.Embedding(
            num_embeddings=original_embeddings.num_embeddings,
            embedding_dim=original_embeddings.embedding_dim,
            padding_idx=original_embeddings.padding_idx,
        )
        # Copy the pre-trained weights into the new table.
        new_decoder_embeddings.weight.data = original_embeddings.weight.data.clone()
        # Encoder keeps the shared embeddings; decoder gets the detached copy.
        model.language_model.model.encoder.embed_tokens = original_embeddings
        model.language_model.model.decoder.embed_tokens = new_decoder_embeddings

        # Train only the decoder, with its embedding copy kept frozen.
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.model.decoder.parameters():
            param.requires_grad = True
        model.language_model.model.decoder.embed_tokens.weight.requires_grad = False
    else:
        # Default: Florence-2 stays fully frozen.
        for param in model.parameters():
            param.requires_grad = False

    return model, processor
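

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a training script might call the two
# loaders above. The SimpleNamespace stand-in for training_args is an
# assumption for demonstration; a real run would build these flags with
# argparse or HfArgumentParser. Because of the relative imports, run this as
# a module (python -m <package>.<this_module>), not as a script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical flag values; only the flag names come from this module.
    training_args = SimpleNamespace(
        unfreeze_florence2_all=False,
        unfreeze_florence2_language_model=True,
        unfreeze_florence2_language_model_decoder=False,
    )

    vae, tokenizer, text_encoder, unet, scheduler = load_sd_model(training_args)
    model, processor = load_Florence2_model(training_args)

    # Sanity check: only the intended Florence-2 parameters should be trainable.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Florence-2 trainable params: {trainable:,} / {total:,}")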