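"""
PyTorch LSTM wrapper for variable-length (padded) sequences.

The wrapper sorts a padded batch by sequence length, packs it with
pack_padded_sequence, runs torch.nn.LSTM, and restores the original batch
order of the outputs and hidden states. An illustrative usage example lives
under ``if __name__ == '__main__'`` below.
"""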
import math
from typing import Sequence, Union, Callable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(10086)
# Type alias for activation functions that map a tensor to a tensor.
tensor_activation = Callable[[torch.Tensor], torch.Tensor]


class LSTM4VarLenSeq(nn.Module):
    def __init__(self, input_size, hidden_size,
                 num_layers=1, bias=True, bidirectional=False, init='orthogonal', take_last=True):
        """
        no dropout support
        batch_first support deprecated, the input and output tensors are
        provided as (batch, seq_len, feature).

        Args:
            input_size:
            hidden_size:
            num_layers:
            bias:
            bidirectional:
            init: ways to init the torch.nn.LSTM parameters,
                supports 'orthogonal' and 'uniform'
            take_last: 'True' if you only want the final hidden state
                otherwise 'False'
        """
        super(LSTM4VarLenSeq, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            bias=bias,
                            bidirectional=bidirectional)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.bidirectional = bidirectional
        self.init = init
        self.take_last = take_last
        self.batch_first = True  # Please don't modify this

        self.init_parameters()

    def init_parameters(self):
        """Orthogonal init generally yields better results than uniform init."""
        num_directions = 2 if self.bidirectional else 1
        if self.init == 'orthogonal':
            gain = 1  # use default value
            for nth in range(self.num_layers * num_directions):
                # w_ih, (4 * hidden_size x input_size)
                nn.init.orthogonal_(self.lstm.all_weights[nth][0], gain=gain)
                # w_hh, (4 * hidden_size x hidden_size)
                nn.init.orthogonal_(self.lstm.all_weights[nth][1], gain=gain)
                # b_ih, (4 * hidden_size)
                nn.init.zeros_(self.lstm.all_weights[nth][2])
                # b_hh, (4 * hidden_size)
                nn.init.zeros_(self.lstm.all_weights[nth][3])
        elif self.init == 'uniform':
            k = math.sqrt(1 / self.hidden_size)
            for nth in range(self.num_layers * num_directions):
                nn.init.uniform_(self.lstm.all_weights[nth][0], -k, k)
                nn.init.uniform_(self.lstm.all_weights[nth][1], -k, k)
                nn.init.zeros_(self.lstm.all_weights[nth][2])
                nn.init.zeros_(self.lstm.all_weights[nth][3])
        else:
            raise NotImplementedError('Unsupported initialization: {}'.format(self.init))

    def forward(self, x, x_len, hx=None):
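        """
        Run the LSTM over a padded batch of variable-length sequences.

        Args:
            x: padded input of shape (batch, max_seq_len, input_size).
            x_len: length of each sequence before padding, shape (batch,).
            hx: optional initial state (h_0, c_0), each of shape
                (num_layers * num_directions, batch, hidden_size).

        Returns:
            hn.squeeze(0) if take_last is True, otherwise (out, (hn, cn));
            all outputs are restored to the original (unsorted) batch order.
        """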
        # 1. Sort x and its corresponding length
        sorted_x_len, sorted_x_idx = torch.sort(x_len, descending=True)
        sorted_x = x[sorted_x_idx]
        # 2. Ready to unsort after LSTM forward pass
        # Note that PyTorch 0.4 has no argsort, but PyTorch 1.0 does.
        _, unsort_x_idx = torch.sort(sorted_x_idx, descending=False)

        # 3. Pack the sorted version of x and x_len, as required by the API.
        x_emb = pack_padded_sequence(sorted_x, sorted_x_len,
                                     batch_first=self.batch_first)

        # 4. Forward LSTM.
        # If an initial state is given, reorder its batch dimension to match
        # the sorted input.
        # out_packed.data.shape is (sum_of_valid_steps, num_directions * hidden_size).
        # See the doc of torch.nn.LSTM for details.
        if hx is not None:
            hx = (hx[0][:, sorted_x_idx], hx[1][:, sorted_x_idx])
        out_packed, (hn, cn) = self.lstm(x_emb, hx)

        # 5. unsort h
        # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
        hn = hn.permute(1, 0, 2)[unsort_x_idx]  # swap the first two dim
        hn = hn.permute(1, 0, 2)  # swap the first two again to recover
        if self.take_last:
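            # Note: squeeze(0) only removes the leading dimension when
            # num_layers * num_directions == 1, i.e. a single-layer,
            # unidirectional LSTM; otherwise hn is returned unchanged.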
            return hn.squeeze(0)
        else:
            # unpack: out
            # (batch, max_seq_len, num_directions * hidden_size)
            out, _ = pad_packed_sequence(out_packed,
                                         batch_first=self.batch_first)
            out = out[unsort_x_idx]
            # unpack: c
            # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
            cn = cn.permute(1, 0, 2)[unsort_x_idx]  # swap the first two dim
            cn = cn.permute(1, 0, 2)  # swap the first two again to recover
            return out, (hn, cn)


if __name__ == '__main__':
    # TODO: port the following examples to the test folder as unittest cases.

    # Unit test for LSTM variable length sequences
    # ================
    net = LSTM4VarLenSeq(200, 100,
                         num_layers=3, bias=True, bidirectional=True, init='orthogonal', take_last=False)

    inputs = torch.tensor([[1, 2, 3, 0],
                           [2, 3, 0, 0],
                           [2, 4, 3, 0],
                           [1, 4, 3, 0],
                           [1, 2, 3, 4]])
    embedding = nn.Embedding(num_embeddings=5, embedding_dim=200, padding_idx=0)
    lens = torch.LongTensor([3, 2, 3, 3, 4])

    input_embed = embedding(inputs)
    output, (h, c) = net(input_embed, lens)
    # (5, 4, 200): (batch, max_seq_len, num_directions * hidden_size), last layer only
    print(output.shape)
    # (6, 5, 100): (num_layers * num_directions, batch, hidden_size)
    print(h.shape)
    # (6, 5, 100): (num_layers * num_directions, batch, hidden_size)
    print(c.shape)
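
    # Illustrative sketch: with the default single-layer, unidirectional
    # configuration and take_last=True, forward returns only the final
    # hidden state in the original batch order, shape (batch, hidden_size).
    # The names net_last and h_last are used only for this example.
    net_last = LSTM4VarLenSeq(200, 100, take_last=True)
    h_last = net_last(input_embed, lens)
    # (5, 100): (batch, hidden_size)
    print(h_last.shape)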