|
--- |
|
license: apache-2.0 |
|
tags: |
|
- satellite-imagery
|
- text-to-image |
|
- image-to-image |
|
- rgb-to-sentinel2 |
|
- stable-diffusion
|
- cyclegan
|
metrics: |
|
- spectral-angle-mapper
|
--- |
|
|
|
# Multispectral Caption-Image Unification via Diffusion and CycleGAN |
|
|
|
[](https://creativecommons.org/licenses/by/4.0/) |
|
[](https://huggingface.co/docs) |
|
|
|
--- |
|
|
|
## 🔗 Resources |
|
|
|
- 📄 Full Paper (Preprint): *link will be added*
|
- 🧠 [Training Code (GitHub)](https://github.com/kursatkomurcu/Multispectral-Caption-Image-Unification-via-Diffusion-and-CycleGAN) |
|
|
|
Data (note: the Google Drive folders can take a while to load):
|
- [RGB Data](https://drive.google.com/drive/folders/1J12nXyrLrjzkuwku9Sgx4WfTeYRajjEz?usp=sharing) |
|
- [Multispectral Data](https://drive.google.com/drive/folders/1mIHOMczZ4nR3GWmzUOt0PiaLpqQfYtUD?usp=sharing) |
|
|
|
<p align="left"> |
|
<a href="https://www.buymeacoffee.com/kursatkomurcu"> |
|
<img |
|
src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" |
|
alt="Buy Me A Coffee" |
|
width="200" |
|
/> |
|
</a> |
|
</p> |
|
|
|
--- |
|
|
|
## 📜 Citation |
|
|
|
If you use this model, please cite: |
|
|
|
```bibtex |
|
% Citation will be added once the paper is published.
|
``` |
|
|
|
## 📜 Overview |
|
|
|
**Multispectral Caption-Image Unification via Diffusion and CycleGAN** is a multimodal pipeline for generating and unifying satellite image data across three modalities:
|
- **Caption (Text)** |
|
- **RGB Image** |
|
- **Multispectral Sentinel-2 Image** |
|
|
|
The system integrates **fine-tuned Stable Diffusion** for text-to-RGB image generation and **CycleGAN** for RGB-to-multispectral translation. |
|
It allows **triplet data creation** even when only partial information (e.g., just a caption or an RGB image) is available.
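
As a rough illustration, the triplet-completion logic can be sketched as follows. `complete_triplet` and its callable arguments are hypothetical wrappers, not APIs from this repository; the Quick Examples below show the actual model calls.

```python
def complete_triplet(caption=None, rgb=None, ms=None, *,
                     caption_to_rgb, rgb_to_caption, rgb_to_ms):
    """Fill in the missing modalities from whichever inputs are given.

    The three callables are assumed to wrap the fine-tuned Stable Diffusion
    model, a captioning model, and the CycleGAN generator, respectively.
    """
    if rgb is None and caption is not None:
        rgb = caption_to_rgb(caption)    # caption -> RGB (Stable Diffusion)
    if caption is None and rgb is not None:
        caption = rgb_to_caption(rgb)    # RGB -> caption
    if ms is None and rgb is not None:
        ms = rgb_to_ms(rgb)              # RGB -> multispectral (CycleGAN)
    return caption, rgb, ms
```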
|
|
|
--- |
|
|
|
## 🚀 Key Features |
|
|
|
- **Caption ➔ RGB Image ➔ Multispectral Image** generation |
|
- **RGB Image ➔ Caption** and **Multispectral Image** generation |
|
- **Multispectral Image ➔ RGB Image ➔ Caption** reconstruction |
|
- Fine-tuned **Stable Diffusion 2-1 Base** on satellite captions |
|
- Custom **CycleGAN** model trained for Sentinel-2 13-band spectral transformation |
|
- Specialized **SAM Loss** (Spectral Angle Mapper) for better multispectral consistency (a minimal sketch follows this list)
|
- Supports creating fully unified datasets from previously disconnected modalities |
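
The SAM loss penalizes the angle between predicted and target per-pixel spectra. Below is a minimal PyTorch sketch of such a loss; the exact formulation and weighting used in training may differ.

```python
import torch

def sam_loss(pred, target, eps=1e-8):
    """Spectral Angle Mapper loss for (B, C, H, W) multispectral tensors.

    Returns the mean angle (in radians) between predicted and target
    spectra at each pixel; identical spectra give an angle of zero.
    """
    dot = (pred * target).sum(dim=1)
    norms = pred.norm(dim=1) * target.norm(dim=1)
    cos = torch.clamp(dot / (norms + eps), -1.0 + eps, 1.0 - eps)
    return torch.acos(cos).mean()
```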
|
|
|
--- |
|
|
|
## Example Results
|
 |
|
 |
|
|
|
|
|
## 📚 Training Details |
|
|
|
- **Stable Diffusion Fine-Tuning:** |
|
- Dataset: 675,000 SkyScript images with captions generated by **Qwen2-VL-2B-Instruct** |
|
- Training: Text-to-Image generation targeting satellite domain |
|
|
|
- **CycleGAN Training:** |
|
- Dataset: 27,000 EuroSAT RGB and multispectral images
|
- Loss: a mix of Spectral Angle Mapper (SAM) and histogram losses (see the soft-histogram sketch after this list)
|
|
|
- **Hardware:** |
|
- Google Colab Pro+ |
|
- NVIDIA A100 GPU |
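
The histogram term is not specified in detail here; one common differentiable proxy is a Gaussian soft-binned histogram, sketched below. Bin count, value range, kernel width, and the mixing weight are illustrative assumptions, not the values used in training.

```python
import torch

def soft_histogram(x, bins=32, vmin=-1.0, vmax=1.0, sigma=0.05):
    """Differentiable per-band histogram of a (B, C, H, W) tensor."""
    centers = torch.linspace(vmin, vmax, bins, device=x.device)
    x = x.flatten(start_dim=2).unsqueeze(-1)              # (B, C, N, 1)
    weights = torch.exp(-0.5 * ((x - centers) / sigma) ** 2)
    hist = weights.sum(dim=2)                             # (B, C, bins)
    return hist / (hist.sum(dim=-1, keepdim=True) + 1e-8)

def histogram_loss(pred, target):
    """L1 distance between per-band soft histograms."""
    return (soft_histogram(pred) - soft_histogram(target)).abs().mean()

# Illustrative mixed objective (the weight is an assumption):
# total_loss = sam_loss(pred, target) + 0.5 * histogram_loss(pred, target)
```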
|
|
|
--- |
|
|
|
## 🛰️ Applications |
|
|
|
- Synthetic satellite dataset generation |
|
- Remote sensing research (land cover classification, environmental monitoring) |
|
- Data augmentation for multispectral models |
|
- Disaster monitoring and environmental change detection |
|
|
|
--- |
|
|
|
## 🧩 Model Components |
|
|
|
| Component | Description | |
|
|:---|:---| |
|
| `stable-diffusion-finetuned-satellite` | Fine-tuned Stable Diffusion 2-1 Base model for satellite image synthesis | |
|
| `cyclegan-rgb-to-multispectral` | Custom CycleGAN for RGB to multispectral (Sentinel-2) translation | |
|
| `synthetic-triplet-dataset` | Synthetic dataset of 120,000 caption + RGB + multispectral triplets |
|
|
|
--- |
|
|
|
## ⚡ Quick Example: Generate an Image from a Single Caption |
|
|
|
```python |
|
import os |
|
import torch |
|
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DPMSolverMultistepScheduler |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
from safetensors.torch import load_file as safe_load |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
|
|
# Path to the fine-tuned UNet checkpoint
|
checkpoint_dir = "/your/path" |
|
checkpoint_path = os.path.join(checkpoint_dir, "model.safetensors") |
|
|
|
base_unet = UNet2DConditionModel.from_pretrained( |
|
"stabilityai/stable-diffusion-2-1-base", |
|
subfolder="unet", |
|
torch_dtype=torch.float16 |
|
) |
|
|
|
# Fine-tuned weights |
|
state_dict = safe_load(checkpoint_path) |
|
base_unet.load_state_dict(state_dict) |
|
unet = base_unet |
|
|
|
vae = AutoencoderKL.from_pretrained( |
|
"stabilityai/stable-diffusion-2-1-base", |
|
subfolder="vae", |
|
torch_dtype=torch.float16 |
|
) |
|
text_encoder = CLIPTextModel.from_pretrained( |
|
"stabilityai/stable-diffusion-2-1-base", |
|
subfolder="text_encoder", |
|
torch_dtype=torch.float16 |
|
) |
|
tokenizer = CLIPTokenizer.from_pretrained( |
|
"stabilityai/stable-diffusion-2-1-base", |
|
subfolder="tokenizer" |
|
) |
|
scheduler = DPMSolverMultistepScheduler.from_pretrained( |
|
"stabilityai/stable-diffusion-2-1-base", |
|
subfolder="scheduler" |
|
) |
|
|
|
safety_checker = None |
|
feature_extractor = None |
|
|
|
# Stable Diffusion pipeline |
|
pipe = StableDiffusionPipeline( |
|
unet=unet, |
|
vae=vae, |
|
text_encoder=text_encoder, |
|
tokenizer=tokenizer, |
|
scheduler=scheduler, |
|
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
    requires_safety_checker=False  # the safety checker is intentionally disabled
|
) |
|
|
|
pipe = pipe.to("cuda") |
|
|
|
prompt = "A coastal city with large harbors and residential areas" |
|
|
|
with torch.autocast("cuda"):
|
result = pipe(prompt, num_inference_steps=100, guidance_scale=7.5) |
|
|
|
image = result.images[0] |
|
|
|
output_dir = "/your/save/path" |
|
os.makedirs(output_dir, exist_ok=True) |
|
output_path = os.path.join(output_dir, "single_prompt_generated.png") |
|
image.save(output_path) |
|
print(f"✅ The image generated and saved: {output_path}") |
|
|
|
# Visualize the result with Matplotlib
|
if os.path.exists(output_path): |
|
img = Image.open(output_path) |
|
plt.figure(figsize=(8, 8)) |
|
plt.imshow(img) |
|
plt.axis("off") |
|
plt.show() |
|
else: |
|
print(f"The file could not find: {output_path}") |
|
|
|
``` |
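
Note that `num_inference_steps=100` is conservative; with `DPMSolverMultistepScheduler`, 25–50 steps usually produce comparable images in a fraction of the time.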
|
|
|
## ⚡ Quick Example: RGB-to-Multispectral Conversion with CycleGAN |
|
|
|
```python |
|
import torch |
|
import torchvision.transforms as transforms |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
from safetensors.torch import safe_open # for loading .safetensors weights |
|
|
|
# --------------------------- |
|
# Model & Input Settings |
|
# --------------------------- |
|
model_path = "cycle_gan/G_model.safetensors" # update to your model path |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
# Load the generator (3 RGB channels in, 10 output channels: the non-RGB
# Sentinel-2 bands; the RGB input fills the other 3 of the 13 bands).
# The Generator class is defined in the training code (see the GitHub repo).
G = Generator(input_nc=3, output_nc=10).to(device)
|
with safe_open(model_path, framework="pt", device="cpu") as f: |
|
state_dict = {k: f.get_tensor(k) for k in f.keys()} |
|
G.load_state_dict(state_dict) |
|
G.eval() |
|
|
|
# Load an RGB test image |
|
rgb_path = "path/to/sample_rgb.jpg" |
|
input_image = Image.open(rgb_path).convert("RGB").resize((512, 512)) |
|
transform = transforms.Compose([ |
|
transforms.ToTensor(), |
|
transforms.Normalize((0.5,)*3, (0.5,)*3) |
|
]) |
|
input_tensor = transform(input_image).unsqueeze(0).to(device) # (1,3,512,512) |
|
|
|
# --------------------------- |
|
# Sliding-Window Inference |
|
# --------------------------- |
|
patch_size = 64 |
|
h, w = 512, 512 |
|
output_fake = torch.zeros((13, h, w), device=device) |
|
|
|
for y in range(0, h, patch_size): |
|
for x in range(0, w, patch_size): |
|
patch = input_tensor[:, :, y:y+patch_size, x:x+patch_size] |
|
with torch.no_grad(): |
|
extra = G(patch) # (1,10,64,64) |
|
# assemble 13-channel patch |
|
combined = torch.empty(1, 13, patch_size, patch_size, device=device) |
|
combined[:, 0, :, :] = extra[:, 0, :, :] # band 1 |
|
combined[:, 1:4, :, :] = patch # bands 2–4 (RGB) |
|
combined[:, 4:, :, :] = extra[:, 1:, :, :] # bands 5–13 |
|
output_fake[:, y:y+patch_size, x:x+patch_size] = combined.squeeze(0) |
|
|
|
# to CPU & normalize from [-1,1] to [0,1] |
|
fake_np = output_fake.cpu().numpy() |
|
fake_np = (fake_np + 1) / 2.0 # shape (13,512,512) |
|
fake_np = np.transpose(fake_np, (1,2,0)) # (512,512,13) |
|
|
|
# Optional: save as GeoTIFF |
|
# import tifffile as tiff |
|
# tiff.imwrite("generated_multispectral.tif", fake_np.astype(np.float32)) |
|
|
|
# --------------------------- |
|
# Spectral Visualization |
|
# --------------------------- |
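
# NOTE: the 0-based channel indices below assume the band order assembled
# above (B1 first, then the RGB bands, then the remaining bands); adjust
# the mapping if your training pipeline ordered the Sentinel-2 bands differently.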
|
spectral_composites = { |
|
"Natural Color (B4,B3,B2)": [1,2,3], |
|
"Color Infrared (B8,B4,B3)": [7,3,2], |
|
"Short-Wave Infrared (B12,B8A,B4)": [12,8,3], |
|
"Agriculture (B11,B8,B2)": [10,7,1], |
|
"Geology (B12,B11,B2)": [12,10,1], |
|
"Bathymetric (B4,B3,B1)": [3,2,0] |
|
} |
|
|
|
# Compute NDVI = (NIR - Red) / (NIR + Red); channel 7 is used as NIR and channel 3 as Red
|
ndvi = (fake_np[:,:,7] - fake_np[:,:,3]) / (fake_np[:,:,7] + fake_np[:,:,3] + 1e-6) |
|
|
|
fig, axs = plt.subplots(2, 4, figsize=(16,8)) |
|
axs = axs.flatten() |
|
|
|
# Plot each composite (natural color as an RGB stack; the others as the band mean under a colormap)
|
for idx, (title, bands) in enumerate(spectral_composites.items()): |
|
img = fake_np[:,:,bands] if title.endswith("(B4,B3,B2)") else np.mean(fake_np[:,:,bands], axis=2) |
|
axs[idx].imshow(img, cmap=None if title.endswith("(B4,B3,B2)") else "inferno") |
|
axs[idx].set_title(title) |
|
axs[idx].axis("off") |
|
|
|
# plot NDVI in the last subplot and hide the unused axis
axs[6].axis("off")  # 6 composites + NDVI fill only 7 of the 8 subplots
axs[-1].imshow(ndvi, cmap="RdYlGn", vmin=-1, vmax=1)
axs[-1].set_title("Vegetation Index (NDVI)")
axs[-1].axis("off")
|
|
|
plt.tight_layout() |
|
plt.show() |
|
``` |
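
Since Spectral Angle Mapper is the model's reported metric, a quick NumPy check of a generated image against a real Sentinel-2 reference might look like the sketch below (the reference path is a placeholder; `fake_np` is the array produced by the script above).

```python
import numpy as np

def sam_degrees(pred, ref, eps=1e-8):
    """Mean spectral angle in degrees between two (H, W, C) images."""
    dot = (pred * ref).sum(axis=-1)
    norms = np.linalg.norm(pred, axis=-1) * np.linalg.norm(ref, axis=-1)
    cos = np.clip(dot / (norms + eps), -1.0, 1.0)
    return float(np.degrees(np.arccos(cos)).mean())

# Example usage (placeholder path; scale the reference to [0, 1] first):
# import tifffile as tiff
# ref = tiff.imread("path/to/reference_multispectral.tif")  # (512, 512, 13)
# print(f"SAM: {sam_degrees(fake_np, ref):.2f} degrees")
```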
|
|
|
|
|
|
|
|
|
|