Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The dummy LoRA state dicts in this repository were generated with the following script:
|
2 |
+
|
3 |
+
```py
|
4 |
+
from diffusers import DiffusionPipeline
|
5 |
+
import torch
|
6 |
+
from peft import LoraConfig
|
7 |
+
from peft.utils import get_peft_model_state_dict
|
8 |
+
from huggingface_hub import create_repo, upload_file
|
9 |
+
import tempfile
|
10 |
+
import os
|
11 |
+
|
12 |
+
|
13 |
+
# Base checkpoints to generate dummy LoRA state dicts for. The first two are
# UNet-based pipelines; FLUX.1-dev uses a transformer denoiser instead (the
# generation function branches on this).
ckpts = [
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "black-forest-labs/FLUX.1-dev"
]

# LoRA ranks to sweep over for every checkpoint.
ranks = [16, 32, 128]

# Hub repository the generated state dicts are uploaded to; `exist_ok=True`
# makes the script re-runnable without failing on an existing repo.
repo_id = create_repo(repo_id="sayakpaul/dummy-lora-state-dicts", exist_ok=True).repo_id
|
22 |
+
|
23 |
+
def get_lora_config(rank=16):
    """Build a ``LoraConfig`` targeting the attention projection layers.

    ``lora_alpha`` mirrors ``r`` for every rank in the sweep, and the LoRA
    matrices are initialized from a Gaussian rather than zeros.
    """
    config_kwargs = {
        "r": rank,
        "lora_alpha": rank,
        "init_lora_weights": "gaussian",
        "target_modules": ["to_k", "to_v", "to_q", "to_out.0"],
    }
    return LoraConfig(**config_kwargs)
|
30 |
+
|
31 |
+
|
32 |
+
def load_pipeline_and_obtain_lora(ckpt, rank):
    """Attach a freshly initialized LoRA adapter to ``ckpt`` and upload it.

    Loads the pipeline in bfloat16, injects a rank-``rank`` adapter into its
    denoiser (``unet`` when the pipeline has one, otherwise ``transformer``),
    serializes the adapter weights via the pipeline class's
    ``save_lora_weights``, and pushes the resulting safetensors file to the
    module-level ``repo_id``.
    """
    pipe = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.bfloat16)

    # Filename encodes both the rank and the base model,
    # e.g. "r@16-FLUX.1-dev.safetensors".
    filename = f"r@{rank}-{ckpt.split('/')[-1]}.safetensors"
    adapter_config = get_lora_config(rank=rank)

    with tempfile.TemporaryDirectory() as workdir:
        serialization_kwargs = {"weight_name": filename}
        if hasattr(pipe, "unet"):
            pipe.unet.add_adapter(adapter_config)
            serialization_kwargs["unet_lora_layers"] = get_peft_model_state_dict(pipe.unet)
        else:
            pipe.transformer.add_adapter(adapter_config)
            serialization_kwargs["transformer_lora_layers"] = get_peft_model_state_dict(pipe.transformer)

        # ``save_lora_weights`` is invoked on the pipeline class itself, so the
        # pipeline-specific serialization format is picked up automatically.
        type(pipe).save_lora_weights(save_directory=workdir, **serialization_kwargs)
        upload_file(
            repo_id=repo_id,
            path_or_fileobj=os.path.join(workdir, filename),
            path_in_repo=filename,
        )
|
50 |
+
|
51 |
+
|
52 |
+
# Generate and upload one dummy state dict per (checkpoint, rank) combination.
for checkpoint_id in ckpts:
    for lora_rank in ranks:
        load_pipeline_and_obtain_lora(ckpt=checkpoint_id, rank=lora_rank)
|
55 |
+
```
|