if __name__ == "__main__": import sys import os import pathlib ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) sys.path.append(ROOT_DIR) import multiprocessing import os import shutil import click import pathlib import h5py from tqdm import tqdm import collections import pickle from equi_diffpo.common.robomimic_util import RobomimicAbsoluteActionConverter def worker(x): path, idx, do_eval = x converter = RobomimicAbsoluteActionConverter(path) if do_eval: abs_actions, info = converter.convert_and_eval_idx(idx) else: abs_actions = converter.convert_idx(idx) info = dict() return abs_actions, info @click.command() @click.option('-i', '--input', required=True, help='input hdf5 path') @click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist') @click.option('-e', '--eval_dir', default=None, help='directory to output evaluation metrics') @click.option('-n', '--num_workers', default=None, type=int) def main(input, output, eval_dir, num_workers): # process inputs input = pathlib.Path(input).expanduser() assert input.is_file() output = pathlib.Path(output).expanduser() assert output.parent.is_dir() assert not output.is_dir() do_eval = False if eval_dir is not None: eval_dir = pathlib.Path(eval_dir).expanduser() assert eval_dir.parent.exists() do_eval = True converter = RobomimicAbsoluteActionConverter(input) # run with multiprocessing.Pool(num_workers) as pool: results = pool.map(worker, [(input, i, do_eval) for i in range(len(converter))]) # save output print('Copying hdf5') shutil.copy(str(input), str(output)) # modify action with h5py.File(output, 'r+') as out_file: for i in tqdm(range(len(converter)), desc="Writing to output"): abs_actions, info = results[i] demo = out_file[f'data/demo_{i}'] demo['actions'][:] = abs_actions # save eval if do_eval: eval_dir.mkdir(parents=False, exist_ok=True) print("Writing error_stats.pkl") infos = [info for _, info in results] pickle.dump(infos, eval_dir.joinpath('error_stats.pkl').open('wb')) print("Generating visualization") metrics = ['pos', 'rot'] metrics_dicts = dict() for m in metrics: metrics_dicts[m] = collections.defaultdict(list) for i in range(len(infos)): info = infos[i] for k, v in info.items(): for m in metrics: metrics_dicts[m][k].append(v[m]) from matplotlib import pyplot as plt plt.switch_backend('PDF') fig, ax = plt.subplots(1, len(metrics)) for i in range(len(metrics)): axis = ax[i] data = metrics_dicts[metrics[i]] for key, value in data.items(): axis.plot(value, label=key) axis.legend() axis.set_title(metrics[i]) fig.set_size_inches(10,4) fig.savefig(str(eval_dir.joinpath('error_stats.pdf'))) fig.savefig(str(eval_dir.joinpath('error_stats.png'))) if __name__ == "__main__": main()