Commit
·
c1f1d32
1
Parent(s):
836bc67
mimicgen
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +4 -0
- README.md +35 -0
- equidiff/.gitignore +156 -0
- equidiff/LICENSE +21 -0
- equidiff/README.md +115 -0
- equidiff/combinehdf5.py +59 -0
- equidiff/conda_environment.yaml +61 -0
- equidiff/equi_diffpo/codecs/imagecodecs_numcodecs.py +1386 -0
- equidiff/equi_diffpo/common/checkpoint_util.py +59 -0
- equidiff/equi_diffpo/common/cv2_util.py +150 -0
- equidiff/equi_diffpo/common/env_util.py +23 -0
- equidiff/equi_diffpo/common/json_logger.py +117 -0
- equidiff/equi_diffpo/common/nested_dict_util.py +32 -0
- equidiff/equi_diffpo/common/normalize_util.py +311 -0
- equidiff/equi_diffpo/common/pose_trajectory_interpolator.py +208 -0
- equidiff/equi_diffpo/common/precise_sleep.py +25 -0
- equidiff/equi_diffpo/common/pymunk_override.py +248 -0
- equidiff/equi_diffpo/common/pymunk_util.py +52 -0
- equidiff/equi_diffpo/common/pytorch_util.py +82 -0
- equidiff/equi_diffpo/common/replay_buffer.py +588 -0
- equidiff/equi_diffpo/common/sampler.py +153 -0
- equidiff/equi_diffpo/common/timestamp_accumulator.py +222 -0
- equidiff/equi_diffpo/config/dp3.yaml +152 -0
- equidiff/equi_diffpo/config/task/mimicgen_abs.yaml +60 -0
- equidiff/equi_diffpo/config/task/mimicgen_pc_abs.yaml +81 -0
- equidiff/equi_diffpo/config/task/mimicgen_rel.yaml +60 -0
- equidiff/equi_diffpo/config/task/mimicgen_voxel_abs.yaml +84 -0
- equidiff/equi_diffpo/config/task/mimicgen_voxel_rel.yaml +84 -0
- equidiff/equi_diffpo/config/test_equi_diffusion_unet_abs_sq2.yaml +141 -0
- equidiff/equi_diffpo/config/test_sq2.yaml +142 -0
- equidiff/equi_diffpo/config/test_th2.yaml +142 -0
- equidiff/equi_diffpo/config/train_act_abs.yaml +88 -0
- equidiff/equi_diffpo/config/train_bc_rnn.yaml +94 -0
- equidiff/equi_diffpo/config/train_diffusion_transformer.yaml +143 -0
- equidiff/equi_diffpo/config/train_diffusion_unet.yaml +140 -0
- equidiff/equi_diffpo/config/train_diffusion_unet_voxel_abs.yaml +137 -0
- equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs.yaml +137 -0
- equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs_sq2_0-1.yaml +137 -0
- equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs_sq2_1-1.yaml +137 -0
- equidiff/equi_diffpo/config/train_equi_diffusion_unet_rel.yaml +136 -0
- equidiff/equi_diffpo/config/train_equi_diffusion_unet_voxel_abs.yaml +137 -0
- equidiff/equi_diffpo/config/train_equi_diffusion_unet_voxel_rel.yaml +137 -0
- equidiff/equi_diffpo/config/train_sq2.yaml +139 -0
- equidiff/equi_diffpo/config/train_sq2_5000.yaml +139 -0
- equidiff/equi_diffpo/config/train_th2_5000.yaml +139 -0
- equidiff/equi_diffpo/dataset/base_dataset.py +51 -0
- equidiff/equi_diffpo/env_runner/base_image_runner.py +9 -0
- equidiff/equi_diffpo/env_runner/base_lowdim_runner.py +9 -0
- equidiff/equi_diffpo/gym_util/async_vector_env.py +673 -0
- equidiff/equi_diffpo/gym_util/multistep_wrapper.py +162 -0
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.mp4
|
| 2 |
+
mimicgen_environments*
|
| 3 |
+
robomimic*
|
| 4 |
+
equidiff/data/*
|
README.md
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Equidiff
|
| 2 |
+
|
| 3 |
+
- folder_name: the name of the folder
|
| 4 |
+
- file_name: the name of your file
|
| 5 |
+
|
| 6 |
+
## Prepare data
|
| 7 |
+
Use mimicgen to generate data.
|
| 8 |
+
|
| 9 |
+
Use `EmbodiedBM/equidiff/combinehdf5.py` to combine data from multiple .hdf5 files if needed.
|
| 10 |
+
|
| 11 |
+
Put hdf5 data at `EmbodiedBM/equidiff/data/robomimic/datasets` with the format [folder_name]/[file_name].hdf5
|
| 12 |
+
|
| 13 |
+
## Convert data
|
| 14 |
+
```bash
|
| 15 |
+
python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/square_d2_test/demo.hdf5 -o data/robomimic/datasets/square_d2_test/demo_abs.hdf5 -n 12
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
## Train
|
| 19 |
+
|
| 20 |
+
Change `CUDA_VISIBLE_DEVICES` to another CUDA device if the one specified below is currently in use.
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
CUDA_VISIBLE_DEVICES=5 MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa HYDRA_FULL_ERROR=1 python train.py --config-name=train_sq2_5000 folder_name=square_d2_5000 file_name=demo n_demo=5000
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
If you use another task than square_d2, you should change the task_name config by adding task_name=[task_name]
|
| 27 |
+
|
| 28 |
+
## Test
|
| 29 |
+
Change the `ckpt_path` to the trained policy's weight's path in `EmbodiedBM/equidiff/equi_diffpo/config/test_sq2.yaml`
|
| 30 |
+
|
| 31 |
+
If you use another task than square_d2, you should change the dataset config in test_sq2.yaml and download the corresponding dataset from [Huggingface](https://huggingface.co/datasets/amandlek/mimicgen_datasets/tree/main/core).
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
python test.py
|
| 35 |
+
```
|
equidiff/.gitignore
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
bin
|
| 2 |
+
logs
|
| 3 |
+
wandb
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
data_local
|
| 7 |
+
.vscode
|
| 8 |
+
_wandb
|
| 9 |
+
|
| 10 |
+
**/.DS_Store
|
| 11 |
+
|
| 12 |
+
fuse.cfg
|
| 13 |
+
|
| 14 |
+
*.ai
|
| 15 |
+
|
| 16 |
+
# Generation results
|
| 17 |
+
results/
|
| 18 |
+
|
| 19 |
+
ray/auth.json
|
| 20 |
+
|
| 21 |
+
# Byte-compiled / optimized / DLL files
|
| 22 |
+
__pycache__/
|
| 23 |
+
*.py[cod]
|
| 24 |
+
*$py.class
|
| 25 |
+
|
| 26 |
+
# C extensions
|
| 27 |
+
*.so
|
| 28 |
+
|
| 29 |
+
# Distribution / packaging
|
| 30 |
+
.Python
|
| 31 |
+
build/
|
| 32 |
+
develop-eggs/
|
| 33 |
+
dist/
|
| 34 |
+
downloads/
|
| 35 |
+
eggs/
|
| 36 |
+
.eggs/
|
| 37 |
+
lib/
|
| 38 |
+
lib64/
|
| 39 |
+
parts/
|
| 40 |
+
sdist/
|
| 41 |
+
var/
|
| 42 |
+
wheels/
|
| 43 |
+
pip-wheel-metadata/
|
| 44 |
+
share/python-wheels/
|
| 45 |
+
*.egg-info/
|
| 46 |
+
.installed.cfg
|
| 47 |
+
*.egg
|
| 48 |
+
MANIFEST
|
| 49 |
+
|
| 50 |
+
# PyInstaller
|
| 51 |
+
# Usually these files are written by a python script from a template
|
| 52 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 53 |
+
*.manifest
|
| 54 |
+
*.spec
|
| 55 |
+
|
| 56 |
+
# Installer logs
|
| 57 |
+
pip-log.txt
|
| 58 |
+
pip-delete-this-directory.txt
|
| 59 |
+
|
| 60 |
+
# Unit test / coverage reports
|
| 61 |
+
htmlcov/
|
| 62 |
+
.tox/
|
| 63 |
+
.nox/
|
| 64 |
+
.coverage
|
| 65 |
+
.coverage.*
|
| 66 |
+
.cache
|
| 67 |
+
nosetests.xml
|
| 68 |
+
coverage.xml
|
| 69 |
+
*.cover
|
| 70 |
+
*.py,cover
|
| 71 |
+
.hypothesis/
|
| 72 |
+
.pytest_cache/
|
| 73 |
+
|
| 74 |
+
# Translations
|
| 75 |
+
*.mo
|
| 76 |
+
*.pot
|
| 77 |
+
|
| 78 |
+
# Django stuff:
|
| 79 |
+
*.log
|
| 80 |
+
local_settings.py
|
| 81 |
+
db.sqlite3
|
| 82 |
+
db.sqlite3-journal
|
| 83 |
+
|
| 84 |
+
# Flask stuff:
|
| 85 |
+
instance/
|
| 86 |
+
.webassets-cache
|
| 87 |
+
|
| 88 |
+
# Scrapy stuff:
|
| 89 |
+
.scrapy
|
| 90 |
+
|
| 91 |
+
# Sphinx documentation
|
| 92 |
+
docs/_build/
|
| 93 |
+
|
| 94 |
+
# PyBuilder
|
| 95 |
+
target/
|
| 96 |
+
|
| 97 |
+
# Jupyter Notebook
|
| 98 |
+
.ipynb_checkpoints
|
| 99 |
+
|
| 100 |
+
# IPython
|
| 101 |
+
profile_default/
|
| 102 |
+
ipython_config.py
|
| 103 |
+
|
| 104 |
+
# pyenv
|
| 105 |
+
.python-version
|
| 106 |
+
|
| 107 |
+
# pipenv
|
| 108 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 109 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 110 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 111 |
+
# install all needed dependencies.
|
| 112 |
+
#Pipfile.lock
|
| 113 |
+
|
| 114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 115 |
+
__pypackages__/
|
| 116 |
+
|
| 117 |
+
# Celery stuff
|
| 118 |
+
celerybeat-schedule
|
| 119 |
+
celerybeat.pid
|
| 120 |
+
|
| 121 |
+
# SageMath parsed files
|
| 122 |
+
*.sage.py
|
| 123 |
+
|
| 124 |
+
# Spyder project settings
|
| 125 |
+
.spyderproject
|
| 126 |
+
.spyproject
|
| 127 |
+
|
| 128 |
+
# Rope project settings
|
| 129 |
+
.ropeproject
|
| 130 |
+
|
| 131 |
+
# mkdocs documentation
|
| 132 |
+
/site
|
| 133 |
+
|
| 134 |
+
# mypy
|
| 135 |
+
.mypy_cache/
|
| 136 |
+
.dmypy.json
|
| 137 |
+
dmypy.json
|
| 138 |
+
|
| 139 |
+
# Pyre type checker
|
| 140 |
+
.pyre/
|
| 141 |
+
equi_diffpo/scripts/equidiff_data_conversion.py
|
| 142 |
+
equi_diffpo/model/equi/vec_conditional_unet1d_1.py
|
| 143 |
+
update_max_score.py
|
| 144 |
+
test4.png
|
| 145 |
+
test3.png
|
| 146 |
+
test2.png
|
| 147 |
+
test1.png
|
| 148 |
+
test.png
|
| 149 |
+
sampled_xyz.png
|
| 150 |
+
pc.pt
|
| 151 |
+
metric3.py
|
| 152 |
+
metric2.py
|
| 153 |
+
metric1.py
|
| 154 |
+
grouped_xyz.png
|
| 155 |
+
all_xyz.png
|
| 156 |
+
1.png
|
equidiff/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Dian Wang
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
equidiff/README.md
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Equivariant Diffusion Policy
|
| 2 |
+
[Project Website](https://equidiff.github.io) | [Paper](https://arxiv.org/pdf/2407.01812) | [Video](https://youtu.be/xIFSx_NVROU?si=MaxsHmih6AnQKAVy)
|
| 3 |
+
<a href="https://pointw.github.io/">Dian Wang</a><sup>1</sup>, <a href="https://www.linkedin.com/in/stephen-hart-3711666/">Stephen Hart</a><sup>2</sup>, <a href="https://www.linkedin.com/in/surovik/">David Surovik</a><sup>2</sup>, <a href="https://kelestemur.com">Tarik Kelestemur</a><sup>2</sup>, <a href="https://haojhuang.github.io/">Haojie Huang</a><sup>1</sup>, <a href="https://www.linkedin.com/in/haibo-zhao-b68742250/">Haibo Zhao</a><sup>1</sup>, <a href="https://www.linkedin.com/in/mark-yeatman-58a49763/">Mark Yeatman</a><sup>2</sup>, <a href="https://www.robo.guru/">Jiuguang Wang</a><sup>2</sup>, <a href="https://www.robinwalters.com/">Robin Walters</a><sup>1</sup>, <a href="https://helpinghandslab.netlify.app/people/">Robert Platt</a><sup>12</sup>
|
| 4 |
+
<sup>1</sup>Northeastern University, <sup>2</sup>Boston Dynamics AI Institute
|
| 5 |
+
Conference on Robot Learning 2024 (Oral)
|
| 6 |
+
 |
|
| 7 |
+
## Installation
|
| 8 |
+
1. Install the following apt packages for mujoco:
|
| 9 |
+
```bash
|
| 10 |
+
sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
|
| 11 |
+
```
|
| 12 |
+
1. Install gfortran (dependency for escnn)
|
| 13 |
+
```bash
|
| 14 |
+
sudo apt install -y gfortran
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
1. Install [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) (strongly recommended) or Anaconda
|
| 18 |
+
1. Clone this repo
|
| 19 |
+
```bash
|
| 20 |
+
git clone https://github.com/pointW/equidiff.git
|
| 21 |
+
cd equidiff
|
| 22 |
+
```
|
| 23 |
+
1. Install environment:
|
| 24 |
+
Use Mambaforge (strongly recommended):
|
| 25 |
+
```bash
|
| 26 |
+
mamba env create -f conda_environment.yaml
|
| 27 |
+
conda activate equidiff
|
| 28 |
+
```
|
| 29 |
+
or use Anaconda (not recommended):
|
| 30 |
+
```bash
|
| 31 |
+
conda env create -f conda_environment.yaml
|
| 32 |
+
conda activate equidiff
|
| 33 |
+
```
|
| 34 |
+
1. Install mimicgen:
|
| 35 |
+
```bash
|
| 36 |
+
cd ..
|
| 37 |
+
git clone https://github.com/NVlabs/mimicgen_environments.git
|
| 38 |
+
cd mimicgen_environments
|
| 39 |
+
# This project was developed with Mimicgen v0.1.0. The latest version should work fine, but it is not tested
|
| 40 |
+
git checkout 081f7dbbe5fff17b28c67ce8ec87c371f32526a9
|
| 41 |
+
pip install -e .
|
| 42 |
+
cd ../equidiff
|
| 43 |
+
```
|
| 44 |
+
1. Make sure mujoco version is 2.3.2 (required by mimicgen)
|
| 45 |
+
```bash
|
| 46 |
+
pip list | grep mujoco
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Dataset
|
| 50 |
+
### Download Dataset
|
| 51 |
+
```bash
|
| 52 |
+
# Download all datasets
|
| 53 |
+
python equi_diffpo/scripts/download_datasets.py --tasks stack_d1 stack_three_d1 square_d2 threading_d2 coffee_d2 three_piece_assembly_d2 hammer_cleanup_d1 mug_cleanup_d1 kitchen_d1 nut_assembly_d0 pick_place_d0 coffee_preparation_d1
|
| 54 |
+
# Alternatively, download one (or several) datasets of interest, e.g.,
|
| 55 |
+
python equi_diffpo/scripts/download_datasets.py --tasks stack_d1
|
| 56 |
+
```
|
| 57 |
+
### Generating Voxel and Point Cloud Observation
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
# Template
|
| 61 |
+
python equi_diffpo/scripts/dataset_states_to_obs.py --input data/robomimic/datasets/[dataset]/[dataset].hdf5 --output data/robomimic/datasets/[dataset]/[dataset]_voxel.hdf5 --num_workers=[n_worker]
|
| 62 |
+
# Replace [dataset] and [n_worker] with your choices.
|
| 63 |
+
# E.g., use 24 workers to generate point cloud and voxel observation for stack_d1
|
| 64 |
+
python equi_diffpo/scripts/dataset_states_to_obs.py --input data/robomimic/datasets/stack_d1/stack_d1.hdf5 --output data/robomimic/datasets/stack_d1/stack_d1_voxel.hdf5 --num_workers=24
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### Convert Action Space in Dataset
|
| 68 |
+
The downloaded dataset has a relative action space. To train with absolute action space, the dataset needs to be converted accordingly
|
| 69 |
+
```bash
|
| 70 |
+
# Template
|
| 71 |
+
python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/[dataset]/[dataset].hdf5 -o data/robomimic/datasets/[dataset]/[dataset]_abs.hdf5 -n [n_worker]
|
| 72 |
+
# Replace [dataset] and [n_worker] with your choices.
|
| 73 |
+
# E.g., convert stack_d1 (non-voxel) with 12 workers
|
| 74 |
+
python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/stack_d1/stack_d1.hdf5 -o data/robomimic/datasets/stack_d1/stack_d1_abs.hdf5 -n 12
|
| 75 |
+
# E.g., convert stack_d1_voxel (voxel) with 12 workers
|
| 76 |
+
python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/stack_d1/stack_d1_voxel.hdf5 -o data/robomimic/datasets/stack_d1/stack_d1_voxel_abs.hdf5 -n 12
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## Training with image observation
|
| 80 |
+
To train Equivariant Diffusion Policy (with absolute pose control) in Stack D1 task:
|
| 81 |
+
```bash
|
| 82 |
+
# Make sure you have the non-voxel converted dataset with absolute action space from the previous step
|
| 83 |
+
python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1 n_demo=100
|
| 84 |
+
```
|
| 85 |
+
To train with relative pose control instead:
|
| 86 |
+
```bash
|
| 87 |
+
python train.py --config-name=train_equi_diffusion_unet_rel task_name=stack_d1 n_demo=100
|
| 88 |
+
```
|
| 89 |
+
To train in other tasks, replace `stack_d1` with `stack_three_d1`, `square_d2`, `threading_d2`, `coffee_d2`, `three_piece_assembly_d2`, `hammer_cleanup_d1`, `mug_cleanup_d1`, `kitchen_d1`, `nut_assembly_d0`, `pick_place_d0`, `coffee_preparation_d1`. Notice that the corresponding dataset should be downloaded already. If training absolute pose control, the data conversion is also needed.
|
| 90 |
+
|
| 91 |
+
To run environments on CPU (to save GPU memory), use `osmesa` instead of `egl` through `MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa`, e.g.,
|
| 92 |
+
```bash
|
| 93 |
+
MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
Equivariant Diffusion Policy requires around 22G GPU memory to run with batch size of 128 (default). To reduce the GPU usage, consider training with smaller batch size and/or reducing the hidden dimension
|
| 97 |
+
```bash
|
| 98 |
+
# to train with batch size of 64 and hidden dimension of 64
|
| 99 |
+
MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1 policy.enc_n_hidden=64 dataloader.batch_size=64
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
## Training with voxel observation
|
| 103 |
+
To train Equivariant Diffusion Policy (with absolute pose control) in Stack D1 task:
|
| 104 |
+
```bash
|
| 105 |
+
# Make sure you have the voxel converted dataset with absolute action space from the previous step
|
| 106 |
+
python train.py --config-name=train_equi_diffusion_unet_voxel_abs task_name=stack_d1 n_demo=100
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
## License
|
| 110 |
+
This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.
|
| 111 |
+
|
| 112 |
+
## Acknowledgement
|
| 113 |
+
* Our repo is built upon the original [Diffusion Policy](https://github.com/real-stanford/diffusion_policy)
|
| 114 |
+
* Our ACT baseline is adapted from its [original repo](https://github.com/tonyzhaozh/act)
|
| 115 |
+
* Our DP3 baseline is adapted from its [original repo](https://github.com/YanjieZe/3D-Diffusion-Policy)
|
equidiff/combinehdf5.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import h5py
|
| 2 |
+
import re
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
def update_suffix(original_string, increment):
    """Return *original_string* with its trailing run of digits shifted by *increment*.

    Strings that do not end in digits are returned unchanged.
    """
    def _shift(match):
        # Re-emit the trailing number after adding the increment.
        return str(int(match.group(1)) + increment)

    return re.sub(r'(\d+)$', _shift, original_string)
|
| 8 |
+
|
| 9 |
+
def merge(output_file, input_files, total_size, truncate_len=-1):
    """Merge demo groups from several robomimic-style HDF5 files into one file.

    Each child of the 'data' group in every input file is copied into the
    output's 'data' group under a new name ``demo_<k>``, where k is drawn
    from a random permutation of ``range(total_size)`` so the merged demos
    end up shuffled. Environment metadata ('env_args') is copied from the
    first input file, so all inputs are assumed to share one environment.

    Parameters
    ----------
    output_file : str
        Path of the HDF5 file to create (overwritten if it exists).
    input_files : list[str]
        Paths of input HDF5 files, each containing a 'data' group.
    total_size : int
        Must be >= the number of demos copied, or ``numbers[i]`` raises
        IndexError.
    truncate_len : int, optional
        Stop after copying this many demos in total; any value <= 0
        disables truncation (default -1).
    """
    numbers = list(range(total_size))
    random.shuffle(numbers)
    with h5py.File(output_file, 'w') as h5out:
        h5out_data = h5out.create_group('data')
        i = 0
        done = False
        for input_file in input_files:
            if done:
                break
            with h5py.File(input_file, 'r') as f:
                d = f['data']
                for key in d:
                    new_key = f"demo_{numbers[i]}"
                    print(new_key)
                    if isinstance(d[key], h5py.Group):
                        # Copy the whole demo group under its shuffled name.
                        d[key].copy(d[key], h5out_data, name=new_key)
                    elif isinstance(d[key], h5py.Dataset):
                        # NOTE(review): datasets are stored under the original
                        # key, not new_key — looks intentional for non-demo
                        # payloads, but confirm against the producing pipeline.
                        h5out_data.create_dataset(key, data=d[key][:])
                    i += 1
                    # BUG FIX: the original `break` only exited this inner
                    # loop, so truncation never applied across multiple input
                    # files. The explicit `> 0` check also makes the -1/0
                    # "no limit" sentinels deliberate instead of accidental.
                    if truncate_len > 0 and i == truncate_len:
                        done = True
                        break
        print(len(h5out_data))
        # Propagate environment metadata from the first input file.
        with h5py.File(input_files[0], 'r') as f:
            d1 = f['data']
            if "env_args" in d1.attrs:
                h5out_data.attrs["env_args"] = d1.attrs["env_args"]
|
| 34 |
+
|
| 35 |
+
def print_hdf5_structure(file_path):
    """Print the group/dataset tree of an HDF5 file and its 'env_args' attribute."""

    def _walk(node, depth=0):
        # One leading space per nesting level, matching the original layout.
        pad = " " * depth
        for name in node:
            child = node[name]
            if isinstance(child, h5py.Group):
                print(pad + f"Group: {name}")
                _walk(child, depth + 1)
            elif isinstance(child, h5py.Dataset):
                print(pad + f"Dataset: {name}, Shape: {child.shape}, Type: {child.dtype}")

    with h5py.File(file_path, 'r') as f:
        print(f"File: {file_path}")
        _walk(f)
        dataset = f["data"]
        if "env_args" in dataset.attrs:
            env_args = dataset.attrs["env_args"]
            print(f"env_args: {env_args}")
|
| 52 |
+
|
| 53 |
+
# input_files = ["/home/siweih/Project/EmbodiedBM/equidiff/data/robomimic/datasets/square_d2/square_d2.hdf5","/home/siweih/Project/EmbodiedBM/equidiff/mix_4000.hdf5"]
|
| 54 |
+
|
| 55 |
+
# demo_num = 5000
|
| 56 |
+
# output_file = f"mix_{demo_num}.hdf5"
|
| 57 |
+
|
| 58 |
+
# merge(output_file, input_files, demo_num)
|
| 59 |
+
print_hdf5_structure("/home/siweih/Project/EmbodiedBM/mimicgen/core_datasets/square/demo_src_square_task_D2/demo.hdf5")
|
equidiff/conda_environment.yaml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: equidiff
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- pytorch3d
|
| 5 |
+
- nvidia
|
| 6 |
+
- conda-forge
|
| 7 |
+
dependencies:
|
| 8 |
+
- python=3.9
|
| 9 |
+
- pip=22.2.2
|
| 10 |
+
- pytorch=2.1.0
|
| 11 |
+
- torchaudio=2.1.0
|
| 12 |
+
- torchvision=0.16.0
|
| 13 |
+
- pytorch-cuda=11.8
|
| 14 |
+
- pytorch3d=0.7.5
|
| 15 |
+
- numpy=1.23.3
|
| 16 |
+
- numba==0.56.4
|
| 17 |
+
- scipy==1.9.1
|
| 18 |
+
- py-opencv=4.6.0
|
| 19 |
+
- cffi=1.15.1
|
| 20 |
+
- ipykernel=6.16
|
| 21 |
+
- matplotlib=3.6.1
|
| 22 |
+
- zarr=2.12.0
|
| 23 |
+
- numcodecs=0.10.2
|
| 24 |
+
- h5py=3.7.0
|
| 25 |
+
- hydra-core=1.2.0
|
| 26 |
+
- einops=0.4.1
|
| 27 |
+
- tqdm=4.64.1
|
| 28 |
+
- dill=0.3.5.1
|
| 29 |
+
- scikit-video=1.1.11
|
| 30 |
+
- scikit-image=0.19.3
|
| 31 |
+
- gym=0.21.0
|
| 32 |
+
- pymunk=6.2.1
|
| 33 |
+
- threadpoolctl=3.1.0
|
| 34 |
+
- shapely=1.8.4
|
| 35 |
+
- cython=0.29.32
|
| 36 |
+
- imageio=2.22.0
|
| 37 |
+
- imageio-ffmpeg=0.4.7
|
| 38 |
+
- termcolor=2.0.1
|
| 39 |
+
- tensorboard=2.10.1
|
| 40 |
+
- tensorboardx=2.5.1
|
| 41 |
+
- psutil=5.9.2
|
| 42 |
+
- click=8.0.4
|
| 43 |
+
- boto3=1.24.96
|
| 44 |
+
- accelerate=0.13.2
|
| 45 |
+
- datasets=2.6.1
|
| 46 |
+
- diffusers=0.11.1
|
| 47 |
+
- av=10.0.0
|
| 48 |
+
- cmake=3.24.3
|
| 49 |
+
# trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
|
| 50 |
+
- llvm-openmp=14
|
| 51 |
+
# trick to force reinstall imagecodecs via pip
|
| 52 |
+
- imagecodecs==2022.8.8
|
| 53 |
+
- pip:
|
| 54 |
+
- open3d
|
| 55 |
+
- wandb==0.17.0
|
| 56 |
+
- pygame
|
| 57 |
+
- imagecodecs==2022.9.26
|
| 58 |
+
- escnn @ https://github.com/pointW/escnn/archive/fc4714cb6dc0d2a32f9fcea35771968b89911109.tar.gz
|
| 59 |
+
- robosuite @ https://github.com/ARISE-Initiative/robosuite/archive/b9d8d3de5e3dfd1724f4a0e6555246c460407daa.tar.gz
|
| 60 |
+
- robomimic @ https://github.com/pointW/robomimic/archive/8aad5b3caaaac9289b1504438a7f5d3a76d06c07.tar.gz
|
| 61 |
+
- robosuite-task-zoo @ https://github.com/pointW/robosuite-task-zoo/archive/0f8a7b2fa5d192e4e8800bebfe8090b28926f3ed.tar.gz
|
equidiff/equi_diffpo/codecs/imagecodecs_numcodecs.py
ADDED
|
@@ -0,0 +1,1386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# imagecodecs/numcodecs.py
|
| 3 |
+
|
| 4 |
+
# Copyright (c) 2021-2022, Christoph Gohlke
|
| 5 |
+
# All rights reserved.
|
| 6 |
+
#
|
| 7 |
+
# Redistribution and use in source and binary forms, with or without
|
| 8 |
+
# modification, are permitted provided that the following conditions are met:
|
| 9 |
+
#
|
| 10 |
+
# 1. Redistributions of source code must retain the above copyright notice,
|
| 11 |
+
# this list of conditions and the following disclaimer.
|
| 12 |
+
#
|
| 13 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 14 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 15 |
+
# and/or other materials provided with the distribution.
|
| 16 |
+
#
|
| 17 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 18 |
+
# contributors may be used to endorse or promote products derived from
|
| 19 |
+
# this software without specific prior written permission.
|
| 20 |
+
#
|
| 21 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 22 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 23 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
| 24 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
| 25 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
| 26 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
| 27 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
| 28 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
| 29 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
| 30 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
| 31 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
| 32 |
+
|
| 33 |
+
"""Additional numcodecs implemented using imagecodecs."""
|
| 34 |
+
|
| 35 |
+
__version__ = '2022.9.26'
|
| 36 |
+
|
| 37 |
+
__all__ = ('register_codecs',)
|
| 38 |
+
|
| 39 |
+
import numpy
|
| 40 |
+
from numcodecs.abc import Codec
|
| 41 |
+
from numcodecs.registry import register_codec, get_codec
|
| 42 |
+
|
| 43 |
+
import imagecodecs
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def protective_squeeze(x: numpy.ndarray):
    """Collapse leading axes while preserving the trailing image axes.

    The last three axes are treated as (H, W, C) and are always kept.
    Leading axes are dropped when they describe a single image, or folded
    into one batch axis when they describe several.
    """
    target_shape = x.shape[-3:]
    if x.ndim > 3 and numpy.prod(x.shape[:-3]) > 1:
        # More than one image: keep a single batch dimension up front.
        target_shape = (-1,) + target_shape
    return x.reshape(target_shape)
|
| 57 |
+
|
| 58 |
+
def get_default_image_compressor(**kwargs):
    """Return a default lossy image codec instance.

    Prefers JPEG XL when the installed imagecodecs build supports it,
    otherwise falls back to JPEG 2000.  Keyword arguments override the
    built-in defaults of the chosen codec.
    """
    if imagecodecs.JPEGXL:
        defaults = {
            'effort': 3,
            'distance': 0.3,
            # bug in libjxl, invalid codestream for non-lossless
            # when decoding speed > 1
            'decodingspeed': 1,
        }
        return JpegXl(**{**defaults, **kwargs})
    return Jpeg2k(**{**{'level': 50}, **kwargs})
|
| 76 |
+
|
| 77 |
+
class Aec(Codec):
    """AEC codec for numcodecs."""

    codec_id = 'imagecodecs_aec'

    def __init__(
        self, bitspersample=None, flags=None, blocksize=None, rsi=None
    ):
        self.bitspersample = bitspersample
        self.flags = flags
        self.blocksize = blocksize
        self.rsi = rsi

    def _options(self):
        # Keyword arguments shared by both encode and decode.
        return dict(
            bitspersample=self.bitspersample,
            flags=self.flags,
            blocksize=self.blocksize,
            rsi=self.rsi,
        )

    def encode(self, buf):
        return imagecodecs.aec_encode(buf, **self._options())

    def decode(self, buf, out=None):
        # Decode into a flattened view of ``out`` when one is provided.
        return imagecodecs.aec_decode(buf, out=_flat(out), **self._options())
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class Apng(Codec):
    """APNG codec for numcodecs."""

    codec_id = 'imagecodecs_apng'

    def __init__(self, level=None, photometric=None, delay=None):
        self.level = level
        self.photometric = photometric
        self.delay = delay

    def encode(self, buf):
        # Collapse leading singleton dims before handing to the encoder.
        image = protective_squeeze(numpy.asarray(buf))
        options = {
            'level': self.level,
            'photometric': self.photometric,
            'delay': self.delay,
        }
        return imagecodecs.apng_encode(image, **options)

    def decode(self, buf, out=None):
        return imagecodecs.apng_decode(buf, out=out)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class Avif(Codec):
    """AVIF codec for numcodecs."""

    codec_id = 'imagecodecs_avif'

    def __init__(
        self,
        level=None,
        speed=None,
        tilelog2=None,
        bitspersample=None,
        pixelformat=None,
        numthreads=None,
        index=None,
    ):
        self.level = level
        self.speed = speed
        self.tilelog2 = tilelog2
        self.bitspersample = bitspersample
        self.pixelformat = pixelformat
        self.numthreads = numthreads
        # ``index`` is only used when decoding (frame selection).
        self.index = index

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        options = {
            'level': self.level,
            'speed': self.speed,
            'tilelog2': self.tilelog2,
            'bitspersample': self.bitspersample,
            'pixelformat': self.pixelformat,
            'numthreads': self.numthreads,
        }
        return imagecodecs.avif_encode(image, **options)

    def decode(self, buf, out=None):
        return imagecodecs.avif_decode(
            buf, index=self.index, numthreads=self.numthreads, out=out
        )
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class Bitorder(Codec):
    """Bitorder codec for numcodecs."""

    # Stateless codec: no configuration parameters.
    codec_id = 'imagecodecs_bitorder'

    def encode(self, buf):
        return imagecodecs.bitorder_encode(buf)

    def decode(self, buf, out=None):
        # Decode into a flattened view of ``out`` when one is provided.
        return imagecodecs.bitorder_decode(buf, out=_flat(out))
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class Bitshuffle(Codec):
    """Bitshuffle codec for numcodecs."""

    codec_id = 'imagecodecs_bitshuffle'

    def __init__(self, itemsize=1, blocksize=0):
        self.itemsize = itemsize
        self.blocksize = blocksize

    def encode(self, buf):
        shuffled = imagecodecs.bitshuffle_encode(
            buf, itemsize=self.itemsize, blocksize=self.blocksize
        )
        # numcodecs expects bytes, not an ndarray.
        return shuffled.tobytes()

    def decode(self, buf, out=None):
        return imagecodecs.bitshuffle_decode(
            buf,
            itemsize=self.itemsize,
            blocksize=self.blocksize,
            out=_flat(out),
        )
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class Blosc(Codec):
    """Blosc codec for numcodecs."""

    codec_id = 'imagecodecs_blosc'

    def __init__(
        self,
        level=None,
        compressor=None,
        typesize=None,
        blocksize=None,
        shuffle=None,
        numthreads=None,
    ):
        self.level = level
        self.compressor = compressor
        self.typesize = typesize
        self.blocksize = blocksize
        self.shuffle = shuffle
        self.numthreads = numthreads

    def encode(self, buf):
        data = protective_squeeze(numpy.asarray(buf))
        options = {
            'level': self.level,
            'compressor': self.compressor,
            'typesize': self.typesize,
            'blocksize': self.blocksize,
            'shuffle': self.shuffle,
            'numthreads': self.numthreads,
        }
        return imagecodecs.blosc_encode(data, **options)

    def decode(self, buf, out=None):
        # Decode into a flattened view of ``out`` when one is provided.
        return imagecodecs.blosc_decode(
            buf, numthreads=self.numthreads, out=_flat(out)
        )
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class Blosc2(Codec):
    """Blosc2 codec for numcodecs."""

    codec_id = 'imagecodecs_blosc2'

    def __init__(
        self,
        level=None,
        compressor=None,
        typesize=None,
        blocksize=None,
        shuffle=None,
        numthreads=None,
    ):
        self.level = level
        self.compressor = compressor
        self.typesize = typesize
        self.blocksize = blocksize
        self.shuffle = shuffle
        self.numthreads = numthreads

    def encode(self, buf):
        data = protective_squeeze(numpy.asarray(buf))
        options = {
            'level': self.level,
            'compressor': self.compressor,
            'typesize': self.typesize,
            'blocksize': self.blocksize,
            'shuffle': self.shuffle,
            'numthreads': self.numthreads,
        }
        return imagecodecs.blosc2_encode(data, **options)

    def decode(self, buf, out=None):
        # Decode into a flattened view of ``out`` when one is provided.
        return imagecodecs.blosc2_decode(
            buf, numthreads=self.numthreads, out=_flat(out)
        )
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class Brotli(Codec):
    """Brotli codec for numcodecs."""

    codec_id = 'imagecodecs_brotli'

    def __init__(self, level=None, mode=None, lgwin=None):
        self.level = level
        self.mode = mode
        self.lgwin = lgwin

    def encode(self, buf):
        options = {
            'level': self.level,
            'mode': self.mode,
            'lgwin': self.lgwin,
        }
        return imagecodecs.brotli_encode(buf, **options)

    def decode(self, buf, out=None):
        # Decode into a flattened view of ``out`` when one is provided.
        return imagecodecs.brotli_decode(buf, out=_flat(out))
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class ByteShuffle(Codec):
    """ByteShuffle codec for numcodecs."""

    codec_id = 'imagecodecs_byteshuffle'

    def __init__(
        self, shape, dtype, axis=-1, dist=1, delta=False, reorder=False
    ):
        # Fixed array geometry: encode asserts incoming buffers match it,
        # and decode uses it to rebuild the array from raw bytes.
        self.shape = tuple(shape)
        self.dtype = numpy.dtype(dtype).str
        self.axis = axis
        self.dist = dist
        self.delta = bool(delta)
        self.reorder = bool(reorder)

    def encode(self, buf):
        buf = protective_squeeze(numpy.asarray(buf))
        assert buf.shape == self.shape
        assert buf.dtype == self.dtype
        # numcodecs expects bytes, hence tobytes().
        return imagecodecs.byteshuffle_encode(
            buf,
            axis=self.axis,
            dist=self.dist,
            delta=self.delta,
            reorder=self.reorder,
        ).tobytes()

    def decode(self, buf, out=None):
        # Rebuild the typed array view when given raw bytes.
        if not isinstance(buf, numpy.ndarray):
            buf = numpy.frombuffer(buf, dtype=self.dtype).reshape(*self.shape)
        return imagecodecs.byteshuffle_decode(
            buf,
            axis=self.axis,
            dist=self.dist,
            delta=self.delta,
            reorder=self.reorder,
            out=out,
        )
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class Bz2(Codec):
    """Bz2 codec for numcodecs."""

    codec_id = 'imagecodecs_bz2'

    def __init__(self, level=None):
        # Compression level; None uses the library default.
        self.level = level

    def encode(self, buf):
        return imagecodecs.bz2_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        # Decode into a flattened view of ``out`` when one is provided.
        return imagecodecs.bz2_decode(buf, out=_flat(out))
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
class Cms(Codec):
    """CMS codec for numcodecs.

    Placeholder: color-management transforms are not implemented; both
    encode and decode raise NotImplementedError.
    """

    codec_id = 'imagecodecs_cms'

    def __init__(self, *args, **kwargs):
        # Accept and ignore any arguments so configurations round-trip.
        pass

    def encode(self, buf, out=None):
        # return imagecodecs.cms_transform(buf)
        raise NotImplementedError

    def decode(self, buf, out=None):
        # return imagecodecs.cms_transform(buf)
        raise NotImplementedError
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class Deflate(Codec):
    """Deflate codec for numcodecs."""

    codec_id = 'imagecodecs_deflate'

    def __init__(self, level=None, raw=False):
        # Compression level; None uses the library default.
        self.level = level
        # ``raw`` is passed to both encode and decode — both sides must
        # agree on this setting.
        self.raw = bool(raw)

    def encode(self, buf):
        return imagecodecs.deflate_encode(buf, level=self.level, raw=self.raw)

    def decode(self, buf, out=None):
        # Decode into a flattened view of ``out`` when one is provided.
        return imagecodecs.deflate_decode(buf, out=_flat(out), raw=self.raw)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class Delta(Codec):
    """Delta codec for numcodecs.

    When ``shape`` and/or ``dtype`` are given, buffers are interpreted as
    typed arrays; otherwise raw buffers are passed straight through to
    imagecodecs.
    """

    codec_id = 'imagecodecs_delta'

    def __init__(self, shape=None, dtype=None, axis=-1, dist=1):
        self.shape = None if shape is None else tuple(shape)
        self.dtype = None if dtype is None else numpy.dtype(dtype).str
        self.axis = axis
        self.dist = dist

    def encode(self, buf):
        if self.shape is not None or self.dtype is not None:
            buf = protective_squeeze(numpy.asarray(buf))
            # Validate only the constraints that were configured.  The
            # previous code asserted both unconditionally, so a dtype-only
            # (or shape-only) configuration could never encode anything.
            if self.shape is not None:
                assert buf.shape == self.shape
            if self.dtype is not None:
                assert buf.dtype == self.dtype
        # numcodecs expects bytes, hence tobytes().
        return imagecodecs.delta_encode(
            buf, axis=self.axis, dist=self.dist
        ).tobytes()

    def decode(self, buf, out=None):
        if self.shape is not None or self.dtype is not None:
            # Rebuild the typed array view; reshape only when a shape was
            # configured (``reshape(*None)`` would raise TypeError).
            buf = numpy.frombuffer(buf, dtype=self.dtype)
            if self.shape is not None:
                buf = buf.reshape(*self.shape)
        return imagecodecs.delta_decode(
            buf, axis=self.axis, dist=self.dist, out=out
        )
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
class Float24(Codec):
    """Float24 codec for numcodecs."""

    codec_id = 'imagecodecs_float24'

    def __init__(self, byteorder=None, rounding=None):
        self.byteorder = byteorder
        self.rounding = rounding

    def encode(self, buf):
        # Collapse leading singleton dims before handing to the encoder.
        buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.float24_encode(
            buf, byteorder=self.byteorder, rounding=self.rounding
        )

    def decode(self, buf, out=None):
        return imagecodecs.float24_decode(
            buf, byteorder=self.byteorder, out=out
        )
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class FloatPred(Codec):
    """Floating Point Predictor codec for numcodecs."""

    codec_id = 'imagecodecs_floatpred'

    def __init__(self, shape, dtype, axis=-1, dist=1):
        # Fixed array geometry: encode asserts incoming buffers match it,
        # and decode uses it to rebuild the array from raw bytes.
        self.shape = tuple(shape)
        self.dtype = numpy.dtype(dtype).str
        self.axis = axis
        self.dist = dist

    def encode(self, buf):
        buf = protective_squeeze(numpy.asarray(buf))
        assert buf.shape == self.shape
        assert buf.dtype == self.dtype
        # numcodecs expects bytes, hence tobytes().
        return imagecodecs.floatpred_encode(
            buf, axis=self.axis, dist=self.dist
        ).tobytes()

    def decode(self, buf, out=None):
        # Rebuild the typed array view when given raw bytes.
        if not isinstance(buf, numpy.ndarray):
            buf = numpy.frombuffer(buf, dtype=self.dtype).reshape(*self.shape)
        return imagecodecs.floatpred_decode(
            buf, axis=self.axis, dist=self.dist, out=out
        )
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class Gif(Codec):
    """GIF codec for numcodecs."""

    # Stateless codec: no configuration parameters.
    codec_id = 'imagecodecs_gif'

    def encode(self, buf):
        # Collapse leading singleton dims before handing to the encoder.
        buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.gif_encode(buf)

    def decode(self, buf, out=None):
        # asrgb=False keeps the decoded data as stored rather than
        # expanding to RGB.
        return imagecodecs.gif_decode(buf, asrgb=False, out=out)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class Heif(Codec):
    """HEIF codec for numcodecs."""

    codec_id = 'imagecodecs_heif'

    def __init__(
        self,
        level=None,
        bitspersample=None,
        photometric=None,
        compression=None,
        numthreads=None,
        index=None,
    ):
        self.level = level
        self.bitspersample = bitspersample
        self.photometric = photometric
        self.compression = compression
        self.numthreads = numthreads
        # ``index`` is only used when decoding (image selection).
        self.index = index

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        options = {
            'level': self.level,
            'bitspersample': self.bitspersample,
            'photometric': self.photometric,
            'compression': self.compression,
            'numthreads': self.numthreads,
        }
        return imagecodecs.heif_encode(image, **options)

    def decode(self, buf, out=None):
        return imagecodecs.heif_decode(
            buf,
            index=self.index,
            photometric=self.photometric,
            numthreads=self.numthreads,
            out=out,
        )
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
class Jetraw(Codec):
    """Jetraw codec for numcodecs."""

    codec_id = 'imagecodecs_jetraw'

    def __init__(
        self,
        shape,
        identifier,
        parameters=None,
        verbosity=None,
        errorbound=None,
    ):
        # Output shape must be configured up front: decode allocates a
        # uint16 array of this shape when no ``out`` is supplied.
        self.shape = shape
        self.identifier = identifier
        self.errorbound = errorbound
        # NOTE: library initialization happens as a side effect of
        # constructing the codec; ``parameters`` and ``verbosity`` are
        # consumed here and not stored on the instance.
        imagecodecs.jetraw_init(parameters, verbosity)

    def encode(self, buf):
        return imagecodecs.jetraw_encode(
            buf, identifier=self.identifier, errorbound=self.errorbound
        )

    def decode(self, buf, out=None):
        if out is None:
            out = numpy.empty(self.shape, numpy.uint16)
        return imagecodecs.jetraw_decode(buf, out=out)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
class Jpeg(Codec):
    """JPEG codec for numcodecs.

    ``header`` and ``tables`` are binary blobs; they are base64-encoded
    in the configuration dictionary (see get_config/from_config) so the
    config stays serializable.
    """

    codec_id = 'imagecodecs_jpeg'

    def __init__(
        self,
        bitspersample=None,
        tables=None,
        header=None,
        colorspace_data=None,
        colorspace_jpeg=None,
        level=None,
        subsampling=None,
        optimize=None,
        smoothing=None,
    ):
        self.tables = tables
        self.header = header
        self.bitspersample = bitspersample
        self.colorspace_data = colorspace_data
        self.colorspace_jpeg = colorspace_jpeg
        self.level = level
        self.subsampling = subsampling
        self.optimize = optimize
        self.smoothing = smoothing

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.jpeg_encode(
            image,
            level=self.level,
            colorspace=self.colorspace_data,
            outcolorspace=self.colorspace_jpeg,
            subsampling=self.subsampling,
            optimize=self.optimize,
            smoothing=self.smoothing,
        )

    def decode(self, buf, out=None):
        # Remember the caller's requested shape: decoding works on the
        # squeezed buffer and the result is reshaped back afterwards.
        requested_shape = None
        if out is not None:
            requested_shape = out.shape
            out = protective_squeeze(out)
        img = imagecodecs.jpeg_decode(
            buf,
            bitspersample=self.bitspersample,
            tables=self.tables,
            header=self.header,
            colorspace=self.colorspace_jpeg,
            outcolorspace=self.colorspace_data,
            out=out,
        )
        if requested_shape is not None:
            img = img.reshape(requested_shape)
        return img

    def get_config(self):
        """Return dictionary holding configuration parameters."""
        config = dict(id=self.codec_id)
        for key, value in self.__dict__.items():
            if key.startswith('_'):
                continue
            if value is not None and key in ('header', 'tables'):
                import base64

                # Binary blobs are base64-encoded for serializability.
                value = base64.b64encode(value).decode()
            config[key] = value
        return config

    @classmethod
    def from_config(cls, config):
        """Instantiate codec from configuration object."""
        for key in ('header', 'tables'):
            value = config.get(key, None)
            if isinstance(value, str):
                import base64

                # Undo the base64 encoding applied by get_config.
                config[key] = base64.b64decode(value.encode())
        return cls(**config)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
class Jpeg2k(Codec):
    """JPEG 2000 codec for numcodecs."""

    codec_id = 'imagecodecs_jpeg2k'

    def __init__(
        self,
        level=None,
        codecformat=None,
        colorspace=None,
        tile=None,
        reversible=None,
        bitspersample=None,
        resolutions=None,
        numthreads=None,
        verbose=0,
    ):
        self.level = level
        self.codecformat = codecformat
        self.colorspace = colorspace
        # Normalize to a tuple so the value round-trips through configs.
        self.tile = None if tile is None else tuple(tile)
        self.reversible = reversible
        self.bitspersample = bitspersample
        self.resolutions = resolutions
        self.numthreads = numthreads
        self.verbose = verbose

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        options = {
            'level': self.level,
            'codecformat': self.codecformat,
            'colorspace': self.colorspace,
            'tile': self.tile,
            'reversible': self.reversible,
            'bitspersample': self.bitspersample,
            'resolutions': self.resolutions,
            'numthreads': self.numthreads,
            'verbose': self.verbose,
        }
        return imagecodecs.jpeg2k_encode(image, **options)

    def decode(self, buf, out=None):
        return imagecodecs.jpeg2k_decode(
            buf, verbose=self.verbose, numthreads=self.numthreads, out=out
        )
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
class JpegLs(Codec):
    """JPEG LS codec for numcodecs."""

    codec_id = 'imagecodecs_jpegls'

    def __init__(self, level=None):
        # Compression level; None uses the library default.
        self.level = level

    def encode(self, buf):
        # Collapse leading singleton dims before handing to the encoder.
        buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.jpegls_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        return imagecodecs.jpegls_decode(buf, out=out)
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
class JpegXl(Codec):
    """JPEG XL codec for numcodecs."""

    codec_id = 'imagecodecs_jpegxl'

    def __init__(
        self,
        # encode
        level=None,
        effort=None,
        distance=None,
        lossless=None,
        decodingspeed=None,
        photometric=None,
        planar=None,
        usecontainer=None,
        # decode
        index=None,
        keeporientation=None,
        # both
        numthreads=None,
    ):
        """Return JPEG XL codec.

        Float input must be in nominal range 0..1.  Currently L, LA, RGB,
        RGBA images are supported in contig mode; extra channels are only
        supported for grayscale images in planar mode.

        Parameters
        ----------
        level : int, optional
            Default None, i.e. does not override the ``lossless`` and
            ``decodingspeed`` options.  When < 0, use lossless
            compression.  When in [0, 4], sets the decoding-speed tier:
            0 is slowest to decode with best quality/density, 4 is
            fastest to decode at some cost in quality/density.
        effort : int, optional
            Encoder effort/speed level; does not affect decoding speed.
            Valid values, from fastest to slowest: 1 lightning,
            2 thunder, 3 falcon, 4 cheetah, 5 hare, 6 wombat,
            7 squirrel, 8 kitten, 9 tortoise.  Lower effort also reduces
            encoder memory use.  lightning/thunder suit lossless
            (modular) mode; squirrel is the libjxl default.  Default 3.
        distance : float, optional
            Target max butteraugli distance for lossy compression;
            lower means higher quality.  Range 0..15; 0.0 is
            mathematically lossless (use ``lossless`` for true lossless),
            1.0 is visually lossless.  Recommended range 0.5..3.0.
            Default 1.0.
        lossless : bool, optional
            Use lossless encoding.  Default False.
        decodingspeed : int, optional
            Duplicate of ``level``; range [0, 4].  Default 0.
        photometric : int or str, optional
            JxlColorSpace value.  The default selection logic is
            involved but works most of the time.  Accepted values:
            int in [-1, 3], or one of 'RGB', 'WHITEISZERO', 'MINISWHITE',
            'BLACKISZERO', 'MINISBLACK', 'GRAY', 'XYB', 'KNOWN'.
        planar : bool, optional
            Enable multi-channel (planar) mode.  Default False.
        usecontainer : bool, optional
            Force the box-based container format (BMFF) even when not
            necessary.  The encoder switches to the container format
            automatically when it is required.  Default False.
        index : int, optional
            Selectively decode frames of an animation.  Default 0,
            decode all frames; > 0 decodes only that frame index.
        keeporientation : bool, optional
            If True, skip applying the bitstream's Orientation
            transform and return pixel data in as-encoded orientation
            (faster, but the caller must handle orientation).  Default
            False: pixels are re-oriented according to the image's
            Orientation setting.
        numthreads : int, optional
            Default 1.  If <= 0, use all cores; if > 32, clipped to 32.
        """

        self.level = level
        self.effort = effort
        self.distance = distance
        self.lossless = bool(lossless)
        self.decodingspeed = decodingspeed
        self.photometric = photometric
        self.planar = planar
        self.usecontainer = usecontainer
        self.index = index
        self.keeporientation = keeporientation
        self.numthreads = numthreads

    def encode(self, buf):
        # TODO: only squeeze all but last dim
        buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.jpegxl_encode(
            buf,
            level=self.level,
            effort=self.effort,
            distance=self.distance,
            lossless=self.lossless,
            decodingspeed=self.decodingspeed,
            photometric=self.photometric,
            planar=self.planar,
            usecontainer=self.usecontainer,
            numthreads=self.numthreads,
        )

    def decode(self, buf, out=None):
        return imagecodecs.jpegxl_decode(
            buf,
            index=self.index,
            keeporientation=self.keeporientation,
            numthreads=self.numthreads,
            out=out,
        )
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
class JpegXr(Codec):
    """JPEG XR codec for numcodecs."""

    codec_id = 'imagecodecs_jpegxr'

    def __init__(
        self,
        level=None,
        photometric=None,
        hasalpha=None,
        resolution=None,
        fp2int=None,
    ):
        self.level = level
        self.photometric = photometric
        self.hasalpha = hasalpha
        self.resolution = resolution
        # ``fp2int`` is only used when decoding.
        self.fp2int = fp2int

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        options = {
            'level': self.level,
            'photometric': self.photometric,
            'hasalpha': self.hasalpha,
            'resolution': self.resolution,
        }
        return imagecodecs.jpegxr_encode(image, **options)

    def decode(self, buf, out=None):
        return imagecodecs.jpegxr_decode(buf, fp2int=self.fp2int, out=out)
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
class Lerc(Codec):
    """LERC codec for numcodecs."""

    codec_id = 'imagecodecs_lerc'

    def __init__(self, level=None, version=None, planar=None):
        self.level = level
        self.version = version
        self.planar = bool(planar)
        # TODO: support mask?
        # self.mask = None

    def encode(self, buf):
        image = protective_squeeze(numpy.asarray(buf))
        options = {
            'level': self.level,
            'version': self.version,
            'planar': self.planar,
        }
        return imagecodecs.lerc_encode(image, **options)

    def decode(self, buf, out=None):
        return imagecodecs.lerc_decode(buf, out=out)
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
class Ljpeg(Codec):
    """LJPEG codec for numcodecs."""

    codec_id = 'imagecodecs_ljpeg'

    def __init__(self, bitspersample=None):
        # Bits per sample; None uses the library default.
        self.bitspersample = bitspersample

    def encode(self, buf):
        # Collapse leading singleton dims before handing to the encoder.
        buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.ljpeg_encode(buf, bitspersample=self.bitspersample)

    def decode(self, buf, out=None):
        return imagecodecs.ljpeg_decode(buf, out=out)
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
class Lz4(Codec):
    """LZ4 codec for numcodecs."""

    codec_id = 'imagecodecs_lz4'

    def __init__(self, level=None, hc=False, header=True):
        self.level = level
        # ``hc`` selects the high-compression encoder variant (encode only).
        self.hc = hc
        # ``header`` is passed to both encode and decode — both sides
        # must agree on this setting.
        self.header = bool(header)

    def encode(self, buf):
        return imagecodecs.lz4_encode(
            buf, level=self.level, hc=self.hc, header=self.header
        )

    def decode(self, buf, out=None):
        # Decode into a flattened view of ``out`` when one is provided.
        return imagecodecs.lz4_decode(buf, header=self.header, out=_flat(out))
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
class Lz4f(Codec):
    """LZ4F codec for numcodecs."""

    codec_id = 'imagecodecs_lz4f'

    def __init__(
        self,
        level=None,
        blocksizeid=False,
        contentchecksum=None,
        blockchecksum=None,
    ):
        self.level = level
        self.blocksizeid = blocksizeid
        self.contentchecksum = contentchecksum
        self.blockchecksum = blockchecksum

    def encode(self, buf):
        options = {
            'level': self.level,
            'blocksizeid': self.blocksizeid,
            'contentchecksum': self.contentchecksum,
            'blockchecksum': self.blockchecksum,
        }
        return imagecodecs.lz4f_encode(buf, **options)

    def decode(self, buf, out=None):
        # Decode into a flattened view of ``out`` when one is provided.
        return imagecodecs.lz4f_decode(buf, out=_flat(out))
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
class Lzf(Codec):
    """LZF codec for numcodecs.

    Thin wrapper around ``imagecodecs.lzf_encode``/``lzf_decode``.
    """

    codec_id = 'imagecodecs_lzf'

    def __init__(self, header=True):
        # header: whether the stream carries an LZF length header
        self.header = bool(header)

    def encode(self, buf):
        return imagecodecs.lzf_encode(buf, header=self.header)

    def decode(self, buf, out=None):
        # _flat: decode into the caller's buffer only if flat and writable
        return imagecodecs.lzf_decode(buf, header=self.header, out=_flat(out))
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
class Lzma(Codec):
    """LZMA codec for numcodecs.

    Thin wrapper around ``imagecodecs.lzma_encode``/``lzma_decode``.
    """

    codec_id = 'imagecodecs_lzma'

    def __init__(self, level=None):
        # level: compression level; None uses the library default
        self.level = level

    def encode(self, buf):
        return imagecodecs.lzma_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        # _flat: decode into the caller's buffer only if flat and writable
        return imagecodecs.lzma_decode(buf, out=_flat(out))
|
| 998 |
+
|
| 999 |
+
|
| 1000 |
+
class Lzw(Codec):
    """LZW codec for numcodecs.

    Thin, parameterless wrapper around ``imagecodecs.lzw_encode``/``lzw_decode``.
    """

    codec_id = 'imagecodecs_lzw'

    def encode(self, buf):
        return imagecodecs.lzw_encode(buf)

    def decode(self, buf, out=None):
        # _flat: decode into the caller's buffer only if flat and writable
        return imagecodecs.lzw_decode(buf, out=_flat(out))
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
class PackBits(Codec):
    """PackBits codec for numcodecs.

    Thin wrapper around ``imagecodecs.packbits_encode``/``packbits_decode``.
    """

    codec_id = 'imagecodecs_packbits'

    def __init__(self, axis=None):
        # axis: array axis along which runs are encoded; None for flat bytes
        self.axis = axis

    def encode(self, buf):
        # raw byte strings pass through untouched; arrays are squeezed first
        if not isinstance(buf, (bytes, bytearray)):
            buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.packbits_encode(buf, axis=self.axis)

    def decode(self, buf, out=None):
        # _flat: decode into the caller's buffer only if flat and writable
        return imagecodecs.packbits_decode(buf, out=_flat(out))
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
+
class Pglz(Codec):
    """PGLZ codec for numcodecs.

    Thin wrapper around ``imagecodecs.pglz_encode``/``pglz_decode``.
    """

    codec_id = 'imagecodecs_pglz'

    def __init__(self, header=True, strategy=None):
        self.header = bool(header)
        self.strategy = strategy

    def encode(self, buf):
        # forward all options to imagecodecs
        options = dict(strategy=self.strategy, header=self.header)
        return imagecodecs.pglz_encode(buf, **options)

    def decode(self, buf, out=None):
        # decode into the caller's buffer only when it is flat and writable
        return imagecodecs.pglz_decode(buf, header=self.header, out=_flat(out))
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
class Png(Codec):
    """PNG codec for numcodecs.

    Thin wrapper around ``imagecodecs.png_encode``/``png_decode``.
    """

    codec_id = 'imagecodecs_png'

    def __init__(self, level=None):
        # level: compression level; None uses the library default
        self.level = level

    def encode(self, buf):
        # squeeze singleton dimensions the image codec cannot represent
        buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.png_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        return imagecodecs.png_decode(buf, out=out)
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
class Qoi(Codec):
    """QOI codec for numcodecs.

    Thin, parameterless wrapper around ``imagecodecs.qoi_encode``/``qoi_decode``.
    """

    codec_id = 'imagecodecs_qoi'

    def __init__(self):
        pass

    def encode(self, buf):
        # squeeze singleton dimensions the image codec cannot represent
        buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.qoi_encode(buf)

    def decode(self, buf, out=None):
        return imagecodecs.qoi_decode(buf, out=out)
|
| 1077 |
+
|
| 1078 |
+
|
| 1079 |
+
class Rgbe(Codec):
    """RGBE codec for numcodecs.

    Thin wrapper around ``imagecodecs.rgbe_encode``/``rgbe_decode``.
    """

    codec_id = 'imagecodecs_rgbe'

    def __init__(self, header=False, shape=None, rle=None):
        # without an embedded header the decoder needs an explicit shape
        if not header and shape is None:
            raise ValueError('must specify data shape if no header')
        # RGBE data always has 3 channels in the last dimension
        if shape and shape[-1] != 3:
            raise ValueError('invalid shape')
        self.shape = shape
        self.header = bool(header)
        self.rle = bool(rle) if rle is not None else None

    def encode(self, buf):
        data = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.rgbe_encode(data, header=self.header, rle=self.rle)

    def decode(self, buf, out=None):
        if out is None and not self.header:
            # headerless streams carry no geometry; preallocate from self.shape
            out = numpy.empty(self.shape, numpy.float32)
        return imagecodecs.rgbe_decode(
            buf, header=self.header, rle=self.rle, out=out
        )
|
| 1103 |
+
|
| 1104 |
+
|
| 1105 |
+
class Rcomp(Codec):
    """Rcomp codec for numcodecs.

    Thin wrapper around ``imagecodecs.rcomp_encode``/``rcomp_decode``.
    """

    codec_id = 'imagecodecs_rcomp'

    def __init__(self, shape, dtype, nblock=None):
        # geometry is mandatory: the stream itself carries no shape/dtype
        self.shape = tuple(shape)
        self.dtype = numpy.dtype(dtype).str
        self.nblock = nblock

    def encode(self, buf):
        return imagecodecs.rcomp_encode(buf, nblock=self.nblock)

    def decode(self, buf, out=None):
        # supply the geometry recorded at construction time
        decode_opts = dict(
            shape=self.shape,
            dtype=self.dtype,
            nblock=self.nblock,
            out=out,
        )
        return imagecodecs.rcomp_decode(buf, **decode_opts)
|
| 1126 |
+
|
| 1127 |
+
|
| 1128 |
+
class Snappy(Codec):
    """Snappy codec for numcodecs.

    Thin, parameterless wrapper around ``imagecodecs.snappy_encode``/``snappy_decode``.
    """

    codec_id = 'imagecodecs_snappy'

    def encode(self, buf):
        return imagecodecs.snappy_encode(buf)

    def decode(self, buf, out=None):
        # _flat: decode into the caller's buffer only if flat and writable
        return imagecodecs.snappy_decode(buf, out=_flat(out))
|
| 1138 |
+
|
| 1139 |
+
|
| 1140 |
+
class Spng(Codec):
    """SPNG codec for numcodecs.

    Thin wrapper around ``imagecodecs.spng_encode``/``spng_decode``.
    """

    codec_id = 'imagecodecs_spng'

    def __init__(self, level=None):
        # level: compression level; None uses the library default
        self.level = level

    def encode(self, buf):
        # squeeze singleton dimensions the image codec cannot represent
        buf = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.spng_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        return imagecodecs.spng_decode(buf, out=out)
|
| 1154 |
+
|
| 1155 |
+
|
| 1156 |
+
class Tiff(Codec):
    """TIFF codec for numcodecs.

    Thin wrapper around ``imagecodecs.tiff_encode``/``tiff_decode``.
    """

    codec_id = 'imagecodecs_tiff'

    def __init__(self, index=None, asrgb=None, verbose=0):
        self.index = index
        self.asrgb = bool(asrgb)
        self.verbose = verbose

    def encode(self, buf):
        # TODO: not implemented
        data = protective_squeeze(numpy.asarray(buf))
        return imagecodecs.tiff_encode(data)

    def decode(self, buf, out=None):
        # forward all decode options to imagecodecs
        decode_opts = dict(
            index=self.index,
            asrgb=self.asrgb,
            verbose=self.verbose,
            out=out,
        )
        return imagecodecs.tiff_decode(buf, **decode_opts)
|
| 1179 |
+
|
| 1180 |
+
|
| 1181 |
+
class Webp(Codec):
    """WebP codec for numcodecs.

    Thin wrapper around ``imagecodecs.webp_encode``/``webp_decode``.
    """

    codec_id = 'imagecodecs_webp'

    def __init__(self, level=None, lossless=None, method=None, hasalpha=None):
        self.level = level
        self.hasalpha = bool(hasalpha)
        self.method = method
        self.lossless = lossless

    def encode(self, buf):
        frame = protective_squeeze(numpy.asarray(buf))
        # forward all encode options to imagecodecs
        return imagecodecs.webp_encode(
            frame,
            level=self.level,
            lossless=self.lossless,
            method=self.method,
        )

    def decode(self, buf, out=None):
        return imagecodecs.webp_decode(buf, hasalpha=self.hasalpha, out=out)
|
| 1200 |
+
|
| 1201 |
+
|
| 1202 |
+
class Xor(Codec):
    """XOR codec for numcodecs.

    Thin wrapper around ``imagecodecs.xor_encode``/``xor_decode``.
    """

    codec_id = 'imagecodecs_xor'

    def __init__(self, shape=None, dtype=None, axis=-1):
        self.shape = tuple(shape) if shape is not None else None
        self.dtype = numpy.dtype(dtype).str if dtype is not None else None
        self.axis = axis

    def encode(self, buf):
        if not (self.shape is None and self.dtype is None):
            # geometry was pinned at construction: enforce it on the input
            buf = protective_squeeze(numpy.asarray(buf))
            assert buf.shape == self.shape
            assert buf.dtype == self.dtype
        return imagecodecs.xor_encode(buf, axis=self.axis).tobytes()

    def decode(self, buf, out=None):
        if not (self.shape is None and self.dtype is None):
            # rebuild the array view from the configured geometry
            buf = numpy.frombuffer(buf, dtype=self.dtype).reshape(*self.shape)
        return imagecodecs.xor_decode(buf, axis=self.axis, out=_flat(out))
|
| 1223 |
+
|
| 1224 |
+
|
| 1225 |
+
class Zfp(Codec):
    """ZFP codec for numcodecs.

    Thin wrapper around ``imagecodecs.zfp_encode``/``zfp_decode``.
    """

    codec_id = 'imagecodecs_zfp'

    def __init__(
        self,
        shape=None,
        dtype=None,
        strides=None,
        level=None,
        mode=None,
        execution=None,
        numthreads=None,
        chunksize=None,
        header=True,
    ):
        if header:
            # geometry travels in the stream header; ignore explicit geometry
            self.shape = self.dtype = self.strides = None
        elif shape is None or dtype is None:
            raise ValueError('invalid shape or dtype')
        else:
            # headerless mode: geometry must be pinned at construction time
            self.shape = tuple(shape)
            self.dtype = numpy.dtype(dtype).str
            self.strides = tuple(strides) if strides is not None else None
        self.level = level
        self.mode = mode
        self.execution = execution
        self.numthreads = numthreads
        self.chunksize = chunksize
        self.header = bool(header)

    def encode(self, buf):
        data = protective_squeeze(numpy.asarray(buf))
        if not self.header:
            # headerless streams can only round-trip the configured geometry
            assert data.shape == self.shape
            assert data.dtype == self.dtype
        return imagecodecs.zfp_encode(
            data,
            level=self.level,
            mode=self.mode,
            execution=self.execution,
            header=self.header,
            numthreads=self.numthreads,
            chunksize=self.chunksize,
        )

    def decode(self, buf, out=None):
        if self.header:
            return imagecodecs.zfp_decode(buf, out=out)
        # headerless: supply the geometry recorded at construction time
        return imagecodecs.zfp_decode(
            buf,
            shape=self.shape,
            dtype=numpy.dtype(self.dtype),
            strides=self.strides,
            numthreads=self.numthreads,
            out=out,
        )
|
| 1285 |
+
|
| 1286 |
+
|
| 1287 |
+
class Zlib(Codec):
    """Zlib codec for numcodecs.

    Thin wrapper around ``imagecodecs.zlib_encode``/``zlib_decode``.
    """

    codec_id = 'imagecodecs_zlib'

    def __init__(self, level=None):
        # level: compression level; None uses the library default
        self.level = level

    def encode(self, buf):
        return imagecodecs.zlib_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        # _flat: decode into the caller's buffer only if flat and writable
        return imagecodecs.zlib_decode(buf, out=_flat(out))
|
| 1300 |
+
|
| 1301 |
+
|
| 1302 |
+
class Zlibng(Codec):
    """Zlibng codec for numcodecs.

    Thin wrapper around ``imagecodecs.zlibng_encode``/``zlibng_decode``.
    """

    codec_id = 'imagecodecs_zlibng'

    def __init__(self, level=None):
        # level: compression level; None uses the library default
        self.level = level

    def encode(self, buf):
        return imagecodecs.zlibng_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        # _flat: decode into the caller's buffer only if flat and writable
        return imagecodecs.zlibng_decode(buf, out=_flat(out))
|
| 1315 |
+
|
| 1316 |
+
|
| 1317 |
+
class Zopfli(Codec):
    """Zopfli codec for numcodecs.

    Thin, parameterless wrapper around ``imagecodecs.zopfli_encode``/``zopfli_decode``.
    """

    codec_id = 'imagecodecs_zopfli'

    def encode(self, buf):
        return imagecodecs.zopfli_encode(buf)

    def decode(self, buf, out=None):
        # _flat: decode into the caller's buffer only if flat and writable
        return imagecodecs.zopfli_decode(buf, out=_flat(out))
|
| 1327 |
+
|
| 1328 |
+
|
| 1329 |
+
class Zstd(Codec):
    """ZStandard codec for numcodecs.

    Thin wrapper around ``imagecodecs.zstd_encode``/``zstd_decode``.
    """

    codec_id = 'imagecodecs_zstd'

    def __init__(self, level=None):
        # level: compression level; None uses the library default
        self.level = level

    def encode(self, buf):
        return imagecodecs.zstd_encode(buf, level=self.level)

    def decode(self, buf, out=None):
        # _flat: decode into the caller's buffer only if flat and writable
        return imagecodecs.zstd_decode(buf, out=_flat(out))
|
| 1342 |
+
|
| 1343 |
+
|
| 1344 |
+
def _flat(out):
|
| 1345 |
+
"""Return numpy array as contiguous view of bytes if possible."""
|
| 1346 |
+
if out is None:
|
| 1347 |
+
return None
|
| 1348 |
+
view = memoryview(out)
|
| 1349 |
+
if view.readonly or not view.contiguous:
|
| 1350 |
+
return None
|
| 1351 |
+
return view.cast('B')
|
| 1352 |
+
|
| 1353 |
+
|
| 1354 |
+
def register_codecs(codecs=None, force=False, verbose=True):
    """Register codecs in this module with numcodecs.

    Parameters:
        codecs: optional collection of codec_id strings; when given, only
            classes whose codec_id is listed are registered.
        force: replace codecs that numcodecs has already registered.
        verbose: warn when a codec is already registered / being replaced.
    """
    for name, cls in globals().items():
        # only consider codec classes; skip the imported Codec base itself
        if not hasattr(cls, 'codec_id') or name == 'Codec':
            continue
        if codecs is not None and cls.codec_id not in codecs:
            continue
        try:
            try:
                # probe whether numcodecs already knows this codec_id
                get_codec({'id': cls.codec_id})
            except TypeError:
                # registered, but failed (instantiation needs arguments);
                # falls through to the outer `else`: treated as registered
                pass
        except ValueError:
            # not registered yet; proceed to register_codec below
            pass
        else:
            # codec_id already registered with numcodecs
            if not force:
                if verbose:
                    log_warning(
                        f'numcodec {cls.codec_id!r} already registered'
                    )
                continue
            if verbose:
                log_warning(f'replacing registered numcodec {cls.codec_id!r}')
        register_codec(cls)
|
| 1380 |
+
|
| 1381 |
+
|
| 1382 |
+
def log_warning(msg, *args, **kwargs):
    """Log *msg* with level WARNING on this module's logger.

    Extra positional/keyword arguments are forwarded to
    ``logging.Logger.warning`` (lazy %-style formatting).
    """
    import logging

    logger = logging.getLogger(__name__)
    logger.warning(msg, *args, **kwargs)
|
equidiff/equi_diffpo/common/checkpoint_util.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Dict
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
class TopKCheckpointManager:
    """
    Track the top-k checkpoints by a monitored metric.

    For each new metric record, decide which path (if any) the caller
    should save a checkpoint to, deleting the evicted checkpoint file
    when the top-k set is full and the new value displaces an entry.
    """

    def __init__(self,
            save_dir,
            monitor_key: str,
            mode='min',
            k=1,
            format_str='epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt'
        ):
        # use explicit exceptions instead of assert: assert is stripped
        # when Python runs with -O, silently disabling validation
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        if k < 0:
            raise ValueError(f"k must be >= 0, got {k}")

        self.save_dir = save_dir
        self.monitor_key = monitor_key
        self.mode = mode
        self.k = k
        self.format_str = format_str
        # maps checkpoint path -> monitored value for checkpoints we kept
        self.path_value_map = dict()

    def get_ckpt_path(self, data: Dict[str, float]) -> Optional[str]:
        """
        Return the path to save a new checkpoint at, or None if the new
        value does not make the top k.

        data must contain monitor_key plus every key used by format_str.
        Ensures save_dir exists whenever a path is returned; may delete
        the file of an evicted checkpoint.
        """
        if self.k == 0:
            return None

        value = data[self.monitor_key]
        ckpt_path = os.path.join(
            self.save_dir, self.format_str.format(**data))

        if len(self.path_value_map) < self.k:
            # under capacity: always accept.
            # (fix: the original returned a path without ensuring the
            # directory exists, so the very first save could fail)
            self.path_value_map[ckpt_path] = value
            os.makedirs(self.save_dir, exist_ok=True)
            return ckpt_path

        # at capacity: find current best/worst entries
        sorted_map = sorted(self.path_value_map.items(), key=lambda x: x[1])
        min_path, min_value = sorted_map[0]
        max_path, max_value = sorted_map[-1]

        if self.mode == 'max':
            # evict the smallest value if the new one beats it
            delete_path = min_path if value > min_value else None
        else:
            # evict the largest value if the new one beats it
            delete_path = max_path if value < max_value else None

        if delete_path is None:
            return None

        del self.path_value_map[delete_path]
        self.path_value_map[ckpt_path] = value

        # fix: os.mkdir fails when save_dir's parent does not exist;
        # makedirs(exist_ok=True) handles nested dirs and races
        os.makedirs(self.save_dir, exist_ok=True)

        if os.path.exists(delete_path):
            os.remove(delete_path)
        return ckpt_path
|
equidiff/equi_diffpo/common/cv2_util.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
import math
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
def draw_reticle(img, u, v, label_color):
    """
    Draw a reticle (cross-hair) on the image, in place, at (u, v).

    @param img (In/Out) uint8 3 channel image
    @param u X coordinate (width)
    @param v Y coordinate (height)
    @param label_color tuple of 3 ints for RGB color used for drawing.
    """
    u, v = int(u), int(v)
    white = (255, 255, 255)

    # three concentric circles: colored, white, colored
    for radius, ring_color in ((10, label_color), (11, white), (12, label_color)):
        cv2.circle(img, (u, v), radius, ring_color, 1)

    # four short white tick marks radiating from just outside the center
    for du, dv in ((0, 1), (1, 0), (0, -1), (-1, 0)):
        cv2.line(img, (u + du, v + dv), (u + 3 * du, v + 3 * dv), white, 1)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def draw_text(
    img,
    *,
    text,
    uv_top_left,
    color=(255, 255, 255),
    fontScale=0.5,
    thickness=1,
    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
    outline_color=(0, 0, 0),
    line_spacing=1.5,
):
    """
    Draw multiline text onto img (in place) with an optional outline,
    anchored at uv_top_left = (u, v) of the first line's top-left corner.
    """
    assert isinstance(text, str)

    cursor = np.array(uv_top_left, dtype=float)
    assert cursor.shape == (2,)

    def put_line(line, org, line_color, line_thickness):
        # single cv2.putText call with the shared font settings
        cv2.putText(
            img,
            text=line,
            org=org,
            fontFace=fontFace,
            fontScale=fontScale,
            color=line_color,
            thickness=line_thickness,
            lineType=cv2.LINE_AA,
        )

    for line in text.splitlines():
        (w, h), _ = cv2.getTextSize(
            text=line,
            fontFace=fontFace,
            fontScale=fontScale,
            thickness=thickness,
        )
        # cv2 anchors text at the bottom-left corner of the line
        org = tuple((cursor + [0, h]).astype(int))

        if outline_color is not None:
            # thicker pass underneath acts as the outline
            put_line(line, org, outline_color, thickness * 3)
        put_line(line, org, color, thickness)

        # advance the cursor to the next line
        cursor += [0, h * line_spacing]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_image_transform(
        input_res: Tuple[int,int]=(1280,720),
        output_res: Tuple[int,int]=(640,480),
        bgr_to_rgb: bool=False):
    """
    Build a function that resizes an (ih, iw, 3) image to output_res by
    scaling to cover the output aspect ratio and center-cropping the excess,
    optionally reversing the channel order (BGR <-> RGB).
    """
    iw, ih = input_res
    ow, oh = output_res
    rw, rh = None, None  # intermediate resize resolution before cropping
    interp_method = cv2.INTER_AREA

    if (iw/ih) >= (ow/oh):
        # input is wider: match heights, width overshoots and is cropped
        rh = oh
        rw = math.ceil(rh / ih * iw)
        if oh > ih:
            # upscaling: INTER_AREA is meant for shrinking, use linear
            interp_method = cv2.INTER_LINEAR
    else:
        # input is taller: match widths, height overshoots and is cropped
        rw = ow
        rh = math.ceil(rw / iw * ih)
        if ow > iw:
            interp_method = cv2.INTER_LINEAR

    # center-crop slices for the overshooting dimension
    w_slice_start = (rw - ow) // 2
    w_slice = slice(w_slice_start, w_slice_start + ow)
    h_slice_start = (rh - oh) // 2
    h_slice = slice(h_slice_start, h_slice_start + oh)
    c_slice = slice(None)
    if bgr_to_rgb:
        # reverse channel order
        c_slice = slice(None, None, -1)

    def transform(img: np.ndarray):
        assert img.shape == ((ih,iw,3))
        # resize
        img = cv2.resize(img, (rw, rh), interpolation=interp_method)
        # crop
        img = img[h_slice, w_slice, c_slice]
        return img
    return transform
|
| 122 |
+
|
| 123 |
+
def optimal_row_cols(
        n_cameras,
        in_wh_ratio,
        max_resolution=(1920, 1080)
    ):
    """
    Choose a (rows x cols) grid for tiling n_cameras feeds onto a canvas
    of at most max_resolution, picking the grid whose aspect ratio best
    matches the canvas.

    Returns (tile_width, tile_height, n_cols, n_rows).
    """
    out_w, out_h = max_resolution
    out_wh_ratio = out_w / out_h

    # candidate layouts: 1..n_cameras rows, just enough columns for each
    rows = np.arange(n_cameras, dtype=np.int64) + 1
    cols = np.ceil(n_cameras / rows).astype(np.int64)
    # aspect ratio of the concatenated grid for each candidate layout
    grid_wh_ratio = in_wh_ratio * (cols / rows)
    # pick the layout whose grid ratio is closest to the canvas ratio
    best = int(np.argmin(np.abs(out_wh_ratio - grid_wh_ratio)))
    best_n_row = rows[best]
    best_n_col = cols[best]

    if grid_wh_ratio[best] >= out_wh_ratio:
        # grid is wider than the canvas: width is the binding constraint
        rw = math.floor(out_w / best_n_col)
        rh = math.floor(rw / in_wh_ratio)
    else:
        # grid is taller than the canvas: height is the binding constraint
        rh = math.floor(out_h / best_n_row)
        rw = math.floor(rh * in_wh_ratio)

    return rw, rh, best_n_col, best_n_row
|
equidiff/equi_diffpo/common/env_util.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def render_env_video(env, states, actions=None):
    """
    Replay a sequence of env states and return the rendered frames as a
    single stacked numpy array, optionally overlaying an action marker.
    """
    observations = states
    imgs = list()
    for i in range(len(observations)):
        state = observations[i]
        env.set_state(state)
        if i == 0:
            # NOTE(review): set_state is deliberately called twice for the
            # first frame — presumably so the renderer fully syncs after a
            # reset; confirm against the env implementation
            env.set_state(state)
        img = env.render()
        # draw action
        if actions is not None:
            action = actions[i]
            # assumes actions live in a 512-unit coordinate space and the
            # render is 96x96 pixels — TODO confirm
            coord = (action / 512 * 96).astype(np.int32)
            cv2.drawMarker(img, coord,
                    color=(255,0,0), markerType=cv2.MARKER_CROSS,
                    markerSize=8, thickness=1)
        imgs.append(img)
    imgs = np.array(imgs)
    return imgs
|
equidiff/equi_diffpo/common/json_logger.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Callable, Any, Sequence
|
| 2 |
+
import os
|
| 3 |
+
import copy
|
| 4 |
+
import json
|
| 5 |
+
import numbers
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def read_json_log(path: str,
                  required_keys: Sequence[str]=tuple(),
                  **kwargs) -> pd.DataFrame:
    """
    Read a json-per-line log file into a DataFrame, tolerating a trailing
    incomplete line (e.g. left by a crashed writer).

    Only lines containing at least one of required_keys (substring match on
    the raw line) are kept; with the default empty tuple nothing matches and
    an empty DataFrame is returned. kwargs are passed to pd.read_json.
    """
    kept = list()
    with open(path, 'r') as f:
        for raw in f:
            if not raw.endswith('\n'):
                # incomplete final line: writer stopped mid-record
                break
            # keep only lines mentioning at least one required key
            if any(key in raw for key in required_keys):
                kept.append(raw)
    if not kept:
        return pd.DataFrame()
    stripped = (line.strip() for line in kept)
    json_buf = f'[{",".join(line for line in stripped if line)}]'
    return pd.read_json(json_buf, **kwargs)
|
| 39 |
+
|
| 40 |
+
class JsonLogger:
    """
    Append-only json-per-line logger that can resume an existing file,
    recovering the last complete record and truncating any partially
    written trailing line.
    """

    def __init__(self, path: str,
            filter_fn: Optional[Callable[[str,Any],bool]]=None):
        if filter_fn is None:
            # by default only numeric values are logged
            filter_fn = lambda k,v: isinstance(v, numbers.Number)

        # default to append mode
        self.path = path
        self.filter_fn = filter_fn
        self.file = None      # open handle between start() and stop()
        self.last_log = None  # most recent logged (or recovered) record

    def start(self):
        """Open the log file, recover the last complete record, and
        truncate any incomplete trailing line."""
        # use line buffering
        try:
            self.file = file = open(self.path, 'r+', buffering=1)
        except FileNotFoundError:
            # file does not exist yet: create it
            self.file = file = open(self.path, 'w+', buffering=1)

        # Move the pointer (similar to a cursor in a text editor) to the end of the file
        pos = file.seek(0, os.SEEK_END)

        # Read each character in the file one at a time from the last
        # character going backwards, searching for a newline character
        # If we find a new line, exit the search
        while pos > 0 and file.read(1) != "\n":
            pos -= 1
            file.seek(pos, os.SEEK_SET)
        # now the file pointer is at one past the last '\n'
        # and pos is at the last '\n'.
        last_line_end = file.tell()

        # find the start of second last line
        pos = max(0, pos-1)
        file.seek(pos, os.SEEK_SET)
        while pos > 0 and file.read(1) != "\n":
            pos -= 1
            file.seek(pos, os.SEEK_SET)
        # now the file pointer is at one past the second last '\n'
        last_line_start = file.tell()

        if last_line_start < last_line_end:
            # has last line of json: recover it as the last known record
            last_line = file.readline()
            self.last_log = json.loads(last_line)

        # remove the last incomplete line (everything past the last '\n')
        file.seek(last_line_end)
        file.truncate()

    def stop(self):
        """Close the log file."""
        self.file.close()
        self.file = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def log(self, data: dict):
        """Filter *data* with filter_fn and append it as one json line."""
        filtered_data = dict(
            filter(lambda x: self.filter_fn(*x), data.items()))
        # save current as last log (note: the int/float coercion below also
        # applies to last_log, since both names reference the same dict)
        self.last_log = filtered_data
        for k, v in filtered_data.items():
            # coerce numpy scalars etc. to plain int/float for json.dumps
            if isinstance(v, numbers.Integral):
                filtered_data[k] = int(v)
            elif isinstance(v, numbers.Number):
                filtered_data[k] = float(v)
        buf = json.dumps(filtered_data)
        # ensure one line per json
        buf = buf.replace('\n','') + '\n'
        self.file.write(buf)

    def get_last_log(self):
        """Return a deep copy of the most recent record, or None."""
        return copy.deepcopy(self.last_log)
|
equidiff/equi_diffpo/common/nested_dict_util.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
|
| 3 |
+
def nested_dict_map(f, x):
    """
    Apply f to every leaf of nested dict x, returning a new nested dict
    with the same structure. A non-dict x is treated as a leaf.
    """
    if isinstance(x, dict):
        return {key: nested_dict_map(f, value) for key, value in x.items()}
    return f(x)
|
| 14 |
+
|
| 15 |
+
def nested_dict_reduce(f, x):
    """
    Reduce all leaf values of nested dict x to a single value with binary
    function f. A non-dict x is returned unchanged (f is not applied).
    """
    if not isinstance(x, dict):
        return x
    # reduce each subtree first, then fold the results together
    return functools.reduce(
        f, (nested_dict_reduce(f, value) for value in x.values()))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def nested_dict_check(f, x):
    """Return whether predicate f holds for every leaf of nested dict x."""
    leaf_flags = nested_dict_map(f, x)
    return nested_dict_reduce(lambda a, b: a and b, leaf_flags)
|
equidiff/equi_diffpo/common/normalize_util.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from equi_diffpo.model.common.normalizer import SingleFieldLinearNormalizer
|
| 2 |
+
from equi_diffpo.common.pytorch_util import dict_apply, dict_apply_reduce, dict_apply_split
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_range_normalizer_from_stat(stat, output_max=1, output_min=-1, range_eps=1e-7):
    """Build a linear normalizer mapping [stat['min'], stat['max']] per
    dimension onto [output_min, output_max].

    Args:
        stat: dict with 'min'/'max' (and typically 'mean'/'std') arrays.
        output_max, output_min: target output range (default [-1, 1]).
        range_eps: dimensions whose extent is below this threshold are not
            scaled; they are mapped to the midpoint of the output range.

    Returns:
        SingleFieldLinearNormalizer built from the computed scale/offset.
    """
    lo = stat['min']
    hi = stat['max']
    extent = hi - lo
    # Near-constant dimensions would produce an exploding scale; substitute
    # the output extent so the division is benign, then fix the offset below.
    degenerate = extent < range_eps
    extent[degenerate] = output_max - output_min
    scale = (output_max - output_min) / extent
    offset = output_min - scale * lo
    # Degenerate dims land on the midpoint of the output range.
    offset[degenerate] = (output_max + output_min) / 2 - lo[degenerate]

    return SingleFieldLinearNormalizer.create_manual(
        scale=scale,
        offset=offset,
        input_stats_dict=stat
    )
|
| 22 |
+
|
| 23 |
+
def get_range_symmetric_normalizer_from_stat(stat, output_max=1, output_min=-1, range_eps=1e-7):
    """Range-normalize to [output_min, output_max], first symmetrizing the
    bounds of the leading two dimensions (x, y) around zero so both map to
    [-abs_max, abs_max].

    Args:
        stat: dict with 'min'/'max' (and typically 'mean'/'std') arrays.
        output_max, output_min: target output range (default [-1, 1]).
        range_eps: dimensions with extent below this are mapped to the
            midpoint of the output range instead of being scaled.

    Returns:
        SingleFieldLinearNormalizer built from the computed scale/offset.
    """
    # BUGFIX: work on copies. The previous version assigned into
    # stat['max'][:2] / stat['min'][:2] through views, mutating the caller's
    # stat arrays in place and corrupting stats shared with other normalizers.
    input_max = np.array(stat['max'], copy=True)
    input_min = np.array(stat['min'], copy=True)
    abs_max = np.max([np.abs(stat['max'][:2]), np.abs(stat['min'][:2])])
    input_max[:2] = abs_max
    input_min[:2] = -abs_max
    input_range = input_max - input_min
    # Near-constant dims would blow up the scale; handle them specially.
    ignore_dim = input_range < range_eps
    input_range[ignore_dim] = output_max - output_min
    scale = (output_max - output_min) / input_range
    offset = output_min - scale * input_min
    # Degenerate dims land on the midpoint of the output range.
    offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]

    # Preserve the previously stored behavior (the recorded stats carried the
    # symmetrized bounds) without touching the input dict.
    sym_stat = dict(stat)
    sym_stat['max'] = input_max
    sym_stat['min'] = input_min
    return SingleFieldLinearNormalizer.create_manual(
        scale=scale,
        offset=offset,
        input_stats_dict=sym_stat
    )
|
| 42 |
+
|
| 43 |
+
def get_voxel_identity_normalizer():
    """Identity normalizer (scale=1, offset=0) for voxel data already in [0, 1].

    The recorded stats describe a uniform distribution on [0, 1]
    (mean 0.5, std sqrt(1/12)).
    """
    unit_stat = {
        'min': np.zeros(1, dtype=np.float32),
        'max': np.ones(1, dtype=np.float32),
        'mean': np.full(1, 0.5, dtype=np.float32),
        'std': np.full(1, np.sqrt(1 / 12), dtype=np.float32),
    }
    return SingleFieldLinearNormalizer.create_manual(
        scale=np.ones(1, dtype=np.float32),
        offset=np.zeros(1, dtype=np.float32),
        input_stats_dict=unit_stat,
    )
|
| 57 |
+
|
| 58 |
+
def get_image_range_normalizer():
    """Normalizer mapping image values from [0, 1] to [-1, 1] via x * 2 - 1.

    The recorded stats describe a uniform distribution on [0, 1]
    (mean 0.5, std sqrt(1/12)).
    """
    unit_stat = {
        'min': np.zeros(1, dtype=np.float32),
        'max': np.ones(1, dtype=np.float32),
        'mean': np.full(1, 0.5, dtype=np.float32),
        'std': np.full(1, np.sqrt(1 / 12), dtype=np.float32),
    }
    return SingleFieldLinearNormalizer.create_manual(
        scale=np.full(1, 2, dtype=np.float32),
        offset=np.full(1, -1, dtype=np.float32),
        input_stats_dict=unit_stat,
    )
|
| 72 |
+
|
| 73 |
+
def get_identity_normalizer_from_stat(stat):
    """Pass-through normalizer (scale=1, offset=0) shaped like stat['min'],
    retaining the provided stats for bookkeeping."""
    template = stat['min']
    return SingleFieldLinearNormalizer.create_manual(
        scale=np.ones_like(template),
        offset=np.zeros_like(template),
        input_stats_dict=stat,
    )
|
| 81 |
+
|
| 82 |
+
def robomimic_abs_action_normalizer_from_stat(stat, rotation_transformer):
    """Build an absolute-action normalizer: range-scale position (dims 0:3)
    to [-1, 1], pass rotation (dims 3:6) and gripper (dims 6:) through
    unchanged.

    Args:
        stat: dict with 'min'/'max'/'mean'/'std' arrays; last axis is the
            action dimension, laid out [pos(3), rot(3), gripper(...)].
        rotation_transformer: object whose forward() maps rotation vectors to
            the target rotation representation — type defined elsewhere in
            the project; confirm against the caller.

    Returns:
        SingleFieldLinearNormalizer with per-dimension scale/offset.
    """
    # Split each stat array into position / rotation / gripper slices.
    result = dict_apply_split(
        stat, lambda x: {
            'pos': x[...,:3],
            'rot': x[...,3:6],
            'gripper': x[...,6:]
        })

    def get_pos_param_info(stat, output_max=1, output_min=-1, range_eps=1e-7):
        # -1, 1 normalization
        input_max = stat['max']
        input_min = stat['min']
        input_range = input_max - input_min
        # Near-constant dims would explode the scale; handle them below.
        ignore_dim = input_range < range_eps
        input_range[ignore_dim] = output_max - output_min
        scale = (output_max - output_min) / input_range
        offset = output_min - scale * input_min
        # Constant dims map to the midpoint of the output range.
        offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]

        return {'scale': scale, 'offset': offset}, stat

    def get_rot_param_info(stat):
        # Identity scaling; only used to size the params via the transformed
        # rotation representation. Reported stats are the canonical [-1, 1].
        example = rotation_transformer.forward(stat['mean'])
        scale = np.ones_like(example)
        offset = np.zeros_like(example)
        info = {
            'max': np.ones_like(example),
            'min': np.full_like(example, -1),
            'mean': np.zeros_like(example),
            'std': np.ones_like(example)
        }
        return {'scale': scale, 'offset': offset}, info

    def get_gripper_param_info(stat):
        # Identity scaling for gripper dims; canonical [-1, 1] stats reported.
        example = stat['max']
        scale = np.ones_like(example)
        offset = np.zeros_like(example)
        info = {
            'max': np.ones_like(example),
            'min': np.full_like(example, -1),
            'mean': np.zeros_like(example),
            'std': np.ones_like(example)
        }
        return {'scale': scale, 'offset': offset}, info

    pos_param, pos_info = get_pos_param_info(result['pos'])
    rot_param, rot_info = get_rot_param_info(result['rot'])
    gripper_param, gripper_info = get_gripper_param_info(result['gripper'])

    # Re-assemble per-dimension params/stats in the original action order.
    param = dict_apply_reduce(
        [pos_param, rot_param, gripper_param],
        lambda x: np.concatenate(x,axis=-1))
    info = dict_apply_reduce(
        [pos_info, rot_info, gripper_info],
        lambda x: np.concatenate(x,axis=-1))

    return SingleFieldLinearNormalizer.create_manual(
        scale=param['scale'],
        offset=param['offset'],
        input_stats_dict=info
    )
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def robomimic_abs_action_only_normalizer_from_stat(stat):
    """Build an absolute-action normalizer: range-scale position (dims 0:3)
    to [-1, 1]; pass all remaining dims (presumably rotation + gripper —
    confirm with callers) through unchanged.

    Args:
        stat: dict with 'min'/'max'/'mean'/'std' arrays; last axis is the
            action dimension with position in the first 3 entries.

    Returns:
        SingleFieldLinearNormalizer with per-dimension scale/offset.
    """
    # Split each stat array into position and everything else.
    result = dict_apply_split(
        stat, lambda x: {
            'pos': x[...,:3],
            'other': x[...,3:]
        })

    def get_pos_param_info(stat, output_max=1, output_min=-1, range_eps=1e-7):
        # -1, 1 normalization
        input_max = stat['max']
        input_min = stat['min']
        input_range = input_max - input_min
        # Near-constant dims would explode the scale; handle them below.
        ignore_dim = input_range < range_eps
        input_range[ignore_dim] = output_max - output_min
        scale = (output_max - output_min) / input_range
        offset = output_min - scale * input_min
        # Constant dims map to the midpoint of the output range.
        offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]

        return {'scale': scale, 'offset': offset}, stat


    def get_other_param_info(stat):
        # Identity scaling for the non-position dims; canonical [-1, 1]
        # stats are reported regardless of the observed values.
        example = stat['max']
        scale = np.ones_like(example)
        offset = np.zeros_like(example)
        info = {
            'max': np.ones_like(example),
            'min': np.full_like(example, -1),
            'mean': np.zeros_like(example),
            'std': np.ones_like(example)
        }
        return {'scale': scale, 'offset': offset}, info

    pos_param, pos_info = get_pos_param_info(result['pos'])
    other_param, other_info = get_other_param_info(result['other'])

    # Re-assemble per-dimension params/stats in the original action order.
    param = dict_apply_reduce(
        [pos_param, other_param],
        lambda x: np.concatenate(x,axis=-1))
    info = dict_apply_reduce(
        [pos_info, other_info],
        lambda x: np.concatenate(x,axis=-1))

    return SingleFieldLinearNormalizer.create_manual(
        scale=param['scale'],
        offset=param['offset'],
        input_stats_dict=info
    )
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def robomimic_abs_action_only_symmetric_normalizer_from_stat(stat):
    """Like robomimic_abs_action_only_normalizer_from_stat, but symmetrizes
    the x/y position bounds around zero before range-scaling position (dims
    0:3) to [-1, 1]; remaining dims pass through unchanged.

    Args:
        stat: dict with 'min'/'max'/'mean'/'std' arrays; last axis is the
            action dimension with position in the first 3 entries.

    Returns:
        SingleFieldLinearNormalizer with per-dimension scale/offset.
    """
    result = dict_apply_split(
        stat, lambda x: {
            'pos': x[...,:3],
            'other': x[...,3:]
        })

    def get_pos_param_info(stat, output_max=1, output_min=-1, range_eps=1e-7):
        # -1, 1 normalization with x/y bounds symmetrized to [-abs_max, abs_max].
        # BUGFIX: copy before the [:2] assignments. The slices produced by
        # dict_apply_split are views into the caller's stat arrays, so the
        # previous in-place writes silently mutated the caller's statistics.
        input_max = np.array(stat['max'], copy=True)
        input_min = np.array(stat['min'], copy=True)
        abs_max = np.max([np.abs(stat['max'][:2]), np.abs(stat['min'][:2])])
        input_max[:2] = abs_max
        input_min[:2] = -abs_max
        input_range = input_max - input_min
        # Near-constant dims would explode the scale; handle them below.
        ignore_dim = input_range < range_eps
        input_range[ignore_dim] = output_max - output_min
        scale = (output_max - output_min) / input_range
        offset = output_min - scale * input_min
        # Constant dims map to the midpoint of the output range.
        offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]

        # Report the symmetrized bounds (matches the previously stored stats)
        # without touching the input dict.
        info = dict(stat)
        info['max'] = input_max
        info['min'] = input_min
        return {'scale': scale, 'offset': offset}, info

    def get_other_param_info(stat):
        # Identity scaling for the non-position dims; canonical [-1, 1]
        # stats are reported regardless of the observed values.
        example = stat['max']
        scale = np.ones_like(example)
        offset = np.zeros_like(example)
        info = {
            'max': np.ones_like(example),
            'min': np.full_like(example, -1),
            'mean': np.zeros_like(example),
            'std': np.ones_like(example)
        }
        return {'scale': scale, 'offset': offset}, info

    pos_param, pos_info = get_pos_param_info(result['pos'])
    other_param, other_info = get_other_param_info(result['other'])

    # Re-assemble per-dimension params/stats in the original action order.
    param = dict_apply_reduce(
        [pos_param, other_param],
        lambda x: np.concatenate(x,axis=-1))
    info = dict_apply_reduce(
        [pos_info, other_info],
        lambda x: np.concatenate(x,axis=-1))

    return SingleFieldLinearNormalizer.create_manual(
        scale=param['scale'],
        offset=param['offset'],
        input_stats_dict=info
    )
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat):
    """Dual-arm variant of robomimic_abs_action_only_normalizer_from_stat:
    the action vector is split into two equal halves (one per arm); each
    half's first 3 dims (position) are range-scaled to [-1, 1] and the rest
    pass through unchanged.

    Args:
        stat: dict with 'min'/'max'/'mean'/'std' arrays; last axis is the
            full dual-arm action dimension (assumed even — both halves laid
            out [pos(3), other(...)]).

    Returns:
        SingleFieldLinearNormalizer with per-dimension scale/offset.
    """
    Da = stat['max'].shape[-1]      # total action dim
    Dah = Da // 2                   # per-arm action dim
    # Slice each stat array into per-arm position / other segments.
    result = dict_apply_split(
        stat, lambda x: {
            'pos0': x[...,:3],
            'other0': x[...,3:Dah],
            'pos1': x[...,Dah:Dah+3],
            'other1': x[...,Dah+3:]
        })

    def get_pos_param_info(stat, output_max=1, output_min=-1, range_eps=1e-7):
        # -1, 1 normalization
        input_max = stat['max']
        input_min = stat['min']
        input_range = input_max - input_min
        # Near-constant dims would explode the scale; handle them below.
        ignore_dim = input_range < range_eps
        input_range[ignore_dim] = output_max - output_min
        scale = (output_max - output_min) / input_range
        offset = output_min - scale * input_min
        # Constant dims map to the midpoint of the output range.
        offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]

        return {'scale': scale, 'offset': offset}, stat


    def get_other_param_info(stat):
        # Identity scaling for non-position dims; canonical [-1, 1] stats.
        example = stat['max']
        scale = np.ones_like(example)
        offset = np.zeros_like(example)
        info = {
            'max': np.ones_like(example),
            'min': np.full_like(example, -1),
            'mean': np.zeros_like(example),
            'std': np.ones_like(example)
        }
        return {'scale': scale, 'offset': offset}, info

    pos0_param, pos0_info = get_pos_param_info(result['pos0'])
    pos1_param, pos1_info = get_pos_param_info(result['pos1'])
    other0_param, other0_info = get_other_param_info(result['other0'])
    other1_param, other1_info = get_other_param_info(result['other1'])

    # Re-assemble per-dimension params/stats in the original action order.
    param = dict_apply_reduce(
        [pos0_param, other0_param, pos1_param, other1_param],
        lambda x: np.concatenate(x,axis=-1))
    info = dict_apply_reduce(
        [pos0_info, other0_info, pos1_info, other1_info],
        lambda x: np.concatenate(x,axis=-1))

    return SingleFieldLinearNormalizer.create_manual(
        scale=param['scale'],
        offset=param['offset'],
        input_stats_dict=info
    )
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def array_to_stats(arr: np.ndarray):
    """Summarize an array with per-dimension min/max/mean/std over axis 0."""
    reducers = {
        'min': np.min,
        'max': np.max,
        'mean': np.mean,
        'std': np.std,
    }
    return {name: reduce(arr, axis=0) for name, reduce in reducers.items()}
