Lillianwei commited on
Commit
c1f1d32
·
1 Parent(s): 836bc67
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +4 -0
  2. README.md +35 -0
  3. equidiff/.gitignore +156 -0
  4. equidiff/LICENSE +21 -0
  5. equidiff/README.md +115 -0
  6. equidiff/combinehdf5.py +59 -0
  7. equidiff/conda_environment.yaml +61 -0
  8. equidiff/equi_diffpo/codecs/imagecodecs_numcodecs.py +1386 -0
  9. equidiff/equi_diffpo/common/checkpoint_util.py +59 -0
  10. equidiff/equi_diffpo/common/cv2_util.py +150 -0
  11. equidiff/equi_diffpo/common/env_util.py +23 -0
  12. equidiff/equi_diffpo/common/json_logger.py +117 -0
  13. equidiff/equi_diffpo/common/nested_dict_util.py +32 -0
  14. equidiff/equi_diffpo/common/normalize_util.py +311 -0
  15. equidiff/equi_diffpo/common/pose_trajectory_interpolator.py +208 -0
  16. equidiff/equi_diffpo/common/precise_sleep.py +25 -0
  17. equidiff/equi_diffpo/common/pymunk_override.py +248 -0
  18. equidiff/equi_diffpo/common/pymunk_util.py +52 -0
  19. equidiff/equi_diffpo/common/pytorch_util.py +82 -0
  20. equidiff/equi_diffpo/common/replay_buffer.py +588 -0
  21. equidiff/equi_diffpo/common/sampler.py +153 -0
  22. equidiff/equi_diffpo/common/timestamp_accumulator.py +222 -0
  23. equidiff/equi_diffpo/config/dp3.yaml +152 -0
  24. equidiff/equi_diffpo/config/task/mimicgen_abs.yaml +60 -0
  25. equidiff/equi_diffpo/config/task/mimicgen_pc_abs.yaml +81 -0
  26. equidiff/equi_diffpo/config/task/mimicgen_rel.yaml +60 -0
  27. equidiff/equi_diffpo/config/task/mimicgen_voxel_abs.yaml +84 -0
  28. equidiff/equi_diffpo/config/task/mimicgen_voxel_rel.yaml +84 -0
  29. equidiff/equi_diffpo/config/test_equi_diffusion_unet_abs_sq2.yaml +141 -0
  30. equidiff/equi_diffpo/config/test_sq2.yaml +142 -0
  31. equidiff/equi_diffpo/config/test_th2.yaml +142 -0
  32. equidiff/equi_diffpo/config/train_act_abs.yaml +88 -0
  33. equidiff/equi_diffpo/config/train_bc_rnn.yaml +94 -0
  34. equidiff/equi_diffpo/config/train_diffusion_transformer.yaml +143 -0
  35. equidiff/equi_diffpo/config/train_diffusion_unet.yaml +140 -0
  36. equidiff/equi_diffpo/config/train_diffusion_unet_voxel_abs.yaml +137 -0
  37. equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs.yaml +137 -0
  38. equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs_sq2_0-1.yaml +137 -0
  39. equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs_sq2_1-1.yaml +137 -0
  40. equidiff/equi_diffpo/config/train_equi_diffusion_unet_rel.yaml +136 -0
  41. equidiff/equi_diffpo/config/train_equi_diffusion_unet_voxel_abs.yaml +137 -0
  42. equidiff/equi_diffpo/config/train_equi_diffusion_unet_voxel_rel.yaml +137 -0
  43. equidiff/equi_diffpo/config/train_sq2.yaml +139 -0
  44. equidiff/equi_diffpo/config/train_sq2_5000.yaml +139 -0
  45. equidiff/equi_diffpo/config/train_th2_5000.yaml +139 -0
  46. equidiff/equi_diffpo/dataset/base_dataset.py +51 -0
  47. equidiff/equi_diffpo/env_runner/base_image_runner.py +9 -0
  48. equidiff/equi_diffpo/env_runner/base_lowdim_runner.py +9 -0
  49. equidiff/equi_diffpo/gym_util/async_vector_env.py +673 -0
  50. equidiff/equi_diffpo/gym_util/multistep_wrapper.py +162 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ *.mp4
