import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTokenizer, AutoProcessor
from .modeling_clip import CustomCLIPTextModel
from .modeling_florence2 import Florence2ForConditionalGeneration
from .configuration_florence2 import Florence2Config 


def load_sd_model(training_args):
    """Load the Stable Diffusion 2.1 base components, all frozen."""

    repo_id = "stabilityai/stable-diffusion-2-1-base"

    text_encoder = CustomCLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", revision=None)
    scheduler = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
    unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", revision=None)

    # Freeze the VAE, text encoder, and UNet; none of them are trained here.
    for m in [vae, text_encoder, unet]:
        for param in m.parameters():
            param.requires_grad = False

    return vae, tokenizer, text_encoder, unet, scheduler


def load_Florence2_model(training_args):
    """Load Florence-2-large and set per-module trainability from training_args."""
    config = Florence2Config.from_pretrained("microsoft/Florence-2-large")
    config.vision_config.model_type = "davit"
    config._attn_implementation = "eager"

    # Load the model with pretrained weights and its matching processor.
    model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large", config=config)
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

    # Select which parameters are trainable.
    if training_args.unfreeze_florence2_all:
        for param in model.parameters():
            param.requires_grad = True
    elif training_args.unfreeze_florence2_language_model:
        # Train the language model only; everything else stays frozen.
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.parameters():
            param.requires_grad = True

        # Re-create lm_head.weight as an independent Parameter (breaking any
        # tie to the shared token embeddings), then keep it frozen. The new
        # Parameter defaults to requires_grad=True, so the freeze must come
        # after the reassignment.
        model.language_model.lm_head.weight = torch.nn.Parameter(
            model.language_model.lm_head.weight.detach().clone()
        )
        for param in model.language_model.lm_head.parameters():
            param.requires_grad = False

    elif training_args.unfreeze_florence2_language_model_decoder:
        # Give the decoder its own embedding layer so that training it does
        # not touch the embeddings shared with the encoder.
        original_embeddings = model.language_model.model.shared
        new_decoder_embeddings = torch.nn.Embedding(
            num_embeddings=original_embeddings.num_embeddings,
            embedding_dim=original_embeddings.embedding_dim,
            padding_idx=original_embeddings.padding_idx,
        )
        # Copy the pretrained weights into the new layer.
        new_decoder_embeddings.weight.data = original_embeddings.weight.data.clone()

        # The encoder keeps the shared embeddings; the decoder gets the copy.
        model.language_model.model.encoder.embed_tokens = original_embeddings
        model.language_model.model.decoder.embed_tokens = new_decoder_embeddings

        # Train the decoder only, with its copied embeddings kept frozen.
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.model.decoder.parameters():
            param.requires_grad = True
        model.language_model.model.decoder.embed_tokens.weight.requires_grad = False
    else:
        for param in model.parameters():
            param.requires_grad = False

    return model, processor
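

# Usage sketch (illustrative only): this assumes `training_args` is any object
# exposing the three boolean flags read above, e.g. an argparse Namespace or a
# dataclass defined elsewhere in this repo.
#
#   from types import SimpleNamespace
#
#   training_args = SimpleNamespace(
#       unfreeze_florence2_all=False,
#       unfreeze_florence2_language_model=True,
#       unfreeze_florence2_language_model_decoder=False,
#   )
#
#   vae, tokenizer, text_encoder, unet, scheduler = load_sd_model(training_args)
#   florence2, processor = load_Florence2_model(training_args)
#   n_trainable = sum(p.numel() for p in florence2.parameters() if p.requires_grad)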