|
equidiff/equi_diffpo/common/pose_trajectory_interpolator.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
import numbers
|
| 3 |
+
import numpy as np
|
| 4 |
+
import scipy.interpolate as si
|
| 5 |
+
import scipy.spatial.transform as st
|
| 6 |
+
|
| 7 |
+
def rotation_distance(a: st.Rotation, b: st.Rotation) -> float:
    """Angular magnitude (radians) of the relative rotation taking a to b."""
    relative = b * a.inv()
    return relative.magnitude()
|
| 9 |
+
|
| 10 |
+
def pose_distance(start_pose, end_pose):
    """Positional (Euclidean) and rotational (angular, radians) distance
    between two 6D poses laid out as [x, y, z, rotvec_x, rotvec_y, rotvec_z].

    Returns:
        (pos_dist, rot_dist) tuple of floats.
    """
    sp = np.array(start_pose)
    ep = np.array(end_pose)
    pos_dist = np.linalg.norm(ep[:3] - sp[:3])
    # Relative-rotation magnitude, equivalent to rotation_distance(start, end).
    rot_start = st.Rotation.from_rotvec(sp[3:])
    rot_end = st.Rotation.from_rotvec(ep[3:])
    rot_dist = (rot_end * rot_start.inv()).magnitude()
    return pos_dist, rot_dist