2
+ mimicgen_environments*
3
+ robomimic*
4
+ equidiff/data/*
README.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Equidiff
2
+
3
+ - folder_name: the name of the folder
4
+ - file_name: the name of your file
5
+
6
+ ## Prepare data
7
+ Use mimicgen to generate data.
8
+
9
+ Use `EmbodiedBM/equidiff/combinehdf5.py` to combine data from multiple .hdf5 files if needed.
10
+
11
+ Put hdf5 data at `EmbodiedBM/equidiff/data/robomimic/datasets` with the format [folder_name]/[file_name].hdf5
12
+
13
+ ## Convert data
14
+ ```bash
15
+ python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/square_d2_test/demo.hdf5 -o data/robomimic/datasets/square_d2_test/demo_abs.hdf5 -n 12
16
+ ```
17
+
18
+ ## Train
19
+
20
+ Use another CUDA device if 7 is currently in use.
21
+
22
+ ```bash
23
+ CUDA_VISIBLE_DEVICES=5 MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa HYDRA_FULL_ERROR=1 python train.py --config-name=train_sq2_5000 folder_name=square_d2_5000 file_name=demo n_demo=5000
24
+ ```
25
+
26
+ If you use a task other than square_d2, you should change the task_name config by adding task_name=[task_name]
27
+
28
+ ## Test
29
+ Change the `ckpt_path` to the trained policy's weight's path in `EmbodiedBM/equidiff/equi_diffpo/config/test_sq2.yaml`
30
+
31
+ If you use a task other than square_d2, you should change the dataset config in test_sq2.yaml and download the corresponding dataset from [Huggingface](https://huggingface.co/datasets/amandlek/mimicgen_datasets/tree/main/core).
32
+
33
+ ```bash
34
+ python test.py
35
+ ```
equidiff/.gitignore ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ bin
2
+ logs
3
+ wandb
4
+
5
+
6
+ data_local
7
+ .vscode
8
+ _wandb
9
+
10
+ **/.DS_Store
11
+
12
+ fuse.cfg
13
+
14
+ *.ai
15
+
16
+ # Generation results
17
+ results/
18
+
19
+ ray/auth.json
20
+
21
+ # Byte-compiled / optimized / DLL files
22
+ __pycache__/
23
+ *.py[cod]
24
+ *$py.class
25
+
26
+ # C extensions
27
+ *.so
28
+
29
+ # Distribution / packaging
30
+ .Python
31
+ build/
32
+ develop-eggs/
33
+ dist/
34
+ downloads/
35
+ eggs/
36
+ .eggs/
37
+ lib/
38
+ lib64/
39
+ parts/
40
+ sdist/
41
+ var/
42
+ wheels/
43
+ pip-wheel-metadata/
44
+ share/python-wheels/
45
+ *.egg-info/
46
+ .installed.cfg
47
+ *.egg
48
+ MANIFEST
49
+
50
+ # PyInstaller
51
+ # Usually these files are written by a python script from a template
52
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
53
+ *.manifest
54
+ *.spec
55
+
56
+ # Installer logs
57
+ pip-log.txt
58
+ pip-delete-this-directory.txt
59
+
60
+ # Unit test / coverage reports
61
+ htmlcov/
62
+ .tox/
63
+ .nox/
64
+ .coverage
65
+ .coverage.*
66
+ .cache
67
+ nosetests.xml
68
+ coverage.xml
69
+ *.cover
70
+ *.py,cover
71
+ .hypothesis/
72
+ .pytest_cache/
73
+
74
+ # Translations
75
+ *.mo
76
+ *.pot
77
+
78
+ # Django stuff:
79
+ *.log
80
+ local_settings.py
81
+ db.sqlite3
82
+ db.sqlite3-journal
83
+
84
+ # Flask stuff:
85
+ instance/
86
+ .webassets-cache
87
+
88
+ # Scrapy stuff:
89
+ .scrapy
90
+
91
+ # Sphinx documentation
92
+ docs/_build/
93
+
94
+ # PyBuilder
95
+ target/
96
+
97
+ # Jupyter Notebook
98
+ .ipynb_checkpoints
99
+
100
+ # IPython
101
+ profile_default/
102
+ ipython_config.py
103
+
104
+ # pyenv
105
+ .python-version
106
+
107
+ # pipenv
108
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
110
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
111
+ # install all needed dependencies.
112
+ #Pipfile.lock
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
136
+ .dmypy.json
137
+ dmypy.json
138
+
139
+ # Pyre type checker
140
+ .pyre/
141
+ equi_diffpo/scripts/equidiff_data_conversion.py
142
+ equi_diffpo/model/equi/vec_conditional_unet1d_1.py
143
+ update_max_score.py
144
+ test4.png
145
+ test3.png
146
+ test2.png
147
+ test1.png
148
+ test.png
149
+ sampled_xyz.png
150
+ pc.pt
151
+ metric3.py
152
+ metric2.py
153
+ metric1.py
154
+ grouped_xyz.png
155
+ all_xyz.png
156
+ 1.png
equidiff/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Dian Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
equidiff/README.md ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Equivariant Diffusion Policy
2
+ [Project Website](https://equidiff.github.io) | [Paper](https://arxiv.org/pdf/2407.01812) | [Video](https://youtu.be/xIFSx_NVROU?si=MaxsHmih6AnQKAVy)
3
+ <a href="https://pointw.github.io/">Dian Wang</a><sup>1</sup>, <a href="https://www.linkedin.com/in/stephen-hart-3711666/">Stephen Hart</a><sup>2</sup>, <a href="https://www.linkedin.com/in/surovik/">David Surovik</a><sup>2</sup>, <a href="https://kelestemur.com">Tarik Kelestemur</a><sup>2</sup>, <a href="https://haojhuang.github.io/">Haojie Huang</a><sup>1</sup>, <a href="https://www.linkedin.com/in/haibo-zhao-b68742250/">Haibo Zhao</a><sup>1</sup>, <a href="https://www.linkedin.com/in/mark-yeatman-58a49763/">Mark Yeatman</a><sup>2</sup>, <a href="https://www.robo.guru/">Jiuguang Wang</a><sup>2</sup>, <a href="https://www.robinwalters.com/">Robin Walters</a><sup>1</sup>, <a href="https://helpinghandslab.netlify.app/people/">Robert Platt</a><sup>12</sup>
4
+ <sup>1</sup>Northeastern University, <sup>2</sup>Boston Dynamics AI Institute
5
+ Conference on Robot Learning 2024 (Oral)
6
+ ![](img/equi.gif) |
7
+ ## Installation
8
+ 1. Install the following apt packages for mujoco:
9
+ ```bash
10
+ sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
11
+ ```
12
+ 1. Install gfortran (dependency for escnn)
13
+ ```bash
14
+ sudo apt install -y gfortran
15
+ ```
16
+
17
+ 1. Install [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) (strongly recommended) or Anaconda
18
+ 1. Clone this repo
19
+ ```bash
20
+ git clone https://github.com/pointW/equidiff.git
21
+ cd equidiff
22
+ ```
23
+ 1. Install environment:
24
+ Use Mambaforge (strongly recommended):
25
+ ```bash
26
+ mamba env create -f conda_environment.yaml
27
+ conda activate equidiff
28
+ ```
29
+ or use Anaconda (not recommended):
30
+ ```bash
31
+ conda env create -f conda_environment.yaml
32
+ conda activate equidiff
33
+ ```
34
+ 1. Install mimicgen:
35
+ ```bash
36
+ cd ..
37
+ git clone https://github.com/NVlabs/mimicgen_environments.git
38
+ cd mimicgen_environments
39
+ # This project was developed with Mimicgen v0.1.0. The latest version should work fine, but it is not tested
40
+ git checkout 081f7dbbe5fff17b28c67ce8ec87c371f32526a9
41
+ pip install -e .
42
+ cd ../equidiff
43
+ ```
44
+ 1. Make sure mujoco version is 2.3.2 (required by mimicgen)
45
+ ```bash
46
+ pip list | grep mujoco
47
+ ```
48
+
49
+ ## Dataset
50
+ ### Download Dataset
51
+ ```bash
52
+ # Download all datasets
53
+ python equi_diffpo/scripts/download_datasets.py --tasks stack_d1 stack_three_d1 square_d2 threading_d2 coffee_d2 three_piece_assembly_d2 hammer_cleanup_d1 mug_cleanup_d1 kitchen_d1 nut_assembly_d0 pick_place_d0 coffee_preparation_d1
54
+ # Alternatively, download one (or several) datasets of interest, e.g.,
55
+ python equi_diffpo/scripts/download_datasets.py --tasks stack_d1
56
+ ```
57
+ ### Generating Voxel and Point Cloud Observation
58
+
59
+ ```bash
60
+ # Template
61
+ python equi_diffpo/scripts/dataset_states_to_obs.py --input data/robomimic/datasets/[dataset]/[dataset].hdf5 --output data/robomimic/datasets/[dataset]/[dataset]_voxel.hdf5 --num_workers=[n_worker]
62
+ # Replace [dataset] and [n_worker] with your choices.
63
+ # E.g., use 24 workers to generate point cloud and voxel observation for stack_d1
64
+ python equi_diffpo/scripts/dataset_states_to_obs.py --input data/robomimic/datasets/stack_d1/stack_d1.hdf5 --output data/robomimic/datasets/stack_d1/stack_d1_voxel.hdf5 --num_workers=24
65
+ ```
66
+
67
+ ### Convert Action Space in Dataset
68
+ The downloaded dataset has a relative action space. To train with absolute action space, the dataset needs to be converted accordingly
69
+ ```bash
70
+ # Template
71
+ python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/[dataset]/[dataset].hdf5 -o data/robomimic/datasets/[dataset]/[dataset]_abs.hdf5 -n [n_worker]
72
+ # Replace [dataset] and [n_worker] with your choices.
73
+ # E.g., convert stack_d1 (non-voxel) with 12 workers
74
+ python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/stack_d1/stack_d1_voxel.hdf5 -o data/robomimic/datasets/stack_d1/stack_d1_abs.hdf5 -n 12
75
+ # E.g., convert stack_d1_voxel (voxel) with 12 workers
76
+ python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/stack_d1/stack_d1_voxel.hdf5 -o data/robomimic/datasets/stack_d1/stack_d1_voxel_abs.hdf5 -n 12
77
+ ```
78
+
79
+ ## Training with image observation
80
+ To train Equivariant Diffusion Policy (with absolute pose control) in Stack D1 task:
81
+ ```bash
82
+ # Make sure you have the non-voxel converted dataset with absolute action space from the previous step
83
+ python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1 n_demo=100
84
+ ```
85
+ To train with relative pose control instead:
86
+ ```bash
87
+ python train.py --config-name=train_equi_diffusion_unet_rel task_name=stack_d1 n_demo=100
88
+ ```
89
+ To train in other tasks, replace `stack_d1` with `stack_three_d1`, `square_d2`, `threading_d2`, `coffee_d2`, `three_piece_assembly_d2`, `hammer_cleanup_d1`, `mug_cleanup_d1`, `kitchen_d1`, `nut_assembly_d0`, `pick_place_d0`, `coffee_preparation_d1`. Notice that the corresponding dataset should be downloaded already. If training absolute pose control, the data conversion is also needed.
90
+
91
+ To run environments on CPU (to save GPU memory), use `osmesa` instead of `egl` through `MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa`, e.g.,
92
+ ```bash
93
+ MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1
94
+ ```
95
+
96
+ Equivariant Diffusion Policy requires around 22G GPU memory to run with batch size of 128 (default). To reduce the GPU usage, consider training with smaller batch size and/or reducing the hidden dimension
97
+ ```bash
98
+ # to train with batch size of 64 and hidden dimension of 64
99
+ MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1 policy.enc_n_hidden=64 dataloader.batch_size=64
100
+ ```
101
+
102
+ ## Training with voxel observation
103
+ To train Equivariant Diffusion Policy (with absolute pose control) in Stack D1 task:
104
+ ```bash
105
+ # Make sure you have the voxel converted dataset with absolute action space from the previous step
106
+ python train.py --config-name=train_equi_diffusion_unet_voxel_abs task_name=stack_d1 n_demo=100
107
+ ```
108
+
109
+ ## License
110
+ This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.
111
+
112
+ ## Acknowledgement
113
+ * Our repo is built upon the original [Diffusion Policy](https://github.com/real-stanford/diffusion_policy)
114
+ * Our ACT baseline is adapted from its [original repo](https://github.com/tonyzhaozh/act)
115
+ * Our DP3 baseline is adapted from its [original repo](https://github.com/YanjieZe/3D-Diffusion-Policy)
equidiff/combinehdf5.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import h5py
2
+ import re
3
+ import random
4
+
5
def update_suffix(original_string, increment):
    """Return *original_string* with its trailing run of digits increased by *increment*.

    Strings without trailing digits are returned unchanged.
    """
    def bump(match):
        # re.sub only calls this when the string ends in digits.
        return str(int(match.group(1)) + increment)

    return re.sub(r'(\d+)$', bump, original_string)
8
+
9
def merge(output_file, input_files, total_size, truncate_len=-1):
    """Merge demos from several robomimic-style HDF5 files into one file.

    Every demo found under ``/data`` in the inputs is copied into ``/data``
    of *output_file* under a shuffled name ``demo_<k>``, so the combined file
    reads like one randomized dataset. The ``env_args`` attribute of the
    first input's ``/data`` group is carried over so downstream tools can
    reconstruct the environment.

    Args:
        output_file: path of the HDF5 file to create (overwritten if present).
        input_files: list of input HDF5 paths, each containing a ``/data`` group.
        total_size: size of the shuffled-name pool; must be >= the total
            number of demos across all inputs.
        truncate_len: stop after copying this many demos; values <= 0 copy all.
    """
    shuffled_ids = list(range(total_size))
    random.shuffle(shuffled_ids)
    copied = 0
    done = False
    with h5py.File(output_file, 'w') as h5out:
        h5out_data = h5out.create_group('data')
        for input_file in input_files:
            if done:
                break
            with h5py.File(input_file, 'r') as f:
                d = f['data']
                for key in d:
                    new_key = f"demo_{shuffled_ids[copied]}"
                    print(new_key)
                    if isinstance(d[key], h5py.Group):
                        # Copy the whole demo group under its shuffled name.
                        f.copy(d[key], h5out_data, name=new_key)
                    elif isinstance(d[key], h5py.Dataset):
                        # BUG FIX: the original wrote the dataset under `key`,
                        # which collides across input files and bypasses the
                        # shuffled renaming; use `new_key` like the Group branch.
                        h5out_data.create_dataset(new_key, data=d[key][:])
                    copied += 1
                    # BUG FIX: the original tested `if truncate_len:` (truthy
                    # for the -1 default) and `break` only left the inner loop,
                    # so truncation never stopped the remaining input files.
                    if truncate_len > 0 and copied == truncate_len:
                        done = True
                        break
        print(len(h5out_data))
        with h5py.File(input_files[0], 'r') as f:
            d1 = f['data']
            if "env_args" in d1.attrs:
                h5out_data.attrs["env_args"] = d1.attrs["env_args"]
34
+
35
def print_hdf5_structure(file_path):
    """Print the group/dataset tree of an HDF5 file.

    Also prints the ``env_args`` attribute of the ``/data`` group when present.
    """
    def walk(node, depth=0):
        # One space of indentation per nesting level, as before.
        pad = " " * depth
        for name in node:
            child = node[name]
            if isinstance(child, h5py.Group):
                print(f"{pad}Group: {name}")
                walk(child, depth + 1)
            elif isinstance(child, h5py.Dataset):
                print(f"{pad}Dataset: {name}, Shape: {child.shape}, Type: {child.dtype}")

    with h5py.File(file_path, 'r') as f:
        print(f"File: {file_path}")
        walk(f)
        dataset = f["data"]
        if "env_args" in dataset.attrs:
            env_args = dataset.attrs["env_args"]
            print(f"env_args: {env_args}")
52
+
53
# Example merge usage (paths are machine-specific; edit before running):
# input_files = ["/home/siweih/Project/EmbodiedBM/equidiff/data/robomimic/datasets/square_d2/square_d2.hdf5","/home/siweih/Project/EmbodiedBM/equidiff/mix_4000.hdf5"]

# demo_num = 5000
# output_file = f"mix_{demo_num}.hdf5"

# merge(output_file, input_files, demo_num)
if __name__ == "__main__":
    # Guarded so importing this module (e.g. for `merge`) does not trigger
    # filesystem I/O against a hard-coded, machine-specific path.
    print_hdf5_structure("/home/siweih/Project/EmbodiedBM/mimicgen/core_datasets/square/demo_src_square_task_D2/demo.hdf5")
equidiff/conda_environment.yaml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: equidiff
2
+ channels:
3
+ - pytorch
4
+ - pytorch3d
5
+ - nvidia
6
+ - conda-forge
7
+ dependencies:
8
+ - python=3.9
9
+ - pip=22.2.2
10
+ - pytorch=2.1.0
11
+ - torchaudio=2.1.0
12
+ - torchvision=0.16.0
13
+ - pytorch-cuda=11.8
14
+ - pytorch3d=0.7.5
15
+ - numpy=1.23.3
16
+ - numba==0.56.4
17
+ - scipy==1.9.1
18
+ - py-opencv=4.6.0
19
+ - cffi=1.15.1
20
+ - ipykernel=6.16
21
+ - matplotlib=3.6.1
22
+ - zarr=2.12.0
23
+ - numcodecs=0.10.2
24
+ - h5py=3.7.0
25
+ - hydra-core=1.2.0
26
+ - einops=0.4.1
27
+ - tqdm=4.64.1
28
+ - dill=0.3.5.1
29
+ - scikit-video=1.1.11
30
+ - scikit-image=0.19.3
31
+ - gym=0.21.0
32
+ - pymunk=6.2.1
33
+ - threadpoolctl=3.1.0
34
+ - shapely=1.8.4
35
+ - cython=0.29.32
36
+ - imageio=2.22.0
37
+ - imageio-ffmpeg=0.4.7
38
+ - termcolor=2.0.1
39
+ - tensorboard=2.10.1
40
+ - tensorboardx=2.5.1
41
+ - psutil=5.9.2
42
+ - click=8.0.4
43
+ - boto3=1.24.96
44
+ - accelerate=0.13.2
45
+ - datasets=2.6.1
46
+ - diffusers=0.11.1
47
+ - av=10.0.0
48
+ - cmake=3.24.3
49
+ # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
50
+ - llvm-openmp=14
51
+ # trick to force reinstall imagecodecs via pip
52
+ - imagecodecs==2022.8.8
53
+ - pip:
54
+ - open3d
55
+ - wandb==0.17.0
56
+ - pygame
57
+ - imagecodecs==2022.9.26
58
+ - escnn @ https://github.com/pointW/escnn/archive/fc4714cb6dc0d2a32f9fcea35771968b89911109.tar.gz
59
+ - robosuite @ https://github.com/ARISE-Initiative/robosuite/archive/b9d8d3de5e3dfd1724f4a0e6555246c460407daa.tar.gz
60
+ - robomimic @ https://github.com/pointW/robomimic/archive/8aad5b3caaaac9289b1504438a7f5d3a76d06c07.tar.gz
61
+ - robosuite-task-zoo @ https://github.com/pointW/robosuite-task-zoo/archive/0f8a7b2fa5d192e4e8800bebfe8090b28926f3ed.tar.gz
equidiff/equi_diffpo/codecs/imagecodecs_numcodecs.py ADDED
@@ -0,0 +1,1386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # imagecodecs/numcodecs.py
3
+
4
+ # Copyright (c) 2021-2022, Christoph Gohlke
5
+ # All rights reserved.
6
+ #
7
+ # Redistribution and use in source and binary forms, with or without
8
+ # modification, are permitted provided that the following conditions are met:
9
+ #
10
+ # 1. Redistributions of source code must retain the above copyright notice,
11
+ # this list of conditions and the following disclaimer.
12
+ #
13
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ # this list of conditions and the following disclaimer in the documentation
15
+ # and/or other materials provided with the distribution.
16
+ #
17
+ # 3. Neither the name of the copyright holder nor the names of its
18
+ # contributors may be used to endorse or promote products derived from
19
+ # this software without specific prior written permission.
20
+ #
21
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
25
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
26
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
27
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
28
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31
+ # POSSIBILITY OF SUCH DAMAGE.
32
+
33
+ """Additional numcodecs implemented using imagecodecs."""
34
+
35
+ __version__ = '2022.9.26'
36
+
37
+ __all__ = ('register_codecs',)
38
+
39
+ import numpy
40
+ from numcodecs.abc import Codec
41
+ from numcodecs.registry import register_codec, get_codec
42
+
43
+ import imagecodecs
44
+
45
+
46
def protective_squeeze(x: numpy.ndarray):
    """
    Squeeze dim only if it's not the last dim.
    Image dim expected to be *, H, W, C
    """
    # Keep the trailing (H, W, C) axes intact; collapse all leading axes
    # into a single batch axis only when they hold more than one image.
    target_shape = x.shape[-3:]
    if x.ndim > 3 and numpy.prod(x.shape[:-3]) > 1:
        target_shape = (-1,) + target_shape
    return x.reshape(target_shape)
57
+
58
def get_default_image_compressor(**kwargs):
    """Return a lossy image codec with sensible defaults.

    Prefers JPEG XL when this imagecodecs build supports it, otherwise falls
    back to JPEG 2000. Any keyword arguments override the defaults.
    """
    if imagecodecs.JPEGXL:
        # has JPEGXL
        defaults = {
            'effort': 3,
            'distance': 0.3,
            # bug in libjxl, invalid codestream for non-lossless
            # when decoding speed > 1
            'decodingspeed': 1,
        }
        defaults.update(kwargs)
        return JpegXl(**defaults)
    defaults = {'level': 50}
    defaults.update(kwargs)
    return Jpeg2k(**defaults)
76
+
77
class Aec(Codec):
    """AEC codec for numcodecs."""

    codec_id = 'imagecodecs_aec'

    def __init__(
        self, bitspersample=None, flags=None, blocksize=None, rsi=None
    ):
        # All parameters are forwarded verbatim to the imagecodecs calls.
        self.bitspersample = bitspersample
        self.flags = flags
        self.blocksize = blocksize
        self.rsi = rsi

    def _params(self):
        # Shared keyword arguments for both encode and decode.
        return dict(
            bitspersample=self.bitspersample,
            flags=self.flags,
            blocksize=self.blocksize,
            rsi=self.rsi,
        )

    def encode(self, buf):
        """Compress *buf* with the AEC encoder."""
        return imagecodecs.aec_encode(buf, **self._params())

    def decode(self, buf, out=None):
        """Decompress *buf*; *out* is flattened for in-place decoding."""
        return imagecodecs.aec_decode(buf, out=_flat(out), **self._params())
108
+
109
+
110
+ class Apng(Codec):
111
+ """APNG codec for numcodecs."""
112
+
113
+ codec_id = 'imagecodecs_apng'
114
+
115
+ def __init__(self, level=None, photometric=None, delay=None):
116
+ self.level = level
117
+ self.photometric = photometric
118
+ self.delay = delay
119
+
120
+ def encode(self, buf):
121
+ buf = protective_squeeze(numpy.asarray(buf))
122
+ return imagecodecs.apng_encode(
123
+ buf,
124
+ level=self.level,
125
+ photometric=self.photometric,
126
+ delay=self.delay,
127
+ )
128
+
129
+ def decode(self, buf, out=None):
130
+ return imagecodecs.apng_decode(buf, out=out)
131
+
132
+
133
+ class Avif(Codec):
134
+ """AVIF codec for numcodecs."""
135
+
136
+ codec_id = 'imagecodecs_avif'
137
+
138
+ def __init__(
139
+ self,
140
+ level=None,
141
+ speed=None,
142
+ tilelog2=None,
143
+ bitspersample=None,
144
+ pixelformat=None,
145
+ numthreads=None,
146
+ index=None,
147
+ ):
148
+ self.level = level
149
+ self.speed = speed
150
+ self.tilelog2 = tilelog2
151
+ self.bitspersample = bitspersample
152
+ self.pixelformat = pixelformat
153
+ self.numthreads = numthreads
154
+ self.index = index
155
+
156
+ def encode(self, buf):
157
+ buf = protective_squeeze(numpy.asarray(buf))
158
+ return imagecodecs.avif_encode(
159
+ buf,
160
+ level=self.level,
161
+ speed=self.speed,
162
+ tilelog2=self.tilelog2,
163
+ bitspersample=self.bitspersample,
164
+ pixelformat=self.pixelformat,
165
+ numthreads=self.numthreads,
166
+ )
167
+
168
+ def decode(self, buf, out=None):
169
+ return imagecodecs.avif_decode(
170
+ buf, index=self.index, numthreads=self.numthreads, out=out
171
+ )
172
+
173
+
174
+ class Bitorder(Codec):
175
+ """Bitorder codec for numcodecs."""
176
+
177
+ codec_id = 'imagecodecs_bitorder'
178
+
179
+ def encode(self, buf):
180
+ return imagecodecs.bitorder_encode(buf)
181
+
182
+ def decode(self, buf, out=None):
183
+ return imagecodecs.bitorder_decode(buf, out=_flat(out))
184
+
185
+
186
+ class Bitshuffle(Codec):
187
+ """Bitshuffle codec for numcodecs."""
188
+
189
+ codec_id = 'imagecodecs_bitshuffle'
190
+
191
+ def __init__(self, itemsize=1, blocksize=0):
192
+ self.itemsize = itemsize
193
+ self.blocksize = blocksize
194
+
195
+ def encode(self, buf):
196
+ return imagecodecs.bitshuffle_encode(
197
+ buf, itemsize=self.itemsize, blocksize=self.blocksize
198
+ ).tobytes()
199
+
200
+ def decode(self, buf, out=None):
201
+ return imagecodecs.bitshuffle_decode(
202
+ buf,
203
+ itemsize=self.itemsize,
204
+ blocksize=self.blocksize,
205
+ out=_flat(out),
206
+ )
207
+
208
+
209
+ class Blosc(Codec):
210
+ """Blosc codec for numcodecs."""
211
+
212
+ codec_id = 'imagecodecs_blosc'
213
+
214
+ def __init__(
215
+ self,
216
+ level=None,
217
+ compressor=None,
218
+ typesize=None,
219
+ blocksize=None,
220
+ shuffle=None,
221
+ numthreads=None,
222
+ ):
223
+ self.level = level
224
+ self.compressor = compressor
225
+ self.typesize = typesize
226
+ self.blocksize = blocksize
227
+ self.shuffle = shuffle
228
+ self.numthreads = numthreads
229
+
230
+ def encode(self, buf):
231
+ buf = protective_squeeze(numpy.asarray(buf))
232
+ return imagecodecs.blosc_encode(
233
+ buf,
234
+ level=self.level,
235
+ compressor=self.compressor,
236
+ typesize=self.typesize,
237
+ blocksize=self.blocksize,
238
+ shuffle=self.shuffle,
239
+ numthreads=self.numthreads,
240
+ )
241
+
242
+ def decode(self, buf, out=None):
243
+ return imagecodecs.blosc_decode(
244
+ buf, numthreads=self.numthreads, out=_flat(out)
245
+ )
246
+
247
+
248
+ class Blosc2(Codec):
249
+ """Blosc2 codec for numcodecs."""
250
+
251
+ codec_id = 'imagecodecs_blosc2'
252
+
253
+ def __init__(
254
+ self,
255
+ level=None,
256
+ compressor=None,
257
+ typesize=None,
258
+ blocksize=None,
259
+ shuffle=None,
260
+ numthreads=None,
261
+ ):
262
+ self.level = level
263
+ self.compressor = compressor
264
+ self.typesize = typesize
265
+ self.blocksize = blocksize
266
+ self.shuffle = shuffle
267
+ self.numthreads = numthreads
268
+
269
+ def encode(self, buf):
270
+ buf = protective_squeeze(numpy.asarray(buf))
271
+ return imagecodecs.blosc2_encode(
272
+ buf,
273
+ level=self.level,
274
+ compressor=self.compressor,
275
+ typesize=self.typesize,
276
+ blocksize=self.blocksize,
277
+ shuffle=self.shuffle,
278
+ numthreads=self.numthreads,
279
+ )
280
+
281
+ def decode(self, buf, out=None):
282
+ return imagecodecs.blosc2_decode(
283
+ buf, numthreads=self.numthreads, out=_flat(out)
284
+ )
285
+
286
+
287
class Brotli(Codec):
    """Brotli compression codec for numcodecs."""

    codec_id = 'imagecodecs_brotli'

    def __init__(self, level=None, mode=None, lgwin=None):
        # Quality level, mode, and log2 window size; None uses defaults.
        self.level = level
        self.mode = mode
        self.lgwin = lgwin

    def encode(self, buf):
        return imagecodecs.brotli_encode(
            buf, level=self.level, mode=self.mode, lgwin=self.lgwin
        )

    def decode(self, buf, out=None):
        # Decode into a flat writable byte view of ``out`` when possible.
        return imagecodecs.brotli_decode(buf, out=_flat(out))
304
+
305
+
306
class ByteShuffle(Codec):
    """ByteShuffle codec for numcodecs.

    Shape and dtype are fixed at construction and asserted on encode;
    decode reinterprets raw bytes with that same geometry.
    """

    codec_id = 'imagecodecs_byteshuffle'

    def __init__(self, shape, dtype, axis=-1, dist=1, delta=False,
                 reorder=False):
        self.shape = tuple(shape)
        self.dtype = numpy.dtype(dtype).str
        self.axis = axis
        self.dist = dist
        self.delta = bool(delta)
        self.reorder = bool(reorder)

    def encode(self, buf):
        data = protective_squeeze(numpy.asarray(buf))
        assert data.shape == self.shape
        assert data.dtype == self.dtype
        shuffled = imagecodecs.byteshuffle_encode(
            data,
            axis=self.axis,
            dist=self.dist,
            delta=self.delta,
            reorder=self.reorder,
        )
        return shuffled.tobytes()

    def decode(self, buf, out=None):
        # Raw byte streams carry no geometry; rebuild it from config.
        if not isinstance(buf, numpy.ndarray):
            buf = numpy.frombuffer(buf, dtype=self.dtype).reshape(*self.shape)
        return imagecodecs.byteshuffle_decode(
            buf,
            axis=self.axis,
            dist=self.dist,
            delta=self.delta,
            reorder=self.reorder,
            out=out,
        )
344
+
345
+
346
class Bz2(Codec):
    """bzip2 compression codec for numcodecs."""

    codec_id = 'imagecodecs_bz2'

    def __init__(self, level=None):
        # Compression level forwarded verbatim; None uses the default.
        self.level = level

    def encode(self, buf):
        return imagecodecs.bz2_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        # Decode into a flat writable byte view of ``out`` when possible.
        return imagecodecs.bz2_decode(buf, out=_flat(out))
359
+
360
+
361
class Cms(Codec):
    """CMS color-transform codec for numcodecs (not implemented)."""

    codec_id = 'imagecodecs_cms'

    def __init__(self, *args, **kwargs):
        # Accept and ignore any arguments for registry compatibility.
        pass

    def encode(self, buf, out=None):
        # return imagecodecs.cms_transform(buf)
        raise NotImplementedError

    def decode(self, buf, out=None):
        # return imagecodecs.cms_transform(buf)
        raise NotImplementedError
376
+
377
+
378
class Deflate(Codec):
    """Deflate (zlib) codec for numcodecs."""

    codec_id = 'imagecodecs_deflate'

    def __init__(self, level=None, raw=False):
        self.level = level
        # raw=True omits the zlib header/trailer around the stream.
        self.raw = bool(raw)

    def encode(self, buf):
        return imagecodecs.deflate_encode(buf, level=self.level, raw=self.raw)

    def decode(self, buf, out=None):
        # Decode into a flat writable byte view of ``out`` when possible.
        return imagecodecs.deflate_decode(buf, out=_flat(out), raw=self.raw)
392
+
393
+
394
class Delta(Codec):
    """Delta codec for numcodecs.

    When shape/dtype are configured, encode validates the input array
    and decode reinterprets raw bytes with that geometry.
    """

    codec_id = 'imagecodecs_delta'

    def __init__(self, shape=None, dtype=None, axis=-1, dist=1):
        self.shape = None if shape is None else tuple(shape)
        self.dtype = None if dtype is None else numpy.dtype(dtype).str
        self.axis = axis
        self.dist = dist

    def encode(self, buf):
        if self.shape is not None or self.dtype is not None:
            buf = protective_squeeze(numpy.asarray(buf))
            assert buf.shape == self.shape
            assert buf.dtype == self.dtype
        encoded = imagecodecs.delta_encode(buf, axis=self.axis, dist=self.dist)
        return encoded.tobytes()

    def decode(self, buf, out=None):
        if self.shape is not None or self.dtype is not None:
            buf = numpy.frombuffer(buf, dtype=self.dtype).reshape(*self.shape)
        return imagecodecs.delta_decode(
            buf, axis=self.axis, dist=self.dist, out=out
        )
420
+
421
+
422
class Float24(Codec):
    """24-bit floating point codec for numcodecs."""

    codec_id = 'imagecodecs_float24'

    def __init__(self, byteorder=None, rounding=None):
        self.byteorder = byteorder
        self.rounding = rounding

    def encode(self, buf):
        data = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.float24_encode(
            data, byteorder=self.byteorder, rounding=self.rounding
        )

    def decode(self, buf, out=None):
        return imagecodecs.float24_decode(
            buf, byteorder=self.byteorder, out=out
        )
441
+
442
+
443
class FloatPred(Codec):
    """Floating-point predictor codec for numcodecs.

    Requires a fixed shape/dtype; encode asserts them, decode uses them
    to reinterpret raw byte input.
    """

    codec_id = 'imagecodecs_floatpred'

    def __init__(self, shape, dtype, axis=-1, dist=1):
        self.shape = tuple(shape)
        self.dtype = numpy.dtype(dtype).str
        self.axis = axis
        self.dist = dist

    def encode(self, buf):
        data = protective_squeeze(numpy.asarray(buf))
        assert data.shape == self.shape
        assert data.dtype == self.dtype
        encoded = imagecodecs.floatpred_encode(
            data, axis=self.axis, dist=self.dist
        )
        return encoded.tobytes()

    def decode(self, buf, out=None):
        if not isinstance(buf, numpy.ndarray):
            buf = numpy.frombuffer(buf, dtype=self.dtype).reshape(*self.shape)
        return imagecodecs.floatpred_decode(
            buf, axis=self.axis, dist=self.dist, out=out
        )
468
+
469
+
470
class Gif(Codec):
    """GIF codec for numcodecs."""

    codec_id = 'imagecodecs_gif'

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.gif_encode(image)

    def decode(self, buf, out=None):
        # asrgb=False keeps the palette-indexed representation.
        return imagecodecs.gif_decode(buf, asrgb=False, out=out)
481
+
482
+
483
class Heif(Codec):
    """HEIF codec for numcodecs."""

    codec_id = 'imagecodecs_heif'

    def __init__(self, level=None, bitspersample=None, photometric=None,
                 compression=None, numthreads=None, index=None):
        # ``index`` is only consulted when decoding; the rest drive encoding
        # (photometric and numthreads are used on both paths).
        self.level = level
        self.bitspersample = bitspersample
        self.photometric = photometric
        self.compression = compression
        self.numthreads = numthreads
        self.index = index

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.heif_encode(
            image,
            level=self.level,
            bitspersample=self.bitspersample,
            photometric=self.photometric,
            compression=self.compression,
            numthreads=self.numthreads,
        )

    def decode(self, buf, out=None):
        return imagecodecs.heif_decode(
            buf,
            index=self.index,
            photometric=self.photometric,
            numthreads=self.numthreads,
            out=out,
        )
523
+
524
+
525
class Jetraw(Codec):
    """Jetraw codec for numcodecs."""

    codec_id = 'imagecodecs_jetraw'

    def __init__(self, shape, identifier, parameters=None, verbosity=None,
                 verbosity_unused=None, errorbound=None):
        self.shape = shape
        self.identifier = identifier
        self.errorbound = errorbound
        # Global library initialization; parameters/verbosity are consumed
        # here and not stored on the instance.
        imagecodecs.jetraw_init(parameters, verbosity)

    def encode(self, buf):
        return imagecodecs.jetraw_encode(
            buf, identifier=self.identifier, errorbound=self.errorbound
        )

    def decode(self, buf, out=None):
        # Jetraw streams carry no shape info; allocate from the configured
        # shape when the caller does not supply an output array.
        if out is None:
            out = numpy.empty(self.shape, numpy.uint16)
        return imagecodecs.jetraw_decode(buf, out=out)
552
+
553
+
554
class Jpeg(Codec):
    """JPEG codec for numcodecs.

    ``header`` and ``tables`` are binary blobs; :meth:`get_config` and
    :meth:`from_config` round-trip them through base64 so that codec
    configurations remain JSON-serializable.
    """

    codec_id = 'imagecodecs_jpeg'

    def __init__(
        self,
        bitspersample=None,
        tables=None,
        header=None,
        colorspace_data=None,
        colorspace_jpeg=None,
        level=None,
        subsampling=None,
        optimize=None,
        smoothing=None,
    ):
        # All attributes are forwarded to imagecodecs; None means default.
        self.tables = tables
        self.header = header
        self.bitspersample = bitspersample
        self.colorspace_data = colorspace_data
        self.colorspace_jpeg = colorspace_jpeg
        self.level = level
        self.subsampling = subsampling
        self.optimize = optimize
        self.smoothing = smoothing

    def encode(self, buf):
        # Squeeze unprotected singleton axes before encoding.
        buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.jpeg_encode(
            buf,
            level=self.level,
            colorspace=self.colorspace_data,
            outcolorspace=self.colorspace_jpeg,
            subsampling=self.subsampling,
            optimize=self.optimize,
            smoothing=self.smoothing,
        )

    def decode(self, buf, out=None):
        # Remember the caller-requested output shape: ``out`` may be
        # squeezed before decoding, so the result is reshaped back.
        out_shape = None
        if out is not None:
            out_shape = out.shape
            out = protective_squeeze(out)
        img = imagecodecs.jpeg_decode(
            buf,
            bitspersample=self.bitspersample,
            tables=self.tables,
            header=self.header,
            colorspace=self.colorspace_jpeg,
            outcolorspace=self.colorspace_data,
            out=out,
        )
        if out_shape is not None:
            img = img.reshape(out_shape)
        return img

    def get_config(self):
        """Return dictionary holding configuration parameters."""
        config = dict(id=self.codec_id)
        for key in self.__dict__:
            if not key.startswith('_'):
                value = getattr(self, key)
                if value is not None and key in ('header', 'tables'):
                    import base64

                    # Binary fields are base64-encoded for serializability.
                    value = base64.b64encode(value).decode()
                config[key] = value
        return config

    @classmethod
    def from_config(cls, config):
        """Instantiate codec from configuration object."""
        for key in ('header', 'tables'):
            value = config.get(key, None)
            if value is not None and isinstance(value, str):
                import base64

                # Reverse the base64 encoding applied by get_config.
                config[key] = base64.b64decode(value.encode())
        return cls(**config)
634
+
635
+
636
class Jpeg2k(Codec):
    """JPEG 2000 codec for numcodecs."""

    codec_id = 'imagecodecs_jpeg2k'

    def __init__(self, level=None, codecformat=None, colorspace=None,
                 tile=None, reversible=None, bitspersample=None,
                 resolutions=None, numthreads=None, verbose=0):
        self.level = level
        self.codecformat = codecformat
        self.colorspace = colorspace
        self.tile = None if tile is None else tuple(tile)
        self.reversible = reversible
        self.bitspersample = bitspersample
        self.resolutions = resolutions
        self.numthreads = numthreads
        self.verbose = verbose

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.jpeg2k_encode(
            image,
            level=self.level,
            codecformat=self.codecformat,
            colorspace=self.colorspace,
            tile=self.tile,
            reversible=self.reversible,
            bitspersample=self.bitspersample,
            resolutions=self.resolutions,
            numthreads=self.numthreads,
            verbose=self.verbose,
        )

    def decode(self, buf, out=None):
        return imagecodecs.jpeg2k_decode(
            buf, verbose=self.verbose, numthreads=self.numthreads, out=out
        )
682
+
683
+
684
class JpegLs(Codec):
    """JPEG-LS codec for numcodecs."""

    codec_id = 'imagecodecs_jpegls'

    def __init__(self, level=None):
        self.level = level

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.jpegls_encode(image, level=self.level)

    def decode(self, buf, out=None):
        return imagecodecs.jpegls_decode(buf, out=out)
698
+
699
+
700
class JpegXl(Codec):
    """JPEG XL codec for numcodecs."""

    codec_id = 'imagecodecs_jpegxl'

    def __init__(
        self,
        # encode
        level=None,
        effort=None,
        distance=None,
        lossless=None,
        decodingspeed=None,
        photometric=None,
        planar=None,
        usecontainer=None,
        # decode
        index=None,
        keeporientation=None,
        # both
        numthreads=None,
    ):
        """Configure JPEG XL encoding/decoding.

        Floats must be in the nominal 0..1 range. L, LA, RGB and RGBA
        images are supported in contig mode; extra channels only for
        grayscale images in planar mode.

        Parameters
        ----------
        level : int, optional
            Overrides lossless/decodingspeed when set: a value < 0
            selects lossless compression; 0..4 sets the decoding speed
            tier (0 = slowest decode / best quality, 4 = fastest).
        effort : int, optional
            Encoder effort 1..9, faster to slower: 1 lightning,
            2 thunder, 3 falcon, 4 cheetah, 5 hare, 6 wombat,
            7 squirrel (library default), 8 kitten, 9 tortoise.
            Lower effort also reduces encoder memory use. Default 3.
        distance : float, optional
            Target max butteraugli distance for lossy compression,
            range 0..15; lower = higher quality. 0.0 is mathematically
            lossless (prefer ``lossless`` for true lossless), 1.0 is
            visually lossless; recommended 0.5..3.0. Default 1.0.
        lossless : bool
            Use lossless encoding. Default False.
        decodingspeed : int, optional
            Decoding speed tier 0..4 (duplicate of ``level``). Default 0.
        photometric : int or str, optional
            Color space, e.g. -1..3 or 'RGB', 'WHITEISZERO'/'MINISWHITE',
            'BLACKISZERO'/'MINISBLACK'/'GRAY', 'XYB', 'KNOWN'.
        planar : bool, optional
            Enable multi-channel (planar) mode. Default False.
        usecontainer : bool, optional
            Force the box-based BMFF container format even when not
            required. Disabled by default.
        index : int, optional
            Animation frame to decode; 0 (default) decodes all frames,
            > 0 decodes only that frame.
        keeporientation : bool, optional
            When true, skip applying the encoded Orientation transform
            and return pixel data in as-in-bitstream orientation.
            Disabled by default (image is re-oriented).
        numthreads : int, optional
            Worker threads; <= 0 uses all cores, values > 32 are
            clipped to 32. Default 1.
        """
        self.level = level
        self.effort = effort
        self.distance = distance
        self.lossless = bool(lossless)
        self.decodingspeed = decodingspeed
        self.photometric = photometric
        self.planar = planar
        self.usecontainer = usecontainer
        self.index = index
        self.keeporientation = keeporientation
        self.numthreads = numthreads

    def encode(self, buf):
        # TODO: only squeeze all but last dim
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.jpegxl_encode(
            image,
            level=self.level,
            effort=self.effort,
            distance=self.distance,
            lossless=self.lossless,
            decodingspeed=self.decodingspeed,
            photometric=self.photometric,
            planar=self.planar,
            usecontainer=self.usecontainer,
            numthreads=self.numthreads,
        )

    def decode(self, buf, out=None):
        return imagecodecs.jpegxl_decode(
            buf,
            index=self.index,
            keeporientation=self.keeporientation,
            numthreads=self.numthreads,
            out=out,
        )
845
+
846
+
847
class JpegXr(Codec):
    """JPEG XR codec for numcodecs."""

    codec_id = 'imagecodecs_jpegxr'

    def __init__(self, level=None, photometric=None, hasalpha=None,
                 resolution=None, fp2int=None):
        self.level = level
        self.photometric = photometric
        self.hasalpha = hasalpha
        self.resolution = resolution
        # fp2int is decode-only: convert fixed-point data to integers.
        self.fp2int = fp2int

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.jpegxr_encode(
            image,
            level=self.level,
            photometric=self.photometric,
            hasalpha=self.hasalpha,
            resolution=self.resolution,
        )

    def decode(self, buf, out=None):
        return imagecodecs.jpegxr_decode(buf, fp2int=self.fp2int, out=out)
878
+
879
+
880
class Lerc(Codec):
    """LERC (Limited Error Raster Compression) codec for numcodecs."""

    codec_id = 'imagecodecs_lerc'

    def __init__(self, level=None, version=None, planar=None):
        self.level = level
        self.version = version
        self.planar = bool(planar)
        # TODO: support mask?
        # self.mask = None

    def encode(self, buf):
        data = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.lerc_encode(
            data,
            level=self.level,
            version=self.version,
            planar=self.planar,
        )

    def decode(self, buf, out=None):
        return imagecodecs.lerc_decode(buf, out=out)
903
+
904
+
905
class Ljpeg(Codec):
    """Lossless JPEG codec for numcodecs."""

    codec_id = 'imagecodecs_ljpeg'

    def __init__(self, bitspersample=None):
        self.bitspersample = bitspersample

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.ljpeg_encode(
            image, bitspersample=self.bitspersample
        )

    def decode(self, buf, out=None):
        return imagecodecs.ljpeg_decode(buf, out=out)
919
+
920
+
921
class Lz4(Codec):
    """LZ4 compression codec for numcodecs."""

    codec_id = 'imagecodecs_lz4'

    def __init__(self, level=None, hc=False, header=True):
        self.level = level
        # hc selects the high-compression variant.
        self.hc = hc
        self.header = bool(header)

    def encode(self, buf):
        return imagecodecs.lz4_encode(
            buf, level=self.level, hc=self.hc, header=self.header
        )

    def decode(self, buf, out=None):
        return imagecodecs.lz4_decode(buf, header=self.header, out=_flat(out))
938
+
939
+
940
class Lz4f(Codec):
    """LZ4 frame-format codec for numcodecs."""

    codec_id = 'imagecodecs_lz4f'

    def __init__(self, level=None, blocksizeid=False, contentchecksum=None,
                 blockchecksum=None):
        self.level = level
        self.blocksizeid = blocksizeid
        self.contentchecksum = contentchecksum
        self.blockchecksum = blockchecksum

    def encode(self, buf):
        return imagecodecs.lz4f_encode(
            buf,
            level=self.level,
            blocksizeid=self.blocksizeid,
            contentchecksum=self.contentchecksum,
            blockchecksum=self.blockchecksum,
        )

    def decode(self, buf, out=None):
        # Decode into a flat writable byte view of ``out`` when possible.
        return imagecodecs.lz4f_decode(buf, out=_flat(out))
968
+
969
+
970
class Lzf(Codec):
    """LZF compression codec for numcodecs."""

    codec_id = 'imagecodecs_lzf'

    def __init__(self, header=True):
        self.header = bool(header)

    def encode(self, buf):
        return imagecodecs.lzf_encode(buf, header=self.header)

    def decode(self, buf, out=None):
        return imagecodecs.lzf_decode(buf, header=self.header, out=_flat(out))
983
+
984
+
985
class Lzma(Codec):
    """LZMA compression codec for numcodecs."""

    codec_id = 'imagecodecs_lzma'

    def __init__(self, level=None):
        self.level = level

    def encode(self, buf):
        return imagecodecs.lzma_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        return imagecodecs.lzma_decode(buf, out=_flat(out))
998
+
999
+
1000
class Lzw(Codec):
    """LZW codec for numcodecs (decode-oriented; no parameters)."""

    codec_id = 'imagecodecs_lzw'

    def encode(self, buf):
        return imagecodecs.lzw_encode(buf)

    def decode(self, buf, out=None):
        return imagecodecs.lzw_decode(buf, out=_flat(out))
1010
+
1011
+
1012
class PackBits(Codec):
    """PackBits run-length codec for numcodecs."""

    codec_id = 'imagecodecs_packbits'

    def __init__(self, axis=None):
        self.axis = axis

    def encode(self, buf):
        # Bytes pass through untouched; arrays are squeezed first.
        if not isinstance(buf, (bytes, bytearray)):
            buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.packbits_encode(buf, axis=self.axis)

    def decode(self, buf, out=None):
        return imagecodecs.packbits_decode(buf, out=_flat(out))
1027
+
1028
+
1029
class Pglz(Codec):
    """PGLZ (PostgreSQL LZ) codec for numcodecs."""

    codec_id = 'imagecodecs_pglz'

    def __init__(self, header=True, strategy=None):
        self.header = bool(header)
        self.strategy = strategy

    def encode(self, buf):
        return imagecodecs.pglz_encode(
            buf, strategy=self.strategy, header=self.header
        )

    def decode(self, buf, out=None):
        return imagecodecs.pglz_decode(
            buf, header=self.header, out=_flat(out)
        )
1045
+
1046
+
1047
class Png(Codec):
    """PNG codec for numcodecs."""

    codec_id = 'imagecodecs_png'

    def __init__(self, level=None):
        self.level = level

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.png_encode(image, level=self.level)

    def decode(self, buf, out=None):
        return imagecodecs.png_decode(buf, out=out)
1061
+
1062
+
1063
class Qoi(Codec):
    """QOI (Quite OK Image) codec for numcodecs."""

    codec_id = 'imagecodecs_qoi'

    def __init__(self):
        # QOI takes no parameters.
        pass

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.qoi_encode(image)

    def decode(self, buf, out=None):
        return imagecodecs.qoi_decode(buf, out=out)
1077
+
1078
+
1079
class Rgbe(Codec):
    """RGBE (Radiance HDR) codec for numcodecs."""

    codec_id = 'imagecodecs_rgbe'

    def __init__(self, header=False, shape=None, rle=None):
        # Without a header the stream is raw, so a shape is mandatory;
        # the last axis must hold the three RGB components.
        if not header and shape is None:
            raise ValueError('must specify data shape if no header')
        if shape and shape[-1] != 3:
            raise ValueError('invalid shape')
        self.shape = shape
        self.header = bool(header)
        self.rle = None if rle is None else bool(rle)

    def encode(self, buf):
        data = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.rgbe_encode(data, header=self.header, rle=self.rle)

    def decode(self, buf, out=None):
        if out is None and not self.header:
            # Headerless streams need a preallocated float32 output.
            out = numpy.empty(self.shape, numpy.float32)
        return imagecodecs.rgbe_decode(
            buf, header=self.header, rle=self.rle, out=out
        )
1103
+
1104
+
1105
class Rcomp(Codec):
    """Rcomp (Rice compression) codec for numcodecs."""

    codec_id = 'imagecodecs_rcomp'

    def __init__(self, shape, dtype, nblock=None):
        # Geometry is required: decode reconstructs the array from it.
        self.shape = tuple(shape)
        self.dtype = numpy.dtype(dtype).str
        self.nblock = nblock

    def encode(self, buf):
        return imagecodecs.rcomp_encode(buf, nblock=self.nblock)

    def decode(self, buf, out=None):
        return imagecodecs.rcomp_decode(
            buf,
            shape=self.shape,
            dtype=self.dtype,
            nblock=self.nblock,
            out=out,
        )
1126
+
1127
+
1128
class Snappy(Codec):
    """Snappy compression codec for numcodecs (no parameters)."""

    codec_id = 'imagecodecs_snappy'

    def encode(self, buf):
        return imagecodecs.snappy_encode(buf)

    def decode(self, buf, out=None):
        return imagecodecs.snappy_decode(buf, out=_flat(out))
1138
+
1139
+
1140
class Spng(Codec):
    """SPNG (simple PNG) codec for numcodecs."""

    codec_id = 'imagecodecs_spng'

    def __init__(self, level=None):
        self.level = level

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.spng_encode(image, level=self.level)

    def decode(self, buf, out=None):
        return imagecodecs.spng_decode(buf, out=out)
1154
+
1155
+
1156
class Tiff(Codec):
    """TIFF codec for numcodecs (decode-oriented)."""

    codec_id = 'imagecodecs_tiff'

    def __init__(self, index=None, asrgb=None, verbose=0):
        self.index = index
        self.asrgb = bool(asrgb)
        self.verbose = verbose

    def encode(self, buf):
        # TODO: not implemented
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.tiff_encode(image)

    def decode(self, buf, out=None):
        return imagecodecs.tiff_decode(
            buf,
            index=self.index,
            asrgb=self.asrgb,
            verbose=self.verbose,
            out=out,
        )
1179
+
1180
+
1181
class Webp(Codec):
    """WebP codec for numcodecs."""

    codec_id = 'imagecodecs_webp'

    def __init__(self, level=None, lossless=None, method=None, hasalpha=None):
        self.level = level
        # hasalpha is decode-only; level/lossless/method drive encoding.
        self.hasalpha = bool(hasalpha)
        self.method = method
        self.lossless = lossless

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.webp_encode(
            image, level=self.level, lossless=self.lossless, method=self.method
        )

    def decode(self, buf, out=None):
        return imagecodecs.webp_decode(buf, hasalpha=self.hasalpha, out=out)
1200
+
1201
+
1202
class Xor(Codec):
    """XOR-delta codec for numcodecs.

    When shape/dtype are configured, encode validates the input array
    and decode reinterprets raw bytes with that geometry.
    """

    codec_id = 'imagecodecs_xor'

    def __init__(self, shape=None, dtype=None, axis=-1):
        self.shape = None if shape is None else tuple(shape)
        self.dtype = None if dtype is None else numpy.dtype(dtype).str
        self.axis = axis

    def encode(self, buf):
        if self.shape is not None or self.dtype is not None:
            buf = protective_squeeze(numpy.asarray(buf))
            assert buf.shape == self.shape
            assert buf.dtype == self.dtype
        return imagecodecs.xor_encode(buf, axis=self.axis).tobytes()

    def decode(self, buf, out=None):
        if self.shape is not None or self.dtype is not None:
            buf = numpy.frombuffer(buf, dtype=self.dtype).reshape(*self.shape)
        return imagecodecs.xor_decode(buf, axis=self.axis, out=_flat(out))
1223
+
1224
+
1225
class Zfp(Codec):
    """ZFP codec for numcodecs.

    With ``header=True`` (default) shape/dtype/strides are embedded in
    the stream; otherwise they must be supplied up front, are asserted
    on encode, and are reused on decode.
    """

    codec_id = 'imagecodecs_zfp'

    def __init__(
        self,
        shape=None,
        dtype=None,
        strides=None,
        level=None,
        mode=None,
        execution=None,
        numthreads=None,
        chunksize=None,
        header=True,
    ):
        if header:
            # Self-describing stream: ignore any explicit geometry.
            self.shape = None
            self.dtype = None
            self.strides = None
        elif shape is None or dtype is None:
            raise ValueError('invalid shape or dtype')
        else:
            self.shape = tuple(shape)
            self.dtype = numpy.dtype(dtype).str
            self.strides = None if strides is None else tuple(strides)
        self.level = level
        self.mode = mode
        self.execution = execution
        self.numthreads = numthreads
        self.chunksize = chunksize
        self.header = bool(header)

    def encode(self, buf):
        buf = protective_squeeze(numpy.asarray(buf))
        if not self.header:
            # Headerless mode: data must match the configured geometry.
            assert buf.shape == self.shape
            assert buf.dtype == self.dtype
        return imagecodecs.zfp_encode(
            buf,
            level=self.level,
            mode=self.mode,
            execution=self.execution,
            header=self.header,
            numthreads=self.numthreads,
            chunksize=self.chunksize,
        )

    def decode(self, buf, out=None):
        if self.header:
            # Geometry is read from the embedded stream header.
            return imagecodecs.zfp_decode(buf, out=out)
        return imagecodecs.zfp_decode(
            buf,
            shape=self.shape,
            dtype=numpy.dtype(self.dtype),
            strides=self.strides,
            numthreads=self.numthreads,
            out=out,
        )
1285
+
1286
+
1287
class Zlib(Codec):
    """zlib compression codec for numcodecs."""

    codec_id = 'imagecodecs_zlib'

    def __init__(self, level=None):
        self.level = level

    def encode(self, buf):
        return imagecodecs.zlib_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        return imagecodecs.zlib_decode(buf, out=_flat(out))
1300
+
1301
+
1302
class Zlibng(Codec):
    """zlib-ng compression codec for numcodecs."""

    codec_id = 'imagecodecs_zlibng'

    def __init__(self, level=None):
        self.level = level

    def encode(self, buf):
        return imagecodecs.zlibng_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        return imagecodecs.zlibng_decode(buf, out=_flat(out))
1315
+
1316
+
1317
class Zopfli(Codec):
    """Zopfli compression codec for numcodecs (no parameters)."""

    codec_id = 'imagecodecs_zopfli'

    def encode(self, buf):
        return imagecodecs.zopfli_encode(buf)

    def decode(self, buf, out=None):
        return imagecodecs.zopfli_decode(buf, out=_flat(out))
1327
+
1328
+
1329
class Zstd(Codec):
    """Zstandard compression codec for numcodecs."""

    codec_id = 'imagecodecs_zstd'

    def __init__(self, level=None):
        self.level = level

    def encode(self, buf):
        return imagecodecs.zstd_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        return imagecodecs.zstd_decode(buf, out=_flat(out))
1342
+
1343
+
1344
+ def _flat(out):
1345
+ """Return numpy array as contiguous view of bytes if possible."""
1346
+ if out is None:
1347
+ return None
1348
+ view = memoryview(out)
1349
+ if view.readonly or not view.contiguous:
1350
+ return None
1351
+ return view.cast('B')
1352
+
1353
+
1354
def register_codecs(codecs=None, force=False, verbose=True):
    """Register codecs in this module with numcodecs.

    Parameters
    ----------
    codecs : container of str, optional
        If given, only codec classes whose ``codec_id`` is contained
        here are registered; otherwise all codecs in this module are.
    force : bool
        Replace codecs that numcodecs has already registered.
    verbose : bool
        Log a warning when a codec is skipped or replaced.
    """
    for name, cls in globals().items():
        # Only consider module members that define a codec_id, and skip
        # the abstract 'Codec' base class itself.
        if not hasattr(cls, 'codec_id') or name == 'Codec':
            continue
        if codecs is not None and cls.codec_id not in codecs:
            continue
        try:
            try:
                # Probe the numcodecs registry for an existing entry.
                get_codec({'id': cls.codec_id})
            except TypeError:
                # registered, but failed
                pass
        except ValueError:
            # not registered yet
            pass
        else:
            # Already registered (possibly broken): honor `force`.
            if not force:
                if verbose:
                    log_warning(
                        f'numcodec {cls.codec_id!r} already registered'
                    )
                continue
            if verbose:
                log_warning(f'replacing registered numcodec {cls.codec_id!r}')
        register_codec(cls)
1380
+
1381
+
1382
def log_warning(msg, *args, **kwargs):
    """Emit *msg* through this module's logger at WARNING level."""
    import logging

    logger = logging.getLogger(__name__)
    logger.warning(msg, *args, **kwargs)
equidiff/equi_diffpo/common/checkpoint_util.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict
2
+ import os
3
+
4
class TopKCheckpointManager:
    """Keep at most k checkpoints, ranked by a monitored metric.

    Maintains a map of checkpoint path -> monitored value and decides, for
    each new candidate, whether it should be saved and which existing
    checkpoint file (if any) must be deleted to stay within capacity k.
    """

    def __init__(self,
            save_dir,
            monitor_key: str,
            mode='min',
            k=1,
            format_str='epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt'
        ):
        """
        save_dir: directory where checkpoint files live (created on demand).
        monitor_key: key in the data dict used for ranking.
        mode: 'min' keeps the k smallest values, 'max' the k largest.
        k: number of checkpoints to keep; 0 disables checkpointing.
        format_str: filename template, formatted with the full data dict.
        """
        assert mode in ['max', 'min']
        assert k >= 0

        self.save_dir = save_dir
        self.monitor_key = monitor_key
        self.mode = mode
        self.k = k
        self.format_str = format_str
        # checkpoint path -> monitored value for currently kept checkpoints
        self.path_value_map = dict()

    def get_ckpt_path(self, data: Dict[str, float]) -> Optional[str]:
        """Return the path to save this checkpoint at, or None to skip it.

        When at capacity and the new value displaces an existing entry,
        the displaced checkpoint file is deleted from disk.
        """
        if self.k == 0:
            return None

        value = data[self.monitor_key]
        ckpt_path = os.path.join(
            self.save_dir, self.format_str.format(**data))

        if len(self.path_value_map) < self.k:
            # under-capacity: always keep.
            # Fix: ensure the save directory exists on this path too, not
            # only on the replacement path below, so the caller can write
            # the very first checkpoints.
            self.path_value_map[ckpt_path] = value
            os.makedirs(self.save_dir, exist_ok=True)
            return ckpt_path

        # at capacity: locate the current extremes
        sorted_map = sorted(self.path_value_map.items(), key=lambda x: x[1])
        min_path, min_value = sorted_map[0]
        max_path, max_value = sorted_map[-1]

        delete_path = None
        if self.mode == 'max':
            # new value beats the current worst (smallest) -> replace it
            if value > min_value:
                delete_path = min_path
        else:
            # mode 'min': new value beats the current worst (largest)
            if value < max_value:
                delete_path = max_path

        if delete_path is None:
            return None

        del self.path_value_map[delete_path]
        self.path_value_map[ckpt_path] = value

        # makedirs (not mkdir) so nested save_dir paths also work
        os.makedirs(self.save_dir, exist_ok=True)

        if os.path.exists(delete_path):
            os.remove(delete_path)
        return ckpt_path
equidiff/equi_diffpo/common/cv2_util.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import math
3
+ import cv2
4
+ import numpy as np
5
+
6
def draw_reticle(img, u, v, label_color):
    """
    Draw a cross-hair reticle on top of *img* at pixel (u, v), in place.

    @param img (In/Out) uint8 3 channel image
    @param u X coordinate (width)
    @param v Y coordinate (height)
    @param label_color tuple of 3 ints for RGB color used for drawing.
    """
    u, v = int(u), int(v)
    white = (255, 255, 255)
    # three concentric rings: colored, white, colored
    for radius, color in ((10, label_color), (11, white), (12, label_color)):
        cv2.circle(img, (u, v), radius, color, 1)
    # four short white ticks radiating from the center
    for du, dv in ((0, 1), (1, 0), (0, -1), (-1, 0)):
        cv2.line(img, (u + du, v + dv), (u + 3 * du, v + 3 * dv), white, 1)
27
+
28
+
29
def draw_text(
    img,
    *,
    text,
    uv_top_left,
    color=(255, 255, 255),
    fontScale=0.5,
    thickness=1,
    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
    outline_color=(0, 0, 0),
    line_spacing=1.5,
):
    """
    Draw multiline text onto *img* with an optional outline, in place.
    """
    assert isinstance(text, str)

    cursor = np.array(uv_top_left, dtype=float)
    assert cursor.shape == (2,)

    for line in text.splitlines():
        (_, line_h), _ = cv2.getTextSize(
            text=line,
            fontFace=fontFace,
            fontScale=fontScale,
            thickness=thickness,
        )
        # putText anchors text at its bottom-left corner
        org = tuple((cursor + [0, line_h]).astype(int))

        if outline_color is not None:
            # a thicker pass underneath forms the outline
            cv2.putText(
                img,
                text=line,
                org=org,
                fontFace=fontFace,
                fontScale=fontScale,
                color=outline_color,
                thickness=thickness * 3,
                lineType=cv2.LINE_AA,
            )
        cv2.putText(
            img,
            text=line,
            org=org,
            fontFace=fontFace,
            fontScale=fontScale,
            color=color,
            thickness=thickness,
            lineType=cv2.LINE_AA,
        )

        cursor += [0, line_h * line_spacing]
82
+
83
+
84
def get_image_transform(
        input_res: Tuple[int,int]=(1280,720),
        output_res: Tuple[int,int]=(640,480),
        bgr_to_rgb: bool=False):
    """Build a closure that resizes an (ih, iw, 3) image to cover
    output_res, then center-crops to exactly output_res, optionally
    reversing the channel order (BGR <-> RGB)."""
    iw, ih = input_res
    ow, oh = output_res

    # Scale so the output rectangle is fully covered, then crop the
    # excess along the other axis.
    if (iw / ih) >= (ow / oh):
        # input is wider: match heights, crop width
        rh = oh
        rw = math.ceil(rh / ih * iw)
        upscaling = oh > ih
    else:
        # input is taller: match widths, crop height
        rw = ow
        rh = math.ceil(rw / iw * ih)
        upscaling = ow > iw
    # INTER_AREA for shrinking, INTER_LINEAR when enlarging
    interp_method = cv2.INTER_LINEAR if upscaling else cv2.INTER_AREA

    w_start = (rw - ow) // 2
    h_start = (rh - oh) // 2
    w_slice = slice(w_start, w_start + ow)
    h_slice = slice(h_start, h_start + oh)
    # reversing the channel axis converts BGR <-> RGB
    c_slice = slice(None, None, -1) if bgr_to_rgb else slice(None)

    def transform(img: np.ndarray):
        assert img.shape == ((ih,iw,3))
        resized = cv2.resize(img, (rw, rh), interpolation=interp_method)
        return resized[h_slice, w_slice, c_slice]
    return transform
122
+
123
def optimal_row_cols(
        n_cameras,
        in_wh_ratio,
        max_resolution=(1920, 1080)
    ):
    """Choose a tiling grid for n_cameras views on a canvas of
    max_resolution, picking the (rows, cols) whose combined aspect ratio
    is closest to the canvas ratio.

    Returns (tile_width, tile_height, n_cols, n_rows).
    """
    out_w, out_h = max_resolution
    out_wh_ratio = out_w / out_h

    # enumerate every candidate row count 1..n_cameras
    rows = np.arange(n_cameras, dtype=np.int64) + 1
    cols = np.ceil(n_cameras / rows).astype(np.int64)
    grid_ratios = in_wh_ratio * (cols / rows)
    # pick the grid whose aspect ratio best matches the canvas
    best = np.argmin(np.abs(out_wh_ratio - grid_ratios))
    n_row = rows[best]
    n_col = cols[best]
    grid_ratio = grid_ratios[best]

    if grid_ratio >= out_wh_ratio:
        # grid is wider than the canvas: width limits tile size
        rw = math.floor(out_w / n_col)
        rh = math.floor(rw / in_wh_ratio)
    else:
        # grid is taller: height limits tile size
        rh = math.floor(out_h / n_row)
        rw = math.floor(rh * in_wh_ratio)

    return rw, rh, n_col, n_row
equidiff/equi_diffpo/common/env_util.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
def render_env_video(env, states, actions=None):
    """Render a sequence of environment states into a stacked image array.

    When actions are given, a red cross marker is drawn per frame at the
    action position (mapped from a 512-unit space to the 96-px render).
    """
    frames = list()
    for idx, state in enumerate(states):
        env.set_state(state)
        if idx == 0:
            # the first state is set twice; presumably required for the
            # env to fully settle -- TODO confirm against env implementation
            env.set_state(state)
        frame = env.render()
        if actions is not None:
            marker = (actions[idx] / 512 * 96).astype(np.int32)
            cv2.drawMarker(frame, marker,
                color=(255,0,0), markerType=cv2.MARKER_CROSS,
                markerSize=8, thickness=1)
        frames.append(frame)
    return np.array(frames)
equidiff/equi_diffpo/common/json_logger.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Callable, Any, Sequence
2
+ import os
3
+ import copy
4
+ import json
5
+ import numbers
6
+ import pandas as pd
7
+
8
+
9
def read_json_log(path: str,
        required_keys: Sequence[str]=tuple(),
        **kwargs) -> pd.DataFrame:
    """
    Read a json-per-line file, tolerating a trailing incomplete line.

    Only lines mentioning at least one of required_keys are kept (an
    empty required_keys therefore yields an empty DataFrame).
    kwargs are forwarded to pd.read_json.
    """
    kept = list()
    with open(path, 'r') as f:
        for raw in f:
            if not raw.endswith('\n'):
                # trailing line without terminator: partially written, stop
                break
            if any(k in raw for k in required_keys):
                kept.append(raw)
    if not kept:
        return pd.DataFrame()
    # assemble one JSON array from the individual line objects
    body = ",".join(s for s in (ln.strip() for ln in kept) if s)
    return pd.read_json(f'[{body}]', **kwargs)
39
+
40
class JsonLogger:
    """Append-mode json-per-line logger that survives interrupted runs.

    On start() it reopens an existing log, recovers the last complete
    json line into ``last_log``, and truncates any trailing partial line
    so new records append cleanly.
    """

    def __init__(self, path: str,
            filter_fn: Optional[Callable[[str,Any],bool]]=None):
        # filter_fn(key, value) -> bool selects which items of each log
        # dict are written; by default only numeric values are kept
        if filter_fn is None:
            filter_fn = lambda k,v: isinstance(v, numbers.Number)

        # default to append mode
        self.path = path
        self.filter_fn = filter_fn
        # open file handle, valid only between start() and stop()
        self.file = None
        # most recent record, either recovered from disk or written via log()
        self.last_log = None

    def start(self):
        """Open the log file, recover the last complete record, and drop
        any trailing incomplete line."""
        # use line buffering
        try:
            self.file = file = open(self.path, 'r+', buffering=1)
        except FileNotFoundError:
            self.file = file = open(self.path, 'w+', buffering=1)

        # Move the pointer (similar to a cursor in a text editor) to the end of the file
        pos = file.seek(0, os.SEEK_END)

        # Read each character in the file one at a time from the last
        # character going backwards, searching for a newline character
        # If we find a new line, exit the search
        while pos > 0 and file.read(1) != "\n":
            pos -= 1
            file.seek(pos, os.SEEK_SET)
        # now the file pointer is at one past the last '\n'
        # and pos is at the last '\n'.
        last_line_end = file.tell()

        # find the start of second last line
        pos = max(0, pos-1)
        file.seek(pos, os.SEEK_SET)
        while pos > 0 and file.read(1) != "\n":
            pos -= 1
            file.seek(pos, os.SEEK_SET)
        # now the file pointer is at one past the second last '\n'
        last_line_start = file.tell()

        if last_line_start < last_line_end:
            # has last line of json
            last_line = file.readline()
            self.last_log = json.loads(last_line)

        # remove the last incomplete line
        file.seek(last_line_end)
        file.truncate()

    def stop(self):
        """Close the underlying file; the logger may be start()ed again."""
        self.file.close()
        self.file = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def log(self, data: dict):
        """Write one record: filter items, coerce numbers to plain
        int/float (for json serialization), and append as a single line."""
        filtered_data = dict(
            filter(lambda x: self.filter_fn(*x), data.items()))
        # save current as last log
        self.last_log = filtered_data
        for k, v in filtered_data.items():
            if isinstance(v, numbers.Integral):
                filtered_data[k] = int(v)
            elif isinstance(v, numbers.Number):
                filtered_data[k] = float(v)
        buf = json.dumps(filtered_data)
        # ensure one line per json
        buf = buf.replace('\n','') + '\n'
        self.file.write(buf)

    def get_last_log(self):
        """Return a deep copy of the most recent record (or None)."""
        return copy.deepcopy(self.last_log)
equidiff/equi_diffpo/common/nested_dict_util.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
def nested_dict_map(f, x):
    """
    Apply f to every leaf of nested dict x, returning a new nested dict.
    """
    if isinstance(x, dict):
        return {key: nested_dict_map(f, value) for key, value in x.items()}
    # non-dict values are leaves
    return f(x)
14
+
15
def nested_dict_reduce(f, x):
    """
    Recursively fold all leaves of nested dict x into a single value with f.
    """
    if not isinstance(x, dict):
        # a leaf reduces to itself
        return x
    return functools.reduce(
        f, (nested_dict_reduce(f, value) for value in x.values()))
27
+
28
+
29
def nested_dict_check(f, x):
    """Return True iff predicate f holds at every leaf of nested dict x."""
    flags = nested_dict_map(f, x)
    return nested_dict_reduce(lambda a, b: a and b, flags)
equidiff/equi_diffpo/common/normalize_util.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from equi_diffpo.model.common.normalizer import SingleFieldLinearNormalizer
2
+ from equi_diffpo.common.pytorch_util import dict_apply, dict_apply_reduce, dict_apply_split
3
+ import numpy as np
4
+
5
+
6
def get_range_normalizer_from_stat(stat, output_max=1, output_min=-1, range_eps=1e-7):
    """Per-dimension linear normalizer mapping [stat min, stat max] onto
    [output_min, output_max]."""
    in_min = stat['min']
    in_max = stat['max']
    in_range = in_max - in_min
    # near-constant dimensions would blow up the scale; map them to the
    # middle of the output range instead
    degenerate = in_range < range_eps
    in_range[degenerate] = output_max - output_min
    scale = (output_max - output_min) / in_range
    offset = output_min - scale * in_min
    offset[degenerate] = (output_max + output_min) / 2 - in_min[degenerate]

    return SingleFieldLinearNormalizer.create_manual(
        scale=scale,
        offset=offset,
        input_stats_dict=stat
    )
22
+
23
def get_range_symmetric_normalizer_from_stat(stat, output_max=1, output_min=-1, range_eps=1e-7):
    """Range normalizer to [output_min, output_max] whose first two
    dimensions (x, y) use a range forced symmetric about zero.

    Bug fix: the original assigned into stat['max'] / stat['min'] directly,
    mutating the caller's statistics in place; min/max are now copied, so
    the stats stored on the normalizer are the caller's original values.
    """
    # copy so the caller's arrays are not modified by the symmetrization
    input_max = np.array(stat['max'], copy=True)
    input_min = np.array(stat['min'], copy=True)
    # force x/y to a range symmetric about zero
    abs_max = np.max([np.abs(stat['max'][:2]), np.abs(stat['min'][:2])])
    input_max[:2] = abs_max
    input_min[:2] = -abs_max
    input_range = input_max - input_min
    # near-constant dims map to the midpoint of the output range
    ignore_dim = input_range < range_eps
    input_range[ignore_dim] = output_max - output_min
    scale = (output_max - output_min) / input_range
    offset = output_min - scale * input_min
    offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]

    return SingleFieldLinearNormalizer.create_manual(
        scale=scale,
        offset=offset,
        input_stats_dict=stat
    )
42
+
43
def get_voxel_identity_normalizer():
    """Identity normalizer for voxel grids whose values already live in [0, 1]."""
    stat = {
        'min': np.array([0], dtype=np.float32),
        'max': np.array([1], dtype=np.float32),
        'mean': np.array([0.5], dtype=np.float32),
        # std of a uniform distribution on [0, 1]
        'std': np.array([np.sqrt(1/12)], dtype=np.float32)
    }
    return SingleFieldLinearNormalizer.create_manual(
        scale=np.array([1], dtype=np.float32),
        offset=np.array([0], dtype=np.float32),
        input_stats_dict=stat
    )
57
+
58
def get_image_range_normalizer():
    """Normalizer mapping image values from [0, 1] to [-1, 1] (x * 2 - 1)."""
    stat = {
        'min': np.array([0], dtype=np.float32),
        'max': np.array([1], dtype=np.float32),
        'mean': np.array([0.5], dtype=np.float32),
        # std of a uniform distribution on [0, 1]
        'std': np.array([np.sqrt(1/12)], dtype=np.float32)
    }
    return SingleFieldLinearNormalizer.create_manual(
        scale=np.array([2], dtype=np.float32),
        offset=np.array([-1], dtype=np.float32),
        input_stats_dict=stat
    )
72
+
73
def get_identity_normalizer_from_stat(stat):
    """Pass-through normalizer (scale 1, offset 0) shaped like stat['min']."""
    template = stat['min']
    return SingleFieldLinearNormalizer.create_manual(
        scale=np.ones_like(template),
        offset=np.zeros_like(template),
        input_stats_dict=stat
    )
81
+
82
def robomimic_abs_action_normalizer_from_stat(stat, rotation_transformer):
    """Normalizer for absolute actions laid out [pos(3), rot(3), gripper]:
    position is range-normalized to [-1, 1]; rotation (converted through
    rotation_transformer) and gripper pass through unchanged."""
    parts = dict_apply_split(
        stat, lambda x: {
            'pos': x[...,:3],
            'rot': x[...,3:6],
            'gripper': x[...,6:]
        })

    def _range_param(part, output_max=1, output_min=-1, range_eps=1e-7):
        # map [min, max] -> [output_min, output_max] per dimension;
        # near-constant dims are sent to the midpoint of the output range
        lo = part['min']
        hi = part['max']
        span = hi - lo
        degenerate = span < range_eps
        span[degenerate] = output_max - output_min
        scale = (output_max - output_min) / span
        offset = output_min - scale * lo
        offset[degenerate] = (output_max + output_min) / 2 - lo[degenerate]
        return {'scale': scale, 'offset': offset}, part

    def _identity_param(example):
        # identity transform with placeholder stats spanning [-1, 1]
        info = {
            'max': np.ones_like(example),
            'min': np.full_like(example, -1),
            'mean': np.zeros_like(example),
            'std': np.ones_like(example)
        }
        return {'scale': np.ones_like(example),
                'offset': np.zeros_like(example)}, info

    pos_param, pos_info = _range_param(parts['pos'])
    # rotation stats take the shape of the transformed representation
    rot_param, rot_info = _identity_param(
        rotation_transformer.forward(parts['rot']['mean']))
    gripper_param, gripper_info = _identity_param(parts['gripper']['max'])

    param = dict_apply_reduce(
        [pos_param, rot_param, gripper_param],
        lambda x: np.concatenate(x, axis=-1))
    info = dict_apply_reduce(
        [pos_info, rot_info, gripper_info],
        lambda x: np.concatenate(x, axis=-1))

    return SingleFieldLinearNormalizer.create_manual(
        scale=param['scale'],
        offset=param['offset'],
        input_stats_dict=info
    )
143
+
144
+
145
def robomimic_abs_action_only_normalizer_from_stat(stat):
    """Normalizer for absolute actions [pos(3), rest]: position is
    range-normalized to [-1, 1]; all remaining dims pass through."""
    parts = dict_apply_split(
        stat, lambda x: {
            'pos': x[...,:3],
            'other': x[...,3:]
        })

    def _range_param(part, output_max=1, output_min=-1, range_eps=1e-7):
        # map [min, max] -> [output_min, output_max] per dimension;
        # near-constant dims are sent to the midpoint of the output range
        lo = part['min']
        hi = part['max']
        span = hi - lo
        degenerate = span < range_eps
        span[degenerate] = output_max - output_min
        scale = (output_max - output_min) / span
        offset = output_min - scale * lo
        offset[degenerate] = (output_max + output_min) / 2 - lo[degenerate]
        return {'scale': scale, 'offset': offset}, part

    def _identity_param(part):
        # identity transform with placeholder stats spanning [-1, 1]
        example = part['max']
        info = {
            'max': np.ones_like(example),
            'min': np.full_like(example, -1),
            'mean': np.zeros_like(example),
            'std': np.ones_like(example)
        }
        return {'scale': np.ones_like(example),
                'offset': np.zeros_like(example)}, info

    pos_param, pos_info = _range_param(parts['pos'])
    other_param, other_info = _identity_param(parts['other'])

    param = dict_apply_reduce(
        [pos_param, other_param],
        lambda x: np.concatenate(x, axis=-1))
    info = dict_apply_reduce(
        [pos_info, other_info],
        lambda x: np.concatenate(x, axis=-1))

    return SingleFieldLinearNormalizer.create_manual(
        scale=param['scale'],
        offset=param['offset'],
        input_stats_dict=info
    )
193
+
194
+
195
def robomimic_abs_action_only_symmetric_normalizer_from_stat(stat):
    """Like robomimic_abs_action_only_normalizer_from_stat, but the x/y
    position range is forced symmetric about zero.

    Bug fix: stat['max']/stat['min'] slices are views into the caller's
    arrays; the original wrote the symmetrized bounds into them in place,
    corrupting the caller's statistics. The bounds are now copied first.
    """
    result = dict_apply_split(
        stat, lambda x: {
            'pos': x[...,:3],
            'other': x[...,3:]
        })

    def get_pos_param_info(stat, output_max=1, output_min=-1, range_eps=1e-7):
        # -1, 1 normalization with x/y forced symmetric about zero.
        # copy: stat['max'] / stat['min'] are views into the caller's arrays
        input_max = np.array(stat['max'], copy=True)
        input_min = np.array(stat['min'], copy=True)
        abs_max = np.max([np.abs(stat['max'][:2]), np.abs(stat['min'][:2])])
        input_max[:2] = abs_max
        input_min[:2] = -abs_max
        input_range = input_max - input_min
        # near-constant dims map to the midpoint of the output range
        ignore_dim = input_range < range_eps
        input_range[ignore_dim] = output_max - output_min
        scale = (output_max - output_min) / input_range
        offset = output_min - scale * input_min
        offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]

        return {'scale': scale, 'offset': offset}, stat

    def get_other_param_info(stat):
        # identity transform with placeholder stats spanning [-1, 1]
        example = stat['max']
        scale = np.ones_like(example)
        offset = np.zeros_like(example)
        info = {
            'max': np.ones_like(example),
            'min': np.full_like(example, -1),
            'mean': np.zeros_like(example),
            'std': np.ones_like(example)
        }
        return {'scale': scale, 'offset': offset}, info

    pos_param, pos_info = get_pos_param_info(result['pos'])
    other_param, other_info = get_other_param_info(result['other'])

    param = dict_apply_reduce(
        [pos_param, other_param],
        lambda x: np.concatenate(x, axis=-1))
    info = dict_apply_reduce(
        [pos_info, other_info],
        lambda x: np.concatenate(x, axis=-1))

    return SingleFieldLinearNormalizer.create_manual(
        scale=param['scale'],
        offset=param['offset'],
        input_stats_dict=info
    )
246
+
247
+
248
def robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat):
    """Normalizer for dual-arm absolute actions: per arm, the 3 position
    dims are range-normalized to [-1, 1] and the rest pass through.
    The action vector is split in half, one half per arm."""
    Da = stat['max'].shape[-1]
    Dah = Da // 2
    parts = dict_apply_split(
        stat, lambda x: {
            'pos0': x[...,:3],
            'other0': x[...,3:Dah],
            'pos1': x[...,Dah:Dah+3],
            'other1': x[...,Dah+3:]
        })

    def _range_param(part, output_max=1, output_min=-1, range_eps=1e-7):
        # map [min, max] -> [output_min, output_max] per dimension;
        # near-constant dims are sent to the midpoint of the output range
        lo = part['min']
        hi = part['max']
        span = hi - lo
        degenerate = span < range_eps
        span[degenerate] = output_max - output_min
        scale = (output_max - output_min) / span
        offset = output_min - scale * lo
        offset[degenerate] = (output_max + output_min) / 2 - lo[degenerate]
        return {'scale': scale, 'offset': offset}, part

    def _identity_param(part):
        # identity transform with placeholder stats spanning [-1, 1]
        example = part['max']
        info = {
            'max': np.ones_like(example),
            'min': np.full_like(example, -1),
            'mean': np.zeros_like(example),
            'std': np.ones_like(example)
        }
        return {'scale': np.ones_like(example),
                'offset': np.zeros_like(example)}, info

    pos0_param, pos0_info = _range_param(parts['pos0'])
    pos1_param, pos1_info = _range_param(parts['pos1'])
    other0_param, other0_info = _identity_param(parts['other0'])
    other1_param, other1_info = _identity_param(parts['other1'])

    # concatenation order must match the action layout: arm0 then arm1
    param = dict_apply_reduce(
        [pos0_param, other0_param, pos1_param, other1_param],
        lambda x: np.concatenate(x, axis=-1))
    info = dict_apply_reduce(
        [pos0_info, other0_info, pos1_info, other1_info],
        lambda x: np.concatenate(x, axis=-1))

    return SingleFieldLinearNormalizer.create_manual(
        scale=param['scale'],
        offset=param['offset'],
        input_stats_dict=info
    )
302
+
303
+
304
def array_to_stats(arr: np.ndarray):
    """Summary statistics (min/max/mean/std) of arr along axis 0."""
    return {
        name: fn(arr, axis=0)
        for name, fn in (('min', np.min), ('max', np.max),
                         ('mean', np.mean), ('std', np.std))
    }
equidiff/equi_diffpo/common/pose_trajectory_interpolator.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+ import numbers
3
+ import numpy as np
4
+ import scipy.interpolate as si
5
+ import scipy.spatial.transform as st
6
+
7
def rotation_distance(a: st.Rotation, b: st.Rotation) -> float:
    """Geodesic angle (radians) between rotations a and b."""
    relative = b * a.inv()
    return relative.magnitude()
9
+
10
def pose_distance(start_pose, end_pose):
    """Distance between two 6D poses [x, y, z, rotvec(3)], returned as
    (translation distance, rotation angle in radians)."""
    start_pose = np.array(start_pose)
    end_pose = np.array(end_pose)
    pos_dist = np.linalg.norm(end_pose[:3] - start_pose[:3])
    start_rot = st.Rotation.from_rotvec(start_pose[3:])
    end_rot = st.Rotation.from_rotvec(end_pose[3:])
    # geodesic angle of the relative rotation (same as rotation_distance)
    rot_dist = (end_rot * start_rot.inv()).magnitude()
    return pos_dist, rot_dist
20
+
21
class PoseTrajectoryInterpolator:
    """Time-parameterized 6D pose trajectory [pos(3), rotvec(3)].

    Positions are linearly interpolated; rotations use Slerp. Calling the
    instance with time(s) returns the pose(s); queries outside the stored
    time range are clamped to the endpoints.
    """

    def __init__(self, times: np.ndarray, poses: np.ndarray):
        # times: (N,) non-decreasing timestamps; poses: (N, 6)
        assert len(times) >= 1
        assert len(poses) == len(times)
        if not isinstance(times, np.ndarray):
            times = np.array(times)
        if not isinstance(poses, np.ndarray):
            poses = np.array(poses)

        if len(times) == 1:
            # special treatment for single step interpolation
            self.single_step = True
            self._times = times
            self._poses = poses
        else:
            self.single_step = False
            assert np.all(times[1:] >= times[:-1])

            pos = poses[:,:3]
            rot = st.Rotation.from_rotvec(poses[:,3:])

            self.pos_interp = si.interp1d(times, pos,
                axis=0, assume_sorted=True)
            self.rot_interp = st.Slerp(times, rot)

    @property
    def times(self) -> np.ndarray:
        """Knot timestamps of the trajectory."""
        if self.single_step:
            return self._times
        else:
            return self.pos_interp.x

    @property
    def poses(self) -> np.ndarray:
        """Knot poses as an (N, 6) array [pos, rotvec]."""
        if self.single_step:
            return self._poses
        else:
            n = len(self.times)
            poses = np.zeros((n, 6))
            poses[:,:3] = self.pos_interp.y
            poses[:,3:] = self.rot_interp(self.times).as_rotvec()
            return poses

    def trim(self,
            start_t: float, end_t: float
            ) -> "PoseTrajectoryInterpolator":
        """Return a new interpolator restricted to [start_t, end_t],
        with interpolated poses at the new endpoints."""
        assert start_t <= end_t
        times = self.times
        # keep only interior knots; endpoints are re-sampled below
        should_keep = (start_t < times) & (times < end_t)
        keep_times = times[should_keep]
        all_times = np.concatenate([[start_t], keep_times, [end_t]])
        # remove duplicates, Slerp requires strictly increasing x
        all_times = np.unique(all_times)
        # interpolate
        all_poses = self(all_times)
        return PoseTrajectoryInterpolator(times=all_times, poses=all_poses)

    def drive_to_waypoint(self,
            pose, time, curr_time,
            max_pos_speed=np.inf,
            max_rot_speed=np.inf
        ) -> "PoseTrajectoryInterpolator":
        """Replace everything after curr_time with a single move to *pose*,
        arriving no earlier than speed limits allow."""
        assert(max_pos_speed > 0)
        assert(max_rot_speed > 0)
        time = max(time, curr_time)

        curr_pose = self(curr_time)
        pos_dist, rot_dist = pose_distance(curr_pose, pose)
        # stretch the duration so neither speed limit is exceeded
        pos_min_duration = pos_dist / max_pos_speed
        rot_min_duration = rot_dist / max_rot_speed
        duration = time - curr_time
        duration = max(duration, max(pos_min_duration, rot_min_duration))
        assert duration >= 0
        last_waypoint_time = curr_time + duration

        # insert new pose
        trimmed_interp = self.trim(curr_time, curr_time)
        times = np.append(trimmed_interp.times, [last_waypoint_time], axis=0)
        poses = np.append(trimmed_interp.poses, [pose], axis=0)

        # create new interpolator
        final_interp = PoseTrajectoryInterpolator(times, poses)
        return final_interp

    def schedule_waypoint(self,
            pose, time,
            max_pos_speed=np.inf,
            max_rot_speed=np.inf,
            curr_time=None,
            last_waypoint_time=None
        ) -> "PoseTrajectoryInterpolator":
        """Append waypoint *pose* at *time*, keeping the already-scheduled
        portion of the trajectory and respecting the speed limits.

        Returns self unchanged if *time* is not in the future relative to
        curr_time.
        """
        assert(max_pos_speed > 0)
        assert(max_rot_speed > 0)
        if last_waypoint_time is not None:
            assert curr_time is not None

        # trim current interpolator to between curr_time and last_waypoint_time
        start_time = self.times[0]
        end_time = self.times[-1]
        assert start_time <= end_time

        if curr_time is not None:
            if time <= curr_time:
                # if insert time is earlier than current time
                # no effect should be done to the interpolator
                return self
            # now, curr_time < time
            start_time = max(curr_time, start_time)

            if last_waypoint_time is not None:
                # if last_waypoint_time is earlier than start_time
                # use start_time
                if time <= last_waypoint_time:
                    end_time = curr_time
                else:
                    end_time = max(last_waypoint_time, curr_time)
            else:
                end_time = curr_time

        end_time = min(end_time, time)
        start_time = min(start_time, end_time)
        # end time should be the latest of all times except time
        # after this we can assume order (proven by zhenjia, due to the 2 min operations)

        # Constraints:
        # start_time <= end_time <= time (proven by zhenjia)
        # curr_time <= start_time (proven by zhenjia)
        # curr_time <= time (proven by zhenjia)

        # time can't change
        # last_waypoint_time can't change
        # curr_time can't change
        assert start_time <= end_time
        assert end_time <= time
        if last_waypoint_time is not None:
            if time <= last_waypoint_time:
                assert end_time == curr_time
            else:
                assert end_time == max(last_waypoint_time, curr_time)

        if curr_time is not None:
            assert curr_time <= start_time
            assert curr_time <= time

        trimmed_interp = self.trim(start_time, end_time)
        # after this, all waypoints in trimmed_interp is within start_time and end_time
        # and is earlier than time

        # determine speed
        duration = time - end_time
        end_pose = trimmed_interp(end_time)
        pos_dist, rot_dist = pose_distance(pose, end_pose)
        # stretch the arrival so neither speed limit is exceeded
        pos_min_duration = pos_dist / max_pos_speed
        rot_min_duration = rot_dist / max_rot_speed
        duration = max(duration, max(pos_min_duration, rot_min_duration))
        assert duration >= 0
        last_waypoint_time = end_time + duration

        # insert new pose
        times = np.append(trimmed_interp.times, [last_waypoint_time], axis=0)
        poses = np.append(trimmed_interp.poses, [pose], axis=0)

        # create new interpolator
        final_interp = PoseTrajectoryInterpolator(times, poses)
        return final_interp


    def __call__(self, t: Union[numbers.Number, np.ndarray]) -> np.ndarray:
        """Evaluate the trajectory at time(s) t; scalar in -> (6,) pose out,
        array in -> (len(t), 6) poses out. Times are clamped to the range."""
        is_single = False
        if isinstance(t, numbers.Number):
            is_single = True
            t = np.array([t])

        pose = np.zeros((len(t), 6))
        if self.single_step:
            pose[:] = self._poses[0]
        else:
            start_time = self.times[0]
            end_time = self.times[-1]
            t = np.clip(t, start_time, end_time)

            pose = np.zeros((len(t), 6))
            pose[:,:3] = self.pos_interp(t)
            pose[:,3:] = self.rot_interp(t).as_rotvec()

        if is_single:
            pose = pose[0]
        return pose
equidiff/equi_diffpo/common/precise_sleep.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
def precise_sleep(dt: float, slack_time: float=0.001, time_func=time.monotonic):
    """
    Sleep for dt seconds with low jitter.

    Coarse time.sleep covers all but the final slack_time; the remainder
    is spun away to avoid scheduler wake-up jitter.
    """
    deadline = time_func() + dt
    if dt > slack_time:
        # coarse sleep; may overshoot by scheduler granularity, hence slack
        time.sleep(dt - slack_time)
    while time_func() < deadline:
        pass
15
+
16
def precise_wait(t_end: float, slack_time: float=0.001, time_func=time.monotonic):
    """
    Block until time_func() reaches t_end with low jitter: coarse sleep
    until slack_time before the deadline, then spin. Returns immediately
    if t_end is already in the past.
    """
    remaining = t_end - time_func()
    if remaining > 0:
        coarse = remaining - slack_time
        if coarse > 0:
            time.sleep(coarse)
    while time_func() < t_end:
        pass
equidiff/equi_diffpo/common/pymunk_override.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ----------------------------------------------------------------------------
2
+ # pymunk
3
+ # Copyright (c) 2007-2016 Victor Blomqvist
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in
13
+ # all copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+ # ----------------------------------------------------------------------------
23
+
24
+ """This submodule contains helper functions to help with quick prototyping
25
+ using pymunk together with pygame.
26
+
27
+ Intended to help with debugging and prototyping, not for actual production use
28
+ in a full application. The methods contained in this module is opinionated
29
+ about your coordinate system and not in any way optimized.
30
+ """
31
+
32
+ __docformat__ = "reStructuredText"
33
+
34
+ __all__ = [
35
+ "DrawOptions",
36
+ "get_mouse_pos",
37
+ "to_pygame",
38
+ "from_pygame",
39
+ "lighten",
40
+ "positive_y_is_up",
41
+ ]
42
+
43
+ from typing import List, Sequence, Tuple
44
+
45
+ import pygame
46
+
47
+ import numpy as np
48
+
49
+ import pymunk
50
+ from pymunk.space_debug_draw_options import SpaceDebugColor
51
+ from pymunk.vec2d import Vec2d
52
+
53
+ positive_y_is_up: bool = False
54
+ """Make increasing values of y point upwards.
55
+
56
+ When True::
57
+
58
+ y
59
+ ^
60
+ | . (3, 3)
61
+ |
62
+ | . (2, 2)
63
+ |
64
+ +------ > x
65
+
66
+ When False::
67
+
68
+ +------ > x
69
+ |
70
+ | . (2, 2)
71
+ |
72
+ | . (3, 3)
73
+ v
74
+ y
75
+
76
+ """
77
+
78
+
79
class DrawOptions(pymunk.SpaceDebugDrawOptions):
    def __init__(self, surface: pygame.Surface) -> None:
        """Draw a pymunk.Space on a pygame.Surface object.

        Typical usage::

        >>> import pymunk
        >>> surface = pygame.Surface((10,10))
        >>> space = pymunk.Space()
        >>> options = pymunk.pygame_util.DrawOptions(surface)
        >>> space.debug_draw(options)

        You can control the color of a shape by setting shape.color to the color
        you want it drawn in::

        >>> c = pymunk.Circle(None, 10)
        >>> c.color = pygame.Color("pink")

        See pygame_util.demo.py for a full example

        Since pygame uses a coordinate system where y points down (in contrast
        to many other cases), you either have to make the physics simulation
        with Pymunk also behave in that way, or flip everything when you draw.

        The easiest is probably to just make the simulation behave the same
        way as Pygame does. In that way all coordinates used are in the same
        orientation and easy to reason about::

        >>> space = pymunk.Space()
        >>> space.gravity = (0, -1000)
        >>> body = pymunk.Body()
        >>> body.position = (0, 0) # will be positioned in the top left corner
        >>> space.debug_draw(options)

        To flip the drawing its possible to set the module property
        :py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
        the simulation upside down before drawing::

        >>> positive_y_is_up = True
        >>> body = pymunk.Body()
        >>> body.position = (0, 0)
        >>> # Body will be position in bottom left corner

        :Parameters:
                surface : pygame.Surface
                    Surface that the objects will be drawn on
        """
        self.surface = surface
        super(DrawOptions, self).__init__()

    def draw_circle(
        self,
        pos: Vec2d,
        angle: float,
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        # Draw a filled circle with a lighter inner disc (4px inset) as a
        # simple shading effect. The orientation line is disabled below.
        p = to_pygame(pos, self.surface)

        pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
        pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius-4), 0)

        # circle_edge/p2 would mark the body's rotation; kept for reference
        # but intentionally not drawn (see commented line).
        circle_edge = pos + Vec2d(radius, 0).rotated(angle)
        p2 = to_pygame(circle_edge, self.surface)
        line_r = 2 if radius > 20 else 1
        # pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)

    def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
        # Thin anti-aliased line between two pymunk points.
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])

    def draw_fat_segment(
        self,
        a: Tuple[float, float],
        b: Tuple[float, float],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        # Draw a thick segment as a center line plus (for r > 2) a filled
        # quad offset perpendicular to the segment and rounded end caps.
        p1 = to_pygame(a, self.surface)
        p2 = to_pygame(b, self.surface)

        r = round(max(1, radius * 2))
        pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
        if r > 2:
            # perpendicular offset vector, scaled to length `radius`
            # NOTE(review): abs() on the components means the quad is only
            # truly perpendicular in some quadrants — existing behavior kept.
            orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
            if orthog[0] == 0 and orthog[1] == 0:
                return
            scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
            orthog[0] = round(orthog[0] * scale)
            orthog[1] = round(orthog[1] * scale)
            points = [
                (p1[0] - orthog[0], p1[1] - orthog[1]),
                (p1[0] + orthog[0], p1[1] + orthog[1]),
                (p2[0] + orthog[0], p2[1] + orthog[1]),
                (p2[0] - orthog[0], p2[1] - orthog[1]),
            ]
            pygame.draw.polygon(self.surface, fill_color.as_int(), points)
            # round caps at both ends
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p1[0]), round(p1[1])),
                round(radius),
            )
            pygame.draw.circle(
                self.surface,
                fill_color.as_int(),
                (round(p2[0]), round(p2[1])),
                round(radius),
            )

    def draw_polygon(
        self,
        verts: Sequence[Tuple[float, float]],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        # Fill the polygon in a lightened color, then trace its edges with
        # fat segments to give rounded corners.
        ps = [to_pygame(v, self.surface) for v in verts]
        ps += [ps[0]]

        # NOTE(review): the incoming radius parameter is deliberately
        # overridden to a fixed edge thickness of 2 — existing behavior kept.
        radius = 2
        pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)

        if radius > 0:
            for i in range(len(verts)):
                a = verts[i]
                b = verts[(i + 1) % len(verts)]
                self.draw_fat_segment(a, b, radius, fill_color, fill_color)

    def draw_dot(
        self, size: float, pos: Tuple[float, float], color: SpaceDebugColor
    ) -> None:
        # Filled dot of the given pixel size at a pymunk position.
        p = to_pygame(pos, self.surface)
        pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
