Add new SentenceTransformer model
Files changed:

- .gitattributes +3 -0
- 1_Pooling/config.json +10 -0
- README.md +99 -155
- config.json +1 -1
- config_sentence_transformers.json +14 -0
- entity_embeddings.npy +3 -0
- entity_linker/config.json +1 -0
- entity_linker/data.npy +3 -0
- entity_linker/kb_id.trie +3 -0
- entity_linker/name.trie +3 -0
- entity_linker/offsets.npy +3 -0
- entity_vocab.tsv +3 -0
- modules.json +14 -0
- sentence_bert_config.json +4 -0
- special_tokens_map.json +37 -0
- tokenization_kpr.py +526 -0
- tokenizer_config.json +64 -0
- vocab.txt +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+entity_linker/kb_id.trie filter=lfs diff=lfs merge=lfs -text
+entity_linker/name.trie filter=lfs diff=lfs merge=lfs -text
+entity_vocab.tsv filter=lfs diff=lfs merge=lfs -text
1_Pooling/config.json
ADDED
@@ -0,0 +1,10 @@
{
  "word_embedding_dimension": 768,
  "pooling_mode_cls_token": true,
  "pooling_mode_mean_tokens": false,
  "pooling_mode_max_tokens": false,
  "pooling_mode_mean_sqrt_len_tokens": false,
  "pooling_mode_weightedmean_tokens": false,
  "pooling_mode_lasttoken": false,
  "include_prompt": true
}
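This pooling configuration keeps only the [CLS] token of the transformer output. As a rough illustration (not the Sentence Transformers implementation itself), CLS pooling reduces a batch of token embeddings to sentence embeddings like this:

```python
import torch

def cls_pooling(token_embeddings: torch.Tensor) -> torch.Tensor:
    # token_embeddings: (batch_size, seq_len, 768). CLS pooling keeps only the
    # hidden state of the first token ([CLS]) as the sentence embedding.
    return token_embeddings[:, 0]

# Dummy batch: 3 sequences of 12 tokens with 768-dimensional hidden states.
pooled = cls_pooling(torch.randn(3, 12, 768))
print(pooled.shape)  # torch.Size([3, 768])
```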
README.md
CHANGED
@@ -1,199 +1,143 @@
 ---
+tags:
+- sentence-transformers
+- sentence-similarity
+- feature-extraction
+- dense
+pipeline_tag: sentence-similarity
+library_name: sentence-transformers
 ---

-<!-- Provide a quick summary of what the model is/does. -->
+# SentenceTransformer

+This is a [sentence-transformers](https://www.SBERT.net) model trained. It maps sentences & paragraphs to a 768-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.

 ## Model Details

 ### Model Description

-This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
-
-- **Developed by:** [More Information Needed]
-- **Funded by [optional]:** [More Information Needed]
-- **Shared by [optional]:** [More Information Needed]
-- **Model type:** [More Information Needed]
-- **Language(s) (NLP):** [More Information Needed]
-- **License:** [More Information Needed]
-- **Finetuned from model [optional]:** [More Information Needed]
+- **Model Type:** Sentence Transformer
+<!-- - **Base model:** [Unknown](https://huggingface.co/unknown) -->
+- **Maximum Sequence Length:** 512 tokens
+- **Output Dimensionality:** 768 dimensions
+- **Similarity Function:** Cosine Similarity
+<!-- - **Training Dataset:** Unknown -->
+<!-- - **Language:** Unknown -->
+<!-- - **License:** Unknown -->

+### Model Sources

+- **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
+- **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
+- **Hugging Face:** [Sentence Transformers on Hugging Face](https://huggingface.co/models?library=sentence-transformers)

+### Full Model Architecture

+```
+SentenceTransformer(
+  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'KPRModelForBert'})
+  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
+)
+```

+## Usage

+### Direct Usage (Sentence Transformers)

+First install the Sentence Transformers library:

+```bash
+pip install -U sentence-transformers
+```

+Then you can load this model and run inference.
+```python
+from sentence_transformers import SentenceTransformer
+
+# Download from the 🤗 Hub
+model = SentenceTransformer("knowledgeable-ai/kpr-bge-base-en-v1.5")
+# Run inference
+sentences = [
+    'The weather is lovely today.',
+    "It's so sunny outside!",
+    'He drove to the stadium.',
+]
+embeddings = model.encode(sentences)
+print(embeddings.shape)
+# [3, 768]
+
+# Get the similarity scores for the embeddings
+similarities = model.similarity(embeddings, embeddings)
+print(similarities)
+# tensor([[1.0000, 0.7985, 0.4422],
+#         [0.7985, 1.0000, 0.4318],
+#         [0.4422, 0.4318, 1.0000]])
+```

+<!--
+### Direct Usage (Transformers)

+<details><summary>Click to see the direct usage in Transformers</summary>

+</details>
+-->

+<!--
+### Downstream Usage (Sentence Transformers)

+You can finetune this model on your own dataset.

+<details><summary>Click to expand</summary>

+</details>
+-->

+<!--
+### Out-of-Scope Use

+*List how the model may foreseeably be misused and address what users ought not to do with the model.*
+-->

+<!--
+## Bias, Risks and Limitations

+*What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
+-->

+<!--
+### Recommendations

+*What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
+-->

 ## Training Details

-<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
-
-#### Preprocessing [optional]
-
-[More Information Needed]
-
-#### Training Hyperparameters
-
-- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
-
-#### Speeds, Sizes, Times [optional]
-
-<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
-
-[More Information Needed]
-
-## Evaluation
-
-<!-- This section describes the evaluation protocols and provides the results. -->
-
-### Testing Data, Factors & Metrics
-
-#### Testing Data
-
-<!-- This should link to a Dataset Card if possible. -->
-
-[More Information Needed]
-
-#### Factors
-
-<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
-
-[More Information Needed]
-
-#### Metrics
-
-<!-- These are the evaluation metrics being used, ideally with a description of why. -->
-
-[More Information Needed]
-
-### Results
-
-[More Information Needed]
-
-#### Summary
-
-## Model Examination [optional]
-
-<!-- Relevant interpretability work for the model goes here -->
-
-[More Information Needed]
-
-## Environmental Impact
-
-<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
-
-Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
-
-- **Hardware Type:** [More Information Needed]
-- **Hours used:** [More Information Needed]
-- **Cloud Provider:** [More Information Needed]
-- **Compute Region:** [More Information Needed]
-- **Carbon Emitted:** [More Information Needed]
-
-## Technical Specifications [optional]
-
-### Model Architecture and Objective
-
-[More Information Needed]
-
-### Compute Infrastructure
-
-[More Information Needed]
-
-#### Hardware
-
-[More Information Needed]
-
-#### Software
-
-[More Information Needed]
-
-## Citation [optional]
-
-<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
-
-**BibTeX:**
-
-[More Information Needed]
-
-**APA:**
-
-[More Information Needed]
-
-## Glossary [optional]
-
+### Framework Versions
+- Python: 3.10.14
+- Sentence Transformers: 5.2.0.dev0
+- Transformers: 4.55.4
+- PyTorch: 2.4.0+cu121
+- Accelerate: 0.34.2
+- Datasets: 2.16.1
+- Tokenizers: 0.21.4

+## Citation

+### BibTeX

+<!--
+## Glossary

+*Clearly define terms in order to be accessible across audiences.*
+-->

+<!--
+## Model Card Authors

+*Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
+-->

+<!--
 ## Model Card Contact

+*Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
+-->
config.json
CHANGED
@@ -35,7 +35,7 @@
   "similarity_function": "cosine",
   "similarity_temperature": 0.02,
   "torch_dtype": "float32",
-  "transformers_version": "4.55.
+  "transformers_version": "4.55.4",
   "type_vocab_size": 2,
   "use_cache": true,
   "use_entity_position_embeddings": true,
config_sentence_transformers.json
ADDED
@@ -0,0 +1,14 @@
{
  "model_type": "SentenceTransformer",
  "__version__": {
    "sentence_transformers": "5.2.0.dev0",
    "transformers": "4.55.4",
    "pytorch": "2.4.0+cu121"
  },
  "prompts": {
    "query": "",
    "document": ""
  },
  "default_prompt_name": null,
  "similarity_fn_name": "cosine"
}
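This configuration declares `query` and `document` prompt names (both empty strings here) and cosine as the similarity function. As a hedged illustration of how these prompt names would be used with the Sentence Transformers API (the repo id follows the README usage example, and `trust_remote_code=True` is assumed because the repo ships custom code):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "knowledgeable-ai/kpr-bge-base-en-v1.5", trust_remote_code=True
)

# The prompts named in config_sentence_transformers.json are empty strings,
# so no prefix is actually added here; the call pattern is the same as for
# models that define non-empty query/document prompts.
query_emb = model.encode(["What is the capital of France?"], prompt_name="query")
doc_emb = model.encode(["Paris is the capital of France."], prompt_name="document")
print(model.similarity(query_emb, doc_emb))
```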
entity_embeddings.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bd83489d63bb45008620d90ba274331981546081491cfdd94be5afea9cb1cfea
size 11126965376
entity_linker/config.json
ADDED
@@ -0,0 +1 @@
{"max_mention_length": 100, "case_sensitive": false, "min_link_prob": 0.05, "min_prior_prob": 0.3, "min_link_count": 1}
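These thresholds control which dictionary matches are kept as entity mentions. A small illustrative calculation (with made-up counts) that mirrors the `link_prob` / `prior_prob` properties defined in `tokenization_kpr.py` below:

```python
# Illustrative only: one candidate surface form with made-up link statistics.
link_count = 40        # times this surface form linked to this particular entity
total_link_count = 50  # times this surface form was linked to any entity
doc_count = 400        # documents containing this surface form

link_prob = min(1.0, total_link_count / doc_count)    # 0.125
prior_prob = min(1.0, link_count / total_link_count)  # 0.8

keep = (
    link_prob >= 0.05      # min_link_prob
    and prior_prob >= 0.3  # min_prior_prob
    and link_count >= 1    # min_link_count
)
print(keep)  # True -> the mention would be kept
```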
entity_linker/data.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9e6bded234ea7be9250487f6e7ed26ccb3f04eb74443d78617d8328c3d3e41b
size 306472944
entity_linker/kb_id.trie
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:21ac7411d8a0e9a5497c2a4695b2b115b166824d34df4d4d01492b600ce375d8
size 13211960
entity_linker/name.trie
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3897d996a84067d9dffaa1a826e114be2c8368c011534048facee4c18fa5b69
size 96855400
entity_linker/offsets.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5ffb5444bbdcb9abcc63ef1a183cd6193c7346865a4b17c6ce7cbbed4c673ef2
size 70649476
entity_vocab.tsv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f835c91847892b455fe15eb75cc008e4504be4c37a7e1c0718165d029e03fda5
size 131894551
modules.json
ADDED
@@ -0,0 +1,14 @@
[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.models.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Pooling",
    "type": "sentence_transformers.models.Pooling"
  }
]
sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
{
  "max_seq_length": 512,
  "do_lower_case": false
}
special_tokens_map.json
ADDED
@@ -0,0 +1,37 @@
{
  "cls_token": {
    "content": "[CLS]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "mask_token": {
    "content": "[MASK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "[PAD]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "sep_token": {
    "content": "[SEP]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "[UNK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenization_kpr.py
ADDED
@@ -0,0 +1,526 @@
from __future__ import annotations

import csv
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple

import numpy as np
import torch
import spacy
from marisa_trie import Trie
from transformers import BatchEncoding, BertTokenizer, PreTrainedTokenizerBase

NONE_ID = "<None>"


@dataclass
class Mention:
    kb_id: str | None
    text: str
    start: int
    end: int
    link_count: int | None
    total_link_count: int | None
    doc_count: int | None

    @property
    def span(self) -> tuple[int, int]:
        return self.start, self.end

    @property
    def link_prob(self) -> float | None:
        if self.doc_count is None or self.total_link_count is None:
            return None
        elif self.doc_count > 0:
            return min(1.0, self.total_link_count / self.doc_count)
        else:
            return 0.0

    @property
    def prior_prob(self) -> float | None:
        if self.link_count is None or self.total_link_count is None:
            return None
        elif self.total_link_count > 0:
            return min(1.0, self.link_count / self.total_link_count)
        else:
            return 0.0

    def __repr__(self):
        return f"<Mention {self.text} -> {self.kb_id}>"


def get_tokenizer(language: str) -> spacy.tokenizer.Tokenizer:
    language_obj = spacy.blank(language)
    return language_obj.tokenizer


class DictionaryEntityLinker:
    def __init__(
        self,
        name_trie: Trie,
        kb_id_trie: Trie,
        data: np.ndarray,
        offsets: np.ndarray,
        max_mention_length: int,
        case_sensitive: bool,
        min_link_prob: float | None,
        min_prior_prob: float | None,
        min_link_count: int | None,
    ):
        self._name_trie = name_trie
        self._kb_id_trie = kb_id_trie
        self._data = data
        self._offsets = offsets
        self._max_mention_length = max_mention_length
        self._case_sensitive = case_sensitive

        self._min_link_prob = min_link_prob
        self._min_prior_prob = min_prior_prob
        self._min_link_count = min_link_count

        self._tokenizer = get_tokenizer("en")

    @staticmethod
    def load(
        data_dir: str,
        min_link_prob: float | None = None,
        min_prior_prob: float | None = None,
        min_link_count: int | None = None,
    ) -> "DictionaryEntityLinker":
        data = np.load(os.path.join(data_dir, "data.npy"))
        offsets = np.load(os.path.join(data_dir, "offsets.npy"))
        name_trie = Trie()
        name_trie.load(os.path.join(data_dir, "name.trie"))
        kb_id_trie = Trie()
        kb_id_trie.load(os.path.join(data_dir, "kb_id.trie"))

        with open(os.path.join(data_dir, "config.json")) as config_file:
            config = json.load(config_file)

        if min_link_prob is None:
            min_link_prob = config.get("min_link_prob", None)

        if min_prior_prob is None:
            min_prior_prob = config.get("min_prior_prob", None)

        if min_link_count is None:
            min_link_count = config.get("min_link_count", None)

        return DictionaryEntityLinker(
            name_trie=name_trie,
            kb_id_trie=kb_id_trie,
            data=data,
            offsets=offsets,
            max_mention_length=config["max_mention_length"],
            case_sensitive=config["case_sensitive"],
            min_link_prob=min_link_prob,
            min_prior_prob=min_prior_prob,
            min_link_count=min_link_count,
        )

    def detect_mentions(self, text: str) -> list[Mention]:
        tokens = self._tokenizer(text)
        end_offsets = frozenset(token.idx + len(token) for token in tokens)
        if not self._case_sensitive:
            text = text.lower()

        ret = []
        cur = 0
        for token in tokens:
            start = token.idx
            if cur > start:
                continue

            for prefix in sorted(
                self._name_trie.prefixes(text[start : start + self._max_mention_length]),
                key=len,
                reverse=True,
            ):
                end = start + len(prefix)
                if end in end_offsets:
                    matched = False
                    mention_idx = self._name_trie[prefix]
                    data_start, data_end = self._offsets[mention_idx : mention_idx + 2]
                    for item in self._data[data_start:data_end]:
                        if item.size == 4:
                            kb_idx, link_count, total_link_count, doc_count = item
                        elif item.size == 1:
                            (kb_idx,) = item
                            link_count, total_link_count, doc_count = None, None, None
                        else:
                            raise ValueError("Unexpected data array format")

                        mention = Mention(
                            kb_id=self._kb_id_trie.restore_key(kb_idx),
                            text=prefix,
                            start=start,
                            end=end,
                            link_count=link_count,
                            total_link_count=total_link_count,
                            doc_count=doc_count,
                        )
                        if item.size == 1 or (
                            mention.link_prob >= self._min_link_prob
                            and mention.prior_prob >= self._min_prior_prob
                            and mention.link_count >= self._min_link_count
                        ):
                            ret.append(mention)

                        matched = True

                    if matched:
                        cur = end
                        break

        return ret

    def detect_mentions_batch(self, texts: list[str]) -> list[list[Mention]]:
        return [self.detect_mentions(text) for text in texts]

    def save(self, data_dir: str) -> None:
        """
        Save the entity linker data to the specified directory.

        Args:
            data_dir: Directory to save the entity linker data
        """
        os.makedirs(data_dir, exist_ok=True)

        # Save numpy arrays
        np.save(os.path.join(data_dir, "data.npy"), self._data)
        np.save(os.path.join(data_dir, "offsets.npy"), self._offsets)

        # Save tries
        self._name_trie.save(os.path.join(data_dir, "name.trie"))
        self._kb_id_trie.save(os.path.join(data_dir, "kb_id.trie"))

        # Save configuration
        with open(os.path.join(data_dir, "config.json"), "w") as config_file:
            json.dump(
                {
                    "max_mention_length": self._max_mention_length,
                    "case_sensitive": self._case_sensitive,
                    "min_link_prob": self._min_link_prob,
                    "min_prior_prob": self._min_prior_prob,
                    "min_link_count": self._min_link_count,
                },
                config_file,
            )


def load_tsv_entity_vocab(file_path: str) -> dict[str, int]:
    vocab = {}
    with open(file_path, "r", encoding="utf-8") as file:
        reader = csv.reader(file, delimiter="\t")
        for row in reader:
            vocab[row[0]] = int(row[1])
    return vocab


def save_tsv_entity_vocab(file_path: str, entity_vocab: dict[str, int]) -> None:
    """
    Save entity vocabulary to a TSV file.

    Args:
        file_path: Path to save the entity vocabulary
        entity_vocab: Entity vocabulary to save
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t")
        for entity_id, idx in entity_vocab.items():
            writer.writerow([entity_id, idx])


class _Entity(NamedTuple):
    entity_id: int
    start: int
    end: int

    @property
    def length(self) -> int:
        return self.end - self.start


def preprocess_text(
    text: str,
    mentions: list[Mention] | None,
    title: str | None,
    title_mentions: list[Mention] | None,
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> dict[str, list[int]]:
    tokens = []
    entity_ids = []
    entity_position_ids = []
    if title is not None:
        if title_mentions is None:
            title_mentions = []

        title_tokens, title_entities = _tokenize_text_with_mentions(title, title_mentions, tokenizer, entity_vocab)
        tokens += title_tokens + [tokenizer.sep_token]
        for entity in title_entities:
            entity_ids.append(entity.entity_id)
            entity_position_ids.append(list(range(entity.start, entity.end)))

    if mentions is None:
        mentions = []

    entity_offset = len(tokens)
    text_tokens, text_entities = _tokenize_text_with_mentions(text, mentions, tokenizer, entity_vocab)
    tokens += text_tokens
    for entity in text_entities:
        entity_ids.append(entity.entity_id)
        entity_position_ids.append(list(range(entity.start + entity_offset, entity.end + entity_offset)))

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    return {
        "input_ids": input_ids,
        "entity_ids": entity_ids,
        "entity_position_ids": entity_position_ids,
    }


def _tokenize_text_with_mentions(
    text: str,
    mentions: list[Mention],
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> tuple[list[str], list[_Entity]]:
    """
    Tokenize text while preserving mention boundaries and mapping entities.

    Args:
        text: Input text to tokenize
        mentions: List of detected mentions in the text
        tokenizer: Pre-trained tokenizer to use for tokenization
        entity_vocab: Mapping from entity KB IDs to entity vocabulary indices

    Returns:
        Tuple containing:
            - List of tokens from the tokenized text
            - List of _Entity objects with entity IDs and token positions
    """
    target_mentions = [mention for mention in mentions if mention.kb_id is not None and mention.kb_id in entity_vocab]
    split_char_positions = {mention.start for mention in target_mentions} | {mention.end for mention in target_mentions}

    tokens: list[str] = []
    cur = 0
    char_to_token_mapping = {}
    for char_position in sorted(split_char_positions):
        target_text = text[cur:char_position]
        tokens += tokenizer.tokenize(target_text)
        char_to_token_mapping[char_position] = len(tokens)
        cur = char_position
    tokens += tokenizer.tokenize(text[cur:])

    entities = [
        _Entity(
            entity_vocab[mention.kb_id],
            char_to_token_mapping[mention.start],
            char_to_token_mapping[mention.end],
        )
        for mention in target_mentions
    ]
    return tokens, entities


class KPRBertTokenizer(BertTokenizer):
    vocab_files_names = {
        **BertTokenizer.vocab_files_names,  # Include the parent class files (vocab.txt)
        "entity_linker_data_file": "entity_linker/data.npy",
        "entity_linker_offsets_file": "entity_linker/offsets.npy",
        "entity_linker_name_trie_file": "entity_linker/name.trie",
        "entity_linker_kb_id_trie_file": "entity_linker/kb_id.trie",
        "entity_linker_config_file": "entity_linker/config.json",
        "entity_vocab_file": "entity_vocab.tsv",
        "entity_embeddings_file": "entity_embeddings.npy",
    }
    model_input_names = [
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "entity_ids",
        "entity_position_ids",
    ]

    def __init__(
        self,
        vocab_file,
        entity_linker_data_file: str,
        entity_vocab_file: str,
        entity_embeddings_file: str | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(vocab_file=vocab_file, *args, **kwargs)
        entity_linker_dir = str(Path(entity_linker_data_file).parent)
        self.entity_linker = DictionaryEntityLinker.load(entity_linker_dir)
        self.entity_to_id = load_tsv_entity_vocab(entity_vocab_file)
        self.id_to_entity = {v: k for k, v in self.entity_to_id.items()}

        self.entity_embeddings = None
        if entity_embeddings_file:
            # Use memory-mapped loading for large embeddings
            self.entity_embeddings = np.load(entity_embeddings_file, mmap_mode="r")
            if self.entity_embeddings.shape[0] != len(self.entity_to_id):
                raise ValueError(
                    f"Entity embeddings shape {self.entity_embeddings.shape[0]} does not match "
                    f"the number of entities {len(self.entity_to_id)}. "
                    "Make sure `embeddings.py` and `entity_vocab.tsv` are consistent."
                )

    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]:
        mentions = self.entity_linker.detect_mentions(text)
        model_inputs = preprocess_text(
            text=text,
            mentions=mentions,
            title=None,
            title_mentions=None,
            tokenizer=self,
            entity_vocab=self.entity_to_id,
        )

        # Prepare the inputs for the model
        # This will add special tokens or truncate the input when specified in kwargs
        # We exclude "return_tensors" from kwargs
        # to avoid issues in passing the data to BatchEncoding outside this method
        prepared_inputs = self.prepare_for_model(
            model_inputs["input_ids"],
            **{k: v for k, v in kwargs.items() if k != "return_tensors"},
        )
        model_inputs.update(prepared_inputs)

        # Account for special tokens
        if kwargs.get("add_special_tokens", True):
            if prepared_inputs["input_ids"][0] != self.cls_token_id:
                raise ValueError(
                    "We assume that the input IDs start with the [CLS] token with add_special_tokens = True."
                )
            # Shift the entity position IDs by 1 to account for the [CLS] token
            model_inputs["entity_position_ids"] = [
                [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"]
            ]

        # If there are no entities in the text, we output the padding entity for the model
        if not model_inputs["entity_ids"]:
            model_inputs["entity_ids"] = [0]  # The padding entity id is 0
            model_inputs["entity_position_ids"] = [[0]]

        # Count the number of special tokens at the end of the input
        num_special_tokens_at_end = 0
        input_ids = prepared_inputs["input_ids"]
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        for input_id in input_ids[::-1]:
            if int(input_id) not in {
                self.sep_token_id,
                self.pad_token_id,
                self.cls_token_id,
            }:
                break
            num_special_tokens_at_end += 1

        # Remove entities that are not in the truncated input
        max_effective_pos = len(model_inputs["input_ids"]) - num_special_tokens_at_end
        entity_indices_to_keep = list()
        for i, position_ids in enumerate(model_inputs["entity_position_ids"]):
            if len(position_ids) > 0 and max(position_ids) < max_effective_pos:
                entity_indices_to_keep.append(i)
        model_inputs["entity_ids"] = [model_inputs["entity_ids"][i] for i in entity_indices_to_keep]
        model_inputs["entity_position_ids"] = [model_inputs["entity_position_ids"][i] for i in entity_indices_to_keep]

        if self.entity_embeddings is not None:
            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32)
        return model_inputs

    def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding:
        for unsupported_arg in ["text_pair", "text_target", "text_pair_target"]:
            if unsupported_arg in kwargs:
                raise ValueError(
                    f"Argument '{unsupported_arg}' is not supported by {self.__class__.__name__}. "
                    "This tokenizer only supports single text inputs. "
                )

        if isinstance(text, str):
            processed_inputs = self._preprocess_text(text, **kwargs)
            return BatchEncoding(
                processed_inputs,
                tensor_type=kwargs.get("return_tensors", None),
                prepend_batch_axis=True,
            )

        processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text]
        collated_inputs = {
            key: [item[key] for item in processed_inputs_list] for key in processed_inputs_list[0].keys()
        }
        if kwargs.get("padding"):
            collated_inputs = self.pad(
                collated_inputs,
                padding=kwargs["padding"],
                max_length=kwargs.get("max_length"),
                pad_to_multiple_of=kwargs.get("pad_to_multiple_of"),
                return_attention_mask=kwargs.get("return_attention_mask"),
                verbose=kwargs.get("verbose", True),
            )
            # Pad entity ids
            max_num_entities = max(len(ids) for ids in collated_inputs["entity_ids"])
            for entity_ids in collated_inputs["entity_ids"]:
                entity_ids += [0] * (max_num_entities - len(entity_ids))
            # Pad entity position ids
            flattened_entity_length = [
                len(ids) for ids_list in collated_inputs["entity_position_ids"] for ids in ids_list
            ]
            max_entity_token_length = max(flattened_entity_length) if flattened_entity_length else 0
            for entity_position_ids_list in collated_inputs["entity_position_ids"]:
                # pad entity_position_ids to max_entity_token_length
                for entity_position_ids in entity_position_ids_list:
                    entity_position_ids += [0] * (max_entity_token_length - len(entity_position_ids))
                # pad to max_num_entities
                entity_position_ids_list += [[0 for _ in range(max_entity_token_length)]] * (
                    max_num_entities - len(entity_position_ids_list)
                )
            # Pad entity embeddings
            if "entity_embeds" in collated_inputs:
                for i in range(len(collated_inputs["entity_embeds"])):
                    collated_inputs["entity_embeds"][i] = np.pad(
                        collated_inputs["entity_embeds"][i],
                        pad_width=(
                            (
                                0,
                                max_num_entities - len(collated_inputs["entity_embeds"][i]),
                            ),
                            (0, 0),
                        ),
                        mode="constant",
                        constant_values=0,
                    )
        return BatchEncoding(collated_inputs, tensor_type=kwargs.get("return_tensors", None))

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
        os.makedirs(save_directory, exist_ok=True)
        saved_files = list(super().save_vocabulary(save_directory, filename_prefix))

        # Save entity linker data
        entity_linker_save_dir = str(
            Path(save_directory) / Path(self.vocab_files_names["entity_linker_data_file"]).parent
        )
        self.entity_linker.save(entity_linker_save_dir)
        for file_name in self.vocab_files_names.values():
            if file_name.startswith("entity_linker/"):
                saved_files.append(file_name)

        # Save entity vocabulary
        entity_vocab_path = str(Path(save_directory) / self.vocab_files_names["entity_vocab_file"])
        save_tsv_entity_vocab(entity_vocab_path, self.entity_to_id)
        saved_files.append(self.vocab_files_names["entity_vocab_file"])

        if self.entity_embeddings is not None:
            entity_embeddings_path = str(Path(save_directory) / self.vocab_files_names["entity_embeddings_file"])
            np.save(entity_embeddings_path, self.entity_embeddings)
            saved_files.append(self.vocab_files_names["entity_embeddings_file"])
        return tuple(saved_files)
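Since `tokenizer_config.json` (below) maps `AutoTokenizer` to `tokenization_kpr.KPRBertTokenizer`, the custom tokenizer can presumably be loaded with remote code enabled. A minimal sketch, assuming the Hub repo id used in the README usage example:

```python
from transformers import AutoTokenizer

# trust_remote_code=True lets transformers import tokenization_kpr.py from the repo.
tokenizer = AutoTokenizer.from_pretrained(
    "knowledgeable-ai/kpr-bge-base-en-v1.5", trust_remote_code=True
)

encoded = tokenizer("The Eiffel Tower is located in Paris.", return_tensors="pt")
# Besides the usual BERT fields, the tokenizer emits entity_ids / entity_position_ids
# (and entity_embeds when entity_embeddings.npy is loaded), matching model_input_names above.
print(encoded.keys())
```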
tokenizer_config.json
ADDED
@@ -0,0 +1,64 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_kpr.KPRBertTokenizer",
      null
    ]
  },
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": true,
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "KPRBertTokenizer",
  "unk_token": "[UNK]"
}
vocab.txt
ADDED
The diff for this file is too large to render; see the raw diff.