iioSnail committed on
Commit 3a346ec · verified · 1 Parent(s): 5300beb

Upload 9 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+config/ms_yahei.ttf filter=lfs diff=lfs merge=lfs -text
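The new attribute line routes the bundled TrueType font through Git LFS, which is why config/ms_yahei.ttf appears later in this commit only as an LFS pointer (version/oid/size) rather than as binary content.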
config.json ADDED
@@ -0,0 +1,35 @@
+{
+  "_name_or_path": "iioSnail/NamBert-for-csc",
+  "architectures": [
+    "NamBertForCSC"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoModel": "csc_model.NamBertForCSC"
+  },
+  "classifier_dropout": null,
+  "directionality": "bidi",
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "pooler_fc_size": 768,
+  "pooler_num_attention_heads": 12,
+  "pooler_num_fc_layers": 3,
+  "pooler_size_per_head": 128,
+  "pooler_type": "first_token_transform",
+  "position_embedding_type": "absolute",
+  "torch_dtype": "float32",
+  "transformers_version": "4.33.2",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 21128
+}
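The auto_map entry above is what lets transformers resolve the custom class shipped in this commit's csc_model.py. A minimal loading sketch (not part of the commit; the repo id is taken from _name_or_path):

from transformers import AutoModel

# Resolves "csc_model.NamBertForCSC" from the auto_map above; transformers
# refuses to execute repo-local code unless trust_remote_code=True.
model = AutoModel.from_pretrained("iioSnail/NamBert-for-csc", trust_remote_code=True)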
config/ms_yahei.ttf ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c3c0e7bbcec69ee4765a53831c7be310acaca1ec1b408974ca4f4c73c1aa400c
+size 15044440
csc_model.py ADDED
@@ -0,0 +1,161 @@
+import os
+import shutil
+import time
+from pathlib import Path
+
+import torch
+from huggingface_hub import hf_hub_download
+from torch import nn
+from transformers import BertPreTrainedModel, BertForMaskedLM
+from transformers.activations import GELUActivation
+from transformers.modeling_outputs import MaskedLMOutput
+
+# Files downloaded at runtime are cached next to this module.
+cache_path = Path(os.path.abspath(__file__)).parent
+
+
+def download_file(filename: str, path: Path):
+    """Fetch `filename` into the cache, preferring a local copy over the Hub."""
+    if os.path.exists(cache_path / filename):
+        return
+
+    if os.path.exists(path / filename):
+        shutil.copyfile(path / filename, cache_path / filename)
+        return
+
+    hf_hub_download(
+        "iioSnail/NamBert-for-csc",
+        filename,
+        local_dir=cache_path,
+    )
+    time.sleep(0.2)
+
+
+class NamBertForCSC(BertPreTrainedModel):
+    """BERT encoder fused with pinyin (phonetic) and glyph (visual) features for Chinese spelling correction."""
+
+    def __init__(self, config):
+        super(NamBertForCSC, self).__init__(config)
+
+        self.bert = BertForMaskedLM(config).bert
+        self.token_forget_gate = nn.Linear(768, 768, bias=False)
+
+        self.pinyin_feature_size = 6
+        self.pinyin_embeddings = PinyinManualEmbeddings()
+
+        self.glyph_feature_size = 56
+        self.glyph_embeddings = GlyphDenseEmbedding()
+
+        self.cls = BertOnlyMLMHead(768 + self.pinyin_feature_size + self.glyph_feature_size,
+                                   config.vocab_size, layer_num=1)
+
+    def forward(self, input_ids, attention_mask, token_type_ids, pinyin_ids, images, **kwargs):
+        batch_size = input_ids.size(0)
+
+        # The tokenizer always returns offset mappings, but BertModel does not accept them.
+        kwargs.pop('offset_mapping', None)
+
+        bert_outputs = self.bert(input_ids=input_ids,
+                                 attention_mask=attention_mask,
+                                 token_type_ids=token_type_ids,
+                                 **kwargs)
+
+        # Gated residual connection: re-inject the raw token embeddings,
+        # scaled by a learned sigmoid "forget gate".
+        token_embeddings = self.bert.embeddings(input_ids)
+        token_embeddings = token_embeddings * self.token_forget_gate(token_embeddings).sigmoid()
+        bert_outputs.last_hidden_state += token_embeddings
+
+        pinyin_embeddings = self.pinyin_embeddings(pinyin_ids)
+        pinyin_embeddings = pinyin_embeddings.view(batch_size, -1, self.pinyin_feature_size)
+
+        glyph_embeddings = self.glyph_embeddings(images)
+        glyph_embeddings = glyph_embeddings.view(batch_size, -1, self.glyph_feature_size)
+
+        # Concatenate semantic, phonetic, and glyph features per token before the MLM head.
+        hidden_states = torch.concat([bert_outputs.last_hidden_state,
+                                      pinyin_embeddings,
+                                      glyph_embeddings], dim=-1)
+
+        logits = self.cls(hidden_states)
+
+        return MaskedLMOutput(
+            logits=logits,
+            hidden_states=bert_outputs.hidden_states,
+            attentions=bert_outputs.attentions,
+        )
+
+
+class PinyinManualEmbeddings(nn.Module):
+    """Projects up-to-6 pinyin letter ids per token through a small linear layer."""
+
+    def __init__(self):
+        super(PinyinManualEmbeddings, self).__init__()
+        self.pinyin_feature_size = 6
+        self.embedding_layer = nn.Linear(6, 6, bias=True)
+
+    def forward(self, inputs):
+        # Right-pad each pinyin sequence with zeros to the fixed width of 6.
+        fill = self.pinyin_feature_size - inputs.size(1)
+        if fill > 0:
+            inputs = torch.concat([inputs, torch.zeros((len(inputs), fill), device=inputs.device)], dim=1).long()
+        inputs = self.embedding_layer(inputs.float())
+        return inputs
+
+
+class GlyphDenseEmbedding(nn.Module):
+    """Compresses a flattened 32x32 glyph bitmap (1024 pixels) into a 56-dim feature."""
+
+    def __init__(self, font_size=32):
+        super(GlyphDenseEmbedding, self).__init__()
+        self.font_size = font_size
+        self.embeddings = nn.Sequential(
+            nn.Linear(1024, 512),
+            nn.ReLU(),
+            nn.Dropout(0.15),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Dropout(0.15),
+            nn.Linear(256, 56),
+            nn.Tanh()
+        )
+
+    def forward(self, images):
+        batch_size = len(images)
+        images = images.view(batch_size, -1) / 255.  # normalize pixel values to [0, 1]
+        return self.embeddings(images)
+
+
+class BertOnlyMLMHead(nn.Module):
+    """MLM prediction head over the concatenated 830-dim (768 + 6 + 56) features."""
+
+    def __init__(self, hidden_size, vocab_size, activation='gelu', layer_num=1):
+        super().__init__()
+        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
+        self.bias = nn.Parameter(torch.zeros(vocab_size))
+        self.decoder.bias = self.bias
+
+        if activation == 'gelu':
+            self.activation = GELUActivation()
+        elif activation == 'tanh':
+            self.activation = nn.Tanh()
+        elif activation == 'sigmoid':
+            self.activation = nn.Sigmoid()
+        else:
+            raise ValueError(f"Unsupported activation: {activation}")
+
+        self.heads = []
+        for i in range(layer_num):
+            self.heads.append(nn.Sequential(
+                nn.Linear(hidden_size, hidden_size),
+                self.activation,
+                nn.LayerNorm(hidden_size, eps=1e-12, elementwise_affine=True),
+            ))
+
+        self.predictions = nn.Sequential(
+            *self.heads,
+            self.decoder
+        )
+
+    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
+        return self.predictions(sequence_output)
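Not part of the commit: a minimal shape-level smoke test of the wiring above, with random weights and dummy inputs (B = batch size, L = sequence length; the 21128 vocab size comes from config.json):

import torch
from transformers import BertConfig
from csc_model import NamBertForCSC

config = BertConfig(vocab_size=21128)   # defaults match the 768-dim sizes hard-coded above
model = NamBertForCSC(config).eval()

B, L = 2, 8
out = model(
    input_ids=torch.randint(0, 21128, (B, L)),
    attention_mask=torch.ones(B, L, dtype=torch.long),
    token_type_ids=torch.zeros(B, L, dtype=torch.long),
    pinyin_ids=torch.randint(0, 27, (B * L, 6)),   # up to 6 pinyin letter ids per token
    images=torch.rand(B * L, 32, 32) * 255,        # 32x32 glyph bitmaps
)
print(out.logits.shape)  # torch.Size([2, 8, 21128])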
csc_tokenizer.py ADDED
@@ -0,0 +1,216 @@
+import os
+import shutil
+import time
+from pathlib import Path
+from typing import List, Optional, Union
+
+import numpy as np
+import pypinyin
+import torch
+from PIL import ImageFont
+from huggingface_hub import hf_hub_download
+from torch.nn.utils.rnn import pad_sequence
+from transformers.tokenization_utils_base import TruncationStrategy
+from transformers.utils import PaddingStrategy
+from transformers.utils.generic import TensorType
+
+try:
+    from tokenizers import BertWordPieceTokenizer
+except ImportError:
+    from tokenizers.implementations import BertWordPieceTokenizer
+
+from transformers import BertTokenizerFast, BatchEncoding
+
+# `from torch import NoneType` fails on recent torch releases, so define it locally.
+NoneType = type(None)
+
+cache_path = Path(os.path.abspath(__file__)).parent
+
+
+def download_file(filename: str, path: Path):
+    """Fetch `filename` into the cache, preferring a local copy over the Hub."""
+    if os.path.exists(cache_path / filename):
+        return
+
+    if os.path.exists(path / filename):
+        shutil.copyfile(path / filename, cache_path / filename)
+        return
+
+    hf_hub_download(
+        "iioSnail/NamBert-for-csc",
+        filename,
+        local_dir=cache_path
+    )
+    time.sleep(0.2)
+
+
+class NamBertTokenizer(BertTokenizerFast):
+    """BertTokenizerFast that additionally builds pinyin ids and glyph images for each token."""
+
+    def __init__(self, **kwargs):
+        super(NamBertTokenizer, self).__init__(**kwargs)
+
+        self.path = Path(kwargs['name_or_path'])
+        vocab_file = cache_path / 'vocab.txt'
+        config_path = cache_path / 'config'
+        if not os.path.exists(config_path):
+            os.makedirs(config_path)
+
+        self.max_length = 20480
+        self.font_size = 32
+
+        download_file('vocab.txt', self.path)
+        download_file('config/ms_yahei.ttf', self.path)
+
+        self.tokenizer = BertWordPieceTokenizer(str(vocab_file))
+
+        # Pre-render every vocab token's glyph and pinyin once, at load time.
+        font = ImageFont.truetype(str(cache_path / 'config' / "ms_yahei.ttf"), size=self.font_size)
+        vocab = self.tokenizer.get_vocab().items()
+        self.input_helper = InputHelper(font, vocab)
+
+    def __call__(self,
+                 text: Union[str, List[str], List[List[str]]] = None,
+                 text_pair: Union[str, List[str], List[List[str]], NoneType] = None,
+                 text_target: Union[str, List[str], List[List[str]]] = None,
+                 text_pair_target: Union[str, List[str], List[List[str]], NoneType] = None,
+                 add_special_tokens: bool = True,
+                 padding: Union[bool, str, PaddingStrategy] = False,
+                 truncation: Union[bool, str, TruncationStrategy] = None,
+                 max_length: Optional[int] = None,
+                 stride: int = 0,
+                 is_split_into_words: bool = False,
+                 pad_to_multiple_of: Optional[int] = None,
+                 return_tensors: Union[str, TensorType, NoneType] = None,
+                 return_token_type_ids: Optional[bool] = None,
+                 return_attention_mask: Optional[bool] = None,
+                 return_overflowing_tokens: bool = False,
+                 return_special_tokens_mask: bool = False,
+                 return_offsets_mapping: bool = False,
+                 return_length: bool = False,
+                 verbose: bool = True, **kwargs) -> BatchEncoding:
+        # PyTorch tensors and offset mappings are always required to build the
+        # pinyin and glyph features, so those two arguments are forced here.
+        encoding = super(NamBertTokenizer, self).__call__(
+            text=text,
+            text_pair=text_pair,
+            text_target=text_target,
+            text_pair_target=text_pair_target,
+            add_special_tokens=add_special_tokens,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            stride=stride,
+            is_split_into_words=is_split_into_words,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_tensors='pt',
+            return_token_type_ids=return_token_type_ids,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_offsets_mapping=True,
+            return_length=return_length,
+            verbose=verbose,
+        )
+
+        input_ids = encoding.input_ids
+        encoding['pinyin_ids'] = self.input_helper.convert_tokens_to_pinyin_embeddings(input_ids.view(-1))
+        encoding['images'] = self.input_helper.convert_tokens_to_images(input_ids.view(-1))
+
+        return encoding
+
+    def restore_ids(self, target_ids, input_ids):
+        # Positions predicted as id 0 or 1 are treated as "no correction" and
+        # restored from the original input ids.
+        for i in range(len(target_ids)):
+            for j in range(len(target_ids[i])):
+                if target_ids[i][j] == 1 or target_ids[i][j] == 0:
+                    target_ids[i][j] = input_ids[i][j]
+
+        return target_ids
+
+
+class InputHelper:
+
+    def __init__(self, font, vocab):
+        self.font = font
+        self.vocab = vocab
+
+        self.pinyin_embedding_cache = None
+        self._init_pinyin_embedding_cache()
+
+        self.token_images_cache = None
+        self._init_token_images_cache()
+
+    def _init_pinyin_embedding_cache(self):
+        self.pinyin_embedding_cache = {}
+        for token, token_id in self.vocab:
+            self.pinyin_embedding_cache[token_id] = convert_char_to_pinyin(token)
+
+    def _init_token_images_cache(self):
+        self.token_images_cache = {}
+        for token, token_id in self.vocab:
+            self.token_images_cache[token_id] = convert_char_to_image(self.font, token, 32)
+
+    def convert_tokens_to_pinyin_embeddings(self, input_ids):
+        input_pinyins = []
+        for input_id in input_ids:
+            input_pinyins.append(self.pinyin_embedding_cache.get(input_id.item(), torch.LongTensor([0])))
+
+        return pad_sequence(input_pinyins, batch_first=True)
+
+    def convert_tokens_to_images(self, input_ids):
+        images = []
+        for input_id in input_ids:
+            images.append(self.token_images_cache.get(input_id.item(), torch.zeros(32, 32)))
+        return torch.stack(images)
+
+
+def convert_char_to_pinyin(character, size=-1, tone=False):
+    # Non-Chinese tokens get a zero vector.
+    if not is_chinese(character):
+        return torch.LongTensor([0] * max(size, 1))
+
+    if tone:
+        pinyin = pypinyin.pinyin(character, style=pypinyin.TONE3)[0][0]
+    else:
+        pinyin = pypinyin.pinyin(character, style=pypinyin.NORMAL)[0][0]
+
+    # Letters 'a'..'z' map to 1..26; with tones, digits map to 28 and above.
+    if not tone:
+        embeddings = torch.tensor([ord(letter) - 96 for letter in pinyin])
+    else:
+        embeddings = []
+        for letter in pinyin:
+            if letter.isnumeric():
+                embeddings.append(int(letter) + 27)
+            else:
+                embeddings.append(ord(letter) - 96)
+        embeddings = torch.tensor(embeddings)
+
+    if size > len(embeddings):
+        padding = torch.zeros(size - len(embeddings), dtype=embeddings.dtype)
+        embeddings = torch.concat([embeddings, padding])
+
+    return embeddings
+
+
+def convert_char_to_image(font, character, font_size=32):
+    # Rasterize the character and crop/center it on a font_size x font_size canvas.
+    image = font.getmask(character)
+    image = np.asarray(image).astype(np.float32).reshape(image.size[::-1])
+
+    image = image[:font_size, :font_size]
+
+    if image.shape != (font_size, font_size):  # `.shape`, not `.size` (an int for ndarrays)
+        back_image = np.zeros((font_size, font_size)).astype(np.float32)
+        offset0 = (font_size - image.shape[0]) // 2
+        offset1 = (font_size - image.shape[1]) // 2
+        back_image[offset0:offset0 + image.shape[0], offset1:offset1 + image.shape[1]] = image
+        image = back_image
+
+    return torch.tensor(image)
+
+
+def is_chinese(uchar):
+    # BMP range of common CJK unified ideographs.
+    return '\u4e00' <= uchar <= '\u9fa5'
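Taken together with csc_model.py above, a hedged end-to-end sketch (not part of the commit; the example sentence is an arbitrary toy input with a 炼/练 confusion, not from this repo):

import torch
from transformers import AutoModel, AutoTokenizer

# trust_remote_code pulls in csc_tokenizer.py / csc_model.py from this repo.
tokenizer = AutoTokenizer.from_pretrained("iioSnail/NamBert-for-csc", trust_remote_code=True)
model = AutoModel.from_pretrained("iioSnail/NamBert-for-csc", trust_remote_code=True)

inputs = tokenizer(["我是炼习时长两年半的个人练习生"])  # returns tensors ('pt' is forced)
with torch.no_grad():
    logits = model(**inputs).logits

pred_ids = logits.argmax(-1)
# Ids 0/1 mean "keep the input token here" (see restore_ids above).
pred_ids = tokenizer.restore_ids(pred_ids, inputs['input_ids'])
texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
print([t.replace(" ", "") for t in texts])  # batch_decode space-separates CJK tokens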
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0e5d81862b4673dab61f223c6e139f32e6a03625002803aae0e81b81de737dc
+size 484812057
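The pointer's size field (484,812,057 bytes, roughly 462 MiB) is consistent with a float32 checkpoint of this architecture: about 102M parameters for the BERT-base encoder plus roughly 18M more for the 830-dimensional MLM head.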
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+{
+  "cls_token": "[CLS]",
+  "mask_token": "[MASK]",
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "unk_token": "[UNK]"
+}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
+{
+  "auto_map": {
+    "AutoTokenizer": [
+      "csc_tokenizer.NamBertTokenizer",
+      null
+    ]
+  },
+  "clean_up_tokenization_spaces": true,
+  "cls_token": "[CLS]",
+  "do_lower_case": true,
+  "mask_token": "[MASK]",
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "strip_accents": null,
+  "tokenize_chinese_chars": true,
+  "tokenizer_class": "NamBertTokenizer",
+  "unk_token": "[UNK]"
+}
vocab.txt ADDED
The diff for this file is too large to render. See raw diff