""" |
|
Script to download datasets packaged with the repository. |
|
""" |
|
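
# Example invocation (a sketch -- the filename "download_datasets.py" is an assumption;
# adjust it to wherever this script lives in your checkout):
#
#   python download_datasets.py --dataset_type core --tasks square_d0 --dry_run
#
# Drop --dry_run to actually download the corresponding hdf5 files.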

import os
import argparse

import mimicgen_envs
import mimicgen_envs.utils.file_utils as FileUtils
from mimicgen_envs import DATASET_REGISTRY
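
# Note: based on how it is indexed below, DATASET_REGISTRY is assumed to be a
# nested dict mapping dataset_type -> task -> metadata (including a "url" entry);
# see the mimicgen_envs package for the actual registry contents.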


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--download_dir",
        type=str,
        default=None,
        help="Base download directory. Created if it doesn't exist. Defaults to data/robomimic/datasets under the current working directory.",
    )

    parser.add_argument(
        "--dataset_type",
        type=str,
        default="core",
        choices=list(DATASET_REGISTRY.keys()),
        help="Dataset type to download datasets for (e.g. source, core, object, robot, large_interpolation). Defaults to core.",
    )

    parser.add_argument(
        "--tasks",
        type=str,
        nargs='+',
        default=["square_d0"],
        help="Tasks to download datasets for. Defaults to the square_d0 task. Pass 'all' to download all tasks "
             "for the provided dataset type, or directly specify the list of tasks.",
    )

    parser.add_argument(
        "--dry_run",
        action='store_true',
        help="set this flag to do a dry run that only prints which datasets would be downloaded",
    )

    args = parser.parse_args()

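    # set the default base download directory if --download_dir was not provided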
    default_base_dir = args.download_dir
    if default_base_dir is None:
        default_base_dir = "data/robomimic/datasets"

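    # resolve the list of tasks to download, validating them against the registry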
    download_dataset_type = args.dataset_type
    download_tasks = args.tasks
    if "all" in download_tasks:
        assert len(download_tasks) == 1, "'all' should be the only --tasks argument, but got: {}".format(args.tasks)
        download_tasks = list(DATASET_REGISTRY[download_dataset_type].keys())
    else:
        for task in download_tasks:
            assert task in DATASET_REGISTRY[download_dataset_type], \
                "got unknown task {} for dataset type {}. Choose one of {}".format(
                    task, download_dataset_type, list(DATASET_REGISTRY[download_dataset_type].keys()))

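    # download each requested dataset (or just report what would be downloaded with --dry_run)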
    for task in download_tasks:
        download_dir = os.path.abspath(os.path.join(default_base_dir, task))
        download_path = os.path.join(download_dir, "{}.hdf5".format(task))
        print("\nDownloading dataset:\n dataset type: {}\n task: {}\n download path: {}"
              .format(download_dataset_type, task, download_path))
        url = DATASET_REGISTRY[download_dataset_type][task]["url"]
        if args.dry_run:
            print("\ndry run: skip download")
        else:
            # make sure the download directory exists, creating it if necessary
            os.makedirs(download_dir, exist_ok=True)
            print("")
            FileUtils.download_url_from_gdrive(
                url=url,
                download_dir=download_dir,
                check_overwrite=True,
            )
        print("")