import math
from typing import Sequence, Union, Callable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
torch.manual_seed(10086)
# Type alias for an activation function: a callable mapping a Tensor to a Tensor.
tensor_activation = Callable[[torch.Tensor], torch.Tensor]


class LSTM4VarLenSeq(nn.Module):
    def __init__(self, input_size, hidden_size,
                 num_layers=1, bias=True, bidirectional=False,
                 init='orthogonal', take_last=True):
        """LSTM for batches of variable-length sequences.

        Dropout is not supported. The batch_first argument is not exposed:
        input and output tensors are always provided as
        (batch, seq_len, feature).

        Args:
            input_size: number of expected features in the input.
            hidden_size: number of features in the hidden state.
            num_layers: number of stacked LSTM layers.
            bias: if False, the LSTM does not use bias weights.
            bidirectional: if True, use a bidirectional LSTM.
            init: how to initialize the torch.nn.LSTM parameters,
                supports 'orthogonal' and 'uniform'.
            take_last: True if you only want the final hidden state,
                otherwise False.
        """
        super(LSTM4VarLenSeq, self).__init__()

        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            bias=bias,
                            bidirectional=bidirectional)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.bidirectional = bidirectional
        self.init = init
        self.take_last = take_last
        self.batch_first = True  # please don't modify this

        self.init_parameters()

    def init_parameters(self):
        """Orthogonal init generally yields better results than uniform init."""
        num_directions = 2 if self.bidirectional else 1
        if self.init == 'orthogonal':
            gain = 1  # use default value
            for nth in range(self.num_layers * num_directions):
                # w_ih, (4 * hidden_size x input_size)
                nn.init.orthogonal_(self.lstm.all_weights[nth][0], gain=gain)
                # w_hh, (4 * hidden_size x hidden_size)
                nn.init.orthogonal_(self.lstm.all_weights[nth][1], gain=gain)
                # b_ih, (4 * hidden_size)
                nn.init.zeros_(self.lstm.all_weights[nth][2])
                # b_hh, (4 * hidden_size)
                nn.init.zeros_(self.lstm.all_weights[nth][3])
        elif self.init == 'uniform':
            k = math.sqrt(1 / self.hidden_size)
            for nth in range(self.num_layers * num_directions):
                nn.init.uniform_(self.lstm.all_weights[nth][0], -k, k)
                nn.init.uniform_(self.lstm.all_weights[nth][1], -k, k)
                nn.init.zeros_(self.lstm.all_weights[nth][2])
                nn.init.zeros_(self.lstm.all_weights[nth][3])
        else:
            raise NotImplementedError('Unsupported initialization: %s' % self.init)

    def forward(self, x, x_len, hx=None):
        # 1. Sort x and its corresponding lengths in descending order.
        sorted_x_len, sorted_x_idx = torch.sort(x_len, descending=True)
        sorted_x = x[sorted_x_idx]
        # 2. Get the indices that undo the sort after the LSTM forward pass.
        # Note that PyTorch 0.4 has no argsort, but PyTorch 1.0 does.
        _, unsort_x_idx = torch.sort(sorted_x_idx, descending=False)
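        # A concrete example of this bookkeeping (hypothetical lengths, not
        # taken from the test below): if x_len = [2, 5, 3], then
        # sorted_x_idx = [1, 2, 0] and unsort_x_idx = [2, 0, 1], so indexing
        # any sorted result with unsort_x_idx restores the original order.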
        # 3. Pack the sorted version of x and x_len, as required by the API.
        x_emb = pack_padded_sequence(sorted_x, sorted_x_len,
                                     batch_first=self.batch_first)
        # 4. Forward the LSTM.
        # If initial states are given, reorder them along the batch
        # dimension so they match the sorted inputs.
        if hx is not None:
            hx = (hx[0][:, sorted_x_idx], hx[1][:, sorted_x_idx])
        # out_packed.data.shape is (valid_seq, num_directions * hidden_dim).
        # See the doc of torch.nn.LSTM for details.
        out_packed, (hn, cn) = self.lstm(x_emb, hx)
        # 5. Unsort h.
        # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
        hn = hn.permute(1, 0, 2)[unsort_x_idx]  # swap the first two dims
        hn = hn.permute(1, 0, 2)  # swap the first two again to recover
        if self.take_last:
            return hn.squeeze(0)
        else:
            # unpack: out
            # (batch, max_seq_len, num_directions * hidden_size)
            out, _ = pad_packed_sequence(out_packed,
                                         batch_first=self.batch_first)
            out = out[unsort_x_idx]
            # unpack: c
            # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
            cn = cn.permute(1, 0, 2)[unsort_x_idx]  # swap the first two dims
            cn = cn.permute(1, 0, 2)  # swap the first two again to recover
            return out, (hn, cn)


if __name__ == '__main__':
    # Note that in the future we will import unittest
    # and port the following examples to a test folder.

    # Unit test for the variable-length-sequence LSTM
    # ===============================================
    net = LSTM4VarLenSeq(200, 100,
                         num_layers=3, bias=True, bidirectional=True,
                         init='orthogonal', take_last=False)

    inputs = torch.tensor([[1, 2, 3, 0],
                           [2, 3, 0, 0],
                           [2, 4, 3, 0],
                           [1, 4, 3, 0],
                           [1, 2, 3, 4]])
    embedding = nn.Embedding(num_embeddings=5, embedding_dim=200, padding_idx=0)
    lens = torch.LongTensor([3, 2, 3, 3, 4])
    input_embed = embedding(inputs)

    output, (h, c) = net(input_embed, lens)
    # (5, 4, 200): batch, seq length, hidden_size * 2 (last layer only)
    print(output.shape)
    # (6, 5, 100): num_layers * num_directions, batch, hidden_size
    print(h.shape)
    # (6, 5, 100): num_layers * num_directions, batch, hidden_size
    print(c.shape)
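    # Minimal sanity checks (a sketch added here, assuming the shapes noted
    # above describe the intended behaviour).
    assert output.shape == (5, 4, 200)
    assert h.shape == (6, 5, 100)
    assert c.shape == (6, 5, 100)
    # For a bidirectional LSTM, the last layer's final forward state should
    # equal the output at each sequence's last valid timestep, and the final
    # backward state should equal the output at timestep 0 (100 = hidden_size).
    for b in range(inputs.size(0)):
        assert torch.allclose(output[b, lens[b] - 1, :100], h[-2, b])
        assert torch.allclose(output[b, 0, 100:], h[-1, b])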