import torch
import torch.nn as nn
import torch.nn.functional as F
import esm

from transformers import PreTrainedModel

from .configuration_TransHLA_II import TransHLA_II_Config

# Let cuDNN auto-tune convolution kernels for the fixed input sizes.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
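
# Architecture overview (as implemented below):
#   1. A frozen ESM-2 backbone (esm2_t33_650M_UR50D) produces per-residue
#      embeddings (layer 33) and a predicted residue-residue contact map.
#   2. A Transformer encoder refines the sequence embeddings; the first
#      (BOS/[CLS]) position serves as the sequence-level summary.
#   3. Two DPCNN-style convolutional towers condense the sequence embeddings
#      and the contact map into fixed-size vectors.
#   4. The three vectors are concatenated and passed through an MLP head that
#      returns class probabilities and a 64-dimensional feature vector.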


class TransHLA_II(nn.Module):
    def __init__(self, config):
        super().__init__()

        max_len = config.max_len
        n_layers = config.n_layers
        n_head = config.n_head
        d_model = config.d_model
        d_ff = config.d_ff
        cnn_num_channel = config.cnn_num_channel
        region_embedding_size = config.region_embedding_size
        cnn_kernel_size = config.cnn_kernel_size
        cnn_padding_size = config.cnn_padding_size
        cnn_stride = config.cnn_stride
        pooling_size = config.pooling_size

        # Frozen ESM-2 backbone; the alphabet is kept so callers can build a
        # batch converter for tokenization.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        # Region embeddings: 1-D convs over the ESM embeddings (d_model input
        # channels) and over the contact map (max_len input channels).
        self.region_cnn1 = nn.Conv1d(
            d_model, cnn_num_channel, region_embedding_size)
        self.region_cnn2 = nn.Conv1d(
            max_len, cnn_num_channel, region_embedding_size)
        self.padding1 = nn.ConstantPad1d((1, 1), 0)
        self.padding2 = nn.ConstantPad1d((0, 1), 0)
        self.relu = nn.ReLU()
        self.cnn1 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size, stride=cnn_stride)
        self.cnn2 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size, stride=cnn_stride)
        self.maxpooling = nn.MaxPool1d(kernel_size=pooling_size)
        # The ESM embeddings are batch-first (B, L, D); batch_first=True makes
        # self-attention run over sequence positions rather than across the
        # batch (the layer would otherwise expect (L, B, D) inputs).
        self.transformer_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2,
            batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_layers, num_layers=n_layers)
        self.bn1 = nn.BatchNorm1d(d_model)
        self.bn2 = nn.BatchNorm1d(cnn_num_channel)
        self.bn3 = nn.BatchNorm1d(cnn_num_channel)
        # MLP head over the concatenated [CLS] + sequence-CNN + structure-CNN
        # features.
        self.fc_task = nn.Sequential(
            nn.Linear(d_model + 2 * cnn_num_channel, d_model // 4),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(d_model // 4, 64),
        )
        self.classifier = nn.Linear(64, 2)
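
    # For the ESM-2 650M backbone d_model is 1280, so the fused feature fed to
    # fc_task has 1280 + 2 * cnn_num_channel dimensions (e.g., 1792 if
    # cnn_num_channel were 256; the actual value comes from the config).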

    def cnn_block1(self, x):
        # Pre-activation convolution for the sequence tower.
        return self.cnn1(self.relu(x))

    def cnn_block2(self, x):
        # DPCNN-style downsampling block: pad to even length, halve by
        # max-pooling, apply two pre-activation convs, add a residual.
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn1(x)
        x = self.relu(x)
        x = self.cnn1(x)
        x = px + x
        return x

    def structure_block1(self, x):
        # Pre-activation convolution for the contact-map tower.
        return self.cnn2(self.relu(x))

    def structure_block2(self, x):
        # Same downsampling block as cnn_block2, using the structure convs.
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn2(x)
        x = self.relu(x)
        x = self.cnn2(x)
        x = px + x
        return x
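
    # Note on the downsampling loops in forward(): assuming pooling_size=2,
    # each pass through cnn_block2/structure_block2 roughly halves the length
    # dimension, so the `while conv.size(-1) >= 2` loops terminate with a
    # length-1 feature map that squeeze() turns into a per-sequence vector.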

    def forward(self, x_in):
        # Run the frozen ESM-2 backbone without tracking gradients; take the
        # layer-33 residue embeddings and the predicted contact map.
        with torch.no_grad():
            results = self.esm(x_in, repr_layers=[33], return_contacts=True)
            emb = results["representations"][33]
            structure_emb = results["contacts"]

        # Sequence tower: Transformer encoder, then the first (BOS/[CLS])
        # position as the sequence-level representation.
        output = self.transformer_encoder(emb)
        representation = output[:, 0, :]
        representation = self.bn1(representation)

        # Sequence-CNN tower: region embedding, a residual pre-activation conv
        # stack, then repeated downsampling to a single position.
        cnn_emb = self.region_cnn1(emb.transpose(1, 2))
        cnn_emb = self.padding1(cnn_emb)
        conv = cnn_emb + self.cnn_block1(self.cnn_block1(cnn_emb))
        while conv.size(-1) >= 2:
            conv = self.cnn_block2(conv)
        cnn_out = torch.squeeze(conv, dim=-1)
        cnn_out = self.bn2(cnn_out)

        # Structure-CNN tower: the same scheme applied to the contact map.
        structure_emb = self.region_cnn2(structure_emb.transpose(1, 2))
        structure_emb = self.padding1(structure_emb)
        structure_conv = structure_emb + \
            self.structure_block1(self.structure_block1(structure_emb))
        while structure_conv.size(-1) >= 2:
            structure_conv = self.structure_block2(structure_conv)
        structure_cnn_out = torch.squeeze(structure_conv, dim=-1)
        structure_cnn_out = self.bn3(structure_cnn_out)

        # Fuse the three representations and classify.
        representation = torch.cat(
            (representation, cnn_out, structure_cnn_out), dim=1)
        reduction_feature = self.fc_task(representation)
        reduction_feature = reduction_feature.view(
            reduction_feature.size(0), -1)
        logits_clsf = self.classifier(reduction_feature)
        # Softmax is applied here, so the first return value contains class
        # probabilities rather than raw logits.
        logits_clsf = F.softmax(logits_clsf, dim=1)
        return logits_clsf, reduction_feature


class TransHLA_II_Model(PreTrainedModel):
    """Hugging Face wrapper so TransHLA_II can be saved and loaded through the
    PreTrainedModel API (from_pretrained / save_pretrained)."""

    config_class = TransHLA_II_Config

    def __init__(self, config):
        super().__init__(config)
        self.model = TransHLA_II(config)

    def forward(self, tensor):
        return self.model(tensor)
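

# Minimal smoke-test sketch (illustrative only, not part of the released API).
# It assumes TransHLA_II_Config is constructible with its defaults and that
# the tokenized peptide length lines up with config.max_len (pad or trim your
# batch accordingly); the example peptide is hypothetical.
if __name__ == "__main__":
    config = TransHLA_II_Config()
    model = TransHLA_II_Model(config)
    model.eval()

    # Tokenize one peptide with the ESM-2 alphabet bundled with the backbone.
    batch_converter = model.model.alphabet.get_batch_converter()
    _, _, tokens = batch_converter([("example", "GELIGILNAAKVPAD")])

    with torch.no_grad():
        probs, features = model(tokens)
    # probs: (1, 2) class probabilities; features: (1, 64) reduced features.
    print(probs.shape, features.shape)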