|
|
import os |
|
|
from typing import Tuple |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
import torchvision.datasets as datasets |
|
|
|
|
|
|
|
|
from medimeta import MedIMeta |
|
|
|
|
|
from data.hoi_dataset import BongardDataset |
|
|
try: |
|
|
from torchvision.transforms import InterpolationMode |
|
|
BICUBIC = InterpolationMode.BICUBIC |
|
|
except ImportError: |
|
|
BICUBIC = Image.BICUBIC |
|
|
|
|
|
from data.fewshot_datasets import * |
|
|
import data.augmix_ops as augmentations |
|
|
|
|
|
import medmnist |
|
|
from medmnist import INFO, Evaluator |
|
|
|
|
|
ID_to_DIRNAME={ |
|
|
'I': 'ImageNet', |
|
|
'A': 'imagenet-a', |
|
|
'K': 'ImageNet-Sketch', |
|
|
'R': 'imagenet-r', |
|
|
'V': 'imagenetv2-matched-frequency-format-val', |
|
|
'flower102': 'Flower102', |
|
|
'dtd': 'DTD', |
|
|
'pets': 'OxfordPets', |
|
|
'cars': 'StanfordCars', |
|
|
'ucf101': 'UCF101', |
|
|
'caltech101': 'Caltech101', |
|
|
'food101': 'Food101', |
|
|
'sun397': 'SUN397', |
|
|
'aircraft': 'fgvc_aircraft', |
|
|
'eurosat': 'eurosat', |
|
|
'idrid':'IDRID', |
|
|
'isic2018':'ISIC2018', |
|
|
'pneumonia_guangzhou':'PneumoniaGuangzhou', |
|
|
'shenzhen_cxr':'ShenzhenCXR', |
|
|
"montgomery_cxr":'MontgomeryCXR', |
|
|
'covid':'Covid' |
|
|
} |
|
|
|
|
|
def build_dataset(set_id, transform, data_root, mode='test', n_shot=None, split="all", bongard_anno=False): |
|
|
|
|
|
testdir = os.path.join(os.path.join(data_root, set_id),ID_to_DIRNAME[set_id]) |
|
|
|
|
|
testset = datasets.ImageFolder(testdir, transform=transform) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return testset |
|
|
|
|
|
def build_medimeta_dataset(data_root, task='bus', disease='Disease', transform=None): |
|
|
dataset = MedIMeta(data_root, task, disease, transform=transform) |
|
|
return dataset |
|
|
|
|
|
def build_medmnist_dataset(data_root, set_id, transform, split='test', size=224, download=False): |
|
|
info = INFO[set_id] |
|
|
DataClass = getattr(medmnist, info['python_class']) |
|
|
dataset = DataClass(split=split, transform=transform, size=size, download=download, root=data_root) |
|
|
return dataset |
|
|
|
|
|
medmnist_datasets = [ |
|
|
'tissuemnist', 'pathmnist', 'chestmnist', 'dermamnist', 'octmnist', |
|
|
'pneumoniamnist', 'retinamnist', 'breastmnist', 'bloodmnist', |
|
|
'organamnist', 'organcmnist', 'organsmnist' |
|
|
] |
|
|
|
|
|
|
|
|
def get_preaugment(): |
|
|
return transforms.Compose([ |
|
|
transforms.RandomResizedCrop(224), |
|
|
transforms.RandomHorizontalFlip(), |
|
|
]) |
|
|
|
|
|
def augmix(image, preprocess, aug_list, severity=1): |
|
|
preaugment = get_preaugment() |
|
|
x_orig = preaugment(image) |
|
|
x_processed = preprocess(x_orig) |
|
|
if len(aug_list) == 0: |
|
|
return x_processed |
|
|
w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0])) |
|
|
m = np.float32(np.random.beta(1.0, 1.0)) |
|
|
|
|
|
mix = torch.zeros_like(x_processed) |
|
|
for i in range(3): |
|
|
x_aug = x_orig.copy() |
|
|
for _ in range(np.random.randint(1, 4)): |
|
|
x_aug = np.random.choice(aug_list)(x_aug, severity) |
|
|
mix += w[i] * preprocess(x_aug) |
|
|
mix = m * x_processed + (1 - m) * mix |
|
|
return mix |
|
|
|
|
|
|
|
|
class AugMixAugmenter(object): |
|
|
def __init__(self, base_transform, preprocess, n_views=2, augmix=False, |
|
|
severity=1): |
|
|
self.base_transform = base_transform |
|
|
self.preprocess = preprocess |
|
|
self.n_views = n_views |
|
|
|
|
|
if augmix: |
|
|
self.aug_list = augmentations.augmentations |
|
|
else: |
|
|
self.aug_list = [] |
|
|
self.severity = severity |
|
|
|
|
|
def __call__(self, x): |
|
|
|
|
|
image = self.preprocess(self.base_transform(x)) |
|
|
views = [augmix(x, self.preprocess, self.aug_list, self.severity) for _ in range(self.n_views)] |
|
|
return [image] + views |
|
|
|
|
|
|
|
|
|
|
|
|