import os
import torch
import torch.nn as nn
from typing import Optional
from dataclasses import dataclass
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from .build import load_sd_model, load_Florence2_model
from .vlv_utils import initiate_time_steps, normalize


class SDConfig(PretrainedConfig):
    """Configuration class for SDModel."""

    model_type = "sd"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> GELU -> Linear."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.layers(x)


@dataclass
class SDOutput(ModelOutput):
    """Output container of ``SDModel.forward``; carries only the loss."""

    loss: Optional[torch.FloatTensor] = None


class SDModel(PreTrainedModel):
    """Diffusion model conditioned on learnable-query decoder states.

    The model combines the components returned by ``load_sd_model`` (VAE,
    tokenizer, text encoder, UNet, scheduler) with the model/processor pair
    returned by ``load_Florence2_model``. A set of learnable query embeddings
    is decoded against the encoded image, projected by ``language_proj`` and
    used as the UNet's conditioning.
    """

    config_class = SDConfig

    # Factor applied to VAE latents after encoding and removed again before
    # decoding (previously repeated inline as the literal 0.18215).
    VAE_SCALING_FACTOR = 0.18215

    def __init__(
        self,
        config=None,
        training_args=None,
    ):
        """Load all sub-models and create the trainable projection/queries.

        Args:
            config: Optional ``SDConfig``; a default one is created if omitted.
            training_args: Object exposing at least ``fp32``, ``batch_size``
                and ``learnable_token_length`` (plus whatever the loader
                helpers read). An optional ``device`` attribute overrides
                automatic device selection.

        Raises:
            ValueError: If ``training_args`` is ``None``.
        """
        if training_args is None:
            # Fail early with a clear message instead of an AttributeError on
            # ``None.fp32`` further down.
            raise ValueError("SDModel requires `training_args`; got None.")
        if config is None:
            config = SDConfig()
        super().__init__(config)

        self.training_args = training_args
        self._dtype = torch.float32 if self.training_args.fp32 else torch.bfloat16

        if hasattr(self.training_args, "device"):
            device = self.training_args.device
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        self._device = torch.device(device)

        (
            self.vae,
            self.tokenizer,
            self.text_encoder,
            self.unet,
            self.scheduler,
        ) = load_sd_model(training_args)
        torch.cuda.empty_cache()
        self.unet.eval()
        self.text_encoder.eval()

        self.model, self.processor = load_Florence2_model(training_args)

        self.unet = self.unet.to(self._dtype).to(device=self._device)
        # NOTE(review): ``to_empty`` allocates *uninitialised* storage on the
        # target device. That is only correct if these modules live on the
        # meta device and real weights are loaded afterwards (e.g. by
        # ``from_pretrained``) -- confirm; otherwise these should be ``.to``.
        self.text_encoder = self.text_encoder.to(self._dtype).to_empty(device=self._device)
        self.model = self.model.to(self._dtype).to_empty(device=self._device)
        # The VAE is always kept in float32, independent of ``self._dtype``.
        self.vae = self.vae.to(torch.float32).to_empty(device=self._device)

        self.batch_size = self.training_args.batch_size

        hidden_dim = 1024
        self.language_proj = nn.Sequential(
            nn.Linear(1024, hidden_dim, dtype=self._dtype),
            nn.GELU(),
            nn.Linear(hidden_dim, 1024, dtype=self._dtype),
        ).to_empty(device=self._device)
        for param in self.language_proj.parameters():
            param.requires_grad = True

        self.num_queries = self.training_args.learnable_token_length
        self.query_embed = nn.Parameter(
            torch.randn(1, self.num_queries, 1024, dtype=self._dtype)
        )
        self.query_embed.requires_grad = True

        self.unet.enable_gradient_checkpointing()

    def _unet_pred_noise(self, x_start, t, noise, context):
        """Noise ``x_start`` at timesteps ``t`` and return the UNet prediction.

        Args:
            x_start: Clean latents.
            t: Timesteps; cast to ``torch.long``.
            noise: Noise tensor added by the scheduler.
            context: Conditioning passed as ``encoder_hidden_states``; it is
                expanded along the first dimension to ``t.shape[0]``.

        Returns:
            The ``.sample`` attribute of the UNet output.
        """
        t = t.to(dtype=torch.long)

        unet_dtype = self.unet.dtype
        x_start = x_start.to(unet_dtype)
        noise = noise.to(unet_dtype)
        context = context.to(unet_dtype)

        nt = t.shape[0]
        noised_latent = self.scheduler.add_noise(x_start, noise, t)
        pred_noise = self.unet(
            noised_latent,
            t,
            encoder_hidden_states=context.expand(nt, -1, -1),
        ).sample
        return pred_noise

    def generate_images(self, images):
        """Sample images conditioned on ``images`` with classifier-free guidance.

        Uses ``training_args.eval_batch_size``, ``image_size``,
        ``num_inference_steps`` and ``guidance_scale``.

        Args:
            images: Images handed to the processor.

        Returns:
            First element of ``self.vae.decode(...)`` for the final latents.
        """
        batch_size = self.training_args.eval_batch_size

        # Shared with the training path instead of duplicating the whole
        # encoder/decoder pass here.
        last_decoder_hidden_state = self.get_conditional_context(
            images, batch_size=batch_size
        )
        conditional_context = self.language_proj(last_decoder_hidden_state)

        # Unconditional branch: embedding of the empty prompt.
        un_token = self.tokenizer(
            "",
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).input_ids.to(self._device)
        un_context_embeddings = self.text_encoder(un_token).last_hidden_state
        un_context_embeddings = un_context_embeddings.expand(batch_size, -1, -1)

        if self.training_args.use_text_encoder:
            context_embeddings = self.text_encoder(
                inputs_embeds=conditional_context.to(self._dtype)
            ).last_hidden_state
        else:
            # Previously left undefined (NameError in the loop below); mirror
            # ``forward``, which feeds the projected context straight to the UNet.
            context_embeddings = conditional_context.to(self._dtype)

        latent_shape = (
            batch_size,
            4,
            self.training_args.image_size // 8,
            self.training_args.image_size // 8,
        )
        latents = torch.randn(latent_shape, device=self._device, dtype=self._dtype)

        scheduler = self.scheduler
        scheduler.set_timesteps(self.training_args.num_inference_steps)
        guidance_scale = self.training_args.guidance_scale

        with torch.no_grad():
            # Loop-invariant: build the [uncond, cond] conditioning once.
            combined_embeddings = torch.cat(
                [un_context_embeddings, context_embeddings], dim=0
            ).to(self._dtype)

            for t in scheduler.timesteps:
                latent_model_input = torch.cat([latents, latents], dim=0)
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=combined_embeddings,
                )[0]

                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
                latents = scheduler.step(noise_pred, t, latents)[0]

        scaled_latents = latents / self.VAE_SCALING_FACTOR
        with torch.no_grad():
            decoded_latents = self.vae.decode(scaled_latents.to(torch.float32))[0]
        return decoded_latents

    def get_conditional_context(self, images, batch_size=None):
        """Decode the learnable queries against the encoded ``images``.

        Args:
            images: Images handed to the processor together with empty prompts.
            batch_size: Number of empty prompts / query copies; defaults to
                ``self.batch_size``.

        Returns:
            ``last_hidden_state`` of the language-model decoder run on the
            expanded ``query_embed``.
        """
        if batch_size is None:
            batch_size = self.batch_size

        prompt = [""] * batch_size
        inputs = (
            self.processor(text=prompt, images=images, return_tensors="pt")
            .to(self._device)
            .to(self._dtype)
        )

        if inputs["input_ids"] is not None:
            inputs_embeds = self.model.language_model.get_input_embeddings()(
                inputs["input_ids"]
            ).to(self._dtype)

        if inputs["pixel_values"] is not None:
            image_features = self.model._encode_image(inputs["pixel_values"]).to(
                self._dtype
            )
            inputs_embeds, attention_mask = (
                self.model._merge_input_ids_with_image_features(
                    image_features, inputs_embeds
                )
            )

        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)

        encoder_outputs = self.model.language_model.model.encoder(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # One copy of the learnable queries per batch element, all attended.
        decoder_input_embeds = self.query_embed.expand(batch_size, -1, -1)
        decoder_attention_mask = torch.ones(
            (batch_size, self.num_queries),
            dtype=self._dtype,
            device=self._device,
        )

        encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
        decoder_input_embeds = decoder_input_embeds.to(self._dtype)
        attention_mask = attention_mask.to(self._dtype)

        decoder_outputs = self.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        last_decoder_hidden_state = decoder_outputs.last_hidden_state
        return last_decoder_hidden_state

    def forward(
        self,
        image=None,
        filename=None,
        **kwargs,
    ) -> SDOutput:
        """Compute the noise-prediction loss for ``image``.

        Args:
            image: Input images; normalised for the VAE and also passed
                unmodified to the language-model processor.
            filename: Accepted but unused.
            **kwargs: Accepted but unused.

        Returns:
            ``SDOutput`` whose ``loss`` is the L1 loss when
            ``training_args.loss == "l1"`` and the MSE loss otherwise.
        """
        images_for_language_model = image

        normalize_images = normalize(image, rescale=True)
        x0 = self.vae.encode(normalize_images.to(torch.float32)).latent_dist.sample()
        latent = x0 * self.VAE_SCALING_FACTOR

        # NOTE(review): timesteps and noise are sized by ``self.batch_size``,
        # not by the batch dimension of ``image`` -- confirm this is intended.
        total_timestep = self.scheduler.num_train_timesteps
        timesteps = initiate_time_steps(
            0, total_timestep, self.batch_size, self.training_args
        ).long()
        timesteps = timesteps.to(self._device)

        c, h, w = latent.shape[1:]
        if not self.training_args.use_same_noise_among_timesteps:
            noise = torch.randn(
                (self.batch_size, c, h, w), device=self._device, dtype=self._dtype
            )
        else:
            # Draw one noise sample and share it across the batch dimension.
            noise = torch.randn((1, c, h, w), device=self._device, dtype=self._dtype)
            noise = noise.repeat(self.batch_size, 1, 1, 1)

        conditional_context = self.get_conditional_context(images_for_language_model)
        conditional_context = self.language_proj(conditional_context)

        if self.training_args.use_text_encoder:
            text_encoder_output = self.text_encoder(
                input_ids=None, inputs_embeds=conditional_context.to(self._dtype)
            )
            pred_noise = self._unet_pred_noise(
                x_start=latent,
                t=timesteps,
                noise=noise,
                context=text_encoder_output.last_hidden_state.to(self._dtype),
            ).to(self._dtype)
        else:
            pred_noise = self._unet_pred_noise(
                x_start=latent,
                t=timesteps,
                noise=noise,
                context=conditional_context.to(self._dtype),
            ).to(self._dtype)

        if self.training_args.loss == "l1":
            loss = torch.nn.functional.l1_loss(pred_noise, noise)
        else:
            loss = torch.nn.functional.mse_loss(pred_noise, noise)

        return SDOutput(loss=loss)