import math
from typing import Sequence, Union, Callable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
torch.manual_seed(10086)
# Type alias for an activation function that maps a Tensor to a Tensor.
tensor_activation = Callable[[torch.Tensor], torch.Tensor]
class LSTM4VarLenSeq(nn.Module):
def __init__(self, input_size, hidden_size,
num_layers=1, bias=True, bidirectional=False, init='orthogonal', take_last=True):
"""
        An nn.LSTM wrapper that handles variable-length (padded) sequences.

        Dropout is not supported. batch_first is not configurable; it is fixed
        to True, so input and output tensors are provided as
        (batch, seq_len, feature).

        Args:
            input_size: number of expected features in the input.
            hidden_size: number of features in the hidden state.
            num_layers: number of stacked LSTM layers.
            bias: if True, use the b_ih and b_hh bias weights.
            bidirectional: if True, use a bidirectional LSTM.
            init: how to initialize the torch.nn.LSTM parameters,
                supports 'orthogonal' and 'uniform'.
            take_last: True to return only the final hidden state,
                False to also return the padded output sequence and cell state.
"""
super(LSTM4VarLenSeq, self).__init__()
self.lstm = nn.LSTM(input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
bidirectional=bidirectional)
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.bidirectional = bidirectional
self.init = init
self.take_last = take_last
self.batch_first = True # Please don't modify this
self.init_parameters()
def init_parameters(self):
"""orthogonal init yields generally good results than uniform init"""
        # nn.LSTM keeps one parameter group per (layer, direction) pair.
        num_directions = 2 if self.bidirectional else 1
        if self.init == 'orthogonal':
            gain = 1  # use default value
            for nth in range(self.num_layers * num_directions):
# w_ih, (4 * hidden_size x input_size)
nn.init.orthogonal_(self.lstm.all_weights[nth][0], gain=gain)
# w_hh, (4 * hidden_size x hidden_size)
nn.init.orthogonal_(self.lstm.all_weights[nth][1], gain=gain)
# b_ih, (4 * hidden_size)
nn.init.zeros_(self.lstm.all_weights[nth][2])
# b_hh, (4 * hidden_size)
nn.init.zeros_(self.lstm.all_weights[nth][3])
elif self.init == 'uniform':
k = math.sqrt(1 / self.hidden_size)
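            # k matches nn.LSTM's default init bound: weights are drawn from
            # U(-sqrt(1/hidden_size), sqrt(1/hidden_size)).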
            for nth in range(self.num_layers * num_directions):
nn.init.uniform_(self.lstm.all_weights[nth][0], -k, k)
nn.init.uniform_(self.lstm.all_weights[nth][1], -k, k)
nn.init.zeros_(self.lstm.all_weights[nth][2])
nn.init.zeros_(self.lstm.all_weights[nth][3])
else:
            raise NotImplementedError('Unsupported initialization: {}'.format(self.init))
def forward(self, x, x_len, hx=None):
# 1. Sort x and its corresponding length
sorted_x_len, sorted_x_idx = torch.sort(x_len, descending=True)
sorted_x = x[sorted_x_idx]
# 2. Ready to unsort after LSTM forward pass
# Note that PyTorch 0.4 has no argsort, but PyTorch 1.0 does.
_, unsort_x_idx = torch.sort(sorted_x_idx, descending=False)
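        # e.g. if x_len = [3, 2, 4] then sorted_x_idx = [2, 0, 1]; sorting that
        # index ascending yields unsort_x_idx = [1, 2, 0], which maps each row
        # of the sorted batch back to its original position.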
# 3. Pack the sorted version of x and x_len, as required by the API.
x_emb = pack_padded_sequence(sorted_x, sorted_x_len,
batch_first=self.batch_first)
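        # pack_padded_sequence flattens the non-padded timesteps of all
        # sequences into one tensor plus per-step batch sizes, so the LSTM
        # never processes padding positions.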
        # 4. Forward pass through the LSTM.
        # out_packed.data has shape (total_valid_timesteps, num_directions * hidden_size).
        # See the torch.nn.LSTM documentation for details.
        # hx, if provided, is assumed to already match the sorted batch order.
        out_packed, (hn, cn) = self.lstm(x_emb, hx)
# 5. unsort h
# (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
hn = hn.permute(1, 0, 2)[unsort_x_idx] # swap the first two dim
hn = hn.permute(1, 0, 2) # swap the first two again to recover
if self.take_last:
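            # NOTE: squeeze(0) only drops the leading dimension when
            # num_layers * num_directions == 1; for stacked or bidirectional
            # LSTMs the returned hn keeps shape
            # (num_layers * num_directions, batch, hidden_size).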
return hn.squeeze(0)
else:
# unpack: out
# (batch, max_seq_len, num_directions * hidden_size)
out, _ = pad_packed_sequence(out_packed,
batch_first=self.batch_first)
out = out[unsort_x_idx]
# unpack: c
# (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
cn = cn.permute(1, 0, 2)[unsort_x_idx] # swap the first two dim
cn = cn.permute(1, 0, 2) # swap the first two again to recover
return out, (hn, cn)
if __name__ == '__main__':
    # TODO: port the following examples to unit tests (unittest) under the test folder.
# Unit test for LSTM variable length sequences
# ================
net = LSTM4VarLenSeq(200, 100,
num_layers=3, bias=True, bidirectional=True, init='orthogonal', take_last=False)
inputs = torch.tensor([[1, 2, 3, 0],
[2, 3, 0, 0],
[2, 4, 3, 0],
[1, 4, 3, 0],
[1, 2, 3, 4]])
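    # padding_idx=0 keeps the embedding of the PAD token (index 0) at all zeros
    # and excludes it from gradient updates.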
embedding = nn.Embedding(num_embeddings=5, embedding_dim=200, padding_idx=0)
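    # Number of valid (non-PAD) tokens in each row of `inputs`; the module sorts
    # and unsorts by length internally, so lens may be given in any order.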
lens = torch.LongTensor([3, 2, 3, 3, 4])
input_embed = embedding(inputs)
output, (h, c) = net(input_embed, lens)
    # (5, 4, 200): batch, max seq length, num_directions * hidden_size (last layer only)
print(output.shape)
    # (6, 5, 100): num_layers * num_directions, batch, hidden_size
print(h.shape)
    # (6, 5, 100): num_layers * num_directions, batch, hidden_size
print(c.shape)
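    # Additional sanity check (a sketch): with a single unidirectional layer and
    # take_last=True, forward() is expected to return only the unsorted final
    # hidden state of shape (batch, hidden_size).
    net_last = LSTM4VarLenSeq(200, 100,
                              num_layers=1, bias=True, bidirectional=False,
                              init='uniform', take_last=True)
    h_last = net_last(input_embed, lens)
    # (5, 100): batch, hidden_size
    print(h_last.shape)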