217
+
218
+
219
def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
    """Get position of the mouse pointer in pymunk coordinates."""
    return from_pygame(pygame.mouse.get_pos(), surface)
223
+
224
+
225
def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
    """Convenience method to convert pymunk coordinates to pygame surface
    local coordinates.

    Note that in case positive_y_is_up is False, this function won't actually do
    anything except converting the point to integers.
    """
    x = round(p[0])
    y = round(p[1])
    if positive_y_is_up:
        # flip the y axis so larger y draws higher up on screen
        y = surface.get_height() - y
    return x, y
236
+
237
+
238
def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
    """Convenience method to convert pygame surface local coordinates to
    pymunk coordinates.

    The y-flip mapping is its own inverse, so this simply delegates to
    to_pygame.
    """
    return to_pygame(p, surface)
243
+
244
+
245
def light_color(color: SpaceDebugColor):
    """Return a brightened copy of ``color``: each RGBA channel scaled by
    1.2 and clipped to 255."""
    rgba = np.float32([color.r, color.g, color.b, color.a])
    brightened = np.minimum(1.2 * rgba, np.float32([255]))
    return SpaceDebugColor(
        r=brightened[0], g=brightened[1], b=brightened[2], a=brightened[3])
equidiff/equi_diffpo/common/pymunk_util.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pygame
2
+ import pymunk
3
+ import pymunk.pygame_util
4
+ import numpy as np
5
+
6
+ COLLTYPE_DEFAULT = 0
7
+ COLLTYPE_MOUSE = 1
8
+ COLLTYPE_BALL = 2
9
+
10
def get_body_type(static=False):
    """Map the ``static`` flag to the corresponding pymunk body type."""
    return pymunk.Body.STATIC if static else pymunk.Body.DYNAMIC
