dalgarak committed
Commit 51adcf4 · verified · 1 Parent(s): 361cba6

Upload 8 files


upload gbst-base-ds6x-newblock-1144k
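
For reference, config.json below registers the custom GBSWT5 classes through auto_map, so the checkpoint has to be loaded with trust_remote_code. A minimal loading sketch (the repository id is a placeholder, not part of this upload):

```python
# Hedged loading sketch; assumes transformers >= 4.33 and that "<repo-id>" is
# replaced with the actual Hub path of this repository.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained("<repo-id>")          # resolves to ByT5Tokenizer (byte-level)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "<repo-id>", trust_remote_code=True                   # required for the auto_map classes
)

inputs = tok("Hello, world!", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=32)
print(tok.decode(out[0], skip_special_tokens=True))
```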

config.json ADDED
@@ -0,0 +1,82 @@
1
+ {
2
+ "_name_or_path": "/home/jhshin/gbst-base-ds6x-newblock-972k-240517/",
3
+ "architectures": [
4
+ "GBSWT5ForConditionalGeneration"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_gbswt5.GBSWT5Config",
8
+ "AutoModel": "modeling_gbswt5.GBSWT5ForModel",
9
+ "AutoModelForSeq2SeqLM": "modeling_gbswt5.GBSWT5ForConditionalGeneration"
10
+ },
11
+ "d_ff": 3968,
12
+ "d_kv": 64,
13
+ "d_model": 1536,
14
+ "decoder_start_token_id": 0,
15
+ "dense_act_fn": "gelu_new",
16
+ "downsample_factor": 6,
17
+ "dropout_rate": 0.0,
18
+ "eos_token_id": 1,
19
+ "feed_forward_proj": "gated-gelu",
20
+ "gbst_batchnorm": false,
21
+ "gradient_checkpointing": false,
22
+ "initializer_factor": 0.05,
23
+ "is_encoder_decoder": true,
24
+ "is_gated_act": true,
25
+ "kv_heads": null,
26
+ "layer_norm_epsilon": 1e-06,
27
+ "max_subword_block_size": null,
28
+ "model_type": "gbswt5",
29
+ "num_decoder_layers": 6,
30
+ "num_heads": 12,
31
+ "num_layers": 18,
32
+ "output_past": true,
33
+ "pad_token_id": 0,
34
+ "relative_attention_max_distance": 128,
35
+ "relative_attention_num_buckets": 32,
36
+ "score_consensus_attn": true,
37
+ "subword_blocks": [
38
+ [
39
+ 1,
40
+ 0
41
+ ],
42
+ [
43
+ 2,
44
+ 0
45
+ ],
46
+ [
47
+ 3,
48
+ 0
49
+ ],
50
+ [
51
+ 4,
52
+ 0
53
+ ],
54
+ [
55
+ 5,
56
+ 0
57
+ ],
58
+ [
59
+ 6,
60
+ 0
61
+ ],
62
+ [
63
+ 7,
64
+ 0
65
+ ],
66
+ [
67
+ 8,
68
+ 0
69
+ ],
70
+ [
71
+ 9,
72
+ 0
73
+ ]
74
+ ],
75
+ "tie_word_embeddings": false,
76
+ "tokenizer_class": "ByT5Tokenizer",
77
+ "torch_dtype": "float32",
78
+ "transformers_version": "4.33.2",
79
+ "use_cache": true,
80
+ "vocab_size": 384,
81
+ "z_loss": 0.0001
82
+ }
configuration_gbswt5.py ADDED
@@ -0,0 +1,139 @@
1
+ """
2
+ GBSWT5 model configuration.
3
+
4
+ Copyright (C) 2023~ ETRI LIRS. Jong-hun Shin.
5
+ """
6
+
7
+ from typing import Mapping
8
+ from transformers.configuration_utils import PretrainedConfig
9
+ from transformers.onnx import OnnxSeq2SeqConfigWithPast
10
+ from transformers.utils import logging
11
+
12
+
13
+ logger = logging.get_logger(__name__)
14
+ _BLOCKS = (
15
+ (1, 0), (2, 0), (3, 0), (4, 0),
16
+ (6, 0), (9, 0),
17
+ #(12, 0), (12, 3), (12, 6), (12, 9)
18
+ )
19
+
20
+
21
+ class GBSWT5Config(PretrainedConfig):
22
+ """ Based on models.t5. configuration_t5. T5Config in hf Transformers. """
23
+ model_type = "gbswt5"
24
+ keys_to_ignore_at_inference = ["past_key_values"]
25
+ attribute_map = {"hidden_size": "d_model",
26
+ "num_attention_heads": "num_heads",
27
+ "num_hidden_layers": "num_layers"}
28
+
29
+ def __init__(
30
+ self,
31
+ vocab_size=384,
32
+ d_model=512,
33
+ d_kv=64,
34
+ d_ff=2048,
35
+ num_layers=6,
36
+ num_decoder_layers=None,
37
+ num_heads=8,
38
+ relative_attention_num_buckets=32,
39
+ relative_attention_max_distance=128,
40
+ dropout_rate=0.1,
41
+ layer_norm_epsilon=1e-6,
42
+ initializer_factor=1.0,
43
+ feed_forward_proj="relu",
44
+ is_encoder_decoder=True,
45
+ use_cache=True,
46
+ pad_token_id=0,
47
+ eos_token_id=1,
48
+ max_subword_block_size=None, # GBSWT-related options from here
49
+ subword_blocks=_BLOCKS,
50
+ downsample_factor=1,
51
+ score_consensus_attn=True,
52
+ z_loss=1e-4,
53
+ gbst_batchnorm=False,
54
+ **kwargs,
55
+ ):
56
+ self.vocab_size = vocab_size
57
+ self.d_model = d_model
58
+ self.d_kv = d_kv
59
+ self.d_ff = d_ff
60
+ self.num_layers = num_layers
61
+ self.num_decoder_layers = (
62
+ num_decoder_layers if num_decoder_layers is not None else self.num_layers
63
+ ) # default = symmetry
64
+ self.num_heads = num_heads
65
+ self.relative_attention_num_buckets = relative_attention_num_buckets
66
+ self.relative_attention_max_distance = relative_attention_max_distance
67
+ self.dropout_rate = dropout_rate
68
+ self.layer_norm_epsilon = layer_norm_epsilon
69
+ self.initializer_factor = initializer_factor
70
+ self.feed_forward_proj = feed_forward_proj
71
+ self.use_cache = use_cache
72
+
73
+ act_info = self.feed_forward_proj.split("-")
74
+ self.dense_act_fn = act_info[-1]
75
+ self.is_gated_act = act_info[0] == "gated"
76
+
77
+ # GBSWT-related configurations
78
+ self.max_subword_block_size = max_subword_block_size
79
+ self.subword_blocks = subword_blocks
80
+ self.downsample_factor = downsample_factor
81
+ self.score_consensus_attn = score_consensus_attn
82
+ self.gbst_batchnorm = gbst_batchnorm
83
+
84
+ # z_loss for computational stability.
85
+ # see https://github.com/tensorflow/mesh/blob \
86
+ # /fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
87
+ # (1) keeps the logits from drifting too far from 0, preventing the round-off
88
+ # errors that occur under bf16. (2) encourages the logits to be normalized log-probabilities.
89
+ self.z_loss = z_loss
90
+
91
+ if self.subword_blocks is not None and isinstance(self.subword_blocks, list):
92
+ for idx, elem in enumerate(self.subword_blocks):
93
+ self.subword_blocks[idx] = tuple(elem)
94
+ self.subword_blocks = tuple(self.subword_blocks)
95
+
96
+ if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
97
+ raise ValueError(
98
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
99
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
100
+ "'gated-gelu' or 'relu'"
101
+ )
102
+
103
+ # for backwards compatibility
104
+ if feed_forward_proj == "gated-gelu":
105
+ self.dense_act_fn = "gelu_new"
106
+
107
+ super().__init__(
108
+ pad_token_id=pad_token_id,
109
+ eos_token_id=eos_token_id,
110
+ is_encoder_decoder=is_encoder_decoder,
111
+ **kwargs,
112
+ )
113
+
114
+
115
+ class GBSWT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
116
+ """ just copy of T5OnnxConfig. """
117
+ @property
118
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
119
+ common_inputs = {
120
+ "input_ids": {0: "batch", 1: "encoder_sequence"},
121
+ "attention_mask": {0: "batch", 1: "encoder_sequence"},
122
+ }
123
+ if self.use_past:
124
+ common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
125
+ common_inputs["decoder_input_ids"] = {0: "batch"}
126
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
127
+ else:
128
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
129
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
130
+
131
+ if self.use_past:
132
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
133
+
134
+ return common_inputs
135
+
136
+ @property
137
+ def default_onnx_opset(self) -> int:
138
+ return 13
139
+
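
As a quick check of the activation parsing and default block tuple defined above, a small sanity sketch (assuming configuration_gbswt5.py is importable from the working directory and transformers is installed):

```python
# Hedged sanity check of GBSWT5Config; the local import path is an assumption.
from configuration_gbswt5 import GBSWT5Config

cfg = GBSWT5Config(feed_forward_proj="gated-gelu", downsample_factor=6)
print(cfg.is_gated_act, cfg.dense_act_fn)  # -> True gelu_new (backwards-compat mapping)
print(cfg.subword_blocks)                  # -> ((1, 0), (2, 0), (3, 0), (4, 0), (6, 0), (9, 0))
```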
gbst.py ADDED
@@ -0,0 +1,254 @@
1
+ """
2
+ Gradient-based Subword Tokenization(GBST) Layer implementation.
3
+
4
+ based on lucidrains/charformer-pytorch implementation,
5
+ which is distributed under the MIT License.
6
+
7
+ original code location:
8
+ https://github.com/lucidrains/charformer-pytorch/charformer_pytorch.py
9
+
10
+ copyright (c) 2023~, ETRI LIRS. Jong-hun Shin.
11
+ """
12
+ import math
13
+ import functools
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+ from typing import Optional
18
+
19
+ from torch import einsum, nn, Tensor
20
+ from transformers.utils import logging
21
+ from einops.layers.torch import Rearrange
22
+ from einops import rearrange, repeat
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ # Block definition
28
+ _BLOCKS = (
29
+ (1, 0), (2, 0), (3, 0), (4, 0),
30
+ (6, 0), (9, 0),
31
+ #(12, 0), (12, 3), (12, 6), (12, 9)
32
+ )
33
+
34
+ @torch.jit.script
35
+ def pad_to_multiple(in_tensor:Tensor, multiple:int, seq_dim:int,
36
+ dim:int, value:Optional[float]):
37
+ seqlen = in_tensor.shape[seq_dim]
38
+ padded_len = math.ceil(seqlen / multiple) * multiple
39
+ if seqlen == padded_len:
40
+ return in_tensor
41
+ pad_offset = (0,) * (-1 - dim) * 2
42
+ if len(pad_offset) == 0:
43
+ return F.pad(in_tensor, (0, padded_len - seqlen), value=value)
44
+ # unpack 2 dims
45
+ d1, d2 = pad_offset
46
+ return F.pad(in_tensor, (d1, d2, 0, padded_len - seqlen), value=value)
47
+
48
+
49
+
50
+
51
+ class Depthwise1dConv(nn.Module):
52
+ def __init__(self, in_dim, out_dim, krnl_size, use_bn=False):
53
+ super().__init__()
54
+ self.use_bn = use_bn
55
+ self.convol = nn.Conv1d(in_dim, out_dim, krnl_size, groups=in_dim)
56
+ # EXPERIMENTAL: add BatchNorm Layer
57
+ if self.use_bn:
58
+ self.bn = nn.BatchNorm1d(out_dim, eps=1e-05,)
59
+ self.proj = nn.Conv1d(out_dim, out_dim, 1)
60
+
61
+ @torch.cuda.amp.autocast(enabled=False, dtype=torch.float32)
62
+ def forward(self, in_tensor):
63
+ in_tensor = self.convol(in_tensor)
64
+ if self.use_bn:
65
+ in_tensor = self.bn(in_tensor)
66
+ return self.proj(in_tensor)
67
+
68
+ def _init_weights(self, factor:float=0.05):
69
+ logger.debug(f"1dConv-Weight initialize called, before: {self.convol.weight.data}")
70
+ self.convol.weight.data.normal_(mean=0.0, std=factor * 1.0)
71
+ self.proj.weight.data.normal_(mean=0.0, std=factor * 1.0)
72
+ logger.debug(f"1dConv-Weight initialize called, after: {self.convol.weight.data}")
73
+
74
+
75
+ class Padding(nn.Module):
76
+ def __init__(self, padding, value=0):
77
+ super().__init__()
78
+ self.padding = padding
79
+ self.value = value
80
+
81
+ def forward(self, in_tensor):
82
+ return F.pad(in_tensor, self.padding, value=self.value)
83
+
84
+
85
+ class GBSWT(nn.Module):
86
+ """ Gradient-based Sub-Word Tokenizer implementation. """
87
+ def __init__(self, embed_tokens,
88
+ max_block_size=None,
89
+ blocks=_BLOCKS,
90
+ downsample_factor=1,
91
+ score_consensus_attn=True,
92
+ use_bn=False,):
93
+ super().__init__()
94
+ num_tokens, dim = embed_tokens.weight.shape
95
+
96
+ assert (max_block_size is not None) ^ (blocks is not None), \
97
+ 'max_block_size or blocks must be given.'
98
+ if blocks is None:
99
+ self.blocks = tuple(map(lambda elem: (elem, 0), range(1, max_block_size+1)))
100
+ else:
101
+ if not isinstance(blocks, tuple):
102
+ raise ValueError('blocks must be assigned as a tuple')
103
+ self.blocks = tuple(map(lambda elem: elem if isinstance(elem, tuple) else (elem, 0), blocks))
104
+ if not all([(offset < block_size) for block_size, offset in self.blocks]):
105
+ raise ValueError('Offset must be smaller than given block size.')
106
+ max_block_size = max(list(map(lambda x: x[0], self.blocks)))
107
+
108
+ assert downsample_factor <= max_block_size, \
109
+ 'downsample factor must be less than or equal to max_block_size.'
110
+
111
+ self.downsample_factor = downsample_factor
112
+ self.score_consensus_attn = score_consensus_attn
113
+ self.use_bn = use_bn
114
+ logger.debug(f"GBSWT Subword Block Combinations: {self.blocks}")
115
+ logger.debug(f"GBSWT Downsampling factor: {self.downsample_factor}, use BatchNorm: {self.use_bn}")
116
+
117
+ def lcm(*num):
118
+ return int(functools.reduce(lambda x, y: int((x * y) / math.gcd(x, y)), num, 1))
119
+
120
+ self.block_pad_multiple = lcm(*[block_size for block_size, _ in self.blocks])
121
+ #print(f"block_pad_multiple: {self.block_pad_multiple}")
122
+
123
+ # layer definition
124
+ self.embeds = embed_tokens
125
+ self.positional_convol = nn.Sequential(
126
+ Padding((0, 0, 0, max_block_size-1)),
127
+ Rearrange('b s d -> b d s'),
128
+ Depthwise1dConv(dim, dim, krnl_size=max_block_size, use_bn=self.use_bn,),
129
+ Rearrange('b d s -> b s d'))
130
+ self.cand_scoring = nn.Sequential(
131
+ nn.Linear(dim, 1),
132
+ Rearrange('... () -> ...'))
133
+
134
+ def _init_weights(self, factor:float=0.05):
135
+ self.positional_convol[2]._init_weights(factor)
136
+ #print(f"GBSTW weight initialization called: before: {self.cand_scoring[0].weight.data}")
137
+ self.cand_scoring[0].weight.data.normal_(mean=0.0, std=factor * 1.0)
138
+ #print(f"GBSTW weight initialization called: after: {self.cand_scoring[0].weight.data}")
139
+
140
+ def get_blocks(self):
141
+ """ return GBST candidate blocking list. """
142
+ return self.blocks
143
+
144
+ @torch.cuda.amp.autocast()
145
+ def forward(self, in_tensor, attention_mask=None):
146
+ b, s = in_tensor.shape
147
+ #print(f"initial shape: b, s : {b}, {s}, in_tensor.shape: {in_tensor.shape}")
148
+ mask = attention_mask
149
+ #print(f"mask: {mask}")
150
+ block_multi, ds_factor = self.block_pad_multiple, self.downsample_factor
151
+
152
+ in_tensor = self.embeds(in_tensor)
153
+ in_tensor = self.positional_convol(in_tensor)
154
+ in_tensor = pad_to_multiple(in_tensor, block_multi,
155
+ seq_dim=1, dim=-2, value=0.0)
156
+ if mask is not None:
157
+ mask = pad_to_multiple(mask, block_multi,
158
+ seq_dim=1, dim=-1, value=False)
159
+
160
+ def _masked_mean(in_tensor:Tensor, mask:Tensor, dim:int=-1):
161
+ len_diff = len(in_tensor.shape) - len(mask.shape)
162
+ mask = torch.unsqueeze(mask, dim=-len_diff)
163
+ in_tensor.masked_fill_(~(mask.bool()), 0.)
164
+
165
+ total_elems = mask.sum(dim=dim)
166
+ mean = in_tensor.sum(dim=dim) / total_elems.clamp(min=1.)
167
+ mean.masked_fill_((total_elems == 0), 0.)
168
+ return mean.float()
169
+
170
+ block_reprs, block_masks = [], []
171
+
172
+ # now set up the candidates by cloning the input sequence
173
+ for block_size, offset in self.blocks:
174
+ block_in = in_tensor.clone()
175
+ if mask is not None:
176
+ block_mask = mask.clone()
177
+ need_padding = offset > 0
178
+
179
+ if need_padding:
180
+ loff, roff = (block_size - offset), offset
181
+ #print(f"loff: {loff}, roff: {roff}")
182
+ block_in = F.pad(block_in, (0, 0, loff, roff), value=0.0)
183
+ if mask is not None:
184
+ block_mask = F.pad(block_mask, (0, 0, loff, roff), value=False)
185
+
186
+ blks = rearrange(block_in, 'b (s m) d -> b s m d', m=block_size)
187
+ if mask is not None:
188
+ mask_blks = rearrange(block_mask, 'b (s m) -> b s m', m=block_size)
189
+ blk_repr = _masked_mean(blks, mask_blks, dim=-2)
190
+ else:
191
+ blk_repr = blks.mean(dim=-2)
192
+
193
+ blk_repr = repeat(blk_repr, 'b s d -> b (s m) d', m=block_size)
194
+
195
+ if need_padding:
196
+ blk_repr = blk_repr[:, loff:-roff]
197
+
198
+ block_reprs.append(blk_repr)
199
+
200
+ if mask is not None:
201
+ mask_blks = torch.any(mask_blks, dim=-1)
202
+ mask_blks = repeat(mask_blks, 'b s -> b (s m)', m=block_size)
203
+ if need_padding:
204
+ mask_blks = mask_blks[:, loff:-roff]
205
+ block_masks.append(mask_blks)
206
+
207
+ # stack them all
208
+ block_reprs = torch.stack(block_reprs, dim=2,)
209
+ scores = self.cand_scoring(block_reprs)
210
+
211
+ if mask is not None:
212
+ block_masks = torch.stack(block_masks, dim=2)
213
+ max_neg_val = -torch.finfo(scores.dtype).max
214
+ scores = scores.masked_fill(~block_masks, max_neg_val)
215
+
216
+ scores = scores.softmax(dim=2)
217
+
218
+ # cheap consensus attention, as equation (5) in paper.
219
+ if self.score_consensus_attn:
220
+ score_sim = einsum('b i d, b j d -> b i j', scores, scores)
221
+
222
+ if mask is not None:
223
+ cross_mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
224
+ max_neg_val = -torch.finfo(score_sim.dtype).max
225
+ score_sim = score_sim.masked_fill((~(cross_mask.bool())), max_neg_val)
226
+
227
+ score_attn = score_sim.softmax(dim=-1)
228
+ scores = einsum('b i j, b j m -> b i m', score_attn, scores)
229
+
230
+ scores = rearrange(scores, 'b n m -> b n m ()')
231
+ in_tensor = (block_reprs * scores).sum(dim=2)
232
+
233
+ @torch.jit.script
234
+ def _reshape_input_tensor(in_tensor:Tensor, s:int, d:int):
235
+ # get divisible length to pad
236
+ m = int(math.ceil(s / d) * d)
237
+ #print(f"_reshape_input_tensor: {m}")
238
+ return in_tensor[:, :m]
239
+
240
+ in_tensor = _reshape_input_tensor(in_tensor, s, ds_factor)
241
+ if mask is not None:
242
+ mask = _reshape_input_tensor(mask, s, ds_factor)
243
+
244
+ # downsample with mean pooling
245
+ in_tensor = rearrange(in_tensor, 'b (n m) d -> b n m d', m=ds_factor)
246
+ if mask is not None:
247
+ mask = rearrange(mask, 'b (n m) -> b n m', m=ds_factor)
248
+ in_tensor = _masked_mean(in_tensor, mask, dim=2)
249
+ mask = torch.any(mask, dim=-1)
250
+ else:
251
+ in_tensor = in_tensor.mean(dim=-2)
252
+
253
+ # since a tuple is returned, the caller's forward() must take [0] and swap it in
254
+ return in_tensor, mask
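
To make the downsampling above concrete, a small shape check (a sketch only; it assumes gbst.py is importable locally and that torch and einops are installed):

```python
# Hedged shape check for the GBSWT layer: with downsample_factor=3, a length-12
# byte sequence is mean-pooled down to 4 positions.
import torch
from torch import nn
from gbst import GBSWT

embed = nn.Embedding(384, 512)                         # byte-level vocab, d_model=512
layer = GBSWT(embed, blocks=((1, 0), (2, 0), (3, 0)), downsample_factor=3)

ids = torch.randint(0, 384, (2, 12))                   # (batch, seq_len)
reprs, mask = layer(ids)                               # no attention_mask -> mask stays None
print(reprs.shape)                                     # expected: torch.Size([2, 4, 512])
# when an attention_mask is given, it is pooled to the same downsampled length.
```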
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "decoder_start_token_id": 0,
4
+ "eos_token_id": 1,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.33.2"
7
+ }
modeling_gbswt5.py ADDED
@@ -0,0 +1,766 @@
1
+ """
2
+ hf transformers-compatible GBST + T5 Model implementation.
3
+
4
+ several methods are copied from huggingface/transformers/models/t5/modeling_t5.py
5
+ as the implementation standard for compatibility (version 4.28.1).
6
+
7
+ hf transformers' modeling_t5.py file is distributed under Apache 2.0 License.
8
+
9
+ Copyright (C) 2023, ETRI LIRS, Jong-hun Shin.
10
+ """
11
+ import copy
+ import warnings
12
+
13
+ from typing import Optional, Union, Tuple
14
+
15
+ import torch
16
+
17
+ from torch import nn
+ from torch.utils.checkpoint import checkpoint
18
+ from transformers import add_start_docstrings
19
+ from transformers.utils import logging
20
+ from transformers.modeling_outputs import (
21
+ BaseModelOutput,
22
+ BaseModelOutputWithPastAndCrossAttentions,
23
+ Seq2SeqLMOutput,
24
+ Seq2SeqModelOutput,
25
+ )
26
+ from transformers.models.t5.modeling_t5 import (
27
+ T5LayerNorm, T5Block, T5Stack,
28
+ T5Model, T5PreTrainedModel, T5ForConditionalGeneration, T5EncoderModel,
29
+ T5DenseActDense, T5DenseGatedActDense, T5Attention,
30
+ T5_START_DOCSTRING
31
+ )
32
+
33
+ from .configuration_gbswt5 import GBSWT5Config
34
+ from .gbst import GBSWT
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ class GBSWT5PreTrainedModel(T5PreTrainedModel):
41
+ config_class = GBSWT5Config
42
+ base_model_prefix = "GBSWT5"
43
+ is_parallelizable = True
44
+ supports_gradient_checkpointing = True
45
+ _no_split_modules = ["T5Block"]
46
+ _keep_in_fp32_modules = ["wo"]
47
+
48
+ def _init_weights(self, module):
49
+ """Initialize the weights. 대부분은 T5PreTrainedModel을 따른다. """
50
+ factor = self.config.initializer_factor # Used for testing weights initialization
51
+ if isinstance(module, T5LayerNorm):
52
+ module.weight.data.fill_(factor * 1.0)
53
+ elif isinstance(
54
+ module,
55
+ ( GBSWT5Model, GBSWT5ForConditionalGeneration, GBSWT5EncoderModel,),
56
+ ):
57
+ # Mesh TensorFlow embeddings initialization
58
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
59
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
60
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
61
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
62
+ if hasattr(module, "qa_outputs"):
63
+ module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
64
+ module.qa_outputs.bias.data.zero_()
65
+ elif isinstance(module, T5DenseActDense):
66
+ # Mesh TensorFlow FF initialization
67
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
68
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
69
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
70
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
71
+ module.wi.bias.data.zero_()
72
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
73
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
74
+ module.wo.bias.data.zero_()
75
+ elif isinstance(module, T5DenseGatedActDense):
76
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
77
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
78
+ module.wi_0.bias.data.zero_()
79
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
80
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
81
+ module.wi_1.bias.data.zero_()
82
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
83
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
84
+ module.wo.bias.data.zero_()
85
+ elif isinstance(module, T5Attention):
86
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
87
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
88
+ d_model = self.config.d_model
89
+ key_value_proj_dim = self.config.d_kv
90
+ n_heads = self.config.num_heads
91
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
92
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
93
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
94
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
95
+ if module.has_relative_attention_bias:
96
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
97
+ elif isinstance(module, GBSWT):
98
+ module._init_weights(factor)
99
+
100
+
101
+ class GBSWT5Stack(GBSWT5PreTrainedModel):
102
+ """ implement GBST-enabled T5Model, based on HF Transformers's T5Stack. """
103
+ def __init__(self, config: GBSWT5Config, embed_tokens :nn.Embedding=None):
104
+ # initialization follows the parent. The inheritance is a bit awkward; it may actually deserve a separate definition.
105
+ super().__init__(config)
106
+
107
+ # override embed_tokens, apply GBWST
108
+ self.embed_tokens = GBSWT(embed_tokens=embed_tokens,
109
+ max_block_size=config.max_subword_block_size,
110
+ blocks=config.subword_blocks,
111
+ downsample_factor=config.downsample_factor,
112
+ score_consensus_attn=config.score_consensus_attn,
113
+ use_bn=config.gbst_batchnorm,)
114
+ self.is_decoder = config.is_decoder
115
+
116
+ self.block = nn.ModuleList(
117
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
118
+ )
119
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
120
+ self.dropout = nn.Dropout(config.dropout_rate)
121
+
122
+ # Initialize weights and apply final processing, same as T5 Stack.
123
+ self.post_init()
124
+ # for Model Parallel
125
+ self.model_parallel = False
126
+ self.device_map = False
127
+ self.gradient_checkpointing = False
128
+ self.downsample_factor = config.downsample_factor
129
+
130
+ def forward(self,
131
+ input_ids=None,
132
+ attention_mask=None,
133
+ encoder_hidden_states=None,
134
+ encoder_attention_mask=None,
135
+ inputs_embeds=None,
136
+ head_mask=None,
137
+ cross_attn_head_mask=None,
138
+ past_key_values=None,
139
+ use_cache=None,
140
+ output_attentions=None,
141
+ output_hidden_states=None,
142
+ return_dict=None,
143
+ ):
144
+ """ GBST 파트를 제외하면, T5Stack.forward() 구현을 그대로 복제하였다. """
145
+ # Model parallel
146
+ if self.model_parallel:
147
+ torch.cuda.set_device(self.first_device)
148
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
149
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
150
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
151
+ output_hidden_states = (
152
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
153
+ )
154
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
155
+
156
+ if input_ids is not None and inputs_embeds is not None:
157
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
158
+ raise ValueError(
159
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
160
+ )
161
+ elif input_ids is not None:
162
+ input_shape = input_ids.size()
163
+ input_ids = input_ids.view(-1, input_shape[-1])
164
+ elif inputs_embeds is not None:
165
+ input_shape = inputs_embeds.size()[:-1]
166
+ else:
167
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
168
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
169
+
170
+ if inputs_embeds is None:
171
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
172
+ #print(f"old: {input_shape}")
173
+ inputs_embeds, attention_mask = self.embed_tokens(input_ids, attention_mask)
174
+ # for downsample_factor > 1
175
+ input_shape = inputs_embeds.size()[:-1]
176
+ #print(f"new: {input_shape}")
177
+
178
+ batch_size, seq_length = input_shape
179
+ #print(f"bs: {batch_size}, sl: {seq_length}")
180
+
181
+ # required mask seq length can be calculated via length of past
182
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
183
+ #print(f"mask_seq_length: {mask_seq_length}")
184
+
185
+ if use_cache is True:
186
+ assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
187
+
188
+ if attention_mask is None:
189
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
190
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
191
+ encoder_seq_length = encoder_hidden_states.shape[1]
192
+ encoder_attention_mask = torch.ones(
193
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
194
+ )
195
+
196
+ # initialize past_key_values with `None` if past does not exist
197
+ if past_key_values is None:
198
+ past_key_values = [None] * len(self.block)
199
+
200
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
201
+ # ourselves in which case we just need to make it broadcastable to all heads.
202
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
203
+
204
+ # If a 2D or 3D attention mask is provided for the cross-attention
205
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
206
+ if self.is_decoder and encoder_hidden_states is not None:
207
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
208
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
209
+ if encoder_attention_mask is None:
210
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
211
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
212
+ else:
213
+ encoder_extended_attention_mask = None
214
+
215
+ if self.gradient_checkpointing and self.training:
216
+ if use_cache:
217
+ logger.warning_once(
218
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
219
+ )
220
+ use_cache = False
221
+
222
+ # Prepare head mask if needed
223
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
224
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
225
+ present_key_value_states = () if use_cache else None
226
+ all_hidden_states = () if output_hidden_states else None
227
+ all_attentions = () if output_attentions else None
228
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
229
+ position_bias = None
230
+ encoder_decoder_position_bias = None
231
+
232
+ hidden_states = self.dropout(inputs_embeds)
233
+
234
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
235
+ layer_head_mask = head_mask[i]
236
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
237
+ # Model parallel
238
+ if self.model_parallel:
239
+ torch.cuda.set_device(hidden_states.device)
240
+ # Ensure that attention_mask is always on the same device as hidden_states
241
+ if attention_mask is not None:
242
+ attention_mask = attention_mask.to(hidden_states.device)
243
+ if position_bias is not None:
244
+ position_bias = position_bias.to(hidden_states.device)
245
+ if encoder_hidden_states is not None:
246
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
247
+ if encoder_extended_attention_mask is not None:
248
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
249
+ if encoder_decoder_position_bias is not None:
250
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
251
+ if layer_head_mask is not None:
252
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
253
+ if cross_attn_layer_head_mask is not None:
254
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
255
+ if output_hidden_states:
256
+ all_hidden_states = all_hidden_states + (hidden_states,)
257
+
258
+ if self.gradient_checkpointing and self.training:
259
+
260
+ def create_custom_forward(module):
261
+ def custom_forward(*inputs):
262
+ return tuple(module(*inputs, use_cache, output_attentions))
263
+
264
+ return custom_forward
265
+
266
+ layer_outputs = checkpoint(
267
+ create_custom_forward(layer_module),
268
+ hidden_states,
269
+ extended_attention_mask,
270
+ position_bias,
271
+ encoder_hidden_states,
272
+ encoder_extended_attention_mask,
273
+ encoder_decoder_position_bias,
274
+ layer_head_mask,
275
+ cross_attn_layer_head_mask,
276
+ None, # past_key_value is always None with gradient checkpointing
277
+ )
278
+ else:
279
+ layer_outputs = layer_module(
280
+ hidden_states,
281
+ attention_mask=extended_attention_mask,
282
+ position_bias=position_bias,
283
+ encoder_hidden_states=encoder_hidden_states,
284
+ encoder_attention_mask=encoder_extended_attention_mask,
285
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
286
+ layer_head_mask=layer_head_mask,
287
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
288
+ past_key_value=past_key_value,
289
+ use_cache=use_cache,
290
+ output_attentions=output_attentions,
291
+ )
292
+
293
+ # layer_outputs is a tuple with:
294
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
295
+ if use_cache is False:
296
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
297
+
298
+ hidden_states, present_key_value_state = layer_outputs[:2]
299
+
300
+ # We share the position biases between the layers - the first layer store them
301
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
302
+ # (cross-attention position bias), (cross-attention weights)
303
+ position_bias = layer_outputs[2]
304
+ if self.is_decoder and encoder_hidden_states is not None:
305
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
306
+ # append next layer key value states
307
+ if use_cache:
308
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
309
+
310
+ if output_attentions:
311
+ all_attentions = all_attentions + (layer_outputs[3],)
312
+ if self.is_decoder:
313
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
314
+
315
+ # Model Parallel: If it's the last layer for that device, put things on the next device
316
+ if self.model_parallel:
317
+ for k, v in self.device_map.items():
318
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
319
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
320
+
321
+ hidden_states = self.final_layer_norm(hidden_states)
322
+ hidden_states = self.dropout(hidden_states)
323
+
324
+ # Add last layer
325
+ if output_hidden_states:
326
+ all_hidden_states = all_hidden_states + (hidden_states,)
327
+
328
+ if not return_dict:
329
+ return tuple(
330
+ v
331
+ for v in [
332
+ hidden_states,
333
+ present_key_value_states,
334
+ all_hidden_states,
335
+ all_attentions,
336
+ all_cross_attentions,
337
+ ]
338
+ if v is not None
339
+ ), attention_mask
340
+
341
+ # must also return the downsampled attention_mask
342
+ return BaseModelOutputWithPastAndCrossAttentions(
343
+ last_hidden_state=hidden_states,
344
+ past_key_values=present_key_value_states,
345
+ hidden_states=all_hidden_states,
346
+ attentions=all_attentions,
347
+ cross_attentions=all_cross_attentions,
348
+ ), attention_mask
349
+
350
+ def get_input_embeddings(self):
351
+ return self.embed_tokens.embeds
352
+
353
+ def set_input_embeddings(self, new_embeddings):
354
+ self.embed_tokens.embeds = new_embeddings
355
+
356
+
357
+ GBSWT5Stack.parallelize = T5Stack.parallelize
358
+ GBSWT5Stack.deparallelize = T5Stack.deparallelize
359
+
360
+
361
+ class GBSWT5Model(GBSWT5PreTrainedModel):
362
+ _keys_to_ignore_on_load_unexpected = [
363
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
364
+ ]
365
+ _tied_weights_keys = ["encoder.embed_tokens.embeds.weight", "decoder_embed_tokens.embeds.weight"]
366
+
367
+ def __init__(self, config: GBSWT5Config):
368
+ """ override T5Model """
369
+ # override some default missing parameters for pretrained ByT5 models (e.g. google/byt5-small)
370
+ if not hasattr(config, 'max_subword_block_size'):
371
+ config.max_subword_block_size = None
372
+ if not hasattr(config, 'subword_blocks'):
373
+ config.subword_blocks = ((1, 0), (2, 0), (3, 0), (6, 0), (9, 0),)
374
+ if not hasattr(config, 'downsample_factor'):
375
+ config.downsample_factor = 1
376
+ if not hasattr(config, 'score_consensus_attn'):
377
+ config.score_consensus_attn = True
378
+
379
+ super().__init__(config)
380
+
381
+ # embeddings are shared, as in vanilla T5
382
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
383
+
384
+ encoder_cfg = copy.deepcopy(config)
385
+ encoder_cfg.is_decoder = False
386
+ encoder_cfg.use_cache = False
387
+ encoder_cfg.is_encoder_decoder = False
388
+ self.encoder = GBSWT5Stack(encoder_cfg, self.shared)
389
+
390
+ # the embedding matrix is shared, but GBSWT must not be
391
+ # applied to the decoder.
392
+ decoder_cfg = copy.deepcopy(config)
393
+ decoder_cfg.is_decoder = True
394
+ decoder_cfg.is_encoder_decoder = False
395
+ decoder_cfg.num_layers = config.num_decoder_layers
396
+ self.decoder = T5Stack(decoder_cfg, self.shared)
397
+
398
+ self.post_init()
399
+
400
+ self.model_parallel = False
401
+ self.device_map = None
402
+
403
+ def forward(self,
404
+ input_ids: Optional[torch.LongTensor] = None,
405
+ attention_mask: Optional[torch.FloatTensor] = None,
406
+ decoder_input_ids: Optional[torch.LongTensor] = None,
407
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
408
+ head_mask: Optional[torch.FloatTensor] = None,
409
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
410
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
411
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
412
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
413
+ inputs_embeds: Optional[torch.Tensor] = None,
414
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
415
+ use_cache: Optional[bool] = None,
416
+ output_attentions: Optional[bool] = None,
417
+ output_hidden_states: Optional[bool] = None,
418
+ return_dict: Optional[bool] = None,
419
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
420
+ """
421
+ The important point is that the attention_mask changes when downsampling is applied,
422
+ so this change must be propagated. Copied from hf transformers 4.29.1.
423
+ """
424
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
425
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
426
+
427
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
428
+ if head_mask is not None and decoder_head_mask is None:
429
+ if self.config.num_layers == self.config.num_decoder_layers:
430
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
431
+ decoder_head_mask = head_mask
432
+
433
+ # Encode if needed (training, first prediction pass)
434
+ if encoder_outputs is None:
435
+ encoder_outputs, attention_mask = self.encoder(
436
+ input_ids=input_ids,
437
+ attention_mask=attention_mask,
438
+ inputs_embeds=inputs_embeds,
439
+ head_mask=head_mask,
440
+ output_attentions=output_attentions,
441
+ output_hidden_states=output_hidden_states,
442
+ return_dict=return_dict,
443
+ )
444
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
445
+ # inference mode (e.g. .generate()) - must unwrap the encoder output tuple
446
+ encoder_outputs, attention_mask = encoder_outputs
447
+ encoder_outputs = BaseModelOutput(
448
+ last_hidden_state=encoder_outputs[0],
449
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
450
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
451
+ )
452
+
453
+ hidden_states = encoder_outputs[0]
454
+
455
+ # Set device for model parallelism
456
+ if self.model_parallel:
457
+ torch.cuda.set_device(self.decoder.first_device)
458
+ hidden_states = hidden_states.to(self.decoder.first_device)
459
+ if decoder_input_ids is not None:
460
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
461
+ if attention_mask is not None:
462
+ attention_mask = attention_mask.to(self.decoder.first_device)
463
+ if decoder_attention_mask is not None:
464
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
465
+
466
+ # Decode
467
+ decoder_outputs = self.decoder(
468
+ input_ids=decoder_input_ids,
469
+ attention_mask=decoder_attention_mask,
470
+ inputs_embeds=decoder_inputs_embeds,
471
+ past_key_values=past_key_values,
472
+ encoder_hidden_states=hidden_states,
473
+ encoder_attention_mask=attention_mask,
474
+ head_mask=decoder_head_mask,
475
+ cross_attn_head_mask=cross_attn_head_mask,
476
+ use_cache=use_cache,
477
+ output_attentions=output_attentions,
478
+ output_hidden_states=output_hidden_states,
479
+ return_dict=return_dict,
480
+ )
481
+
482
+ if not return_dict:
483
+ return decoder_outputs + encoder_outputs
484
+
485
+ return Seq2SeqModelOutput(
486
+ last_hidden_state=decoder_outputs.last_hidden_state,
487
+ past_key_values=decoder_outputs.past_key_values,
488
+ decoder_hidden_states=decoder_outputs.hidden_states,
489
+ decoder_attentions=decoder_outputs.attentions,
490
+ cross_attentions=decoder_outputs.cross_attentions,
491
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
492
+ encoder_hidden_states=encoder_outputs.hidden_states,
493
+ encoder_attentions=encoder_outputs.attentions,
494
+ )
495
+
496
+
497
+ GBSWT5Model.parallelize = T5Model.parallelize
498
+ GBSWT5Model.deparallelize = T5Model.deparallelize
499
+ GBSWT5Model.get_input_embeddings = T5Model.get_input_embeddings
500
+ GBSWT5Model.set_input_embeddings = T5Model.set_input_embeddings
501
+ GBSWT5Model.get_encoder = T5Model.get_encoder
502
+ GBSWT5Model._prune_heads = T5Model._prune_heads
503
+
504
+
505
+ @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
506
+ class GBSWT5ForConditionalGeneration(GBSWT5PreTrainedModel):
507
+ _keys_to_ignore_on_load_unexpected = [
508
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
509
+ ]
510
+ _tied_weights_keys = ["encoder.embed_tokens.embeds.weight",
511
+ "decoder_embed_tokens.embeds.weight",
512
+ "lm_head.weight"]
513
+
514
+ def __init__(self, config: GBSWT5Config):
515
+ # override some default missing parameters for pretrained ByT5 models (e.g. google/byt5-small)
516
+ if not hasattr(config, 'max_subword_block_size'):
517
+ config.max_subword_block_size = None
518
+ if not hasattr(config, 'subword_blocks'):
519
+ config.subword_blocks = ((1, 0), (2, 0), (3, 0), (6, 0), (9, 0),)
520
+ if not hasattr(config, 'downsample_factor'):
521
+ config.downsample_factor = 1
522
+ if not hasattr(config, 'score_consensus_attn'):
523
+ config.score_consensus_attn = True
524
+
525
+ # inherit the grandparent's init as-is; the rest follows T5ForConditionalGeneration
526
+ super().__init__(config)
527
+
528
+ self.model_dim = config.d_model
529
+
530
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
531
+
532
+ encoder_cfg = copy.deepcopy(config)
533
+ encoder_cfg.is_decoder = False
534
+ encoder_cfg.use_cache = False
535
+ encoder_cfg.is_encoder_decoder = False
536
+ self.encoder = GBSWT5Stack(encoder_cfg, self.shared)
537
+
538
+ # the embedding matrix is shared, but GBSWT must not be
539
+ # applied to the decoder.
540
+ decoder_cfg = copy.deepcopy(config)
541
+ decoder_cfg.is_decoder = True
542
+ decoder_cfg.is_encoder_decoder = False
543
+ decoder_cfg.num_layers = config.num_decoder_layers
544
+ self.decoder = T5Stack(decoder_cfg, self.shared)
545
+
546
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
547
+
548
+ # Initialize weights and apply final processing
549
+ self.post_init()
550
+
551
+ # Model parallel
552
+ self.model_parallel = False
553
+ self.device_map = None
554
+
555
+ def forward(self,
556
+ input_ids: Optional[torch.LongTensor] = None,
557
+ attention_mask: Optional[torch.FloatTensor] = None,
558
+ decoder_input_ids: Optional[torch.LongTensor] = None,
559
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
560
+ head_mask: Optional[torch.FloatTensor] = None,
561
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
562
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
563
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
564
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
565
+ inputs_embeds: Optional[torch.FloatTensor] = None,
566
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
567
+ labels: Optional[torch.LongTensor] = None,
568
+ use_cache: Optional[bool] = None,
569
+ output_attentions: Optional[bool] = None,
570
+ output_hidden_states: Optional[bool] = None,
571
+ return_dict: Optional[bool] = None,
572
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
573
+ """
574
+ The key point is that the attention_mask modified in the encoder outputs must be re-applied,
575
+ because the attention_mask changes when downsampling is used.
576
+ """
577
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
578
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
579
+
580
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
581
+ if head_mask is not None and decoder_head_mask is None:
582
+ if self.config.num_layers == self.config.num_decoder_layers:
583
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
584
+ decoder_head_mask = head_mask
585
+
586
+ # Encode if needed (training, first prediction pass)
587
+ if encoder_outputs is None:
588
+ # Convert encoder inputs in embeddings if needed
589
+ encoder_outputs, attention_mask = self.encoder(
590
+ input_ids=input_ids,
591
+ attention_mask=attention_mask,
592
+ inputs_embeds=inputs_embeds,
593
+ head_mask=head_mask,
594
+ output_attentions=output_attentions,
595
+ output_hidden_states=output_hidden_states,
596
+ return_dict=return_dict,
597
+ )
598
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
599
+ # inference mode (e.g. .generate()) - must unwrap the encoder output tuple
600
+ encoder_outputs, attention_mask = encoder_outputs
601
+ encoder_outputs = BaseModelOutput(
602
+ last_hidden_state=encoder_outputs[0],
603
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
604
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
605
+ )
606
+
607
+ hidden_states = encoder_outputs[0]
608
+
609
+ if self.model_parallel:
610
+ torch.cuda.set_device(self.decoder.first_device)
611
+
612
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
613
+ # get decoder inputs from shifting lm labels to the right
614
+ decoder_input_ids = self._shift_right(labels)
615
+
616
+ # Set device for model parallelism
617
+ if self.model_parallel:
618
+ torch.cuda.set_device(self.decoder.first_device)
619
+ hidden_states = hidden_states.to(self.decoder.first_device)
620
+ if decoder_input_ids is not None:
621
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
622
+ if attention_mask is not None:
623
+ attention_mask = attention_mask.to(self.decoder.first_device)
624
+ if decoder_attention_mask is not None:
625
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
626
+
627
+ # Decode
628
+ decoder_outputs = self.decoder(
629
+ input_ids=decoder_input_ids,
630
+ attention_mask=decoder_attention_mask,
631
+ inputs_embeds=decoder_inputs_embeds,
632
+ past_key_values=past_key_values,
633
+ encoder_hidden_states=hidden_states,
634
+ encoder_attention_mask=attention_mask,
635
+ head_mask=decoder_head_mask,
636
+ cross_attn_head_mask=cross_attn_head_mask,
637
+ use_cache=use_cache,
638
+ output_attentions=output_attentions,
639
+ output_hidden_states=output_hidden_states,
640
+ return_dict=return_dict,
641
+ )
642
+
643
+ sequence_output = decoder_outputs[0]
644
+
645
+ # Set device for model parallelism
646
+ if self.model_parallel:
647
+ torch.cuda.set_device(self.encoder.first_device)
648
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
649
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
650
+
651
+ if self.config.tie_word_embeddings:
652
+ # Rescale output before projecting on vocab
653
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
654
+ sequence_output = sequence_output * (self.model_dim**-0.5)
655
+
656
+ lm_logits = self.lm_head(sequence_output)
657
+
658
+ loss = None
659
+ if labels is not None:
660
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
661
+ # move labels to correct device to enable PP
662
+ labels = labels.to(lm_logits.device)
663
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
664
+ # add z_loss for computational stability in bf16 amp.
665
+ # see https://github.com/huggingface/transformers/pull/10956#issuecomment-820712267
666
+ if self.config.z_loss != 0.0:
667
+ log_z = lm_logits.view(-1).logsumexp(-1)
668
+ loss += self.config.z_loss * log_z.square()
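+ # i.e. loss = cross-entropy + z_loss * (log Z)^2, where log Z is the logsumexp
+ # taken over the flattened logits; the squared penalty keeps log Z close to 0.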
669
+
670
+ if not return_dict:
671
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
672
+ return ((loss,) + output) if loss is not None else output
673
+
674
+ return Seq2SeqLMOutput(
675
+ loss=loss,
676
+ logits=lm_logits,
677
+ past_key_values=decoder_outputs.past_key_values,
678
+ decoder_hidden_states=decoder_outputs.hidden_states,
679
+ decoder_attentions=decoder_outputs.attentions,
680
+ cross_attentions=decoder_outputs.cross_attentions,
681
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
682
+ encoder_hidden_states=encoder_outputs.hidden_states,
683
+ encoder_attentions=encoder_outputs.attentions,
684
+ )
685
+
686
+
687
+ GBSWT5ForConditionalGeneration.parallelize = T5ForConditionalGeneration.parallelize
688
+ GBSWT5ForConditionalGeneration.deparallelize = T5ForConditionalGeneration.deparallelize
689
+ GBSWT5ForConditionalGeneration.get_input_embeddings = T5ForConditionalGeneration.get_input_embeddings
690
+ GBSWT5ForConditionalGeneration.set_input_embeddings = T5ForConditionalGeneration.set_input_embeddings
691
+ GBSWT5ForConditionalGeneration.get_output_embeddings = T5ForConditionalGeneration.get_output_embeddings
692
+ GBSWT5ForConditionalGeneration.set_output_embeddings = T5ForConditionalGeneration.set_output_embeddings
693
+ GBSWT5ForConditionalGeneration.get_encoder = T5ForConditionalGeneration.get_encoder
694
+ GBSWT5ForConditionalGeneration.prepare_inputs_for_generation = T5ForConditionalGeneration.prepare_inputs_for_generation
695
+ GBSWT5ForConditionalGeneration.prepare_decoder_input_ids_from_labels = T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
696
+ GBSWT5ForConditionalGeneration._reorder_cache = T5ForConditionalGeneration._reorder_cache
697
+ GBSWT5ForConditionalGeneration._prune_heads = T5Model._prune_heads
698
+
699
+
700
+ class GBSWT5EncoderModel(T5PreTrainedModel):
701
+ _tied_weights_keys = ["encoder.embed_tokens.embeds.weight"]
702
+
703
+ def __init__(self, config: GBSWT5Config):
704
+ # override some default missing parameters for pretrained ByT5 models (e.g. google/byt5-small)
705
+ if not hasattr(config, 'max_subword_block_size'):
706
+ config.max_subword_block_size = None
707
+ if not hasattr(config, 'subword_blocks'):
708
+ config.subword_blocks = ((1, 0), (2, 0), (3, 0), (6, 0), (9, 0),)
709
+ if not hasattr(config, 'downsample_factor'):
710
+ config.downsample_factor = 1
711
+ if not hasattr(config, 'score_consensus_attn'):
712
+ config.score_consensus_attn = True
713
+
714
+ super().__init__(config)
715
+
716
+ # embeddings are shared, as in vanilla T5
717
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
718
+
719
+ encoder_cfg = copy.deepcopy(config)
720
+ encoder_cfg.is_decoder = False
721
+ encoder_cfg.use_cache = False
722
+ encoder_cfg.is_encoder_decoder = False
723
+ self.encoder = GBSWT5Stack(encoder_cfg, self.shared)
724
+
725
+ self.post_init()
726
+
727
+ self.model_parallel = False
728
+ self.device_map = None
729
+
730
+ def forward(self,
731
+ input_ids: Optional[torch.LongTensor] = None,
732
+ attention_mask: Optional[torch.FloatTensor] = None,
733
+ head_mask: Optional[torch.FloatTensor] = None,
734
+ inputs_embeds: Optional[torch.FloatTensor] = None,
735
+ output_attentions: Optional[bool] = None,
736
+ output_hidden_states: Optional[bool] = None,
737
+ return_dict: Optional[bool] = None,
738
+ return_resized_attention_mask: Optional[bool] = None,
739
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
740
+ r"""
741
+ Also returns the downsampled attention_mask, but only when return_resized_attention_mask=True.
742
+ """
743
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
744
+
745
+ encoder_outputs, attention_mask = self.encoder(
746
+ input_ids=input_ids,
747
+ attention_mask=attention_mask,
748
+ inputs_embeds=inputs_embeds,
749
+ head_mask=head_mask,
750
+ output_attentions=output_attentions,
751
+ output_hidden_states=output_hidden_states,
752
+ return_dict=return_dict,
753
+ )
754
+
755
+ if return_resized_attention_mask:
756
+ return encoder_outputs, attention_mask
757
+
758
+ return encoder_outputs
759
+
760
+
761
+ GBSWT5EncoderModel.parallelize = T5EncoderModel.parallelize
762
+ GBSWT5EncoderModel.deparallelize = T5EncoderModel.deparallelize
763
+ GBSWT5EncoderModel.get_input_embeddings = T5EncoderModel.get_input_embeddings
764
+ GBSWT5EncoderModel.set_input_embeddings = T5EncoderModel.set_input_embeddings
765
+ GBSWT5EncoderModel.get_encoder = T5EncoderModel.get_encoder
766
+ GBSWT5EncoderModel._prune_heads = T5EncoderModel._prune_heads
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c90684b6b4905ad58b09a39183a3651f2c969f1c584cf5733886c2239d7519a4
3
+ size 2336209429
special_tokens_map.json ADDED
@@ -0,0 +1,150 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>",
103
+ "<extra_id_100>",
104
+ "<extra_id_101>",
105
+ "<extra_id_102>",
106
+ "<extra_id_103>",
107
+ "<extra_id_104>",
108
+ "<extra_id_105>",
109
+ "<extra_id_106>",
110
+ "<extra_id_107>",
111
+ "<extra_id_108>",
112
+ "<extra_id_109>",
113
+ "<extra_id_110>",
114
+ "<extra_id_111>",
115
+ "<extra_id_112>",
116
+ "<extra_id_113>",
117
+ "<extra_id_114>",
118
+ "<extra_id_115>",
119
+ "<extra_id_116>",
120
+ "<extra_id_117>",
121
+ "<extra_id_118>",
122
+ "<extra_id_119>",
123
+ "<extra_id_120>",
124
+ "<extra_id_121>",
125
+ "<extra_id_122>",
126
+ "<extra_id_123>",
127
+ "<extra_id_124>"
128
+ ],
129
+ "eos_token": {
130
+ "content": "</s>",
131
+ "lstrip": false,
132
+ "normalized": true,
133
+ "rstrip": false,
134
+ "single_word": false
135
+ },
136
+ "pad_token": {
137
+ "content": "<pad>",
138
+ "lstrip": false,
139
+ "normalized": true,
140
+ "rstrip": false,
141
+ "single_word": false
142
+ },
143
+ "unk_token": {
144
+ "content": "<unk>",
145
+ "lstrip": false,
146
+ "normalized": true,
147
+ "rstrip": false,
148
+ "single_word": false
149
+ }
150
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,157 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>",
103
+ "<extra_id_100>",
104
+ "<extra_id_101>",
105
+ "<extra_id_102>",
106
+ "<extra_id_103>",
107
+ "<extra_id_104>",
108
+ "<extra_id_105>",
109
+ "<extra_id_106>",
110
+ "<extra_id_107>",
111
+ "<extra_id_108>",
112
+ "<extra_id_109>",
113
+ "<extra_id_110>",
114
+ "<extra_id_111>",
115
+ "<extra_id_112>",
116
+ "<extra_id_113>",
117
+ "<extra_id_114>",
118
+ "<extra_id_115>",
119
+ "<extra_id_116>",
120
+ "<extra_id_117>",
121
+ "<extra_id_118>",
122
+ "<extra_id_119>",
123
+ "<extra_id_120>",
124
+ "<extra_id_121>",
125
+ "<extra_id_122>",
126
+ "<extra_id_123>",
127
+ "<extra_id_124>"
128
+ ],
129
+ "clean_up_tokenization_spaces": true,
130
+ "eos_token": {
131
+ "__type": "AddedToken",
132
+ "content": "</s>",
133
+ "lstrip": false,
134
+ "normalized": true,
135
+ "rstrip": false,
136
+ "single_word": false
137
+ },
138
+ "extra_ids": 125,
139
+ "model_max_length": 1000000000000000019884624838656,
140
+ "pad_token": {
141
+ "__type": "AddedToken",
142
+ "content": "<pad>",
143
+ "lstrip": false,
144
+ "normalized": true,
145
+ "rstrip": false,
146
+ "single_word": false
147
+ },
148
+ "tokenizer_class": "ByT5Tokenizer",
149
+ "unk_token": {
150
+ "__type": "AddedToken",
151
+ "content": "<unk>",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false
156
+ }
157
+ }