if __name__ == "__main__":
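    # Running as a script: put the repository root on sys.path and make it the
    # working directory so equi_diffpo imports and relative paths resolve.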
    import sys
    import os
    import pathlib

    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    sys.path.append(ROOT_DIR)
    os.chdir(ROOT_DIR)

import os
import hydra
import torch
from omegaconf import OmegaConf
import pathlib
from pathlib import Path
from torch.utils.data import DataLoader
import copy
import random
import wandb
import tqdm
import numpy as np
import shutil
import json
from equi_diffpo.workspace.base_workspace import BaseWorkspace
from equi_diffpo.policy.base_image_policy import BaseImagePolicy
from equi_diffpo.dataset.base_dataset import BaseImageDataset
from equi_diffpo.env_runner.base_image_runner import BaseImageRunner
from equi_diffpo.common.checkpoint_util import TopKCheckpointManager
from equi_diffpo.common.json_logger import JsonLogger
from equi_diffpo.common.pytorch_util import dict_apply, optimizer_to
from equi_diffpo.model.diffusion.ema_model import EMAModel
from equi_diffpo.model.common.lr_scheduler import get_scheduler

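# Allow ${eval:...} interpolations in the Hydra/OmegaConf config files.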
OmegaConf.register_new_resolver("eval", eval, replace=True)


class TestEquiWorkspace(BaseWorkspace):
    include_keys = ['global_step', 'epoch']

    def __init__(self, cfg: OmegaConf, output_dir=None):
        super().__init__(cfg, output_dir=output_dir)

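        # Seed Python, NumPy, and torch RNGs for reproducibility.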
        seed = cfg.training.seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

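        # Instantiate the policy (and an EMA copy, if enabled) from the Hydra config.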
        self.model: BaseImagePolicy = hydra.utils.instantiate(cfg.policy)

        self.ema_model: BaseImagePolicy = None
        if cfg.training.use_ema:
            self.ema_model = copy.deepcopy(self.model)

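        # Build the optimizer and sanity-check that it covers every trainable parameter.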
        self.optimizer = self.model.get_optimizer(**cfg.optimizer)

        total_params = 0
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                total_params += param.numel()
        assert total_params == sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        self.global_step = 0
        self.epoch = 0

    def run(self):
        cfg = copy.deepcopy(self.cfg)

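        # Load the checkpoint to evaluate, if it exists.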
        ckpt_path = Path(cfg.training.ckpt_path)
        if ckpt_path.is_file():
            print(f"Resuming from checkpoint {ckpt_path}")
            self.load_checkpoint(path=ckpt_path)

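        # The dataset is instantiated only to recover its normalizer for the policy.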
        dataset: BaseImageDataset
        dataset = hydra.utils.instantiate(cfg.task.dataset)
        assert isinstance(dataset, BaseImageDataset)
        normalizer = dataset.get_normalizer()

        self.model.set_normalizer(normalizer)
        if cfg.training.use_ema:
            self.ema_model.set_normalizer(normalizer)

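        # Build the environment runner that performs the rollout evaluation.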
        env_runner: BaseImageRunner
        env_runner = hydra.utils.instantiate(
            cfg.task.env_runner,
            output_dir=self.output_dir)
        assert isinstance(env_runner, BaseImageRunner)

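        # Move the policy, its EMA copy, and the optimizer state to the target device.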
        device = torch.device(cfg.training.device)
        self.model.to(device)
        if self.ema_model is not None:
            self.ema_model.to(device)
        optimizer_to(self.optimizer, device)

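        # Debug mode shrinks the epoch and step limits.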
        if cfg.training.debug:
            cfg.training.num_epochs = 2
            cfg.training.max_train_steps = 3
            cfg.training.max_val_steps = 3
            cfg.training.rollout_every = 1
            cfg.training.checkpoint_every = 1
            cfg.training.val_every = 1
            cfg.training.sample_every = 1

        cfg.training.num_epochs = 1

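        # Run the rollout evaluation once and log the results.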
        log_path = os.path.join(self.output_dir, 'logs.json.txt')
        with JsonLogger(log_path) as json_logger:
            step_log = dict()

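            # Evaluate with the EMA weights when EMA is enabled.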
            policy = self.model
            if cfg.training.use_ema:
                policy = self.ema_model
            policy.eval()

            runner_log = env_runner.run(policy)
            step_log.update(runner_log)

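            # Record the config and rollout metrics in logs.json.txt.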
            json_logger.log(cfg)
            json_logger.log(step_log)
            print(cfg.policy.n_action_steps)
            print(step_log['test/mean_score'])

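            # Append a one-line JSON summary of this evaluation to cfg.log_txt_path.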
            with open(cfg.log_txt_path, 'a') as f:
                log_data = {
                    "run_name": cfg.run_name,
                    "task": cfg.task_name,
                    "diversity": cfg.diversity,
                    "ckpt_path": cfg.ckpt_path,
                    "demo": cfg.n_demo,
                    "n_action_steps": cfg.policy.n_action_steps,
                    "horizon": cfg.horizon,
                    "test/mean_score": step_log['test/mean_score']
                }
                f.write(json.dumps(log_data) + '\n')


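# Hydra entry point: the config name defaults to this file's stem, loaded from ../config.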
@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
    config_name=pathlib.Path(__file__).stem)
def main(cfg):
    workspace = TestEquiWorkspace(cfg)
    workspace.run()


if __name__ == "__main__":
    main()