---
license: apache-2.0
tags:
- satellite imagery
- text-to-image
- image-to-image
- rgb-to-sentinel2
- stable diffusion
- cycle gan
metrics:
- spectral angle mapper
---

# Multispectral Caption-Image Unification via Diffusion and CycleGAN

---

## Resources

- [Full Paper (Preprint)](link_to_paper_if_available) (link will be added)
- [Training Code (GitHub)](https://github.com/kursatkomurcu/Multispectral-Caption-Image-Unification-via-Diffusion-and-CycleGAN)

The data folders below are large, so they may take a while to load:

- [RGB Data](https://drive.google.com/drive/folders/1J12nXyrLrjzkuwku9Sgx4WfTeYRajjEz?usp=sharing)
- [Multispectral Data](https://drive.google.com/drive/folders/1mIHOMczZ4nR3GWmzUOt0PiaLpqQfYtUD?usp=sharing)
---

## Citation

If you use this model, please cite:

```bibtex
@article{will be added
}
```

## Overview

**Multispectral Caption-Image Unification via Diffusion and CycleGAN** proposes a full multimodal pipeline for generating and unifying satellite image data across three modalities:

- **Caption (Text)**
- **RGB Image**
- **Multispectral Sentinel-2 Image**

The system integrates **fine-tuned Stable Diffusion** for text-to-RGB image generation and **CycleGAN** for RGB-to-multispectral translation. It allows **triplet data creation** even when only partial information (e.g., just a caption or an RGB image) is available.

---

## Key Features

- **Caption → RGB Image → Multispectral Image** generation
- **RGB Image → Caption** and **Multispectral Image** generation
- **Multispectral Image → RGB Image → Caption** reconstruction
- Fine-tuned **Stable Diffusion 2-1 Base** on satellite captions
- Custom **CycleGAN** model trained for Sentinel-2 13-band spectral transformation
- Specialized **SAM Loss** (Spectral Angle Mapper) for better multispectral consistency
- Supports creating fully unified datasets from previously disconnected modalities

---

## Example of Results

## Training Details

- **Stable Diffusion Fine-Tuning:**
  - Dataset: 675,000 SkyScript images with captions generated by **Qwen2-VL-2B-Instruct**
  - Training: text-to-image generation targeting the satellite domain
- **CycleGAN Training:**
  - Dataset: 27,000 EuroSAT RGB and multispectral images
  - Special loss: a mix of Spectral Angle Mapper (SAM) and histogram losses (see the sketch after this list)
- **Hardware:**
  - Google Colab Pro+
  - NVIDIA A100 GPU
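The exact loss implementation lives in the training code linked above; the following is a minimal, self-contained PyTorch sketch of how a Spectral Angle Mapper term can be combined with a histogram-matching term. The function names, the sorted-values histogram proxy, and the mixing weights (`lambda_sam`, `lambda_hist`) are illustrative assumptions, not the repository's exact code.

```python
import torch
import torch.nn.functional as F

def sam_loss(fake: torch.Tensor, real: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Mean spectral angle (radians) between per-pixel spectra, inputs (N, C, H, W)."""
    # Treat each pixel as a C-dimensional spectral vector and compare directions.
    cos_sim = F.cosine_similarity(fake, real, dim=1, eps=eps)  # (N, H, W)
    angle = torch.acos(cos_sim.clamp(-1 + eps, 1 - eps))       # clamp avoids NaN at +/-1
    return angle.mean()

def histogram_loss(fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
    """Per-band intensity-distribution distance via sorted values (a differentiable
    stand-in for comparing histograms band by band)."""
    loss = 0.0
    for c in range(fake.shape[1]):
        f = fake[:, c].flatten().sort().values
        r = real[:, c].flatten().sort().values
        loss = loss + (f - r).abs().mean()
    return loss / fake.shape[1]

# Illustrative mixing; the weights actually used in training are not given on this card.
def spectral_loss(fake, real, lambda_sam=1.0, lambda_hist=0.5):
    return lambda_sam * sam_loss(fake, real) + lambda_hist * histogram_loss(fake, real)
```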
---

## Applications

- Synthetic satellite dataset generation
- Remote sensing research (land cover classification, environmental monitoring)
- Data augmentation for multispectral models
- Disaster monitoring and environmental change detection

---

## Model Components

| Component | Description |
|:---|:---|
| `stable-diffusion-finetuned-satellite` | Fine-tuned Stable Diffusion 2-1 Base model for satellite image synthesis |
| `cyclegan-rgb-to-multispectral` | Custom CycleGAN for RGB-to-multispectral (Sentinel-2) translation |
| `synthetic-triplet-dataset` | 120,000 synthetic triplets of RGB image, multispectral image, and caption |

---

## Quick Example: Generate an Image from a Single Caption

```python
import os

import matplotlib.pyplot as plt
import torch
from diffusers import (
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from PIL import Image
from safetensors.torch import load_file as safe_load
from transformers import CLIPTextModel, CLIPTokenizer

# Path to the fine-tuned UNet checkpoint
checkpoint_dir = "/your/path"
checkpoint_path = os.path.join(checkpoint_dir, "model.safetensors")

# Start from the base UNet, then load the fine-tuned weights
base_unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="unet", torch_dtype=torch.float16
)
state_dict = safe_load(checkpoint_path)
base_unet.load_state_dict(state_dict)
unet = base_unet

# The remaining components come from the base model unchanged
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="vae", torch_dtype=torch.float16
)
text_encoder = CLIPTextModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder", torch_dtype=torch.float16
)
tokenizer = CLIPTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler"
)

# Assemble the Stable Diffusion pipeline (no safety checker needed for satellite imagery)
pipe = StableDiffusionPipeline(
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
)
pipe = pipe.to("cuda")

prompt = "A coastal city with large harbors and residential areas"
with torch.cuda.amp.autocast():
    result = pipe(prompt, num_inference_steps=100, guidance_scale=7.5)
image = result.images[0]

output_dir = "/your/save/path"
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "single_prompt_generated.png")
image.save(output_path)
print(f"Image generated and saved: {output_path}")

# Visualize with Matplotlib
if os.path.exists(output_path):
    img = Image.open(output_path)
    plt.figure(figsize=(8, 8))
    plt.imshow(img)
    plt.axis("off")
    plt.show()
else:
    print(f"File not found: {output_path}")
```
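## Quick Example: Caption an RGB Image with Qwen2-VL

The key features list an **RGB Image → Caption** step, and the training captions were produced with **Qwen2-VL-2B-Instruct**. Below is a minimal sketch of captioning an RGB tile with that model through `transformers`; the prompt text and generation settings are illustrative and may differ from those used to build the dataset.

```python
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model_id = "Qwen/Qwen2-VL-2B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open("path/to/sample_rgb.jpg").convert("RGB")

# Chat-style prompt with one image slot; the instruction wording is illustrative.
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe this satellite image in one short sentence."},
    ],
}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device)

with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=64)

# Strip the prompt tokens and keep only the newly generated caption.
caption = processor.batch_decode(
    generated[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
)[0]
print(caption)
```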
## Quick Example: RGB-to-Multispectral Conversion with CycleGAN

```python
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from safetensors.torch import safe_open  # for loading .safetensors weights

# ---------------------------
# Model & Input Settings
# ---------------------------
model_path = "cycle_gan/G_model.safetensors"  # update to your model path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load your Generator: 3 RGB channels in, 10 extra Sentinel-2 bands out
# (RGB is reinserted below to form 13 bands). The Generator class itself
# is defined in the training code linked above; import it from there.
G = Generator(input_nc=3, output_nc=10).to(device)
with safe_open(model_path, framework="pt", device="cpu") as f:
    state_dict = {k: f.get_tensor(k) for k in f.keys()}
G.load_state_dict(state_dict)
G.eval()

# Load an RGB test image
rgb_path = "path/to/sample_rgb.jpg"
input_image = Image.open(rgb_path).convert("RGB").resize((512, 512))
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])
input_tensor = transform(input_image).unsqueeze(0).to(device)  # (1,3,512,512)

# ---------------------------
# Sliding-Window Inference
# ---------------------------
patch_size = 64
h, w = 512, 512
output_fake = torch.zeros((13, h, w), device=device)

for y in range(0, h, patch_size):
    for x in range(0, w, patch_size):
        patch = input_tensor[:, :, y:y+patch_size, x:x+patch_size]
        with torch.no_grad():
            extra = G(patch)  # (1,10,64,64)
        # Assemble the 13-channel patch from generated bands plus the RGB input
        combined = torch.empty(1, 13, patch_size, patch_size, device=device)
        combined[:, 0, :, :] = extra[:, 0, :, :]    # band 1
        combined[:, 1:4, :, :] = patch              # bands 2-4 (RGB)
        combined[:, 4:, :, :] = extra[:, 1:, :, :]  # bands 5-13
        output_fake[:, y:y+patch_size, x:x+patch_size] = combined.squeeze(0)

# To CPU & normalize from [-1,1] to [0,1]
fake_np = output_fake.cpu().numpy()
fake_np = (fake_np + 1) / 2.0             # shape (13,512,512)
fake_np = np.transpose(fake_np, (1, 2, 0))  # (512,512,13)

# Optional: save as GeoTIFF
# import tifffile as tiff
# tiff.imwrite("generated_multispectral.tif", fake_np.astype(np.float32))

# ---------------------------
# Spectral Visualization
# ---------------------------
# Band composites (indices into the assembled 13-band array above)
spectral_composites = {
    "Natural Color (B4,B3,B2)": [1, 2, 3],
    "Color Infrared (B8,B4,B3)": [7, 3, 2],
    "Short-Wave Infrared (B12,B8A,B4)": [12, 8, 3],
    "Agriculture (B11,B8,B2)": [10, 7, 1],
    "Geology (B12,B11,B2)": [12, 10, 1],
    "Bathymetric (B4,B3,B1)": [3, 2, 0],
}

# Compute NDVI = (NIR - Red) / (NIR + Red)
ndvi = (fake_np[:, :, 7] - fake_np[:, :, 3]) / (fake_np[:, :, 7] + fake_np[:, :, 3] + 1e-6)

fig, axs = plt.subplots(2, 4, figsize=(16, 8))
axs = axs.flatten()

# Plot each composite: true color as RGB, the rest as single-channel means
for idx, (title, bands) in enumerate(spectral_composites.items()):
    img = fake_np[:, :, bands] if title.endswith("(B4,B3,B2)") else np.mean(fake_np[:, :, bands], axis=2)
    axs[idx].imshow(img, cmap=None if title.endswith("(B4,B3,B2)") else "inferno")
    axs[idx].set_title(title)
    axs[idx].axis("off")

# Hide the unused panel and plot NDVI in the last one
axs[6].axis("off")
axs[-1].imshow(ndvi, cmap="RdYlGn", vmin=-1, vmax=1)
axs[-1].set_title("Vegetation Index (NDVI)")
axs[-1].axis("off")

plt.tight_layout()
plt.show()
```
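## Quick Example: Building a Full Triplet from a Caption

Chaining the two examples above yields the caption → RGB → multispectral triplets described in the overview. The helper below is an illustrative wrapper that reuses `pipe`, `G`, and `transform` from the previous snippets; `caption_to_triplet` is not a function shipped with this repository.

```python
import torch

def caption_to_triplet(pipe, G, transform, caption, device, patch_size=64, size=512):
    """Illustrative caption -> RGB -> 13-band triplet, reusing objects from above."""
    # 1) Caption -> RGB via the fine-tuned Stable Diffusion pipeline
    rgb = pipe(caption, num_inference_steps=100, guidance_scale=7.5).images[0]
    rgb = rgb.resize((size, size))

    # 2) RGB -> multispectral via the CycleGAN generator, patch by patch
    x = transform(rgb).unsqueeze(0).to(device)  # (1,3,size,size) in [-1,1]
    ms = torch.zeros((13, size, size), device=device)
    for y0 in range(0, size, patch_size):
        for x0 in range(0, size, patch_size):
            patch = x[:, :, y0:y0+patch_size, x0:x0+patch_size]
            with torch.no_grad():
                extra = G(patch)  # (1,10,ps,ps) generated extra bands
            ms[0, y0:y0+patch_size, x0:x0+patch_size] = extra[0, 0]
            ms[1:4, y0:y0+patch_size, x0:x0+patch_size] = patch[0]
            ms[4:, y0:y0+patch_size, x0:x0+patch_size] = extra[0, 1:]
    ms_np = (ms.cpu().numpy() + 1) / 2.0  # 13 bands in [0,1]
    return caption, rgb, ms_np

# Usage: one synthetic triplet from a single caption
# caption, rgb_img, ms_bands = caption_to_triplet(
#     pipe, G, transform, "A coastal city with large harbors", device)
```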