# TCube_Merging / data / datautils.py
# Uploaded by razaimam45 ("Upload 108 files")
# Commit a96891a (verified)
import os
from typing import Tuple
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Import MedIMeta
from medimeta import MedIMeta
from data.hoi_dataset import BongardDataset
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
from data.fewshot_datasets import *
import data.augmix_ops as augmentations
import medmnist
from medmnist import INFO, Evaluator
# Maps a dataset identifier (the ``set_id`` passed to the dataset builders)
# to the name of the directory that holds that dataset's images.
ID_to_DIRNAME = {
    # ImageNet and its variants.
    "I": "ImageNet",
    "A": "imagenet-a",
    "K": "ImageNet-Sketch",
    "R": "imagenet-r",
    "V": "imagenetv2-matched-frequency-format-val",
    # Other image classification benchmarks.
    "flower102": "Flower102",
    "dtd": "DTD",
    "pets": "OxfordPets",
    "cars": "StanfordCars",
    "ucf101": "UCF101",
    "caltech101": "Caltech101",
    "food101": "Food101",
    "sun397": "SUN397",
    "aircraft": "fgvc_aircraft",
    "eurosat": "eurosat",
    # NOTE(review): presumably medical-imaging datasets, judging by the
    # names only -- confirm against the data layout.
    "idrid": "IDRID",
    "isic2018": "ISIC2018",
    "pneumonia_guangzhou": "PneumoniaGuangzhou",
    "shenzhen_cxr": "ShenzhenCXR",
    "montgomery_cxr": "MontgomeryCXR",
    "covid": "Covid",
}
def build_dataset(set_id, transform, data_root, mode='test', n_shot=None, split="all", bongard_anno=False):
    """Build an ``ImageFolder`` dataset for the dataset identified by *set_id*.

    The images are read from ``<data_root>/<set_id>/<ID_to_DIRNAME[set_id]>``.

    Args:
        set_id: Dataset identifier; must be a key of ``ID_to_DIRNAME``.
        transform: Transform handed to ``datasets.ImageFolder``.
        data_root: Root directory containing one sub-directory per ``set_id``.
        mode, n_shot, split, bongard_anno: Accepted for interface
            compatibility but currently ignored.

    Returns:
        The constructed ``datasets.ImageFolder`` instance.

    Raises:
        KeyError: If *set_id* is not a key of ``ID_to_DIRNAME``.

    Note:
        Earlier dataset-specific handling (a ``val`` sub-folder for ImageNet,
        few-shot dataset construction via ``build_fewshot_dataset``, and the
        Bongard dataset) is disabled; every ``set_id`` now goes through the
        same ``ImageFolder`` path.
    """
    set_root = os.path.join(data_root, set_id)
    image_dir = os.path.join(set_root, ID_to_DIRNAME[set_id])
    return datasets.ImageFolder(image_dir, transform=transform)
def build_medimeta_dataset(data_root, task='bus', disease='Disease', transform=None):
    """Construct a ``MedIMeta`` dataset.

    ``data_root``, ``task`` and ``disease`` are forwarded positionally, in
    that order, to the ``MedIMeta`` constructor; ``transform`` is forwarded
    as a keyword argument.

    Returns:
        The constructed ``MedIMeta`` instance.
    """
    return MedIMeta(data_root, task, disease, transform=transform)
def build_medmnist_dataset(data_root, set_id, transform, split='test', size=224, download=False):
    """Instantiate the MedMNIST dataset class registered for *set_id*.

    The class name is read from ``INFO[set_id]['python_class']`` and resolved
    as an attribute of the ``medmnist`` package.

    Args:
        data_root: Passed to the dataset class as ``root``.
        set_id: Key into ``medmnist.INFO``.
        transform: Passed through as ``transform``.
        split: Passed through as ``split`` (default ``'test'``).
        size: Passed through as ``size`` (default ``224``).
        download: Passed through as ``download`` (default ``False``).

    Returns:
        An instance of the resolved MedMNIST dataset class.

    Raises:
        KeyError: If *set_id* is not present in ``INFO``.
    """
    class_name = INFO[set_id]['python_class']
    dataset_cls = getattr(medmnist, class_name)
    return dataset_cls(
        root=data_root,
        split=split,
        transform=transform,
        size=size,
        download=download,
    )