15
+
16
+
17
def create_rectangle(space,
        pos_x,pos_y,width,height,
        density=3,static=False):
    """Add an axis-aligned box to ``space`` and return (body, shape).

    The box is centered at (pos_x, pos_y); ``static`` selects a STATIC
    body instead of the default DYNAMIC one.
    """
    box_body = pymunk.Body(body_type=get_body_type(static))
    box_body.position = (pos_x, pos_y)
    box_shape = pymunk.Poly.create_box(box_body, (width, height))
    box_shape.density = density
    space.add(box_body, box_shape)
    return box_body, box_shape
26
+
27
+
28
def create_rectangle_bb(space,
        left, bottom, right, top,
        **kwargs):
    """Create a rectangle from bounding-box edges instead of center+size.

    Extra keyword arguments are forwarded to create_rectangle.
    """
    center_x = 0.5 * (left + right)
    center_y = 0.5 * (top + bottom)
    return create_rectangle(
        space, center_x, center_y, right - left, top - bottom, **kwargs)
36
+
37
def create_circle(space, pos_x, pos_y, radius, density=3, static=False):
    """Add a circle body/shape (collision type COLLTYPE_BALL) to ``space``
    and return (body, shape)."""
    ball_body = pymunk.Body(body_type=get_body_type(static))
    ball_body.position = (pos_x, pos_y)
    ball_shape = pymunk.Circle(ball_body, radius=radius)
    ball_shape.density = density
    ball_shape.collision_type = COLLTYPE_BALL
    space.add(ball_body, ball_shape)
    return ball_body, ball_shape
45
+
46
def get_body_state(body):
    """Pack a body's kinematic state into a float32 vector of length 6:
    [x, y, angle, vx, vy, angular_velocity]."""
    x, y = body.position
    vx, vy = body.velocity
    return np.array(
        [x, y, body.angle, vx, vy, body.angular_velocity],
        dtype=np.float32)
equidiff/equi_diffpo/common/pytorch_util.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Callable, List
2
+ import collections
3
+ import torch
4
+ import torch.nn as nn
5
+
6
def dict_apply(
        x: Dict[str, torch.Tensor],
        func: Callable[[torch.Tensor], torch.Tensor]
        ) -> Dict[str, torch.Tensor]:
    """Apply ``func`` to every leaf value of a (possibly nested) dict,
    recursing into sub-dicts and returning a new dict of the same shape."""
    return {
        key: dict_apply(value, func) if isinstance(value, dict) else func(value)
        for key, value in x.items()
    }
+
18
def pad_remaining_dims(x, target):
    """Right-pad ``x``'s shape with singleton dims until it matches the rank
    of ``target``. ``x.shape`` must be a prefix of ``target.shape``."""
    ndim_x = len(x.shape)
    assert x.shape == target.shape[:ndim_x]
    trailing = (1,) * (len(target.shape) - ndim_x)
    return x.reshape(x.shape + trailing)
21
+
22
def dict_apply_split(
        x: Dict[str, torch.Tensor],
        split_func: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
        ) -> Dict[str, torch.Tensor]:
    """Split every value of ``x`` into named parts via ``split_func`` and
    regroup the parts by part name: result[part][key] = part_of_value."""
    grouped = collections.defaultdict(dict)
    for outer_key, value in x.items():
        for part_name, part in split_func(value).items():
            grouped[part_name][outer_key] = part
    return grouped
32
+
33
def dict_apply_reduce(
        x: List[Dict[str, torch.Tensor]],
        reduce_func: Callable[[List[torch.Tensor]], torch.Tensor]
        ) -> Dict[str, torch.Tensor]:
    """Key-wise reduction over a list of dicts: for each key of ``x[0]``,
    collect that key's values from every dict and reduce them with
    ``reduce_func``."""
    return {key: reduce_func([d[key] for d in x]) for key in x[0].keys()}
41
+
42
+
43
def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Replace, in place, every submodule of ``root_module`` for which
    ``predicate`` is True with ``func(submodule)``.

    predicate: Return true if the module is to be replaced.
    func: Return new module to use.
    Returns root_module (or func(root_module) if the root itself matches).
    """
    if predicate(root_module):
        return func(root_module)

    def _matching_paths():
        # dotted paths (split into components) of all matching submodules
        return [name.split('.') for name, mod
                in root_module.named_modules(remove_duplicate=True)
                if predicate(mod)]

    for *parent_path, leaf in _matching_paths():
        owner = root_module
        if len(parent_path) > 0:
            owner = root_module.get_submodule('.'.join(parent_path))
        # nn.Sequential children are addressed by integer index, not attribute
        if isinstance(owner, nn.Sequential):
            owner[int(leaf)] = func(owner[int(leaf)])
        else:
            setattr(owner, leaf, func(getattr(owner, leaf)))

    # verify every matching module has been replaced
    assert len(_matching_paths()) == 0
    return root_module
76
+
77
def optimizer_to(optimizer, device):
    """Move every tensor held in the optimizer's state to ``device``
    (in place) and return the optimizer."""
    for per_param_state in optimizer.state.values():
        for name, value in per_param_state.items():
            if isinstance(value, torch.Tensor):
                per_param_state[name] = value.to(device=device)
    return optimizer
equidiff/equi_diffpo/common/replay_buffer.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Dict, Optional
2
+ import os
3
+ import math
4
+ import numbers
5
+ import zarr
6
+ import numcodecs
7
+ import numpy as np
8
+ from functools import cached_property
9
+
10
def check_chunks_compatible(chunks: tuple, shape: tuple):
    """Validate that ``chunks`` supplies one positive integer per dimension
    of ``shape``; raises AssertionError otherwise."""
    assert len(shape) == len(chunks)
    for size in chunks:
        assert isinstance(size, numbers.Integral)
        assert size > 0
15
+
16
def rechunk_recompress_array(group, name,
        chunks=None, chunk_length=None,
        compressor=None, tmp_key='_temp'):
    """Rewrite the zarr array ``group[name]`` with new chunking and/or
    compression, in place within the group.

    chunks: full target chunk tuple; if None, ``chunk_length`` (if given)
        replaces only the first (time) chunk dimension, otherwise the old
        chunking is kept.
    compressor: target compressor; None keeps the old one.
    tmp_key: temporary key the old array is moved to during the copy.
    Returns the (possibly new) array.
    """
    old_arr = group[name]
    if chunks is None:
        if chunk_length is not None:
            # only re-chunk along the first (time) axis
            chunks = (chunk_length,) + old_arr.chunks[1:]
        else:
            chunks = old_arr.chunks
    check_chunks_compatible(chunks, old_arr.shape)

    if compressor is None:
        compressor = old_arr.compressor

    if (chunks == old_arr.chunks) and (compressor == old_arr.compressor):
        # no change
        return old_arr

    # rechunk recompress: move the original aside, copy back with the new
    # chunks/compressor, then drop the temporary
    group.move(name, tmp_key)
    old_arr = group[tmp_key]
    n_copied, n_skipped, n_bytes_copied = zarr.copy(
        source=old_arr,
        dest=group,
        name=name,
        chunks=chunks,
        compressor=compressor,
    )
    del group[tmp_key]
    arr = group[name]
    return arr
47
+
48
def get_optimal_chunks(shape, dtype,
        target_chunk_bytes=2e6,
        max_chunk_length=None):
    """
    Choose a chunk shape whose byte size is close to ``target_chunk_bytes``,
    chunking preferentially along the leading (time) dimension.

    Common shapes
    T,D
    T,N,D
    T,H,W,C
    T,N,H,W,C
    """
    itemsize = np.dtype(dtype).itemsize
    # work on the reversed shape so the time axis comes last
    rev = list(shape[::-1])
    if max_chunk_length is not None:
        rev[-1] = int(max_chunk_length)

    # largest prefix of reversed dims that still fits the byte budget
    split = len(shape) - 1
    for i in range(len(shape) - 1):
        bytes_here = itemsize * np.prod(rev[:i])
        bytes_next = itemsize * np.prod(rev[:i + 1])
        if bytes_here <= target_chunk_bytes \
                and bytes_next > target_chunk_bytes:
            split = i

    rev_chunks = rev[:split]
    bytes_per_item = itemsize * np.prod(rev[:split])
    # fill the remaining budget along the split dimension
    rev_chunks.append(min(rev[split],
        math.ceil(target_chunk_bytes / bytes_per_item)))
    # any dims beyond the split get chunk size 1
    rev_chunks.extend([1] * (len(shape) - len(rev_chunks)))
    return tuple(rev_chunks[::-1])
82
+
83
+
84
class ReplayBuffer:
    """
    Zarr-based temporal datastructure.
    Assumes first dimension to be time. Only chunk in time dimension.

    Backed either by a zarr.Group (on disk or in-memory compressed) or by a
    plain nested dict of numpy arrays; ``self.backend`` reports which.
    """
    def __init__(self,
            root: Union[zarr.Group,
                Dict[str,dict]]):
        """
        Dummy constructor. Use copy_from* and create_from* class methods instead.
        """
        assert('data' in root)
        assert('meta' in root)
        assert('episode_ends' in root['meta'])
        # every data array must cover exactly the recorded number of steps
        for key, value in root['data'].items():
            assert(value.shape[0] == root['meta']['episode_ends'][-1])
        self.root = root

    # ============= create constructors ===============
    @classmethod
    def create_empty_zarr(cls, storage=None, root=None):
        """Create an empty zarr-backed buffer (in-memory store by default)."""
        if root is None:
            if storage is None:
                storage = zarr.MemoryStore()
            root = zarr.group(store=storage)
        data = root.require_group('data', overwrite=False)
        meta = root.require_group('meta', overwrite=False)
        if 'episode_ends' not in meta:
            # episode_ends is tiny and frequently resized: no compression
            meta.zeros('episode_ends', shape=(0,), dtype=np.int64,
                compressor=None, overwrite=False)
        return cls(root=root)

    @classmethod
    def create_empty_numpy(cls):
        """Create an empty numpy-dict-backed buffer."""
        root = {
            'data': dict(),
            'meta': {
                'episode_ends': np.zeros((0,), dtype=np.int64)
            }
        }
        return cls(root=root)

    @classmethod
    def create_from_group(cls, group, **kwargs):
        """Wrap an existing zarr group, initializing it if empty."""
        if 'data' not in group:
            # create from scratch
            buffer = cls.create_empty_zarr(root=group, **kwargs)
        else:
            # already exist
            buffer = cls(root=group, **kwargs)
        return buffer

    @classmethod
    def create_from_path(cls, zarr_path, mode='r', **kwargs):
        """
        Open a on-disk zarr directly (for dataset larger than memory).
        Slower.
        """
        group = zarr.open(os.path.expanduser(zarr_path), mode)
        return cls.create_from_group(group, **kwargs)

    # ============= copy constructors ===============
    @classmethod
    def copy_from_store(cls, src_store, store=None, keys=None,
            chunks: Dict[str,tuple]=dict(),
            compressors: Union[dict, str, numcodecs.abc.Codec]=dict(),
            if_exists='replace',
            **kwargs):
        """
        Load to memory. With store=None the data is materialized as numpy
        arrays; otherwise it is copied into the given zarr store,
        re-chunking/re-compressing only where requested.
        """
        src_root = zarr.group(src_store)
        root = None
        if store is None:
            # numpy backend
            meta = dict()
            for key, value in src_root['meta'].items():
                if len(value.shape) == 0:
                    meta[key] = np.array(value)
                else:
                    meta[key] = value[:]

            if keys is None:
                keys = src_root['data'].keys()
            data = dict()
            for key in keys:
                arr = src_root['data'][key]
                data[key] = arr[:]

            root = {
                'meta': meta,
                'data': data
            }
        else:
            root = zarr.group(store=store)
            # copy without recompression
            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(source=src_store, dest=store,
                source_path='/meta', dest_path='/meta', if_exists=if_exists)
            data_group = root.create_group('data', overwrite=True)
            if keys is None:
                keys = src_root['data'].keys()
            for key in keys:
                value = src_root['data'][key]
                cks = cls._resolve_array_chunks(
                    chunks=chunks, key=key, array=value)
                cpr = cls._resolve_array_compressor(
                    compressors=compressors, key=key, array=value)
                if cks == value.chunks and cpr == value.compressor:
                    # identical layout: raw byte copy, no recompression
                    this_path = '/data/' + key
                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                        source=src_store, dest=store,
                        source_path=this_path, dest_path=this_path,
                        if_exists=if_exists
                    )
                else:
                    # copy with recompression
                    n_copied, n_skipped, n_bytes_copied = zarr.copy(
                        source=value, dest=data_group, name=key,
                        chunks=cks, compressor=cpr, if_exists=if_exists
                    )
        buffer = cls(root=root)
        return buffer

    @classmethod
    def copy_from_path(cls, zarr_path, backend=None, store=None, keys=None,
            chunks: Dict[str,tuple]=dict(),
            compressors: Union[dict, str, numcodecs.abc.Codec]=dict(),
            if_exists='replace',
            **kwargs):
        """
        Copy a on-disk zarr to in-memory compressed.
        Recommended
        """
        if backend == 'numpy':
            print('backend argument is deprecated!')
            store = None
        group = zarr.open(os.path.expanduser(zarr_path), 'r')
        return cls.copy_from_store(src_store=group.store, store=store,
            keys=keys, chunks=chunks, compressors=compressors,
            if_exists=if_exists, **kwargs)

    # ============= save methods ===============
    def save_to_store(self, store,
            chunks: Optional[Dict[str,tuple]]=dict(),
            compressors: Union[str, numcodecs.abc.Codec, dict]=dict(),
            if_exists='replace',
            **kwargs):
        """Persist the buffer into a zarr store; avoids recompression when
        source and target layouts match."""
        root = zarr.group(store)
        if self.backend == 'zarr':
            # recompression free copy
            n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                source=self.root.store, dest=store,
                source_path='/meta', dest_path='/meta', if_exists=if_exists)
        else:
            meta_group = root.create_group('meta', overwrite=True)
            # save meta, no chunking
            for key, value in self.root['meta'].items():
                _ = meta_group.array(
                    name=key,
                    data=value,
                    shape=value.shape,
                    chunks=value.shape)

        # save data, chunk
        data_group = root.create_group('data', overwrite=True)
        for key, value in self.root['data'].items():
            cks = self._resolve_array_chunks(
                chunks=chunks, key=key, array=value)
            cpr = self._resolve_array_compressor(
                compressors=compressors, key=key, array=value)
            if isinstance(value, zarr.Array):
                if cks == value.chunks and cpr == value.compressor:
                    # copy without recompression
                    this_path = '/data/' + key
                    n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
                        source=self.root.store, dest=store,
                        source_path=this_path, dest_path=this_path, if_exists=if_exists)
                else:
                    # copy with recompression
                    n_copied, n_skipped, n_bytes_copied = zarr.copy(
                        source=value, dest=data_group, name=key,
                        chunks=cks, compressor=cpr, if_exists=if_exists
                    )
            else:
                # numpy
                _ = data_group.array(
                    name=key,
                    data=value,
                    chunks=cks,
                    compressor=cpr
                )
        return store

    def save_to_path(self, zarr_path,
            chunks: Optional[Dict[str,tuple]]=dict(),
            compressors: Union[str, numcodecs.abc.Codec, dict]=dict(),
            if_exists='replace',
            **kwargs):
        """Persist the buffer to a zarr directory on disk."""
        store = zarr.DirectoryStore(os.path.expanduser(zarr_path))
        return self.save_to_store(store, chunks=chunks,
            compressors=compressors, if_exists=if_exists, **kwargs)

    @staticmethod
    def resolve_compressor(compressor='default'):
        """Map the string presets 'default'/'disk' to concrete Blosc codecs;
        pass through anything else (including None) unchanged."""
        if compressor == 'default':
            compressor = numcodecs.Blosc(cname='lz4', clevel=5,
                shuffle=numcodecs.Blosc.NOSHUFFLE)
        elif compressor == 'disk':
            compressor = numcodecs.Blosc('zstd', clevel=5,
                shuffle=numcodecs.Blosc.BITSHUFFLE)
        return compressor

    @classmethod
    def _resolve_array_compressor(cls,
            compressors: Union[dict, str, numcodecs.abc.Codec], key, array):
        """Pick the compressor for one array: per-key override, else the
        source array's compressor, else the default preset."""
        # allows compressor to be explicitly set to None ('nil' = unset)
        cpr = 'nil'
        if isinstance(compressors, dict):
            if key in compressors:
                cpr = cls.resolve_compressor(compressors[key])
            elif isinstance(array, zarr.Array):
                cpr = array.compressor
        else:
            cpr = cls.resolve_compressor(compressors)
        # backup default
        if cpr == 'nil':
            cpr = cls.resolve_compressor('default')
        return cpr

    @classmethod
    def _resolve_array_chunks(cls,
            chunks: Union[dict, tuple], key, array):
        """Pick the chunk tuple for one array: per-key override, else the
        source array's chunks, else a computed optimum; always validated."""
        cks = None
        if isinstance(chunks, dict):
            if key in chunks:
                cks = chunks[key]
            elif isinstance(array, zarr.Array):
                cks = array.chunks
        elif isinstance(chunks, tuple):
            cks = chunks
        else:
            raise TypeError(f"Unsupported chunks type {type(chunks)}")
        # backup default
        if cks is None:
            cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype)
        # check
        check_chunks_compatible(chunks=cks, shape=array.shape)
        return cks

    # ============= properties =================
    @cached_property
    def data(self):
        # 'data' group/dict: one time-indexed array per key
        return self.root['data']

    @cached_property
    def meta(self):
        # 'meta' group/dict: bookkeeping arrays (episode_ends, ...)
        return self.root['meta']

    def update_meta(self, data):
        """Write (and overwrite) metadata entries; values must be coercible
        to non-object numpy arrays."""
        # sanitize data
        np_data = dict()
        for key, value in data.items():
            if isinstance(value, np.ndarray):
                np_data[key] = value
            else:
                arr = np.array(value)
                if arr.dtype == object:
                    raise TypeError(f"Invalid value type {type(value)}")
                np_data[key] = arr

        meta_group = self.meta
        if self.backend == 'zarr':
            for key, value in np_data.items():
                _ = meta_group.array(
                    name=key,
                    data=value,
                    shape=value.shape,
                    chunks=value.shape,
                    overwrite=True)
        else:
            meta_group.update(np_data)

        return meta_group

    @property
    def episode_ends(self):
        # cumulative end index (exclusive) of each episode along time
        return self.meta['episode_ends']

    def get_episode_idxs(self):
        """Return an array mapping every timestep to its episode index."""
        import numba
        # BUG FIX: the jit call was previously a bare expression
        # (`numba.jit(nopython=True)` on its own line), so the returned
        # decorator was discarded and the helper ran un-jitted.
        @numba.jit(nopython=True)
        def _get_episode_idxs(episode_ends):
            result = np.zeros((episode_ends[-1],), dtype=np.int64)
            for i in range(len(episode_ends)):
                start = 0
                if i > 0:
                    start = episode_ends[i-1]
                end = episode_ends[i]
                for idx in range(start, end):
                    result[idx] = i
            return result
        # materialize to an ndarray so nopython mode also works with the
        # zarr backend (zarr.Array is not numba-compatible)
        return _get_episode_idxs(self.episode_ends[:])


    @property
    def backend(self):
        """'zarr' when backed by a zarr.Group, 'numpy' otherwise."""
        backend = 'numpy'
        if isinstance(self.root, zarr.Group):
            backend = 'zarr'
        return backend

    # =========== dict-like API ==============
    def __repr__(self) -> str:
        if self.backend == 'zarr':
            return str(self.root.tree())
        else:
            return super().__repr__()

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    def __getitem__(self, key):
        return self.data[key]

    def __contains__(self, key):
        return key in self.data

    # =========== our API ==============
    @property
    def n_steps(self):
        """Total number of timesteps stored across all episodes."""
        if len(self.episode_ends) == 0:
            return 0
        return self.episode_ends[-1]

    @property
    def n_episodes(self):
        return len(self.episode_ends)

    @property
    def chunk_size(self):
        """Time-dimension chunk length of the first data array (zarr only)."""
        if self.backend == 'zarr':
            return next(iter(self.data.arrays()))[-1].chunks[0]
        return None

    @property
    def episode_lengths(self):
        ends = self.episode_ends[:]
        ends = np.insert(ends, 0, 0)
        lengths = np.diff(ends)
        return lengths

    def add_episode(self,
            data: Dict[str, np.ndarray],
            chunks: Optional[Dict[str,tuple]]=dict(),
            compressors: Union[str, numcodecs.abc.Codec, dict]=dict()):
        """Append one episode; all value arrays must share the same length
        along the time (first) dimension."""
        assert(len(data) > 0)
        is_zarr = (self.backend == 'zarr')

        curr_len = self.n_steps
        episode_length = None
        for key, value in data.items():
            assert(len(value.shape) >= 1)
            if episode_length is None:
                episode_length = len(value)
            else:
                assert(episode_length == len(value))
        new_len = curr_len + episode_length

        for key, value in data.items():
            new_shape = (new_len,) + value.shape[1:]
            # create array
            if key not in self.data:
                if is_zarr:
                    cks = self._resolve_array_chunks(
                        chunks=chunks, key=key, array=value)
                    cpr = self._resolve_array_compressor(
                        compressors=compressors, key=key, array=value)
                    arr = self.data.zeros(name=key,
                        shape=new_shape,
                        chunks=cks,
                        dtype=value.dtype,
                        compressor=cpr)
                else:
                    # copy data to prevent modify
                    arr = np.zeros(shape=new_shape, dtype=value.dtype)
                    self.data[key] = arr
            else:
                arr = self.data[key]
                assert(value.shape[1:] == arr.shape[1:])
                # same method for both zarr and numpy
                if is_zarr:
                    arr.resize(new_shape)
                else:
                    arr.resize(new_shape, refcheck=False)
            # copy data
            arr[-value.shape[0]:] = value

        # append to episode ends
        episode_ends = self.episode_ends
        if is_zarr:
            episode_ends.resize(episode_ends.shape[0] + 1)
        else:
            episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False)
        episode_ends[-1] = new_len

        # rechunk episode_ends with headroom when it outgrows its chunk
        if is_zarr:
            if episode_ends.chunks[0] < episode_ends.shape[0]:
                rechunk_recompress_array(self.meta, 'episode_ends',
                    chunk_length=int(episode_ends.shape[0] * 1.5))

    def drop_episode(self):
        """Remove the last episode by truncating every data array."""
        is_zarr = (self.backend == 'zarr')
        episode_ends = self.episode_ends[:].copy()
        assert(len(episode_ends) > 0)
        start_idx = 0
        if len(episode_ends) > 1:
            start_idx = episode_ends[-2]
        for key, value in self.data.items():
            new_shape = (start_idx,) + value.shape[1:]
            if is_zarr:
                value.resize(new_shape)
            else:
                value.resize(new_shape, refcheck=False)
        if is_zarr:
            self.episode_ends.resize(len(episode_ends)-1)
        else:
            self.episode_ends.resize(len(episode_ends)-1, refcheck=False)

    def pop_episode(self):
        """Remove and return (a copy of) the last episode."""
        assert(self.n_episodes > 0)
        episode = self.get_episode(self.n_episodes-1, copy=True)
        self.drop_episode()
        return episode

    def extend(self, data):
        self.add_episode(data)

    def get_episode(self, idx, copy=False):
        """Return episode ``idx`` as a dict of per-key slices; supports
        negative indices."""
        idx = list(range(len(self.episode_ends)))[idx]
        start_idx = 0
        if idx > 0:
            start_idx = self.episode_ends[idx-1]
        end_idx = self.episode_ends[idx]
        result = self.get_steps_slice(start_idx, end_idx, copy=copy)
        return result

    def get_episode_slice(self, idx):
        """Return the time-dimension slice covering episode ``idx``."""
        start_idx = 0
        if idx > 0:
            start_idx = self.episode_ends[idx-1]
        end_idx = self.episode_ends[idx]
        return slice(start_idx, end_idx)

    def get_steps_slice(self, start, stop, step=None, copy=False):
        """Return a dict of per-key time slices [start:stop:step]; ``copy``
        forces copies for numpy-backed arrays (zarr reads always copy)."""
        _slice = slice(start, stop, step)

        result = dict()
        for key, value in self.data.items():
            x = value[_slice]
            if copy and isinstance(value, np.ndarray):
                x = x.copy()
            result[key] = x
        return result

    # =========== chunking =============
    def get_chunks(self) -> dict:
        assert self.backend == 'zarr'
        chunks = dict()
        for key, value in self.data.items():
            chunks[key] = value.chunks
        return chunks

    def set_chunks(self, chunks: dict):
        """Rechunk selected data arrays in place (zarr only)."""
        assert self.backend == 'zarr'
        for key, value in chunks.items():
            if key in self.data:
                arr = self.data[key]
                if value != arr.chunks:
                    check_chunks_compatible(chunks=value, shape=arr.shape)
                    rechunk_recompress_array(self.data, key, chunks=value)

    def get_compressors(self) -> dict:
        assert self.backend == 'zarr'
        compressors = dict()
        for key, value in self.data.items():
            compressors[key] = value.compressor
        return compressors

    def set_compressors(self, compressors: dict):
        """Recompress selected data arrays in place (zarr only)."""
        assert self.backend == 'zarr'
        for key, value in compressors.items():
            if key in self.data:
                arr = self.data[key]
                compressor = self.resolve_compressor(value)
                if compressor != arr.compressor:
                    rechunk_recompress_array(self.data, key, compressor=compressor)
