sayakpaul HF staff commited on
Commit
6d81182
·
verified ·
1 Parent(s): 9f9c297

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -0
README.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ State dicts generated with:
2
+
3
+ ```py
4
+ from diffusers import DiffusionPipeline
5
+ import torch
6
+ from peft import LoraConfig
7
+ from peft.utils import get_peft_model_state_dict
8
+ from huggingface_hub import create_repo, upload_file
9
+ import tempfile
10
+ import os
11
+
12
+
13
+ ckpts = [
14
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
15
+ "stabilityai/stable-diffusion-xl-base-1.0",
16
+ "black-forest-labs/FLUX.1-dev"
17
+ ]
18
+
19
+ ranks = [16, 32, 128]
20
+
21
+ repo_id = create_repo(repo_id="sayakpaul/dummy-lora-state-dicts", exist_ok=True).repo_id
22
+
23
+ def get_lora_config(rank=16):
24
+ return LoraConfig(
25
+ r=rank,
26
+ lora_alpha=rank,
27
+ init_lora_weights="gaussian",
28
+ target_modules=["to_k", "to_v", "to_q", "to_out.0"],
29
+ )
30
+
31
+
32
+ def load_pipeline_and_obtain_lora(ckpt, rank):
33
+ pipeline = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
34
+ pipeline_cls = pipeline.__class__
35
+
36
+ lora_config = get_lora_config(rank=rank)
37
+ weight_name = f"r@{rank}-{ckpt.split('/')[-1]}.safetensors"
38
+
39
+ with tempfile.TemporaryDirectory() as tmpdir:
40
+ save_kwargs = {"weight_name": weight_name}
41
+ if hasattr(pipeline, "unet"):
42
+ pipeline.unet.add_adapter(lora_config)
43
+ save_kwargs.update({"unet_lora_layers": get_peft_model_state_dict(pipeline.unet)})
44
+ else:
45
+ pipeline.transformer.add_adapter(lora_config)
46
+ save_kwargs.update({"transformer_lora_layers": get_peft_model_state_dict(pipeline.transformer)})
47
+
48
+ pipeline_cls.save_lora_weights(save_directory=tmpdir, **save_kwargs)
49
+ upload_file(repo_id=repo_id, path_or_fileobj=os.path.join(tmpdir, weight_name), path_in_repo=weight_name)
50
+
51
+
52
+ for ckpt in ckpts:
53
+ for rank in ranks:
54
+ load_pipeline_and_obtain_lora(ckpt=ckpt, rank=rank)
55
+ ```