# 🚀 Stable Diffusion with Transformers (Advanced Training)

This project demonstrates how to train a Stable Diffusion-like model on an image dataset with Transformer-based denoising. The implementation leverages PyTorch, Hugging Face Diffusers, and Transformers.
## 📌 Overview
Stable Diffusion is a Latent Diffusion Model (LDM) that generates images by:
- Encoding images into a latent space using a VAE (Variational Autoencoder).
- Adding Gaussian noise to the latents across multiple time steps.
- Training a denoising Transformer/UNet to remove noise step by step.
- Using a text encoder (CLIP) for prompt conditioning.
- Decoding the cleaned latents back to an image.
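The noise-adding step has a convenient closed form: for clean latents x₀, Gaussian noise ε, and the scheduler's cumulative signal rate ᾱ_t, the noisy latents are x_t = √ᾱ_t · x₀ + √(1 − ᾱ_t) · ε. Below is a minimal sketch (assuming a default `DDPMScheduler` from diffusers and dummy latent shapes) of what the scheduler's `add_noise` helper computes:

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

latents = torch.randn(1, 4, 32, 32)   # stand-in for VAE-encoded latents
noise = torch.randn_like(latents)     # epsilon ~ N(0, I)
t = torch.tensor([500])               # a sampled timestep

# Closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
alpha_bar_t = scheduler.alphas_cumprod[t].view(-1, 1, 1, 1)
noisy_manual = alpha_bar_t.sqrt() * latents + (1 - alpha_bar_t).sqrt() * noise

# The library helper used later in the training loop computes the same thing
noisy_api = scheduler.add_noise(latents, noise, t)
assert torch.allclose(noisy_manual, noisy_api, atol=1e-5)
```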
## 🔬 Architecture
```mermaid
graph TD;
    A[Input Image] -->|VAE Encoder| B[Latent Space];
    B -->|Add Noise| C[Noisy Latents];
    C -->|Transformer / UNet Denoiser| D[Clean Latents];
    D -->|VAE Decoder| E[Output Image];
    F[Text Prompt] -->|CLIP Encoder| C;
```
- **VAE** → compresses the image into a latent space
- **Transformer/UNet** → learns to denoise the latents
- **Text Encoder (CLIP)** → aligns text and image
- **Noise Scheduler** → controls the forward and reverse diffusion
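For reference, the forward (noising) and reverse (denoising) processes that the scheduler controls follow the standard DDPM formulation (β_t is the variance schedule, c the text conditioning):

```latex
q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t \mathbf{I}\right)
\qquad
p_\theta(x_{t-1} \mid x_t, c) = \mathcal{N}\!\left(x_{t-1};\ \mu_\theta(x_t, t, c),\ \Sigma_t\right)
```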
## 📂 Dataset
Images should be resized to 256×256 and normalized to [-1, 1]. Optionally, provide text captions for conditioning.

Example directory layout:
```
data/
├── class1/
│   ├── img1.png
│   └── img2.jpg
├── class2/
│   ├── img3.png
│   └── img4.jpg
```
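If you want real captions for conditioning (rather than the fixed placeholder prompt used in the training script below), a paired image/caption dataset can be sketched as follows. This is a minimal sketch: the `captions.json` file and its `{"relative/path.png": "caption"}` layout are assumptions for illustration, not part of this repo.

```python
import json
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class CaptionedImageDataset(Dataset):
    """Yields (image, caption) pairs from a captions.json mapping path -> caption (assumed layout)."""

    def __init__(self, root, captions_file, transform=None):
        self.root = Path(root)
        self.transform = transform
        with open(captions_file) as f:
            self.captions = json.load(f)          # e.g. {"class1/img1.png": "a red car", ...}
        self.files = sorted(self.captions.keys())

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        rel_path = self.files[idx]
        image = Image.open(self.root / rel_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.captions[rel_path]
```

The training loop below would then tokenize the caption batch returned by the DataLoader instead of a fixed prompt.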
## ⚙️ Training Algorithm
The training process for Stable Diffusion with Transformers follows these steps:

- **Encode Images** → Pass input images through the VAE encoder to obtain latent representations.
- **Sample Noise & Timestep** → Randomly sample Gaussian noise and a timestep `t`.
- **Add Noise** → Corrupt the latent vectors with noise according to timestep `t`.
- **Text Conditioning** → Encode text prompts using CLIP (or another Transformer text encoder).
- **Noise Prediction** → Feed the noisy latents + text embeddings into the Transformer/UNet to predict the added noise.
- **Compute Loss** → Calculate the Mean Squared Error (MSE) between the predicted noise and the true noise.
- **Backpropagation** → Update the model weights using gradient descent.
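Written out, the loss in the last two steps is the standard ε-prediction objective over latents z (with text conditioning c and noise ε):

```latex
\mathcal{L} = \mathbb{E}_{z_0,\ \epsilon \sim \mathcal{N}(0,\mathbf{I}),\ t,\ c}
\left[\, \big\lVert \epsilon - \epsilon_\theta(z_t,\ t,\ c) \big\rVert_2^2 \,\right],
\qquad
z_t = \sqrt{\bar{\alpha}_t}\, z_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon
```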
```mermaid
flowchart TD
    A[Image] -->|VAE Encoder| B[Latent Space]
    B -->|Add Noise + t| C[Noisy Latents]
    D[Text Prompt] -->|CLIP Encoder| C
    C -->|Transformer / UNet| E[Predicted Noise]
    E -->|MSE Loss| F[Training Update]
```
## 🧑‍💻 Example Training Code
```python
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Dataset: resize to 256x256 and normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
dataset = datasets.ImageFolder("path_to_images", transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Components
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
# The SD v1.x UNet expects 768-dim text embeddings, i.e. CLIP ViT-L/14
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
scheduler = DDPMScheduler(num_train_timesteps=1000)

device = "cuda" if torch.cuda.is_available() else "cpu"
vae, unet, text_encoder = vae.to(device), unet.to(device), text_encoder.to(device)

# Only the UNet is trained; the VAE and text encoder stay frozen
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

# Training loop
for epoch in range(10):
    for images, _ in dataloader:
        images = images.to(device)

        # Encode images into latents (0.18215 is the SD latent scaling factor)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * 0.18215

        # Sample noise and a random timestep per image, then corrupt the latents
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=device
        )
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # Text conditioning (fixed placeholder prompt, one per image in the batch)
        text_inputs = tokenizer(
            ["a photo"] * images.shape[0],
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            text_embeds = text_encoder(text_inputs.input_ids).last_hidden_state

        # Predict the added noise and regress it against the true noise
        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeds).sample
        loss = nn.MSELoss()(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} | Loss: {loss.item():.4f}")
```
## 💾 Saving & Inference
Save the trained UNet weights:

```python
torch.save(unet.state_dict(), "unet_trained.pth")
```
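Alternatively (not in the original script, just a Diffusers-native option), the UNet can be saved and reloaded together with its config:

```python
# Save in the Diffusers format (config + weights)
unet.save_pretrained("unet_trained")

# Reload later
from diffusers import UNet2DConditionModel
unet = UNet2DConditionModel.from_pretrained("unet_trained")
```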
Inference pipeline:

1. Sample a random latent.
2. Iteratively denoise it with the scheduler + trained UNet.
3. Decode the clean latent with the VAE → image.
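A minimal sketch of those three steps, reusing the `unet`, `vae`, `tokenizer`, `text_encoder`, and `scheduler` objects from the training script above (the 50 denoising steps and the placeholder prompt are arbitrary illustrative choices):

```python
unet.eval()
scheduler.set_timesteps(50)  # number of denoising steps

# 1. Sample a random latent (4 channels, 32x32 latents for 256x256 images)
latents = torch.randn(1, unet.config.in_channels, 32, 32, device=device)

# Text conditioning for the prompt we want to generate
text_inputs = tokenizer(
    ["a photo"], padding="max_length", max_length=tokenizer.model_max_length,
    truncation=True, return_tensors="pt",
).to(device)
with torch.no_grad():
    text_embeds = text_encoder(text_inputs.input_ids).last_hidden_state

# 2. Iteratively denoise with the scheduler + trained UNet
for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = unet(latents, t, encoder_hidden_states=text_embeds).sample
    latents = scheduler.step(noise_pred, t, latents).prev_sample

# 3. Decode with the VAE -> image tensor in [0, 1]
with torch.no_grad():
    image = vae.decode(latents / 0.18215).sample
image = (image / 2 + 0.5).clamp(0, 1)
```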
## 📖 References
- [High-Resolution Image Synthesis with Latent Diffusion Models (Stable Diffusion)](https://arxiv.org/abs/2112.10752)
- [Hugging Face Diffusers](https://github.com/huggingface/diffusers)
- [Scalable Diffusion Models with Transformers (DiT)](https://arxiv.org/abs/2212.09748)
## ✅ Future Work
- Replace the UNet with a pure Transformer denoiser (DiT).
- Use larger text encoders (T5/DeBERTa).
- Train on custom captioned datasets.