|
| 20 |
+
|
| 21 |
+
class PoseTrajectoryInterpolator:
    """Piecewise interpolator over 6D poses (xyz position + rotation vector).

    Position is linearly interpolated; rotation uses Slerp. Queries outside
    [times[0], times[-1]] are clamped to the endpoint poses. A single-sample
    trajectory is handled as a constant pose.
    """
    def __init__(self, times: np.ndarray, poses: np.ndarray):
        """
        Args:
            times: (N,) non-decreasing timestamps (array-like accepted).
            poses: (N, 6) poses; columns 0:3 position, 3:6 rotation vector.
        """
        assert len(times) >= 1
        assert len(poses) == len(times)
        if not isinstance(times, np.ndarray):
            times = np.array(times)
        if not isinstance(poses, np.ndarray):
            poses = np.array(poses)

        if len(times) == 1:
            # special treatment for single step interpolation
            self.single_step = True
            self._times = times
            self._poses = poses
        else:
            self.single_step = False
            assert np.all(times[1:] >= times[:-1])

            pos = poses[:,:3]
            rot = st.Rotation.from_rotvec(poses[:,3:])

            self.pos_interp = si.interp1d(times, pos,
                axis=0, assume_sorted=True)
            self.rot_interp = st.Slerp(times, rot)

    @property
    def times(self) -> np.ndarray:
        """Knot timestamps of the trajectory."""
        if self.single_step:
            return self._times
        else:
            return self.pos_interp.x

    @property
    def poses(self) -> np.ndarray:
        """Knot poses, reconstructed as (N, 6) [pos, rotvec] rows."""
        if self.single_step:
            return self._poses
        else:
            n = len(self.times)
            poses = np.zeros((n, 6))
            poses[:,:3] = self.pos_interp.y
            poses[:,3:] = self.rot_interp(self.times).as_rotvec()
            return poses

    def trim(self,
            start_t: float, end_t: float
            ) -> "PoseTrajectoryInterpolator":
        """Return a new interpolator restricted to [start_t, end_t],
        re-sampling the endpoint poses so the window is exactly covered."""
        assert start_t <= end_t
        times = self.times
        should_keep = (start_t < times) & (times < end_t)
        keep_times = times[should_keep]
        all_times = np.concatenate([[start_t], keep_times, [end_t]])
        # remove duplicates, Slerp requires strictly increasing x
        all_times = np.unique(all_times)
        # interpolate
        all_poses = self(all_times)
        return PoseTrajectoryInterpolator(times=all_times, poses=all_poses)

    def drive_to_waypoint(self,
            pose, time, curr_time,
            max_pos_speed=np.inf,
            max_rot_speed=np.inf
        ) -> "PoseTrajectoryInterpolator":
        """Return a new trajectory that moves from the pose at curr_time to
        `pose`, arriving no earlier than `time` and no faster than the given
        speed limits (which may stretch the arrival time)."""
        assert(max_pos_speed > 0)
        assert(max_rot_speed > 0)
        time = max(time, curr_time)

        curr_pose = self(curr_time)
        pos_dist, rot_dist = pose_distance(curr_pose, pose)
        # Minimum durations implied by the speed limits.
        pos_min_duration = pos_dist / max_pos_speed
        rot_min_duration = rot_dist / max_rot_speed
        duration = time - curr_time
        duration = max(duration, max(pos_min_duration, rot_min_duration))
        assert duration >= 0
        last_waypoint_time = curr_time + duration

        # insert new pose
        trimmed_interp = self.trim(curr_time, curr_time)
        times = np.append(trimmed_interp.times, [last_waypoint_time], axis=0)
        poses = np.append(trimmed_interp.poses, [pose], axis=0)

        # create new interpolator
        final_interp = PoseTrajectoryInterpolator(times, poses)
        return final_interp

    def schedule_waypoint(self,
            pose, time,
            max_pos_speed=np.inf,
            max_rot_speed=np.inf,
            curr_time=None,
            last_waypoint_time=None
        ) -> "PoseTrajectoryInterpolator":
        """Insert `pose` as a waypoint at (approximately) `time`, trimming
        the existing trajectory and enforcing the speed limits.

        If `time` is not after `curr_time`, the trajectory is unchanged.
        The window-trimming logic below is order-sensitive; the inline
        assertions document the invariants it establishes.
        """
        assert(max_pos_speed > 0)
        assert(max_rot_speed > 0)
        if last_waypoint_time is not None:
            assert curr_time is not None

        # trim current interpolator to between curr_time and last_waypoint_time
        start_time = self.times[0]
        end_time = self.times[-1]
        assert start_time <= end_time

        if curr_time is not None:
            if time <= curr_time:
                # if insert time is earlier than current time
                # no effect should be done to the interpolator
                return self
            # now, curr_time < time
            start_time = max(curr_time, start_time)

            if last_waypoint_time is not None:
                # if last_waypoint_time is earlier than start_time
                # use start_time
                if time <= last_waypoint_time:
                    end_time = curr_time
                else:
                    end_time = max(last_waypoint_time, curr_time)
            else:
                end_time = curr_time

        end_time = min(end_time, time)
        start_time = min(start_time, end_time)
        # end time should be the latest of all times except time
        # after this we can assume order (proven by zhenjia, due to the 2 min operations)

        # Constraints:
        # start_time <= end_time <= time (proven by zhenjia)
        # curr_time <= start_time (proven by zhenjia)
        # curr_time <= time (proven by zhenjia)

        # time can't change
        # last_waypoint_time can't change
        # curr_time can't change
        assert start_time <= end_time
        assert end_time <= time
        if last_waypoint_time is not None:
            if time <= last_waypoint_time:
                assert end_time == curr_time
            else:
                assert end_time == max(last_waypoint_time, curr_time)

        if curr_time is not None:
            assert curr_time <= start_time
            assert curr_time <= time

        trimmed_interp = self.trim(start_time, end_time)
        # after this, all waypoints in trimmed_interp is within start_time and end_time
        # and is earlier than time

        # determine speed
        duration = time - end_time
        end_pose = trimmed_interp(end_time)
        pos_dist, rot_dist = pose_distance(pose, end_pose)
        # Stretch the segment if the speed limits require more time.
        pos_min_duration = pos_dist / max_pos_speed
        rot_min_duration = rot_dist / max_rot_speed
        duration = max(duration, max(pos_min_duration, rot_min_duration))
        assert duration >= 0
        last_waypoint_time = end_time + duration

        # insert new pose
        times = np.append(trimmed_interp.times, [last_waypoint_time], axis=0)
        poses = np.append(trimmed_interp.poses, [pose], axis=0)

        # create new interpolator
        final_interp = PoseTrajectoryInterpolator(times, poses)
        return final_interp


    def __call__(self, t: Union[numbers.Number, np.ndarray]) -> np.ndarray:
        """Evaluate the trajectory at time(s) t.

        Scalar t returns a (6,) pose; array t returns (len(t), 6). Times are
        clamped to the trajectory's time span.
        """
        is_single = False
        if isinstance(t, numbers.Number):
            is_single = True
            t = np.array([t])

        pose = np.zeros((len(t), 6))
        if self.single_step:
            pose[:] = self._poses[0]
        else:
            start_time = self.times[0]
            end_time = self.times[-1]
            t = np.clip(t, start_time, end_time)

            pose = np.zeros((len(t), 6))
            pose[:,:3] = self.pos_interp(t)
            pose[:,3:] = self.rot_interp(t).as_rotvec()

        if is_single:
            pose = pose[0]
        return pose
|
equidiff/equi_diffpo/common/precise_sleep.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
def precise_sleep(dt: float, slack_time: float=0.001, time_func=time.monotonic):
    """
    Use hybrid of time.sleep and spinning to minimize jitter.
    Sleep dt - slack_time seconds first, then spin for the rest.
    """
    start = time_func()
    deadline = start + dt
    if dt > slack_time:
        # Coarse OS sleep for the bulk of the interval; any overshoot of the
        # scheduler is absorbed by the spin below.
        time.sleep(dt - slack_time)
    # Busy-wait the final slack for sub-millisecond precision.
    while time_func() < deadline:
        pass
    return
|
| 15 |
+
|
| 16 |
+
def precise_wait(t_end: float, slack_time: float=0.001, time_func=time.monotonic):
    """Block until time_func() reaches t_end: coarse sleep for all but
    slack_time of the wait, then spin for precision."""
    remaining = t_end - time_func()
    if remaining > 0:
        coarse = remaining - slack_time
        if coarse > 0:
            time.sleep(coarse)
    # Spin through the final slack (or the whole wait if it was short).
    while time_func() < t_end:
        pass
    return
|
equidiff/equi_diffpo/common/pymunk_override.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ----------------------------------------------------------------------------
|
| 2 |
+
# pymunk
|
| 3 |
+
# Copyright (c) 2007-2016 Victor Blomqvist
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in
|
| 13 |
+
# all copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
# ----------------------------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
"""This submodule contains helper functions to help with quick prototyping
|
| 25 |
+
using pymunk together with pygame.
|
| 26 |
+
|
| 27 |
+
Intended to help with debugging and prototyping, not for actual production use
|
| 28 |
+
in a full application. The methods contained in this module is opinionated
|
| 29 |
+
about your coordinate system and not in any way optimized.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
__docformat__ = "reStructuredText"
|
| 33 |
+
|
| 34 |
+
__all__ = [
|
| 35 |
+
"DrawOptions",
|
| 36 |
+
"get_mouse_pos",
|
| 37 |
+
"to_pygame",
|
| 38 |
+
"from_pygame",
|
| 39 |
+
"lighten",
|
| 40 |
+
"positive_y_is_up",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
from typing import List, Sequence, Tuple
|
| 44 |
+
|
| 45 |
+
import pygame
|
| 46 |
+
|
| 47 |
+
import numpy as np
|
| 48 |
+
|
| 49 |
+
import pymunk
|
| 50 |
+
from pymunk.space_debug_draw_options import SpaceDebugColor
|
| 51 |
+
from pymunk.vec2d import Vec2d
|
| 52 |
+
|
| 53 |
+
positive_y_is_up: bool = False
|
| 54 |
+
"""Make increasing values of y point upwards.
|
| 55 |
+
|
| 56 |
+
When True::
|
| 57 |
+
|
| 58 |
+
y
|
| 59 |
+
^
|
| 60 |
+
| . (3, 3)
|
| 61 |
+
|
|
| 62 |
+
| . (2, 2)
|
| 63 |
+
|
|
| 64 |
+
+------ > x
|
| 65 |
+
|
| 66 |
+
When False::
|
| 67 |
+
|
| 68 |
+
+------ > x
|
| 69 |
+
|
|
| 70 |
+
| . (2, 2)
|
| 71 |
+
|
|
| 72 |
+
| . (3, 3)
|
| 73 |
+
v
|
| 74 |
+
y
|
| 75 |
+
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class DrawOptions(pymunk.SpaceDebugDrawOptions):
|
| 80 |
+
def __init__(self, surface: pygame.Surface) -> None:
|
| 81 |
+
"""Draw a pymunk.Space on a pygame.Surface object.
|
| 82 |
+
|
| 83 |
+
Typical usage::
|
| 84 |
+
|
| 85 |
+
>>> import pymunk
|
| 86 |
+
>>> surface = pygame.Surface((10,10))
|
| 87 |
+
>>> space = pymunk.Space()
|
| 88 |
+
>>> options = pymunk.pygame_util.DrawOptions(surface)
|
| 89 |
+
>>> space.debug_draw(options)
|
| 90 |
+
|
| 91 |
+
You can control the color of a shape by setting shape.color to the color
|
| 92 |
+
you want it drawn in::
|
| 93 |
+
|
| 94 |
+
>>> c = pymunk.Circle(None, 10)
|
| 95 |
+
>>> c.color = pygame.Color("pink")
|
| 96 |
+
|
| 97 |
+
See pygame_util.demo.py for a full example
|
| 98 |
+
|
| 99 |
+
Since pygame uses a coordinate system where y points down (in contrast
|
| 100 |
+
to many other cases), you either have to make the physics simulation
|
| 101 |
+
with Pymunk also behave in that way, or flip everything when you draw.
|
| 102 |
+
|
| 103 |
+
The easiest is probably to just make the simulation behave the same
|
| 104 |
+
way as Pygame does. In that way all coordinates used are in the same
|
| 105 |
+
orientation and easy to reason about::
|
| 106 |
+
|
| 107 |
+
>>> space = pymunk.Space()
|
| 108 |
+
>>> space.gravity = (0, -1000)
|
| 109 |
+
>>> body = pymunk.Body()
|
| 110 |
+
>>> body.position = (0, 0) # will be positioned in the top left corner
|
| 111 |
+
>>> space.debug_draw(options)
|
| 112 |
+
|
| 113 |
+
To flip the drawing its possible to set the module property
|
| 114 |
+
:py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
|
| 115 |
+
the simulation upside down before drawing::
|
| 116 |
+
|
| 117 |
+
>>> positive_y_is_up = True
|
| 118 |
+
>>> body = pymunk.Body()
|
| 119 |
+
>>> body.position = (0, 0)
|
| 120 |
+
>>> # Body will be position in bottom left corner
|
| 121 |
+
|
| 122 |
+
:Parameters:
|
| 123 |
+
surface : pygame.Surface
|
| 124 |
+
Surface that the objects will be drawn on
|
| 125 |
+
"""
|
| 126 |
+
self.surface = surface
|
| 127 |
+
super(DrawOptions, self).__init__()
|
| 128 |
+
|
| 129 |
+
def draw_circle(
|
| 130 |
+
self,
|
| 131 |
+
pos: Vec2d,
|
| 132 |
+
angle: float,
|
| 133 |
+
radius: float,
|
| 134 |
+
outline_color: SpaceDebugColor,
|
| 135 |
+
fill_color: SpaceDebugColor,
|
| 136 |
+
) -> None:
|
| 137 |
+
p = to_pygame(pos, self.surface)
|
| 138 |
+
|
| 139 |
+
pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
|
| 140 |
+
pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius-4), 0)
|
| 141 |
+
|
| 142 |
+
circle_edge = pos + Vec2d(radius, 0).rotated(angle)
|
| 143 |
+
p2 = to_pygame(circle_edge, self.surface)
|
| 144 |
+
line_r = 2 if radius > 20 else 1
|
| 145 |
+
# pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)
|
| 146 |
+
|
| 147 |
+
def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
|
| 148 |
+
p1 = to_pygame(a, self.surface)
|
| 149 |
+
p2 = to_pygame(b, self.surface)
|
| 150 |
+
|
| 151 |
+
pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])
|
| 152 |
+
|
| 153 |
+
def draw_fat_segment(
|
| 154 |
+
self,
|
| 155 |
+
a: Tuple[float, float],
|
| 156 |
+
b: Tuple[float, float],
|
| 157 |
+
radius: float,
|
| 158 |
+
outline_color: SpaceDebugColor,
|
| 159 |
+
fill_color: SpaceDebugColor,
|
| 160 |
+
) -> None:
|
| 161 |
+
p1 = to_pygame(a, self.surface)
|
| 162 |
+
p2 = to_pygame(b, self.surface)
|
| 163 |
+
|
| 164 |
+
r = round(max(1, radius * 2))
|
| 165 |
+
pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
|
| 166 |
+
if r > 2:
|
| 167 |
+
orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
|
| 168 |
+
if orthog[0] == 0 and orthog[1] == 0:
|
| 169 |
+
return
|
| 170 |
+
scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
|
| 171 |
+
orthog[0] = round(orthog[0] * scale)
|
| 172 |
+
orthog[1] = round(orthog[1] * scale)
|
| 173 |
+
points = [
|
| 174 |
+
(p1[0] - orthog[0], p1[1] - orthog[1]),
|
| 175 |
+
(p1[0] + orthog[0], p1[1] + orthog[1]),
|
| 176 |
+
(p2[0] + orthog[0], p2[1] + orthog[1]),
|
| 177 |
+
(p2[0] - orthog[0], p2[1] - orthog[1]),
|
| 178 |
+
]
|
| 179 |
+
pygame.draw.polygon(self.surface, fill_color.as_int(), points)
|
| 180 |
+
pygame.draw.circle(
|
| 181 |
+
self.surface,
|
| 182 |
+
fill_color.as_int(),
|
| 183 |
+
(round(p1[0]), round(p1[1])),
|
| 184 |
+
round(radius),
|
| 185 |
+
)
|
| 186 |
+
pygame.draw.circle(
|
| 187 |
+
self.surface,
|
| 188 |
+
fill_color.as_int(),
|
| 189 |
+
(round(p2[0]), round(p2[1])),
|
| 190 |
+
round(radius),
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
    def draw_polygon(
        self,
        verts: Sequence[Tuple[float, float]],
        radius: float,
        outline_color: SpaceDebugColor,
        fill_color: SpaceDebugColor,
    ) -> None:
        """Draw a filled polygon with rounded edges.

        The interior is filled with a lightened fill color, then every edge
        is drawn as a fat segment so corners appear rounded.
        Note: ``outline_color`` is accepted for interface compatibility but
        not used by this implementation.
        """
        ps = [to_pygame(v, self.surface) for v in verts]
        # Close the polygon outline by repeating the first point.
        ps += [ps[0]]

        # NOTE(review): the incoming `radius` is deliberately discarded and
        # forced to 2, so edges are always drawn with a fixed thickness.
        radius = 2
        pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)

        if radius > 0:
            # Outline each edge (wrapping back to vertex 0) as a fat segment.
            for i in range(len(verts)):
                a = verts[i]
                b = verts[(i + 1) % len(verts)]
                self.draw_fat_segment(a, b, radius, fill_color, fill_color)
