import torch
import numpy as np
from PIL import Image
from einops import repeat
from datasets import load_dataset, concatenate_datasets
from IPython.display import display, HTML
from torchvision.transforms import ToPILImage, PILToTensor, Compose
from torchvision.transforms import Resize, RandomCrop, CenterCrop, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation, InterpolationMode
from vit_pytorch.mae import MAE
from vit_pytorch.simple_vit_with_register_tokens import SimpleViT
from einops.layers.torch import Rearrange
# Placeholder class so torch.load can unpickle the Args namespace saved in the checkpoint.
class Args: pass
device = "cpu"
checkpoint = torch.load("v0.0.1.pt", map_location="cpu")
args = checkpoint['args']
args.crops_per_sample = 1
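
Everything the models below need comes from the `args` namespace stored in the checkpoint; a quick way to see what was loaded (a sketch listing only the attributes this notebook consumes):

```python
for name in ("img_dim", "patch_size", "num_classes", "embed_dim", "depth",
             "heads", "mlp_dim", "masking_ratio", "decoder_depth"):
    print(name, "=", getattr(args, name))
```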

encoder = SimpleViT(
    image_size = args.img_dim[1],
    channels = args.img_dim[0],
    patch_size = args.patch_size,
    num_classes = args.num_classes,
    dim = args.embed_dim,
    depth = args.depth,
    heads = args.heads,
    mlp_dim = args.mlp_dim,
    dim_head = args.embed_dim//args.heads,
).to(device)

model = MAE(
    encoder=encoder,
    decoder_dim=args.embed_dim,
    masking_ratio=args.masking_ratio,
    decoder_depth=args.decoder_depth,
    decoder_heads=args.heads,
    decoder_dim_head=args.embed_dim//args.heads,
).to(device)

model.load_state_dict(checkpoint['model_state_dict'])
<All keys matched successfully>
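
With the weights restored, a quick size check (a sketch using standard PyTorch introspection):

```python
n_encoder = sum(p.numel() for p in model.encoder.parameters())
n_total = sum(p.numel() for p in model.parameters())
print(f"encoder parameters: {n_encoder:,}")
print(f"total parameters (encoder + decoder): {n_total:,}")
```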
dataset = load_dataset("danjacobellis/cell_synthetic_labels")
transforms = Compose([
    RandomCrop(896),
    RandomRotation(22.5),
    CenterCrop(672),
    Resize(224, interpolation=InterpolationMode.LANCZOS),
    RandomVerticalFlip(0.5),
    RandomHorizontalFlip(0.5),
    PILToTensor(),
])
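
The large RandomCrop happens before rotation so that the subsequent CenterCrop can discard the empty borders rotation introduces, and the result is resized to the model's 224-pixel input. A quick check of the pipeline output (a sketch, assuming the dataset's `image` field is a PIL image, as `collate_fn` below also assumes):

```python
example = dataset['validation'][0]['image']
out = transforms(example)
print(out.shape, out.dtype)  # expect torch.Size([C, 224, 224]) torch.uint8
```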

def collate_fn(batch):
    # Apply the augmentation pipeline and stack crops_per_sample crops of each
    # sample into a single uint8 batch of shape (B, C, H, W).
    batch_size = len(batch)*args.crops_per_sample
    inputs = torch.zeros(
        (batch_size, args.img_dim[0], args.img_dim[1], args.img_dim[2]),
        dtype=torch.uint8
    )
    for i_sample, sample in enumerate(batch):
        img = sample['image']
        for i_crop in range(args.crops_per_sample):
            ind = i_sample*args.crops_per_sample + i_crop
            inputs[ind,:,:,:] = transforms(img)

    return inputs
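
A minimal sanity check of `collate_fn` (a sketch): with `crops_per_sample = 1`, two dataset records should collate into a batch of shape `(2, C, 224, 224)`.

```python
mini_batch = [dataset['validation'][i] for i in range(2)]
print(collate_fn(mini_batch).shape)
```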
data_loader_valid = torch.utils.data.DataLoader(
    dataset['validation'],
    batch_size=8,
    shuffle=False,
    num_workers=args.num_workers,
    drop_last=False,
    pin_memory=True,
    collate_fn=collate_fn
)
with torch.no_grad():
    x = next(iter(data_loader_valid))
    x = x.to(torch.float)
    x = x / 255  # scale uint8 pixel values to [0, 1]
    x = x.to(device)

    # Patchify the batch and embed each patch, adding the encoder's positional embedding.
    patches = model.to_patch(x)
    batch, num_patches, *_ = patches.shape

    tokens = model.patch_to_emb(patches)
    tokens += model.encoder.pos_embedding.to(device, dtype=tokens.dtype)

    # Choose random patches to mask: argsort of uniform noise yields a random
    # permutation per image; the first num_masked indices form the mask.
    num_masked = int(model.masking_ratio * num_patches)
    rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
    masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

    # The encoder sees only the unmasked tokens.
    batch_range = torch.arange(batch, device = device)[:, None]
    tokens = tokens[batch_range, unmasked_indices]

    masked_patches = patches[batch_range, masked_indices]
    encoded_tokens = model.encoder.transformer(tokens)
    # Project encoder outputs to the decoder dimension and add decoder positional embeddings.
    decoder_tokens = model.enc_to_dec(encoded_tokens)
    unmasked_decoder_tokens = decoder_tokens + model.decoder_pos_emb(unmasked_indices)

    # Fill masked positions with the learned mask token (plus positional embedding),
    # scatter all tokens back into their original order, and decode.
    mask_tokens = repeat(model.mask_token, 'd -> b n d', b = batch, n = num_masked)
    mask_tokens = mask_tokens + model.decoder_pos_emb(masked_indices)

    decoder_tokens = torch.zeros(batch, num_patches, model.decoder_dim, device=device)
    decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
    decoder_tokens[batch_range, masked_indices] = mask_tokens
    decoded_tokens = model.decoder(decoder_tokens)

    # Predict pixel values for the masked patches and score them against the originals.
    mask_tokens = decoded_tokens[batch_range, masked_indices]
    pred_pixel_values = model.to_pixels(mask_tokens)

    recon_loss = torch.nn.functional.mse_loss(pred_pixel_values, masked_patches)
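
The cell above mirrors `MAE.forward` from vit_pytorch step by step, so calling the model directly should give a loss of the same magnitude (a sketch; the mask is resampled on every call, so the two values will not match exactly):

```python
with torch.no_grad():
    print("walkthrough loss:", recon_loss.item())
    print("model(x) loss:  ", model(x).item())
```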
def reconstruct_image(patches, model_input, masked_indices=None, pred_pixel_values=None, patch_size=8):
    """Reassemble flattened patches into a (B, C, H, W) image array.

    With masked_indices only, masked patches are blanked to zero;
    with pred_pixel_values as well, masked patches are replaced by the predictions.
    """
    patches = patches.cpu()
    masked_indices_in = masked_indices is not None
    predicted_pixels_in = pred_pixel_values is not None
    if masked_indices_in:
        masked_indices = masked_indices.cpu()
    if predicted_pixels_in:
        pred_pixel_values = pred_pixel_values.cpu()
    patch_width = patch_height = patch_size
    reconstructed_image = patches.clone()
    if masked_indices_in or predicted_pixels_in:
        for i in range(reconstructed_image.shape[0]):
            if masked_indices_in and predicted_pixels_in:
                reconstructed_image[i, masked_indices[i]] = pred_pixel_values[i].float()
            elif masked_indices_in:
                reconstructed_image[i, masked_indices[i]] = 0
    # Invert the encoder's patchify rearrangement.
    invert_patch = Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
                             w=model_input.shape[3] // patch_width,
                             h=model_input.shape[2] // patch_height,
                             c=model_input.shape[1],
                             p1=patch_height, p2=patch_width)
    reconstructed_image = invert_patch(reconstructed_image)
    return reconstructed_image.numpy()
with torch.no_grad():
    reconstructed_images1 = reconstruct_image(
        patches,
        x,
        masked_indices=masked_indices,
        pred_pixel_values=pred_pixel_values,
        patch_size=args.patch_size
    )
    reconstructed_images2 = reconstruct_image(
        patches,
        x,
        masked_indices=masked_indices,
        patch_size=args.patch_size
    )
# For each sample: show the original (first channel), the masked input, and the reconstruction.
for i_img, img in enumerate(x):
    rec1 = reconstructed_images1[i_img]  # masked patches filled with predictions
    rec2 = reconstructed_images2[i_img]  # masked patches blanked out
    display(ToPILImage()(img[0]))
    display(ToPILImage()(rec2[0]))
    display(ToPILImage()(rec1[0]))

*(Output: 24 images, three per validation sample: the original image, the masked input with dropped patches blanked, and the MAE reconstruction.)*

!jupyter nbconvert --to markdown README.ipynb
[NbConvertApp] Converting notebook README.ipynb to markdown
[NbConvertApp] Support files will be in README_files/
[NbConvertApp] Writing 7517 bytes to README.md