# soccer-qa-4b / src / datasets / data_manager.py
# VarunKodathala's picture
# Upload folder using huggingface_hub
# 0e37bb2 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from logging import getLogger
# Module-level seed constant. It is not referenced anywhere in the visible
# part of this file. NOTE(review): confirm whether other modules import it.
_GLOBAL_SEED = 0
# getLogger() is called without a name, so this is the root logger.
logger = getLogger()
def init_data(
    batch_size,
    transform=None,
    shared_transform=None,
    data="ImageNet",
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    drop_last=True,
    subset_file=None,
    clip_len=None,
    dataset_fpcs=None,
    frame_sample_rate=None,
    duration=None,
    fps=None,
    num_clips=1,
    random_clip_sampling=True,
    allow_clip_overlap=False,
    filter_short_videos=False,
    filter_long_videos=int(1e9),
    datasets_weights=None,
    persistent_workers=False,
    deterministic=True,
    log_dir=None,
):
    """Build the data loader and distributed sampler for the dataset named by ``data``.

    ``data`` is matched case-insensitively against the supported dataset
    kinds and the call is delegated to the matching factory:

    * ``"imagenet"`` -> ``src.datasets.imagenet1k.make_imagenet1k``
    * ``"videodataset"`` -> ``src.datasets.video_dataset.make_videodataset``

    Arguments are passed through to the selected factory unchanged; their
    precise semantics are defined by those factories, not here.

    Forwarded to both factories:
        ``batch_size``, ``transform``, ``collator``, ``pin_mem``,
        ``num_workers``, ``world_size``, ``rank``, ``persistent_workers``
        and ``root_path`` (passed as ``root_path`` to the ImageNet factory
        and as ``data_paths`` to the video factory).

    Forwarded to the ImageNet factory only:
        ``image_folder``, ``training``, ``drop_last``, ``subset_file``.

    Forwarded to the video factory only:
        ``shared_transform``, ``clip_len`` (as ``frames_per_clip``),
        ``dataset_fpcs``, ``frame_sample_rate`` (as ``frame_step``),
        ``duration``, ``fps``, ``num_clips``, ``random_clip_sampling``,
        ``allow_clip_overlap``, ``filter_short_videos``,
        ``filter_long_videos``, ``datasets_weights``, ``deterministic``,
        ``log_dir``.

    Returns:
        A ``(data_loader, dist_sampler)`` tuple taken from the factory's
        return value. The dataset object the factory also returns is
        discarded.

    Raises:
        ValueError: If ``data`` does not name a supported dataset kind.
    """
    dataset_kind = data.lower()

    if dataset_kind == "imagenet":
        # Imported inside the branch so only the selected backend is loaded.
        from src.datasets.imagenet1k import make_imagenet1k

        _, data_loader, dist_sampler = make_imagenet1k(
            transform=transform,
            batch_size=batch_size,
            collator=collator,
            pin_mem=pin_mem,
            training=training,
            num_workers=num_workers,
            world_size=world_size,
            rank=rank,
            root_path=root_path,
            image_folder=image_folder,
            persistent_workers=persistent_workers,
            drop_last=drop_last,
            subset_file=subset_file,
        )
    elif dataset_kind == "videodataset":
        # Imported inside the branch so only the selected backend is loaded.
        from src.datasets.video_dataset import make_videodataset

        _, data_loader, dist_sampler = make_videodataset(
            data_paths=root_path,
            batch_size=batch_size,
            frames_per_clip=clip_len,
            dataset_fpcs=dataset_fpcs,
            frame_step=frame_sample_rate,
            duration=duration,
            fps=fps,
            num_clips=num_clips,
            random_clip_sampling=random_clip_sampling,
            allow_clip_overlap=allow_clip_overlap,
            filter_short_videos=filter_short_videos,
            filter_long_videos=filter_long_videos,
            shared_transform=shared_transform,
            transform=transform,
            datasets_weights=datasets_weights,
            collator=collator,
            num_workers=num_workers,
            pin_mem=pin_mem,
            persistent_workers=persistent_workers,
            world_size=world_size,
            rank=rank,
            deterministic=deterministic,
            log_dir=log_dir,
        )
    else:
        # Without this branch the return below would fail with an opaque
        # UnboundLocalError, because neither local would have been assigned.
        raise ValueError(
            f"Unsupported dataset type {data!r}; "
            "expected 'ImageNet' or 'VideoDataset' (case-insensitive)"
        )

    return (data_loader, dist_sampler)