from transformers import BertModel, BertConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

# Not referenced by any code in this module; kept so code that may import it
# keeps working.
time_gap = 10000.0


class CSIBERT(nn.Module):
    """BERT encoder over sequences of CSI feature frames with timestamp embeddings.

    Each input frame sequence is re-weighted per feature channel (``attention``),
    projected to the BERT hidden size (``csi_emb``), summed with a sinusoidal
    embedding of the frame timestamps (``time_embedding``) and passed to
    ``BertModel`` through ``inputs_embeds``.

    Args:
        bertconfig: ``transformers.BertConfig``. ``max_position_embeddings``
            fixes the sequence length the model accepts (``self.len``) and
            ``hidden_size`` the output width (``self.hidden_dim``).
        input_dim: Number of features per frame (last dimension of ``x``).
    """

    def __init__(self, bertconfig, input_dim):
        super().__init__()
        self.bertconfig = bertconfig
        self.bert = BertModel(bertconfig)
        self.hidden_dim = bertconfig.hidden_size
        self.input_dim = input_dim
        self.len = bertconfig.max_position_embeddings
        # NOTE(review): Norm1/Norm2/Norm3 and fusion_emb are not used by any
        # forward path in this file. They are kept so the parameter set (and
        # therefore saved state_dicts / hub checkpoints) is unchanged.
        self.Norm1 = nn.LayerNorm(self.input_dim)
        self.Norm2 = nn.LayerNorm(self.hidden_dim)
        self.Norm3 = nn.LayerNorm(self.hidden_dim)
        # Per-frame projection input_dim -> hidden_dim.
        self.csi_emb = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )
        # Projects the input_dim-wide sinusoidal time code to hidden_dim.
        self.time_emb = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )
        self.fusion_emb = nn.Sequential(
            nn.Linear(self.hidden_dim * 2, self.hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )
        # Maps one channel's length-``self.len`` series to a single scalar
        # weight; used by ``attention``.
        self.arl = nn.Sequential(
            nn.Linear(self.len, self.len // 2),
            nn.ReLU(),
            nn.Linear(self.len // 2, self.len // 4),
            nn.ReLU(),
            nn.Linear(self.len // 4, 1),
        )

    def forward(self, x, timestamp=None, attention_mask=None):
        """Encode ``x`` and return BERT's last hidden state.

        Args:
            x: Tensor of shape ``(batch, self.len, input_dim)``; cast to float32.
            timestamp: ``(batch, self.len)`` time of each frame, any real or
                integer dtype. When ``None``, frames are treated as uniformly
                spaced.
            attention_mask: Optional mask forwarded unchanged to ``BertModel``.

        Returns:
            Tensor of shape ``(batch, self.len, hidden_dim)``.
        """
        x = x.to(torch.float32)
        if timestamp is None:
            # Without this, ``time_embedding`` would fail on ``None.device``
            # (CSI_BERT2.forward passes None by default).
            batch_size, length = x.shape[0], x.shape[1]
            timestamp = (
                torch.arange(length, device=x.device, dtype=torch.float32)
                .unsqueeze(0)
                .expand(batch_size, -1)
            )
        x = self.attention(x)
        x = self.csi_emb(x)
        x = x + self.time_embedding(timestamp)
        out = self.bert(
            inputs_embeds=x,
            attention_mask=attention_mask,
            output_hidden_states=False,
        )
        return out.last_hidden_state

    def time_embedding(self, timestamp, t=1):
        """Sinusoidal embedding of per-frame timestamps, projected by ``time_emb``.

        Each sequence is shifted to start at 0 and scaled to end at
        ``self.len``, raised to the power ``t``, then encoded over
        ``input_dim`` channels: even channels use sin, odd channels cos, and
        channels ``2k``/``2k+1`` share the divisor ``10000 ** (2k / input_dim)``.

        Args:
            timestamp: ``(batch, length)`` tensor.
            t: Exponent applied to the normalised timestamps.

        Returns:
            ``(batch, length, hidden_dim)`` float32 tensor.
        """
        device = timestamp.device
        start = timestamp[:, :1]
        span = timestamp[:, -1:] - start
        # A single-frame or constant-time sequence has zero span; 0/0 would
        # produce NaN. Use 1 instead so such sequences map to all-zero times.
        span = torch.where(span == 0, torch.ones_like(span), span)
        # Subtract in the input dtype (keeps precision for large float64 times),
        # then cast: float64 values would otherwise reach the float32 Linear
        # layers of ``time_emb`` and raise a dtype mismatch.
        timestamp = ((timestamp - start) / span * self.len).to(torch.float32)
        timestamp = timestamp ** t

        d_model = self.input_dim
        dim = torch.arange(d_model, device=device)
        divisor = 10000 ** (dim // 2 * 2 / d_model)
        # Broadcast (batch, length, 1) against (d_model,) -> (batch, length, d_model).
        angles = timestamp.unsqueeze(-1) / divisor
        emb = torch.where(dim % 2 == 0, torch.sin(angles), torch.cos(angles))
        return self.time_emb(emb)

    def attention(self, x):
        """Scale every feature channel by a weight computed from its own time series.

        ``arl`` reduces each channel's length-``self.len`` series to one scalar,
        which then multiplies that whole channel. Requires
        ``x.shape[-2] == self.len``; the output has the same shape as ``x``.
        """
        channels = torch.transpose(x, -1, -2)  # (..., input_dim, len)
        weights = self.arl(channels)  # (..., input_dim, 1)
        return torch.transpose(channels * weights, -1, -2)


class Token_Classifier(nn.Module):
    """Per-position classifier on top of a ``CSIBERT`` encoder.

    Args:
        bert: A ``CSIBERT`` instance; its ``hidden_dim`` sizes the head.
        class_num: Number of output logits per position.
    """

    def __init__(self, bert, class_num=52):
        super().__init__()
        self.bert = bert
        self.classifier = nn.Sequential(
            nn.Linear(bert.hidden_dim, bert.hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(bert.hidden_dim // 2, class_num),
        )

    def forward(self, x, timestamp, attention_mask=None):
        """Return logits of shape ``(batch, seq_len, class_num)``."""
        hidden = self.bert(x, timestamp, attention_mask=attention_mask)
        return self.classifier(hidden)


class SelfAttention(nn.Module):
    """Produce ``r`` attention distributions over the sequence axis.

    Scores are ``ws2(tanh(ws1(h)))`` followed by a softmax over positions.

    Args:
        input_dim: Feature width of ``h``.
        da: Hidden width of the scoring MLP.
        r: Number of attention rows (pooling heads).
    """

    def __init__(self, input_dim, da, r):
        super().__init__()
        self.ws1 = nn.Linear(input_dim, da, bias=False)
        self.ws2 = nn.Linear(da, r, bias=False)

    def forward(self, h):
        """Map ``h`` of shape ``(batch, seq_len, input_dim)`` to ``(batch, r, seq_len)``."""
        scores = self.ws2(torch.tanh(self.ws1(h)))  # (batch, seq_len, r)
        attn_mat = F.softmax(scores, dim=1)  # normalise over positions
        return attn_mat.permute(0, 2, 1)


class Sequence_Classifier(nn.Module):
    """Sequence-level classifier: CSIBERT encoder + attention pooling + MLP head.

    Args:
        csibert: A ``CSIBERT`` instance.
        class_num: Number of output logits.
        hs: Feature width of the encoder output; must equal
            ``csibert.hidden_dim``.
        da: Hidden width of the pooling scorer.
        r: Number of pooling heads; the head sees ``hs * r`` features.
    """

    def __init__(self, csibert, class_num, hs=128, da=128, r=4):
        super().__init__()
        self.bert = csibert
        self.attention = SelfAttention(hs, da, r)
        self.classifier = nn.Sequential(
            nn.Linear(hs * r, hs * r // 2),
            nn.ReLU(),
            nn.Linear(hs * r // 2, class_num),
        )

    def forward(self, x, timestamp, attention_mask=None):
        """Return logits of shape ``(batch, class_num)``."""
        hidden = self.bert(x, timestamp, attention_mask=attention_mask)
        attn_mat = self.attention(hidden)  # (batch, r, seq_len)
        pooled = torch.bmm(attn_mat, hidden)  # (batch, r, hs)
        return self.classifier(pooled.flatten(1))


class CSI_BERT2(nn.Module, PyTorchModelHubMixin):
    """Hub-loadable wrapper that builds a ``CSIBERT`` from plain hyper-parameters.

    Args:
        max_len: Sequence length (``max_position_embeddings``).
        hs: BERT hidden size.
        layers: Number of BERT layers.
        heads: Number of attention heads.
        intermediate_size: BERT feed-forward width.
        carrier_dim: Features per input frame (``CSIBERT`` ``input_dim``).
    """

    def __init__(self, max_len=100, hs=128, layers=6, heads=8,
                 intermediate_size=512, carrier_dim=52):
        super().__init__()
        self.config = BertConfig(
            max_position_embeddings=max_len,
            hidden_size=hs,
            num_hidden_layers=layers,
            num_attention_heads=heads,
            intermediate_size=intermediate_size,
        )
        self.model = CSIBERT(self.config, carrier_dim)

    def forward(self, x, timestamp=None, attn_mask=None):
        """Run the wrapped ``CSIBERT``; ``timestamp=None`` means uniformly spaced frames."""
        return self.model(x, timestamp, attn_mask)