from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig
from safetensors.torch import load_file
import torchvision.transforms as transforms

from .build import load_sd_model, load_Florence2_model
from .vlv_utils import initiate_time_steps, normalize, process_caption
from .VLV_stage1 import SDModel, SDConfig
from .configuration_vlv import VLV_Config

import os
import sys
import argparse


def handle_module_prefix(state_dict):
    """Handle 'module.' prefix in state dict keys."""
    if any(k.startswith('module.') for k in state_dict.keys()):
        return {k.replace('module.', ''): v for k, v in state_dict.items()}
    return state_dict


def create_model_args(args):
    """Create model arguments needed by SDModel."""
    model_args = argparse.Namespace()
    model_args.use_text_encoder = args.use_text_encoder
    model_args.batch_size = args.batch_size
    model_args.eval_batch_size = args.batch_size
    model_args.distributed_strategy = 'none'
    model_args.fp32 = args.fp32
    model_args.learnable_token_length = args.learnable_token_length
    model_args.num_inference_steps = args.num_inference_steps
    model_args.image_size = args.image_size
    model_args.guidance_scale = args.guidance_scale
    model_args.unfreeze_florence2_all = False
    model_args.unfreeze_florence2_language_model = False
    model_args.unfreeze_florence2_language_model_decoder = False
    return model_args


def load_model_checkpoint(model, model_path, device):
    """Load model checkpoint."""
    try:
        checkpoint = torch.load(model_path, map_location="cpu")

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        state_dict = handle_module_prefix(state_dict)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        if missing_keys:
            print(f"Missing keys: {missing_keys[:10]}...")  # Show first 10
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys[:10]}...")  # Show first 10
        print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise e

    return model


def initialize_diffusion_model(args):
    """Initialize the diffusion model."""
    config = SDConfig()
    diffusion_model_args = create_model_args(args)
    diffusion_model = SDModel(config, diffusion_model_args)
    _dtype = torch.float32 if diffusion_model_args.fp32 else torch.bfloat16

    # Delete components that aren't needed for inference
    if hasattr(diffusion_model, 'vae'):
        del diffusion_model.vae
    if hasattr(diffusion_model, 'unet'):
        del diffusion_model.unet

    # Clear CUDA cache
    torch.cuda.empty_cache()

    diffusion_model = diffusion_model.to(_dtype)

    # Freeze parameters that shouldn't be trained
    for param in diffusion_model.language_proj.parameters():
        param.requires_grad = False
    diffusion_model.query_embed.requires_grad = False

    return diffusion_model


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        return self.layers(x)


@dataclass
class CLIPDecoderOutput(ModelOutput):
    """Output class for the CLIP Decoder model."""

    last_hidden_state: Optional[torch.FloatTensor] = None
    generated_ids: Optional[torch.LongTensor] = None
    generated_text: Optional[list] = None
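
# Inference pipeline implemented by CLIPDecoder below (see get_conditional_context
# and process_image):
#   1. The frozen VLV (Florence-2-based) encoder embeds the image together with an
#      empty text prompt.
#   2. Learnable query embeddings are decoded against those encoder features to
#      produce a compact visual context of length num_queries.
#   3. language_proj and the CLIP text encoder map that context into CLIP text
#      space, and the MLP projects it into the Qwen2 embedding space.
#   4. Qwen2 then generates (or, given captions, is trained to predict) the
#      caption from these soft-prompt embeddings.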
class CLIPDecoder(nn.Module):
    def __init__(
        self,
        language_model: str,
        VLV_model: SDModel,
        device: torch.device,
        bf16: str,
        qwen2_config: dict = None,
        args: argparse.Namespace = None
    ):
        """
        Initialize the CLIP Decoder model.

        Args:
            language_model: Path to the language model
            VLV_model: The VLV model instance
            device: The device to run the model on
            bf16: Precision flag; "bf16" selects bfloat16, anything else float32
            qwen2_config: Optional qwen2 configuration dict
            args: Optional namespace with additional runtime settings
        """
        super(CLIPDecoder, self).__init__()
        self._dtype = torch.bfloat16 if bf16 == "bf16" else torch.float32
        self.qwen2_tokenizer = AutoTokenizer.from_pretrained(language_model)
        self.qwen2_config = AutoConfig.from_pretrained(language_model)
        self.qwen2_model = AutoModelForCausalLM.from_pretrained(
            language_model,
            torch_dtype=self._dtype,
            device_map=None,
            low_cpu_mem_usage=True
        )
        self.VLV_model = VLV_model  # fp32 in this case
        self.device = device
        self.mlp = MLP(input_dim=1024, output_dim=self.qwen2_model.config.hidden_size)
        self.ignore_token_id = -100

    def get_conditional_context(self, images, batch_size):
        """
        Get conditional context from images using the diffusion model.

        Args:
            images: Input images
            batch_size: Batch size

        Returns:
            Decoder hidden states from the diffusion model
        """
        prompt = [""] * batch_size
        inputs = self.VLV_model.processor(
            text=prompt, images=images, return_tensors="pt"
        ).to(self.device).to(self._dtype)

        # Ensure all components are on the correct device
        self.VLV_model = self.VLV_model.to(inputs["input_ids"].device)
        self.qwen2_model = self.qwen2_model.to(inputs["input_ids"].device)
        self.mlp = self.mlp.to(inputs["input_ids"].device)
        self.VLV_model.model.language_model.model = self.VLV_model.model.language_model.model.to(
            inputs["input_ids"].device
        )

        if inputs["input_ids"] is not None:
            inputs_embeds = self.VLV_model.model.language_model.get_input_embeddings()(
                inputs["input_ids"]
            ).to(self.device)
        if inputs["pixel_values"] is not None:
            image_features = self.VLV_model.model._encode_image(inputs["pixel_values"]).to(self.device)
            inputs_embeds, attention_mask = self.VLV_model.model._merge_input_ids_with_image_features(
                image_features, inputs_embeds
            )
        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)
            encoder_outputs = self.VLV_model.model.language_model.model.encoder(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )

        decoder_inputs_embeds = self.VLV_model.query_embed.expand(batch_size, -1, -1)
        decoder_attention_mask = torch.ones(
            (batch_size, self.VLV_model.num_queries), dtype=self._dtype, device=self.device
        )
        encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
        decoder_input_embeds = decoder_inputs_embeds.to(self._dtype)
        attention_mask = attention_mask.to(self._dtype)
        decoder_outputs = self.VLV_model.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        return decoder_outputs.last_hidden_state
    def process_image(self, images, batch_size):
        """
        Process images to get clip text embeddings.

        Args:
            images: Input images
            batch_size: Batch size

        Returns:
            Processed clip text embeddings and attention mask
        """
        decoder_hidden_states = self.get_conditional_context(images, batch_size)
        context_embeds = self.VLV_model.language_proj(decoder_hidden_states)
        clip_text_embeds = self.VLV_model.text_encoder(inputs_embeds=context_embeds).last_hidden_state
        clip_text_embeds = self.mlp(clip_text_embeds)
        clip_text_embeds_attention_mask = torch.ones(
            (batch_size, self.VLV_model.num_queries), dtype=torch.long, device=self.device
        )
        return clip_text_embeds, clip_text_embeds_attention_mask

    def prepare_generation_inputs(self, clip_text_embeds, clip_text_attention_mask=None):
        """
        Prepare inputs for text generation.

        Args:
            clip_text_embeds: Processed clip text embeddings
            clip_text_attention_mask: Attention mask for clip text embeddings

        Returns:
            Dictionary of generation inputs
        """
        if clip_text_attention_mask is None:
            clip_text_attention_mask = torch.ones(
                (clip_text_embeds.shape[0], clip_text_embeds.shape[1]),
                dtype=torch.long,
                device=clip_text_embeds.device
            )
        return {
            "inputs_embeds": clip_text_embeds,
            "attention_mask": clip_text_attention_mask
        }

    def generate(self, images, max_new_tokens=300, num_beams=4, early_stopping=True):
        """
        Generate text from images.

        Args:
            images: Input images
            max_new_tokens: Maximum number of tokens to generate
            num_beams: Number of beams for beam search
            early_stopping: Whether to stop early in beam search

        Returns:
            CLIPDecoderOutput with generated ids and text
        """
        batch_size = len(images)
        clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)
        generation_inputs = self.prepare_generation_inputs(clip_text_embeds, clip_text_attention_mask)
        generation_inputs["inputs_embeds"] = generation_inputs["inputs_embeds"].to(self._dtype)
        generation_inputs["attention_mask"] = generation_inputs["attention_mask"].to(self._dtype)

        generated_ids = self.qwen2_model.generate(
            inputs_embeds=generation_inputs["inputs_embeds"],
            attention_mask=generation_inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            early_stopping=early_stopping
        )
        generated_text = self.qwen2_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        processed_generated_text = [process_caption(text) for text in generated_text]

        return CLIPDecoderOutput(
            generated_ids=generated_ids,
            generated_text=processed_generated_text
        )
    def forward(self, images, captions=None):
        """
        Forward pass for training.

        Args:
            images: Input images
            captions: Target captions (optional, for training)

        Returns:
            CLIPDecoderOutput with the projected embeddings when no captions are
            given, otherwise the language-model output with loss and logits
        """
        batch_size = images.shape[0]

        # Process images
        clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)

        # If no captions provided, return embeddings for generation
        if captions is None:
            return CLIPDecoderOutput(
                last_hidden_state=clip_text_embeds
            )

        assert len(captions) == batch_size

        # Process captions for training
        processed_captions = [process_caption(caption) for caption in captions]
        qwen_input_ids = self.qwen2_tokenizer(
            text=processed_captions,
            truncation=True,
            return_tensors="pt",
            padding="max_length",
            max_length=300,
            return_token_type_ids=False,
        ).input_ids
        qwen_attention_mask = qwen_input_ids.ne(self.qwen2_tokenizer.pad_token_id).to(torch.long).to(self.device)

        # Prepare labels for training
        labels = qwen_input_ids
        labels[labels == self.qwen2_tokenizer.pad_token_id] = self.ignore_token_id
        labels = labels.to(self.device)

        # Get embeddings for captions to create the full input sequence
        labels_for_embeddings = labels.clone()
        labels_for_embeddings[labels_for_embeddings == self.ignore_token_id] = self.qwen2_tokenizer.pad_token_id
        clip_text_embeds_qwen = self.qwen2_model.get_input_embeddings()(labels_for_embeddings)

        # Concatenate the embeddings and prepare attention mask
        inputs_embeds = torch.cat((clip_text_embeds, clip_text_embeds_qwen), dim=1)
        clip_seq_len = clip_text_embeds.shape[1]
        clip_ignore_labels = torch.full((labels.shape[0], clip_seq_len), self.ignore_token_id).to(labels)
        combined_labels = torch.cat((clip_ignore_labels, labels), dim=1)
        attention_mask = torch.cat((
            clip_text_attention_mask,
            qwen_attention_mask
        ), dim=1)

        # Forward through language model
        outputs = self.qwen2_model(
            inputs_embeds=inputs_embeds,
            labels=combined_labels,
            attention_mask=attention_mask,
            use_cache=False
        )

        return outputs


# HuggingFace Model Wrapper
class VLV_MODEL(PreTrainedModel):
    config_class = VLV_Config
    model_type = "VLV_decoder"

    def __init__(self, config):
        """Load the CLIPDecoder model."""
        super().__init__(config)

        # Initialize the diffusion model first
        device = "cuda"
        de_diffusion_model = initialize_diffusion_model(config)

        clip_decoder_model = CLIPDecoder(
            language_model=config.qwen_model,
            VLV_model=de_diffusion_model,
            device=device,
            bf16=config.mixed_precision,
            qwen2_config=config.qwen2_config
        )

        # Load the trained weights
        # clip_decoder_model = load_model_checkpoint(clip_decoder_model, config.clip_decoder_checkpoint, device)

        # Set to evaluation mode
        clip_decoder_model.eval()

        # Store components directly as attributes to match checkpoint structure
        self.VLV_model = clip_decoder_model.VLV_model
        self.qwen2_model = clip_decoder_model.qwen2_model
        self.mlp = clip_decoder_model.mlp

        # Keep the full model for methods
        self._clip_decoder_model = clip_decoder_model

        self.max_new_tokens = config.max_length
        self.num_beams = config.num_beams
        self.transform = self.get_transform(config.image_size)

    def get_transform(self, image_size):
        """Transformation pipeline for input images."""
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop((image_size, image_size)),
            transforms.PILToTensor(),
        ])
    @classmethod
    def from_checkpoint(cls, checkpoint_path, config=None, **kwargs):
        """
        Load model from original training checkpoint.

        Args:
            checkpoint_path: Path to the original model.pt checkpoint
            config: Optional VLV_Config, will create default if None
            **kwargs: Additional arguments for model initialization
        """
        if config is None:
            # Create default config
            config = VLV_Config(
                image_size=384,
                guidance_scale=7.5,
                learnable_token_length=77,
                max_length=300,
                num_beams=4,
                **kwargs
            )

        # Initialize model
        model = cls(config)

        # Load checkpoint weights
        device = "cuda" if torch.cuda.is_available() else "cpu"
        load_model_checkpoint(model._clip_decoder_model, checkpoint_path, device)

        return model

    def forward(self, valid_images, max_length):
        """Caption a list of PIL images and return a CLIPDecoderOutput."""
        valid_images = [self.transform(img) for img in valid_images]
        if hasattr(self._clip_decoder_model, 'module'):
            outputs = self._clip_decoder_model.module.generate(
                valid_images,
                max_new_tokens=max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )
        else:
            outputs = self._clip_decoder_model.generate(
                valid_images,
                max_new_tokens=max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )
        return outputs
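

# ---------------------------------------------------------------------------
# Usage sketch (not executed on import): a minimal example of how this wrapper
# might be driven for captioning. The checkpoint path, image file, and reliance
# on VLV_Config defaults below are placeholders/assumptions, not shipped values.
if __name__ == "__main__":
    from PIL import Image

    # Build the wrapper from a local training checkpoint (placeholder path),
    # letting from_checkpoint create its default config.
    model = VLV_MODEL.from_checkpoint("path/to/model.pt")

    # Caption a single RGB image (placeholder file name).
    image = Image.open("example.jpg").convert("RGB")
    outputs = model([image], max_length=300)
    print(outputs.generated_text)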