File size: 4,721 Bytes
0e37bb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import subprocess
import time
from logging import getLogger
import numpy as np
import torch
import torchvision
# Module-level seed constant; not referenced by the definitions reviewed in this file.
_GLOBAL_SEED = 0
# getLogger() called without a name returns the root logger.
logger = getLogger()
class ImageNet(torchvision.datasets.ImageFolder):
    """ImageFolder rooted at the ``train/`` or ``val/`` split of an ImageNet tree."""

    def __init__(
        self,
        root,
        image_folder="imagenet_full_size/061417/",
        tar_file="imagenet_full_size-061417.tar.gz",
        transform=None,
        train=True,
        job_id=None,
        local_rank=None,
        index_targets=False,
    ):
        """
        ImageNet

        Dataset wrapper

        :param root: root network directory for ImageNet data
        :param image_folder: path to images inside root network directory
        :param tar_file: zipped image_folder inside root network directory
            (accepted for interface compatibility; not read by this constructor)
        :param transform: transform forwarded to ``ImageFolder``
        :param train: whether to load train data (or validation)
        :param job_id: scheduler job-id used to create dir on local machine
            (accepted for interface compatibility; not read by this constructor)
        :param local_rank: accepted for interface compatibility; not read by
            this constructor
        :param index_targets: whether to index the id of each labeled image;
            when true, ``self.target_indices[t]`` is the list of sample
            indices whose label is ``t``
        """
        suffix = "train/" if train else "val/"
        data_path = os.path.join(root, image_folder, suffix)
        logger.info(f"data-path {data_path}")
        super().__init__(root=data_path, transform=transform)
        logger.info("Initialized ImageNet")

        if index_targets:
            self.targets = np.array([sample[1] for sample in self.samples])
            # NOTE(review): converting (path, int) tuples to one ndarray
            # presumably coerces the integer labels to strings; kept as-is
            # because callers may rely on it -- confirm before changing.
            self.samples = np.array(self.samples)

            mint = None
            self.target_indices = []
            for t in range(len(self.classes)):
                # flatnonzero always yields a 1-D result, so a class with
                # exactly one sample still produces a one-element list
                # (squeeze(argwhere(...)) collapsed it to a bare int, which
                # broke the len() calls below).
                indices = np.flatnonzero(self.targets == t).tolist()
                self.target_indices.append(indices)
                mint = len(indices) if mint is None else min(mint, len(indices))
                logger.debug(f"num-labeled target {t} {len(indices)}")
            logger.info(f"min. labeled indices {mint}")
class ImageNetSubset:
    """View over an ImageNet dataset restricted to the images listed in a file.

    Samples are loaded through the wrapped dataset's ``loader``, ``transform``
    and ``target_transform``.
    """

    def __init__(self, dataset, subset_file):
        """
        ImageNetSubset

        :param dataset: ImageNet dataset object
        :param subset_file: '.txt' file containing IDs of IN1K images to keep
        """
        self.dataset = dataset
        self.subset_file = subset_file
        self.filter_dataset_(subset_file)

    def filter_dataset_(self, subset_file):
        """Filter self.dataset to a subset

        Each non-blank line of ``subset_file`` is an image file name whose
        prefix before the first ``_`` is the class directory name. Blank or
        whitespace-only lines are skipped and surrounding whitespace is
        ignored.

        :param subset_file: path of the text file listing images to keep
        :raises KeyError: if a line names a class missing from
            ``dataset.class_to_idx``
        """
        root = self.dataset.root
        class_to_idx = self.dataset.class_to_idx
        # -- update samples to subset of IN1k targets/samples
        new_samples = []
        logger.info(f"Using {subset_file}")
        with open(subset_file, "r") as rfile:
            for line in rfile:
                img = line.strip()
                if not img:
                    # A blank line carries no image id; without this guard
                    # it would be looked up as a class name and raise.
                    continue
                class_name = img.split("_")[0]
                target = class_to_idx[class_name]
                new_samples.append((os.path.join(root, class_name, img), target))
        self.samples = new_samples

    @property
    def classes(self):
        """Class names of the wrapped dataset."""
        return self.dataset.classes

    def __len__(self):
        """Number of images kept by the subset file."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(img, target)`` after transforms."""
        path, target = self.samples[index]
        img = self.dataset.loader(path)
        if self.dataset.transform is not None:
            img = self.dataset.transform(img)
        if self.dataset.target_transform is not None:
            target = self.dataset.target_transform(target)
        return img, target
def make_imagenet1k(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    drop_last=True,
    persistent_workers=False,
    subset_file=None,
):
    """Build the ImageNet dataset, a distributed sampler, and a DataLoader.

    :param transform: transform forwarded to the ``ImageNet`` dataset
    :param batch_size: batch size given to the ``DataLoader``
    :param collator: ``collate_fn`` given to the ``DataLoader``
    :param pin_mem: ``pin_memory`` flag given to the ``DataLoader``
    :param num_workers: number of ``DataLoader`` worker processes
    :param world_size: ``num_replicas`` given to the ``DistributedSampler``
    :param rank: ``rank`` given to the ``DistributedSampler``
    :param root_path: root directory passed to ``ImageNet`` as ``root``
    :param image_folder: image directory inside ``root_path``; when ``None``
        the ``ImageNet`` default is used
    :param training: whether to load the train split (or validation)
    :param drop_last: ``drop_last`` flag given to the ``DataLoader``
    :param persistent_workers: ``persistent_workers`` flag given to the
        ``DataLoader``
    :param subset_file: optional text file of image IDs; when given, the
        dataset is wrapped in ``ImageNetSubset``
    :returns: tuple ``(dataset, data_loader, dist_sampler)``
    """
    dataset_kwargs = {
        "root": root_path,
        "transform": transform,
        "train": training,
        "index_targets": False,
    }
    # Forwarding None would override ImageNet's default and then fail inside
    # os.path.join; only pass the folder when the caller supplied one.
    if image_folder is not None:
        dataset_kwargs["image_folder"] = image_folder
    dataset = ImageNet(**dataset_kwargs)

    if subset_file is not None:
        dataset = ImageNetSubset(dataset, subset_file)
    logger.info("ImageNet dataset created")

    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset,
        num_replicas=world_size,
        rank=rank,
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collator,
        sampler=dist_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
    )
    logger.info("ImageNet unsupervised data loader created")

    return dataset, data_loader, dist_sampler
|