"""Dataset builders and AugMix-style multi-view augmentation helpers."""
import os
from typing import Tuple
from PIL import Image
import numpy as np

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Import MedIMeta
from medimeta import MedIMeta

from data.hoi_dataset import BongardDataset

# Newer torchvision exposes InterpolationMode; fall back to the PIL constant
# when that import is unavailable.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

from data.fewshot_datasets import *
import data.augmix_ops as augmentations

import medmnist
from medmnist import INFO, Evaluator

# Maps a dataset identifier (``set_id``) to the directory name holding its images.
ID_to_DIRNAME = {
    'I': 'ImageNet',
    'A': 'imagenet-a',
    'K': 'ImageNet-Sketch',
    'R': 'imagenet-r',
    'V': 'imagenetv2-matched-frequency-format-val',
    'flower102': 'Flower102',
    'dtd': 'DTD',
    'pets': 'OxfordPets',
    'cars': 'StanfordCars',
    'ucf101': 'UCF101',
    'caltech101': 'Caltech101',
    'food101': 'Food101',
    'sun397': 'SUN397',
    'aircraft': 'fgvc_aircraft',
    'eurosat': 'eurosat',
    'idrid': 'IDRID',
    'isic2018': 'ISIC2018',
    'pneumonia_guangzhou': 'PneumoniaGuangzhou',
    'shenzhen_cxr': 'ShenzhenCXR',
    'montgomery_cxr': 'MontgomeryCXR',
    'covid': 'Covid',
}


def build_dataset(set_id, transform, data_root, mode='test', n_shot=None, split="all", bongard_anno=False):
    """Build an ``ImageFolder`` dataset for ``set_id``.

    Images are read from ``<data_root>/<set_id>/<ID_to_DIRNAME[set_id]>``.

    Args:
        set_id: Key into ``ID_to_DIRNAME``; also used as the first directory
            level below ``data_root``.
        transform: Transform handed to ``datasets.ImageFolder``.
        data_root: Root directory containing the per-dataset folders.
        mode, n_shot, split, bongard_anno: Accepted for interface
            compatibility; the active code path does not read them (they are
            only referenced by the disabled dispatch kept below).

    Returns:
        The constructed ``datasets.ImageFolder``.

    Raises:
        KeyError: If ``set_id`` is not a key of ``ID_to_DIRNAME``.
    """
    dataset_root = os.path.join(data_root, set_id)
    image_dir = os.path.join(dataset_root, ID_to_DIRNAME[set_id])
    # testdir = os.path.join(os.path.join(data_root, ID_to_DIRNAME[set_id]), 'test')
    dataset = datasets.ImageFolder(image_dir, transform=transform)

    # Disabled per-dataset dispatch, kept for reference:
    # if set_id == 'I':
    #     # ImageNet validation set
    #     testdir = os.path.join(os.path.join(data_root, ID_to_DIRNAME[set_id]), 'val')
    #     testset = datasets.ImageFolder(testdir, transform=transform)
    # elif set_id in ['A', 'K', 'R', 'V']:
    #     testdir = os.path.join(data_root, ID_to_DIRNAME[set_id])
    #     testset = datasets.ImageFolder(testdir, transform=transform)
    # elif set_id in fewshot_datasets:
    #     if mode == 'train' and n_shot:
    #         testset = build_fewshot_dataset(set_id, os.path.join(data_root, ID_to_DIRNAME[set_id.lower()]), transform, mode=mode, n_shot=n_shot)
    #     else:
    #         testset = build_fewshot_dataset(set_id, os.path.join(data_root, ID_to_DIRNAME[set_id.lower()]), transform, mode=mode)
    # elif set_id == 'bongard':
    #     assert isinstance(transform, Tuple)
    #     base_transform, query_transform = transform
    #     testset = BongardDataset(data_root, split, mode, base_transform, query_transform, bongard_anno)
    # else:
    #     raise NotImplementedError

    return dataset


