File size: 3,258 Bytes
c1f1d32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].

"""
Script to download datasets packaged with the repository.
"""
import os
import argparse

import mimicgen_envs
import mimicgen_envs.utils.file_utils as FileUtils
from mimicgen_envs import DATASET_REGISTRY


if __name__ == "__main__":
    # Command-line entry point: parse args, validate the requested dataset
    # type / tasks against DATASET_REGISTRY, then download one hdf5 per task.
    parser = argparse.ArgumentParser()

    # directory to download datasets to
    parser.add_argument(
        "--download_dir",
        type=str,
        default=None,
        help="Base download directory. Created if it doesn't exist. Defaults to datasets folder in repository.",
    )

    # dataset type to download datasets for
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="core",
        choices=list(DATASET_REGISTRY.keys()),
        help="Dataset type to download datasets for (e.g. source, core, object, robot, large_interpolation). Defaults to core.",
    )

    # tasks to download datasets for
    parser.add_argument(
        "--tasks",
        type=str,
        nargs='+',
        default=["square_d0"],
        help="Tasks to download datasets for. Defaults to square_d0 task. Pass 'all' to download all tasks\
            for the provided dataset type or directly specify the list of tasks.",
    )

    # dry run - don't actually download datasets, but print which datasets would be downloaded
    parser.add_argument(
        "--dry_run",
        action='store_true',
        help="set this flag to do a dry run to only print which datasets would be downloaded"
    )

    args = parser.parse_args()

    # set default base directory for downloads
    default_base_dir = args.download_dir
    if default_base_dir is None:
        default_base_dir = "data/robomimic/datasets"

    # Resolve the task list. Validation uses parser.error (not assert) so it
    # still runs under `python -O` and reports a proper CLI usage message with
    # exit status 2 instead of an AssertionError traceback.
    download_dataset_type = args.dataset_type
    download_tasks = args.tasks
    if "all" in download_tasks:
        # "all" is a sentinel that expands to every task of this dataset type;
        # mixing it with explicit task names is ambiguous, so reject that.
        if len(download_tasks) != 1:
            parser.error("all should be only tasks argument but got: {}".format(args.tasks))
        download_tasks = list(DATASET_REGISTRY[download_dataset_type].keys())
    else:
        for task in download_tasks:
            if task not in DATASET_REGISTRY[download_dataset_type]:
                parser.error(
                    "got unknown task {} for dataset type {}. Choose one of {}".format(
                        task, download_dataset_type, list(DATASET_REGISTRY[download_dataset_type].keys())
                    )
                )

    # download requested datasets, one subdirectory per task
    for task in download_tasks:
        download_dir = os.path.abspath(os.path.join(default_base_dir, task))
        # download_path is only used for the progress message below; the
        # actual file name is determined by the download helper.
        download_path = os.path.join(download_dir, "{}.hdf5".format(task))
        print("\nDownloading dataset:\n    dataset type: {}\n    task: {}\n    download path: {}"
            .format(download_dataset_type, task, download_path))
        url = DATASET_REGISTRY[download_dataset_type][task]["url"]
        if args.dry_run:
            print("\ndry run: skip download")
        else:
            # Make sure path exists and create if it doesn't
            os.makedirs(download_dir, exist_ok=True)
            print("")
            FileUtils.download_url_from_gdrive(
                url=url,
                download_dir=download_dir,
                check_overwrite=True,
            )
        print("")