🚀 Stable Diffusion with Transformers (Advanced Training)

This project demonstrates how to train a Stable Diffusion-style model on an image dataset, using a Transformer-based denoiser.
The implementation builds on PyTorch, Hugging Face Diffusers, and Transformers.


📌 Overview

Stable Diffusion is a Latent Diffusion Model (LDM) that generates images by:

  1. Encoding images into a latent space using a VAE (Variational Autoencoder).
  2. Adding Gaussian noise to the latents across multiple time steps (this has a closed form; see the sketch after this list).
  3. Training a denoising Transformer/UNet to remove noise step by step.
  4. Using a text encoder (CLIP) for prompt conditioning.
  5. Decoding the cleaned latents back to an image.
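
A minimal PyTorch sketch of that closed-form forward process, q(z_t | z_0) = N(√ᾱ_t · z_0, (1 − ᾱ_t) I), using the linear beta schedule from the DDPM paper (the schedule values and helper name here are illustrative, not from any library):

import torch

# Linear beta schedule from the DDPM paper; alphas_cumprod is the cumulative signal rate ᾱ_t.
num_train_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(latents, noise, t):
    # z_t = sqrt(ᾱ_t) * z_0 + sqrt(1 - ᾱ_t) * ε, broadcast over (B, C, H, W)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a_bar.sqrt() * latents + (1.0 - a_bar).sqrt() * noise

# One corrupted batch at random timesteps:
z0 = torch.randn(8, 4, 32, 32)
t = torch.randint(0, num_train_timesteps, (8,))
zt = add_noise(z0, torch.randn_like(z0), t)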

🔬 Architecture

graph TD;
    A[Input Image] -->|VAE Encoder| B[Latent Space];
    B -->|Add Noise| C[Noisy Latents];
    C -->|Transformer / UNet Denoiser| D[Clean Latents];
    D -->|VAE Decoder| E[Output Image];
    F[Text Prompt] -->|CLIP Encoder| C;

  • VAE → Compresses image → latent space

  • Transformer/UNet → Learns to denoise latent

  • Text Encoder → Aligns text + image

  • Noise Scheduler → Controls forward & reverse diffusion (both directions are sketched below)
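
Both directions in the last bullet correspond to two scheduler calls in diffusers. A minimal sketch, with random tensors standing in for real latents and for the model's noise prediction:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

latents = torch.randn(1, 4, 32, 32)   # stand-in for VAE latents
noise = torch.randn_like(latents)
timestep = torch.tensor([500])

# Forward diffusion: corrupt the latents at timestep t.
noisy = scheduler.add_noise(latents, noise, timestep)

# Reverse diffusion: one denoising step, given a noise prediction
# (random here; the trained UNet supplies this during sampling).
noise_pred = torch.randn_like(latents)
prev = scheduler.step(noise_pred, 500, noisy).prev_sample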

📂 Dataset

  • Images should be resized to 256×256 and normalized to [-1, 1].

  • Optional: Provide text captions for conditioning (a captioned-dataset sketch follows the example below).

  • Example:

data/
 ├── class1/
 │   ├── img1.png
 │   └── img2.jpg
 ├── class2/
 │   ├── img3.png
 │   └── img4.jpg
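
torchvision's ImageFolder handles the unconditional layout above. For caption conditioning, a small custom Dataset that pairs each image with a string is enough; the sketch below assumes a captions.json file mapping file names to caption strings (that file name and format are illustrative, not required by any library):

import json
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class CaptionedImageDataset(Dataset):
    """Yields (image_tensor, caption) pairs for text-conditioned training."""
    def __init__(self, root, captions_file, transform):
        self.root = Path(root)
        self.captions = json.load(open(captions_file))  # {"img1.png": "a cat", ...}
        self.files = sorted(self.captions.keys())
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        name = self.files[idx]
        image = Image.open(self.root / name).convert("RGB")
        return self.transform(image), self.captions[name]

The default DataLoader collation turns the caption strings into a list per batch, which can replace the fixed "a photo" prompt in the training loop below.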

⚙️ Training Algorithm

The training process for Stable Diffusion with Transformers follows these steps:

  1. Encode Images → Pass input images through a VAE Encoder to obtain latent representations.
  2. Sample Noise & Timestep → Randomly sample Gaussian noise and a timestep t.
  3. Add Noise → Corrupt the latent vectors with noise according to timestep t.
  4. Text Conditioning → Encode text prompts using CLIP (or another Transformer text encoder).
  5. Noise Prediction → Feed the noisy latents + text embeddings into the Transformer/UNet to predict the added noise.
  6. Compute Loss → Calculate the Mean Squared Error (MSE) between predicted noise and true noise (the objective is written out after this list).
  7. Backpropagation → Update model weights using gradient descent.
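
Steps 5 and 6 together implement the standard epsilon-prediction objective, with ε_θ the Transformer/UNet denoiser, c the text embedding, and ᾱ_t the scheduler's cumulative signal rate:

$$
\mathcal{L} = \mathbb{E}_{z_0,\, \varepsilon \sim \mathcal{N}(0, I),\, t,\, c}\Big[\big\lVert \varepsilon - \varepsilon_\theta(z_t,\, t,\, c) \big\rVert^2\Big],
\qquad z_t = \sqrt{\bar\alpha_t}\, z_0 + \sqrt{1 - \bar\alpha_t}\, \varepsilon
$$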

flowchart TD
    A[Image] -->|VAE Encoder| B[Latent Space]
    B -->|Add Noise + t| C[Noisy Latents]
    D[Text Prompt] -->|CLIP Encoder| C
    C -->|Transformer / UNet| E[Predicted Noise]
    E -->|MSE Loss| F[Training Update]

🧑‍💻 Example Training Code

from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
import torch, torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Dataset
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # map RGB to [-1, 1]
])
dataset = datasets.ImageFolder("path_to_images", transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Components
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
# SD v1.x was trained with CLIP ViT-L/14; its 768-dim embeddings match the UNet's cross-attention.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
scheduler = DDPMScheduler(num_train_timesteps=1000)

device = "cuda" if torch.cuda.is_available() else "cpu"
vae, unet, text_encoder = vae.to(device), unet.to(device), text_encoder.to(device)
vae.requires_grad_(False)           # freeze VAE and text encoder;
text_encoder.requires_grad_(False)  # only the UNet is trained

optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

# Training Loop
for epoch in range(10):
    for images, _ in dataloader:
        images = images.to(device)
        latents = vae.encode(images).latent_dist.sample() * 0.18215  # SD v1 latent scaling factor
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=device)
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # One placeholder prompt per image; swap in real captions for conditioned training.
        text_inputs = tokenizer(["a photo"] * images.shape[0], padding="max_length",
                                max_length=tokenizer.model_max_length,
                                truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():
            text_embeds = text_encoder(text_inputs.input_ids).last_hidden_state

        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeds).sample
        loss = nn.MSELoss()(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} | Loss: {loss.item()}")

💾 Saving & Inference

Save the trained UNet:

torch.save(unet.state_dict(), "unet_trained.pth")
# Alternatively: unet.save_pretrained("unet_trained")


# Inference pipeline
# 1. Sample random latent
# 2. Iteratively denoise with scheduler + trained UNet
# 3. Decode with VAE → image
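
A minimal sampling loop following those three steps, reusing the components defined in the training script above (the prompt, step count, and latent shape are illustrative; no classifier-free guidance is applied):

import torch

unet.load_state_dict(torch.load("unet_trained.pth"))
unet.eval()

scheduler.set_timesteps(50)                          # number of denoising steps
latents = torch.randn(1, 4, 32, 32, device=device)   # 1. random latent (256 / 8 = 32)

prompt = tokenizer(["a photo"], padding="max_length",
                   max_length=tokenizer.model_max_length,
                   truncation=True, return_tensors="pt").to(device)
with torch.no_grad():
    text_embeds = text_encoder(prompt.input_ids).last_hidden_state

# 2. Iteratively denoise with the scheduler + trained UNet.
for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = unet(latents, t, encoder_hidden_states=text_embeds).sample
    latents = scheduler.step(noise_pred, t, latents).prev_sample

# 3. Decode with the VAE (undo the 0.18215 scaling), then map [-1, 1] → [0, 1].
with torch.no_grad():
    image = vae.decode(latents / 0.18215).sample
image = (image / 2 + 0.5).clamp(0, 1)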

📖 References

  • Stable Diffusion: Rombach et al., "High-Resolution Image Synthesis with Latent Diffusion Models" (arXiv:2112.10752)

  • Hugging Face Diffusers: https://github.com/huggingface/diffusers

  • Diffusion Transformer (DiT): Peebles & Xie, "Scalable Diffusion Models with Transformers" (arXiv:2212.09748)

✅ Future Work

  • Replace the UNet with a pure Transformer denoiser (DiT).

  • Use larger text encoders (e.g., T5 or DeBERTa).

  • Train on custom captioned datasets.