equidiff/equi_diffpo/common/sampler.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import numpy as np
3
+ import numba
4
+ from equi_diffpo.common.replay_buffer import ReplayBuffer
5
+
6
+
7
@numba.jit(nopython=True)
def create_indices(
    episode_ends:np.ndarray, sequence_length:int,
    episode_mask: np.ndarray,
    pad_before: int=0, pad_after: int=0,
    debug:bool=True) -> np.ndarray:
    """Build an (N, 4) index table of sequence windows over all episodes.

    Each row is (buffer_start_idx, buffer_end_idx,
    sample_start_idx, sample_end_idx): the [start, end) span in the flat
    data buffer and where that span lands inside a padded window of
    length sequence_length. Masked-out episodes contribute no rows.
    """
    # The original evaluated this comparison and discarded the result —
    # a no-op; make the shape invariant an actual assertion.
    assert episode_mask.shape == episode_ends.shape
    pad_before = min(max(pad_before, 0), sequence_length-1)
    pad_after = min(max(pad_after, 0), sequence_length-1)

    indices = list()
    for i in range(len(episode_ends)):
        if not episode_mask[i]:
            # skip episode
            continue
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start+1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            if debug:
                assert(start_offset >= 0)
                assert(end_offset >= 0)
                assert (sample_end_idx - sample_start_idx) == (buffer_end_idx - buffer_start_idx)
            indices.append([
                buffer_start_idx, buffer_end_idx,
                sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices
48
+
49
+
50
def get_val_mask(n_episodes, val_ratio, seed=0):
    """Return a boolean mask selecting validation episodes.

    When val_ratio > 0, guarantees at least one validation episode and
    at least one training episode.
    """
    mask = np.zeros(n_episodes, dtype=bool)
    if val_ratio <= 0:
        return mask
    # clamp into [1, n_episodes-1] so both splits stay non-empty
    n_val = min(max(1, round(n_episodes * val_ratio)), n_episodes-1)
    rng = np.random.default_rng(seed=seed)
    chosen = rng.choice(n_episodes, size=n_val, replace=False)
    mask[chosen] = True
    return mask
61
+
62
+
63
def downsample_mask(mask, max_n, seed=0):
    """Randomly keep at most max_n True entries of a boolean mask.

    Returns the input mask unchanged when max_n is None or the mask
    already has no more than max_n True entries.
    """
    if (max_n is None) or (np.sum(mask) <= max_n):
        return mask
    keep = int(max_n)
    true_positions = np.nonzero(mask)[0]
    rng = np.random.default_rng(seed=seed)
    picked = true_positions[
        rng.choice(len(true_positions), size=keep, replace=False)]
    out = np.zeros_like(mask)
    out[picked] = True
    assert np.sum(out) == keep
    return out
76
+
77
class SequenceSampler:
    """Samples fixed-length windows from a ReplayBuffer.

    Windows may extend past episode boundaries by up to pad_before /
    pad_after steps; out-of-episode positions are filled by repeating
    the episode's first or last frame in sample_sequence().
    """
    def __init__(self,
        replay_buffer: 'ReplayBuffer',
        sequence_length:int,
        pad_before:int=0,
        pad_after:int=0,
        keys=None,
        key_first_k=None,
        episode_mask: Optional[np.ndarray]=None,
        ):
        """
        key_first_k: dict str: int
            Only take first k data from these keys (to improve perf)
        """
        super().__init__()
        assert(sequence_length >= 1)
        # fix: the original used key_first_k=dict() — a mutable default
        # argument shared across instances
        if key_first_k is None:
            key_first_k = dict()
        if keys is None:
            keys = list(replay_buffer.keys())

        episode_ends = replay_buffer.episode_ends[:]
        if episode_mask is None:
            episode_mask = np.ones(episode_ends.shape, dtype=bool)

        if np.any(episode_mask):
            indices = create_indices(episode_ends,
                sequence_length=sequence_length,
                pad_before=pad_before,
                pad_after=pad_after,
                episode_mask=episode_mask
                )
        else:
            indices = np.zeros((0,4), dtype=np.int64)

        # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx)
        self.indices = indices
        self.keys = list(keys) # prevent OmegaConf list performance problem
        self.sequence_length = sequence_length
        self.replay_buffer = replay_buffer
        self.key_first_k = key_first_k

    def __len__(self):
        return len(self.indices)

    def sample_sequence(self, idx):
        """Return {key: (sequence_length, ...) array} for window `idx`,
        repeating the episode's first/last frame into padded regions."""
        buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx \
            = self.indices[idx]
        result = dict()
        for key in self.keys:
            input_arr = self.replay_buffer[key]
            # performance optimization, avoid small allocation if possible
            if key not in self.key_first_k:
                sample = input_arr[buffer_start_idx:buffer_end_idx]
            else:
                # performance optimization, only load used obs steps
                n_data = buffer_end_idx - buffer_start_idx
                k_data = min(self.key_first_k[key], n_data)
                # fill value with Nan to catch bugs
                # the non-loaded region should never be used
                sample = np.full((n_data,) + input_arr.shape[1:],
                    fill_value=np.nan, dtype=input_arr.dtype)
                # fix: the original wrapped this in try/except with
                # pdb.set_trace() — let failures surface normally
                sample[:k_data] = input_arr[buffer_start_idx:buffer_start_idx+k_data]
            data = sample
            if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length):
                data = np.zeros(
                    shape=(self.sequence_length,) + input_arr.shape[1:],
                    dtype=input_arr.dtype)
                if sample_start_idx > 0:
                    data[:sample_start_idx] = sample[0]
                if sample_end_idx < self.sequence_length:
                    data[sample_end_idx:] = sample[-1]
                data[sample_start_idx:sample_end_idx] = sample
            result[key] = data
        return result
equidiff/equi_diffpo/common/timestamp_accumulator.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Optional, Dict
2
+ import math
3
+ import numpy as np
4
+
5
+
6
def get_accumulate_timestamp_idxs(
    timestamps: List[float],
    start_time: float,
    dt: float,
    eps:float=1e-5,
    next_global_idx: Optional[int]=0,
    allow_negative=False
) -> Tuple[List[int], List[int], int]:
    """
    For each dt window, choose the first timestamp in the window.
    Assumes timestamps are sorted. One timestamp may be chosen multiple
    times when frames were dropped. next_global_idx should start at 0
    and then reuse the returned value; pass None when overwriting
    previously emitted indices is desired.

    Returns:
        local_idxs: indices into the given timestamps array
        global_idxs: the global index of each chosen timestamp
        next_global_idx: value to pass on the next call
    """
    local_idxs = list()
    global_idxs = list()
    for local_idx, ts in enumerate(timestamps):
        # eps absorbs floating point error so ts == start_time + k * dt
        # always lands in bucket k
        global_idx = math.floor((ts - start_time) / dt + eps)
        if global_idx < 0 and not allow_negative:
            continue
        if next_global_idx is None:
            next_global_idx = global_idx
        # a gap (dropped frames) makes one timestamp fill several buckets
        n_repeats = max(0, global_idx - next_global_idx + 1)
        local_idxs.extend([local_idx] * n_repeats)
        global_idxs.extend(range(next_global_idx, next_global_idx + n_repeats))
        next_global_idx += n_repeats
    return local_idxs, global_idxs, next_global_idx
42
+
43
+
44
def align_timestamps(
    timestamps: List[float],
    target_global_idxs: List[int],
    start_time: float,
    dt: float,
    eps:float=1e-5):
    """Map each target global index to a local index into `timestamps`.

    Extra matches beyond the targets are truncated; if the tail is
    missing, the last timestamp is repeated.

    Raises:
        ValueError: when no timestamp aligns with the target window at
            all (the original dropped into pdb here, then would have
            crashed on global_idxs[-1] anyway).
    """
    if isinstance(target_global_idxs, np.ndarray):
        target_global_idxs = target_global_idxs.tolist()
    assert len(target_global_idxs) > 0

    local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
        timestamps=timestamps,
        start_time=start_time,
        dt=dt,
        eps=eps,
        next_global_idx=target_global_idxs[0],
        allow_negative=True
    )
    if len(global_idxs) > len(target_global_idxs):
        # if more steps available, truncate
        global_idxs = global_idxs[:len(target_global_idxs)]
        local_idxs = local_idxs[:len(target_global_idxs)]

    if len(global_idxs) == 0:
        # fix: replaced `import pdb; pdb.set_trace()` with a real error
        raise ValueError(
            'no timestamps aligned with target_global_idxs')

    for i in range(len(target_global_idxs) - len(global_idxs)):
        # if missing, repeat
        local_idxs.append(len(timestamps)-1)
        global_idxs.append(global_idxs[-1] + 1)
    assert global_idxs == target_global_idxs
    assert len(local_idxs) == len(global_idxs)
    return local_idxs
77
+
78
+
79
class TimestampObsAccumulator:
    """Accumulates observation dicts into fixed-rate (dt-spaced) buffers.

    Incoming timestamps are bucketed to global indices via
    get_accumulate_timestamp_idxs; buffers are allocated lazily on the
    first put() and grown geometrically. Dropped frames are filled by
    repeating the most recent observation.
    """
    def __init__(self,
            start_time: float,
            dt: float,
            eps: float=1e-5):
        self.start_time = start_time
        self.dt = dt
        self.eps = eps
        # key -> (capacity, ...) array of accumulated observations
        self.obs_buffer = dict()
        # (capacity,) float64 array of raw source timestamps; None until
        # the first successful put()
        self.timestamp_buffer = None
        # next unfilled global index; doubles as the logical length
        self.next_global_idx = 0

    def __len__(self):
        # number of filled steps (monotonic: obs are never overwritten)
        return self.next_global_idx

    @property
    def data(self):
        """Accumulated observations, truncated to the filled length."""
        if self.timestamp_buffer is None:
            return dict()
        result = dict()
        for key, value in self.obs_buffer.items():
            result[key] = value[:len(self)]
        return result

    @property
    def actual_timestamps(self):
        """Raw source timestamps recorded for each filled step."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.timestamp_buffer[:len(self)]

    @property
    def timestamps(self):
        """Idealized evenly spaced timestamps: start_time + k * dt."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.start_time + np.arange(len(self)) * self.dt

    def put(self, data: Dict[str, np.ndarray], timestamps: np.ndarray):
        """
        data:
            key: T,*

        Bucket `timestamps` into dt-spaced global slots and write the
        matching rows of `data` into the internal buffers, growing them
        when needed.
        """
        local_idxs, global_idxs, self.next_global_idx = get_accumulate_timestamp_idxs(
            timestamps=timestamps,
            start_time=self.start_time,
            dt=self.dt,
            eps=self.eps,
            next_global_idx=self.next_global_idx
        )

        if len(global_idxs) > 0:
            if self.timestamp_buffer is None:
                # first allocation
                # NOTE(review): initial capacity equals the first batch's
                # length; later puts rely on the reallocation path below
                self.obs_buffer = dict()
                for key, value in data.items():
                    self.obs_buffer[key] = np.zeros_like(value)
                self.timestamp_buffer = np.zeros(
                    (len(timestamps),), dtype=np.float64)

            this_max_size = global_idxs[-1] + 1
            if this_max_size > len(self.timestamp_buffer):
                # reallocate with geometric (2x) growth to amortize copies
                new_size = max(this_max_size, len(self.timestamp_buffer) * 2)
                for key in list(self.obs_buffer.keys()):
                    new_shape = (new_size,) + self.obs_buffer[key].shape[1:]
                    # NOTE(review): np.resize pads by repeating existing
                    # data; padding lies beyond len(self) and accessors
                    # truncate to len(self), so it should never be read —
                    # confirm
                    self.obs_buffer[key] = np.resize(self.obs_buffer[key], new_shape)
                self.timestamp_buffer = np.resize(self.timestamp_buffer, (new_size))

            # write data
            for key, value in self.obs_buffer.items():
                value[global_idxs] = data[key][local_idxs]
            self.timestamp_buffer[global_idxs] = timestamps[local_idxs]
151
+
152
+
153
class TimestampActionAccumulator:
    """Accumulates action arrays into fixed-rate (dt-spaced) buffers.

    Unlike TimestampObsAccumulator, indices are computed with
    next_global_idx=None on every put(), so later calls may overwrite
    previously recorded actions.
    """
    def __init__(self,
            start_time: float,
            dt: float,
            eps: float=1e-5):
        """
        Different from Obs accumulator, the action accumulator
        allows overwriting previous values.
        """
        self.start_time = start_time
        self.dt = dt
        self.eps = eps
        # (capacity, ...) array of actions; allocated lazily on first put()
        self.action_buffer = None
        # (capacity,) float64 array of raw source timestamps
        self.timestamp_buffer = None
        # logical number of valid steps (highest global index written + 1)
        self.size = 0

    def __len__(self):
        return self.size

    @property
    def actions(self):
        """Accumulated actions, truncated to the filled length."""
        if self.action_buffer is None:
            return np.array([])
        return self.action_buffer[:len(self)]

    @property
    def actual_timestamps(self):
        """Raw source timestamps recorded for each filled step."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.timestamp_buffer[:len(self)]

    @property
    def timestamps(self):
        """Idealized evenly spaced timestamps: start_time + k * dt."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.start_time + np.arange(len(self)) * self.dt

    def put(self, actions: np.ndarray, timestamps: np.ndarray):
        """
        Note: timestamps is the time when the action will be issued,
        not when the action will be completed (target_timestamp)
        """
        local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
            timestamps=timestamps,
            start_time=self.start_time,
            dt=self.dt,
            eps=self.eps,
            # allows overwriting previous actions
            next_global_idx=None
        )

        if len(global_idxs) > 0:
            if self.timestamp_buffer is None:
                # first allocation
                # NOTE(review): initial capacity equals the first batch's
                # length; later puts rely on the reallocation path below
                self.action_buffer = np.zeros_like(actions)
                self.timestamp_buffer = np.zeros((len(actions),), dtype=np.float64)

            this_max_size = global_idxs[-1] + 1
            if this_max_size > len(self.timestamp_buffer):
                # reallocate with geometric (2x) growth to amortize copies
                new_size = max(this_max_size, len(self.timestamp_buffer) * 2)
                new_shape = (new_size,) + self.action_buffer.shape[1:]
                # NOTE(review): np.resize pads by repeating existing data;
                # padding typically sits beyond self.size — verify gaps
                # across non-contiguous puts are handled by callers
                self.action_buffer = np.resize(self.action_buffer, new_shape)
                self.timestamp_buffer = np.resize(self.timestamp_buffer, (new_size,))

            # potentially rewrite old data (as expected)
            self.action_buffer[global_idxs] = actions[local_idxs]
            self.timestamp_buffer[global_idxs] = timestamps[local_idxs]
            self.size = max(self.size, this_max_size)
