import math
from typing import Sequence, Union, Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(10086)

# Type alias: an activation function mapping a Tensor to a Tensor.
tensor_activation = Callable[[torch.Tensor], torch.Tensor]


class LSTM4VarLenSeq(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
                 bidirectional=False, init='orthogonal', take_last=True):
        """
        No dropout support.
        batch_first is not configurable: the input and output tensors are
        always provided as (batch, seq_len, feature).

        Args:
            input_size:
            hidden_size:
            num_layers:
            bias:
            bidirectional:
            init: how to initialize the torch.nn.LSTM parameters;
                supports 'orthogonal' and 'uniform'.
            take_last: True if you only want the final hidden state,
                otherwise False.
        """
        super(LSTM4VarLenSeq, self).__init__()

        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            bias=bias,
                            bidirectional=bidirectional)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.bidirectional = bidirectional
        self.init = init
        self.take_last = take_last
        self.batch_first = True  # Please don't modify this.

        self.init_parameters()

    def init_parameters(self):
        """Orthogonal init generally yields better results than uniform init."""
        num_directions = 2 if self.bidirectional else 1
        if self.init == 'orthogonal':
            gain = 1  # use the default value
            for nth in range(self.num_layers * num_directions):
                # w_ih, (4 * hidden_size x input_size)
                nn.init.orthogonal_(self.lstm.all_weights[nth][0], gain=gain)
                # w_hh, (4 * hidden_size x hidden_size)
                nn.init.orthogonal_(self.lstm.all_weights[nth][1], gain=gain)
                # b_ih, (4 * hidden_size)
                nn.init.zeros_(self.lstm.all_weights[nth][2])
                # b_hh, (4 * hidden_size)
                nn.init.zeros_(self.lstm.all_weights[nth][3])
        elif self.init == 'uniform':
            k = math.sqrt(1 / self.hidden_size)
            for nth in range(self.num_layers * num_directions):
                nn.init.uniform_(self.lstm.all_weights[nth][0], -k, k)
                nn.init.uniform_(self.lstm.all_weights[nth][1], -k, k)
                nn.init.zeros_(self.lstm.all_weights[nth][2])
                nn.init.zeros_(self.lstm.all_weights[nth][3])
        else:
            raise NotImplementedError('Unsupported initialization')

    def forward(self, x, x_len, hx=None):
        # 1. Sort x and its corresponding lengths in descending order,
        #    as required by pack_padded_sequence.
        sorted_x_len, sorted_x_idx = torch.sort(x_len, descending=True)
        sorted_x = x[sorted_x_idx]
        # 2. Compute the indices that undo the sort after the LSTM forward pass.
        #    Note that PyTorch 0.4 has no argsort, but PyTorch 1.0 does.
        _, unsort_x_idx = torch.sort(sorted_x_idx, descending=False)
        # 3. Pack the sorted version of x and x_len, as required by the API.
        x_emb = pack_padded_sequence(sorted_x, sorted_x_len,
                                     batch_first=self.batch_first)
        # 4. Forward the LSTM.
        #    out_packed.data.shape is (valid_seq, num_directions * hidden_size).
        #    See the doc of torch.nn.LSTM for details.
        if hx is not None:
            # The initial states (h_0, c_0) must follow the sorted batch order.
            hx = (hx[0][:, sorted_x_idx], hx[1][:, sorted_x_idx])
        out_packed, (hn, cn) = self.lstm(x_emb, hx)
        # 5. Unsort h.
        #    (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
        hn = hn.permute(1, 0, 2)[unsort_x_idx]  # swap the first two dims
        hn = hn.permute(1, 0, 2)  # swap the first two dims again to recover

        if self.take_last:
            # squeeze(0) only removes the leading dim for a single-layer,
            # unidirectional LSTM, yielding (batch, hidden_size).
            return hn.squeeze(0)
        else:
            # Unpack: out
            # (batch, max_seq_len, num_directions * hidden_size)
            out, _ = pad_packed_sequence(out_packed,
                                         batch_first=self.batch_first)
            out = out[unsort_x_idx]
            # Unsort: c
            # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
            cn = cn.permute(1, 0, 2)[unsort_x_idx]  # swap the first two dims
            cn = cn.permute(1, 0, 2)  # swap the first two dims again to recover
            return out, (hn, cn)


if __name__ == '__main__':
    # Note that in the future we will import unittest
    # and port the following examples to a test folder.

    # Unit test for the LSTM over variable-length sequences
    # ======================================================
    net = LSTM4VarLenSeq(200, 100,
                         num_layers=3, bias=True, bidirectional=True,
                         init='orthogonal', take_last=False)

    inputs = torch.tensor([[1, 2, 3, 0],
                           [2, 3, 0, 0],
                           [2, 4, 3, 0],
                           [1, 4, 3, 0],
                           [1, 2, 3, 4]])
    embedding = nn.Embedding(num_embeddings=5, embedding_dim=200,
                             padding_idx=0)
    lens = torch.LongTensor([3, 2, 3, 3, 4])

    input_embed = embedding(inputs)
    output, (h, c) = net(input_embed, lens)
    # (5, 4, 200): batch, seq_len, num_directions * hidden_size (last layer only)
    print(output.shape)
    # (6, 5, 100): num_layers * num_directions, batch, hidden_size
    print(h.shape)
    # (6, 5, 100): num_layers * num_directions, batch, hidden_size
    print(c.shape)
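
    # ------------------------------------------------------------------
    # A minimal sketch, not part of the original module, assuming
    # PyTorch >= 1.1: pack_padded_sequence accepts enforce_sorted=False,
    # which records the sort/unsort indices inside the PackedSequence so
    # that nn.LSTM returns outputs and hidden states already in the
    # original batch order. The manual sort/unsort in
    # LSTM4VarLenSeq.forward is therefore only needed on older releases.
    # It reuses `input_embed` and `lens` defined above.
    plain_lstm = nn.LSTM(input_size=200, hidden_size=100, num_layers=3,
                         bidirectional=True, batch_first=True)
    packed = pack_padded_sequence(input_embed, lens, batch_first=True,
                                  enforce_sorted=False)
    packed_out, (h2, c2) = plain_lstm(packed)
    out2, _ = pad_packed_sequence(packed_out, batch_first=True)
    # Shapes match the module above: (5, 4, 200) and (6, 5, 100).
    print(out2.shape)
    print(h2.shape)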