# TextBraTS/utils/data_utils.py
# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os
import numpy as np
import torch
from monai import data, transforms
from monai.data import NibabelReader
from monai.transforms import MapTransform
# Load pre-computed BioBERT text-feature arrays (.npy) referenced in the datalist.
class LoadNumpyd(MapTransform):
    """Dictionary transform that materializes pre-computed text features.

    For every configured key the current value is interpreted as a path to
    a ``.npy`` file; the array is loaded with :func:`numpy.load` and its
    leading singleton axis is stripped via ``np.squeeze(..., axis=0)``.
    """

    def __init__(self, keys):
        super().__init__(keys)

    def __call__(self, data):
        # Work on a shallow copy so the caller's dict is left untouched.
        out = dict(data)
        for key in self.keys:
            loaded = np.load(out[key])
            # Drop the leading batch axis saved by the feature extractor.
            out[key] = np.squeeze(loaded, axis=0)
        return out
class Sampler(torch.utils.data.Sampler):
    """Distributed sampler that partitions a dataset across replicas.

    Behaves like ``torch.utils.data.distributed.DistributedSampler`` but
    additionally allows turning off the padding that makes every rank
    yield the same number of samples (``make_even=False``).
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
        # Fall back to the active process group when replica count / rank
        # are not supplied explicitly.
        if num_replicas is None or rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            if num_replicas is None:
                num_replicas = torch.distributed.get_world_size()
            if rank is None:
                rank = torch.distributed.get_rank()
        self.shuffle = shuffle
        self.make_even = make_even
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Per-rank sample count after even padding, and the padded total.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        all_ids = list(range(len(self.dataset)))
        self.valid_length = len(all_ids[self.rank : self.total_size : self.num_replicas])

    def __iter__(self):
        count = len(self.dataset)
        if self.shuffle:
            # Seed by epoch so every rank draws the same permutation.
            gen = torch.Generator()
            gen.manual_seed(self.epoch)
            order = torch.randperm(count, generator=gen).tolist()
        else:
            order = list(range(count))
        if self.make_even:
            deficit = self.total_size - len(order)
            if deficit > 0:
                if deficit < len(order):
                    # Pad by repeating a prefix of the index list.
                    order += order[:deficit]
                else:
                    # Dataset smaller than the deficit: sample with replacement.
                    extra = np.random.randint(low=0, high=len(order), size=deficit)
                    order += [order[pick] for pick in extra]
            assert len(order) == self.total_size
        # Round-robin assignment of indices to this rank.
        mine = order[self.rank : self.total_size : self.num_replicas]
        self.num_samples = len(mine)
        return iter(mine)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        # Called once per epoch so the shuffle permutation changes.
        self.epoch = epoch
def datafold_read(datalist, basedir, fold=0, key="training"):
    """Load a datalist JSON and split its entries into train/validation.

    Every string (or list-of-string) field of each entry is prefixed with
    ``basedir``; empty strings are left untouched.  Entries whose ``fold``
    field equals *fold* form the validation split, all others the
    training split.

    Returns:
        (train_entries, val_entries) as two lists of dicts.
    """
    with open(datalist) as handle:
        entries = json.load(handle)[key]

    # Resolve relative paths against the data root, in place.
    for entry in entries:
        for field, value in entry.items():
            if isinstance(value, list):
                entry[field] = [os.path.join(basedir, item) for item in value]
            elif isinstance(value, str):
                entry[field] = os.path.join(basedir, value) if value else value

    train, val = [], []
    for entry in entries:
        target = val if ("fold" in entry and entry["fold"] == fold) else train
        target.append(entry)
    return train, val
def _make_transform(roi, augment):
    """Build the shared BraTS preprocessing pipeline.

    Args:
        roi: ``[x, y, z]`` spatial size for ``Resized``.
        augment: when True, append the random intensity augmentations
            used only for training.

    The original code spelled out three near-identical ``Compose`` lists
    (train/val/test, with val and test byte-identical); this helper is the
    single source of truth for the shared steps.
    """
    steps = [
        transforms.LoadImaged(keys=["image", "label"], reader=NibabelReader()),
        LoadNumpyd(keys=["text_feature"]),
        transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        transforms.Resized(keys=["image", "label"], spatial_size=roi),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
    if augment:
        # Training-only random intensity jitter (prob=1.0 as in the
        # original pipeline).
        steps += [
            transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        ]
    steps.append(transforms.ToTensord(keys=["image", "label", "text_feature"]))
    return transforms.Compose(steps)


def get_loader(args):
    """Create the MONAI data loader(s) for training, validation, or testing.

    Args:
        args: namespace providing ``data_dir``, ``json_list``, ``fold``,
            ``roi_x``/``roi_y``/``roi_z``, ``batch_size``, ``workers``,
            ``distributed`` and ``test_mode``.

    Returns:
        A single test ``DataLoader`` when ``args.test_mode`` is truthy,
        otherwise ``[train_loader, val_loader]``.
    """
    train_files, validation_files = datafold_read(
        datalist=args.json_list, basedir=args.data_dir, fold=args.fold
    )
    roi = [args.roi_x, args.roi_y, args.roi_z]
    # Validation and test share the deterministic pipeline; training adds
    # random intensity augmentation.
    eval_transform = _make_transform(roi, augment=False)

    if args.test_mode:
        val_ds = data.Dataset(data=validation_files, transform=eval_transform)
        val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None
        loader = data.DataLoader(
            val_ds, batch_size=1, shuffle=False, num_workers=args.workers, sampler=val_sampler, pin_memory=True
        )
    else:
        train_ds = data.Dataset(data=train_files, transform=_make_transform(roi, augment=True))
        train_sampler = Sampler(train_ds) if args.distributed else None
        train_loader = data.DataLoader(
            train_ds,
            batch_size=args.batch_size,
            # Shuffle locally only when no distributed sampler owns the order.
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            sampler=train_sampler,
            pin_memory=True,
        )
        val_ds = data.Dataset(data=validation_files, transform=eval_transform)
        val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None
        val_loader = data.DataLoader(
            val_ds, batch_size=1, shuffle=False, num_workers=args.workers, sampler=val_sampler, pin_memory=True
        )
        loader = [train_loader, val_loader]
    return loader