|
| 211 |
+
|
| 212 |
+
def draw_dot(
|
| 213 |
+
self, size: float, pos: Tuple[float, float], color: SpaceDebugColor
|
| 214 |
+
) -> None:
|
| 215 |
+
p = to_pygame(pos, self.surface)
|
| 216 |
+
pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
    """Return the mouse pointer position expressed in pymunk coordinates."""
    screen_pos = pygame.mouse.get_pos()
    return from_pygame(screen_pos, surface)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
    """Convert pymunk coordinates to pygame surface-local coordinates.

    When the module-level ``positive_y_is_up`` flag is False this only
    rounds the point to integers; otherwise the y axis is flipped against
    the surface height.
    """
    x = round(p[0])
    if not positive_y_is_up:
        return x, round(p[1])
    return x, surface.get_height() - round(p[1])
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
    """Convert pygame surface-local coordinates back to pymunk coordinates.

    The coordinate mapping is its own inverse, so this simply delegates
    to ``to_pygame``.
    """
    return to_pygame(p, surface)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def light_color(color: SpaceDebugColor):
    """Return a copy of *color* brightened by 20%, clipped to 255 per channel."""
    rgba = np.float32([color.r, color.g, color.b, color.a])
    lightened = np.minimum(1.2 * rgba, np.float32([255]))
    return SpaceDebugColor(r=lightened[0], g=lightened[1],
                           b=lightened[2], a=lightened[3])
|
equidiff/equi_diffpo/common/pymunk_util.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pygame
|
| 2 |
+
import pymunk
|
| 3 |
+
import pymunk.pygame_util
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
# Collision-type tags assigned to pymunk shapes; used to distinguish
# shape categories when registering collision handlers.
COLLTYPE_DEFAULT = 0
COLLTYPE_MOUSE = 1
COLLTYPE_BALL = 2
|
| 9 |
+
|
| 10 |
+
def get_body_type(static=False):
    """Map the ``static`` flag to the corresponding pymunk body-type constant."""
    return pymunk.Body.STATIC if static else pymunk.Body.DYNAMIC
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def create_rectangle(space,
        pos_x, pos_y, width, height,
        density=3, static=False):
    """Add a box-shaped body to *space* and return the (body, shape) pair.

    The body is dynamic unless ``static`` is True; ``density`` lets pymunk
    derive mass and moment from the box dimensions.
    """
    body = pymunk.Body(body_type=get_body_type(static))
    body.position = (pos_x, pos_y)
    box = pymunk.Poly.create_box(body, (width, height))
    box.density = density
    space.add(body, box)
    return body, box
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def create_rectangle_bb(space,
        left, bottom, right, top,
        **kwargs):
    """Create a rectangle from bounding-box edges instead of center/size.

    Extra keyword arguments are forwarded to ``create_rectangle``.
    """
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    return create_rectangle(space, center_x, center_y,
                            right - left, top - bottom, **kwargs)
|
| 36 |
+
|
| 37 |
+
def create_circle(space, pos_x, pos_y, radius, density=3, static=False):
    """Add a circular body tagged as COLLTYPE_BALL; return (body, shape)."""
    body = pymunk.Body(body_type=get_body_type(static))
    body.position = (pos_x, pos_y)
    ball = pymunk.Circle(body, radius=radius)
    ball.density = density
    # Tag the shape so collision handlers can target balls specifically.
    ball.collision_type = COLLTYPE_BALL
    space.add(body, ball)
    return body, ball
|
| 45 |
+
|
| 46 |
+
def get_body_state(body):
    """Pack a body's pose and twist into a float32 vector.

    Layout: ``[x, y, angle, vx, vy, angular_velocity]``.
    """
    px, py = body.position
    vx, vy = body.velocity
    return np.array(
        [px, py, body.angle, vx, vy, body.angular_velocity],
        dtype=np.float32)
|
equidiff/equi_diffpo/common/pytorch_util.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Callable, List
|
| 2 |
+
import collections
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
def dict_apply(
        x: Dict[str, torch.Tensor],
        func: Callable[[torch.Tensor], torch.Tensor]
        ) -> Dict[str, torch.Tensor]:
    """Apply *func* to every leaf value of a (possibly nested) dict.

    Nested dict values are recursed into; all other values are passed to
    *func* directly. A new dict is returned; the input is not modified.
    """
    return {
        key: dict_apply(value, func) if isinstance(value, dict) else func(value)
        for key, value in x.items()
    }
|
| 17 |
+
|
| 18 |
+
def pad_remaining_dims(x, target):
    """Append singleton dims to *x* so it broadcasts against *target*.

    *x*'s shape must be a prefix of *target*'s shape; the returned view has
    the same rank as *target* with trailing dimensions of size 1.
    """
    assert x.shape == target.shape[:len(x.shape)]
    n_missing = len(target.shape) - len(x.shape)
    return x.reshape(x.shape + (1,) * n_missing)
|
| 21 |
+
|
| 22 |
+
def dict_apply_split(
        x: Dict[str, torch.Tensor],
        split_func: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
        ) -> Dict[str, torch.Tensor]:
    """Split every value with *split_func* and regroup by the split keys.

    Transposes ``{k: {split_k: v}}`` into ``{split_k: {k: v}}``.
    """
    grouped = collections.defaultdict(dict)
    for outer_key, value in x.items():
        for inner_key, piece in split_func(value).items():
            grouped[inner_key][outer_key] = piece
    return grouped
