RabiaSufian commited on
Commit
b32fda0
·
verified ·
1 Parent(s): 9ca1f75

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +32 -0
model.py CHANGED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ class LSTMClassifier(nn.Module):
5
+ def __init__(self, input_size=1, hidden_size=32, num_layers=1,
6
+ bidirectional=True, dropout=0.0, num_classes=2):
7
+ super(LSTMClassifier, self).__init__()
8
+ self.hidden_size = hidden_size
9
+ self.num_layers = num_layers
10
+ self.bidirectional = bidirectional
11
+
12
+ self.lstm = nn.LSTM(
13
+ input_size=input_size,
14
+ hidden_size=hidden_size,
15
+ num_layers=num_layers,
16
+ batch_first=True,
17
+ dropout=dropout if num_layers > 1 else 0.0,
18
+ bidirectional=bidirectional
19
+ )
20
+
21
+ direction_factor = 2 if bidirectional else 1
22
+ self.fc = nn.Linear(hidden_size * direction_factor, num_classes)
23
+
24
+ def forward(self, x):
25
+ _, (hn, _) = self.lstm(x)
26
+ if self.bidirectional:
27
+ forward = hn[-2]
28
+ backward = hn[-1]
29
+ combined = torch.cat((forward, backward), dim=1)
30
+ else:
31
+ combined = hn[-1]
32
+ return self.fc(combined)