|
if __name__ == "__main__":
    # Only when executed as a script: append the directory two levels above
    # the folder containing this file to sys.path. Presumably this lets the
    # project-local `equi_diffpo` import further down resolve without the
    # package being installed — confirm against the repo layout.
    import os
    import pathlib
    import sys

    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    sys.path.append(ROOT_DIR)
|
|
|
import multiprocessing |
|
import os |
|
import shutil |
|
import click |
|
import pathlib |
|
import h5py |
|
from tqdm import tqdm |
|
import collections |
|
import pickle |
|
from equi_diffpo.common.robomimic_util import RobomimicAbsoluteActionConverter |
|
|
|
# Per-process cache of converters keyed by dataset path. Module globals are
# private to each pool worker process, so every process builds a converter at
# most once per path instead of once per demo index.
_CONVERTER_CACHE = {}


def _get_converter(path):
    """Return this process's converter for *path*, constructing it on first use."""
    converter = _CONVERTER_CACHE.get(path)
    if converter is None:
        converter = RobomimicAbsoluteActionConverter(path)
        _CONVERTER_CACHE[path] = converter
    return converter


def worker(x):
    """Convert a single demo; intended to run inside a multiprocessing pool.

    Args:
        x: ``(path, idx, do_eval)`` tuple: the hdf5 dataset path, the demo
            index to convert, and whether to also compute evaluation info.

    Returns:
        ``(abs_actions, info)``. ``info`` is the second value returned by
        ``convert_and_eval_idx`` when *do_eval* is true, otherwise ``{}``.
    """
    path, idx, do_eval = x
    # NOTE(review): reusing one converter for many demos assumes its
    # convert_* methods do not depend on state left over from earlier calls.
    converter = _get_converter(path)
    if do_eval:
        abs_actions, info = converter.convert_and_eval_idx(idx)
    else:
        abs_actions = converter.convert_idx(idx)
        info = dict()
    return abs_actions, info
|
|
|
def _collect_metrics(infos, metrics):
    """Regroup per-demo info dicts into ``{metric: {key: [value per demo]}}``.

    Each element of *infos* is a mapping whose values are indexed by every
    name in *metrics* (``value[m]``); demos are appended in input order.
    """
    metrics_dicts = {m: collections.defaultdict(list) for m in metrics}
    for info in infos:
        for key, value in info.items():
            for m in metrics:
                metrics_dicts[m][key].append(value[m])
    return metrics_dicts


def _plot_metrics(metrics_dicts, metrics, eval_dir):
    """Draw one subplot per metric (one line per key) and save as PDF and PNG."""
    from matplotlib import pyplot as plt
    plt.switch_backend('PDF')

    # squeeze=False keeps `axes` 2-D, so indexing also works for one metric.
    fig, axes = plt.subplots(1, len(metrics), squeeze=False)
    try:
        for axis, metric in zip(axes[0], metrics):
            for key, value in metrics_dicts[metric].items():
                axis.plot(value, label=key)
            axis.legend()
            axis.set_title(metric)
        fig.set_size_inches(10, 4)
        fig.savefig(str(eval_dir.joinpath('error_stats.pdf')))
        fig.savefig(str(eval_dir.joinpath('error_stats.png')))
    finally:
        # Release the figure's memory even if saving fails.
        plt.close(fig)


def _write_eval_outputs(eval_dir, results):
    """Write ``error_stats.pkl`` plus PDF/PNG plots of the per-demo info."""
    eval_dir.mkdir(parents=False, exist_ok=True)

    print("Writing error_stats.pkl")
    infos = [info for _, info in results]
    # `with` guarantees the pickle is flushed and the handle closed.
    with eval_dir.joinpath('error_stats.pkl').open('wb') as f:
        pickle.dump(infos, f)

    print("Generating visualization")
    metrics = ['pos', 'rot']
    _plot_metrics(_collect_metrics(infos, metrics), metrics, eval_dir)


@click.command()
@click.option('-i', '--input', required=True, help='input hdf5 path')
@click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist')
@click.option('-e', '--eval_dir', default=None, help='directory to output evaluation metrics')
@click.option('-n', '--num_workers', default=None, type=int)
def main(input, output, eval_dir, num_workers):
    """Copy *input* to *output*, replacing every ``data/demo_{i}/actions``
    with the actions produced by ``RobomimicAbsoluteActionConverter``.

    Demos are converted in parallel with a ``multiprocessing.Pool`` of
    *num_workers* processes (``None`` lets Pool choose). When *eval_dir* is
    given, per-demo evaluation info is pickled and plotted there.

    Raises:
        click.BadParameter: if a path argument fails validation.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    input = pathlib.Path(input).expanduser()
    if not input.is_file():
        raise click.BadParameter(f'{input} is not a file', param_hint="'--input'")

    output = pathlib.Path(output).expanduser()
    if not output.parent.is_dir():
        raise click.BadParameter(
            f'parent directory {output.parent} does not exist', param_hint="'--output'")
    if output.is_dir():
        raise click.BadParameter(f'{output} is a directory', param_hint="'--output'")

    do_eval = eval_dir is not None
    if do_eval:
        eval_dir = pathlib.Path(eval_dir).expanduser()
        if not eval_dir.parent.exists():
            raise click.BadParameter(
                f'parent directory {eval_dir.parent} does not exist', param_hint="'--eval_dir'")

    # The converter is only needed in this process for the demo count
    # (workers build their own), so release it before the pool starts.
    converter = RobomimicAbsoluteActionConverter(input)
    num_demos = len(converter)
    del converter

    tasks = [(input, i, do_eval) for i in range(num_demos)]
    with multiprocessing.Pool(num_workers) as pool:
        results = pool.map(worker, tasks)

    print('Copying hdf5')
    shutil.copy(str(input), str(output))

    # pool.map preserves order, so results[i] belongs to demo_{i}.
    with h5py.File(output, 'r+') as out_file:
        for i, (abs_actions, _) in enumerate(tqdm(results, desc="Writing to output")):
            out_file[f'data/demo_{i}']['actions'][:] = abs_actions

    if do_eval:
        _write_eval_outputs(eval_dir, results)
|
|
|
|
|
if __name__ == "__main__":
    # Hand off to the click command, which reads its options from the
    # command line before calling main().
    main()
|
|