# moPPIt/flow_matching/utils/model_wrapper.py
# (Hosting-page residue retained as comments: "AlienChen's picture",
#  "Upload 72 files", commit 3527383, verified.)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC
from torch import nn, Tensor
class ModelWrapper(ABC, nn.Module):
    """Adapter that holds another model and routes the forward pass through it.

    Subclass this and override :meth:`forward` to inject custom logic around
    the wrapped model. The wrapped model is kept on ``self.model``.
    """

    def __init__(self, model: nn.Module):
        """Store ``model`` as the module this wrapper delegates to.

        Args:
            model (nn.Module): the model to wrap.
        """
        super().__init__()
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""Evaluate the wrapped model on :math:`x` at time :math:`t`.

        The default implementation is a plain pass-through: the wrapped model
        is called with ``x`` and ``t`` as keyword arguments, followed by every
        entry of ``extras`` unchanged. The wrapped model is therefore expected
        to accept parameters named ``x`` and ``t``.

        Subclasses overriding this method typically:

        - verify that ``t`` has the dimensions the wrapped model expects;
        - apply any custom forward-pass logic;
        - call the wrapped model.

        Args:
            x (Tensor): input data to the model, expected shape (batch_size, ...).
            t (Tensor): time, expected shape (batch_size).
            **extras: additional information forwarded verbatim to the model,
                e.g., a text condition.

        Returns:
            Tensor: whatever the wrapped model returns for these inputs.
        """
        wrapped = self.model
        return wrapped(x=x, t=t, **extras)