PACT-Net / models /polyatomic.py
rk-random's picture
Upload folder using huggingface_hub
9a67fbe verified
raw
history blame
4.52 kB
import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
PNAConv,
global_mean_pool,
global_max_pool,
global_add_pool,
)
from torch_geometric.utils import degree
class PolyatomicNet(nn.Module):
def __init__(
self,
node_feat_dim,
edge_feat_dim,
graph_feat_dim,
deg,
hidden_dim=128,
num_layers=5,
dropout=0.1,
):
super().__init__()
self.graph_feat_dim = graph_feat_dim
self.node_emb = nn.Linear(node_feat_dim, hidden_dim)
self.deg = deg
self.virtualnode_emb = nn.Embedding(1, hidden_dim)
self.vn_mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
# For graph-level feature projection
self.graph_proj = nn.Sequential(
nn.Linear(graph_feat_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
# PNAConv requires degree preprocessing
self.deg_emb = nn.Embedding(20, hidden_dim) # cap degree buckets
aggregators = ["mean", "min", "max", "std"]
scalers = ["identity", "amplification", "attenuation"]
self.convs = nn.ModuleList()
self.bns = nn.ModuleList()
for _ in range(num_layers):
conv = PNAConv(
in_channels=hidden_dim,
out_channels=hidden_dim,
aggregators=aggregators,
scalers=scalers,
edge_dim=edge_feat_dim,
towers=4,
pre_layers=1,
post_layers=1,
divide_input=True,
deg=deg,
)
self.convs.append(conv)
self.bns.append(nn.BatchNorm1d(hidden_dim))
self.dropout = nn.Dropout(dropout)
# Final readout
self.readout = nn.Sequential(
nn.Linear(hidden_dim * 3, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 1),
)
def forward(self, data):
x, edge_index, edge_attr, batch = (
data.x,
data.edge_index,
data.edge_attr,
data.batch,
)
deg = degree(edge_index[0], x.size(0), dtype=torch.long).clamp(max=19)
h = self.node_emb(x) + self.deg_emb(deg)
vn = self.virtualnode_emb(
torch.zeros(batch.max().item() + 1, dtype=torch.long, device=x.device)
)
for conv, bn in zip(self.convs, self.bns):
h = h + vn[batch]
h = conv(h, edge_index, edge_attr)
h = bn(h)
h = F.relu(h)
h = self.dropout(h)
vn = vn + self.vn_mlp(global_mean_pool(h, batch))
mean_pool = global_mean_pool(h, batch)
max_pool = global_max_pool(h, batch)
# add_pool = global_add_pool(h, batch)
max_feat_dim = self.graph_feat_dim
if hasattr(data, "graph_feats") and isinstance(
data, torch_geometric.data.Batch # type: ignore
):
g_proj_list = []
for g in data.to_data_list():
g_feat = g.graph_feats.to(x.device)
if g_feat.size(0) < max_feat_dim:
padded = torch.zeros(max_feat_dim, device=g_feat.device)
padded[: g_feat.size(0)] = g_feat
g_feat = padded
elif g_feat.size(0) > max_feat_dim:
g_feat = g_feat[:max_feat_dim]
g_feat = torch.nan_to_num(g_feat, nan=0.0, posinf=1e5, neginf=-1e5)
g_proj_list.append(self.graph_proj(g_feat))
g_proj = torch.stack(g_proj_list, dim=0)
else:
g_feat = data.graph_feats.to(x.device)
if g_feat.size(0) < max_feat_dim:
padded = torch.zeros(max_feat_dim, device=g_feat.device)
padded[: g_feat.size(0)] = g_feat
g_feat = padded
elif g_feat.size(0) > max_feat_dim:
g_feat = g_feat[:max_feat_dim]
g_feat = torch.nan_to_num(g_feat, nan=0.0, posinf=1e5, neginf=-1e5)
g_proj = self.graph_proj(g_feat).unsqueeze(0)
final_input = torch.cat([mean_pool, max_pool, g_proj], dim=1)
return self.readout(final_input).view(-1)