File size: 3,354 Bytes
492f6af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTokenizer, AutoProcessor
from .modeling_clip import CustomCLIPTextModel
from .modeling_florence2 import Florence2ForConditionalGeneration
from .configuration_florence2 import Florence2Config
def load_sd_model(training_args):
    """Load the Stable Diffusion 2.1 (base) components, fully frozen.

    Args:
        training_args: training configuration (currently unused here).

    Returns:
        Tuple of (vae, tokenizer, text_encoder, unet, scheduler); the VAE,
        text encoder and UNet have all parameters frozen
        (``requires_grad=False``).
    """
    repo_id = "stabilityai/stable-diffusion-2-1-base"
    tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
    text_encoder = CustomCLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", revision=None)
    unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", revision=None)
    scheduler = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
    # These modules serve only as fixed feature extractors / denoisers;
    # freeze every parameter in one call per module.
    for frozen_module in (vae, text_encoder, unet):
        frozen_module.requires_grad_(False)
    return (vae, tokenizer, text_encoder, unet, scheduler)
def load_Florence2_model(training_args):
    """Load Florence-2-large and select which parameter groups are trainable.

    Args:
        training_args: object exposing the boolean flags
            ``unfreeze_florence2_all``,
            ``unfreeze_florence2_language_model`` and
            ``unfreeze_florence2_language_model_decoder``. The first flag
            that is truthy (in that order) wins; if none is set the whole
            model is frozen.

    Returns:
        Tuple of (model, processor).
    """
    config = Florence2Config.from_pretrained("microsoft/Florence-2-large")
    config.vision_config.model_type = "davit"
    # Force eager attention (disable SDPA/flash paths) for the custom
    # modeling code.
    config._attn_implementation = "eager"
    # Load the model with pre-trained weights
    model = Florence2ForConditionalGeneration.from_pretrained(
        "microsoft/Florence-2-large", config=config)
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-large", trust_remote_code=True)

    if training_args.unfreeze_florence2_all:
        # Train every parameter of the full vision-language model.
        for param in model.parameters():
            param.requires_grad = True
    elif training_args.unfreeze_florence2_language_model:
        # Train only the language model, but keep the lm_head frozen.
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.parameters():
            param.requires_grad = True
        # Untie lm_head from the (trainable) shared embeddings by giving it
        # its own detached copy of the weights, created already frozen.
        # (The original code froze lm_head, replaced its weight with a new
        # Parameter -- which defaults to requires_grad=True -- and then had
        # to freeze it a second time; constructing the Parameter with
        # requires_grad=False removes both redundant freeze loops.)
        model.language_model.lm_head.weight = torch.nn.Parameter(
            model.language_model.lm_head.weight.detach().clone(),
            requires_grad=False)
    elif training_args.unfreeze_florence2_language_model_decoder:
        # Create a separate embedding layer for the decoder so encoder and
        # decoder no longer share one table.
        original_embeddings = model.language_model.model.shared
        new_decoder_embeddings = torch.nn.Embedding(
            num_embeddings=original_embeddings.num_embeddings,
            embedding_dim=original_embeddings.embedding_dim,
            padding_idx=original_embeddings.padding_idx,
        )
        # Copy the weights
        new_decoder_embeddings.weight.data = original_embeddings.weight.data.clone()
        # Encoder keeps the shared table; decoder gets its private copy.
        model.language_model.model.encoder.embed_tokens = original_embeddings
        model.language_model.model.decoder.embed_tokens = new_decoder_embeddings
        # Train the decoder only...
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.model.decoder.parameters():
            param.requires_grad = True
        # ...except its embedding table, which stays frozen.
        model.language_model.model.decoder.embed_tokens.weight.requires_grad = False
    else:
        # Default: the whole model is frozen.
        for param in model.parameters():
            param.requires_grad = False
    return model, processor
|