equidiff/equi_diffpo/config/dp3.yaml ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_pc_abs
4
+
5
+ name: train_dp3
6
+ _target_: equi_diffpo.workspace.train_dp3_workspace.TrainDP3Workspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "debug"
10
+
11
+ task_name: stack_d1
12
+ n_demo: 200
13
+ horizon: 16
14
+ n_obs_steps: 2
15
+ n_action_steps: 8
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ keypoint_visible_rate: 1.0
19
+ obs_as_global_cond: True
20
+ dataset_target: equi_diffpo.dataset.robomimic_replay_point_cloud_dataset.RobomimicReplayPointCloudDataset
21
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
22
+
23
+ policy:
24
+ _target_: equi_diffpo.policy.dp3.DP3
25
+ use_point_crop: true
26
+ condition_type: film
27
+ use_down_condition: true
28
+ use_mid_condition: true
29
+ use_up_condition: true
30
+
31
+ diffusion_step_embed_dim: 128
32
+ down_dims:
33
+ - 512
34
+ - 1024
35
+ - 2048
36
+ crop_shape:
37
+ - 80
38
+ - 80
39
+ encoder_output_dim: 64
40
+ horizon: ${horizon}
41
+ kernel_size: 5
42
+ n_action_steps: ${n_action_steps}
43
+ n_groups: 8
44
+ n_obs_steps: ${n_obs_steps}
45
+
46
+ noise_scheduler:
47
+ _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler
48
+ num_train_timesteps: 100
49
+ beta_start: 0.0001
50
+ beta_end: 0.02
51
+ beta_schedule: squaredcos_cap_v2
52
+ clip_sample: True
53
+ set_alpha_to_one: True
54
+ steps_offset: 0
55
+ prediction_type: sample
56
+
57
+
58
+ num_inference_steps: 10
59
+ obs_as_global_cond: true
60
+ shape_meta: ${shape_meta}
61
+
62
+ use_pc_color: true
63
+ pointnet_type: "pointnet"
64
+
65
+
66
+ pointcloud_encoder_cfg:
67
+ in_channels: 3
68
+ out_channels: ${policy.encoder_output_dim}
69
+ use_layernorm: true
70
+ final_norm: layernorm # layernorm, none
71
+ normal_channel: false
72
+
73
+
74
+ ema:
75
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
76
+ update_after_step: 0
77
+ inv_gamma: 1.0
78
+ power: 0.75
79
+ min_value: 0.0
80
+ max_value: 0.9999
81
+
82
+ dataloader:
83
+ batch_size: 128
84
+ num_workers: 8
85
+ shuffle: True
86
+ pin_memory: True
87
+ persistent_workers: True
88
+
89
+ val_dataloader:
90
+ batch_size: 128
91
+ num_workers: 8
92
+ shuffle: False
93
+ pin_memory: True
94
+ persistent_workers: True
95
+
96
+ optimizer:
97
+ _target_: torch.optim.AdamW
98
+ lr: 1.0e-4
99
+ betas: [0.95, 0.999]
100
+ eps: 1.0e-8
101
+ weight_decay: 1.0e-6
102
+
103
+ training:
104
+ device: "cuda:0"
105
+ seed: 42
106
+ debug: False
107
+ resume: True
108
+ lr_scheduler: cosine
109
+ lr_warmup_steps: 500
110
+ num_epochs: ${eval:'50000 / ${n_demo}'}
111
+ gradient_accumulate_every: 1
112
+ use_ema: True
113
+ rollout_every: ${eval:'1000 / ${n_demo}'}
114
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
115
+ val_every: 1
116
+ sample_every: 5
117
+ max_train_steps: null
118
+ max_val_steps: null
119
+ tqdm_interval_sec: 1.0
120
+
121
+ logging:
122
+ project: dp3_${task_name}
123
+ resume: true
124
+ mode: online
125
+ name: dp3_${n_demo}
126
+ tags: ["${name}", "${task_name}", "${exp_name}"]
127
+ id: null
128
+ group: null
129
+
130
+
131
+ checkpoint:
132
+ save_ckpt: False # if True, save checkpoint every checkpoint_every
133
+ topk:
134
+ monitor_key: test_mean_score
135
+ mode: max
136
+ k: 1
137
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
138
+ save_last_ckpt: True # this only saves when save_ckpt is True
139
+ save_last_snapshot: False
140
+
141
+ multi_run:
142
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
143
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
144
+
145
+ hydra:
146
+ job:
147
+ override_dirname: ${name}
148
+ run:
149
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
150
+ sweep:
151
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
152
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/task/mimicgen_abs.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: mimicgen_abs
2
+
3
+ shape_meta: &shape_meta
4
+ # acceptable types: rgb, low_dim
5
+ obs:
6
+ agentview_image:
7
+ shape: [3, 84, 84]
8
+ type: rgb
9
+ robot0_eye_in_hand_image:
10
+ shape: [3, 84, 84]
11
+ type: rgb
12
+ robot0_eef_pos:
13
+ shape: [3]
14
+ # type default: low_dim
15
+ robot0_eef_quat:
16
+ shape: [4]
17
+ robot0_gripper_qpos:
18
+ shape: [2]
19
+ action:
20
+ shape: [10]
21
+
22
+ abs_action: &abs_action True
23
+
24
+ env_runner:
25
+ _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
26
+ dataset_path: ${dataset_path}
27
+ shape_meta: *shape_meta
28
+ n_train: 6
29
+ n_train_vis: 2
30
+ train_start_idx: 0
31
+ n_test: 50
32
+ n_test_vis: 4
33
+ test_start_seed: 100000
34
+ max_steps: ${get_max_steps:${task_name}}
35
+ n_obs_steps: ${n_obs_steps}
36
+ n_action_steps: ${n_action_steps}
37
+ render_obs_key: 'agentview_image'
38
+ fps: 10
39
+ crf: 22
40
+ past_action: ${past_action_visible}
41
+ abs_action: *abs_action
42
+ tqdm_interval_sec: 1.0
43
+ n_envs: 28
44
+
45
+ dataset:
46
+ # _target_: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
47
+ _target_: ${dataset}
48
+ n_demo: ${n_demo}
49
+ shape_meta: *shape_meta
50
+ dataset_path: ${dataset_path}
51
+ horizon: ${horizon}
52
+ pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
53
+ pad_after: ${eval:'${n_action_steps}-1'}
54
+ n_obs_steps: ${dataset_obs_steps}
55
+ abs_action: *abs_action
56
+ rotation_rep: 'rotation_6d'
57
+ use_legacy_normalizer: False
58
+ use_cache: True
59
+ seed: 42
60
+ val_ratio: 0.02
equidiff/equi_diffpo/config/task/mimicgen_pc_abs.yaml ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: mimicgen_pc_abs
2
+
3
+ shape_meta: &shape_meta
4
+ # acceptable types: rgb, low_dim
5
+ obs:
6
+ robot0_eye_in_hand_image:
7
+ shape: [3, 84, 84]
8
+ type: rgb
9
+ point_cloud:
10
+ shape: [1024, 6]
11
+ type: point_cloud
12
+ robot0_eef_pos:
13
+ shape: [3]
14
+ # type default: low_dim
15
+ robot0_eef_quat:
16
+ shape: [4]
17
+ robot0_gripper_qpos:
18
+ shape: [2]
19
+ action:
20
+ shape: [10]
21
+
22
+ env_runner_shape_meta: &env_runner_shape_meta
23
+ # acceptable types: rgb, low_dim
24
+ obs:
25
+ robot0_eye_in_hand_image:
26
+ shape: [3, 84, 84]
27
+ type: rgb
28
+ agentview_image:
29
+ shape: [3, 84, 84]
30
+ type: rgb
31
+ point_cloud:
32
+ shape: [1024, 6]
33
+ type: point_cloud
34
+ robot0_eef_pos:
35
+ shape: [3]
36
+ # type default: low_dim
37
+ robot0_eef_quat:
38
+ shape: [4]
39
+ robot0_gripper_qpos:
40
+ shape: [2]
41
+ action:
42
+ shape: [10]
43
+
44
+ abs_action: &abs_action True
45
+
46
+ env_runner:
47
+ _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
48
+ dataset_path: ${dataset_path}
49
+ shape_meta: *env_runner_shape_meta
50
+ n_train: 6
51
+ n_train_vis: 2
52
+ train_start_idx: 0
53
+ n_test: 50
54
+ n_test_vis: 4
55
+ test_start_seed: 100000
56
+ max_steps: ${get_max_steps:${task_name}}
57
+ n_obs_steps: ${n_obs_steps}
58
+ n_action_steps: ${n_action_steps}
59
+ render_obs_key: 'agentview_image'
60
+ fps: 10
61
+ crf: 22
62
+ past_action: False
63
+ abs_action: *abs_action
64
+ tqdm_interval_sec: 1.0
65
+ n_envs: 28
66
+
67
+ dataset:
68
+ _target_: ${dataset_target}
69
+ n_demo: ${n_demo}
70
+ shape_meta: *shape_meta
71
+ dataset_path: ${dataset_path}
72
+ horizon: ${horizon}
73
+ pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
74
+ pad_after: ${eval:'${n_action_steps}-1'}
75
+ n_obs_steps: ${dataset_obs_steps}
76
+ abs_action: *abs_action
77
+ rotation_rep: 'rotation_6d'
78
+ use_legacy_normalizer: False
79
+ use_cache: False
80
+ seed: 42
81
+ val_ratio: 0.02
equidiff/equi_diffpo/config/task/mimicgen_rel.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: mimicgen_rel
2
+
3
+ shape_meta: &shape_meta
4
+ # acceptable types: rgb, low_dim
5
+ obs:
6
+ agentview_image:
7
+ shape: [3, 84, 84]
8
+ type: rgb
9
+ robot0_eye_in_hand_image:
10
+ shape: [3, 84, 84]
11
+ type: rgb
12
+ robot0_eef_pos:
13
+ shape: [3]
14
+ # type default: low_dim
15
+ robot0_eef_quat:
16
+ shape: [4]
17
+ robot0_gripper_qpos:
18
+ shape: [2]
19
+ action:
20
+ shape: [7]
21
+
22
+ abs_action: &abs_action False
23
+
24
+ env_runner:
25
+ _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
26
+ dataset_path: ${dataset_path}
27
+ shape_meta: *shape_meta
28
+ n_train: 6
29
+ n_train_vis: 2
30
+ train_start_idx: 0
31
+ n_test: 50
32
+ n_test_vis: 4
33
+ test_start_seed: 100000
34
+ max_steps: ${get_max_steps:${task_name}}
35
+ n_obs_steps: ${n_obs_steps}
36
+ n_action_steps: ${n_action_steps}
37
+ render_obs_key: 'agentview_image'
38
+ fps: 10
39
+ crf: 22
40
+ past_action: ${past_action_visible}
41
+ abs_action: *abs_action
42
+ tqdm_interval_sec: 1.0
43
+ n_envs: 28
44
+
45
+ dataset:
46
+ # _target_: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
47
+ _target_: ${dataset}
48
+ n_demo: ${n_demo}
49
+ shape_meta: *shape_meta
50
+ dataset_path: ${dataset_path}
51
+ horizon: ${horizon}
52
+ pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
53
+ pad_after: ${eval:'${n_action_steps}-1'}
54
+ n_obs_steps: ${dataset_obs_steps}
55
+ abs_action: *abs_action
56
+ rotation_rep: 'rotation_6d'
57
+ use_legacy_normalizer: False
58
+ use_cache: True
59
+ seed: 42
60
+ val_ratio: 0.02
equidiff/equi_diffpo/config/task/mimicgen_voxel_abs.yaml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: mimicgen_abs  # NOTE(review): this file is mimicgen_voxel_abs.yaml but keeps the name from mimicgen_abs.yaml — confirm whether `mimicgen_voxel_abs` was intended
2
+
3
+ shape_meta: &shape_meta
4
+ # acceptable types: rgb, low_dim
5
+ obs:
6
+ robot0_eye_in_hand_image:
7
+ shape: [3, 84, 84]
8
+ type: rgb
9
+ voxels:
10
+ shape: [4, 64, 64, 64]
11
+ type: voxel
12
+ robot0_eef_pos:
13
+ shape: [3]
14
+ # type default: low_dim
15
+ robot0_eef_quat:
16
+ shape: [4]
17
+ robot0_gripper_qpos:
18
+ shape: [2]
19
+ action:
20
+ shape: [10]
21
+
22
+ env_runner_shape_meta: &env_runner_shape_meta
23
+ # acceptable types: rgb, low_dim
24
+ obs:
25
+ robot0_eye_in_hand_image:
26
+ shape: [3, 84, 84]
27
+ type: rgb
28
+ agentview_image:
29
+ shape: [3, 84, 84]
30
+ type: rgb
31
+ voxels:
32
+ shape: [4, 64, 64, 64]
33
+ type: voxel
34
+ robot0_eef_pos:
35
+ shape: [3]
36
+ # type default: low_dim
37
+ robot0_eef_quat:
38
+ shape: [4]
39
+ robot0_gripper_qpos:
40
+ shape: [2]
41
+ action:
42
+ shape: [10]
43
+
44
+ # dataset_path: &dataset_path data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
45
+ abs_action: &abs_action True
46
+
47
+ env_runner:
48
+ _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
49
+ dataset_path: ${dataset_path}
50
+ shape_meta: *env_runner_shape_meta
51
+ n_train: 6
52
+ n_train_vis: 2
53
+ train_start_idx: 0
54
+ n_test: 50
55
+ n_test_vis: 4
56
+ test_start_seed: 100000
57
+ max_steps: ${get_max_steps:${task_name}}
58
+ n_obs_steps: ${n_obs_steps}
59
+ n_action_steps: ${n_action_steps}
60
+ render_obs_key: 'agentview_image'
61
+ fps: 10
62
+ crf: 22
63
+ past_action: ${past_action_visible}
64
+ abs_action: *abs_action
65
+ tqdm_interval_sec: 1.0
66
+ n_envs: 28
67
+
68
+ dataset:
69
+ _target_: equi_diffpo.dataset.robomimic_replay_voxel_sym_dataset.RobomimicReplayVoxelSymDataset
70
+ n_demo: ${n_demo}
71
+ shape_meta: *shape_meta
72
+ dataset_path: ${dataset_path}
73
+ horizon: ${horizon}
74
+ pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
75
+ pad_after: ${eval:'${n_action_steps}-1'}
76
+ n_obs_steps: ${dataset_obs_steps}
77
+ abs_action: *abs_action
78
+ rotation_rep: 'rotation_6d'
79
+ use_legacy_normalizer: False
80
+ use_cache: True
81
+ seed: 42
82
+ val_ratio: 0.02
83
+ ws_x_center: ${get_ws_x_center:${task_name}}
84
+ ws_y_center: ${get_ws_y_center:${task_name}}
equidiff/equi_diffpo/config/task/mimicgen_voxel_rel.yaml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: mimicgen_rel  # NOTE(review): this file is mimicgen_voxel_rel.yaml but keeps the name from mimicgen_rel.yaml — confirm whether `mimicgen_voxel_rel` was intended
2
+
3
+ shape_meta: &shape_meta
4
+ # acceptable types: rgb, low_dim
5
+ obs:
6
+ robot0_eye_in_hand_image:
7
+ shape: [3, 84, 84]
8
+ type: rgb
9
+ voxels:
10
+ shape: [4, 64, 64, 64]
11
+ type: voxel
12
+ robot0_eef_pos:
13
+ shape: [3]
14
+ # type default: low_dim
15
+ robot0_eef_quat:
16
+ shape: [4]
17
+ robot0_gripper_qpos:
18
+ shape: [2]
19
+ action:
20
+ shape: [7]
21
+
22
+ env_runner_shape_meta: &env_runner_shape_meta
23
+ # acceptable types: rgb, low_dim
24
+ obs:
25
+ robot0_eye_in_hand_image:
26
+ shape: [3, 84, 84]
27
+ type: rgb
28
+ agentview_image:
29
+ shape: [3, 84, 84]
30
+ type: rgb
31
+ voxels:
32
+ shape: [4, 64, 64, 64]
33
+ type: voxel
34
+ robot0_eef_pos:
35
+ shape: [3]
36
+ # type default: low_dim
37
+ robot0_eef_quat:
38
+ shape: [4]
39
+ robot0_gripper_qpos:
40
+ shape: [2]
41
+ action:
42
+ shape: [7]
43
+
44
+ # dataset_path: &dataset_path data/robomimic/datasets/${task_name}/${task_name}_voxel.hdf5
45
+ abs_action: &abs_action False
46
+
47
+ env_runner:
48
+ _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
49
+ dataset_path: ${dataset_path}
50
+ shape_meta: *env_runner_shape_meta
51
+ n_train: 6
52
+ n_train_vis: 2
53
+ train_start_idx: 0
54
+ n_test: 50
55
+ n_test_vis: 4
56
+ test_start_seed: 100000
57
+ max_steps: ${get_max_steps:${task_name}}
58
+ n_obs_steps: ${n_obs_steps}
59
+ n_action_steps: ${n_action_steps}
60
+ render_obs_key: 'agentview_image'
61
+ fps: 10
62
+ crf: 22
63
+ past_action: ${past_action_visible}
64
+ abs_action: *abs_action
65
+ tqdm_interval_sec: 1.0
66
+ n_envs: 28
67
+
68
+ dataset:
69
+ _target_: equi_diffpo.dataset.robomimic_replay_voxel_sym_dataset.RobomimicReplayVoxelSymDataset
70
+ n_demo: ${n_demo}
71
+ shape_meta: *shape_meta
72
+ dataset_path: ${dataset_path}
73
+ horizon: ${horizon}
74
+ pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
75
+ pad_after: ${eval:'${n_action_steps}-1'}
76
+ n_obs_steps: ${dataset_obs_steps}
77
+ abs_action: *abs_action
78
+ rotation_rep: 'rotation_6d'
79
+ use_legacy_normalizer: False
80
+ use_cache: True
81
+ seed: 42
82
+ val_ratio: 0.02
83
+ ws_x_center: ${get_ws_x_center:${task_name}}
84
+ ws_y_center: ${get_ws_y_center:${task_name}}
equidiff/equi_diffpo/config/test_equi_diffusion_unet_abs_sq2.yaml ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ name: equi_diff
6
+ _target_: equi_diffpo.workspace.test_equi_workspace.TestEquiWorkspace
7
+ ckpt_path: data/outputs/2025.01.10/06.04.25_equi_diff_square_d2_high/checkpoints/epoch=0046-test_mean_score=0.760.ckpt
8
+ diversity: high
9
+
10
+ shape_meta: ${task.shape_meta}
11
+ exp_name: "default"
12
+
13
+ task_name: square_d2
14
+ log_txt_path: data/test_result.txt
15
+ n_demo: 1000
16
+ horizon: 16
17
+ n_obs_steps: 2
18
+ n_action_steps: 8
19
+ n_latency_steps: 0
20
+ dataset_obs_steps: ${n_obs_steps}
21
+ past_action_visible: False
22
+ dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
23
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
24
+
25
+ policy:
26
+ _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
27
+
28
+ shape_meta: ${shape_meta}
29
+
30
+ noise_scheduler:
31
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
32
+ num_train_timesteps: 100
33
+ beta_start: 0.0001
34
+ beta_end: 0.02
35
+ beta_schedule: squaredcos_cap_v2
36
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but that easily produces NaN
37
+ clip_sample: True # required when predict_epsilon=False
38
+ prediction_type: epsilon # or sample
39
+
40
+ horizon: ${horizon}
41
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
42
+ n_obs_steps: ${n_obs_steps}
43
+ num_inference_steps: 100
44
+ crop_shape: [76, 76]
45
+ # crop_shape: null
46
+ diffusion_step_embed_dim: 128
47
+ enc_n_hidden: 128
48
+ down_dims: [512, 1024, 2048]
49
+ kernel_size: 5
50
+ n_groups: 8
51
+ cond_predict_scale: True
52
+ rot_aug: False
53
+
54
+ # scheduler.step params
55
+ # predict_epsilon: True
56
+
57
+ ema:
58
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
59
+ update_after_step: 0
60
+ inv_gamma: 1.0
61
+ power: 0.75
62
+ min_value: 0.0
63
+ max_value: 0.9999
64
+
65
+ dataloader:
66
+ batch_size: 128
67
+ num_workers: 4
68
+ shuffle: True
69
+ pin_memory: True
70
+ persistent_workers: True
71
+ drop_last: true
72
+
73
+ val_dataloader:
74
+ batch_size: 128
75
+ num_workers: 8
76
+ shuffle: False
77
+ pin_memory: True
78
+ persistent_workers: True
79
+
80
+ optimizer:
81
+ betas: [0.95, 0.999]
82
+ eps: 1.0e-08
83
+ learning_rate: 0.0001
84
+ weight_decay: 1.0e-06
85
+
86
+ training:
87
+ ckpt_path: ${ckpt_path}
88
+ device: "cuda:0"
89
+ seed: 0
90
+ debug: False
91
+ resume: True
92
+ # optimization
93
+ lr_scheduler: cosine
94
+ lr_warmup_steps: 500
95
+ num_epochs: ${eval:'50000 / ${n_demo}'}
96
+ gradient_accumulate_every: 1
97
+ # EMA destroys performance when used with BatchNorm
98
+ # replace BatchNorm with GroupNorm.
99
+ use_ema: True
100
+ # training loop control
101
+ # in epochs
102
+ rollout_every: ${eval:'1000 / ${n_demo}'}
103
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
104
+ val_every: 1
105
+ sample_every: 5
106
+ # steps per epoch
107
+ max_train_steps: null
108
+ max_val_steps: null
109
+ # misc
110
+ tqdm_interval_sec: 1.0
111
+
112
+ logging:
113
+ project: test_diffusion_policy_${task_name}
114
+ resume: True
115
+ mode: online
116
+ name: equidiff_${n_demo}_${diversity}_${policy.n_action_steps}
117
+ tags: ["${name}", "${task_name}", "${exp_name}"]
118
+ id: null
119
+ group: null
120
+
121
+ checkpoint:
122
+ topk:
123
+ monitor_key: test_mean_score
124
+ mode: max
125
+ k: 5
126
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
127
+ save_last_ckpt: True
128
+ save_last_snapshot: False
129
+
130
+ # multi_run:
131
+ # run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
132
+ # wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
133
+
134
+ hydra:
135
+ job:
136
+ override_dirname: ${name}
137
+ run:
138
+ dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
139
+ sweep:
140
+ dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
141
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/test_sq2.yaml ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ run_name: square_d2_test
6
+ name: equi_diff
7
+ _target_: equi_diffpo.workspace.test_equi_workspace.TestEquiWorkspace
8
+ ckpt_path: /home/siweih/Project/EmbodiedBM/equidiff/data/outputs/2025.02.23/00.15.04_equi_diff_square_d2/checkpoints/epoch=0019-test_mean_score=0.840.ckpt
9
+ diversity: high
10
+
11
+ shape_meta: ${task.shape_meta}
12
+ exp_name: "default"
13
+
14
+ task_name: square_d2
15
+ log_txt_path: data/sq2_test_result.txt
16
+ n_demo: 1000
17
+ horizon: 16
18
+ n_obs_steps: 2
19
+ n_action_steps: 8
20
+ n_latency_steps: 0
21
+ dataset_obs_steps: ${n_obs_steps}
22
+ past_action_visible: False
23
+ dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
24
+ dataset_path: data/robomimic/datasets/square_d2/square_d2_abs.hdf5
25
+
26
+ policy:
27
+ _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
28
+
29
+ shape_meta: ${shape_meta}
30
+
31
+ noise_scheduler:
32
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
33
+ num_train_timesteps: 100
34
+ beta_start: 0.0001
35
+ beta_end: 0.02
36
+ beta_schedule: squaredcos_cap_v2
37
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
38
+ clip_sample: True # required when predict_epsilon=False
39
+ prediction_type: epsilon # or sample
40
+
41
+ horizon: ${horizon}
42
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
43
+ n_obs_steps: ${n_obs_steps}
44
+ num_inference_steps: 100
45
+ crop_shape: [76, 76]
46
+ # crop_shape: null
47
+ diffusion_step_embed_dim: 128
48
+ enc_n_hidden: 128
49
+ down_dims: [512, 1024, 2048]
50
+ kernel_size: 5
51
+ n_groups: 8
52
+ cond_predict_scale: True
53
+ rot_aug: False
54
+
55
+ # scheduler.step params
56
+ # predict_epsilon: True
57
+
58
+ ema:
59
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
60
+ update_after_step: 0
61
+ inv_gamma: 1.0
62
+ power: 0.75
63
+ min_value: 0.0
64
+ max_value: 0.9999
65
+
66
+ dataloader:
67
+ batch_size: 128
68
+ num_workers: 4
69
+ shuffle: True
70
+ pin_memory: True
71
+ persistent_workers: True
72
+ drop_last: true
73
+
74
+ val_dataloader:
75
+ batch_size: 128
76
+ num_workers: 8
77
+ shuffle: False
78
+ pin_memory: True
79
+ persistent_workers: True
80
+
81
+ optimizer:
82
+ betas: [0.95, 0.999]
83
+ eps: 1.0e-08
84
+ learning_rate: 0.0001
85
+ weight_decay: 1.0e-06
86
+
87
+ training:
88
+ ckpt_path: ${ckpt_path}
89
+ device: "cuda:0"
90
+ seed: 0
91
+ debug: False
92
+ resume: True
93
+ # optimization
94
+ lr_scheduler: cosine
95
+ lr_warmup_steps: 500
96
+ num_epochs: ${eval:'50000 / ${n_demo}'}
97
+ gradient_accumulate_every: 1
98
+ # EMA destroys performance when used with BatchNorm
99
+ # replace BatchNorm with GroupNorm.
100
+ use_ema: True
101
+ # training loop control
102
+ # in epochs
103
+ rollout_every: ${eval:'1000 / ${n_demo}'}
104
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
105
+ val_every: 1
106
+ sample_every: 5
107
+ # steps per epoch
108
+ max_train_steps: null
109
+ max_val_steps: null
110
+ # misc
111
+ tqdm_interval_sec: 1.0
112
+
113
+ logging:
114
+ project: test_diffusion_policy_${task_name}
115
+ resume: True
116
+ mode: online
117
+ name: equidiff_${n_demo}_${diversity}_${policy.n_action_steps}
118
+ tags: ["${name}", "${task_name}", "${exp_name}"]
119
+ id: null
120
+ group: null
121
+
122
+ checkpoint:
123
+ topk:
124
+ monitor_key: test_mean_score
125
+ mode: max
126
+ k: 5
127
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
128
+ save_last_ckpt: True
129
+ save_last_snapshot: False
130
+
131
+ # multi_run:
132
+ # run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
133
+ # wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
134
+
135
+ hydra:
136
+ job:
137
+ override_dirname: ${name}
138
+ run:
139
+ dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
140
+ sweep:
141
+ dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
142
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/test_th2.yaml ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ run_name: threading_d2_test
6
+ name: equi_diff
7
+ _target_: equi_diffpo.workspace.test_equi_workspace.TestEquiWorkspace
8
+ ckpt_path: null
9
+ diversity: high
10
+
11
+ shape_meta: ${task.shape_meta}
12
+ exp_name: "default"
13
+
14
+ task_name: threading_d2
15
+ log_txt_path: data/th2_test_result.txt
16
+ n_demo: 100
17
+ horizon: 16
18
+ n_obs_steps: 2
19
+ n_action_steps: 8
20
+ n_latency_steps: 0
21
+ dataset_obs_steps: ${n_obs_steps}
22
+ past_action_visible: False
23
+ dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
24
+ dataset_path: data/robomimic/datasets/threading_d2_test/demo_abs.hdf5
25
+
26
+ policy:
27
+ _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
28
+
29
+ shape_meta: ${shape_meta}
30
+
31
+ noise_scheduler:
32
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
33
+ num_train_timesteps: 100
34
+ beta_start: 0.0001
35
+ beta_end: 0.02
36
+ beta_schedule: squaredcos_cap_v2
37
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
38
+ clip_sample: True # required when predict_epsilon=False
39
+ prediction_type: epsilon # or sample
40
+
41
+ horizon: ${horizon}
42
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
43
+ n_obs_steps: ${n_obs_steps}
44
+ num_inference_steps: 100
45
+ crop_shape: [76, 76]
46
+ # crop_shape: null
47
+ diffusion_step_embed_dim: 128
48
+ enc_n_hidden: 128
49
+ down_dims: [512, 1024, 2048]
50
+ kernel_size: 5
51
+ n_groups: 8
52
+ cond_predict_scale: True
53
+ rot_aug: False
54
+
55
+ # scheduler.step params
56
+ # predict_epsilon: True
57
+
58
+ ema:
59
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
60
+ update_after_step: 0
61
+ inv_gamma: 1.0
62
+ power: 0.75
63
+ min_value: 0.0
64
+ max_value: 0.9999
65
+
66
+ dataloader:
67
+ batch_size: 128
68
+ num_workers: 4
69
+ shuffle: True
70
+ pin_memory: True
71
+ persistent_workers: True
72
+ drop_last: true
73
+
74
+ val_dataloader:
75
+ batch_size: 128
76
+ num_workers: 8
77
+ shuffle: False
78
+ pin_memory: True
79
+ persistent_workers: True
80
+
81
+ optimizer:
82
+ betas: [0.95, 0.999]
83
+ eps: 1.0e-08
84
+ learning_rate: 0.0001
85
+ weight_decay: 1.0e-06
86
+
87
+ training:
88
+ ckpt_path: ${ckpt_path}
89
+ device: "cuda:0"
90
+ seed: 0
91
+ debug: False
92
+ resume: True
93
+ # optimization
94
+ lr_scheduler: cosine
95
+ lr_warmup_steps: 500
96
+ num_epochs: ${eval:'50000 / ${n_demo}'}
97
+ gradient_accumulate_every: 1
98
+ # EMA destroys performance when used with BatchNorm
99
+ # replace BatchNorm with GroupNorm.
100
+ use_ema: True
101
+ # training loop control
102
+ # in epochs
103
+ rollout_every: ${eval:'1000 / ${n_demo}'}
104
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
105
+ val_every: 1
106
+ sample_every: 5
107
+ # steps per epoch
108
+ max_train_steps: null
109
+ max_val_steps: null
110
+ # misc
111
+ tqdm_interval_sec: 1.0
112
+
113
+ logging:
114
+ project: test_diffusion_policy_${task_name}
115
+ resume: True
116
+ mode: online
117
+ name: equidiff_${n_demo}_${diversity}_${policy.n_action_steps}
118
+ tags: ["${name}", "${task_name}", "${exp_name}"]
119
+ id: null
120
+ group: null
121
+
122
+ checkpoint:
123
+ topk:
124
+ monitor_key: test_mean_score
125
+ mode: max
126
+ k: 5
127
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
128
+ save_last_ckpt: True
129
+ save_last_snapshot: False
130
+
131
+ # multi_run:
132
+ # run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
133
+ # wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
134
+
135
+ hydra:
136
+ job:
137
+ override_dirname: ${name}
138
+ run:
139
+ dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
140
+ sweep:
141
+ dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
142
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_act_abs.yaml ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ name: act
6
+ _target_: equi_diffpo.workspace.train_act_workspace.TrainActWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: stack_d1
12
+ n_demo: 200
13
+ horizon: 10
14
+ n_obs_steps: 1
15
+ n_action_steps: 10
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ past_action_visible: False
19
+ dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
20
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
21
+
22
+ policy:
23
+ _target_: equi_diffpo.policy.act_policy.ACTPolicyWrapper
24
+
25
+ shape_meta: ${shape_meta}
26
+
27
+ max_timesteps: ${task.env_runner.max_steps}
28
+ temporal_agg: false
29
+ n_envs: ${task.env_runner.n_envs}
30
+ horizon: ${horizon}
31
+
32
+ dataloader:
33
+ batch_size: 64
34
+ num_workers: 4
35
+ shuffle: True
36
+ pin_memory: True
37
+ persistent_workers: True
38
+
39
+ val_dataloader:
40
+ batch_size: 64
41
+ num_workers: 4
42
+ shuffle: False
43
+ pin_memory: True
44
+ persistent_workers: True
45
+
46
+ training:
47
+ device: "cuda:0"
48
+ seed: 0
49
+ debug: False
50
+ resume: True
51
+ num_epochs: ${eval:'50000 / ${n_demo}'}
52
+ rollout_every: ${eval:'1000 / ${n_demo}'}
53
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
54
+ val_every: 1
55
+ max_train_steps: null
56
+ max_val_steps: null
57
+ tqdm_interval_sec: 1.0
58
+
59
+ logging:
60
+ project: diffusion_policy_${task_name}
61
+ resume: True
62
+ mode: online
63
+ name: act_demo${n_demo}
64
+ tags: ["${name}", "${task_name}", "${exp_name}"]
65
+ id: null
66
+ group: null
67
+
68
+ checkpoint:
69
+ topk:
70
+ monitor_key: test_mean_score
71
+ mode: max
72
+ k: 5
73
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
74
+ save_last_ckpt: True
75
+ save_last_snapshot: False
76
+
77
+ multi_run:
78
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
79
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
80
+
81
+ hydra:
82
+ job:
83
+ override_dirname: ${name}
84
+ run:
85
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
86
+ sweep:
87
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
88
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_bc_rnn.yaml ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_rel
4
+
5
+ name: bc_rnn
6
+ _target_: equi_diffpo.workspace.train_robomimic_image_workspace.TrainRobomimicImageWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: stack_d1
12
+ n_demo: 200
13
+ horizon: &horizon 10
14
+ n_obs_steps: 1
15
+ n_action_steps: 1
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: *horizon
18
+ past_action_visible: False
19
+ dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
20
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}.hdf5
21
+
22
+ policy:
23
+ _target_: equi_diffpo.policy.robomimic_image_policy.RobomimicImagePolicy
24
+ shape_meta: ${shape_meta}
25
+ algo_name: bc_rnn
26
+ obs_type: image
27
+ # oc.select resolver: key, default
28
+ task_name: ${oc.select:task.task_name,lift}
29
+ dataset_type: ${oc.select:task.dataset_type,ph}
30
+ crop_shape: [76,76]
31
+
32
+ dataloader:
33
+ batch_size: 64
34
+ num_workers: 4
35
+ shuffle: True
36
+ pin_memory: True
37
+ persistent_workers: True
38
+
39
+ val_dataloader:
40
+ batch_size: 64
41
+ num_workers: 4
42
+ shuffle: False
43
+ pin_memory: True
44
+ persistent_workers: True
45
+
46
+ training:
47
+ device: "cuda:0"
48
+ seed: 0
49
+ debug: False
50
+ resume: True
51
+ # optimization
52
+ num_epochs: ${eval:'50000 / ${n_demo}'}
53
+ # training loop control
54
+ # in epochs
55
+ rollout_every: ${eval:'1000 / ${n_demo}'}
56
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
57
+ val_every: 1
58
+ sample_every: 5
59
+ # steps per epoch
60
+ max_train_steps: null
61
+ max_val_steps: null
62
+ # misc
63
+ tqdm_interval_sec: 1.0
64
+
65
+ logging:
66
+ project: diffusion_policy_${task_name}
67
+ resume: True
68
+ mode: online
69
+ name: bc_rnn_demo${n_demo}
70
+ tags: ["${name}", "${task_name}", "${exp_name}"]
71
+ id: null
72
+ group: null
73
+
74
+ checkpoint:
75
+ topk:
76
+ monitor_key: test_mean_score
77
+ mode: max
78
+ k: 5
79
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
80
+ save_last_ckpt: True
81
+ save_last_snapshot: False
82
+
83
+ multi_run:
84
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
85
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
86
+
87
+ hydra:
88
+ job:
89
+ override_dirname: ${name}
90
+ run:
91
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
92
+ sweep:
93
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
94
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_diffusion_transformer.yaml ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ name: diff_t
6
+ _target_: equi_diffpo.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: stack_d1
12
+ n_demo: 200
13
+ horizon: 10
14
+ n_obs_steps: 2
15
+ n_action_steps: 8
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ past_action_visible: False
19
+ obs_as_cond: True
20
+ dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
21
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
22
+
23
+ policy:
24
+ _target_: equi_diffpo.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
25
+
26
+ shape_meta: ${shape_meta}
27
+
28
+ noise_scheduler:
29
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
30
+ num_train_timesteps: 100
31
+ beta_start: 0.0001
32
+ beta_end: 0.02
33
+ beta_schedule: squaredcos_cap_v2
34
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
35
+ clip_sample: True # required when predict_epsilon=False
36
+ prediction_type: epsilon # or sample
37
+
38
+ horizon: ${horizon}
39
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
40
+ n_obs_steps: ${n_obs_steps}
41
+ num_inference_steps: 100
42
+
43
+ crop_shape: [76, 76]
44
+ obs_encoder_group_norm: True
45
+ eval_fixed_crop: True
46
+
47
+ n_layer: 8
48
+ n_cond_layers: 0 # >0: use transformer encoder for cond, otherwise use MLP
49
+ n_head: 4
50
+ n_emb: 256
51
+ p_drop_emb: 0.0
52
+ p_drop_attn: 0.3
53
+ causal_attn: True
54
+ time_as_cond: True # if false, use BERT like encoder only arch, time as input
55
+ obs_as_cond: ${obs_as_cond}
56
+
57
+ # scheduler.step params
58
+ # predict_epsilon: True
59
+
60
+ ema:
61
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
62
+ update_after_step: 0
63
+ inv_gamma: 1.0
64
+ power: 0.75
65
+ min_value: 0.0
66
+ max_value: 0.9999
67
+
68
+ dataloader:
69
+ batch_size: 64
70
+ num_workers: 4
71
+ shuffle: True
72
+ pin_memory: True
73
+ persistent_workers: True
74
+
75
+ val_dataloader:
76
+ batch_size: 64
77
+ num_workers: 4
78
+ shuffle: False
79
+ pin_memory: True
80
+ persistent_workers: True
81
+
82
+ optimizer:
83
+ transformer_weight_decay: 1.0e-3
84
+ obs_encoder_weight_decay: 1.0e-6
85
+ learning_rate: 1.0e-4
86
+ betas: [0.9, 0.95]
87
+
88
+ training:
89
+ device: "cuda:0"
90
+ seed: 0
91
+ debug: False
92
+ resume: True
93
+ # optimization
94
+ lr_scheduler: cosine
95
+ # Transformer needs LR warmup
96
+ lr_warmup_steps: 1000
97
+ num_epochs: ${eval:'50000 / ${n_demo}'}
98
+ gradient_accumulate_every: 1
99
+ # EMA destroys performance when used with BatchNorm
100
+ # replace BatchNorm with GroupNorm.
101
+ use_ema: True
102
+ # training loop control
103
+ # in epochs
104
+ rollout_every: ${eval:'1000 / ${n_demo}'}
105
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
106
+ val_every: 1
107
+ sample_every: 5
108
+ # steps per epoch
109
+ max_train_steps: null
110
+ max_val_steps: null
111
+ # misc
112
+ tqdm_interval_sec: 1.0
113
+
114
+ logging:
115
+ project: diffusion_policy_${task_name}
116
+ resume: True
117
+ mode: online
118
+ name: diff_t_demo${n_demo}
119
+ tags: ["${name}", "${task_name}", "${exp_name}"]
120
+ id: null
121
+ group: null
122
+
123
+ checkpoint:
124
+ topk:
125
+ monitor_key: test_mean_score
126
+ mode: max
127
+ k: 5
128
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
129
+ save_last_ckpt: True
130
+ save_last_snapshot: False
131
+
132
+ multi_run:
133
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
134
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
135
+
136
+ hydra:
137
+ job:
138
+ override_dirname: ${name}
139
+ run:
140
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
141
+ sweep:
142
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
143
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_diffusion_unet.yaml ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ name: diff_c
6
+ _target_: equi_diffpo.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: stack_d1
12
+ n_demo: 200
13
+ horizon: 16
14
+ n_obs_steps: 2
15
+ n_action_steps: 8
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ past_action_visible: False
19
+ obs_as_global_cond: True
20
+ dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
21
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
22
+
23
+ policy:
24
+ _target_: equi_diffpo.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
25
+
26
+ shape_meta: ${shape_meta}
27
+
28
+ noise_scheduler:
29
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
30
+ num_train_timesteps: 100
31
+ beta_start: 0.0001
32
+ beta_end: 0.02
33
+ beta_schedule: squaredcos_cap_v2
34
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
35
+ clip_sample: True # required when predict_epsilon=False
36
+ prediction_type: epsilon # or sample
37
+
38
+ horizon: ${horizon}
39
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
40
+ n_obs_steps: ${n_obs_steps}
41
+ num_inference_steps: 100
42
+ obs_as_global_cond: ${obs_as_global_cond}
43
+ crop_shape: [76, 76]
44
+ # crop_shape: null
45
+ diffusion_step_embed_dim: 128
46
+ down_dims: [512, 1024, 2048]
47
+ kernel_size: 5
48
+ n_groups: 8
49
+ cond_predict_scale: True
50
+ obs_encoder_group_norm: True
51
+ eval_fixed_crop: True
52
+ rot_aug: False
53
+
54
+ # scheduler.step params
55
+ # predict_epsilon: True
56
+
57
+ ema:
58
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
59
+ update_after_step: 0
60
+ inv_gamma: 1.0
61
+ power: 0.75
62
+ min_value: 0.0
63
+ max_value: 0.9999
64
+
65
+ dataloader:
66
+ batch_size: 64
67
+ num_workers: 4
68
+ shuffle: True
69
+ pin_memory: True
70
+ persistent_workers: True
71
+
72
+ val_dataloader:
73
+ batch_size: 64
74
+ num_workers: 4
75
+ shuffle: False
76
+ pin_memory: True
77
+ persistent_workers: True
78
+
79
+ optimizer:
80
+ _target_: torch.optim.AdamW
81
+ lr: 1.0e-4
82
+ betas: [0.95, 0.999]
83
+ eps: 1.0e-8
84
+ weight_decay: 1.0e-6
85
+
86
+ training:
87
+ device: "cuda:0"
88
+ seed: 0
89
+ debug: False
90
+ resume: True
91
+ # optimization
92
+ lr_scheduler: cosine
93
+ lr_warmup_steps: 500
94
+ num_epochs: ${eval:'50000 / ${n_demo}'}
95
+ gradient_accumulate_every: 1
96
+ # EMA destroys performance when used with BatchNorm
97
+ # replace BatchNorm with GroupNorm.
98
+ use_ema: True
99
+ # training loop control
100
+ # in epochs
101
+ rollout_every: ${eval:'1000 / ${n_demo}'}
102
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
103
+ val_every: 1
104
+ sample_every: 5
105
+ # steps per epoch
106
+ max_train_steps: null
107
+ max_val_steps: null
108
+ # misc
109
+ tqdm_interval_sec: 1.0
110
+
111
+ logging:
112
+ project: diffusion_policy_${task_name}
113
+ resume: True
114
+ mode: online
115
+ name: diff_c_demo${n_demo}
116
+ tags: ["${name}", "${task_name}", "${exp_name}"]
117
+ id: null
118
+ group: null
119
+
120
+ checkpoint:
121
+ topk:
122
+ monitor_key: test_mean_score
123
+ mode: max
124
+ k: 5
125
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
126
+ save_last_ckpt: True
127
+ save_last_snapshot: False
128
+
129
+ multi_run:
130
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
131
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
132
+
133
+ hydra:
134
+ job:
135
+ override_dirname: ${name}
136
+ run:
137
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
138
+ sweep:
139
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
140
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_diffusion_unet_voxel_abs.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_voxel_abs
4
+
5
+ name: diff_voxel
6
+ _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: stack_d1
12
+ n_demo: 200
13
+ horizon: 16
14
+ n_obs_steps: 1
15
+ n_action_steps: 8
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ past_action_visible: False
19
+ # dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
21
+
22
+ policy:
23
+ _target_: equi_diffpo.policy.diffusion_unet_voxel_policy.DiffusionUNetPolicyVoxel
24
+
25
+ shape_meta: ${shape_meta}
26
+
27
+ noise_scheduler:
28
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29
+ num_train_timesteps: 100
30
+ beta_start: 0.0001
31
+ beta_end: 0.02
32
+ beta_schedule: squaredcos_cap_v2
33
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
34
+ clip_sample: True # required when predict_epsilon=False
35
+ prediction_type: epsilon # or sample
36
+
37
+ horizon: ${horizon}
38
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39
+ n_obs_steps: ${n_obs_steps}
40
+ num_inference_steps: 100
41
+ crop_shape: [58, 58, 58]
42
+ # crop_shape: null
43
+ diffusion_step_embed_dim: 128
44
+ enc_n_hidden: 256
45
+ down_dims: [256, 512, 1024]
46
+ kernel_size: 5
47
+ n_groups: 8
48
+ cond_predict_scale: True
49
+ rot_aug: False
50
+
51
+ # scheduler.step params
52
+ # predict_epsilon: True
53
+
54
+ ema:
55
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
56
+ update_after_step: 0
57
+ inv_gamma: 1.0
58
+ power: 0.75
59
+ min_value: 0.0
60
+ max_value: 0.9999
61
+
62
+ dataloader:
63
+ batch_size: 64
64
+ num_workers: 16
65
+ shuffle: True
66
+ pin_memory: True
67
+ persistent_workers: True
68
+ drop_last: true
69
+
70
+ val_dataloader:
71
+ batch_size: 64
72
+ num_workers: 16
73
+ shuffle: False
74
+ pin_memory: True
75
+ persistent_workers: True
76
+
77
+ optimizer:
78
+ betas: [0.95, 0.999]
79
+ eps: 1.0e-08
80
+ learning_rate: 0.0001
81
+ weight_decay: 1.0e-06
82
+
83
+ training:
84
+ device: "cuda:0"
85
+ seed: 0
86
+ debug: False
87
+ resume: True
88
+ # optimization
89
+ lr_scheduler: cosine
90
+ lr_warmup_steps: 500
91
+ num_epochs: ${eval:'50000 / ${n_demo}'}
92
+ gradient_accumulate_every: 1
93
+ # EMA destroys performance when used with BatchNorm
94
+ # replace BatchNorm with GroupNorm.
95
+ use_ema: True
96
+ # training loop control
97
+ # in epochs
98
+ rollout_every: ${eval:'1000 / ${n_demo}'}
99
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
100
+ val_every: 1
101
+ sample_every: 5
102
+ # steps per epoch
103
+ max_train_steps: null
104
+ max_val_steps: null
105
+ # misc
106
+ tqdm_interval_sec: 1.0
107
+
108
+ logging:
109
+ project: equi_diff_${task_name}_voxel
110
+ resume: True
111
+ mode: online
112
+ name: diff_voxel_${n_demo}
113
+ tags: ["${name}", "${task_name}", "${exp_name}"]
114
+ id: null
115
+ group: null
116
+
117
+ checkpoint:
118
+ topk:
119
+ monitor_key: test_mean_score
120
+ mode: max
121
+ k: 5
122
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
123
+ save_last_ckpt: True
124
+ save_last_snapshot: False
125
+
126
+ multi_run:
127
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
128
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
129
+
130
+ hydra:
131
+ job:
132
+ override_dirname: ${name}
133
+ run:
134
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
135
+ sweep:
136
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ name: equi_diff
6
+ _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: stack_d1
12
+ n_demo: 1000
13
+ horizon: 16
14
+ n_obs_steps: 2
15
+ n_action_steps: 8
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ past_action_visible: False
19
+ dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
21
+
22
+ policy:
23
+ _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
24
+
25
+ shape_meta: ${shape_meta}
26
+
27
+ noise_scheduler:
28
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29
+ num_train_timesteps: 100
30
+ beta_start: 0.0001
31
+ beta_end: 0.02
32
+ beta_schedule: squaredcos_cap_v2
33
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
34
+ clip_sample: True # required when predict_epsilon=False
35
+ prediction_type: epsilon # or sample
36
+
37
+ horizon: ${horizon}
38
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39
+ n_obs_steps: ${n_obs_steps}
40
+ num_inference_steps: 100
41
+ crop_shape: [76, 76]
42
+ # crop_shape: null
43
+ diffusion_step_embed_dim: 128
44
+ enc_n_hidden: 128
45
+ down_dims: [512, 1024, 2048]
46
+ kernel_size: 5
47
+ n_groups: 8
48
+ cond_predict_scale: True
49
+ rot_aug: False
50
+
51
+ # scheduler.step params
52
+ # predict_epsilon: True
53
+
54
+ ema:
55
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
56
+ update_after_step: 0
57
+ inv_gamma: 1.0
58
+ power: 0.75
59
+ min_value: 0.0
60
+ max_value: 0.9999
61
+
62
+ dataloader:
63
+ batch_size: 128
64
+ num_workers: 4
65
+ shuffle: True
66
+ pin_memory: True
67
+ persistent_workers: True
68
+ drop_last: true
69
+
70
+ val_dataloader:
71
+ batch_size: 128
72
+ num_workers: 8
73
+ shuffle: False
74
+ pin_memory: True
75
+ persistent_workers: True
76
+
77
+ optimizer:
78
+ betas: [0.95, 0.999]
79
+ eps: 1.0e-08
80
+ learning_rate: 0.0001
81
+ weight_decay: 1.0e-06
82
+
83
+ training:
84
+ device: "cuda:0"
85
+ seed: 0
86
+ debug: False
87
+ resume: True
88
+ # optimization
89
+ lr_scheduler: cosine
90
+ lr_warmup_steps: 500
91
+ num_epochs: ${eval:'50000 / ${n_demo}'}
92
+ gradient_accumulate_every: 1
93
+ # EMA destroys performance when used with BatchNorm
94
+ # replace BatchNorm with GroupNorm.
95
+ use_ema: True
96
+ # training loop control
97
+ # in epochs
98
+ rollout_every: ${eval:'1000 / ${n_demo}'}
99
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
100
+ val_every: 1
101
+ sample_every: 5
102
+ # steps per epoch
103
+ max_train_steps: null
104
+ max_val_steps: null
105
+ # misc
106
+ tqdm_interval_sec: 1.0
107
+
108
+ logging:
109
+ project: diffusion_policy_${task_name}
110
+ resume: True
111
+ mode: online
112
+ name: equidiff_demo${n_demo}
113
+ tags: ["${name}", "${task_name}", "${exp_name}"]
114
+ id: null
115
+ group: null
116
+
117
+ checkpoint:
118
+ topk:
119
+ monitor_key: test_mean_score
120
+ mode: max
121
+ k: 5
122
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
123
+ save_last_ckpt: True
124
+ save_last_snapshot: False
125
+
126
+ multi_run:
127
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
128
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
129
+
130
+ hydra:
131
+ job:
132
+ override_dirname: ${name}
133
+ run:
134
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
135
+ sweep:
136
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs_sq2_0-1.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ name: equi_diff
6
+ _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: square_d2
12
+ n_demo: 1000
13
+ horizon: 16
14
+ n_obs_steps: 2
15
+ n_action_steps: 8
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ past_action_visible: False
19
+ dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
21
+
22
+ policy:
23
+ _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
24
+
25
+ shape_meta: ${shape_meta}
26
+
27
+ noise_scheduler:
28
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29
+ num_train_timesteps: 100
30
+ beta_start: 0.0001
31
+ beta_end: 0.02
32
+ beta_schedule: squaredcos_cap_v2
33
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
34
+ clip_sample: True # required when predict_epsilon=False
35
+ prediction_type: epsilon # or sample
36
+
37
+ horizon: ${horizon}
38
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39
+ n_obs_steps: ${n_obs_steps}
40
+ num_inference_steps: 100
41
+ crop_shape: [76, 76]
42
+ # crop_shape: null
43
+ diffusion_step_embed_dim: 128
44
+ enc_n_hidden: 128
45
+ down_dims: [512, 1024, 2048]
46
+ kernel_size: 5
47
+ n_groups: 8
48
+ cond_predict_scale: True
49
+ rot_aug: False
50
+
51
+ # scheduler.step params
52
+ # predict_epsilon: True
53
+
54
+ ema:
55
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
56
+ update_after_step: 0
57
+ inv_gamma: 1.0
58
+ power: 0.75
59
+ min_value: 0.0
60
+ max_value: 0.9999
61
+
62
+ dataloader:
63
+ batch_size: 128
64
+ num_workers: 4
65
+ shuffle: True
66
+ pin_memory: True
67
+ persistent_workers: True
68
+ drop_last: true
69
+
70
+ val_dataloader:
71
+ batch_size: 128
72
+ num_workers: 8
73
+ shuffle: False
74
+ pin_memory: True
75
+ persistent_workers: True
76
+
77
+ optimizer:
78
+ betas: [0.95, 0.999]
79
+ eps: 1.0e-08
80
+ learning_rate: 0.0001
81
+ weight_decay: 1.0e-06
82
+
83
+ training:
84
+ device: "cuda:0"
85
+ seed: 0
86
+ debug: False
87
+ resume: True
88
+ # optimization
89
+ lr_scheduler: cosine
90
+ lr_warmup_steps: 500
91
+ num_epochs: ${eval:'50000 / ${n_demo}'}
92
+ gradient_accumulate_every: 1
93
+ # EMA destroys performance when used with BatchNorm
94
+ # replace BatchNorm with GroupNorm.
95
+ use_ema: True
96
+ # training loop control
97
+ # in epochs
98
+ rollout_every: ${eval:'1000 / ${n_demo}'}
99
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
100
+ val_every: 1
101
+ sample_every: 5
102
+ # steps per epoch
103
+ max_train_steps: null
104
+ max_val_steps: null
105
+ # misc
106
+ tqdm_interval_sec: 1.0
107
+
108
+ logging:
109
+ project: diffusion_policy_${task_name}
110
+ resume: True
111
+ mode: online
112
+ name: equidiff_demo${n_demo}
113
+ tags: ["${name}", "${task_name}", "${exp_name}"]
114
+ id: null
115
+ group: null
116
+
117
+ checkpoint:
118
+ topk:
119
+ monitor_key: test_mean_score
120
+ mode: max
121
+ k: 5
122
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
123
+ save_last_ckpt: True
124
+ save_last_snapshot: False
125
+
126
+ multi_run:
127
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
128
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
129
+
130
+ hydra:
131
+ job:
132
+ override_dirname: ${name}
133
+ run:
134
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
135
+ sweep:
136
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs_sq2_1-1.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ name: equi_diff
6
+ _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: square_d2
12
+ n_demo: 1000
13
+ horizon: 16
14
+ n_obs_steps: 2
15
+ n_action_steps: 8
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ past_action_visible: False
19
+ dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20
+ dataset_path: data/robomimic/datasets/${task_name}_1-1/${task_name}_1-1_abs.hdf5
21
+
22
+ policy:
23
+ _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
24
+
25
+ shape_meta: ${shape_meta}
26
+
27
+ noise_scheduler:
28
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29
+ num_train_timesteps: 100
30
+ beta_start: 0.0001
31
+ beta_end: 0.02
32
+ beta_schedule: squaredcos_cap_v2
33
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
34
+ clip_sample: True # required when predict_epsilon=False
35
+ prediction_type: epsilon # or sample
36
+
37
+ horizon: ${horizon}
38
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39
+ n_obs_steps: ${n_obs_steps}
40
+ num_inference_steps: 100
41
+ crop_shape: [76, 76]
42
+ # crop_shape: null
43
+ diffusion_step_embed_dim: 128
44
+ enc_n_hidden: 128
45
+ down_dims: [512, 1024, 2048]
46
+ kernel_size: 5
47
+ n_groups: 8
48
+ cond_predict_scale: True
49
+ rot_aug: False
50
+
51
+ # scheduler.step params
52
+ # predict_epsilon: True
53
+
54
+ ema:
55
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
56
+ update_after_step: 0
57
+ inv_gamma: 1.0
58
+ power: 0.75
59
+ min_value: 0.0
60
+ max_value: 0.9999
61
+
62
+ dataloader:
63
+ batch_size: 128
64
+ num_workers: 4
65
+ shuffle: True
66
+ pin_memory: True
67
+ persistent_workers: True
68
+ drop_last: true
69
+
70
+ val_dataloader:
71
+ batch_size: 128
72
+ num_workers: 8
73
+ shuffle: False
74
+ pin_memory: True
75
+ persistent_workers: True
76
+
77
+ optimizer:
78
+ betas: [0.95, 0.999]
79
+ eps: 1.0e-08
80
+ learning_rate: 0.0001
81
+ weight_decay: 1.0e-06
82
+
83
+ training:
84
+ device: "cuda:0"
85
+ seed: 0
86
+ debug: False
87
+ resume: True
88
+ # optimization
89
+ lr_scheduler: cosine
90
+ lr_warmup_steps: 500
91
+ num_epochs: ${eval:'50000 / ${n_demo}'}
92
+ gradient_accumulate_every: 1
93
+ # EMA destroys performance when used with BatchNorm
94
+ # replace BatchNorm with GroupNorm.
95
+ use_ema: True
96
+ # training loop control
97
+ # in epochs
98
+ rollout_every: ${eval:'1000 / ${n_demo}'}
99
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
100
+ val_every: 1
101
+ sample_every: 5
102
+ # steps per epoch
103
+ max_train_steps: null
104
+ max_val_steps: null
105
+ # misc
106
+ tqdm_interval_sec: 1.0
107
+
108
+ logging:
109
+ project: diffusion_policy_${task_name}
110
+ resume: True
111
+ mode: online
112
+ name: equidiff_demo${n_demo}
113
+ tags: ["${name}", "${task_name}", "${exp_name}"]
114
+ id: null
115
+ group: null
116
+
117
+ checkpoint:
118
+ topk:
119
+ monitor_key: test_mean_score
120
+ mode: max
121
+ k: 5
122
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
123
+ save_last_ckpt: True
124
+ save_last_snapshot: False
125
+
126
+ multi_run:
127
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
128
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
129
+
130
+ hydra:
131
+ job:
132
+ override_dirname: ${name}
133
+ run:
134
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
135
+ sweep:
136
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_equi_diffusion_unet_rel.yaml ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_rel
4
+
5
+ name: equi_diff
6
+ _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: stack_d1
12
+ n_demo: 200
13
+ horizon: 16
14
+ n_obs_steps: 2
15
+ n_action_steps: 8
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ past_action_visible: False
19
+ dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}.hdf5
21
+
22
+ policy:
23
+ _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_rel_policy.DiffusionEquiUNetCNNEncRelPolicy
24
+
25
+ shape_meta: ${shape_meta}
26
+
27
+ noise_scheduler:
28
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29
+ num_train_timesteps: 100
30
+ beta_start: 0.0001
31
+ beta_end: 0.02
32
+ beta_schedule: squaredcos_cap_v2
33
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
34
+ clip_sample: True # required when predict_epsilon=False
35
+ prediction_type: epsilon # or sample
36
+
37
+ horizon: ${horizon}
38
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39
+ n_obs_steps: ${n_obs_steps}
40
+ num_inference_steps: 100
41
+ crop_shape: [76, 76]
42
+ # crop_shape: null
43
+ diffusion_step_embed_dim: 128
44
+ enc_n_hidden: 128
45
+ down_dims: [512, 1024, 2048]
46
+ kernel_size: 5
47
+ n_groups: 8
48
+ cond_predict_scale: True
49
+
50
+ # scheduler.step params
51
+ # predict_epsilon: True
52
+
53
+ ema:
54
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
55
+ update_after_step: 0
56
+ inv_gamma: 1.0
57
+ power: 0.75
58
+ min_value: 0.0
59
+ max_value: 0.9999
60
+
61
+ dataloader:
62
+ batch_size: 128
63
+ num_workers: 4
64
+ shuffle: True
65
+ pin_memory: True
66
+ persistent_workers: True
67
+ drop_last: true
68
+
69
+ val_dataloader:
70
+ batch_size: 128
71
+ num_workers: 8
72
+ shuffle: False
73
+ pin_memory: True
74
+ persistent_workers: True
75
+
76
+ optimizer:
77
+ betas: [0.95, 0.999]
78
+ eps: 1.0e-08
79
+ learning_rate: 0.0001
80
+ weight_decay: 1.0e-06
81
+
82
+ training:
83
+ device: "cuda:0"
84
+ seed: 0
85
+ debug: False
86
+ resume: True
87
+ # optimization
88
+ lr_scheduler: cosine
89
+ lr_warmup_steps: 500
90
+ num_epochs: ${eval:'50000 / ${n_demo}'}
91
+ gradient_accumulate_every: 1
92
+ # EMA destroys performance when used with BatchNorm
93
+ # replace BatchNorm with GroupNorm.
94
+ use_ema: True
95
+ # training loop control
96
+ # in epochs
97
+ rollout_every: ${eval:'1000 / ${n_demo}'}
98
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
99
+ val_every: 1
100
+ sample_every: 5
101
+ # steps per epoch
102
+ max_train_steps: null
103
+ max_val_steps: null
104
+ # misc
105
+ tqdm_interval_sec: 1.0
106
+
107
+ logging:
108
+ project: diffusion_policy_${task_name}_vel
109
+ resume: True
110
+ mode: online
111
+ name: equidiff_demo${n_demo}
112
+ tags: ["${name}", "${task_name}", "${exp_name}"]
113
+ id: null
114
+ group: null
115
+
116
+ checkpoint:
117
+ topk:
118
+ monitor_key: test_mean_score
119
+ mode: max
120
+ k: 5
121
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
122
+ save_last_ckpt: True
123
+ save_last_snapshot: False
124
+
125
+ multi_run:
126
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
127
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
128
+
129
+ hydra:
130
+ job:
131
+ override_dirname: ${name}
132
+ run:
133
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
134
+ sweep:
135
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
136
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_equi_diffusion_unet_voxel_abs.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_voxel_abs
4
+
5
+ name: equi_diff_voxel
6
+ _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: stack_d1
12
+ n_demo: 200
13
+ horizon: 16
14
+ n_obs_steps: 1
15
+ n_action_steps: 8
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ past_action_visible: False
19
+ # dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
21
+
22
+ policy:
23
+ _target_: equi_diffpo.policy.diffusion_equi_unet_voxel_policy.DiffusionEquiUNetPolicyVoxel
24
+
25
+ shape_meta: ${shape_meta}
26
+
27
+ noise_scheduler:
28
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29
+ num_train_timesteps: 100
30
+ beta_start: 0.0001
31
+ beta_end: 0.02
32
+ beta_schedule: squaredcos_cap_v2
33
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
34
+ clip_sample: True # required when predict_epsilon=False
35
+ prediction_type: epsilon # or sample
36
+
37
+ horizon: ${horizon}
38
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39
+ n_obs_steps: ${n_obs_steps}
40
+ num_inference_steps: 100
41
+ crop_shape: [58, 58, 58]
42
+ # crop_shape: null
43
+ diffusion_step_embed_dim: 128
44
+ enc_n_hidden: 128
45
+ down_dims: [256, 512, 1024]
46
+ kernel_size: 5
47
+ n_groups: 8
48
+ cond_predict_scale: True
49
+ rot_aug: True
50
+
51
+ # scheduler.step params
52
+ # predict_epsilon: True
53
+
54
+ ema:
55
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
56
+ update_after_step: 0
57
+ inv_gamma: 1.0
58
+ power: 0.75
59
+ min_value: 0.0
60
+ max_value: 0.9999
61
+
62
+ dataloader:
63
+ batch_size: 64
64
+ num_workers: 16
65
+ shuffle: True
66
+ pin_memory: True
67
+ persistent_workers: True
68
+ drop_last: true
69
+
70
+ val_dataloader:
71
+ batch_size: 64
72
+ num_workers: 16
73
+ shuffle: False
74
+ pin_memory: True
75
+ persistent_workers: True
76
+
77
+ optimizer:
78
+ betas: [0.95, 0.999]
79
+ eps: 1.0e-08
80
+ learning_rate: 0.0001
81
+ weight_decay: 1.0e-06
82
+
83
+ training:
84
+ device: "cuda:0"
85
+ seed: 0
86
+ debug: False
87
+ resume: True
88
+ # optimization
89
+ lr_scheduler: cosine
90
+ lr_warmup_steps: 500
91
+ num_epochs: ${eval:'50000 / ${n_demo}'}
92
+ gradient_accumulate_every: 1
93
+ # EMA destroys performance when used with BatchNorm
94
+ # replace BatchNorm with GroupNorm.
95
+ use_ema: True
96
+ # training loop control
97
+ # in epochs
98
+ rollout_every: ${eval:'1000 / ${n_demo}'}
99
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
100
+ val_every: 1
101
+ sample_every: 5
102
+ # steps per epoch
103
+ max_train_steps: null
104
+ max_val_steps: null
105
+ # misc
106
+ tqdm_interval_sec: 1.0
107
+
108
+ logging:
109
+ project: equi_diff_${task_name}_voxel
110
+ resume: True
111
+ mode: online
112
+ name: equi_diff_voxel_${n_demo}
113
+ tags: ["${name}", "${task_name}", "${exp_name}"]
114
+ id: null
115
+ group: null
116
+
117
+ checkpoint:
118
+ topk:
119
+ monitor_key: test_mean_score
120
+ mode: max
121
+ k: 5
122
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
123
+ save_last_ckpt: True
124
+ save_last_snapshot: False
125
+
126
+ multi_run:
127
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
128
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
129
+
130
+ hydra:
131
+ job:
132
+ override_dirname: ${name}
133
+ run:
134
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
135
+ sweep:
136
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_equi_diffusion_unet_voxel_rel.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_voxel_rel
4
+
5
+ name: equi_diff_voxel
6
+ _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: stack_d1
12
+ n_demo: 200
13
+ horizon: 16
14
+ n_obs_steps: 1
15
+ n_action_steps: 8
16
+ n_latency_steps: 0
17
+ dataset_obs_steps: ${n_obs_steps}
18
+ past_action_visible: False
19
+ # dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20
+ dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel.hdf5
21
+
22
+ policy:
23
+ _target_: equi_diffpo.policy.diffusion_equi_unet_voxel_rel_policy.DiffusionEquiUNetRelPolicyVoxel
24
+
25
+ shape_meta: ${shape_meta}
26
+
27
+ noise_scheduler:
28
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29
+ num_train_timesteps: 100
30
+ beta_start: 0.0001
31
+ beta_end: 0.02
32
+ beta_schedule: squaredcos_cap_v2
33
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
34
+ clip_sample: True # required when predict_epsilon=False
35
+ prediction_type: epsilon # or sample
36
+
37
+ horizon: ${horizon}
38
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39
+ n_obs_steps: ${n_obs_steps}
40
+ num_inference_steps: 100
41
+ crop_shape: [58, 58, 58]
42
+ # crop_shape: null
43
+ diffusion_step_embed_dim: 128
44
+ enc_n_hidden: 128
45
+ down_dims: [256, 512, 1024]
46
+ kernel_size: 5
47
+ n_groups: 8
48
+ cond_predict_scale: True
49
+ rot_aug: True
50
+
51
+ # scheduler.step params
52
+ # predict_epsilon: True
53
+
54
+ ema:
55
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
56
+ update_after_step: 0
57
+ inv_gamma: 1.0
58
+ power: 0.75
59
+ min_value: 0.0
60
+ max_value: 0.9999
61
+
62
+ dataloader:
63
+ batch_size: 64
64
+ num_workers: 16
65
+ shuffle: True
66
+ pin_memory: True
67
+ persistent_workers: True
68
+ drop_last: true
69
+
70
+ val_dataloader:
71
+ batch_size: 64
72
+ num_workers: 16
73
+ shuffle: False
74
+ pin_memory: True
75
+ persistent_workers: True
76
+
77
+ optimizer:
78
+ betas: [0.95, 0.999]
79
+ eps: 1.0e-08
80
+ learning_rate: 0.0001
81
+ weight_decay: 1.0e-06
82
+
83
+ training:
84
+ device: "cuda:0"
85
+ seed: 0
86
+ debug: False
87
+ resume: True
88
+ # optimization
89
+ lr_scheduler: cosine
90
+ lr_warmup_steps: 500
91
+ num_epochs: ${eval:'50000 / ${n_demo}'}
92
+ gradient_accumulate_every: 1
93
+ # EMA destroys performance when used with BatchNorm
94
+ # replace BatchNorm with GroupNorm.
95
+ use_ema: True
96
+ # training loop control
97
+ # in epochs
98
+ rollout_every: ${eval:'1000 / ${n_demo}'}
99
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
100
+ val_every: 1
101
+ sample_every: 5
102
+ # steps per epoch
103
+ max_train_steps: null
104
+ max_val_steps: null
105
+ # misc
106
+ tqdm_interval_sec: 1.0
107
+
108
+ logging:
109
+ project: equi_diff_${task_name}_voxel_rel
110
+ resume: True
111
+ mode: online
112
+ name: equi_diff_voxel_${n_demo}
113
+ tags: ["${name}", "${task_name}", "${exp_name}"]
114
+ id: null
115
+ group: null
116
+
117
+ checkpoint:
118
+ topk:
119
+ monitor_key: test_mean_score
120
+ mode: max
121
+ k: 5
122
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
123
+ save_last_ckpt: True
124
+ save_last_snapshot: False
125
+
126
+ multi_run:
127
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
128
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
129
+
130
+ hydra:
131
+ job:
132
+ override_dirname: ${name}
133
+ run:
134
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
135
+ sweep:
136
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_sq2.yaml ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ name: equi_diff
6
+ _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: square_d2
12
+ folder_name: square_d2
13
+ file_name: square_d2
14
+ n_demo: 1000
15
+ horizon: 16
16
+ n_obs_steps: 2
17
+ n_action_steps: 8
18
+ n_latency_steps: 0
19
+ dataset_obs_steps: ${n_obs_steps}
20
+ past_action_visible: False
21
+ dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
22
+ dataset_path: data/robomimic/datasets/${folder_name}/${file_name}_abs.hdf5
23
+
24
+ policy:
25
+ _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
26
+
27
+ shape_meta: ${shape_meta}
28
+
29
+ noise_scheduler:
30
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
31
+ num_train_timesteps: 100
32
+ beta_start: 0.0001
33
+ beta_end: 0.02
34
+ beta_schedule: squaredcos_cap_v2
35
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
36
+ clip_sample: True # required when predict_epsilon=False
37
+ prediction_type: epsilon # or sample
38
+
39
+ horizon: ${horizon}
40
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
41
+ n_obs_steps: ${n_obs_steps}
42
+ num_inference_steps: 100
43
+ crop_shape: [76, 76]
44
+ # crop_shape: null
45
+ diffusion_step_embed_dim: 128
46
+ enc_n_hidden: 128
47
+ down_dims: [512, 1024, 2048]
48
+ kernel_size: 5
49
+ n_groups: 8
50
+ cond_predict_scale: True
51
+ rot_aug: False
52
+
53
+ # scheduler.step params
54
+ # predict_epsilon: True
55
+
56
+ ema:
57
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
58
+ update_after_step: 0
59
+ inv_gamma: 1.0
60
+ power: 0.75
61
+ min_value: 0.0
62
+ max_value: 0.9999
63
+
64
+ dataloader:
65
+ batch_size: 128
66
+ num_workers: 4
67
+ shuffle: True
68
+ pin_memory: True
69
+ persistent_workers: True
70
+ drop_last: true
71
+
72
+ val_dataloader:
73
+ batch_size: 128
74
+ num_workers: 8
75
+ shuffle: False
76
+ pin_memory: True
77
+ persistent_workers: True
78
+
79
+ optimizer:
80
+ betas: [0.95, 0.999]
81
+ eps: 1.0e-08
82
+ learning_rate: 0.0001
83
+ weight_decay: 1.0e-06
84
+
85
+ training:
86
+ device: "cuda:0"
87
+ seed: 0
88
+ debug: False
89
+ resume: True
90
+ # optimization
91
+ lr_scheduler: cosine
92
+ lr_warmup_steps: 500
93
+ num_epochs: ${eval:'50000 / ${n_demo}'}
94
+ gradient_accumulate_every: 1
95
+ # EMA destroys performance when used with BatchNorm
96
+ # replace BatchNorm with GroupNorm.
97
+ use_ema: True
98
+ # training loop control
99
+ # in epochs
100
+ rollout_every: ${eval:'1000 / ${n_demo}'}
101
+ checkpoint_every: ${eval:'1000 / ${n_demo}'}
102
+ val_every: 1
103
+ sample_every: 5
104
+ # steps per epoch
105
+ max_train_steps: null
106
+ max_val_steps: null
107
+ # misc
108
+ tqdm_interval_sec: 1.0
109
+
110
+ logging:
111
+ project: diffusion_policy_${task_name}
112
+ resume: True
113
+ mode: online
114
+ name: equidiff_demo${n_demo}
115
+ tags: ["${name}", "${task_name}", "${exp_name}"]
116
+ id: null
117
+ group: null
118
+
119
+ checkpoint:
120
+ topk:
121
+ monitor_key: test_mean_score
122
+ mode: max
123
+ k: 5
124
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
125
+ save_last_ckpt: True
126
+ save_last_snapshot: False
127
+
128
+ multi_run:
129
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
130
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
131
+
132
+ hydra:
133
+ job:
134
+ override_dirname: ${name}
135
+ run:
136
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137
+ sweep:
138
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
139
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_sq2_5000.yaml ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ name: equi_diff
6
+ _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: square_d2
12
+ folder_name: square_d2
13
+ file_name: square_d2
14
+ n_demo: 5000
15
+ horizon: 16
16
+ n_obs_steps: 2
17
+ n_action_steps: 8
18
+ n_latency_steps: 0
19
+ dataset_obs_steps: ${n_obs_steps}
20
+ past_action_visible: False
21
+ dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
22
+ dataset_path: data/robomimic/datasets/${folder_name}/${file_name}_abs.hdf5
23
+
24
+ policy:
25
+ _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
26
+
27
+ shape_meta: ${shape_meta}
28
+
29
+ noise_scheduler:
30
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
31
+ num_train_timesteps: 100
32
+ beta_start: 0.0001
33
+ beta_end: 0.02
34
+ beta_schedule: squaredcos_cap_v2
35
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
36
+ clip_sample: True # required when predict_epsilon=False
37
+ prediction_type: epsilon # or sample
38
+
39
+ horizon: ${horizon}
40
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
41
+ n_obs_steps: ${n_obs_steps}
42
+ num_inference_steps: 100
43
+ crop_shape: [76, 76]
44
+ # crop_shape: null
45
+ diffusion_step_embed_dim: 128
46
+ enc_n_hidden: 128
47
+ down_dims: [512, 1024, 2048]
48
+ kernel_size: 5
49
+ n_groups: 8
50
+ cond_predict_scale: True
51
+ rot_aug: False
52
+
53
+ # scheduler.step params
54
+ # predict_epsilon: True
55
+
56
+ ema:
57
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
58
+ update_after_step: 0
59
+ inv_gamma: 1.0
60
+ power: 0.75
61
+ min_value: 0.0
62
+ max_value: 0.9999
63
+
64
+ dataloader:
65
+ batch_size: 128
66
+ num_workers: 4
67
+ shuffle: True
68
+ pin_memory: True
69
+ persistent_workers: True
70
+ drop_last: true
71
+
72
+ val_dataloader:
73
+ batch_size: 128
74
+ num_workers: 8
75
+ shuffle: False
76
+ pin_memory: True
77
+ persistent_workers: True
78
+
79
+ optimizer:
80
+ betas: [0.95, 0.999]
81
+ eps: 1.0e-08
82
+ learning_rate: 0.0001
83
+ weight_decay: 1.0e-06
84
+
85
+ training:
86
+ device: "cuda:0"
87
+ seed: 0
88
+ debug: False
89
+ resume: True
90
+ # optimization
91
+ lr_scheduler: cosine
92
+ lr_warmup_steps: 2500
93
+ num_epochs: ${eval:'100000 / ${n_demo}'}
94
+ gradient_accumulate_every: 1
95
+ # EMA destroys performance when used with BatchNorm
96
+ # replace BatchNorm with GroupNorm.
97
+ use_ema: True
98
+ # training loop control
99
+ # in epochs
100
+ rollout_every: ${eval:'5000 / ${n_demo}'}
101
+ checkpoint_every: ${eval:'5000 / ${n_demo}'}
102
+ val_every: 1
103
+ sample_every: 5
104
+ # steps per epoch
105
+ max_train_steps: null
106
+ max_val_steps: null
107
+ # misc
108
+ tqdm_interval_sec: 1.0
109
+
110
+ logging:
111
+ project: diffusion_policy_${task_name}
112
+ resume: True
113
+ mode: online
114
+ name: equidiff_demo${n_demo}
115
+ tags: ["${name}", "${task_name}", "${exp_name}"]
116
+ id: null
117
+ group: null
118
+
119
+ checkpoint:
120
+ topk:
121
+ monitor_key: test_mean_score
122
+ mode: max
123
+ k: 5
124
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
125
+ save_last_ckpt: True
126
+ save_last_snapshot: False
127
+
128
+ multi_run:
129
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
130
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
131
+
132
+ hydra:
133
+ job:
134
+ override_dirname: ${name}
135
+ run:
136
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137
+ sweep:
138
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
139
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/config/train_th2_5000.yaml ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - task: mimicgen_abs
4
+
5
+ name: equi_diff
6
+ _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7
+
8
+ shape_meta: ${task.shape_meta}
9
+ exp_name: "default"
10
+
11
+ task_name: threading_d2
12
+ folder_name: threading_d2
13
+ file_name: threading_d2
14
+ n_demo: 5000
15
+ horizon: 16
16
+ n_obs_steps: 2
17
+ n_action_steps: 8
18
+ n_latency_steps: 0
19
+ dataset_obs_steps: ${n_obs_steps}
20
+ past_action_visible: False
21
+ dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
22
+ dataset_path: data/robomimic/datasets/${folder_name}/${file_name}_abs.hdf5
23
+
24
+ policy:
25
+ _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
26
+
27
+ shape_meta: ${shape_meta}
28
+
29
+ noise_scheduler:
30
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
31
+ num_train_timesteps: 100
32
+ beta_start: 0.0001
33
+ beta_end: 0.02
34
+ beta_schedule: squaredcos_cap_v2
35
+ variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
36
+ clip_sample: True # required when predict_epsilon=False
37
+ prediction_type: epsilon # or sample
38
+
39
+ horizon: ${horizon}
40
+ n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
41
+ n_obs_steps: ${n_obs_steps}
42
+ num_inference_steps: 100
43
+ crop_shape: [76, 76]
44
+ # crop_shape: null
45
+ diffusion_step_embed_dim: 128
46
+ enc_n_hidden: 128
47
+ down_dims: [512, 1024, 2048]
48
+ kernel_size: 5
49
+ n_groups: 8
50
+ cond_predict_scale: True
51
+ rot_aug: False
52
+
53
+ # scheduler.step params
54
+ # predict_epsilon: True
55
+
56
+ ema:
57
+ _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
58
+ update_after_step: 0
59
+ inv_gamma: 1.0
60
+ power: 0.75
61
+ min_value: 0.0
62
+ max_value: 0.9999
63
+
64
+ dataloader:
65
+ batch_size: 128
66
+ num_workers: 4
67
+ shuffle: True
68
+ pin_memory: True
69
+ persistent_workers: True
70
+ drop_last: true
71
+
72
+ val_dataloader:
73
+ batch_size: 128
74
+ num_workers: 8
75
+ shuffle: False
76
+ pin_memory: True
77
+ persistent_workers: True
78
+
79
+ optimizer:
80
+ betas: [0.95, 0.999]
81
+ eps: 1.0e-08
82
+ learning_rate: 0.0001
83
+ weight_decay: 1.0e-06
84
+
85
+ training:
86
+ device: "cuda:0"
87
+ seed: 0
88
+ debug: False
89
+ resume: True
90
+ # optimization
91
+ lr_scheduler: cosine
92
+ lr_warmup_steps: 2500
93
+ num_epochs: ${eval:'100000 / ${n_demo}'}
94
+ gradient_accumulate_every: 1
95
+ # EMA destroys performance when used with BatchNorm
96
+ # replace BatchNorm with GroupNorm.
97
+ use_ema: True
98
+ # training loop control
99
+ # in epochs
100
+ rollout_every: ${eval:'5000 / ${n_demo}'}
101
+ checkpoint_every: ${eval:'5000 / ${n_demo}'}
102
+ val_every: 1
103
+ sample_every: 5
104
+ # steps per epoch
105
+ max_train_steps: null
106
+ max_val_steps: null
107
+ # misc
108
+ tqdm_interval_sec: 1.0
109
+
110
+ logging:
111
+ project: diffusion_policy_${task_name}
112
+ resume: True
113
+ mode: online
114
+ name: equidiff_demo${n_demo}
115
+ tags: ["${name}", "${task_name}", "${exp_name}"]
116
+ id: null
117
+ group: null
118
+
119
+ checkpoint:
120
+ topk:
121
+ monitor_key: test_mean_score
122
+ mode: max
123
+ k: 5
124
+ format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
125
+ save_last_ckpt: True
126
+ save_last_snapshot: False
127
+
128
+ multi_run:
129
+ run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
130
+ wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
131
+
132
+ hydra:
133
+ job:
134
+ override_dirname: ${name}
135
+ run:
136
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137
+ sweep:
138
+ dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
139
+ subdir: ${hydra.job.num}
equidiff/equi_diffpo/dataset/base_dataset.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ import torch
4
+ import torch.nn
5
+ from equi_diffpo.model.common.normalizer import LinearNormalizer
6
+
7
class BaseLowdimDataset(torch.utils.data.Dataset):
    """Abstract base for low-dimensional (state-based) trajectory datasets.

    Subclasses yield dicts of observation/action windows (see
    ``__getitem__``) and supply a fitted :class:`LinearNormalizer`.
    The base class itself behaves as an empty dataset.
    """

    def get_validation_dataset(self) -> 'BaseLowdimDataset':
        # return an empty dataset by default
        return BaseLowdimDataset()

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        # Subclasses must return a normalizer fitted to this dataset's data.
        raise NotImplementedError()

    def get_all_actions(self) -> torch.Tensor:
        # Subclasses must return every action in the dataset as one tensor.
        raise NotImplementedError()

    def __len__(self) -> int:
        # Empty by default; subclasses report the number of sample windows.
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        output:
            obs: T, Do
            action: T, Da
        """
        raise NotImplementedError()
28
+
29
+
30
class BaseImageDataset(torch.utils.data.Dataset):
    """Abstract base for image-observation trajectory datasets.

    Mirrors :class:`BaseLowdimDataset` but ``__getitem__`` returns a nested
    ``obs`` dict of per-key tensors instead of a flat observation vector.
    The base class itself behaves as an empty dataset.
    """

    def get_validation_dataset(self) -> 'BaseImageDataset':
        # Fixed: the return annotation previously said 'BaseLowdimDataset',
        # which contradicted the actual returned type.
        # return an empty dataset by default
        return BaseImageDataset()

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        # Subclasses must return a normalizer fitted to this dataset's data.
        raise NotImplementedError()

    def get_all_actions(self) -> torch.Tensor:
        # Subclasses must return every action in the dataset as one tensor.
        raise NotImplementedError()

    def __len__(self) -> int:
        # Empty by default; subclasses report the number of sample windows.
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        output:
            obs:
                key: T, *
            action: T, Da
        """
        raise NotImplementedError()