|
| 32 |
+
|
| 33 |
+
def dict_apply_reduce(
        x: List[Dict[str, torch.Tensor]],
        reduce_func: Callable[[List[torch.Tensor]], torch.Tensor]
        ) -> Dict[str, torch.Tensor]:
    """Reduce a list of dicts sharing keys into one dict, key by key.

    Keys are taken from the first dict; every dict in *x* must contain them.
    """
    return {
        key: reduce_func([entry[key] for entry in x])
        for key in x[0].keys()
    }
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Replace every submodule matching *predicate* with ``func(submodule)``.

    predicate: Return true if the module is to be replaced.
    func: Return new module to use.
    """
    # The root itself may be the module to swap out.
    if predicate(root_module):
        return func(root_module)

    matches = [name.split('.') for name, module
               in root_module.named_modules(remove_duplicate=True)
               if predicate(module)]
    for *parent_path, leaf in matches:
        parent = root_module
        if len(parent_path) > 0:
            parent = root_module.get_submodule('.'.join(parent_path))
        # nn.Sequential children are addressed by integer index, not attribute.
        if isinstance(parent, nn.Sequential):
            old_module = parent[int(leaf)]
        else:
            old_module = getattr(parent, leaf)
        new_module = func(old_module)
        if isinstance(parent, nn.Sequential):
            parent[int(leaf)] = new_module
        else:
            setattr(parent, leaf, new_module)
    # Verify that no matching module survived the replacement pass.
    remaining = [name for name, module
                 in root_module.named_modules(remove_duplicate=True)
                 if predicate(module)]
    assert len(remaining) == 0
    return root_module
|
| 76 |
+
|
| 77 |
+
def optimizer_to(optimizer, device):
    """Move every tensor held in the optimizer's state to *device* in place.

    Useful after loading a checkpoint on CPU before resuming elsewhere.
    Returns the same optimizer object for chaining.
    """
    for param_state in optimizer.state.values():
        for name, value in param_state.items():
            if isinstance(value, torch.Tensor):
                param_state[name] = value.to(device=device)
    return optimizer
|
equidiff/equi_diffpo/common/replay_buffer.py
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, Dict, Optional
|
| 2 |
+
import os
|
| 3 |
+
import math
|
| 4 |
+
import numbers
|
| 5 |
+
import zarr
|
| 6 |
+
import numcodecs
|
| 7 |
+
import numpy as np
|
| 8 |
+
from functools import cached_property
|
| 9 |
+
|
| 10 |
+
def check_chunks_compatible(chunks: tuple, shape: tuple):
    """Assert *chunks* has one positive integer entry per axis of *shape*."""
    assert len(shape) == len(chunks)
    for chunk_size in chunks:
        assert isinstance(chunk_size, numbers.Integral)
        assert chunk_size > 0
|
| 15 |
+
|
| 16 |
+
def rechunk_recompress_array(group, name,
        chunks=None, chunk_length=None,
        compressor=None, tmp_key='_temp'):
    """Rewrite ``group[name]`` with new chunking and/or compression.

    Either pass an explicit ``chunks`` tuple, or ``chunk_length`` to change
    only the first (time) chunk dimension. ``compressor=None`` keeps the
    existing compressor. If nothing would change, the array is returned
    untouched; otherwise the array is moved aside to ``tmp_key`` and copied
    back through zarr with the new settings.
    """
    old_arr = group[name]
    if chunks is None:
        if chunk_length is not None:
            # Change only the time-axis chunk size; keep remaining dims.
            chunks = (chunk_length,) + old_arr.chunks[1:]
        else:
            chunks = old_arr.chunks
    check_chunks_compatible(chunks, old_arr.shape)

    if compressor is None:
        compressor = old_arr.compressor

    if (chunks == old_arr.chunks) and (compressor == old_arr.compressor):
        # no change
        return old_arr

    # rechunk recompress: move the array aside, then copy it back with the
    # new chunks/compressor and drop the temporary.
    group.move(name, tmp_key)
    old_arr = group[tmp_key]
    n_copied, n_skipped, n_bytes_copied = zarr.copy(
        source=old_arr,
        dest=group,
        name=name,
        chunks=chunks,
        compressor=compressor,
    )
    del group[tmp_key]
    arr = group[name]
    return arr
|
| 47 |
+
|
| 48 |
+
def get_optimal_chunks(shape, dtype,
        target_chunk_bytes=2e6,
        max_chunk_length=None):
    """
    Pick a chunk shape of roughly *target_chunk_bytes* for a temporal array.

    Trailing axes are kept whole up to the axis where the byte budget is
    crossed; that axis is shortened, and all leading axes get chunk size 1.

    Common shapes
    T,D
    T,N,D
    T,H,W,C
    T,N,H,W,C
    """
    itemsize = np.dtype(dtype).itemsize
    # Work on the reversed shape so trailing (fastest-varying) axes come first.
    rshape = list(shape[::-1])
    if max_chunk_length is not None:
        # Cap the time axis (last element of the reversed shape).
        rshape[-1] = int(max_chunk_length)

    # Locate the axis where the cumulative chunk size crosses the budget.
    split_idx = len(shape) - 1
    for i in range(len(shape) - 1):
        bytes_below = itemsize * np.prod(rshape[:i])
        bytes_above = itemsize * np.prod(rshape[:i + 1])
        if bytes_below <= target_chunk_bytes \
                and bytes_above > target_chunk_bytes:
            split_idx = i

    rchunks = rshape[:split_idx]
    item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
    this_max_chunk_length = rshape[split_idx]
    # Shorten the crossing axis so the chunk fits the byte budget.
    next_chunk_length = min(this_max_chunk_length, math.ceil(
        target_chunk_bytes / item_chunk_bytes))
    rchunks.append(next_chunk_length)
    # All remaining (leading) axes are chunked to 1.
    rchunks.extend([1] * (len(shape) - len(rchunks)))
    return tuple(rchunks[::-1])
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class ReplayBuffer:
|
| 85 |
+
"""
|
| 86 |
+
Zarr-based temporal datastructure.
|
| 87 |
+
Assumes first dimension to be time. Only chunk in time dimension.
|
| 88 |
+
"""
|
| 89 |
+
def __init__(self,
|
| 90 |
+
root: Union[zarr.Group,
|
| 91 |
+
Dict[str,dict]]):
|
| 92 |
+
"""
|
| 93 |
+
Dummy constructor. Use copy_from* and create_from* class methods instead.
|
| 94 |
+
"""
|
| 95 |
+
assert('data' in root)
|
| 96 |
+
assert('meta' in root)
|
| 97 |
+
assert('episode_ends' in root['meta'])
|
| 98 |
+
for key, value in root['data'].items():
|
| 99 |
+
assert(value.shape[0] == root['meta']['episode_ends'][-1])
|
| 100 |
+
self.root = root
|
| 101 |
+
|
| 102 |
+
# ============= create constructors ===============
|
| 103 |
+
@classmethod
|
| 104 |
+
def create_empty_zarr(cls, storage=None, root=None):
|
| 105 |
+
if root is None:
|
| 106 |
+
if storage is None:
|
| 107 |
+
storage = zarr.MemoryStore()
|
| 108 |
+
root = zarr.group(store=storage)
|
| 109 |
+
data = root.require_group('data', overwrite=False)
|
| 110 |
+
meta = root.require_group('meta', overwrite=False)
|
| 111 |
+
if 'episode_ends' not in meta:
|
| 112 |
+
episode_ends = meta.zeros('episode_ends', shape=(0,), dtype=np.int64,
|
| 113 |
+
compressor=None, overwrite=False)
|
| 114 |
+
return cls(root=root)
|
| 115 |
+
|
| 116 |
+
@classmethod
|
| 117 |
+
def create_empty_numpy(cls):
|
| 118 |
+
root = {
|
| 119 |
+
'data': dict(),
|
| 120 |
+
'meta': {
|
| 121 |
+
'episode_ends': np.zeros((0,), dtype=np.int64)
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
return cls(root=root)
|
| 125 |
+
|
| 126 |
+
@classmethod
|
| 127 |
+
def create_from_group(cls, group, **kwargs):
|
| 128 |
+
if 'data' not in group:
|
| 129 |
+
# create from stratch
|
| 130 |
+
buffer = cls.create_empty_zarr(root=group, **kwargs)
|
| 131 |
+
else:
|
| 132 |
+
# already exist
|
| 133 |
+
buffer = cls(root=group, **kwargs)
|
| 134 |
+
return buffer
|
| 135 |
+
|
| 136 |
+
@classmethod
|
| 137 |
+
def create_from_path(cls, zarr_path, mode='r', **kwargs):
|
| 138 |
+
"""
|
| 139 |
+
Open a on-disk zarr directly (for dataset larger than memory).
|
| 140 |
+
Slower.
|
| 141 |
+
"""
|
| 142 |
+
group = zarr.open(os.path.expanduser(zarr_path), mode)
|
| 143 |
+
return cls.create_from_group(group, **kwargs)
|
| 144 |
+
|
| 145 |
+
# ============= copy constructors ===============
|
| 146 |
+
@classmethod
|
| 147 |
+
def copy_from_store(cls, src_store, store=None, keys=None,
|
| 148 |
+
chunks: Dict[str,tuple]=dict(),
|
| 149 |
+
compressors: Union[dict, str, numcodecs.abc.Codec]=dict(),
|
| 150 |
+
if_exists='replace',
|
| 151 |
+
**kwargs):
|
| 152 |
+
"""
|
| 153 |
+
Load to memory.
|
| 154 |
+
"""
|
| 155 |
+
src_root = zarr.group(src_store)
|
| 156 |
+
root = None
|
| 157 |
+
if store is None:
|
| 158 |
+
# numpy backend
|
| 159 |
+
meta = dict()
|
| 160 |
+
for key, value in src_root['meta'].items():
|
| 161 |
+
if len(value.shape) == 0:
|
| 162 |
+
meta[key] = np.array(value)
|
| 163 |
+
else:
|
| 164 |
+
meta[key] = value[:]
|
| 165 |
+
|
| 166 |
+
if keys is None:
|
| 167 |
+
keys = src_root['data'].keys()
|
| 168 |
+
data = dict()
|
| 169 |
+
for key in keys:
|
| 170 |
+
arr = src_root['data'][key]
|
| 171 |
+
data[key] = arr[:]
|
| 172 |
+
|
| 173 |
+
root = {
|
| 174 |
+
'meta': meta,
|
| 175 |
+
'data': data
|
| 176 |
+
}
|
| 177 |
+
else:
|
| 178 |
+
root = zarr.group(store=store)
|
| 179 |
+
# copy without recompression
|
| 180 |
+
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(source=src_store, dest=store,
|
| 181 |
+
source_path='/meta', dest_path='/meta', if_exists=if_exists)
|
| 182 |
+
data_group = root.create_group('data', overwrite=True)
|
| 183 |
+
if keys is None:
|
| 184 |
+
keys = src_root['data'].keys()
|
| 185 |
+
for key in keys:
|
| 186 |
+
value = src_root['data'][key]
|
| 187 |
+
cks = cls._resolve_array_chunks(
|
| 188 |
+
chunks=chunks, key=key, array=value)
|
| 189 |
+
cpr = cls._resolve_array_compressor(
|
| 190 |
+
compressors=compressors, key=key, array=value)
|
| 191 |
+
if cks == value.chunks and cpr == value.compressor:
|
| 192 |
+
# copy without recompression
|
| 193 |
+
this_path = '/data/' + key
|
| 194 |
+
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
| 195 |
+
source=src_store, dest=store,
|
| 196 |
+
source_path=this_path, dest_path=this_path,
|
| 197 |
+
if_exists=if_exists
|
| 198 |
+
)
|
| 199 |
+
else:
|
| 200 |
+
# copy with recompression
|
| 201 |
+
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
| 202 |
+
source=value, dest=data_group, name=key,
|
| 203 |
+
chunks=cks, compressor=cpr, if_exists=if_exists
|
| 204 |
+
)
|
| 205 |
+
buffer = cls(root=root)
|
| 206 |
+
return buffer
|
| 207 |
+
|
| 208 |
+
@classmethod
|
| 209 |
+
def copy_from_path(cls, zarr_path, backend=None, store=None, keys=None,
|
| 210 |
+
chunks: Dict[str,tuple]=dict(),
|
| 211 |
+
compressors: Union[dict, str, numcodecs.abc.Codec]=dict(),
|
| 212 |
+
if_exists='replace',
|
| 213 |
+
**kwargs):
|
| 214 |
+
"""
|
| 215 |
+
Copy a on-disk zarr to in-memory compressed.
|
| 216 |
+
Recommended
|
| 217 |
+
"""
|
| 218 |
+
if backend == 'numpy':
|
| 219 |
+
print('backend argument is deprecated!')
|
| 220 |
+
store = None
|
| 221 |
+
group = zarr.open(os.path.expanduser(zarr_path), 'r')
|
| 222 |
+
return cls.copy_from_store(src_store=group.store, store=store,
|
| 223 |
+
keys=keys, chunks=chunks, compressors=compressors,
|
| 224 |
+
if_exists=if_exists, **kwargs)
|
| 225 |
+
|
| 226 |
+
# ============= save methods ===============
|
| 227 |
+
def save_to_store(self, store,
|
| 228 |
+
chunks: Optional[Dict[str,tuple]]=dict(),
|
| 229 |
+
compressors: Union[str, numcodecs.abc.Codec, dict]=dict(),
|
| 230 |
+
if_exists='replace',
|
| 231 |
+
**kwargs):
|
| 232 |
+
|
| 233 |
+
root = zarr.group(store)
|
| 234 |
+
if self.backend == 'zarr':
|
| 235 |
+
# recompression free copy
|
| 236 |
+
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
| 237 |
+
source=self.root.store, dest=store,
|
| 238 |
+
source_path='/meta', dest_path='/meta', if_exists=if_exists)
|
| 239 |
+
else:
|
| 240 |
+
meta_group = root.create_group('meta', overwrite=True)
|
| 241 |
+
# save meta, no chunking
|
| 242 |
+
for key, value in self.root['meta'].items():
|
| 243 |
+
_ = meta_group.array(
|
| 244 |
+
name=key,
|
| 245 |
+
data=value,
|
| 246 |
+
shape=value.shape,
|
| 247 |
+
chunks=value.shape)
|
| 248 |
+
|
| 249 |
+
# save data, chunk
|
| 250 |
+
data_group = root.create_group('data', overwrite=True)
|
| 251 |
+
for key, value in self.root['data'].items():
|
| 252 |
+
cks = self._resolve_array_chunks(
|
| 253 |
+
chunks=chunks, key=key, array=value)
|
| 254 |
+
cpr = self._resolve_array_compressor(
|
| 255 |
+
compressors=compressors, key=key, array=value)
|
| 256 |
+
if isinstance(value, zarr.Array):
|
| 257 |
+
if cks == value.chunks and cpr == value.compressor:
|
| 258 |
+
# copy without recompression
|
| 259 |
+
this_path = '/data/' + key
|
| 260 |
+
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
| 261 |
+
source=self.root.store, dest=store,
|
| 262 |
+
source_path=this_path, dest_path=this_path, if_exists=if_exists)
|
| 263 |
+
else:
|
| 264 |
+
# copy with recompression
|
| 265 |
+
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
| 266 |
+
source=value, dest=data_group, name=key,
|
| 267 |
+
chunks=cks, compressor=cpr, if_exists=if_exists
|
| 268 |
+
)
|
| 269 |
+
else:
|
| 270 |
+
# numpy
|
| 271 |
+
_ = data_group.array(
|
| 272 |
+
name=key,
|
| 273 |
+
data=value,
|
| 274 |
+
chunks=cks,
|
| 275 |
+
compressor=cpr
|
| 276 |
+
)
|
| 277 |
+
return store
|
| 278 |
+
|
| 279 |
+
def save_to_path(self, zarr_path,
|
| 280 |
+
chunks: Optional[Dict[str,tuple]]=dict(),
|
| 281 |
+
compressors: Union[str, numcodecs.abc.Codec, dict]=dict(),
|
| 282 |
+
if_exists='replace',
|
| 283 |
+
**kwargs):
|
| 284 |
+
store = zarr.DirectoryStore(os.path.expanduser(zarr_path))
|
| 285 |
+
return self.save_to_store(store, chunks=chunks,
|
| 286 |
+
compressors=compressors, if_exists=if_exists, **kwargs)
|
| 287 |
+
|
| 288 |
+
@staticmethod
|
| 289 |
+
def resolve_compressor(compressor='default'):
|
| 290 |
+
if compressor == 'default':
|
| 291 |
+
compressor = numcodecs.Blosc(cname='lz4', clevel=5,
|
| 292 |
+
shuffle=numcodecs.Blosc.NOSHUFFLE)
|
| 293 |
+
elif compressor == 'disk':
|
| 294 |
+
compressor = numcodecs.Blosc('zstd', clevel=5,
|
| 295 |
+
shuffle=numcodecs.Blosc.BITSHUFFLE)
|
| 296 |
+
return compressor
|
| 297 |
+
|
| 298 |
+
@classmethod
|
| 299 |
+
def _resolve_array_compressor(cls,
|
| 300 |
+
compressors: Union[dict, str, numcodecs.abc.Codec], key, array):
|
| 301 |
+
# allows compressor to be explicitly set to None
|
| 302 |
+
cpr = 'nil'
|
| 303 |
+
if isinstance(compressors, dict):
|
| 304 |
+
if key in compressors:
|
| 305 |
+
cpr = cls.resolve_compressor(compressors[key])
|
| 306 |
+
elif isinstance(array, zarr.Array):
|
| 307 |
+
cpr = array.compressor
|
| 308 |
+
else:
|
| 309 |
+
cpr = cls.resolve_compressor(compressors)
|
| 310 |
+
# backup default
|
| 311 |
+
if cpr == 'nil':
|
| 312 |
+
cpr = cls.resolve_compressor('default')
|
| 313 |
+
return cpr
|
| 314 |
+
|
| 315 |
+
@classmethod
|
| 316 |
+
def _resolve_array_chunks(cls,
|
| 317 |
+
chunks: Union[dict, tuple], key, array):
|
| 318 |
+
cks = None
|
| 319 |
+
if isinstance(chunks, dict):
|
| 320 |
+
if key in chunks:
|
| 321 |
+
cks = chunks[key]
|
| 322 |
+
elif isinstance(array, zarr.Array):
|
| 323 |
+
cks = array.chunks
|
| 324 |
+
elif isinstance(chunks, tuple):
|
| 325 |
+
cks = chunks
|
| 326 |
+
else:
|
| 327 |
+
raise TypeError(f"Unsupported chunks type {type(chunks)}")
|
| 328 |
+
# backup default
|
| 329 |
+
if cks is None:
|
| 330 |
+
cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype)
|
| 331 |
+
# check
|
| 332 |
+
check_chunks_compatible(chunks=cks, shape=array.shape)
|
| 333 |
+
return cks
|
| 334 |
+
|
| 335 |
+
# ============= properties =================
|
| 336 |
+
@cached_property
|
| 337 |
+
def data(self):
|
| 338 |
+
return self.root['data']
|
| 339 |
+
|
| 340 |
+
@cached_property
|
| 341 |
+
def meta(self):
|
| 342 |
+
return self.root['meta']
|
| 343 |
+
|
| 344 |
+
def update_meta(self, data):
|
| 345 |
+
# sanitize data
|
| 346 |
+
np_data = dict()
|
| 347 |
+
for key, value in data.items():
|
| 348 |
+
if isinstance(value, np.ndarray):
|
| 349 |
+
np_data[key] = value
|
| 350 |
+
else:
|
| 351 |
+
arr = np.array(value)
|
| 352 |
+
if arr.dtype == object:
|
| 353 |
+
raise TypeError(f"Invalid value type {type(value)}")
|
| 354 |
+
np_data[key] = arr
|
| 355 |
+
|
| 356 |
+
meta_group = self.meta
|
| 357 |
+
if self.backend == 'zarr':
|
| 358 |
+
for key, value in np_data.items():
|
| 359 |
+
_ = meta_group.array(
|
| 360 |
+
name=key,
|
| 361 |
+
data=value,
|
| 362 |
+
shape=value.shape,
|
| 363 |
+
chunks=value.shape,
|
| 364 |
+
overwrite=True)
|
| 365 |
+
else:
|
| 366 |
+
meta_group.update(np_data)
|
| 367 |
+
|
| 368 |
+
return meta_group
|
| 369 |
+
|
| 370 |
+
@property
|
| 371 |
+
def episode_ends(self):
|
| 372 |
+
return self.meta['episode_ends']
|
| 373 |
+
|
| 374 |
+
def get_episode_idxs(self):
|
| 375 |
+
import numba
|
| 376 |
+
numba.jit(nopython=True)
|
| 377 |
+
def _get_episode_idxs(episode_ends):
|
| 378 |
+
result = np.zeros((episode_ends[-1],), dtype=np.int64)
|
| 379 |
+
for i in range(len(episode_ends)):
|
| 380 |
+
start = 0
|
| 381 |
+
if i > 0:
|
| 382 |
+
start = episode_ends[i-1]
|
| 383 |
+
end = episode_ends[i]
|
| 384 |
+
for idx in range(start, end):
|
| 385 |
+
result[idx] = i
|
| 386 |
+
return result
|
| 387 |
+
return _get_episode_idxs(self.episode_ends)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
@property
|
| 391 |
+
def backend(self):
|
| 392 |
+
backend = 'numpy'
|
| 393 |
+
if isinstance(self.root, zarr.Group):
|
| 394 |
+
backend = 'zarr'
|
| 395 |
+
return backend
|
| 396 |
+
|
| 397 |
+
# =========== dict-like API ==============
|
| 398 |
+
def __repr__(self) -> str:
|
| 399 |
+
if self.backend == 'zarr':
|
| 400 |
+
return str(self.root.tree())
|
| 401 |
+
else:
|
| 402 |
+
return super().__repr__()
|
| 403 |
+
|
| 404 |
+
def keys(self):
|
| 405 |
+
return self.data.keys()
|
| 406 |
+
|
| 407 |
+
def values(self):
|
| 408 |
+
return self.data.values()
|
| 409 |
+
|
| 410 |
+
def items(self):
|
| 411 |
+
return self.data.items()
|
| 412 |
+
|
| 413 |
+
def __getitem__(self, key):
|
| 414 |
+
return self.data[key]
|
| 415 |
+
|
| 416 |
+
def __contains__(self, key):
|
| 417 |
+
return key in self.data
|
| 418 |
+
|
| 419 |
+
# =========== our API ==============
|
| 420 |
+
@property
|
| 421 |
+
def n_steps(self):
|
| 422 |
+
if len(self.episode_ends) == 0:
|
| 423 |
+
return 0
|
| 424 |
+
return self.episode_ends[-1]
|
| 425 |
+
|
| 426 |
+
@property
|
| 427 |
+
def n_episodes(self):
|
| 428 |
+
return len(self.episode_ends)
|
| 429 |
+
|
| 430 |
+
@property
|
| 431 |
+
def chunk_size(self):
|
| 432 |
+
if self.backend == 'zarr':
|
| 433 |
+
return next(iter(self.data.arrays()))[-1].chunks[0]
|
| 434 |
+
return None
|
| 435 |
+
|
| 436 |
+
@property
|
| 437 |
+
def episode_lengths(self):
|
| 438 |
+
ends = self.episode_ends[:]
|
| 439 |
+
ends = np.insert(ends, 0, 0)
|
| 440 |
+
lengths = np.diff(ends)
|
| 441 |
+
return lengths
|
| 442 |
+
|
| 443 |
+
def add_episode(self,
|
| 444 |
+
data: Dict[str, np.ndarray],
|
| 445 |
+
chunks: Optional[Dict[str,tuple]]=dict(),
|
| 446 |
+
compressors: Union[str, numcodecs.abc.Codec, dict]=dict()):
|
| 447 |
+
assert(len(data) > 0)
|
| 448 |
+
is_zarr = (self.backend == 'zarr')
|
| 449 |
+
|
| 450 |
+
curr_len = self.n_steps
|
| 451 |
+
episode_length = None
|
| 452 |
+
for key, value in data.items():
|
| 453 |
+
assert(len(value.shape) >= 1)
|
| 454 |
+
if episode_length is None:
|
| 455 |
+
episode_length = len(value)
|
| 456 |
+
else:
|
| 457 |
+
assert(episode_length == len(value))
|
| 458 |
+
new_len = curr_len + episode_length
|
| 459 |
+
|
| 460 |
+
for key, value in data.items():
|
| 461 |
+
new_shape = (new_len,) + value.shape[1:]
|
| 462 |
+
# create array
|
| 463 |
+
if key not in self.data:
|
| 464 |
+
if is_zarr:
|
| 465 |
+
cks = self._resolve_array_chunks(
|
| 466 |
+
chunks=chunks, key=key, array=value)
|
| 467 |
+
cpr = self._resolve_array_compressor(
|
| 468 |
+
compressors=compressors, key=key, array=value)
|
| 469 |
+
arr = self.data.zeros(name=key,
|
| 470 |
+
shape=new_shape,
|
| 471 |
+
chunks=cks,
|
| 472 |
+
dtype=value.dtype,
|
| 473 |
+
compressor=cpr)
|
| 474 |
+
else:
|
| 475 |
+
# copy data to prevent modify
|
| 476 |
+
arr = np.zeros(shape=new_shape, dtype=value.dtype)
|
| 477 |
+
self.data[key] = arr
|
| 478 |
+
else:
|
| 479 |
+
arr = self.data[key]
|
| 480 |
+
assert(value.shape[1:] == arr.shape[1:])
|
| 481 |
+
# same method for both zarr and numpy
|
| 482 |
+
if is_zarr:
|
| 483 |
+
arr.resize(new_shape)
|
| 484 |
+
else:
|
| 485 |
+
arr.resize(new_shape, refcheck=False)
|
| 486 |
+
# copy data
|
| 487 |
+
arr[-value.shape[0]:] = value
|
| 488 |
+
|
| 489 |
+
# append to episode ends
|
| 490 |
+
episode_ends = self.episode_ends
|
| 491 |
+
if is_zarr:
|
| 492 |
+
episode_ends.resize(episode_ends.shape[0] + 1)
|
| 493 |
+
else:
|
| 494 |
+
episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False)
|
| 495 |
+
episode_ends[-1] = new_len
|
| 496 |
+
|
| 497 |
+
# rechunk
|
| 498 |
+
if is_zarr:
|
| 499 |
+
if episode_ends.chunks[0] < episode_ends.shape[0]:
|
| 500 |
+
rechunk_recompress_array(self.meta, 'episode_ends',
|
| 501 |
+
chunk_length=int(episode_ends.shape[0] * 1.5))
|
| 502 |
+
|
| 503 |
+
def drop_episode(self):
|
| 504 |
+
is_zarr = (self.backend == 'zarr')
|
| 505 |
+
episode_ends = self.episode_ends[:].copy()
|
| 506 |
+
assert(len(episode_ends) > 0)
|
| 507 |
+
start_idx = 0
|
| 508 |
+
if len(episode_ends) > 1:
|
| 509 |
+
start_idx = episode_ends[-2]
|
| 510 |
+
for key, value in self.data.items():
|
| 511 |
+
new_shape = (start_idx,) + value.shape[1:]
|
| 512 |
+
if is_zarr:
|
| 513 |
+
value.resize(new_shape)
|
| 514 |
+
else:
|
| 515 |
+
value.resize(new_shape, refcheck=False)
|
| 516 |
+
if is_zarr:
|
| 517 |
+
self.episode_ends.resize(len(episode_ends)-1)
|
| 518 |
+
else:
|
| 519 |
+
self.episode_ends.resize(len(episode_ends)-1, refcheck=False)
|
| 520 |
+
|
| 521 |
+
def pop_episode(self):
|
| 522 |
+
assert(self.n_episodes > 0)
|
| 523 |
+
episode = self.get_episode(self.n_episodes-1, copy=True)
|
| 524 |
+
self.drop_episode()
|
| 525 |
+
return episode
|
| 526 |
+
|
| 527 |
+
def extend(self, data):
|
| 528 |
+
self.add_episode(data)
|
| 529 |
+
|
| 530 |
+
def get_episode(self, idx, copy=False):
|
| 531 |
+
idx = list(range(len(self.episode_ends)))[idx]
|
| 532 |
+
start_idx = 0
|
| 533 |
+
if idx > 0:
|
| 534 |
+
start_idx = self.episode_ends[idx-1]
|
| 535 |
+
end_idx = self.episode_ends[idx]
|
| 536 |
+
result = self.get_steps_slice(start_idx, end_idx, copy=copy)
|
| 537 |
+
return result
|
| 538 |
+
|
| 539 |
+
def get_episode_slice(self, idx):
|
| 540 |
+
start_idx = 0
|
| 541 |
+
if idx > 0:
|
| 542 |
+
start_idx = self.episode_ends[idx-1]
|
| 543 |
+
end_idx = self.episode_ends[idx]
|
| 544 |
+
return slice(start_idx, end_idx)
|
| 545 |
+
|
| 546 |
+
    def get_steps_slice(self, start, stop, step=None, copy=False):
        """Return {key: array[start:stop:step]} across all data keys.

        ``copy`` forces a copy only for numpy-backed arrays; other array
        types (e.g. zarr) are returned as whatever their slicing yields.
        """
        _slice = slice(start, stop, step)

        result = dict()
        for key, value in self.data.items():
            x = value[_slice]
            if copy and isinstance(value, np.ndarray):
                x = x.copy()
            result[key] = x
        return result
|
| 556 |
+
|
| 557 |
+
    # =========== chunking =============
    def get_chunks(self) -> dict:
        """Return {key: chunk shape} for every zarr data array (zarr only)."""
        assert self.backend == 'zarr'
        chunks = dict()
        for key, value in self.data.items():
            chunks[key] = value.chunks
        return chunks
|
| 564 |
+
|
| 565 |
+
    def set_chunks(self, chunks: dict):
        """Rechunk the given zarr arrays; keys absent from the buffer are ignored."""
        assert self.backend == 'zarr'
        for key, value in chunks.items():
            if key in self.data:
                arr = self.data[key]
                # skip the expensive rewrite when chunking is already as requested
                if value != arr.chunks:
                    check_chunks_compatible(chunks=value, shape=arr.shape)
                    rechunk_recompress_array(self.data, key, chunks=value)
|
| 573 |
+
|
| 574 |
+
    def get_compressors(self) -> dict:
        """Return {key: compressor} for every zarr data array (zarr only)."""
        assert self.backend == 'zarr'
        compressors = dict()
        for key, value in self.data.items():
            compressors[key] = value.compressor
        return compressors
|
| 580 |
+
|
| 581 |
+
    def set_compressors(self, compressors: dict):
        """Recompress the given zarr arrays; keys absent from the buffer are ignored."""
        assert self.backend == 'zarr'
        for key, value in compressors.items():
            if key in self.data:
                arr = self.data[key]
                # value may be a spec rather than a compressor object; normalize first
                compressor = self.resolve_compressor(value)
                if compressor != arr.compressor:
                    rechunk_recompress_array(self.data, key, compressor=compressor)
|
equidiff/equi_diffpo/common/sampler.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import numpy as np
|
| 3 |
+
import numba
|
| 4 |
+
from equi_diffpo.common.replay_buffer import ReplayBuffer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@numba.jit(nopython=True)
def create_indices(
    episode_ends: np.ndarray, sequence_length: int,
    episode_mask: np.ndarray,
    pad_before: int = 0, pad_after: int = 0,
    debug: bool = True) -> np.ndarray:
    """Enumerate all valid (possibly padded) sequence windows over the buffer.

    Returns an (N, 4) int array of rows
    (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx):
    buffer_* index into the flat replay buffer; sample_* say where the
    fetched steps land inside the fixed-length output window.
    """
    # Fix: the original wrote `episode_mask.shape == episode_ends.shape` as a
    # bare expression (a no-op); enforce the invariant with an assert instead.
    assert episode_mask.shape == episode_ends.shape
    # clamp padding to the representable range for this sequence length
    pad_before = min(max(pad_before, 0), sequence_length-1)
    pad_after = min(max(pad_after, 0), sequence_length-1)

    indices = list()
    for i in range(len(episode_ends)):
        if not episode_mask[i]:
            # skip episode (e.g. held out for validation)
            continue
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        # windows may start before the episode (pad_before) and end after
        # it (pad_after); clipping to real data happens below
        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start+1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            if debug:
                assert(start_offset >= 0)
                assert(end_offset >= 0)
                assert (sample_end_idx - sample_start_idx) == (buffer_end_idx - buffer_start_idx)
            indices.append([
                buffer_start_idx, buffer_end_idx,
                sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_val_mask(n_episodes, val_ratio, seed=0):
    """Return a boolean mask selecting which episodes form the validation set.

    Guarantees at least 1 validation episode and at least 1 training episode
    whenever val_ratio > 0. Deterministic for a given seed.
    """
    mask = np.zeros(n_episodes, dtype=bool)
    if val_ratio <= 0:
        return mask

    # at least one episode for validation, at least one left for training
    n_val = min(max(1, round(n_episodes * val_ratio)), n_episodes - 1)
    chosen = np.random.default_rng(seed=seed).choice(
        n_episodes, size=n_val, replace=False)
    mask[chosen] = True
    return mask
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def downsample_mask(mask, max_n, seed=0):
    """Randomly keep at most ``max_n`` True entries of a boolean episode mask.

    Returns the input mask unchanged when max_n is None or already satisfied;
    otherwise returns a new mask with exactly int(max_n) entries set.
    """
    if max_n is None or np.sum(mask) <= max_n:
        return mask

    n_keep = int(max_n)
    active = np.nonzero(mask)[0]
    rng = np.random.default_rng(seed=seed)
    picked = active[rng.choice(len(active), size=n_keep, replace=False)]
    out = np.zeros_like(mask)
    out[picked] = True
    assert np.sum(out) == n_keep
    return out
|
| 76 |
+
|
| 77 |
+
class SequenceSampler:
    """Samples fixed-length windows from a ReplayBuffer, padding windows that
    cross episode boundaries by repeating the first/last frame."""

    def __init__(self,
        replay_buffer: ReplayBuffer,
        sequence_length: int,
        pad_before: int = 0,
        pad_after: int = 0,
        keys=None,
        key_first_k=None,
        episode_mask: Optional[np.ndarray] = None,
        ):
        """
        key_first_k: dict str: int
            Only take first k data from these keys (to improve perf)
        """
        super().__init__()
        assert(sequence_length >= 1)
        # Fix: the original signature used a mutable default (key_first_k=dict());
        # use a None sentinel so instances cannot accidentally share the dict.
        if key_first_k is None:
            key_first_k = dict()
        if keys is None:
            keys = list(replay_buffer.keys())

        episode_ends = replay_buffer.episode_ends[:]
        if episode_mask is None:
            # consider every episode by default
            episode_mask = np.ones(episode_ends.shape, dtype=bool)

        if np.any(episode_mask):
            indices = create_indices(episode_ends,
                sequence_length=sequence_length,
                pad_before=pad_before,
                pad_after=pad_after,
                episode_mask=episode_mask
                )
        else:
            # no episodes selected: empty index table with the right shape
            indices = np.zeros((0,4), dtype=np.int64)

        # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx)
        self.indices = indices
        self.keys = list(keys) # prevent OmegaConf list performance problem
        self.sequence_length = sequence_length
        self.replay_buffer = replay_buffer
        self.key_first_k = key_first_k

    def __len__(self):
        return len(self.indices)

    def sample_sequence(self, idx):
        """Return {key: (sequence_length, ...) array} for window ``idx``."""
        buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx \
            = self.indices[idx]
        result = dict()
        for key in self.keys:
            input_arr = self.replay_buffer[key]
            # performance optimization, avoid small allocation if possible
            if key not in self.key_first_k:
                sample = input_arr[buffer_start_idx:buffer_end_idx]
            else:
                # performance optimization, only load used obs steps
                n_data = buffer_end_idx - buffer_start_idx
                k_data = min(self.key_first_k[key], n_data)
                # fill with NaN to catch bugs:
                # the non-loaded region should never be used downstream
                sample = np.full((n_data,) + input_arr.shape[1:],
                    fill_value=np.nan, dtype=input_arr.dtype)
                # Fix: removed a try/except that dropped into pdb.set_trace()
                # on failure (hangs non-interactive dataloader workers);
                # indexing errors now propagate normally.
                sample[:k_data] = input_arr[buffer_start_idx:buffer_start_idx+k_data]
            data = sample
            if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length):
                # window crosses an episode boundary: pad by repeating
                # the first/last available frame
                data = np.zeros(
                    shape=(self.sequence_length,) + input_arr.shape[1:],
                    dtype=input_arr.dtype)
                if sample_start_idx > 0:
                    data[:sample_start_idx] = sample[0]
                if sample_end_idx < self.sequence_length:
                    data[sample_end_idx:] = sample[-1]
                data[sample_start_idx:sample_end_idx] = sample
            result[key] = data
        return result
|
equidiff/equi_diffpo/common/timestamp_accumulator.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple, Optional, Dict
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_accumulate_timestamp_idxs(
        timestamps: List[float],
        start_time: float,
        dt: float,
        eps: float = 1e-5,
        next_global_idx: Optional[int] = 0,
        allow_negative=False
    ) -> Tuple[List[int], List[int], int]:
    """
    For each dt window, choose the first timestamp in the window.
    Assumes timestamps are sorted. One timestamp may be chosen multiple
    times when frames were dropped.
    next_global_idx should start at 0 normally; pass the returned value on
    subsequent calls. Pass None when overwriting previous values is desired.

    Returns:
        local_idxs: indices into the given timestamps list
        global_idxs: the global slot index assigned to each chosen timestamp
        next_global_idx: value to pass on the next call
    """
    chosen_local = []
    chosen_global = []
    for pos, ts in enumerate(timestamps):
        # eps absorbs floating point error so ts == start_time + k * dt
        # lands in slot k rather than slot k-1
        slot = math.floor((ts - start_time) / dt + eps)
        if slot < 0 and not allow_negative:
            continue
        if next_global_idx is None:
            next_global_idx = slot

        # fill every slot up to and including this one with this timestamp
        repeats = max(0, slot - next_global_idx + 1)
        chosen_local.extend([pos] * repeats)
        chosen_global.extend(range(next_global_idx, next_global_idx + repeats))
        next_global_idx += repeats
    return chosen_local, chosen_global, next_global_idx
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def align_timestamps(
        timestamps: List[float],
        target_global_idxs: List[int],
        start_time: float,
        dt: float,
        eps: float = 1e-5):
    """Map each target global index to an index into ``timestamps``.

    Extra available steps are truncated; missing trailing steps are filled
    by repeating the last available timestamp index.

    Returns:
        local_idxs: one index into ``timestamps`` per target global index.
    Raises:
        ValueError: when no timestamp maps to any target slot.
    """
    if isinstance(target_global_idxs, np.ndarray):
        target_global_idxs = target_global_idxs.tolist()
    assert len(target_global_idxs) > 0

    local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
        timestamps=timestamps,
        start_time=start_time,
        dt=dt,
        eps=eps,
        next_global_idx=target_global_idxs[0],
        allow_negative=True
    )
    if len(global_idxs) > len(target_global_idxs):
        # if more steps available, truncate
        global_idxs = global_idxs[:len(target_global_idxs)]
        local_idxs = local_idxs[:len(target_global_idxs)]

    # Fix: the original dropped into pdb.set_trace() here, which hangs any
    # non-interactive process (and would crash on global_idxs[-1] below
    # anyway); raise a descriptive error instead.
    if len(global_idxs) == 0:
        raise ValueError(
            'align_timestamps: no timestamp maps to the requested '
            'target_global_idxs; cannot align.')

    for i in range(len(target_global_idxs) - len(global_idxs)):
        # if missing, repeat the last available timestamp
        local_idxs.append(len(timestamps)-1)
        global_idxs.append(global_idxs[-1] + 1)
    assert global_idxs == target_global_idxs
    assert len(local_idxs) == len(global_idxs)
    return local_idxs
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TimestampObsAccumulator:
    """Accumulates timestamped observation dicts into fixed-rate (dt-spaced) buffers.

    Each incoming timestamp is mapped to a global slot k such that the slot
    time is start_time + k * dt; slots are only filled forward (never
    overwritten — contrast with TimestampActionAccumulator).
    """
    def __init__(self,
            start_time: float,
            dt: float,
            eps: float=1e-5):
        self.start_time = start_time
        self.dt = dt
        self.eps = eps
        self.obs_buffer = dict()
        # lazily allocated on the first put(); None also means "no data yet"
        self.timestamp_buffer = None
        self.next_global_idx = 0

    def __len__(self):
        # number of dt slots filled so far
        return self.next_global_idx

    @property
    def data(self):
        """Accumulated arrays, trimmed to the filled length."""
        if self.timestamp_buffer is None:
            return dict()
        result = dict()
        for key, value in self.obs_buffer.items():
            result[key] = value[:len(self)]
        return result

    @property
    def actual_timestamps(self):
        """Source timestamps actually stored in each filled slot."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.timestamp_buffer[:len(self)]

    @property
    def timestamps(self):
        """Idealized evenly spaced timestamps: start_time + k * dt."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.start_time + np.arange(len(self)) * self.dt

    def put(self, data: Dict[str, np.ndarray], timestamps: np.ndarray):
        """
        data:
            key: T,*
        """
        # map incoming timestamps to global slots, continuing from the
        # previous call (monotonic: next_global_idx only grows)
        local_idxs, global_idxs, self.next_global_idx = get_accumulate_timestamp_idxs(
            timestamps=timestamps,
            start_time=self.start_time,
            dt=self.dt,
            eps=self.eps,
            next_global_idx=self.next_global_idx
        )

        if len(global_idxs) > 0:
            if self.timestamp_buffer is None:
                # first allocation
                self.obs_buffer = dict()
                for key, value in data.items():
                    self.obs_buffer[key] = np.zeros_like(value)
                self.timestamp_buffer = np.zeros(
                    (len(timestamps),), dtype=np.float64)

            this_max_size = global_idxs[-1] + 1
            if this_max_size > len(self.timestamp_buffer):
                # reallocate with amortized doubling
                new_size = max(this_max_size, len(self.timestamp_buffer) * 2)
                for key in list(self.obs_buffer.keys()):
                    new_shape = (new_size,) + self.obs_buffer[key].shape[1:]
                    self.obs_buffer[key] = np.resize(self.obs_buffer[key], new_shape)
                self.timestamp_buffer = np.resize(self.timestamp_buffer, (new_size))

            # write data
            for key, value in self.obs_buffer.items():
                value[global_idxs] = data[key][local_idxs]
            self.timestamp_buffer[global_idxs] = timestamps[local_idxs]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class TimestampActionAccumulator:
    def __init__(self,
            start_time: float,
            dt: float,
            eps: float=1e-5):
        """
        Different from the Obs accumulator, the action accumulator
        allows overwriting previous values (next_global_idx=None below).
        """
        self.start_time = start_time
        self.dt = dt
        self.eps = eps
        # both buffers lazily allocated on first put()
        self.action_buffer = None
        self.timestamp_buffer = None
        # highest filled slot + 1 (tracked explicitly since writes can rewind)
        self.size = 0

    def __len__(self):
        return self.size

    @property
    def actions(self):
        """Accumulated actions, trimmed to the filled length."""
        if self.action_buffer is None:
            return np.array([])
        return self.action_buffer[:len(self)]

    @property
    def actual_timestamps(self):
        """Source timestamps actually stored in each filled slot."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.timestamp_buffer[:len(self)]

    @property
    def timestamps(self):
        """Idealized evenly spaced timestamps: start_time + k * dt."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.start_time + np.arange(len(self)) * self.dt

    def put(self, actions: np.ndarray, timestamps: np.ndarray):
        """
        Note: timestamps is the time when the action will be issued,
        not when the action will be completed (target_timestamp)
        """
        local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
            timestamps=timestamps,
            start_time=self.start_time,
            dt=self.dt,
            eps=self.eps,
            # allows overwriting previous actions
            next_global_idx=None
        )

        if len(global_idxs) > 0:
            if self.timestamp_buffer is None:
                # first allocation
                self.action_buffer = np.zeros_like(actions)
                self.timestamp_buffer = np.zeros((len(actions),), dtype=np.float64)

            this_max_size = global_idxs[-1] + 1
            if this_max_size > len(self.timestamp_buffer):
                # reallocate with amortized doubling
                new_size = max(this_max_size, len(self.timestamp_buffer) * 2)
                new_shape = (new_size,) + self.action_buffer.shape[1:]
                self.action_buffer = np.resize(self.action_buffer, new_shape)
                self.timestamp_buffer = np.resize(self.timestamp_buffer, (new_size,))

            # potentially rewrite old data (as expected)
            self.action_buffer[global_idxs] = actions[local_idxs]
            self.timestamp_buffer[global_idxs] = timestamps[local_idxs]
            # size never shrinks even when overwriting earlier slots
            self.size = max(self.size, this_max_size)
|
equidiff/equi_diffpo/config/dp3.yaml
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_pc_abs
|
| 4 |
+
|
| 5 |
+
name: train_dp3
|
| 6 |
+
_target_: equi_diffpo.workspace.train_dp3_workspace.TrainDP3Workspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "debug"
|
| 10 |
+
|
| 11 |
+
task_name: stack_d1
|
| 12 |
+
n_demo: 200
|
| 13 |
+
horizon: 16
|
| 14 |
+
n_obs_steps: 2
|
| 15 |
+
n_action_steps: 8
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
keypoint_visible_rate: 1.0
|
| 19 |
+
obs_as_global_cond: True
|
| 20 |
+
dataset_target: equi_diffpo.dataset.robomimic_replay_point_cloud_dataset.RobomimicReplayPointCloudDataset
|
| 21 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
|
| 22 |
+
|
| 23 |
+
policy:
|
| 24 |
+
_target_: equi_diffpo.policy.dp3.DP3
|
| 25 |
+
use_point_crop: true
|
| 26 |
+
condition_type: film
|
| 27 |
+
use_down_condition: true
|
| 28 |
+
use_mid_condition: true
|
| 29 |
+
use_up_condition: true
|
| 30 |
+
|
| 31 |
+
diffusion_step_embed_dim: 128
|
| 32 |
+
down_dims:
|
| 33 |
+
- 512
|
| 34 |
+
- 1024
|
| 35 |
+
- 2048
|
| 36 |
+
crop_shape:
|
| 37 |
+
- 80
|
| 38 |
+
- 80
|
| 39 |
+
encoder_output_dim: 64
|
| 40 |
+
horizon: ${horizon}
|
| 41 |
+
kernel_size: 5
|
| 42 |
+
n_action_steps: ${n_action_steps}
|
| 43 |
+
n_groups: 8
|
| 44 |
+
n_obs_steps: ${n_obs_steps}
|
| 45 |
+
|
| 46 |
+
noise_scheduler:
|
| 47 |
+
_target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler
|
| 48 |
+
num_train_timesteps: 100
|
| 49 |
+
beta_start: 0.0001
|
| 50 |
+
beta_end: 0.02
|
| 51 |
+
beta_schedule: squaredcos_cap_v2
|
| 52 |
+
clip_sample: True
|
| 53 |
+
set_alpha_to_one: True
|
| 54 |
+
steps_offset: 0
|
| 55 |
+
prediction_type: sample
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
num_inference_steps: 10
|
| 59 |
+
obs_as_global_cond: true
|
| 60 |
+
shape_meta: ${shape_meta}
|
| 61 |
+
|
| 62 |
+
use_pc_color: true
|
| 63 |
+
pointnet_type: "pointnet"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
pointcloud_encoder_cfg:
|
| 67 |
+
in_channels: 3
|
| 68 |
+
out_channels: ${policy.encoder_output_dim}
|
| 69 |
+
use_layernorm: true
|
| 70 |
+
final_norm: layernorm # layernorm, none
|
| 71 |
+
normal_channel: false
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
ema:
|
| 75 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 76 |
+
update_after_step: 0
|
| 77 |
+
inv_gamma: 1.0
|
| 78 |
+
power: 0.75
|
| 79 |
+
min_value: 0.0
|
| 80 |
+
max_value: 0.9999
|
| 81 |
+
|
| 82 |
+
dataloader:
|
| 83 |
+
batch_size: 128
|
| 84 |
+
num_workers: 8
|
| 85 |
+
shuffle: True
|
| 86 |
+
pin_memory: True
|
| 87 |
+
persistent_workers: True
|
| 88 |
+
|
| 89 |
+
val_dataloader:
|
| 90 |
+
batch_size: 128
|
| 91 |
+
num_workers: 8
|
| 92 |
+
shuffle: False
|
| 93 |
+
pin_memory: True
|
| 94 |
+
persistent_workers: True
|
| 95 |
+
|
| 96 |
+
optimizer:
|
| 97 |
+
_target_: torch.optim.AdamW
|
| 98 |
+
lr: 1.0e-4
|
| 99 |
+
betas: [0.95, 0.999]
|
| 100 |
+
eps: 1.0e-8
|
| 101 |
+
weight_decay: 1.0e-6
|
| 102 |
+
|
| 103 |
+
training:
|
| 104 |
+
device: "cuda:0"
|
| 105 |
+
seed: 42
|
| 106 |
+
debug: False
|
| 107 |
+
resume: True
|
| 108 |
+
lr_scheduler: cosine
|
| 109 |
+
lr_warmup_steps: 500
|
| 110 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 111 |
+
gradient_accumulate_every: 1
|
| 112 |
+
use_ema: True
|
| 113 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 114 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 115 |
+
val_every: 1
|
| 116 |
+
sample_every: 5
|
| 117 |
+
max_train_steps: null
|
| 118 |
+
max_val_steps: null
|
| 119 |
+
tqdm_interval_sec: 1.0
|
| 120 |
+
|
| 121 |
+
logging:
|
| 122 |
+
project: dp3_${task_name}
|
| 123 |
+
resume: true
|
| 124 |
+
mode: online
|
| 125 |
+
name: dp3_${n_demo}
|
| 126 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 127 |
+
id: null
|
| 128 |
+
group: null
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
checkpoint:
|
| 132 |
+
save_ckpt: False # if True, save checkpoint every checkpoint_every
|
| 133 |
+
topk:
|
| 134 |
+
monitor_key: test_mean_score
|
| 135 |
+
mode: max
|
| 136 |
+
k: 1
|
| 137 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 138 |
+
save_last_ckpt: True # this only saves when save_ckpt is True
|
| 139 |
+
save_last_snapshot: False
|
| 140 |
+
|
| 141 |
+
multi_run:
|
| 142 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 143 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 144 |
+
|
| 145 |
+
hydra:
|
| 146 |
+
job:
|
| 147 |
+
override_dirname: ${name}
|
| 148 |
+
run:
|
| 149 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 150 |
+
sweep:
|
| 151 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 152 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/task/mimicgen_abs.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: mimicgen_abs
|
| 2 |
+
|
| 3 |
+
shape_meta: &shape_meta
|
| 4 |
+
# acceptable types: rgb, low_dim
|
| 5 |
+
obs:
|
| 6 |
+
agentview_image:
|
| 7 |
+
shape: [3, 84, 84]
|
| 8 |
+
type: rgb
|
| 9 |
+
robot0_eye_in_hand_image:
|
| 10 |
+
shape: [3, 84, 84]
|
| 11 |
+
type: rgb
|
| 12 |
+
robot0_eef_pos:
|
| 13 |
+
shape: [3]
|
| 14 |
+
# type default: low_dim
|
| 15 |
+
robot0_eef_quat:
|
| 16 |
+
shape: [4]
|
| 17 |
+
robot0_gripper_qpos:
|
| 18 |
+
shape: [2]
|
| 19 |
+
action:
|
| 20 |
+
shape: [10]
|
| 21 |
+
|
| 22 |
+
abs_action: &abs_action True
|
| 23 |
+
|
| 24 |
+
env_runner:
|
| 25 |
+
_target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
|
| 26 |
+
dataset_path: ${dataset_path}
|
| 27 |
+
shape_meta: *shape_meta
|
| 28 |
+
n_train: 6
|
| 29 |
+
n_train_vis: 2
|
| 30 |
+
train_start_idx: 0
|
| 31 |
+
n_test: 50
|
| 32 |
+
n_test_vis: 4
|
| 33 |
+
test_start_seed: 100000
|
| 34 |
+
max_steps: ${get_max_steps:${task_name}}
|
| 35 |
+
n_obs_steps: ${n_obs_steps}
|
| 36 |
+
n_action_steps: ${n_action_steps}
|
| 37 |
+
render_obs_key: 'agentview_image'
|
| 38 |
+
fps: 10
|
| 39 |
+
crf: 22
|
| 40 |
+
past_action: ${past_action_visible}
|
| 41 |
+
abs_action: *abs_action
|
| 42 |
+
tqdm_interval_sec: 1.0
|
| 43 |
+
n_envs: 28
|
| 44 |
+
|
| 45 |
+
dataset:
|
| 46 |
+
# _target_: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
|
| 47 |
+
_target_: ${dataset}
|
| 48 |
+
n_demo: ${n_demo}
|
| 49 |
+
shape_meta: *shape_meta
|
| 50 |
+
dataset_path: ${dataset_path}
|
| 51 |
+
horizon: ${horizon}
|
| 52 |
+
pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
|
| 53 |
+
pad_after: ${eval:'${n_action_steps}-1'}
|
| 54 |
+
n_obs_steps: ${dataset_obs_steps}
|
| 55 |
+
abs_action: *abs_action
|
| 56 |
+
rotation_rep: 'rotation_6d'
|
| 57 |
+
use_legacy_normalizer: False
|
| 58 |
+
use_cache: True
|
| 59 |
+
seed: 42
|
| 60 |
+
val_ratio: 0.02
|
equidiff/equi_diffpo/config/task/mimicgen_pc_abs.yaml
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: mimicgen_pc_abs
|
| 2 |
+
|
| 3 |
+
shape_meta: &shape_meta
|
| 4 |
+
# acceptable types: rgb, low_dim
|
| 5 |
+
obs:
|
| 6 |
+
robot0_eye_in_hand_image:
|
| 7 |
+
shape: [3, 84, 84]
|
| 8 |
+
type: rgb
|
| 9 |
+
point_cloud:
|
| 10 |
+
shape: [1024, 6]
|
| 11 |
+
type: point_cloud
|
| 12 |
+
robot0_eef_pos:
|
| 13 |
+
shape: [3]
|
| 14 |
+
# type default: low_dim
|
| 15 |
+
robot0_eef_quat:
|
| 16 |
+
shape: [4]
|
| 17 |
+
robot0_gripper_qpos:
|
| 18 |
+
shape: [2]
|
| 19 |
+
action:
|
| 20 |
+
shape: [10]
|
| 21 |
+
|
| 22 |
+
env_runner_shape_meta: &env_runner_shape_meta
|
| 23 |
+
# acceptable types: rgb, low_dim
|
| 24 |
+
obs:
|
| 25 |
+
robot0_eye_in_hand_image:
|
| 26 |
+
shape: [3, 84, 84]
|
| 27 |
+
type: rgb
|
| 28 |
+
agentview_image:
|
| 29 |
+
shape: [3, 84, 84]
|
| 30 |
+
type: rgb
|
| 31 |
+
point_cloud:
|
| 32 |
+
shape: [1024, 6]
|
| 33 |
+
type: point_cloud
|
| 34 |
+
robot0_eef_pos:
|
| 35 |
+
shape: [3]
|
| 36 |
+
# type default: low_dim
|
| 37 |
+
robot0_eef_quat:
|
| 38 |
+
shape: [4]
|
| 39 |
+
robot0_gripper_qpos:
|
| 40 |
+
shape: [2]
|
| 41 |
+
action:
|
| 42 |
+
shape: [10]
|
| 43 |
+
|
| 44 |
+
abs_action: &abs_action True
|
| 45 |
+
|
| 46 |
+
env_runner:
|
| 47 |
+
_target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
|
| 48 |
+
dataset_path: ${dataset_path}
|
| 49 |
+
shape_meta: *env_runner_shape_meta
|
| 50 |
+
n_train: 6
|
| 51 |
+
n_train_vis: 2
|
| 52 |
+
train_start_idx: 0
|
| 53 |
+
n_test: 50
|
| 54 |
+
n_test_vis: 4
|
| 55 |
+
test_start_seed: 100000
|
| 56 |
+
max_steps: ${get_max_steps:${task_name}}
|
| 57 |
+
n_obs_steps: ${n_obs_steps}
|
| 58 |
+
n_action_steps: ${n_action_steps}
|
| 59 |
+
render_obs_key: 'agentview_image'
|
| 60 |
+
fps: 10
|
| 61 |
+
crf: 22
|
| 62 |
+
past_action: False
|
| 63 |
+
abs_action: *abs_action
|
| 64 |
+
tqdm_interval_sec: 1.0
|
| 65 |
+
n_envs: 28
|
| 66 |
+
|
| 67 |
+
dataset:
|
| 68 |
+
_target_: ${dataset_target}
|
| 69 |
+
n_demo: ${n_demo}
|
| 70 |
+
shape_meta: *shape_meta
|
| 71 |
+
dataset_path: ${dataset_path}
|
| 72 |
+
horizon: ${horizon}
|
| 73 |
+
pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
|
| 74 |
+
pad_after: ${eval:'${n_action_steps}-1'}
|
| 75 |
+
n_obs_steps: ${dataset_obs_steps}
|
| 76 |
+
abs_action: *abs_action
|
| 77 |
+
rotation_rep: 'rotation_6d'
|
| 78 |
+
use_legacy_normalizer: False
|
| 79 |
+
use_cache: False
|
| 80 |
+
seed: 42
|
| 81 |
+
val_ratio: 0.02
|
equidiff/equi_diffpo/config/task/mimicgen_rel.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: mimicgen_rel
|
| 2 |
+
|
| 3 |
+
shape_meta: &shape_meta
|
| 4 |
+
# acceptable types: rgb, low_dim
|
| 5 |
+
obs:
|
| 6 |
+
agentview_image:
|
| 7 |
+
shape: [3, 84, 84]
|
| 8 |
+
type: rgb
|
| 9 |
+
robot0_eye_in_hand_image:
|
| 10 |
+
shape: [3, 84, 84]
|
| 11 |
+
type: rgb
|
| 12 |
+
robot0_eef_pos:
|
| 13 |
+
shape: [3]
|
| 14 |
+
# type default: low_dim
|
| 15 |
+
robot0_eef_quat:
|
| 16 |
+
shape: [4]
|
| 17 |
+
robot0_gripper_qpos:
|
| 18 |
+
shape: [2]
|
| 19 |
+
action:
|
| 20 |
+
shape: [7]
|
| 21 |
+
|
| 22 |
+
abs_action: &abs_action False
|
| 23 |
+
|
| 24 |
+
env_runner:
|
| 25 |
+
_target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
|
| 26 |
+
dataset_path: ${dataset_path}
|
| 27 |
+
shape_meta: *shape_meta
|
| 28 |
+
n_train: 6
|
| 29 |
+
n_train_vis: 2
|
| 30 |
+
train_start_idx: 0
|
| 31 |
+
n_test: 50
|
| 32 |
+
n_test_vis: 4
|
| 33 |
+
test_start_seed: 100000
|
| 34 |
+
max_steps: ${get_max_steps:${task_name}}
|
| 35 |
+
n_obs_steps: ${n_obs_steps}
|
| 36 |
+
n_action_steps: ${n_action_steps}
|
| 37 |
+
render_obs_key: 'agentview_image'
|
| 38 |
+
fps: 10
|
| 39 |
+
crf: 22
|
| 40 |
+
past_action: ${past_action_visible}
|
| 41 |
+
abs_action: *abs_action
|
| 42 |
+
tqdm_interval_sec: 1.0
|
| 43 |
+
n_envs: 28
|
| 44 |
+
|
| 45 |
+
dataset:
|
| 46 |
+
# _target_: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
|
| 47 |
+
_target_: ${dataset}
|
| 48 |
+
n_demo: ${n_demo}
|
| 49 |
+
shape_meta: *shape_meta
|
| 50 |
+
dataset_path: ${dataset_path}
|
| 51 |
+
horizon: ${horizon}
|
| 52 |
+
pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
|
| 53 |
+
pad_after: ${eval:'${n_action_steps}-1'}
|
| 54 |
+
n_obs_steps: ${dataset_obs_steps}
|
| 55 |
+
abs_action: *abs_action
|
| 56 |
+
rotation_rep: 'rotation_6d'
|
| 57 |
+
use_legacy_normalizer: False
|
| 58 |
+
use_cache: True
|
| 59 |
+
seed: 42
|
| 60 |
+
val_ratio: 0.02
|
equidiff/equi_diffpo/config/task/mimicgen_voxel_abs.yaml
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: mimicgen_abs
|
| 2 |
+
|
| 3 |
+
shape_meta: &shape_meta
|
| 4 |
+
# acceptable types: rgb, low_dim
|
| 5 |
+
obs:
|
| 6 |
+
robot0_eye_in_hand_image:
|
| 7 |
+
shape: [3, 84, 84]
|
| 8 |
+
type: rgb
|
| 9 |
+
voxels:
|
| 10 |
+
shape: [4, 64, 64, 64]
|
| 11 |
+
type: voxel
|
| 12 |
+
robot0_eef_pos:
|
| 13 |
+
shape: [3]
|
| 14 |
+
# type default: low_dim
|
| 15 |
+
robot0_eef_quat:
|
| 16 |
+
shape: [4]
|
| 17 |
+
robot0_gripper_qpos:
|
| 18 |
+
shape: [2]
|
| 19 |
+
action:
|
| 20 |
+
shape: [10]
|
| 21 |
+
|
| 22 |
+
env_runner_shape_meta: &env_runner_shape_meta
|
| 23 |
+
# acceptable types: rgb, low_dim
|
| 24 |
+
obs:
|
| 25 |
+
robot0_eye_in_hand_image:
|
| 26 |
+
shape: [3, 84, 84]
|
| 27 |
+
type: rgb
|
| 28 |
+
agentview_image:
|
| 29 |
+
shape: [3, 84, 84]
|
| 30 |
+
type: rgb
|
| 31 |
+
voxels:
|
| 32 |
+
shape: [4, 64, 64, 64]
|
| 33 |
+
type: voxel
|
| 34 |
+
robot0_eef_pos:
|
| 35 |
+
shape: [3]
|
| 36 |
+
# type default: low_dim
|
| 37 |
+
robot0_eef_quat:
|
| 38 |
+
shape: [4]
|
| 39 |
+
robot0_gripper_qpos:
|
| 40 |
+
shape: [2]
|
| 41 |
+
action:
|
| 42 |
+
shape: [10]
|
| 43 |
+
|
| 44 |
+
# dataset_path: &dataset_path data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
|
| 45 |
+
abs_action: &abs_action True
|
| 46 |
+
|
| 47 |
+
env_runner:
|
| 48 |
+
_target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
|
| 49 |
+
dataset_path: ${dataset_path}
|
| 50 |
+
shape_meta: *env_runner_shape_meta
|
| 51 |
+
n_train: 6
|
| 52 |
+
n_train_vis: 2
|
| 53 |
+
train_start_idx: 0
|
| 54 |
+
n_test: 50
|
| 55 |
+
n_test_vis: 4
|
| 56 |
+
test_start_seed: 100000
|
| 57 |
+
max_steps: ${get_max_steps:${task_name}}
|
| 58 |
+
n_obs_steps: ${n_obs_steps}
|
| 59 |
+
n_action_steps: ${n_action_steps}
|
| 60 |
+
render_obs_key: 'agentview_image'
|
| 61 |
+
fps: 10
|
| 62 |
+
crf: 22
|
| 63 |
+
past_action: ${past_action_visible}
|
| 64 |
+
abs_action: *abs_action
|
| 65 |
+
tqdm_interval_sec: 1.0
|
| 66 |
+
n_envs: 28
|
| 67 |
+
|
| 68 |
+
dataset:
|
| 69 |
+
_target_: equi_diffpo.dataset.robomimic_replay_voxel_sym_dataset.RobomimicReplayVoxelSymDataset
|
| 70 |
+
n_demo: ${n_demo}
|
| 71 |
+
shape_meta: *shape_meta
|
| 72 |
+
dataset_path: ${dataset_path}
|
| 73 |
+
horizon: ${horizon}
|
| 74 |
+
pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
|
| 75 |
+
pad_after: ${eval:'${n_action_steps}-1'}
|
| 76 |
+
n_obs_steps: ${dataset_obs_steps}
|
| 77 |
+
abs_action: *abs_action
|
| 78 |
+
rotation_rep: 'rotation_6d'
|
| 79 |
+
use_legacy_normalizer: False
|
| 80 |
+
use_cache: True
|
| 81 |
+
seed: 42
|
| 82 |
+
val_ratio: 0.02
|
| 83 |
+
ws_x_center: ${get_ws_x_center:${task_name}}
|
| 84 |
+
ws_y_center: ${get_ws_y_center:${task_name}}
|
equidiff/equi_diffpo/config/task/mimicgen_voxel_rel.yaml
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: mimicgen_rel
|
| 2 |
+
|
| 3 |
+
shape_meta: &shape_meta
|
| 4 |
+
# acceptable types: rgb, low_dim
|
| 5 |
+
obs:
|
| 6 |
+
robot0_eye_in_hand_image:
|
| 7 |
+
shape: [3, 84, 84]
|
| 8 |
+
type: rgb
|
| 9 |
+
voxels:
|
| 10 |
+
shape: [4, 64, 64, 64]
|
| 11 |
+
type: voxel
|
| 12 |
+
robot0_eef_pos:
|
| 13 |
+
shape: [3]
|
| 14 |
+
# type default: low_dim
|
| 15 |
+
robot0_eef_quat:
|
| 16 |
+
shape: [4]
|
| 17 |
+
robot0_gripper_qpos:
|
| 18 |
+
shape: [2]
|
| 19 |
+
action:
|
| 20 |
+
shape: [7]
|
| 21 |
+
|
| 22 |
+
env_runner_shape_meta: &env_runner_shape_meta
|
| 23 |
+
# acceptable types: rgb, low_dim
|
| 24 |
+
obs:
|
| 25 |
+
robot0_eye_in_hand_image:
|
| 26 |
+
shape: [3, 84, 84]
|
| 27 |
+
type: rgb
|
| 28 |
+
agentview_image:
|
| 29 |
+
shape: [3, 84, 84]
|
| 30 |
+
type: rgb
|
| 31 |
+
voxels:
|
| 32 |
+
shape: [4, 64, 64, 64]
|
| 33 |
+
type: voxel
|
| 34 |
+
robot0_eef_pos:
|
| 35 |
+
shape: [3]
|
| 36 |
+
# type default: low_dim
|
| 37 |
+
robot0_eef_quat:
|
| 38 |
+
shape: [4]
|
| 39 |
+
robot0_gripper_qpos:
|
| 40 |
+
shape: [2]
|
| 41 |
+
action:
|
| 42 |
+
shape: [7]
|
| 43 |
+
|
| 44 |
+
# dataset_path: &dataset_path data/robomimic/datasets/${task_name}/${task_name}_voxel.hdf5
|
| 45 |
+
abs_action: &abs_action False
|
| 46 |
+
|
| 47 |
+
env_runner:
|
| 48 |
+
_target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
|
| 49 |
+
dataset_path: ${dataset_path}
|
| 50 |
+
shape_meta: *env_runner_shape_meta
|
| 51 |
+
n_train: 6
|
| 52 |
+
n_train_vis: 2
|
| 53 |
+
train_start_idx: 0
|
| 54 |
+
n_test: 50
|
| 55 |
+
n_test_vis: 4
|
| 56 |
+
test_start_seed: 100000
|
| 57 |
+
max_steps: ${get_max_steps:${task_name}}
|
| 58 |
+
n_obs_steps: ${n_obs_steps}
|
| 59 |
+
n_action_steps: ${n_action_steps}
|
| 60 |
+
render_obs_key: 'agentview_image'
|
| 61 |
+
fps: 10
|
| 62 |
+
crf: 22
|
| 63 |
+
past_action: ${past_action_visible}
|
| 64 |
+
abs_action: *abs_action
|
| 65 |
+
tqdm_interval_sec: 1.0
|
| 66 |
+
n_envs: 28
|
| 67 |
+
|
| 68 |
+
dataset:
|
| 69 |
+
_target_: equi_diffpo.dataset.robomimic_replay_voxel_sym_dataset.RobomimicReplayVoxelSymDataset
|
| 70 |
+
n_demo: ${n_demo}
|
| 71 |
+
shape_meta: *shape_meta
|
| 72 |
+
dataset_path: ${dataset_path}
|
| 73 |
+
horizon: ${horizon}
|
| 74 |
+
pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
|
| 75 |
+
pad_after: ${eval:'${n_action_steps}-1'}
|
| 76 |
+
n_obs_steps: ${dataset_obs_steps}
|
| 77 |
+
abs_action: *abs_action
|
| 78 |
+
rotation_rep: 'rotation_6d'
|
| 79 |
+
use_legacy_normalizer: False
|
| 80 |
+
use_cache: True
|
| 81 |
+
seed: 42
|
| 82 |
+
val_ratio: 0.02
|
| 83 |
+
ws_x_center: ${get_ws_x_center:${task_name}}
|
| 84 |
+
ws_y_center: ${get_ws_y_center:${task_name}}
|
equidiff/equi_diffpo/config/test_equi_diffusion_unet_abs_sq2.yaml
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
name: equi_diff
|
| 6 |
+
_target_: equi_diffpo.workspace.test_equi_workspace.TestEquiWorkspace
|
| 7 |
+
ckpt_path: data/outputs/2025.01.10/06.04.25_equi_diff_square_d2_high/checkpoints/epoch=0046-test_mean_score=0.760.ckpt
|
| 8 |
+
diversity: high
|
| 9 |
+
|
| 10 |
+
shape_meta: ${task.shape_meta}
|
| 11 |
+
exp_name: "default"
|
| 12 |
+
|
| 13 |
+
task_name: square_d2
|
| 14 |
+
log_txt_path: data/test_result.txt
|
| 15 |
+
n_demo: 1000
|
| 16 |
+
horizon: 16
|
| 17 |
+
n_obs_steps: 2
|
| 18 |
+
n_action_steps: 8
|
| 19 |
+
n_latency_steps: 0
|
| 20 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 21 |
+
past_action_visible: False
|
| 22 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 23 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
|
| 24 |
+
|
| 25 |
+
policy:
|
| 26 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
|
| 27 |
+
|
| 28 |
+
shape_meta: ${shape_meta}
|
| 29 |
+
|
| 30 |
+
noise_scheduler:
|
| 31 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 32 |
+
num_train_timesteps: 100
|
| 33 |
+
beta_start: 0.0001
|
| 34 |
+
beta_end: 0.02
|
| 35 |
+
beta_schedule: squaredcos_cap_v2
|
| 36 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 37 |
+
clip_sample: True # required when predict_epsilon=False
|
| 38 |
+
prediction_type: epsilon # or sample
|
| 39 |
+
|
| 40 |
+
horizon: ${horizon}
|
| 41 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 42 |
+
n_obs_steps: ${n_obs_steps}
|
| 43 |
+
num_inference_steps: 100
|
| 44 |
+
crop_shape: [76, 76]
|
| 45 |
+
# crop_shape: null
|
| 46 |
+
diffusion_step_embed_dim: 128
|
| 47 |
+
enc_n_hidden: 128
|
| 48 |
+
down_dims: [512, 1024, 2048]
|
| 49 |
+
kernel_size: 5
|
| 50 |
+
n_groups: 8
|
| 51 |
+
cond_predict_scale: True
|
| 52 |
+
rot_aug: False
|
| 53 |
+
|
| 54 |
+
# scheduler.step params
|
| 55 |
+
# predict_epsilon: True
|
| 56 |
+
|
| 57 |
+
ema:
|
| 58 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 59 |
+
update_after_step: 0
|
| 60 |
+
inv_gamma: 1.0
|
| 61 |
+
power: 0.75
|
| 62 |
+
min_value: 0.0
|
| 63 |
+
max_value: 0.9999
|
| 64 |
+
|
| 65 |
+
dataloader:
|
| 66 |
+
batch_size: 128
|
| 67 |
+
num_workers: 4
|
| 68 |
+
shuffle: True
|
| 69 |
+
pin_memory: True
|
| 70 |
+
persistent_workers: True
|
| 71 |
+
drop_last: true
|
| 72 |
+
|
| 73 |
+
val_dataloader:
|
| 74 |
+
batch_size: 128
|
| 75 |
+
num_workers: 8
|
| 76 |
+
shuffle: False
|
| 77 |
+
pin_memory: True
|
| 78 |
+
persistent_workers: True
|
| 79 |
+
|
| 80 |
+
optimizer:
|
| 81 |
+
betas: [0.95, 0.999]
|
| 82 |
+
eps: 1.0e-08
|
| 83 |
+
learning_rate: 0.0001
|
| 84 |
+
weight_decay: 1.0e-06
|
| 85 |
+
|
| 86 |
+
training:
|
| 87 |
+
ckpt_path: ${ckpt_path}
|
| 88 |
+
device: "cuda:0"
|
| 89 |
+
seed: 0
|
| 90 |
+
debug: False
|
| 91 |
+
resume: True
|
| 92 |
+
# optimization
|
| 93 |
+
lr_scheduler: cosine
|
| 94 |
+
lr_warmup_steps: 500
|
| 95 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 96 |
+
gradient_accumulate_every: 1
|
| 97 |
+
# EMA destroys performance when used with BatchNorm
|
| 98 |
+
# replace BatchNorm with GroupNorm.
|
| 99 |
+
use_ema: True
|
| 100 |
+
# training loop control
|
| 101 |
+
# in epochs
|
| 102 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 103 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 104 |
+
val_every: 1
|
| 105 |
+
sample_every: 5
|
| 106 |
+
# steps per epoch
|
| 107 |
+
max_train_steps: null
|
| 108 |
+
max_val_steps: null
|
| 109 |
+
# misc
|
| 110 |
+
tqdm_interval_sec: 1.0
|
| 111 |
+
|
| 112 |
+
logging:
|
| 113 |
+
project: test_diffusion_policy_${task_name}
|
| 114 |
+
resume: True
|
| 115 |
+
mode: online
|
| 116 |
+
name: equidiff_${n_demo}_${diversity}_${policy.n_action_steps}
|
| 117 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 118 |
+
id: null
|
| 119 |
+
group: null
|
| 120 |
+
|
| 121 |
+
checkpoint:
|
| 122 |
+
topk:
|
| 123 |
+
monitor_key: test_mean_score
|
| 124 |
+
mode: max
|
| 125 |
+
k: 5
|
| 126 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 127 |
+
save_last_ckpt: True
|
| 128 |
+
save_last_snapshot: False
|
| 129 |
+
|
| 130 |
+
# multi_run:
|
| 131 |
+
# run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 132 |
+
# wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 133 |
+
|
| 134 |
+
hydra:
|
| 135 |
+
job:
|
| 136 |
+
override_dirname: ${name}
|
| 137 |
+
run:
|
| 138 |
+
dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 139 |
+
sweep:
|
| 140 |
+
dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 141 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/test_sq2.yaml
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
run_name: square_d2_test
|
| 6 |
+
name: equi_diff
|
| 7 |
+
_target_: equi_diffpo.workspace.test_equi_workspace.TestEquiWorkspace
|
| 8 |
+
ckpt_path: /home/siweih/Project/EmbodiedBM/equidiff/data/outputs/2025.02.23/00.15.04_equi_diff_square_d2/checkpoints/epoch=0019-test_mean_score=0.840.ckpt
|
| 9 |
+
diversity: high
|
| 10 |
+
|
| 11 |
+
shape_meta: ${task.shape_meta}
|
| 12 |
+
exp_name: "default"
|
| 13 |
+
|
| 14 |
+
task_name: square_d2
|
| 15 |
+
log_txt_path: data/sq2_test_result.txt
|
| 16 |
+
n_demo: 1000
|
| 17 |
+
horizon: 16
|
| 18 |
+
n_obs_steps: 2
|
| 19 |
+
n_action_steps: 8
|
| 20 |
+
n_latency_steps: 0
|
| 21 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 22 |
+
past_action_visible: False
|
| 23 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 24 |
+
dataset_path: data/robomimic/datasets/square_d2/square_d2_abs.hdf5
|
| 25 |
+
|
| 26 |
+
policy:
|
| 27 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
|
| 28 |
+
|
| 29 |
+
shape_meta: ${shape_meta}
|
| 30 |
+
|
| 31 |
+
noise_scheduler:
|
| 32 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 33 |
+
num_train_timesteps: 100
|
| 34 |
+
beta_start: 0.0001
|
| 35 |
+
beta_end: 0.02
|
| 36 |
+
beta_schedule: squaredcos_cap_v2
|
| 37 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 38 |
+
clip_sample: True # required when predict_epsilon=False
|
| 39 |
+
prediction_type: epsilon # or sample
|
| 40 |
+
|
| 41 |
+
horizon: ${horizon}
|
| 42 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 43 |
+
n_obs_steps: ${n_obs_steps}
|
| 44 |
+
num_inference_steps: 100
|
| 45 |
+
crop_shape: [76, 76]
|
| 46 |
+
# crop_shape: null
|
| 47 |
+
diffusion_step_embed_dim: 128
|
| 48 |
+
enc_n_hidden: 128
|
| 49 |
+
down_dims: [512, 1024, 2048]
|
| 50 |
+
kernel_size: 5
|
| 51 |
+
n_groups: 8
|
| 52 |
+
cond_predict_scale: True
|
| 53 |
+
rot_aug: False
|
| 54 |
+
|
| 55 |
+
# scheduler.step params
|
| 56 |
+
# predict_epsilon: True
|
| 57 |
+
|
| 58 |
+
ema:
|
| 59 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 60 |
+
update_after_step: 0
|
| 61 |
+
inv_gamma: 1.0
|
| 62 |
+
power: 0.75
|
| 63 |
+
min_value: 0.0
|
| 64 |
+
max_value: 0.9999
|
| 65 |
+
|
| 66 |
+
dataloader:
|
| 67 |
+
batch_size: 128
|
| 68 |
+
num_workers: 4
|
| 69 |
+
shuffle: True
|
| 70 |
+
pin_memory: True
|
| 71 |
+
persistent_workers: True
|
| 72 |
+
drop_last: true
|
| 73 |
+
|
| 74 |
+
val_dataloader:
|
| 75 |
+
batch_size: 128
|
| 76 |
+
num_workers: 8
|
| 77 |
+
shuffle: False
|
| 78 |
+
pin_memory: True
|
| 79 |
+
persistent_workers: True
|
| 80 |
+
|
| 81 |
+
optimizer:
|
| 82 |
+
betas: [0.95, 0.999]
|
| 83 |
+
eps: 1.0e-08
|
| 84 |
+
learning_rate: 0.0001
|
| 85 |
+
weight_decay: 1.0e-06
|
| 86 |
+
|
| 87 |
+
training:
|
| 88 |
+
ckpt_path: ${ckpt_path}
|
| 89 |
+
device: "cuda:0"
|
| 90 |
+
seed: 0
|
| 91 |
+
debug: False
|
| 92 |
+
resume: True
|
| 93 |
+
# optimization
|
| 94 |
+
lr_scheduler: cosine
|
| 95 |
+
lr_warmup_steps: 500
|
| 96 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 97 |
+
gradient_accumulate_every: 1
|
| 98 |
+
# EMA destroys performance when used with BatchNorm
|
| 99 |
+
# replace BatchNorm with GroupNorm.
|
| 100 |
+
use_ema: True
|
| 101 |
+
# training loop control
|
| 102 |
+
# in epochs
|
| 103 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 104 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 105 |
+
val_every: 1
|
| 106 |
+
sample_every: 5
|
| 107 |
+
# steps per epoch
|
| 108 |
+
max_train_steps: null
|
| 109 |
+
max_val_steps: null
|
| 110 |
+
# misc
|
| 111 |
+
tqdm_interval_sec: 1.0
|
| 112 |
+
|
| 113 |
+
logging:
|
| 114 |
+
project: test_diffusion_policy_${task_name}
|
| 115 |
+
resume: True
|
| 116 |
+
mode: online
|
| 117 |
+
name: equidiff_${n_demo}_${diversity}_${policy.n_action_steps}
|
| 118 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 119 |
+
id: null
|
| 120 |
+
group: null
|
| 121 |
+
|
| 122 |
+
checkpoint:
|
| 123 |
+
topk:
|
| 124 |
+
monitor_key: test_mean_score
|
| 125 |
+
mode: max
|
| 126 |
+
k: 5
|
| 127 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 128 |
+
save_last_ckpt: True
|
| 129 |
+
save_last_snapshot: False
|
| 130 |
+
|
| 131 |
+
# multi_run:
|
| 132 |
+
# run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 133 |
+
# wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 134 |
+
|
| 135 |
+
hydra:
|
| 136 |
+
job:
|
| 137 |
+
override_dirname: ${name}
|
| 138 |
+
run:
|
| 139 |
+
dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 140 |
+
sweep:
|
| 141 |
+
dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 142 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/test_th2.yaml
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
run_name: threading_d2_test
|
| 6 |
+
name: equi_diff
|
| 7 |
+
_target_: equi_diffpo.workspace.test_equi_workspace.TestEquiWorkspace
|
| 8 |
+
ckpt_path: null
|
| 9 |
+
diversity: high
|
| 10 |
+
|
| 11 |
+
shape_meta: ${task.shape_meta}
|
| 12 |
+
exp_name: "default"
|
| 13 |
+
|
| 14 |
+
task_name: threading_d2
|
| 15 |
+
log_txt_path: data/th2_test_result.txt
|
| 16 |
+
n_demo: 100
|
| 17 |
+
horizon: 16
|
| 18 |
+
n_obs_steps: 2
|
| 19 |
+
n_action_steps: 8
|
| 20 |
+
n_latency_steps: 0
|
| 21 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 22 |
+
past_action_visible: False
|
| 23 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 24 |
+
dataset_path: data/robomimic/datasets/threading_d2_test/demo_abs.hdf5
|
| 25 |
+
|
| 26 |
+
policy:
|
| 27 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
|
| 28 |
+
|
| 29 |
+
shape_meta: ${shape_meta}
|
| 30 |
+
|
| 31 |
+
noise_scheduler:
|
| 32 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 33 |
+
num_train_timesteps: 100
|
| 34 |
+
beta_start: 0.0001
|
| 35 |
+
beta_end: 0.02
|
| 36 |
+
beta_schedule: squaredcos_cap_v2
|
| 37 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 38 |
+
clip_sample: True # required when predict_epsilon=False
|
| 39 |
+
prediction_type: epsilon # or sample
|
| 40 |
+
|
| 41 |
+
horizon: ${horizon}
|
| 42 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 43 |
+
n_obs_steps: ${n_obs_steps}
|
| 44 |
+
num_inference_steps: 100
|
| 45 |
+
crop_shape: [76, 76]
|
| 46 |
+
# crop_shape: null
|
| 47 |
+
diffusion_step_embed_dim: 128
|
| 48 |
+
enc_n_hidden: 128
|
| 49 |
+
down_dims: [512, 1024, 2048]
|
| 50 |
+
kernel_size: 5
|
| 51 |
+
n_groups: 8
|
| 52 |
+
cond_predict_scale: True
|
| 53 |
+
rot_aug: False
|
| 54 |
+
|
| 55 |
+
# scheduler.step params
|
| 56 |
+
# predict_epsilon: True
|
| 57 |
+
|
| 58 |
+
ema:
|
| 59 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 60 |
+
update_after_step: 0
|
| 61 |
+
inv_gamma: 1.0
|
| 62 |
+
power: 0.75
|
| 63 |
+
min_value: 0.0
|
| 64 |
+
max_value: 0.9999
|
| 65 |
+
|
| 66 |
+
dataloader:
|
| 67 |
+
batch_size: 128
|
| 68 |
+
num_workers: 4
|
| 69 |
+
shuffle: True
|
| 70 |
+
pin_memory: True
|
| 71 |
+
persistent_workers: True
|
| 72 |
+
drop_last: true
|
| 73 |
+
|
| 74 |
+
val_dataloader:
|
| 75 |
+
batch_size: 128
|
| 76 |
+
num_workers: 8
|
| 77 |
+
shuffle: False
|
| 78 |
+
pin_memory: True
|
| 79 |
+
persistent_workers: True
|
| 80 |
+
|
| 81 |
+
optimizer:
|
| 82 |
+
betas: [0.95, 0.999]
|
| 83 |
+
eps: 1.0e-08
|
| 84 |
+
learning_rate: 0.0001
|
| 85 |
+
weight_decay: 1.0e-06
|
| 86 |
+
|
| 87 |
+
training:
|
| 88 |
+
ckpt_path: ${ckpt_path}
|
| 89 |
+
device: "cuda:0"
|
| 90 |
+
seed: 0
|
| 91 |
+
debug: False
|
| 92 |
+
resume: True
|
| 93 |
+
# optimization
|
| 94 |
+
lr_scheduler: cosine
|
| 95 |
+
lr_warmup_steps: 500
|
| 96 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 97 |
+
gradient_accumulate_every: 1
|
| 98 |
+
# EMA destroys performance when used with BatchNorm
|
| 99 |
+
# replace BatchNorm with GroupNorm.
|
| 100 |
+
use_ema: True
|
| 101 |
+
# training loop control
|
| 102 |
+
# in epochs
|
| 103 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 104 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 105 |
+
val_every: 1
|
| 106 |
+
sample_every: 5
|
| 107 |
+
# steps per epoch
|
| 108 |
+
max_train_steps: null
|
| 109 |
+
max_val_steps: null
|
| 110 |
+
# misc
|
| 111 |
+
tqdm_interval_sec: 1.0
|
| 112 |
+
|
| 113 |
+
logging:
|
| 114 |
+
project: test_diffusion_policy_${task_name}
|
| 115 |
+
resume: True
|
| 116 |
+
mode: online
|
| 117 |
+
name: equidiff_${n_demo}_${diversity}_${policy.n_action_steps}
|
| 118 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 119 |
+
id: null
|
| 120 |
+
group: null
|
| 121 |
+
|
| 122 |
+
checkpoint:
|
| 123 |
+
topk:
|
| 124 |
+
monitor_key: test_mean_score
|
| 125 |
+
mode: max
|
| 126 |
+
k: 5
|
| 127 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 128 |
+
save_last_ckpt: True
|
| 129 |
+
save_last_snapshot: False
|
| 130 |
+
|
| 131 |
+
# multi_run:
|
| 132 |
+
# run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 133 |
+
# wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 134 |
+
|
| 135 |
+
hydra:
|
| 136 |
+
job:
|
| 137 |
+
override_dirname: ${name}
|
| 138 |
+
run:
|
| 139 |
+
dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 140 |
+
sweep:
|
| 141 |
+
dir: data/test_outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 142 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_act_abs.yaml
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
name: act
|
| 6 |
+
_target_: equi_diffpo.workspace.train_act_workspace.TrainActWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: stack_d1
|
| 12 |
+
n_demo: 200
|
| 13 |
+
horizon: 10
|
| 14 |
+
n_obs_steps: 1
|
| 15 |
+
n_action_steps: 10
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
|
| 20 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
|
| 21 |
+
|
| 22 |
+
policy:
|
| 23 |
+
_target_: equi_diffpo.policy.act_policy.ACTPolicyWrapper
|
| 24 |
+
|
| 25 |
+
shape_meta: ${shape_meta}
|
| 26 |
+
|
| 27 |
+
max_timesteps: ${task.env_runner.max_steps}
|
| 28 |
+
temporal_agg: false
|
| 29 |
+
n_envs: ${task.env_runner.n_envs}
|
| 30 |
+
horizon: ${horizon}
|
| 31 |
+
|
| 32 |
+
dataloader:
|
| 33 |
+
batch_size: 64
|
| 34 |
+
num_workers: 4
|
| 35 |
+
shuffle: True
|
| 36 |
+
pin_memory: True
|
| 37 |
+
persistent_workers: True
|
| 38 |
+
|
| 39 |
+
val_dataloader:
|
| 40 |
+
batch_size: 64
|
| 41 |
+
num_workers: 4
|
| 42 |
+
shuffle: False
|
| 43 |
+
pin_memory: True
|
| 44 |
+
persistent_workers: True
|
| 45 |
+
|
| 46 |
+
training:
|
| 47 |
+
device: "cuda:0"
|
| 48 |
+
seed: 0
|
| 49 |
+
debug: False
|
| 50 |
+
resume: True
|
| 51 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 52 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 53 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 54 |
+
val_every: 1
|
| 55 |
+
max_train_steps: null
|
| 56 |
+
max_val_steps: null
|
| 57 |
+
tqdm_interval_sec: 1.0
|
| 58 |
+
|
| 59 |
+
logging:
|
| 60 |
+
project: diffusion_policy_${task_name}
|
| 61 |
+
resume: True
|
| 62 |
+
mode: online
|
| 63 |
+
name: act_demo${n_demo}
|
| 64 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 65 |
+
id: null
|
| 66 |
+
group: null
|
| 67 |
+
|
| 68 |
+
checkpoint:
|
| 69 |
+
topk:
|
| 70 |
+
monitor_key: test_mean_score
|
| 71 |
+
mode: max
|
| 72 |
+
k: 5
|
| 73 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 74 |
+
save_last_ckpt: True
|
| 75 |
+
save_last_snapshot: False
|
| 76 |
+
|
| 77 |
+
multi_run:
|
| 78 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 79 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 80 |
+
|
| 81 |
+
hydra:
|
| 82 |
+
job:
|
| 83 |
+
override_dirname: ${name}
|
| 84 |
+
run:
|
| 85 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 86 |
+
sweep:
|
| 87 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 88 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_bc_rnn.yaml
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_rel
|
| 4 |
+
|
| 5 |
+
name: bc_rnn
|
| 6 |
+
_target_: equi_diffpo.workspace.train_robomimic_image_workspace.TrainRobomimicImageWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: stack_d1
|
| 12 |
+
n_demo: 200
|
| 13 |
+
horizon: &horizon 10
|
| 14 |
+
n_obs_steps: 1
|
| 15 |
+
n_action_steps: 1
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: *horizon
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
|
| 20 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}.hdf5
|
| 21 |
+
|
| 22 |
+
policy:
|
| 23 |
+
_target_: equi_diffpo.policy.robomimic_image_policy.RobomimicImagePolicy
|
| 24 |
+
shape_meta: ${shape_meta}
|
| 25 |
+
algo_name: bc_rnn
|
| 26 |
+
obs_type: image
|
| 27 |
+
# oc.select resolver: key, default
|
| 28 |
+
task_name: ${oc.select:task.task_name,lift}
|
| 29 |
+
dataset_type: ${oc.select:task.dataset_type,ph}
|
| 30 |
+
crop_shape: [76,76]
|
| 31 |
+
|
| 32 |
+
dataloader:
|
| 33 |
+
batch_size: 64
|
| 34 |
+
num_workers: 4
|
| 35 |
+
shuffle: True
|
| 36 |
+
pin_memory: True
|
| 37 |
+
persistent_workers: True
|
| 38 |
+
|
| 39 |
+
val_dataloader:
|
| 40 |
+
batch_size: 64
|
| 41 |
+
num_workers: 4
|
| 42 |
+
shuffle: False
|
| 43 |
+
pin_memory: True
|
| 44 |
+
persistent_workers: True
|
| 45 |
+
|
| 46 |
+
training:
|
| 47 |
+
device: "cuda:0"
|
| 48 |
+
seed: 0
|
| 49 |
+
debug: False
|
| 50 |
+
resume: True
|
| 51 |
+
# optimization
|
| 52 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 53 |
+
# training loop control
|
| 54 |
+
# in epochs
|
| 55 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 56 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 57 |
+
val_every: 1
|
| 58 |
+
sample_every: 5
|
| 59 |
+
# steps per epoch
|
| 60 |
+
max_train_steps: null
|
| 61 |
+
max_val_steps: null
|
| 62 |
+
# misc
|
| 63 |
+
tqdm_interval_sec: 1.0
|
| 64 |
+
|
| 65 |
+
logging:
|
| 66 |
+
project: diffusion_policy_${task_name}
|
| 67 |
+
resume: True
|
| 68 |
+
mode: online
|
| 69 |
+
name: bc_rnn_demo${n_demo}
|
| 70 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 71 |
+
id: null
|
| 72 |
+
group: null
|
| 73 |
+
|
| 74 |
+
checkpoint:
|
| 75 |
+
topk:
|
| 76 |
+
monitor_key: test_mean_score
|
| 77 |
+
mode: max
|
| 78 |
+
k: 5
|
| 79 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 80 |
+
save_last_ckpt: True
|
| 81 |
+
save_last_snapshot: False
|
| 82 |
+
|
| 83 |
+
multi_run:
|
| 84 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 85 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 86 |
+
|
| 87 |
+
hydra:
|
| 88 |
+
job:
|
| 89 |
+
override_dirname: ${name}
|
| 90 |
+
run:
|
| 91 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 92 |
+
sweep:
|
| 93 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 94 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_diffusion_transformer.yaml
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
name: diff_t
|
| 6 |
+
_target_: equi_diffpo.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: stack_d1
|
| 12 |
+
n_demo: 200
|
| 13 |
+
horizon: 10
|
| 14 |
+
n_obs_steps: 2
|
| 15 |
+
n_action_steps: 8
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
obs_as_cond: True
|
| 20 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
|
| 21 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
|
| 22 |
+
|
| 23 |
+
policy:
|
| 24 |
+
_target_: equi_diffpo.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
|
| 25 |
+
|
| 26 |
+
shape_meta: ${shape_meta}
|
| 27 |
+
|
| 28 |
+
noise_scheduler:
|
| 29 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 30 |
+
num_train_timesteps: 100
|
| 31 |
+
beta_start: 0.0001
|
| 32 |
+
beta_end: 0.02
|
| 33 |
+
beta_schedule: squaredcos_cap_v2
|
| 34 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 35 |
+
clip_sample: True # required when predict_epsilon=False
|
| 36 |
+
prediction_type: epsilon # or sample
|
| 37 |
+
|
| 38 |
+
horizon: ${horizon}
|
| 39 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 40 |
+
n_obs_steps: ${n_obs_steps}
|
| 41 |
+
num_inference_steps: 100
|
| 42 |
+
|
| 43 |
+
crop_shape: [76, 76]
|
| 44 |
+
obs_encoder_group_norm: True
|
| 45 |
+
eval_fixed_crop: True
|
| 46 |
+
|
| 47 |
+
n_layer: 8
|
| 48 |
+
n_cond_layers: 0 # >0: use transformer encoder for cond, otherwise use MLP
|
| 49 |
+
n_head: 4
|
| 50 |
+
n_emb: 256
|
| 51 |
+
p_drop_emb: 0.0
|
| 52 |
+
p_drop_attn: 0.3
|
| 53 |
+
causal_attn: True
|
| 54 |
+
time_as_cond: True # if false, use BERT like encoder only arch, time as input
|
| 55 |
+
obs_as_cond: ${obs_as_cond}
|
| 56 |
+
|
| 57 |
+
# scheduler.step params
|
| 58 |
+
# predict_epsilon: True
|
| 59 |
+
|
| 60 |
+
ema:
|
| 61 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 62 |
+
update_after_step: 0
|
| 63 |
+
inv_gamma: 1.0
|
| 64 |
+
power: 0.75
|
| 65 |
+
min_value: 0.0
|
| 66 |
+
max_value: 0.9999
|
| 67 |
+
|
| 68 |
+
dataloader:
|
| 69 |
+
batch_size: 64
|
| 70 |
+
num_workers: 4
|
| 71 |
+
shuffle: True
|
| 72 |
+
pin_memory: True
|
| 73 |
+
persistent_workers: True
|
| 74 |
+
|
| 75 |
+
val_dataloader:
|
| 76 |
+
batch_size: 64
|
| 77 |
+
num_workers: 4
|
| 78 |
+
shuffle: False
|
| 79 |
+
pin_memory: True
|
| 80 |
+
persistent_workers: True
|
| 81 |
+
|
| 82 |
+
optimizer:
|
| 83 |
+
transformer_weight_decay: 1.0e-3
|
| 84 |
+
obs_encoder_weight_decay: 1.0e-6
|
| 85 |
+
learning_rate: 1.0e-4
|
| 86 |
+
betas: [0.9, 0.95]
|
| 87 |
+
|
| 88 |
+
training:
|
| 89 |
+
device: "cuda:0"
|
| 90 |
+
seed: 0
|
| 91 |
+
debug: False
|
| 92 |
+
resume: True
|
| 93 |
+
# optimization
|
| 94 |
+
lr_scheduler: cosine
|
| 95 |
+
# Transformer needs LR warmup
|
| 96 |
+
lr_warmup_steps: 1000
|
| 97 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 98 |
+
gradient_accumulate_every: 1
|
| 99 |
+
# EMA destroys performance when used with BatchNorm
|
| 100 |
+
# replace BatchNorm with GroupNorm.
|
| 101 |
+
use_ema: True
|
| 102 |
+
# training loop control
|
| 103 |
+
# in epochs
|
| 104 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 105 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 106 |
+
val_every: 1
|
| 107 |
+
sample_every: 5
|
| 108 |
+
# steps per epoch
|
| 109 |
+
max_train_steps: null
|
| 110 |
+
max_val_steps: null
|
| 111 |
+
# misc
|
| 112 |
+
tqdm_interval_sec: 1.0
|
| 113 |
+
|
| 114 |
+
logging:
|
| 115 |
+
project: diffusion_policy_${task_name}
|
| 116 |
+
resume: True
|
| 117 |
+
mode: online
|
| 118 |
+
name: diff_t_demo${n_demo}
|
| 119 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 120 |
+
id: null
|
| 121 |
+
group: null
|
| 122 |
+
|
| 123 |
+
checkpoint:
|
| 124 |
+
topk:
|
| 125 |
+
monitor_key: test_mean_score
|
| 126 |
+
mode: max
|
| 127 |
+
k: 5
|
| 128 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 129 |
+
save_last_ckpt: True
|
| 130 |
+
save_last_snapshot: False
|
| 131 |
+
|
| 132 |
+
multi_run:
|
| 133 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 134 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 135 |
+
|
| 136 |
+
hydra:
|
| 137 |
+
job:
|
| 138 |
+
override_dirname: ${name}
|
| 139 |
+
run:
|
| 140 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 141 |
+
sweep:
|
| 142 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 143 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_diffusion_unet.yaml
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
name: diff_c
|
| 6 |
+
_target_: equi_diffpo.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: stack_d1
|
| 12 |
+
n_demo: 200
|
| 13 |
+
horizon: 16
|
| 14 |
+
n_obs_steps: 2
|
| 15 |
+
n_action_steps: 8
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
obs_as_global_cond: True
|
| 20 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
|
| 21 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
|
| 22 |
+
|
| 23 |
+
policy:
|
| 24 |
+
_target_: equi_diffpo.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
|
| 25 |
+
|
| 26 |
+
shape_meta: ${shape_meta}
|
| 27 |
+
|
| 28 |
+
noise_scheduler:
|
| 29 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 30 |
+
num_train_timesteps: 100
|
| 31 |
+
beta_start: 0.0001
|
| 32 |
+
beta_end: 0.02
|
| 33 |
+
beta_schedule: squaredcos_cap_v2
|
| 34 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 35 |
+
clip_sample: True # required when predict_epsilon=False
|
| 36 |
+
prediction_type: epsilon # or sample
|
| 37 |
+
|
| 38 |
+
horizon: ${horizon}
|
| 39 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 40 |
+
n_obs_steps: ${n_obs_steps}
|
| 41 |
+
num_inference_steps: 100
|
| 42 |
+
obs_as_global_cond: ${obs_as_global_cond}
|
| 43 |
+
crop_shape: [76, 76]
|
| 44 |
+
# crop_shape: null
|
| 45 |
+
diffusion_step_embed_dim: 128
|
| 46 |
+
down_dims: [512, 1024, 2048]
|
| 47 |
+
kernel_size: 5
|
| 48 |
+
n_groups: 8
|
| 49 |
+
cond_predict_scale: True
|
| 50 |
+
obs_encoder_group_norm: True
|
| 51 |
+
eval_fixed_crop: True
|
| 52 |
+
rot_aug: False
|
| 53 |
+
|
| 54 |
+
# scheduler.step params
|
| 55 |
+
# predict_epsilon: True
|
| 56 |
+
|
| 57 |
+
ema:
|
| 58 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 59 |
+
update_after_step: 0
|
| 60 |
+
inv_gamma: 1.0
|
| 61 |
+
power: 0.75
|
| 62 |
+
min_value: 0.0
|
| 63 |
+
max_value: 0.9999
|
| 64 |
+
|
| 65 |
+
dataloader:
|
| 66 |
+
batch_size: 64
|
| 67 |
+
num_workers: 4
|
| 68 |
+
shuffle: True
|
| 69 |
+
pin_memory: True
|
| 70 |
+
persistent_workers: True
|
| 71 |
+
|
| 72 |
+
val_dataloader:
|
| 73 |
+
batch_size: 64
|
| 74 |
+
num_workers: 4
|
| 75 |
+
shuffle: False
|
| 76 |
+
pin_memory: True
|
| 77 |
+
persistent_workers: True
|
| 78 |
+
|
| 79 |
+
optimizer:
|
| 80 |
+
_target_: torch.optim.AdamW
|
| 81 |
+
lr: 1.0e-4
|
| 82 |
+
betas: [0.95, 0.999]
|
| 83 |
+
eps: 1.0e-8
|
| 84 |
+
weight_decay: 1.0e-6
|
| 85 |
+
|
| 86 |
+
training:
|
| 87 |
+
device: "cuda:0"
|
| 88 |
+
seed: 0
|
| 89 |
+
debug: False
|
| 90 |
+
resume: True
|
| 91 |
+
# optimization
|
| 92 |
+
lr_scheduler: cosine
|
| 93 |
+
lr_warmup_steps: 500
|
| 94 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 95 |
+
gradient_accumulate_every: 1
|
| 96 |
+
# EMA destroys performance when used with BatchNorm
|
| 97 |
+
# replace BatchNorm with GroupNorm.
|
| 98 |
+
use_ema: True
|
| 99 |
+
# training loop control
|
| 100 |
+
# in epochs
|
| 101 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 102 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 103 |
+
val_every: 1
|
| 104 |
+
sample_every: 5
|
| 105 |
+
# steps per epoch
|
| 106 |
+
max_train_steps: null
|
| 107 |
+
max_val_steps: null
|
| 108 |
+
# misc
|
| 109 |
+
tqdm_interval_sec: 1.0
|
| 110 |
+
|
| 111 |
+
logging:
|
| 112 |
+
project: diffusion_policy_${task_name}
|
| 113 |
+
resume: True
|
| 114 |
+
mode: online
|
| 115 |
+
name: diff_c_demo${n_demo}
|
| 116 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 117 |
+
id: null
|
| 118 |
+
group: null
|
| 119 |
+
|
| 120 |
+
checkpoint:
|
| 121 |
+
topk:
|
| 122 |
+
monitor_key: test_mean_score
|
| 123 |
+
mode: max
|
| 124 |
+
k: 5
|
| 125 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 126 |
+
save_last_ckpt: True
|
| 127 |
+
save_last_snapshot: False
|
| 128 |
+
|
| 129 |
+
multi_run:
|
| 130 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 131 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 132 |
+
|
| 133 |
+
hydra:
|
| 134 |
+
job:
|
| 135 |
+
override_dirname: ${name}
|
| 136 |
+
run:
|
| 137 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 138 |
+
sweep:
|
| 139 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 140 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_diffusion_unet_voxel_abs.yaml
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_voxel_abs
|
| 4 |
+
|
| 5 |
+
name: diff_voxel
|
| 6 |
+
_target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: stack_d1
|
| 12 |
+
n_demo: 200
|
| 13 |
+
horizon: 16
|
| 14 |
+
n_obs_steps: 1
|
| 15 |
+
n_action_steps: 8
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
# dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 20 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
|
| 21 |
+
|
| 22 |
+
policy:
|
| 23 |
+
_target_: equi_diffpo.policy.diffusion_unet_voxel_policy.DiffusionUNetPolicyVoxel
|
| 24 |
+
|
| 25 |
+
shape_meta: ${shape_meta}
|
| 26 |
+
|
| 27 |
+
noise_scheduler:
|
| 28 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 29 |
+
num_train_timesteps: 100
|
| 30 |
+
beta_start: 0.0001
|
| 31 |
+
beta_end: 0.02
|
| 32 |
+
beta_schedule: squaredcos_cap_v2
|
| 33 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 34 |
+
clip_sample: True # required when predict_epsilon=False
|
| 35 |
+
prediction_type: epsilon # or sample
|
| 36 |
+
|
| 37 |
+
horizon: ${horizon}
|
| 38 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 39 |
+
n_obs_steps: ${n_obs_steps}
|
| 40 |
+
num_inference_steps: 100
|
| 41 |
+
crop_shape: [58, 58, 58]
|
| 42 |
+
# crop_shape: null
|
| 43 |
+
diffusion_step_embed_dim: 128
|
| 44 |
+
enc_n_hidden: 256
|
| 45 |
+
down_dims: [256, 512, 1024]
|
| 46 |
+
kernel_size: 5
|
| 47 |
+
n_groups: 8
|
| 48 |
+
cond_predict_scale: True
|
| 49 |
+
rot_aug: False
|
| 50 |
+
|
| 51 |
+
# scheduler.step params
|
| 52 |
+
# predict_epsilon: True
|
| 53 |
+
|
| 54 |
+
ema:
|
| 55 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 56 |
+
update_after_step: 0
|
| 57 |
+
inv_gamma: 1.0
|
| 58 |
+
power: 0.75
|
| 59 |
+
min_value: 0.0
|
| 60 |
+
max_value: 0.9999
|
| 61 |
+
|
| 62 |
+
dataloader:
|
| 63 |
+
batch_size: 64
|
| 64 |
+
num_workers: 16
|
| 65 |
+
shuffle: True
|
| 66 |
+
pin_memory: True
|
| 67 |
+
persistent_workers: True
|
| 68 |
+
drop_last: true
|
| 69 |
+
|
| 70 |
+
val_dataloader:
|
| 71 |
+
batch_size: 64
|
| 72 |
+
num_workers: 16
|
| 73 |
+
shuffle: False
|
| 74 |
+
pin_memory: True
|
| 75 |
+
persistent_workers: True
|
| 76 |
+
|
| 77 |
+
optimizer:
|
| 78 |
+
betas: [0.95, 0.999]
|
| 79 |
+
eps: 1.0e-08
|
| 80 |
+
learning_rate: 0.0001
|
| 81 |
+
weight_decay: 1.0e-06
|
| 82 |
+
|
| 83 |
+
training:
|
| 84 |
+
device: "cuda:0"
|
| 85 |
+
seed: 0
|
| 86 |
+
debug: False
|
| 87 |
+
resume: True
|
| 88 |
+
# optimization
|
| 89 |
+
lr_scheduler: cosine
|
| 90 |
+
lr_warmup_steps: 500
|
| 91 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 92 |
+
gradient_accumulate_every: 1
|
| 93 |
+
# EMA destroys performance when used with BatchNorm
|
| 94 |
+
# replace BatchNorm with GroupNorm.
|
| 95 |
+
use_ema: True
|
| 96 |
+
# training loop control
|
| 97 |
+
# in epochs
|
| 98 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 99 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 100 |
+
val_every: 1
|
| 101 |
+
sample_every: 5
|
| 102 |
+
# steps per epoch
|
| 103 |
+
max_train_steps: null
|
| 104 |
+
max_val_steps: null
|
| 105 |
+
# misc
|
| 106 |
+
tqdm_interval_sec: 1.0
|
| 107 |
+
|
| 108 |
+
logging:
|
| 109 |
+
project: equi_diff_${task_name}_voxel
|
| 110 |
+
resume: True
|
| 111 |
+
mode: online
|
| 112 |
+
name: diff_voxel_${n_demo}
|
| 113 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 114 |
+
id: null
|
| 115 |
+
group: null
|
| 116 |
+
|
| 117 |
+
checkpoint:
|
| 118 |
+
topk:
|
| 119 |
+
monitor_key: test_mean_score
|
| 120 |
+
mode: max
|
| 121 |
+
k: 5
|
| 122 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 123 |
+
save_last_ckpt: True
|
| 124 |
+
save_last_snapshot: False
|
| 125 |
+
|
| 126 |
+
multi_run:
|
| 127 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 128 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 129 |
+
|
| 130 |
+
hydra:
|
| 131 |
+
job:
|
| 132 |
+
override_dirname: ${name}
|
| 133 |
+
run:
|
| 134 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 135 |
+
sweep:
|
| 136 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 137 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs.yaml
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
name: equi_diff
|
| 6 |
+
_target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: stack_d1
|
| 12 |
+
n_demo: 1000
|
| 13 |
+
horizon: 16
|
| 14 |
+
n_obs_steps: 2
|
| 15 |
+
n_action_steps: 8
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 20 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
|
| 21 |
+
|
| 22 |
+
policy:
|
| 23 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
|
| 24 |
+
|
| 25 |
+
shape_meta: ${shape_meta}
|
| 26 |
+
|
| 27 |
+
noise_scheduler:
|
| 28 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 29 |
+
num_train_timesteps: 100
|
| 30 |
+
beta_start: 0.0001
|
| 31 |
+
beta_end: 0.02
|
| 32 |
+
beta_schedule: squaredcos_cap_v2
|
| 33 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 34 |
+
clip_sample: True # required when predict_epsilon=False
|
| 35 |
+
prediction_type: epsilon # or sample
|
| 36 |
+
|
| 37 |
+
horizon: ${horizon}
|
| 38 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 39 |
+
n_obs_steps: ${n_obs_steps}
|
| 40 |
+
num_inference_steps: 100
|
| 41 |
+
crop_shape: [76, 76]
|
| 42 |
+
# crop_shape: null
|
| 43 |
+
diffusion_step_embed_dim: 128
|
| 44 |
+
enc_n_hidden: 128
|
| 45 |
+
down_dims: [512, 1024, 2048]
|
| 46 |
+
kernel_size: 5
|
| 47 |
+
n_groups: 8
|
| 48 |
+
cond_predict_scale: True
|
| 49 |
+
rot_aug: False
|
| 50 |
+
|
| 51 |
+
# scheduler.step params
|
| 52 |
+
# predict_epsilon: True
|
| 53 |
+
|
| 54 |
+
ema:
|
| 55 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 56 |
+
update_after_step: 0
|
| 57 |
+
inv_gamma: 1.0
|
| 58 |
+
power: 0.75
|
| 59 |
+
min_value: 0.0
|
| 60 |
+
max_value: 0.9999
|
| 61 |
+
|
| 62 |
+
dataloader:
|
| 63 |
+
batch_size: 128
|
| 64 |
+
num_workers: 4
|
| 65 |
+
shuffle: True
|
| 66 |
+
pin_memory: True
|
| 67 |
+
persistent_workers: True
|
| 68 |
+
drop_last: true
|
| 69 |
+
|
| 70 |
+
val_dataloader:
|
| 71 |
+
batch_size: 128
|
| 72 |
+
num_workers: 8
|
| 73 |
+
shuffle: False
|
| 74 |
+
pin_memory: True
|
| 75 |
+
persistent_workers: True
|
| 76 |
+
|
| 77 |
+
optimizer:
|
| 78 |
+
betas: [0.95, 0.999]
|
| 79 |
+
eps: 1.0e-08
|
| 80 |
+
learning_rate: 0.0001
|
| 81 |
+
weight_decay: 1.0e-06
|
| 82 |
+
|
| 83 |
+
training:
|
| 84 |
+
device: "cuda:0"
|
| 85 |
+
seed: 0
|
| 86 |
+
debug: False
|
| 87 |
+
resume: True
|
| 88 |
+
# optimization
|
| 89 |
+
lr_scheduler: cosine
|
| 90 |
+
lr_warmup_steps: 500
|
| 91 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 92 |
+
gradient_accumulate_every: 1
|
| 93 |
+
# EMA destroys performance when used with BatchNorm
|
| 94 |
+
# replace BatchNorm with GroupNorm.
|
| 95 |
+
use_ema: True
|
| 96 |
+
# training loop control
|
| 97 |
+
# in epochs
|
| 98 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 99 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 100 |
+
val_every: 1
|
| 101 |
+
sample_every: 5
|
| 102 |
+
# steps per epoch
|
| 103 |
+
max_train_steps: null
|
| 104 |
+
max_val_steps: null
|
| 105 |
+
# misc
|
| 106 |
+
tqdm_interval_sec: 1.0
|
| 107 |
+
|
| 108 |
+
logging:
|
| 109 |
+
project: diffusion_policy_${task_name}
|
| 110 |
+
resume: True
|
| 111 |
+
mode: online
|
| 112 |
+
name: equidiff_demo${n_demo}
|
| 113 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 114 |
+
id: null
|
| 115 |
+
group: null
|
| 116 |
+
|
| 117 |
+
checkpoint:
|
| 118 |
+
topk:
|
| 119 |
+
monitor_key: test_mean_score
|
| 120 |
+
mode: max
|
| 121 |
+
k: 5
|
| 122 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 123 |
+
save_last_ckpt: True
|
| 124 |
+
save_last_snapshot: False
|
| 125 |
+
|
| 126 |
+
multi_run:
|
| 127 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 128 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 129 |
+
|
| 130 |
+
hydra:
|
| 131 |
+
job:
|
| 132 |
+
override_dirname: ${name}
|
| 133 |
+
run:
|
| 134 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 135 |
+
sweep:
|
| 136 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 137 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs_sq2_0-1.yaml
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
name: equi_diff
|
| 6 |
+
_target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: square_d2
|
| 12 |
+
n_demo: 1000
|
| 13 |
+
horizon: 16
|
| 14 |
+
n_obs_steps: 2
|
| 15 |
+
n_action_steps: 8
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 20 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
|
| 21 |
+
|
| 22 |
+
policy:
|
| 23 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
|
| 24 |
+
|
| 25 |
+
shape_meta: ${shape_meta}
|
| 26 |
+
|
| 27 |
+
noise_scheduler:
|
| 28 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 29 |
+
num_train_timesteps: 100
|
| 30 |
+
beta_start: 0.0001
|
| 31 |
+
beta_end: 0.02
|
| 32 |
+
beta_schedule: squaredcos_cap_v2
|
| 33 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 34 |
+
clip_sample: True # required when predict_epsilon=False
|
| 35 |
+
prediction_type: epsilon # or sample
|
| 36 |
+
|
| 37 |
+
horizon: ${horizon}
|
| 38 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 39 |
+
n_obs_steps: ${n_obs_steps}
|
| 40 |
+
num_inference_steps: 100
|
| 41 |
+
crop_shape: [76, 76]
|
| 42 |
+
# crop_shape: null
|
| 43 |
+
diffusion_step_embed_dim: 128
|
| 44 |
+
enc_n_hidden: 128
|
| 45 |
+
down_dims: [512, 1024, 2048]
|
| 46 |
+
kernel_size: 5
|
| 47 |
+
n_groups: 8
|
| 48 |
+
cond_predict_scale: True
|
| 49 |
+
rot_aug: False
|
| 50 |
+
|
| 51 |
+
# scheduler.step params
|
| 52 |
+
# predict_epsilon: True
|
| 53 |
+
|
| 54 |
+
ema:
|
| 55 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 56 |
+
update_after_step: 0
|
| 57 |
+
inv_gamma: 1.0
|
| 58 |
+
power: 0.75
|
| 59 |
+
min_value: 0.0
|
| 60 |
+
max_value: 0.9999
|
| 61 |
+
|
| 62 |
+
dataloader:
|
| 63 |
+
batch_size: 128
|
| 64 |
+
num_workers: 4
|
| 65 |
+
shuffle: True
|
| 66 |
+
pin_memory: True
|
| 67 |
+
persistent_workers: True
|
| 68 |
+
drop_last: true
|
| 69 |
+
|
| 70 |
+
val_dataloader:
|
| 71 |
+
batch_size: 128
|
| 72 |
+
num_workers: 8
|
| 73 |
+
shuffle: False
|
| 74 |
+
pin_memory: True
|
| 75 |
+
persistent_workers: True
|
| 76 |
+
|
| 77 |
+
optimizer:
|
| 78 |
+
betas: [0.95, 0.999]
|
| 79 |
+
eps: 1.0e-08
|
| 80 |
+
learning_rate: 0.0001
|
| 81 |
+
weight_decay: 1.0e-06
|
| 82 |
+
|
| 83 |
+
training:
|
| 84 |
+
device: "cuda:0"
|
| 85 |
+
seed: 0
|
| 86 |
+
debug: False
|
| 87 |
+
resume: True
|
| 88 |
+
# optimization
|
| 89 |
+
lr_scheduler: cosine
|
| 90 |
+
lr_warmup_steps: 500
|
| 91 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 92 |
+
gradient_accumulate_every: 1
|
| 93 |
+
# EMA destroys performance when used with BatchNorm
|
| 94 |
+
# replace BatchNorm with GroupNorm.
|
| 95 |
+
use_ema: True
|
| 96 |
+
# training loop control
|
| 97 |
+
# in epochs
|
| 98 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 99 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 100 |
+
val_every: 1
|
| 101 |
+
sample_every: 5
|
| 102 |
+
# steps per epoch
|
| 103 |
+
max_train_steps: null
|
| 104 |
+
max_val_steps: null
|
| 105 |
+
# misc
|
| 106 |
+
tqdm_interval_sec: 1.0
|
| 107 |
+
|
| 108 |
+
logging:
|
| 109 |
+
project: diffusion_policy_${task_name}
|
| 110 |
+
resume: True
|
| 111 |
+
mode: online
|
| 112 |
+
name: equidiff_demo${n_demo}
|
| 113 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 114 |
+
id: null
|
| 115 |
+
group: null
|
| 116 |
+
|
| 117 |
+
checkpoint:
|
| 118 |
+
topk:
|
| 119 |
+
monitor_key: test_mean_score
|
| 120 |
+
mode: max
|
| 121 |
+
k: 5
|
| 122 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 123 |
+
save_last_ckpt: True
|
| 124 |
+
save_last_snapshot: False
|
| 125 |
+
|
| 126 |
+
multi_run:
|
| 127 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 128 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 129 |
+
|
| 130 |
+
hydra:
|
| 131 |
+
job:
|
| 132 |
+
override_dirname: ${name}
|
| 133 |
+
run:
|
| 134 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 135 |
+
sweep:
|
| 136 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 137 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_equi_diffusion_unet_abs_sq2_1-1.yaml
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
name: equi_diff
|
| 6 |
+
_target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: square_d2
|
| 12 |
+
n_demo: 1000
|
| 13 |
+
horizon: 16
|
| 14 |
+
n_obs_steps: 2
|
| 15 |
+
n_action_steps: 8
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 20 |
+
dataset_path: data/robomimic/datasets/${task_name}_1-1/${task_name}_1-1_abs.hdf5
|
| 21 |
+
|
| 22 |
+
policy:
|
| 23 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
|
| 24 |
+
|
| 25 |
+
shape_meta: ${shape_meta}
|
| 26 |
+
|
| 27 |
+
noise_scheduler:
|
| 28 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 29 |
+
num_train_timesteps: 100
|
| 30 |
+
beta_start: 0.0001
|
| 31 |
+
beta_end: 0.02
|
| 32 |
+
beta_schedule: squaredcos_cap_v2
|
| 33 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 34 |
+
clip_sample: True # required when predict_epsilon=False
|
| 35 |
+
prediction_type: epsilon # or sample
|
| 36 |
+
|
| 37 |
+
horizon: ${horizon}
|
| 38 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 39 |
+
n_obs_steps: ${n_obs_steps}
|
| 40 |
+
num_inference_steps: 100
|
| 41 |
+
crop_shape: [76, 76]
|
| 42 |
+
# crop_shape: null
|
| 43 |
+
diffusion_step_embed_dim: 128
|
| 44 |
+
enc_n_hidden: 128
|
| 45 |
+
down_dims: [512, 1024, 2048]
|
| 46 |
+
kernel_size: 5
|
| 47 |
+
n_groups: 8
|
| 48 |
+
cond_predict_scale: True
|
| 49 |
+
rot_aug: False
|
| 50 |
+
|
| 51 |
+
# scheduler.step params
|
| 52 |
+
# predict_epsilon: True
|
| 53 |
+
|
| 54 |
+
ema:
|
| 55 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 56 |
+
update_after_step: 0
|
| 57 |
+
inv_gamma: 1.0
|
| 58 |
+
power: 0.75
|
| 59 |
+
min_value: 0.0
|
| 60 |
+
max_value: 0.9999
|
| 61 |
+
|
| 62 |
+
dataloader:
|
| 63 |
+
batch_size: 128
|
| 64 |
+
num_workers: 4
|
| 65 |
+
shuffle: True
|
| 66 |
+
pin_memory: True
|
| 67 |
+
persistent_workers: True
|
| 68 |
+
drop_last: true
|
| 69 |
+
|
| 70 |
+
val_dataloader:
|
| 71 |
+
batch_size: 128
|
| 72 |
+
num_workers: 8
|
| 73 |
+
shuffle: False
|
| 74 |
+
pin_memory: True
|
| 75 |
+
persistent_workers: True
|
| 76 |
+
|
| 77 |
+
optimizer:
|
| 78 |
+
betas: [0.95, 0.999]
|
| 79 |
+
eps: 1.0e-08
|
| 80 |
+
learning_rate: 0.0001
|
| 81 |
+
weight_decay: 1.0e-06
|
| 82 |
+
|
| 83 |
+
training:
|
| 84 |
+
device: "cuda:0"
|
| 85 |
+
seed: 0
|
| 86 |
+
debug: False
|
| 87 |
+
resume: True
|
| 88 |
+
# optimization
|
| 89 |
+
lr_scheduler: cosine
|
| 90 |
+
lr_warmup_steps: 500
|
| 91 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 92 |
+
gradient_accumulate_every: 1
|
| 93 |
+
# EMA destroys performance when used with BatchNorm
|
| 94 |
+
# replace BatchNorm with GroupNorm.
|
| 95 |
+
use_ema: True
|
| 96 |
+
# training loop control
|
| 97 |
+
# in epochs
|
| 98 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 99 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 100 |
+
val_every: 1
|
| 101 |
+
sample_every: 5
|
| 102 |
+
# steps per epoch
|
| 103 |
+
max_train_steps: null
|
| 104 |
+
max_val_steps: null
|
| 105 |
+
# misc
|
| 106 |
+
tqdm_interval_sec: 1.0
|
| 107 |
+
|
| 108 |
+
logging:
|
| 109 |
+
project: diffusion_policy_${task_name}
|
| 110 |
+
resume: True
|
| 111 |
+
mode: online
|
| 112 |
+
name: equidiff_demo${n_demo}
|
| 113 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 114 |
+
id: null
|
| 115 |
+
group: null
|
| 116 |
+
|
| 117 |
+
checkpoint:
|
| 118 |
+
topk:
|
| 119 |
+
monitor_key: test_mean_score
|
| 120 |
+
mode: max
|
| 121 |
+
k: 5
|
| 122 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 123 |
+
save_last_ckpt: True
|
| 124 |
+
save_last_snapshot: False
|
| 125 |
+
|
| 126 |
+
multi_run:
|
| 127 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 128 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 129 |
+
|
| 130 |
+
hydra:
|
| 131 |
+
job:
|
| 132 |
+
override_dirname: ${name}
|
| 133 |
+
run:
|
| 134 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 135 |
+
sweep:
|
| 136 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 137 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_equi_diffusion_unet_rel.yaml
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_rel
|
| 4 |
+
|
| 5 |
+
name: equi_diff
|
| 6 |
+
_target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: stack_d1
|
| 12 |
+
n_demo: 200
|
| 13 |
+
horizon: 16
|
| 14 |
+
n_obs_steps: 2
|
| 15 |
+
n_action_steps: 8
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 20 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}.hdf5
|
| 21 |
+
|
| 22 |
+
policy:
|
| 23 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_rel_policy.DiffusionEquiUNetCNNEncRelPolicy
|
| 24 |
+
|
| 25 |
+
shape_meta: ${shape_meta}
|
| 26 |
+
|
| 27 |
+
noise_scheduler:
|
| 28 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 29 |
+
num_train_timesteps: 100
|
| 30 |
+
beta_start: 0.0001
|
| 31 |
+
beta_end: 0.02
|
| 32 |
+
beta_schedule: squaredcos_cap_v2
|
| 33 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 34 |
+
clip_sample: True # required when predict_epsilon=False
|
| 35 |
+
prediction_type: epsilon # or sample
|
| 36 |
+
|
| 37 |
+
horizon: ${horizon}
|
| 38 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 39 |
+
n_obs_steps: ${n_obs_steps}
|
| 40 |
+
num_inference_steps: 100
|
| 41 |
+
crop_shape: [76, 76]
|
| 42 |
+
# crop_shape: null
|
| 43 |
+
diffusion_step_embed_dim: 128
|
| 44 |
+
enc_n_hidden: 128
|
| 45 |
+
down_dims: [512, 1024, 2048]
|
| 46 |
+
kernel_size: 5
|
| 47 |
+
n_groups: 8
|
| 48 |
+
cond_predict_scale: True
|
| 49 |
+
|
| 50 |
+
# scheduler.step params
|
| 51 |
+
# predict_epsilon: True
|
| 52 |
+
|
| 53 |
+
ema:
|
| 54 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 55 |
+
update_after_step: 0
|
| 56 |
+
inv_gamma: 1.0
|
| 57 |
+
power: 0.75
|
| 58 |
+
min_value: 0.0
|
| 59 |
+
max_value: 0.9999
|
| 60 |
+
|
| 61 |
+
dataloader:
|
| 62 |
+
batch_size: 128
|
| 63 |
+
num_workers: 4
|
| 64 |
+
shuffle: True
|
| 65 |
+
pin_memory: True
|
| 66 |
+
persistent_workers: True
|
| 67 |
+
drop_last: true
|
| 68 |
+
|
| 69 |
+
val_dataloader:
|
| 70 |
+
batch_size: 128
|
| 71 |
+
num_workers: 8
|
| 72 |
+
shuffle: False
|
| 73 |
+
pin_memory: True
|
| 74 |
+
persistent_workers: True
|
| 75 |
+
|
| 76 |
+
optimizer:
|
| 77 |
+
betas: [0.95, 0.999]
|
| 78 |
+
eps: 1.0e-08
|
| 79 |
+
learning_rate: 0.0001
|
| 80 |
+
weight_decay: 1.0e-06
|
| 81 |
+
|
| 82 |
+
training:
|
| 83 |
+
device: "cuda:0"
|
| 84 |
+
seed: 0
|
| 85 |
+
debug: False
|
| 86 |
+
resume: True
|
| 87 |
+
# optimization
|
| 88 |
+
lr_scheduler: cosine
|
| 89 |
+
lr_warmup_steps: 500
|
| 90 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 91 |
+
gradient_accumulate_every: 1
|
| 92 |
+
# EMA destroys performance when used with BatchNorm
|
| 93 |
+
# replace BatchNorm with GroupNorm.
|
| 94 |
+
use_ema: True
|
| 95 |
+
# training loop control
|
| 96 |
+
# in epochs
|
| 97 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 98 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 99 |
+
val_every: 1
|
| 100 |
+
sample_every: 5
|
| 101 |
+
# steps per epoch
|
| 102 |
+
max_train_steps: null
|
| 103 |
+
max_val_steps: null
|
| 104 |
+
# misc
|
| 105 |
+
tqdm_interval_sec: 1.0
|
| 106 |
+
|
| 107 |
+
logging:
|
| 108 |
+
project: diffusion_policy_${task_name}_vel
|
| 109 |
+
resume: True
|
| 110 |
+
mode: online
|
| 111 |
+
name: equidiff_demo${n_demo}
|
| 112 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 113 |
+
id: null
|
| 114 |
+
group: null
|
| 115 |
+
|
| 116 |
+
checkpoint:
|
| 117 |
+
topk:
|
| 118 |
+
monitor_key: test_mean_score
|
| 119 |
+
mode: max
|
| 120 |
+
k: 5
|
| 121 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 122 |
+
save_last_ckpt: True
|
| 123 |
+
save_last_snapshot: False
|
| 124 |
+
|
| 125 |
+
multi_run:
|
| 126 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 127 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 128 |
+
|
| 129 |
+
hydra:
|
| 130 |
+
job:
|
| 131 |
+
override_dirname: ${name}
|
| 132 |
+
run:
|
| 133 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 134 |
+
sweep:
|
| 135 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 136 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_equi_diffusion_unet_voxel_abs.yaml
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_voxel_abs
|
| 4 |
+
|
| 5 |
+
name: equi_diff_voxel
|
| 6 |
+
_target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: stack_d1
|
| 12 |
+
n_demo: 200
|
| 13 |
+
horizon: 16
|
| 14 |
+
n_obs_steps: 1
|
| 15 |
+
n_action_steps: 8
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
# dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 20 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
|
| 21 |
+
|
| 22 |
+
policy:
|
| 23 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_voxel_policy.DiffusionEquiUNetPolicyVoxel
|
| 24 |
+
|
| 25 |
+
shape_meta: ${shape_meta}
|
| 26 |
+
|
| 27 |
+
noise_scheduler:
|
| 28 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 29 |
+
num_train_timesteps: 100
|
| 30 |
+
beta_start: 0.0001
|
| 31 |
+
beta_end: 0.02
|
| 32 |
+
beta_schedule: squaredcos_cap_v2
|
| 33 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 34 |
+
clip_sample: True # required when predict_epsilon=False
|
| 35 |
+
prediction_type: epsilon # or sample
|
| 36 |
+
|
| 37 |
+
horizon: ${horizon}
|
| 38 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 39 |
+
n_obs_steps: ${n_obs_steps}
|
| 40 |
+
num_inference_steps: 100
|
| 41 |
+
crop_shape: [58, 58, 58]
|
| 42 |
+
# crop_shape: null
|
| 43 |
+
diffusion_step_embed_dim: 128
|
| 44 |
+
enc_n_hidden: 128
|
| 45 |
+
down_dims: [256, 512, 1024]
|
| 46 |
+
kernel_size: 5
|
| 47 |
+
n_groups: 8
|
| 48 |
+
cond_predict_scale: True
|
| 49 |
+
rot_aug: True
|
| 50 |
+
|
| 51 |
+
# scheduler.step params
|
| 52 |
+
# predict_epsilon: True
|
| 53 |
+
|
| 54 |
+
ema:
|
| 55 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 56 |
+
update_after_step: 0
|
| 57 |
+
inv_gamma: 1.0
|
| 58 |
+
power: 0.75
|
| 59 |
+
min_value: 0.0
|
| 60 |
+
max_value: 0.9999
|
| 61 |
+
|
| 62 |
+
dataloader:
|
| 63 |
+
batch_size: 64
|
| 64 |
+
num_workers: 16
|
| 65 |
+
shuffle: True
|
| 66 |
+
pin_memory: True
|
| 67 |
+
persistent_workers: True
|
| 68 |
+
drop_last: true
|
| 69 |
+
|
| 70 |
+
val_dataloader:
|
| 71 |
+
batch_size: 64
|
| 72 |
+
num_workers: 16
|
| 73 |
+
shuffle: False
|
| 74 |
+
pin_memory: True
|
| 75 |
+
persistent_workers: True
|
| 76 |
+
|
| 77 |
+
optimizer:
|
| 78 |
+
betas: [0.95, 0.999]
|
| 79 |
+
eps: 1.0e-08
|
| 80 |
+
learning_rate: 0.0001
|
| 81 |
+
weight_decay: 1.0e-06
|
| 82 |
+
|
| 83 |
+
training:
|
| 84 |
+
device: "cuda:0"
|
| 85 |
+
seed: 0
|
| 86 |
+
debug: False
|
| 87 |
+
resume: True
|
| 88 |
+
# optimization
|
| 89 |
+
lr_scheduler: cosine
|
| 90 |
+
lr_warmup_steps: 500
|
| 91 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 92 |
+
gradient_accumulate_every: 1
|
| 93 |
+
# EMA destroys performance when used with BatchNorm
|
| 94 |
+
# replace BatchNorm with GroupNorm.
|
| 95 |
+
use_ema: True
|
| 96 |
+
# training loop control
|
| 97 |
+
# in epochs
|
| 98 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 99 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 100 |
+
val_every: 1
|
| 101 |
+
sample_every: 5
|
| 102 |
+
# steps per epoch
|
| 103 |
+
max_train_steps: null
|
| 104 |
+
max_val_steps: null
|
| 105 |
+
# misc
|
| 106 |
+
tqdm_interval_sec: 1.0
|
| 107 |
+
|
| 108 |
+
logging:
|
| 109 |
+
project: equi_diff_${task_name}_voxel
|
| 110 |
+
resume: True
|
| 111 |
+
mode: online
|
| 112 |
+
name: equi_diff_voxel_${n_demo}
|
| 113 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 114 |
+
id: null
|
| 115 |
+
group: null
|
| 116 |
+
|
| 117 |
+
checkpoint:
|
| 118 |
+
topk:
|
| 119 |
+
monitor_key: test_mean_score
|
| 120 |
+
mode: max
|
| 121 |
+
k: 5
|
| 122 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 123 |
+
save_last_ckpt: True
|
| 124 |
+
save_last_snapshot: False
|
| 125 |
+
|
| 126 |
+
multi_run:
|
| 127 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 128 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 129 |
+
|
| 130 |
+
hydra:
|
| 131 |
+
job:
|
| 132 |
+
override_dirname: ${name}
|
| 133 |
+
run:
|
| 134 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 135 |
+
sweep:
|
| 136 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 137 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_equi_diffusion_unet_voxel_rel.yaml
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_voxel_rel
|
| 4 |
+
|
| 5 |
+
name: equi_diff_voxel
|
| 6 |
+
_target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: stack_d1
|
| 12 |
+
n_demo: 200
|
| 13 |
+
horizon: 16
|
| 14 |
+
n_obs_steps: 1
|
| 15 |
+
n_action_steps: 8
|
| 16 |
+
n_latency_steps: 0
|
| 17 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 18 |
+
past_action_visible: False
|
| 19 |
+
# dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 20 |
+
dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel.hdf5
|
| 21 |
+
|
| 22 |
+
policy:
|
| 23 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_voxel_rel_policy.DiffusionEquiUNetRelPolicyVoxel
|
| 24 |
+
|
| 25 |
+
shape_meta: ${shape_meta}
|
| 26 |
+
|
| 27 |
+
noise_scheduler:
|
| 28 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 29 |
+
num_train_timesteps: 100
|
| 30 |
+
beta_start: 0.0001
|
| 31 |
+
beta_end: 0.02
|
| 32 |
+
beta_schedule: squaredcos_cap_v2
|
| 33 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 34 |
+
clip_sample: True # required when predict_epsilon=False
|
| 35 |
+
prediction_type: epsilon # or sample
|
| 36 |
+
|
| 37 |
+
horizon: ${horizon}
|
| 38 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 39 |
+
n_obs_steps: ${n_obs_steps}
|
| 40 |
+
num_inference_steps: 100
|
| 41 |
+
crop_shape: [58, 58, 58]
|
| 42 |
+
# crop_shape: null
|
| 43 |
+
diffusion_step_embed_dim: 128
|
| 44 |
+
enc_n_hidden: 128
|
| 45 |
+
down_dims: [256, 512, 1024]
|
| 46 |
+
kernel_size: 5
|
| 47 |
+
n_groups: 8
|
| 48 |
+
cond_predict_scale: True
|
| 49 |
+
rot_aug: True
|
| 50 |
+
|
| 51 |
+
# scheduler.step params
|
| 52 |
+
# predict_epsilon: True
|
| 53 |
+
|
| 54 |
+
ema:
|
| 55 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 56 |
+
update_after_step: 0
|
| 57 |
+
inv_gamma: 1.0
|
| 58 |
+
power: 0.75
|
| 59 |
+
min_value: 0.0
|
| 60 |
+
max_value: 0.9999
|
| 61 |
+
|
| 62 |
+
dataloader:
|
| 63 |
+
batch_size: 64
|
| 64 |
+
num_workers: 16
|
| 65 |
+
shuffle: True
|
| 66 |
+
pin_memory: True
|
| 67 |
+
persistent_workers: True
|
| 68 |
+
drop_last: true
|
| 69 |
+
|
| 70 |
+
val_dataloader:
|
| 71 |
+
batch_size: 64
|
| 72 |
+
num_workers: 16
|
| 73 |
+
shuffle: False
|
| 74 |
+
pin_memory: True
|
| 75 |
+
persistent_workers: True
|
| 76 |
+
|
| 77 |
+
optimizer:
|
| 78 |
+
betas: [0.95, 0.999]
|
| 79 |
+
eps: 1.0e-08
|
| 80 |
+
learning_rate: 0.0001
|
| 81 |
+
weight_decay: 1.0e-06
|
| 82 |
+
|
| 83 |
+
training:
|
| 84 |
+
device: "cuda:0"
|
| 85 |
+
seed: 0
|
| 86 |
+
debug: False
|
| 87 |
+
resume: True
|
| 88 |
+
# optimization
|
| 89 |
+
lr_scheduler: cosine
|
| 90 |
+
lr_warmup_steps: 500
|
| 91 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 92 |
+
gradient_accumulate_every: 1
|
| 93 |
+
# EMA destroys performance when used with BatchNorm
|
| 94 |
+
# replace BatchNorm with GroupNorm.
|
| 95 |
+
use_ema: True
|
| 96 |
+
# training loop control
|
| 97 |
+
# in epochs
|
| 98 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 99 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 100 |
+
val_every: 1
|
| 101 |
+
sample_every: 5
|
| 102 |
+
# steps per epoch
|
| 103 |
+
max_train_steps: null
|
| 104 |
+
max_val_steps: null
|
| 105 |
+
# misc
|
| 106 |
+
tqdm_interval_sec: 1.0
|
| 107 |
+
|
| 108 |
+
logging:
|
| 109 |
+
project: equi_diff_${task_name}_voxel_rel
|
| 110 |
+
resume: True
|
| 111 |
+
mode: online
|
| 112 |
+
name: equi_diff_voxel_${n_demo}
|
| 113 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 114 |
+
id: null
|
| 115 |
+
group: null
|
| 116 |
+
|
| 117 |
+
checkpoint:
|
| 118 |
+
topk:
|
| 119 |
+
monitor_key: test_mean_score
|
| 120 |
+
mode: max
|
| 121 |
+
k: 5
|
| 122 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 123 |
+
save_last_ckpt: True
|
| 124 |
+
save_last_snapshot: False
|
| 125 |
+
|
| 126 |
+
multi_run:
|
| 127 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 128 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 129 |
+
|
| 130 |
+
hydra:
|
| 131 |
+
job:
|
| 132 |
+
override_dirname: ${name}
|
| 133 |
+
run:
|
| 134 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 135 |
+
sweep:
|
| 136 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 137 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_sq2.yaml
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
name: equi_diff
|
| 6 |
+
_target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: square_d2
|
| 12 |
+
folder_name: square_d2
|
| 13 |
+
file_name: square_d2
|
| 14 |
+
n_demo: 1000
|
| 15 |
+
horizon: 16
|
| 16 |
+
n_obs_steps: 2
|
| 17 |
+
n_action_steps: 8
|
| 18 |
+
n_latency_steps: 0
|
| 19 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 20 |
+
past_action_visible: False
|
| 21 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 22 |
+
dataset_path: data/robomimic/datasets/${folder_name}/${file_name}_abs.hdf5
|
| 23 |
+
|
| 24 |
+
policy:
|
| 25 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
|
| 26 |
+
|
| 27 |
+
shape_meta: ${shape_meta}
|
| 28 |
+
|
| 29 |
+
noise_scheduler:
|
| 30 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 31 |
+
num_train_timesteps: 100
|
| 32 |
+
beta_start: 0.0001
|
| 33 |
+
beta_end: 0.02
|
| 34 |
+
beta_schedule: squaredcos_cap_v2
|
| 35 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 36 |
+
clip_sample: True # required when predict_epsilon=False
|
| 37 |
+
prediction_type: epsilon # or sample
|
| 38 |
+
|
| 39 |
+
horizon: ${horizon}
|
| 40 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 41 |
+
n_obs_steps: ${n_obs_steps}
|
| 42 |
+
num_inference_steps: 100
|
| 43 |
+
crop_shape: [76, 76]
|
| 44 |
+
# crop_shape: null
|
| 45 |
+
diffusion_step_embed_dim: 128
|
| 46 |
+
enc_n_hidden: 128
|
| 47 |
+
down_dims: [512, 1024, 2048]
|
| 48 |
+
kernel_size: 5
|
| 49 |
+
n_groups: 8
|
| 50 |
+
cond_predict_scale: True
|
| 51 |
+
rot_aug: False
|
| 52 |
+
|
| 53 |
+
# scheduler.step params
|
| 54 |
+
# predict_epsilon: True
|
| 55 |
+
|
| 56 |
+
ema:
|
| 57 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 58 |
+
update_after_step: 0
|
| 59 |
+
inv_gamma: 1.0
|
| 60 |
+
power: 0.75
|
| 61 |
+
min_value: 0.0
|
| 62 |
+
max_value: 0.9999
|
| 63 |
+
|
| 64 |
+
dataloader:
|
| 65 |
+
batch_size: 128
|
| 66 |
+
num_workers: 4
|
| 67 |
+
shuffle: True
|
| 68 |
+
pin_memory: True
|
| 69 |
+
persistent_workers: True
|
| 70 |
+
drop_last: true
|
| 71 |
+
|
| 72 |
+
val_dataloader:
|
| 73 |
+
batch_size: 128
|
| 74 |
+
num_workers: 8
|
| 75 |
+
shuffle: False
|
| 76 |
+
pin_memory: True
|
| 77 |
+
persistent_workers: True
|
| 78 |
+
|
| 79 |
+
optimizer:
|
| 80 |
+
betas: [0.95, 0.999]
|
| 81 |
+
eps: 1.0e-08
|
| 82 |
+
learning_rate: 0.0001
|
| 83 |
+
weight_decay: 1.0e-06
|
| 84 |
+
|
| 85 |
+
training:
|
| 86 |
+
device: "cuda:0"
|
| 87 |
+
seed: 0
|
| 88 |
+
debug: False
|
| 89 |
+
resume: True
|
| 90 |
+
# optimization
|
| 91 |
+
lr_scheduler: cosine
|
| 92 |
+
lr_warmup_steps: 500
|
| 93 |
+
num_epochs: ${eval:'50000 / ${n_demo}'}
|
| 94 |
+
gradient_accumulate_every: 1
|
| 95 |
+
# EMA destroys performance when used with BatchNorm
|
| 96 |
+
# replace BatchNorm with GroupNorm.
|
| 97 |
+
use_ema: True
|
| 98 |
+
# training loop control
|
| 99 |
+
# in epochs
|
| 100 |
+
rollout_every: ${eval:'1000 / ${n_demo}'}
|
| 101 |
+
checkpoint_every: ${eval:'1000 / ${n_demo}'}
|
| 102 |
+
val_every: 1
|
| 103 |
+
sample_every: 5
|
| 104 |
+
# steps per epoch
|
| 105 |
+
max_train_steps: null
|
| 106 |
+
max_val_steps: null
|
| 107 |
+
# misc
|
| 108 |
+
tqdm_interval_sec: 1.0
|
| 109 |
+
|
| 110 |
+
logging:
|
| 111 |
+
project: diffusion_policy_${task_name}
|
| 112 |
+
resume: True
|
| 113 |
+
mode: online
|
| 114 |
+
name: equidiff_demo${n_demo}
|
| 115 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 116 |
+
id: null
|
| 117 |
+
group: null
|
| 118 |
+
|
| 119 |
+
checkpoint:
|
| 120 |
+
topk:
|
| 121 |
+
monitor_key: test_mean_score
|
| 122 |
+
mode: max
|
| 123 |
+
k: 5
|
| 124 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 125 |
+
save_last_ckpt: True
|
| 126 |
+
save_last_snapshot: False
|
| 127 |
+
|
| 128 |
+
multi_run:
|
| 129 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 130 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 131 |
+
|
| 132 |
+
hydra:
|
| 133 |
+
job:
|
| 134 |
+
override_dirname: ${name}
|
| 135 |
+
run:
|
| 136 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 137 |
+
sweep:
|
| 138 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 139 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_sq2_5000.yaml
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
name: equi_diff
|
| 6 |
+
_target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: square_d2
|
| 12 |
+
folder_name: square_d2
|
| 13 |
+
file_name: square_d2
|
| 14 |
+
n_demo: 5000
|
| 15 |
+
horizon: 16
|
| 16 |
+
n_obs_steps: 2
|
| 17 |
+
n_action_steps: 8
|
| 18 |
+
n_latency_steps: 0
|
| 19 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 20 |
+
past_action_visible: False
|
| 21 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 22 |
+
dataset_path: data/robomimic/datasets/${folder_name}/${file_name}_abs.hdf5
|
| 23 |
+
|
| 24 |
+
policy:
|
| 25 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
|
| 26 |
+
|
| 27 |
+
shape_meta: ${shape_meta}
|
| 28 |
+
|
| 29 |
+
noise_scheduler:
|
| 30 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 31 |
+
num_train_timesteps: 100
|
| 32 |
+
beta_start: 0.0001
|
| 33 |
+
beta_end: 0.02
|
| 34 |
+
beta_schedule: squaredcos_cap_v2
|
| 35 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 36 |
+
clip_sample: True # required when predict_epsilon=False
|
| 37 |
+
prediction_type: epsilon # or sample
|
| 38 |
+
|
| 39 |
+
horizon: ${horizon}
|
| 40 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 41 |
+
n_obs_steps: ${n_obs_steps}
|
| 42 |
+
num_inference_steps: 100
|
| 43 |
+
crop_shape: [76, 76]
|
| 44 |
+
# crop_shape: null
|
| 45 |
+
diffusion_step_embed_dim: 128
|
| 46 |
+
enc_n_hidden: 128
|
| 47 |
+
down_dims: [512, 1024, 2048]
|
| 48 |
+
kernel_size: 5
|
| 49 |
+
n_groups: 8
|
| 50 |
+
cond_predict_scale: True
|
| 51 |
+
rot_aug: False
|
| 52 |
+
|
| 53 |
+
# scheduler.step params
|
| 54 |
+
# predict_epsilon: True
|
| 55 |
+
|
| 56 |
+
ema:
|
| 57 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 58 |
+
update_after_step: 0
|
| 59 |
+
inv_gamma: 1.0
|
| 60 |
+
power: 0.75
|
| 61 |
+
min_value: 0.0
|
| 62 |
+
max_value: 0.9999
|
| 63 |
+
|
| 64 |
+
dataloader:
|
| 65 |
+
batch_size: 128
|
| 66 |
+
num_workers: 4
|
| 67 |
+
shuffle: True
|
| 68 |
+
pin_memory: True
|
| 69 |
+
persistent_workers: True
|
| 70 |
+
drop_last: true
|
| 71 |
+
|
| 72 |
+
val_dataloader:
|
| 73 |
+
batch_size: 128
|
| 74 |
+
num_workers: 8
|
| 75 |
+
shuffle: False
|
| 76 |
+
pin_memory: True
|
| 77 |
+
persistent_workers: True
|
| 78 |
+
|
| 79 |
+
optimizer:
|
| 80 |
+
betas: [0.95, 0.999]
|
| 81 |
+
eps: 1.0e-08
|
| 82 |
+
learning_rate: 0.0001
|
| 83 |
+
weight_decay: 1.0e-06
|
| 84 |
+
|
| 85 |
+
training:
|
| 86 |
+
device: "cuda:0"
|
| 87 |
+
seed: 0
|
| 88 |
+
debug: False
|
| 89 |
+
resume: True
|
| 90 |
+
# optimization
|
| 91 |
+
lr_scheduler: cosine
|
| 92 |
+
lr_warmup_steps: 2500
|
| 93 |
+
num_epochs: ${eval:'100000 / ${n_demo}'}
|
| 94 |
+
gradient_accumulate_every: 1
|
| 95 |
+
# EMA destroys performance when used with BatchNorm
|
| 96 |
+
# replace BatchNorm with GroupNorm.
|
| 97 |
+
use_ema: True
|
| 98 |
+
# training loop control
|
| 99 |
+
# in epochs
|
| 100 |
+
rollout_every: ${eval:'5000 / ${n_demo}'}
|
| 101 |
+
checkpoint_every: ${eval:'5000 / ${n_demo}'}
|
| 102 |
+
val_every: 1
|
| 103 |
+
sample_every: 5
|
| 104 |
+
# steps per epoch
|
| 105 |
+
max_train_steps: null
|
| 106 |
+
max_val_steps: null
|
| 107 |
+
# misc
|
| 108 |
+
tqdm_interval_sec: 1.0
|
| 109 |
+
|
| 110 |
+
logging:
|
| 111 |
+
project: diffusion_policy_${task_name}
|
| 112 |
+
resume: True
|
| 113 |
+
mode: online
|
| 114 |
+
name: equidiff_demo${n_demo}
|
| 115 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 116 |
+
id: null
|
| 117 |
+
group: null
|
| 118 |
+
|
| 119 |
+
checkpoint:
|
| 120 |
+
topk:
|
| 121 |
+
monitor_key: test_mean_score
|
| 122 |
+
mode: max
|
| 123 |
+
k: 5
|
| 124 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 125 |
+
save_last_ckpt: True
|
| 126 |
+
save_last_snapshot: False
|
| 127 |
+
|
| 128 |
+
multi_run:
|
| 129 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 130 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 131 |
+
|
| 132 |
+
hydra:
|
| 133 |
+
job:
|
| 134 |
+
override_dirname: ${name}
|
| 135 |
+
run:
|
| 136 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 137 |
+
sweep:
|
| 138 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 139 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/config/train_th2_5000.yaml
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- task: mimicgen_abs
|
| 4 |
+
|
| 5 |
+
name: equi_diff
|
| 6 |
+
_target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
|
| 7 |
+
|
| 8 |
+
shape_meta: ${task.shape_meta}
|
| 9 |
+
exp_name: "default"
|
| 10 |
+
|
| 11 |
+
task_name: threading_d2
|
| 12 |
+
folder_name: threading_d2
|
| 13 |
+
file_name: threading_d2
|
| 14 |
+
n_demo: 5000
|
| 15 |
+
horizon: 16
|
| 16 |
+
n_obs_steps: 2
|
| 17 |
+
n_action_steps: 8
|
| 18 |
+
n_latency_steps: 0
|
| 19 |
+
dataset_obs_steps: ${n_obs_steps}
|
| 20 |
+
past_action_visible: False
|
| 21 |
+
dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
|
| 22 |
+
dataset_path: data/robomimic/datasets/${folder_name}/${file_name}_abs.hdf5
|
| 23 |
+
|
| 24 |
+
policy:
|
| 25 |
+
_target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
|
| 26 |
+
|
| 27 |
+
shape_meta: ${shape_meta}
|
| 28 |
+
|
| 29 |
+
noise_scheduler:
|
| 30 |
+
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
| 31 |
+
num_train_timesteps: 100
|
| 32 |
+
beta_start: 0.0001
|
| 33 |
+
beta_end: 0.02
|
| 34 |
+
beta_schedule: squaredcos_cap_v2
|
| 35 |
+
variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan
|
| 36 |
+
clip_sample: True # required when predict_epsilon=False
|
| 37 |
+
prediction_type: epsilon # or sample
|
| 38 |
+
|
| 39 |
+
horizon: ${horizon}
|
| 40 |
+
n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
|
| 41 |
+
n_obs_steps: ${n_obs_steps}
|
| 42 |
+
num_inference_steps: 100
|
| 43 |
+
crop_shape: [76, 76]
|
| 44 |
+
# crop_shape: null
|
| 45 |
+
diffusion_step_embed_dim: 128
|
| 46 |
+
enc_n_hidden: 128
|
| 47 |
+
down_dims: [512, 1024, 2048]
|
| 48 |
+
kernel_size: 5
|
| 49 |
+
n_groups: 8
|
| 50 |
+
cond_predict_scale: True
|
| 51 |
+
rot_aug: False
|
| 52 |
+
|
| 53 |
+
# scheduler.step params
|
| 54 |
+
# predict_epsilon: True
|
| 55 |
+
|
| 56 |
+
ema:
|
| 57 |
+
_target_: equi_diffpo.model.diffusion.ema_model.EMAModel
|
| 58 |
+
update_after_step: 0
|
| 59 |
+
inv_gamma: 1.0
|
| 60 |
+
power: 0.75
|
| 61 |
+
min_value: 0.0
|
| 62 |
+
max_value: 0.9999
|
| 63 |
+
|
| 64 |
+
dataloader:
|
| 65 |
+
batch_size: 128
|
| 66 |
+
num_workers: 4
|
| 67 |
+
shuffle: True
|
| 68 |
+
pin_memory: True
|
| 69 |
+
persistent_workers: True
|
| 70 |
+
drop_last: true
|
| 71 |
+
|
| 72 |
+
val_dataloader:
|
| 73 |
+
batch_size: 128
|
| 74 |
+
num_workers: 8
|
| 75 |
+
shuffle: False
|
| 76 |
+
pin_memory: True
|
| 77 |
+
persistent_workers: True
|
| 78 |
+
|
| 79 |
+
optimizer:
|
| 80 |
+
betas: [0.95, 0.999]
|
| 81 |
+
eps: 1.0e-08
|
| 82 |
+
learning_rate: 0.0001
|
| 83 |
+
weight_decay: 1.0e-06
|
| 84 |
+
|
| 85 |
+
training:
|
| 86 |
+
device: "cuda:0"
|
| 87 |
+
seed: 0
|
| 88 |
+
debug: False
|
| 89 |
+
resume: True
|
| 90 |
+
# optimization
|
| 91 |
+
lr_scheduler: cosine
|
| 92 |
+
lr_warmup_steps: 2500
|
| 93 |
+
num_epochs: ${eval:'100000 / ${n_demo}'}
|
| 94 |
+
gradient_accumulate_every: 1
|
| 95 |
+
# EMA destroys performance when used with BatchNorm
|
| 96 |
+
# replace BatchNorm with GroupNorm.
|
| 97 |
+
use_ema: True
|
| 98 |
+
# training loop control
|
| 99 |
+
# in epochs
|
| 100 |
+
rollout_every: ${eval:'5000 / ${n_demo}'}
|
| 101 |
+
checkpoint_every: ${eval:'5000 / ${n_demo}'}
|
| 102 |
+
val_every: 1
|
| 103 |
+
sample_every: 5
|
| 104 |
+
# steps per epoch
|
| 105 |
+
max_train_steps: null
|
| 106 |
+
max_val_steps: null
|
| 107 |
+
# misc
|
| 108 |
+
tqdm_interval_sec: 1.0
|
| 109 |
+
|
| 110 |
+
logging:
|
| 111 |
+
project: diffusion_policy_${task_name}
|
| 112 |
+
resume: True
|
| 113 |
+
mode: online
|
| 114 |
+
name: equidiff_demo${n_demo}
|
| 115 |
+
tags: ["${name}", "${task_name}", "${exp_name}"]
|
| 116 |
+
id: null
|
| 117 |
+
group: null
|
| 118 |
+
|
| 119 |
+
checkpoint:
|
| 120 |
+
topk:
|
| 121 |
+
monitor_key: test_mean_score
|
| 122 |
+
mode: max
|
| 123 |
+
k: 5
|
| 124 |
+
format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
|
| 125 |
+
save_last_ckpt: True
|
| 126 |
+
save_last_snapshot: False
|
| 127 |
+
|
| 128 |
+
multi_run:
|
| 129 |
+
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 130 |
+
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
|
| 131 |
+
|
| 132 |
+
hydra:
|
| 133 |
+
job:
|
| 134 |
+
override_dirname: ${name}
|
| 135 |
+
run:
|
| 136 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 137 |
+
sweep:
|
| 138 |
+
dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
|
| 139 |
+
subdir: ${hydra.job.num}
|
equidiff/equi_diffpo/dataset/base_dataset.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn
|
| 5 |
+
from equi_diffpo.model.common.normalizer import LinearNormalizer
|
| 6 |
+
|
| 7 |
+
class BaseLowdimDataset(torch.utils.data.Dataset):
    """Abstract dataset interface for low-dimensional (state-based) policies.

    Subclasses are expected to override every method below; the base class
    behaves as an empty dataset (length 0).
    """

    def get_validation_dataset(self) -> 'BaseLowdimDataset':
        # return an empty dataset by default
        return BaseLowdimDataset()

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Return a LinearNormalizer fitted to this dataset's data."""
        raise NotImplementedError()

    def get_all_actions(self) -> torch.Tensor:
        """Return all actions in the dataset as a single tensor."""
        raise NotImplementedError()

    def __len__(self) -> int:
        # empty by default; subclasses report the number of samples
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        output:
            obs: T, Do
            action: T, Da
        """
        raise NotImplementedError()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class BaseImageDataset(torch.utils.data.Dataset):
    """Abstract dataset interface for image-based policies.

    Subclasses are expected to override every method below; the base class
    behaves as an empty dataset (length 0).
    """

    def get_validation_dataset(self) -> 'BaseImageDataset':
        # Fixed return annotation: was 'BaseLowdimDataset' (copy-paste error
        # from the sibling class) although this returns a BaseImageDataset.
        # return an empty dataset by default
        return BaseImageDataset()

    def get_normalizer(self, **kwargs) -> 'LinearNormalizer':
        """Return a LinearNormalizer fitted to this dataset's data."""
        raise NotImplementedError()

    def get_all_actions(self) -> torch.Tensor:
        """Return all actions in the dataset as a single tensor."""
        raise NotImplementedError()

    def __len__(self) -> int:
        # empty by default; subclasses report the number of samples
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        output:
            obs:
                key: T, *
            action: T, Da
        """
        raise NotImplementedError()
|
equidiff/equi_diffpo/env_runner/base_image_runner.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
from equi_diffpo.policy.base_image_policy import BaseImagePolicy
|
| 3 |
+
|
| 4 |
+
class BaseImageRunner:
    """Abstract base for rollout runners driven by image-based policies.

    Args:
        output_dir: directory where the runner writes its outputs
            (e.g. rollout artifacts) — used by subclasses.
    """

    def __init__(self, output_dir):
        self.output_dir = output_dir

    def run(self, policy: BaseImagePolicy) -> Dict:
        """Evaluate `policy`; subclasses return a Dict of results (e.g. logs/metrics)."""
        raise NotImplementedError()
|
equidiff/equi_diffpo/env_runner/base_lowdim_runner.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
from equi_diffpo.policy.base_lowdim_policy import BaseLowdimPolicy
|
| 3 |
+
|
| 4 |
+
class BaseLowdimRunner:
    """Abstract base for rollout runners driven by low-dimensional policies.

    Args:
        output_dir: directory where the runner writes its outputs
            (e.g. rollout artifacts) — used by subclasses.
    """

    def __init__(self, output_dir):
        self.output_dir = output_dir

    def run(self, policy: BaseLowdimPolicy) -> Dict:
        """Evaluate `policy`; subclasses return a Dict of results (e.g. logs/metrics)."""
        raise NotImplementedError()
|
equidiff/equi_diffpo/gym_util/async_vector_env.py
ADDED
|
@@ -0,0 +1,673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Back ported methods: call, set_attr from v0.26
|
| 3 |
+
Disabled auto-reset after done
|
| 4 |
+
Added render method.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import numpy as np
|
| 9 |
+
import multiprocessing as mp
|
| 10 |
+
if os.getenv("MUJOCO_GL") != "osmesa":
|
| 11 |
+
mp.set_start_method('spawn', force=True)
|
| 12 |
+
import time
|
| 13 |
+
import sys
|
| 14 |
+
from enum import Enum
|
| 15 |
+
from copy import deepcopy
|
| 16 |
+
|
| 17 |
+
from gym import logger
|
| 18 |
+
from gym.vector.vector_env import VectorEnv
|
| 19 |
+
from gym.error import (
|
| 20 |
+
AlreadyPendingCallError,
|
| 21 |
+
NoAsyncCallError,
|
| 22 |
+
ClosedEnvironmentError,
|
| 23 |
+
CustomSpaceError,
|
| 24 |
+
)
|
| 25 |
+
from gym.vector.utils import (
|
| 26 |
+
create_shared_memory,
|
| 27 |
+
create_empty_array,
|
| 28 |
+
write_to_shared_memory,
|
| 29 |
+
read_from_shared_memory,
|
| 30 |
+
concatenate,
|
| 31 |
+
CloudpickleWrapper,
|
| 32 |
+
clear_mpi_env_vars,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
__all__ = ["AsyncVectorEnv"]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AsyncState(Enum):
    """State of AsyncVectorEnv's pending asynchronous request."""
    DEFAULT = "default"        # no pending call
    WAITING_RESET = "reset"    # reset_async sent, awaiting reset_wait
    WAITING_STEP = "step"      # step_async sent, awaiting step_wait
    WAITING_CALL = "call"      # call_async sent, awaiting call_wait
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class AsyncVectorEnv(VectorEnv):
    """Vectorized environment that runs multiple environments in parallel. It
    uses `multiprocessing` processes, and pipes for communication.

    Local modifications relative to upstream gym (see module docstring):
    `call`/`set_attr` back-ported from v0.26, auto-reset on done disabled,
    `dummy_env_fn`, `call_each` and `render` added.

    Parameters
    ----------
    env_fns : iterable of callable
        Functions that create the environments.
    observation_space : `gym.spaces.Space` instance, optional
        Observation space of a single environment. If `None`, then the
        observation space of the first environment is taken.
    action_space : `gym.spaces.Space` instance, optional
        Action space of a single environment. If `None`, then the action space
        of the first environment is taken.
    shared_memory : bool (default: `True`)
        If `True`, then the observations from the worker processes are
        communicated back through shared variables. This can improve the
        efficiency if the observations are large (e.g. images).
    copy : bool (default: `True`)
        If `True`, then the `reset` and `step` methods return a copy of the
        observations.
    context : str, optional
        Context for multiprocessing. If `None`, then the default context is used.
        Only available in Python 3.
    daemon : bool (default: `True`)
        If `True`, then subprocesses have `daemon` flag turned on; that is, they
        will quit if the head process quits. However, `daemon=True` prevents
        subprocesses to spawn children, so for some environments you may want
        to have it set to `False`
    worker : function, optional
        WARNING - advanced mode option! If set, then use that worker in a subprocess
        instead of a default one. Can be useful to override some inner vector env
        logic, for instance, how resets on done are handled. Provides high
        degree of flexibility and a high chance to shoot yourself in the foot; thus,
        if you are writing your own worker, it is recommended to start from the code
        for `_worker` (or `_worker_shared_memory`) method below, and add changes
    """

    def __init__(
        self,
        env_fns,
        dummy_env_fn=None,
        observation_space=None,
        action_space=None,
        shared_memory=True,
        copy=True,
        context=None,
        daemon=True,
        worker=None,
    ):
        ctx = mp.get_context(context)
        self.env_fns = env_fns
        self.shared_memory = shared_memory
        self.copy = copy

        # Added dummy_env_fn to fix OpenGL error in Mujoco
        # disable any OpenGL rendering in dummy_env_fn, since it
        # will conflict with OpenGL context in the forked child process
        if dummy_env_fn is None:
            dummy_env_fn = env_fns[0]
        # The dummy env is only used to read metadata and (optionally) spaces,
        # then closed before any worker process is started.
        dummy_env = dummy_env_fn()
        self.metadata = dummy_env.metadata

        if (observation_space is None) or (action_space is None):
            observation_space = observation_space or dummy_env.observation_space
            action_space = action_space or dummy_env.action_space
        dummy_env.close()
        del dummy_env
        super(AsyncVectorEnv, self).__init__(
            num_envs=len(env_fns),
            observation_space=observation_space,
            action_space=action_space,
        )

        if self.shared_memory:
            try:
                # Workers write observations directly into this shared buffer;
                # the parent reads them back as numpy views.
                _obs_buffer = create_shared_memory(
                    self.single_observation_space, n=self.num_envs, ctx=ctx
                )
                self.observations = read_from_shared_memory(
                    _obs_buffer, self.single_observation_space, n=self.num_envs
                )
            except CustomSpaceError:
                raise ValueError(
                    "Using `shared_memory=True` in `AsyncVectorEnv` "
                    "is incompatible with non-standard Gym observation spaces "
                    "(i.e. custom spaces inheriting from `gym.Space`), and is "
                    "only compatible with default Gym spaces (e.g. `Box`, "
                    "`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
                    "if you use custom observation spaces."
                )
        else:
            _obs_buffer = None
            self.observations = create_empty_array(
                self.single_observation_space, n=self.num_envs, fn=np.zeros
            )

        self.parent_pipes, self.processes = [], []
        self.error_queue = ctx.Queue()
        target = _worker_shared_memory if self.shared_memory else _worker
        # `worker` (if provided) overrides the default worker function.
        target = worker or target
        with clear_mpi_env_vars():
            for idx, env_fn in enumerate(self.env_fns):
                parent_pipe, child_pipe = ctx.Pipe()
                process = ctx.Process(
                    target=target,
                    name="Worker<{0}>-{1}".format(type(self).__name__, idx),
                    args=(
                        idx,
                        CloudpickleWrapper(env_fn),
                        child_pipe,
                        parent_pipe,
                        _obs_buffer,
                        self.error_queue,
                    ),
                )

                self.parent_pipes.append(parent_pipe)
                self.processes.append(process)

                process.daemon = daemon
                process.start()
                # The parent keeps only `parent_pipe`; the child end is closed here.
                child_pipe.close()

        self._state = AsyncState.DEFAULT
        self._check_observation_spaces()

    def seed(self, seeds=None):
        """Seed each sub-environment.

        `seeds` may be `None` (no seeding), a single int (expanded to
        `seed + i` per env), or a list of per-env seeds.
        """
        self._assert_is_running()
        if seeds is None:
            seeds = [None for _ in range(self.num_envs)]
        if isinstance(seeds, int):
            seeds = [seeds + i for i in range(self.num_envs)]
        assert len(seeds) == self.num_envs

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `seed` while waiting "
                "for a pending call to `{0}` to complete.".format(self._state.value),
                self._state.value,
            )

        for pipe, seed in zip(self.parent_pipes, seeds):
            pipe.send(("seed", seed))
        _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

    def reset_async(self):
        """Send a `reset` request to every worker without waiting for results."""
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `reset_async` while waiting "
                "for a pending call to `{0}` to complete".format(self._state.value),
                self._state.value,
            )

        for pipe in self.parent_pipes:
            pipe.send(("reset", None))
        self._state = AsyncState.WAITING_RESET

    def reset_wait(self, timeout=None):
        """
        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `reset_wait` times out. If
            `None`, the call to `reset_wait` never times out.
        Returns
        -------
        observations : sample from `observation_space`
            A batch of observations from the vectorized environment.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            raise NoAsyncCallError(
                "Calling `reset_wait` without any prior " "call to `reset_async`.",
                AsyncState.WAITING_RESET.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                "The call to `reset_wait` has timed out after "
                "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        # With shared memory, workers already wrote into self.observations;
        # otherwise batch the per-env results here.
        if not self.shared_memory:
            self.observations = concatenate(
                results, self.observations, self.single_observation_space
            )

        return deepcopy(self.observations) if self.copy else self.observations

    def step_async(self, actions):
        """
        Parameters
        ----------
        actions : iterable of samples from `action_space`
            List of actions.
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `step_async` while waiting "
                "for a pending call to `{0}` to complete.".format(self._state.value),
                self._state.value,
            )

        for pipe, action in zip(self.parent_pipes, actions):
            pipe.send(("step", action))
        self._state = AsyncState.WAITING_STEP

    def step_wait(self, timeout=None):
        """
        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `step_wait` times out. If
            `None`, the call to `step_wait` never times out.
        Returns
        -------
        observations : sample from `observation_space`
            A batch of observations from the vectorized environment.
        rewards : `np.ndarray` instance (dtype `np.float_`)
            A vector of rewards from the vectorized environment.
        dones : `np.ndarray` instance (dtype `np.bool_`)
            A vector whose entries indicate whether the episode has ended.
        infos : list of dict
            A list of auxiliary diagnostic information.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            raise NoAsyncCallError(
                "Calling `step_wait` without any prior call " "to `step_async`.",
                AsyncState.WAITING_STEP.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                "The call to `step_wait` has timed out after "
                "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT
        observations_list, rewards, dones, infos = zip(*results)

        if not self.shared_memory:
            self.observations = concatenate(
                observations_list, self.observations, self.single_observation_space
            )

        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.array(rewards),
            np.array(dones, dtype=np.bool_),
            infos,
        )

    def close_extras(self, timeout=None, terminate=False):
        """
        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `close` times out. If `None`,
            the call to `close` never times out. If the call to `close` times
            out, then all processes are terminated.
        terminate : bool (default: `False`)
            If `True`, then the `close` operation is forced and all processes
            are terminated.
        """
        timeout = 0 if terminate else timeout
        try:
            if self._state != AsyncState.DEFAULT:
                logger.warn(
                    "Calling `close` while waiting for a pending "
                    "call to `{0}` to complete.".format(self._state.value)
                )
                # Drain the pending request (reset_wait/step_wait/call_wait)
                # before shutting workers down.
                function = getattr(self, "{0}_wait".format(self._state.value))
                function(timeout)
        except mp.TimeoutError:
            terminate = True

        if terminate:
            for process in self.processes:
                if process.is_alive():
                    process.terminate()
        else:
            # Graceful shutdown: ask each worker to exit, then wait for its ack.
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.send(("close", None))
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.recv()

        for pipe in self.parent_pipes:
            if pipe is not None:
                pipe.close()
        for process in self.processes:
            process.join()

    def _poll(self, timeout=None):
        """Return True iff every worker pipe has a result ready within `timeout`."""
        self._assert_is_running()
        if timeout is None:
            return True
        # The deadline is shared across pipes: each successive poll only gets
        # the time remaining until `end_time`.
        end_time = time.perf_counter() + timeout
        delta = None
        for pipe in self.parent_pipes:
            delta = max(end_time - time.perf_counter(), 0)
            if pipe is None:
                return False
            if pipe.closed or (not pipe.poll(delta)):
                return False
        return True

    def _check_observation_spaces(self):
        """Verify every worker's observation space matches `single_observation_space`."""
        self._assert_is_running()
        for pipe in self.parent_pipes:
            pipe.send(("_check_observation_space", self.single_observation_space))
        same_spaces, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        if not all(same_spaces):
            raise RuntimeError(
                "Some environments have an observation space "
                "different from `{0}`. In order to batch observations, the "
                "observation spaces from all environments must be "
                "equal.".format(self.single_observation_space)
            )

    def _assert_is_running(self):
        """Raise ClosedEnvironmentError if `close()` has already been called."""
        if self.closed:
            raise ClosedEnvironmentError(
                "Trying to operate on `{0}`, after a "
                "call to `close()`.".format(type(self).__name__)
            )

    def _raise_if_errors(self, successes):
        """If any worker reported failure, log its error, close its pipe, and re-raise."""
        if all(successes):
            return

        num_errors = self.num_envs - sum(successes)
        assert num_errors > 0
        for _ in range(num_errors):
            index, exctype, value = self.error_queue.get()
            logger.error(
                "Received the following error from Worker-{0}: "
                "{1}: {2}".format(index, exctype.__name__, value)
            )
            logger.error("Shutting down Worker-{0}.".format(index))
            # Mark the failed worker's pipe as unusable.
            self.parent_pipes[index].close()
            self.parent_pipes[index] = None

        logger.error("Raising the last exception back to the main process.")
        raise exctype(value)

    def call_async(self, name: str, *args, **kwargs):
        """Calls the method with name asynchronously and apply args and kwargs to the method.

        Args:
            name: Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `call_async` while waiting "
                f"for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for pipe in self.parent_pipes:
            pipe.send(("_call", (name, args, kwargs)))
        self._state = AsyncState.WAITING_CALL

    def call_wait(self, timeout = None) -> list:
        """Calls all parent pipes and waits for the results.

        Args:
            timeout: Number of seconds before the call to `step_wait` times out.
                If `None` (default), the call to `step_wait` never times out.

        Returns:
            List of the results of the individual calls to the method or property for each environment.

        Raises:
            NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
            TimeoutError: The call to `call_wait` has timed out after timeout second(s).
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_CALL:
            raise NoAsyncCallError(
                "Calling `call_wait` without any prior call to `call_async`.",
                AsyncState.WAITING_CALL.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `call_wait` has timed out after {timeout} second(s)."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        return results

    def call(self, name: str, *args, **kwargs):
        """Call a method, or get a property, from each parallel environment.

        Args:
            name (str): Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Returns:
            List of the results of the individual calls to the method or property for each environment.
        """
        self.call_async(name, *args, **kwargs)
        return self.call_wait()


    def call_each(self, name: str,
            args_list: list=None,
            kwargs_list: list=None,
            timeout = None):
        """Like `call`, but with per-environment positional/keyword arguments.

        Args:
            name: Name of the method or property to call on each environment.
            args_list: Optional list (one entry per env) of positional-arg lists.
            kwargs_list: Optional list (one entry per env) of keyword-arg dicts.
            timeout: Seconds to wait for all results; `None` waits forever.

        Returns:
            Tuple of per-environment results.
        """
        n_envs = len(self.parent_pipes)
        if args_list is None:
            args_list = [[]] * n_envs
        assert len(args_list) == n_envs

        if kwargs_list is None:
            kwargs_list = [dict()] * n_envs
        assert len(kwargs_list) == n_envs

        # send
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `call_async` while waiting "
                f"for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for i, pipe in enumerate(self.parent_pipes):
            pipe.send(("_call", (name, args_list[i], kwargs_list[i])))
        self._state = AsyncState.WAITING_CALL

        # receive
        self._assert_is_running()
        if self._state != AsyncState.WAITING_CALL:
            raise NoAsyncCallError(
                "Calling `call_wait` without any prior call to `call_async`.",
                AsyncState.WAITING_CALL.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `call_wait` has timed out after {timeout} second(s)."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        return results


    def set_attr(self, name: str, values):
        """Sets an attribute of the sub-environments.

        Args:
            name: Name of the property to be set in each individual environment.
            values: Values of the property to be set to. If ``values`` is a list or
                tuple, then it corresponds to the values for each individual
                environment, otherwise a single value is set for all environments.

        Raises:
            ValueError: Values must be a list or tuple with length equal to the number of environments.
            AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
        """
        self._assert_is_running()
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]
        if len(values) != self.num_envs:
            raise ValueError(
                "Values must be a list or tuple with length equal to the "
                f"number of environments. Got `{len(values)}` values for "
                f"{self.num_envs} environments."
            )

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `set_attr` while waiting "
                f"for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        for pipe, value in zip(self.parent_pipes, values):
            pipe.send(("_setattr", (name, value)))
        _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

    def render(self, *args, **kwargs):
        """Render every sub-environment; returns the list of per-env render results."""
        return self.call('render', *args, **kwargs)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
    """Subprocess entry point used when `shared_memory` is disabled.

    Creates the environment from `env_fn` and services commands received over
    `pipe` until told to close. Every reply is a `(result, success)` pair; on
    any exception the traceback info is pushed onto `error_queue` and
    `(None, False)` is sent so the parent can re-raise.

    Note: auto-reset after `done` is intentionally disabled (see commented
    lines) — the parent is responsible for resetting finished environments.
    """
    assert shared_memory is None
    env = env_fn()
    # The child keeps only its own end of the pipe.
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()
            if command == "reset":
                observation = env.reset()
                pipe.send((observation, True))
            elif command == "step":
                observation, reward, done, info = env.step(data)
                # if done:
                #     observation = env.reset()
                pipe.send(((observation, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_call":
                name, args, kwargs = data
                if name in ["reset", "step", "seed", "close"]:
                    raise ValueError(
                        f"Trying to call function `{name}` with "
                        f"`_call`. Use `{name}` directly instead."
                    )
                # `_call` serves both method calls and property reads.
                function = getattr(env, name)
                if callable(function):
                    pipe.send((function(*args, **kwargs), True))
                else:
                    pipe.send((function, True))
            elif command == "_setattr":
                name, value = data
                setattr(env, name, value)
                pipe.send((None, True))

            elif command == "_check_observation_space":
                pipe.send((data == env.observation_space, True))
            else:
                raise RuntimeError(
                    "Received unknown command `{0}`. Must "
                    "be one of {`reset`, `step`, `seed`, `close`, "
                    "`_check_observation_space`}.".format(command)
                )
    except (KeyboardInterrupt, Exception):
        # Report (index, exc_type, exc_value) to the parent, then signal failure.
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
    # Subprocess entry point for an async vector-env worker that writes
    # observations into a shared-memory buffer; the pipe reply then carries
    # None in the observation slot instead of the (potentially large) array.
    #
    # Protocol mirrors `_worker`: (command, data) in, (result, success) out,
    # with failures reported via `error_queue` as (index, exc_type, exc_value).
    assert shared_memory is not None
    env = env_fn()
    # Cached once; used both for shared-memory layout and the space check.
    observation_space = env.observation_space
    # Close the parent's end of the pipe inside this child process.
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()
            if command == "reset":
                observation = env.reset()
                # Observation travels through shared memory, not the pipe.
                write_to_shared_memory(
                    index, observation, shared_memory, observation_space
                )
                pipe.send((None, True))
            elif command == "step":
                observation, reward, done, info = env.step(data)
                # NOTE: auto-reset on `done` is deliberately disabled here;
                # episode boundaries are handled by the caller instead.
                # if done:
                #     observation = env.reset()
                write_to_shared_memory(
                    index, observation, shared_memory, observation_space
                )
                pipe.send(((None, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                # Acknowledge first, then leave the loop (finally closes env).
                pipe.send((None, True))
                break
            elif command == "_call":
                # Generic method/attribute access: data = (name, args, kwargs).
                name, args, kwargs = data
                if name in ["reset", "step", "seed", "close"]:
                    raise ValueError(
                        f"Trying to call function `{name}` with "
                        f"`_call`. Use `{name}` directly instead."
                    )
                function = getattr(env, name)
                if callable(function):
                    pipe.send((function(*args, **kwargs), True))
                else:
                    # Plain attribute: reply with its current value.
                    pipe.send((function, True))
            elif command == "_setattr":
                name, value = data
                setattr(env, name, value)
                pipe.send((None, True))
            elif command == "_check_observation_space":
                # Report whether the parent's space equals this env's space.
                pipe.send((data == observation_space, True))
            else:
                raise RuntimeError(
                    "Received unknown command `{0}`. Must "
                    "be one of {`reset`, `step`, `seed`, `close`, "
                    "`_check_observation_space`}.".format(command)
                )
    except (KeyboardInterrupt, Exception):
        # Report the failure to the parent before replying, preserving the
        # one-reply-per-command contract.
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()
equidiff/equi_diffpo/gym_util/multistep_wrapper.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gym
|
| 2 |
+
from gym import spaces
|
| 3 |
+
import numpy as np
|
| 4 |
+
from collections import defaultdict, deque
|
| 5 |
+
import dill
|
| 6 |
+
|
| 7 |
+
def stack_repeated(x, n):
    """Stack ``n`` copies of ``x`` along a new leading axis: (n,) + x.shape."""
    expanded = np.expand_dims(x, axis=0)
    return np.repeat(expanded, n, axis=0)
| 9 |
+
|
| 10 |
+
def repeated_box(box_space, n):
    """Return a Box space with ``box_space``'s bounds tiled ``n`` times."""
    tiled_low = stack_repeated(box_space.low, n)
    tiled_high = stack_repeated(box_space.high, n)
    return spaces.Box(
        low=tiled_low,
        high=tiled_high,
        shape=(n,) + box_space.shape,
        dtype=box_space.dtype,
    )
| 17 |
+
|
| 18 |
+
def repeated_space(space, n):
    """Recursively build an ``n``-step stacked version of a Box/Dict space."""
    if isinstance(space, spaces.Dict):
        stacked = spaces.Dict()
        for key, subspace in space.items():
            stacked[key] = repeated_space(subspace, n)
        return stacked
    if isinstance(space, spaces.Box):
        return repeated_box(space, n)
    raise RuntimeError(f'Unsupported space type {type(space)}')
| 28 |
+
|
| 29 |
+
def take_last_n(x, n):
    """Return (up to) the last ``n`` elements of ``x`` as a numpy array."""
    items = list(x)
    tail = min(len(items), n)
    # NOTE: when tail == 0 the slice [-0:] yields the whole list — preserved
    # from the original implementation.
    return np.array(items[-tail:])
| 33 |
+
|
| 34 |
+
def dict_take_last_n(x, n):
    """Apply ``take_last_n`` to every value of dict ``x``; keys unchanged."""
    return {key: take_last_n(value, n) for key, value in x.items()}
| 39 |
+
|
| 40 |
+
def aggregate(data, method='max'):
    """Reduce ``data`` to a scalar.

    method: one of 'max' (== any on booleans), 'min' (== all on booleans),
    'mean', or 'sum'. Raises NotImplementedError for anything else.
    """
    reducers = {
        'max': np.max,
        'min': np.min,
        'mean': np.mean,
        'sum': np.sum,
    }
    if method not in reducers:
        raise NotImplementedError()
    return reducers[method](data)
| 53 |
+
|
| 54 |
+
def stack_last_n_obs(all_obs, n_steps):
    """Stack the last ``n_steps`` observations into (n_steps,) + obs_shape.

    When fewer than ``n_steps`` observations exist, the front of the result
    is left-padded with the oldest available observation.
    """
    assert len(all_obs) > 0
    frames = list(all_obs)
    latest = frames[-1]
    stacked = np.zeros((n_steps,) + latest.shape, dtype=latest.dtype)
    start = -min(n_steps, len(frames))
    stacked[start:] = np.array(frames[start:])
    if n_steps > len(frames):
        # pad the missing leading steps with the oldest frame we copied
        stacked[:start] = stacked[start]
    return stacked
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MultiStepWrapper(gym.Wrapper):
    """Wrap a single-step env into a multi-step interface.

    ``step`` takes a sequence of ``n_action_steps`` actions, executes them in
    order (stopping early once the episode terminates), and returns
    observations stacked over the last ``n_obs_steps`` frames. The reward is
    aggregated with ``reward_agg_method`` over all rewards since reset;
    ``done`` is True if any underlying step reported termination.

    Fix vs. original: the duplicate ``self.n_obs_steps = n_obs_steps``
    assignment in ``__init__`` has been removed.
    """

    def __init__(self,
            env,
            n_obs_steps,
            n_action_steps,
            max_episode_steps=None,
            reward_agg_method='max'
        ):
        super().__init__(env)
        # Stacked action/observation spaces with a leading step axis.
        self._action_space = repeated_space(env.action_space, n_action_steps)
        self._observation_space = repeated_space(env.observation_space, n_obs_steps)
        self.max_episode_steps = max_episode_steps
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.reward_agg_method = reward_agg_method

        # History buffers; obs/info deques keep one extra frame beyond
        # n_obs_steps. reward/done grow for the whole episode.
        self.obs = deque(maxlen=n_obs_steps+1)
        self.reward = list()
        self.done = list()
        self.info = defaultdict(lambda : deque(maxlen=n_obs_steps+1))

    def reset(self):
        """Resets the environment and clears all history buffers."""
        obs = super().reset()

        self.obs = deque([obs], maxlen=self.n_obs_steps+1)
        self.reward = list()
        self.done = list()
        self.info = defaultdict(lambda : deque(maxlen=self.n_obs_steps+1))

        obs = self._get_obs(self.n_obs_steps)
        return obs

    def step(self, action):
        """
        actions: (n_action_steps,) + action_shape

        Returns (stacked_obs, aggregated_reward, done, info) where info maps
        each key to its last ``n_obs_steps`` recorded values.
        """
        for act in action:
            if len(self.done) > 0 and self.done[-1]:
                # termination: never step a finished episode
                break
            observation, reward, done, info = super().step(act)

            self.obs.append(observation)
            self.reward.append(reward)
            if (self.max_episode_steps is not None) \
                and (len(self.reward) >= self.max_episode_steps):
                # truncation: force done once the step budget is exhausted
                done = True
            self.done.append(done)
            self._add_info(info)

        observation = self._get_obs(self.n_obs_steps)
        # NOTE: aggregated over ALL rewards since reset, not just this call.
        reward = aggregate(self.reward, self.reward_agg_method)
        done = aggregate(self.done, 'max')
        info = dict_take_last_n(self.info, self.n_obs_steps)
        return observation, reward, done, info

    def _get_obs(self, n_steps=1):
        """
        Output (n_steps,) + obs_shape, left-padded with the oldest frame.
        """
        assert(len(self.obs) > 0)
        if isinstance(self.observation_space, spaces.Box):
            return stack_last_n_obs(self.obs, n_steps)
        elif isinstance(self.observation_space, spaces.Dict):
            result = dict()
            for key in self.observation_space.keys():
                result[key] = stack_last_n_obs(
                    [obs[key] for obs in self.obs],
                    n_steps
                )
            return result
        else:
            raise RuntimeError('Unsupported space type')

    def _add_info(self, info):
        # Append each info value to its bounded per-key history deque.
        for key, value in info.items():
            self.info[key].append(value)

    def get_rewards(self):
        """Return the raw per-step rewards recorded since the last reset."""
        return self.reward

    def get_attr(self, name):
        """Generic attribute accessor (used by vectorized-env plumbing)."""
        return getattr(self, name)

    def run_dill_function(self, dill_fn):
        """Deserialize a dill-pickled function and call it on this wrapper.

        WARNING: ``dill.loads`` executes arbitrary code — only feed it
        trusted payloads.
        """
        fn = dill.loads(dill_fn)
        return fn(self)

    def get_infos(self):
        """Return accumulated info histories as plain lists keyed by name."""
        result = dict()
        for k, v in self.info.items():
            result[k] = list(v)
        return result
|