"""Two-layer graph neural networks producing one output vector per graph.

Every model here applies two message-passing layers to the node features,
mean-pools the node embeddings of each graph, and maps the pooled vector
through a linear head.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    GCNConv,
    GINConv,
    GATConv,
    SAGEConv,
    global_mean_pool,
)


class _TwoLayerGraphModel(torch.nn.Module):
    """Shared forward pass for the two-layer, mean-pooled models below.

    Subclasses must assign ``conv1`` and ``conv2`` (called as
    ``conv(x, edge_index)``) and ``lin`` in ``__init__``. They may override
    :meth:`_activate` to change the nonlinearity applied after each
    convolution (ReLU by default).
    """

    def _activate(self, x):
        """Apply the nonlinearity used after each convolution layer."""
        return F.relu(x)

    def forward(self, data):
        """Compute one output row per graph in ``data``.

        Args:
            data: Object exposing ``x`` (node features), ``edge_index``
                (graph connectivity) and ``batch`` (node-to-graph
                assignment), e.g. a PyG ``Data``/``Batch``.

        Returns:
            Output of ``self.lin`` applied to the per-graph mean of the
            node embeddings.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self._activate(self.conv1(x, edge_index))
        x = self._activate(self.conv2(x, edge_index))
        # Average node embeddings within each graph -> one vector per graph.
        x = global_mean_pool(x, batch)
        return self.lin(x)


class GCN(_TwoLayerGraphModel):
    """Two ``GCNConv`` layers with ReLU, mean pooling and a linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of both convolution outputs.
        out_channels: Size of the per-graph output.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, out_channels)


def _gin_mlp(in_dim, hidden_dim):
    """Build the Linear-ReLU-Linear MLP wrapped by each ``GINConv``."""
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
    )


class GIN(_TwoLayerGraphModel):
    """Two ``GINConv`` layers (MLP updates) with ReLU, mean pooling and a linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of both GIN MLPs and their outputs.
        out_channels: Size of the per-graph output.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Keep this construction order (both MLPs first, then both convs):
        # reordering changes the sequence of random draws and therefore the
        # initial weights obtained under a fixed seed.
        nn1 = _gin_mlp(in_channels, hidden_channels)
        nn2 = _gin_mlp(hidden_channels, hidden_channels)
        self.conv1 = GINConv(nn1)
        self.conv2 = GINConv(nn2)
        self.lin = nn.Linear(hidden_channels, out_channels)


class GAT(_TwoLayerGraphModel):
    """Two ``GATConv`` layers with ELU, mean pooling and a linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Per-head output width of the attention layers.
        out_channels: Size of the per-graph output.
        heads: Number of attention heads in the first layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        # conv1's head outputs are concatenated, so conv2 receives
        # hidden_channels * heads features; its single head brings the width
        # back to hidden_channels for the linear head.
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=1)
        self.lin = nn.Linear(hidden_channels, out_channels)

    def _activate(self, x):
        """Use ELU after each attention layer."""
        return F.elu(x)


class GraphSAGE(_TwoLayerGraphModel):
    """Two ``SAGEConv`` layers with ReLU, mean pooling and a linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of both convolution outputs.
        out_channels: Size of the per-graph output.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, out_channels)