equidiff/equi_diffpo/env_runner/base_image_runner.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ from equi_diffpo.policy.base_image_policy import BaseImagePolicy
3
+
4
class BaseImageRunner:
    """Abstract environment-rollout runner for image-based policies.

    Subclasses implement :meth:`run`, which evaluates a policy in one or
    more environments and returns a dict of metrics/logs.
    """

    def __init__(self, output_dir):
        # Directory where rollout artifacts (e.g. videos, logs) are written.
        self.output_dir = output_dir

    def run(self, policy: BaseImagePolicy) -> Dict:
        raise NotImplementedError()
equidiff/equi_diffpo/env_runner/base_lowdim_runner.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ from equi_diffpo.policy.base_lowdim_policy import BaseLowdimPolicy
3
+
4
class BaseLowdimRunner:
    """Abstract environment-rollout runner for low-dimensional policies.

    Subclasses implement :meth:`run`, which evaluates a policy in one or
    more environments and returns a dict of metrics/logs.
    """

    def __init__(self, output_dir):
        # Directory where rollout artifacts (e.g. videos, logs) are written.
        self.output_dir = output_dir

    def run(self, policy: BaseLowdimPolicy) -> Dict:
        raise NotImplementedError()
equidiff/equi_diffpo/gym_util/async_vector_env.py ADDED
@@ -0,0 +1,673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Back ported methods: call, set_attr from v0.26
3
+ Disabled auto-reset after done
4
+ Added render method.
5
+ """
6
+
7
import os
import numpy as np
import multiprocessing as mp
# Force the 'spawn' start method unless software rendering (osmesa) is
# selected. NOTE(review): presumably this avoids forked workers inheriting
# the parent's OpenGL context (see the dummy_env_fn note in AsyncVectorEnv
# below) — confirm against the rendering backend in use.
if os.getenv("MUJOCO_GL") != "osmesa":
    mp.set_start_method('spawn', force=True)
import time
import sys
from enum import Enum
from copy import deepcopy

from gym import logger
from gym.vector.vector_env import VectorEnv
from gym.error import (
    AlreadyPendingCallError,
    NoAsyncCallError,
    ClosedEnvironmentError,
    CustomSpaceError,
)
from gym.vector.utils import (
    create_shared_memory,
    create_empty_array,
    write_to_shared_memory,
    read_from_shared_memory,
    concatenate,
    CloudpickleWrapper,
    clear_mpi_env_vars,
)

__all__ = ["AsyncVectorEnv"]
36
+
37
+
38
class AsyncState(Enum):
    """Which asynchronous request (if any) is currently outstanding.

    Used by AsyncVectorEnv to reject overlapping async calls and to route
    `close` to the matching `*_wait` method.
    """

    DEFAULT = "default"        # no request pending
    WAITING_RESET = "reset"    # reset_async sent, awaiting reset_wait
    WAITING_STEP = "step"      # step_async sent, awaiting step_wait
    WAITING_CALL = "call"      # call_async sent, awaiting call_wait
43
+
44
+
45
class AsyncVectorEnv(VectorEnv):
    """Vectorized environment that runs multiple environments in parallel. It
    uses `multiprocessing` processes, and pipes for communication.
    Parameters
    ----------
    env_fns : iterable of callable
        Functions that create the environments.
    observation_space : `gym.spaces.Space` instance, optional
        Observation space of a single environment. If `None`, then the
        observation space of the first environment is taken.
    action_space : `gym.spaces.Space` instance, optional
        Action space of a single environment. If `None`, then the action space
        of the first environment is taken.
    shared_memory : bool (default: `True`)
        If `True`, then the observations from the worker processes are
        communicated back through shared variables. This can improve the
        efficiency if the observations are large (e.g. images).
    copy : bool (default: `True`)
        If `True`, then the `reset` and `step` methods return a copy of the
        observations.
    context : str, optional
        Context for multiprocessing. If `None`, then the default context is used.
        Only available in Python 3.
    daemon : bool (default: `True`)
        If `True`, then subprocesses have `daemon` flag turned on; that is, they
        will quit if the head process quits. However, `daemon=True` prevents
        subprocesses to spawn children, so for some environments you may want
        to have it set to `False`
    worker : function, optional
        WARNING - advanced mode option! If set, then use that worker in a subprocess
        instead of a default one. Can be useful to override some inner vector env
        logic, for instance, how resets on done are handled. Provides high
        degree of flexibility and a high chance to shoot yourself in the foot; thus,
        if you are writing your own worker, it is recommended to start from the code
        for `_worker` (or `_worker_shared_memory`) method below, and add changes
    """

    def __init__(
        self,
        env_fns,
        dummy_env_fn=None,
        observation_space=None,
        action_space=None,
        shared_memory=True,
        copy=True,
        context=None,
        daemon=True,
        worker=None,
    ):
        ctx = mp.get_context(context)
        self.env_fns = env_fns
        self.shared_memory = shared_memory
        self.copy = copy

        # Added dummy_env_fn to fix OpenGL error in Mujoco
        # disable any OpenGL rendering in dummy_env_fn, since it
        # will conflict with OpenGL context in the forked child process
        if dummy_env_fn is None:
            dummy_env_fn = env_fns[0]
        # A throwaway env instantiated in the parent only to read
        # metadata and (if needed) the observation/action spaces.
        dummy_env = dummy_env_fn()
        self.metadata = dummy_env.metadata

        if (observation_space is None) or (action_space is None):
            observation_space = observation_space or dummy_env.observation_space
            action_space = action_space or dummy_env.action_space
        dummy_env.close()
        del dummy_env
        super(AsyncVectorEnv, self).__init__(
            num_envs=len(env_fns),
            observation_space=observation_space,
            action_space=action_space,
        )

        if self.shared_memory:
            try:
                # Workers write observations into this shared buffer;
                # the parent reads them back without pickling.
                _obs_buffer = create_shared_memory(
                    self.single_observation_space, n=self.num_envs, ctx=ctx
                )
                self.observations = read_from_shared_memory(
                    _obs_buffer, self.single_observation_space, n=self.num_envs
                )
            except CustomSpaceError:
                raise ValueError(
                    "Using `shared_memory=True` in `AsyncVectorEnv` "
                    "is incompatible with non-standard Gym observation spaces "
                    "(i.e. custom spaces inheriting from `gym.Space`), and is "
                    "only compatible with default Gym spaces (e.g. `Box`, "
                    "`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
                    "if you use custom observation spaces."
                )
        else:
            _obs_buffer = None
            self.observations = create_empty_array(
                self.single_observation_space, n=self.num_envs, fn=np.zeros
            )

        self.parent_pipes, self.processes = [], []
        self.error_queue = ctx.Queue()
        target = _worker_shared_memory if self.shared_memory else _worker
        # A user-supplied worker overrides the default choice above.
        target = worker or target
        with clear_mpi_env_vars():
            for idx, env_fn in enumerate(self.env_fns):
                parent_pipe, child_pipe = ctx.Pipe()
                process = ctx.Process(
                    target=target,
                    name="Worker<{0}>-{1}".format(type(self).__name__, idx),
                    args=(
                        idx,
                        CloudpickleWrapper(env_fn),
                        child_pipe,
                        parent_pipe,
                        _obs_buffer,
                        self.error_queue,
                    ),
                )

                self.parent_pipes.append(parent_pipe)
                self.processes.append(process)

                process.daemon = daemon
                process.start()
                # The child keeps its own copy of child_pipe; closing the
                # parent's copy lets EOF propagate if the child dies.
                child_pipe.close()

        self._state = AsyncState.DEFAULT
        self._check_observation_spaces()

    def seed(self, seeds=None):
        """Seed each sub-environment; an int is fanned out as seed+i."""
        self._assert_is_running()
        if seeds is None:
            seeds = [None for _ in range(self.num_envs)]
        if isinstance(seeds, int):
            seeds = [seeds + i for i in range(self.num_envs)]
        assert len(seeds) == self.num_envs

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `seed` while waiting "
                "for a pending call to `{0}` to complete.".format(self._state.value),
                self._state.value,
            )

        for pipe, seed in zip(self.parent_pipes, seeds):
            pipe.send(("seed", seed))
        _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

    def reset_async(self):
        """Send a reset request to every worker without waiting."""
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `reset_async` while waiting "
                "for a pending call to `{0}` to complete".format(self._state.value),
                self._state.value,
            )

        for pipe in self.parent_pipes:
            pipe.send(("reset", None))
        self._state = AsyncState.WAITING_RESET

    def reset_wait(self, timeout=None):
        """
        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `reset_wait` times out. If
            `None`, the call to `reset_wait` never times out.
        Returns
        -------
        observations : sample from `observation_space`
            A batch of observations from the vectorized environment.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            raise NoAsyncCallError(
                "Calling `reset_wait` without any prior " "call to `reset_async`.",
                AsyncState.WAITING_RESET.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                "The call to `reset_wait` has timed out after "
                "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        # With shared memory the workers already wrote into
        # self.observations; otherwise batch the pickled results here.
        if not self.shared_memory:
            self.observations = concatenate(
                results, self.observations, self.single_observation_space
            )

        return deepcopy(self.observations) if self.copy else self.observations

    def step_async(self, actions):
        """
        Parameters
        ----------
        actions : iterable of samples from `action_space`
            List of actions.
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `step_async` while waiting "
                "for a pending call to `{0}` to complete.".format(self._state.value),
                self._state.value,
            )

        for pipe, action in zip(self.parent_pipes, actions):
            pipe.send(("step", action))
        self._state = AsyncState.WAITING_STEP

    def step_wait(self, timeout=None):
        """
        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `step_wait` times out. If
            `None`, the call to `step_wait` never times out.
        Returns
        -------
        observations : sample from `observation_space`
            A batch of observations from the vectorized environment.
        rewards : `np.ndarray` instance (dtype `np.float_`)
            A vector of rewards from the vectorized environment.
        dones : `np.ndarray` instance (dtype `np.bool_`)
            A vector whose entries indicate whether the episode has ended.
        infos : list of dict
            A list of auxiliary diagnostic information.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            raise NoAsyncCallError(
                "Calling `step_wait` without any prior call " "to `step_async`.",
                AsyncState.WAITING_STEP.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                "The call to `step_wait` has timed out after "
                "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT
        observations_list, rewards, dones, infos = zip(*results)

        if not self.shared_memory:
            self.observations = concatenate(
                observations_list, self.observations, self.single_observation_space
            )

        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.array(rewards),
            np.array(dones, dtype=np.bool_),
            infos,
        )

    def close_extras(self, timeout=None, terminate=False):
        """
        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `close` times out. If `None`,
            the call to `close` never times out. If the call to `close` times
            out, then all processes are terminated.
        terminate : bool (default: `False`)
            If `True`, then the `close` operation is forced and all processes
            are terminated.
        """
        timeout = 0 if terminate else timeout
        try:
            if self._state != AsyncState.DEFAULT:
                logger.warn(
                    "Calling `close` while waiting for a pending "
                    "call to `{0}` to complete.".format(self._state.value)
                )
                # Drain the pending request by invoking the matching
                # reset_wait/step_wait/call_wait before shutting down.
                function = getattr(self, "{0}_wait".format(self._state.value))
                function(timeout)
        except mp.TimeoutError:
            terminate = True

        if terminate:
            for process in self.processes:
                if process.is_alive():
                    process.terminate()
        else:
            # Graceful path: ask each live worker to exit, then collect acks.
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.send(("close", None))
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.recv()

        for pipe in self.parent_pipes:
            if pipe is not None:
                pipe.close()
        for process in self.processes:
            process.join()

    def _poll(self, timeout=None):
        # Return True iff every worker pipe has data ready within the
        # shared deadline; a broken (None/closed) pipe counts as failure.
        self._assert_is_running()
        if timeout is None:
            return True
        end_time = time.perf_counter() + timeout
        delta = None
        for pipe in self.parent_pipes:
            delta = max(end_time - time.perf_counter(), 0)
            if pipe is None:
                return False
            if pipe.closed or (not pipe.poll(delta)):
                return False
        return True

    def _check_observation_spaces(self):
        # Ask every worker to compare its env's observation space against
        # the parent's single_observation_space.
        self._assert_is_running()
        for pipe in self.parent_pipes:
            pipe.send(("_check_observation_space", self.single_observation_space))
        same_spaces, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        if not all(same_spaces):
            raise RuntimeError(
                "Some environments have an observation space "
                "different from `{0}`. In order to batch observations, the "
                "observation spaces from all environments must be "
                "equal.".format(self.single_observation_space)
            )

    def _assert_is_running(self):
        if self.closed:
            raise ClosedEnvironmentError(
                "Trying to operate on `{0}`, after a "
                "call to `close()`.".format(type(self).__name__)
            )

    def _raise_if_errors(self, successes):
        # Each False in successes corresponds to one entry in error_queue;
        # log them all, close the failed pipes, and re-raise the last one.
        if all(successes):
            return

        num_errors = self.num_envs - sum(successes)
        assert num_errors > 0
        for _ in range(num_errors):
            index, exctype, value = self.error_queue.get()
            logger.error(
                "Received the following error from Worker-{0}: "
                "{1}: {2}".format(index, exctype.__name__, value)
            )
            logger.error("Shutting down Worker-{0}.".format(index))
            self.parent_pipes[index].close()
            self.parent_pipes[index] = None

        logger.error("Raising the last exception back to the main process.")
        raise exctype(value)

    def call_async(self, name: str, *args, **kwargs):
        """Calls the method with name asynchronously and apply args and kwargs to the method.

        Args:
            name: Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `call_async` while waiting "
                f"for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for pipe in self.parent_pipes:
            pipe.send(("_call", (name, args, kwargs)))
        self._state = AsyncState.WAITING_CALL

    def call_wait(self, timeout = None) -> list:
        """Calls all parent pipes and waits for the results.

        Args:
            timeout: Number of seconds before the call to `step_wait` times out.
                If `None` (default), the call to `step_wait` never times out.

        Returns:
            List of the results of the individual calls to the method or property for each environment.

        Raises:
            NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
            TimeoutError: The call to `call_wait` has timed out after timeout second(s).
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_CALL:
            raise NoAsyncCallError(
                "Calling `call_wait` without any prior call to `call_async`.",
                AsyncState.WAITING_CALL.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `call_wait` has timed out after {timeout} second(s)."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        return results

    def call(self, name: str, *args, **kwargs):
        """Call a method, or get a property, from each parallel environment.

        Args:
            name (str): Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Returns:
            List of the results of the individual calls to the method or property for each environment.
        """
        self.call_async(name, *args, **kwargs)
        return self.call_wait()


    def call_each(self, name: str,
            args_list: list=None,
            kwargs_list: list=None,
            timeout = None):
        """Like :meth:`call`, but with per-environment args/kwargs.

        Args:
            name: Name of the method or property to call.
            args_list: Positional args per environment (defaults to none).
            kwargs_list: Keyword args per environment (defaults to none).
            timeout: Seconds to wait for all results; `None` waits forever.

        Returns:
            List with each environment's result.
        """
        n_envs = len(self.parent_pipes)
        if args_list is None:
            args_list = [[]] * n_envs
        assert len(args_list) == n_envs

        if kwargs_list is None:
            kwargs_list = [dict()] * n_envs
        assert len(kwargs_list) == n_envs

        # send
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `call_async` while waiting "
                f"for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for i, pipe in enumerate(self.parent_pipes):
            pipe.send(("_call", (name, args_list[i], kwargs_list[i])))
        self._state = AsyncState.WAITING_CALL

        # receive
        self._assert_is_running()
        if self._state != AsyncState.WAITING_CALL:
            raise NoAsyncCallError(
                "Calling `call_wait` without any prior call to `call_async`.",
                AsyncState.WAITING_CALL.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `call_wait` has timed out after {timeout} second(s)."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        return results


    def set_attr(self, name: str, values):
        """Sets an attribute of the sub-environments.

        Args:
            name: Name of the property to be set in each individual environment.
            values: Values of the property to be set to. If ``values`` is a list or
                tuple, then it corresponds to the values for each individual
                environment, otherwise a single value is set for all environments.

        Raises:
            ValueError: Values must be a list or tuple with length equal to the number of environments.
            AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
        """
        self._assert_is_running()
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]
        if len(values) != self.num_envs:
            raise ValueError(
                "Values must be a list or tuple with length equal to the "
                f"number of environments. Got `{len(values)}` values for "
                f"{self.num_envs} environments."
            )

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `set_attr` while waiting "
                f"for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for pipe, value in zip(self.parent_pipes, values):
            pipe.send(("_setattr", (name, value)))
        _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

    def render(self, *args, **kwargs):
        # Render every sub-environment; returns the list of per-env results.
        return self.call('render', *args, **kwargs)
561
+
562
+
563
+
564
def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
    """Subprocess loop (no shared memory): observations travel back
    through the pipe itself.

    Replies are ``(payload, success)`` tuples; on any exception the
    (index, exctype, value) triple is pushed to ``error_queue`` and a
    ``(None, False)`` reply is sent so the parent can re-raise.
    """
    assert shared_memory is None
    env = env_fn()
    # Close the parent's pipe end inherited by this child process.
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()
            if command == "reset":
                observation = env.reset()
                pipe.send((observation, True))
            elif command == "step":
                observation, reward, done, info = env.step(data)
                # Auto-reset on done deliberately disabled (see module docstring):
                # if done:
                #     observation = env.reset()
                pipe.send(((observation, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_call":
                name, args, kwargs = data
                if name in ["reset", "step", "seed", "close"]:
                    raise ValueError(
                        f"Trying to call function `{name}` with "
                        f"`_call`. Use `{name}` directly instead."
                    )
                function = getattr(env, name)
                # `_call` doubles as attribute access: call if callable,
                # otherwise return the attribute value itself.
                if callable(function):
                    pipe.send((function(*args, **kwargs), True))
                else:
                    pipe.send((function, True))
            elif command == "_setattr":
                name, value = data
                setattr(env, name, value)
                pipe.send((None, True))

            elif command == "_check_observation_space":
                pipe.send((data == env.observation_space, True))
            else:
                raise RuntimeError(
                    "Received unknown command `{0}`. Must "
                    "be one of {`reset`, `step`, `seed`, `close`, "
                    "`_check_observation_space`}.".format(command)
                )
    except (KeyboardInterrupt, Exception):
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()
615
+
616
+
617
def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
    """Subprocess loop (shared-memory variant): observations are written
    into the shared buffer at this worker's slot, and the pipe carries
    only ``(payload, success)`` acknowledgements (payload ``None`` for
    observation-bearing replies).
    """
    assert shared_memory is not None
    env = env_fn()
    observation_space = env.observation_space
    # Close the parent's pipe end inherited by this child process.
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()
            if command == "reset":
                observation = env.reset()
                write_to_shared_memory(
                    index, observation, shared_memory, observation_space
                )
                pipe.send((None, True))
            elif command == "step":
                observation, reward, done, info = env.step(data)
                # Auto-reset on done deliberately disabled (see module docstring):
                # if done:
                #     observation = env.reset()
                write_to_shared_memory(
                    index, observation, shared_memory, observation_space
                )
                pipe.send(((None, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_call":
                name, args, kwargs = data
                if name in ["reset", "step", "seed", "close"]:
                    raise ValueError(
                        f"Trying to call function `{name}` with "
                        f"`_call`. Use `{name}` directly instead."
                    )
                function = getattr(env, name)
                # `_call` doubles as attribute access: call if callable,
                # otherwise return the attribute value itself.
                if callable(function):
                    pipe.send((function(*args, **kwargs), True))
                else:
                    pipe.send((function, True))
            elif command == "_setattr":
                name, value = data
                setattr(env, name, value)
                pipe.send((None, True))
            elif command == "_check_observation_space":
                pipe.send((data == observation_space, True))
            else:
                raise RuntimeError(
                    "Received unknown command `{0}`. Must "
                    "be one of {`reset`, `step`, `seed`, `close`, "
                    "`_check_observation_space`}.".format(command)
                )
    except (KeyboardInterrupt, Exception):
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()
equidiff/equi_diffpo/gym_util/multistep_wrapper.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ from gym import spaces
3
+ import numpy as np
4
+ from collections import defaultdict, deque
5
+ import dill
6
+
7
def stack_repeated(x, n):
    """Stack n identical copies of x along a new leading axis."""
    return np.stack([x] * n, axis=0)
9
+
10
def repeated_box(box_space, n):
    """Build a Box space whose samples are n stacked samples of box_space."""
    stacked_low = stack_repeated(box_space.low, n)
    stacked_high = stack_repeated(box_space.high, n)
    return spaces.Box(
        low=stacked_low,
        high=stacked_high,
        shape=(n,) + box_space.shape,
        dtype=box_space.dtype,
    )
17
+
18
def repeated_space(space, n):
    """Recursively wrap a gym space so every leaf Box is repeated n times."""
    if isinstance(space, spaces.Dict):
        wrapped = spaces.Dict()
        for name, subspace in space.items():
            wrapped[name] = repeated_space(subspace, n)
        return wrapped
    if isinstance(space, spaces.Box):
        return repeated_box(space, n)
    raise RuntimeError(f'Unsupported space type {type(space)}')
28
+
29
def take_last_n(x, n):
    """Return the last min(n, len(x)) items of x as a numpy array."""
    items = list(x)
    count = min(len(items), n)
    # NOTE: when count == 0, items[-0:] is the whole list (original behavior).
    return np.array(items[-count:])
33
+
34
def dict_take_last_n(x, n):
    """Apply take_last_n to every value of the dict x, preserving keys."""
    return {key: take_last_n(value, n) for key, value in x.items()}
39
+
40
def aggregate(data, method='max'):
    """Reduce data with the named aggregation.

    Supported methods: 'max' (equivalent to any for booleans),
    'min' (equivalent to all), 'mean', 'sum'.
    """
    reducers = {
        'max': np.max,
        'min': np.min,
        'mean': np.mean,
        'sum': np.sum,
    }
    if method not in reducers:
        raise NotImplementedError()
    return reducers[method](data)
53
+
54
def stack_last_n_obs(all_obs, n_steps):
    """Stack the last n_steps observations into shape (n_steps,) + obs_shape.

    When fewer than n_steps observations exist, the front of the result is
    padded by repeating the oldest available observation.
    """
    assert len(all_obs) > 0
    history = list(all_obs)
    template = history[-1]
    result = np.zeros((n_steps,) + template.shape, dtype=template.dtype)
    n_avail = min(n_steps, len(history))
    pad = n_steps - n_avail
    result[pad:] = np.array(history[len(history) - n_avail:])
    if pad > 0:
        # repeat the oldest copied observation into the padded slots
        result[:pad] = result[pad]
    return result
65
+
66
+
67
class MultiStepWrapper(gym.Wrapper):
    """Wrap a single-step env so policies see observation histories and
    emit action sequences.

    `reset`/`step` return the last `n_obs_steps` observations stacked along
    a new leading time axis (front-padded by repeating the oldest available
    observation); `step` consumes up to `n_action_steps` actions, executing
    them one by one and aggregating the per-step rewards/dones.
    """

    def __init__(self, 
            env, 
            n_obs_steps, 
            n_action_steps, 
            max_episode_steps=None,
            reward_agg_method='max'
        ):
        """
        env: wrapped single-step environment
        n_obs_steps: number of past observations stacked per returned obs
        n_action_steps: number of actions consumed per call to `step`
        max_episode_steps: truncate the episode after this many env steps
        reward_agg_method: reduction over per-step rewards
            ('max', 'min', 'mean' or 'sum')
        """
        super().__init__(env)
        self._action_space = repeated_space(env.action_space, n_action_steps)
        self._observation_space = repeated_space(env.observation_space, n_obs_steps)
        self.max_episode_steps = max_episode_steps
        # Fixed: the original assigned n_obs_steps twice; one assignment suffices.
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.reward_agg_method = reward_agg_method

        # Per-episode buffers; obs/info keep one extra slot beyond n_obs_steps.
        self.obs = deque(maxlen=n_obs_steps+1)
        self.reward = list()
        self.done = list()
        self.info = defaultdict(lambda : deque(maxlen=n_obs_steps+1))

    def reset(self):
        """Resets the environment using kwargs."""
        obs = super().reset()

        # Start fresh buffers seeded with the initial observation.
        self.obs = deque([obs], maxlen=self.n_obs_steps+1)
        self.reward = list()
        self.done = list()
        self.info = defaultdict(lambda : deque(maxlen=self.n_obs_steps+1))

        obs = self._get_obs(self.n_obs_steps)
        return obs

    def step(self, action):
        """
        actions: (n_action_steps,) + action_shape
        """
        for act in action:
            if len(self.done) > 0 and self.done[-1]:
                # termination: stop executing once the episode has ended
                break
            observation, reward, done, info = super().step(act)

            self.obs.append(observation)
            self.reward.append(reward)
            if (self.max_episode_steps is not None) \
                and (len(self.reward) >= self.max_episode_steps):
                # truncation
                done = True
            self.done.append(done)
            self._add_info(info)

        observation = self._get_obs(self.n_obs_steps)
        reward = aggregate(self.reward, self.reward_agg_method)
        # 'max' over booleans == any(): done if any executed step finished.
        done = aggregate(self.done, 'max')
        info = dict_take_last_n(self.info, self.n_obs_steps)
        return observation, reward, done, info

    def _get_obs(self, n_steps=1):
        """
        Output (n_steps,) + obs_shape
        """
        assert(len(self.obs) > 0)
        if isinstance(self.observation_space, spaces.Box):
            return stack_last_n_obs(self.obs, n_steps)
        elif isinstance(self.observation_space, spaces.Dict):
            # Stack each observation key independently.
            result = dict()
            for key in self.observation_space.keys():
                result[key] = stack_last_n_obs(
                    [obs[key] for obs in self.obs],
                    n_steps
                )
            return result
        else:
            raise RuntimeError('Unsupported space type')

    def _add_info(self, info):
        # Append each info value to its bounded per-key history.
        for key, value in info.items():
            self.info[key].append(value)

    def get_rewards(self):
        # Full list of per-step rewards for the current episode.
        return self.reward

    def get_attr(self, name):
        # Attribute access by name; used via the vectorized env's `call`.
        return getattr(self, name)

    def run_dill_function(self, dill_fn):
        # Execute a dill-serialized function against this wrapper, letting a
        # parent process run arbitrary code inside the env worker.
        fn = dill.loads(dill_fn)
        return fn(self)

    def get_infos(self):
        # Snapshot of the bounded info histories as plain lists.
        result = dict()
        for k, v in self.info.items():
            result[k] = list(v)
        return result