from typing import Dict

import torch

from equi_diffpo.model.common.normalizer import LinearNormalizer
from equi_diffpo.policy.base_image_policy import BaseImagePolicy
from equi_diffpo.common.pytorch_util import dict_apply
from equi_diffpo.common.robomimic_config_util import get_robomimic_config

from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo
import robomimic.utils.obs_utils as ObsUtils

class RobomimicImagePolicy(BaseImagePolicy):
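    """Wrap a robomimic PolicyAlgo (default: BC-RNN) behind the equi_diffpo
    BaseImagePolicy interface: build a robomimic config from shape_meta,
    instantiate the underlying model with algo_factory, and handle
    (un)normalization with a LinearNormalizer.
    """
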
    def __init__(self,
            shape_meta: dict,
            algo_name='bc_rnn',
            obs_type='image',
            task_name='square',
            dataset_type='ph',
            crop_shape=(76, 76)
        ):
        super().__init__()
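
        # parse shape_meta: the action dimensionality plus the shape and
        # modality ('rgb' or 'low_dim') of every observation key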
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        obs_shape_meta = shape_meta['obs']
        obs_config = {
            'low_dim': [],
            'rgb': [],
            'depth': [],
            'scan': []
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            shape = attr['shape']
            obs_key_shapes[key] = list(shape)

            obs_modality = attr.get('type', 'low_dim')
            if obs_modality == 'rgb':
                obs_config['rgb'].append(key)
            elif obs_modality == 'low_dim':
                obs_config['low_dim'].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_modality}")
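
        # fetch the stock robomimic config for this algo/task/dataset combo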
        config = get_robomimic_config(
            algo_name=algo_name,
            hdf5_type=obs_type,
            task_name=task_name,
            dataset_type=dataset_type)
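
        # patch the config to match our observation keys, and size (or
        # disable) the CropRandomizer according to crop_shape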
        with config.unlocked():
            config.observation.modalities.obs = obs_config

            if crop_shape is None:
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality['obs_randomizer_class'] = None
            else:
                ch, cw = crop_shape
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality.obs_randomizer_kwargs.crop_height = ch
                        modality.obs_randomizer_kwargs.crop_width = cw
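
        # initialize robomimic's global observation-processing state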
        ObsUtils.initialize_obs_utils_with_config(config)
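
        # instantiate the policy on CPU; .to(device) moves it later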
        model: PolicyAlgo = algo_factory(
            algo_name=config.algo_name,
            config=config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=action_dim,
            device='cpu',
        )
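
        # keeping a reference to model.nets (an nn.ModuleDict in robomimic)
        # lets this nn.Module track the policy's parameters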
        self.model = model
        self.nets = model.nets
        self.normalizer = LinearNormalizer()
        self.config = config

    def to(self, *args, **kwargs):
        # keep self.model.device in sync with wherever the module is moved
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.model.device = device
        return super().to(*args, **kwargs)
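
    # ========= inference =========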
    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        nobs_dict = self.normalizer(obs_dict)
        # robomimic's get_action consumes a single observation step,
        # so take the first step of the (B, To, ...) stack
        robomimic_obs_dict = dict_apply(nobs_dict, lambda x: x[:, 0, ...])
        naction = self.model.get_action(robomimic_obs_dict)
        action = self.normalizer['action'].unnormalize(naction)
        result = {
            # restore a length-1 time dimension: (B, Da) -> (B, 1, Da)
            'action': action[:, None, :]
        }
        return result

    def reset(self):
        # clear recurrent state held by the underlying policy
        # (e.g. the hidden state of a BC-RNN)
        self.model.reset()
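
    # ========= training =========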
    def set_normalizer(self, normalizer: LinearNormalizer):
        self.normalizer.load_state_dict(normalizer.state_dict())

    def train_on_batch(self, batch, epoch, validate=False):
        # normalize obs and actions before handing the batch to robomimic
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        robomimic_batch = {
            'obs': nobs,
            'actions': nactions
        }
        input_batch = self.model.process_batch_for_training(
            robomimic_batch)
        info = self.model.train_on_batch(
            batch=input_batch, epoch=epoch, validate=validate)
        return info

    def on_epoch_end(self, epoch):
        self.model.on_epoch_end(epoch)

    def get_optimizer(self):
        return self.model.optimizers['policy']


def test():
    import os
    from omegaconf import OmegaConf
    cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml')
    cfg = OmegaConf.load(cfg_path)
    shape_meta = cfg.shape_meta

    # construction smoke test: exercises config patching and algo_factory
    policy = RobomimicImagePolicy(shape_meta=shape_meta)
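

if __name__ == '__main__':
    # assumed convenience entry point: runs the smoke test above;
    # cfg_path is machine-specific and may need adjusting
    test()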