RS2002 committed (verified) · Commit cea9202 · 1 Parent(s): 86422a8

Upload model.py

Files changed (1): model.py (+170, -0)
model.py ADDED
@@ -0,0 +1,170 @@
from transformers import BertModel, BertConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

# Referenced only by the commented-out absolute-time normalisation in CSIBERT.time_embedding.
time_gap = 10000.0

class CSIBERT(nn.Module):
    # BERT encoder over CSI frames, with a timestamp-driven positional embedding
    # and a simple attention re-weighting layer over the subcarriers.
    def __init__(self, bertconfig, input_dim):
        super().__init__()
        self.bertconfig = bertconfig
        self.bert = BertModel(bertconfig)
        self.hidden_dim = bertconfig.hidden_size
        self.input_dim = input_dim
        self.len = bertconfig.max_position_embeddings

        self.Norm1 = nn.LayerNorm(self.input_dim)
        self.Norm2 = nn.LayerNorm(self.hidden_dim)
        self.Norm3 = nn.LayerNorm(self.hidden_dim)

        # Projects each CSI frame (input_dim subcarriers) into the BERT hidden space.
        self.csi_emb = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim)
        )

        # Projects the sinusoidal time encoding into the BERT hidden space.
        self.time_emb = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim)
        )

        self.fusion_emb = nn.Sequential(
            nn.Linear(self.hidden_dim * 2, self.hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim)
        )

        # Attention re-weighting layer: scores each subcarrier from its full time series.
        self.arl = nn.Sequential(
            nn.Linear(self.len, self.len // 2),
            nn.ReLU(),
            nn.Linear(self.len // 2, self.len // 4),
            nn.ReLU(),
            nn.Linear(self.len // 4, 1)
        )

    def forward(self, x, timestamp, attention_mask=None):
        # x: (batch, seq_len, input_dim) CSI frames; timestamp: (batch, seq_len)
        x = x.to(torch.float32)

        x = self.attention(x)
        x = self.csi_emb(x)
        x_time = self.time_embedding(timestamp)
        x = x + x_time
        y = self.bert(inputs_embeds=x, attention_mask=attention_mask, output_hidden_states=False)
        y = y.last_hidden_state
        return y

    def time_embedding(self, timestamp, t=1):
        device = timestamp.device
        # timestamp = (timestamp - timestamp[:, 0:1]) / time_gap
        # timestamp = (timestamp - timestamp[:, 0:1]) / (timestamp[:, -1:] - timestamp[:, 0:1])
        # Rescale timestamps to [0, self.len] relative to the first and last frame.
        timestamp = (timestamp - timestamp[:, 0:1]) / (timestamp[:, -1:] - timestamp[:, 0:1]) * self.len

        timestamp **= t
        d_model = self.input_dim
        dim = torch.tensor(list(range(d_model))).to(device)
        batch_size, length = timestamp.shape
        timestamp = timestamp.unsqueeze(2).repeat(1, 1, d_model)
        dim = dim.reshape([1, 1, -1]).repeat(batch_size, length, 1)
        # Sinusoidal position encoding driven by the (possibly irregular) timestamps,
        # interleaving sin on even dimensions and cos on odd dimensions.
        sin_emb = torch.sin(timestamp / 10000 ** (dim // 2 * 2 / d_model))
        cos_emb = torch.cos(timestamp / 10000 ** (dim // 2 * 2 / d_model))
        mask = torch.zeros(d_model).to(device)
        mask[::2] = 1
        emb = sin_emb * mask + cos_emb * (1 - mask)
        emb = self.time_emb(emb)

        # timestamp = torch.unsqueeze(timestamp, -1)
        # emb = self.time_emb(timestamp)

        return emb

    # def attention(self, x):
    #     y = torch.transpose(x, -1, -2)
    #     batch_size = y.shape[0]
    #     queries = self.query(y).view(batch_size, -1, self.head_num, self.head_dim).transpose(1, 2)
    #     keys = self.key(y).view(batch_size, -1, self.head_num, self.head_dim).transpose(1, 2)
    #     values = self.value(y).view(batch_size, -1, self.head_num, self.head_dim).transpose(1, 2)
    #     attention_weights = self.softmax(torch.matmul(queries, keys.transpose(-1, -2)) / (self.head_dim ** 0.5))
    #
    #     # attended_values = torch.matmul(attention_weights, values).transpose(1, 2)
    #     # attended_values = attended_values.reshape(batch_size, self.input_dim, self.len)
    #     # attended_values = self.norm(attended_values)
    #     # y = attended_values.transpose(1, 2)
    #
    #     attended_values = torch.matmul(attention_weights, values).transpose(-1, -2)
    #     attended_values = attended_values.reshape(batch_size, self.len, self.input_dim)
    #     y = self.norm(attended_values)
    #
    #     return y + x

    def attention(self, x):
        # Rescales each subcarrier of x by a learned score: (batch, seq_len, input_dim) in and out.
        y = torch.transpose(x, -1, -2)
        attn = self.arl(y)
        y = y * attn
        y = torch.transpose(y, -1, -2)
        return y

class Token_Classifier(nn.Module):
    # Frame-level (token-level) prediction head on top of a CSIBERT encoder.
    def __init__(self, bert, class_num=52):
        super().__init__()
        self.bert = bert
        self.classifier = nn.Sequential(
            nn.Linear(bert.hidden_dim, bert.hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(bert.hidden_dim // 2, class_num)
        )

    def forward(self, x, timestamp, attention_mask=None):
        x = self.bert(x, timestamp, attention_mask=attention_mask)
        x = self.classifier(x)
        return x

class SelfAttention(nn.Module):
    # Produces an r-head attention matrix of shape (batch, r, seq_len) over the encoder outputs.
    def __init__(self, input_dim, da, r):
        super().__init__()
        self.ws1 = nn.Linear(input_dim, da, bias=False)
        self.ws2 = nn.Linear(da, r, bias=False)

    def forward(self, h):
        attn_mat = F.softmax(self.ws2(torch.tanh(self.ws1(h))), dim=1)
        attn_mat = attn_mat.permute(0, 2, 1)
        return attn_mat

class Sequence_Classifier(nn.Module):
    # Sequence-level classification head: attention pooling over CSIBERT outputs.
    def __init__(self, csibert, class_num, hs=128, da=128, r=4):
        super().__init__()
        self.bert = csibert
        self.attention = SelfAttention(hs, da, r)
        self.classifier = nn.Sequential(
            nn.Linear(hs * r, hs * r // 2),
            nn.ReLU(),
            nn.Linear(hs * r // 2, class_num)
        )

    def forward(self, x, timestamp, attention_mask=None):
        x = self.bert(x, timestamp, attention_mask=attention_mask)
        attn_mat = self.attention(x)
        m = torch.bmm(attn_mat, x)            # (batch, r, hs)
        flatten = m.view(m.size()[0], -1)     # (batch, r * hs)
        res = self.classifier(flatten)
        return res

class CSI_BERT2(nn.Module, PyTorchModelHubMixin):
    # Hub-ready wrapper (via PyTorchModelHubMixin): builds the BertConfig and the CSIBERT backbone.
    def __init__(self, max_len=100, hs=128, layers=6, heads=8, intermediate_size=512, carrier_dim=52):
        super().__init__()
        self.config = BertConfig(max_position_embeddings=max_len, hidden_size=hs, num_hidden_layers=layers,
                                 num_attention_heads=heads, intermediate_size=intermediate_size)
        self.model = CSIBERT(self.config, carrier_dim)

    def forward(self, x, timestamp=None, attn_mask=None):
        return self.model(x, timestamp, attn_mask)
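
For reference, a minimal usage sketch (not part of the commit). It assumes the default sizes above (max_len=100, hs=128, carrier_dim=52); the batch size, the 7-class task, the random tensors, and the local checkpoint path are illustrative placeholders only.

    import torch
    # from model import CSI_BERT2, Sequence_Classifier   # the classes defined in this file

    model = CSI_BERT2(max_len=100, hs=128, layers=6, heads=8,
                      intermediate_size=512, carrier_dim=52)

    batch, seq_len, carriers = 4, 100, 52                                # illustrative sizes
    csi = torch.randn(batch, seq_len, carriers)                          # dummy CSI frames
    timestamp = torch.cumsum(torch.rand(batch, seq_len) + 0.01, dim=-1)  # strictly increasing timestamps
    mask = torch.ones(batch, seq_len)                                    # no padding in this toy batch

    hidden = model(csi, timestamp, mask)                                 # (batch, seq_len, 128)

    # Sequence-level classification on top of the backbone (hypothetical 7-class task):
    clf = Sequence_Classifier(model.model, class_num=7, hs=128, da=128, r=4)
    logits = clf(csi, timestamp, attention_mask=mask)                    # (batch, 7)

    # Because CSI_BERT2 mixes in PyTorchModelHubMixin, it also supports e.g.
    # model.save_pretrained("./csi-bert2-checkpoint") and
    # CSI_BERT2.from_pretrained("./csi-bert2-checkpoint").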