import math
from typing import Sequence, Union, Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(10086)

tensor_activation = Callable[[torch.Tensor], torch.Tensor]


class LSTM4VarLenSeq(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
                 bidirectional=False, init='orthogonal', take_last=True):
        """LSTM wrapper for batches of variable-length, padded sequences.

        Dropout is not supported. batch_first is not configurable: the input
        and output tensors are always laid out as (batch, seq_len, feature).

        Args:
            input_size: number of expected features in the input.
            hidden_size: number of features in the hidden state.
            num_layers: number of stacked LSTM layers.
            bias: whether the LSTM uses bias weights.
            bidirectional: whether to use a bidirectional LSTM.
            init: how to initialize the torch.nn.LSTM parameters,
                either 'orthogonal' or 'uniform'.
            take_last: True to return only the final hidden state,
                False to also return the padded outputs and cell states
                (see the take_last=True sketch in the __main__ demo below).
        """
        super(LSTM4VarLenSeq, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            bias=bias,
                            bidirectional=bidirectional)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.bidirectional = bidirectional
        self.init = init
        self.take_last = take_last
        self.batch_first = True

        self.init_parameters()

    def init_parameters(self):
        """Orthogonal init generally yields better results than uniform init."""
        # Each entry of lstm.all_weights holds, in order, weight_ih, weight_hh,
        # bias_ih and bias_hh for one layer/direction.
        num_directions = 2 if self.bidirectional else 1
        if self.init == 'orthogonal':
            gain = 1
            for nth in range(self.num_layers * num_directions):
                nn.init.orthogonal_(self.lstm.all_weights[nth][0], gain=gain)
                nn.init.orthogonal_(self.lstm.all_weights[nth][1], gain=gain)
                nn.init.zeros_(self.lstm.all_weights[nth][2])
                nn.init.zeros_(self.lstm.all_weights[nth][3])
        elif self.init == 'uniform':
            k = math.sqrt(1 / self.hidden_size)
            for nth in range(self.num_layers * num_directions):
                nn.init.uniform_(self.lstm.all_weights[nth][0], -k, k)
                nn.init.uniform_(self.lstm.all_weights[nth][1], -k, k)
                nn.init.zeros_(self.lstm.all_weights[nth][2])
                nn.init.zeros_(self.lstm.all_weights[nth][3])
        else:
            raise NotImplementedError('Unsupported initialization')

    def forward(self, x, x_len, hx=None):
        # Sort the batch by descending length (required by pack_padded_sequence)
        # and remember how to undo the sort afterwards.
        sorted_x_len, sorted_x_idx = torch.sort(x_len, descending=True)
        sorted_x = x[sorted_x_idx]
        # Argsort of the sorting indices recovers the original batch order.
        _, unsort_x_idx = torch.sort(sorted_x_idx, descending=False)

        packed_x = pack_padded_sequence(sorted_x, sorted_x_len,
                                        batch_first=self.batch_first)

        # An optional initial state must be reordered to match the sorted batch.
        if hx is not None:
            h0, c0 = hx
            hx = (h0[:, sorted_x_idx], c0[:, sorted_x_idx])
        out_packed, (hn, cn) = self.lstm(packed_x, hx)

        # hn: (num_layers * num_directions, batch, hidden); unsort along batch.
        hn = hn.permute(1, 0, 2)[unsort_x_idx]
        hn = hn.permute(1, 0, 2)

        if self.take_last:
            # Intended for single-layer, unidirectional use, where the leading
            # (num_layers * num_directions) dim is 1 and can be squeezed away.
            return hn.squeeze(0)
        else:
            # Unpack to (batch, seq_len, num_directions * hidden) and unsort.
            out, _ = pad_packed_sequence(out_packed,
                                         batch_first=self.batch_first)
            out = out[unsort_x_idx]
            cn = cn.permute(1, 0, 2)[unsort_x_idx]
            cn = cn.permute(1, 0, 2)
            return out, (hn, cn)


if __name__ == '__main__':
    net = LSTM4VarLenSeq(200, 100, num_layers=3, bias=True, bidirectional=True,
                         init='orthogonal', take_last=False)

    # Token ids padded with 0; lens holds the number of non-padding tokens per row.
    inputs = torch.tensor([[1, 2, 3, 0],
                           [2, 3, 0, 0],
                           [2, 4, 3, 0],
                           [1, 4, 3, 0],
                           [1, 2, 3, 4]])
    embedding = nn.Embedding(num_embeddings=5, embedding_dim=200, padding_idx=0)
    lens = torch.LongTensor([3, 2, 3, 3, 4])

    input_embed = embedding(inputs)
    output, (h, c) = net(input_embed, lens)

    print(output.shape)
    print(h.shape)
    print(c.shape)
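    # With the configuration above (hidden_size=100, num_layers=3,
    # bidirectional=True, batch of 5, max length 4), the prints should show:
    #   output: torch.Size([5, 4, 200])   (batch, seq_len, 2 * hidden_size)
    #   h, c:   torch.Size([6, 5, 100])   (num_layers * 2, batch, hidden_size)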
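    # A minimal sketch of the take_last=True path (not part of the original
    # demo): with a single-layer, unidirectional model, forward() returns only
    # the final hidden state of shape (batch, hidden_size). The names net_last
    # and last_hidden are illustrative.
    net_last = LSTM4VarLenSeq(200, 100, num_layers=1, bias=True,
                              bidirectional=False, init='uniform', take_last=True)
    last_hidden = net_last(input_embed, lens)
    print(last_hidden.shape)  # expected: torch.Size([5, 100])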