# Commit metadata (author: lambertxiao, commit 492f6af, verified):
# lambertxiao's picture
# Overwrite with converted Qwen2.5-3B model files
# 492f6af verified
import os
import torch
import torch.nn as nn
from typing import Optional
from dataclasses import dataclass
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from .build import load_sd_model, load_Florence2_model
from .vlv_utils import initiate_time_steps, normalize
class SDConfig(PretrainedConfig):
    """Configuration object for :class:`SDModel`.

    Declares no settings of its own: every keyword argument is handed
    unchanged to :class:`PretrainedConfig`. Only ``model_type`` is set here.
    """

    model_type = "sd"

    def __init__(self, **kwargs):
        # Pure pass-through to the transformers base configuration.
        super(SDConfig, self).__init__(**kwargs)
class MLP(nn.Module):
    """Two-layer feed-forward projection: Linear -> GELU -> Linear.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: width of both the hidden layer and the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Kept inside a single ``nn.Sequential`` named ``layers`` so the
        # state-dict keys stay ``layers.0.*`` / ``layers.2.*``.
        stages = [
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the projection to ``x`` along its last dimension."""
        projected = self.layers(x)
        return projected
@dataclass
class SDOutput(ModelOutput):
    """Return container of :meth:`SDModel.forward`.

    Attributes:
        loss: the noise-prediction loss (L1 or MSE), or ``None``.
    """

    loss: Optional[torch.FloatTensor] = None
