Upload folder using huggingface_hub
- .gitattributes +2 -0
- README.md +95 -0
- added_tokens.json +33 -0
- assets/teaser.jpg +3 -0
- config.json +53 -0
- model.safetensors.index.json +698 -0
- models/config.py +45 -0
- models/gen_pipeline.py +398 -0
- models/heads.py +283 -0
- models/llama_model.py +568 -0
- models/nextstep_model.py +553 -0
- pytorch-model-00001-of-00004.safetensors +3 -0
- pytorch-model-00002-of-00004.safetensors +3 -0
- pytorch-model-00003-of-00004.safetensors +3 -0
- pytorch-model-00004-of-00004.safetensors +3 -0
- requirements.txt +14 -0
- special_tokens_map.json +26 -0
- tokenizer.json +3 -0
- tokenizer_config.json +276 -0
- utils/aspect_ratio.py +107 -0
- utils/compile_utils.py +122 -0
- utils/image_utils.py +314 -0
- utils/misc.py +51 -0
- utils/model_utils.py +128 -0
- vae/checkpoint.pt +3 -0
- vae/config.json +14 -0
- vae/nextstep_ae.py +494 -0
- vocab.json +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/teaser.jpg filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,95 @@
---
license: apache-2.0
---

## NextStep-1: Toward Autoregressive Image Generation with Continuous Tokens at Scale

[Homepage](https://stepfun.ai/research/en/nextstep-1)
| [GitHub](https://github.com/stepfun-ai/NextStep-1)
| [Paper](https://arxiv.org/abs/2508.10711)

We introduce **NextStep-1**, a 14B autoregressive model paired with a 157M flow matching head, trained on discrete text tokens and continuous image tokens with a next-token prediction objective.
**NextStep-1** achieves state-of-the-art performance among autoregressive models on text-to-image generation tasks, exhibiting strong capabilities in high-fidelity image synthesis.

<div align='center'>
<img src="assets/teaser.jpg" class="interpolation-image" alt="arch." width="100%" />
</div>

## Environment Setup

To avoid potential errors when loading and running the model, we recommend the following setup:

```shell
conda create -n nextstep python=3.11 -y
conda activate nextstep

pip install uv  # optional

# download requirements.txt from this repo first
uv pip install -r requirements.txt

# diffusers==0.34.0
# einops==0.8.1
# gradio==5.42.0
# loguru==0.7.3
# numpy==1.26.4
# omegaconf==2.3.0
# Pillow==11.0.0
# Requests==2.32.4
# safetensors==0.5.3
# tabulate==0.9.0
# torch==2.5.1
# torchvision==0.20.1
# tqdm==4.67.1
# transformers==4.55.0
```
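If your environment was set up differently, a quick check that the pinned versions are actually installed can save debugging time later. This is a minimal sketch using only the standard library; the package list simply mirrors the requirements above:

```python
# Minimal environment check: compares a few installed packages against the
# versions pinned in requirements.txt.
from importlib.metadata import version, PackageNotFoundError

pinned = {
    "torch": "2.5.1",
    "torchvision": "0.20.1",
    "transformers": "4.55.0",
    "diffusers": "0.34.0",
    "safetensors": "0.5.3",
}

for name, expected in pinned.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        installed = "not installed"
    status = "OK" if installed == expected else "MISMATCH"
    print(f"{name:<14} expected {expected:<8} found {installed:<14} {status}")
```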

## Usage

```python
import torch
from transformers import AutoTokenizer, AutoModel
from models.gen_pipeline import NextStepPipeline

HF_HUB = "stepfun-ai/NextStep-1-Large-Pretrain"

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=True, trust_remote_code=True)
model = AutoModel.from_pretrained(HF_HUB, local_files_only=True, trust_remote_code=True)
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device="cuda", dtype=torch.bfloat16)

# set prompts
positive_prompt = "masterpiece, film grained, best quality."
negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
example_prompt = "A realistic photograph of a wall with \"NextStep-1.1 is coming\" prominently displayed"

# generate image from text
IMG_SIZE = 512
image = pipeline.generate_image(
    example_prompt,
    hw=(IMG_SIZE, IMG_SIZE),
    num_images_per_caption=1,
    positive_prompt=positive_prompt,
    negative_prompt=negative_prompt,
    cfg=7.5,
    cfg_img=1.0,
    cfg_schedule="constant",
    use_norm=False,
    num_sampling_steps=28,
    timesteps_shift=1.0,
    seed=3407,
)[0]
image.save("./assets/output.jpg")
```
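The same `generate_image` call can be reused for several captions in a row. The sketch below continues from the snippet above and only varies the prompt and seed; the example captions are placeholders:

```python
# Batching sketch: reuses the pipeline, prompts, and generate_image() settings
# from the snippet above, changing only the caption and seed per image.
prompts = [
    "A watercolor painting of a lighthouse at dawn",
    "A close-up photo of a dew-covered spider web",
]

for i, prompt in enumerate(prompts):
    image = pipeline.generate_image(
        prompt,
        hw=(IMG_SIZE, IMG_SIZE),
        num_images_per_caption=1,
        positive_prompt=positive_prompt,
        negative_prompt=negative_prompt,
        cfg=7.5,
        cfg_img=1.0,
        cfg_schedule="constant",
        use_norm=False,
        num_sampling_steps=28,
        timesteps_shift=1.0,
        seed=3407 + i,  # vary the seed per caption
    )[0]
    image.save(f"./assets/output_{i}.jpg")
```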

## Citation

If you find NextStep useful for your research and applications, please consider starring this repository and citing:

```bibtex
@misc{nextstep_1,
  title={NextStep-1: Toward Autoregressive Image Generation with Continuous Tokens at Scale},
  author={NextStep Team},
  year={2025},
  url={https://github.com/stepfun-ai/NextStep-1},
}
```
added_tokens.json
ADDED
@@ -0,0 +1,33 @@
{
  "</tool_call>": 151658,
  "<tool_call>": 151657,
  "<|begin_of_image|>": 151667,
  "<|begin_of_prompt_refinement|>": 151670,
  "<|begin_of_thinking|>": 151672,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|end_of_image|>": 151668,
  "<|end_of_prompt_refinement|>": 151671,
  "<|end_of_thinking|>": 151673,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_area|>": 151666,
  "<|image_pad|>": 151655,
  "<|image_placeholder|>": 151669,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|video_pad|>": 151656,
  "<|vision_end|>": 151653,
  "<|vision_pad|>": 151654,
  "<|vision_start|>": 151652,
  "[PAD]": 151665
}
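For reference, the ids above can be cross-checked against the tokenizer shipped in this repo. This is a minimal sketch reusing the `AutoTokenizer` call from the README; the chosen tokens are just a sample of the entries listed in this file:

```python
from transformers import AutoTokenizer

# Assumes the repo has been downloaded locally, as in the Usage example.
tokenizer = AutoTokenizer.from_pretrained(
    "stepfun-ai/NextStep-1-Large-Pretrain", local_files_only=True, trust_remote_code=True
)

# Spot-check a few of the added image tokens against added_tokens.json.
for token in ["<|begin_of_image|>", "<|image_placeholder|>", "<|end_of_image|>", "[PAD]"]:
    print(token, tokenizer.convert_tokens_to_ids(token))
# Expected ids from the file above: 151667, 151669, 151668, 151665
```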
assets/teaser.jpg
ADDED
Git LFS Details
config.json
ADDED
@@ -0,0 +1,53 @@
{
  "_attn_implementation_autoset": true,
  "architectures": [
    "LlamaForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "models/config.NextStepConfig",
    "AutoModel": "models/nextstep_model.NextStep"
  },
  "attention_bias": true,
  "attention_dropout": 0.0,
  "base_image_grid_size": 32,
  "boi": 151667,
  "bos_token_id": 151643,
  "eoi": 151668,
  "eos_token_id": 151643,
  "fm_head_batch_mul": 4,
  "fm_head_dim": 1536,
  "fm_head_layers": 12,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 5120,
  "im_loss_weight": 1.0,
  "image_placeholder_id": 151669,
  "initializer_range": 0.02,
  "intermediate_size": 13824,
  "latent_channels": 16,
  "latent_patch_size": 2,
  "latent_size": 64,
  "lm_loss_weight": 0.01,
  "max_position_embeddings": 131072,
  "max_window_layers": 48,
  "mlp_bias": false,
  "model_type": "nextstep",
  "num_attention_heads": 40,
  "num_hidden_layers": 48,
  "num_key_value_heads": 8,
  "o_attention_bias": false,
  "pad_token_id_added": 151665,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": 131072,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.55.0",
  "use_cache": true,
  "use_gen_pos_embed": false,
  "use_sliding_window": false,
  "vae_name_or_path": "./vae",
  "vocab_size": 152064
}
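Because `auto_map` points `AutoConfig` at `models/config.NextStepConfig`, the configuration can also be loaded on its own with `trust_remote_code=True`. A minimal sketch; the printed fields are simply a sample of the values listed above:

```python
from transformers import AutoConfig

# trust_remote_code is required because auto_map points at classes shipped
# in this repo (models/config.py and models/nextstep_model.py).
config = AutoConfig.from_pretrained(
    "stepfun-ai/NextStep-1-Large-Pretrain", local_files_only=True, trust_remote_code=True
)

# A few of the fields defined in config.json above.
print(config.model_type)           # "nextstep"
print(config.hidden_size)          # 5120
print(config.num_hidden_layers)    # 48
print(config.fm_head_dim)          # 1536
print(config.latent_channels)      # 16
```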
model.safetensors.index.json
ADDED
@@ -0,0 +1,698 @@
metadata.total_size: 29907628160; the weight_map assigns embed_tokens.weight, the image_head.* and image_in_projector/image_out_projector parameters, and the per-layer transformer weights (layers.*.self_attn, layers.*.mlp, and layer norms) across the four shards pytorch-model-00001-of-00004.safetensors through pytorch-model-00004-of-00004.safetensors.
|
536 |
+
"layers.4.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
537 |
+
"layers.4.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
538 |
+
"layers.4.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
539 |
+
"layers.40.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
540 |
+
"layers.40.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
541 |
+
"layers.40.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
542 |
+
"layers.40.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
543 |
+
"layers.40.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
544 |
+
"layers.40.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
545 |
+
"layers.40.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
546 |
+
"layers.40.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
547 |
+
"layers.40.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
548 |
+
"layers.40.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
549 |
+
"layers.40.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
|
550 |
+
"layers.40.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
551 |
+
"layers.41.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
|
552 |
+
"layers.41.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
553 |
+
"layers.41.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
554 |
+
"layers.41.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
555 |
+
"layers.41.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
556 |
+
"layers.41.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
557 |
+
"layers.41.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
558 |
+
"layers.41.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
|
559 |
+
"layers.41.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
560 |
+
"layers.41.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
561 |
+
"layers.41.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
562 |
+
"layers.41.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
563 |
+
"layers.42.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
564 |
+
"layers.42.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
565 |
+
"layers.42.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
566 |
+
"layers.42.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
567 |
+
"layers.42.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
568 |
+
"layers.42.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
569 |
+
"layers.42.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
570 |
+
"layers.42.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
571 |
+
"layers.42.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
572 |
+
"layers.42.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
573 |
+
"layers.42.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
574 |
+
"layers.42.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
575 |
+
"layers.43.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
576 |
+
"layers.43.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
577 |
+
"layers.43.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
578 |
+
"layers.43.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
579 |
+
"layers.43.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
580 |
+
"layers.43.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
581 |
+
"layers.43.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
582 |
+
"layers.43.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
583 |
+
"layers.43.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
584 |
+
"layers.43.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
585 |
+
"layers.43.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
586 |
+
"layers.43.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
587 |
+
"layers.44.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
588 |
+
"layers.44.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
589 |
+
"layers.44.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
590 |
+
"layers.44.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
591 |
+
"layers.44.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
592 |
+
"layers.44.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
593 |
+
"layers.44.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
594 |
+
"layers.44.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
595 |
+
"layers.44.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
596 |
+
"layers.44.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
597 |
+
"layers.44.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
598 |
+
"layers.44.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
599 |
+
"layers.45.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
600 |
+
"layers.45.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
601 |
+
"layers.45.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
602 |
+
"layers.45.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
603 |
+
"layers.45.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
604 |
+
"layers.45.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
605 |
+
"layers.45.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
606 |
+
"layers.45.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
607 |
+
"layers.45.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
608 |
+
"layers.45.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
609 |
+
"layers.45.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
610 |
+
"layers.45.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
611 |
+
"layers.46.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
612 |
+
"layers.46.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
613 |
+
"layers.46.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
614 |
+
"layers.46.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
615 |
+
"layers.46.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
616 |
+
"layers.46.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
617 |
+
"layers.46.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
618 |
+
"layers.46.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
619 |
+
"layers.46.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
620 |
+
"layers.46.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
621 |
+
"layers.46.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
622 |
+
"layers.46.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
623 |
+
"layers.47.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
624 |
+
"layers.47.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
625 |
+
"layers.47.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
626 |
+
"layers.47.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
627 |
+
"layers.47.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
628 |
+
"layers.47.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
629 |
+
"layers.47.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
630 |
+
"layers.47.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
631 |
+
"layers.47.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
632 |
+
"layers.47.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
633 |
+
"layers.47.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
634 |
+
"layers.47.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
635 |
+
"layers.5.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
636 |
+
"layers.5.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
637 |
+
"layers.5.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
638 |
+
"layers.5.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
639 |
+
"layers.5.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
640 |
+
"layers.5.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
641 |
+
"layers.5.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
642 |
+
"layers.5.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
643 |
+
"layers.5.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
644 |
+
"layers.5.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
645 |
+
"layers.5.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
646 |
+
"layers.5.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
647 |
+
"layers.6.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
648 |
+
"layers.6.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
649 |
+
"layers.6.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
650 |
+
"layers.6.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
651 |
+
"layers.6.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
652 |
+
"layers.6.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
653 |
+
"layers.6.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
654 |
+
"layers.6.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
655 |
+
"layers.6.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
656 |
+
"layers.6.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
657 |
+
"layers.6.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
658 |
+
"layers.6.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
659 |
+
"layers.7.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
660 |
+
"layers.7.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
661 |
+
"layers.7.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
662 |
+
"layers.7.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
663 |
+
"layers.7.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
664 |
+
"layers.7.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
665 |
+
"layers.7.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
666 |
+
"layers.7.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
667 |
+
"layers.7.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
668 |
+
"layers.7.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
669 |
+
"layers.7.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
670 |
+
"layers.7.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
671 |
+
"layers.8.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
672 |
+
"layers.8.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
673 |
+
"layers.8.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
674 |
+
"layers.8.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
675 |
+
"layers.8.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
676 |
+
"layers.8.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
677 |
+
"layers.8.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
678 |
+
"layers.8.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
679 |
+
"layers.8.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
680 |
+
"layers.8.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
681 |
+
"layers.8.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
682 |
+
"layers.8.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
683 |
+
"layers.9.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
684 |
+
"layers.9.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
685 |
+
"layers.9.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
686 |
+
"layers.9.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
687 |
+
"layers.9.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
|
688 |
+
"layers.9.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
689 |
+
"layers.9.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
690 |
+
"layers.9.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
691 |
+
"layers.9.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
692 |
+
"layers.9.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
693 |
+
"layers.9.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
|
694 |
+
"layers.9.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
|
695 |
+
"lm_head.weight": "pytorch-model-00003-of-00004.safetensors",
|
696 |
+
"norm.weight": "pytorch-model-00003-of-00004.safetensors"
|
697 |
+
}
|
698 |
+
}
|
models/config.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
2 |
+
|
3 |
+
class NextStepConfig(LlamaConfig):
|
4 |
+
|
5 |
+
model_type = "nextstep"
|
6 |
+
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
vae_name_or_path: str | None = None,
|
10 |
+
latent_size: int = 32,
|
11 |
+
latent_patch_size: int = 2,
|
12 |
+
latent_channels: int = 16,
|
13 |
+
boi: int | None = None,
|
14 |
+
eoi: int | None = None,
|
15 |
+
image_placeholder_id: int | None = None,
|
16 |
+
pad_token_id_added: int | None = None,
|
17 |
+
lm_loss_weight: float = 0.01,
|
18 |
+
im_loss_weight: float = 1.0,
|
19 |
+
fm_head_dim: int = 1536,
|
20 |
+
fm_head_layers: int = 12,
|
21 |
+
fm_head_batch_mul: int = 4,
|
22 |
+
o_attention_bias: bool | None = None,
|
23 |
+
**kwargs,
|
24 |
+
):
|
25 |
+
super().__init__(**kwargs)
|
26 |
+
|
27 |
+
self.vae_name_or_path = vae_name_or_path
|
28 |
+
|
29 |
+
self.latent_size = latent_size
|
30 |
+
self.latent_patch_size = latent_patch_size
|
31 |
+
self.latent_channels = latent_channels
|
32 |
+
|
33 |
+
self.boi = boi
|
34 |
+
self.eoi = eoi
|
35 |
+
self.image_placeholder_id = image_placeholder_id
|
36 |
+
self.pad_token_id_added = pad_token_id_added
|
37 |
+
|
38 |
+
self.lm_loss_weight = lm_loss_weight
|
39 |
+
self.im_loss_weight = im_loss_weight
|
40 |
+
|
41 |
+
self.fm_head_dim = fm_head_dim
|
42 |
+
self.fm_head_layers = fm_head_layers
|
43 |
+
self.fm_head_batch_mul = fm_head_batch_mul
|
44 |
+
|
45 |
+
self.o_attention_bias = self.attention_bias if o_attention_bias is None else o_attention_bias
|
models/gen_pipeline.py
ADDED
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import copy
|
3 |
+
from typing import Literal
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
from tqdm.auto import tqdm
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torchvision.transforms as transforms
|
11 |
+
|
12 |
+
from transformers import AutoTokenizer
|
13 |
+
from transformers.cache_utils import Cache, StaticCache
|
14 |
+
|
15 |
+
from models.nextstep_model import NextStep
|
16 |
+
from vae.nextstep_ae import AutoencoderKL
|
17 |
+
from utils.image_utils import to_pil
|
18 |
+
from utils.model_utils import layer_norm
|
19 |
+
from utils.compile_utils import compile_manager
|
20 |
+
from utils.misc import set_seed
|
21 |
+
|
22 |
+
DEFAULT_IMAGE_AREA_TOKEN = "<|image_area|>"
|
23 |
+
|
24 |
+
def hw2str(h: int, w: int) -> str:
|
25 |
+
return f"{h}*{w}"
|
26 |
+
|
27 |
+
|
28 |
+
class NextStepPipeline:
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
model_name_or_path: str | None = None,
|
32 |
+
vae_name_or_path: str | None = None,
|
33 |
+
tokenizer: AutoTokenizer | None = None,
|
34 |
+
model: nn.Module | None = None,
|
35 |
+
vae: AutoencoderKL | None = None,
|
36 |
+
):
|
37 |
+
if model is not None:
|
38 |
+
self.tokenizer = copy.deepcopy(tokenizer)
|
39 |
+
self.tokenizer.padding_side = "left"
|
40 |
+
self.model = model
|
41 |
+
elif model_name_or_path is not None:
|
42 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
43 |
+
model_name_or_path,
|
44 |
+
local_files_only=True,
|
45 |
+
padding_side="left",
|
46 |
+
use_fast=True,
|
47 |
+
)
|
48 |
+
self.model: NextStep = NextStep.from_pretrained(model_name_or_path, local_files_only=True)
|
49 |
+
else:
|
50 |
+
raise ValueError("model or model_name_or_path is required")
|
51 |
+
|
52 |
+
self.tokenizer.add_eos_token = False
|
53 |
+
if vae_name_or_path is None:
|
54 |
+
vae_name_or_path = getattr(self.model.config, "vae_name_or_path", None)
|
55 |
+
if vae is not None:
|
56 |
+
self.vae = vae
|
57 |
+
elif vae_name_or_path is not None:
|
58 |
+
self.vae = AutoencoderKL.from_pretrained(vae_name_or_path)
|
59 |
+
else:
|
60 |
+
raise ValueError("vae or vae_name_or_path is required")
|
61 |
+
|
62 |
+
self.model.eval()
|
63 |
+
self.vae.eval()
|
64 |
+
|
65 |
+
vae_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
66 |
+
self.down_factor = vae_factor * self.model.config.latent_patch_size
|
67 |
+
self.shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
|
68 |
+
self.scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
|
69 |
+
|
70 |
+
self.boi = self.model.config.boi
|
71 |
+
self.eoi = self.model.config.eoi
|
72 |
+
|
73 |
+
self.image_placeholder_id = self.model.config.image_placeholder_id
|
74 |
+
self.pil2tensor = transforms.Compose(
|
75 |
+
[
|
76 |
+
transforms.ToTensor(),
|
77 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
78 |
+
]
|
79 |
+
)
|
80 |
+
self.__device = self.model.device
|
81 |
+
self.__dtype = self.model.dtype
|
82 |
+
self.to(self.device, self.dtype)
|
83 |
+
|
84 |
+
@property
|
85 |
+
def device(self):
|
86 |
+
return self.__device
|
87 |
+
|
88 |
+
@property
|
89 |
+
def device_type(self):
|
90 |
+
if isinstance(self.__device, str):
|
91 |
+
return self.__device
|
92 |
+
return self.__device.type
|
93 |
+
|
94 |
+
@property
|
95 |
+
def dtype(self):
|
96 |
+
return self.__dtype
|
97 |
+
|
98 |
+
def to(self, device: str | None = None, dtype: torch.dtype | None = None):
|
99 |
+
if device is not None:
|
100 |
+
self.__device = device
|
101 |
+
if dtype is not None:
|
102 |
+
self.__dtype = dtype
|
103 |
+
self.model.to(self.__device, dtype=self.__dtype)
|
104 |
+
self.vae.to(self.__device, dtype=self.__dtype)
|
105 |
+
return self
|
106 |
+
|
107 |
+
def _image_str(self, hw: tuple[int, int] = (256, 256)):
|
108 |
+
latent_hw = (hw[0] // self.down_factor, hw[1] // self.down_factor)
|
109 |
+
image_ids = [self.boi] + [self.image_placeholder_id] * (latent_hw[0] * latent_hw[1]) + [self.eoi]
|
110 |
+
image_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(*latent_hw) + self.tokenizer.decode(image_ids)
|
111 |
+
return image_str
|
112 |
+
|
113 |
+
def _check_input(
|
114 |
+
self, captions: str | list[str], images: Image.Image | list[Image.Image] | None
|
115 |
+
) -> tuple[list[str], list[Image.Image] | None]:
|
116 |
+
if not isinstance(captions, list):
|
117 |
+
captions = [captions]
|
118 |
+
if images is not None:
|
119 |
+
if not isinstance(images, list):
|
120 |
+
images = [images]
|
121 |
+
# Validate image count matches <image> tokens in captions
|
122 |
+
image_token_count = 0
|
123 |
+
for caption in captions:
|
124 |
+
num_image_token = len(re.findall(r"<image>", caption))
|
125 |
+
assert num_image_token == 1, f"Caption `{caption}` has {num_image_token} image tokens, but only 1 is allowed."
|
126 |
+
image_token_count += num_image_token
|
127 |
+
if image_token_count != len(images):
|
128 |
+
raise ValueError(
|
129 |
+
f"Number of images ({len(images)}) does not match number of image tokens ({image_token_count}).\n"
|
130 |
+
f"Captions: {captions}"
|
131 |
+
)
|
132 |
+
hws = [(image.size[1], image.size[0]) for image in images]
|
133 |
+
# Replace <image> tokens sequentially with corresponding image_str based on hw
|
134 |
+
processed_captions = []
|
135 |
+
image_idx = 0
|
136 |
+
for caption in captions:
|
137 |
+
# Process each caption
|
138 |
+
processed_caption = caption
|
139 |
+
num_image_tokens = processed_caption.count("<image>")
|
140 |
+
# Replace each <image> token in order
|
141 |
+
for _ in range(num_image_tokens):
|
142 |
+
processed_caption = processed_caption.replace("<image>", self._image_str(hws[image_idx]), 1)
|
143 |
+
image_idx += 1
|
144 |
+
processed_captions.append(processed_caption)
|
145 |
+
captions = processed_captions
|
146 |
+
return captions, images
|
147 |
+
|
148 |
+
def _build_captions(
|
149 |
+
self,
|
150 |
+
captions: str | list[str],
|
151 |
+
images: list[Image.Image] | None = None,
|
152 |
+
num_images_per_caption: int = 1,
|
153 |
+
positive_prompt: str | None = None,
|
154 |
+
negative_prompt: str | None = None,
|
155 |
+
cfg: float = 1.0,
|
156 |
+
cfg_img: float = 1.0,
|
157 |
+
):
|
158 |
+
# 1. repeat captions and images
|
159 |
+
if not isinstance(captions, list):
|
160 |
+
captions = [captions]
|
161 |
+
|
162 |
+
captions = [caption for caption in captions for _ in range(num_images_per_caption)]
|
163 |
+
if images is not None:
|
164 |
+
images = [image for image in images for _ in range(num_images_per_caption)]
|
165 |
+
|
166 |
+
# 2. add positive prompt
|
167 |
+
if positive_prompt is not None and positive_prompt != "":
|
168 |
+
captions = [f"{caption} {positive_prompt}" for caption in captions]
|
169 |
+
|
170 |
+
# 3. add negative prompt
|
171 |
+
if negative_prompt is None:
|
172 |
+
negative_prompt = ""
|
173 |
+
|
174 |
+
num_samples = len(captions)
|
175 |
+
if cfg != 1.0 and cfg_img != 1.0: # use both image and text CFG
|
176 |
+
w, h = images[0].size
|
177 |
+
captions = (
|
178 |
+
captions + [self._image_str((h, w)) + negative_prompt] * num_samples
|
179 |
+
)
|
180 |
+
images = images + images
|
181 |
+
captions = captions + [negative_prompt] * num_samples
|
182 |
+
elif cfg != 1.0 and cfg_img == 1.0: # use text CFG
|
183 |
+
captions = captions + [negative_prompt] * num_samples
|
184 |
+
elif cfg == 1.0 and cfg_img == 1.0:
|
185 |
+
pass
|
186 |
+
|
187 |
+
return captions, images
|
188 |
+
|
189 |
+
def _add_prefix_ids(self, hw: tuple[int, int], input_ids: torch.Tensor, attention_mask: torch.Tensor):
|
190 |
+
prefix_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(hw[0] // self.down_factor, hw[1] // self.down_factor)
|
191 |
+
prefix_output = self.tokenizer(
|
192 |
+
prefix_str,
|
193 |
+
truncation=False,
|
194 |
+
add_special_tokens=True,
|
195 |
+
return_tensors="pt"
|
196 |
+
)
|
197 |
+
prefix_input_ids = prefix_output.input_ids.to(input_ids.device, dtype=input_ids.dtype)
|
198 |
+
prefix_attention_mask = prefix_output.attention_mask.to(attention_mask.device, dtype=attention_mask.dtype)
|
199 |
+
# remove bos token
|
200 |
+
if self.tokenizer.bos_token is not None:
|
201 |
+
prefix_input_ids = prefix_input_ids[:, 1:]
|
202 |
+
prefix_attention_mask = prefix_attention_mask[:, 1:]
|
203 |
+
# add boi token
|
204 |
+
prefix_input_ids = torch.cat(
|
205 |
+
[
|
206 |
+
prefix_input_ids,
|
207 |
+
prefix_input_ids.new_tensor([self.model.config.boi]).unsqueeze(0),
|
208 |
+
],
|
209 |
+
dim=1,
|
210 |
+
)
|
211 |
+
prefix_attention_mask = torch.cat(
|
212 |
+
[
|
213 |
+
prefix_attention_mask,
|
214 |
+
prefix_attention_mask.new_ones((prefix_attention_mask.shape[0], 1)),
|
215 |
+
],
|
216 |
+
dim=1,
|
217 |
+
)
|
218 |
+
bsz = input_ids.shape[0]
|
219 |
+
input_ids = torch.cat([input_ids, prefix_input_ids.expand(bsz, -1)], dim=1)
|
220 |
+
attention_mask = torch.cat([attention_mask, prefix_attention_mask.expand(bsz, -1)], dim=1)
|
221 |
+
|
222 |
+
return input_ids, attention_mask
|
223 |
+
|
224 |
+
@torch.no_grad()
|
225 |
+
def decoding(
|
226 |
+
self,
|
227 |
+
c: torch.Tensor,
|
228 |
+
attention_mask: torch.Tensor,
|
229 |
+
past_key_values: Cache,
|
230 |
+
max_new_len: int,
|
231 |
+
num_images_per_caption: int,
|
232 |
+
use_norm: bool = False,
|
233 |
+
cfg: float = 1.0,
|
234 |
+
cfg_img: float = 1.0,
|
235 |
+
cfg_schedule: Literal["linear", "constant"] = "constant",
|
236 |
+
timesteps_shift: float = 1.0,
|
237 |
+
num_sampling_steps: int = 20,
|
238 |
+
progress: bool = True,
|
239 |
+
hw: tuple[int, int] = (256, 256),
|
240 |
+
step: int = 0,
|
241 |
+
):
|
242 |
+
indices = list(range(max_new_len))
|
243 |
+
indices = tqdm(indices, unit="tokens") if progress else indices
|
244 |
+
tokens = None
|
245 |
+
for step in indices:
|
246 |
+
# cfg schedule follow Muse
|
247 |
+
if cfg_schedule == "linear":
|
248 |
+
tokens_len = 0 if tokens is None else tokens.shape[1]
|
249 |
+
cfg_iter = max(cfg / 2, 1 + (cfg - 1) * tokens_len / max_new_len)
|
250 |
+
cfg_img_iter = max(cfg_img / 2, 1 + (cfg_img - 1) * tokens_len / max_new_len)
|
251 |
+
elif cfg_schedule == "constant":
|
252 |
+
cfg_iter = cfg
|
253 |
+
cfg_img_iter = cfg_img
|
254 |
+
else:
|
255 |
+
raise NotImplementedError
|
256 |
+
|
257 |
+
c = self.model.image_out_projector(c)
|
258 |
+
token_sampled = self.model.image_head.sample(
|
259 |
+
c=c.squeeze(1),
|
260 |
+
cfg=cfg_iter,
|
261 |
+
cfg_img=cfg_img_iter,
|
262 |
+
timesteps_shift=timesteps_shift,
|
263 |
+
num_sampling_steps=num_sampling_steps,
|
264 |
+
noise_repeat=num_images_per_caption,
|
265 |
+
)
|
266 |
+
|
267 |
+
if use_norm:
|
268 |
+
token_sampled = layer_norm(token_sampled, normalized_shape=token_sampled.size()[1:])
|
269 |
+
if tokens is not None:
|
270 |
+
tokens = torch.cat([tokens, token_sampled.unsqueeze(1)], dim=1)
|
271 |
+
else:
|
272 |
+
tokens = token_sampled.unsqueeze(1)
|
273 |
+
|
274 |
+
cur_inputs_embeds = self.model.image_in_projector(tokens[:, -1:])
|
275 |
+
if cfg != 1.0 and cfg_img == 1.0:
|
276 |
+
cur_inputs_embeds = torch.cat([cur_inputs_embeds, cur_inputs_embeds], dim=0)
|
277 |
+
elif cfg != 1.0 and cfg_img != 1.0:
|
278 |
+
cur_inputs_embeds = torch.cat([cur_inputs_embeds, cur_inputs_embeds, cur_inputs_embeds], dim=0)
|
279 |
+
|
280 |
+
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
|
281 |
+
outputs = self.model.forward_model(
|
282 |
+
inputs_embeds=cur_inputs_embeds,
|
283 |
+
attention_mask=attention_mask,
|
284 |
+
past_key_values=past_key_values,
|
285 |
+
use_cache=True,
|
286 |
+
)
|
287 |
+
past_key_values = outputs.past_key_values
|
288 |
+
c = outputs.last_hidden_state[:, -1:]
|
289 |
+
if self.model.config.use_gen_pos_embed:
|
290 |
+
c = c + self.model.gen_pos_embed_with_ar(hw[0], hw[1])[:, step + 1 : step + 2, :]
|
291 |
+
|
292 |
+
return tokens
|
293 |
+
|
294 |
+
@torch.no_grad()
|
295 |
+
def generate_image(
|
296 |
+
self,
|
297 |
+
captions: str | list[str],
|
298 |
+
images: list[Image.Image] | None = None,
|
299 |
+
num_images_per_caption: int = 1,
|
300 |
+
positive_prompt: str | None = None,
|
301 |
+
negative_prompt: str | None = None,
|
302 |
+
hw: tuple[int, int] = (256, 256),
|
303 |
+
use_norm: bool = False,
|
304 |
+
cfg: float = 1.0,
|
305 |
+
cfg_img: float = 1.0,
|
306 |
+
cfg_schedule: Literal["linear", "constant"] = "constant",
|
307 |
+
num_sampling_steps: int = 20,
|
308 |
+
timesteps_shift: float = 1.0,
|
309 |
+
seed: int = 42,
|
310 |
+
progress: bool = True,
|
311 |
+
) -> list[Image.Image]:
|
312 |
+
# 0. set seed
|
313 |
+
if seed is not None:
|
314 |
+
set_seed(seed)
|
315 |
+
|
316 |
+
# 1. check input
|
317 |
+
captions, images = self._check_input(captions, images)
|
318 |
+
|
319 |
+
# 2. build captions
|
320 |
+
captions, images = self._build_captions(
|
321 |
+
captions, images, num_images_per_caption, positive_prompt, negative_prompt, cfg, cfg_img
|
322 |
+
)
|
323 |
+
|
324 |
+
# 3. encode images
|
325 |
+
# `images` must be processed by `process_images` before calling this function
|
326 |
+
latents = None
|
327 |
+
if images is not None:
|
328 |
+
pixel_values = [self.pil2tensor(image) for image in images]
|
329 |
+
pixel_values = torch.stack(pixel_values).to(self.device)
|
330 |
+
with compile_manager.compile_disabled():
|
331 |
+
posterior = self.vae.encode(pixel_values.to(self.vae.dtype)).latent_dist
|
332 |
+
latents = (posterior.sample() - self.shift_factor) * self.scaling_factor
|
333 |
+
captions = [self.tokenizer.bos_token + caption if self.tokenizer.bos_token is not None else caption for caption in captions]
|
334 |
+
|
335 |
+
# 4. tokenize caption & add prefix ids
|
336 |
+
output = self.tokenizer(
|
337 |
+
captions,
|
338 |
+
padding="longest",
|
339 |
+
truncation=False,
|
340 |
+
add_special_tokens=True,
|
341 |
+
return_tensors="pt",
|
342 |
+
padding_side="left"
|
343 |
+
)
|
344 |
+
input_ids = output.input_ids.to(self.device)
|
345 |
+
attention_mask = output.attention_mask.to(self.device)
|
346 |
+
input_ids, attention_mask = self._add_prefix_ids(hw, input_ids, attention_mask)
|
347 |
+
|
348 |
+
# 5. LLM prefill
|
349 |
+
max_new_len = (hw[0] // self.down_factor) * (hw[1] // self.down_factor)
|
350 |
+
max_cache_len = input_ids.shape[1] + max_new_len
|
351 |
+
past_key_values = StaticCache(
|
352 |
+
config=self.model.config,
|
353 |
+
max_batch_size=input_ids.shape[0],
|
354 |
+
max_cache_len=max_cache_len,
|
355 |
+
device=self.device,
|
356 |
+
dtype=self.dtype,
|
357 |
+
)
|
358 |
+
inputs_embeds = self.model.prepare_inputs_embeds(input_ids, latents)
|
359 |
+
with compile_manager.compile_disabled():
|
360 |
+
outputs = self.model.forward_model(
|
361 |
+
inputs_embeds=inputs_embeds,
|
362 |
+
attention_mask=attention_mask,
|
363 |
+
past_key_values=past_key_values,
|
364 |
+
use_cache=True,
|
365 |
+
)
|
366 |
+
past_key_values = outputs.past_key_values
|
367 |
+
c = outputs.last_hidden_state[:, -1:]
|
368 |
+
if self.model.config.use_gen_pos_embed:
|
369 |
+
c = c + self.model.gen_pos_embed_with_ar(hw[0], hw[1])[:, 0:1, :]
|
370 |
+
|
371 |
+
# 6. decoding
|
372 |
+
tokens = self.decoding(
|
373 |
+
c=c,
|
374 |
+
attention_mask=attention_mask,
|
375 |
+
past_key_values=past_key_values,
|
376 |
+
max_new_len=max_new_len,
|
377 |
+
num_images_per_caption=num_images_per_caption,
|
378 |
+
use_norm=use_norm,
|
379 |
+
cfg=cfg,
|
380 |
+
cfg_img=cfg_img,
|
381 |
+
cfg_schedule=cfg_schedule,
|
382 |
+
timesteps_shift=timesteps_shift,
|
383 |
+
num_sampling_steps=num_sampling_steps,
|
384 |
+
progress=progress,
|
385 |
+
hw=hw,
|
386 |
+
)
|
387 |
+
|
388 |
+
# 7. unpatchify
|
389 |
+
latents = self.model.unpatchify(tokens)
|
390 |
+
latents = (latents / self.scaling_factor) + self.shift_factor
|
391 |
+
|
392 |
+
# 8. decode latents
|
393 |
+
with compile_manager.compile_disabled():
|
394 |
+
sampled_images = self.vae.decode(latents.to(self.vae.dtype)).sample
|
395 |
+
sampled_images = sampled_images.detach().cpu().to(torch.float32)
|
396 |
+
pil_images = [to_pil(img) for img in sampled_images]
|
397 |
+
|
398 |
+
return pil_images
|
models/heads.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.utils.checkpoint import checkpoint
|
6 |
+
|
7 |
+
from transformers.activations import ACT2FN
|
8 |
+
|
9 |
+
from models.config import LlamaConfig
|
10 |
+
from utils.misc import LargeInt
|
11 |
+
from utils.model_utils import expand_t, randn_tensor
|
12 |
+
from utils.compile_utils import smart_compile
|
13 |
+
|
14 |
+
|
15 |
+
class LlamaMLP(nn.Module):
|
16 |
+
def __init__(self, config: LlamaConfig):
|
17 |
+
super().__init__()
|
18 |
+
self.config = config
|
19 |
+
self.hidden_size = config.hidden_size
|
20 |
+
self.intermediate_size = config.intermediate_size
|
21 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
22 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
23 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
24 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
28 |
+
return down_proj
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def modulate(x, shift, scale=None):
|
34 |
+
if shift is None:
|
35 |
+
return x * (1 + scale)
|
36 |
+
return x * (1 + scale) + shift
|
37 |
+
|
38 |
+
|
39 |
+
class ResBlock(nn.Module):
|
40 |
+
def __init__(self, channels, mlp_ratio=1.0):
|
41 |
+
super().__init__()
|
42 |
+
self.channels = channels
|
43 |
+
self.intermediate_size = int(channels * mlp_ratio)
|
44 |
+
|
45 |
+
self.in_ln = nn.LayerNorm(self.channels, eps=1e-6)
|
46 |
+
self.mlp = nn.Sequential(
|
47 |
+
nn.Linear(self.channels, self.intermediate_size),
|
48 |
+
nn.SiLU(),
|
49 |
+
nn.Linear(self.intermediate_size, self.channels),
|
50 |
+
)
|
51 |
+
|
52 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True))
|
53 |
+
|
54 |
+
def forward(self, x, y):
|
55 |
+
shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
|
56 |
+
h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
|
57 |
+
h = self.mlp(h)
|
58 |
+
return x + gate_mlp * h
|
59 |
+
|
60 |
+
|
61 |
+
class FinalLayer(nn.Module):
|
62 |
+
def __init__(self, model_channels, out_channels):
|
63 |
+
super().__init__()
|
64 |
+
self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
|
65 |
+
self.linear = nn.Linear(model_channels, out_channels, bias=True)
|
66 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True))
|
67 |
+
|
68 |
+
def forward(self, x, c):
|
69 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
70 |
+
x = modulate(self.norm_final(x), shift, scale)
|
71 |
+
x = self.linear(x)
|
72 |
+
return x
|
73 |
+
|
74 |
+
|
75 |
+
class TimestepEmbedder(nn.Module):
|
76 |
+
"""
|
77 |
+
Embeds scalar timesteps into vector representations.
|
78 |
+
"""
|
79 |
+
|
80 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
81 |
+
super().__init__()
|
82 |
+
self.mlp = nn.Sequential(
|
83 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
84 |
+
nn.SiLU(),
|
85 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
86 |
+
)
|
87 |
+
self.frequency_embedding_size = frequency_embedding_size
|
88 |
+
|
89 |
+
@staticmethod
|
90 |
+
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0):
|
91 |
+
"""
|
92 |
+
Create sinusoidal timestep embeddings.
|
93 |
+
:param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
94 |
+
:param dim: the dimension of the output.
|
95 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
96 |
+
:return: an (N, D) Tensor of positional embeddings.
|
97 |
+
"""
|
98 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
99 |
+
half = dim // 2
|
100 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
101 |
+
device=t.device
|
102 |
+
)
|
103 |
+
args = t[:, None].float() * freqs[None]
|
104 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
105 |
+
if dim % 2:
|
106 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
107 |
+
return embedding
|
108 |
+
|
109 |
+
def forward(self, t):
|
110 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
111 |
+
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
|
112 |
+
return t_emb
|
113 |
+
|
114 |
+
|
115 |
+
class SimpleMLPAdaLN(nn.Module):
|
116 |
+
def __init__(self, input_dim, cond_dim, dim=1536, layers=12, mlp_ratio=1.0):
|
117 |
+
super().__init__()
|
118 |
+
self.input_dim = input_dim
|
119 |
+
self.cond_dim = cond_dim
|
120 |
+
self.dim = dim
|
121 |
+
self.layers = layers
|
122 |
+
self.mlp_ratio = mlp_ratio
|
123 |
+
|
124 |
+
self.time_embed = TimestepEmbedder(dim)
|
125 |
+
self.cond_embed = nn.Linear(cond_dim, dim)
|
126 |
+
self.input_proj = nn.Linear(input_dim, dim)
|
127 |
+
|
128 |
+
res_blocks = []
|
129 |
+
for _ in range(layers):
|
130 |
+
res_blocks.append(ResBlock(dim, mlp_ratio))
|
131 |
+
self.res_blocks = nn.ModuleList(res_blocks)
|
132 |
+
|
133 |
+
self.final_layer = FinalLayer(dim, input_dim)
|
134 |
+
|
135 |
+
self.grad_checkpointing = False
|
136 |
+
|
137 |
+
self.initialize_weights()
|
138 |
+
|
139 |
+
def initialize_weights(self):
|
140 |
+
def _basic_init(module):
|
141 |
+
if isinstance(module, nn.Linear):
|
142 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
143 |
+
if module.bias is not None:
|
144 |
+
nn.init.constant_(module.bias, 0)
|
145 |
+
|
146 |
+
self.apply(_basic_init)
|
147 |
+
|
148 |
+
# Initialize timestep embedding MLP
|
149 |
+
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
|
150 |
+
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
|
151 |
+
|
152 |
+
# Zero-out adaLN modulation layers
|
153 |
+
for block in self.res_blocks:
|
154 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
155 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
156 |
+
|
157 |
+
# Zero-out output layers
|
158 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
159 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
160 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
161 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
162 |
+
|
163 |
+
@smart_compile()
|
164 |
+
def forward(self, x, t, c):
|
165 |
+
"""
|
166 |
+
x.shape = (bsz, input_dim)
|
167 |
+
t.shape = (bsz,)
|
168 |
+
c.shape = (bsz, cond_dim)
|
169 |
+
"""
|
170 |
+
|
171 |
+
x = self.input_proj(x)
|
172 |
+
t = self.time_embed(t)
|
173 |
+
c = self.cond_embed(c)
|
174 |
+
|
175 |
+
y = t + c
|
176 |
+
|
177 |
+
for block in self.res_blocks:
|
178 |
+
if self.grad_checkpointing and self.training:
|
179 |
+
x = checkpoint(block, x, y, use_reentrant=True)
|
180 |
+
else:
|
181 |
+
x = block(x, y)
|
182 |
+
|
183 |
+
return self.final_layer(x, y)
|
184 |
+
|
185 |
+
|
186 |
+
class FlowMatchingHead(nn.Module):
|
187 |
+
|
188 |
+
def __init__(self, input_dim, cond_dim, dim=1536, layers=12, mlp_ratio=1.0):
|
189 |
+
super(FlowMatchingHead, self).__init__()
|
190 |
+
self.input_dim = input_dim
|
191 |
+
self.net = SimpleMLPAdaLN(input_dim=input_dim, cond_dim=cond_dim, dim=dim, layers=layers, mlp_ratio=mlp_ratio)
|
192 |
+
|
193 |
+
@property
|
194 |
+
def dtype(self):
|
195 |
+
return self.net.input_proj.weight.dtype
|
196 |
+
|
197 |
+
@property
|
198 |
+
def device(self):
|
199 |
+
return self.net.input_proj.weight.device
|
200 |
+
|
201 |
+
@property
|
202 |
+
def trainable_params(self) -> float:
|
203 |
+
n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
204 |
+
return LargeInt(n_params)
|
205 |
+
|
206 |
+
|
207 |
+
def get_score_from_velocity(self, velocity, x, t):
|
208 |
+
"""Wrapper function: transfrom velocity prediction model to score
|
209 |
+
Args:
|
210 |
+
velocity: [bsz, ...] shaped tensor; velocity model output
|
211 |
+
x: [bsz, ...] shaped tensor; x_t data point
|
212 |
+
t: [bsz,] time tensor
|
213 |
+
"""
|
214 |
+
t = expand_t(t, x)
|
215 |
+
alpha_t, d_alpha_t = t, 1
|
216 |
+
sigma_t, d_sigma_t = 1 - t, -1
|
217 |
+
mean = x
|
218 |
+
reverse_alpha_ratio = alpha_t / d_alpha_t
|
219 |
+
var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
|
220 |
+
score = (reverse_alpha_ratio * velocity - mean) / var
|
221 |
+
return score
|
222 |
+
|
223 |
+
def get_velocity_from_cfg(self, velocity, cfg, cfg_img, cfg_mult):
|
224 |
+
if cfg_mult == 2:
|
225 |
+
cond_v, uncond_v = torch.chunk(velocity, 2, dim=0)
|
226 |
+
velocity = uncond_v + cfg * (cond_v - uncond_v)
|
227 |
+
elif cfg_mult == 3:
|
228 |
+
cond_v, uncond_v1, uncond_v2 = torch.chunk(velocity, 3, dim=0)
|
229 |
+
velocity = uncond_v2 + cfg_img * (uncond_v1 - uncond_v2) + cfg * (cond_v - uncond_v1)
|
230 |
+
return velocity
|
231 |
+
|
232 |
+
@smart_compile(options={"triton.cudagraphs": True}, fullgraph=True)
|
233 |
+
@torch.no_grad()
|
234 |
+
def sample(
|
235 |
+
self,
|
236 |
+
c: torch.Tensor,
|
237 |
+
cfg: float = 1.0,
|
238 |
+
cfg_img: float = 1.0,
|
239 |
+
timesteps_shift: float = 1.0,
|
240 |
+
num_sampling_steps: int = 20,
|
241 |
+
last_step_size: float = 0.0,
|
242 |
+
noise_repeat: int = 1,
|
243 |
+
):
|
244 |
+
# """c.shape = (bsz, cond_dim)"""
|
245 |
+
cfg_mult = 1
|
246 |
+
if cfg > 1.0:
|
247 |
+
cfg_mult += 1
|
248 |
+
if cfg_img > 1.0:
|
249 |
+
cfg_mult += 1
|
250 |
+
|
251 |
+
noise = randn_tensor((c.shape[0] // cfg_mult, self.input_dim), noise_repeat, self.device)
|
252 |
+
|
253 |
+
mean_x = noise
|
254 |
+
x = noise
|
255 |
+
xs = []
|
256 |
+
|
257 |
+
t0, t1 = 0, 1
|
258 |
+
timesteps = torch.linspace(t0, t1, num_sampling_steps + 1, device=c.device)[:-1]
|
259 |
+
timesteps = timesteps / (timesteps_shift - (timesteps_shift - 1) * timesteps)
|
260 |
+
timesteps = torch.cat([timesteps, torch.ones(1, device=c.device)])
|
261 |
+
for ti, tj in zip(timesteps[:-1], timesteps[1:]):
|
262 |
+
dt = tj - ti
|
263 |
+
|
264 |
+
combined = torch.cat([x] * cfg_mult, dim=0)
|
265 |
+
velocity = self.net(combined.to(c.dtype), ti.expand(c.shape[0]).to(c), c)
|
266 |
+
velocity = velocity.to(torch.float32)
|
267 |
+
|
268 |
+
velocity = self.get_velocity_from_cfg(velocity, cfg, cfg_img, cfg_mult)
|
269 |
+
score = self.get_score_from_velocity(velocity, x, ti.expand(x.shape[0]).to(x))
|
270 |
+
drift = velocity + (1 - expand_t(ti.expand(x.shape[0]).to(x), x)) * score
|
271 |
+
|
272 |
+
w_cur = randn_tensor((c.shape[0] // cfg_mult, self.input_dim), noise_repeat, self.device)
|
273 |
+
dw = w_cur * torch.sqrt(dt)
|
274 |
+
|
275 |
+
mean_x = x + drift * dt
|
276 |
+
x = mean_x + torch.sqrt(2 * (1 - expand_t(ti.expand(x.shape[0]).to(x), x))) * dw
|
277 |
+
xs.append(x)
|
278 |
+
|
279 |
+
|
280 |
+
if len(xs) != num_sampling_steps:
|
281 |
+
raise ValueError(f"Samples ({len(xs)}) does not match the number of steps ({num_sampling_steps})")
|
282 |
+
|
283 |
+
return xs[-1].to(c.dtype)
|
models/llama_model.py
ADDED
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
from loguru import logger
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from transformers.cache_utils import Cache, StaticCache
|
9 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
10 |
+
from transformers.utils import is_flash_attn_greater_or_equal_2_10
|
11 |
+
from transformers import ROPE_INIT_FUNCTIONS
|
12 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
13 |
+
|
14 |
+
from models.heads import LlamaMLP
|
15 |
+
from utils.model_utils import apply_rotary_pos_emb, repeat_kv
|
16 |
+
from models.config import NextStepConfig
|
17 |
+
|
18 |
+
|
19 |
+
class LlamaRMSNorm(nn.Module):
|
20 |
+
"""LlamaRMSNorm is equivalent to T5LayerNorm"""
|
21 |
+
|
22 |
+
def __init__(self, hidden_size, eps=1e-6):
|
23 |
+
super().__init__()
|
24 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
25 |
+
self.variance_epsilon = eps
|
26 |
+
|
27 |
+
def forward(self, hidden_states):
|
28 |
+
input_dtype = hidden_states.dtype
|
29 |
+
hidden_states = hidden_states.to(torch.float32)
|
30 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
31 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
32 |
+
return self.weight * hidden_states.to(input_dtype)
|
33 |
+
|
34 |
+
def extra_repr(self):
|
35 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
36 |
+
|
37 |
+
|
38 |
+
class LlamaRotaryEmbedding(nn.Module):
|
39 |
+
def __init__(self, device=None, config: Optional[LlamaConfig] = None):
|
40 |
+
super().__init__()
|
41 |
+
self.rope_type = "default"
|
42 |
+
self.config = config
|
43 |
+
|
44 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
45 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
46 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
47 |
+
|
48 |
+
@torch.no_grad()
|
49 |
+
def forward(self, x, position_ids):
|
50 |
+
# Core RoPE block
|
51 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
52 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
53 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
54 |
+
device_type = x.device.type
|
55 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
56 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
57 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
58 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
59 |
+
cos = emb.cos()
|
60 |
+
sin = emb.sin()
|
61 |
+
|
62 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
63 |
+
cos = cos * self.attention_scaling
|
64 |
+
sin = sin * self.attention_scaling
|
65 |
+
|
66 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
67 |
+
|
68 |
+
|
69 |
+
class LlamaAttention(nn.Module):
|
70 |
+
def __init__(self, config: NextStepConfig, layer_idx: Optional[int]):
|
71 |
+
super().__init__()
|
72 |
+
self.config = config
|
73 |
+
self.layer_idx = layer_idx
|
74 |
+
|
75 |
+
self.attention_dropout = config.attention_dropout
|
76 |
+
self.hidden_size = config.hidden_size
|
77 |
+
self.num_heads = config.num_attention_heads
|
78 |
+
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
79 |
+
self.num_key_value_heads = config.num_key_value_heads
|
80 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
81 |
+
self.max_position_embeddings = config.max_position_embeddings
|
82 |
+
self.rope_theta = config.rope_theta
|
83 |
+
self.is_causal = True
|
84 |
+
|
85 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
86 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
87 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
88 |
+
self.o_proj = nn.Linear(
|
89 |
+
self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, "o_attention_bias", config.attention_bias)
|
90 |
+
)
|
91 |
+
self._flash_attn_uses_top_left_mask = False
|
92 |
+
|
93 |
+
def forward_sdpa(
|
94 |
+
self,
|
95 |
+
hidden_states: torch.Tensor,
|
96 |
+
attention_mask: Optional[torch.Tensor] = None,
|
97 |
+
position_ids: Optional[torch.LongTensor] = None,
|
98 |
+
past_key_value: Optional[Cache] = None,
|
99 |
+
output_attentions: bool = False,
|
100 |
+
use_cache: bool = False,
|
101 |
+
cache_position: Optional[torch.LongTensor] = None,
|
102 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
103 |
+
**kwargs,
|
104 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
105 |
+
bsz, q_len, _ = hidden_states.size()
|
106 |
+
|
107 |
+
query_states = self.q_proj(hidden_states)
|
108 |
+
key_states = self.k_proj(hidden_states)
|
109 |
+
value_states = self.v_proj(hidden_states)
|
110 |
+
|
111 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
112 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
113 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
114 |
+
|
115 |
+
if position_embeddings is None:
|
116 |
+
logger.warning_once(
|
117 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
118 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
119 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
120 |
+
"removed and `position_embeddings` will be mandatory."
|
121 |
+
)
|
122 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
123 |
+
else:
|
124 |
+
cos, sin = position_embeddings
|
125 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
126 |
+
|
127 |
+
if past_key_value is not None:
|
128 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
129 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
130 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
131 |
+
|
132 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
133 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
134 |
+
|
135 |
+
causal_mask = attention_mask
|
136 |
+
if attention_mask is not None:
|
137 |
+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
138 |
+
|
139 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
140 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
141 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
142 |
+
query_states = query_states.contiguous()
|
143 |
+
key_states = key_states.contiguous()
|
144 |
+
value_states = value_states.contiguous()
|
145 |
+
|
146 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
147 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
148 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
149 |
+
|
150 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
151 |
+
query_states,
|
152 |
+
key_states,
|
153 |
+
value_states,
|
154 |
+
attn_mask=causal_mask,
|
155 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
156 |
+
is_causal=is_causal,
|
157 |
+
)
|
158 |
+
|
159 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
160 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
161 |
+
|
162 |
+
attn_output = self.o_proj(attn_output)
|
163 |
+
|
164 |
+
return attn_output, None, past_key_value
|
165 |
+
|
166 |
+
def forward_flash(
|
167 |
+
self,
|
168 |
+
hidden_states: torch.Tensor,
|
169 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
170 |
+
position_ids: Optional[torch.LongTensor] = None,
|
171 |
+
past_key_value: Optional[Cache] = None,
|
172 |
+
output_attentions: bool = False,
|
173 |
+
use_cache: bool = False,
|
174 |
+
cache_position: Optional[torch.LongTensor] = None,
|
175 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
176 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
177 |
+
if isinstance(past_key_value, StaticCache):
|
178 |
+
raise ValueError(
|
179 |
+
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
180 |
+
"make sure to use `sdpa` in the mean time, and open an issue at GitHub - huggingface/transformers: 🤗 Transformers: the model-definition framework for state-of-the-a"
|
181 |
+
)
|
182 |
+
|
183 |
+
output_attentions = False
|
184 |
+
|
185 |
+
bsz, q_len, _ = hidden_states.size()
|
186 |
+
|
187 |
+
query_states = self.q_proj(hidden_states)
|
188 |
+
key_states = self.k_proj(hidden_states)
|
189 |
+
value_states = self.v_proj(hidden_states)
|
190 |
+
|
191 |
+
# Flash attention requires the input to have the shape
|
192 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
193 |
+
# therefore we just need to keep the original shape
|
194 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
195 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
196 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
197 |
+
|
198 |
+
if position_embeddings is None:
|
199 |
+
logger.warning_once(
|
200 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
201 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
202 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
203 |
+
"removed and `position_embeddings` will be mandatory."
|
204 |
+
)
|
205 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
206 |
+
else:
|
207 |
+
cos, sin = position_embeddings
|
208 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
209 |
+
|
210 |
+
if past_key_value is not None:
|
211 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
212 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
213 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
214 |
+
|
215 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
216 |
+
# to be able to avoid many of these transpose/reshape/view.
|
217 |
+
query_states = query_states.transpose(1, 2)
|
218 |
+
key_states = key_states.transpose(1, 2)
|
219 |
+
value_states = value_states.transpose(1, 2)
|
220 |
+
|
221 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
222 |
+
|
223 |
+
input_dtype = query_states.dtype
|
224 |
+
if input_dtype == torch.float32:
|
225 |
+
if torch.is_autocast_enabled():
|
226 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
227 |
+
# Handle the case where the model is quantized
|
228 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
229 |
+
target_dtype = self.config._pre_quantization_dtype
|
230 |
+
else:
|
231 |
+
target_dtype = self.q_proj.weight.dtype
|
232 |
+
|
233 |
+
logger.warning_once(
|
234 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
235 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
236 |
+
f" {target_dtype}."
|
237 |
+
)
|
238 |
+
|
239 |
+
query_states = query_states.to(target_dtype)
|
240 |
+
key_states = key_states.to(target_dtype)
|
241 |
+
value_states = value_states.to(target_dtype)
|
242 |
+
|
243 |
+
attn_output = _flash_attention_forward(
|
244 |
+
query_states,
|
245 |
+
key_states,
|
246 |
+
value_states,
|
247 |
+
attention_mask,
|
248 |
+
q_len,
|
249 |
+
position_ids=position_ids,
|
250 |
+
dropout=dropout_rate,
|
251 |
+
sliding_window=getattr(self, "sliding_window", None),
|
252 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
253 |
+
is_causal=self.is_causal,
|
254 |
+
)
|
255 |
+
|
256 |
+
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
257 |
+
attn_output = self.o_proj(attn_output)
|
258 |
+
|
259 |
+
if not output_attentions:
|
260 |
+
attn_weights = None
|
261 |
+
|
262 |
+
return attn_output, attn_weights, past_key_value
|
263 |
+
|
264 |
+
def forward(
|
265 |
+
self,
|
266 |
+
hidden_states: torch.Tensor,
|
267 |
+
attention_mask: Optional[torch.Tensor] = None,
|
268 |
+
position_ids: Optional[torch.LongTensor] = None,
|
269 |
+
past_key_value: Optional[Cache] = None,
|
270 |
+
output_attentions: bool = False,
|
271 |
+
use_cache: bool = False,
|
272 |
+
cache_position: Optional[torch.LongTensor] = None,
|
273 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
274 |
+
**kwargs,
|
275 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
276 |
+
bsz, q_len, _ = hidden_states.size()
|
277 |
+
|
278 |
+
query_states = self.q_proj(hidden_states)
|
279 |
+
key_states = self.k_proj(hidden_states)
|
280 |
+
value_states = self.v_proj(hidden_states)
|
281 |
+
|
282 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
283 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
284 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
285 |
+
|
286 |
+
if position_embeddings is None:
|
287 |
+
logger.warning_once(
|
288 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
289 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
290 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
291 |
+
"removed and `position_embeddings` will be mandatory."
|
292 |
+
)
|
293 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
294 |
+
else:
|
295 |
+
cos, sin = position_embeddings
|
296 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
297 |
+
|
298 |
+
if past_key_value is not None:
|
299 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
300 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
301 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
302 |
+
|
303 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
304 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
305 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
306 |
+
|
307 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
308 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
309 |
+
attn_weights = attn_weights + causal_mask
|
310 |
+
|
311 |
+
# upcast attention to fp32
|
312 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
313 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
314 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
315 |
+
|
316 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
317 |
+
raise ValueError(
|
318 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
319 |
+
f" {attn_output.size()}"
|
320 |
+
)
|
321 |
+
|
322 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
323 |
+
|
324 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
325 |
+
|
326 |
+
attn_output = self.o_proj(attn_output)
|
327 |
+
|
328 |
+
if not output_attentions:
|
329 |
+
attn_weights = None
|
330 |
+
|
331 |
+
return attn_output, attn_weights, past_key_value
|
332 |
+
|
333 |
+
|
334 |
+
class LlamaFlashAttention2(LlamaAttention):
|
335 |
+
"""
|
336 |
+
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
|
337 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
338 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
339 |
+
"""
|
340 |
+
|
341 |
+
def __init__(self, *args, **kwargs):
|
342 |
+
super().__init__(*args, **kwargs)
|
343 |
+
|
344 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
345 |
+
|
346 |
+
def forward(
|
347 |
+
self,
|
348 |
+
hidden_states: torch.Tensor,
|
349 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
350 |
+
past_key_value: Optional[Cache] = None,
|
351 |
+
output_attentions: bool = False,
|
352 |
+
use_cache: bool = False,
|
353 |
+
cache_position: Optional[torch.LongTensor] = None,
|
354 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
355 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
356 |
+
if isinstance(past_key_value, StaticCache):
|
357 |
+
raise ValueError(
|
358 |
+
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
359 |
+
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
360 |
+
)
|
361 |
+
|
362 |
+
output_attentions = False
|
363 |
+
|
364 |
+
bsz, q_len, _ = hidden_states.size()
|
365 |
+
|
366 |
+
query_states = self.q_proj(hidden_states)
|
367 |
+
key_states = self.k_proj(hidden_states)
|
368 |
+
value_states = self.v_proj(hidden_states)
|
369 |
+
|
370 |
+
# Flash attention requires the input to have the shape
|
371 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
372 |
+
# therefore we just need to keep the original shape
|
373 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
374 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
375 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
376 |
+
|
377 |
+
cos, sin = position_embeddings
|
378 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
379 |
+
|
380 |
+
if past_key_value is not None:
|
381 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
382 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
383 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
384 |
+
|
385 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
386 |
+
# to be able to avoid many of these transpose/reshape/view.
|
387 |
+
query_states = query_states.transpose(1, 2)
|
388 |
+
key_states = key_states.transpose(1, 2)
|
389 |
+
value_states = value_states.transpose(1, 2)
|
390 |
+
|
391 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
392 |
+
|
393 |
+
input_dtype = query_states.dtype
|
394 |
+
if input_dtype == torch.float32:
|
395 |
+
if torch.is_autocast_enabled():
|
396 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
397 |
+
# Handle the case where the model is quantized
|
398 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
399 |
+
target_dtype = self.config._pre_quantization_dtype
|
400 |
+
else:
|
401 |
+
target_dtype = self.q_proj.weight.dtype
|
402 |
+
|
403 |
+
logger.warning_once(
|
404 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
405 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
406 |
+
f" {target_dtype}."
|
407 |
+
)
|
408 |
+
|
409 |
+
query_states = query_states.to(target_dtype)
|
410 |
+
key_states = key_states.to(target_dtype)
|
411 |
+
value_states = value_states.to(target_dtype)
|
412 |
+
|
413 |
+
attn_output = _flash_attention_forward(
|
414 |
+
query_states,
|
415 |
+
key_states,
|
416 |
+
value_states,
|
417 |
+
attention_mask,
|
418 |
+
q_len,
|
419 |
+
position_ids=None,
|
420 |
+
dropout=dropout_rate,
|
421 |
+
sliding_window=getattr(self, "sliding_window", None),
|
422 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
423 |
+
is_causal=self.is_causal,
|
424 |
+
)
|
425 |
+
|
426 |
+
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
427 |
+
attn_output = self.o_proj(attn_output)
|
428 |
+
|
429 |
+
if not output_attentions:
|
430 |
+
attn_weights = None
|
431 |
+
|
432 |
+
return attn_output, attn_weights, past_key_value
|
433 |
+
|
434 |
+
|
435 |
+
class LlamaSdpaAttention(LlamaAttention):
|
436 |
+
"""
|
437 |
+
Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
438 |
+
`LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
439 |
+
SDPA API.
|
440 |
+
"""
|
441 |
+
|
442 |
+
# Adapted from LlamaAttention.forward
|
443 |
+
def forward(
|
444 |
+
self,
|
445 |
+
hidden_states: torch.Tensor,
|
446 |
+
attention_mask: Optional[torch.Tensor] = None,
|
447 |
+
past_key_value: Optional[Cache] = None,
|
448 |
+
output_attentions: bool = False,
|
449 |
+
use_cache: bool = False,
|
450 |
+
cache_position: Optional[torch.LongTensor] = None,
|
451 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
452 |
+
**kwargs,
|
453 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
454 |
+
|
455 |
+
bsz, q_len, _ = hidden_states.size()
|
456 |
+
|
457 |
+
query_states = self.q_proj(hidden_states)
|
458 |
+
key_states = self.k_proj(hidden_states)
|
459 |
+
value_states = self.v_proj(hidden_states)
|
460 |
+
|
461 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
462 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
463 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
464 |
+
|
465 |
+
cos, sin = position_embeddings
|
466 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
467 |
+
|
468 |
+
if past_key_value is not None:
|
469 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
470 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
471 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
472 |
+
|
473 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
474 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
475 |
+
|
476 |
+
causal_mask = attention_mask
|
477 |
+
if attention_mask is not None:
|
478 |
+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
479 |
+
|
480 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
481 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
482 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
483 |
+
query_states = query_states.contiguous()
|
484 |
+
key_states = key_states.contiguous()
|
485 |
+
value_states = value_states.contiguous()
|
486 |
+
|
487 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
488 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
489 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
490 |
+
|
491 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
492 |
+
query_states,
|
493 |
+
key_states,
|
494 |
+
value_states,
|
495 |
+
attn_mask=causal_mask,
|
496 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
497 |
+
is_causal=is_causal,
|
498 |
+
)
|
499 |
+
|
500 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
501 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
502 |
+
|
503 |
+
attn_output = self.o_proj(attn_output)
|
504 |
+
|
505 |
+
return attn_output, None, past_key_value
|
506 |
+
|
507 |
+
|
508 |
+
LLAMA_ATTENTION_CLASSES = {
|
509 |
+
"eager": LlamaAttention,
|
510 |
+
"flash_attention_2": LlamaFlashAttention2,
|
511 |
+
"sdpa": LlamaSdpaAttention,
|
512 |
+
}
|
513 |
+
|
514 |
+
|
515 |
+
class LlamaDecoderLayer(nn.Module):
|
516 |
+
def __init__(self, config: LlamaConfig, layer_idx: int):
|
517 |
+
super().__init__()
|
518 |
+
self.hidden_size = config.hidden_size
|
519 |
+
|
520 |
+
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
521 |
+
|
522 |
+
self.mlp = LlamaMLP(config)
|
523 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
524 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
525 |
+
|
526 |
+
def forward(
|
527 |
+
self,
|
528 |
+
hidden_states: torch.Tensor,
|
529 |
+
attention_mask: Optional[torch.Tensor] = None,
|
530 |
+
past_key_value: Optional[Cache] = None,
|
531 |
+
output_attentions: Optional[bool] = False,
|
532 |
+
use_cache: Optional[bool] = False,
|
533 |
+
cache_position: Optional[torch.LongTensor] = None,
|
534 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
535 |
+
**kwargs,
|
536 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
537 |
+
residual = hidden_states
|
538 |
+
|
539 |
+
hidden_states = self.input_layernorm(hidden_states)
|
540 |
+
|
541 |
+
# Self Attention
|
542 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
543 |
+
hidden_states=hidden_states,
|
544 |
+
attention_mask=attention_mask,
|
545 |
+
past_key_value=past_key_value,
|
546 |
+
output_attentions=output_attentions,
|
547 |
+
use_cache=use_cache,
|
548 |
+
cache_position=cache_position,
|
549 |
+
position_embeddings=position_embeddings,
|
550 |
+
**kwargs,
|
551 |
+
)
|
552 |
+
hidden_states = residual + hidden_states
|
553 |
+
|
554 |
+
# Fully Connected
|
555 |
+
residual = hidden_states
|
556 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
557 |
+
hidden_states = self.mlp(hidden_states)
|
558 |
+
hidden_states = residual + hidden_states
|
559 |
+
|
560 |
+
outputs = (hidden_states,)
|
561 |
+
|
562 |
+
if output_attentions:
|
563 |
+
outputs += (self_attn_weights,)
|
564 |
+
|
565 |
+
if use_cache:
|
566 |
+
outputs += (present_key_value,)
|
567 |
+
|
568 |
+
return outputs
|
models/nextstep_model.py
ADDED
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import inspect
|
4 |
+
from loguru import logger
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.nn import CrossEntropyLoss
|
10 |
+
|
11 |
+
from safetensors.torch import safe_open
|
12 |
+
from transformers.modeling_utils import PreTrainedModel
|
13 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
14 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
15 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
16 |
+
|
17 |
+
from models.config import NextStepConfig
|
18 |
+
from models.llama_model import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
|
19 |
+
from models.heads import FlowMatchingHead
|
20 |
+
from utils.misc import LargeInt
|
21 |
+
from utils.compile_utils import smart_compile
|
22 |
+
from utils.model_utils import get_2d_sincos_pos_embed
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class NextStepOutputWithPast(CausalLMOutputWithPast):
|
27 |
+
lm_loss: torch.FloatTensor | None = None
|
28 |
+
im_loss: torch.FloatTensor | None = None
|
29 |
+
|
30 |
+
|
31 |
+
class NextStepPreTrainedModel(PreTrainedModel):
|
32 |
+
config_class = NextStepConfig
|
33 |
+
supports_gradient_checkpointing = True
|
34 |
+
_no_split_modules = ["LlamaDecoderLayer"]
|
35 |
+
_skip_keys_device_placement = ["past_key_values"]
|
36 |
+
_supports_flash_attn_2 = True
|
37 |
+
_supports_sdpa = True
|
38 |
+
_supports_cache_class = True
|
39 |
+
_supports_quantized_cache = True
|
40 |
+
_supports_static_cache = True
|
41 |
+
|
42 |
+
def _init_weights(self, module):
|
43 |
+
std = self.config.initializer_range
|
44 |
+
if isinstance(module, nn.Linear):
|
45 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
46 |
+
if module.bias is not None:
|
47 |
+
module.bias.data.zero_()
|
48 |
+
elif isinstance(module, nn.Embedding):
|
49 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
50 |
+
if module.padding_idx is not None:
|
51 |
+
module.weight.data[module.padding_idx].zero_()
|
52 |
+
|
53 |
+
@property
|
54 |
+
def trainable_params(self) -> float:
|
55 |
+
n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
56 |
+
return LargeInt(n_params)
|
57 |
+
|
58 |
+
|
59 |
+
class NextStep(NextStepPreTrainedModel):
|
60 |
+
|
61 |
+
def __init__(self, config: NextStepConfig):
|
62 |
+
super().__init__(config)
|
63 |
+
self.padding_idx = config.pad_token_id
|
64 |
+
self.vocab_size = config.vocab_size
|
65 |
+
|
66 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
67 |
+
|
68 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
69 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
70 |
+
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
71 |
+
|
72 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
73 |
+
|
74 |
+
self.gradient_checkpointing = False
|
75 |
+
|
76 |
+
# Initialize weights and apply final processing
|
77 |
+
self.post_init()
|
78 |
+
|
79 |
+
token_dim = self.config.latent_channels * self.config.latent_patch_size**2
|
80 |
+
|
81 |
+
self.image_in_projector = nn.Linear(token_dim, config.hidden_size)
|
82 |
+
self.image_in_projector.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
83 |
+
self.image_in_projector.bias.data.zero_()
|
84 |
+
|
85 |
+
self.image_out_projector = nn.Linear(config.hidden_size, config.hidden_size)
|
86 |
+
self.image_out_projector.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
87 |
+
self.image_out_projector.bias.data.zero_()
|
88 |
+
|
89 |
+
self.image_head = FlowMatchingHead(
|
90 |
+
input_dim=token_dim,
|
91 |
+
cond_dim=config.hidden_size,
|
92 |
+
dim=config.fm_head_dim,
|
93 |
+
layers=config.fm_head_layers,
|
94 |
+
)
|
95 |
+
|
96 |
+
if config.use_gen_pos_embed:
|
97 |
+
self.init_gen_pos_embed()
|
98 |
+
|
99 |
+
def init_gen_pos_embed(self):
|
100 |
+
self.register_buffer(
|
101 |
+
"gen_pos_embed",
|
102 |
+
torch.from_numpy(
|
103 |
+
get_2d_sincos_pos_embed(
|
104 |
+
self.config.hidden_size, self.config.base_image_grid_size
|
105 |
+
)
|
106 |
+
).float().unsqueeze(0),
|
107 |
+
)
|
108 |
+
|
109 |
+
def gen_pos_embed_with_ar(self, h, w):
|
110 |
+
bsz, hw, dim = self.gen_pos_embed.shape
|
111 |
+
gen_pos_embed = self.gen_pos_embed.reshape(bsz, int(hw**0.5), int(hw**0.5), dim)
|
112 |
+
gen_pos_embed = gen_pos_embed[:, :h, :w, :]
|
113 |
+
gen_pos_embed = gen_pos_embed.reshape(bsz, -1, dim)
|
114 |
+
return gen_pos_embed
|
115 |
+
|
116 |
+
@property
|
117 |
+
def image_size(self):
|
118 |
+
return self.config.image_size
|
119 |
+
|
120 |
+
@property
|
121 |
+
def image_patch_size(self):
|
122 |
+
return self.config.patch_size
|
123 |
+
|
124 |
+
@property
|
125 |
+
def image_grid_size(self):
|
126 |
+
return round(self.image_size / self.image_patch_size)
|
127 |
+
|
128 |
+
def get_input_embeddings(self):
|
129 |
+
return self.embed_tokens
|
130 |
+
|
131 |
+
def set_input_embeddings(self, value):
|
132 |
+
self.embed_tokens = value
|
133 |
+
|
134 |
+
def get_output_embeddings(self):
|
135 |
+
return self.lm_head
|
136 |
+
|
137 |
+
def set_output_embeddings(self, new_embeddings):
|
138 |
+
self.lm_head = new_embeddings
|
139 |
+
|
140 |
+
def load_lm_head(self, lm_head_dir: str | None = None):
|
141 |
+
index_json_file = os.path.join(lm_head_dir, "model.safetensors.index.json")
|
142 |
+
head_weight_name = "lm_head.weight" if not self.config.tie_word_embeddings else "model.embed_tokens.weight"
|
143 |
+
if os.path.exists(index_json_file):
|
144 |
+
with open(index_json_file, "r") as f:
|
145 |
+
index = json.load(f)
|
146 |
+
model_name = index["weight_map"][head_weight_name]
|
147 |
+
else:
|
148 |
+
model_name = "model.safetensors"
|
149 |
+
with safe_open(os.path.join(lm_head_dir, model_name), framework="pt") as f:
|
150 |
+
loaded_weight = f.get_tensor(head_weight_name)
|
151 |
+
loaded_weight = loaded_weight.to(dtype=self.lm_head.weight.dtype, device=self.lm_head.weight.device)
|
152 |
+
self.lm_head.weight.data.copy_(loaded_weight)
|
153 |
+
|
154 |
+
def patchify(self, img: torch.Tensor):
|
155 |
+
"""
|
156 |
+
img: (bsz, C, H, W)
|
157 |
+
x: (bsz, H * W / patch_size**2, patch_size**2 * C)
|
158 |
+
"""
|
159 |
+
bsz, c, h, w = img.shape
|
160 |
+
p = self.config.latent_patch_size
|
161 |
+
h_, w_ = h // p, w // p
|
162 |
+
|
163 |
+
img = img.reshape(bsz, c, h_, p, w_, p)
|
164 |
+
img = torch.einsum("nchpwq->nhwcpq", img)
|
165 |
+
x = img.reshape(bsz, h_ * w_, c * p**2)
|
166 |
+
return x
|
167 |
+
|
168 |
+
def unpatchify(self, x: torch.Tensor, h: int = None, w: int = None):
|
169 |
+
"""
|
170 |
+
x: (bsz, H * W / patch_size**2, patch_size**2 * C)
|
171 |
+
img: (bsz, C, H, W)
|
172 |
+
"""
|
173 |
+
bsz = x.shape[0]
|
174 |
+
p = self.config.latent_patch_size
|
175 |
+
c = self.config.latent_channels
|
176 |
+
if h is None and w is None:
|
177 |
+
h_ = w_ = int(x.shape[1] ** 0.5)
|
178 |
+
else:
|
179 |
+
h_, w_ = h, w
|
180 |
+
assert h_ * w_ == x.shape[1], f"Invalid sequence length {x.shape[1]}."
|
181 |
+
|
182 |
+
x = x.reshape(bsz, h_, w_, c, p, p)
|
183 |
+
x = torch.einsum("nhwcpq->nchpwq", x)
|
184 |
+
img = x.reshape(bsz, c, h_ * p, w_ * p)
|
185 |
+
return img
|
186 |
+
|
187 |
+
def prepare_inputs_embeds(self, input_ids: torch.LongTensor | None = None, latents: torch.FloatTensor | None = None):
|
188 |
+
if latents is None:
|
189 |
+
if not self.training:
|
190 |
+
return self.embed_tokens(input_ids)
|
191 |
+
else: # dummy forward for image pass, for the consistent shape of gradient.
|
192 |
+
raise NotImplementedError("Dummy forward for image pass is not implemented.")
|
193 |
+
else:
|
194 |
+
bs, seq_length = input_ids.shape
|
195 |
+
inputs_embeds = torch.zeros(
|
196 |
+
(bs, seq_length, self.config.hidden_size),
|
197 |
+
device=self.embed_tokens.weight.device,
|
198 |
+
dtype=self.embed_tokens.weight.dtype,
|
199 |
+
)
|
200 |
+
im_indices = input_ids == self.config.image_placeholder_id
|
201 |
+
lm_indices = ~im_indices
|
202 |
+
|
203 |
+
if isinstance(latents, list):
|
204 |
+
tokens = torch.cat([self.patchify(latent) for latent in latents], dim=1)
|
205 |
+
else:
|
206 |
+
tokens = self.patchify(latents)
|
207 |
+
# tokens = tokens.reshape(1, -1, tokens.shape[-1])
|
208 |
+
|
209 |
+
image_embeds = self.image_in_projector(tokens)
|
210 |
+
image_embeds = image_embeds.view(-1, self.config.hidden_size)
|
211 |
+
|
212 |
+
token_embeds = self.embed_tokens(input_ids[lm_indices])
|
213 |
+
|
214 |
+
inputs_embeds[im_indices] = image_embeds.to(inputs_embeds.dtype)
|
215 |
+
inputs_embeds[lm_indices] = token_embeds
|
216 |
+
|
217 |
+
return inputs_embeds
|
218 |
+
|
219 |
+
def _update_causal_mask(
|
220 |
+
self,
|
221 |
+
attention_mask: torch.Tensor,
|
222 |
+
input_tensor: torch.Tensor,
|
223 |
+
cache_position: torch.Tensor,
|
224 |
+
past_key_values: Cache,
|
225 |
+
output_attentions: bool,
|
226 |
+
):
|
227 |
+
if self.config._attn_implementation == "flash_attention_2":
|
228 |
+
if attention_mask is not None and (attention_mask == 0.0).any():
|
229 |
+
return attention_mask
|
230 |
+
return None
|
231 |
+
|
232 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
233 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
234 |
+
# to infer the attention mask.
|
235 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
236 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
237 |
+
|
238 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
239 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
240 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
241 |
+
attention_mask,
|
242 |
+
inputs_embeds=input_tensor,
|
243 |
+
past_key_values_length=past_seen_tokens,
|
244 |
+
is_training=self.training,
|
245 |
+
):
|
246 |
+
return None
|
247 |
+
|
248 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
249 |
+
sequence_length = input_tensor.shape[1]
|
250 |
+
if using_static_cache:
|
251 |
+
target_length = past_key_values.get_max_cache_shape()
|
252 |
+
else:
|
253 |
+
target_length = (
|
254 |
+
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1
|
255 |
+
)
|
256 |
+
|
257 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
258 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
259 |
+
attention_mask,
|
260 |
+
sequence_length=sequence_length,
|
261 |
+
target_length=target_length,
|
262 |
+
dtype=dtype,
|
263 |
+
device=device,
|
264 |
+
cache_position=cache_position,
|
265 |
+
batch_size=input_tensor.shape[0],
|
266 |
+
)
|
267 |
+
|
268 |
+
if (
|
269 |
+
self.config._attn_implementation == "sdpa"
|
270 |
+
and attention_mask is not None
|
271 |
+
and attention_mask.device.type == "cuda"
|
272 |
+
and not output_attentions
|
273 |
+
):
|
274 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
275 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
276 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
277 |
+
min_dtype = torch.finfo(dtype).min
|
278 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
279 |
+
|
280 |
+
return causal_mask
|
281 |
+
|
282 |
+
@staticmethod
|
283 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
284 |
+
attention_mask: torch.Tensor,
|
285 |
+
sequence_length: int,
|
286 |
+
target_length: int,
|
287 |
+
dtype: torch.dtype,
|
288 |
+
device: torch.device,
|
289 |
+
cache_position: torch.Tensor,
|
290 |
+
batch_size: int,
|
291 |
+
**kwargs,
|
292 |
+
):
|
293 |
+
"""
|
294 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
295 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
296 |
+
|
297 |
+
Args:
|
298 |
+
attention_mask (`torch.Tensor`):
|
299 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
300 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
301 |
+
sequence_length (`int`):
|
302 |
+
The sequence length being processed.
|
303 |
+
target_length (`int`):
|
304 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
305 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
306 |
+
dtype (`torch.dtype`):
|
307 |
+
The dtype to use for the 4D attention mask.
|
308 |
+
device (`torch.device`):
|
309 |
+
The device to plcae the 4D attention mask on.
|
310 |
+
cache_position (`torch.Tensor`):
|
311 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
312 |
+
batch_size (`torch.Tensor`):
|
313 |
+
Batch size.
|
314 |
+
"""
|
315 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
316 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
317 |
+
causal_mask = attention_mask
|
318 |
+
else:
|
319 |
+
min_dtype = torch.finfo(dtype).min
|
320 |
+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
321 |
+
if sequence_length != 1:
|
322 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
323 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
324 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
325 |
+
if attention_mask is not None:
|
326 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
327 |
+
mask_length = attention_mask.shape[-1]
|
328 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
329 |
+
padding_mask = padding_mask == 0
|
330 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
|
331 |
+
|
332 |
+
return causal_mask
|
333 |
+
|
334 |
+
@smart_compile()
|
335 |
+
def forward_model(
|
336 |
+
self,
|
337 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
338 |
+
attention_mask: torch.Tensor | None = None,
|
339 |
+
past_key_values: Cache | list[torch.FloatTensor] | None = None,
|
340 |
+
use_cache: bool | None = None,
|
341 |
+
output_attentions: bool | None = None,
|
342 |
+
output_hidden_states: bool | None = None,
|
343 |
+
cache_position: torch.LongTensor | None = None,
|
344 |
+
) -> tuple | BaseModelOutputWithPast:
|
345 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
346 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
347 |
+
|
348 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
349 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
350 |
+
use_cache = False
|
351 |
+
|
352 |
+
if use_cache and past_key_values is None:
|
353 |
+
past_key_values = DynamicCache()
|
354 |
+
|
355 |
+
if cache_position is None:
|
356 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
357 |
+
cache_position = torch.arange(
|
358 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
359 |
+
)
|
360 |
+
position_ids = cache_position.unsqueeze(0)
|
361 |
+
|
362 |
+
causal_mask = self._update_causal_mask(
|
363 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
364 |
+
)
|
365 |
+
hidden_states = inputs_embeds
|
366 |
+
|
367 |
+
# create position embeddings to be shared across the decoder layers
|
368 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
369 |
+
|
370 |
+
# decoder layers
|
371 |
+
all_hidden_states = () if output_hidden_states else None
|
372 |
+
all_self_attns = () if output_attentions else None
|
373 |
+
|
374 |
+
for decoder_layer in self.layers:
|
375 |
+
if output_hidden_states:
|
376 |
+
all_hidden_states += (hidden_states,)
|
377 |
+
|
378 |
+
if self.gradient_checkpointing and self.training:
|
379 |
+
layer_outputs = self._gradient_checkpointing_func(
|
380 |
+
decoder_layer.__call__,
|
381 |
+
hidden_states,
|
382 |
+
causal_mask,
|
383 |
+
past_key_values,
|
384 |
+
output_attentions,
|
385 |
+
use_cache,
|
386 |
+
cache_position,
|
387 |
+
position_embeddings,
|
388 |
+
)
|
389 |
+
else:
|
390 |
+
layer_outputs = decoder_layer(
|
391 |
+
hidden_states,
|
392 |
+
attention_mask=causal_mask,
|
393 |
+
past_key_value=past_key_values,
|
394 |
+
output_attentions=output_attentions,
|
395 |
+
use_cache=use_cache,
|
396 |
+
cache_position=cache_position,
|
397 |
+
position_embeddings=position_embeddings,
|
398 |
+
)
|
399 |
+
|
400 |
+
hidden_states = layer_outputs[0]
|
401 |
+
|
402 |
+
if output_attentions:
|
403 |
+
all_self_attns += (layer_outputs[1],)
|
404 |
+
|
405 |
+
hidden_states = self.norm(hidden_states)
|
406 |
+
|
407 |
+
# add hidden states from the last decoder layer
|
408 |
+
if output_hidden_states:
|
409 |
+
all_hidden_states += (hidden_states,)
|
410 |
+
|
411 |
+
return BaseModelOutputWithPast(
|
412 |
+
last_hidden_state=hidden_states,
|
413 |
+
past_key_values=past_key_values if use_cache else None,
|
414 |
+
hidden_states=all_hidden_states,
|
415 |
+
attentions=all_self_attns,
|
416 |
+
)
|
417 |
+
|
418 |
+
|
419 |
+
|
420 |
+
def prepare_inputs_for_generation(
|
421 |
+
self,
|
422 |
+
input_ids: torch.LongTensor,
|
423 |
+
past_key_values: Cache | None = None,
|
424 |
+
attention_mask: torch.LongTensor | None = None,
|
425 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
426 |
+
cache_position: torch.LongTensor | None = None,
|
427 |
+
**kwargs,
|
428 |
+
):
|
429 |
+
"""
|
430 |
+
Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or
|
431 |
+
slicing inputs given the existing cache.
|
432 |
+
|
433 |
+
See the forward pass in the model documentation for expected arguments (different models might have different
|
434 |
+
requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
|
435 |
+
"""
|
436 |
+
|
437 |
+
# 1. Handle BC:
|
438 |
+
model_inputs = {}
|
439 |
+
# - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
|
440 |
+
if self._supports_cache_class:
|
441 |
+
model_inputs["cache_position"] = cache_position
|
442 |
+
# - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
|
443 |
+
# function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
|
444 |
+
# (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
|
445 |
+
elif cache_position is None:
|
446 |
+
past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
447 |
+
cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
448 |
+
|
449 |
+
# 2. Generic cache-dependent input preparation
|
450 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
451 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
452 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
453 |
+
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
|
454 |
+
if past_key_values is not None:
|
455 |
+
model_inputs["past_key_values"] = past_key_values
|
456 |
+
if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3
|
457 |
+
input_ids = input_ids[:, -cache_position.shape[0] :]
|
458 |
+
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
459 |
+
input_ids = input_ids[:, cache_position]
|
460 |
+
|
461 |
+
# 3. Prepare base model inputs
|
462 |
+
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
463 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
464 |
+
if not self.config.is_encoder_decoder:
|
465 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
466 |
+
model_inputs[input_ids_key] = None
|
467 |
+
model_inputs["inputs_embeds"] = inputs_embeds
|
468 |
+
else:
|
469 |
+
# `clone` calls in this function ensure a consistent stride. See #32227
|
470 |
+
model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
|
471 |
+
model_inputs["inputs_embeds"] = None
|
472 |
+
else:
|
473 |
+
model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
|
474 |
+
|
475 |
+
# 4. Create missing `position_ids` on the fly
|
476 |
+
if (
|
477 |
+
attention_mask is not None
|
478 |
+
and kwargs.get("position_ids") is None
|
479 |
+
and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
|
480 |
+
):
|
481 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
482 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
483 |
+
kwargs["position_ids"] = position_ids # placed in kwargs for further processing (see below)
|
484 |
+
|
485 |
+
# 5. Slice model inputs if it's an input that should have the same length as `input_ids`
|
486 |
+
for model_input_name in ["position_ids", "token_type_ids"]:
|
487 |
+
model_input = kwargs.get(model_input_name)
|
488 |
+
if model_input is not None:
|
489 |
+
if past_key_values:
|
490 |
+
model_input = model_input[:, -input_ids.shape[1] :]
|
491 |
+
model_input = model_input.clone(memory_format=torch.contiguous_format)
|
492 |
+
model_inputs[model_input_name] = model_input
|
493 |
+
|
494 |
+
# 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
|
495 |
+
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
496 |
+
if model_inputs["inputs_embeds"] is not None:
|
497 |
+
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
498 |
+
device = model_inputs["inputs_embeds"].device
|
499 |
+
else:
|
500 |
+
batch_size, sequence_length = model_inputs[input_ids_key].shape
|
501 |
+
device = model_inputs[input_ids_key].device
|
502 |
+
|
503 |
+
# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
|
504 |
+
# the 4D causal mask exists, it should be present in the base model (XXXModel class).
|
505 |
+
base_model = getattr(self, self.base_model_prefix, None)
|
506 |
+
if base_model is None:
|
507 |
+
causal_mask_creation_function = getattr(self, "_prepare_4d_causal_attention_mask_with_cache_position", None)
|
508 |
+
else:
|
509 |
+
causal_mask_creation_function = getattr(
|
510 |
+
base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
|
511 |
+
)
|
512 |
+
if causal_mask_creation_function is None:
|
513 |
+
logger.warning_once(
|
514 |
+
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
|
515 |
+
"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
|
516 |
+
"writing code, see Llama for an example implementation. If you're a user, please report this "
|
517 |
+
"issue on GitHub."
|
518 |
+
)
|
519 |
+
else:
|
520 |
+
attention_mask = causal_mask_creation_function(
|
521 |
+
attention_mask,
|
522 |
+
sequence_length=sequence_length,
|
523 |
+
target_length=past_key_values.get_max_cache_shape(),
|
524 |
+
dtype=self.dtype,
|
525 |
+
device=device,
|
526 |
+
cache_position=cache_position,
|
527 |
+
batch_size=batch_size,
|
528 |
+
config=self.config,
|
529 |
+
past_key_values=past_key_values,
|
530 |
+
)
|
531 |
+
if attention_mask is not None:
|
532 |
+
model_inputs["attention_mask"] = attention_mask
|
533 |
+
|
534 |
+
# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
|
535 |
+
for key, value in kwargs.items():
|
536 |
+
if key not in model_inputs:
|
537 |
+
model_inputs[key] = value
|
538 |
+
|
539 |
+
# 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
|
540 |
+
model_inputs.pop("labels", None)
|
541 |
+
return model_inputs
|
542 |
+
|
543 |
+
@torch.no_grad()
|
544 |
+
def generate(self, inputs: torch.LongTensor = None, **kwargs):
|
545 |
+
input_ids = kwargs.pop("input_ids")
|
546 |
+
latents = kwargs.pop("latents", None)
|
547 |
+
inputs_embeds = self.prepare_inputs_embeds(input_ids, latents)
|
548 |
+
return super().generate(inputs=inputs, input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
|
549 |
+
|
550 |
+
def gradient_checkpointing_enable(self, **kwargs):
|
551 |
+
super().gradient_checkpointing_enable(**kwargs)
|
552 |
+
|
553 |
+
self.image_head.net.grad_checkpointing = True
|
pytorch-model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c53c995d71566ae0cedc0f2e405db0bd71543d34e253e1f180c0bc58985404a3
|
3 |
+
size 9962132680
|
pytorch-model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a1c3b95aa05095eb482494102594d25647a6d4e5963ecde87de7b1a7d6bb0e2e
|
3 |
+
size 9909693448
|
pytorch-model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f13b038eea5b0f797b927202d562801d35a8e998fd149e0e6c3035c988eaba7c
|
3 |
+
size 8478742432
|
pytorch-model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f990b56fcaa144772b37626cec841d519a62662d74fba20e9102a55cb2e7b638
|
3 |
+
size 1557135464
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.34.0
|
2 |
+
einops==0.8.1
|
3 |
+
gradio==5.42.0
|
4 |
+
loguru==0.7.3
|
5 |
+
numpy==1.26.4
|
6 |
+
omegaconf==2.3.0
|
7 |
+
Pillow==11.0.0
|
8 |
+
Requests==2.32.4
|
9 |
+
safetensors==0.5.3
|
10 |
+
tabulate==0.9.0
|
11 |
+
torch==2.5.1
|
12 |
+
torchvision==0.20.1
|
13 |
+
tqdm==4.67.1
|
14 |
+
transformers==4.55.0
|
special_tokens_map.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|image_area|>",
|
4 |
+
"<|begin_of_image|>",
|
5 |
+
"<|end_of_image|>",
|
6 |
+
"<|image_placeholder|>",
|
7 |
+
"<|begin_of_prompt_refinement|>",
|
8 |
+
"<|end_of_prompt_refinement|>",
|
9 |
+
"<|begin_of_thinking|>",
|
10 |
+
"<|end_of_thinking|>"
|
11 |
+
],
|
12 |
+
"eos_token": {
|
13 |
+
"content": "<|endoftext|>",
|
14 |
+
"lstrip": false,
|
15 |
+
"normalized": false,
|
16 |
+
"rstrip": false,
|
17 |
+
"single_word": false
|
18 |
+
},
|
19 |
+
"pad_token": {
|
20 |
+
"content": "[PAD]",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false
|
25 |
+
}
|
26 |
+
}
|
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9a0cb05912073d41a5b70def44dfe1f4331bfae17a7f5b80873ef966f19dd2c8
|
3 |
+
size 11423661
|
tokenizer_config.json
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
  "add_bos_token": false,
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "151643": {"content": "<|endoftext|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151644": {"content": "<|im_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151645": {"content": "<|im_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151646": {"content": "<|object_ref_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151647": {"content": "<|object_ref_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151648": {"content": "<|box_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151649": {"content": "<|box_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151650": {"content": "<|quad_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151651": {"content": "<|quad_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151652": {"content": "<|vision_start|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151653": {"content": "<|vision_end|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151654": {"content": "<|vision_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151655": {"content": "<|image_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151656": {"content": "<|video_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151657": {"content": "<tool_call>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151658": {"content": "</tool_call>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151659": {"content": "<|fim_prefix|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151660": {"content": "<|fim_middle|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151661": {"content": "<|fim_suffix|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151662": {"content": "<|fim_pad|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151663": {"content": "<|repo_name|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151664": {"content": "<|file_sep|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": false},
    "151665": {"content": "[PAD]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151666": {"content": "<|image_area|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151667": {"content": "<|begin_of_image|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151668": {"content": "<|end_of_image|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151669": {"content": "<|image_placeholder|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151670": {"content": "<|begin_of_prompt_refinement|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151671": {"content": "<|end_of_prompt_refinement|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151672": {"content": "<|begin_of_thinking|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "151673": {"content": "<|end_of_thinking|>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}
  },
  "additional_special_tokens": [
    "<|image_area|>",
    "<|begin_of_image|>",
    "<|end_of_image|>",
    "<|image_placeholder|>",
    "<|begin_of_prompt_refinement|>",
    "<|end_of_prompt_refinement|>",
    "<|begin_of_thinking|>",
    "<|end_of_thinking|>"
  ],
  "bos_token": null,
  "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|endoftext|>",
  "errors": "replace",
  "extra_special_tokens": {},
  "model_max_length": 8192,
  "pad_token": "[PAD]",
  "padding_side": "right",
  "split_special_tokens": false,
  "tokenizer_class": "Qwen2Tokenizer",
  "unk_token": null
}
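For reference, a minimal sketch of loading this tokenizer with `transformers` (the local directory path is an assumption; any folder containing the tokenizer files works):

```python
from transformers import AutoTokenizer

# "." is a placeholder for the downloaded snapshot directory.
tokenizer = AutoTokenizer.from_pretrained(".")

# The image-related special tokens declared in added_tokens_decoder are usable directly.
print(tokenizer.convert_tokens_to_ids("<|begin_of_image|>"))  # 151667
print(tokenizer.pad_token)  # [PAD]
```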
utils/aspect_ratio.py
ADDED
@@ -0,0 +1,107 @@
import numpy as np
import PIL.Image

ANY_ASPECT_RATIO = (0, 0)

HW_ASPECT_RATIOS = [
    (8, 32),  # 256
    (9, 28),  # 252
    (10, 25),  # 250
    (11, 23),  # 253
    (12, 21),  # 252
    (13, 19),  # 247
    (14, 18),  # 252
    (15, 17),  # 255
    (16, 16),  # 256
    (17, 15),  # 255
    (18, 14),  # 252
    (19, 13),  # 247
    (21, 12),  # 252
    (23, 11),  # 253
    (25, 10),  # 250
    (28, 9),  # 252
    (32, 8),  # 256
]


def get_ar_base(ars: list[tuple[int, int]] = HW_ASPECT_RATIOS):
    sqrt_products = [round(np.sqrt(h * w)) for h, w in ars]
    return round(np.mean(sqrt_products))


def ar2str(h: int, w: int) -> str:
    return f"{h}*{w}"


def str2ar(s: str) -> tuple[int, int]:
    return tuple(map(int, s.split("*")))


def center_crop_arr_with_buckets(pil_image, ars: list[tuple[int, int]] = HW_ASPECT_RATIOS, crop=True, buckets: list[int] = [256, 512, 768, 1024]):
    """
    Center crop the image to match the closest aspect ratio from the provided list.

    Args:
        pil_image: PIL Image to be cropped
        ars: List of aspect ratios as (height, width) tuples
        crop: Whether to center crop (True) or only resize (False)
        buckets: Candidate base resolutions; the largest bucket whose area the image covers is used

    Returns:
        PIL Image cropped to the closest aspect ratio
    """
    # ar_base = get_ar_base(ars)
    # Get current image dimensions
    width, height = pil_image.size

    buckets = sorted(buckets, reverse=True)
    image_size = buckets[-1]

    for bucket in buckets:
        if width * height >= bucket * bucket:
            image_size = bucket
            break

    return center_crop_arr_with_ar(pil_image, image_size, ars, crop)


def center_crop_arr_with_ar(pil_image, image_size: int, ars: list[tuple[int, int]] = HW_ASPECT_RATIOS, crop=True):
    """
    Center crop the image to match the closest aspect ratio from the provided list.

    Args:
        pil_image: PIL Image to be cropped
        image_size: Target size for the smaller dimension
        ars: List of aspect ratios as (height, width) tuples

    Returns:
        PIL Image cropped to the closest aspect ratio
    """
    ar_base = get_ar_base(ars)
    assert image_size % ar_base == 0, f"image_size must be divisible by {ar_base}"

    # Get current image dimensions
    width, height = pil_image.size

    current_ar = height / width

    # Find the closest aspect ratio
    closest_ar_idx = np.argmin([abs(current_ar - (h / w)) for h, w in ars])
    target_h, target_w = ars[closest_ar_idx]

    if crop:
        target_h, target_w = round(image_size / ar_base * target_h), round(image_size / ar_base * target_w)

        # First, resize the image while maintaining aspect ratio so both dimensions reach at least the target size
        scale = max(target_h / height, target_w / width)
        new_height = round(height * scale)
        new_width = round(width * scale)
        pil_image = pil_image.resize((new_width, new_height), resample=PIL.Image.LANCZOS)

        arr = np.array(pil_image)
        # Then perform center crop to the target dimensions
        crop_y = (new_height - target_h) // 2
        crop_x = (new_width - target_w) // 2

        return PIL.Image.fromarray(arr[crop_y : crop_y + target_h, crop_x : crop_x + target_w])
    else:
        scale = image_size // ar_base
        return pil_image.resize((round(target_w * scale), round(target_h * scale)), resample=PIL.Image.LANCZOS)
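A minimal usage sketch for the bucketing helper above (the input path and the printed size are illustrative):

```python
import PIL.Image
from utils.aspect_ratio import center_crop_arr_with_buckets

# "photo.jpg" is a placeholder path.
img = PIL.Image.open("photo.jpg").convert("RGB")

# Picks the largest bucket (256/512/768/1024) whose area the image covers,
# then center-crops to the closest (h, w) ratio in HW_ASPECT_RATIOS.
img = center_crop_arr_with_buckets(img)
print(img.size)  # e.g. (1024, 1024) for a square photo larger than 1024x1024
```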
utils/compile_utils.py
ADDED
@@ -0,0 +1,122 @@
import contextlib
import functools
import os
from typing import Callable, Dict, Optional

import torch

from loguru import logger

"""
Usage:

1. Control through environment variable (at startup):
   export ENABLE_TORCH_COMPILE=true
   python your_script.py

2. Control through environment variable (disable):
   export ENABLE_TORCH_COMPILE=false  # or leave it unset
   python your_script.py

3. Dynamically control in code:
   compile_manager.set_compile_enabled(True)   # enable
   compile_manager.set_compile_enabled(False)  # disable

4. Select version at runtime:
   # use the version configured
   result = my_function(args)

   # force use the original version
   result = my_function.original(args)

   # force use the compiled version
   result = my_function.compiled(args)
"""

# Global configuration: control whether to enable compile through an environment variable.
# Compilation is disabled unless ENABLE_TORCH_COMPILE is set to "true".
ENABLE_TORCH_COMPILE = os.getenv("ENABLE_TORCH_COMPILE", "false").lower() == "true"


class CompileManager:
    """Global controller for torch.compile"""

    def __init__(self):
        self.compile_enabled = ENABLE_TORCH_COMPILE
        self.compiled_functions: Dict[str, Callable] = {}
        self.original_functions: Dict[str, Callable] = {}

    def set_compile_enabled(self, enabled: bool):
        """Dynamically set whether to enable compile"""
        self.compile_enabled = enabled

    def get_compile_status(self):
        """Get the current compile status"""
        return self.compile_enabled

    @contextlib.contextmanager
    def compile_disabled(self):
        """Temporarily disable compile within the context"""
        original_status = self.compile_enabled
        try:
            self.compile_enabled = False
            yield
        finally:
            self.compile_enabled = original_status


# global instance
compile_manager = CompileManager()


def smart_compile(func: Optional[Callable] = None, **compile_kwargs):
    """
    Smart compile decorator

    Args:
        func: The function to decorate
        **compile_kwargs: Other compile parameters, see https://pytorch.org/docs/stable/generated/torch.compile.html
    """

    def decorator(fn: Callable) -> Callable:
        # save the original function
        original_func = fn
        # Use qualified name to handle functions with same name in different classes
        # Include module name to handle functions with same name in different files
        func_name = f"{fn.__module__}.{fn.__qualname__}"
        compile_manager.original_functions[func_name] = original_func

        # if compile is disabled, return the original function
        if not compile_manager.compile_enabled:
            # add attributes to the original function for later access
            original_func.original = original_func
            original_func.compiled = original_func  # point to itself
            return original_func

        # create the compiled function
        try:
            compiled_func = torch.compile(original_func, **compile_kwargs)
            compile_manager.compiled_functions[func_name] = compiled_func
        except Exception as e:
            logger.warning(f"[WARNING] Failed to compile function {func_name}: {e}")
            # if compile fails, revert to the original function
            compiled_func = original_func

        @functools.wraps(original_func)
        def wrapper(*args, **kwargs):
            # check whether to use the compiled version at runtime
            if compile_manager.compile_enabled:
                return compiled_func(*args, **kwargs)
            else:
                return original_func(*args, **kwargs)

        # add attributes to the wrapper for later access
        wrapper.original = original_func
        wrapper.compiled = compiled_func

        return wrapper

    # support direct use of @smart_compile or @smart_compile(...)
    if func is not None:
        return decorator(func)
    return decorator
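Following the usage notes in the module docstring, a minimal sketch of the decorator (the `mlp` function and the `mode` argument are illustrative):

```python
import torch
from utils.compile_utils import smart_compile, compile_manager

# Compilation must be enabled before the decorator runs, e.g. via
# `export ENABLE_TORCH_COMPILE=true`, or this call placed before the definition.
compile_manager.set_compile_enabled(True)

@smart_compile(mode="reduce-overhead")  # extra kwargs are forwarded to torch.compile
def mlp(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x @ x.T)

y = mlp(torch.randn(8, 8))               # dispatches to the compiled version
y_ref = mlp.original(torch.randn(8, 8))  # force the eager version
```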
utils/image_utils.py
ADDED
@@ -0,0 +1,314 @@
import io
import os
from typing import Literal, TypeAlias

import numpy as np
import PIL.Image
import PIL.ImageOps
import requests
import torch

"""
- pil: `PIL.Image.Image`, size (w, h), `uint8` data
- np: `np.ndarray`, shape (h, w, c), default `np.uint8`
- pt: `torch.Tensor`, shape (c, h, w), default `torch.uint8`
"""
ImageType: TypeAlias = PIL.Image.Image | np.ndarray | torch.Tensor
ImageTypeStr: TypeAlias = Literal["pil", "np", "pt"]
ImageFormat: TypeAlias = Literal["JPEG", "PNG"]
DataFormat: TypeAlias = Literal["255", "01", "11"]


IMG_SUPPORT_MODE = ["L", "LA", "RGB", "RGBA", "CMYK", "P", "1"]
IMAGE_EXT_LOWER = ["png", "jpeg", "jpg", "webp"]
IMAGE_EXT = IMAGE_EXT_LOWER + [_ext.upper() for _ext in IMAGE_EXT_LOWER]


def check_image_type(image: ImageType):
    if not (isinstance(image, PIL.Image.Image) or isinstance(image, np.ndarray) or isinstance(image, torch.Tensor)):
        raise TypeError(f"`image` should be PIL Image, ndarray or Tensor. Got `{type(image)}`.")


def to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
    # Automatically adjust the orientation of the image to match the direction it was taken.
    image = PIL.ImageOps.exif_transpose(image)

    if image.mode not in IMG_SUPPORT_MODE:
        raise ValueError(f"Only support mode in `{IMG_SUPPORT_MODE}`, got `{image.mode}`")

    if image.mode == "LA":
        image = image.convert("RGBA")

    # add white background for RGBA images, and convert to RGB
    if image.mode == "RGBA":
        background = PIL.Image.new("RGBA", image.size, "white")
        image = PIL.Image.alpha_composite(background, image).convert("RGB")

    # then convert to RGB
    image = image.convert("RGB")

    return image


def load_image(
    image: str | os.PathLike | PIL.Image.Image | bytes,
    *,
    output_type: ImageTypeStr = "pil",
) -> ImageType:
    """
    Loads `image` to a PIL Image, NumPy array or PyTorch tensor.

    Args:
        image (str | PIL.Image.Image): The path to image or PIL Image.
        output_type (ImageTypeStr, optional): The type of the output image. Defaults to "pil".
            The current version supports "pil", "np", "pt".

    Returns:
        ImageType: The loaded image in the given type.
    """
    timeout = 10
    # Load the `image` into a PIL Image.
    if isinstance(image, str) or isinstance(image, os.PathLike):
        if image.startswith("http://") or image.startswith("https://"):
            try:
                image = PIL.Image.open(requests.get(image, stream=True, timeout=timeout).raw)
            except requests.exceptions.Timeout:
                raise ValueError(f"HTTP request timed out after {timeout} seconds")
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or url, URLs must start with `http://`, `https://` or `s3+[profile]://`, and `{image}` is not a valid path."
            )
    elif isinstance(image, PIL.Image.Image):
        image = image
    elif isinstance(image, bytes):
        image = PIL.Image.open(io.BytesIO(image))
    else:
        raise ValueError(f"`image` must be a path or PIL Image, got `{type(image)}`")

    image = to_rgb(image)

    if output_type == "pil":
        image = image
    elif output_type == "np":
        image = to_np(image)
    elif output_type == "pt":
        image = to_pt(image)
    else:
        raise ValueError(f"`output_type` must be one of `{ImageTypeStr}`, got `{output_type}`")

    return image


def to_pil(image: ImageType, image_mode: DataFormat | None = None) -> PIL.Image.Image:
    """
    Convert a NumPy array or a PyTorch tensor to a PIL image.
    """
    check_image_type(image)

    if isinstance(image, PIL.Image.Image):
        return image

    elif isinstance(image, np.ndarray):
        image = normalize_np(image, image_mode)

    elif isinstance(image, torch.Tensor):
        image = normalize_pt(image, image_mode)

        image = image.cpu().permute(1, 2, 0).numpy()
        assert image.dtype == np.uint8, f"Supposed to convert `torch.uint8` to `np.uint8`, but got `{image.dtype}`"

    mode_map = {1: "L", 3: "RGB"}
    mode = mode_map[image.shape[-1]]

    if image.shape[-1] == 1:
        image = image[:, :, 0]

    return PIL.Image.fromarray(image, mode=mode)


def to_np(image: ImageType, image_mode: DataFormat | None = None) -> np.ndarray:
    """
    Convert a PIL image or a PyTorch tensor to a NumPy array.
    """
    check_image_type(image)

    if isinstance(image, PIL.Image.Image):
        image = np.array(image, np.uint8, copy=True)

    if isinstance(image, np.ndarray):
        image = normalize_np(image, image_mode)

    elif isinstance(image, torch.Tensor):
        image = normalize_pt(image, image_mode)

        image = image.cpu().permute(1, 2, 0).numpy()
        assert image.dtype == np.uint8, f"Supposed to convert `torch.uint8` to `np.uint8`, but got `{image.dtype}`"

    return image


def to_pt(image: ImageType, image_mode: DataFormat | None = None) -> torch.Tensor:
    """
    Convert a PIL image or a NumPy array to a PyTorch tensor.
    """
    check_image_type(image)

    if isinstance(image, torch.Tensor):
        image = normalize_pt(image, image_mode)
        return image

    # convert PIL Image to NumPy array
    if isinstance(image, PIL.Image.Image):
        image = np.array(image, np.uint8, copy=True)

    image = normalize_np(image, image_mode)

    image = torch.from_numpy(image.transpose((2, 0, 1))).contiguous()
    assert image.dtype == torch.uint8, f"Supposed to convert `np.uint8` to `torch.uint8`, but got `{image.dtype}`"
    return image


def normalize_np(image: np.ndarray, image_mode: DataFormat | None = None) -> np.ndarray:
    """
    Normalize a NumPy array to the standard format of shape (h, w, c) and uint8.
    """
    if image.ndim not in {2, 3}:
        raise ValueError(f"`image` should be 2 or 3 dimensions. Got {image.ndim} dimensions.")

    elif image.ndim == 2:
        # if 2D image, add channel dimension (HWC)
        image = np.expand_dims(image, 2)

    if image.shape[-1] not in {1, 3}:
        raise ValueError(f"`image` should have 1 (`L`) or 3 (`RGB`) channels. Got {image.shape[-1]} channels.")

    image = to_dataformat(image, image_mode=image_mode, mode="255")

    return image


def normalize_pt(image: torch.Tensor, image_mode: DataFormat | None = None) -> torch.Tensor:
    """
    Normalize a PyTorch tensor to the standard format of shape (c, h, w) and uint8.
    """
    if image.ndimension() not in {2, 3}:
        raise ValueError(f"`image` should be 2 or 3 dimensions. Got {image.ndimension()} dimensions.")

    elif image.ndimension() == 2:
        # if 2D image, add channel dimension (CHW)
        image = image.unsqueeze(0)

    # check number of channels
    if image.shape[-3] not in {1, 3}:
        raise ValueError(f"`image` should have 1 (`L`) or 3 (`RGB`) channels. Got {image.shape[-3]} channels.")

    image = to_dataformat(image, image_mode=image_mode, mode="255")

    return image


def to_dataformat(
    image: ImageType,
    *,
    image_mode: DataFormat | None = None,
    mode: DataFormat = "255",
) -> np.ndarray | torch.Tensor:
    check_image_type(image)

    # convert PIL Image to NumPy array
    if isinstance(image, PIL.Image.Image):
        image = np.array(image, np.uint8, copy=True)
        image_mode = "255"

    # guess image mode
    if image.dtype == np.uint8 or image.dtype == torch.uint8:
        guess_image_mode = "255"
    elif image.dtype == np.float32 or image.dtype == np.float16 or image.dtype == torch.float32 or image.dtype == torch.float16:
        if image.min() < 0.0:
            guess_image_mode = "11"
        else:
            guess_image_mode = "01"
    else:
        raise ValueError(f"Unsupported dtype `{image.dtype}`")

    if image_mode is None:
        image_mode = guess_image_mode
    else:
        if guess_image_mode != image_mode:
            print(f"Guess image mode is `{guess_image_mode}`, but image mode is `{image_mode}`")

    if isinstance(image, np.ndarray):
        if image_mode == "255" and mode != "255":
            np.clip((image.astype(np.float32) / 255), 0, 1, out=image)
            if mode == "11":
                np.clip((image * 2 - 1), -1, 1, out=image)

        elif image_mode == "01" and mode != "01":
            if mode == "255":
                np.clip(image, 0, 1, out=image)
                image = (image * 255).round().astype(np.uint8)
            elif mode == "11":
                np.clip((image * 2 - 1), -1, 1, out=image)

        elif image_mode == "11" and mode != "11":
            np.clip((image / 2 + 0.5), 0, 1, out=image)
            if mode == "255":
                image = (image * 255).round().astype(np.uint8)

    elif isinstance(image, torch.Tensor):
        if image_mode == "255" and mode != "255":
            image = image.to(dtype=torch.float32).div(255).clamp(0, 1)
            if mode == "11":
                image = (image * 2 - 1).clamp(-1, 1)

        elif image_mode == "01" and mode != "01":
            if mode == "255":
                image = image.clamp(0, 1)
                image = (image * 255).round().to(dtype=torch.uint8)
            elif mode == "11":
                image = (image * 2 - 1).clamp(-1, 1)

        elif image_mode == "11" and mode != "11":
            image = (image / 2 + 0.5).clamp(0, 1)
            if mode == "255":
                image = image.mul(255).round().to(dtype=torch.uint8)

    return image


def resize_image(pil_image, image_size):
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=PIL.Image.BOX)

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=PIL.Image.BICUBIC)
    return pil_image


def center_crop_arr(pil_image, image_size, crop=True):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    if crop:
        pil_image = resize_image(pil_image, image_size)
        arr = np.array(pil_image)
        crop_y = (arr.shape[0] - image_size) // 2
        crop_x = (arr.shape[1] - image_size) // 2
        return PIL.Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
    else:
        # pad the image to a square
        width, height = pil_image.size
        if width != height:
            # create a square canvas whose side equals the longer edge
            max_dim = max(width, height)
            padded_img = PIL.Image.new(pil_image.mode, (max_dim, max_dim), (0, 0, 0))
            # paste the original image centered on the square canvas
            padded_img.paste(pil_image, ((max_dim - width) // 2, (max_dim - height) // 2))
            pil_image = padded_img
        pil_image = resize_image(pil_image, image_size)
        return pil_image
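A minimal usage sketch for the loaders and converters above (the file path is a placeholder):

```python
from utils.image_utils import load_image, to_dataformat

pil_img = load_image("example.png")                   # RGB PIL.Image
pt_img = load_image("example.png", output_type="pt")  # uint8 tensor, shape (3, H, W)

# Convert the uint8 tensor to the [-1, 1] float range commonly used as model input.
x = to_dataformat(pt_img, image_mode="255", mode="11")
print(x.min() >= -1 and x.max() <= 1)  # True
```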
utils/misc.py
ADDED
@@ -0,0 +1,51 @@
import os
import numpy as np
import random

import torch


def set_seed(seed: int, rank: int = 0):
    random.seed(seed + rank)
    np.random.seed(seed + rank)
    torch.manual_seed(seed + rank)
    torch.cuda.manual_seed_all(seed + rank)
    torch.backends.cudnn.deterministic = True
    os.environ["PYTHONHASHSEED"] = str(seed + rank)


class LargeInt(int):
    def __new__(cls, value):
        if isinstance(value, str):
            units = {"K": 1e3, "M": 1e6, "B": 1e9, "T": 1e12}
            last_char = value[-1].upper()
            if last_char in units:
                num = float(value[:-1]) * units[last_char]
                return super(LargeInt, cls).__new__(cls, int(num))
            else:
                return super(LargeInt, cls).__new__(cls, int(value))
        else:
            return super(LargeInt, cls).__new__(cls, value)

    def __str__(self):
        value = int(self)
        if abs(value) < 1000:
            return f"{value}"
        for unit in ["", "K", "M", "B", "T"]:
            if abs(value) < 1000:
                return f"{value:.1f}{unit}"
            value /= 1000
        return f"{value:.1f}P"  # P stands for Peta, or 10^15

    def __repr__(self):
        return f'"{self.__str__()}"'  # Ensure repr also returns the string with quotes

    def __json__(self):
        return f'"{self.__str__()}"'

    def __add__(self, other):
        if isinstance(other, int):
            return LargeInt(super().__add__(other))
        return NotImplemented

    def __radd__(self, other):
        return self.__add__(other)  # This ensures commutativity
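A quick sketch of the helpers above:

```python
from utils.misc import LargeInt, set_seed

set_seed(42)                    # seeds python, numpy and torch together

n_params = LargeInt("14B")      # parsed as 14_000_000_000
print(int(n_params))            # 14000000000
print(LargeInt(157_000_000))    # 157.0M
```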
utils/model_utils.py
ADDED
@@ -0,0 +1,128 @@
import torch
import numpy as np


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32) / pe_interpolation
    grid_w = np.arange(grid_size, dtype=np.float32) / pe_interpolation
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def expand_t(t, x):
    """Function to reshape time t to broadcastable dimension of x

    Args:
        t: [bsz,], time vector
        x: [bsz,...], data point
    """
    dims = [1] * (len(x.size()) - 1)
    t = t.view(t.size(0), *dims)
    return t


def randn_tensor(shape, noise_repeat, device, dtype=torch.float32):
    bsz = shape[0]
    if bsz % noise_repeat != 0:
        raise ValueError(f"Batch size ({bsz}) must be divisible by noise repeat ({noise_repeat})")
    _shape = (noise_repeat,) + shape[1:]
    _tensor = torch.randn(_shape, device=device, dtype=dtype).repeat(bsz // noise_repeat, 1)
    return _tensor


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def identity(input: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    return input


def rms_norm(
    input: torch.Tensor,
    normalized_shape: torch.Size,
    eps: float = 1e-6,
) -> torch.Tensor:
    dtype = input.dtype
    input = input.to(torch.float32)
    variance = input.flatten(-len(normalized_shape)).pow(2).mean(dim=-1)[(...,) + (None,) * len(normalized_shape)]
    input = input * torch.rsqrt(variance + eps)
    return input.to(dtype)


def layer_norm(
    input: torch.Tensor,
    normalized_shape: torch.Size,
    eps: float = 1e-6,
) -> torch.Tensor:
    dtype = input.dtype
    input = input.to(torch.float32)
    mean = input.flatten(-len(normalized_shape)).mean(dim=-1)[(...,) + (None,) * len(normalized_shape)]
    variance = (input - mean).flatten(-len(normalized_shape)).pow(2).mean(dim=-1)[(...,) + (None,) * len(normalized_shape)]
    input = (input - mean) * torch.rsqrt(variance + eps)
    return input.to(dtype)
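A small sanity-check sketch for the positional-embedding helper above (the grid size and embedding width are illustrative):

```python
import numpy as np
from utils.model_utils import get_2d_sincos_pos_embed

# 2D sin-cos table for a 16x16 token grid with 64-dim embeddings.
pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=16)
print(pos_embed.shape)                 # (256, 64)
print(np.abs(pos_embed).max() <= 1.0)  # True: entries are sines and cosines
```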
vae/checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:99293255229a29297e2851858db3794497d1b0b09b20c308c1062636ea4bcdd9
size 335365010
vae/config.json
ADDED
@@ -0,0 +1,14 @@
{
  "resolution": 256,
  "in_channels": 3,
  "ch": 128,
  "out_ch": 3,
  "ch_mult": [1, 2, 4, 4],
  "num_res_blocks": 2,
  "z_channels": 16,
  "shift_factor": 0,
  "scaling_factor": 1,
  "deterministic": true,
  "encoder_norm": true,
  "psz": 1
}
vae/nextstep_ae.py
ADDED
@@ -0,0 +1,494 @@
import os
import json
import inspect
from dataclasses import dataclass, field, asdict
from loguru import logger
from omegaconf import OmegaConf
from tabulate import tabulate
from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.checkpoint import checkpoint

from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput

from utils.misc import LargeInt
from utils.model_utils import randn_tensor
from utils.compile_utils import smart_compile


@dataclass
class AutoEncoderParams:
    resolution: int = 256
    in_channels: int = 3
    ch: int = 128
    out_ch: int = 3
    ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
    num_res_blocks: int = 2
    z_channels: int = 16
    scaling_factor: float = 0.3611
    shift_factor: float = 0.1159
    deterministic: bool = False
    encoder_norm: bool = False
    psz: int | None = None


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

        self.grad_checkpointing = False

    @smart_compile()
    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                block_fn = self.down[i_level].block[i_block]
                if self.grad_checkpointing:
                    h = checkpoint(block_fn, hs[-1])
                else:
                    h = block_fn(hs[-1])
                if len(self.down[i_level].attn) > 0:
                    attn_fn = self.down[i_level].attn[i_block]
                    if self.grad_checkpointing:
                        h = checkpoint(attn_fn, h)
                    else:
                        h = attn_fn(h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

        self.grad_checkpointing = False

    @smart_compile()
    def forward(self, z: Tensor) -> Tensor:
        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                block_fn = self.up[i_level].block[i_block]
                if self.grad_checkpointing:
                    h = checkpoint(block_fn, h)
                else:
                    h = block_fn(h)
                if len(self.up[i_level].attn) > 0:
                    attn_fn = self.up[i_level].attn[i_block]
                    if self.grad_checkpointing:
                        h = checkpoint(attn_fn, h)
                    else:
                        h = attn_fn(h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


def layer_norm_2d(input: torch.Tensor, normalized_shape: torch.Size, eps: float = 1e-6) -> torch.Tensor:
    # input.shape = (bsz, c, h, w)
    _input = input.permute(0, 2, 3, 1)
    _input = F.layer_norm(_input, normalized_shape, None, None, eps)
    _input = _input.permute(0, 3, 1, 2)
    return _input


class AutoencoderKL(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.config = params
        self.config = OmegaConf.create(asdict(self.config))
        self.config.latent_channels = params.z_channels
        self.config.block_out_channels = params.ch_mult

        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )

        self.encoder_norm = params.encoder_norm
        self.psz = params.psz

        self.apply(self._init_weights)

    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.GroupNorm):
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    def gradient_checkpointing_enable(self):
        self.encoder.grad_checkpointing = True
        self.decoder.grad_checkpointing = True

    @property
    def dtype(self):
        return self.encoder.conv_in.weight.dtype

    @property
    def device(self):
        return self.encoder.conv_in.weight.device

    @property
    def trainable_params(self) -> float:
        n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return LargeInt(n_params)

    @property
    def params_info(self) -> str:
        encoder_params = str(LargeInt(sum(p.numel() for p in self.encoder.parameters())))
        decoder_params = str(LargeInt(sum(p.numel() for p in self.decoder.parameters())))
        table = [["encoder", encoder_params], ["decoder", decoder_params]]
        return tabulate(table, headers=["Module", "Params"], tablefmt="grid")

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def patchify(self, img: torch.Tensor):
        """
        img: (bsz, C, H, W)
        x: (bsz, patch_size**2 * C, H / patch_size, W / patch_size)
        """
        bsz, c, h, w = img.shape
        p = self.psz
        h_, w_ = h // p, w // p

        img = img.reshape(bsz, c, h_, p, w_, p)
        img = torch.einsum("nchpwq->ncpqhw", img)
        x = img.reshape(bsz, c * p**2, h_, w_)
        return x

    def unpatchify(self, x: torch.Tensor):
        """
        x: (bsz, patch_size**2 * C, H / patch_size, W / patch_size)
        img: (bsz, C, H, W)
        """
        bsz = x.shape[0]
        p = self.psz
        c = self.config.latent_channels
        h_, w_ = x.shape[2], x.shape[3]

        x = x.reshape(bsz, c, p, p, h_, w_)
        x = torch.einsum("ncpqhw->nchpwq", x)
        img = x.reshape(bsz, c, h_ * p, w_ * p)
        return img

    def encode(self, x: torch.Tensor, return_dict: bool = True):
        moments = self.encoder(x)

        mean, logvar = torch.chunk(moments, 2, dim=1)
        if self.psz is not None:
            mean = self.patchify(mean)

        if self.encoder_norm:
            mean = layer_norm_2d(mean, mean.size()[-1:])

        if self.psz is not None:
            mean = self.unpatchify(mean)

        moments = torch.cat([mean, logvar], dim=1).contiguous()

        posterior = DiagonalGaussianDistribution(moments, deterministic=self.params.deterministic)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z: torch.Tensor, return_dict: bool = True):
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(self, input, sample_posterior=True, noise_strength=0.0):
        posterior = self.encode(input).latent_dist
        z = posterior.sample() if sample_posterior else posterior.mode()
        if noise_strength > 0.0:
            p = torch.distributions.Uniform(0, noise_strength)
            z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor(
                z.shape, device=z.device, dtype=z.dtype
            )
        dec = self.decode(z).sample
        return dec, posterior

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        config_path = os.path.join(model_path, "config.json")
        ckpt_path = os.path.join(model_path, "checkpoint.pt")

        if not os.path.isdir(model_path) or not os.path.isfile(config_path) or not os.path.isfile(ckpt_path):
            raise ValueError(
                f"Invalid model path: {model_path}. The path should contain both config.json and checkpoint.pt files."
            )

        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)

        with open(config_path, "r") as f:
            config: dict = json.load(f)
        config.update(kwargs)
        kwargs = config

        # Filter out kwargs that are not in AutoEncoderParams
        # This ensures we only pass parameters that the model can accept
        valid_kwargs = {}
        param_signature = inspect.signature(AutoEncoderParams.__init__).parameters
        for key, value in kwargs.items():
            if key in param_signature:
                valid_kwargs[key] = value
            else:
                logger.info(f"Ignoring parameter '{key}' as it's not defined in AutoEncoderParams")

        params = AutoEncoderParams(**valid_kwargs)
        model = cls(params)
        try:
            msg = model.load_state_dict(state_dict, strict=False)
            logger.info(f"Loaded state_dict from {ckpt_path}")
            logger.info(f"Missing keys:\n{msg.missing_keys}")
            logger.info(f"Unexpected keys:\n{msg.unexpected_keys}")
        except Exception as e:
            logger.error(e)
            logger.warning(f"Failed to load state_dict from {ckpt_path}, using random initialization")
        return model
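A minimal sketch of loading the autoencoder with the `vae/config.json` and `vae/checkpoint.pt` shipped in this repo (assumes the working directory is the snapshot root; the random input stands in for a real image in [-1, 1]):

```python
import torch
from vae.nextstep_ae import AutoencoderKL

vae = AutoencoderKL.from_pretrained("vae").eval()

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)        # placeholder image batch in [-1, 1]
    posterior = vae.encode(x).latent_dist
    z = posterior.mode()                   # (1, 16, 32, 32) for this config (8x downsampling)
    recon = vae.decode(z).sample           # (1, 3, 256, 256)
```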
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff