from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig
from safetensors.torch import load_file
import torchvision.transforms as transforms
from .build import load_sd_model, load_Florence2_model
from .vlv_utils import initiate_time_steps, normalize, process_caption
from .VLV_stage1 import SDModel, SDConfig
from .configuration_vlv import VLV_Config
import os
import sys
import argparse
def handle_module_prefix(state_dict):
"""Handle 'module.' prefix in state dict keys."""
if any(k.startswith('module.') for k in state_dict.keys()):
return {k.replace('module.', ''): v for k, v in state_dict.items()}
return state_dict
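# Illustrative only: checkpoints saved from a DataParallel/DDP-wrapped model carry
# a leading 'module.' on every key, which a plain load_state_dict would reject, e.g.
#   handle_module_prefix({"module.mlp.layers.0.weight": w}) -> {"mlp.layers.0.weight": w}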
def create_model_args(args):
"""Create model arguments needed by SDModel."""
model_args = argparse.Namespace()
model_args.use_text_encoder = args.use_text_encoder
model_args.batch_size = args.batch_size
model_args.eval_batch_size = args.batch_size
model_args.distributed_strategy = 'none'
model_args.fp32 = args.fp32
model_args.learnable_token_length = args.learnable_token_length
model_args.num_inference_steps = args.num_inference_steps
model_args.image_size = args.image_size
model_args.guidance_scale = args.guidance_scale
model_args.unfreeze_florence2_all = False
model_args.unfreeze_florence2_language_model = False
model_args.unfreeze_florence2_language_model_decoder = False
return model_args
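# Illustrative only: create_model_args() expects an object exposing the CLI-style
# attributes read above; for notebook use it can be a plain Namespace (these values
# are assumptions, not recommended settings), e.g.
#   args = argparse.Namespace(use_text_encoder=True, batch_size=1, fp32=False,
#                             learnable_token_length=77, num_inference_steps=50,
#                             image_size=384, guidance_scale=7.5)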
def load_model_checkpoint(model, model_path, device):
"""Load model checkpoint."""
try:
checkpoint = torch.load(model_path, map_location="cpu")
# Handle different checkpoint formats
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
state_dict = checkpoint['model_state_dict']
elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
state_dict = handle_module_prefix(state_dict)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {missing_keys[:10]}...") # Show first 10
if unexpected_keys:
print(f"Unexpected keys: {unexpected_keys[:10]}...") # Show first 10
print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
return model
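# Illustrative only (the path is hypothetical): load_model_checkpoint() accepts raw
# state dicts as well as checkpoints wrapped under 'model_state_dict' or 'state_dict':
#   model = load_model_checkpoint(model, "/path/to/model.pt", device="cpu")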
def initialize_diffusion_model(args):
"""Initialize the diffusion model."""
config = SDConfig()
diffusion_model_args = create_model_args(args)
diffusion_model = SDModel(config, diffusion_model_args)
_dtype = torch.float32 if diffusion_model_args.fp32 else torch.bfloat16
# Delete components that aren't needed for inference
if hasattr(diffusion_model, 'vae'):
del diffusion_model.vae
if hasattr(diffusion_model, 'unet'):
del diffusion_model.unet
# Clear CUDA cache
torch.cuda.empty_cache()
diffusion_model = diffusion_model.to(_dtype)
# Freeze parameters that shouldn't be trained
for param in diffusion_model.language_proj.parameters():
param.requires_grad = False
diffusion_model.query_embed.requires_grad = False
return diffusion_model
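# Note: the SDModel returned above has its VAE/UNet deleted and its language
# projection and query embeddings frozen; only the Florence-2 encoder/decoder path
# is kept, which is all CLIPDecoder needs to build the conditional context.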
class MLP(nn.Module):
def __init__(self, input_dim, output_dim):
super(MLP, self).__init__()
self.layers = nn.Sequential(
nn.Linear(input_dim, output_dim),
nn.GELU(),
nn.Linear(output_dim, output_dim),
)
def forward(self, x):
return self.layers(x)
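# The two-layer MLP above projects the 1024-dim CLIP text-encoder features into the
# Qwen2 hidden size so they can be fed to the language model as input embeddings
# (see CLIPDecoder.__init__ and process_image below).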
@dataclass
class CLIPDecoderOutput(ModelOutput):
"""
Output class for the CLIP Decoder model.
"""
last_hidden_state: Optional[torch.FloatTensor] = None
generated_ids: Optional[torch.LongTensor] = None
generated_text: Optional[list] = None
class CLIPDecoder(nn.Module):
def __init__(
self,
language_model: str,
VLV_model: SDModel,
device: torch.device,
bf16: str,
qwen2_config: dict = None,
args: argparse.Namespace = None
):
"""
Initialize the CLIP Decoder model.
Args:
language_model: Path to the language model
VLV_model: The VLV model instance
device: The device to run the model on
bf16: Whether to use bfloat16 precision
qwen2_config: Optional qwen2 configuration dict
"""
super(CLIPDecoder, self).__init__()
self._dtype = torch.bfloat16 if bf16 == "bf16" else torch.float32
self.qwen2_tokenizer = AutoTokenizer.from_pretrained(language_model)
self.qwen2_config = AutoConfig.from_pretrained(language_model)
self.qwen2_model = AutoModelForCausalLM.from_pretrained(
language_model,
torch_dtype=self._dtype,
device_map=None,
low_cpu_mem_usage=True
)
self.VLV_model = VLV_model # fp32 in this case
self.device = device
self.mlp = MLP(input_dim=1024, output_dim=self.qwen2_model.config.hidden_size)
self.ignore_token_id = -100
def get_conditional_context(self, images, batch_size):
"""
Get conditional context from images using the diffusion model.
Args:
images: Input images
batch_size: Batch size
Returns:
Decoder hidden states from the diffusion model
"""
prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
inputs = self.VLV_model.processor(text=prompt, images=images, return_tensors="pt").to(self.device).to(self._dtype)
# Ensure all components are on the correct device
self.VLV_model = self.VLV_model.to(inputs["input_ids"].device)
self.qwen2_model = self.qwen2_model.to(inputs["input_ids"].device)
self.mlp = self.mlp.to(inputs["input_ids"].device)
self.VLV_model.model.language_model.model = self.VLV_model.model.language_model.model.to(inputs["input_ids"].device)
if inputs["input_ids"] is not None:
inputs_embeds = self.VLV_model.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self.device)
if inputs["pixel_values"] is not None:
image_features = self.VLV_model.model._encode_image(inputs["pixel_values"]).to(self.device)
inputs_embeds, attention_mask = self.VLV_model.model._merge_input_ids_with_image_features(
image_features, inputs_embeds
)
if inputs_embeds is not None:
attention_mask = attention_mask.to(inputs_embeds.dtype)
encoder_outputs = self.VLV_model.model.language_model.model.encoder(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True
)
decoder_inputs_embeds = self.VLV_model.query_embed.expand(batch_size, -1, -1)
decoder_attention_mask = torch.ones(
(batch_size, self.VLV_model.num_queries),
dtype=self._dtype,
device=self.device
)
encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
decoder_input_embeds = decoder_inputs_embeds.to(self._dtype)
attention_mask = attention_mask.to(self._dtype)
decoder_outputs = self.VLV_model.model.language_model.model.decoder(
inputs_embeds=decoder_input_embeds,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True
)
return decoder_outputs.last_hidden_state
def process_image(self, images, batch_size):
"""
Process images to get clip text embeddings.
Args:
images: Input images
batch_size: Batch size
Returns:
Processed clip text embeddings and attention mask
"""
decoder_hidden_states = self.get_conditional_context(images, batch_size)
context_embeds = self.VLV_model.language_proj(decoder_hidden_states)
clip_text_embeds = self.VLV_model.text_encoder(inputs_embeds=context_embeds).last_hidden_state
clip_text_embeds = self.mlp(clip_text_embeds)
clip_text_embeds_attention_mask = torch.ones(
(batch_size, self.VLV_model.num_queries),
dtype=torch.long,
device=self.device
)
return clip_text_embeds, clip_text_embeds_attention_mask
def prepare_generation_inputs(self, clip_text_embeds, clip_text_attention_mask=None):
"""
Prepare inputs for text generation.
Args:
clip_text_embeds: Processed clip text embeddings
clip_text_attention_mask: Attention mask for clip text embeddings
Returns:
Dictionary of generation inputs
"""
if clip_text_attention_mask is None:
clip_text_attention_mask = torch.ones(
(clip_text_embeds.shape[0], clip_text_embeds.shape[1]),
dtype=torch.long,
device=clip_text_embeds.device
)
return {
"inputs_embeds": clip_text_embeds,
"attention_mask": clip_text_attention_mask
}
def generate(self, images, max_new_tokens=300, num_beams=4, early_stopping=True):
"""
Generate text from images.
Args:
images: Input images
max_new_tokens: Maximum number of tokens to generate
num_beams: Number of beams for beam search
early_stopping: Whether to stop early in beam search
Returns:
CLIPDecoderOutput with generated ids and text
"""
batch_size = len(images)
clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)
generation_inputs = self.prepare_generation_inputs(clip_text_embeds, clip_text_attention_mask)
generation_inputs["inputs_embeds"] = generation_inputs["inputs_embeds"].to(self._dtype)
generation_inputs["attention_mask"] = generation_inputs["attention_mask"].to(self._dtype)
generated_ids = self.qwen2_model.generate(
inputs_embeds=generation_inputs["inputs_embeds"],
attention_mask=generation_inputs["attention_mask"],
max_new_tokens=max_new_tokens,
num_beams=num_beams,
early_stopping=early_stopping
)
generated_text = self.qwen2_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
processed_generated_text = [process_caption(text) for text in generated_text]
return CLIPDecoderOutput(
generated_ids=generated_ids,
generated_text=processed_generated_text
)
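    # Example (illustrative):
    #   out = decoder.generate(pil_images, max_new_tokens=300, num_beams=4)
    #   out.generated_text  -> captions cleaned with process_caption
    #   out.generated_ids   -> raw Qwen2 token ids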
def forward(self, images, captions=None):
"""
Forward pass for training.
Args:
images: Input images
captions: Target captions (optional, for training)
Returns:
CLIPDecoderOutput with loss and logits
"""
batch_size = images.shape[0]
# Process images
clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)
# If no captions provided, return embeddings for generation
if captions is None:
return CLIPDecoderOutput(
last_hidden_state=clip_text_embeds
)
assert len(captions) == batch_size
# Process captions for training
processed_captions = [process_caption(caption) for caption in captions]
qwen_input_ids = self.qwen2_tokenizer(
text=processed_captions,
truncation=True,
return_tensors="pt",
padding="max_length",
max_length=300,
return_token_type_ids=False,
).input_ids
qwen_attention_mask = qwen_input_ids.ne(self.qwen2_tokenizer.pad_token_id).to(torch.long).to(self.device)
# Prepare labels for training
labels = qwen_input_ids
labels[labels == self.qwen2_tokenizer.pad_token_id] = self.ignore_token_id
labels = labels.to(self.device)
# Get embeddings for captions to create the full input sequence
labels_for_embeddings = labels.clone()
labels_for_embeddings[labels_for_embeddings == self.ignore_token_id] = self.qwen2_tokenizer.pad_token_id
clip_text_embeds_qwen = self.qwen2_model.get_input_embeddings()(labels_for_embeddings)
# Concatenate the embeddings and prepare attention mask
inputs_embeds = torch.cat((clip_text_embeds, clip_text_embeds_qwen), dim=1)
clip_seq_len = clip_text_embeds.shape[1]
clip_ignore_labels = torch.full((labels.shape[0], clip_seq_len), self.ignore_token_id).to(labels)
combined_labels = torch.cat((clip_ignore_labels, labels), dim=1)
attention_mask = torch.cat((
clip_text_attention_mask,
qwen_attention_mask
), dim=1)
# Forward through language model
outputs = self.qwen2_model(
inputs_embeds=inputs_embeds,
labels=combined_labels,
attention_mask=attention_mask,
use_cache=False
)
return outputs
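# Training sketch (illustrative; variable names are assumptions): forward() prepends
# the projected visual tokens to the caption embeddings and masks that visual prefix
# out of the loss with ignore_token_id (-100):
#   loss = clip_decoder(images=batch_images, captions=batch_captions).loss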
# HuggingFace Model Wrapper
class VLV_MODEL(PreTrainedModel):
config_class = VLV_Config
model_type = "VLV_decoder"
    def __init__(self, config):
        """Load and assemble the CLIPDecoder model from a VLV_Config."""
        super().__init__(config)
        # Initialize the diffusion (VLV) backbone first
        device = "cuda" if torch.cuda.is_available() else "cpu"
        de_diffusion_model = initialize_diffusion_model(config)
clip_decoder_model = CLIPDecoder(
language_model=config.qwen_model,
VLV_model=de_diffusion_model,
device=device,
bf16=config.mixed_precision,
qwen2_config=config.qwen2_config
)
# Load the trained weights
# clip_decoder_model = load_model_checkpoint(clip_decoder_model, config.clip_decoder_checkpoint, device)
# Set to evaluation mode
clip_decoder_model.eval()
# Store components directly as attributes to match checkpoint structure
self.VLV_model = clip_decoder_model.VLV_model
self.qwen2_model = clip_decoder_model.qwen2_model
self.mlp = clip_decoder_model.mlp
# Keep the full model for methods
self._clip_decoder_model = clip_decoder_model
self.max_new_tokens = config.max_length
self.num_beams = config.num_beams
self.transform = self.get_transform(config.image_size)
def get_transform(self, image_size):
"""Transformation pipeline for input images."""
return transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop((image_size, image_size)),
transforms.PILToTensor(),
])
@classmethod
def from_checkpoint(cls, checkpoint_path, config=None, **kwargs):
"""
Load model from original training checkpoint.
Args:
checkpoint_path: Path to the original model.pt checkpoint
config: Optional VLV_Config, will create default if None
**kwargs: Additional arguments for model initialization
"""
if config is None:
# Create default config
config = VLV_Config(
image_size=384,
guidance_scale=7.5,
learnable_token_length=77,
max_length=300,
num_beams=4,
**kwargs
)
# Initialize model
model = cls(config)
# Load checkpoint weights
device = "cuda" if torch.cuda.is_available() else "cpu"
load_model_checkpoint(model._clip_decoder_model, checkpoint_path, device)
return model
def forward(self, valid_images, max_length):
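        # Caption a batch of PIL images: apply the resize/center-crop transform,
        # then delegate to CLIPDecoder.generate (unwrapping a DataParallel
        # 'module' wrapper if one is present).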
valid_images = [self.transform(img) for img in valid_images]
if hasattr(self._clip_decoder_model, 'module'):
outputs = self._clip_decoder_model.module.generate(
valid_images,
max_new_tokens=max_length,
num_beams=self.num_beams,
early_stopping=True
)
else:
outputs = self._clip_decoder_model.generate(
valid_images,
max_new_tokens=max_length,
num_beams=self.num_beams,
early_stopping=True
)
return outputs
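# Illustrative usage (the checkpoint path is hypothetical; from_checkpoint builds a
# default VLV_Config when none is supplied):
#   from PIL import Image
#   model = VLV_MODEL.from_checkpoint("/path/to/model.pt")
#   out = model([Image.open("example.jpg").convert("RGB")], max_length=300)
#   print(out.generated_text)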