|
---
license: apache-2.0
tags:
- vision
---
|
|
|
|
|
# VisualSplit |
|
|
|
**VisualSplit** is a ViT-based model that explicitly factorises an image into **classical visual descriptors**—such as **edges**, **color segmentation**, and **grayscale histogram**—and learns to reconstruct the image conditioned on those descriptors. This design yields **interpretable representations** where geometry (edges), albedo/appearance (segmented colors), and global tone (histogram) can be reasoned about or varied independently. |
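
To make the factorisation concrete, the sketch below illustrates the three descriptor families with a self-contained, assumed set of operators (Sobel edge magnitude, a 64-bin grayscale histogram, uniform color quantisation). It is illustrative only; the repo's `FeatureExtractor` is the canonical implementation.

```python
# Illustrative only: rough stand-ins for the three descriptor families.
# The operators here (Sobel edges, 64-bin histogram, 8-level quantisation)
# are assumptions for demonstration, not VisualSplit's actual FeatureExtractor.
import torch
import torch.nn.functional as F

def toy_descriptors(rgb: torch.Tensor, bins: int = 64, levels: int = 8):
    """rgb: (B, 3, H, W) in [0, 1]."""
    gray = rgb.mean(dim=1, keepdim=True)                  # (B, 1, H, W)

    # Geometry: Sobel gradient magnitude as a crude edge map.
    kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]).view(1, 1, 3, 3)
    gx = F.conv2d(gray, kx, padding=1)
    gy = F.conv2d(gray, kx.transpose(2, 3), padding=1)
    edges = (gx ** 2 + gy ** 2).sqrt()                    # (B, 1, H, W)

    # Global tone: per-image grayscale histogram.
    hist = torch.stack([torch.histc(g, bins=bins, min=0.0, max=1.0) for g in gray])

    # Appearance: coarse color segmentation via uniform quantisation.
    segmented = (rgb * (levels - 1)).round() / (levels - 1)
    return edges, hist, segmented
```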
|
|
|
> **Training data**: ImageNet-1K. |
|
--- |
|
|
|
## Model Description |
|
|
|
- **Inputs** (at inference): |
|
  - An RGB image, which is converted into the required descriptors (edges, color segmentation, grayscale histogram) by the provided `FeatureExtractor`.
|
- **Outputs**: |
|
- A reconstructed RGB image tensor (same spatial size as the model’s training resolution; default `224×224` unless you trained otherwise). |
|
|
|
--- |
|
|
|
## Getting Started (Inference) |
|
|
|
Below are two ways to run inference with the uploaded `model.safetensors`. |
|
|
|
### 1) Minimal PyTorch + safetensors (load state dict) |
|
|
|
```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# 1) Import the model & config from the VisualSplit repo
from visualsplit.models.CrossViT import CrossViTForPreTraining, CrossViTConfig
from visualsplit.utils import FeatureExtractor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Build a config matching your training run (edit if you changed widths/depths)
config = CrossViTConfig(
    image_size=224,   # change if your training size differs
    patch_size=16,
    # ... any other config fields your repo exposes
)

model = CrossViTForPreTraining(config).to(device)
model.eval()

# 3) Download and load the state dict from this model repo
#    (replace REPO_ID with your Hugging Face model id, e.g. "HenryQUQ/visualsplit")
ckpt_path = hf_hub_download(repo_id="REPO_ID", filename="model.safetensors")
state_dict = load_file(ckpt_path)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

# 4) Prepare an input image and extract descriptors
from PIL import Image
from torchvision import transforms

image = Image.open("input.jpg").convert("RGB")
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
])
pixel_values = transform(image).unsqueeze(0).to(device)  # (1, 3, H, W)

# The FeatureExtractor provided by the repo returns the required descriptor tensors
extractor = FeatureExtractor().to(device)
with torch.no_grad():
    edge, gray_hist, segmented_rgb, _ = extractor(pixel_values)

# 5) Run inference (reconstruction)
with torch.no_grad():
    outputs = model(
        source_edge=edge,
        source_gray_level_histogram=gray_hist,
        source_segmented_rgb=segmented_rgb,
    )
# Your repo's forward returns may differ; adjust the key accordingly:
reconstructed = outputs["logits_reshape"]  # (1, 3, H, W)

# 6) Convert to PIL for visualisation
to_pil = transforms.ToPILImage()
recon_img = to_pil(reconstructed.squeeze(0).cpu().clamp(0, 1))
recon_img.save("reconstructed.png")
print("Saved to reconstructed.png")
```
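
Because the descriptors are passed to the model as separate inputs, they can in principle be recombined across images, e.g. geometry (edges) from one photo with global tone (histogram) from another. The sketch below reuses `model`, `extractor`, `transform`, and `device` from the snippet above; it is an illustrative experiment rather than a documented feature of the repo, and whether a given mixture yields a sensible reconstruction depends on how the model was trained.

```python
# Hypothetical descriptor-swap: edges and segmented colors from image A,
# grayscale histogram (global tone) from image B.
img_a = transform(Image.open("a.jpg").convert("RGB")).unsqueeze(0).to(device)
img_b = transform(Image.open("b.jpg").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    edge_a, hist_a, seg_a, _ = extractor(img_a)
    edge_b, hist_b, seg_b, _ = extractor(img_b)
    mixed = model(
        source_edge=edge_a,                   # geometry from A
        source_gray_level_histogram=hist_b,   # global tone from B
        source_segmented_rgb=seg_a,           # appearance from A
    )["logits_reshape"]                       # adjust the key if your forward differs

transforms.ToPILImage()(mixed.squeeze(0).cpu().clamp(0, 1)).save("mixed.png")
```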
|
|
|
### 2) Reproducing the notebook flow (`notebook/validation.ipynb`) |
|
|
|
The repository provides a validation notebook that: |
|
1. Loads the trained model, |
|
2. Uses `FeatureExtractor` to compute **edges**, **color-segmented RGB**, and **grayscale histograms**, |
|
3. Runs the model to obtain a reconstructed image, |
|
4. Saves/visualises the result. |
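
If you have Jupyter installed, you can reproduce it locally by running `jupyter notebook notebook/validation.ipynb` from the repository root.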
|
|
|
--- |
|
|
|
## Installation & Requirements |
|
|
|
```bash
# clone the VisualSplit code
git clone https://github.com/HenryQUQ/VisualSplit.git
cd VisualSplit

# install the package in editable mode
pip install -e .
```
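
The inference snippet above additionally relies on `torch`, `torchvision`, `safetensors`, `huggingface_hub`, and `Pillow`; if the editable install does not pull these in, install them separately with pip.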
|
|
|
--- |
|
|
|
## Training Data |
|
|
|
- **Dataset**: **ImageNet-1K**. |
|
|
> This repository only hosts the **trained checkpoint for inference**. Follow the GitHub repo for the full training pipeline and data preparation scripts. |
|
|
|
--- |
|
|
|
## Model Sources |
|
|
|
- **Code**: https://github.com/HenryQUQ/VisualSplit |
|
- **Weights (this page)**: this Hugging Face model repo |
|
|
|
--- |
|
|
|
## Citation |
|
|
|
If you use this model or the ideas behind it, please cite:
|
|
|
```bibtex
@inproceedings{Qu2025VisualSplit,
  title     = {Exploring Image Representation with Decoupled Classical Visual Descriptors},
  author    = {Qu, Chenyuan and Chen, Hao and Jiao, Jianbo},
  booktitle = {British Machine Vision Conference (BMVC)},
  year      = {2025}
}
```
|
|
|
--- |