|
|
|
|
|
|
|
|
|
|
|
from logging import getLogger |
|
|
|
# Module-level seed constant. It is not referenced anywhere in the visible
# part of this file -- NOTE(review): confirm whether other modules read it.
_GLOBAL_SEED = 0 |
|
# getLogger() is called without a name, so this is the root logger rather
# than a module-specific one.
logger = getLogger() |
|
|
|
|
|
def init_data(
    batch_size,
    transform=None,
    shared_transform=None,
    data="ImageNet",
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    drop_last=True,
    subset_file=None,
    clip_len=None,
    dataset_fpcs=None,
    frame_sample_rate=None,
    duration=None,
    fps=None,
    num_clips=1,
    random_clip_sampling=True,
    allow_clip_overlap=False,
    filter_short_videos=False,
    filter_long_videos=int(1e9),
    datasets_weights=None,
    persistent_workers=False,
    deterministic=True,
    log_dir=None,
):
    """Build a data loader and its sampler for the requested dataset type.

    Dispatches on ``data`` (compared case-insensitively) to one of the
    project's dataset factories and forwards the relevant arguments:

    * ``"imagenet"`` -> ``src.datasets.imagenet1k.make_imagenet1k``.
      Forwards ``transform``, ``batch_size``, ``collator``, ``pin_mem``,
      ``training``, ``num_workers``, ``world_size``, ``rank``,
      ``root_path``, ``image_folder``, ``persistent_workers``,
      ``drop_last`` and ``subset_file``. The video-specific arguments and
      ``shared_transform`` are ignored.
    * ``"videodataset"`` -> ``src.datasets.video_dataset.make_videodataset``.
      Forwards ``root_path`` as ``data_paths``, ``clip_len`` as
      ``frames_per_clip`` and ``frame_sample_rate`` as ``frame_step``,
      plus ``batch_size``, ``dataset_fpcs``, ``duration``, ``fps``,
      ``num_clips``, ``random_clip_sampling``, ``allow_clip_overlap``,
      ``filter_short_videos``, ``filter_long_videos``,
      ``shared_transform``, ``transform``, ``datasets_weights``,
      ``collator``, ``num_workers``, ``pin_mem``, ``persistent_workers``,
      ``world_size``, ``rank``, ``deterministic`` and ``log_dir``.
      ``image_folder``, ``training``, ``drop_last`` and ``subset_file``
      are not forwarded on this path.

    The meaning of each forwarded argument is defined by the factory that
    receives it.

    Returns:
        A ``(data_loader, dist_sampler)`` tuple. The dataset object that
        the factory also returns is discarded.

    Raises:
        ValueError: If ``data`` names neither supported dataset type.
    """
    dataset_kind = data.lower()

    # The factory imports stay inside their branches so that only the
    # module for the selected dataset type gets imported.
    if dataset_kind == "imagenet":
        from src.datasets.imagenet1k import make_imagenet1k

        _dataset, data_loader, dist_sampler = make_imagenet1k(
            transform=transform,
            batch_size=batch_size,
            collator=collator,
            pin_mem=pin_mem,
            training=training,
            num_workers=num_workers,
            world_size=world_size,
            rank=rank,
            root_path=root_path,
            image_folder=image_folder,
            persistent_workers=persistent_workers,
            drop_last=drop_last,
            subset_file=subset_file,
        )

    elif dataset_kind == "videodataset":
        from src.datasets.video_dataset import make_videodataset

        _dataset, data_loader, dist_sampler = make_videodataset(
            data_paths=root_path,
            batch_size=batch_size,
            frames_per_clip=clip_len,
            dataset_fpcs=dataset_fpcs,
            frame_step=frame_sample_rate,
            duration=duration,
            fps=fps,
            num_clips=num_clips,
            random_clip_sampling=random_clip_sampling,
            allow_clip_overlap=allow_clip_overlap,
            filter_short_videos=filter_short_videos,
            filter_long_videos=filter_long_videos,
            shared_transform=shared_transform,
            transform=transform,
            datasets_weights=datasets_weights,
            collator=collator,
            num_workers=num_workers,
            pin_mem=pin_mem,
            persistent_workers=persistent_workers,
            world_size=world_size,
            rank=rank,
            deterministic=deterministic,
            log_dir=log_dir,
        )

    else:
        # Without this branch the return below would fail with an
        # UnboundLocalError that says nothing about the bad argument.
        raise ValueError(
            f"Unsupported dataset type {data!r}; "
            "expected 'ImageNet' or 'VideoDataset' (case-insensitive)."
        )

    return (data_loader, dist_sampler)
|
|