fmthoker committed on
Commit 401fa20 · verified · 1 Parent(s): baeec23

Upload 95 files

This view is limited to 50 of the 95 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. .gitattributes +3 -0
  2. annotation_files/k400/test.csv +0 -0
  3. annotation_files/k400/train.csv +3 -0
  4. annotation_files/k400/val.csv +0 -0
  5. annotation_files/object_instances.txt +0 -0
  6. annotation_files/ssv2/test.csv +0 -0
  7. annotation_files/ssv2/train.csv +3 -0
  8. annotation_files/ssv2/val.csv +0 -0
  9. configs/beit-base-patch16-224-pt22k-ft22k.json +30 -0
  10. configs/config.py +132 -0
  11. configs/config_bert.json +22 -0
  12. configs/config_bert_large.json +25 -0
  13. configs/data.py +195 -0
  14. configs/model.py +27 -0
  15. configs/pretrain.py +101 -0
  16. configs/qa.py +20 -0
  17. configs/qa_anet.py +27 -0
  18. configs/qa_msrvtt.py +27 -0
  19. configs/ret_anet.py +27 -0
  20. configs/ret_coco.py +37 -0
  21. configs/ret_didemo.py +36 -0
  22. configs/ret_flickr.py +37 -0
  23. configs/ret_msrvtt.py +31 -0
  24. configs/ret_msrvtt_9k.py +7 -0
  25. configs/ret_msrvtt_mc.py +30 -0
  26. configs/ret_ssv2_label.py +24 -0
  27. configs/ret_ssv2_template.py +24 -0
  28. configs/tvqa.py +36 -0
  29. figs/smile.jpg +3 -0
  30. models_viclip/__init__.py +0 -0
  31. models_viclip/backbones/__init__.py +0 -0
  32. models_viclip/backbones/beit/__init__.py +0 -0
  33. models_viclip/backbones/beit/builder.py +85 -0
  34. models_viclip/backbones/beit/st_beit.py +1749 -0
  35. models_viclip/backbones/bert/.tokenization_bert.py.swp +0 -0
  36. models_viclip/backbones/bert/__init__.py +0 -0
  37. models_viclip/backbones/bert/__pycache__/__init__.cpython-310.pyc +0 -0
  38. models_viclip/backbones/bert/__pycache__/__init__.cpython-38.pyc +0 -0
  39. models_viclip/backbones/bert/__pycache__/tokenization_bert.cpython-310.pyc +0 -0
  40. models_viclip/backbones/bert/__pycache__/tokenization_bert.cpython-38.pyc +0 -0
  41. models_viclip/backbones/bert/builder.py +68 -0
  42. models_viclip/backbones/bert/tokenization_bert.py +546 -0
  43. models_viclip/backbones/bert/xbert.py +2157 -0
  44. models_viclip/backbones/blip_toremove/Qformer.py +1237 -0
  45. models_viclip/backbones/blip_toremove/__init__.py +0 -0
  46. models_viclip/backbones/blip_toremove/builder.py +44 -0
  47. models_viclip/backbones/blip_toremove/modeling_t5.py +2063 -0
  48. models_viclip/backbones/clip/__pycache__/clip_text.cpython-310.pyc +0 -0
  49. models_viclip/backbones/clip/__pycache__/clip_text.cpython-38.pyc +0 -0
  50. models_viclip/backbones/clip/__pycache__/clip_vision.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ annotation_files/k400/train.csv filter=lfs diff=lfs merge=lfs -text
+ annotation_files/ssv2/train.csv filter=lfs diff=lfs merge=lfs -text
+ figs/smile.jpg filter=lfs diff=lfs merge=lfs -text
annotation_files/k400/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
annotation_files/k400/train.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:932b926a43e791b761badbf349dbd23107cfbcc069057d320212251cc60700c0
+ size 21512452
annotation_files/k400/val.csv ADDED
The diff for this file is too large to render. See raw diff
 
annotation_files/object_instances.txt ADDED
The diff for this file is too large to render. See raw diff
 
annotation_files/ssv2/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
annotation_files/ssv2/train.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4833f9f9ddf54236dcdd66ecf59c7c0156fb1889a41832e9add5378d484ca787
+ size 15352045
annotation_files/ssv2/val.csv ADDED
The diff for this file is too large to render. See raw diff
 
configs/beit-base-patch16-224-pt22k-ft22k.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "note": "this file is a copy of the BEiT model config, not used directly",
+   "architectures": [
+     "BeitForImageClassification"
+   ],
+   "url": "https://huggingface.co/microsoft/beit-base-patch16-224-pt22k-ft22k/raw/main/config.json",
+   "attention_probs_dropout_prob": 0.0,
+   "drop_path_rate": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.0,
+   "hidden_size": 768,
+   "image_size": 224,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "layer_scale_init_value": 0.1,
+   "model_type": "beit",
+   "num_attention_heads": 12,
+   "num_channels": 3,
+   "num_hidden_layers": 12,
+   "patch_size": 16,
+   "torch_dtype": "float32",
+   "transformers_version": "4.11.0.dev0",
+   "use_absolute_position_embeddings": false,
+   "use_mask_token": false,
+   "use_mean_pooling": true,
+   "use_relative_position_bias": true,
+   "use_shared_relative_position_bias": false,
+   "vocab_size": 8192
+ }
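The JSON above is a verbatim copy of the upstream BEiT checkpoint configuration (see its "url" field) and, per its own "note", is not consumed directly. As a minimal sketch that is not part of this commit, assuming `transformers` is installed and the repo-relative path below, it can be parsed with the standard config loader:

import transformers

# Hypothetical check: parse the copied file and inspect a few fields.
beit_cfg = transformers.BeitConfig.from_json_file(
    "configs/beit-base-patch16-224-pt22k-ft22k.json"
)
print(beit_cfg.hidden_size, beit_cfg.num_hidden_layers, beit_cfg.patch_size)  # 768 12 16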
configs/config.py ADDED
@@ -0,0 +1,132 @@
+ from configs.data import *
+ from configs.model import *
+
+ # ========================= data ==========================
+ train_corpus = "webvid_10m"
+ train_file = "${available_corpus[${train_corpus}]}"  # for lazy evaluation
+ test_file = dict(
+     test=[
+         "/ibex/project/c2134/LSMDC/annotations/LSMDC16_challenge_1000_publictest.json",
+         "/ibex/project/c2134/LSMDC/videos/",
+         "video",
+     ],
+ )
+ test_types = ["test"]
+ num_workers = 10
+
+ stop_key = None
+
+ # ========================= input ==========================
+ num_frames = 1
+ num_frames_test = 1
+ batch_size = 512
+ batch_size_test = 64
+ max_txt_l = 32
+
+ inputs = dict(
+     image_res=224,
+     video_input=dict(
+         num_frames="${num_frames}",
+         sample_type="rand",
+         num_frames_test="${num_frames_test}",
+         sample_type_test="middle",
+         random_aug=False,
+     ),
+     max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+     batch_size=dict(image="${batch_size}", video="${batch_size}"),
+     batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+ )
+
+ # ========================= model ==========================
+ text_enc = "bert_large"
+ model = dict(
+     model_cls="ViCLIP",
+     vision_encoder=dict(
+         # backbone
+         name="vit_b16",
+         pretrained='CLIP-ViT-B/16',
+         d_model=1024,
+         kernel_size=1,
+         center=True,
+         drop_path_rate=0.1,
+         masking_prob=0.9,
+         checkpoint_num=24,
+     ),
+     text_encoder=dict(
+         pretrained='CLIP-ViT-B/16',  # This is for vindlu default tokenizer, this is never used
+         name="vit_b16",
+         d_model=512,
+         vocab_size=49408,
+     ),
+     requires_raw_text=True,
+     embed_dim=768,
+     temp=1 / 100.0,
+     temp_min=1 / 100.0,
+     freeze_text=True,
+ )
+
+ criterion = dict(
+     loss_weight=dict(
+         vtc=1.0,
+         # mlm=1.0,
+         # vtm=1.0,
+         # mvm=0.0,
+         # mac=1.0,
+     ),  # 0: disabled.
+ )
+
+ optimizer = dict(
+     opt="adamW",
+     lr=4e-4,
+     opt_betas=[0.9, 0.98],  # default
+     weight_decay=0.2,
+     max_grad_norm=-1,  # requires a positive float, use -1 to disable
+     # use a different lr for some modules, e.g., larger lr for new modules
+     different_lr=dict(enable=False, module_names=[], lr=1e-3),
+ )
+
+ scheduler = dict(sched="cosine", epochs=12, min_lr_multi=0.01, warmup_epochs=0.5)
+
+ evaluate = False
+ deep_fusion = False
+ evaluation = dict(
+     eval_frame_ensemble="concat",  # [concat, max, mean, lse]
+     eval_x_only=False,
+     k_test=128,
+     eval_offload=True,  # offload gpu tensors to cpu to save memory.
+ )
+
+ fp16 = True
+ gradient_checkpointing = True
+
+ # ========================= wandb ==========================
+ wandb = dict(
+     enable=True,
+     entity="likunchang",  # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+     project="vindlu_videoclip",  # setup in your command line
+ )
+ dist_url = "env://"
+ device = "cuda"
+ mode = "pt"
+
+ # ========================= others ==========================
+ output_dir = None  # output dir
+ resume = False  # if True, load optimizer and scheduler states as well
+ debug = False
+ log_freq = 10
+ seed = 42
+
+ save_latest = True
+ auto_resume = True
+ pretrained_path = ""  # path to pretrained model weights, for resume only?
+
+ deepspeed = dict(
+     enable=False,
+     stage=2,
+ )
+
+ wiseft = dict(
+     enable=False,
+     coef=0.5,
+     keys_to_exclude=["vision_encoder.temporal_positional_embedding"]
+ )
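The quoted "${...}" values above (e.g. train_file and the entries inside inputs) are placeholders that the training code resolves lazily against other names in the same config. A minimal sketch of reading such a module-style config as a flat dict, under the assumption that the loader simply collects top-level names (the actual resolver is not part of this commit):

import importlib

# Collect the top-level names of configs/config.py; "${...}" strings stay unresolved here.
cfg_module = importlib.import_module("configs.config")
cfg = {k: v for k, v in vars(cfg_module).items() if not k.startswith("_")}
print(cfg["train_corpus"])  # "webvid_10m"
print(cfg["train_file"])    # still "${available_corpus[${train_corpus}]}" until the resolver runs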
configs/config_bert.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 30522,
+   "fusion_layer": 9,
+   "encoder_width": 768,
+   "cross_module": "ca"
+ }
configs/config_bert_large.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 30522,
+   "fusion_layer": 19,
+   "encoder_width": 768,
+   "cross_module": "ca"
+ }
configs/data.py ADDED
@@ -0,0 +1,195 @@
+ import os as __os  # add "__" if not want to be exported
+ from copy import deepcopy as __deepcopy
+
+ #data_dir = f'{VL_DATA_DIR}'
+ data_dir = '/ssdstore/fmthoker/videos_images/'
+ if data_dir is None:
+     raise ValueError("please set environment `VL_DATA_DIR` before continue")
+
+
+ #data_root = __os.path.join(data_dir, "videos_images")
+ #anno_root_pt = __os.path.join(data_dir, "anno_pretrain")
+ #anno_root_downstream = __os.path.join(data_dir, "anno_downstream")
+ data_root = data_dir
+ anno_root_pt = __os.path.join("/ssdstore/fmthoker/videos_images/", "anno_pretrain")
+ anno_root_downstream = __os.path.join("/ssdstore/fmthoker/videos_images/", "anno_downstream")
+
+ # ============== pretraining datasets=================
+ available_corpus = dict(
+     # pretraining datasets
+     cc3m=[
+         f"{anno_root_pt}/cc3m_train.json",
+         "{your_data_root}"
+     ],
+     cc12m=[
+         f"{anno_root_pt}/cc12m_train.json",
+         "{your_data_root}"
+     ],
+     sbu=[
+         f"{anno_root_pt}/sbu.json",
+         "{your_data_root}"
+     ],
+     vg=[
+         f"{anno_root_pt}/vg.json",
+         "{your_data_root}"
+     ],
+     coco=[
+         f"{anno_root_pt}/coco.json",
+         "{your_data_root}"
+     ],
+     imagenet1k=[
+         f"{anno_root_pt}/imagenet1k_train.json",
+         "{your_data_root}"
+     ],
+     webvid=[
+         f"{anno_root_pt}/webvid_train.json",
+         "{your_data_root}",
+         "video"
+     ],
+     webvid_10m=[
+         f"{anno_root_pt}/webvid_10m_train.json",
+         "{your_data_root}",
+         "video",
+     ],
+     kinetics400=[
+         f"{anno_root_pt}/kinetics400_train.json",
+         "{your_data_root}",
+         "video",
+     ],
+     kinetics710=[
+         f"{anno_root_pt}/kinetics710_train.json",
+         "{your_data_root}",
+         "video",
+     ],
+     kinetics710_raw=[
+         f"{anno_root_pt}/kinetics710_raw_train.json",
+         "{your_data_root}",
+         "only_video",
+     ],
+     internvid_10m_flt=[
+         #f"{anno_root_pt}/internvid_10m_flt.json",
+         #f"/ibex/project/c2134/InternVid-10M-FLT/internvid_10m_flt.json",
+         #"/ibex/project/c2134/InternVid-10M-FLT/vd-foundation___InternVid-10M-FLT/raw/InternVId-FLT_1/",
+         f"/ibex/project/c2134/InternVid-10M-FLT/vd-foundation___InternVid-10M-FLT/annotations/internvid_10m_flt.json",
+         f"/ibex/project/c2134/InternVid-10M-FLT/vd-foundation___InternVid-10M-FLT/videos/",
+         "video"
+     ],
+     internvid_300k_flt=[
+         f"/ibex/project/c2134/InternVid-10M-FLT/vd-foundation___InternVid-10M-FLT/annotations/internvid_300k_subset1.json",
+         f"/ibex/project/c2134/InternVid-10M-FLT/vd-foundation___InternVid-10M-FLT/videos/",
+         "video"
+     ],
+     mad_300k=[
+         f"/ibex/project/c2134/Fida/MAD/annotations/v2/MAD_train_viclip.json",
+         f"/ibex/project/c2134/Fida/MAD/data/folder_pre_shards",
+         "video"
+     ],
+     mad_100k=[
+         f"/ibex/project/c2134/Fida/MAD/annotations/v2/MAD_train_viclip_100k.json",
+         f"/ibex/project/c2134/Fida/MAD/data/folder_pre_shards",
+         "video"
+     ],
+ )
+
+ # composed datasets.
+ available_corpus["coco_vg"] = [available_corpus["coco"], available_corpus["vg"]]
+ available_corpus["in1k_k710"] = [
+     available_corpus["imagenet1k"],
+     available_corpus["kinetics710"],
+ ]
+ available_corpus["webvid_cc3m"] = [available_corpus["webvid"], available_corpus["cc3m"]]
+ available_corpus["webvid_cc3m_in1k_k710"] = [
+     available_corpus["webvid"],
+     available_corpus["cc3m"],
+     available_corpus["imagenet1k"],
+     available_corpus["kinetics710"],
+ ]
+ available_corpus["webvid_cc3m_k710raw"] = [
+     available_corpus["webvid"],
+     available_corpus["cc3m"],
+     available_corpus["kinetics710_raw"],
+ ]
+ available_corpus["webvid_14m"] = [
+     available_corpus["webvid"],
+     available_corpus["cc3m"],
+     available_corpus["coco"],
+     available_corpus["vg"],
+     available_corpus["sbu"],
+     available_corpus["cc12m"],
+ ]
+ available_corpus["webvid12m_14m"] = [
+     available_corpus["webvid"],
+     available_corpus["webvid_10m"],
+     available_corpus["cc3m"],
+     available_corpus["coco"],
+     available_corpus["vg"],
+     available_corpus["sbu"],
+     available_corpus["cc12m"],
+ ]
+ available_corpus["webvid10m_14m"] = [
+     available_corpus["webvid_10m"],
+     available_corpus["cc3m"],
+     available_corpus["coco"],
+     available_corpus["vg"],
+     available_corpus["sbu"],
+     available_corpus["cc12m"],
+ ]
+ available_corpus["simple_17m"] = [
+     available_corpus["webvid"],
+     available_corpus["cc3m"],
+     available_corpus["cc12m"],
+ ]
+ available_corpus["simple_25m"] = [
+     available_corpus["webvid_10m"],
+     available_corpus["cc3m"],
+     available_corpus["cc12m"],
+ ]
+ available_corpus["viclip_20m"] = [
+     available_corpus["internvid_10m_flt"],
+     available_corpus["webvid_10m"],
+ ]
+ available_corpus["viclip"] = [
+     available_corpus["internvid_10m_flt"],
+ ]
+ available_corpus["viclip_mad_300k"] = [
+     available_corpus["mad_300k"],
+ ]
+ available_corpus["viclip_mad_100k"] = [
+     available_corpus["mad_100k"],
+ ]
+ available_corpus["viclip_internvid_300k"] = [
+     available_corpus["internvid_300k_flt"],
+ ]
+
+ # ============== for validation =================
+ available_corpus["msrvtt_1k_test"] = [
+     f"{anno_root_downstream}/msrvtt_test1k.json",
+     f"{data_root}/msrvtt_2fps_224",
+     "video",
+ ]
+ available_corpus["k400_act_val"] = [
+     f"{anno_root_downstream}/kinetics400_validate.json",
+     "{your_data_root}",
+     "video",
+ ]
+ available_corpus["k600_act_val"] = [
+     f"{anno_root_downstream}/kinetics600_validate.json",
+     "{your_data_root}",
+     "video",
+ ]
+ available_corpus["k700_act_val"] = [
+     f"{anno_root_downstream}/kinetics700_validate.json",
+     "{your_data_root}",
+     "video",
+ ]
+ available_corpus["sthsthv1_act_val"] = [
+     f"{anno_root_downstream}/sthsthv1_validate_clean2.json",
+     "{your_data_root}",
+     "video",
+ ]
+ available_corpus["sthsthv2_act_val"] = [
+     f"{anno_root_downstream}/sthsthv2_validate_clean2.json",
+     "{your_data_root}",
+     "video",
+ ]
+
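Each simple entry in available_corpus follows the convention [annotation_json, data_root] for image data or [annotation_json, data_root, media_type] for video data, while the composed corpora are lists of such entries. A minimal sketch of unpacking one entry under that assumption (the dataset builder that consumes these entries is not among the files shown here):

from configs.data import available_corpus

def unpack(entry):
    # entry = [annotation_json, data_root] or [annotation_json, data_root, media_type]
    anno_path, root, *rest = entry
    media_type = rest[0] if rest else "image"
    return anno_path, root, media_type

anno_path, root, media_type = unpack(available_corpus["webvid_10m"])
print(media_type)  # "video"
for sub_entry in available_corpus["viclip_20m"]:  # composed corpus: a list of entries
    print(unpack(sub_entry)[0])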
configs/model.py ADDED
@@ -0,0 +1,27 @@
+ VisionEncoders = dict()
+ VisionEncoders["beit"] = dict(
+     name="beit_base",
+     pretrained="microsoft/beit-base-patch16-224-pt22k-ft22k",
+     d_model=768,
+ )
+ VisionEncoders["beit_large"] = dict(
+     name="beit_large",
+     pretrained="microsoft/beit-large-patch16-224-pt22k-ft22k",
+     d_model=1024,
+ )
+
+ TextEncoders = dict()
+ TextEncoders["bert"] = dict(
+     name="bert_base",
+     pretrained="bert-base-uncased",
+     config="configs/config_bert.json",
+     d_model=768,
+     fusion_layer=9,
+ )
+ TextEncoders["bert_large"] = dict(
+     name="bert_large",
+     pretrained="bert-large-uncased",
+     config="configs/config_bert_large.json",
+     d_model=1024,
+     fusion_layer=19,
+ )
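configs/pretrain.py below selects one of these dicts through the interpolated string "${TextEncoders[${text_enc}]}". A minimal sketch of what that lookup resolves to, written as plain Python (the interpolation engine itself is not part of this commit):

from configs.model import TextEncoders, VisionEncoders

text_enc = "bert_large"  # mirrors the text_enc variable used in the configs
text_cfg = TextEncoders[text_enc]
vision_cfg = VisionEncoders["beit"]
print(text_cfg["pretrained"], text_cfg["d_model"], text_cfg["fusion_layer"])  # bert-large-uncased 1024 19
print(vision_cfg["pretrained"])  # microsoft/beit-base-patch16-224-pt22k-ft22k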
configs/pretrain.py ADDED
@@ -0,0 +1,101 @@
+ from .data import *
+ from .model import *
+
+ # ========================= data ==========================
+ train_corpus = "webvid_cc3m"
+ train_file = "${available_corpus[${train_corpus}]}"  # for lazy evaluation
+ test_file = dict(msrvtt_1k_test=available_corpus["msrvtt_1k_test"])
+ test_types = ["msrvtt_1k_test"]
+ num_workers = 6
+
+ stop_key = None
+
+ # ========================= input ==========================
+ num_frames = 4
+ num_frames_test = 4
+ batch_size = 64
+ max_txt_l = 32
+
+ inputs = dict(
+     image_res=224,
+     video_input=dict(
+         num_frames="${num_frames}",
+         sample_type="rand",
+         num_frames_test="${num_frames_test}",
+         sample_type_test="middle",
+         random_aug=False,
+     ),
+     max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+     batch_size=dict(image="${batch_size}", video="${batch_size}"),
+     batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
+ )
+
+ # ========================= model ==========================
+ vision_enc = "beit"
+ text_enc = "bert"
+ model = dict(
+     vision_encoder="${VisionEncoders[${vision_enc}]}",
+     text_encoder="${TextEncoders[${text_enc}]}",
+     temporal_modeling=dict(
+         num_frames="${num_frames}",
+         temporal_model_block="timesformer",
+         temporal_model_position="last",
+         temporal_model_config=dict(input_dim="${model.vision_encoder.d_model}"),
+         use_temporal_position_embedding=True,
+     ),
+     vit_add_ln=True,
+     multimodal=dict(enable=True),
+     embed_dim=256,
+     temp=0.07,
+ )
+
+ criterion = dict(
+     loss_weight=dict(vtc=1.0, mlm=1.0, vtm=1.0, mvm=0.0),  # 0: disabled.
+     vtm_hard_neg=True,
+     mlm_masking_prob=0.5,
+ )
+
+ optimizer = dict(
+     opt="adamW",
+     lr=1e-4,
+     opt_betas=[0.9, 0.999],  # default
+     weight_decay=0.02,
+     max_grad_norm=-1,  # requires a positive float, use -1 to disable
+     # use a different lr for some modules, e.g., larger lr for new modules
+     different_lr=dict(enable=False, module_names=[], lr=1e-3),
+ )
+
+ scheduler = dict(sched="cosine", epochs=10, min_lr_multi=0.01, warmup_epochs=1)
+
+ evaluate = False
+ deep_fusion = False
+ evaluation = dict(
+     eval_frame_ensemble="concat",  # [concat, max, mean, lse]
+     eval_x_only=False,
+     k_test=128,
+     eval_offload=True,  # offload gpu tensors to cpu to save memory.
+ )
+
+ fp16 = True
+ gradient_checkpointing = True
+
+ # ========================= wandb ==========================
+ wandb = dict(
+     enable=True,
+     entity="likunchang",  # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+     project="vindlu",  # setup in your command line
+ )
+ dist_url = "env://"
+ device = "cuda"
+ mode = "pt"
+
+ # ========================= others ==========================
+ output_dir = None  # output dir
+ resume = False  # if True, load optimizer and scheduler states as well
+ debug = False
+ log_freq = 100
+ seed = 42
+
+ save_latest = True
+ auto_resume = True
+ pretrained_path = ""  # path to pretrained model weights, for resume only?
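The downstream task configs that follow (qa*.py, ret_*.py, tvqa.py) all use the same pattern: star-import this pretraining config and then rebind or mutate selected names. A minimal sketch of the effect, assuming configs/ is importable as a package:

from configs import pretrain, ret_msrvtt

# Rebound scalars differ per task module; names not overridden keep the pretrain defaults.
print(pretrain.batch_size)    # 64, the pretraining default
print(ret_msrvtt.batch_size)  # 32, overridden for MSRVTT retrieval
print(ret_msrvtt.num_frames)  # 12

Note that dict-valued settings such as scheduler and criterion are mutated in place by the task configs (e.g. scheduler["epochs"] = 5), so those are shared objects rather than per-module copies.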
configs/qa.py ADDED
@@ -0,0 +1,20 @@
+ from .pretrain import *
+
+ del available_corpus
+
+ criterion["loss_weight"]["mlm"] = 0.0
+ scheduler["warmup_epochs"] = 0.5
+
+ max_txt_l = 32
+ batch_size = 32
+ num_frames = 12
+
+ optimizer["lr"] = 1e-5
+ log_freq = 100
+
+ # =========additional args for VQA ============
+ eos = "[SEP]"
+ max_q_len = 25
+ max_a_len = 5
+ # =========end ================================
+
configs/qa_anet.py ADDED
@@ -0,0 +1,27 @@
+ from .qa import *
+
+ train_file = [
+     [
+         f"{anno_root_downstream}/anet_qa_train.json",
+         f"{data_root}/activity_net_2fps_360",
+         "video",
+     ]
+ ]
+ test_file = dict(
+     val=[
+         f"{anno_root_downstream}/anet_qa_val.json",
+         f"{data_root}/activity_net_2fps_360",
+         "video",
+     ],
+     test=[
+         f"{anno_root_downstream}/anet_qa_test.json",
+         f"{data_root}/activity_net_2fps_360",
+         "video",
+     ]
+ )
+ dataset_name = "anet"
+
+ answer_list = f"{anno_root_downstream}/anet_qa_answer_list.json"  # list of answer words
+
+ test_types = ["val"]
+ stop_key = "val"  # used to choose the best ckpt. If None, save the last.
configs/qa_msrvtt.py ADDED
@@ -0,0 +1,27 @@
+ from .qa import *
+
+ train_file = [
+     [
+         f"{anno_root_downstream}/msrvtt_qa_train.json",
+         f"{data_root}/msrvtt_2fps_224",
+         "video",
+     ]
+ ]
+ test_file = dict(
+     val=[
+         f"{anno_root_downstream}/msrvtt_qa_val.json",
+         f"{data_root}/msrvtt_2fps_224",
+         "video",
+     ],
+     test=[
+         f"{anno_root_downstream}/msrvtt_qa_test.json",
+         f"{data_root}/msrvtt_2fps_224",
+         "video",
+     ],
+ )
+ dataset_name = "msrvtt"
+
+ answer_list = f"{anno_root_downstream}/msrvtt_qa_answer_list.json"  # list of answer words
+
+ test_types = ["val"]
+ stop_key = "val"  # used to choose the best ckpt. If None, save the last.
configs/ret_anet.py ADDED
@@ -0,0 +1,27 @@
+ from .pretrain import *
+
+ del available_corpus
+
+ train_file = [
+     f"{anno_root_downstream}/anet_ret_train.json",
+     f"{data_root}/activity_net_2fps_360",
+     "video",
+ ]
+ test_file = dict(
+     test=[
+         f"{anno_root_downstream}/anet_ret_val_1.json",
+         f"{data_root}/activity_net_2fps_360",
+         "video",
+     ],
+ )
+
+ test_types = ["test"]
+ stop_key = "test/"  # used to choose the best ckpt. If None, save the last.
+ is_paragraph_retrieval = True
+
+ max_txt_l = 64
+ batch_size = 32
+ num_frames = 12
+
+ optimizer["lr"] = 1e-5
+ log_freq = 100
configs/ret_coco.py ADDED
@@ -0,0 +1,37 @@
+ from .pretrain import *
+
+ del available_corpus
+
+ train_file = [
+     f"{anno_root_downstream}/coco_train.json",
+     f"{data_root}/coco",
+     "video",
+ ]
+ test_file = dict(
+     val=[
+         f"{anno_root_downstream}/coco_val.json",
+         f"{data_root}/coco",
+         "video",
+     ],
+     test=[
+         f"{anno_root_downstream}/coco_test.json",
+         f"{data_root}/coco",
+         "video",
+     ],
+ )
+
+ test_types = ["val"]
+ stop_key = "val/"  # used to choose the best ckpt. If None, save the last.
+ is_paragraph_retrieval = False
+
+ criterion["loss_weight"]["mlm"] = 0.0
+ scheduler["warmup_epochs"] = 0
+ optimizer["lr"] = 1e-5
+
+
+ max_txt_l = 22
+ batch_size = 128
+ num_frames = 1
+ num_frames_test = 1
+
+ log_freq = 100
configs/ret_didemo.py ADDED
@@ -0,0 +1,36 @@
+ from .pretrain import *
+
+ del available_corpus
+
+ train_file = [
+     f"{anno_root_downstream}/didemo_ret_train.json",
+     f"{data_root}/didemo_2fps_360_trimed30",
+     "video",
+ ]
+ test_file = dict(
+     val=[
+         f"{anno_root_downstream}/didemo_ret_val.json",
+         f"{data_root}/didemo_2fps_360_trimed30",
+         "video",
+     ],
+     test=[
+         f"{anno_root_downstream}/didemo_ret_test.json",
+         f"{data_root}/didemo_2fps_360_trimed30",
+         "video",
+     ],
+ )
+
+ test_types = ["val"]
+ stop_key = "val/"  # used to choose the best ckpt. If None, save the last.
+ is_paragraph_retrieval = True
+
+ criterion["loss_weight"]["mlm"] = 0.0
+ scheduler["warmup_epochs"] = 0
+ optimizer["lr"] = 1e-5
+
+
+ max_txt_l = 64
+ batch_size = 32
+ num_frames = 12
+
+ log_freq = 10
configs/ret_flickr.py ADDED
@@ -0,0 +1,37 @@
+ from .pretrain import *
+
+ del available_corpus
+
+ train_file = [
+     f"{anno_root_downstream}/flickr30k_train.json",
+     f"{data_root}/f30k",
+     "video",
+ ]
+ test_file = dict(
+     val=[
+         f"{anno_root_downstream}/flickr30k_val.json",
+         f"{data_root}/f30k",
+         "video",
+     ],
+     test=[
+         f"{anno_root_downstream}/flickr30k_test.json",
+         f"{data_root}/f30k",
+         "video",
+     ],
+ )
+
+ test_types = ["val"]
+ stop_key = "val/"  # used to choose the best ckpt. If None, save the last.
+ is_paragraph_retrieval = False
+
+ criterion["loss_weight"]["mlm"] = 0.0
+ scheduler["warmup_epochs"] = 0
+ optimizer["lr"] = 1e-5
+
+
+ max_txt_l = 32
+ batch_size = 128
+ num_frames = 1
+ num_frames_test = 1
+
+ log_freq = 100
configs/ret_msrvtt.py ADDED
@@ -0,0 +1,31 @@
+ from .pretrain import *
+
+ del available_corpus
+
+ train_file = [
+     f"{anno_root_downstream}/msrvtt_ret_train7k.json",
+     f"{data_root}/msrvtt_2fps_224",
+     "video",
+ ]
+ test_file = dict(
+     test=[
+         f"{anno_root_downstream}/msrvtt_ret_test1k.json",
+         f"{data_root}/msrvtt_2fps_224",
+         "video",
+     ],
+ )
+
+ test_types = ["test"]
+ stop_key = None  # used to choose the best ckpt. If None, save the last.
+ is_paragraph_retrieval = False
+
+ criterion["loss_weight"]["mlm"] = 0.0
+ scheduler["warmup_epochs"] = 0
+ scheduler["epochs"] = 5
+ optimizer["lr"] = 1e-5
+
+ max_txt_l = 32
+ batch_size = 32
+ num_frames = 12
+
+ log_freq = 100
configs/ret_msrvtt_9k.py ADDED
@@ -0,0 +1,7 @@
+ from .ret_msrvtt import *
+
+ train_file = [
+     f"{anno_root_downstream}/msrvtt_ret_train9k.json",
+     f"{data_root}/msrvtt_2fps_224",
+     "video",
+ ]
configs/ret_msrvtt_mc.py ADDED
@@ -0,0 +1,30 @@
+ from .pretrain import *
+
+ del available_corpus
+
+ train_file = [
+     f"{anno_root_downstream}/msrvtt_ret_train7k.json",
+     f"{data_root}/msrvtt_2fps_224",
+     "video",
+ ]
+ test_file = dict(
+     mc_test=[
+         f"{anno_root_downstream}/msrvtt_mc_test.json",
+         f"{data_root}/msrvtt_2fps_224",
+         "video",
+     ]
+ )
+
+ test_types = ["mc_test"]
+ stop_key = None  # used to choose the best ckpt. If None, save the last.
+ is_paragraph_retrieval = False
+
+ criterion["loss_weight"]["mlm"] = 0.0
+ scheduler["warmup_epochs"] = 0
+ optimizer["lr"] = 1e-5
+
+ max_txt_l = 32
+ batch_size = 32
+ num_frames = 12
+
+ log_freq = 100
configs/ret_ssv2_label.py ADDED
@@ -0,0 +1,24 @@
+ from .ret_msrvtt import *
+
+ train_file = [
+     f"{anno_root_downstream}/ssv2_ret_label_train.json",
+     f"{data_root}/ssv2",
+     "video",
+ ]
+ test_file = dict(
+     val=[
+         f"{anno_root_downstream}/ssv2_ret_label_val_small.json",
+         f"{data_root}/ssv2",
+         "video",
+     ],
+ )
+
+ test_types = ["val"]
+ stop_key = None  # used to choose the best ckpt. If None, save the last.
+
+ has_multi_vision_gt = True
+
+ scheduler["epochs"] = 10
+ optimizer["lr"] = 1e-4
+
+ max_txt_l = 25
configs/ret_ssv2_template.py ADDED
@@ -0,0 +1,24 @@
+ from .ret_msrvtt import *
+
+ train_file = [
+     f"{anno_root_downstream}/ssv2_ret_template_train.json",
+     f"{data_root}/ssv2",
+     "video",
+ ]
+ test_file = dict(
+     val=[
+         f"{anno_root_downstream}/ssv2_ret_template_val_small.json",
+         f"{data_root}/ssv2",
+         "video",
+     ],
+ )
+
+ test_types = ["val"]
+ stop_key = None  # used to choose the best ckpt. If None, save the last.
+
+ has_multi_vision_gt = True
+
+ scheduler["epochs"] = 10
+ optimizer["lr"] = 1e-4
+
+ max_txt_l = 22
configs/tvqa.py ADDED
@@ -0,0 +1,36 @@
+ from .pretrain import *
+
+ del available_corpus
+
+ train_file = [
+     f"{anno_root_downstream}/tvqa_train_with_answer.json",
+     f"{data_root}/tvqa_trimmed_3fps",
+     "video",
+ ]
+ test_file = dict(
+     val=[
+         f"{anno_root_downstream}/tvqa_val_with_answer.json",
+         f"{data_root}/tvqa_trimmed_3fps",
+         "video",
+     ],
+     test=[
+         f"{anno_root_downstream}/tvqa_test_public_with_answer.json",
+         f"{data_root}/tvqa_trimmed_3fps",
+         "video",
+     ],
+ )
+
+ test_types = ["val"]
+ stop_key = "val"  # used to choose the best ckpt. If None, save the last.
+ is_paragraph_retrieval = False
+
+ criterion["loss_weight"]["mlm"] = 0.0
+ optimizer["lr"] = 1e-5
+ scheduler["warmup_epochs"] = 0.5
+ scheduler["epochs"] = 10
+
+ max_txt_l = 150
+ batch_size = 32
+ num_frames = 12
+
+ log_freq = 100
figs/smile.jpg ADDED

Git LFS Details

  • SHA256: 695eabbf0c7395d4acc6dd439d3b0989443e1fae533c19e7f6b3e1bf831b8af8
  • Pointer size: 131 Bytes
  • Size of remote file: 338 kB
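Because figs/smile.jpg and the two train.csv files are stored through Git LFS, the repository itself only carries pointer files (as the train.csv diffs above show). A minimal sketch, not part of the commit, for checking a locally fetched copy against the SHA256 recorded here:

import hashlib

expected = "695eabbf0c7395d4acc6dd439d3b0989443e1fae533c19e7f6b3e1bf831b8af8"  # from the LFS details above
with open("figs/smile.jpg", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print("ok" if digest == expected else "mismatch")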
models_viclip/__init__.py ADDED
File without changes
models_viclip/backbones/__init__.py ADDED
File without changes
models_viclip/backbones/beit/__init__.py ADDED
File without changes
models_viclip/backbones/beit/builder.py ADDED
@@ -0,0 +1,85 @@
+ import logging
+
+ from models.utils import (interpolate_pos_relative_bias_beit,
+                           load_temp_embed_with_mismatch)
+
+ logger = logging.getLogger(__name__)
+
+
+ def interpolate_pos_embed_beit(state_dict, new_model):
+     """interpolate the positional embeddings.
+     The spatial pe is relative and temporal pe is absolute.
+     additional temporal pe is padded with 0.
+
+     Args:
+         state_dict (dict): The state_dict.
+         new_model (nn.Module): The created model.
+
+     Returns: dict. The state_dict with updated positional embeddings.
+
+     """
+     state_dict = interpolate_pos_relative_bias_beit(
+         state_dict_old=state_dict,
+         state_dict_new=new_model.state_dict(),
+         patch_shape_new=new_model.vision_encoder.embeddings.patch_embeddings.patch_shape,
+     )
+     # absolute temporal pos bias
+     temporal_pe_key = "vision_encoder.embeddings.temporal_position_embeddings"
+     if temporal_pe_key in state_dict:
+         logger.info(f"interpolate temporal positional embeddings: {temporal_pe_key}")
+         state_dict[temporal_pe_key] = load_temp_embed_with_mismatch(
+             temp_embed_old=state_dict[temporal_pe_key],
+             temp_embed_new=new_model.state_dict()[temporal_pe_key],
+         )
+     return state_dict
+
+
+ def build_beit(model_config, image_res, checkpoint):
+     """build beit with configuration.
+
+     Args:
+         config (dict): The configs for beit.
+         image_res (int): The image resolution.
+         checkpoint (bool): Whether to enable gradient checkpointing.
+
+     Returns: nn.Module
+
+     """
+     from .st_beit import BeitConfig as config_cls
+     from .st_beit import BeitModel as model_cls
+
+     logger.info(
+         f"Loading vit pre-trained weights from huggingface {model_config.vision_encoder.pretrained}."
+     )
+     # BEiT uses average pooled tokens instead of [CLS] used by other models
+     aux_kwargs = {"add_pooling_layer": True}
+     tmp_model = model_cls.from_pretrained(model_config.vision_encoder.pretrained, **aux_kwargs)
+     state_dict = tmp_model.state_dict()
+     del tmp_model
+
+     logger.info(f"Init new model with new image size {image_res}, and load weights.")
+
+     other_cfg = model_config.temporal_modeling
+     vit_config = config_cls.from_pretrained(
+         model_config.vision_encoder.pretrained, image_size=image_res, **other_cfg
+     )
+     model = model_cls(config=vit_config, **aux_kwargs)
+
+     if checkpoint:
+         model.gradient_checkpointing_enable()
+
+     # interpolate relative pos bias
+     state_dict = interpolate_pos_relative_bias_beit(
+         state_dict_old=state_dict,
+         state_dict_new=model.state_dict(),
+         patch_shape_new=model.embeddings.patch_embeddings.patch_shape,
+     )
+
+     # del prompt_bias_table
+     for k in list(state_dict.keys()):
+         if "prompt_bias_table" in k:
+             del state_dict[k]
+
+     msg = model.load_state_dict(state_dict, strict=False)
+     logger.info(msg)
+     return model
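A minimal usage sketch for build_beit, assuming model_config is an attribute-style dict (e.g. easydict) shaped like the model section of configs/pretrain.py; the exact caller is not shown among these 50 files:

from easydict import EasyDict as edict

from models_viclip.backbones.beit.builder import build_beit

# Assumed config layout, mirroring configs/pretrain.py after interpolation.
model_config = edict(
    vision_encoder=dict(pretrained="microsoft/beit-base-patch16-224-pt22k-ft22k", d_model=768),
    temporal_modeling=dict(
        num_frames=4,
        temporal_model_block="timesformer",
        temporal_model_position="last",
        temporal_model_config=dict(input_dim=768),
        use_temporal_position_embedding=True,
    ),
)
# Downloads the HuggingFace BEiT weights, rebuilds the model at 224px, and loads them.
beit = build_beit(model_config, image_res=224, checkpoint=True)

Note that builder.py itself imports helpers from models.utils, so a models package (not shown among these 50 files) must also be importable.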
models_viclip/backbones/beit/st_beit.py ADDED
@@ -0,0 +1,1749 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch BEiT model."""
16
+
17
+ import collections.abc
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import einops
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ from torch.nn import functional as F
28
+ from transformers.activations import ACT2FN
29
+ from transformers.configuration_utils import PretrainedConfig
30
+ from transformers.modeling_outputs import (BaseModelOutput,
31
+ BaseModelOutputWithPooling,
32
+ ImageClassifierOutput,
33
+ MaskedLMOutput,
34
+ SemanticSegmenterOutput)
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.pytorch_utils import (find_pruneable_heads_and_indices,
37
+ prune_linear_layer)
38
+ from transformers.utils import (add_code_sample_docstrings,
39
+ add_start_docstrings,
40
+ add_start_docstrings_to_model_forward, logging,
41
+ replace_return_docstrings)
42
+
43
+ from models.utils import interpolate_temporal_pos_embed
44
+
45
+ from ...modules.temporal_model import (X_CLIP, STAdapter, TemporalAttention,
46
+ WindowTemporalAttention)
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+ # General docstring
51
+ _CONFIG_FOR_DOC = "BeitConfig"
52
+ _FEAT_EXTRACTOR_FOR_DOC = "BeitFeatureExtractor"
53
+
54
+ # Base docstring
55
+ _CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224-pt22k"
56
+ _EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
57
+
58
+ # Image classification docstring
59
+ _IMAGE_CLASS_CHECKPOINT = "microsoft/beit-base-patch16-224"
60
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
61
+
62
+ BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "microsoft/beit-base-patch16-224",
64
+ # See all BEiT models at https://huggingface.co/models?filter=beit
65
+ ]
66
+
67
+
68
+ class BeitConfig(PretrainedConfig):
69
+ r"""
70
+ This is the configuration class to store the configuration of a [`BeitModel`]. It is used to instantiate an BEiT
71
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
72
+ defaults will yield a similar configuration to that of the BEiT
73
+ [microsoft/beit-base-patch16-224-pt22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k) architecture.
74
+
75
+ Args:
76
+ vocab_size (`int`, *optional*, defaults to 8092):
77
+ Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during
78
+ pre-training.
79
+ hidden_size (`int`, *optional*, defaults to 768):
80
+ Dimensionality of the encoder layers and the pooler layer.
81
+ num_hidden_layers (`int`, *optional*, defaults to 12):
82
+ Number of hidden layers in the Transformer encoder.
83
+ num_attention_heads (`int`, *optional*, defaults to 12):
84
+ Number of attention heads for each attention layer in the Transformer encoder.
85
+ intermediate_size (`int`, *optional*, defaults to 3072):
86
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
87
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
88
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
89
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
90
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
91
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
92
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
93
+ The dropout ratio for the attention probabilities.
94
+ initializer_range (`float`, *optional*, defaults to 0.02):
95
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
96
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
97
+ The epsilon used by the layer normalization layers.
98
+ image_size (`int`, *optional*, defaults to 224):
99
+ The size (resolution) of each image.
100
+ patch_size (`int`, *optional*, defaults to 16):
101
+ The size (resolution) of each patch.
102
+ num_channels (`int`, *optional*, defaults to 3):
103
+ The number of input channels.
104
+ use_mask_token (`bool`, *optional*, defaults to `False`):
105
+ Whether to use a mask token for masked image modeling.
106
+ use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`):
107
+ Whether to use BERT-style absolute position embeddings.
108
+ use_relative_position_bias (`bool`, *optional*, defaults to `False`):
109
+ Whether to use T5-style relative position embeddings in the self-attention layers.
110
+ use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`):
111
+ Whether to use the same relative position embeddings across all self-attention layers of the Transformer.
112
+ layer_scale_init_value (`float`, *optional*, defaults to 0.1):
113
+ Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
114
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
115
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
116
+ use_mean_pooling (`bool`, *optional*, defaults to `True`):
117
+ Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
118
+ CLS token, before applying the classification head.
119
+ out_indices (`List[int]`, *optional*, defaults to `[3, 5, 7, 11]`):
120
+ Indices of the feature maps to use for semantic segmentation.
121
+ pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
122
+ Pooling scales used in Pooling Pyramid Module applied on the last feature map.
123
+ use_auxiliary_head (`bool`, *optional*, defaults to `True`):
124
+ Whether to use an auxiliary head during training.
125
+ auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
126
+ Weight of the cross-entropy loss of the auxiliary head.
127
+ auxiliary_channels (`int`, *optional*, defaults to 256):
128
+ Number of channels to use in the auxiliary head.
129
+ auxiliary_num_convs (`int`, *optional*, defaults to 1):
130
+ Number of convolutional layers to use in the auxiliary head.
131
+ auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
132
+ Whether to concatenate the output of the auxiliary head with the input before the classification layer.
133
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
134
+ The index that is ignored by the loss function of the semantic segmentation model.
135
+
136
+ Example:
137
+
138
+ ```python
139
+ >>> from transformers import BeitModel, BeitConfig
140
+
141
+ >>> # Initializing a BEiT beit-base-patch16-224-pt22k style configuration
142
+ >>> configuration = BeitConfig()
143
+
144
+ >>> # Initializing a model from the beit-base-patch16-224-pt22k style configuration
145
+ >>> model = BeitModel(configuration)
146
+
147
+ >>> # Accessing the model configuration
148
+ >>> configuration = model.config
149
+ ```"""
150
+ model_type = "beit"
151
+
152
+ def __init__(
153
+ self,
154
+ vocab_size=8192,
155
+ hidden_size=768,
156
+ num_hidden_layers=12,
157
+ num_attention_heads=12,
158
+ intermediate_size=3072,
159
+ hidden_act="gelu",
160
+ hidden_dropout_prob=0.0,
161
+ attention_probs_dropout_prob=0.0,
162
+ initializer_range=0.02,
163
+ layer_norm_eps=1e-12,
164
+ is_encoder_decoder=False,
165
+ image_size=224,
166
+ num_frames=1,
167
+ patch_size=16,
168
+ num_channels=3,
169
+ use_mask_token=False,
170
+ use_absolute_position_embeddings=False,
171
+ use_relative_position_bias=False,
172
+ use_shared_relative_position_bias=False,
173
+ layer_scale_init_value=0.1,
174
+ drop_path_rate=0.1,
175
+ use_mean_pooling=True,
176
+ out_indices=[3, 5, 7, 11],
177
+ pool_scales=[1, 2, 3, 6],
178
+ use_auxiliary_head=True,
179
+ auxiliary_loss_weight=0.4,
180
+ auxiliary_channels=256,
181
+ auxiliary_num_convs=1,
182
+ auxiliary_concat_input=False,
183
+ semantic_loss_ignore_index=255,
184
+ temporal_model_block="none",
185
+ temporal_model_position="last",
186
+ temporal_model_init_value=0.0,
187
+ temporal_model_config={},
188
+ use_temporal_position_embedding=False,
189
+ add_k_prompts=0,
190
+ **kwargs,
191
+ ):
192
+ super().__init__(**kwargs)
193
+
194
+ self.vocab_size = vocab_size
195
+ self.hidden_size = hidden_size
196
+ self.num_hidden_layers = num_hidden_layers
197
+ self.num_attention_heads = num_attention_heads
198
+ self.intermediate_size = intermediate_size
199
+ self.hidden_act = hidden_act
200
+ self.hidden_dropout_prob = hidden_dropout_prob
201
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
202
+ self.initializer_range = initializer_range
203
+ self.layer_norm_eps = layer_norm_eps
204
+
205
+ self.image_size = image_size
206
+ self.patch_size = patch_size
207
+ self.num_channels = num_channels
208
+ self.use_mask_token = use_mask_token
209
+ self.use_absolute_position_embeddings = use_absolute_position_embeddings
210
+ self.use_relative_position_bias = use_relative_position_bias
211
+ self.use_shared_relative_position_bias = use_shared_relative_position_bias
212
+ self.layer_scale_init_value = layer_scale_init_value
213
+ self.drop_path_rate = drop_path_rate
214
+ self.use_mean_pooling = use_mean_pooling
215
+ # decode head attributes (semantic segmentation)
216
+ self.out_indices = out_indices
217
+ self.pool_scales = pool_scales
218
+ # auxiliary head attributes (semantic segmentation)
219
+ self.use_auxiliary_head = use_auxiliary_head
220
+ self.auxiliary_loss_weight = auxiliary_loss_weight
221
+ self.auxiliary_channels = auxiliary_channels
222
+ self.auxiliary_num_convs = auxiliary_num_convs
223
+ self.auxiliary_concat_input = auxiliary_concat_input
224
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
225
+
226
+ self.temporal_model_block = temporal_model_block
227
+ self.temporal_model_config = temporal_model_config
228
+ self.temporal_model_position = temporal_model_position
229
+ self.temporal_model_init_value = temporal_model_init_value
230
+ self.use_temporal_position_embedding = use_temporal_position_embedding
231
+ self.add_k_prompts = add_k_prompts
232
+ self.num_frames = num_frames
233
+
234
+
235
+ @dataclass
236
+ class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
237
+ """
238
+ Class for outputs of [`BeitModel`].
239
+
240
+ Args:
241
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
242
+ Sequence of hidden-states at the output of the last layer of the model.
243
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
244
+ Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
245
+ *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
246
+ will be returned.
247
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
248
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
249
+ shape `(batch_size, sequence_length, hidden_size)`.
250
+
251
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
252
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
253
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
254
+ sequence_length)`.
255
+
256
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
257
+ heads.
258
+ """
259
+
260
+
261
+ def drop_path(
262
+ input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
263
+ ) -> torch.Tensor:
264
+ """
265
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
266
+
267
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
268
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
269
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
270
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
271
+ argument.
272
+ """
273
+ if drop_prob == 0.0 or not training:
274
+ return input
275
+ keep_prob = 1 - drop_prob
276
+ shape = (input.shape[0],) + (1,) * (
277
+ input.ndim - 1
278
+ ) # work with diff dim tensors, not just 2D ConvNets
279
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
280
+ random_tensor.floor_() # binarize
281
+ output = input.div(keep_prob) * random_tensor
282
+ return output
283
+
284
+
285
+ class BeitDropPath(nn.Module):
286
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
287
+
288
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
289
+ super().__init__()
290
+ self.drop_prob = drop_prob
291
+
292
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
293
+ return drop_path(x, self.drop_prob, self.training)
294
+
295
+ def extra_repr(self) -> str:
296
+ return "p={}".format(self.drop_prob)
297
+
298
+
299
+ # Based on timm implementation, which can be found here:
300
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
301
+ class BeitEmbeddings(nn.Module):
302
+ """
303
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
304
+
305
+ """
306
+
307
+ def __init__(self, config: BeitConfig) -> None:
308
+ super().__init__()
309
+
310
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
311
+ if config.use_mask_token:
312
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
313
+ else:
314
+ self.mask_token = None
315
+ self.patch_embeddings = BeitPatchEmbeddings(config)
316
+ num_patches = self.patch_embeddings.num_patches
317
+ if config.use_absolute_position_embeddings:
318
+ self.position_embeddings = nn.Parameter(
319
+ torch.zeros(1, num_patches + 1, config.hidden_size)
320
+ )
321
+ else:
322
+ self.position_embeddings = None
323
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
324
+
325
+ if config.use_temporal_position_embedding:
326
+ self.temporal_position_embeddings = nn.parameter.Parameter(
327
+ torch.zeros(1, config.num_frames, 1, config.hidden_size)
328
+ )
329
+ else:
330
+ self.temporal_position_embeddings = None
331
+
332
+ if config.add_k_prompts > 0:
333
+ self.prompt_tokens = nn.parameter.Parameter(
334
+ torch.zeros(1, config.add_k_prompts, config.hidden_size)
335
+ )
336
+ else:
337
+ self.prompt_tokens = None
338
+
339
+ def forward(
340
+ self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None
341
+ ) -> torch.Tensor:
342
+ """
343
+ Args:
344
+ pixel_values (torch.Tensor): The input image patches. Shape: [B, T, C, H, W].
345
+
346
+
347
+ """
348
+ t = pixel_values.shape[1]
349
+ pixel_values = einops.rearrange(pixel_values, "b t c h w -> (b t) c h w")
350
+
351
+ embeddings = self.patch_embeddings(pixel_values)
352
+ batch_size, seq_len, _ = embeddings.size() # [(b t) l c]
353
+
354
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
355
+ if bool_masked_pos is not None:
356
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
357
+ # replace the masked visual tokens by mask_tokens
358
+ w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
359
+ embeddings = embeddings * (1 - w) + mask_tokens * w
360
+
361
+ if self.prompt_tokens is not None:
362
+ prompt_tokens = self.prompt_tokens.expand(batch_size, -1, -1)
363
+ embeddings = torch.cat((cls_tokens, embeddings, prompt_tokens), dim=1)
364
+ else:
365
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1) # [B*T, L, C]
366
+ if self.position_embeddings is not None:
367
+ embeddings = embeddings + self.position_embeddings
368
+
369
+ embeddings = einops.rearrange(embeddings, "(b t) l c -> b t l c", t=t)
370
+ if self.temporal_position_embeddings is not None:
371
+ if t <= self.temporal_position_embeddings.shape[1]:
372
+ embeddings = embeddings + self.temporal_position_embeddings[:, :t]
373
+ else:
374
+ tpe = interpolate_temporal_pos_embed(self.temporal_position_embeddings, t)
375
+ embeddings = embeddings + tpe
376
+
377
+ embeddings = self.dropout(embeddings)
378
+
379
+ return embeddings
380
+
381
+
382
+ class BeitPatchEmbeddings(nn.Module):
383
+ """
384
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
385
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
386
+ Transformer.
387
+ """
388
+
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ image_size, patch_size = config.image_size, config.patch_size
392
+ num_channels, hidden_size = config.num_channels, config.hidden_size
393
+
394
+ image_size = (
395
+ image_size
396
+ if isinstance(image_size, collections.abc.Iterable)
397
+ else (image_size, image_size)
398
+ )
399
+ patch_size = (
400
+ patch_size
401
+ if isinstance(patch_size, collections.abc.Iterable)
402
+ else (patch_size, patch_size)
403
+ )
404
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
405
+ patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
406
+ self.image_size = image_size
407
+ self.patch_size = patch_size
408
+ self.num_channels = num_channels
409
+ self.num_patches = num_patches
410
+ self.patch_shape = patch_shape
411
+
412
+ self.projection = nn.Conv2d(
413
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
414
+ )
415
+
416
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
417
+ batch_size, num_channels, height, width = pixel_values.shape
418
+ if num_channels != self.num_channels:
419
+ raise ValueError(
420
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
421
+ )
422
+ if height != self.image_size[0] or width != self.image_size[1]:
423
+ raise ValueError(
424
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
425
+ )
426
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
427
+
428
+ return embeddings
429
+
430
+
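# A minimal sketch of the patch projection above, assuming a 224x224 RGB input and
# patch size 16 (hypothetical sizes, not taken from the model config).
import torch
from torch import nn

proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
x = torch.randn(1, 3, 224, 224)
patches = proj(x).flatten(2).transpose(1, 2)   # [1, 196, 768]: one token per 16x16 patch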
431
+ class BeitSelfAttention(nn.Module):
432
+ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
433
+ super().__init__()
434
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
435
+ config, "embedding_size"
436
+ ):
437
+ raise ValueError(
438
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
439
+ f"heads {config.num_attention_heads}."
440
+ )
441
+
442
+ self.num_attention_heads = config.num_attention_heads
443
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
444
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
445
+
446
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
447
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
448
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
449
+
450
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
451
+
452
+ if window_size:
453
+ self.relative_position_bias = BeitRelativePositionBias(
454
+ config, window_size=window_size
455
+ )
456
+ else:
457
+ self.relative_position_bias = None
458
+
459
+ def transpose_for_scores(self, x):
460
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
461
+ x = x.view(*new_x_shape)
462
+ return x.permute(0, 2, 1, 3)
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states: torch.Tensor,
467
+ head_mask: Optional[torch.Tensor] = None,
468
+ output_attentions: bool = False,
469
+ relative_position_bias: Optional["BeitRelativePositionBias"] = None,
470
+ ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
471
+ mixed_query_layer = self.query(hidden_states)
472
+
473
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
474
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
475
+ query_layer = self.transpose_for_scores(mixed_query_layer)
476
+
477
+ # Take the dot product between "query" and "key" to get the raw attention scores.
478
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
479
+
480
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
481
+
482
+ # Add relative position bias if present.
483
+ if self.relative_position_bias is not None:
484
+ attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0)
485
+
486
+ # Add shared relative position bias if provided.
487
+ if relative_position_bias is not None:
488
+ attention_scores = attention_scores + relative_position_bias
489
+
490
+ # Normalize the attention scores to probabilities.
491
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
492
+
493
+ # This is actually dropping out entire tokens to attend to, which might
494
+ # seem a bit unusual, but is taken from the original Transformer paper.
495
+ attention_probs = self.dropout(attention_probs)
496
+
497
+ # Mask heads if we want to
498
+ if head_mask is not None:
499
+ attention_probs = attention_probs * head_mask
500
+
501
+ context_layer = torch.matmul(attention_probs, value_layer)
502
+
503
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
504
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
505
+ context_layer = context_layer.view(*new_context_layer_shape)
506
+
507
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
508
+
509
+ return outputs
510
+
511
+
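# Sketch of transpose_for_scores above, assuming 12 heads of size 64 (hidden size 768):
# the hidden dimension is split per head and the head axis is moved ahead of the sequence.
import torch

bsz, seq_len, num_heads, head_size = 2, 197, 12, 64
x = torch.randn(bsz, seq_len, num_heads * head_size)
x = x.view(bsz, seq_len, num_heads, head_size).permute(0, 2, 1, 3)   # [2, 12, 197, 64]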
512
+ class BeitSelfOutput(nn.Module):
513
+ """
514
+ The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the
515
+ layernorm applied before each block.
516
+ """
517
+
518
+ def __init__(self, config: BeitConfig) -> None:
519
+ super().__init__()
520
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
521
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
522
+
523
+ def forward(
524
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None
525
+ ) -> torch.Tensor:
526
+ hidden_states = self.dense(hidden_states)
527
+ hidden_states = self.dropout(hidden_states)
528
+
529
+ return hidden_states
530
+
531
+
532
+ class BeitAttention(nn.Module):
533
+ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
534
+ super().__init__()
535
+ self.attention = BeitSelfAttention(config, window_size=window_size)
536
+ self.output = BeitSelfOutput(config)
537
+ self.pruned_heads = set()
538
+
539
+ def prune_heads(self, heads):
540
+ if len(heads) == 0:
541
+ return
542
+ heads, index = find_pruneable_heads_and_indices(
543
+ heads,
544
+ self.attention.num_attention_heads,
545
+ self.attention.attention_head_size,
546
+ self.pruned_heads,
547
+ )
548
+
549
+ # Prune linear layers
550
+ self.attention.query = prune_linear_layer(self.attention.query, index)
551
+ self.attention.key = prune_linear_layer(self.attention.key, index)
552
+ self.attention.value = prune_linear_layer(self.attention.value, index)
553
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
554
+
555
+ # Update hyper params and store pruned heads
556
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
557
+ self.attention.all_head_size = (
558
+ self.attention.attention_head_size * self.attention.num_attention_heads
559
+ )
560
+ self.pruned_heads = self.pruned_heads.union(heads)
561
+
562
+ def forward(
563
+ self,
564
+ hidden_states: torch.Tensor,
565
+ head_mask: Optional[torch.Tensor] = None,
566
+ output_attentions: bool = False,
567
+ relative_position_bias: Optional["BeitRelativePositionBias"] = None,
568
+ ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
569
+ self_outputs = self.attention(
570
+ hidden_states, head_mask, output_attentions, relative_position_bias
571
+ )
572
+
573
+ attention_output = self.output(self_outputs[0], hidden_states)
574
+
575
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
576
+ return outputs
577
+
578
+
579
+ class BeitIntermediate(nn.Module):
580
+ def __init__(self, config: BeitConfig) -> None:
581
+ super().__init__()
582
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
583
+ if isinstance(config.hidden_act, str):
584
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
585
+ else:
586
+ self.intermediate_act_fn = config.hidden_act
587
+
588
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
589
+ hidden_states = self.dense(hidden_states)
590
+ hidden_states = self.intermediate_act_fn(hidden_states)
591
+
592
+ return hidden_states
593
+
594
+
595
+ class BeitOutput(nn.Module):
596
+ def __init__(self, config: BeitConfig) -> None:
597
+ super().__init__()
598
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
599
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
600
+
601
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
602
+ hidden_states = self.dense(hidden_states)
603
+ hidden_states = self.dropout(hidden_states)
604
+
605
+ return hidden_states
606
+
607
+
608
+ class TemporalAttentionBeit(nn.Module):
609
+
610
+ """temporal attention using BeitAttention"""
611
+
612
+ def __init__(self, config: BeitConfig):
613
+ """TODO: to be defined."""
614
+ super().__init__()
615
+
616
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
617
+ self.attention = BeitAttention(config, window_size=None)
618
+ self.scale = nn.Parameter(
619
+ config.temporal_model_init_value * torch.ones((config.hidden_size)),
620
+ requires_grad=True,
621
+ )
622
+ self.drop_path = BeitDropPath(config.drop_path_rate)
623
+
624
+ def forward(self, hidden_states: torch.Tensor):
625
+ """forward function
626
+
627
+ Args:
628
+ hidden_states (torch.Tensor): The input. Shape: [b,t,l,c]
629
+
630
+ Returns: torch.Tensor of shape [b,t,l,c], the input plus the scaled temporal-attention residual.
631
+
632
+ """
633
+ b = hidden_states.shape[0]
634
+ output = einops.rearrange(hidden_states, "b t l c -> (b l) t c")
635
+ output = self.layernorm_before(output)
636
+ output = self.attention(output)
637
+ output = einops.rearrange(output[0], "(b l) t c -> b t l c", b=b)
638
+ return hidden_states + self.drop_path(output) * self.scale  # output is already [b, t, l, c]
639
+
640
+
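# Sketch of the temporal attention layout above: spatial locations are folded into the
# batch so each location attends only across the T frames (assumed b=2, t=4, l=197, c=768).
import torch
import einops

b, t, l, c = 2, 4, 197, 768
x = torch.randn(b, t, l, c)
seq = einops.rearrange(x, "b t l c -> (b l) t c")          # [394, 4, 768]: length-t sequences
back = einops.rearrange(seq, "(b l) t c -> b t l c", b=b)  # restored to [2, 4, 197, 768]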
641
+ class BeitLayer(nn.Module):
642
+ """This corresponds to the Block class in the timm implementation."""
643
+
644
+ def __init__(
645
+ self,
646
+ config: BeitConfig,
647
+ window_size: Optional[tuple] = None,
648
+ drop_path_rate: float = 0.0,
649
+ ) -> None:
650
+ super().__init__()
651
+ self.config = config
652
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
653
+ self.seq_len_dim = 1
654
+ self.attention = BeitAttention(config, window_size=window_size)
655
+ self.intermediate = BeitIntermediate(config)
656
+ self.output = BeitOutput(config)
657
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
658
+ self.drop_path = (
659
+ BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
660
+ )
661
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
662
+
663
+ self.temporal_model_position = config.temporal_model_position
664
+
665
+ init_values = config.layer_scale_init_value
666
+ if init_values > 0:
667
+ self.lambda_1 = nn.Parameter(
668
+ init_values * torch.ones((config.hidden_size)), requires_grad=True
669
+ )
670
+ self.lambda_2 = nn.Parameter(
671
+ init_values * torch.ones((config.hidden_size)), requires_grad=True
672
+ )
673
+ else:
674
+ self.lambda_1, self.lambda_2 = None, None
675
+
676
+ if config.temporal_model_block == "st_adapter":
677
+ self.temp_model = STAdapter(**config.temporal_model_config)
678
+ elif config.temporal_model_block == "timesformer":
679
+ self.temp_model = TemporalAttention(**config.temporal_model_config)
680
+ elif config.temporal_model_block == "ta_beit":
681
+ self.temp_model = TemporalAttentionBeit(config)
682
+ elif config.temporal_model_block == "window_attention":
683
+ self.temp_model = WindowTemporalAttention(**config.temporal_model_config)
684
+ elif config.temporal_model_block == "xclip":
685
+ self.temp_model = X_CLIP(**config.temporal_model_config)
686
+ elif config.temporal_model_block == "none":
687
+ self.temp_model = None
688
+ else:
689
+ raise ValueError(f"not accepted temporal model: {config.temporal_model_block}")
690
+
691
+ self.temporal_model_block = config.temporal_model_block
692
+
693
+ def forward(
694
+ self,
695
+ hidden_states: torch.Tensor,
696
+ head_mask: Optional[torch.Tensor] = None,
697
+ output_attentions: bool = False,
698
+ relative_position_bias: Optional["BeitRelativePositionBias"] = None,
699
+ ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
700
+
701
+ b, t, l, c = hidden_states.shape
702
+
703
+ if self.temporal_model_block == "xclip":
704
+ assert (
705
+ self.temporal_model_position == "first" and self.config.add_k_prompts == 1
706
+ ), "xclip must be put before the attention and add_k_prompts must be 1."
707
+
708
+ if self.temp_model is not None and self.temporal_model_position == "first":
709
+ hidden_states = self.temp_model(hidden_states)
710
+
711
+ hidden_states = einops.rearrange(hidden_states, "b t l c -> (b t) l c")
712
+
713
+ self_attention_outputs = self.attention(
714
+ self.layernorm_before(
715
+ hidden_states
716
+ ), # in BEiT, layernorm is applied before self-attention
717
+ head_mask,
718
+ output_attentions=output_attentions,
719
+ relative_position_bias=relative_position_bias,
720
+ )
721
+ attention_output = self_attention_outputs[0]
722
+
723
+ # add self attentions if we output attention weights
724
+ outputs = self_attention_outputs[1:]
725
+
726
+ # apply lambda_1 if present
727
+ if self.lambda_1 is not None:
728
+ attention_output = self.lambda_1 * attention_output
729
+
730
+ # first residual connection
731
+ hidden_states = self.drop_path(attention_output) + hidden_states
732
+
733
+ # in BEiT, layernorm is also applied after self-attention
734
+ layer_output = self.layernorm_after(hidden_states)
735
+
736
+ layer_output = self.intermediate(layer_output)
737
+ layer_output = self.output(layer_output)
738
+
739
+ if self.lambda_2 is not None:
740
+ layer_output = self.lambda_2 * layer_output
741
+
742
+ # second residual connection
743
+ layer_output = self.drop_path(layer_output) + hidden_states
744
+
745
+ layer_output = einops.rearrange(layer_output, "(b t) l c -> b t l c", b=b)
746
+
747
+ # apply temporal modeling block
748
+ if self.temp_model is not None and self.temporal_model_position == "last":
749
+ layer_output = self.temp_model(layer_output)
750
+
751
+ outputs = (layer_output,) + outputs
752
+
753
+ return outputs
754
+
755
+
756
+ class BeitRelativePositionBias(nn.Module):
757
+ def __init__(self, config: BeitConfig, window_size: tuple) -> None:
758
+ super().__init__()
759
+ self.window_size = window_size
760
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
761
+ self.relative_position_bias_table = nn.Parameter(
762
+ torch.zeros(self.num_relative_distance, config.num_attention_heads)
763
+ ) # 2*Wh-1 * 2*Ww-1, nH
764
+ # cls to token & token 2 cls & cls to cls
765
+
766
+ # get pair-wise relative position index for each token inside the window
767
+ coords_h = torch.arange(window_size[0])
768
+ coords_w = torch.arange(window_size[1])
769
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
770
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
771
+ relative_coords = (
772
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
773
+ ) # 2, Wh*Ww, Wh*Ww
774
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
775
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
776
+ relative_coords[:, :, 1] += window_size[1] - 1
777
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
778
+ relative_position_index = torch.zeros(
779
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
780
+ )
781
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
782
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
783
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
784
+ relative_position_index[0, 0] = self.num_relative_distance - 1
785
+
786
+ self.register_buffer("relative_position_index", relative_position_index)
787
+
788
+ # add bias for prompts
789
+ k = config.add_k_prompts
790
+ self.k = k
791
+ if k > 0:
792
+ self.prompt_bias_table = nn.parameter.Parameter(
793
+ torch.zeros((2 + k) * k, config.num_attention_heads)
794
+ ) # k prompt-to-token, k token-to-prompt, k*k prompt-to-prompt
795
+ else:
796
+ self.prompt_bias_table = None
797
+
798
+ def forward(self) -> torch.Tensor:
799
+ relative_position_bias = self.relative_position_bias_table[
800
+ self.relative_position_index.view(-1)
801
+ ].view(
802
+ self.window_size[0] * self.window_size[1] + 1,
803
+ self.window_size[0] * self.window_size[1] + 1,
804
+ -1,
805
+ ) # Wh*Ww,Wh*Ww,nH
806
+
807
+ k = self.k
808
+ if k > 0:
809
+ l = self.window_size[0] * self.window_size[1] + 1
810
+ bias = torch.zeros(l + k, l + k, relative_position_bias.shape[-1]).to(
811
+ relative_position_bias.device
812
+ )
813
+ bias[:l, :l] = relative_position_bias
814
+ bias[l:, :l] = self.prompt_bias_table[:k].view(k, 1, -1) # prompt to token
815
+ bias[:l, l:] = self.prompt_bias_table[k : 2 * k].view(1, k, -1) # token to prompt
816
+ bias[l:, l:] = self.prompt_bias_table[2 * k :, :].view(k, k, -1) # prompt to prompt
817
+
818
+ # bias[k:, k:] = relative_position_bias
819
+ # bias[:k, k:] = self.prompt_bias_table[:k].view(k, 1, -1)
820
+ # bias[k:, :k] = self.prompt_bias_table[k : 2 * k].view(1, k, -1)
821
+ # bias[:k, :k] = self.prompt_bias_table[2 * k :].view(k, k, -1)
822
+ else:
823
+ bias = relative_position_bias
824
+
825
+ return bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
826
+
827
+
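# Worked example for the relative position bias above, assuming a 14x14 patch grid and 12 heads:
Wh = Ww = 14
num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3   # 729 + 3 = 732 distinct offsets
table_shape = (num_relative_distance, 12)                 # learnable bias table: [732, 12]
bias_shape = (12, Wh * Ww + 1, Wh * Ww + 1)               # returned bias (with k=0 prompts): [12, 197, 197]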
828
+ class BeitEncoder(nn.Module):
829
+ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
830
+ super().__init__()
831
+ self.config = config
832
+ if config.use_shared_relative_position_bias:
833
+ self.relative_position_bias = BeitRelativePositionBias(
834
+ config, window_size=window_size
835
+ )
836
+ else:
837
+ self.relative_position_bias = None
838
+
839
+ # stochastic depth decay rule
840
+ dpr = [
841
+ x.item()
842
+ for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
843
+ ]
844
+ self.layer = nn.ModuleList(
845
+ [
846
+ BeitLayer(
847
+ config,
848
+ window_size=window_size if config.use_relative_position_bias else None,
849
+ drop_path_rate=dpr[i],
850
+ )
851
+ for i in range(config.num_hidden_layers)
852
+ ]
853
+ )
854
+ self.gradient_checkpointing = False
855
+
856
+ def forward(
857
+ self,
858
+ hidden_states: torch.Tensor,
859
+ head_mask: Optional[torch.Tensor] = None,
860
+ output_attentions: bool = False,
861
+ output_hidden_states: bool = False,
862
+ return_dict: bool = True,
863
+ ) -> Union[tuple, BaseModelOutput]:
864
+ all_hidden_states = () if output_hidden_states else None
865
+ all_self_attentions = () if output_attentions else None
866
+
867
+ for i, layer_module in enumerate(self.layer):
868
+ if output_hidden_states:
869
+ # all_hidden_states = all_hidden_states + (
870
+ # einops.rearrange(hidden_states, "b t l c -> (b t) l c"),
871
+ # )
872
+ all_hidden_states = all_hidden_states + (hidden_states,)
873
+
874
+ layer_head_mask = head_mask[i] if head_mask is not None else None
875
+
876
+ if self.gradient_checkpointing and self.training:
877
+
878
+ def create_custom_forward(module):
879
+ def custom_forward(*inputs):
880
+ return module(*inputs, output_attentions)
881
+
882
+ return custom_forward
883
+
884
+ layer_outputs = torch.utils.checkpoint.checkpoint(
885
+ create_custom_forward(layer_module),
886
+ hidden_states,
887
+ layer_head_mask,
888
+ use_reentrant=False,
889
+ )
890
+ else:
891
+ relative_position_bias = (
892
+ self.relative_position_bias()
893
+ if self.relative_position_bias is not None
894
+ else None
895
+ )
896
+ layer_outputs = layer_module(
897
+ hidden_states, layer_head_mask, output_attentions, relative_position_bias
898
+ )
899
+
900
+ hidden_states = layer_outputs[0]
901
+
902
+ if output_attentions:
903
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
904
+
905
+ # hidden_states = einops.rearrange(hidden_states, "b t l c -> (b t) l c")
906
+
907
+ if output_hidden_states:
908
+ all_hidden_states = all_hidden_states + (hidden_states,)
909
+
910
+ if not return_dict:
911
+ return tuple(
912
+ v
913
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
914
+ if v is not None
915
+ )
916
+ return BaseModelOutput(
917
+ last_hidden_state=hidden_states,
918
+ hidden_states=all_hidden_states,
919
+ attentions=all_self_attentions,
920
+ )
921
+
922
+
923
+ class BeitPreTrainedModel(PreTrainedModel):
924
+ """
925
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
926
+ models.
927
+ """
928
+
929
+ config_class = BeitConfig
930
+ base_model_prefix = "beit"
931
+ main_input_name = "pixel_values"
932
+ supports_gradient_checkpointing = True
933
+
934
+ def _init_weights(self, module):
935
+ """Initialize the weights"""
936
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
937
+ # Slightly different from the TF version which uses truncated_normal for initialization
938
+ # cf https://github.com/pytorch/pytorch/pull/5617
939
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
940
+ if module.bias is not None:
941
+ module.bias.data.zero_()
942
+ elif isinstance(module, nn.Embedding):
943
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
944
+ if module.padding_idx is not None:
945
+ module.weight.data[module.padding_idx].zero_()
946
+ elif isinstance(module, nn.LayerNorm):
947
+ module.bias.data.zero_()
948
+ module.weight.data.fill_(1.0)
949
+
950
+ def _set_gradient_checkpointing(self, module, value=False):
951
+ if isinstance(module, BeitEncoder):
952
+ module.gradient_checkpointing = value
953
+
954
+
955
+ BEIT_START_DOCSTRING = r"""
956
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
957
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
958
+ behavior.
959
+
960
+ Parameters:
961
+ config ([`BeitConfig`]): Model configuration class with all the parameters of the model.
962
+ Initializing with a config file does not load the weights associated with the model, only the
963
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
964
+ """
965
+
966
+ BEIT_INPUTS_DOCSTRING = r"""
967
+ Args:
968
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
969
+ Pixel values. Pixel values can be obtained using [`BeitFeatureExtractor`]. See
970
+ [`BeitFeatureExtractor.__call__`] for details.
971
+
972
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
973
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
974
+
975
+ - 1 indicates the head is **not masked**,
976
+ - 0 indicates the head is **masked**.
977
+
978
+ output_attentions (`bool`, *optional*):
979
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
980
+ tensors for more detail.
981
+ output_hidden_states (`bool`, *optional*):
982
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
983
+ more detail.
984
+ return_dict (`bool`, *optional*):
985
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
986
+ """
987
+
988
+
989
+ @add_start_docstrings(
990
+ "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.",
991
+ BEIT_START_DOCSTRING,
992
+ )
993
+ class BeitModel(BeitPreTrainedModel):
994
+ def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None:
995
+ super().__init__(config)
996
+ self.config = config
997
+
998
+ self.embeddings = BeitEmbeddings(config)
999
+ self.encoder = BeitEncoder(
1000
+ config, window_size=self.embeddings.patch_embeddings.patch_shape
1001
+ )
1002
+
1003
+ self.layernorm = (
1004
+ nn.Identity()
1005
+ if config.use_mean_pooling
1006
+ else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1007
+ )
1008
+ self.pooler = BeitPooler(config) if add_pooling_layer else None
1009
+
1010
+ # Initialize weights and apply final processing
1011
+ self.post_init()
1012
+
1013
+ def get_input_embeddings(self):
1014
+ return self.embeddings.patch_embeddings
1015
+
1016
+ def _prune_heads(self, heads_to_prune):
1017
+ """
1018
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1019
+ class PreTrainedModel
1020
+ """
1021
+ for layer, heads in heads_to_prune.items():
1022
+ self.encoder.layer[layer].attention.prune_heads(heads)
1023
+
1024
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
1025
+ @add_code_sample_docstrings(
1026
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
1027
+ checkpoint=_CHECKPOINT_FOR_DOC,
1028
+ output_type=BeitModelOutputWithPooling,
1029
+ config_class=_CONFIG_FOR_DOC,
1030
+ modality="vision",
1031
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1032
+ )
1033
+ def forward(
1034
+ self,
1035
+ pixel_values: Optional[torch.Tensor] = None,
1036
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
1037
+ head_mask: Optional[torch.Tensor] = None,
1038
+ output_attentions: Optional[bool] = None,
1039
+ output_hidden_states: Optional[bool] = None,
1040
+ return_dict: Optional[bool] = None,
1041
+ ) -> Union[tuple, BeitModelOutputWithPooling]:
1042
+ output_attentions = (
1043
+ output_attentions
1044
+ if output_attentions is not None
1045
+ else self.config.output_attentions
1046
+ )
1047
+ output_hidden_states = (
1048
+ output_hidden_states
1049
+ if output_hidden_states is not None
1050
+ else self.config.output_hidden_states
1051
+ )
1052
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1053
+
1054
+ if pixel_values is None:
1055
+ raise ValueError("You have to specify pixel_values")
1056
+
1057
+ # Prepare head mask if needed
1058
+ # 1.0 in head_mask indicate we keep the head
1059
+ # attention_probs has shape bsz x n_heads x N x N
1060
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1061
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1062
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1063
+
1064
+ # pixel_values: [bsz, nframes, c, h, w]
1065
+ assert pixel_values.ndim == 5, (
1066
+ f"input shape to st_beit: {pixel_values.shape}"
1067
+ )
1068
+
1069
+ embedding_output = self.embeddings(
1070
+ pixel_values, bool_masked_pos
1071
+ ) # [bs, nframes, L, c]
1072
+
1073
+ encoder_outputs = self.encoder(
1074
+ embedding_output,
1075
+ head_mask=head_mask,
1076
+ output_attentions=output_attentions,
1077
+ output_hidden_states=output_hidden_states,
1078
+ return_dict=return_dict,
1079
+ )
1080
+ sequence_output = encoder_outputs[0]
1081
+ sequence_output = self.layernorm(sequence_output)
1082
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1083
+
1084
+ # logger.info(f"sequence_output: {sequence_output.shape}. pooled_output: {pooled_output.shape}")
1085
+
1086
+ if not return_dict:
1087
+ head_outputs = (
1088
+ (sequence_output, pooled_output)
1089
+ if pooled_output is not None
1090
+ else (sequence_output,)
1091
+ )
1092
+ return head_outputs + encoder_outputs[1:]
1093
+
1094
+ return BeitModelOutputWithPooling(
1095
+ last_hidden_state=sequence_output,
1096
+ pooler_output=pooled_output,
1097
+ hidden_states=encoder_outputs.hidden_states,
1098
+ attentions=encoder_outputs.attentions,
1099
+ )
1100
+
1101
+
1102
+ class BeitPooler(nn.Module):
1103
+ def __init__(self, config: BeitConfig) -> None:
1104
+ super().__init__()
1105
+ self.num_prompts = config.add_k_prompts
1106
+ self.layernorm = (
1107
+ nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1108
+ if config.use_mean_pooling
1109
+ else None
1110
+ )
1111
+
1112
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1113
+ """
1114
+ Args:
1115
+ hidden_states (torch.Tensor): Shape: [B,T,L,C]
1116
+ """
1117
+ if self.layernorm is not None:
1118
+ # Mean pool the final hidden states of the patch tokens
1119
+ # patch_tokens = hidden_states[:, 1 + self.num_prompts :, :]
1120
+ if self.num_prompts > 0:
1121
+ patch_tokens = hidden_states[:, :, 1 : -self.num_prompts, :]
1122
+ else:
1123
+ patch_tokens = hidden_states[:, :, 1:, :]
1124
+ pooled_output = self.layernorm(patch_tokens.mean(2))
1125
+ else:
1126
+ # Pool by simply taking the final hidden state of the [CLS] token
1127
+ pooled_output = hidden_states[:, :, 0]
1128
+
1129
+ return pooled_output
1130
+
1131
+
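# Sketch of the mean-pooling branch above: average the patch tokens (dropping [CLS] and
# any prompt tokens) separately for every frame. Sizes are assumed for illustration only.
import torch

B, T, L, C = 2, 4, 197, 768
hidden_states = torch.randn(B, T, L, C)
pooled = hidden_states[:, :, 1:, :].mean(2)   # [2, 4, 768]: one pooled vector per frame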
1132
+ @add_start_docstrings(
1133
+ """Beit Model transformer with a 'language' modeling head on top. BEiT does masked image modeling by predicting
1134
+ visual tokens of a Vector-Quantized Variational Autoencoder (VQ-VAE), whereas other vision models like ViT and DeiT
1135
+ predict RGB pixel values. As a result, this class is incompatible with [`AutoModelForMaskedImageModeling`], so you
1136
+ will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with BEiT.""",
1137
+ BEIT_START_DOCSTRING,
1138
+ )
1139
+ class BeitForMaskedImageModeling(BeitPreTrainedModel):
1140
+ def __init__(self, config: BeitConfig) -> None:
1141
+ super().__init__(config)
1142
+
1143
+ self.num_labels = config.num_labels
1144
+ self.beit = BeitModel(config, add_pooling_layer=False)
1145
+
1146
+ # Classifier head
1147
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1148
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
1149
+
1150
+ # Initialize weights and apply final processing
1151
+ self.post_init()
1152
+
1153
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
1154
+ @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
1155
+ def forward(
1156
+ self,
1157
+ pixel_values: Optional[torch.Tensor] = None,
1158
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
1159
+ head_mask: Optional[torch.Tensor] = None,
1160
+ labels: Optional[torch.Tensor] = None,
1161
+ output_attentions: Optional[bool] = None,
1162
+ output_hidden_states: Optional[bool] = None,
1163
+ return_dict: Optional[bool] = None,
1164
+ ) -> Union[tuple, MaskedLMOutput]:
1165
+ r"""
1166
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
1167
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
1168
+
1169
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1170
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1171
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1172
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1173
+
1174
+ Returns:
1175
+
1176
+ Examples:
1177
+
1178
+ ```python
1179
+ >>> from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling
1180
+ >>> import torch
1181
+ >>> from PIL import Image
1182
+ >>> import requests
1183
+
1184
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1185
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1186
+
1187
+ >>> feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
1188
+ >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
1189
+
1190
+ >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
1191
+ >>> pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
1192
+ >>> # create random boolean mask of shape (batch_size, num_patches)
1193
+ >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
1194
+
1195
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
1196
+ >>> loss, logits = outputs.loss, outputs.logits
1197
+ >>> list(logits.shape)
1198
+ [1, 196, 8192]
1199
+ ```"""
1200
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1201
+
1202
+ outputs = self.beit(
1203
+ pixel_values,
1204
+ bool_masked_pos=bool_masked_pos,
1205
+ head_mask=head_mask,
1206
+ output_attentions=output_attentions,
1207
+ output_hidden_states=output_hidden_states,
1208
+ return_dict=return_dict,
1209
+ )
1210
+
1211
+ sequence_output = outputs[0]
1212
+ sequence_output = self.layernorm(sequence_output)
1213
+ prediction_scores = self.lm_head(sequence_output[:, 1:])
1214
+
1215
+ masked_lm_loss = None
1216
+ if labels is not None:
1217
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1218
+ masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels)
1219
+
1220
+ if not return_dict:
1221
+ output = (prediction_scores,) + outputs[1:]
1222
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1223
+
1224
+ return MaskedLMOutput(
1225
+ loss=masked_lm_loss,
1226
+ logits=prediction_scores,
1227
+ hidden_states=outputs.hidden_states,
1228
+ attentions=outputs.attentions,
1229
+ )
1230
+
1231
+
1232
+ @add_start_docstrings(
1233
+ """
1234
+ Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
1235
+ hidden states of the patch tokens) e.g. for ImageNet.
1236
+ """,
1237
+ BEIT_START_DOCSTRING,
1238
+ )
1239
+ class BeitForImageClassification(BeitPreTrainedModel):
1240
+ def __init__(self, config: BeitConfig) -> None:
1241
+ super().__init__(config)
1242
+
1243
+ self.num_labels = config.num_labels
1244
+ self.beit = BeitModel(config, add_pooling_layer=True)
1245
+
1246
+ # Classifier head
1247
+ self.classifier = (
1248
+ nn.Linear(config.hidden_size, config.num_labels)
1249
+ if config.num_labels > 0
1250
+ else nn.Identity()
1251
+ )
1252
+
1253
+ # Initialize weights and apply final processing
1254
+ self.post_init()
1255
+
1256
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
1257
+ @add_code_sample_docstrings(
1258
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
1259
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
1260
+ output_type=ImageClassifierOutput,
1261
+ config_class=_CONFIG_FOR_DOC,
1262
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
1263
+ )
1264
+ def forward(
1265
+ self,
1266
+ pixel_values: Optional[torch.Tensor] = None,
1267
+ head_mask: Optional[torch.Tensor] = None,
1268
+ labels: Optional[torch.Tensor] = None,
1269
+ output_attentions: Optional[bool] = None,
1270
+ output_hidden_states: Optional[bool] = None,
1271
+ return_dict: Optional[bool] = None,
1272
+ ) -> Union[tuple, ImageClassifierOutput]:
1273
+ r"""
1274
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1275
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1276
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1277
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1278
+ """
1279
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1280
+ outputs = self.beit(
1281
+ pixel_values,
1282
+ head_mask=head_mask,
1283
+ output_attentions=output_attentions,
1284
+ output_hidden_states=output_hidden_states,
1285
+ return_dict=return_dict,
1286
+ )
1287
+
1288
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
1289
+
1290
+ logits = self.classifier(pooled_output)
1291
+
1292
+ loss = None
1293
+ if labels is not None:
1294
+ if self.config.problem_type is None:
1295
+ if self.num_labels == 1:
1296
+ self.config.problem_type = "regression"
1297
+ elif self.num_labels > 1 and (
1298
+ labels.dtype == torch.long or labels.dtype == torch.int
1299
+ ):
1300
+ self.config.problem_type = "single_label_classification"
1301
+ else:
1302
+ self.config.problem_type = "multi_label_classification"
1303
+
1304
+ if self.config.problem_type == "regression":
1305
+ loss_fct = MSELoss()
1306
+ if self.num_labels == 1:
1307
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1308
+ else:
1309
+ loss = loss_fct(logits, labels)
1310
+ elif self.config.problem_type == "single_label_classification":
1311
+ loss_fct = CrossEntropyLoss()
1312
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1313
+ elif self.config.problem_type == "multi_label_classification":
1314
+ loss_fct = BCEWithLogitsLoss()
1315
+ loss = loss_fct(logits, labels)
1316
+ if not return_dict:
1317
+ output = (logits,) + outputs[2:]
1318
+ return ((loss,) + output) if loss is not None else output
1319
+
1320
+ return ImageClassifierOutput(
1321
+ loss=loss,
1322
+ logits=logits,
1323
+ hidden_states=outputs.hidden_states,
1324
+ attentions=outputs.attentions,
1325
+ )
1326
+
1327
+
1328
+ class BeitConvModule(nn.Module):
1329
+ """
1330
+ A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
1331
+ layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
1332
+
1333
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1334
+ """
1335
+
1336
+ def __init__(
1337
+ self,
1338
+ in_channels: int,
1339
+ out_channels: int,
1340
+ kernel_size: Union[int, Tuple[int, int]],
1341
+ padding: Union[int, Tuple[int, int], str] = 0,
1342
+ bias: bool = False,
1343
+ dilation: Union[int, Tuple[int, int]] = 1,
1344
+ ) -> None:
1345
+ super().__init__()
1346
+ self.conv = nn.Conv2d(
1347
+ in_channels=in_channels,
1348
+ out_channels=out_channels,
1349
+ kernel_size=kernel_size,
1350
+ padding=padding,
1351
+ bias=bias,
1352
+ dilation=dilation,
1353
+ )
1354
+ self.bn = nn.BatchNorm2d(out_channels)
1355
+ self.activation = nn.ReLU()
1356
+
1357
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
1358
+ output = self.conv(input)
1359
+ output = self.bn(output)
1360
+ output = self.activation(output)
1361
+
1362
+ return output
1363
+
1364
+
1365
+ class BeitPyramidPoolingBlock(nn.Module):
1366
+ def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
1367
+ super().__init__()
1368
+ self.layers = [
1369
+ nn.AdaptiveAvgPool2d(pool_scale),
1370
+ BeitConvModule(in_channels, channels, kernel_size=1),
1371
+ ]
1372
+ for i, layer in enumerate(self.layers):
1373
+ self.add_module(str(i), layer)
1374
+
1375
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
1376
+ hidden_state = input
1377
+ for layer in self.layers:
1378
+ hidden_state = layer(hidden_state)
1379
+ return hidden_state
1380
+
1381
+
1382
+ class BeitPyramidPoolingModule(nn.Module):
1383
+ """
1384
+ Pyramid Pooling Module (PPM) used in PSPNet.
1385
+
1386
+ Args:
1387
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
1388
+ Module.
1389
+ in_channels (int): Input channels.
1390
+ channels (int): Channels after modules, before conv_seg.
1391
+ align_corners (bool): align_corners argument of F.interpolate.
1392
+
1393
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1394
+ """
1395
+
1396
+ def __init__(
1397
+ self,
1398
+ pool_scales: Tuple[int, ...],
1399
+ in_channels: int,
1400
+ channels: int,
1401
+ align_corners: bool,
1402
+ ) -> None:
1403
+ super().__init__()
1404
+ self.pool_scales = pool_scales
1405
+ self.align_corners = align_corners
1406
+ self.in_channels = in_channels
1407
+ self.channels = channels
1408
+ self.blocks = []
1409
+ for i, pool_scale in enumerate(pool_scales):
1410
+ block = BeitPyramidPoolingBlock(
1411
+ pool_scale=pool_scale, in_channels=in_channels, channels=channels
1412
+ )
1413
+ self.blocks.append(block)
1414
+ self.add_module(str(i), block)
1415
+
1416
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
1417
+ ppm_outs = []
1418
+ for ppm in self.blocks:
1419
+ ppm_out = ppm(x)
1420
+ upsampled_ppm_out = nn.functional.interpolate(
1421
+ ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
1422
+ )
1423
+ ppm_outs.append(upsampled_ppm_out)
1424
+ return ppm_outs
1425
+
1426
+
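# Sketch of the pyramid pooling above, assuming pool_scales (1, 2, 3, 6) and 14x14 features:
# each branch adaptively pools, applies a 1x1 conv, then upsamples back to the input size.
import torch
from torch import nn

x = torch.randn(1, 768, 14, 14)
outs = []
for s in (1, 2, 3, 6):
    pooled = nn.Conv2d(768, 768, kernel_size=1)(nn.AdaptiveAvgPool2d(s)(x))   # [1, 768, s, s]
    outs.append(nn.functional.interpolate(
        pooled, size=x.shape[2:], mode="bilinear", align_corners=False))      # [1, 768, 14, 14]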
1427
+ class BeitUperHead(nn.Module):
1428
+ """
1429
+ Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
1430
+ [UPerNet](https://arxiv.org/abs/1807.10221).
1431
+
1432
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1433
+ """
1434
+
1435
+ def __init__(self, config: BeitConfig) -> None:
1436
+ super().__init__()
1437
+
1438
+ self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
1439
+ self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
1440
+ self.channels = config.hidden_size
1441
+ self.align_corners = False
1442
+ self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
1443
+
1444
+ # PSP Module
1445
+ self.psp_modules = BeitPyramidPoolingModule(
1446
+ self.pool_scales,
1447
+ self.in_channels[-1],
1448
+ self.channels,
1449
+ align_corners=self.align_corners,
1450
+ )
1451
+ self.bottleneck = BeitConvModule(
1452
+ self.in_channels[-1] + len(self.pool_scales) * self.channels,
1453
+ self.channels,
1454
+ kernel_size=3,
1455
+ padding=1,
1456
+ )
1457
+ # FPN Module
1458
+ self.lateral_convs = nn.ModuleList()
1459
+ self.fpn_convs = nn.ModuleList()
1460
+ for in_channels in self.in_channels[:-1]: # skip the top layer
1461
+ l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
1462
+ fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
1463
+ self.lateral_convs.append(l_conv)
1464
+ self.fpn_convs.append(fpn_conv)
1465
+
1466
+ self.fpn_bottleneck = BeitConvModule(
1467
+ len(self.in_channels) * self.channels,
1468
+ self.channels,
1469
+ kernel_size=3,
1470
+ padding=1,
1471
+ )
1472
+
1473
+ def psp_forward(self, inputs):
1474
+ x = inputs[-1]
1475
+ psp_outs = [x]
1476
+ psp_outs.extend(self.psp_modules(x))
1477
+ psp_outs = torch.cat(psp_outs, dim=1)
1478
+ output = self.bottleneck(psp_outs)
1479
+
1480
+ return output
1481
+
1482
+ def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
1483
+ # build laterals
1484
+ laterals = [
1485
+ lateral_conv(encoder_hidden_states[i])
1486
+ for i, lateral_conv in enumerate(self.lateral_convs)
1487
+ ]
1488
+
1489
+ laterals.append(self.psp_forward(encoder_hidden_states))
1490
+
1491
+ # build top-down path
1492
+ used_backbone_levels = len(laterals)
1493
+ for i in range(used_backbone_levels - 1, 0, -1):
1494
+ prev_shape = laterals[i - 1].shape[2:]
1495
+ laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
1496
+ laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
1497
+ )
1498
+
1499
+ # build outputs
1500
+ fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
1501
+ # append psp feature
1502
+ fpn_outs.append(laterals[-1])
1503
+
1504
+ for i in range(used_backbone_levels - 1, 0, -1):
1505
+ fpn_outs[i] = nn.functional.interpolate(
1506
+ fpn_outs[i],
1507
+ size=fpn_outs[0].shape[2:],
1508
+ mode="bilinear",
1509
+ align_corners=self.align_corners,
1510
+ )
1511
+ fpn_outs = torch.cat(fpn_outs, dim=1)
1512
+ output = self.fpn_bottleneck(fpn_outs)
1513
+ output = self.classifier(output)
1514
+
1515
+ return output
1516
+
1517
+
1518
+ class BeitFCNHead(nn.Module):
1519
+ """
1520
+ Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of
1522
+ [FCN](https://arxiv.org/abs/1411.4038).
1522
+
1523
+ Args:
1524
+ config (BeitConfig): Configuration.
1525
+ in_index (int): Index of the encoder hidden state fed to this head. Default: 2.
1526
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
1527
+ dilation (int): The dilation rate for convs in the head. Default: 1.
1528
+
1529
+
1530
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1531
+ """
1532
+
1533
+ def __init__(
1534
+ self,
1535
+ config: BeitConfig,
1536
+ in_index: int = 2,
1537
+ kernel_size: int = 3,
1538
+ dilation: Union[int, Tuple[int, int]] = 1,
1539
+ ) -> None:
1540
+ super().__init__()
1541
+ self.in_channels = config.hidden_size
1542
+ self.channels = config.auxiliary_channels
1543
+ self.num_convs = config.auxiliary_num_convs
1544
+ self.concat_input = config.auxiliary_concat_input
1545
+ self.in_index = in_index
1546
+
1547
+ conv_padding = (kernel_size // 2) * dilation
1548
+ convs = []
1549
+ convs.append(
1550
+ BeitConvModule(
1551
+ self.in_channels,
1552
+ self.channels,
1553
+ kernel_size=kernel_size,
1554
+ padding=conv_padding,
1555
+ dilation=dilation,
1556
+ )
1557
+ )
1558
+ for i in range(self.num_convs - 1):
1559
+ convs.append(
1560
+ BeitConvModule(
1561
+ self.channels,
1562
+ self.channels,
1563
+ kernel_size=kernel_size,
1564
+ padding=conv_padding,
1565
+ dilation=dilation,
1566
+ )
1567
+ )
1568
+ if self.num_convs == 0:
1569
+ self.convs = nn.Identity()
1570
+ else:
1571
+ self.convs = nn.Sequential(*convs)
1572
+ if self.concat_input:
1573
+ self.conv_cat = BeitConvModule(
1574
+ self.in_channels + self.channels,
1575
+ self.channels,
1576
+ kernel_size=kernel_size,
1577
+ padding=kernel_size // 2,
1578
+ )
1579
+
1580
+ self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
1581
+
1582
+ def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
1583
+ # just take the relevant feature maps
1584
+ hidden_states = encoder_hidden_states[self.in_index]
1585
+ output = self.convs(hidden_states)
1586
+ if self.concat_input:
1587
+ output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
1588
+ output = self.classifier(output)
1589
+ return output
1590
+
1591
+
1592
+ @add_start_docstrings(
1593
+ """
1594
+ Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
1595
+ """,
1596
+ BEIT_START_DOCSTRING,
1597
+ )
1598
+ class BeitForSemanticSegmentation(BeitPreTrainedModel):
1599
+ def __init__(self, config: BeitConfig) -> None:
1600
+ super().__init__(config)
1601
+
1602
+ self.num_labels = config.num_labels
1603
+ self.beit = BeitModel(config, add_pooling_layer=False)
1604
+
1605
+ # FPNs
1606
+ self.fpn1 = nn.Sequential(
1607
+ nn.ConvTranspose2d(
1608
+ config.hidden_size, config.hidden_size, kernel_size=2, stride=2
1609
+ ),
1610
+ nn.BatchNorm2d(config.hidden_size),
1611
+ nn.GELU(),
1612
+ nn.ConvTranspose2d(
1613
+ config.hidden_size, config.hidden_size, kernel_size=2, stride=2
1614
+ ),
1615
+ )
1616
+ self.fpn2 = nn.Sequential(
1617
+ nn.ConvTranspose2d(
1618
+ config.hidden_size, config.hidden_size, kernel_size=2, stride=2
1619
+ ),
1620
+ )
1621
+ self.fpn3 = nn.Identity()
1622
+ self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
1623
+
1624
+ # Semantic segmentation head(s)
1625
+ self.decode_head = BeitUperHead(config)
1626
+ self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None
1627
+
1628
+ # Initialize weights and apply final processing
1629
+ self.post_init()
1630
+
1631
+ def compute_loss(self, logits, auxiliary_logits, labels):
1632
+ # upsample logits to the images' original size
1633
+ upsampled_logits = nn.functional.interpolate(
1634
+ logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
1635
+ )
1636
+ if auxiliary_logits is not None:
1637
+ upsampled_auxiliary_logits = nn.functional.interpolate(
1638
+ auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
1639
+ )
1640
+ # compute weighted loss
1641
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
1642
+ main_loss = loss_fct(upsampled_logits, labels)
1643
+ auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
1644
+ loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
1645
+
1646
+ return loss
1647
+
1648
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
1649
+ @replace_return_docstrings(
1650
+ output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC
1651
+ )
1652
+ def forward(
1653
+ self,
1654
+ pixel_values: Optional[torch.Tensor] = None,
1655
+ head_mask: Optional[torch.Tensor] = None,
1656
+ labels: Optional[torch.Tensor] = None,
1657
+ output_attentions: Optional[bool] = None,
1658
+ output_hidden_states: Optional[bool] = None,
1659
+ return_dict: Optional[bool] = None,
1660
+ ) -> Union[tuple, SemanticSegmenterOutput]:
1661
+ r"""
1662
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
1663
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
1664
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
1665
+
1666
+ Returns:
1667
+
1668
+ Examples:
1669
+
1670
+ ```python
1671
+ >>> from transformers import AutoFeatureExtractor, BeitForSemanticSegmentation
1672
+ >>> from PIL import Image
1673
+ >>> import requests
1674
+
1675
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1676
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1677
+
1678
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
1679
+ >>> model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
1680
+
1681
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
1682
+ >>> outputs = model(**inputs)
1683
+ >>> # logits are of shape (batch_size, num_labels, height, width)
1684
+ >>> logits = outputs.logits
1685
+ ```"""
1686
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1687
+ output_hidden_states = (
1688
+ output_hidden_states
1689
+ if output_hidden_states is not None
1690
+ else self.config.output_hidden_states
1691
+ )
1692
+
1693
+ outputs = self.beit(
1694
+ pixel_values,
1695
+ head_mask=head_mask,
1696
+ output_attentions=output_attentions,
1697
+ output_hidden_states=True, # we need the intermediate hidden states
1698
+ return_dict=return_dict,
1699
+ )
1700
+
1701
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
1702
+
1703
+ # only keep certain features, and reshape
1704
+ # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
1705
+ features = [
1706
+ feature
1707
+ for idx, feature in enumerate(encoder_hidden_states)
1708
+ if idx + 1 in self.config.out_indices
1709
+ ]
1710
+ batch_size = pixel_values.shape[0]
1711
+ patch_resolution = self.config.image_size // self.config.patch_size
1712
+ features = [
1713
+ x[:, 1:, :]
1714
+ .permute(0, 2, 1)
1715
+ .reshape(batch_size, -1, patch_resolution, patch_resolution)
1716
+ for x in features
1717
+ ]
1718
+
1719
+ # apply FPNs
1720
+ ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
1721
+ for i in range(len(features)):
1722
+ features[i] = ops[i](features[i])
1723
+
1724
+ logits = self.decode_head(features)
1725
+
1726
+ auxiliary_logits = None
1727
+ if self.auxiliary_head is not None:
1728
+ auxiliary_logits = self.auxiliary_head(features)
1729
+
1730
+ loss = None
1731
+ if labels is not None:
1732
+ if self.config.num_labels == 1:
1733
+ raise ValueError("The number of labels should be greater than one")
1734
+ else:
1735
+ loss = self.compute_loss(logits, auxiliary_logits, labels)
1736
+
1737
+ if not return_dict:
1738
+ if output_hidden_states:
1739
+ output = (logits,) + outputs[1:]
1740
+ else:
1741
+ output = (logits,) + outputs[2:]
1742
+ return ((loss,) + output) if loss is not None else output
1743
+
1744
+ return SemanticSegmenterOutput(
1745
+ loss=loss,
1746
+ logits=logits,
1747
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1748
+ attentions=outputs.attentions,
1749
+ )
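# Worked example for the four FPN branches above, assuming a 224 input and patch size 16
# (so a 14x14 token grid); values are illustrative, not read from the config.
patch_res = 224 // 16                 # 14
fpn_resolutions = [
    patch_res * 4,                    # fpn1: two stride-2 transposed convs -> 56x56
    patch_res * 2,                    # fpn2: one stride-2 transposed conv  -> 28x28
    patch_res,                        # fpn3: identity                      -> 14x14
    patch_res // 2,                   # fpn4: 2x2 max pool                  -> 7x7
]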
models_viclip/backbones/bert/.tokenization_bert.py.swp ADDED
Binary file (36.9 kB). View file
 
models_viclip/backbones/bert/__init__.py ADDED
File without changes
models_viclip/backbones/bert/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (186 Bytes). View file
 
models_viclip/backbones/bert/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (175 Bytes). View file
 
models_viclip/backbones/bert/__pycache__/tokenization_bert.cpython-310.pyc ADDED
Binary file (20 kB). View file
 
models_viclip/backbones/bert/__pycache__/tokenization_bert.cpython-38.pyc ADDED
Binary file (19.6 kB). View file
 
models_viclip/backbones/bert/builder.py ADDED
@@ -0,0 +1,68 @@
1
+ from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel
2
+
3
+ import logging
4
+ logger = logging.getLogger(__name__)
5
+
6
+ def build_bert(model_config, pretrain, checkpoint):
7
+ """build text encoder.
8
+
9
+ Args:
10
+ model_config (dict): model config.
11
+ pretrain (bool): Whether to do pretrain or finetuning.
12
+ checkpoint (bool): whether to do gradient_checkpointing.
13
+
14
+ Returns: the text encoder: a BertForMaskedLM when pretraining, otherwise a BertModel.
15
+
16
+ """
17
+ bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
18
+ bert_config.encoder_width = model_config.vision_encoder.d_model
19
+ bert_config.gradient_checkpointing = checkpoint
20
+ bert_config.fusion_layer = model_config.text_encoder.fusion_layer
21
+
22
+ if not model_config.multimodal.enable:
23
+ bert_config.fusion_layer = bert_config.num_hidden_layers
24
+
25
+ if pretrain:
26
+ text_encoder, loading_info = BertForMaskedLM.from_pretrained(
27
+ model_config.text_encoder.pretrained,
28
+ config=bert_config,
29
+ output_loading_info=True,
30
+ )
31
+ else:
32
+ text_encoder, loading_info = BertModel.from_pretrained(
33
+ model_config.text_encoder.pretrained,
34
+ config=bert_config,
35
+ add_pooling_layer=False,
36
+ output_loading_info=True,
37
+ )
38
+
39
+ return text_encoder
40
+
41
+
42
+ def build_bert_decoder(model_config, checkpoint):
43
+ """build text decoder the same as the multimodal encoder.
44
+
45
+ Args:
46
+ model_config (dict): model config.
47
+ checkpoint (bool): whether to do gradient_checkpointing.
48
+
49
+ Returns: the text decoder, a BertLMHeadModel built from the layers above
50
+ `fusion_layer` of the multimodal encoder.
51
+
52
+ """
53
+ bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
54
+ bert_config.encoder_width = model_config.vision_encoder.d_model
55
+ bert_config.gradient_checkpointing = checkpoint
56
+
57
+ bert_config.fusion_layer = 0
58
+ bert_config.num_hidden_layers = (
59
+ bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
60
+ )
61
+
62
+ text_decoder, loading_info = BertLMHeadModel.from_pretrained(
63
+ model_config.text_encoder.pretrained,
64
+ config=bert_config,
65
+ output_loading_info=True,
66
+ )
67
+
68
+ return text_decoder
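# Worked example of the layer split in build_bert_decoder above, assuming a 12-layer BERT
# and fusion_layer = 9 (illustrative values, not read from any config here):
num_hidden_layers = 12
fusion_layer = 9
decoder_layers = num_hidden_layers - fusion_layer   # 3: the decoder keeps only the fusion layers
decoder_fusion_layer = 0                            # so cross-attention starts from its first layer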
models_viclip/backbones/bert/tokenization_bert.py ADDED
@@ -0,0 +1,546 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Bert."""
16
+
17
+
18
+ import collections
19
+ import os
20
+ import unicodedata
21
+ from typing import List, Optional, Tuple
22
+
23
+ from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
24
+ from transformers.utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {
32
+ "vocab_file": {
33
+ "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
34
+ "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
35
+ "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
36
+ "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
37
+ "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
38
+ "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
39
+ "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
40
+ "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
41
+ "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
42
+ "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
43
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
44
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
45
+ "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
46
+ "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
47
+ "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
48
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
49
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
50
+ "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
51
+ }
52
+ }
53
+
54
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55
+ "bert-base-uncased": 512,
56
+ "bert-large-uncased": 512,
57
+ "bert-base-cased": 512,
58
+ "bert-large-cased": 512,
59
+ "bert-base-multilingual-uncased": 512,
60
+ "bert-base-multilingual-cased": 512,
61
+ "bert-base-chinese": 512,
62
+ "bert-base-german-cased": 512,
63
+ "bert-large-uncased-whole-word-masking": 512,
64
+ "bert-large-cased-whole-word-masking": 512,
65
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67
+ "bert-base-cased-finetuned-mrpc": 512,
68
+ "bert-base-german-dbmdz-cased": 512,
69
+ "bert-base-german-dbmdz-uncased": 512,
70
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
71
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72
+ "wietsedv/bert-base-dutch-cased": 512,
73
+ }
74
+
75
+ PRETRAINED_INIT_CONFIGURATION = {
76
+ "bert-base-uncased": {"do_lower_case": True},
77
+ "bert-large-uncased": {"do_lower_case": True},
78
+ "bert-base-cased": {"do_lower_case": False},
79
+ "bert-large-cased": {"do_lower_case": False},
80
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
81
+ "bert-base-multilingual-cased": {"do_lower_case": False},
82
+ "bert-base-chinese": {"do_lower_case": False},
83
+ "bert-base-german-cased": {"do_lower_case": False},
84
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94
+ }
95
+
96
+
97
+ def load_vocab(vocab_file):
98
+ """Loads a vocabulary file into a dictionary."""
99
+ vocab = collections.OrderedDict()
100
+ with open(vocab_file, "r", encoding="utf-8") as reader:
101
+ tokens = reader.readlines()
102
+ for index, token in enumerate(tokens):
103
+ token = token.rstrip("\n")
104
+ vocab[token] = index
105
+ return vocab
106
+
107
+
108
+ def whitespace_tokenize(text):
109
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
110
+ text = text.strip()
111
+ if not text:
112
+ return []
113
+ tokens = text.split()
114
+ return tokens
115
+
116
+
117
+ class BertTokenizer(PreTrainedTokenizer):
118
+ r"""
119
+ Construct a BERT tokenizer. Based on WordPiece.
120
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
121
+ Users should refer to this superclass for more information regarding those methods.
122
+ Args:
123
+ vocab_file (:obj:`str`):
124
+ File containing the vocabulary.
125
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
126
+ Whether or not to lowercase the input when tokenizing.
127
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
128
+ Whether or not to do basic tokenization before WordPiece.
129
+ never_split (:obj:`Iterable`, `optional`):
130
+ Collection of tokens which will never be split during tokenization. Only has an effect when
131
+ :obj:`do_basic_tokenize=True`
132
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
133
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
134
+ token instead.
135
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
136
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
137
+ sequence classification or for a text and a question for question answering. It is also used as the last
138
+ token of a sequence built with special tokens.
139
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
140
+ The token used for padding, for example when batching sequences of different lengths.
141
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
142
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
143
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
144
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
145
+ The token used for masking values. This is the token used when training this model with masked language
146
+ modeling. This is the token which the model will try to predict.
147
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
148
+ Whether or not to tokenize Chinese characters.
149
+ This should likely be deactivated for Japanese (see this `issue
150
+ <https://github.com/huggingface/transformers/issues/328>`__).
151
+ strip_accents: (:obj:`bool`, `optional`):
152
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
153
+ value for :obj:`lowercase` (as in the original BERT).
154
+ """
155
+
156
+ vocab_files_names = VOCAB_FILES_NAMES
157
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160
+
161
+ def __init__(
162
+ self,
163
+ vocab_file,
164
+ do_lower_case=True,
165
+ do_basic_tokenize=True,
166
+ never_split=None,
167
+ unk_token="[UNK]",
168
+ sep_token="[SEP]",
169
+ pad_token="[PAD]",
170
+ cls_token="[CLS]",
171
+ mask_token="[MASK]",
172
+ tokenize_chinese_chars=True,
173
+ strip_accents=None,
174
+ **kwargs
175
+ ):
176
+ super().__init__(
177
+ do_lower_case=do_lower_case,
178
+ do_basic_tokenize=do_basic_tokenize,
179
+ never_split=never_split,
180
+ unk_token=unk_token,
181
+ sep_token=sep_token,
182
+ pad_token=pad_token,
183
+ cls_token=cls_token,
184
+ mask_token=mask_token,
185
+ tokenize_chinese_chars=tokenize_chinese_chars,
186
+ strip_accents=strip_accents,
187
+ **kwargs,
188
+ )
189
+
190
+ if not os.path.isfile(vocab_file):
191
+ raise ValueError(
192
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
193
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
194
+ vocab_file)
195
+ )
196
+ self.vocab = load_vocab(vocab_file)
197
+ self.ids_to_tokens = collections.OrderedDict(
198
+ [(ids, tok) for tok, ids in self.vocab.items()])
199
+ self.do_basic_tokenize = do_basic_tokenize
200
+ if do_basic_tokenize:
201
+ self.basic_tokenizer = BasicTokenizer(
202
+ do_lower_case=do_lower_case,
203
+ never_split=never_split,
204
+ tokenize_chinese_chars=tokenize_chinese_chars,
205
+ strip_accents=strip_accents,
206
+ )
207
+ self.wordpiece_tokenizer = WordpieceTokenizer(
208
+ vocab=self.vocab, unk_token=self.unk_token)
209
+
210
+ @property
211
+ def do_lower_case(self):
212
+ return self.basic_tokenizer.do_lower_case
213
+
214
+ @property
215
+ def vocab_size(self):
216
+ return len(self.vocab)
217
+
218
+ def get_vocab(self):
219
+ return dict(self.vocab, **self.added_tokens_encoder)
220
+
221
+ def _tokenize(self, text):
222
+ split_tokens = []
223
+ if self.do_basic_tokenize:
224
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
225
+
226
+ # If the token is part of the never_split set
227
+ if token in self.basic_tokenizer.never_split:
228
+ split_tokens.append(token)
229
+ else:
230
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
231
+ else:
232
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
233
+ return split_tokens
234
+
235
+ def _convert_token_to_id(self, token):
236
+ """Converts a token (str) to an id using the vocab."""
237
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
238
+
239
+ def _convert_id_to_token(self, index):
240
+ """Converts an index (integer) to a token (str) using the vocab."""
241
+ return self.ids_to_tokens.get(index, self.unk_token)
242
+
243
+ def convert_tokens_to_string(self, tokens):
244
+ """Converts a sequence of tokens (strings) into a single string."""
245
+ out_string = " ".join(tokens).replace(" ##", "").strip()
246
+ return out_string
247
+
248
+ def build_inputs_with_special_tokens(
249
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
250
+ ) -> List[int]:
251
+ """
252
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
253
+ adding special tokens. A BERT sequence has the following format:
254
+ - single sequence: ``[CLS] X ``
255
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
256
+ Args:
257
+ token_ids_0 (:obj:`List[int]`):
258
+ List of IDs to which the special tokens will be added.
259
+ token_ids_1 (:obj:`List[int]`, `optional`):
260
+ Optional second list of IDs for sequence pairs.
261
+ Returns:
262
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
263
+ """
264
+ if token_ids_1 is None:
265
+ return [self.cls_token_id] + token_ids_0
266
+ cls = [self.cls_token_id]
267
+ sep = [self.sep_token_id]
268
+ return cls + token_ids_0 + sep + token_ids_1 + sep
269
+
270
+ def get_special_tokens_mask(
271
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
272
+ ) -> List[int]:
273
+ """
274
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
275
+ special tokens using the tokenizer ``prepare_for_model`` method.
276
+ Args:
277
+ token_ids_0 (:obj:`List[int]`):
278
+ List of IDs.
279
+ token_ids_1 (:obj:`List[int]`, `optional`):
280
+ Optional second list of IDs for sequence pairs.
281
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
282
+ Whether or not the token list is already formatted with special tokens for the model.
283
+ Returns:
284
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
285
+ """
286
+
287
+ if already_has_special_tokens:
288
+ if token_ids_1 is not None:
289
+ raise ValueError(
290
+ "You should not supply a second sequence if the provided sequence of "
291
+ "ids is already formatted with special tokens for the model."
292
+ )
293
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
294
+
295
+ if token_ids_1 is not None:
296
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
297
+ return [1] + ([0] * len(token_ids_0)) + [1]
298
+
299
+ def create_token_type_ids_from_sequences(
300
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
301
+ ) -> List[int]:
302
+ """
303
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
304
+ pair mask has the following format:
305
+ ::
306
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
307
+ | first sequence | second sequence |
308
+ If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
309
+ Args:
310
+ token_ids_0 (:obj:`List[int]`):
311
+ List of IDs.
312
+ token_ids_1 (:obj:`List[int]`, `optional`):
313
+ Optional second list of IDs for sequence pairs.
314
+ Returns:
315
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
316
+ sequence(s).
317
+ """
318
+ sep = [self.sep_token_id]
319
+ cls = [self.cls_token_id]
320
+ if token_ids_1 is None:
321
+ return len(cls + token_ids_0 + sep) * [0]
322
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
323
+
324
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
325
+ index = 0
326
+ if os.path.isdir(save_directory):
327
+ vocab_file = os.path.join(
328
+ save_directory, (filename_prefix + "-" if filename_prefix else "") +
329
+ VOCAB_FILES_NAMES["vocab_file"]
330
+ )
331
+ else:
332
+ vocab_file = (filename_prefix +
333
+ "-" if filename_prefix else "") + save_directory
334
+ with open(vocab_file, "w", encoding="utf-8") as writer:
335
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
336
+ if index != token_index:
337
+ logger.warning(
338
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
339
+ " Please check that the vocabulary is not corrupted!".format(
340
+ vocab_file)
341
+ )
342
+ index = token_index
343
+ writer.write(token + "\n")
344
+ index += 1
345
+ return (vocab_file,)
346
+
347
+
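A minimal usage sketch for the tokenizer class above (it assumes the `bert-base-uncased` vocabulary can be downloaded; the token pieces follow the WordPiece docstring example further below).

```python
from models_viclip.backbones.bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer.tokenize("unaffable")      # ['un', '##aff', '##able']
ids = tokenizer.convert_tokens_to_ids(tokens)

# Note the deviation from stock BERT: a single sequence is built as "[CLS] X"
# with no trailing [SEP], while pairs keep the usual "[CLS] A [SEP] B [SEP]".
single = tokenizer.build_inputs_with_special_tokens(ids)
pair = tokenizer.build_inputs_with_special_tokens(ids, ids)
```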
348
+ class BasicTokenizer(object):
349
+ """
350
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
351
+ Args:
352
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
353
+ Whether or not to lowercase the input when tokenizing.
354
+ never_split (:obj:`Iterable`, `optional`):
355
+ Collection of tokens which will never be split during tokenization. Only has an effect when
356
+ :obj:`do_basic_tokenize=True`
357
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
358
+ Whether or not to tokenize Chinese characters.
359
+ This should likely be deactivated for Japanese (see this `issue
360
+ <https://github.com/huggingface/transformers/issues/328>`__).
361
+ strip_accents: (:obj:`bool`, `optional`):
362
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
363
+ value for :obj:`lowercase` (as in the original BERT).
364
+ """
365
+
366
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
367
+ if never_split is None:
368
+ never_split = []
369
+ self.do_lower_case = do_lower_case
370
+ self.never_split = set(never_split)
371
+ self.tokenize_chinese_chars = tokenize_chinese_chars
372
+ self.strip_accents = strip_accents
373
+
374
+ def tokenize(self, text, never_split=None):
375
+ """
376
+ Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
377
+ WordPieceTokenizer.
378
+ Args:
379
+ **never_split**: (`optional`) list of str
380
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
381
+ :func:`PreTrainedTokenizer.tokenize`). List of tokens not to split.
382
+ """
383
+ # union() returns a new set by concatenating the two sets.
384
+ never_split = self.never_split.union(
385
+ set(never_split)) if never_split else self.never_split
386
+ text = self._clean_text(text)
387
+
388
+ # This was added on November 1st, 2018 for the multilingual and Chinese
389
+ # models. This is also applied to the English models now, but it doesn't
390
+ # matter since the English models were not trained on any Chinese data
391
+ # and generally don't have any Chinese data in them (there are Chinese
392
+ # characters in the vocabulary because Wikipedia does have some Chinese
393
+ # words in the English Wikipedia.).
394
+ if self.tokenize_chinese_chars:
395
+ text = self._tokenize_chinese_chars(text)
396
+ orig_tokens = whitespace_tokenize(text)
397
+ split_tokens = []
398
+ for token in orig_tokens:
399
+ if token not in never_split:
400
+ if self.do_lower_case:
401
+ token = token.lower()
402
+ if self.strip_accents is not False:
403
+ token = self._run_strip_accents(token)
404
+ elif self.strip_accents:
405
+ token = self._run_strip_accents(token)
406
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
407
+
408
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
409
+ return output_tokens
410
+
411
+ def _run_strip_accents(self, text):
412
+ """Strips accents from a piece of text."""
413
+ text = unicodedata.normalize("NFD", text)
414
+ output = []
415
+ for char in text:
416
+ cat = unicodedata.category(char)
417
+ if cat == "Mn":
418
+ continue
419
+ output.append(char)
420
+ return "".join(output)
421
+
422
+ def _run_split_on_punc(self, text, never_split=None):
423
+ """Splits punctuation on a piece of text."""
424
+ if never_split is not None and text in never_split:
425
+ return [text]
426
+ chars = list(text)
427
+ i = 0
428
+ start_new_word = True
429
+ output = []
430
+ while i < len(chars):
431
+ char = chars[i]
432
+ if _is_punctuation(char):
433
+ output.append([char])
434
+ start_new_word = True
435
+ else:
436
+ if start_new_word:
437
+ output.append([])
438
+ start_new_word = False
439
+ output[-1].append(char)
440
+ i += 1
441
+
442
+ return ["".join(x) for x in output]
443
+
444
+ def _tokenize_chinese_chars(self, text):
445
+ """Adds whitespace around any CJK character."""
446
+ output = []
447
+ for char in text:
448
+ cp = ord(char)
449
+ if self._is_chinese_char(cp):
450
+ output.append(" ")
451
+ output.append(char)
452
+ output.append(" ")
453
+ else:
454
+ output.append(char)
455
+ return "".join(output)
456
+
457
+ def _is_chinese_char(self, cp):
458
+ """Checks whether CP is the codepoint of a CJK character."""
459
+ # This defines a "chinese character" as anything in the CJK Unicode block:
460
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
461
+ #
462
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
463
+ # despite its name. The modern Korean Hangul alphabet is a different block,
464
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
465
+ # space-separated words, so they are not treated specially and handled
466
+ # like all of the other languages.
467
+ if (
468
+ (cp >= 0x4E00 and cp <= 0x9FFF)
469
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
470
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
471
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
472
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
473
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
474
+ or (cp >= 0xF900 and cp <= 0xFAFF)
475
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
476
+ ): #
477
+ return True
478
+
479
+ return False
480
+
481
+ def _clean_text(self, text):
482
+ """Performs invalid character removal and whitespace cleanup on text."""
483
+ output = []
484
+ for char in text:
485
+ cp = ord(char)
486
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
487
+ continue
488
+ if _is_whitespace(char):
489
+ output.append(" ")
490
+ else:
491
+ output.append(char)
492
+ return "".join(output)
493
+
494
+
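A small illustrative sketch of `BasicTokenizer` behaviour: lower-casing, accent stripping and punctuation splitting, with no vocabulary involved.

```python
from models_viclip.backbones.bert.tokenization_bert import BasicTokenizer

basic = BasicTokenizer(do_lower_case=True)
print(basic.tokenize("Héllo, World!"))
# expected: ['hello', ',', 'world', '!']
```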
495
+ class WordpieceTokenizer(object):
496
+ """Runs WordPiece tokenization."""
497
+
498
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
499
+ self.vocab = vocab
500
+ self.unk_token = unk_token
501
+ self.max_input_chars_per_word = max_input_chars_per_word
502
+
503
+ def tokenize(self, text):
504
+ """
505
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
506
+ tokenization using the given vocabulary.
507
+ For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.
508
+ Args:
509
+ text: A single token or whitespace separated tokens. This should have
510
+ already been passed through `BasicTokenizer`.
511
+ Returns:
512
+ A list of wordpiece tokens.
513
+ """
514
+
515
+ output_tokens = []
516
+ for token in whitespace_tokenize(text):
517
+ chars = list(token)
518
+ if len(chars) > self.max_input_chars_per_word:
519
+ output_tokens.append(self.unk_token)
520
+ continue
521
+
522
+ is_bad = False
523
+ start = 0
524
+ sub_tokens = []
525
+ while start < len(chars):
526
+ end = len(chars)
527
+ cur_substr = None
528
+ while start < end:
529
+ substr = "".join(chars[start:end])
530
+ if start > 0:
531
+ substr = "##" + substr
532
+ if substr in self.vocab:
533
+ cur_substr = substr
534
+ break
535
+ end -= 1
536
+ if cur_substr is None:
537
+ is_bad = True
538
+ break
539
+ sub_tokens.append(cur_substr)
540
+ start = end
541
+
542
+ if is_bad:
543
+ output_tokens.append(self.unk_token)
544
+ else:
545
+ output_tokens.extend(sub_tokens)
546
+ return output_tokens
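To make the greedy longest-match-first loop above concrete, here is a tiny standalone sketch using a hand-made vocabulary (not the real BERT vocab).

```python
from models_viclip.backbones.bert.tokenization_bert import WordpieceTokenizer

vocab = {"un", "##aff", "##able", "[UNK]"}            # toy vocabulary
wp = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
print(wp.tokenize("unaffable"))     # ['un', '##aff', '##able']
print(wp.tokenize("unbelievable"))  # ['[UNK]'] -- no matching pieces in the toy vocab
```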
models_viclip/backbones/bert/xbert.py ADDED
@@ -0,0 +1,2157 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+ import math
19
+ import os
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ import transformers
28
+ from torch import Tensor, device, dtype, nn
29
+ from torch.nn import CrossEntropyLoss, MSELoss
30
+ from transformers.activations import ACT2FN
31
+ # from transformers.models.bert.configuration_bert import BertConfig
32
+ from transformers.configuration_utils import PretrainedConfig
33
+ from transformers.file_utils import (ModelOutput, add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ replace_return_docstrings)
36
+ from transformers.modeling_outputs import (
37
+ BaseModelOutputWithPastAndCrossAttentions,
38
+ BaseModelOutputWithPoolingAndCrossAttentions,
39
+ CausalLMOutputWithCrossAttentions, MaskedLMOutput,
40
+ MultipleChoiceModelOutput, NextSentencePredictorOutput,
41
+ QuestionAnsweringModelOutput, SequenceClassifierOutput,
42
+ TokenClassifierOutput)
43
+ from transformers.modeling_utils import (PreTrainedModel,
44
+ apply_chunking_to_forward,
45
+ find_pruneable_heads_and_indices,
46
+ prune_linear_layer)
47
+ from transformers.utils import logging
48
+
49
+ transformers.logging.set_verbosity_error()
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ _CONFIG_FOR_DOC = "BertConfig"
54
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
55
+
56
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
57
+ "bert-base-uncased",
58
+ "bert-large-uncased",
59
+ "bert-base-cased",
60
+ "bert-large-cased",
61
+ "bert-base-multilingual-uncased",
62
+ "bert-base-multilingual-cased",
63
+ "bert-base-chinese",
64
+ "bert-base-german-cased",
65
+ "bert-large-uncased-whole-word-masking",
66
+ "bert-large-cased-whole-word-masking",
67
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
68
+ "bert-large-cased-whole-word-masking-finetuned-squad",
69
+ "bert-base-cased-finetuned-mrpc",
70
+ "bert-base-german-dbmdz-cased",
71
+ "bert-base-german-dbmdz-uncased",
72
+ "cl-tohoku/bert-base-japanese",
73
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
74
+ "cl-tohoku/bert-base-japanese-char",
75
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
76
+ "TurkuNLP/bert-base-finnish-cased-v1",
77
+ "TurkuNLP/bert-base-finnish-uncased-v1",
78
+ "wietsedv/bert-base-dutch-cased",
79
+ # See all BERT models at https://huggingface.co/models?filter=bert
80
+ ]
81
+
82
+
83
+ class BertConfig(PretrainedConfig):
84
+ r"""
85
+ This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
86
+ instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
87
+ configuration with the defaults will yield a similar configuration to that of the BERT
88
+ [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.
89
+
90
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
91
+ documentation from [`PretrainedConfig`] for more information.
92
+
93
+
94
+ Args:
95
+ vocab_size (`int`, *optional*, defaults to 30522):
96
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
97
+ `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
98
+ hidden_size (`int`, *optional*, defaults to 768):
99
+ Dimensionality of the encoder layers and the pooler layer.
100
+ num_hidden_layers (`int`, *optional*, defaults to 12):
101
+ Number of hidden layers in the Transformer encoder.
102
+ num_attention_heads (`int`, *optional*, defaults to 12):
103
+ Number of attention heads for each attention layer in the Transformer encoder.
104
+ intermediate_size (`int`, *optional*, defaults to 3072):
105
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
106
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
107
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
108
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
109
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
110
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
111
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
112
+ The dropout ratio for the attention probabilities.
113
+ max_position_embeddings (`int`, *optional*, defaults to 512):
114
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
115
+ just in case (e.g., 512 or 1024 or 2048).
116
+ type_vocab_size (`int`, *optional*, defaults to 2):
117
+ The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
118
+ initializer_range (`float`, *optional*, defaults to 0.02):
119
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
120
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
121
+ The epsilon used by the layer normalization layers.
122
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
123
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
124
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
125
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
126
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
127
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
128
+ use_cache (`bool`, *optional*, defaults to `True`):
129
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
130
+ relevant if `config.is_decoder=True`.
131
+ classifier_dropout (`float`, *optional*):
132
+ The dropout ratio for the classification head.
133
+
134
+ Examples:
135
+
136
+ ```python
137
+ >>> from transformers import BertModel, BertConfig
138
+
139
+ >>> # Initializing a BERT bert-base-uncased style configuration
140
+ >>> configuration = BertConfig()
141
+
142
+ >>> # Initializing a model from the bert-base-uncased style configuration
143
+ >>> model = BertModel(configuration)
144
+
145
+ >>> # Accessing the model configuration
146
+ >>> configuration = model.config
147
+ ```"""
148
+ model_type = "bert"
149
+
150
+ def __init__(
151
+ self,
152
+ vocab_size=30522,
153
+ hidden_size=768,
154
+ num_hidden_layers=12,
155
+ num_attention_heads=12,
156
+ intermediate_size=3072,
157
+ hidden_act="gelu",
158
+ hidden_dropout_prob=0.1,
159
+ attention_probs_dropout_prob=0.1,
160
+ max_position_embeddings=512,
161
+ type_vocab_size=2,
162
+ initializer_range=0.02,
163
+ layer_norm_eps=1e-12,
164
+ pad_token_id=0,
165
+ position_embedding_type="absolute",
166
+ use_cache=True,
167
+ classifier_dropout=None,
168
+ cross_module="ca",
169
+ **kwargs,
170
+ ):
171
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
172
+
173
+ self.vocab_size = vocab_size
174
+ self.hidden_size = hidden_size
175
+ self.num_hidden_layers = num_hidden_layers
176
+ self.num_attention_heads = num_attention_heads
177
+ self.hidden_act = hidden_act
178
+ self.intermediate_size = intermediate_size
179
+ self.hidden_dropout_prob = hidden_dropout_prob
180
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
181
+ self.max_position_embeddings = max_position_embeddings
182
+ self.type_vocab_size = type_vocab_size
183
+ self.initializer_range = initializer_range
184
+ self.layer_norm_eps = layer_norm_eps
185
+ self.position_embedding_type = position_embedding_type
186
+ self.use_cache = use_cache
187
+ self.classifier_dropout = classifier_dropout
188
+ self.cross_module = cross_module
189
+
190
+
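A minimal sketch of the repo-specific config fields used later in this file; `fusion_layer` and `encoder_width` are not constructor arguments here, the builders set them after loading the config (the values below are illustrative).

```python
from models_viclip.backbones.bert.xbert import BertConfig

config = BertConfig()        # bert-base defaults: 12 layers, hidden_size 768
config.fusion_layer = 9      # layers >= 9 get a cross-attention block (see BertLayer below)
config.encoder_width = 768   # width of the visual features fed into cross-attention
print(config.num_hidden_layers, config.cross_module)  # 12 ca
```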
191
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
192
+ """Load tf checkpoints in a pytorch model."""
193
+ try:
194
+ import re
195
+
196
+ import numpy as np
197
+ import tensorflow as tf
198
+ except ImportError:
199
+ logger.error(
200
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
201
+ "https://www.tensorflow.org/install/ for installation instructions."
202
+ )
203
+ raise
204
+ tf_path = os.path.abspath(tf_checkpoint_path)
205
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
206
+ # Load weights from TF model
207
+ init_vars = tf.train.list_variables(tf_path)
208
+ names = []
209
+ arrays = []
210
+ for name, shape in init_vars:
211
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
212
+ array = tf.train.load_variable(tf_path, name)
213
+ names.append(name)
214
+ arrays.append(array)
215
+
216
+ for name, array in zip(names, arrays):
217
+ name = name.split("/")
218
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
219
+ # which are not required for using pretrained model
220
+ if any(
221
+ n
222
+ in [
223
+ "adam_v",
224
+ "adam_m",
225
+ "AdamWeightDecayOptimizer",
226
+ "AdamWeightDecayOptimizer_1",
227
+ "global_step",
228
+ ]
229
+ for n in name
230
+ ):
231
+ logger.info("Skipping {}".format("/".join(name)))
232
+ continue
233
+ pointer = model
234
+ for m_name in name:
235
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
236
+ scope_names = re.split(r"_(\d+)", m_name)
237
+ else:
238
+ scope_names = [m_name]
239
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
240
+ pointer = getattr(pointer, "weight")
241
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
242
+ pointer = getattr(pointer, "bias")
243
+ elif scope_names[0] == "output_weights":
244
+ pointer = getattr(pointer, "weight")
245
+ elif scope_names[0] == "squad":
246
+ pointer = getattr(pointer, "classifier")
247
+ else:
248
+ try:
249
+ pointer = getattr(pointer, scope_names[0])
250
+ except AttributeError:
251
+ logger.info("Skipping {}".format("/".join(name)))
252
+ continue
253
+ if len(scope_names) >= 2:
254
+ num = int(scope_names[1])
255
+ pointer = pointer[num]
256
+ if m_name[-11:] == "_embeddings":
257
+ pointer = getattr(pointer, "weight")
258
+ elif m_name == "kernel":
259
+ array = np.transpose(array)
260
+ try:
261
+ assert (
262
+ pointer.shape == array.shape
263
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
264
+ except AssertionError as e:
265
+ e.args += (pointer.shape, array.shape)
266
+ raise
267
+ logger.info("Initialize PyTorch weight {}".format(name))
268
+ pointer.data = torch.from_numpy(array)
269
+ return model
270
+
271
+
272
+ class BertEmbeddings(nn.Module):
273
+ """Construct the embeddings from word, position and token_type embeddings."""
274
+
275
+ def __init__(self, config):
276
+ super().__init__()
277
+ self.word_embeddings = nn.Embedding(
278
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
279
+ )
280
+ self.position_embeddings = nn.Embedding(
281
+ config.max_position_embeddings, config.hidden_size
282
+ )
283
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
284
+
285
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
286
+ # any TensorFlow checkpoint file
287
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
288
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
289
+
290
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
291
+ self.register_buffer(
292
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
293
+ )
294
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
295
+
296
+ self.config = config
297
+
298
+ def forward(
299
+ self,
300
+ input_ids=None,
301
+ token_type_ids=None,
302
+ position_ids=None,
303
+ inputs_embeds=None,
304
+ past_key_values_length=0,
305
+ ):
306
+ if input_ids is not None:
307
+ input_shape = input_ids.size()
308
+ else:
309
+ input_shape = inputs_embeds.size()[:-1]
310
+
311
+ seq_length = input_shape[1]
312
+
313
+ if position_ids is None:
314
+ position_ids = self.position_ids[
315
+ :, past_key_values_length : seq_length + past_key_values_length
316
+ ]
317
+
318
+ if token_type_ids is None:
319
+ token_type_ids = torch.zeros(
320
+ input_shape, dtype=torch.long, device=self.position_ids.device
321
+ )
322
+
323
+ if inputs_embeds is None:
324
+ inputs_embeds = self.word_embeddings(input_ids)
325
+
326
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
327
+
328
+ embeddings = inputs_embeds + token_type_embeddings
329
+ if self.position_embedding_type == "absolute":
330
+ position_embeddings = self.position_embeddings(position_ids)
331
+ embeddings += position_embeddings
332
+ embeddings = self.LayerNorm(embeddings)
333
+ embeddings = self.dropout(embeddings)
334
+ return embeddings
335
+
336
+
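A quick shape sketch for the embedding module above (a batch of 2 sequences of 16 tokens, illustrative).

```python
import torch
from models_viclip.backbones.bert.xbert import BertConfig, BertEmbeddings

config = BertConfig()
emb = BertEmbeddings(config)
input_ids = torch.randint(0, config.vocab_size, (2, 16))
print(emb(input_ids=input_ids).shape)  # torch.Size([2, 16, 768])
```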
337
+ class BertSelfAttention(nn.Module):
338
+ def __init__(self, config, is_cross_attention):
339
+ super().__init__()
340
+ self.config = config
341
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
342
+ config, "embedding_size"
343
+ ):
344
+ raise ValueError(
345
+ "The hidden size (%d) is not a multiple of the number of attention "
346
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
347
+ )
348
+
349
+ self.num_attention_heads = config.num_attention_heads
350
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
351
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
352
+
353
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
354
+ if is_cross_attention:
355
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
356
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
357
+ else:
358
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
359
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
360
+
361
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
362
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
363
+ if (
364
+ self.position_embedding_type == "relative_key"
365
+ or self.position_embedding_type == "relative_key_query"
366
+ ):
367
+ self.max_position_embeddings = config.max_position_embeddings
368
+ self.distance_embedding = nn.Embedding(
369
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
370
+ )
371
+ self.save_attention = False
372
+
373
+ def save_attn_gradients(self, attn_gradients):
374
+ self.attn_gradients = attn_gradients
375
+
376
+ def get_attn_gradients(self):
377
+ return self.attn_gradients
378
+
379
+ def save_attention_map(self, attention_map):
380
+ self.attention_map = attention_map
381
+
382
+ def get_attention_map(self):
383
+ return self.attention_map
384
+
385
+ def transpose_for_scores(self, x):
386
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
387
+ x = x.view(*new_x_shape)
388
+ return x.permute(0, 2, 1, 3)
389
+
390
+ def forward(
391
+ self,
392
+ hidden_states,
393
+ attention_mask=None,
394
+ head_mask=None,
395
+ encoder_hidden_states=None,
396
+ encoder_attention_mask=None,
397
+ past_key_value=None,
398
+ output_attentions=False,
399
+ ):
400
+ mixed_query_layer = self.query(hidden_states)
401
+
402
+ # If this is instantiated as a cross-attention module, the keys
403
+ # and values come from an encoder; the attention mask needs to be
404
+ # such that the encoder's padding tokens are not attended to.
405
+ is_cross_attention = encoder_hidden_states is not None
406
+
407
+ if is_cross_attention:
408
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
409
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
410
+ attention_mask = encoder_attention_mask
411
+ elif past_key_value is not None:
412
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
413
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
414
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
415
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
416
+ else:
417
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
418
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
419
+
420
+ query_layer = self.transpose_for_scores(mixed_query_layer)
421
+
422
+ past_key_value = (key_layer, value_layer)
423
+
424
+ # Take the dot product between "query" and "key" to get the raw attention scores.
425
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
426
+
427
+ if (
428
+ self.position_embedding_type == "relative_key"
429
+ or self.position_embedding_type == "relative_key_query"
430
+ ):
431
+ seq_length = hidden_states.size()[1]
432
+ position_ids_l = torch.arange(
433
+ seq_length, dtype=torch.long, device=hidden_states.device
434
+ ).view(-1, 1)
435
+ position_ids_r = torch.arange(
436
+ seq_length, dtype=torch.long, device=hidden_states.device
437
+ ).view(1, -1)
438
+ distance = position_ids_l - position_ids_r
439
+ positional_embedding = self.distance_embedding(
440
+ distance + self.max_position_embeddings - 1
441
+ )
442
+ positional_embedding = positional_embedding.to(
443
+ dtype=query_layer.dtype
444
+ ) # fp16 compatibility
445
+
446
+ if self.position_embedding_type == "relative_key":
447
+ relative_position_scores = torch.einsum(
448
+ "bhld,lrd->bhlr", query_layer, positional_embedding
449
+ )
450
+ attention_scores = attention_scores + relative_position_scores
451
+ elif self.position_embedding_type == "relative_key_query":
452
+ relative_position_scores_query = torch.einsum(
453
+ "bhld,lrd->bhlr", query_layer, positional_embedding
454
+ )
455
+ relative_position_scores_key = torch.einsum(
456
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
457
+ )
458
+ attention_scores = (
459
+ attention_scores
460
+ + relative_position_scores_query
461
+ + relative_position_scores_key
462
+ )
463
+
464
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
465
+ if attention_mask is not None:
466
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
467
+ attention_scores = attention_scores + attention_mask
468
+
469
+ # Normalize the attention scores to probabilities.
470
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
471
+
472
+ if is_cross_attention and self.save_attention:
473
+ self.save_attention_map(attention_probs)
474
+ attention_probs.register_hook(self.save_attn_gradients)
475
+
476
+ # This is actually dropping out entire tokens to attend to, which might
477
+ # seem a bit unusual, but is taken from the original Transformer paper.
478
+ attention_probs_dropped = self.dropout(attention_probs)
479
+
480
+ # Mask heads if we want to
481
+ if head_mask is not None:
482
+ attention_probs_dropped = attention_probs_dropped * head_mask
483
+
484
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
485
+
486
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
487
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
488
+ context_layer = context_layer.view(*new_context_layer_shape)
489
+
490
+ # added `attention_scores` to return tuple
491
+ outputs = (
492
+ (context_layer, attention_probs, attention_scores)
493
+ if output_attentions
494
+ else (context_layer,)
495
+ )
496
+
497
+ outputs = outputs + (past_key_value,)
498
+ return outputs
499
+
500
+
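A sketch of the cross-attention path described in the comments above: queries come from the text stream, while keys/values are projected from encoder features of width `config.encoder_width` (1024 is an assumed, illustrative visual feature width).

```python
import torch
from models_viclip.backbones.bert.xbert import BertConfig, BertSelfAttention

config = BertConfig()
config.encoder_width = 1024                        # assumed visual feature width
xattn = BertSelfAttention(config, is_cross_attention=True)

text = torch.randn(2, 16, config.hidden_size)      # queries
visual = torch.randn(2, 50, config.encoder_width)  # keys / values
context = xattn(text, encoder_hidden_states=visual)[0]
print(context.shape)  # torch.Size([2, 16, 768])
```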
501
+ class BertSelfOutput(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
506
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
507
+
508
+ def forward(self, hidden_states, input_tensor):
509
+ hidden_states = self.dense(hidden_states)
510
+ hidden_states = self.dropout(hidden_states)
511
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
512
+ return hidden_states
513
+
514
+
515
+ class BertAttention(nn.Module):
516
+ def __init__(self, config, is_cross_attention=False):
517
+ super().__init__()
518
+
519
+ self.self = BertSelfAttention(config, is_cross_attention)
520
+
521
+ self.output = BertSelfOutput(config)
522
+ self.pruned_heads = set()
523
+
524
+ def prune_heads(self, heads):
525
+ if len(heads) == 0:
526
+ return
527
+ heads, index = find_pruneable_heads_and_indices(
528
+ heads,
529
+ self.self.num_attention_heads,
530
+ self.self.attention_head_size,
531
+ self.pruned_heads,
532
+ )
533
+
534
+ # Prune linear layers
535
+ self.self.query = prune_linear_layer(self.self.query, index)
536
+ self.self.key = prune_linear_layer(self.self.key, index)
537
+ self.self.value = prune_linear_layer(self.self.value, index)
538
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
539
+
540
+ # Update hyper params and store pruned heads
541
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
542
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
543
+ self.pruned_heads = self.pruned_heads.union(heads)
544
+
545
+ def forward(
546
+ self,
547
+ hidden_states,
548
+ attention_mask=None,
549
+ head_mask=None,
550
+ encoder_hidden_states=None,
551
+ encoder_attention_mask=None,
552
+ past_key_value=None,
553
+ output_attentions=False,
554
+ ):
555
+ self_outputs = self.self(
556
+ hidden_states,
557
+ attention_mask,
558
+ head_mask,
559
+ encoder_hidden_states,
560
+ encoder_attention_mask,
561
+ past_key_value,
562
+ output_attentions,
563
+ )
564
+ attention_output = self.output(self_outputs[0], hidden_states)
565
+ # add attentions if we output them
566
+ outputs = (attention_output,) + self_outputs[1:]
567
+ return outputs # (context_layer, attention_probs, attention_scores, past_key_value,)
568
+
569
+
570
+ class BertIntermediate(nn.Module):
571
+ def __init__(self, config):
572
+ super().__init__()
573
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
574
+ if isinstance(config.hidden_act, str):
575
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
576
+ else:
577
+ self.intermediate_act_fn = config.hidden_act
578
+
579
+ def forward(self, hidden_states):
580
+ hidden_states = self.dense(hidden_states)
581
+ hidden_states = self.intermediate_act_fn(hidden_states)
582
+ return hidden_states
583
+
584
+
585
+ class BertOutput(nn.Module):
586
+ def __init__(self, config):
587
+ super().__init__()
588
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
589
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
590
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
591
+
592
+ def forward(self, hidden_states, input_tensor):
593
+ hidden_states = self.dense(hidden_states)
594
+ hidden_states = self.dropout(hidden_states)
595
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
596
+ return hidden_states
597
+
598
+
599
+ class BertLayer(nn.Module):
600
+ def __init__(self, config, layer_num):
601
+ super().__init__()
602
+ self.config = config
+ self.layer_num = layer_num  # forward() references self.layer_num when encoder_hidden_states is a list
603
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
604
+ self.seq_len_dim = 1
605
+ self.attention = BertAttention(config)
606
+
607
+ self.has_cross_attention = layer_num >= config.fusion_layer
608
+ if self.has_cross_attention:
609
+ self.crossattention = BertAttention(config, is_cross_attention=True)
610
+ self.intermediate = BertIntermediate(config)
611
+ self.output = BertOutput(config)
612
+
613
+ def forward(
614
+ self,
615
+ hidden_states,
616
+ attention_mask=None,
617
+ head_mask=None,
618
+ encoder_hidden_states=None,
619
+ encoder_attention_mask=None,
620
+ past_key_value=None,
621
+ output_attentions=False,
622
+ ):
623
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
624
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
625
+ self_attention_outputs = self.attention(
626
+ hidden_states,
627
+ attention_mask,
628
+ head_mask,
629
+ output_attentions=output_attentions,
630
+ past_key_value=self_attn_past_key_value,
631
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
632
+ attention_output = self_attention_outputs[0]
633
+
634
+ outputs = self_attention_outputs[1:-1]
635
+ present_key_value = self_attention_outputs[-1]
636
+
637
+ if self.has_cross_attention:
638
+ assert (
639
+ encoder_hidden_states is not None
640
+ ), "encoder_hidden_states must be given for cross-attention layers"
641
+
642
+ if type(encoder_hidden_states) == list:
643
+ cross_attention_outputs = self.crossattention(
644
+ attention_output,
645
+ attention_mask,
646
+ head_mask,
647
+ encoder_hidden_states[
648
+ (self.layer_num - self.config.fusion_layer)
649
+ % len(encoder_hidden_states)
650
+ ],
651
+ encoder_attention_mask[
652
+ (self.layer_num - self.config.fusion_layer)
653
+ % len(encoder_hidden_states)
654
+ ],
655
+ output_attentions=output_attentions,
656
+ )
657
+ attention_output = cross_attention_outputs[0]
658
+ outputs = outputs + cross_attention_outputs[1:-1]
659
+
660
+ else:
661
+ cross_attention_outputs = self.crossattention(
662
+ attention_output,
663
+ attention_mask,
664
+ head_mask,
665
+ encoder_hidden_states,
666
+ encoder_attention_mask,
667
+ output_attentions=output_attentions,
668
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
669
+ attention_output = cross_attention_outputs[0]
670
+ # add cross attentions if we output attention weights
671
+ outputs = outputs + cross_attention_outputs[1:-1]
672
+ layer_output = apply_chunking_to_forward(
673
+ self.feed_forward_chunk,
674
+ self.chunk_size_feed_forward,
675
+ self.seq_len_dim,
676
+ attention_output,
677
+ )
678
+ outputs = (layer_output,) + outputs
679
+
680
+ outputs = outputs + (present_key_value,)
681
+
682
+ return outputs
683
+
684
+ def feed_forward_chunk(self, attention_output):
685
+ intermediate_output = self.intermediate(attention_output)
686
+ layer_output = self.output(intermediate_output, attention_output)
687
+ return layer_output
688
+
689
+
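With `fusion_layer = 9`, only the top three of twelve layers are built with a cross-attention block; a minimal sketch (values illustrative):

```python
from models_viclip.backbones.bert.xbert import BertConfig, BertLayer

config = BertConfig()
config.fusion_layer = 9
config.encoder_width = 768
print([BertLayer(config, i).has_cross_attention for i in range(config.num_hidden_layers)])
# [False] * 9 + [True] * 3
```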
690
+ class BertEncoder(nn.Module):
691
+ def __init__(self, config):
692
+ super().__init__()
693
+ self.config = config
694
+ self.layer = nn.ModuleList(
695
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
696
+ )
697
+ logger.info(f"build bert with cross_module: {config.cross_module}")
698
+
699
+ def forward(
700
+ self,
701
+ hidden_states,
702
+ attention_mask=None,
703
+ head_mask=None,
704
+ encoder_hidden_states=None,
705
+ encoder_attention_mask=None,
706
+ past_key_values=None,
707
+ use_cache=None,
708
+ output_attentions=False,
709
+ output_hidden_states=False,
710
+ return_dict=True,
711
+ mode="multi_modal",
712
+ normalize_attention=True,
713
+ ):
714
+ all_hidden_states = () if output_hidden_states else None
715
+ all_self_attentions = () if output_attentions else None
716
+ # all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
717
+ all_cross_attentions = () if output_attentions else None
718
+
719
+ next_decoder_cache = () if use_cache else None
720
+
721
+ if (
722
+ mode == "text" or mode == "temporal"
723
+ ): # "temporal" mode was added for the temporal attention module.
724
+ start_layer = 0
725
+ output_layer = self.config.fusion_layer
726
+
727
+ elif mode == "fusion":
728
+ start_layer = self.config.fusion_layer
729
+ output_layer = self.config.num_hidden_layers
730
+
731
+ elif mode == "multi_modal":
732
+ start_layer = 0
733
+ output_layer = self.config.num_hidden_layers
734
+
735
+ for i in range(start_layer, output_layer):
736
+ layer_module = self.layer[i]
737
+ if output_hidden_states:
738
+ all_hidden_states = all_hidden_states + (hidden_states,)
739
+
740
+ layer_head_mask = head_mask[i] if head_mask is not None else None
741
+ past_key_value = past_key_values[i] if past_key_values is not None else None
742
+
743
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
744
+
745
+ if use_cache:
746
+ logger.warn(
747
+ "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
748
+ "`use_cache=False`..."
749
+ )
750
+ use_cache = False
751
+
752
+ def create_custom_forward(module):
753
+ def custom_forward(*inputs):
754
+ return module(*inputs, past_key_value, output_attentions)
755
+
756
+ return custom_forward
757
+
758
+ layer_outputs = torch.utils.checkpoint.checkpoint(
759
+ create_custom_forward(layer_module),
760
+ hidden_states,
761
+ attention_mask,
762
+ layer_head_mask,
763
+ encoder_hidden_states,
764
+ encoder_attention_mask,
765
+ use_reentrant=False,
766
+ )
767
+ else:
768
+ layer_outputs = layer_module(
769
+ hidden_states,
770
+ attention_mask,
771
+ layer_head_mask,
772
+ encoder_hidden_states,
773
+ encoder_attention_mask,
774
+ past_key_value,
775
+ output_attentions,
776
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
777
+ hidden_states = layer_outputs[0]
778
+ if use_cache:
779
+ next_decoder_cache += (layer_outputs[-1],)
780
+ if output_attentions:
781
+ # whether to output normalized attention,
782
+ # note for unnormalized attention, there is a mask added
783
+ offset = int(normalize_attention)
784
+ # all_self_attentions = all_self_attentions + (layer_outputs[1], )
785
+ all_self_attentions = all_self_attentions + (layer_outputs[2 - offset],)
786
+ if hasattr(layer_module, "crossattention"):
787
+ # all_cross_attentions = all_cross_attentions + (layer_outputs[3], )
788
+ all_cross_attentions = all_cross_attentions + (layer_outputs[4 - offset],)
789
+
790
+ if output_hidden_states:
791
+ all_hidden_states = all_hidden_states + (hidden_states,)
792
+
793
+ if not return_dict:
794
+ return tuple(
795
+ v
796
+ for v in [
797
+ hidden_states,
798
+ next_decoder_cache,
799
+ all_hidden_states,
800
+ all_self_attentions,
801
+ all_cross_attentions,
802
+ ]
803
+ if v is not None
804
+ )
805
+ return BaseModelOutputWithPastAndCrossAttentions(
806
+ last_hidden_state=hidden_states,
807
+ past_key_values=next_decoder_cache,
808
+ hidden_states=all_hidden_states,
809
+ attentions=all_self_attentions,
810
+ cross_attentions=all_cross_attentions,
811
+ )
812
+
813
+
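The `mode` argument above only decides which contiguous slice of the layer stack is run: the uni-modal modes stop at `config.fusion_layer`, `fusion` starts there, and `multi_modal` runs everything. A minimal standalone sketch of that selection logic (the `fusion_layer=9` and `num_hidden_layers=12` values are illustrative; the real ones come from the BERT config used at runtime):

# Sketch only: mirrors the start_layer / output_layer selection in BertEncoder.forward.
def layer_range(mode, fusion_layer=9, num_hidden_layers=12):
    """Return the (start, end) indices of the BertLayer blocks run for a given mode."""
    if mode in ("text", "temporal"):      # uni-modal: layers [0, fusion_layer)
        return 0, fusion_layer
    if mode == "fusion":                  # cross-modal layers only: [fusion_layer, L)
        return fusion_layer, num_hidden_layers
    if mode == "multi_modal":             # full stack: [0, L)
        return 0, num_hidden_layers
    raise ValueError(f"unknown mode: {mode}")

assert layer_range("text") == (0, 9)
assert layer_range("fusion") == (9, 12)
assert layer_range("multi_modal") == (0, 12)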
814
+ class BertPooler(nn.Module):
815
+ def __init__(self, config):
816
+ super().__init__()
817
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
818
+ self.activation = nn.Tanh()
819
+
820
+ def forward(self, hidden_states):
821
+ # We "pool" the model by simply taking the hidden state corresponding
822
+ # to the first token.
823
+ first_token_tensor = hidden_states[:, 0]
824
+ pooled_output = self.dense(first_token_tensor)
825
+ pooled_output = self.activation(pooled_output)
826
+ return pooled_output
827
+
828
+
829
+ class BertPredictionHeadTransform(nn.Module):
830
+ def __init__(self, config):
831
+ super().__init__()
832
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
833
+ if isinstance(config.hidden_act, str):
834
+ self.transform_act_fn = ACT2FN[config.hidden_act]
835
+ else:
836
+ self.transform_act_fn = config.hidden_act
837
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
838
+
839
+ def forward(self, hidden_states):
840
+ hidden_states = self.dense(hidden_states)
841
+ hidden_states = self.transform_act_fn(hidden_states)
842
+ hidden_states = self.LayerNorm(hidden_states)
843
+ return hidden_states
844
+
845
+
846
+ class BertLMPredictionHead(nn.Module):
847
+ def __init__(self, config):
848
+ super().__init__()
849
+ self.transform = BertPredictionHeadTransform(config)
850
+
851
+ # The output weights are the same as the input embeddings, but there is
852
+ # an output-only bias for each token.
853
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
854
+
855
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
856
+
857
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
858
+ self.decoder.bias = self.bias
859
+
860
+ def forward(self, hidden_states):
861
+ hidden_states = self.transform(hidden_states)
862
+ hidden_states = self.decoder(hidden_states)
863
+ return hidden_states
864
+
865
+
866
+ class BertOnlyMLMHead(nn.Module):
867
+ def __init__(self, config):
868
+ super().__init__()
869
+ self.predictions = BertLMPredictionHead(config)
870
+
871
+ def forward(self, sequence_output):
872
+ prediction_scores = self.predictions(sequence_output)
873
+ return prediction_scores
874
+
875
+
876
+ class BertOnlyNSPHead(nn.Module):
877
+ def __init__(self, config):
878
+ super().__init__()
879
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
880
+
881
+ def forward(self, pooled_output):
882
+ seq_relationship_score = self.seq_relationship(pooled_output)
883
+ return seq_relationship_score
884
+
885
+
886
+ class BertPreTrainingHeads(nn.Module):
887
+ def __init__(self, config):
888
+ super().__init__()
889
+ self.predictions = BertLMPredictionHead(config)
890
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
891
+
892
+ def forward(self, sequence_output, pooled_output):
893
+ prediction_scores = self.predictions(sequence_output)
894
+ seq_relationship_score = self.seq_relationship(pooled_output)
895
+ return prediction_scores, seq_relationship_score
896
+
897
+
898
+ class BertPreTrainedModel(PreTrainedModel):
899
+ """
900
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
901
+ models.
902
+ """
903
+
904
+ config_class = BertConfig
905
+ load_tf_weights = load_tf_weights_in_bert
906
+ base_model_prefix = "bert"
907
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
908
+
909
+ def _init_weights(self, module):
910
+ """Initialize the weights"""
911
+ if isinstance(module, (nn.Linear, nn.Embedding)):
912
+ # Slightly different from the TF version which uses truncated_normal for initialization
913
+ # cf https://github.com/pytorch/pytorch/pull/5617
914
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
915
+ elif isinstance(module, nn.LayerNorm):
916
+ module.bias.data.zero_()
917
+ module.weight.data.fill_(1.0)
918
+ if isinstance(module, nn.Linear) and module.bias is not None:
919
+ module.bias.data.zero_()
920
+
921
+
922
+ @dataclass
923
+ class BertForPreTrainingOutput(ModelOutput):
924
+ """
925
+ Output type of :class:`~transformers.BertForPreTraining`.
926
+ Args:
927
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
928
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
929
+ (classification) loss.
930
+ prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
931
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
932
+ seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
933
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
934
+ before SoftMax).
935
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
936
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
937
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
938
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
939
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
940
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
941
+ sequence_length, sequence_length)`.
942
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
943
+ heads.
944
+ """
945
+
946
+ loss: Optional[torch.FloatTensor] = None
947
+ prediction_logits: torch.FloatTensor = None
948
+ seq_relationship_logits: torch.FloatTensor = None
949
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
950
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
951
+
952
+
953
+ BERT_START_DOCSTRING = r"""
954
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
955
+ methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
956
+ pruning heads etc.)
957
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
958
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
959
+ general usage and behavior.
960
+ Parameters:
961
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
962
+ Initializing with a config file does not load the weights associated with the model, only the
963
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
964
+ weights.
965
+ """
966
+
967
+ BERT_INPUTS_DOCSTRING = r"""
968
+ Args:
969
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
970
+ Indices of input sequence tokens in the vocabulary.
971
+ Indices can be obtained using :class:`~transformers.BertTokenizer`. See
972
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
973
+ details.
974
+ `What are input IDs? <../glossary.html#input-ids>`__
975
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
976
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
977
+ - 1 for tokens that are **not masked**,
978
+ - 0 for tokens that are **masked**.
979
+ `What are attention masks? <../glossary.html#attention-mask>`__
980
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
981
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
982
+ 1]``:
983
+ - 0 corresponds to a `sentence A` token,
984
+ - 1 corresponds to a `sentence B` token.
985
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
986
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
987
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
988
+ config.max_position_embeddings - 1]``.
989
+ `What are position IDs? <../glossary.html#position-ids>`_
990
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
991
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
992
+ - 1 indicates the head is **not masked**,
993
+ - 0 indicates the head is **masked**.
994
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
995
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
996
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
997
+ vectors than the model's internal embedding lookup matrix.
998
+ output_attentions (:obj:`bool`, `optional`):
999
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
1000
+ tensors for more detail.
1001
+ output_hidden_states (:obj:`bool`, `optional`):
1002
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
1003
+ more detail.
1004
+ return_dict (:obj:`bool`, `optional`):
1005
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
1006
+ """
1007
+
1008
+
1009
+ @add_start_docstrings(
1010
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
1011
+ BERT_START_DOCSTRING,
1012
+ )
1013
+ class BertModel(BertPreTrainedModel):
1014
+ """
1015
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
1016
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
1017
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
1018
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
1019
+ To be used as a decoder, the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
1020
+ input to the forward pass.
1021
+ """
1022
+
1023
+ def __init__(self, config, add_pooling_layer=True):
1024
+ super().__init__(config)
1025
+ self.config = config
1026
+
1027
+ self.embeddings = BertEmbeddings(config)
1028
+
1029
+ self.encoder = BertEncoder(config)
1030
+
1031
+ self.pooler = BertPooler(config) if add_pooling_layer else None
1032
+
1033
+ self.init_weights()
1034
+
1035
+ def get_input_embeddings(self):
1036
+ return self.embeddings.word_embeddings
1037
+
1038
+ def set_input_embeddings(self, value):
1039
+ self.embeddings.word_embeddings = value
1040
+
1041
+ def _prune_heads(self, heads_to_prune):
1042
+ """
1043
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1044
+ class PreTrainedModel
1045
+ """
1046
+ for layer, heads in heads_to_prune.items():
1047
+ self.encoder.layer[layer].attention.prune_heads(heads)
1048
+
1049
+ def get_extended_attention_mask(
1050
+ self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool
1051
+ ) -> Tensor:
1052
+ """
1053
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
1054
+
1055
+ Arguments:
1056
+ attention_mask (:obj:`torch.Tensor`):
1057
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
1058
+ input_shape (:obj:`Tuple[int]`):
1059
+ The shape of the input to the model.
1060
+ device: (:obj:`torch.device`):
1061
+ The device of the input to the model.
1062
+
1063
+ Returns:
1064
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
1065
+ """
1066
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1067
+ # ourselves in which case we just need to make it broadcastable to all heads.
1068
+ if attention_mask.dim() == 3:
1069
+ extended_attention_mask = attention_mask[:, None, :, :]
1070
+ elif attention_mask.dim() == 2:
1071
+ # Provided a padding mask of dimensions [batch_size, seq_length]
1072
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
1073
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
1074
+ if is_decoder:
1075
+ batch_size, seq_length = input_shape
1076
+ seq_ids = torch.arange(seq_length, device=device)
1077
+ causal_mask = (
1078
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
1079
+ <= seq_ids[None, :, None]
1080
+ )
1081
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
1082
+ # causal and attention masks must have same type with pytorch version < 1.3
1083
+ causal_mask = causal_mask.to(attention_mask.dtype)
1084
+
1085
+ if causal_mask.shape[1] < attention_mask.shape[1]:
1086
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
1087
+ causal_mask = torch.cat(
1088
+ [
1089
+ torch.ones(
1090
+ (batch_size, seq_length, prefix_seq_len),
1091
+ device=device,
1092
+ dtype=causal_mask.dtype,
1093
+ ),
1094
+ causal_mask,
1095
+ ],
1096
+ axis=-1,
1097
+ )
1098
+
1099
+ extended_attention_mask = (
1100
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
1101
+ )
1102
+ else:
1103
+ extended_attention_mask = attention_mask[:, None, None, :]
1104
+ else:
1105
+ raise ValueError(
1106
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
1107
+ input_shape, attention_mask.shape
1108
+ )
1109
+ )
1110
+
1111
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1112
+ # masked positions, this operation will create a tensor which is 0.0 for
1113
+ # positions we want to attend and -10000.0 for masked positions.
1114
+ # Since we are adding it to the raw scores before the softmax, this is
1115
+ # effectively the same as removing these entirely.
1116
+ extended_attention_mask = extended_attention_mask.to(
1117
+ dtype=self.dtype
1118
+ ) # fp16 compatibility
1119
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
1120
+ return extended_attention_mask
1121
+
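The `(1.0 - mask) * -10000.0` step above turns a {1, 0} keep/ignore mask into an additive bias on the raw attention scores, so masked positions receive negligible weight after the softmax. A small self-contained sketch (shapes and values are illustrative):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.float)  # [batch, seq]
extended = attention_mask[:, None, None, :]                       # broadcastable over heads
extended = (1.0 - extended) * -10000.0                            # kept -> 0.0, masked -> -10000.0

scores = torch.zeros(1, 1, 4, 4)                  # pretend raw attention scores
probs = torch.softmax(scores + extended, dim=-1)
print(probs[0, 0, 0])                             # last (masked) position gets ~0 weight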
1122
+ def forward(
1123
+ self,
1124
+ input_ids=None,
1125
+ attention_mask=None,
1126
+ token_type_ids=None,
1127
+ position_ids=None,
1128
+ head_mask=None,
1129
+ inputs_embeds=None,
1130
+ encoder_embeds=None,
1131
+ encoder_hidden_states=None,
1132
+ encoder_attention_mask=None,
1133
+ past_key_values=None,
1134
+ use_cache=None,
1135
+ output_attentions=None,
1136
+ output_hidden_states=None,
1137
+ return_dict=None,
1138
+ is_decoder=False,
1139
+ mode="multi_modal",
1140
+ normalize_attention=True,
1141
+ ):
1142
+ r"""
1143
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1144
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1145
+ the model is configured as a decoder.
1146
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1147
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1148
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1149
+ - 1 for tokens that are **not masked**,
1150
+ - 0 for tokens that are **masked**.
1151
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1152
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1153
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1154
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1155
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1156
+ use_cache (:obj:`bool`, `optional`):
1157
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1158
+ decoding (see :obj:`past_key_values`).
1159
+ """
1160
+ output_attentions = (
1161
+ output_attentions
1162
+ if output_attentions is not None
1163
+ else self.config.output_attentions
1164
+ )
1165
+ output_hidden_states = (
1166
+ output_hidden_states
1167
+ if output_hidden_states is not None
1168
+ else self.config.output_hidden_states
1169
+ )
1170
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1171
+
1172
+ if is_decoder:
1173
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1174
+ else:
1175
+ use_cache = False
1176
+
1177
+ if input_ids is not None and inputs_embeds is not None:
1178
+ raise ValueError(
1179
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1180
+ )
1181
+ elif input_ids is not None:
1182
+ input_shape = input_ids.size()
1183
+ batch_size, seq_length = input_shape
1184
+ device = input_ids.device
1185
+ elif inputs_embeds is not None:
1186
+ input_shape = inputs_embeds.size()[:-1]
1187
+ batch_size, seq_length = input_shape
1188
+ device = inputs_embeds.device
1189
+ elif encoder_embeds is not None:
1190
+ input_shape = encoder_embeds.size()[:-1]
1191
+ batch_size, seq_length = input_shape
1192
+ device = encoder_embeds.device
1193
+ else:
1194
+ raise ValueError(
1195
+ "You have to specify either input_ids or inputs_embeds or encoder_embeds"
1196
+ )
1197
+
1198
+ # past_key_values_length
1199
+ past_key_values_length = (
1200
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
1201
+ )
1202
+
1203
+ if attention_mask is None:
1204
+ attention_mask = torch.ones(
1205
+ ((batch_size, seq_length + past_key_values_length)), device=device
1206
+ )
1207
+ if token_type_ids is None:
1208
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1209
+
1210
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1211
+ # ourselves in which case we just need to make it broadcastable to all heads.
1212
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
1213
+ attention_mask, input_shape, device, is_decoder
1214
+ )
1215
+
1216
+ # If a 2D or 3D attention mask is provided for the cross-attention
1217
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1218
+ if encoder_hidden_states is not None:
1219
+ if type(encoder_hidden_states) == list:
1220
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
1221
+ 0
1222
+ ].size()
1223
+ else:
1224
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1225
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1226
+
1227
+ if type(encoder_attention_mask) == list:
1228
+ encoder_extended_attention_mask = [
1229
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
1230
+ ]
1231
+ elif encoder_attention_mask is None:
1232
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1233
+ encoder_extended_attention_mask = self.invert_attention_mask(
1234
+ encoder_attention_mask
1235
+ )
1236
+ else:
1237
+ encoder_extended_attention_mask = self.invert_attention_mask(
1238
+ encoder_attention_mask
1239
+ )
1240
+ else:
1241
+ encoder_extended_attention_mask = None
1242
+
1243
+ # Prepare head mask if needed
1244
+ # 1.0 in head_mask indicate we keep the head
1245
+ # attention_probs has shape bsz x n_heads x N x N
1246
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1247
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1248
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1249
+
1250
+ if encoder_embeds is None:
1251
+ embedding_output = self.embeddings(
1252
+ input_ids=input_ids,
1253
+ position_ids=position_ids,
1254
+ token_type_ids=token_type_ids,
1255
+ inputs_embeds=inputs_embeds,
1256
+ past_key_values_length=past_key_values_length,
1257
+ )
1258
+ else:
1259
+ embedding_output = encoder_embeds
1260
+
1261
+ encoder_outputs = self.encoder(
1262
+ embedding_output,
1263
+ attention_mask=extended_attention_mask,
1264
+ head_mask=head_mask,
1265
+ encoder_hidden_states=encoder_hidden_states,
1266
+ encoder_attention_mask=encoder_extended_attention_mask,
1267
+ past_key_values=past_key_values,
1268
+ use_cache=use_cache,
1269
+ output_attentions=output_attentions,
1270
+ output_hidden_states=output_hidden_states,
1271
+ return_dict=return_dict,
1272
+ mode=mode,
1273
+ normalize_attention=normalize_attention,
1274
+ )
1275
+ sequence_output = encoder_outputs[0]
1276
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1277
+
1278
+ if not return_dict:
1279
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1280
+
1281
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1282
+ last_hidden_state=sequence_output,
1283
+ pooler_output=pooled_output,
1284
+ past_key_values=encoder_outputs.past_key_values,
1285
+ hidden_states=encoder_outputs.hidden_states,
1286
+ attentions=encoder_outputs.attentions,
1287
+ cross_attentions=encoder_outputs.cross_attentions,
1288
+ )
1289
+
1290
+
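A hypothetical usage sketch of the two-stage call pattern this model supports: run the uni-modal layers with `mode="text"`, then feed the resulting hidden states back through the cross-attention layers with `mode="fusion"`. The `video_embeds` / `video_atts` names and shapes are assumptions for illustration, not part of this file:

def encode_text_then_fuse(text_encoder, input_ids, attention_mask,
                          video_embeds, video_atts):
    # layers [0, fusion_layer): self-attention over text tokens only
    text_out = text_encoder(
        input_ids, attention_mask=attention_mask, return_dict=True, mode="text"
    )
    # layers [fusion_layer, L): text states cross-attend to the visual tokens
    fused = text_encoder(
        encoder_embeds=text_out.last_hidden_state,
        attention_mask=attention_mask,
        encoder_hidden_states=video_embeds,
        encoder_attention_mask=video_atts,
        return_dict=True,
        mode="fusion",
    )
    return fused.last_hidden_state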
1291
+ @add_start_docstrings(
1292
+ """
1293
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1294
+ sentence prediction (classification)` head.
1295
+ """,
1296
+ BERT_START_DOCSTRING,
1297
+ )
1298
+ class BertForPreTraining(BertPreTrainedModel):
1299
+ def __init__(self, config):
1300
+ super().__init__(config)
1301
+
1302
+ self.bert = BertModel(config)
1303
+ self.cls = BertPreTrainingHeads(config)
1304
+
1305
+ self.init_weights()
1306
+
1307
+ def get_output_embeddings(self):
1308
+ return self.cls.predictions.decoder
1309
+
1310
+ def set_output_embeddings(self, new_embeddings):
1311
+ self.cls.predictions.decoder = new_embeddings
1312
+
1313
+ @add_start_docstrings_to_model_forward(
1314
+ BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
1315
+ )
1316
+ @replace_return_docstrings(
1317
+ output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
1318
+ )
1319
+ def forward(
1320
+ self,
1321
+ input_ids=None,
1322
+ attention_mask=None,
1323
+ token_type_ids=None,
1324
+ position_ids=None,
1325
+ head_mask=None,
1326
+ inputs_embeds=None,
1327
+ labels=None,
1328
+ next_sentence_label=None,
1329
+ output_attentions=None,
1330
+ output_hidden_states=None,
1331
+ return_dict=None,
1332
+ ):
1333
+ r"""
1334
+ labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
1335
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1336
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1337
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1338
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
1339
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1340
+ (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
1341
+ - 0 indicates sequence B is a continuation of sequence A,
1342
+ - 1 indicates sequence B is a random sequence.
1343
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1344
+ Used to hide legacy arguments that have been deprecated.
1345
+ Returns:
1346
+ Example::
1347
+ >>> from transformers import BertTokenizer, BertForPreTraining
1348
+ >>> import torch
1349
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1350
+ >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
1351
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1352
+ >>> outputs = model(**inputs)
1353
+ >>> prediction_logits = outputs.prediction_logits
1354
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1355
+ """
1356
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1357
+
1358
+ outputs = self.bert(
1359
+ input_ids,
1360
+ attention_mask=attention_mask,
1361
+ token_type_ids=token_type_ids,
1362
+ position_ids=position_ids,
1363
+ head_mask=head_mask,
1364
+ inputs_embeds=inputs_embeds,
1365
+ output_attentions=output_attentions,
1366
+ output_hidden_states=output_hidden_states,
1367
+ return_dict=return_dict,
1368
+ )
1369
+
1370
+ sequence_output, pooled_output = outputs[:2]
1371
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1372
+
1373
+ total_loss = None
1374
+ if labels is not None and next_sentence_label is not None:
1375
+ loss_fct = CrossEntropyLoss()
1376
+ masked_lm_loss = loss_fct(
1377
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1378
+ )
1379
+ next_sentence_loss = loss_fct(
1380
+ seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)
1381
+ )
1382
+ total_loss = masked_lm_loss + next_sentence_loss
1383
+
1384
+ if not return_dict:
1385
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1386
+ return ((total_loss,) + output) if total_loss is not None else output
1387
+
1388
+ return BertForPreTrainingOutput(
1389
+ loss=total_loss,
1390
+ prediction_logits=prediction_scores,
1391
+ seq_relationship_logits=seq_relationship_score,
1392
+ hidden_states=outputs.hidden_states,
1393
+ attentions=outputs.attentions,
1394
+ )
1395
+
1396
+
1397
+ @add_start_docstrings(
1398
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """,
1399
+ BERT_START_DOCSTRING,
1400
+ )
1401
+ class BertLMHeadModel(BertPreTrainedModel):
1402
+
1403
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1404
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1405
+
1406
+ def __init__(self, config):
1407
+ super().__init__(config)
1408
+
1409
+ self.bert = BertModel(config, add_pooling_layer=False)
1410
+ self.cls = BertOnlyMLMHead(config)
1411
+
1412
+ self.init_weights()
1413
+
1414
+ def get_output_embeddings(self):
1415
+ return self.cls.predictions.decoder
1416
+
1417
+ def set_output_embeddings(self, new_embeddings):
1418
+ self.cls.predictions.decoder = new_embeddings
1419
+
1420
+ @add_start_docstrings_to_model_forward(
1421
+ BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
1422
+ )
1423
+ @replace_return_docstrings(
1424
+ output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC
1425
+ )
1426
+ def forward(
1427
+ self,
1428
+ input_ids=None,
1429
+ attention_mask=None,
1430
+ token_type_ids=None,
1431
+ position_ids=None,
1432
+ head_mask=None,
1433
+ inputs_embeds=None,
1434
+ encoder_hidden_states=None,
1435
+ encoder_attention_mask=None,
1436
+ labels=None,
1437
+ past_key_values=None,
1438
+ use_cache=None,
1439
+ output_attentions=None,
1440
+ output_hidden_states=None,
1441
+ return_dict=None,
1442
+ is_decoder=True,
1443
+ reduction="mean",
1444
+ mode="multi_modal",
1445
+ normalize_attention=True,
1446
+ soft_labels=None,
1447
+ alpha=0,
1448
+ return_logits=False,
1449
+ ):
1450
+ r"""
1451
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1452
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1453
+ the model is configured as a decoder.
1454
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1455
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1456
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1457
+ - 1 for tokens that are **not masked**,
1458
+ - 0 for tokens that are **masked**.
1459
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1460
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1461
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1462
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1463
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1464
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1465
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1466
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1467
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1468
+ use_cache (:obj:`bool`, `optional`):
1469
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1470
+ decoding (see :obj:`past_key_values`).
1471
+ Returns:
1472
+ Example::
1473
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1474
+ >>> import torch
1475
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1476
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1477
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1478
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1479
+ >>> outputs = model(**inputs)
1480
+ >>> prediction_logits = outputs.logits
1481
+ """
1482
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1483
+ if labels is not None:
1484
+ use_cache = False
1485
+
1486
+ outputs = self.bert(
1487
+ input_ids,
1488
+ attention_mask=attention_mask,
1489
+ token_type_ids=token_type_ids,
1490
+ position_ids=position_ids,
1491
+ head_mask=head_mask,
1492
+ inputs_embeds=inputs_embeds,
1493
+ encoder_hidden_states=encoder_hidden_states,
1494
+ encoder_attention_mask=encoder_attention_mask,
1495
+ past_key_values=past_key_values,
1496
+ use_cache=use_cache,
1497
+ output_attentions=output_attentions,
1498
+ output_hidden_states=output_hidden_states,
1499
+ return_dict=return_dict,
1500
+ is_decoder=is_decoder,
1501
+ mode=mode,
1502
+ normalize_attention=normalize_attention,
1503
+ )
1504
+
1505
+ sequence_output = outputs[0]
1506
+ prediction_scores = self.cls(sequence_output)
1507
+
1508
+ if return_logits:
1509
+ return prediction_scores[:, :-1, :].contiguous()
1510
+
1511
+ lm_loss = None
1512
+ if labels is not None:
1513
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1514
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1515
+ labels = labels[:, 1:].contiguous()
1516
+ loss_fct = CrossEntropyLoss(reduction=reduction)
1517
+ lm_loss = loss_fct(
1518
+ shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1519
+ )
1520
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1521
+
1522
+ if soft_labels is not None:
1523
+ loss_distill = -torch.sum(
1524
+ F.log_softmax(shifted_prediction_scores, dim=1) * soft_labels, dim=-1
1525
+ )
1526
+ loss_distill = (loss_distill * (labels != -100)).sum(1)
1527
+ lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill
1528
+
1529
+ if not return_dict:
1530
+ output = (prediction_scores,) + outputs[2:]
1531
+ return ((lm_loss,) + output) if lm_loss is not None else output
1532
+
1533
+ return CausalLMOutputWithCrossAttentions(
1534
+ loss=lm_loss,
1535
+ logits=prediction_scores,
1536
+ past_key_values=outputs.past_key_values,
1537
+ hidden_states=outputs.hidden_states,
1538
+ attentions=outputs.attentions,
1539
+ cross_attentions=outputs.cross_attentions,
1540
+ )
1541
+
1542
+ def prepare_inputs_for_generation(
1543
+ self, input_ids, past=None, attention_mask=None, **model_kwargs
1544
+ ):
1545
+ input_shape = input_ids.shape
1546
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1547
+ if attention_mask is None:
1548
+ attention_mask = input_ids.new_ones(input_shape)
1549
+
1550
+ # cut decoder_input_ids if past is used
1551
+ if past is not None:
1552
+ input_ids = input_ids[:, -1:]
1553
+
1554
+ return {
1555
+ "input_ids": input_ids,
1556
+ "attention_mask": attention_mask,
1557
+ "past_key_values": past,
1558
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1559
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1560
+ "is_decoder": True,
1561
+ }
1562
+
1563
+ def _reorder_cache(self, past, beam_idx):
1564
+ reordered_past = ()
1565
+ for layer_past in past:
1566
+ reordered_past += (
1567
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),
1568
+ )
1569
+ return reordered_past
1570
+
1571
+
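For reference, the next-token shift in `BertLMHeadModel.forward` can be reproduced in isolation; this sketch mirrors the shape handling above with made-up sizes:

import torch
import torch.nn.functional as F

B, T, V = 2, 5, 11                        # batch, sequence length, vocab size
prediction_scores = torch.randn(B, T, V)  # decoder logits
labels = torch.randint(0, V, (B, T))

shifted_scores = prediction_scores[:, :-1, :].contiguous()  # drop the last position
shifted_labels = labels[:, 1:].contiguous()                  # drop the first token

loss = F.cross_entropy(
    shifted_scores.view(-1, V), shifted_labels.view(-1), reduction="none"
)
per_sample_loss = loss.view(B, -1).sum(1)  # same per-sample summation as above
print(per_sample_loss.shape)               # torch.Size([2])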
1572
+ @dataclass
1573
+ class MaskedLMOutputWithDistill(MaskedLMOutput):
1574
+ loss_aux: Optional[torch.FloatTensor] = None
1575
+ loss_distill: Optional[torch.FloatTensor] = None
1576
+
1577
+
1578
+ @add_start_docstrings(
1579
+ """Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING
1580
+ )
1581
+ class BertForMaskedLM(BertPreTrainedModel):
1582
+
1583
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1584
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1585
+
1586
+ def __init__(self, config):
1587
+ super().__init__(config)
1588
+
1589
+ self.bert = BertModel(config, add_pooling_layer=False)
1590
+ self.cls = BertOnlyMLMHead(config)
1591
+
1592
+ self.init_weights()
1593
+
1594
+ def tie_aux_decoder_weights(self, module, aux_modules):
1595
+ """Tie decoder weights of all `aux_modules` to `module`, (not bias)"""
1596
+ for m in aux_modules:
1597
+ m.predictions.decoder.weight = module.predictions.decoder.weight
1598
+
1599
+ def get_output_embeddings(self):
1600
+ return self.cls.predictions.decoder
1601
+
1602
+ def set_output_embeddings(self, new_embeddings):
1603
+ self.cls.predictions.decoder = new_embeddings
1604
+
1605
+ def forward(
1606
+ self,
1607
+ input_ids=None,
1608
+ attention_mask=None,
1609
+ token_type_ids=None,
1610
+ position_ids=None,
1611
+ head_mask=None,
1612
+ inputs_embeds=None,
1613
+ encoder_embeds=None,
1614
+ encoder_hidden_states=None,
1615
+ encoder_attention_mask=None,
1616
+ labels=None,
1617
+ output_attentions=None,
1618
+ output_hidden_states=None,
1619
+ return_dict=None,
1620
+ is_decoder=False,
1621
+ mode="multi_modal",
1622
+ normalize_attention=True,
1623
+ soft_labels=None,
1624
+ alpha=0,
1625
+ return_logits=False,
1626
+ ):
1627
+ r"""
1628
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1629
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1630
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1631
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1632
+ """
1633
+
1634
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1635
+
1636
+ outputs = self.bert(
1637
+ input_ids,
1638
+ attention_mask=attention_mask,
1639
+ token_type_ids=token_type_ids,
1640
+ position_ids=position_ids,
1641
+ head_mask=head_mask,
1642
+ inputs_embeds=inputs_embeds,
1643
+ encoder_embeds=encoder_embeds,
1644
+ encoder_hidden_states=encoder_hidden_states,
1645
+ encoder_attention_mask=encoder_attention_mask,
1646
+ output_attentions=output_attentions,
1647
+ output_hidden_states=output_hidden_states,
1648
+ return_dict=return_dict,
1649
+ is_decoder=is_decoder,
1650
+ mode=mode,
1651
+ normalize_attention=normalize_attention,
1652
+ )
1653
+
1654
+ sequence_output = outputs[0]
1655
+ prediction_scores = self.cls(sequence_output)
1656
+
1657
+ if return_logits:
1658
+ return prediction_scores
1659
+
1660
+ masked_lm_loss = None
1661
+ masked_lm_loss_aux = 0.0
1662
+ if labels is not None:
1663
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1664
+ masked_lm_loss = loss_fct(
1665
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1666
+ )
1667
+
1668
+ if soft_labels is not None:
1669
+ loss_distill = -torch.sum(
1670
+ F.log_softmax(prediction_scores, dim=1) * soft_labels, dim=-1
1671
+ )
1672
+ loss_distill = loss_distill[labels != -100].mean()
1673
+ masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill
1674
+
1675
+ if not return_dict:
1676
+ output = (prediction_scores,) + outputs[2:]
1677
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1678
+
1679
+ # changed from MaskedLMOutput to MaskedLMOutputWithDistill
1680
+ return MaskedLMOutputWithDistill(
1681
+ loss=masked_lm_loss,
1682
+ loss_aux=masked_lm_loss_aux,
1683
+ logits=prediction_scores,
1684
+ hidden_states=outputs.hidden_states,
1685
+ attentions=outputs.attentions,
1686
+ )
1687
+
1688
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1689
+ input_shape = input_ids.shape
1690
+ effective_batch_size = input_shape[0]
1691
+
1692
+ # add a dummy token
1693
+ assert (
1694
+ self.config.pad_token_id is not None
1695
+ ), "The PAD token should be defined for generation"
1696
+ attention_mask = torch.cat(
1697
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
1698
+ )
1699
+ dummy_token = torch.full(
1700
+ (effective_batch_size, 1),
1701
+ self.config.pad_token_id,
1702
+ dtype=torch.long,
1703
+ device=input_ids.device,
1704
+ )
1705
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1706
+
1707
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1708
+
1709
+
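A minimal sketch of how the hard MLM loss and the soft-label (distillation) term above are blended by `alpha`, with the teacher distribution normalized over the vocabulary dimension; all sizes and the `alpha` value are illustrative:

import torch
import torch.nn.functional as F

B, T, V, alpha = 2, 6, 11, 0.4
prediction_scores = torch.randn(B, T, V)                    # student logits
soft_labels = torch.softmax(torch.randn(B, T, V), dim=-1)   # teacher probabilities
labels = torch.randint(0, V, (B, T))
labels[:, 0] = -100                                         # e.g. unmasked positions are ignored

hard_loss = F.cross_entropy(prediction_scores.view(-1, V), labels.view(-1))
distill = -(F.log_softmax(prediction_scores, dim=-1) * soft_labels).sum(-1)
distill = distill[labels != -100].mean()
loss = (1 - alpha) * hard_loss + alpha * distill
print(float(loss))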
1710
+ @add_start_docstrings(
1711
+ """Bert Model with a `next sentence prediction (classification)` head on top. """,
1712
+ BERT_START_DOCSTRING,
1713
+ )
1714
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1715
+ def __init__(self, config):
1716
+ super().__init__(config)
1717
+
1718
+ self.bert = BertModel(config)
1719
+ self.cls = BertOnlyNSPHead(config)
1720
+
1721
+ self.init_weights()
1722
+
1723
+ @add_start_docstrings_to_model_forward(
1724
+ BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
1725
+ )
1726
+ @replace_return_docstrings(
1727
+ output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
1728
+ )
1729
+ def forward(
1730
+ self,
1731
+ input_ids=None,
1732
+ attention_mask=None,
1733
+ token_type_ids=None,
1734
+ position_ids=None,
1735
+ head_mask=None,
1736
+ inputs_embeds=None,
1737
+ labels=None,
1738
+ output_attentions=None,
1739
+ output_hidden_states=None,
1740
+ return_dict=None,
1741
+ **kwargs,
1742
+ ):
1743
+ r"""
1744
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1745
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1746
+ (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
1747
+ - 0 indicates sequence B is a continuation of sequence A,
1748
+ - 1 indicates sequence B is a random sequence.
1749
+ Returns:
1750
+ Example::
1751
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1752
+ >>> import torch
1753
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1754
+ >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
1755
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1756
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1757
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
1758
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1759
+ >>> logits = outputs.logits
1760
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1761
+ """
1762
+
1763
+ if "next_sentence_label" in kwargs:
1764
+ warnings.warn(
1765
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
1766
+ FutureWarning,
1767
+ )
1768
+ labels = kwargs.pop("next_sentence_label")
1769
+
1770
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1771
+
1772
+ outputs = self.bert(
1773
+ input_ids,
1774
+ attention_mask=attention_mask,
1775
+ token_type_ids=token_type_ids,
1776
+ position_ids=position_ids,
1777
+ head_mask=head_mask,
1778
+ inputs_embeds=inputs_embeds,
1779
+ output_attentions=output_attentions,
1780
+ output_hidden_states=output_hidden_states,
1781
+ return_dict=return_dict,
1782
+ )
1783
+
1784
+ pooled_output = outputs[1]
1785
+
1786
+ seq_relationship_scores = self.cls(pooled_output)
1787
+
1788
+ next_sentence_loss = None
1789
+ if labels is not None:
1790
+ loss_fct = CrossEntropyLoss()
1791
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1792
+
1793
+ if not return_dict:
1794
+ output = (seq_relationship_scores,) + outputs[2:]
1795
+ return (
1796
+ ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1797
+ )
1798
+
1799
+ return NextSentencePredictorOutput(
1800
+ loss=next_sentence_loss,
1801
+ logits=seq_relationship_scores,
1802
+ hidden_states=outputs.hidden_states,
1803
+ attentions=outputs.attentions,
1804
+ )
1805
+
1806
+
1807
+ @add_start_docstrings(
1808
+ """
1809
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1810
+ output) e.g. for GLUE tasks.
1811
+ """,
1812
+ BERT_START_DOCSTRING,
1813
+ )
1814
+ class BertForSequenceClassification(BertPreTrainedModel):
1815
+ def __init__(self, config):
1816
+ super().__init__(config)
1817
+ self.num_labels = config.num_labels
1818
+
1819
+ self.bert = BertModel(config)
1820
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1821
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1822
+
1823
+ self.init_weights()
1824
+
1825
+ def forward(
1826
+ self,
1827
+ input_ids=None,
1828
+ attention_mask=None,
1829
+ token_type_ids=None,
1830
+ position_ids=None,
1831
+ head_mask=None,
1832
+ inputs_embeds=None,
1833
+ labels=None,
1834
+ output_attentions=None,
1835
+ output_hidden_states=None,
1836
+ return_dict=None,
1837
+ ):
1838
+ r"""
1839
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1840
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1841
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1842
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1843
+ """
1844
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1845
+
1846
+ outputs = self.bert(
1847
+ input_ids,
1848
+ attention_mask=attention_mask,
1849
+ token_type_ids=token_type_ids,
1850
+ position_ids=position_ids,
1851
+ head_mask=head_mask,
1852
+ inputs_embeds=inputs_embeds,
1853
+ output_attentions=output_attentions,
1854
+ output_hidden_states=output_hidden_states,
1855
+ return_dict=return_dict,
1856
+ )
1857
+
1858
+ pooled_output = outputs[1]
1859
+
1860
+ pooled_output = self.dropout(pooled_output)
1861
+ logits = self.classifier(pooled_output)
1862
+
1863
+ loss = None
1864
+ if labels is not None:
1865
+ if self.num_labels == 1:
1866
+ # We are doing regression
1867
+ loss_fct = MSELoss()
1868
+ loss = loss_fct(logits.view(-1), labels.view(-1))
1869
+ else:
1870
+ loss_fct = CrossEntropyLoss()
1871
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1872
+
1873
+ if not return_dict:
1874
+ output = (logits,) + outputs[2:]
1875
+ return ((loss,) + output) if loss is not None else output
1876
+
1877
+ return SequenceClassifierOutput(
1878
+ loss=loss,
1879
+ logits=logits,
1880
+ hidden_states=outputs.hidden_states,
1881
+ attentions=outputs.attentions,
1882
+ )
1883
+
1884
+
1885
+ @add_start_docstrings(
1886
+ """
1887
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1888
+ softmax) e.g. for RocStories/SWAG tasks.
1889
+ """,
1890
+ BERT_START_DOCSTRING,
1891
+ )
1892
+ class BertForMultipleChoice(BertPreTrainedModel):
1893
+ def __init__(self, config):
1894
+ super().__init__(config)
1895
+
1896
+ self.bert = BertModel(config)
1897
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1898
+ self.classifier = nn.Linear(config.hidden_size, 1)
1899
+
1900
+ self.init_weights()
1901
+
1902
+ def forward(
1903
+ self,
1904
+ input_ids=None,
1905
+ attention_mask=None,
1906
+ token_type_ids=None,
1907
+ position_ids=None,
1908
+ head_mask=None,
1909
+ inputs_embeds=None,
1910
+ labels=None,
1911
+ output_attentions=None,
1912
+ output_hidden_states=None,
1913
+ return_dict=None,
1914
+ ):
1915
+ r"""
1916
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1917
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1918
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
1919
+ :obj:`input_ids` above)
1920
+ """
1921
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1922
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1923
+
1924
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1925
+ attention_mask = (
1926
+ attention_mask.view(-1, attention_mask.size(-1))
1927
+ if attention_mask is not None
1928
+ else None
1929
+ )
1930
+ token_type_ids = (
1931
+ token_type_ids.view(-1, token_type_ids.size(-1))
1932
+ if token_type_ids is not None
1933
+ else None
1934
+ )
1935
+ position_ids = (
1936
+ position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1937
+ )
1938
+ inputs_embeds = (
1939
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1940
+ if inputs_embeds is not None
1941
+ else None
1942
+ )
1943
+
1944
+ outputs = self.bert(
1945
+ input_ids,
1946
+ attention_mask=attention_mask,
1947
+ token_type_ids=token_type_ids,
1948
+ position_ids=position_ids,
1949
+ head_mask=head_mask,
1950
+ inputs_embeds=inputs_embeds,
1951
+ output_attentions=output_attentions,
1952
+ output_hidden_states=output_hidden_states,
1953
+ return_dict=return_dict,
1954
+ )
1955
+
1956
+ pooled_output = outputs[1]
1957
+
1958
+ pooled_output = self.dropout(pooled_output)
1959
+ logits = self.classifier(pooled_output)
1960
+ reshaped_logits = logits.view(-1, num_choices)
1961
+
1962
+ loss = None
1963
+ if labels is not None:
1964
+ loss_fct = CrossEntropyLoss()
1965
+ loss = loss_fct(reshaped_logits, labels)
1966
+
1967
+ if not return_dict:
1968
+ output = (reshaped_logits,) + outputs[2:]
1969
+ return ((loss,) + output) if loss is not None else output
1970
+
1971
+ return MultipleChoiceModelOutput(
1972
+ loss=loss,
1973
+ logits=reshaped_logits,
1974
+ hidden_states=outputs.hidden_states,
1975
+ attentions=outputs.attentions,
1976
+ )
1977
+
1978
+
1979
+ @add_start_docstrings(
1980
+ """
1981
+ Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1982
+ Named-Entity-Recognition (NER) tasks.
1983
+ """,
1984
+ BERT_START_DOCSTRING,
1985
+ )
1986
+ class BertForTokenClassification(BertPreTrainedModel):
1987
+
1988
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1989
+
1990
+ def __init__(self, config):
1991
+ super().__init__(config)
1992
+ self.num_labels = config.num_labels
1993
+
1994
+ self.bert = BertModel(config, add_pooling_layer=False)
1995
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1996
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1997
+
1998
+ self.init_weights()
1999
+
2000
+ def forward(
2001
+ self,
2002
+ input_ids=None,
2003
+ attention_mask=None,
2004
+ token_type_ids=None,
2005
+ position_ids=None,
2006
+ head_mask=None,
2007
+ inputs_embeds=None,
2008
+ labels=None,
2009
+ output_attentions=None,
2010
+ output_hidden_states=None,
2011
+ return_dict=None,
2012
+ ):
2013
+ r"""
2014
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
2015
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
2016
+ 1]``.
2017
+ """
2018
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2019
+
2020
+ outputs = self.bert(
2021
+ input_ids,
2022
+ attention_mask=attention_mask,
2023
+ token_type_ids=token_type_ids,
2024
+ position_ids=position_ids,
2025
+ head_mask=head_mask,
2026
+ inputs_embeds=inputs_embeds,
2027
+ output_attentions=output_attentions,
2028
+ output_hidden_states=output_hidden_states,
2029
+ return_dict=return_dict,
2030
+ )
2031
+
2032
+ sequence_output = outputs[0]
2033
+
2034
+ sequence_output = self.dropout(sequence_output)
2035
+ logits = self.classifier(sequence_output)
2036
+
2037
+ loss = None
2038
+ if labels is not None:
2039
+ loss_fct = CrossEntropyLoss()
2040
+ # Only keep active parts of the loss
2041
+ if attention_mask is not None:
2042
+ active_loss = attention_mask.view(-1) == 1
2043
+ active_logits = logits.view(-1, self.num_labels)
2044
+ active_labels = torch.where(
2045
+ active_loss,
2046
+ labels.view(-1),
2047
+ torch.tensor(loss_fct.ignore_index).type_as(labels),
2048
+ )
2049
+ loss = loss_fct(active_logits, active_labels)
2050
+ else:
2051
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
2052
+
2053
+ if not return_dict:
2054
+ output = (logits,) + outputs[2:]
2055
+ return ((loss,) + output) if loss is not None else output
2056
+
2057
+ return TokenClassifierOutput(
2058
+ loss=loss,
2059
+ logits=logits,
2060
+ hidden_states=outputs.hidden_states,
2061
+ attentions=outputs.attentions,
2062
+ )
2063
+
2064
+
2065
+ @add_start_docstrings(
2066
+ """
2067
+ Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
2068
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
2069
+ """,
2070
+ BERT_START_DOCSTRING,
2071
+ )
2072
+ class BertForQuestionAnswering(BertPreTrainedModel):
2073
+
2074
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
2075
+
2076
+ def __init__(self, config):
2077
+ super().__init__(config)
2078
+ self.num_labels = config.num_labels
2079
+
2080
+ self.bert = BertModel(config, add_pooling_layer=False)
2081
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
2082
+
2083
+ self.init_weights()
2084
+
2085
+ def forward(
2086
+ self,
2087
+ input_ids=None,
2088
+ attention_mask=None,
2089
+ token_type_ids=None,
2090
+ position_ids=None,
2091
+ head_mask=None,
2092
+ inputs_embeds=None,
2093
+ start_positions=None,
2094
+ end_positions=None,
2095
+ output_attentions=None,
2096
+ output_hidden_states=None,
2097
+ return_dict=None,
2098
+ ):
2099
+ r"""
2100
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
2101
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
2102
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
2103
+ sequence are not taken into account for computing the loss.
2104
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
2105
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
2106
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
2107
+ sequence are not taken into account for computing the loss.
2108
+ """
2109
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2110
+
2111
+ outputs = self.bert(
2112
+ input_ids,
2113
+ attention_mask=attention_mask,
2114
+ token_type_ids=token_type_ids,
2115
+ position_ids=position_ids,
2116
+ head_mask=head_mask,
2117
+ inputs_embeds=inputs_embeds,
2118
+ output_attentions=output_attentions,
2119
+ output_hidden_states=output_hidden_states,
2120
+ return_dict=return_dict,
2121
+ )
2122
+
2123
+ sequence_output = outputs[0]
2124
+
2125
+ logits = self.qa_outputs(sequence_output)
2126
+ start_logits, end_logits = logits.split(1, dim=-1)
2127
+ start_logits = start_logits.squeeze(-1)
2128
+ end_logits = end_logits.squeeze(-1)
2129
+
2130
+ total_loss = None
2131
+ if start_positions is not None and end_positions is not None:
2132
+ # If we are on multi-GPU, splitting adds a dimension
2133
+ if len(start_positions.size()) > 1:
2134
+ start_positions = start_positions.squeeze(-1)
2135
+ if len(end_positions.size()) > 1:
2136
+ end_positions = end_positions.squeeze(-1)
2137
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
2138
+ ignored_index = start_logits.size(1)
2139
+ start_positions.clamp_(0, ignored_index)
2140
+ end_positions.clamp_(0, ignored_index)
2141
+
2142
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
2143
+ start_loss = loss_fct(start_logits, start_positions)
2144
+ end_loss = loss_fct(end_logits, end_positions)
2145
+ total_loss = (start_loss + end_loss) / 2
2146
+
2147
+ if not return_dict:
2148
+ output = (start_logits, end_logits) + outputs[2:]
2149
+ return ((total_loss,) + output) if total_loss is not None else output
2150
+
2151
+ return QuestionAnsweringModelOutput(
2152
+ loss=total_loss,
2153
+ start_logits=start_logits,
2154
+ end_logits=end_logits,
2155
+ hidden_states=outputs.hidden_states,
2156
+ attentions=outputs.attentions,
2157
+ )
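For reference, the span head above is a single `Linear(hidden_size, 2)` whose output is split into start and end logits along the last dimension; a standalone sketch with toy shapes:

import torch
import torch.nn as nn

hidden_size, B, T = 16, 2, 7
qa_outputs = nn.Linear(hidden_size, 2)
sequence_output = torch.randn(B, T, hidden_size)

logits = qa_outputs(sequence_output)                # [B, T, 2]
start_logits, end_logits = logits.split(1, dim=-1)  # two [B, T, 1] tensors
start_logits = start_logits.squeeze(-1)             # [B, T]
end_logits = end_logits.squeeze(-1)                 # [B, T]
print(start_logits.shape, end_logits.shape)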
models_viclip/backbones/blip_toremove/Qformer.py ADDED
@@ -0,0 +1,1237 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from timm.models.layers import drop_path
25
+ from transformers.activations import ACT2FN
26
+ from transformers.file_utils import (
27
+ ModelOutput,
28
+ )
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPoolingAndCrossAttentions,
32
+ CausalLMOutputWithCrossAttentions,
33
+ MaskedLMOutput,
34
+ MultipleChoiceModelOutput,
35
+ NextSentencePredictorOutput,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutput,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_utils import (
41
+ PreTrainedModel,
42
+ apply_chunking_to_forward,
43
+ find_pruneable_heads_and_indices,
44
+ prune_linear_layer,
45
+ )
46
+ from transformers.utils import logging
47
+ from transformers.models.bert.configuration_bert import BertConfig
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(
58
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
59
+ )
60
+ self.position_embeddings = nn.Embedding(
61
+ config.max_position_embeddings, config.hidden_size
62
+ )
63
+
64
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
65
+ # any TensorFlow checkpoint file
66
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
67
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
68
+
69
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
70
+ self.register_buffer(
71
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
72
+ )
73
+ self.position_embedding_type = getattr(
74
+ config, "position_embedding_type", "absolute"
75
+ )
76
+
77
+ self.config = config
78
+
79
+ def forward(
80
+ self,
81
+ input_ids=None,
82
+ position_ids=None,
83
+ query_embeds=None,
84
+ past_key_values_length=0,
85
+ ):
86
+ if input_ids is not None:
87
+ seq_length = input_ids.size()[1]
88
+ else:
89
+ seq_length = 0
90
+
91
+ if position_ids is None:
92
+ position_ids = self.position_ids[
93
+ :, past_key_values_length : seq_length + past_key_values_length
94
+ ].clone()
95
+
96
+ if input_ids is not None:
97
+ embeddings = self.word_embeddings(input_ids)
98
+ if self.position_embedding_type == "absolute":
99
+ position_embeddings = self.position_embeddings(position_ids)
100
+ embeddings = embeddings + position_embeddings
101
+
102
+ if query_embeds is not None:
103
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
104
+ else:
105
+ embeddings = query_embeds
106
+
107
+ embeddings = self.LayerNorm(embeddings)
108
+ embeddings = self.dropout(embeddings)
109
+ return embeddings
110
+
111
+
112
+ class BertSelfAttention(nn.Module):
113
+ def __init__(self, config, is_cross_attention):
114
+ super().__init__()
115
+ self.config = config
116
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
117
+ config, "embedding_size"
118
+ ):
119
+ raise ValueError(
120
+ "The hidden size (%d) is not a multiple of the number of attention "
121
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
122
+ )
123
+
124
+ self.num_attention_heads = config.num_attention_heads
125
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
126
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
127
+
128
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
129
+ if is_cross_attention:
130
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
131
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
132
+ else:
133
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
134
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
135
+
136
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
137
+ self.position_embedding_type = getattr(
138
+ config, "position_embedding_type", "absolute"
139
+ )
140
+ if (
141
+ self.position_embedding_type == "relative_key"
142
+ or self.position_embedding_type == "relative_key_query"
143
+ ):
144
+ self.max_position_embeddings = config.max_position_embeddings
145
+ self.distance_embedding = nn.Embedding(
146
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
147
+ )
148
+ self.save_attention = False
149
+
150
+ def save_attn_gradients(self, attn_gradients):
151
+ self.attn_gradients = attn_gradients
152
+
153
+ def get_attn_gradients(self):
154
+ return self.attn_gradients
155
+
156
+ def save_attention_map(self, attention_map):
157
+ self.attention_map = attention_map
158
+
159
+ def get_attention_map(self):
160
+ return self.attention_map
161
+
162
+ def transpose_for_scores(self, x):
163
+ new_x_shape = x.size()[:-1] + (
164
+ self.num_attention_heads,
165
+ self.attention_head_size,
166
+ )
167
+ x = x.view(*new_x_shape)
168
+ return x.permute(0, 2, 1, 3)
169
+
170
+ def forward(
171
+ self,
172
+ hidden_states,
173
+ attention_mask=None,
174
+ head_mask=None,
175
+ encoder_hidden_states=None,
176
+ encoder_attention_mask=None,
177
+ past_key_value=None,
178
+ output_attentions=False,
179
+ ):
180
+
181
+ # If this is instantiated as a cross-attention module, the keys
182
+ # and values come from an encoder; the attention mask needs to be
183
+ # such that the encoder's padding tokens are not attended to.
184
+ is_cross_attention = encoder_hidden_states is not None
185
+
186
+ if is_cross_attention:
187
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
188
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
189
+ attention_mask = encoder_attention_mask
190
+ elif past_key_value is not None:
191
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
192
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
193
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
194
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
195
+ else:
196
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
197
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
198
+
199
+ mixed_query_layer = self.query(hidden_states)
200
+
201
+ query_layer = self.transpose_for_scores(mixed_query_layer)
202
+
203
+ past_key_value = (key_layer, value_layer)
204
+
205
+ # Take the dot product between "query" and "key" to get the raw attention scores.
206
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
207
+
208
+ if (
209
+ self.position_embedding_type == "relative_key"
210
+ or self.position_embedding_type == "relative_key_query"
211
+ ):
212
+ seq_length = hidden_states.size()[1]
213
+ position_ids_l = torch.arange(
214
+ seq_length, dtype=torch.long, device=hidden_states.device
215
+ ).view(-1, 1)
216
+ position_ids_r = torch.arange(
217
+ seq_length, dtype=torch.long, device=hidden_states.device
218
+ ).view(1, -1)
219
+ distance = position_ids_l - position_ids_r
220
+ positional_embedding = self.distance_embedding(
221
+ distance + self.max_position_embeddings - 1
222
+ )
223
+ positional_embedding = positional_embedding.to(
224
+ dtype=query_layer.dtype
225
+ ) # fp16 compatibility
226
+
227
+ if self.position_embedding_type == "relative_key":
228
+ relative_position_scores = torch.einsum(
229
+ "bhld,lrd->bhlr", query_layer, positional_embedding
230
+ )
231
+ attention_scores = attention_scores + relative_position_scores
232
+ elif self.position_embedding_type == "relative_key_query":
233
+ relative_position_scores_query = torch.einsum(
234
+ "bhld,lrd->bhlr", query_layer, positional_embedding
235
+ )
236
+ relative_position_scores_key = torch.einsum(
237
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
238
+ )
239
+ attention_scores = (
240
+ attention_scores
241
+ + relative_position_scores_query
242
+ + relative_position_scores_key
243
+ )
244
+
245
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
246
+ if attention_mask is not None:
247
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
248
+ attention_scores = attention_scores + attention_mask
249
+
250
+ # Normalize the attention scores to probabilities.
251
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
252
+
253
+ if is_cross_attention and self.save_attention:
254
+ self.save_attention_map(attention_probs)
255
+ attention_probs.register_hook(self.save_attn_gradients)
256
+
257
+ # This is actually dropping out entire tokens to attend to, which might
258
+ # seem a bit unusual, but is taken from the original Transformer paper.
259
+ attention_probs_dropped = self.dropout(attention_probs)
260
+
261
+ # Mask heads if we want to
262
+ if head_mask is not None:
263
+ attention_probs_dropped = attention_probs_dropped * head_mask
264
+
265
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
266
+
267
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
268
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
269
+ context_layer = context_layer.view(*new_context_layer_shape)
270
+
271
+ outputs = (
272
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
273
+ )
274
+
275
+ outputs = outputs + (past_key_value,)
276
+ return outputs
277
+
278
+
279
+ class DropPath(nn.Module):
280
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
281
+ """
282
+ def __init__(self, drop_prob=None):
283
+ super(DropPath, self).__init__()
284
+ self.drop_prob = drop_prob
285
+
286
+ def forward(self, x):
287
+ return drop_path(x, self.drop_prob, self.training)
288
+
289
+ def extra_repr(self) -> str:
290
+ return 'p={}'.format(self.drop_prob)
291
+
292
+
293
+ class BertSelfOutput(nn.Module):
294
+ def __init__(self, config, drop_path_prob=0.):
295
+ super().__init__()
296
+ self.drop_path = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()
297
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
298
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
299
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
300
+
301
+ def forward(self, hidden_states, input_tensor):
302
+ hidden_states = self.dense(hidden_states)
303
+ hidden_states = self.dropout(hidden_states)
304
+ hidden_states = self.drop_path(hidden_states)
305
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
306
+ return hidden_states
307
+
308
+
309
+ class BertAttention(nn.Module):
310
+ def __init__(self, config, is_cross_attention=False, drop_path_prob=0.,):
311
+ super().__init__()
312
+ self.self = BertSelfAttention(config, is_cross_attention)
313
+ self.output = BertSelfOutput(config, drop_path_prob=drop_path_prob)
314
+ self.pruned_heads = set()
315
+
316
+ def prune_heads(self, heads):
317
+ if len(heads) == 0:
318
+ return
319
+ heads, index = find_pruneable_heads_and_indices(
320
+ heads,
321
+ self.self.num_attention_heads,
322
+ self.self.attention_head_size,
323
+ self.pruned_heads,
324
+ )
325
+
326
+ # Prune linear layers
327
+ self.self.query = prune_linear_layer(self.self.query, index)
328
+ self.self.key = prune_linear_layer(self.self.key, index)
329
+ self.self.value = prune_linear_layer(self.self.value, index)
330
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
331
+
332
+ # Update hyper params and store pruned heads
333
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
334
+ self.self.all_head_size = (
335
+ self.self.attention_head_size * self.self.num_attention_heads
336
+ )
337
+ self.pruned_heads = self.pruned_heads.union(heads)
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states,
342
+ attention_mask=None,
343
+ head_mask=None,
344
+ encoder_hidden_states=None,
345
+ encoder_attention_mask=None,
346
+ past_key_value=None,
347
+ output_attentions=False,
348
+ ):
349
+ self_outputs = self.self(
350
+ hidden_states,
351
+ attention_mask,
352
+ head_mask,
353
+ encoder_hidden_states,
354
+ encoder_attention_mask,
355
+ past_key_value,
356
+ output_attentions,
357
+ )
358
+ attention_output = self.output(self_outputs[0], hidden_states)
359
+
360
+ outputs = (attention_output,) + self_outputs[
361
+ 1:
362
+ ] # add attentions if we output them
363
+ return outputs
364
+
365
+
366
+ class BertIntermediate(nn.Module):
367
+ def __init__(self, config):
368
+ super().__init__()
369
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
370
+ if isinstance(config.hidden_act, str):
371
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
372
+ else:
373
+ self.intermediate_act_fn = config.hidden_act
374
+
375
+ def forward(self, hidden_states):
376
+ hidden_states = self.dense(hidden_states)
377
+ hidden_states = self.intermediate_act_fn(hidden_states)
378
+ return hidden_states
379
+
380
+
381
+ class BertOutput(nn.Module):
382
+ def __init__(self, config, drop_path=0.):
383
+ super().__init__()
384
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
385
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
386
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
387
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
388
+
389
+ def forward(self, hidden_states, input_tensor):
390
+ hidden_states = self.dense(hidden_states)
391
+ hidden_states = self.dropout(hidden_states)
392
+ hidden_states = self.drop_path(hidden_states)
393
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
394
+ return hidden_states
395
+
396
+
397
+ class BertLayer(nn.Module):
398
+ def __init__(self, config, layer_num):
399
+ super().__init__()
400
+ self.config = config
401
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
402
+ self.seq_len_dim = 1
403
+ drop_path_prob = config.drop_path_list[layer_num]
404
+ self.attention = BertAttention(config, drop_path_prob=drop_path_prob)
405
+ self.layer_num = layer_num
406
+ if (
407
+ self.config.add_cross_attention
408
+ and layer_num % self.config.cross_attention_freq == 0
409
+ ):
410
+ self.crossattention = BertAttention(
411
+ config, is_cross_attention=self.config.add_cross_attention,
412
+ drop_path_prob=drop_path_prob
413
+ )
414
+ self.has_cross_attention = True
415
+ else:
416
+ self.has_cross_attention = False
417
+ self.intermediate = BertIntermediate(config)
418
+ self.output = BertOutput(config, drop_path=drop_path_prob)
419
+
420
+ self.intermediate_query = BertIntermediate(config)
421
+ self.output_query = BertOutput(config, drop_path=drop_path_prob)
422
+
423
+ def forward(
424
+ self,
425
+ hidden_states,
426
+ attention_mask=None,
427
+ head_mask=None,
428
+ encoder_hidden_states=None,
429
+ encoder_attention_mask=None,
430
+ past_key_value=None,
431
+ output_attentions=False,
432
+ query_length=0,
433
+ ):
434
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
435
+ self_attn_past_key_value = (
436
+ past_key_value[:2] if past_key_value is not None else None
437
+ )
438
+ self_attention_outputs = self.attention(
439
+ hidden_states,
440
+ attention_mask,
441
+ head_mask,
442
+ output_attentions=output_attentions,
443
+ past_key_value=self_attn_past_key_value,
444
+ )
445
+ attention_output = self_attention_outputs[0]
446
+ outputs = self_attention_outputs[1:-1]
447
+
448
+ present_key_value = self_attention_outputs[-1]
449
+
450
+ if query_length > 0:
451
+ query_attention_output = attention_output[:, :query_length, :]
452
+
453
+ if self.has_cross_attention:
454
+ assert (
455
+ encoder_hidden_states is not None
456
+ ), "encoder_hidden_states must be given for cross-attention layers"
457
+ cross_attention_outputs = self.crossattention(
458
+ query_attention_output,
459
+ attention_mask,
460
+ head_mask,
461
+ encoder_hidden_states,
462
+ encoder_attention_mask,
463
+ output_attentions=output_attentions,
464
+ )
465
+ query_attention_output = cross_attention_outputs[0]
466
+ outputs = (
467
+ outputs + cross_attention_outputs[1:-1]
468
+ ) # add cross attentions if we output attention weights
469
+
470
+ layer_output = apply_chunking_to_forward(
471
+ self.feed_forward_chunk_query,
472
+ self.chunk_size_feed_forward,
473
+ self.seq_len_dim,
474
+ query_attention_output,
475
+ )
476
+ if attention_output.shape[1] > query_length:
477
+ layer_output_text = apply_chunking_to_forward(
478
+ self.feed_forward_chunk,
479
+ self.chunk_size_feed_forward,
480
+ self.seq_len_dim,
481
+ attention_output[:, query_length:, :],
482
+ )
483
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
484
+ else:
485
+ layer_output = apply_chunking_to_forward(
486
+ self.feed_forward_chunk,
487
+ self.chunk_size_feed_forward,
488
+ self.seq_len_dim,
489
+ attention_output,
490
+ )
491
+ outputs = (layer_output,) + outputs
492
+
493
+ outputs = outputs + (present_key_value,)
494
+
495
+ return outputs
496
+
497
+ def feed_forward_chunk(self, attention_output):
498
+ intermediate_output = self.intermediate(attention_output)
499
+ layer_output = self.output(intermediate_output, attention_output)
500
+ return layer_output
501
+
502
+ def feed_forward_chunk_query(self, attention_output):
503
+ intermediate_output = self.intermediate_query(attention_output)
504
+ layer_output = self.output_query(intermediate_output, attention_output)
505
+ return layer_output
506
+
507
+
508
+ class BertEncoder(nn.Module):
509
+ def __init__(self, config):
510
+ super().__init__()
511
+ self.config = config
512
+ self.layer = nn.ModuleList(
513
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
514
+ )
515
+
516
+ def forward(
517
+ self,
518
+ hidden_states,
519
+ attention_mask=None,
520
+ head_mask=None,
521
+ encoder_hidden_states=None,
522
+ encoder_attention_mask=None,
523
+ past_key_values=None,
524
+ use_cache=None,
525
+ output_attentions=False,
526
+ output_hidden_states=False,
527
+ return_dict=True,
528
+ query_length=0,
529
+ ):
530
+ all_hidden_states = () if output_hidden_states else None
531
+ all_self_attentions = () if output_attentions else None
532
+ all_cross_attentions = (
533
+ () if output_attentions and self.config.add_cross_attention else None
534
+ )
535
+
536
+ next_decoder_cache = () if use_cache else None
537
+
538
+ for i in range(self.config.num_hidden_layers):
539
+ layer_module = self.layer[i]
540
+ if output_hidden_states:
541
+ all_hidden_states = all_hidden_states + (hidden_states,)
542
+
543
+ layer_head_mask = head_mask[i] if head_mask is not None else None
544
+ past_key_value = past_key_values[i] if past_key_values is not None else None
545
+
546
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
547
+
548
+ if use_cache:
549
+ logger.warn(
550
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
551
+ )
552
+ use_cache = False
553
+
554
+ def create_custom_forward(module):
555
+ def custom_forward(*inputs):
556
+ return module(
557
+ *inputs, past_key_value, output_attentions, query_length
558
+ )
559
+
560
+ return custom_forward
561
+
562
+ layer_outputs = torch.utils.checkpoint.checkpoint(
563
+ create_custom_forward(layer_module),
564
+ hidden_states,
565
+ attention_mask,
566
+ layer_head_mask,
567
+ encoder_hidden_states,
568
+ encoder_attention_mask,
569
+ )
570
+ else:
571
+ layer_outputs = layer_module(
572
+ hidden_states,
573
+ attention_mask,
574
+ layer_head_mask,
575
+ encoder_hidden_states,
576
+ encoder_attention_mask,
577
+ past_key_value,
578
+ output_attentions,
579
+ query_length,
580
+ )
581
+
582
+ hidden_states = layer_outputs[0]
583
+ if use_cache:
584
+ next_decoder_cache += (layer_outputs[-1],)
585
+ if output_attentions:
586
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
587
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
588
+
589
+ if output_hidden_states:
590
+ all_hidden_states = all_hidden_states + (hidden_states,)
591
+
592
+ if not return_dict:
593
+ return tuple(
594
+ v
595
+ for v in [
596
+ hidden_states,
597
+ next_decoder_cache,
598
+ all_hidden_states,
599
+ all_self_attentions,
600
+ all_cross_attentions,
601
+ ]
602
+ if v is not None
603
+ )
604
+ return BaseModelOutputWithPastAndCrossAttentions(
605
+ last_hidden_state=hidden_states,
606
+ past_key_values=next_decoder_cache,
607
+ hidden_states=all_hidden_states,
608
+ attentions=all_self_attentions,
609
+ cross_attentions=all_cross_attentions,
610
+ )
611
+
612
+
613
+ class BertPooler(nn.Module):
614
+ def __init__(self, config):
615
+ super().__init__()
616
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
617
+ self.activation = nn.Tanh()
618
+
619
+ def forward(self, hidden_states):
620
+ # We "pool" the model by simply taking the hidden state corresponding
621
+ # to the first token.
622
+ first_token_tensor = hidden_states[:, 0]
623
+ pooled_output = self.dense(first_token_tensor)
624
+ pooled_output = self.activation(pooled_output)
625
+ return pooled_output
626
+
627
+
628
+ class BertPredictionHeadTransform(nn.Module):
629
+ def __init__(self, config):
630
+ super().__init__()
631
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
632
+ if isinstance(config.hidden_act, str):
633
+ self.transform_act_fn = ACT2FN[config.hidden_act]
634
+ else:
635
+ self.transform_act_fn = config.hidden_act
636
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.dense(hidden_states)
640
+ hidden_states = self.transform_act_fn(hidden_states)
641
+ hidden_states = self.LayerNorm(hidden_states)
642
+ return hidden_states
643
+
644
+
645
+ class BertLMPredictionHead(nn.Module):
646
+ def __init__(self, config):
647
+ super().__init__()
648
+ self.transform = BertPredictionHeadTransform(config)
649
+
650
+ # The output weights are the same as the input embeddings, but there is
651
+ # an output-only bias for each token.
652
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
653
+
654
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
655
+
656
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
657
+ self.decoder.bias = self.bias
658
+
659
+ def forward(self, hidden_states):
660
+ hidden_states = self.transform(hidden_states)
661
+ hidden_states = self.decoder(hidden_states)
662
+ return hidden_states
663
+
664
+
665
+ class BertOnlyMLMHead(nn.Module):
666
+ def __init__(self, config):
667
+ super().__init__()
668
+ self.predictions = BertLMPredictionHead(config)
669
+
670
+ def forward(self, sequence_output):
671
+ prediction_scores = self.predictions(sequence_output)
672
+ return prediction_scores
673
+
674
+
675
+ class BertPreTrainedModel(PreTrainedModel):
676
+ """
677
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
678
+ models.
679
+ """
680
+
681
+ config_class = BertConfig
682
+ base_model_prefix = "bert"
683
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
684
+
685
+ def _init_weights(self, module):
686
+ """Initialize the weights"""
687
+ if isinstance(module, (nn.Linear, nn.Embedding)):
688
+ # Slightly different from the TF version which uses truncated_normal for initialization
689
+ # cf https://github.com/pytorch/pytorch/pull/5617
690
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
691
+ elif isinstance(module, nn.LayerNorm):
692
+ module.bias.data.zero_()
693
+ module.weight.data.fill_(1.0)
694
+ if isinstance(module, nn.Linear) and module.bias is not None:
695
+ module.bias.data.zero_()
696
+
697
+
698
+ class BertModel(BertPreTrainedModel):
699
+ """
700
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
701
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
702
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
703
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
704
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
705
+ input to the forward pass.
706
+ """
707
+
708
+ def __init__(self, config, add_pooling_layer=False):
709
+ super().__init__(config)
710
+ self.config = config
711
+
712
+ self.embeddings = BertEmbeddings(config)
713
+
714
+ self.encoder = BertEncoder(config)
715
+
716
+ self.pooler = BertPooler(config) if add_pooling_layer else None
717
+
718
+ self.init_weights()
719
+
720
+ def get_input_embeddings(self):
721
+ return self.embeddings.word_embeddings
722
+
723
+ def set_input_embeddings(self, value):
724
+ self.embeddings.word_embeddings = value
725
+
726
+ def _prune_heads(self, heads_to_prune):
727
+ """
728
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
729
+ class PreTrainedModel
730
+ """
731
+ for layer, heads in heads_to_prune.items():
732
+ self.encoder.layer[layer].attention.prune_heads(heads)
733
+
734
+ def get_extended_attention_mask(
735
+ self,
736
+ attention_mask: Tensor,
737
+ input_shape: Tuple[int],
738
+ device: device,
739
+ is_decoder: bool,
740
+ has_query: bool = False,
741
+ ) -> Tensor:
742
+ """
743
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
744
+
745
+ Arguments:
746
+ attention_mask (:obj:`torch.Tensor`):
747
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
748
+ input_shape (:obj:`Tuple[int]`):
749
+ The shape of the input to the model.
750
+ device: (:obj:`torch.device`):
751
+ The device of the input to the model.
752
+
753
+ Returns:
754
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
755
+ """
756
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
757
+ # ourselves in which case we just need to make it broadcastable to all heads.
758
+ if attention_mask.dim() == 3:
759
+ extended_attention_mask = attention_mask[:, None, :, :]
760
+ elif attention_mask.dim() == 2:
761
+ # Provided a padding mask of dimensions [batch_size, seq_length]
762
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
763
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
764
+ if is_decoder:
765
+ batch_size, seq_length = input_shape
766
+
767
+ seq_ids = torch.arange(seq_length, device=device)
768
+ causal_mask = (
769
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
770
+ <= seq_ids[None, :, None]
771
+ )
772
+
773
+ # add a prefix ones mask to the causal mask
774
+ # causal and attention masks must have same type with pytorch version < 1.3
775
+ causal_mask = causal_mask.to(attention_mask.dtype)
776
+
777
+ if causal_mask.shape[1] < attention_mask.shape[1]:
778
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
779
+ if has_query: # UniLM style attention mask
780
+ causal_mask = torch.cat(
781
+ [
782
+ torch.zeros(
783
+ (batch_size, prefix_seq_len, seq_length),
784
+ device=device,
785
+ dtype=causal_mask.dtype,
786
+ ),
787
+ causal_mask,
788
+ ],
789
+ axis=1,
790
+ )
791
+ causal_mask = torch.cat(
792
+ [
793
+ torch.ones(
794
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
795
+ device=device,
796
+ dtype=causal_mask.dtype,
797
+ ),
798
+ causal_mask,
799
+ ],
800
+ axis=-1,
801
+ )
802
+ extended_attention_mask = (
803
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
804
+ )
805
+ else:
806
+ extended_attention_mask = attention_mask[:, None, None, :]
807
+ else:
808
+ raise ValueError(
809
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
810
+ input_shape, attention_mask.shape
811
+ )
812
+ )
813
+
814
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
815
+ # masked positions, this operation will create a tensor which is 0.0 for
816
+ # positions we want to attend and -10000.0 for masked positions.
817
+ # Since we are adding it to the raw scores before the softmax, this is
818
+ # effectively the same as removing these entirely.
819
+ extended_attention_mask = extended_attention_mask.to(
820
+ dtype=self.dtype
821
+ ) # fp16 compatibility
822
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
823
+ return extended_attention_mask
824
+
825
+ def forward(
826
+ self,
827
+ input_ids=None,
828
+ attention_mask=None,
829
+ position_ids=None,
830
+ head_mask=None,
831
+ query_embeds=None,
832
+ encoder_hidden_states=None,
833
+ encoder_attention_mask=None,
834
+ past_key_values=None,
835
+ use_cache=None,
836
+ output_attentions=None,
837
+ output_hidden_states=None,
838
+ return_dict=None,
839
+ is_decoder=False,
840
+ ):
841
+ r"""
842
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
843
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
844
+ the model is configured as a decoder.
845
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
846
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
847
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
848
+ - 1 for tokens that are **not masked**,
849
+ - 0 for tokens that are **masked**.
850
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
851
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
852
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
853
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
854
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
855
+ use_cache (:obj:`bool`, `optional`):
856
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
857
+ decoding (see :obj:`past_key_values`).
858
+ """
859
+ output_attentions = (
860
+ output_attentions
861
+ if output_attentions is not None
862
+ else self.config.output_attentions
863
+ )
864
+ output_hidden_states = (
865
+ output_hidden_states
866
+ if output_hidden_states is not None
867
+ else self.config.output_hidden_states
868
+ )
869
+ return_dict = (
870
+ return_dict if return_dict is not None else self.config.use_return_dict
871
+ )
872
+
873
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
874
+
875
+ if input_ids is None:
876
+ assert (
877
+ query_embeds is not None
878
+ ), "You have to specify query_embeds when input_ids is None"
879
+
880
+ # past_key_values_length
881
+ past_key_values_length = (
882
+ past_key_values[0][0].shape[2] - self.config.query_length
883
+ if past_key_values is not None
884
+ else 0
885
+ )
886
+
887
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
888
+
889
+ embedding_output = self.embeddings(
890
+ input_ids=input_ids,
891
+ position_ids=position_ids,
892
+ query_embeds=query_embeds,
893
+ past_key_values_length=past_key_values_length,
894
+ )
895
+
896
+ input_shape = embedding_output.size()[:-1]
897
+ batch_size, seq_length = input_shape
898
+ device = embedding_output.device
899
+
900
+ if attention_mask is None:
901
+ attention_mask = torch.ones(
902
+ ((batch_size, seq_length + past_key_values_length)), device=device
903
+ )
904
+
905
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
906
+ # ourselves in which case we just need to make it broadcastable to all heads.
907
+ if is_decoder:
908
+ extended_attention_mask = self.get_extended_attention_mask(
909
+ attention_mask,
910
+ input_ids.shape,
911
+ device,
912
+ is_decoder,
913
+ has_query=(query_embeds is not None),
914
+ )
915
+ else:
916
+ extended_attention_mask = self.get_extended_attention_mask(
917
+ attention_mask, input_shape, device, is_decoder
918
+ )
919
+
920
+ # If a 2D or 3D attention mask is provided for the cross-attention
921
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
922
+ if encoder_hidden_states is not None:
923
+ if type(encoder_hidden_states) == list:
924
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
925
+ 0
926
+ ].size()
927
+ else:
928
+ (
929
+ encoder_batch_size,
930
+ encoder_sequence_length,
931
+ _,
932
+ ) = encoder_hidden_states.size()
933
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
934
+
935
+ if type(encoder_attention_mask) == list:
936
+ encoder_extended_attention_mask = [
937
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
938
+ ]
939
+ elif encoder_attention_mask is None:
940
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
941
+ encoder_extended_attention_mask = self.invert_attention_mask(
942
+ encoder_attention_mask
943
+ )
944
+ else:
945
+ encoder_extended_attention_mask = self.invert_attention_mask(
946
+ encoder_attention_mask
947
+ )
948
+ else:
949
+ encoder_extended_attention_mask = None
950
+
951
+ # Prepare head mask if needed
952
+ # 1.0 in head_mask indicate we keep the head
953
+ # attention_probs has shape bsz x n_heads x N x N
954
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
955
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
956
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
957
+
958
+ encoder_outputs = self.encoder(
959
+ embedding_output,
960
+ attention_mask=extended_attention_mask,
961
+ head_mask=head_mask,
962
+ encoder_hidden_states=encoder_hidden_states,
963
+ encoder_attention_mask=encoder_extended_attention_mask,
964
+ past_key_values=past_key_values,
965
+ use_cache=use_cache,
966
+ output_attentions=output_attentions,
967
+ output_hidden_states=output_hidden_states,
968
+ return_dict=return_dict,
969
+ query_length=query_length,
970
+ )
971
+ sequence_output = encoder_outputs[0]
972
+ pooled_output = (
973
+ self.pooler(sequence_output) if self.pooler is not None else None
974
+ )
975
+
976
+ if not return_dict:
977
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
978
+
979
+ return BaseModelOutputWithPoolingAndCrossAttentions(
980
+ last_hidden_state=sequence_output,
981
+ pooler_output=pooled_output,
982
+ past_key_values=encoder_outputs.past_key_values,
983
+ hidden_states=encoder_outputs.hidden_states,
984
+ attentions=encoder_outputs.attentions,
985
+ cross_attentions=encoder_outputs.cross_attentions,
986
+ )
987
+
988
+
989
+ class BertLMHeadModel(BertPreTrainedModel):
990
+
991
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
992
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
993
+
994
+ def __init__(self, config):
995
+ super().__init__(config)
996
+
997
+ self.bert = BertModel(config, add_pooling_layer=False)
998
+ self.cls = BertOnlyMLMHead(config)
999
+
1000
+ self.init_weights()
1001
+
1002
+ def get_output_embeddings(self):
1003
+ return self.cls.predictions.decoder
1004
+
1005
+ def set_output_embeddings(self, new_embeddings):
1006
+ self.cls.predictions.decoder = new_embeddings
1007
+
1008
+ def forward(
1009
+ self,
1010
+ input_ids=None,
1011
+ attention_mask=None,
1012
+ position_ids=None,
1013
+ head_mask=None,
1014
+ query_embeds=None,
1015
+ encoder_hidden_states=None,
1016
+ encoder_attention_mask=None,
1017
+ labels=None,
1018
+ past_key_values=None,
1019
+ use_cache=True,
1020
+ output_attentions=None,
1021
+ output_hidden_states=None,
1022
+ return_dict=None,
1023
+ return_logits=False,
1024
+ is_decoder=True,
1025
+ reduction="mean",
1026
+ ):
1027
+ r"""
1028
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1029
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1030
+ the model is configured as a decoder.
1031
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1032
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1033
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1034
+ - 1 for tokens that are **not masked**,
1035
+ - 0 for tokens that are **masked**.
1036
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1037
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1038
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1039
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
1040
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1041
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1042
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1043
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1044
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1045
+ use_cache (:obj:`bool`, `optional`):
1046
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1047
+ decoding (see :obj:`past_key_values`).
1048
+ Returns:
1049
+ Example::
1050
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1051
+ >>> import torch
1052
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1053
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1054
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1055
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1056
+ >>> outputs = model(**inputs)
1057
+ >>> prediction_logits = outputs.logits
1058
+ """
1059
+ return_dict = (
1060
+ return_dict if return_dict is not None else self.config.use_return_dict
1061
+ )
1062
+ if labels is not None:
1063
+ use_cache = False
1064
+ if past_key_values is not None:
1065
+ query_embeds = None
1066
+
1067
+ outputs = self.bert(
1068
+ input_ids,
1069
+ attention_mask=attention_mask,
1070
+ position_ids=position_ids,
1071
+ head_mask=head_mask,
1072
+ query_embeds=query_embeds,
1073
+ encoder_hidden_states=encoder_hidden_states,
1074
+ encoder_attention_mask=encoder_attention_mask,
1075
+ past_key_values=past_key_values,
1076
+ use_cache=use_cache,
1077
+ output_attentions=output_attentions,
1078
+ output_hidden_states=output_hidden_states,
1079
+ return_dict=return_dict,
1080
+ is_decoder=is_decoder,
1081
+ )
1082
+
1083
+ sequence_output = outputs[0]
1084
+ if query_embeds is not None:
1085
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1086
+
1087
+ prediction_scores = self.cls(sequence_output)
1088
+
1089
+ if return_logits:
1090
+ return prediction_scores[:, :-1, :].contiguous()
1091
+
1092
+ lm_loss = None
1093
+ if labels is not None:
1094
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1095
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1096
+ labels = labels[:, 1:].contiguous()
1097
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1098
+ lm_loss = loss_fct(
1099
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1100
+ labels.view(-1),
1101
+ )
1102
+ if reduction == "none":
1103
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1104
+
1105
+ if not return_dict:
1106
+ output = (prediction_scores,) + outputs[2:]
1107
+ return ((lm_loss,) + output) if lm_loss is not None else output
1108
+
1109
+ return CausalLMOutputWithCrossAttentions(
1110
+ loss=lm_loss,
1111
+ logits=prediction_scores,
1112
+ past_key_values=outputs.past_key_values,
1113
+ hidden_states=outputs.hidden_states,
1114
+ attentions=outputs.attentions,
1115
+ cross_attentions=outputs.cross_attentions,
1116
+ )
1117
+
1118
+ def prepare_inputs_for_generation(
1119
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1120
+ ):
1121
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1122
+ if attention_mask is None:
1123
+ attention_mask = input_ids.new_ones(input_ids.shape)
1124
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1125
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1126
+
1127
+ # cut decoder_input_ids if past is used
1128
+ if past is not None:
1129
+ input_ids = input_ids[:, -1:]
1130
+
1131
+ return {
1132
+ "input_ids": input_ids,
1133
+ "query_embeds": query_embeds,
1134
+ "attention_mask": attention_mask,
1135
+ "past_key_values": past,
1136
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1137
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1138
+ "is_decoder": True,
1139
+ }
1140
+
1141
+ def _reorder_cache(self, past, beam_idx):
1142
+ reordered_past = ()
1143
+ for layer_past in past:
1144
+ reordered_past += (
1145
+ tuple(
1146
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1147
+ ),
1148
+ )
1149
+ return reordered_past
1150
+
1151
+
1152
+ class BertForMaskedLM(BertPreTrainedModel):
1153
+
1154
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1155
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1156
+
1157
+ def __init__(self, config):
1158
+ super().__init__(config)
1159
+
1160
+ self.bert = BertModel(config, add_pooling_layer=False)
1161
+ self.cls = BertOnlyMLMHead(config)
1162
+
1163
+ self.init_weights()
1164
+
1165
+ def get_output_embeddings(self):
1166
+ return self.cls.predictions.decoder
1167
+
1168
+ def set_output_embeddings(self, new_embeddings):
1169
+ self.cls.predictions.decoder = new_embeddings
1170
+
1171
+ def forward(
1172
+ self,
1173
+ input_ids=None,
1174
+ attention_mask=None,
1175
+ position_ids=None,
1176
+ head_mask=None,
1177
+ query_embeds=None,
1178
+ encoder_hidden_states=None,
1179
+ encoder_attention_mask=None,
1180
+ labels=None,
1181
+ output_attentions=None,
1182
+ output_hidden_states=None,
1183
+ return_dict=None,
1184
+ return_logits=False,
1185
+ is_decoder=False,
1186
+ ):
1187
+ r"""
1188
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1189
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1190
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1191
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1192
+ """
1193
+
1194
+ return_dict = (
1195
+ return_dict if return_dict is not None else self.config.use_return_dict
1196
+ )
1197
+
1198
+ outputs = self.bert(
1199
+ input_ids,
1200
+ attention_mask=attention_mask,
1201
+ position_ids=position_ids,
1202
+ head_mask=head_mask,
1203
+ query_embeds=query_embeds,
1204
+ encoder_hidden_states=encoder_hidden_states,
1205
+ encoder_attention_mask=encoder_attention_mask,
1206
+ output_attentions=output_attentions,
1207
+ output_hidden_states=output_hidden_states,
1208
+ return_dict=return_dict,
1209
+ is_decoder=is_decoder,
1210
+ )
1211
+
1212
+ if query_embeds is not None:
1213
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1214
+ prediction_scores = self.cls(sequence_output)
1215
+
1216
+ if return_logits:
1217
+ return prediction_scores
1218
+
1219
+ masked_lm_loss = None
1220
+ if labels is not None:
1221
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1222
+ masked_lm_loss = loss_fct(
1223
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1224
+ )
1225
+
1226
+ if not return_dict:
1227
+ output = (prediction_scores,) + outputs[2:]
1228
+ return (
1229
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1230
+ )
1231
+
1232
+ return MaskedLMOutput(
1233
+ loss=masked_lm_loss,
1234
+ logits=prediction_scores,
1235
+ hidden_states=outputs.hidden_states,
1236
+ attentions=outputs.attentions,
1237
+ )
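
One detail worth calling out from `get_extended_attention_mask` in the Qformer.py code above: the 1/0 padding mask is converted into an additive mask (0 where tokens are visible, -10000 where they are masked) that is simply added to the raw attention scores before the softmax. The toy example below illustrates that conversion; the shapes and values are illustrative, not the repository's code.

# Toy illustration of the additive attention mask used in Qformer.py above.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.float)  # 1 = attend, 0 = padding
extended = attention_mask[:, None, None, :]                          # (batch, 1, 1, seq): broadcast over heads and queries
extended = (1.0 - extended) * -10000.0                               # 0 for visible tokens, -10000 for masked tokens

scores = torch.zeros(1, 1, 5, 5)                                     # pretend raw attention scores
probs = torch.softmax(scores + extended, dim=-1)
print(probs[0, 0, 0])                                                # ~[0.33, 0.33, 0.33, ~0, ~0]: padded positions get no attention
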
models_viclip/backbones/blip_toremove/__init__.py ADDED
File without changes
models_viclip/backbones/blip_toremove/builder.py ADDED
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+import logging
+
+
+from .Qformer import BertConfig, BertLMHeadModel
+from .modeling_t5 import T5Config, T5ForConditionalGeneration
+from models.utils import load_temp_embed_with_mismatch
+
+logger = logging.getLogger(__name__)
+
+
+def build_qformer(num_query_token, vision_width,
+                  qformer_hidden_dropout_prob=0.1,
+                  qformer_attention_probs_dropout_prob=0.1,
+                  drop_path_rate=0.,
+                  ):
+    encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+    encoder_config.encoder_width = vision_width
+    # insert cross-attention layer every other block
+    encoder_config.add_cross_attention = True
+    encoder_config.cross_attention_freq = 2
+    encoder_config.query_length = num_query_token
+    encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
+    encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
+    encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_config.num_hidden_layers)]
+    logger.info(f"Drop_path:{encoder_config.drop_path_list}")
+    logger.info(encoder_config)
+    Qformer = BertLMHeadModel.from_pretrained(
+        "bert-base-uncased", config=encoder_config
+    )
+    query_tokens = nn.Parameter(
+        torch.zeros(1, num_query_token, encoder_config.hidden_size)
+    )
+    query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+    return Qformer, query_tokens
+
+def interpolate_pos_embed_blip(state_dict, new_model):
+    if "vision_temp_embed" in state_dict:
+        vision_temp_embed_new = new_model.state_dict()["vision_temp_embed"]
+        state_dict["vision_temp_embed"] = load_temp_embed_with_mismatch(
+            state_dict["vision_temp_embed"], vision_temp_embed_new, add_zero=False
+        )
+    return state_dict
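
`build_qformer` above wires up the BLIP-2-style pattern: a small set of learned query tokens (zero-initialized, then `normal_` with the config's `initializer_range`) that cross-attend to visual features every other BERT layer (`cross_attention_freq = 2`), with the key/value projections sized to `vision_width`. In this repo the call would look roughly like `qformer, query_tokens = build_qformer(num_query_token, vision_width)` followed by `qformer.bert(query_embeds=..., encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts)`, but that requires downloading the `bert-base-uncased` checkpoint and importing the repository's modules. The sketch below is a simplified, self-contained stand-in for the same query-token mechanism using plain PyTorch; all names and sizes are illustrative, not the repository's implementation.

# Simplified stand-in (not the repo's Q-Former): learned query tokens cross-attending
# to visual features, the pattern build_qformer() sets up above.
import torch
import torch.nn as nn

num_query_token, hidden_size, vision_width = 32, 768, 1408   # illustrative sizes
batch = 2

query_tokens = nn.Parameter(torch.zeros(1, num_query_token, hidden_size))
query_tokens.data.normal_(mean=0.0, std=0.02)                # same init scheme as build_qformer

proj = nn.Linear(vision_width, hidden_size)                  # stands in for the cross-attention K/V projection
cross_attn = nn.MultiheadAttention(hidden_size, num_heads=12, batch_first=True)

image_embeds = torch.randn(batch, 196, vision_width)         # e.g. ViT patch features
queries = query_tokens.expand(batch, -1, -1)                 # share the learned queries across the batch
out, _ = cross_attn(queries, proj(image_embeds), proj(image_embeds))
print(out.shape)                                             # torch.Size([2, 32, 768])
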
models_viclip/backbones/blip_toremove/modeling_t5.py ADDED
@@ -0,0 +1,2063 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch T5 model."""
16
+
17
+
18
+ import copy
19
+ import math
20
+ import os
21
+ import warnings
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss
27
+ from torch.utils.checkpoint import checkpoint
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutput,
32
+ BaseModelOutputWithPastAndCrossAttentions,
33
+ Seq2SeqLMOutput,
34
+ Seq2SeqModelOutput,
35
+ )
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.pytorch_utils import (
38
+ ALL_LAYERNORM_LAYERS,
39
+ find_pruneable_heads_and_indices,
40
+ prune_linear_layer,
41
+ )
42
+ from transformers.utils import (
43
+ DUMMY_INPUTS,
44
+ DUMMY_MASK,
45
+ add_start_docstrings,
46
+ add_start_docstrings_to_model_forward,
47
+ is_torch_fx_proxy,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
52
+ from transformers.models.t5.configuration_t5 import T5Config
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ _CONFIG_FOR_DOC = "T5Config"
58
+ _TOKENIZER_FOR_DOC = "T5Tokenizer"
59
+ _CHECKPOINT_FOR_DOC = "t5-small"
60
+
61
+ ####################################################
62
+ # This dict contains ids and associated url
63
+ # for the pretrained weights provided with the models
64
+ ####################################################
65
+ T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
66
+ "t5-small",
67
+ "t5-base",
68
+ "t5-large",
69
+ "t5-3b",
70
+ "t5-11b",
71
+ # See all T5 models at https://huggingface.co/models?filter=t5
72
+ ]
73
+
74
+
75
+ ####################################################
76
+ # This is a conversion method from TF 1.0 to PyTorch
77
+ # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
78
+ ####################################################
79
+ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
80
+ """Load tf checkpoints in a pytorch model."""
81
+ try:
82
+ import re
83
+
84
+ import numpy as np
85
+ import tensorflow as tf
86
+ except ImportError:
87
+ logger.error(
88
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
89
+ "https://www.tensorflow.org/install/ for installation instructions."
90
+ )
91
+ raise
92
+ tf_path = os.path.abspath(tf_checkpoint_path)
93
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
94
+ # Load weights from TF model
95
+ init_vars = tf.train.list_variables(tf_path)
96
+ names = []
97
+ tf_weights = {}
98
+ for name, shape in init_vars:
99
+ logger.info(f"Loading TF weight {name} with shape {shape}")
100
+ array = tf.train.load_variable(tf_path, name)
101
+ names.append(name)
102
+ tf_weights[name] = array
103
+
104
+ for txt_name in names:
105
+ name = txt_name.split("/")
106
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
107
+ # which are not required for using pretrained model
108
+ if any(
109
+ n
110
+ in [
111
+ "adam_v",
112
+ "adam_m",
113
+ "AdamWeightDecayOptimizer",
114
+ "AdamWeightDecayOptimizer_1",
115
+ "global_step",
116
+ ]
117
+ for n in name
118
+ ):
119
+ logger.info(f"Skipping {'/'.join(name)}")
120
+ tf_weights.pop(txt_name, None)
121
+ continue
122
+ if "_slot_" in name[-1]:
123
+ logger.info(f"Skipping {'/'.join(name)}")
124
+ tf_weights.pop(txt_name, None)
125
+ continue
126
+ pointer = model
127
+ array = tf_weights[txt_name]
128
+
129
+ for m_name in name:
130
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
131
+ scope_names = re.split(r"_(\d+)", m_name)
132
+ else:
133
+ scope_names = [m_name]
134
+ if scope_names[0] in ["kernel", "scale", "embedding"]:
135
+ pointer = getattr(pointer, "weight")
136
+ elif scope_names[0] == "self_attention":
137
+ pointer = getattr(pointer, "layer")
138
+ pointer = pointer[0]
139
+ elif scope_names[0] == "enc_dec_attention":
140
+ pointer = getattr(pointer, "layer")
141
+ pointer = pointer[1]
142
+ elif scope_names[0] == "dense_relu_dense":
143
+ pointer = getattr(pointer, "layer")
144
+ pointer = pointer[2]
145
+ elif scope_names[0] == "rms_norm":
146
+ if hasattr(pointer, "layer_norm"):
147
+ pointer = getattr(pointer, "layer_norm")
148
+ elif hasattr(pointer, "final_layer_norm"):
149
+ pointer = getattr(pointer, "final_layer_norm")
150
+ elif scope_names[0] == "scale":
151
+ pointer = getattr(pointer, "weight")
152
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
153
+ pointer = getattr(pointer, "bias")
154
+ elif scope_names[0] == "squad":
155
+ pointer = getattr(pointer, "classifier")
156
+ elif scope_names[0] == "decoder" and name[1] == "logits":
157
+ continue
158
+ elif scope_names[0] == "logits":
159
+ pointer = getattr(pointer, "lm_head")
160
+ elif (
161
+ scope_names[0] == "wi"
162
+ and len(scope_names) > 1
163
+ and scope_names[1].isdigit()
164
+ ):
165
+ pointer = getattr(pointer, f"wi_{scope_names[1]}")
166
+ continue
167
+ else:
168
+ try:
169
+ pointer = getattr(pointer, scope_names[0])
170
+ except AttributeError:
171
+ logger.info(f"Skipping {'/'.join(name)}")
172
+ continue
173
+ if len(scope_names) >= 2:
174
+ num = int(scope_names[1])
175
+ pointer = pointer[num]
176
+ if scope_names[0] not in ["kernel", "scale", "embedding"]:
177
+ pointer = getattr(pointer, "weight")
178
+ if scope_names[0] != "embedding":
179
+ logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
180
+ array = np.transpose(array)
181
+ try:
182
+ assert (
183
+ pointer.shape == array.shape
184
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
185
+ except AssertionError as e:
186
+ e.args += (pointer.shape, array.shape)
187
+ raise
188
+ logger.info(f"Initialize PyTorch weight {name}")
189
+ pointer.data = torch.from_numpy(array.astype(np.float32))
190
+ tf_weights.pop(txt_name, None)
191
+
192
+ logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
193
+ return model
194
+
195
+
196
+ ####################################################
197
+ # PyTorch Models are constructed by sub-classing
198
+ # - torch.nn.Module for the layers and
199
+ # - PreTrainedModel for the models (it-self a sub-class of nn.Module)
200
+ ####################################################
201
+ PARALLELIZE_DOCSTRING = r"""
202
+ This is an experimental feature and is subject to change at a moment's notice.
203
+
204
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
205
+ it will evenly distribute blocks across all devices.
206
+
207
+ Args:
208
+ device_map (`Dict[int, list]`, optional, defaults to None):
209
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
210
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
211
+ have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
212
+ following number of attention modules:
213
+
214
+ - t5-small: 6
215
+ - t5-base: 12
216
+ - t5-large: 24
217
+ - t5-3b: 24
218
+ - t5-11b: 24
219
+
220
+ Example:
221
+
222
+ ```python
223
+ # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
224
+ model = T5ForConditionalGeneration.from_pretrained("t5-3b")
225
+ device_map = {
226
+ 0: [0, 1, 2],
227
+ 1: [3, 4, 5, 6, 7, 8, 9],
228
+ 2: [10, 11, 12, 13, 14, 15, 16],
229
+ 3: [17, 18, 19, 20, 21, 22, 23],
230
+ }
231
+ model.parallelize(device_map)
232
+ ```
233
+ """
234
+ DEPARALLELIZE_DOCSTRING = r"""
235
+ Moves the model to cpu from a model parallel state.
236
+
237
+ Example:
238
+
239
+ ```python
240
+ # On a 4 GPU machine with t5-3b:
241
+ model = T5ForConditionalGeneration.from_pretrained("t5-3b")
242
+ device_map = {
243
+ 0: [0, 1, 2],
244
+ 1: [3, 4, 5, 6, 7, 8, 9],
245
+ 2: [10, 11, 12, 13, 14, 15, 16],
246
+ 3: [17, 18, 19, 20, 21, 22, 23],
247
+ }
248
+ model.parallelize(device_map) # Splits the model across several devices
249
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
250
+ ```
251
+ """
252
+
253
+
254
+ class T5LayerNorm(nn.Module):
255
+ def __init__(self, hidden_size, eps=1e-6):
256
+ """
257
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
258
+ """
259
+ super().__init__()
260
+ self.weight = nn.Parameter(torch.ones(hidden_size))
261
+ self.variance_epsilon = eps
262
+
263
+ def forward(self, hidden_states):
264
+
265
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
266
+ # Square Layer Normalization (https://arxiv.org/abs/1910.07467); thus the variance is calculated
267
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
268
+ # half-precision inputs is done in fp32
269
+
270
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
271
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
272
+
273
+ # convert into half-precision if necessary
274
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
275
+ hidden_states = hidden_states.to(self.weight.dtype)
276
+
277
+ return self.weight * hidden_states
278
+
279
+
280
+ try:
281
+ from apex.normalization import FusedRMSNorm
282
+
283
+ T5LayerNorm = FusedRMSNorm # noqa
284
+
285
+ logger.info(
286
+ "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm"
287
+ )
288
+ except ImportError:
289
+ # using the normal T5LayerNorm
290
+ pass
291
+ except Exception:
292
+ logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
293
+ pass
294
+
295
+ ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
296
+
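+ # --- Editor's note: a minimal illustrative sketch (not part of the original file).
+ # It checks the RMS-style normalization described above: features are rescaled by their
+ # root mean square (no mean subtraction, no bias), then a learned gain is applied.
+ # The helper name `_demo_t5_rms_norm` is hypothetical and exists only for illustration.
+ def _demo_t5_rms_norm():
+     x = torch.randn(2, 4, 8)
+     norm = T5LayerNorm(8, eps=1e-6)  # may resolve to apex FusedRMSNorm if available
+     manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
+     # Both paths should agree up to floating-point tolerance.
+     return torch.allclose(norm(x), norm.weight * manual, atol=1e-5)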
297
+
298
+ class T5DenseActDense(nn.Module):
299
+ def __init__(self, config: T5Config):
300
+ super().__init__()
301
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
302
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
303
+ self.dropout = nn.Dropout(config.dropout_rate)
304
+ self.act = ACT2FN[config.dense_act_fn]
305
+
306
+ def forward(self, hidden_states):
307
+ hidden_states = self.wi(hidden_states)
308
+ hidden_states = self.act(hidden_states)
309
+ hidden_states = self.dropout(hidden_states)
310
+ hidden_states = self.wo(hidden_states)
311
+ return hidden_states
312
+
313
+
314
+ class T5DenseGatedActDense(nn.Module):
315
+ def __init__(self, config: T5Config):
316
+ super().__init__()
317
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
318
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
319
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
320
+ self.dropout = nn.Dropout(config.dropout_rate)
321
+ self.act = ACT2FN[config.dense_act_fn]
322
+
323
+ def forward(self, hidden_states):
324
+ hidden_gelu = self.act(self.wi_0(hidden_states))
325
+ hidden_linear = self.wi_1(hidden_states)
326
+ hidden_states = hidden_gelu * hidden_linear
327
+ hidden_states = self.dropout(hidden_states)
328
+ hidden_states = self.wo(hidden_states)
329
+ return hidden_states
330
+
331
+
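+ # --- Editor's note: a minimal illustrative sketch (not part of the original file).
+ # The gated variant above computes act(wi_0(x)) * wi_1(x) before projecting back with wo,
+ # i.e. the GEGLU-style feed-forward that T5LayerFF selects when `config.is_gated_act` is
+ # True (e.g. `feed_forward_proj="gated-gelu"`). The helper name and tiny config values
+ # are hypothetical, for illustration only.
+ def _demo_gated_feed_forward():
+     cfg = T5Config(d_model=16, d_ff=32, dropout_rate=0.0, feed_forward_proj="gated-gelu")
+     ff = T5DenseGatedActDense(cfg)
+     out = ff(torch.randn(2, 5, 16))
+     return out.shape  # torch.Size([2, 5, 16]) -- the model dimension is preserved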
332
+ class T5LayerFF(nn.Module):
333
+ def __init__(self, config: T5Config):
334
+ super().__init__()
335
+ if config.is_gated_act:
336
+ self.DenseReluDense = T5DenseGatedActDense(config)
337
+ else:
338
+ self.DenseReluDense = T5DenseActDense(config)
339
+
340
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
341
+ self.dropout = nn.Dropout(config.dropout_rate)
342
+
343
+ def forward(self, hidden_states):
344
+ forwarded_states = self.layer_norm(hidden_states)
345
+ forwarded_states = self.DenseReluDense(forwarded_states)
346
+ hidden_states = hidden_states + self.dropout(forwarded_states)
347
+ return hidden_states
348
+
349
+
350
+ class T5Attention(nn.Module):
351
+ def __init__(self, config: T5Config, has_relative_attention_bias=False):
352
+ super().__init__()
353
+ self.is_decoder = config.is_decoder
354
+ self.has_relative_attention_bias = has_relative_attention_bias
355
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
356
+ self.relative_attention_max_distance = config.relative_attention_max_distance
357
+ self.d_model = config.d_model
358
+ self.key_value_proj_dim = config.d_kv
359
+ self.n_heads = config.num_heads
360
+ self.dropout = config.dropout_rate
361
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
362
+
363
+ # Mesh TensorFlow initialization to avoid scaling before softmax
364
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
365
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
366
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
367
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
368
+
369
+ if self.has_relative_attention_bias:
370
+ self.relative_attention_bias = nn.Embedding(
371
+ self.relative_attention_num_buckets, self.n_heads
372
+ )
373
+ self.pruned_heads = set()
374
+ self.gradient_checkpointing = False
375
+
376
+ def prune_heads(self, heads):
377
+ if len(heads) == 0:
378
+ return
379
+ heads, index = find_pruneable_heads_and_indices(
380
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
381
+ )
382
+ # Prune linear layers
383
+ self.q = prune_linear_layer(self.q, index)
384
+ self.k = prune_linear_layer(self.k, index)
385
+ self.v = prune_linear_layer(self.v, index)
386
+ self.o = prune_linear_layer(self.o, index, dim=1)
387
+ # Update hyper params
388
+ self.n_heads = self.n_heads - len(heads)
389
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
390
+ self.pruned_heads = self.pruned_heads.union(heads)
391
+
392
+ @staticmethod
393
+ def _relative_position_bucket(
394
+ relative_position, bidirectional=True, num_buckets=32, max_distance=128
395
+ ):
396
+ """
397
+ Adapted from Mesh Tensorflow:
398
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
399
+
400
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
401
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
402
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
403
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
404
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
405
+ This should allow for more graceful generalization to longer sequences than the model has been trained on.
406
+
407
+ Args:
408
+ relative_position: an int32 Tensor
409
+ bidirectional: a boolean - whether the attention is bidirectional
410
+ num_buckets: an integer
411
+ max_distance: an integer
412
+
413
+ Returns:
414
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
415
+ """
416
+ relative_buckets = 0
417
+ if bidirectional:
418
+ num_buckets //= 2
419
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
420
+ relative_position = torch.abs(relative_position)
421
+ else:
422
+ relative_position = -torch.min(
423
+ relative_position, torch.zeros_like(relative_position)
424
+ )
425
+ # now relative_position is in the range [0, inf)
426
+
427
+ # half of the buckets are for exact increments in positions
428
+ max_exact = num_buckets // 2
429
+ is_small = relative_position < max_exact
430
+
431
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
432
+ relative_position_if_large = max_exact + (
433
+ torch.log(relative_position.float() / max_exact)
434
+ / math.log(max_distance / max_exact)
435
+ * (num_buckets - max_exact)
436
+ ).to(torch.long)
437
+ relative_position_if_large = torch.min(
438
+ relative_position_if_large,
439
+ torch.full_like(relative_position_if_large, num_buckets - 1),
440
+ )
441
+
442
+ relative_buckets += torch.where(
443
+ is_small, relative_position, relative_position_if_large
444
+ )
445
+ return relative_buckets
446
+
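+ # --- Editor's note (illustrative, not part of the original file): a small worked example
+ # of the bucketing above with the T5 defaults (num_buckets=32, max_distance=128,
+ # bidirectional=True). Offsets -3..3 map to buckets 3, 2, 1, 0, 17, 18, 19: negative
+ # offsets use the first half of the buckets, positive offsets are shifted by
+ # num_buckets // 2 = 16, and offsets beyond max_distance saturate into the last bucket
+ # of each half.
+ # >>> T5Attention._relative_position_bucket(torch.arange(-3, 4))
+ # tensor([ 3,  2,  1,  0, 17, 18, 19])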
447
+ def compute_bias(self, query_length, key_length, device=None):
448
+ """Compute binned relative position bias"""
449
+ if device is None:
450
+ device = self.relative_attention_bias.weight.device
451
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[
452
+ :, None
453
+ ]
454
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
455
+ None, :
456
+ ]
457
+ relative_position = (
458
+ memory_position - context_position
459
+ ) # shape (query_length, key_length)
460
+ relative_position_bucket = self._relative_position_bucket(
461
+ relative_position, # shape (query_length, key_length)
462
+ bidirectional=(not self.is_decoder),
463
+ num_buckets=self.relative_attention_num_buckets,
464
+ max_distance=self.relative_attention_max_distance,
465
+ )
466
+ values = self.relative_attention_bias(
467
+ relative_position_bucket
468
+ ) # shape (query_length, key_length, num_heads)
469
+ values = values.permute([2, 0, 1]).unsqueeze(
470
+ 0
471
+ ) # shape (1, num_heads, query_length, key_length)
472
+ return values
473
+
474
+ def forward(
475
+ self,
476
+ hidden_states,
477
+ mask=None,
478
+ key_value_states=None,
479
+ position_bias=None,
480
+ past_key_value=None,
481
+ layer_head_mask=None,
482
+ query_length=None,
483
+ use_cache=False,
484
+ output_attentions=False,
485
+ ):
486
+ """
487
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
488
+ """
489
+ # Input is (batch_size, seq_length, dim)
490
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
491
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
492
+ batch_size, seq_length = hidden_states.shape[:2]
493
+
494
+ real_seq_length = seq_length
495
+
496
+ if past_key_value is not None:
497
+ assert (
498
+ len(past_key_value) == 2
499
+ ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
500
+ real_seq_length += (
501
+ past_key_value[0].shape[2] if query_length is None else query_length
502
+ )
503
+
504
+ key_length = (
505
+ real_seq_length if key_value_states is None else key_value_states.shape[1]
506
+ )
507
+
508
+ def shape(states):
509
+ """projection"""
510
+ return states.view(
511
+ batch_size, -1, self.n_heads, self.key_value_proj_dim
512
+ ).transpose(1, 2)
513
+
514
+ def unshape(states):
515
+ """reshape"""
516
+ return (
517
+ states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
518
+ )
519
+
520
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
521
+ """projects hidden states correctly to key/query states"""
522
+ if key_value_states is None:
523
+ # self-attn
524
+ # (batch_size, n_heads, seq_length, dim_per_head)
525
+ hidden_states = shape(proj_layer(hidden_states))
526
+ elif past_key_value is None:
527
+ # cross-attn
528
+ # (batch_size, n_heads, seq_length, dim_per_head)
529
+ hidden_states = shape(proj_layer(key_value_states))
530
+
531
+ if past_key_value is not None:
532
+ if key_value_states is None:
533
+ # self-attn
534
+ # (batch_size, n_heads, key_length, dim_per_head)
535
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
536
+ else:
537
+ # cross-attn
538
+ hidden_states = past_key_value
539
+ return hidden_states
540
+
541
+ # get query states
542
+ query_states = shape(
543
+ self.q(hidden_states)
544
+ ) # (batch_size, n_heads, seq_length, dim_per_head)
545
+
546
+ # get key/value states
547
+ key_states = project(
548
+ hidden_states,
549
+ self.k,
550
+ key_value_states,
551
+ past_key_value[0] if past_key_value is not None else None,
552
+ )
553
+ value_states = project(
554
+ hidden_states,
555
+ self.v,
556
+ key_value_states,
557
+ past_key_value[1] if past_key_value is not None else None,
558
+ )
559
+
560
+ # compute scores
561
+ scores = torch.matmul(
562
+ query_states, key_states.transpose(3, 2)
563
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
564
+
565
+ if position_bias is None:
566
+ if not self.has_relative_attention_bias:
567
+ position_bias = torch.zeros(
568
+ (1, self.n_heads, real_seq_length, key_length),
569
+ device=scores.device,
570
+ dtype=scores.dtype,
571
+ )
572
+ if self.gradient_checkpointing and self.training:
573
+ position_bias.requires_grad = True
574
+ else:
575
+ position_bias = self.compute_bias(
576
+ real_seq_length, key_length, device=scores.device
577
+ )
578
+
579
+ # if keys and values are already calculated
580
+ # we want only the last query position bias
581
+ if past_key_value is not None:
582
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
583
+
584
+ if mask is not None:
585
+ position_bias = (
586
+ position_bias + mask
587
+ ) # (batch_size, n_heads, seq_length, key_length)
588
+
589
+ if self.pruned_heads:
590
+ mask = torch.ones(position_bias.shape[1])
591
+ mask[list(self.pruned_heads)] = 0
592
+ position_bias_masked = position_bias[:, mask.bool()]
593
+ else:
594
+ position_bias_masked = position_bias
595
+
596
+ scores += position_bias_masked
597
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
598
+ scores
599
+ ) # (batch_size, n_heads, seq_length, key_length)
600
+ attn_weights = nn.functional.dropout(
601
+ attn_weights, p=self.dropout, training=self.training
602
+ ) # (batch_size, n_heads, seq_length, key_length)
603
+
604
+ # Mask heads if we want to
605
+ if layer_head_mask is not None:
606
+ attn_weights = attn_weights * layer_head_mask
607
+
608
+ attn_output = unshape(
609
+ torch.matmul(attn_weights, value_states)
610
+ ) # (batch_size, seq_length, dim)
611
+ attn_output = self.o(attn_output)
612
+
613
+ present_key_value_state = (
614
+ (key_states, value_states) if (self.is_decoder and use_cache) else None
615
+ )
616
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
617
+
618
+ if output_attentions:
619
+ outputs = outputs + (attn_weights,)
620
+ return outputs
621
+
622
+
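+ # --- Editor's note: a minimal illustrative sketch (not part of the original file).
+ # Running the attention module above on its own, with a tiny hypothetical config, shows
+ # the output layout of `forward`: (attn_output, present_key_value, position_bias).
+ def _demo_t5_attention():
+     cfg = T5Config(d_model=16, d_kv=4, num_heads=4, dropout_rate=0.0)
+     attn = T5Attention(cfg, has_relative_attention_bias=True)
+     out, present, bias = attn(torch.randn(2, 5, 16))[:3]
+     # out: (2, 5, 16); bias: (1, num_heads, 5, 5); `present` is None because the module
+     # is not used here as a decoder with use_cache=True.
+     return out.shape, bias.shape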
623
+ class T5LayerSelfAttention(nn.Module):
624
+ def __init__(self, config, has_relative_attention_bias=False):
625
+ super().__init__()
626
+ self.SelfAttention = T5Attention(
627
+ config, has_relative_attention_bias=has_relative_attention_bias
628
+ )
629
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
630
+ self.dropout = nn.Dropout(config.dropout_rate)
631
+
632
+ def forward(
633
+ self,
634
+ hidden_states,
635
+ attention_mask=None,
636
+ position_bias=None,
637
+ layer_head_mask=None,
638
+ past_key_value=None,
639
+ use_cache=False,
640
+ output_attentions=False,
641
+ ):
642
+ normed_hidden_states = self.layer_norm(hidden_states)
643
+ attention_output = self.SelfAttention(
644
+ normed_hidden_states,
645
+ mask=attention_mask,
646
+ position_bias=position_bias,
647
+ layer_head_mask=layer_head_mask,
648
+ past_key_value=past_key_value,
649
+ use_cache=use_cache,
650
+ output_attentions=output_attentions,
651
+ )
652
+ hidden_states = hidden_states + self.dropout(attention_output[0])
653
+ outputs = (hidden_states,) + attention_output[
654
+ 1:
655
+ ] # add attentions if we output them
656
+ return outputs
657
+
658
+
659
+ class T5LayerCrossAttention(nn.Module):
660
+ def __init__(self, config):
661
+ super().__init__()
662
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
663
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
664
+ self.dropout = nn.Dropout(config.dropout_rate)
665
+
666
+ def forward(
667
+ self,
668
+ hidden_states,
669
+ key_value_states,
670
+ attention_mask=None,
671
+ position_bias=None,
672
+ layer_head_mask=None,
673
+ past_key_value=None,
674
+ use_cache=False,
675
+ query_length=None,
676
+ output_attentions=False,
677
+ ):
678
+ normed_hidden_states = self.layer_norm(hidden_states)
679
+ attention_output = self.EncDecAttention(
680
+ normed_hidden_states,
681
+ mask=attention_mask,
682
+ key_value_states=key_value_states,
683
+ position_bias=position_bias,
684
+ layer_head_mask=layer_head_mask,
685
+ past_key_value=past_key_value,
686
+ use_cache=use_cache,
687
+ query_length=query_length,
688
+ output_attentions=output_attentions,
689
+ )
690
+ layer_output = hidden_states + self.dropout(attention_output[0])
691
+ outputs = (layer_output,) + attention_output[
692
+ 1:
693
+ ] # add attentions if we output them
694
+ return outputs
695
+
696
+
697
+ class T5Block(nn.Module):
698
+ def __init__(self, config, has_relative_attention_bias=False):
699
+ super().__init__()
700
+ self.is_decoder = config.is_decoder
701
+ self.layer = nn.ModuleList()
702
+ self.layer.append(
703
+ T5LayerSelfAttention(
704
+ config, has_relative_attention_bias=has_relative_attention_bias
705
+ )
706
+ )
707
+ if self.is_decoder:
708
+ self.layer.append(T5LayerCrossAttention(config))
709
+
710
+ self.layer.append(T5LayerFF(config))
711
+
712
+ def forward(
713
+ self,
714
+ hidden_states,
715
+ attention_mask=None,
716
+ position_bias=None,
717
+ encoder_hidden_states=None,
718
+ encoder_attention_mask=None,
719
+ encoder_decoder_position_bias=None,
720
+ layer_head_mask=None,
721
+ cross_attn_layer_head_mask=None,
722
+ past_key_value=None,
723
+ use_cache=False,
724
+ output_attentions=False,
725
+ return_dict=True,
726
+ ):
727
+
728
+ if past_key_value is not None:
729
+ if not self.is_decoder:
730
+ logger.warning(
731
+ "`past_key_values` is passed to the encoder. Please make sure this is intended."
732
+ )
733
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
734
+
735
+ if len(past_key_value) != expected_num_past_key_values:
736
+ raise ValueError(
737
+ f"There should be {expected_num_past_key_values} past states. "
738
+ f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
739
+ f"Got {len(past_key_value)} past key / value states"
740
+ )
741
+
742
+ self_attn_past_key_value = past_key_value[:2]
743
+ cross_attn_past_key_value = past_key_value[2:]
744
+ else:
745
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
746
+
747
+ self_attention_outputs = self.layer[0](
748
+ hidden_states,
749
+ attention_mask=attention_mask,
750
+ position_bias=position_bias,
751
+ layer_head_mask=layer_head_mask,
752
+ past_key_value=self_attn_past_key_value,
753
+ use_cache=use_cache,
754
+ output_attentions=output_attentions,
755
+ )
756
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
757
+ attention_outputs = self_attention_outputs[
758
+ 2:
759
+ ] # Keep self-attention outputs and relative position weights
760
+
761
+ # clamp inf values to enable fp16 training
762
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
763
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
764
+ hidden_states = torch.clamp(
765
+ hidden_states, min=-clamp_value, max=clamp_value
766
+ )
767
+
768
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
769
+ if do_cross_attention:
770
+ # the actual query length is unknown for cross attention
771
+ # if using past key value states. Need to inject it here
772
+ if present_key_value_state is not None:
773
+ query_length = present_key_value_state[0].shape[2]
774
+ else:
775
+ query_length = None
776
+
777
+ cross_attention_outputs = self.layer[1](
778
+ hidden_states,
779
+ key_value_states=encoder_hidden_states,
780
+ attention_mask=encoder_attention_mask,
781
+ position_bias=encoder_decoder_position_bias,
782
+ layer_head_mask=cross_attn_layer_head_mask,
783
+ past_key_value=cross_attn_past_key_value,
784
+ query_length=query_length,
785
+ use_cache=use_cache,
786
+ output_attentions=output_attentions,
787
+ )
788
+ hidden_states = cross_attention_outputs[0]
789
+
790
+ # clamp inf values to enable fp16 training
791
+ if (
792
+ hidden_states.dtype == torch.float16
793
+ and torch.isinf(hidden_states).any()
794
+ ):
795
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
796
+ hidden_states = torch.clamp(
797
+ hidden_states, min=-clamp_value, max=clamp_value
798
+ )
799
+
800
+ # Combine self attn and cross attn key value states
801
+ if present_key_value_state is not None:
802
+ present_key_value_state = (
803
+ present_key_value_state + cross_attention_outputs[1]
804
+ )
805
+
806
+ # Keep cross-attention outputs and relative position weights
807
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
808
+
809
+ # Apply Feed Forward layer
810
+ hidden_states = self.layer[-1](hidden_states)
811
+
812
+ # clamp inf values to enable fp16 training
813
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
814
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
815
+ hidden_states = torch.clamp(
816
+ hidden_states, min=-clamp_value, max=clamp_value
817
+ )
818
+
819
+ outputs = (hidden_states,)
820
+
821
+ if use_cache:
822
+ outputs = outputs + (present_key_value_state,) + attention_outputs
823
+ else:
824
+ outputs = outputs + attention_outputs
825
+
826
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
827
+
828
+
829
+ class T5PreTrainedModel(PreTrainedModel):
830
+ """
831
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
832
+ models.
833
+ """
834
+
835
+ config_class = T5Config
836
+ load_tf_weights = load_tf_weights_in_t5
837
+ base_model_prefix = "transformer"
838
+ is_parallelizable = True
839
+ supports_gradient_checkpointing = True
840
+ _no_split_modules = ["T5Block"]
841
+
842
+ @property
843
+ def dummy_inputs(self):
844
+ input_ids = torch.tensor(DUMMY_INPUTS)
845
+ input_mask = torch.tensor(DUMMY_MASK)
846
+ dummy_inputs = {
847
+ "decoder_input_ids": input_ids,
848
+ "input_ids": input_ids,
849
+ "decoder_attention_mask": input_mask,
850
+ }
851
+ return dummy_inputs
852
+
853
+ def _init_weights(self, module):
854
+ """Initialize the weights"""
855
+ factor = (
856
+ self.config.initializer_factor
857
+ ) # Used for testing weights initialization
858
+ if isinstance(module, T5LayerNorm):
859
+ module.weight.data.fill_(factor * 1.0)
860
+ elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
861
+ # Mesh TensorFlow embeddings initialization
862
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
863
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
864
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
865
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
866
+ elif isinstance(module, T5DenseActDense):
867
+ # Mesh TensorFlow FF initialization
868
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
869
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
870
+ module.wi.weight.data.normal_(
871
+ mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
872
+ )
873
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
874
+ module.wi.bias.data.zero_()
875
+ module.wo.weight.data.normal_(
876
+ mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
877
+ )
878
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
879
+ module.wo.bias.data.zero_()
880
+ elif isinstance(module, T5DenseGatedActDense):
881
+ module.wi_0.weight.data.normal_(
882
+ mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
883
+ )
884
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
885
+ module.wi_0.bias.data.zero_()
886
+ module.wi_1.weight.data.normal_(
887
+ mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
888
+ )
889
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
890
+ module.wi_1.bias.data.zero_()
891
+ module.wo.weight.data.normal_(
892
+ mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)
893
+ )
894
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
895
+ module.wo.bias.data.zero_()
896
+ elif isinstance(module, T5Attention):
897
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
898
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
899
+ d_model = self.config.d_model
900
+ key_value_proj_dim = self.config.d_kv
901
+ n_heads = self.config.num_heads
902
+ module.q.weight.data.normal_(
903
+ mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)
904
+ )
905
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
906
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
907
+ module.o.weight.data.normal_(
908
+ mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)
909
+ )
910
+ if module.has_relative_attention_bias:
911
+ module.relative_attention_bias.weight.data.normal_(
912
+ mean=0.0, std=factor * ((d_model) ** -0.5)
913
+ )
914
+
915
+ def _set_gradient_checkpointing(self, module, value=False):
916
+ if isinstance(module, (T5Attention, T5Stack)):
917
+ module.gradient_checkpointing = value
918
+
919
+ def _shift_right(self, input_ids):
920
+ decoder_start_token_id = self.config.decoder_start_token_id
921
+ pad_token_id = self.config.pad_token_id
922
+
923
+ assert decoder_start_token_id is not None, (
924
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
925
+ " See T5 docs for more information"
926
+ )
927
+
928
+ # shift inputs to the right
929
+ if is_torch_fx_proxy(input_ids):
930
+ # Item assignment is not supported natively for proxies.
931
+ shifted_input_ids = torch.full(
932
+ input_ids.shape[:-1] + (1,), decoder_start_token_id
933
+ )
934
+ shifted_input_ids = torch.cat(
935
+ [shifted_input_ids, input_ids[..., :-1]], dim=-1
936
+ )
937
+ else:
938
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
939
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
940
+ shifted_input_ids[..., 0] = decoder_start_token_id
941
+
942
+ assert (
943
+ pad_token_id is not None
944
+ ), "self.model.config.pad_token_id has to be defined."
945
+ # replace possible -100 values in labels by `pad_token_id`
946
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
947
+
948
+ return shifted_input_ids
949
+
950
+
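+ # --- Editor's note (illustrative, not part of the original file): what `_shift_right`
+ # produces, assuming the T5 default `decoder_start_token_id == pad_token_id == 0`:
+ #   labels                     = [[ 6,  7,  8, -100]]
+ #   shifted decoder_input_ids  = [[ 0,  6,  7,  8]]
+ # The sequence is shifted one step to the right, the decoder start token is prepended,
+ # and any remaining -100 label positions are replaced by the pad token id.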
951
+ class T5Stack(T5PreTrainedModel):
952
+ def __init__(self, config, embed_tokens=None):
953
+ super().__init__(config)
954
+
955
+ self.embed_tokens = embed_tokens
956
+ self.is_decoder = config.is_decoder
957
+
958
+ self.block = nn.ModuleList(
959
+ [
960
+ T5Block(config, has_relative_attention_bias=bool(i == 0))
961
+ for i in range(config.num_layers)
962
+ ]
963
+ )
964
+ self.final_layer_norm = T5LayerNorm(
965
+ config.d_model, eps=config.layer_norm_epsilon
966
+ )
967
+ self.dropout = nn.Dropout(config.dropout_rate)
968
+
969
+ # Initialize weights and apply final processing
970
+ self.post_init()
971
+ # Model parallel
972
+ self.model_parallel = False
973
+ self.device_map = None
974
+ self.gradient_checkpointing = False
975
+
976
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
977
+ def parallelize(self, device_map=None):
978
+ # Check validity of device_map
979
+ self.device_map = (
980
+ get_device_map(len(self.block), range(torch.cuda.device_count()))
981
+ if device_map is None
982
+ else device_map
983
+ )
984
+ assert_device_map(self.device_map, len(self.block))
985
+ self.model_parallel = True
986
+ self.first_device = (
987
+ "cpu"
988
+ if "cpu" in self.device_map.keys()
989
+ else "cuda:" + str(min(self.device_map.keys()))
990
+ )
991
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
992
+ # Load onto devices
993
+ for k, v in self.device_map.items():
994
+ for layer in v:
995
+ cuda_device = "cuda:" + str(k)
996
+ self.block[layer] = self.block[layer].to(cuda_device)
997
+
998
+ # Set embed_tokens to first layer
999
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
1000
+ # Set final layer norm to last device
1001
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
1002
+
1003
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1004
+ def deparallelize(self):
1005
+ self.model_parallel = False
1006
+ self.device_map = None
1007
+ self.first_device = "cpu"
1008
+ self.last_device = "cpu"
1009
+ for i in range(len(self.block)):
1010
+ self.block[i] = self.block[i].to("cpu")
1011
+ self.embed_tokens = self.embed_tokens.to("cpu")
1012
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
1013
+ torch.cuda.empty_cache()
1014
+
1015
+ def get_input_embeddings(self):
1016
+ return self.embed_tokens
1017
+
1018
+ def set_input_embeddings(self, new_embeddings):
1019
+ self.embed_tokens = new_embeddings
1020
+
1021
+ def forward(
1022
+ self,
1023
+ input_ids=None,
1024
+ attention_mask=None,
1025
+ encoder_hidden_states=None,
1026
+ encoder_attention_mask=None,
1027
+ inputs_embeds=None,
1028
+ head_mask=None,
1029
+ cross_attn_head_mask=None,
1030
+ past_key_values=None,
1031
+ use_cache=None,
1032
+ output_attentions=None,
1033
+ output_hidden_states=None,
1034
+ return_dict=None,
1035
+ ):
1036
+ # Model parallel
1037
+ if self.model_parallel:
1038
+ torch.cuda.set_device(self.first_device)
1039
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
1040
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1041
+ output_attentions = (
1042
+ output_attentions
1043
+ if output_attentions is not None
1044
+ else self.config.output_attentions
1045
+ )
1046
+ output_hidden_states = (
1047
+ output_hidden_states
1048
+ if output_hidden_states is not None
1049
+ else self.config.output_hidden_states
1050
+ )
1051
+ return_dict = (
1052
+ return_dict if return_dict is not None else self.config.use_return_dict
1053
+ )
1054
+
1055
+ if input_ids is not None and inputs_embeds is not None:
1056
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
1057
+ raise ValueError(
1058
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
1059
+ )
1060
+ elif input_ids is not None:
1061
+ input_shape = input_ids.size()
1062
+ input_ids = input_ids.view(-1, input_shape[-1])
1063
+ elif inputs_embeds is not None:
1064
+ input_shape = inputs_embeds.size()[:-1]
1065
+ else:
1066
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
1067
+ raise ValueError(
1068
+ f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
1069
+ )
1070
+
1071
+ if inputs_embeds is None:
1072
+ assert (
1073
+ self.embed_tokens is not None
1074
+ ), "You have to initialize the model with valid token embeddings"
1075
+ inputs_embeds = self.embed_tokens(input_ids)
1076
+
1077
+ batch_size, seq_length = input_shape
1078
+
1079
+ # required mask seq length can be calculated via length of past
1080
+ mask_seq_length = (
1081
+ past_key_values[0][0].shape[2] + seq_length
1082
+ if past_key_values is not None
1083
+ else seq_length
1084
+ )
1085
+
1086
+ if use_cache is True:
1087
+ assert (
1088
+ self.is_decoder
1089
+ ), f"`use_cache` can only be set to `True` if {self} is used as a decoder"
1090
+
1091
+ if attention_mask is None:
1092
+ attention_mask = torch.ones(
1093
+ batch_size, mask_seq_length, device=inputs_embeds.device
1094
+ )
1095
+ if (
1096
+ self.is_decoder
1097
+ and encoder_attention_mask is None
1098
+ and encoder_hidden_states is not None
1099
+ ):
1100
+ encoder_seq_length = encoder_hidden_states.shape[1]
1101
+ encoder_attention_mask = torch.ones(
1102
+ batch_size,
1103
+ encoder_seq_length,
1104
+ device=inputs_embeds.device,
1105
+ dtype=torch.long,
1106
+ )
1107
+
1108
+ # initialize past_key_values with `None` if past does not exist
1109
+ if past_key_values is None:
1110
+ past_key_values = [None] * len(self.block)
1111
+
1112
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1113
+ # ourselves in which case we just need to make it broadcastable to all heads.
1114
+ extended_attention_mask = self.get_extended_attention_mask(
1115
+ attention_mask, input_shape
1116
+ )
1117
+
1118
+ # If a 2D or 3D attention mask is provided for the cross-attention
1119
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
1120
+ if self.is_decoder and encoder_hidden_states is not None:
1121
+ (
1122
+ encoder_batch_size,
1123
+ encoder_sequence_length,
1124
+ _,
1125
+ ) = encoder_hidden_states.size()
1126
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1127
+ if encoder_attention_mask is None:
1128
+ encoder_attention_mask = torch.ones(
1129
+ encoder_hidden_shape, device=inputs_embeds.device
1130
+ )
1131
+ encoder_extended_attention_mask = self.invert_attention_mask(
1132
+ encoder_attention_mask
1133
+ )
1134
+ else:
1135
+ encoder_extended_attention_mask = None
1136
+
1137
+ # Prepare head mask if needed
1138
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
1139
+ cross_attn_head_mask = self.get_head_mask(
1140
+ cross_attn_head_mask, self.config.num_layers
1141
+ )
1142
+ present_key_value_states = () if use_cache else None
1143
+ all_hidden_states = () if output_hidden_states else None
1144
+ all_attentions = () if output_attentions else None
1145
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
1146
+ position_bias = None
1147
+ encoder_decoder_position_bias = None
1148
+
1149
+ hidden_states = self.dropout(inputs_embeds)
1150
+
1151
+ for i, (layer_module, past_key_value) in enumerate(
1152
+ zip(self.block, past_key_values)
1153
+ ):
1154
+ layer_head_mask = head_mask[i]
1155
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
1156
+ # Model parallel
1157
+ if self.model_parallel:
1158
+ torch.cuda.set_device(hidden_states.device)
1159
+ # Ensure that attention_mask is always on the same device as hidden_states
1160
+ if attention_mask is not None:
1161
+ attention_mask = attention_mask.to(hidden_states.device)
1162
+ if position_bias is not None:
1163
+ position_bias = position_bias.to(hidden_states.device)
1164
+ if encoder_hidden_states is not None:
1165
+ encoder_hidden_states = encoder_hidden_states.to(
1166
+ hidden_states.device
1167
+ )
1168
+ if encoder_extended_attention_mask is not None:
1169
+ encoder_extended_attention_mask = (
1170
+ encoder_extended_attention_mask.to(hidden_states.device)
1171
+ )
1172
+ if encoder_decoder_position_bias is not None:
1173
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(
1174
+ hidden_states.device
1175
+ )
1176
+ if layer_head_mask is not None:
1177
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
1178
+ if cross_attn_layer_head_mask is not None:
1179
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
1180
+ hidden_states.device
1181
+ )
1182
+ if output_hidden_states:
1183
+ all_hidden_states = all_hidden_states + (hidden_states,)
1184
+
1185
+ if self.gradient_checkpointing and self.training:
1186
+ if use_cache:
1187
+ logger.warning(
1188
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1189
+ )
1190
+ use_cache = False
1191
+
1192
+ def create_custom_forward(module):
1193
+ def custom_forward(*inputs):
1194
+ return tuple(module(*inputs, use_cache, output_attentions))
1195
+
1196
+ return custom_forward
1197
+
1198
+ layer_outputs = checkpoint(
1199
+ create_custom_forward(layer_module),
1200
+ hidden_states,
1201
+ extended_attention_mask,
1202
+ position_bias,
1203
+ encoder_hidden_states,
1204
+ encoder_extended_attention_mask,
1205
+ encoder_decoder_position_bias,
1206
+ layer_head_mask,
1207
+ cross_attn_layer_head_mask,
1208
+ None, # past_key_value is always None with gradient checkpointing
1209
+ )
1210
+ else:
1211
+ layer_outputs = layer_module(
1212
+ hidden_states,
1213
+ attention_mask=extended_attention_mask,
1214
+ position_bias=position_bias,
1215
+ encoder_hidden_states=encoder_hidden_states,
1216
+ encoder_attention_mask=encoder_extended_attention_mask,
1217
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
1218
+ layer_head_mask=layer_head_mask,
1219
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1220
+ past_key_value=past_key_value,
1221
+ use_cache=use_cache,
1222
+ output_attentions=output_attentions,
1223
+ )
1224
+
1225
+ # layer_outputs is a tuple with:
1226
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1227
+ if use_cache is False:
1228
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
1229
+
1230
+ hidden_states, present_key_value_state = layer_outputs[:2]
1231
+
1232
+ # We share the position biases between the layers - the first layer stores them
1233
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
1234
+ # (cross-attention position bias), (cross-attention weights)
1235
+ position_bias = layer_outputs[2]
1236
+ if self.is_decoder and encoder_hidden_states is not None:
1237
+ encoder_decoder_position_bias = layer_outputs[
1238
+ 4 if output_attentions else 3
1239
+ ]
1240
+ # append next layer key value states
1241
+ if use_cache:
1242
+ present_key_value_states = present_key_value_states + (
1243
+ present_key_value_state,
1244
+ )
1245
+
1246
+ if output_attentions:
1247
+ all_attentions = all_attentions + (layer_outputs[3],)
1248
+ if self.is_decoder:
1249
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
1250
+
1251
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1252
+ if self.model_parallel:
1253
+ for k, v in self.device_map.items():
1254
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1255
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1256
+
1257
+ hidden_states = self.final_layer_norm(hidden_states)
1258
+ hidden_states = self.dropout(hidden_states)
1259
+
1260
+ # Add last layer
1261
+ if output_hidden_states:
1262
+ all_hidden_states = all_hidden_states + (hidden_states,)
1263
+
1264
+ if not return_dict:
1265
+ return tuple(
1266
+ v
1267
+ for v in [
1268
+ hidden_states,
1269
+ present_key_value_states,
1270
+ all_hidden_states,
1271
+ all_attentions,
1272
+ all_cross_attentions,
1273
+ ]
1274
+ if v is not None
1275
+ )
1276
+ return BaseModelOutputWithPastAndCrossAttentions(
1277
+ last_hidden_state=hidden_states,
1278
+ past_key_values=present_key_value_states,
1279
+ hidden_states=all_hidden_states,
1280
+ attentions=all_attentions,
1281
+ cross_attentions=all_cross_attentions,
1282
+ )
1283
+
1284
+
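+ # --- Editor's note: a minimal illustrative sketch (not part of the original file).
+ # `T5Stack` can be used as a standalone (non-autoregressive) encoder by passing a config
+ # with `is_decoder=False` together with its own embedding table. The helper name and the
+ # tiny config values below are hypothetical, for illustration only.
+ def _demo_encoder_only_stack():
+     cfg = T5Config(vocab_size=100, d_model=16, d_kv=4, d_ff=32, num_layers=2, num_heads=4, dropout_rate=0.0)
+     cfg.is_decoder = False
+     cfg.use_cache = False
+     embeddings = nn.Embedding(cfg.vocab_size, cfg.d_model)
+     encoder = T5Stack(cfg, embed_tokens=embeddings)
+     out = encoder(input_ids=torch.randint(0, cfg.vocab_size, (2, 7)), return_dict=True)
+     return out.last_hidden_state.shape  # torch.Size([2, 7, 16])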
1285
+ T5_START_DOCSTRING = r"""
1286
+
1287
+ The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
1288
+ Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
1289
+ Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
1290
+ text-to-text denoising generative setting.
1291
+
1292
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1293
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1294
+ etc.)
1295
+
1296
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1297
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1298
+ and behavior.
1299
+
1300
+ Parameters:
1301
+ config ([`T5Config`]): Model configuration class with all the parameters of the model.
1302
+ Initializing with a config file does not load the weights associated with the model, only the
1303
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1304
+ """
1305
+
1306
+ T5_INPUTS_DOCSTRING = r"""
1307
+ Args:
1308
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1309
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1310
+ should be able to pad the inputs on both the right and the left.
1311
+
1312
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1313
+ [`PreTrainedTokenizer.__call__`] for details.
1314
+
1315
+ [What are input IDs?](../glossary#input-ids)
1316
+
1317
+ To know more on how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
1318
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1319
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1320
+
1321
+ - 1 for tokens that are **not masked**,
1322
+ - 0 for tokens that are **masked**.
1323
+
1324
+ [What are attention masks?](../glossary#attention-mask)
1325
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1326
+ Indices of decoder input sequence tokens in the vocabulary.
1327
+
1328
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1329
+ [`PreTrainedTokenizer.__call__`] for details.
1330
+
1331
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
1332
+
1333
+ T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
1334
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
1335
+
1336
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
1337
+ Training](./t5#training).
1338
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1339
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
1340
+ be used by default.
1341
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1342
+ Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
1343
+ 1]`:
1344
+
1345
+ - 1 indicates the head is **not masked**,
1346
+ - 0 indicates the head is **masked**.
1347
+
1348
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1349
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
1350
+ 1]`:
1351
+
1352
+ - 1 indicates the head is **not masked**,
1353
+ - 0 indicates the head is **masked**.
1354
+
1355
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1356
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
1357
+ `[0, 1]`:
1358
+
1359
+ - 1 indicates the head is **not masked**,
1360
+ - 0 indicates the head is **masked**.
1361
+
1362
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
1363
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
1364
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
1365
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1366
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1367
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1368
+
1369
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1370
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1371
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1372
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1373
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1374
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1375
+ model's internal embedding lookup matrix.
1376
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
1377
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
1378
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
1379
+ input (see `past_key_values`). This is useful if you want more control over how to convert
1380
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
1381
+
1382
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
1383
+ of `inputs_embeds`.
1384
+
1385
+ use_cache (`bool`, *optional*):
1386
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1387
+ `past_key_values`).
1388
+
1389
+ output_attentions (`bool`, *optional*):
1390
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1391
+ tensors for more detail.
1392
+ output_hidden_states (`bool`, *optional*):
1393
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1394
+ more detail.
1395
+ return_dict (`bool`, *optional*):
1396
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1397
+ """
1398
+
1399
+ T5_ENCODER_INPUTS_DOCSTRING = r"""
1400
+ Args:
1401
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1402
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1403
+ should be able to pad the inputs on both the right and the left.
1404
+
1405
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1406
+ [`PreTrainedTokenizer.__call__`] for details.
1407
+
1408
+ To know more on how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
1409
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1410
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1411
+
1412
+ - 1 for tokens that are **not masked**,
1413
+ - 0 for tokens that are **masked**.
1414
+
1415
+ [What are attention masks?](../glossary#attention-mask)
1416
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1417
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
1418
+
1419
+ - 1 indicates the head is **not masked**,
1420
+ - 0 indicates the head is **masked**.
1421
+
1422
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1423
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1424
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1425
+ model's internal embedding lookup matrix.
1426
+ output_attentions (`bool`, *optional*):
1427
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1428
+ tensors for more detail.
1429
+ output_hidden_states (`bool`, *optional*):
1430
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1431
+ more detail.
1432
+ return_dict (`bool`, *optional*):
1433
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1434
+ """
1435
+
1436
+ # Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1437
+ __HEAD_MASK_WARNING_MSG = """
1438
+ The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
1439
+ `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
1440
+ If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
1441
+ num_heads)`.
1442
+ """
1443
+
1444
+
1445
+ @add_start_docstrings(
1446
+ "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
1447
+ T5_START_DOCSTRING,
1448
+ )
1449
+ class T5Model(T5PreTrainedModel):
1450
+ _keys_to_ignore_on_load_missing = [
1451
+ r"encoder.embed_tokens.weight",
1452
+ r"decoder.embed_tokens.weight",
1453
+ ]
1454
+ _keys_to_ignore_on_load_unexpected = [
1455
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1456
+ ]
1457
+
1458
+ def __init__(self, config: T5Config):
1459
+ super().__init__(config)
1460
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1461
+
1462
+ encoder_config = copy.deepcopy(config)
1463
+ encoder_config.is_decoder = False
1464
+ encoder_config.use_cache = False
1465
+ encoder_config.is_encoder_decoder = False
1466
+ self.encoder = T5Stack(encoder_config, self.shared)
1467
+
1468
+ decoder_config = copy.deepcopy(config)
1469
+ decoder_config.is_decoder = True
1470
+ decoder_config.is_encoder_decoder = False
1471
+ decoder_config.num_layers = config.num_decoder_layers
1472
+ self.decoder = T5Stack(decoder_config, self.shared)
1473
+
1474
+ # Initialize weights and apply final processing
1475
+ self.post_init()
1476
+
1477
+ # Model parallel
1478
+ self.model_parallel = False
1479
+ self.device_map = None
1480
+
1481
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1482
+ def parallelize(self, device_map=None):
1483
+ self.device_map = (
1484
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1485
+ if device_map is None
1486
+ else device_map
1487
+ )
1488
+ assert_device_map(self.device_map, len(self.encoder.block))
1489
+ self.encoder.parallelize(self.device_map)
1490
+ self.decoder.parallelize(self.device_map)
1491
+ self.model_parallel = True
1492
+
1493
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1494
+ def deparallelize(self):
1495
+ self.encoder.deparallelize()
1496
+ self.decoder.deparallelize()
1497
+ self.encoder = self.encoder.to("cpu")
1498
+ self.decoder = self.decoder.to("cpu")
1499
+ self.model_parallel = False
1500
+ self.device_map = None
1501
+ torch.cuda.empty_cache()
1502
+
1503
+ def get_input_embeddings(self):
1504
+ return self.shared
1505
+
1506
+ def set_input_embeddings(self, new_embeddings):
1507
+ self.shared = new_embeddings
1508
+ self.encoder.set_input_embeddings(new_embeddings)
1509
+ self.decoder.set_input_embeddings(new_embeddings)
1510
+
1511
+ def get_encoder(self):
1512
+ return self.encoder
1513
+
1514
+ def get_decoder(self):
1515
+ return self.decoder
1516
+
1517
+ def _prune_heads(self, heads_to_prune):
1518
+ """
1519
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1520
+ class PreTrainedModel
1521
+ """
1522
+ for layer, heads in heads_to_prune.items():
1523
+ self.encoder.layer[layer].attention.prune_heads(heads)
1524
+
1525
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
1526
+ @replace_return_docstrings(
1527
+ output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC
1528
+ )
1529
+ def forward(
1530
+ self,
1531
+ input_ids: Optional[torch.LongTensor] = None,
1532
+ attention_mask: Optional[torch.FloatTensor] = None,
1533
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1534
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1535
+ head_mask: Optional[torch.FloatTensor] = None,
1536
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1537
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1538
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1539
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1540
+ inputs_embeds: Optional[torch.Tensor] = None,
1541
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
1542
+ use_cache: Optional[bool] = None,
1543
+ output_attentions: Optional[bool] = None,
1544
+ output_hidden_states: Optional[bool] = None,
1545
+ return_dict: Optional[bool] = None,
1546
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
1547
+ r"""
1548
+ Returns:
1549
+
1550
+ Example:
1551
+
1552
+ ```python
1553
+ >>> from transformers import T5Tokenizer, T5Model
1554
+
1555
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1556
+ >>> model = T5Model.from_pretrained("t5-small")
1557
+
1558
+ >>> input_ids = tokenizer(
1559
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1560
+ ... ).input_ids # Batch size 1
1561
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1562
+
1563
+ >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
1564
+ >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
1565
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids)
1566
+
1567
+ >>> # forward pass
1568
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1569
+ >>> last_hidden_states = outputs.last_hidden_state
1570
+ ```"""
1571
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1572
+ return_dict = (
1573
+ return_dict if return_dict is not None else self.config.use_return_dict
1574
+ )
1575
+
1576
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1577
+ if head_mask is not None and decoder_head_mask is None:
1578
+ if self.config.num_layers == self.config.num_decoder_layers:
1579
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
1580
+ decoder_head_mask = head_mask
1581
+
1582
+ # Encode if needed (training, first prediction pass)
1583
+ if encoder_outputs is None:
1584
+ encoder_outputs = self.encoder(
1585
+ input_ids=input_ids,
1586
+ attention_mask=attention_mask,
1587
+ inputs_embeds=inputs_embeds,
1588
+ head_mask=head_mask,
1589
+ output_attentions=output_attentions,
1590
+ output_hidden_states=output_hidden_states,
1591
+ return_dict=return_dict,
1592
+ )
1593
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1594
+ encoder_outputs = BaseModelOutput(
1595
+ last_hidden_state=encoder_outputs[0],
1596
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1597
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1598
+ )
1599
+
1600
+ hidden_states = encoder_outputs[0]
1601
+
1602
+ # Set device for model parallelism
1603
+ if self.model_parallel:
1604
+ torch.cuda.set_device(self.decoder.first_device)
1605
+ hidden_states = hidden_states.to(self.decoder.first_device)
1606
+ if decoder_input_ids is not None:
1607
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1608
+ if attention_mask is not None:
1609
+ attention_mask = attention_mask.to(self.decoder.first_device)
1610
+ if decoder_attention_mask is not None:
1611
+ decoder_attention_mask = decoder_attention_mask.to(
1612
+ self.decoder.first_device
1613
+ )
1614
+
1615
+ # Decode
1616
+ decoder_outputs = self.decoder(
1617
+ input_ids=decoder_input_ids,
1618
+ attention_mask=decoder_attention_mask,
1619
+ inputs_embeds=decoder_inputs_embeds,
1620
+ past_key_values=past_key_values,
1621
+ encoder_hidden_states=hidden_states,
1622
+ encoder_attention_mask=attention_mask,
1623
+ head_mask=decoder_head_mask,
1624
+ cross_attn_head_mask=cross_attn_head_mask,
1625
+ use_cache=use_cache,
1626
+ output_attentions=output_attentions,
1627
+ output_hidden_states=output_hidden_states,
1628
+ return_dict=return_dict,
1629
+ )
1630
+
1631
+ if not return_dict:
1632
+ return decoder_outputs + encoder_outputs
1633
+
1634
+ return Seq2SeqModelOutput(
1635
+ last_hidden_state=decoder_outputs.last_hidden_state,
1636
+ past_key_values=decoder_outputs.past_key_values,
1637
+ decoder_hidden_states=decoder_outputs.hidden_states,
1638
+ decoder_attentions=decoder_outputs.attentions,
1639
+ cross_attentions=decoder_outputs.cross_attentions,
1640
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1641
+ encoder_hidden_states=encoder_outputs.hidden_states,
1642
+ encoder_attentions=encoder_outputs.attentions,
1643
+ )
1644
+
1645
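Because `forward()` only runs the encoder when `encoder_outputs` is not supplied (the "Encode if needed" branch above), the encoder pass can be computed once and reused across several decoder inputs. A sketch of that pattern:

```python
from transformers import T5Tokenizer, T5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")

enc = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt")
encoder_outputs = model.get_encoder()(
    input_ids=enc.input_ids, attention_mask=enc.attention_mask
)

for prefix in ["Studies show that", "Dogs are"]:
    dec_ids = tokenizer(prefix, return_tensors="pt").input_ids
    out = model(
        encoder_outputs=encoder_outputs,
        attention_mask=enc.attention_mask,
        decoder_input_ids=model._shift_right(dec_ids),
    )
    print(prefix, out.last_hidden_state.shape)
```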
+
1646
+ @add_start_docstrings(
1647
+ """T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING
1648
+ )
1649
+ class T5ForConditionalGeneration(T5PreTrainedModel):
1650
+ _keys_to_ignore_on_load_missing = [
1651
+ r"encoder.embed_tokens.weight",
1652
+ r"decoder.embed_tokens.weight",
1653
+ r"lm_head.weight",
1654
+ ]
1655
+ _keys_to_ignore_on_load_unexpected = [
1656
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1657
+ ]
1658
+
1659
+ def __init__(self, config: T5Config):
1660
+ super().__init__(config)
1661
+ self.model_dim = config.d_model
1662
+
1663
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1664
+
1665
+ encoder_config = copy.deepcopy(config)
1666
+ encoder_config.is_decoder = False
1667
+ encoder_config.use_cache = False
1668
+ encoder_config.is_encoder_decoder = False
1669
+ self.encoder = T5Stack(encoder_config, self.shared)
1670
+
1671
+ decoder_config = copy.deepcopy(config)
1672
+ decoder_config.is_decoder = True
1673
+ decoder_config.is_encoder_decoder = False
1674
+ decoder_config.num_layers = config.num_decoder_layers
1675
+ self.decoder = T5Stack(decoder_config, self.shared)
1676
+
1677
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1678
+
1679
+ # Initialize weights and apply final processing
1680
+ self.post_init()
1681
+
1682
+ # Model parallel
1683
+ self.model_parallel = False
1684
+ self.device_map = None
1685
+
1686
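Note that `lm_head` is created without a bias and, because the default T5 config sets `tie_word_embeddings=True`, `post_init()` ties it to the `shared` embedding matrix (which is also why `forward()` later rescales the decoder output by `d_model**-0.5`). A quick check, assuming the original `t5-small` weights:

```python
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")
print(model.config.tie_word_embeddings)  # True for the original T5 checkpoints
print(model.lm_head.weight.data_ptr() == model.shared.weight.data_ptr())  # True: same storage
```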
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1687
+ def parallelize(self, device_map=None):
1688
+ self.device_map = (
1689
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1690
+ if device_map is None
1691
+ else device_map
1692
+ )
1693
+ assert_device_map(self.device_map, len(self.encoder.block))
1694
+ self.encoder.parallelize(self.device_map)
1695
+ self.decoder.parallelize(self.device_map)
1696
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1697
+ self.model_parallel = True
1698
+
1699
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1700
+ def deparallelize(self):
1701
+ self.encoder.deparallelize()
1702
+ self.decoder.deparallelize()
1703
+ self.encoder = self.encoder.to("cpu")
1704
+ self.decoder = self.decoder.to("cpu")
1705
+ self.lm_head = self.lm_head.to("cpu")
1706
+ self.model_parallel = False
1707
+ self.device_map = None
1708
+ torch.cuda.empty_cache()
1709
+
1710
+ def get_input_embeddings(self):
1711
+ return self.shared
1712
+
1713
+ def set_input_embeddings(self, new_embeddings):
1714
+ self.shared = new_embeddings
1715
+ self.encoder.set_input_embeddings(new_embeddings)
1716
+ self.decoder.set_input_embeddings(new_embeddings)
1717
+
1718
+ def set_output_embeddings(self, new_embeddings):
1719
+ self.lm_head = new_embeddings
1720
+
1721
+ def get_output_embeddings(self):
1722
+ return self.lm_head
1723
+
1724
+ def get_encoder(self):
1725
+ return self.encoder
1726
+
1727
+ def get_decoder(self):
1728
+ return self.decoder
1729
+
1730
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
1731
+ @replace_return_docstrings(
1732
+ output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
1733
+ )
1734
+ def forward(
1735
+ self,
1736
+ input_ids: Optional[torch.LongTensor] = None,
1737
+ attention_mask: Optional[torch.FloatTensor] = None,
1738
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1739
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1740
+ head_mask: Optional[torch.FloatTensor] = None,
1741
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1742
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1743
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1744
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1745
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1746
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1747
+ labels: Optional[torch.LongTensor] = None,
1748
+ use_cache: Optional[bool] = None,
1749
+ output_attentions: Optional[bool] = None,
1750
+ output_hidden_states: Optional[bool] = None,
1751
+ return_dict: Optional[bool] = None,
1752
+ reduction: Optional[str] = "mean",
1753
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1754
+ r"""
1755
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
+ labels in `[0, ..., config.vocab_size - 1]`.
1759
+
1760
+ Returns:
1761
+
1762
+ Examples:
1763
+
1764
+ ```python
1765
+ >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
1766
+
1767
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1768
+ >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
1769
+
1770
+ >>> # training
1771
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
1772
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
1773
+ >>> outputs = model(input_ids=input_ids, labels=labels)
1774
+ >>> loss = outputs.loss
1775
+ >>> logits = outputs.logits
1776
+
1777
+ >>> # inference
1778
+ >>> input_ids = tokenizer(
1779
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
1780
+ ... ).input_ids # Batch size 1
1781
+ >>> outputs = model.generate(input_ids)
1782
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1783
+ >>> # studies have shown that owning a dog is good for you.
1784
+ ```"""
1785
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1786
+ return_dict = (
1787
+ return_dict if return_dict is not None else self.config.use_return_dict
1788
+ )
1789
+
1790
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1791
+ if head_mask is not None and decoder_head_mask is None:
1792
+ if self.config.num_layers == self.config.num_decoder_layers:
1793
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
1794
+ decoder_head_mask = head_mask
1795
+
1796
+ # Encode if needed (training, first prediction pass)
1797
+ if encoder_outputs is None:
1798
+ # Convert encoder inputs in embeddings if needed
1799
+ encoder_outputs = self.encoder(
1800
+ input_ids=input_ids,
1801
+ attention_mask=attention_mask,
1802
+ inputs_embeds=inputs_embeds,
1803
+ head_mask=head_mask,
1804
+ output_attentions=output_attentions,
1805
+ output_hidden_states=output_hidden_states,
1806
+ return_dict=return_dict,
1807
+ )
1808
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1809
+ encoder_outputs = BaseModelOutput(
1810
+ last_hidden_state=encoder_outputs[0],
1811
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1812
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1813
+ )
1814
+
1815
+ hidden_states = encoder_outputs[0]
1816
+
1817
+ if self.model_parallel:
1818
+ torch.cuda.set_device(self.decoder.first_device)
1819
+
1820
+ if (
1821
+ labels is not None
1822
+ and decoder_input_ids is None
1823
+ and decoder_inputs_embeds is None
1824
+ ):
1825
+ # get decoder inputs from shifting lm labels to the right
1826
+ decoder_input_ids = self._shift_right(labels)
1827
+
1828
+ # Set device for model parallelism
1829
+ if self.model_parallel:
1830
+ torch.cuda.set_device(self.decoder.first_device)
1831
+ hidden_states = hidden_states.to(self.decoder.first_device)
1832
+ if decoder_input_ids is not None:
1833
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1834
+ if attention_mask is not None:
1835
+ attention_mask = attention_mask.to(self.decoder.first_device)
1836
+ if decoder_attention_mask is not None:
1837
+ decoder_attention_mask = decoder_attention_mask.to(
1838
+ self.decoder.first_device
1839
+ )
1840
+
1841
+ # Decode
1842
+ decoder_outputs = self.decoder(
1843
+ input_ids=decoder_input_ids,
1844
+ attention_mask=decoder_attention_mask,
1845
+ inputs_embeds=decoder_inputs_embeds,
1846
+ past_key_values=past_key_values,
1847
+ encoder_hidden_states=hidden_states,
1848
+ encoder_attention_mask=attention_mask,
1849
+ head_mask=decoder_head_mask,
1850
+ cross_attn_head_mask=cross_attn_head_mask,
1851
+ use_cache=use_cache,
1852
+ output_attentions=output_attentions,
1853
+ output_hidden_states=output_hidden_states,
1854
+ return_dict=return_dict,
1855
+ )
1856
+
1857
+ sequence_output = decoder_outputs[0]
1858
+
1859
+ # Set device for model parallelism
1860
+ if self.model_parallel:
1861
+ torch.cuda.set_device(self.encoder.first_device)
1862
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
1863
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
1864
+
1865
+ if self.config.tie_word_embeddings:
1866
+ # Rescale output before projecting on vocab
1867
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1868
+ sequence_output = sequence_output * (self.model_dim**-0.5)
1869
+
1870
+ lm_logits = self.lm_head(sequence_output)
1871
+
1872
+ loss = None
1873
+ if labels is not None:
1874
+ loss_fct = CrossEntropyLoss(ignore_index=-100, reduction=reduction)
1875
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1876
+ if reduction == "none":
1877
+ loss = loss.view(lm_logits.size(0), -1).sum(1)
1878
+
1879
+ if not return_dict:
1880
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1881
+ return ((loss,) + output) if loss is not None else output
1882
+
1883
+ return Seq2SeqLMOutput(
1884
+ loss=loss,
1885
+ logits=lm_logits,
1886
+ past_key_values=decoder_outputs.past_key_values,
1887
+ decoder_hidden_states=decoder_outputs.hidden_states,
1888
+ decoder_attentions=decoder_outputs.attentions,
1889
+ cross_attentions=decoder_outputs.cross_attentions,
1890
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1891
+ encoder_hidden_states=encoder_outputs.hidden_states,
1892
+ encoder_attentions=encoder_outputs.attentions,
1893
+ )
1894
+
1895
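The `reduction` argument is an addition over the stock `transformers` signature: with `reduction="none"` the token-level cross-entropy is summed over the sequence dimension, yielding one loss value per example (handy for scoring candidate sequences). A sketch, assuming this file is importable as `modeling_t5` (the import path is hypothetical):

```python
from transformers import T5Tokenizer
from modeling_t5 import T5ForConditionalGeneration  # hypothetical local import of this file

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer(
    ["summarize: a dog runs in the park", "summarize: a cat sleeps on the sofa"],
    return_tensors="pt", padding=True,
)
labels = tokenizer(["a dog runs", "a cat sleeps"], return_tensors="pt", padding=True).input_ids
labels[labels == tokenizer.pad_token_id] = -100  # padding must not contribute to the loss

out = model(**inputs, labels=labels, reduction="none")
print(out.loss.shape)  # torch.Size([2]) -- one summed loss per example
```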
+ def prepare_inputs_for_generation(
1896
+ self,
1897
+ input_ids,
1898
+ past=None,
1899
+ attention_mask=None,
1900
+ head_mask=None,
1901
+ decoder_head_mask=None,
1902
+ cross_attn_head_mask=None,
1903
+ use_cache=None,
1904
+ encoder_outputs=None,
1905
+ **kwargs,
1906
+ ):
1907
+
1908
+ # cut decoder_input_ids if past is used
1909
+ if past is not None:
1910
+ input_ids = input_ids[:, -1:]
1911
+
1912
+ return {
1913
+ "decoder_input_ids": input_ids,
1914
+ "past_key_values": past,
1915
+ "encoder_outputs": encoder_outputs,
1916
+ "attention_mask": attention_mask,
1917
+ "head_mask": head_mask,
1918
+ "decoder_head_mask": decoder_head_mask,
1919
+ "cross_attn_head_mask": cross_attn_head_mask,
1920
+ "use_cache": use_cache,
1921
+ }
1922
+
1923
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1924
+ return self._shift_right(labels)
1925
+
1926
+ def _reorder_cache(self, past, beam_idx):
1927
+ # if decoder past is not included in output
1928
+ # speedy decoding is disabled and no need to reorder
1929
+ if past is None:
1930
+ logger.warning(
1931
+ "You might want to consider setting `use_cache=True` to speed up decoding"
1932
+ )
1933
+ return past
1934
+
1935
+ reordered_decoder_past = ()
1936
+ for layer_past_states in past:
1937
+ # get the correct batch idx from layer past batch dim
1938
+ # batch dim of `past` is at 2nd position
1939
+ reordered_layer_past_states = ()
1940
+ for layer_past_state in layer_past_states:
1941
+ # need to set correct `past` for each of the four key / value states
1942
+ reordered_layer_past_states = reordered_layer_past_states + (
1943
+ layer_past_state.index_select(
1944
+ 0, beam_idx.to(layer_past_state.device)
1945
+ ),
1946
+ )
1947
+
1948
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
1949
+ assert len(reordered_layer_past_states) == len(layer_past_states)
1950
+
1951
+ reordered_decoder_past = reordered_decoder_past + (
1952
+ reordered_layer_past_states,
1953
+ )
1954
+ return reordered_decoder_past
1955
+
1956
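Both hooks above are driven by `generate()`: once `past` is populated, `prepare_inputs_for_generation()` feeds only the last decoder token, and `_reorder_cache()` realigns the cached key/value states with the surviving beams after each beam-search step. A minimal sketch:

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer(
    "summarize: studies have shown that owning a dog is good for you",
    return_tensors="pt",
).input_ids
summary_ids = model.generate(input_ids, num_beams=4, max_length=20, use_cache=True)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```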
+
1957
+ @add_start_docstrings(
1958
+ "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
1959
+ T5_START_DOCSTRING,
1960
+ )
1961
+ class T5EncoderModel(T5PreTrainedModel):
1962
+ authorized_missing_keys = [
1963
+ r"encoder.embed_tokens.weight",
1964
+ ]
1965
+
1966
+ def __init__(self, config: T5Config):
1967
+ super().__init__(config)
1968
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1969
+
1970
+ encoder_config = copy.deepcopy(config)
1971
+ encoder_config.use_cache = False
1972
+ encoder_config.is_encoder_decoder = False
1973
+ self.encoder = T5Stack(encoder_config, self.shared)
1974
+
1975
+ # Initialize weights and apply final processing
1976
+ self.post_init()
1977
+
1978
+ # Model parallel
1979
+ self.model_parallel = False
1980
+ self.device_map = None
1981
+
1982
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1983
+ def parallelize(self, device_map=None):
1984
+ self.device_map = (
1985
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1986
+ if device_map is None
1987
+ else device_map
1988
+ )
1989
+ assert_device_map(self.device_map, len(self.encoder.block))
1990
+ self.encoder.parallelize(self.device_map)
1991
+ self.model_parallel = True
1992
+
1993
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1994
+ def deparallelize(self):
1995
+ self.encoder.deparallelize()
1996
+ self.encoder = self.encoder.to("cpu")
1997
+ self.model_parallel = False
1998
+ self.device_map = None
1999
+ torch.cuda.empty_cache()
2000
+
2001
+ def get_input_embeddings(self):
2002
+ return self.shared
2003
+
2004
+ def set_input_embeddings(self, new_embeddings):
2005
+ self.shared = new_embeddings
2006
+ self.encoder.set_input_embeddings(new_embeddings)
2007
+
2008
+ def get_encoder(self):
2009
+ return self.encoder
2010
+
2011
+ def _prune_heads(self, heads_to_prune):
2012
+ """
2013
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
2014
+ class PreTrainedModel
2015
+ """
2016
+ for layer, heads in heads_to_prune.items():
2017
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
2018
+
2019
+ @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
2020
+ @replace_return_docstrings(
2021
+ output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC
2022
+ )
2023
+ def forward(
2024
+ self,
2025
+ input_ids: Optional[torch.LongTensor] = None,
2026
+ attention_mask: Optional[torch.FloatTensor] = None,
2027
+ head_mask: Optional[torch.FloatTensor] = None,
2028
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2029
+ output_attentions: Optional[bool] = None,
2030
+ output_hidden_states: Optional[bool] = None,
2031
+ return_dict: Optional[bool] = None,
2032
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
2033
+ r"""
2034
+ Returns:
2035
+
2036
+ Example:
2037
+
2038
+ ```python
2039
+ >>> from transformers import T5Tokenizer, T5EncoderModel
2040
+
2041
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
2042
+ >>> model = T5EncoderModel.from_pretrained("t5-small")
2043
+ >>> input_ids = tokenizer(
2044
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
2045
+ ... ).input_ids # Batch size 1
2046
+ >>> outputs = model(input_ids=input_ids)
2047
+ >>> last_hidden_states = outputs.last_hidden_state
2048
+ ```"""
2049
+ return_dict = (
2050
+ return_dict if return_dict is not None else self.config.use_return_dict
2051
+ )
2052
+
2053
+ encoder_outputs = self.encoder(
2054
+ input_ids=input_ids,
2055
+ attention_mask=attention_mask,
2056
+ inputs_embeds=inputs_embeds,
2057
+ head_mask=head_mask,
2058
+ output_attentions=output_attentions,
2059
+ output_hidden_states=output_hidden_states,
2060
+ return_dict=return_dict,
2061
+ )
2062
+
2063
+ return encoder_outputs
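Beyond the docstring example, the encoder-only model works as a frozen text feature extractor; the masked mean pooling below is an illustrative choice, not something this file prescribes:

```python
import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small").eval()

batch = tokenizer(["a dog playing fetch", "a cat curled up on a sofa"],
                  return_tensors="pt", padding=True)
with torch.no_grad():
    hidden = model(**batch).last_hidden_state          # (batch, seq_len, d_model)

mask = batch.attention_mask.unsqueeze(-1).float()      # (batch, seq_len, 1)
sentence_emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
print(sentence_emb.shape)                              # torch.Size([2, 512]) for t5-small
```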
models_viclip/backbones/clip/__pycache__/clip_text.cpython-310.pyc ADDED
Binary file (9 kB). View file
 
models_viclip/backbones/clip/__pycache__/clip_text.cpython-38.pyc ADDED
Binary file (9.28 kB). View file
 
models_viclip/backbones/clip/__pycache__/clip_vision.cpython-310.pyc ADDED
Binary file (9.75 kB). View file