File size: 5,662 Bytes
29c8230 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import esm
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, StratifiedShuffleSplit, StratifiedKFold
import collections
from torch.utils.data import DataLoader, TensorDataset
import os
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score, precision_score
import random
from sklearn.metrics import auc
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
#import esm
from tqdm import tqdm
import time
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, matthews_corrcoef, recall_score, f1_score, precision_score
# Enable cuDNN and let it auto-tune convolution algorithms for the observed
# input sizes; benchmark=True speeds up fixed-shape batches at the cost of
# deterministic kernel selection.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
from transformers import PretrainedConfig
from typing import List
from .configuration_TransHLA_II import TransHLA_II_Config
from transformers import PreTrainedModel
class TransHLA_II(nn.Module):
    """Peptide--HLA class II binding classifier over frozen ESM-2 features.

    Three feature streams are extracted from a frozen ESM-2 (t33, 650M)
    encoder and fused:

      1. A TransformerEncoder over the per-residue ESM embeddings; the
         position-0 token vector is taken as the sequence representation.
      2. A residual 1-D CNN pyramid over the same residue embeddings,
         repeatedly pooled until the length dimension collapses to 1.
      3. The same style of CNN pyramid over the ESM predicted contact map.

    The three pooled vectors are concatenated, reduced by an MLP to 64-d,
    and classified into 2 classes (softmax probabilities are returned, not
    raw logits, despite the variable name ``logits_clsf``).
    """

    def __init__(self, config):
        """Build all sub-modules from a TransHLA_II_Config-like object.

        Args:
            config: configuration object providing ``max_len``, ``n_layers``,
                ``n_head``, ``d_model``, ``d_ff`` and the CNN hyper-parameters
                (``cnn_num_channel``, ``region_embedding_size``,
                ``cnn_kernel_size``, ``cnn_padding_size``, ``cnn_stride``,
                ``pooling_size``).

        Note: downloads/loads the pretrained ESM-2 650M weights, so
        construction is slow and requires network/cache access.
        """
        super(TransHLA_II, self).__init__()
        max_len = config.max_len
        n_layers = config.n_layers
        n_head = config.n_head
        d_model = config.d_model
        d_ff = config.d_ff
        # NOTE(review): cnn_padding_index is read from config but never used
        # anywhere in this class.
        cnn_padding_index = config.cnn_padding_index
        cnn_num_channel = config.cnn_num_channel
        region_embedding_size = config.region_embedding_size
        cnn_kernel_size = config.cnn_kernel_size
        cnn_padding_size = config.cnn_padding_size
        cnn_stride = config.cnn_stride
        pooling_size = config.pooling_size
        # Frozen feature extractor (run under no_grad in forward()).
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        # Region embedding for the sequence stream: consumes d_model channels
        # (ESM embedding dim) after transpose(1, 2) in forward().
        self.region_cnn1 = nn.Conv1d(
            d_model, cnn_num_channel, region_embedding_size)
        # Region embedding for the structure stream: the contact map is
        # (batch, L, L), so its channel dim after transpose is max_len —
        # presumably inputs are padded to exactly max_len; TODO confirm.
        self.region_cnn2 = nn.Conv1d(
            max_len, cnn_num_channel, region_embedding_size)
        # padding1 restores length lost to the region-embedding convolution;
        # padding2 makes odd lengths divisible before max-pooling.
        self.padding1 = nn.ConstantPad1d((1, 1), 0)
        self.padding2 = nn.ConstantPad1d((0, 1), 0)
        self.relu = nn.ReLU()
        self.cnn1 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size, stride=cnn_stride)
        self.cnn2 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size, stride=cnn_stride)
        self.maxpooling = nn.MaxPool1d(kernel_size=pooling_size)
        # NOTE(review): TransformerEncoderLayer defaults to batch_first=False
        # (expects (seq, batch, dim)), but forward() feeds the ESM embedding,
        # which is (batch, seq, dim), and then reads output[:, 0, :] as a
        # per-batch CLS vector. This axis convention is what the released
        # checkpoints were trained with — do not "fix" without retraining.
        self.transformer_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_layers, num_layers=n_layers)
        self.bn1 = nn.BatchNorm1d(d_model)
        self.bn2 = nn.BatchNorm1d(cnn_num_channel)
        self.bn3 = nn.BatchNorm1d(cnn_num_channel)
        # Fusion head: transformer CLS (d_model) + two CNN vectors
        # (cnn_num_channel each) -> 64-d reduced feature.
        self.fc_task = nn.Sequential(
            nn.Linear(d_model+2*cnn_num_channel, d_model // 4),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(d_model // 4, 64),
        )
        self.classifier = nn.Linear(64, 2)

    def cnn_block1(self, x):
        """Pre-activation conv for the sequence stream: cnn1(relu(x))."""
        return self.cnn1(self.relu(x))

    def cnn_block2(self, x):
        """One downsampling residual step for the sequence stream.

        Pads to even length, max-pools (halving the length), then applies
        two pre-activation convolutions and adds the pooled input back.
        Note both convolutions are the SAME module ``self.cnn1`` — weights
        are shared across the whole pyramid (mirrors structure_block2).
        """
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn1(x)
        x = self.relu(x)
        x = self.cnn1(x)
        x = px + x
        return x

    def structure_block1(self, x):
        """Pre-activation conv for the contact-map stream: cnn2(relu(x))."""
        return self.cnn2(self.relu(x))

    def structure_block2(self, x):
        """Downsampling residual step for the contact-map stream.

        Identical shape logic to cnn_block2, but uses the shared
        ``self.cnn2`` weights instead of ``self.cnn1``.
        """
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn2(x)
        x = self.relu(x)
        x = self.cnn2(x)
        x = px + x
        return x

    def forward(self, x_in):
        """Classify a batch of ESM token-id sequences.

        Args:
            x_in: token ids as produced by the ESM alphabet/batch converter
                (assumed shape (batch, seq_len) — TODO confirm against caller).

        Returns:
            Tuple ``(probs, features)``: softmax class probabilities of shape
            (batch, 2) and the 64-d reduced feature vectors (batch, 64).
        """
        # ESM-2 is used as a frozen feature extractor: final-layer (33)
        # residue embeddings plus the predicted residue-residue contact map.
        with torch.no_grad():
            results = self.esm(x_in, repr_layers=[33], return_contacts=True)
        emb = results["representations"][33]
        structure_emb = results["contacts"]
        # Stream 1: transformer over ESM embeddings, take position-0 vector
        # (see batch_first review note in __init__).
        output = self.transformer_encoder(emb)
        representation = output[:, 0, :]
        representation = self.bn1(representation)
        # Stream 2: CNN pyramid over embeddings; transpose to channels-first.
        cnn_emb = self.region_cnn1(emb.transpose(1, 2))
        cnn_emb = self.padding1(cnn_emb)
        conv = cnn_emb + self.cnn_block1(self.cnn_block1(cnn_emb))
        # Halve the length until it collapses to 1, then drop that dim.
        while conv.size(-1) >= 2:
            conv = self.cnn_block2(conv)
        cnn_out = torch.squeeze(conv, dim=-1)
        cnn_out = self.bn2(cnn_out)
        # Stream 3: same pyramid over the contact map.
        structure_emb = self.region_cnn2(structure_emb.transpose(1, 2))
        structure_emb = self.padding1(structure_emb)
        structure_conv = structure_emb + \
            self.structure_block1(self.structure_block1(structure_emb))
        while structure_conv.size(-1) >= 2:
            structure_conv = self.structure_block2(structure_conv)
        structure_cnn_out = torch.squeeze(structure_conv, dim=-1)
        structure_cnn_out = self.bn3(structure_cnn_out)
        # Fuse the three streams and reduce to the 64-d task feature.
        representation = torch.concat(
            (representation,cnn_out,structure_cnn_out), dim=1)
        reduction_feature = self.fc_task(representation)
        reduction_feature = reduction_feature.view(
            reduction_feature.size(0), -1)
        logits_clsf = self.classifier(reduction_feature)
        # Softmax applied here, so the first return value is probabilities.
        logits_clsf = torch.nn.functional.softmax(logits_clsf, dim=1)
        return logits_clsf, reduction_feature
class TransHLA_II_Model(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`TransHLA_II`.

    Registers ``TransHLA_II_Config`` as the config class so the inner
    network can be saved/loaded through the standard ``transformers``
    ``from_pretrained`` / ``save_pretrained`` machinery.
    """

    config_class = TransHLA_II_Config

    def __init__(self, config):
        """Build the wrapped TransHLA_II network from ``config``."""
        super().__init__(config)
        self.model = TransHLA_II(config)

    def forward(self, tensor):
        """Delegate straight to the wrapped model.

        Args:
            tensor: batch of ESM token ids (whatever ``TransHLA_II.forward``
                accepts — presumably (batch, seq_len) ids; confirm upstream).

        Returns:
            Tuple of (softmax class probabilities, 64-d reduced features),
            exactly as produced by the wrapped model.
        """
        return self.model(tensor)