|
|
|
import os |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import math |
|
from safetensors.torch import load_file |
|
from typing import List, Optional, Tuple, Union |
|
import torch.utils.checkpoint |
|
from huggingface_hub import snapshot_download |
|
from transformers.modeling_outputs import BaseModelOutputWithPast |
|
from transformers import Phi3Config, Phi3Model |
|
from transformers.cache_utils import Cache, DynamicCache |
|
from transformers.utils import logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class Phi3Transformer(Phi3Model): |
|
""" |
|
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`] |
|
We only modified the attention mask |
|
Args: |
|
config: Phi3Config |
|
""" |
|
    def prefetch_layer(self, layer_idx: int, device: torch.device):
        """Start an asynchronous copy of the weights of layer `layer_idx` to `device`."""
        with torch.cuda.stream(self.prefetch_stream):
            # Move the layer's parameters to the GPU without blocking the compute stream.
            for name, param in self.layers[layer_idx].named_parameters():
                param.data = param.data.to(device, non_blocking=True)
|
|
|
    def evict_previous_layer(self, layer_idx: int):
        """Move the weights of the previous layer back to the CPU."""
        # For layer 0 this wraps around to -1, i.e. the last layer, which keeps the
        # rolling window of GPU-resident layers bounded during generation.
        prev_layer_idx = layer_idx - 1
        for name, param in self.layers[prev_layer_idx].named_parameters():
            param.data = param.data.to("cpu", non_blocking=True)
|
|
|
    def get_offload_layer(self, layer_idx: int, device: torch.device):
        if not hasattr(self, "prefetch_stream"):
            self.prefetch_stream = torch.cuda.Stream()

        # Make sure compute on the current stream is done before evicting the previous layer.
        torch.cuda.current_stream().synchronize()
        self.evict_previous_layer(layer_idx)

        # Block until the prefetch of the current layer's weights has finished.
        self.prefetch_stream.synchronize()

        # Start prefetching the next layer, wrapping around to layer 0 at the end.
        self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
|
|
|
|
|
def forward( |
|
self, |
|
        input_ids: Optional[torch.LongTensor] = None,
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
offload_model: Optional[bool] = False, |
|
) -> Union[Tuple, BaseModelOutputWithPast]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if (input_ids is None) ^ (inputs_embeds is not None): |
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
|
|
|
if self.gradient_checkpointing and self.training: |
|
if use_cache: |
|
logger.warning_once( |
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
|
) |
|
use_cache = False |
|
|
|
|
|
return_legacy_cache = False |
|
if use_cache and not isinstance(past_key_values, Cache): |
|
return_legacy_cache = True |
|
if past_key_values is None: |
|
past_key_values = DynamicCache() |
|
else: |
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values) |
|
logger.warning_once( |
|
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " |
|
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " |
|
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" |
|
                )
|
        if attention_mask is not None and attention_mask.dim() == 3:
            # Convert the pre-computed 3D (batch, query, key) 0/1 mask into an additive
            # bias: allowed positions become 0, masked positions a large negative value.
            dtype = inputs_embeds.dtype
            min_dtype = torch.finfo(dtype).min
            attention_mask = (1 - attention_mask) * min_dtype
            attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
        else:
            raise ValueError("`attention_mask` must be provided as a 3D (batch, query, key) tensor")
|
|
|
|
|
|
|
|
|
hidden_states = inputs_embeds |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
next_decoder_cache = None |
|
|
|
        for layer_idx, decoder_layer in enumerate(self.layers):
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
if self.gradient_checkpointing and self.training: |
|
layer_outputs = self._gradient_checkpointing_func( |
|
decoder_layer.__call__, |
|
hidden_states, |
|
attention_mask, |
|
position_ids, |
|
past_key_values, |
|
output_attentions, |
|
use_cache, |
|
cache_position, |
|
) |
|
else: |
|
if offload_model and not self.training: |
|
                    self.get_offload_layer(layer_idx, device=inputs_embeds.device)
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if use_cache: |
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1] |
|
|
|
if output_attentions: |
|
all_self_attns += (layer_outputs[1],) |
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
|
|
|
next_cache = next_decoder_cache if use_cache else None |
|
if return_legacy_cache: |
|
next_cache = next_cache.to_legacy_cache() |
|
|
|
if not return_dict: |
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) |
|
return BaseModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
past_key_values=next_cache, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
) |
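

# Minimal sketch (not part of the original file; assumes a deliberately tiny Phi3Config
# and a transformers version in the range this module targets, where Phi-3 attention
# derives its rotary embeddings from `position_ids`): the modified forward expects a
# 3D (batch, query, key) 0/1 attention mask instead of the usual 2D padding mask.
def _phi3_transformer_mask_demo():
    cfg = Phi3Config(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                     num_attention_heads=4, num_key_value_heads=4, vocab_size=128)
    model = Phi3Transformer(cfg).eval()
    embeds = torch.randn(1, 8, cfg.hidden_size)
    mask = torch.tril(torch.ones(1, 8, 8))       # 1 = attend, 0 = masked
    pos = torch.arange(8).unsqueeze(0)
    out = model(inputs_embeds=embeds, attention_mask=mask, position_ids=pos, use_cache=False)
    return out.last_hidden_state.shape           # (1, 8, 64)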
|
|
|
|
|
def modulate(x, shift, scale): |
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) |
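

# Illustrative sketch (shapes are arbitrary, not taken from the model): `modulate`
# broadcasts a per-sample (shift, scale) pair over the token dimension, which is how
# the adaLN conditioning below is applied.
def _modulate_shape_demo():
    x = torch.randn(2, 16, 64)        # (batch, tokens, hidden)
    shift = torch.zeros(2, 64)        # one shift vector per sample
    scale = 0.1 * torch.ones(2, 64)   # one scale vector per sample
    out = modulate(x, shift, scale)   # x * (1 + scale[:, None]) + shift[:, None]
    assert out.shape == x.shape
    return out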
|
|
|
|
|
class TimestepEmbedder(nn.Module): |
|
""" |
|
Embeds scalar timesteps into vector representations. |
|
""" |
|
def __init__(self, hidden_size, frequency_embedding_size=256): |
|
super().__init__() |
|
self.mlp = nn.Sequential( |
|
nn.Linear(frequency_embedding_size, hidden_size, bias=True), |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, hidden_size, bias=True), |
|
) |
|
self.frequency_embedding_size = frequency_embedding_size |
|
|
|
@staticmethod |
|
def timestep_embedding(t, dim, max_period=10000): |
|
""" |
|
Create sinusoidal timestep embeddings. |
|
:param t: a 1-D Tensor of N indices, one per batch element. |
|
These may be fractional. |
|
:param dim: the dimension of the output. |
|
:param max_period: controls the minimum frequency of the embeddings. |
|
:return: an (N, D) Tensor of positional embeddings. |
|
""" |
|
|
|
half = dim // 2 |
|
freqs = torch.exp( |
|
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half |
|
).to(device=t.device) |
|
args = t[:, None].float() * freqs[None] |
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) |
|
if dim % 2: |
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) |
|
return embedding |
|
|
|
def forward(self, t, dtype=torch.float32): |
|
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) |
|
t_emb = self.mlp(t_freq) |
|
return t_emb |
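

# Illustrative sketch (toy sizes, not the model's real hidden size): the static
# sinusoidal table maps N scalar timesteps to (N, frequency_embedding_size) features,
# which the small MLP then projects to (N, hidden_size).
def _timestep_embedder_demo():
    embedder = TimestepEmbedder(hidden_size=128, frequency_embedding_size=256)
    t = torch.tensor([0.0, 10.0, 999.0])                 # fractional timesteps are allowed
    freqs = TimestepEmbedder.timestep_embedding(t, 256)  # (3, 256)
    emb = embedder(t)                                    # (3, 128)
    return freqs.shape, emb.shape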
|
|
|
|
|
class FinalLayer(nn.Module): |
|
""" |
|
The final layer of DiT. |
|
""" |
|
def __init__(self, hidden_size, patch_size, out_channels): |
|
super().__init__() |
|
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) |
|
self.adaLN_modulation = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 2 * hidden_size, bias=True) |
|
) |
|
|
|
def forward(self, x, c): |
|
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) |
|
x = modulate(self.norm_final(x), shift, scale) |
|
x = self.linear(x) |
|
return x |
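

# Illustrative sketch (toy sizes): FinalLayer maps each token of the last hidden state,
# modulated by the conditioning vector c through adaLN, to patch_size**2 * out_channels
# values that `unpatchify` later folds back into an image latent.
def _final_layer_demo():
    layer = FinalLayer(hidden_size=64, patch_size=2, out_channels=4)
    x = torch.randn(1, 16, 64)   # (batch, tokens, hidden)
    c = torch.randn(1, 64)       # conditioning vector (e.g. a timestep embedding)
    out = layer(x, c)            # (1, 16, 2 * 2 * 4)
    return out.shape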
|
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1): |
|
""" |
|
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or |
|
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) |
|
""" |
|
if isinstance(grid_size, int): |
|
grid_size = (grid_size, grid_size) |
|
|
|
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale |
|
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale |
|
grid = np.meshgrid(grid_w, grid_h) |
|
grid = np.stack(grid, axis=0) |
|
|
|
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) |
|
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) |
|
if cls_token and extra_tokens > 0: |
|
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) |
|
return pos_embed |
|
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): |
|
assert embed_dim % 2 == 0 |
|
|
|
|
|
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) |
|
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) |
|
|
|
emb = np.concatenate([emb_h, emb_w], axis=1) |
|
return emb |
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): |
|
""" |
|
embed_dim: output dimension for each position |
|
pos: a list of positions to be encoded: size (M,) |
|
out: (M, D) |
|
""" |
|
assert embed_dim % 2 == 0 |
|
omega = np.arange(embed_dim // 2, dtype=np.float64) |
|
omega /= embed_dim / 2. |
|
omega = 1. / 10000**omega |
|
|
|
pos = pos.reshape(-1) |
|
out = np.einsum('m,d->md', pos, omega) |
|
|
|
emb_sin = np.sin(out) |
|
emb_cos = np.cos(out) |
|
|
|
emb = np.concatenate([emb_sin, emb_cos], axis=1) |
|
return emb |
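

# Illustrative sketch: the helpers above build a fixed 2D sin/cos position table; for a
# (grid_h, grid_w) grid and an even embed_dim the result is a NumPy array of shape
# (grid_h * grid_w, embed_dim).
def _sincos_pos_embed_demo():
    pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(8, 12), base_size=8)
    return pos.shape  # (96, 64)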
|
|
|
|
|
class PatchEmbedMR(nn.Module): |
|
""" 2D Image to Patch Embedding |
|
""" |
|
def __init__( |
|
self, |
|
patch_size: int = 2, |
|
in_chans: int = 4, |
|
embed_dim: int = 768, |
|
bias: bool = True, |
|
): |
|
super().__init__() |
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) |
|
|
|
def forward(self, x): |
|
x = self.proj(x) |
|
x = x.flatten(2).transpose(1, 2) |
|
return x |
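

# Illustrative sketch (toy sizes): PatchEmbedMR turns an (N, C, H, W) latent into an
# (N, (H / patch_size) * (W / patch_size), embed_dim) token sequence via a strided Conv2d.
def _patch_embed_demo():
    embed = PatchEmbedMR(patch_size=2, in_chans=4, embed_dim=32)
    latent = torch.randn(1, 4, 16, 16)
    tokens = embed(latent)    # (1, 64, 32)
    return tokens.shape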
|
|
|
|
|
class OmniGenOriginalModel(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
def __init__( |
|
self, |
|
transformer_config: Phi3Config, |
|
patch_size=2, |
|
in_channels=4, |
|
pe_interpolation: float = 1.0, |
|
pos_embed_max_size: int = 192, |
|
): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels |
|
self.patch_size = patch_size |
|
self.pos_embed_max_size = pos_embed_max_size |
|
|
|
hidden_size = transformer_config.hidden_size |
|
|
|
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True) |
|
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True) |
|
|
|
self.time_token = TimestepEmbedder(hidden_size) |
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
|
|
self.pe_interpolation = pe_interpolation |
|
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64) |
|
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True) |
|
|
|
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) |
|
|
|
self.initialize_weights() |
|
|
|
self.llm = Phi3Transformer(config=transformer_config) |
|
self.llm.config.use_cache = False |
|
|
|
@classmethod |
|
def from_pretrained(cls, model_name): |
|
if not os.path.exists(model_name): |
|
cache_folder = os.getenv('HF_HUB_CACHE') |
|
model_name = snapshot_download(repo_id=model_name, |
|
cache_dir=cache_folder, |
|
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) |
|
config = Phi3Config.from_pretrained(model_name) |
|
model = cls(config) |
|
if os.path.exists(os.path.join(model_name, 'model.safetensors')): |
|
print("Loading safetensors") |
|
ckpt = load_file(os.path.join(model_name, 'model.safetensors')) |
|
else: |
|
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu') |
|
model.load_state_dict(ckpt) |
|
return model |
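
    # Usage sketch (the argument is whichever local directory or Hugging Face repo id
    # holds the OmniGen weights; no specific id is assumed here):
    #   model = OmniGenOriginalModel.from_pretrained("<path-or-repo-id>")
    #   model = model.eval()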
|
|
|
def initialize_weights(self): |
|
        # Weight init must run before the Phi-3 backbone is attached in __init__, otherwise
        # `_basic_init` would re-initialize its pretrained weights.
        assert not hasattr(self, "llm")
|
|
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
torch.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
self.apply(_basic_init) |
|
|
|
|
|
w = self.x_embedder.proj.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
nn.init.constant_(self.x_embedder.proj.bias, 0) |
|
|
|
        w = self.input_x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.input_x_embedder.proj.bias, 0)
|
|
|
|
|
|
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) |
|
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02) |
|
|
|
|
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) |
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
|
c = self.out_channels |
|
|
|
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c)) |
|
x = torch.einsum('nhwpqc->nchpwq', x) |
|
imgs = x.reshape(shape=(x.shape[0], c, h, w)) |
|
return imgs |
|
|
|
|
|
def cropped_pos_embed(self, height, width): |
|
"""Crops positional embeddings for SD3 compatibility.""" |
|
if self.pos_embed_max_size is None: |
|
raise ValueError("`pos_embed_max_size` must be set for cropping.") |
|
|
|
height = height // self.patch_size |
|
width = width // self.patch_size |
|
if height > self.pos_embed_max_size: |
|
raise ValueError( |
|
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." |
|
) |
|
if width > self.pos_embed_max_size: |
|
raise ValueError( |
|
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." |
|
) |
|
|
|
top = (self.pos_embed_max_size - height) // 2 |
|
left = (self.pos_embed_max_size - width) // 2 |
|
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) |
|
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] |
|
|
|
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) |
|
return spatial_pos_embed |
|
|
|
|
|
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False): |
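        # Two code paths: a list of differently-sized latents is embedded item by item
        # (optionally right-padded along the token dimension), while a single batched
        # tensor is embedded directly; both add the cropped 2D sin/cos position embedding.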
|
if isinstance(latents, list): |
|
return_list = False |
|
if padding_latent is None: |
|
padding_latent = [None] * len(latents) |
|
return_list = True |
|
patched_latents, num_tokens, shapes = [], [], [] |
|
for latent, padding in zip(latents, padding_latent): |
|
height, width = latent.shape[-2:] |
|
if is_input_images: |
|
latent = self.input_x_embedder(latent) |
|
else: |
|
latent = self.x_embedder(latent) |
|
pos_embed = self.cropped_pos_embed(height, width) |
|
latent = latent + pos_embed |
|
if padding is not None: |
|
latent = torch.cat([latent, padding], dim=-2) |
|
patched_latents.append(latent) |
|
|
|
num_tokens.append(pos_embed.size(1)) |
|
shapes.append([height, width]) |
|
if not return_list: |
|
latents = torch.cat(patched_latents, dim=0) |
|
else: |
|
latents = patched_latents |
|
else: |
|
height, width = latents.shape[-2:] |
|
if is_input_images: |
|
latents = self.input_x_embedder(latents) |
|
else: |
|
latents = self.x_embedder(latents) |
|
pos_embed = self.cropped_pos_embed(height, width) |
|
latents = latents + pos_embed |
|
num_tokens = latents.size(1) |
|
shapes = [height, width] |
|
return latents, num_tokens, shapes |
|
|
|
|
|
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False): |
|
""" |
|
|
|
""" |
|
input_is_list = isinstance(x, list) |
|
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent) |
|
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1) |
|
|
|
if input_img_latents is not None: |
|
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True) |
|
if input_ids is not None: |
|
condition_embeds = self.llm.embed_tokens(input_ids).clone() |
|
input_img_inx = 0 |
|
for b_inx in input_image_sizes.keys(): |
|
for start_inx, end_inx in input_image_sizes[b_inx]: |
|
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx] |
|
input_img_inx += 1 |
|
if input_img_latents is not None: |
|
assert input_img_inx == len(input_latents) |
|
|
|
input_emb = torch.cat([condition_embeds, time_token, x], dim=1) |
|
else: |
|
input_emb = torch.cat([time_token, x], dim=1) |
|
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model) |
|
output, past_key_values = output.last_hidden_state, output.past_key_values |
|
if input_is_list: |
|
image_embedding = output[:, -max(num_tokens):] |
|
time_emb = self.t_embedder(timestep, dtype=x.dtype) |
|
x = self.final_layer(image_embedding, time_emb) |
|
latents = [] |
|
for i in range(x.size(0)): |
|
latent = x[i:i+1, :num_tokens[i]] |
|
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1]) |
|
latents.append(latent) |
|
else: |
|
image_embedding = output[:, -num_tokens:] |
|
time_emb = self.t_embedder(timestep, dtype=x.dtype) |
|
x = self.final_layer(image_embedding, time_emb) |
|
latents = self.unpatchify(x, shapes[0], shapes[1]) |
|
|
|
if return_past_key_values: |
|
return latents, past_key_values |
|
return latents |
|
|
|
@torch.no_grad() |
|
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model): |
|
self.llm.config.use_cache = use_kv_cache |
|
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model) |
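        # Nested classifier-free guidance: the batch is stacked as
        # [conditional, unconditional, image-conditional]; image guidance is applied
        # relative to the unconditional branch, then text guidance relative to the
        # image-conditional branch.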
|
if use_img_cfg: |
|
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0) |
|
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond) |
|
model_out = [cond, cond, cond] |
|
else: |
|
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0) |
|
cond = uncond + cfg_scale * (cond - uncond) |
|
model_out = [cond, cond] |
|
|
|
return torch.cat(model_out, dim=0), past_key_values |
|
|
|
|
|
@torch.no_grad() |
|
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model): |
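        # Runs each guidance branch through the model in its own forward pass with its
        # own KV cache, then combines the branches with the same guidance formulas as
        # `forward_with_cfg`.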
|
self.llm.config.use_cache = use_kv_cache |
|
if past_key_values is None: |
|
past_key_values = [None] * len(attention_mask) |
|
|
|
x = torch.split(x, len(x) // len(attention_mask), dim=0) |
|
timestep = timestep.to(x[0].dtype) |
|
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0) |
|
|
|
        model_out, new_past_key_values = [], []
        for i in range(len(input_ids)):
            temp_out, temp_past_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
            model_out.append(temp_out)
            new_past_key_values.append(temp_past_key_values)
|
|
|
if len(model_out) == 3: |
|
cond, uncond, img_cond = model_out |
|
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond) |
|
model_out = [cond, cond, cond] |
|
elif len(model_out) == 2: |
|
cond, uncond = model_out |
|
cond = uncond + cfg_scale * (cond - uncond) |
|
model_out = [cond, cond] |
|
else: |
|
return model_out[0] |
|
|
|
        return torch.cat(model_out, dim=0), new_past_key_values
|
|
|
|
|
|
|
class OmniGenTransformer(OmniGenOriginalModel): |
|
def __init__(self): |
|
config = { |
|
"_name_or_path": "Phi-3-vision-128k-instruct", |
|
"architectures": [ |
|
"Phi3ForCausalLM" |
|
], |
|
"attention_dropout": 0.0, |
|
"bos_token_id": 1, |
|
"eos_token_id": 2, |
|
"hidden_act": "silu", |
|
"hidden_size": 3072, |
|
"initializer_range": 0.02, |
|
"intermediate_size": 8192, |
|
"max_position_embeddings": 131072, |
|
"model_type": "phi3", |
|
"num_attention_heads": 32, |
|
"num_hidden_layers": 32, |
|
"num_key_value_heads": 32, |
|
"original_max_position_embeddings": 4096, |
|
"rms_norm_eps": 1e-05, |
|
"rope_scaling": { |
|
"long_factor": [ |
|
1.0299999713897705, |
|
1.0499999523162842, |
|
1.0499999523162842, |
|
1.0799999237060547, |
|
1.2299998998641968, |
|
1.2299998998641968, |
|
1.2999999523162842, |
|
1.4499999284744263, |
|
1.5999999046325684, |
|
1.6499998569488525, |
|
1.8999998569488525, |
|
2.859999895095825, |
|
3.68999981880188, |
|
5.419999599456787, |
|
5.489999771118164, |
|
5.489999771118164, |
|
9.09000015258789, |
|
11.579999923706055, |
|
15.65999984741211, |
|
15.769999504089355, |
|
15.789999961853027, |
|
18.360000610351562, |
|
21.989999771118164, |
|
23.079999923706055, |
|
30.009998321533203, |
|
32.35000228881836, |
|
32.590003967285156, |
|
35.56000518798828, |
|
39.95000457763672, |
|
53.840003967285156, |
|
56.20000457763672, |
|
57.95000457763672, |
|
59.29000473022461, |
|
59.77000427246094, |
|
59.920005798339844, |
|
61.190006256103516, |
|
61.96000671386719, |
|
62.50000762939453, |
|
63.3700065612793, |
|
63.48000717163086, |
|
63.48000717163086, |
|
63.66000747680664, |
|
63.850006103515625, |
|
64.08000946044922, |
|
64.760009765625, |
|
64.80001068115234, |
|
64.81001281738281, |
|
64.81001281738281 |
|
], |
|
"short_factor": [ |
|
1.05, |
|
1.05, |
|
1.05, |
|
1.1, |
|
1.1, |
|
1.1, |
|
1.2500000000000002, |
|
1.2500000000000002, |
|
1.4000000000000004, |
|
1.4500000000000004, |
|
1.5500000000000005, |
|
1.8500000000000008, |
|
1.9000000000000008, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.000000000000001, |
|
2.1000000000000005, |
|
2.1000000000000005, |
|
2.2, |
|
2.3499999999999996, |
|
2.3499999999999996, |
|
2.3499999999999996, |
|
2.3499999999999996, |
|
2.3999999999999995, |
|
2.3999999999999995, |
|
2.6499999999999986, |
|
2.6999999999999984, |
|
2.8999999999999977, |
|
2.9499999999999975, |
|
3.049999999999997, |
|
3.049999999999997, |
|
3.049999999999997 |
|
], |
|
"type": "su" |
|
}, |
|
"rope_theta": 10000.0, |
|
"sliding_window": 131072, |
|
"tie_word_embeddings": False, |
|
"torch_dtype": "bfloat16", |
|
"transformers_version": "4.38.1", |
|
"use_cache": True, |
|
"vocab_size": 32064, |
|
"_attn_implementation": "sdpa" |
|
} |
|
config = Phi3Config(**config) |
|
super().__init__(config) |
|
|
|
|
|
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False): |
|
input_is_list = isinstance(x, list) |
|
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent) |
|
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1) |
|
|
|
if input_img_latents is not None: |
|
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True) |
|
if input_ids is not None: |
|
condition_embeds = self.llm.embed_tokens(input_ids).clone() |
|
input_img_inx = 0 |
|
for b_inx in input_image_sizes.keys(): |
|
for start_inx, end_inx in input_image_sizes[b_inx]: |
|
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx] |
|
input_img_inx += 1 |
|
if input_img_latents is not None: |
|
assert input_img_inx == len(input_latents) |
|
|
|
input_emb = torch.cat([condition_embeds, time_token, x], dim=1) |
|
else: |
|
input_emb = torch.cat([time_token, x], dim=1) |
|
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model) |
|
output, past_key_values = output.last_hidden_state, output.past_key_values |
|
if input_is_list: |
|
image_embedding = output[:, -max(num_tokens):] |
|
time_emb = self.t_embedder(timestep, dtype=x.dtype) |
|
x = self.final_layer(image_embedding, time_emb) |
|
latents = [] |
|
for i in range(x.size(0)): |
|
latent = x[i:i+1, :num_tokens[i]] |
|
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1]) |
|
latents.append(latent) |
|
else: |
|
image_embedding = output[:, -num_tokens:] |
|
time_emb = self.t_embedder(timestep, dtype=x.dtype) |
|
x = self.final_layer(image_embedding, time_emb) |
|
latents = self.unpatchify(x, shapes[0], shapes[1]) |
|
|
|
if return_past_key_values: |
|
return latents, past_key_values |
|
return latents |
|
|
|
|
|
@torch.no_grad() |
|
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model): |
|
self.llm.config.use_cache = use_kv_cache |
|
if past_key_values is None: |
|
past_key_values = [None] * len(attention_mask) |
|
|
|
x = torch.split(x, len(x) // len(attention_mask), dim=0) |
|
timestep = timestep.to(x[0].dtype) |
|
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0) |
|
|
|
        model_out, new_past_key_values = [], []
        for i in range(len(input_ids)):
            temp_out, temp_past_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
            model_out.append(temp_out)
            new_past_key_values.append(temp_past_key_values)
|
|
|
if len(model_out) == 3: |
|
cond, uncond, img_cond = model_out |
|
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond) |
|
model_out = [cond, cond, cond] |
|
elif len(model_out) == 2: |
|
cond, uncond = model_out |
|
cond = uncond + cfg_scale * (cond - uncond) |
|
model_out = [cond, cond] |
|
else: |
|
return model_out[0] |
|
|
|
        return torch.cat(model_out, dim=0), new_past_key_values
|
|
|
|
|
@staticmethod |
|
def state_dict_converter(): |
|
return OmniGenTransformerStateDictConverter() |
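

# Illustrative sketch (not executed at import; actually calling it instantiates the full
# hard-coded 32-layer, 3072-dim Phi-3 backbone, which is memory-heavy): OmniGenTransformer
# takes no constructor arguments, and its state-dict converter is an identity mapping.
def _omnigen_transformer_demo():
    model = OmniGenTransformer()
    converter = OmniGenTransformer.state_dict_converter()
    state_dict = converter.from_civitai(model.state_dict())   # returned unchanged
    return model, state_dict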
|
|
|
|
|
|
|
class OmniGenTransformerStateDictConverter: |
|
def __init__(self): |
|
pass |
|
|
|
def from_diffusers(self, state_dict): |
|
return state_dict |
|
|
|
def from_civitai(self, state_dict): |
|
return state_dict |
|
|