File size: 3,258 Bytes
c1f1d32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].
"""
Script to download datasets packaged with the repository.
"""
import os
import argparse
import mimicgen_envs
import mimicgen_envs.utils.file_utils as FileUtils
from mimicgen_envs import DATASET_REGISTRY
def build_arg_parser(dataset_types):
    """
    Build the command-line parser for the dataset download script.

    Args:
        dataset_types (iterable of str): names accepted by --dataset_type

    Returns:
        parser (argparse.ArgumentParser): parser with the --download_dir,
            --dataset_type, --tasks and --dry_run options registered
    """
    parser = argparse.ArgumentParser()

    # directory to download datasets to
    parser.add_argument(
        "--download_dir",
        type=str,
        default=None,
        help="Base download directory. Created if it doesn't exist. Defaults to "
             "data/robomimic/datasets, relative to the current working directory.",
    )

    # dataset type to download datasets for
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="core",
        choices=list(dataset_types),
        help="Dataset type to download datasets for (e.g. source, core, object, robot, large_interpolation). Defaults to core.",
    )

    # tasks to download datasets for
    parser.add_argument(
        "--tasks",
        type=str,
        nargs='+',
        default=["square_d0"],
        help="Tasks to download datasets for. Defaults to square_d0 task. Pass 'all' to download all tasks "
             "for the provided dataset type or directly specify the list of tasks.",
    )

    # dry run - don't actually download datasets, but print which datasets would be downloaded
    parser.add_argument(
        "--dry_run",
        action='store_true',
        help="set this flag to do a dry run to only print which datasets would be downloaded",
    )
    return parser


def resolve_tasks(requested_tasks, dataset_type, registry):
    """
    Validate the requested task names and expand the special value 'all'.

    Args:
        requested_tasks (list of str): task names from the command line, or
            the single entry 'all'
        dataset_type (str): key of @registry to look the tasks up in
        registry (dict): maps dataset type to a dict keyed by task name

    Returns:
        tasks (list of str): every task registered for @dataset_type when
            'all' was requested, otherwise a copy of @requested_tasks

    Raises:
        ValueError: if 'all' is combined with other task names, or if a
            requested task is not registered for @dataset_type
    """
    known_tasks = registry[dataset_type]

    if "all" in requested_tasks:
        if len(requested_tasks) != 1:
            raise ValueError("all should be only tasks argument but got: {}".format(requested_tasks))
        return list(known_tasks.keys())

    # explicit raise instead of assert, so validation survives `python -O`
    for task in requested_tasks:
        if task not in known_tasks:
            raise ValueError(
                "got unknown task {} for dataset type {}. Choose one of {}".format(
                    task, dataset_type, list(known_tasks.keys())
                )
            )
    return list(requested_tasks)


def main(argv=None):
    """
    Download (or, with --dry_run, only list) the requested datasets.

    Args:
        argv (list of str or None): command-line arguments to parse. If None,
            argparse reads them from sys.argv.
    """
    parser = build_arg_parser(DATASET_REGISTRY.keys())
    args = parser.parse_args(argv)

    # set default base directory for downloads
    base_dir = args.download_dir
    if base_dir is None:
        base_dir = "data/robomimic/datasets"

    # load args
    dataset_type = args.dataset_type
    try:
        tasks = resolve_tasks(args.tasks, dataset_type, DATASET_REGISTRY)
    except ValueError as err:
        # prints usage plus the message and exits with a non-zero status
        parser.error(str(err))

    # download requested datasets
    for task in tasks:
        download_dir = os.path.abspath(os.path.join(base_dir, task))
        # only used for the log message below; the downloader is given the directory
        download_path = os.path.join(download_dir, "{}.hdf5".format(task))
        print("\nDownloading dataset:\n dataset type: {}\n task: {}\n download path: {}"
            .format(dataset_type, task, download_path))
        url = DATASET_REGISTRY[dataset_type][task]["url"]

        if args.dry_run:
            print("\ndry run: skip download")
            continue

        # Make sure path exists and create if it doesn't
        os.makedirs(download_dir, exist_ok=True)
        print("")
        FileUtils.download_url_from_gdrive(
            url=url,
            download_dir=download_dir,
            check_overwrite=True,
        )
        print("")


if __name__ == "__main__":
    main()