import torch

from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTokenizer, AutoProcessor

from .modeling_clip import CustomCLIPTextModel
from .modeling_florence2 import Florence2ForConditionalGeneration
from .configuration_florence2 import Florence2Config


def load_sd_model(training_args):
    """Load the Stable Diffusion 2.1 base components and freeze their weights."""
    repo_id = "stabilityai/stable-diffusion-2-1-base"

    text_encoder = CustomCLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", revision=None)
    scheduler = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
    unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", revision=None)

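    # Freeze the VAE, text encoder, and UNet; none of these SD components are trained.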
    for m in [vae, text_encoder, unet]:
        for param in m.parameters():
            param.requires_grad = False

    return vae, tokenizer, text_encoder, unet, scheduler


def load_Florence2_model(training_args):
    """Load Florence-2 large and choose which submodules are trainable from training_args."""
    config = Florence2Config.from_pretrained("microsoft/Florence-2-large")
    config.vision_config.model_type = "davit"
    config._attn_implementation = "eager"

    model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large", config=config)
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

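    # Decide which Florence-2 submodules are trainable based on the training_args flags.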
    if training_args.unfreeze_florence2_all:
        for param in model.parameters():
            param.requires_grad = True
    elif training_args.unfreeze_florence2_language_model:
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.parameters():
            param.requires_grad = True

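        # Replace the lm_head weight with a detached copy so it no longer shares
        # storage with the trainable token embeddings (BART-style heads are usually
        # tied to them), and keep the lm_head itself frozen.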
        model.language_model.lm_head.weight = torch.nn.Parameter(
            model.language_model.lm_head.weight.detach().clone())
        for p in model.language_model.lm_head.parameters():
            p.requires_grad = False

    elif training_args.unfreeze_florence2_language_model_decoder:
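        # Give the decoder its own copy of the shared embedding table so the
        # encoder keeps the original embeddings while the decoder is trained.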
        original_embeddings = model.language_model.model.shared
        new_decoder_embeddings = torch.nn.Embedding(
            num_embeddings=original_embeddings.num_embeddings,
            embedding_dim=original_embeddings.embedding_dim,
            padding_idx=original_embeddings.padding_idx,
        )
        new_decoder_embeddings.weight.data = original_embeddings.weight.data.clone()

        model.language_model.model.encoder.embed_tokens = original_embeddings
        model.language_model.model.decoder.embed_tokens = new_decoder_embeddings

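        # Train only the decoder; its copied embedding table stays frozen.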
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.model.decoder.parameters():
            param.requires_grad = True
        model.language_model.model.decoder.embed_tokens.weight.requires_grad = False
    else:
        for param in model.parameters():
            param.requires_grad = False

    return model, processor
|
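
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal example of how these loaders
# might be called, assuming `training_args` is any object exposing the
# unfreeze_florence2_* boolean flags checked above. The SimpleNamespace below
# is a hypothetical stand-in for the real training arguments, and running this
# will download the pretrained checkpoints.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    training_args = SimpleNamespace(
        unfreeze_florence2_all=False,
        unfreeze_florence2_language_model=True,
        unfreeze_florence2_language_model_decoder=False,
    )

    vae, tokenizer, text_encoder, unet, scheduler = load_sd_model(training_args)
    florence2, processor = load_Florence2_model(training_args)

    trainable = sum(p.numel() for p in florence2.parameters() if p.requires_grad)
    print(f"Trainable Florence-2 parameters: {trainable:,}")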