class SDModel(PreTrainedModel):
    """Diffusion model conditioned on Florence-2 learnable-query features.

    Conditioning pipeline (all implemented below):

    1. ``get_conditional_context`` runs the Florence-2 encoder on the image
       together with the ``"<MORE_DETAILED_CAPTION>"`` prompt, then runs the
       Florence-2 decoder on ``num_queries`` learnable query embeddings
       (``query_embed``) instead of token ids.
    2. ``language_proj`` maps the decoder's last hidden state; when
       ``training_args.use_text_encoder`` is set, the result is additionally
       passed through the text encoder as ``inputs_embeds``.
    3. The resulting sequence is the UNet's ``encoder_hidden_states``.

    ``forward`` returns the noise-prediction loss for a batch of images;
    ``generate_images`` samples images with classifier-free guidance and
    decodes them with the VAE.
    """

    config_class = SDConfig

    def __init__(
        self,
        config=None,
        training_args=None,
    ):
        """Load all sub-models and place them on the target device.

        Args:
            config: optional :class:`SDConfig`; a default one is created when
                omitted.
            training_args: namespace of run options. This constructor reads
                ``fp32``, ``batch_size``, ``learnable_token_length`` and
                (optionally) ``device``; it is also passed to
                ``load_sd_model`` / ``load_Florence2_model``. Although it
                defaults to ``None``, it is required.

        Raises:
            ValueError: if ``training_args`` is ``None``.
        """
        if training_args is None:
            raise ValueError("SDModel requires `training_args`; got None.")
        if config is None:
            config = SDConfig()
        super().__init__(config)
        self.training_args = training_args

        self._dtype = torch.float32 if training_args.fp32 else torch.bfloat16
        if hasattr(training_args, "device"):
            device = training_args.device
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = torch.device(device)

        self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler = load_sd_model(training_args)
        torch.cuda.empty_cache()
        self.unet.eval()
        self.text_encoder.eval()
        self.model, self.processor = load_Florence2_model(training_args)

        self.unet = self._place_module(self.unet, self._dtype)
        self.text_encoder = self._place_module(self.text_encoder, self._dtype)
        self.model = self._place_module(self.model, self._dtype)
        # The VAE is always kept in float32, independent of ``fp32``.
        self.vae = self._place_module(self.vae, torch.float32)

        self.batch_size = training_args.batch_size

        # Width 1024 is hard-coded; presumably it matches the Florence-2
        # decoder and text-encoder hidden size — TODO confirm.
        hidden_dim = 1024
        self.language_proj = self._place_module(
            nn.Sequential(
                nn.Linear(1024, hidden_dim, dtype=self._dtype),
                nn.GELU(),
                nn.Linear(hidden_dim, 1024, dtype=self._dtype),
            ),
            self._dtype,
        )
        for param in self.language_proj.parameters():
            param.requires_grad = True

        self.num_queries = training_args.learnable_token_length
        # Created directly on the target device so it matches every other
        # sub-module placed above.
        self.query_embed = nn.Parameter(
            torch.randn(1, self.num_queries, 1024, dtype=self._dtype, device=self._device)
        )
        self.query_embed.requires_grad = True
        self.unet.enable_gradient_checkpointing()

    def _place_module(self, module, dtype):
        """Cast ``module`` to ``dtype`` and move it onto ``self._device``.

        ``nn.Module.to_empty`` moves parameters *without copying storage*, so
        the resulting values are uninitialized. It is therefore used only when
        the module still holds ``meta`` tensors, which ``.to`` cannot copy.
        Materialized weights are moved with ``.to`` so loaded or freshly
        initialized values survive.
        """
        module = module.to(dtype)
        has_meta = any(p.is_meta for p in module.parameters()) or any(
            b.is_meta for b in module.buffers()
        )
        if has_meta:
            return module.to_empty(device=self._device)
        return module.to(device=self._device)

    def _unet_pred_noise(self, x_start, t, noise, context):
        """Noise ``x_start`` at timesteps ``t`` and return the UNet's noise prediction.

        Args:
            x_start: clean latents.
            t: timesteps, one per sample (cast to ``long``).
            noise: noise added via ``self.scheduler.add_noise``.
            context: conditioning sequence; expanded along dim 0 to ``len(t)``.

        Returns:
            The ``.sample`` output of the UNet.
        """
        t = t.to(dtype=torch.long)
        dtype = self.unet.dtype
        x_start = x_start.to(dtype)
        noise = noise.to(dtype)
        context = context.to(dtype)
        nt = t.shape[0]
        noised_latent = self.scheduler.add_noise(x_start, noise, t)
        pred_noise = self.unet(
            noised_latent,
            t,
            encoder_hidden_states=context.expand(nt, -1, -1),
        ).sample
        return pred_noise

    @torch.no_grad()
    def generate_images(self, images):
        """Sample images conditioned on ``images`` with classifier-free guidance.

        Runs entirely without autograd: nothing here is differentiated.

        Reads ``training_args.eval_batch_size``, ``use_text_encoder``,
        ``image_size``, ``num_inference_steps`` and ``guidance_scale``.

        Args:
            images: conditioning images accepted by ``self.processor``; the
                count is expected to equal ``eval_batch_size``.

        Returns:
            The VAE-decoded samples (first element of ``vae.decode``'s output).
        """
        batch_size = self.training_args.eval_batch_size
        query_states = self.get_conditional_context(images, batch_size=batch_size)
        conditional_context = self.language_proj(query_states)

        # Unconditional branch: text-encoder output for the empty prompt.
        un_token = self.tokenizer(
            "", padding="max_length", truncation=True, max_length=77, return_tensors="pt"
        ).input_ids.to(self._device)
        un_context_embeddings = self.text_encoder(un_token).last_hidden_state
        un_context_embeddings = un_context_embeddings.expand(batch_size, -1, -1)

        if self.training_args.use_text_encoder:
            context_embeddings = self.text_encoder(
                inputs_embeds=conditional_context.to(self._dtype)
            ).last_hidden_state
        else:
            # Mirrors ``forward``: without the text encoder the projected
            # queries condition the UNet directly.
            context_embeddings = conditional_context
        # Loop-invariant: [unconditional; conditional] along the batch axis.
        combined_embeddings = torch.cat([un_context_embeddings, context_embeddings], dim=0).to(self._dtype)

        latent_shape = (batch_size, 4, self.training_args.image_size // 8, self.training_args.image_size // 8)
        latents = torch.randn(latent_shape, device=self._device, dtype=self._dtype)
        scheduler = self.scheduler
        scheduler.set_timesteps(self.training_args.num_inference_steps)
        # Schedulers expect the initial noise at their own scale (1.0 for
        # DDPM/DDIM-style schedulers, larger for sigma-based ones).
        latents = latents * getattr(scheduler, "init_noise_sigma", 1.0)

        guidance_scale = self.training_args.guidance_scale
        for t in scheduler.timesteps:
            latent_model_input = torch.cat([latents, latents], dim=0)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=combined_embeddings
            )[0]
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents)[0]

        # Undo the latent scaling applied in ``forward`` (x0 * 0.18215).
        scaled_latents = latents / 0.18215
        decoded_latents = self.vae.decode(scaled_latents.to(torch.float32))[0]
        return decoded_latents

    def get_conditional_context(self, images, batch_size=None):
        """Encode ``images`` with Florence-2 and decode the learnable queries.

        Args:
            images: images accepted by ``self.processor``; their count should
                equal ``batch_size``.
            batch_size: number of prompts / query sequences to build; defaults
                to ``self.batch_size``.

        Returns:
            The Florence-2 decoder's last hidden state over the
            ``num_queries`` learnable queries, one sequence per batch element.
            ``language_proj`` is NOT applied here.
        """
        if batch_size is None:
            batch_size = self.batch_size
        prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
        inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(self._device).to(self._dtype)

        if inputs["input_ids"] is not None:
            inputs_embeds = self.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self._dtype)
        if inputs["pixel_values"] is not None:
            image_features = self.model._encode_image(inputs["pixel_values"]).to(self._dtype)
            # Florence-2 private helper: prepends image features to the prompt
            # embeddings and returns the matching attention mask.
            inputs_embeds, attention_mask = self.model._merge_input_ids_with_image_features(image_features, inputs_embeds)
        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)

        encoder_outputs = self.model.language_model.model.encoder(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # The decoder consumes the learnable queries instead of token ids.
        decoder_input_embeds = self.query_embed.expand(batch_size, -1, -1).to(self._dtype)
        decoder_attention_mask = torch.ones(
            (batch_size, self.num_queries),
            dtype=self._dtype,
            device=self._device,
        )
        decoder_outputs = self.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs.last_hidden_state.to(self._dtype),
            encoder_attention_mask=attention_mask.to(self._dtype),
            output_hidden_states=True,
            return_dict=True,
        )
        return decoder_outputs.last_hidden_state

    def forward(
        self,
        image=None,
        filename=None,
        **kwargs,
    ) -> SDOutput:
        """Compute the noise-prediction loss for a batch of images.

        Args:
            image: image batch; passed to ``normalize`` for the VAE and to
                the Florence-2 processor for conditioning.
            filename: unused; accepted for dataloader compatibility.
            **kwargs: ignored.

        Returns:
            :class:`SDOutput` whose ``loss`` is the L1 loss when
            ``training_args.loss == "l1"``, otherwise the MSE loss, between
            the predicted and the true noise.
        """
        images_for_language_model = image
        normalize_images = normalize(image, rescale=True)
        x0 = self.vae.encode(normalize_images.to(torch.float32)).latent_dist.sample()
        # Latent scaling constant (presumably the SD VAE scaling factor).
        latent = x0 * 0.18215
        # Size everything from the actual batch so a smaller final batch
        # stays consistent with the noise, timesteps and prompts below.
        batch_size, c, h, w = latent.shape

        total_timestep = self.scheduler.num_train_timesteps
        timesteps = initiate_time_steps(0, total_timestep, batch_size, self.training_args).long()
        timesteps = timesteps.to(self._device)

        if self.training_args.use_same_noise_among_timesteps:
            noise = torch.randn((1, c, h, w), device=self._device, dtype=self._dtype)
            noise = noise.repeat(batch_size, 1, 1, 1)
        else:
            noise = torch.randn((batch_size, c, h, w), device=self._device, dtype=self._dtype)

        conditional_context = self.get_conditional_context(images_for_language_model, batch_size=batch_size)
        conditional_context = self.language_proj(conditional_context)
        if self.training_args.use_text_encoder:
            text_encoder_output = self.text_encoder(input_ids=None, inputs_embeds=conditional_context.to(self._dtype))
            context = text_encoder_output.last_hidden_state
        else:
            context = conditional_context
        pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=context.to(self._dtype))

        # Reduce in float32: averaging in bfloat16 loses precision.
        pred_noise = pred_noise.float()
        target = noise.float()
        if self.training_args.loss == "l1":
            loss = torch.nn.functional.l1_loss(pred_noise, target)
        else:
            loss = torch.nn.functional.mse_loss(pred_noise, target)
        return SDOutput(loss=loss)