# Dataset identifiers that are treated as MedMNIST datasets.
medmnist_datasets = [
    "tissuemnist",
    "pathmnist",
    "chestmnist",
    "dermamnist",
    "octmnist",
    "pneumoniamnist",
    "retinamnist",
    "breastmnist",
    "bloodmnist",
    "organamnist",
    "organcmnist",
    "organsmnist",
]
# AugMix Transforms
def get_preaugment():
    """Return the pre-augmentation pipeline applied before AugMix.

    The pipeline is a ``transforms.Compose`` of a 224-pixel
    ``RandomResizedCrop`` followed by a ``RandomHorizontalFlip``.
    """
    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
    return transforms.Compose(steps)
def augmix(image, preprocess, aug_list, severity=1):
    """Produce one randomly augmented (AugMix-style) view of *image*.

    The image is first passed through the pipeline from ``get_preaugment()``
    and then through *preprocess*. If *aug_list* is empty, that preprocessed
    result is returned unchanged. Otherwise three augmentation chains are
    built, each applying one to three operations drawn at random from
    *aug_list*; their preprocessed outputs are combined with Dirichlet
    weights, and the combination is blended with the un-augmented result
    using a Beta(1, 1) coefficient.

    Args:
        image: Input image. The pre-augmented image must support ``.copy()``
            (presumably a PIL image -- confirm against callers).
        preprocess: Callable mapping an image to a torch tensor.
        aug_list: Sequence of callables invoked as ``op(image, severity)``.
        severity: Severity value forwarded to each augmentation operation.

    Returns:
        The preprocessed tensor, or the blended mixture tensor.
    """
    base_view = get_preaugment()(image)
    clean = preprocess(base_view)
    if len(aug_list) == 0:
        return clean

    # One mixing weight per chain, plus the coefficient blending the clean
    # view with the mixture. Drawn in this order to keep the RNG stream stable.
    chain_weights = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
    blend = np.float32(np.random.beta(1.0, 1.0))

    mixture = torch.zeros_like(clean)
    for weight in chain_weights:
        augmented = base_view.copy()
        depth = np.random.randint(1, 4)  # 1, 2 or 3 operations per chain
        for _ in range(depth):
            operation = np.random.choice(aug_list)
            augmented = operation(augmented, severity)
        mixture += weight * preprocess(augmented)

    return blend * clean + (1 - blend) * mixture
class AugMixAugmenter(object):
    """Callable that expands one input into a base view plus augmented views.

    Calling an instance returns a list whose first element is
    ``preprocess(base_transform(x))`` and whose remaining ``n_views``
    elements are produced by ``augmix`` from the raw input ``x``.
    """

    def __init__(self, base_transform, preprocess, n_views=2, augmix=False,
                 severity=1):
        """Store the transforms and select the augmentation operations.

        Args:
            base_transform: Transform applied to the input for the base view.
            preprocess: Transform applied after ``base_transform`` and inside
                ``augmix``.
            n_views: Number of augmented views generated per call.
            augmix: When truthy, use ``augmentations.augmentations`` as the
                operation list; otherwise use an empty list, so ``augmix``
                only applies the pre-augmentation and ``preprocess``.
            severity: Severity forwarded to the augmentation operations.
        """
        self.base_transform = base_transform
        self.preprocess = preprocess
        self.n_views = n_views
        self.aug_list = augmentations.augmentations if augmix else []
        self.severity = severity

    def __call__(self, x):
        """Return ``[base_view, view_1, ..., view_n]`` for the input *x*."""
        # The base view is computed first, before any augmented view.
        outputs = [self.preprocess(self.base_transform(x))]
        for _ in range(self.n_views):
            outputs.append(
                augmix(x, self.preprocess, self.aug_list, self.severity)
            )
        return outputs