|
from transformers import BertModel,BertConfig
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from huggingface_hub import PyTorchModelHubMixin
|
|
|
|
# NOTE(review): not referenced anywhere in this module; possibly imported by
# other code — confirm before removing or changing.
time_gap = 1e4
|
|
|
|
class CSIBERT(nn.Module):
    """BERT encoder over per-timestep CSI vectors with a learned time embedding.

    Each input timestep (a vector of ``input_dim`` values) is reweighted
    per channel by :meth:`attention`, projected to the BERT hidden size by
    ``csi_emb``, summed with an embedding of its normalized timestamp
    (:meth:`time_embedding`), and fed to a ``BertModel`` via ``inputs_embeds``.

    Args:
        bertconfig: ``BertConfig`` used to build the encoder. ``hidden_size``
            sets the output width; ``max_position_embeddings`` is the number
            of timesteps the per-channel gate ``arl`` is built for, so inputs
            must have exactly that many timesteps.
        input_dim: Number of values per timestep in the input.
    """

    def __init__(self, bertconfig, input_dim):
        super().__init__()
        self.bertconfig = bertconfig
        self.bert = BertModel(bertconfig)
        self.hidden_dim = bertconfig.hidden_size
        self.input_dim = input_dim
        # ``arl`` is a Linear over the time axis, so this is also the required
        # input sequence length.
        self.len = bertconfig.max_position_embeddings

        # NOTE(review): Norm1-3 and fusion_emb are not used by forward(); they
        # are kept so existing checkpoints (state_dict keys) keep loading.
        self.Norm1 = nn.LayerNorm(self.input_dim)
        self.Norm2 = nn.LayerNorm(self.hidden_dim)
        self.Norm3 = nn.LayerNorm(self.hidden_dim)

        # Per-timestep projection of the CSI vector to the hidden size.
        self.csi_emb = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )

        # Refines the sinusoidal time encoding (input_dim wide) to hidden size.
        self.time_emb = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )

        self.fusion_emb = nn.Sequential(
            nn.Linear(self.hidden_dim * 2, self.hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )

        # Maps each channel's full time series (self.len values) to one scalar
        # gate; see attention().
        self.arl = nn.Sequential(
            nn.Linear(self.len, self.len // 2),
            nn.ReLU(),
            nn.Linear(self.len // 2, self.len // 4),
            nn.ReLU(),
            nn.Linear(self.len // 4, 1),
        )

    def forward(self, x, timestamp, attention_mask=None):
        """Encode a batch of CSI sequences.

        Args:
            x: Tensor of shape ``(batch, self.len, input_dim)``; cast to float32.
            timestamp: ``(batch, seq_len)`` sample times, or ``None`` to treat
                the samples as evenly spaced (0, 1, ..., seq_len - 1).
            attention_mask: Optional mask forwarded unchanged to ``BertModel``.

        Returns:
            The encoder's ``last_hidden_state``,
            shape ``(batch, seq_len, hidden_dim)``.
        """
        x = x.to(torch.float32)
        if timestamp is None:
            # Previously this crashed inside time_embedding (None has no
            # .device); fall back to uniformly spaced sample indices.
            batch_size, seq_len = x.shape[0], x.shape[1]
            timestamp = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)

        x = self.attention(x)
        x = self.csi_emb(x)
        x = x + self.time_embedding(timestamp)

        y = self.bert(inputs_embeds=x, attention_mask=attention_mask, output_hidden_states=False)
        return y.last_hidden_state

    def time_embedding(self, timestamp, t=1):
        """Sinusoidal embedding of per-sample timestamps, refined by ``time_emb``.

        Timestamps are shifted so each sequence starts at 0 and scaled so its
        last sample maps to ``self.len``, then raised to the power ``t``. They
        are encoded with sine on even channels and cosine on odd channels over
        ``input_dim`` channels, and passed through ``time_emb``.

        Args:
            timestamp: ``(batch, seq_len)`` tensor of sample times (integer or
                floating dtype).
            t: Exponent applied to the normalized timestamps.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_dim)``.
        """
        device = timestamp.device
        start = timestamp[:, 0:1]
        span = timestamp[:, -1:] - start
        # A single-sample or constant-time sequence has zero span; dividing by
        # it would give NaN/inf, so such sequences normalize to all zeros.
        span = torch.where(span == 0, torch.ones_like(span), span)
        # Subtract in the input dtype before casting, so large absolute times
        # are not rounded first. The cast to float32 matches the float32
        # layers of time_emb, which a float64 input would otherwise break.
        timestamp = ((timestamp - start) / span * self.len).to(torch.float32)
        timestamp = timestamp ** t

        d_model = self.input_dim
        channel = torch.arange(d_model, device=device)
        # Channels 2i and 2i+1 share the divisor 10000^(2i / d_model).
        divisor = 10000 ** (channel // 2 * 2 / d_model)
        angle = timestamp.unsqueeze(-1) / divisor  # (batch, seq_len, d_model)
        is_even = channel % 2 == 0
        emb = torch.where(is_even, torch.sin(angle), torch.cos(angle))
        return self.time_emb(emb)

    def attention(self, x):
        """Scale every channel of ``x`` by a gate computed from its time series.

        ``arl`` maps the values of each channel along the second-to-last axis
        (which must have length ``self.len``) to a single unbounded scalar.
        That scalar multiplies the channel at every timestep. The output has
        the same shape as ``x``.
        """
        per_channel = torch.transpose(x, -1, -2)  # (..., input_dim, len)
        gate = self.arl(per_channel)  # (..., input_dim, 1)
        return torch.transpose(per_channel * gate, -1, -2)
|
|
|
|
class Token_Classifier(nn.Module):
    """Per-timestep classification head on top of a CSI encoder.

    Args:
        bert: Encoder module that exposes ``hidden_dim`` and is called as
            ``bert(x, timestamp, attention_mask=...)``.
        class_num: Number of output classes per timestep.
    """

    def __init__(self, bert, class_num=52):
        super().__init__()
        self.bert = bert
        width = bert.hidden_dim
        half_width = width // 2
        self.classifier = nn.Sequential(
            nn.Linear(width, half_width),
            nn.ReLU(),
            nn.Linear(half_width, class_num),
        )

    def forward(self, x, timestamp, attention_mask=None):
        """Return unnormalized class scores for every position of the encoder output."""
        features = self.bert(x, timestamp, attention_mask=attention_mask)
        return self.classifier(features)
|
|
|
|
class SelfAttention(nn.Module):
    """Compute ``r`` rows of attention-pooling weights over a sequence.

    Positions are scored with ``ws2(tanh(ws1(h)))`` (both layers bias-free).
    The scores are softmax-normalized over the sequence axis.

    Args:
        input_dim: Feature size of each position.
        da: Width of the hidden scoring layer.
        r: Number of weight rows (attention heads) produced.
    """

    def __init__(self, input_dim, da, r):
        super().__init__()
        self.ws1 = nn.Linear(input_dim, da, bias=False)
        self.ws2 = nn.Linear(da, r, bias=False)

    def forward(self, h):
        """Map ``h`` ``(batch, seq_len, input_dim)`` to weights ``(batch, r, seq_len)``.

        Each of the ``r`` rows sums to 1 over ``seq_len``.
        """
        scores = self.ws2(self.ws1(h).tanh())  # (batch, seq_len, r)
        weights = scores.softmax(dim=1)  # normalize across positions
        return weights.transpose(1, 2)
|
|
|
|
|
|
class Sequence_Classifier(nn.Module):
    """Sequence-level classifier: encoder, attention pooling, then an MLP.

    Args:
        csibert: Encoder called as ``csibert(x, timestamp, attention_mask=...)``;
            its output's last dimension must equal ``hs``.
        class_num: Number of output classes.
        hs: Encoder output width.
        da: Hidden width of the attention scoring layer.
        r: Number of attention rows; the pooled vector has ``hs * r`` values.
    """

    def __init__(self, csibert, class_num, hs=128, da=128, r=4):
        super().__init__()
        self.bert = csibert
        self.attention = SelfAttention(hs, da, r)
        pooled_width = hs * r
        self.classifier = nn.Sequential(
            nn.Linear(pooled_width, pooled_width // 2),
            nn.ReLU(),
            nn.Linear(pooled_width // 2, class_num),
        )

    def forward(self, x, timestamp, attention_mask=None):
        """Return unnormalized class scores of shape ``(batch, class_num)``.

        NOTE(review): ``attention_mask`` only reaches the encoder. The pooling
        weights are computed from the encoder output alone and so cover every
        position, including any padding. Confirm this is intended.
        """
        hidden = self.bert(x, timestamp, attention_mask=attention_mask)
        weights = self.attention(hidden)  # (batch, r, seq_len)
        pooled = torch.bmm(weights, hidden)  # (batch, r, hs)
        batch_size = pooled.size(0)
        return self.classifier(pooled.view(batch_size, -1))
|
|
|
|
class CSI_BERT2(nn.Module, PyTorchModelHubMixin):
    """Build a ``CSIBERT`` from plain hyperparameters.

    Also mixes in ``PyTorchModelHubMixin`` so the model can be saved and
    loaded through the Hugging Face Hub.

    Args:
        max_len: ``max_position_embeddings`` of the BERT config; ``CSIBERT``
            also uses it as the required input sequence length.
        hs: BERT hidden size.
        layers: Number of transformer layers.
        heads: Number of attention heads.
        intermediate_size: Feed-forward width inside each transformer layer.
        carrier_dim: Values per timestep, passed as ``CSIBERT``'s ``input_dim``.
    """

    def __init__(self, max_len=100, hs=128, layers=6, heads=8, intermediate_size=512, carrier_dim=52):
        super().__init__()
        self.config = BertConfig(
            max_position_embeddings=max_len,
            hidden_size=hs,
            num_hidden_layers=layers,
            num_attention_heads=heads,
            intermediate_size=intermediate_size,
        )
        self.model = CSIBERT(self.config, carrier_dim)

    def forward(self, x, timestamp=None, attn_mask=None):
        """Run the wrapped ``CSIBERT``.

        ``timestamp`` is forwarded as-is, and ``attn_mask`` is passed as its
        ``attention_mask``.
        """
        return self.model(x, timestamp, attention_mask=attn_mask)