---
license: apache-2.0
tags:
- vision
---
# VisualSplit
**VisualSplit** is a ViT-based model that explicitly factorises an image into **classical visual descriptors**—such as **edges**, **color segmentation**, and **grayscale histogram**—and learns to reconstruct the image conditioned on those descriptors. This design yields **interpretable representations** where geometry (edges), albedo/appearance (segmented colors), and global tone (histogram) can be reasoned about or varied independently.
> **Training data**: ImageNet-1K.
---
## Model Description
- **Inputs** (at inference):
  - An RGB image, which is converted to the required descriptors (edges, color segmentation, grayscale histogram) by the provided `FeatureExtractor`.
- **Outputs**:
  - A reconstructed RGB image tensor at the model’s training resolution (`224×224` by default, unless you trained at a different size).
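To get an intuition for what these descriptors look like, the snippet below approximates them with generic image operations (an edge filter, palette quantisation, and a grayscale histogram). It is only an illustration; the repo’s `FeatureExtractor`, used in the examples below, computes its own versions of these descriptors.

```python
from PIL import Image, ImageFilter

# Rough stand-ins for the three descriptor families (not the repo's FeatureExtractor).
image = Image.open("input.jpg").convert("RGB")

edges = image.convert("L").filter(ImageFilter.FIND_EDGES)  # geometry: edge map
segmented = image.quantize(colors=8).convert("RGB")        # appearance: coarse color segmentation
gray_hist = image.convert("L").histogram()                 # global tone: 256-bin grayscale histogram

edges.save("edges_preview.png")
segmented.save("segmented_preview.png")
print("histogram bins:", len(gray_hist))
```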
---
## Getting Started (Inference)
Below are two ways to run inference with the uploaded `model.safetensors`.
### 1) Minimal PyTorch + safetensors (load state dict)
```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# 1) Import the model & config from the VisualSplit code repo
from visualsplit.models.CrossViT import CrossViTForPreTraining, CrossViTConfig
from visualsplit.utils import FeatureExtractor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Build a config matching your training run (edit if you changed widths/depths)
config = CrossViTConfig(
    image_size=224,  # change if your training size differs
    patch_size=16,
    # ... any other config fields your repo exposes
)
model = CrossViTForPreTraining(config).to(device)
model.eval()

# 3) Download and load the state dict from this model repo
#    Replace REPO_ID with your Hugging Face model id, e.g. "HenryQUQ/visualsplit"
ckpt_path = hf_hub_download(repo_id="REPO_ID", filename="model.safetensors")
state_dict = load_file(ckpt_path)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

# 4) Prepare an input image and extract descriptors
from PIL import Image
from torchvision import transforms

image = Image.open("input.jpg").convert("RGB")
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
])
pixel_values = transform(image).unsqueeze(0).to(device)  # (1, 3, H, W)

# The FeatureExtractor provided by the repo returns the required descriptor tensors
extractor = FeatureExtractor().to(device)
with torch.no_grad():
    edge, gray_hist, segmented_rgb, _ = extractor(pixel_values)

# 5) Run inference (reconstruction)
with torch.no_grad():
    outputs = model(
        source_edge=edge,
        source_gray_level_histogram=gray_hist,
        source_segmented_rgb=segmented_rgb,
    )
# Your repo’s forward returns may differ; adjust the key accordingly:
reconstructed = outputs["logits_reshape"]  # (1, 3, H, W)

# 6) Convert to PIL for visualisation
to_pil = transforms.ToPILImage()
recon_img = to_pil(reconstructed.squeeze(0).cpu().clamp(0, 1))
recon_img.save("reconstructed.png")
print("Saved to reconstructed.png")
```
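Because the descriptors are decoupled, you can also mix descriptors taken from different images, e.g. keep the edges and segmented colors of one image while borrowing the grayscale histogram of another to shift the global tone. A minimal sketch, reusing `model`, `extractor`, `transform`, `to_pil` and the tensors from the example above (the second file name is hypothetical, and the quality of the result depends on your checkpoint):

```python
from PIL import Image

# Load a second image whose global tone we want to borrow.
other = transform(Image.open("other.jpg").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    _, other_hist, _, _ = extractor(other)
    mixed = model(
        source_edge=edge,                        # geometry from the first image
        source_gray_level_histogram=other_hist,  # global tone from the second image
        source_segmented_rgb=segmented_rgb,      # appearance from the first image
    )

# As above, adjust the output key if your forward returns differ.
to_pil(mixed["logits_reshape"].squeeze(0).cpu().clamp(0, 1)).save("mixed_tone.png")
```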
### 2) Reproducing the notebook flow (`notebook/validation.ipynb`)
The repository provides a validation notebook that:
1. Loads the trained model,
2. Uses `FeatureExtractor` to compute **edges**, **color-segmented RGB**, and **grayscale histograms**,
3. Runs the model to obtain a reconstructed image,
4. Saves/visualises the result.
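If you prefer a plain script to the notebook, the sketch below mirrors that flow using the objects from example 1, saving the input, the image-shaped descriptors, and the reconstruction for a quick side-by-side check (the exact descriptor shapes depend on the repo’s `FeatureExtractor`, so the descriptor images are normalised before saving):

```python
from torchvision.utils import save_image

# Reuses pixel_values, edge, segmented_rgb and reconstructed from example 1.
save_image(pixel_values.cpu(), "viz_input.png")
save_image(edge.float().cpu(), "viz_edges.png", normalize=True)
save_image(segmented_rgb.float().cpu(), "viz_segmented.png", normalize=True)
save_image(reconstructed.cpu().clamp(0, 1), "viz_reconstruction.png")
```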
---
## Installation & Requirements
```bash
# clone the VisualSplit code
git clone https://github.com/HenryQUQ/VisualSplit.git
cd VisualSplit

# install the package and its dependencies (editable mode)
pip install -e .
```
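A quick way to confirm the package was picked up (the module paths match the imports used in the examples above):

```python
# Verify that the classes used in the inference examples can be imported.
from visualsplit.models.CrossViT import CrossViTForPreTraining, CrossViTConfig
from visualsplit.utils import FeatureExtractor

print(CrossViTForPreTraining.__name__, CrossViTConfig.__name__, FeatureExtractor.__name__)
```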
---
## Training Data
- **Dataset**: **ImageNet-1K**.
> This repository only hosts the **trained checkpoint for inference**. Follow the GitHub repo for the full training pipeline and data preparation scripts.
---
## Model Sources
- **Code**: https://github.com/HenryQUQ/VisualSplit
- **Weights (this page)**: this Hugging Face model repo
---
## Citation
If you use this model or the ideas behind it, please cite:
```bibtex
@inproceedings{Qu2025VisualSplit,
  title     = {Exploring Image Representation with Decoupled Classical Visual Descriptors},
  author    = {Qu, Chenyuan and Chen, Hao and Jiao, Jianbo},
  booktitle = {British Machine Vision Conference (BMVC)},
  year      = {2025}
}
```
---