|
from typing import Dict |
|
import torch |
|
import torch.nn as nn |
|
from equi_diffpo.model.common.module_attr_mixin import ModuleAttrMixin |
|
from equi_diffpo.model.common.normalizer import LinearNormalizer |
|
|
|
class BaseLowdimPolicy(ModuleAttrMixin):
    """Base policy class declaring the hooks that subclasses provide.

    ``predict_action`` and ``set_normalizer`` raise ``NotImplementedError``
    here, so a concrete policy has to override both. ``reset`` does nothing
    unless a subclass overrides it.
    """

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Turn a dictionary of observation tensors into a dictionary of actions.

        Args:
            obs_dict: tensors keyed by name, laid out as
                obs: B,To,Do

        Returns:
            Tensors keyed by name, laid out as
                action: B,Ta,Da

        Raises:
            NotImplementedError: always, on this base class.

        NOTE(review): B presumably is the batch size, To/Ta the number of
        observation/action steps and Do/Da the per-step feature sizes --
        confirm against a concrete subclass.

        Example timelines for To = 3, Ta = 4, T = 6 (``o`` marks an
        observation step, ``a`` an action step)::

            |o|o|o|
            | | |a|a|a|a|
            |o|o|
            | |a|a|a|a|a|
            | | | | |a|a|
        """
        raise NotImplementedError

    def reset(self):
        """Do nothing by default.

        NOTE(review): presumably a hook that lets subclasses clear state kept
        between calls -- confirm with the callers.
        """
        return None

    def set_normalizer(self, normalizer: LinearNormalizer):
        """Receive the ``LinearNormalizer`` the policy should use.

        Args:
            normalizer: normalizer instance handed to the policy.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError
|
|
|
|