def build_medimeta_dataset(data_root, task='bus', disease='Disease', transform=None):
    """Construct a ``MedIMeta`` dataset from ``data_root``, ``task`` and ``disease``."""
    return MedIMeta(data_root, task, disease, transform=transform)


def build_medmnist_dataset(data_root, set_id, transform, split='test', size=224, download=False):
    """Instantiate the medmnist dataset class registered for ``set_id``.

    The class name is read from ``INFO[set_id]['python_class']`` and resolved
    as an attribute of the ``medmnist`` package; the remaining arguments are
    forwarded to its constructor as keywords.

    Raises:
        KeyError: If ``set_id`` is not present in ``INFO``.
    """
    class_name = INFO[set_id]['python_class']
    dataset_cls = getattr(medmnist, class_name)
    return dataset_cls(
        split=split,
        transform=transform,
        size=size,
        download=download,
        root=data_root,
    )


medmnist_datasets = [
    'tissuemnist',
    'pathmnist',
    'chestmnist',
    'dermamnist',
    'octmnist',
    'pneumoniamnist',
    'retinamnist',
    'breastmnist',
    'bloodmnist',
    'organamnist',
    'organcmnist',
    'organsmnist',
]


# AugMix Transforms
def get_preaugment():
    """Return the pre-augmentation: random resized crop to 224, then random horizontal flip."""
    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
    return transforms.Compose(steps)


def augmix(image, preprocess, aug_list, severity=1):
    """Produce one augmented view of ``image``.

    The image is first pre-augmented (see ``get_preaugment``) and preprocessed.
    With an empty ``aug_list`` that result is returned as is. Otherwise three
    augmentation chains are each applied to a copy of the pre-augmented image,
    preprocessed, combined with Dirichlet-sampled weights, and finally blended
    with the un-augmented preprocessed image using a Beta-sampled factor.

    Args:
        image: Input image; must be accepted by the pre-augmentation transforms.
        preprocess: Callable turning an image into a tensor.
        aug_list: Sequence of callables invoked as ``op(image, severity)``.
        severity: Value passed through to every augmentation op.

    Returns:
        The preprocessed (and, when ``aug_list`` is non-empty, mixed) tensor.
    """
    x_orig = get_preaugment()(image)
    x_processed = preprocess(x_orig)
    if len(aug_list) == 0:
        return x_processed

    chain_weights = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
    blend = np.float32(np.random.beta(1.0, 1.0))

    mix = torch.zeros_like(x_processed)
    for weight in chain_weights:
        augmented = x_orig.copy()
        # randint's upper bound is exclusive: each chain applies 1 to 3 ops.
        chain_length = np.random.randint(1, 4)
        for _ in range(chain_length):
            op = np.random.choice(aug_list)
            augmented = op(augmented, severity)
        mix += weight * preprocess(augmented)

    return blend * x_processed + (1 - blend) * mix


class AugMixAugmenter:
    """Callable that returns a base view plus ``n_views`` ``augmix`` views of an input."""

    def __init__(self, base_transform, preprocess, n_views=2, augmix=False, severity=1):
        """Store the transforms and select the augmentation op list.

        Args:
            base_transform: Transform applied before ``preprocess`` for the base view.
            preprocess: Callable turning an image into a tensor.
            n_views: Number of additional views produced per call.
            augmix: When true, use ``augmentations.augmentations`` as the op
                list; otherwise the op list is empty.
            severity: Value forwarded to ``augmix`` for every view.
        """
        self.base_transform = base_transform
        self.preprocess = preprocess
        self.n_views = n_views
        self.aug_list = augmentations.augmentations if augmix else []
        self.severity = severity

    def __call__(self, x):
        """Return ``[base_view, view_1, ..., view_n]`` for input ``x``."""
        image = self.preprocess(self.base_transform(x))
        views = []
        for _ in range(self.n_views):
            views.append(augmix(x, self.preprocess, self.aug_list, self.severity))
        return [image, *views]