if __name__ == "__main__":
    import sys
    import os
    import pathlib

    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    sys.path.append(ROOT_DIR)
    os.chdir(ROOT_DIR)

import os
import hydra
import torch
from omegaconf import OmegaConf
import pathlib
from pathlib import Path
from torch.utils.data import DataLoader
import copy
import random
import wandb
import tqdm
import numpy as np
import shutil
import json
from equi_diffpo.workspace.base_workspace import BaseWorkspace
from equi_diffpo.policy.base_image_policy import BaseImagePolicy
from equi_diffpo.dataset.base_dataset import BaseImageDataset
from equi_diffpo.env_runner.base_image_runner import BaseImageRunner
from equi_diffpo.common.checkpoint_util import TopKCheckpointManager
from equi_diffpo.common.json_logger import JsonLogger
from equi_diffpo.common.pytorch_util import dict_apply, optimizer_to
from equi_diffpo.model.diffusion.ema_model import EMAModel
from equi_diffpo.model.common.lr_scheduler import get_scheduler

OmegaConf.register_new_resolver("eval", eval, replace=True)


class TestEquiWorkspace(BaseWorkspace):
    include_keys = ['global_step', 'epoch']

    def __init__(self, cfg: OmegaConf, output_dir=None):
        super().__init__(cfg, output_dir=output_dir)

        # set seed
        seed = cfg.training.seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # configure model
        self.model: BaseImagePolicy = hydra.utils.instantiate(cfg.policy)

        self.ema_model: BaseImagePolicy = None
        if cfg.training.use_ema:
            self.ema_model = copy.deepcopy(self.model)

        # configure optimizer; sanity-check that it covers every trainable parameter
        self.optimizer = self.model.get_optimizer(**cfg.optimizer)
        total_params = 0
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                total_params += param.numel()
        assert total_params == sum(
            p.numel() for p in self.model.parameters() if p.requires_grad)

        # configure training state
        self.global_step = 0
        self.epoch = 0

    def run(self):
        cfg = copy.deepcopy(self.cfg)

        # load the checkpoint to evaluate
        ckpt_path = Path(cfg.training.ckpt_path)
        if ckpt_path.is_file():
            print(f"Resuming from checkpoint {ckpt_path}")
            self.load_checkpoint(path=ckpt_path)

        # configure dataset (only used to fit the normalizer)
        dataset: BaseImageDataset
        # print(cfg.task.dataset)
        dataset = hydra.utils.instantiate(cfg.task.dataset)
        assert isinstance(dataset, BaseImageDataset)
        normalizer = dataset.get_normalizer()

        self.model.set_normalizer(normalizer)
        if cfg.training.use_ema:
            self.ema_model.set_normalizer(normalizer)

        # configure env
        env_runner: BaseImageRunner
        env_runner = hydra.utils.instantiate(
            cfg.task.env_runner,
            output_dir=self.output_dir)
        assert isinstance(env_runner, BaseImageRunner)

        # configure logging
        # wandb_run = wandb.init(
        #     dir=str(self.output_dir),
        #     config=OmegaConf.to_container(cfg, resolve=True),
        #     **cfg.logging
        # )
        # wandb.config.update(
        #     {
        #         "output_dir": self.output_dir,
        #     }
        # )

        # device transfer
        device = torch.device(cfg.training.device)
        self.model.to(device)
        if self.ema_model is not None:
            self.ema_model.to(device)
        optimizer_to(self.optimizer, device)

        if cfg.training.debug:
            cfg.training.num_epochs = 2
            cfg.training.max_train_steps = 3
            cfg.training.max_val_steps = 3
            cfg.training.rollout_every = 1
            cfg.training.checkpoint_every = 1
            cfg.training.val_every = 1
            cfg.training.sample_every = 1

        # this workspace performs a single evaluation pass, never a training loop
        cfg.training.num_epochs = 1

        # evaluation
        log_path = os.path.join(self.output_dir, 'logs.json.txt')
        with JsonLogger(log_path) as json_logger:
            step_log = dict()

            # ========= eval ==========
            policy = self.model
            if cfg.training.use_ema:
                policy = self.ema_model
            policy.eval()

            # run rollout
            runner_log = env_runner.run(policy)

            # log all
            step_log.update(runner_log)
            # wandb_run.log(step_log, step=self.global_step)
            # log the resolved config as a plain dict so it is JSON-serializable
            json_logger.log(OmegaConf.to_container(cfg, resolve=True))
            json_logger.log(step_log)
            print(cfg.policy.n_action_steps)
            print(step_log['test/mean_score'])

            # append a one-line JSON summary of this evaluation run
            with open(cfg.log_txt_path, 'a') as f:
                log_data = {
                    "run_name": cfg.run_name,
                    "task": cfg.task_name,
                    "diversity": cfg.diversity,
                    "ckpt_path": cfg.ckpt_path,
                    "demo": cfg.n_demo,
                    "n_action_steps": cfg.policy.n_action_steps,
                    "horizon": cfg.horizon,
                    "test/mean_score": step_log['test/mean_score'],
                }
                f.write(json.dumps(log_data) + '\n')


@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
    config_name=pathlib.Path(__file__).stem)
def main(cfg):
    workspace = TestEquiWorkspace(cfg)
    workspace.run()


if __name__ == "__main__":
    main()
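
# Example invocation (a sketch, not part of the script): the hydra decorator
# above resolves the config name from this file's stem, so the script path and
# checkpoint value shown here are assumptions for illustration only. The
# training.ckpt_path and training.device keys are the ones the code reads.
#
#   python equi_diffpo/workspace/test_equi_workspace.py \
#       training.ckpt_path=data/outputs/latest.ckpt \
#       training.device=cuda:0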