"""
The probing model following the HEAR Benchmark

Authors:
  * Hear Team 2021
  * Leo 2022
"""


from typing import List, Optional, Tuple

import torch

import s3prl.nn.pooling as pooling

__all__ = ["HearFullyConnectedPrediction"]


class HearFullyConnectedPrediction(torch.nn.Module):
    """
    The specific prediction head used in the HEAR Benchmark.
    Modified from: https://github.com/hearbenchmark/hear-eval-kit/blob/855964977238e89dfc76394aa11c37010edb6f20/heareval/predictions/task_predictions.py#L142

    Args:
        input_size (int): the input feature dimension
        output_size (int): the output dimension, e.g. the number of classes
        hidden_size (int): hidden size shared across all hidden layers. Default: 1024
        hidden_layers (int): number of hidden layers, all of :code:`hidden_size`. Default: 2
        norm_after_activation (bool): whether to place the normalization after the
            activation instead of before it. Default: False
        dropout (float): dropout ratio. Default: 0.1
        initialization (str): initialization method name available in :obj:`torch.nn.init`.
            Default: xavier_uniform_
        hidden_norm (str): normalization class name available in :obj:`torch.nn`.
            Default: BatchNorm1d
        pooling_type (str): the pooling class name in :obj:`s3prl.nn.pooling`,
            e.g. MeanPooling. Default: None (no pooling; frame-level prediction)
        pooling_conf (dict): the arguments for initializing the pooling class.
            Default: empty dict
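
    Example:
        A minimal usage sketch (not part of the original source; it assumes a
        :code:`MeanPooling` class taking :code:`input_size` as its first
        argument is available in :obj:`s3prl.nn.pooling`)::

            import torch

            model = HearFullyConnectedPrediction(
                input_size=768, output_size=10, pooling_type="MeanPooling"
            )
            x = torch.randn(4, 100, 768)   # (batch_size, seq_len, input_size)
            x_len = torch.full((4,), 100)  # (batch_size, )
            y, y_len = model(x, x_len)     # y: (4, 10), y_len: all 1s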
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int = 1024,
        hidden_layers: int = 2,
        norm_after_activation: bool = False,
        dropout: float = 0.1,
        initialization: str = "xavier_uniform_",
        hidden_norm: str = "BatchNorm1d",
        pooling_type: Optional[str] = None,
        pooling_conf: Optional[dict] = None,
    ):
        super().__init__()
        self._input_size = input_size
        self._output_size = output_size
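        # Resolve the configured names into actual callables/classes, e.g.
        # "xavier_uniform_" -> torch.nn.init.xavier_uniform_ and
        # "BatchNorm1d" -> torch.nn.BatchNorm1d.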
        initialization = getattr(torch.nn.init, initialization)
        hidden_norm = getattr(torch.nn, hidden_norm)

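        # curdim tracks the current feature dimension as it flows through the
        # optional pooling layer and the hidden stack below.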
        curdim = input_size
        if pooling_type is not None:
            pooling_cls = getattr(pooling, pooling_type)
            self.pooling = pooling_cls(input_size, **(pooling_conf or {}))
            curdim = self.pooling.output_size

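        # Each hidden block is Linear -> [Norm] -> Dropout -> ReLU -> [Norm],
        # where norm_after_activation picks which side of the activation the
        # normalization sits on. Each Linear is initialized with the gain of
        # the activation feeding it: "linear" for the first layer, "relu"
        # for the rest.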
        hidden_modules: List[torch.nn.Module] = []
        last_activation = "linear"
        if hidden_layers:
            for i in range(hidden_layers):
                linear = torch.nn.Linear(curdim, hidden_size)
                initialization(
                    linear.weight,
                    gain=torch.nn.init.calculate_gain(last_activation),
                )
                hidden_modules.append(linear)
                if not norm_after_activation:
                    hidden_modules.append(hidden_norm(hidden_size))
                hidden_modules.append(torch.nn.Dropout(dropout))
                hidden_modules.append(torch.nn.ReLU())
                if norm_after_activation:
                    hidden_modules.append(hidden_norm(hidden_size))
                curdim = hidden_size
                last_activation = "relu"

            self.hidden = torch.nn.Sequential(*hidden_modules)
        else:
            self.hidden = torch.nn.Identity()  # type: ignore
        self.projection = torch.nn.Linear(curdim, output_size)

        initialization(
            self.projection.weight, gain=torch.nn.init.calculate_gain(last_activation)
        )

    @property
    def input_size(self) -> int:
        return self._input_size

    @property
    def output_size(self) -> int:
        return self._output_size

    def forward(self, x, x_len) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x (torch.FloatTensor): (batch_size, seq_len, input_size)
            x_len (torch.LongTensor): (batch_size, )

        Returns:
            tuple:

            1. y (torch.FloatTensor)
            2. y_len (torch.LongTensor)

            If :code:`pooling_type` is None, :code:`y` is (batch_size, seq_len, output_size) and :code:`y_len` is (batch_size, );
            otherwise, :code:`y` is (batch_size, output_size) and :code:`y_len` is (batch_size, ) filled with 1s.
        """
        if hasattr(self, "pooling"):
            x = self.pooling(x, x_len)
            x_len = torch.ones(len(x), dtype=torch.long, device=x.device)

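        # torch.nn.BatchNorm1d normalizes over dim 1, so frame-level (3-D)
        # input is flattened to (batch_size * seq_len, feature_size) before
        # the hidden stack and reshaped back after the projection.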
        shape = x.shape
        if len(shape) == 3:
            bs, ts, hidden_size = x.shape
            x = x.reshape(bs * ts, hidden_size)

        x = self.hidden(x)
        x = self.projection(x)

        if len(shape) == 3:
            x = x.reshape(bs, ts, -1)

        return x, x_len