|
|
import torch |
|
|
import torch_geometric |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch_geometric.nn import ( |
|
|
PNAConv, |
|
|
global_mean_pool, |
|
|
global_max_pool, |
|
|
global_add_pool, |
|
|
) |
|
|
from torch_geometric.utils import degree |
|
|
|
|
|
|
|
|
class PolyatomicNet(nn.Module):
    """PNA message-passing regressor with a virtual node and graph-level features.

    Node features are linearly embedded and summed with a learned embedding
    of each node's (clamped) degree.  A stack of ``PNAConv`` layers, each
    followed by batch norm, ReLU and dropout, updates the node states; a
    per-graph "virtual node" vector is added to every node before each layer
    and refreshed from the mean-pooled node states after it.  The final
    prediction is an MLP over the concatenation of mean-pooled nodes,
    max-pooled nodes and a projection of the per-graph feature vector.

    Args:
        node_feat_dim: Size of each input node feature vector.
        edge_feat_dim: Size of each edge feature vector (``PNAConv.edge_dim``).
        graph_feat_dim: Size the per-graph feature vector is padded or
            truncated to before projection.
        deg: Forwarded unchanged to every ``PNAConv`` as its ``deg`` argument
            (in PyG this is the degree histogram used by the scalers).
        hidden_dim: Width of all hidden representations.  It is passed to
            ``PNAConv`` together with ``towers=4``.
        num_layers: Number of ``PNAConv`` layers.
        dropout: Dropout probability used after each layer and in the readout.
    """

    # Number of rows in the degree-embedding table; degrees are clamped to
    # ``_NUM_DEGREE_BUCKETS - 1`` in ``forward``.
    _NUM_DEGREE_BUCKETS = 20

    def __init__(
        self,
        node_feat_dim,
        edge_feat_dim,
        graph_feat_dim,
        deg,
        hidden_dim=128,
        num_layers=5,
        dropout=0.1,
    ):
        super().__init__()
        # NOTE: submodule names and creation order are kept stable so that
        # existing checkpoints and seeded initialisation remain compatible.
        self.graph_feat_dim = graph_feat_dim
        self.node_emb = nn.Linear(node_feat_dim, hidden_dim)
        self.deg = deg

        # Single learned vector that seeds every graph's virtual node.
        self.virtualnode_emb = nn.Embedding(1, hidden_dim)
        self.vn_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Projects the fixed-width graph-level feature vector to hidden_dim.
        self.graph_proj = nn.Sequential(
            nn.Linear(graph_feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.deg_emb = nn.Embedding(self._NUM_DEGREE_BUCKETS, hidden_dim)

        aggregators = ["mean", "min", "max", "std"]
        scalers = ["identity", "amplification", "attenuation"]

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for _ in range(num_layers):
            conv = PNAConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                aggregators=aggregators,
                scalers=scalers,
                edge_dim=edge_feat_dim,
                towers=4,
                pre_layers=1,
                post_layers=1,
                divide_input=True,
                deg=deg,
            )
            self.convs.append(conv)
            self.bns.append(nn.BatchNorm1d(hidden_dim))

        self.dropout = nn.Dropout(dropout)

        # Input is [mean_pool | max_pool | graph_proj], hence hidden_dim * 3.
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    @staticmethod
    def _fit_graph_feats(g_feat, target_dim, dtype=torch.float32, device=None):
        """Return ``g_feat`` as a clean 1-D vector of exactly ``target_dim`` entries.

        The input is flattened, cast to ``dtype`` (and moved to ``device`` if
        given), zero-padded on the right or truncated to ``target_dim``, and
        finally NaN is replaced by 0 and +/-inf by +/-1e5.
        """
        vec = torch.as_tensor(g_feat, dtype=dtype, device=device).reshape(-1)
        missing = target_dim - vec.size(0)
        if missing > 0:
            vec = torch.cat([vec, vec.new_zeros(missing)])
        elif missing < 0:
            vec = vec[:target_dim]
        return torch.nan_to_num(vec, nan=0.0, posinf=1e5, neginf=-1e5)

    def _gather_graph_feats(self, data, num_graphs, like):
        """Build the ``[B, graph_feat_dim]`` graph-feature matrix for ``data``.

        ``like`` supplies the target dtype and device.  If ``data`` carries no
        ``graph_feats`` attribute, an all-zero matrix with ``num_graphs`` rows
        is returned, i.e. absent features are treated like an empty vector
        that is fully zero-padded.
        """
        feats = getattr(data, "graph_feats", None)
        if feats is None:
            return like.new_zeros(num_graphs, self.graph_feat_dim)

        if isinstance(data, torch_geometric.data.Batch):
            # Collation concatenates per-graph vectors; splitting the batch
            # back into graphs recovers each graph's own (possibly ragged)
            # feature vector.
            per_graph = [g.graph_feats for g in data.to_data_list()]
        else:
            per_graph = [feats]

        rows = [
            self._fit_graph_feats(
                f, self.graph_feat_dim, dtype=like.dtype, device=like.device
            )
            for f in per_graph
        ]
        return torch.stack(rows, dim=0)

    def forward(self, data):
        """Predict one scalar per graph.

        Args:
            data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and
                ``batch``; ``graph_feats`` is optional.  ``batch`` may be
                ``None``, in which case all nodes are treated as one graph.

        Returns:
            1-D tensor with one prediction per graph in ``data``.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        batch = data.batch
        if batch is None:
            # Un-batched single graph: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        num_graphs = int(batch.max().item()) + 1

        # Degree is counted over edge_index[0] and clamped so it always
        # indexes a valid row of the degree-embedding table.
        max_deg_idx = self.deg_emb.num_embeddings - 1
        deg = degree(edge_index[0], x.size(0), dtype=torch.long).clamp(
            max=max_deg_idx
        )
        h = self.node_emb(x) + self.deg_emb(deg)

        # One virtual-node state per graph, all starting from the same vector.
        vn = self.virtualnode_emb(
            torch.zeros(num_graphs, dtype=torch.long, device=x.device)
        )

        for conv, bn in zip(self.convs, self.bns):
            h = h + vn[batch]
            h = conv(h, edge_index, edge_attr)
            h = bn(h)
            h = F.relu(h)
            h = self.dropout(h)
            vn = vn + self.vn_mlp(global_mean_pool(h, batch))

        mean_pool = global_mean_pool(h, batch)
        max_pool = global_max_pool(h, batch)

        # Project all graphs' features in a single batched call.
        g_proj = self.graph_proj(self._gather_graph_feats(data, num_graphs, h))

        final_input = torch.cat([mean_pool, max_pool, g_proj], dim=1)
        return self.readout(final_input).view(-1)
|
|
|