Update README.md
Browse files
README.md
CHANGED
@@ -111,27 +111,29 @@ It allows **triplet data creation** even when only partial information (e.g., ju
|
|
111 |
## ⚡ Quick Example: Generate an Image from a Single Caption
|
112 |
|
113 |
```python
|
|
|
114 |
import torch
|
115 |
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DPMSolverMultistepScheduler
|
116 |
from transformers import CLIPTextModel, CLIPTokenizer
|
117 |
from safetensors.torch import load_file as safe_load
|
|
|
118 |
from PIL import Image
|
119 |
-
import os
|
120 |
|
121 |
-
#
|
122 |
-
|
|
|
123 |
|
124 |
-
# Initialize UNet and load weights
|
125 |
base_unet = UNet2DConditionModel.from_pretrained(
|
126 |
"stabilityai/stable-diffusion-2-1-base",
|
127 |
subfolder="unet",
|
128 |
torch_dtype=torch.float16
|
129 |
)
|
|
|
|
|
130 |
state_dict = safe_load(checkpoint_path)
|
131 |
base_unet.load_state_dict(state_dict)
|
132 |
unet = base_unet
|
133 |
|
134 |
-
# Load VAE, text encoder, tokenizer, scheduler
|
135 |
vae = AutoencoderKL.from_pretrained(
|
136 |
"stabilityai/stable-diffusion-2-1-base",
|
137 |
subfolder="vae",
|
@@ -151,31 +153,44 @@ scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
|
151 |
subfolder="scheduler"
|
152 |
)
|
153 |
|
154 |
-
|
|
|
|
|
|
|
155 |
pipe = StableDiffusionPipeline(
|
156 |
unet=unet,
|
157 |
vae=vae,
|
158 |
text_encoder=text_encoder,
|
159 |
tokenizer=tokenizer,
|
160 |
scheduler=scheduler,
|
161 |
-
safety_checker=
|
162 |
-
feature_extractor=
|
163 |
-
)
|
|
|
|
|
164 |
|
165 |
-
|
166 |
-
|
|
|
|
|
167 |
|
168 |
-
# Generate the image
|
169 |
-
result = pipe(prompt, num_inference_steps=50)
|
170 |
image = result.images[0]
|
171 |
|
172 |
-
|
173 |
-
output_dir = "./generated_images"
|
174 |
os.makedirs(output_dir, exist_ok=True)
|
175 |
-
output_path = os.path.join(output_dir, "
|
176 |
image.save(output_path)
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
```
|
181 |
|
|
|
111 |
## ⚡ Quick Example: Generate an Image from a Single Caption
|
112 |
|
113 |
```python
|
114 |
+
import os
|
115 |
import torch
|
116 |
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DPMSolverMultistepScheduler
|
117 |
from transformers import CLIPTextModel, CLIPTokenizer
|
118 |
from safetensors.torch import load_file as safe_load
|
119 |
+
import matplotlib.pyplot as plt
|
120 |
from PIL import Image
|
|
|
121 |
|
122 |
+
# Checkpoint fine-tuned UNet
|
123 |
+
checkpoint_dir = "/your/path"
|
124 |
+
checkpoint_path = os.path.join(checkpoint_dir, "model.safetensors")
|
125 |
|
|
|
126 |
base_unet = UNet2DConditionModel.from_pretrained(
|
127 |
"stabilityai/stable-diffusion-2-1-base",
|
128 |
subfolder="unet",
|
129 |
torch_dtype=torch.float16
|
130 |
)
|
131 |
+
|
132 |
+
# Fine-tuned weights
|
133 |
state_dict = safe_load(checkpoint_path)
|
134 |
base_unet.load_state_dict(state_dict)
|
135 |
unet = base_unet
|
136 |
|
|
|
137 |
vae = AutoencoderKL.from_pretrained(
|
138 |
"stabilityai/stable-diffusion-2-1-base",
|
139 |
subfolder="vae",
|
|
|
153 |
subfolder="scheduler"
|
154 |
)
|
155 |
|
156 |
+
safety_checker = None
|
157 |
+
feature_extractor = None
|
158 |
+
|
159 |
+
# Stable Diffusion pipeline
|
160 |
pipe = StableDiffusionPipeline(
|
161 |
unet=unet,
|
162 |
vae=vae,
|
163 |
text_encoder=text_encoder,
|
164 |
tokenizer=tokenizer,
|
165 |
scheduler=scheduler,
|
166 |
+
safety_checker=safety_checker,
|
167 |
+
feature_extractor=feature_extractor
|
168 |
+
)
|
169 |
+
|
170 |
+
pipe = pipe.to("cuda")
|
171 |
|
172 |
+
prompt = "A coastal city with large harbors and residential areas"
|
173 |
+
|
174 |
+
with torch.cuda.amp.autocast():
|
175 |
+
result = pipe(prompt, num_inference_steps=100, guidance_scale=7.5)
|
176 |
|
|
|
|
|
177 |
image = result.images[0]
|
178 |
|
179 |
+
output_dir = "/your/save/path"
|
|
|
180 |
os.makedirs(output_dir, exist_ok=True)
|
181 |
+
output_path = os.path.join(output_dir, "single_prompt_generated.png")
|
182 |
image.save(output_path)
|
183 |
+
print(f"✅ The image generated and saved: {output_path}")
|
184 |
+
|
185 |
+
# 8. Matplotlib ile görselleştir
|
186 |
+
if os.path.exists(output_path):
|
187 |
+
img = Image.open(output_path)
|
188 |
+
plt.figure(figsize=(8, 8))
|
189 |
+
plt.imshow(img)
|
190 |
+
plt.axis("off")
|
191 |
+
plt.show()
|
192 |
+
else:
|
193 |
+
print(f"The file could not find: {output_path}")
|
194 |
|
195 |
```
|
196 |
|