from typing import Optional

from torch.distributed import is_initialized
from torch.utils.data import Dataset, DistributedSampler

def get_ddp_sampler(dataset: Dataset, epoch: int) -> Optional[DistributedSampler]:
    """
    Create a DistributedSampler for the current epoch if DDP is initialized;
    return None otherwise so the caller can fall back to ordinary sampling.
    """
    if is_initialized():
        sampler = DistributedSampler(dataset)
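        # set_epoch reseeds the sampler's shuffle so each epoch yields a
        # different ordering; without it, every epoch sees the same permutation.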
        sampler.set_epoch(epoch)
    else:
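        # DDP not initialized: no sampler is needed, and the DataLoader
        # can handle shuffling on its own.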
        sampler = None
    return sampler
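
# A minimal usage sketch under one assumption: `make_loader` is a hypothetical
# helper, not part of the original code. It shows how the returned sampler is
# typically wired into a DataLoader: shuffle must be disabled whenever a
# sampler is supplied (DataLoader rejects shuffle=True together with a
# sampler), and the sampler is rebuilt each epoch so set_epoch takes effect.
from torch.utils.data import DataLoader

def make_loader(dataset: Dataset, epoch: int, batch_size: int = 32) -> DataLoader:
    sampler = get_ddp_sampler(dataset, epoch)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        # Shuffle only when no sampler is present; DistributedSampler already
        # shuffles internally, and per-epoch via set_epoch.
        shuffle=(sampler is None),
    )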