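# when run as a script, put the repo root on sys.path and chdir into it so equi_diffpo imports and relative paths resolve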
if __name__ == "__main__":
    import sys
    import os
    import pathlib

    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    sys.path.append(ROOT_DIR)
    os.chdir(ROOT_DIR)

import os
import hydra
import torch
from omegaconf import OmegaConf
import pathlib
from torch.utils.data import DataLoader
import copy
import random
import wandb
import tqdm
import numpy as np
import shutil
from equi_diffpo.workspace.base_workspace import BaseWorkspace
from equi_diffpo.policy.robomimic_image_policy import RobomimicImagePolicy
from equi_diffpo.dataset.base_dataset import BaseImageDataset
from equi_diffpo.env_runner.base_image_runner import BaseImageRunner
from equi_diffpo.common.checkpoint_util import TopKCheckpointManager
from equi_diffpo.common.json_logger import JsonLogger
from equi_diffpo.common.pytorch_util import dict_apply, optimizer_to

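# allow ${eval:...} expressions in config files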
OmegaConf.register_new_resolver("eval", eval, replace=True)


class TrainRobomimicImageWorkspace(BaseWorkspace):
    include_keys = ['global_step', 'epoch']

    def __init__(self, cfg: OmegaConf, output_dir=None):
        super().__init__(cfg, output_dir=output_dir)

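        # seed RNGs for reproducibility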
        seed = cfg.training.seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

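        # instantiate the policy from the Hydra config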
        self.model: RobomimicImagePolicy = hydra.utils.instantiate(cfg.policy)

        self.global_step = 0
        self.epoch = 0

    def run(self):
        cfg = copy.deepcopy(self.cfg)

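        # resume training from the latest checkpoint if requested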
        if cfg.training.resume:
            latest_ckpt_path = self.get_checkpoint_path()
            if latest_ckpt_path.is_file():
                print(f"Resuming from checkpoint {latest_ckpt_path}")
                self.load_checkpoint(path=latest_ckpt_path)
                self.epoch += 1
                self.global_step += 1

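        # configure dataset and training dataloader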
        dataset: BaseImageDataset
        dataset = hydra.utils.instantiate(cfg.task.dataset)
        assert isinstance(dataset, BaseImageDataset)
        train_dataloader = DataLoader(dataset, **cfg.dataloader)
        normalizer = dataset.get_normalizer()

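        # configure validation dataset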
        val_dataset = dataset.get_validation_dataset()
        val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)

        self.model.set_normalizer(normalizer)

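        # configure environment rollout runner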
        env_runner: BaseImageRunner
        env_runner = hydra.utils.instantiate(
            cfg.task.env_runner,
            output_dir=self.output_dir)
        assert isinstance(env_runner, BaseImageRunner)

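        # configure wandb logging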
        wandb_run = wandb.init(
            dir=str(self.output_dir),
            config=OmegaConf.to_container(cfg, resolve=True),
            **cfg.logging
        )
        wandb.config.update(
            {
                "output_dir": self.output_dir,
            }
        )

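        # configure top-k checkpoint manager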
        topk_manager = TopKCheckpointManager(
            save_dir=os.path.join(self.output_dir, 'checkpoints'),
            **cfg.checkpoint.topk
        )

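        # device transfer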
        device = torch.device(cfg.training.device)
        self.model.to(device)

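        # keep one training batch around for the periodic action-MSE check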
        train_sampling_batch = None

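        # shortcut settings for quick debug runs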
        if cfg.training.debug:
            cfg.training.num_epochs = 2
            cfg.training.max_train_steps = 3
            cfg.training.max_val_steps = 3
            cfg.training.rollout_every = 1
            cfg.training.checkpoint_every = 1
            cfg.training.val_every = 1
            cfg.training.sample_every = 1

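        # training loop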
        log_path = os.path.join(self.output_dir, 'logs.json.txt')
        with JsonLogger(log_path) as json_logger:
            while self.epoch < cfg.training.num_epochs:
                step_log = dict()
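                # train for this epoch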
                train_losses = list()
                with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
                        leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
                    for batch_idx, batch in enumerate(tepoch):
                        batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
                        if train_sampling_batch is None:
                            train_sampling_batch = batch

                        info = self.model.train_on_batch(batch, epoch=self.epoch)

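                        # logging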
                        loss_cpu = info['losses']['action_loss'].item()
                        tepoch.set_postfix(loss=loss_cpu, refresh=False)
                        train_losses.append(loss_cpu)
                        step_log = {
                            'train_loss': loss_cpu,
                            'global_step': self.global_step,
                            'epoch': self.epoch
                        }

                        is_last_batch = (batch_idx == (len(train_dataloader)-1))
                        if not is_last_batch:
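                            # the last step of the epoch is logged later, together with eval metrics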
                            wandb_run.log(step_log, step=self.global_step)
                            json_logger.log(step_log)
                            self.global_step += 1

                        if (cfg.training.max_train_steps is not None) \
                                and batch_idx >= (cfg.training.max_train_steps - 1):
                            break

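                # replace the per-step train_loss with the epoch average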
                train_loss = np.mean(train_losses)
                step_log['train_loss'] = train_loss

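                # evaluate for this epoch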
                self.model.eval()

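                # run rollout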
                if (self.epoch % cfg.training.rollout_every) == 0:
                    runner_log = env_runner.run(self.model)
                    step_log.update(runner_log)

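                # run validation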
                if (self.epoch % cfg.training.val_every) == 0:
                    with torch.no_grad():
                        val_losses = list()
                        with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
                                leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
                            for batch_idx, batch in enumerate(tepoch):
                                batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
                                info = self.model.train_on_batch(batch, epoch=self.epoch, validate=True)
                                loss = info['losses']['action_loss']
                                val_losses.append(loss)
                                if (cfg.training.max_val_steps is not None) \
                                        and batch_idx >= (cfg.training.max_val_steps - 1):
                                    break
                        if len(val_losses) > 0:
                            val_loss = torch.mean(torch.tensor(val_losses)).item()
                            step_log['val_loss'] = val_loss

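                # predict actions on the saved training batch and log the action MSE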
                if (self.epoch % cfg.training.sample_every) == 0:
                    with torch.no_grad():
                        batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
                        obs_dict = batch['obs']
                        gt_action = batch['action']
                        T = gt_action.shape[1]

                        pred_actions = list()
                        self.model.reset()
                        for i in range(T):
                            result = self.model.predict_action(
                                dict_apply(obs_dict, lambda x: x[:, [i]])
                            )
                            pred_actions.append(result['action'])
                        pred_actions = torch.cat(pred_actions, dim=1)
                        mse = torch.nn.functional.mse_loss(pred_actions, gt_action)
                        step_log['train_action_mse_error'] = mse.item()
                        del batch
                        del obs_dict
                        del gt_action
                        del result
                        del pred_actions
                        del mse

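                # checkpoint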
                if (self.epoch % cfg.training.checkpoint_every) == 0:
                    if cfg.checkpoint.save_last_ckpt:
                        self.save_checkpoint()
                    if cfg.checkpoint.save_last_snapshot:
                        self.save_snapshot()

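                    # sanitize metric names for the checkpoint manager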
                    metric_dict = dict()
                    for key, value in step_log.items():
                        new_key = key.replace('/', '_')
                        metric_dict[new_key] = value

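                    # get_ckpt_path returns None unless this epoch's metric makes the top k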
                    topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)

                    if topk_ckpt_path is not None:
                        self.save_checkpoint(path=topk_ckpt_path)

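                # end of evaluation for this epoch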
                self.model.train()

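                # log the last training step together with eval metrics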
                wandb_run.log(step_log, step=self.global_step)
                json_logger.log(step_log)
                self.global_step += 1
                self.epoch += 1


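# Hydra entry point; the config name defaults to this file's stem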
@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
    config_name=pathlib.Path(__file__).stem)
def main(cfg):
    workspace = TrainRobomimicImageWorkspace(cfg)
    workspace.run()


if __name__ == "__main__":
    main()