import torch
from torch.nn import Linear
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

from models_cifm.mlp_and_gnn import MLPBiasFree


class EGNNLayer(MessagePassing):
    """E(n) Equivariant GNN Layer.

    Paper: "E(n) Equivariant Graph Neural Networks", Satorras et al., 2021.

    This variant gates edge messages with an embedded inner product of the
    endpoint features and uses bias-free MLPs throughout.
    """

    def __init__(self, emb_dim, num_mlp_layers, aggr="add"):
        r"""
        Args:
            emb_dim: (int) - hidden dimension `d`
            num_mlp_layers: (int) - number of layers in each internal MLP
            aggr: (str) - aggregation function `\oplus` (add/mean/max)
        """
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim

        # Bias-free embedding of the scalar distance feature.
        self.dist_embedding = Linear(1, emb_dim, bias=False)
        # Scalar gate computed from the inner product of the endpoint features.
        self.innerprod_embedding = MLPBiasFree(in_dim=1, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
        # Message MLP over [h_i, h_j, dist_embedding], hence 3 * emb_dim inputs.
        self.mlp_msg = MLPBiasFree(in_dim=3 * emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
        # Scalar weight for the equivariant coordinate update.
        self.mlp_pos = MLPBiasFree(in_dim=emb_dim, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
        # Node feature update MLP.
        self.mlp_upd = MLPBiasFree(in_dim=emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)

    def forward(self, h, pos, edge_index):
        """
        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - pairs of edges (i, j)
        Returns:
            out: [(n, d), (n, 3)] - updated node features and node coordinates
        """
        out = self.propagate(edge_index, h=h, pos=pos)
        return out

    def message(self, h_i, h_j, pos_i, pos_j):
        # Relative position of node i w.r.t. node j; rotates with the input.
        pos_diff = pos_i - pos_j
        # Rotation-invariant distance feature, softly decayed with a fixed length scale of 30.
        dists = torch.exp(-torch.norm(pos_diff, dim=-1).unsqueeze(1) / 30)
        # Rotation-invariant similarity of the endpoint features.
        inner_prod = torch.mean(h_i * h_j, dim=-1).unsqueeze(1)

        # Gate the concatenated invariant edge features by the embedded inner product.
        msg = torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1) * self.innerprod_embedding(inner_prod)
        msg = self.mlp_msg(msg)

        # Scale the relative position by an invariant scalar predicted from the
        # message; this keeps the coordinate update E(n)-equivariant.
        pos_diff = pos_diff * self.mlp_pos(msg)

        return msg, pos_diff, inner_prod

    def aggregate(self, inputs, index):
        msgs, pos_diffs, inner_prod = inputs

        # Sum feature messages over each node's incoming edges.
        msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce="add")

        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="add")

        # Average the coordinate updates over contributing edges, ignoring
        # edges whose endpoint features have an inner product of exactly zero.
        counts = torch.ones_like(inner_prod)
        counts[inner_prod == 0] = 0
        counts = scatter(counts, index, dim=0, reduce="add")
        counts[counts == 0] = 1  # guard against division by zero
        pos_aggr = pos_aggr / counts

        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(msg_aggr)
        # Residual coordinate update preserves translation equivariance.
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
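

# Minimal usage sketch (illustrative addition, not part of the original
# module): it assumes MLPBiasFree is importable as above and uses a toy
# bidirectional ring graph so every node receives at least one message.
if __name__ == "__main__":
    torch.manual_seed(0)

    n, d = 8, 16
    layer = EGNNLayer(emb_dim=d, num_mlp_layers=2)

    h = torch.randn(n, d)
    pos = torch.randn(n, 3)
    src = torch.arange(n)
    dst = (src + 1) % n
    edge_index = torch.stack([torch.cat([src, dst]), torch.cat([dst, src])])

    h_out, pos_out = layer(h, pos, edge_index)
    print(layer, h_out.shape, pos_out.shape)  # (n, d) and (n, 3)

    # Rough E(3)-equivariance check: an orthogonal transform plus translation
    # of the input coordinates should leave the output features unchanged and
    # move the output coordinates the same way (up to floating-point error).
    Q, _ = torch.linalg.qr(torch.randn(3, 3))
    t = torch.randn(3)
    h_rot, pos_rot = layer(h, pos @ Q.T + t, edge_index)
    assert torch.allclose(h_rot, h_out, atol=1e-4)
    assert torch.allclose(pos_rot, pos_out @ Q.T + t, atol=1e-4)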