VisualSplit

VisualSplit is a ViT-based model that explicitly factorises an image into classical visual descriptors (such as edges, color segmentation, and a grayscale histogram) and learns to reconstruct the image conditioned on those descriptors. This design yields interpretable representations in which geometry (edges), albedo/appearance (segmented colors), and global tone (histogram) can be reasoned about or varied independently.
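The three descriptor types named above can be made concrete with a toy, pure-Python sketch on a tiny image. This is an illustration under stated assumptions, not the repo's FeatureExtractor. The BT.601 grayscale weights, the forward-difference edge detector with a fixed threshold, the bin count, and per-channel quantisation (standing in for real color segmentation) are all hypothetical choices.

```python
# Toy illustration of the three descriptor types (geometry, appearance, tone).
# ASSUMPTION: NOT the VisualSplit FeatureExtractor; every threshold, bin count
# and quantisation rule below is a hypothetical choice for illustration only.

RGBGrid = list[list[tuple[float, float, float]]]  # H x W grid of (r, g, b) in [0, 1]


def to_gray(img: RGBGrid) -> list[list[float]]:
    """Luma using ITU-R BT.601 weights."""
    return [[0.299 * r + 0.587 * g + 0.114 * b for (r, g, b) in row] for row in img]


def gray_histogram(gray: list[list[float]], bins: int = 8) -> list[float]:
    """Normalised histogram of grayscale values (global tone)."""
    counts = [0] * bins
    for row in gray:
        for v in row:
            counts[min(int(v * bins), bins - 1)] += 1  # clamp v == 1.0 into last bin
    total = sum(counts)
    return [c / total for c in counts]


def edge_map(gray: list[list[float]], threshold: float = 0.25) -> list[list[int]]:
    """Binary edges from forward-difference gradient magnitude (geometry)."""
    h, w = len(gray), len(gray[0])
    edges = [[0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            gx = gray[y][min(x + 1, w - 1)] - gray[y][x]
            gy = gray[min(y + 1, h - 1)][x] - gray[y][x]
            if (gx * gx + gy * gy) ** 0.5 > threshold:
                edges[y][x] = 1
    return edges


def segment_colors(img: RGBGrid, levels: int = 2) -> RGBGrid:
    """Crude color 'segmentation': snap each channel to `levels` (>= 2) values."""
    def q(c: float) -> float:
        return round(c * (levels - 1)) / (levels - 1)
    return [[(q(r), q(g), q(b)) for (r, g, b) in row] for row in img]


if __name__ == "__main__":
    # 2x2 image: black left column, white right column -> a vertical edge.
    img = [[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)], [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]]
    gray = to_gray(img)
    print(edge_map(gray))           # [[1, 0], [1, 0]]
    print(gray_histogram(gray, 4))  # [0.5, 0.0, 0.0, 0.5]
```

Each function captures one factor, so in principle one factor can be changed while the other two stay fixed. That is the property the model's decoupled conditioning is built around.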

Training data: ImageNet-1K.


Model Description

  • Inputs (at inference):
    • An RGB image, supplied for convenience. The model itself is conditioned on descriptors, which the provided FeatureExtractor computes from the image: an edge map, a color-segmented RGB image, and a grayscale histogram.
  • Outputs:
    • A reconstructed RGB image tensor of shape (1, 3, H, W) at the model’s training resolution (224×224 by default, unless you trained at a different size).
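For a ViT, the input resolution and the patch size (patch_size=16 in the config of the inference script) together fix the number of patch tokens. The sketch below assumes CrossViT splits the image into non-overlapping square patches like a standard ViT. That is an assumption, not something this card states. `num_patches` is a hypothetical helper.

```python
def num_patches(image_size: int = 224, patch_size: int = 16) -> int:
    """Token count for a standard ViT with non-overlapping square patches (assumed)."""
    if image_size % patch_size != 0:
        raise ValueError(
            f"image_size {image_size} is not divisible by patch_size {patch_size}"
        )
    per_side = image_size // patch_size
    return per_side * per_side


print(num_patches())         # 14 x 14 grid -> 196
print(num_patches(384, 16))  # 24 x 24 grid -> 576
```

If you change `image_size` in the config, keeping it a multiple of `patch_size` avoids a partial patch at the image border.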

Getting Started (Inference)

Below are two ways to run inference with the uploaded model.safetensors checkpoint.

1) Minimal PyTorch + safetensors (load state dict)

```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# 1) Import your model & config from the VisualSplit repo
from visualsplit.models.CrossViT import CrossViTForPreTraining, CrossViTConfig
from visualsplit.utils import FeatureExtractor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Build a config matching your training (edit if you changed widths/depths)
config = CrossViTConfig(
    image_size=224,           # change if your training size differs
    patch_size=16,
    # ... any other config fields your repo exposes
)

model = CrossViTForPreTraining(config).to(device)
model.eval()

# 3) Download and load the state dict from this model repo
#    Replace REPO_ID with your Hugging Face model id, e.g. "HenryQUQ/visualsplit"
ckpt_path = hf_hub_download(repo_id="REPO_ID", filename="model.safetensors")
state_dict = load_file(ckpt_path)
# strict=False tolerates key mismatches; non-empty lists below usually mean
# the config does not match the checkpoint
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

# 4) Prepare an input image and extract descriptors
from PIL import Image
from torchvision import transforms

image = Image.open("input.jpg").convert("RGB")
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
])
pixel_values = transform(image).unsqueeze(0).to(device)   # (1, 3, H, W)

# FeatureExtractor provided by the repo should return the required tensors
extractor = FeatureExtractor().to(device)
with torch.no_grad():
    edge, gray_hist, segmented_rgb, _ = extractor(pixel_values)

# 5) Run inference (reconstruction)
with torch.no_grad():
    outputs = model(
        source_edge=edge,
        source_gray_level_histogram=gray_hist,
        source_segmented_rgb=segmented_rgb,
    )
# Your repo’s forward returns may differ; adjust the key accordingly:
reconstructed = outputs["logits_reshape"]  # (1, 3, H, W)

# 6) Convert to PIL for visualisation
to_pil = transforms.ToPILImage()
recon_img = to_pil(reconstructed.squeeze(0).cpu().clamp(0, 1))
recon_img.save("reconstructed.png")
print("Saved to reconstructed.png")
```
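Two spots in the script above can fail silently. Loading with `strict=False` only prints mismatched keys, and the output key `logits_reshape` may differ between repo versions. Below is a minimal sketch of two hypothetical helpers (not part of the VisualSplit repo) that fail loudly instead. `get_reconstruction` accepts either dict-style or attribute-style outputs.

```python
from collections.abc import Mapping, Sequence


def check_load_report(missing: Sequence[str], unexpected: Sequence[str]) -> None:
    """Raise if a strict=False load skipped or ignored any weights.

    `missing` / `unexpected` are the two lists returned by
    torch.nn.Module.load_state_dict(..., strict=False).
    """
    problems = []
    if missing:
        problems.append(f"{len(missing)} missing keys, e.g. {list(missing)[:3]}")
    if unexpected:
        problems.append(f"{len(unexpected)} unexpected keys, e.g. {list(unexpected)[:3]}")
    if problems:
        raise RuntimeError("Checkpoint/config mismatch: " + "; ".join(problems))


def get_reconstruction(outputs, key: str = "logits_reshape"):
    """Return the reconstruction from a dict-like or attribute-style output.

    ASSUMPTION: the forward pass exposes the image under `key`; change it to
    whatever your version of the repo returns.
    """
    if isinstance(outputs, Mapping) and key in outputs:
        return outputs[key]
    if hasattr(outputs, key):
        return getattr(outputs, key)
    raise KeyError(f"model output has no field {key!r}")
```

In the script, `check_load_report(missing, unexpected)` would replace the two `print` calls, and `get_reconstruction(outputs)` would replace the `outputs["logits_reshape"]` lookup.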

2) Reproducing the notebook flow (notebook/validation.ipynb)

The repository provides a validation notebook that:

  1. Loads the trained model,
  2. Uses FeatureExtractor to compute edges, color-segmented RGB, and grayscale histograms,
  3. Runs the model to obtain a reconstructed image,
  4. Saves/visualises the result.
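Reconstruction is conditioned on separate descriptors, so one way to probe the factorisation described at the top of this card is to take some descriptors from one image and the rest from another, then make the same forward call. The sketch below assumes each extractor call yields `(edge, gray_hist, segmented_rgb, _)` as in the inference script. `mix_descriptors` and the short descriptor names are hypothetical, and it makes no claim about how the released checkpoint behaves on mixed inputs.

```python
# Short descriptor names (hypothetical) -> forward() keyword names used on this card.
DESCRIPTOR_KWARGS = {
    "edge": "source_edge",
    "gray_hist": "source_gray_level_histogram",
    "segmented_rgb": "source_segmented_rgb",
}


def mix_descriptors(desc_a: dict, desc_b: dict, from_b: set[str]) -> dict:
    """Build forward() kwargs, taking the descriptors named in `from_b` from image B."""
    unknown = set(from_b) - set(DESCRIPTOR_KWARGS)
    if unknown:
        raise ValueError(f"unknown descriptor names: {sorted(unknown)}")
    return {
        kwarg: (desc_b if name in from_b else desc_a)[name]
        for name, kwarg in DESCRIPTOR_KWARGS.items()
    }


# Usage with objects from the inference script (not executed here):
#   names = ("edge", "gray_hist", "segmented_rgb")
#   a = dict(zip(names, extractor(pixel_values_a)[:3]))
#   b = dict(zip(names, extractor(pixel_values_b)[:3]))
#   outputs = model(**mix_descriptors(a, b, from_b={"segmented_rgb"}))
```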

Installation & Requirements

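```bash
# clone the VisualSplit code
git clone https://github.com/HenryQUQ/VisualSplit.git
cd VisualSplit
# pip install -e .   # if the repo is installable; otherwise run from the repo root so `visualsplit` is importable
```

The inference script also imports torch, torchvision, safetensors, huggingface_hub and Pillow, so these must be installed in the same environment.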
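One quick way to confirm the environment before running the inference script is to check that every top-level module it imports can be found. This is a small sketch using only the standard library. `missing_modules` is a hypothetical helper, and `visualsplit` resolves only if the package was installed or you run from the repository root.

```python
import importlib.util

# Top-level modules imported by the inference script on this card.
REQUIRED = ["torch", "torchvision", "safetensors", "huggingface_hub", "PIL", "visualsplit"]


def missing_modules(names: list[str]) -> list[str]:
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    print("All imports available" if not missing else f"Missing: {missing}")
```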

Training Data

  • Dataset: ImageNet-1K.
  • This repository only hosts the trained checkpoint for inference. Follow the GitHub repo for the full training pipeline and data preparation scripts.


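Model Sources

  • Code: https://github.com/HenryQUQ/VisualSplit
  • Paper: Exploring Image Representation with Decoupled Classical Visual Descriptors (BMVC 2025)
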
Citation

If you use this model or its ideas, please cite:

```bibtex
@inproceedings{Qu2025VisualSplit,
  title     = {Exploring Image Representation with Decoupled Classical Visual Descriptors},
  author    = {Qu, Chenyuan and Chen, Hao and Jiao, Jianbo},
  booktitle = {British Machine Vision Conference (BMVC)},
  year      = {2025}
}
```

Checkpoint Details

  • Format: Safetensors (model.safetensors)
  • Model size: 1.33B params
  • Tensor type: F32
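As a rough guide to memory needs: F32 stores each parameter in 4 bytes, so the weights listed as 1.33B params take about 1.33 × 10^9 × 4 ≈ 5.32 GB (about 4.95 GiB) on their own. That figure excludes activations and other runtime overhead. The arithmetic as a small helper (`weight_bytes` is a hypothetical name; the parameter count is the rounded figure shown above):

```python
def weight_bytes(n_params: int, bytes_per_param: int = 4) -> int:
    """Raw storage for the weights only (F32 = 4 bytes per parameter)."""
    return n_params * bytes_per_param


n = 1_330_000_000  # "1.33B params" as displayed (rounded)
size = weight_bytes(n)
print(f"{size / 1e9:.2f} GB, {size / 2**30:.2f} GiB")  # 5.32 GB, 4.95 GiB
```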