ikuyamada commited on
Commit
0c98849
·
verified ·
1 Parent(s): 3f94a29

Add new SentenceTransformer model

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ entity_linker/kb_id.trie filter=lfs diff=lfs merge=lfs -text
37
+ entity_linker/name.trie filter=lfs diff=lfs merge=lfs -text
38
+ entity_vocab.tsv filter=lfs diff=lfs merge=lfs -text
1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 768,
3
+ "pooling_mode_cls_token": true,
4
+ "pooling_mode_mean_tokens": false,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false,
7
+ "pooling_mode_weightedmean_tokens": false,
8
+ "pooling_mode_lasttoken": false,
9
+ "include_prompt": true
10
+ }
README.md CHANGED
@@ -1,199 +1,143 @@
1
  ---
2
- library_name: transformers
3
- tags: []
 
 
 
 
 
4
  ---
5
 
6
- # Model Card for Model ID
7
-
8
- <!-- Provide a quick summary of what the model is/does. -->
9
-
10
 
 
11
 
12
  ## Model Details
13
 
14
  ### Model Description
 
 
 
 
 
 
 
 
15
 
16
- <!-- Provide a longer summary of what this model is. -->
17
-
18
- This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
-
20
- - **Developed by:** [More Information Needed]
21
- - **Funded by [optional]:** [More Information Needed]
22
- - **Shared by [optional]:** [More Information Needed]
23
- - **Model type:** [More Information Needed]
24
- - **Language(s) (NLP):** [More Information Needed]
25
- - **License:** [More Information Needed]
26
- - **Finetuned from model [optional]:** [More Information Needed]
27
 
28
- ### Model Sources [optional]
 
 
29
 
30
- <!-- Provide the basic links for the model. -->
31
 
32
- - **Repository:** [More Information Needed]
33
- - **Paper [optional]:** [More Information Needed]
34
- - **Demo [optional]:** [More Information Needed]
 
 
 
35
 
36
- ## Uses
37
 
38
- <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
 
40
- ### Direct Use
41
 
42
- <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
 
 
43
 
44
- [More Information Needed]
 
 
45
 
46
- ### Downstream Use [optional]
 
 
 
 
 
 
 
 
 
 
47
 
48
- <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
 
 
 
 
 
 
49
 
50
- [More Information Needed]
 
51
 
52
- ### Out-of-Scope Use
53
 
54
- <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
 
55
 
56
- [More Information Needed]
 
57
 
58
- ## Bias, Risks, and Limitations
59
 
60
- <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
 
62
- [More Information Needed]
 
63
 
64
- ### Recommendations
 
65
 
66
- <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
 
67
 
68
- Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
 
69
 
70
- ## How to Get Started with the Model
 
71
 
72
- Use the code below to get started with the model.
 
73
 
74
- [More Information Needed]
 
75
 
76
  ## Training Details
77
 
78
- ### Training Data
79
-
80
- <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
-
82
- [More Information Needed]
83
-
84
- ### Training Procedure
85
-
86
- <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
-
88
- #### Preprocessing [optional]
89
-
90
- [More Information Needed]
91
-
92
-
93
- #### Training Hyperparameters
94
-
95
- - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
-
97
- #### Speeds, Sizes, Times [optional]
98
-
99
- <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
-
101
- [More Information Needed]
102
-
103
- ## Evaluation
104
-
105
- <!-- This section describes the evaluation protocols and provides the results. -->
106
-
107
- ### Testing Data, Factors & Metrics
108
-
109
- #### Testing Data
110
-
111
- <!-- This should link to a Dataset Card if possible. -->
112
-
113
- [More Information Needed]
114
-
115
- #### Factors
116
-
117
- <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
-
119
- [More Information Needed]
120
-
121
- #### Metrics
122
-
123
- <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
-
125
- [More Information Needed]
126
-
127
- ### Results
128
-
129
- [More Information Needed]
130
-
131
- #### Summary
132
-
133
-
134
-
135
- ## Model Examination [optional]
136
-
137
- <!-- Relevant interpretability work for the model goes here -->
138
-
139
- [More Information Needed]
140
-
141
- ## Environmental Impact
142
-
143
- <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
-
145
- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
-
147
- - **Hardware Type:** [More Information Needed]
148
- - **Hours used:** [More Information Needed]
149
- - **Cloud Provider:** [More Information Needed]
150
- - **Compute Region:** [More Information Needed]
151
- - **Carbon Emitted:** [More Information Needed]
152
-
153
- ## Technical Specifications [optional]
154
-
155
- ### Model Architecture and Objective
156
-
157
- [More Information Needed]
158
-
159
- ### Compute Infrastructure
160
-
161
- [More Information Needed]
162
-
163
- #### Hardware
164
-
165
- [More Information Needed]
166
-
167
- #### Software
168
-
169
- [More Information Needed]
170
-
171
- ## Citation [optional]
172
-
173
- <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
-
175
- **BibTeX:**
176
-
177
- [More Information Needed]
178
-
179
- **APA:**
180
-
181
- [More Information Needed]
182
-
183
- ## Glossary [optional]
184
 
185
- <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
 
187
- [More Information Needed]
188
 
189
- ## More Information [optional]
 
190
 
191
- [More Information Needed]
 
192
 
193
- ## Model Card Authors [optional]
 
194
 
195
- [More Information Needed]
 
196
 
 
197
  ## Model Card Contact
198
 
199
- [More Information Needed]
 
 
1
  ---
2
+ tags:
3
+ - sentence-transformers
4
+ - sentence-similarity
5
+ - feature-extraction
6
+ - dense
7
+ pipeline_tag: sentence-similarity
8
+ library_name: sentence-transformers
9
  ---
10
 
11
+ # SentenceTransformer
 
 
 
12
 
13
+ This is a [sentence-transformers](https://www.SBERT.net) model trained. It maps sentences & paragraphs to a 768-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
14
 
15
  ## Model Details
16
 
17
  ### Model Description
18
+ - **Model Type:** Sentence Transformer
19
+ <!-- - **Base model:** [Unknown](https://huggingface.co/unknown) -->
20
+ - **Maximum Sequence Length:** 512 tokens
21
+ - **Output Dimensionality:** 768 dimensions
22
+ - **Similarity Function:** Cosine Similarity
23
+ <!-- - **Training Dataset:** Unknown -->
24
+ <!-- - **Language:** Unknown -->
25
+ <!-- - **License:** Unknown -->
26
 
27
+ ### Model Sources
 
 
 
 
 
 
 
 
 
 
28
 
29
+ - **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
30
+ - **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
31
+ - **Hugging Face:** [Sentence Transformers on Hugging Face](https://huggingface.co/models?library=sentence-transformers)
32
 
33
+ ### Full Model Architecture
34
 
35
+ ```
36
+ SentenceTransformer(
37
+ (0): Transformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'KPRModelForBert'})
38
+ (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
39
+ )
40
+ ```
41
 
42
+ ## Usage
43
 
44
+ ### Direct Usage (Sentence Transformers)
45
 
46
+ First install the Sentence Transformers library:
47
 
48
+ ```bash
49
+ pip install -U sentence-transformers
50
+ ```
51
 
52
+ Then you can load this model and run inference.
53
+ ```python
54
+ from sentence_transformers import SentenceTransformer
55
 
56
+ # Download from the 🤗 Hub
57
+ model = SentenceTransformer("knowledgeable-ai/kpr-bge-base-en-v1.5")
58
+ # Run inference
59
+ sentences = [
60
+ 'The weather is lovely today.',
61
+ "It's so sunny outside!",
62
+ 'He drove to the stadium.',
63
+ ]
64
+ embeddings = model.encode(sentences)
65
+ print(embeddings.shape)
66
+ # [3, 768]
67
 
68
+ # Get the similarity scores for the embeddings
69
+ similarities = model.similarity(embeddings, embeddings)
70
+ print(similarities)
71
+ # tensor([[1.0000, 0.7985, 0.4422],
72
+ # [0.7985, 1.0000, 0.4318],
73
+ # [0.4422, 0.4318, 1.0000]])
74
+ ```
75
 
76
+ <!--
77
+ ### Direct Usage (Transformers)
78
 
79
+ <details><summary>Click to see the direct usage in Transformers</summary>
80
 
81
+ </details>
82
+ -->
83
 
84
+ <!--
85
+ ### Downstream Usage (Sentence Transformers)
86
 
87
+ You can finetune this model on your own dataset.
88
 
89
+ <details><summary>Click to expand</summary>
90
 
91
+ </details>
92
+ -->
93
 
94
+ <!--
95
+ ### Out-of-Scope Use
96
 
97
+ *List how the model may foreseeably be misused and address what users ought not to do with the model.*
98
+ -->
99
 
100
+ <!--
101
+ ## Bias, Risks and Limitations
102
 
103
+ *What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
104
+ -->
105
 
106
+ <!--
107
+ ### Recommendations
108
 
109
+ *What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
110
+ -->
111
 
112
  ## Training Details
113
 
114
+ ### Framework Versions
115
+ - Python: 3.10.14
116
+ - Sentence Transformers: 5.2.0.dev0
117
+ - Transformers: 4.55.4
118
+ - PyTorch: 2.4.0+cu121
119
+ - Accelerate: 0.34.2
120
+ - Datasets: 2.16.1
121
+ - Tokenizers: 0.21.4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ ## Citation
124
 
125
+ ### BibTeX
126
 
127
+ <!--
128
+ ## Glossary
129
 
130
+ *Clearly define terms in order to be accessible across audiences.*
131
+ -->
132
 
133
+ <!--
134
+ ## Model Card Authors
135
 
136
+ *Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
137
+ -->
138
 
139
+ <!--
140
  ## Model Card Contact
141
 
142
+ *Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
143
+ -->
config.json CHANGED
@@ -35,7 +35,7 @@
35
  "similarity_function": "cosine",
36
  "similarity_temperature": 0.02,
37
  "torch_dtype": "float32",
38
- "transformers_version": "4.55.2",
39
  "type_vocab_size": 2,
40
  "use_cache": true,
41
  "use_entity_position_embeddings": true,
 
35
  "similarity_function": "cosine",
36
  "similarity_temperature": 0.02,
37
  "torch_dtype": "float32",
38
+ "transformers_version": "4.55.4",
39
  "type_vocab_size": 2,
40
  "use_cache": true,
41
  "use_entity_position_embeddings": true,
config_sentence_transformers.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "SentenceTransformer",
3
+ "__version__": {
4
+ "sentence_transformers": "5.2.0.dev0",
5
+ "transformers": "4.55.4",
6
+ "pytorch": "2.4.0+cu121"
7
+ },
8
+ "prompts": {
9
+ "query": "",
10
+ "document": ""
11
+ },
12
+ "default_prompt_name": null,
13
+ "similarity_fn_name": "cosine"
14
+ }
entity_embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd83489d63bb45008620d90ba274331981546081491cfdd94be5afea9cb1cfea
3
+ size 11126965376
entity_linker/config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"max_mention_length": 100, "case_sensitive": false, "min_link_prob": 0.05, "min_prior_prob": 0.3, "min_link_count": 1}
entity_linker/data.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9e6bded234ea7be9250487f6e7ed26ccb3f04eb74443d78617d8328c3d3e41b
3
+ size 306472944
entity_linker/kb_id.trie ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21ac7411d8a0e9a5497c2a4695b2b115b166824d34df4d4d01492b600ce375d8
3
+ size 13211960
entity_linker/name.trie ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3897d996a84067d9dffaa1a826e114be2c8368c011534048facee4c18fa5b69
3
+ size 96855400
entity_linker/offsets.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ffb5444bbdcb9abcc63ef1a183cd6193c7346865a4b17c6ce7cbbed4c673ef2
3
+ size 70649476
entity_vocab.tsv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f835c91847892b455fe15eb75cc008e4504be4c37a7e1c0718165d029e03fda5
3
+ size 131894551
modules.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ }
14
+ ]
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "max_seq_length": 512,
3
+ "do_lower_case": false
4
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
tokenization_kpr.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import json
5
+ import os
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import NamedTuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ import spacy
13
+ from marisa_trie import Trie
14
+ from transformers import BatchEncoding, BertTokenizer, PreTrainedTokenizerBase
15
+
16
+ NONE_ID = "<None>"
17
+
18
+
19
+ @dataclass
20
+ class Mention:
21
+ kb_id: str | None
22
+ text: str
23
+ start: int
24
+ end: int
25
+ link_count: int | None
26
+ total_link_count: int | None
27
+ doc_count: int | None
28
+
29
+ @property
30
+ def span(self) -> tuple[int, int]:
31
+ return self.start, self.end
32
+
33
+ @property
34
+ def link_prob(self) -> float | None:
35
+ if self.doc_count is None or self.total_link_count is None:
36
+ return None
37
+ elif self.doc_count > 0:
38
+ return min(1.0, self.total_link_count / self.doc_count)
39
+ else:
40
+ return 0.0
41
+
42
+ @property
43
+ def prior_prob(self) -> float | None:
44
+ if self.link_count is None or self.total_link_count is None:
45
+ return None
46
+ elif self.total_link_count > 0:
47
+ return min(1.0, self.link_count / self.total_link_count)
48
+ else:
49
+ return 0.0
50
+
51
+ def __repr__(self):
52
+ return f"<Mention {self.text} -> {self.kb_id}>"
53
+
54
+
55
+ def get_tokenizer(language: str) -> spacy.tokenizer.Tokenizer:
56
+ language_obj = spacy.blank(language)
57
+ return language_obj.tokenizer
58
+
59
+
60
+ class DictionaryEntityLinker:
61
+ def __init__(
62
+ self,
63
+ name_trie: Trie,
64
+ kb_id_trie: Trie,
65
+ data: np.ndarray,
66
+ offsets: np.ndarray,
67
+ max_mention_length: int,
68
+ case_sensitive: bool,
69
+ min_link_prob: float | None,
70
+ min_prior_prob: float | None,
71
+ min_link_count: int | None,
72
+ ):
73
+ self._name_trie = name_trie
74
+ self._kb_id_trie = kb_id_trie
75
+ self._data = data
76
+ self._offsets = offsets
77
+ self._max_mention_length = max_mention_length
78
+ self._case_sensitive = case_sensitive
79
+
80
+ self._min_link_prob = min_link_prob
81
+ self._min_prior_prob = min_prior_prob
82
+ self._min_link_count = min_link_count
83
+
84
+ self._tokenizer = get_tokenizer("en")
85
+
86
+ @staticmethod
87
+ def load(
88
+ data_dir: str,
89
+ min_link_prob: float | None = None,
90
+ min_prior_prob: float | None = None,
91
+ min_link_count: int | None = None,
92
+ ) -> "DictionaryEntityLinker":
93
+ data = np.load(os.path.join(data_dir, "data.npy"))
94
+ offsets = np.load(os.path.join(data_dir, "offsets.npy"))
95
+ name_trie = Trie()
96
+ name_trie.load(os.path.join(data_dir, "name.trie"))
97
+ kb_id_trie = Trie()
98
+ kb_id_trie.load(os.path.join(data_dir, "kb_id.trie"))
99
+
100
+ with open(os.path.join(data_dir, "config.json")) as config_file:
101
+ config = json.load(config_file)
102
+
103
+ if min_link_prob is None:
104
+ min_link_prob = config.get("min_link_prob", None)
105
+
106
+ if min_prior_prob is None:
107
+ min_prior_prob = config.get("min_prior_prob", None)
108
+
109
+ if min_link_count is None:
110
+ min_link_count = config.get("min_link_count", None)
111
+
112
+ return DictionaryEntityLinker(
113
+ name_trie=name_trie,
114
+ kb_id_trie=kb_id_trie,
115
+ data=data,
116
+ offsets=offsets,
117
+ max_mention_length=config["max_mention_length"],
118
+ case_sensitive=config["case_sensitive"],
119
+ min_link_prob=min_link_prob,
120
+ min_prior_prob=min_prior_prob,
121
+ min_link_count=min_link_count,
122
+ )
123
+
124
+ def detect_mentions(self, text: str) -> list[Mention]:
125
+ tokens = self._tokenizer(text)
126
+ end_offsets = frozenset(token.idx + len(token) for token in tokens)
127
+ if not self._case_sensitive:
128
+ text = text.lower()
129
+
130
+ ret = []
131
+ cur = 0
132
+ for token in tokens:
133
+ start = token.idx
134
+ if cur > start:
135
+ continue
136
+
137
+ for prefix in sorted(
138
+ self._name_trie.prefixes(text[start : start + self._max_mention_length]),
139
+ key=len,
140
+ reverse=True,
141
+ ):
142
+ end = start + len(prefix)
143
+ if end in end_offsets:
144
+ matched = False
145
+ mention_idx = self._name_trie[prefix]
146
+ data_start, data_end = self._offsets[mention_idx : mention_idx + 2]
147
+ for item in self._data[data_start:data_end]:
148
+ if item.size == 4:
149
+ kb_idx, link_count, total_link_count, doc_count = item
150
+ elif item.size == 1:
151
+ (kb_idx,) = item
152
+ link_count, total_link_count, doc_count = None, None, None
153
+ else:
154
+ raise ValueError("Unexpected data array format")
155
+
156
+ mention = Mention(
157
+ kb_id=self._kb_id_trie.restore_key(kb_idx),
158
+ text=prefix,
159
+ start=start,
160
+ end=end,
161
+ link_count=link_count,
162
+ total_link_count=total_link_count,
163
+ doc_count=doc_count,
164
+ )
165
+ if item.size == 1 or (
166
+ mention.link_prob >= self._min_link_prob
167
+ and mention.prior_prob >= self._min_prior_prob
168
+ and mention.link_count >= self._min_link_count
169
+ ):
170
+ ret.append(mention)
171
+
172
+ matched = True
173
+
174
+ if matched:
175
+ cur = end
176
+ break
177
+
178
+ return ret
179
+
180
+ def detect_mentions_batch(self, texts: list[str]) -> list[list[Mention]]:
181
+ return [self.detect_mentions(text) for text in texts]
182
+
183
+ def save(self, data_dir: str) -> None:
184
+ """
185
+ Save the entity linker data to the specified directory.
186
+
187
+ Args:
188
+ data_dir: Directory to save the entity linker data
189
+ """
190
+ os.makedirs(data_dir, exist_ok=True)
191
+
192
+ # Save numpy arrays
193
+ np.save(os.path.join(data_dir, "data.npy"), self._data)
194
+ np.save(os.path.join(data_dir, "offsets.npy"), self._offsets)
195
+
196
+ # Save tries
197
+ self._name_trie.save(os.path.join(data_dir, "name.trie"))
198
+ self._kb_id_trie.save(os.path.join(data_dir, "kb_id.trie"))
199
+
200
+ # Save configuration
201
+ with open(os.path.join(data_dir, "config.json"), "w") as config_file:
202
+ json.dump(
203
+ {
204
+ "max_mention_length": self._max_mention_length,
205
+ "case_sensitive": self._case_sensitive,
206
+ "min_link_prob": self._min_link_prob,
207
+ "min_prior_prob": self._min_prior_prob,
208
+ "min_link_count": self._min_link_count,
209
+ },
210
+ config_file,
211
+ )
212
+
213
+
214
+ def load_tsv_entity_vocab(file_path: str) -> dict[str, int]:
215
+ vocab = {}
216
+ with open(file_path, "r", encoding="utf-8") as file:
217
+ reader = csv.reader(file, delimiter="\t")
218
+ for row in reader:
219
+ vocab[row[0]] = int(row[1])
220
+ return vocab
221
+
222
+
223
+ def save_tsv_entity_vocab(file_path: str, entity_vocab: dict[str, int]) -> None:
224
+ """
225
+ Save entity vocabulary to a TSV file.
226
+
227
+ Args:
228
+ file_path: Path to save the entity vocabulary
229
+ entity_vocab: Entity vocabulary to save
230
+ """
231
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
232
+ with open(file_path, "w", encoding="utf-8") as f:
233
+ writer = csv.writer(f, delimiter="\t")
234
+ for entity_id, idx in entity_vocab.items():
235
+ writer.writerow([entity_id, idx])
236
+
237
+
238
+ class _Entity(NamedTuple):
239
+ entity_id: int
240
+ start: int
241
+ end: int
242
+
243
+ @property
244
+ def length(self) -> int:
245
+ return self.end - self.start
246
+
247
+
248
+ def preprocess_text(
249
+ text: str,
250
+ mentions: list[Mention] | None,
251
+ title: str | None,
252
+ title_mentions: list[Mention] | None,
253
+ tokenizer: PreTrainedTokenizerBase,
254
+ entity_vocab: dict[str, int],
255
+ ) -> dict[str, list[int]]:
256
+ tokens = []
257
+ entity_ids = []
258
+ entity_position_ids = []
259
+ if title is not None:
260
+ if title_mentions is None:
261
+ title_mentions = []
262
+
263
+ title_tokens, title_entities = _tokenize_text_with_mentions(title, title_mentions, tokenizer, entity_vocab)
264
+ tokens += title_tokens + [tokenizer.sep_token]
265
+ for entity in title_entities:
266
+ entity_ids.append(entity.entity_id)
267
+ entity_position_ids.append(list(range(entity.start, entity.end)))
268
+
269
+ if mentions is None:
270
+ mentions = []
271
+
272
+ entity_offset = len(tokens)
273
+ text_tokens, text_entities = _tokenize_text_with_mentions(text, mentions, tokenizer, entity_vocab)
274
+ tokens += text_tokens
275
+ for entity in text_entities:
276
+ entity_ids.append(entity.entity_id)
277
+ entity_position_ids.append(list(range(entity.start + entity_offset, entity.end + entity_offset)))
278
+
279
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
280
+
281
+ return {
282
+ "input_ids": input_ids,
283
+ "entity_ids": entity_ids,
284
+ "entity_position_ids": entity_position_ids,
285
+ }
286
+
287
+
288
+ def _tokenize_text_with_mentions(
289
+ text: str,
290
+ mentions: list[Mention],
291
+ tokenizer: PreTrainedTokenizerBase,
292
+ entity_vocab: dict[str, int],
293
+ ) -> tuple[list[str], list[_Entity]]:
294
+ """
295
+ Tokenize text while preserving mention boundaries and mapping entities.
296
+
297
+ Args:
298
+ text: Input text to tokenize
299
+ mentions: List of detected mentions in the text
300
+ tokenizer: Pre-trained tokenizer to use for tokenization
301
+ entity_vocab: Mapping from entity KB IDs to entity vocabulary indices
302
+
303
+ Returns:
304
+ Tuple containing:
305
+ - List of tokens from the tokenized text
306
+ - List of _Entity objects with entity IDs and token positions
307
+ """
308
+ target_mentions = [mention for mention in mentions if mention.kb_id is not None and mention.kb_id in entity_vocab]
309
+ split_char_positions = {mention.start for mention in target_mentions} | {mention.end for mention in target_mentions}
310
+
311
+ tokens: list[str] = []
312
+ cur = 0
313
+ char_to_token_mapping = {}
314
+ for char_position in sorted(split_char_positions):
315
+ target_text = text[cur:char_position]
316
+ tokens += tokenizer.tokenize(target_text)
317
+ char_to_token_mapping[char_position] = len(tokens)
318
+ cur = char_position
319
+ tokens += tokenizer.tokenize(text[cur:])
320
+
321
+ entities = [
322
+ _Entity(
323
+ entity_vocab[mention.kb_id],
324
+ char_to_token_mapping[mention.start],
325
+ char_to_token_mapping[mention.end],
326
+ )
327
+ for mention in target_mentions
328
+ ]
329
+ return tokens, entities
330
+
331
+
332
+ class KPRBertTokenizer(BertTokenizer):
333
+ vocab_files_names = {
334
+ **BertTokenizer.vocab_files_names, # Include the parent class files (vocab.txt)
335
+ "entity_linker_data_file": "entity_linker/data.npy",
336
+ "entity_linker_offsets_file": "entity_linker/offsets.npy",
337
+ "entity_linker_name_trie_file": "entity_linker/name.trie",
338
+ "entity_linker_kb_id_trie_file": "entity_linker/kb_id.trie",
339
+ "entity_linker_config_file": "entity_linker/config.json",
340
+ "entity_vocab_file": "entity_vocab.tsv",
341
+ "entity_embeddings_file": "entity_embeddings.npy",
342
+ }
343
+ model_input_names = [
344
+ "input_ids",
345
+ "token_type_ids",
346
+ "attention_mask",
347
+ "entity_ids",
348
+ "entity_position_ids",
349
+ ]
350
+
351
+ def __init__(
352
+ self,
353
+ vocab_file,
354
+ entity_linker_data_file: str,
355
+ entity_vocab_file: str,
356
+ entity_embeddings_file: str | None = None,
357
+ *args,
358
+ **kwargs,
359
+ ):
360
+ super().__init__(vocab_file=vocab_file, *args, **kwargs)
361
+ entity_linker_dir = str(Path(entity_linker_data_file).parent)
362
+ self.entity_linker = DictionaryEntityLinker.load(entity_linker_dir)
363
+ self.entity_to_id = load_tsv_entity_vocab(entity_vocab_file)
364
+ self.id_to_entity = {v: k for k, v in self.entity_to_id.items()}
365
+
366
+ self.entity_embeddings = None
367
+ if entity_embeddings_file:
368
+ # Use memory-mapped loading for large embeddings
369
+ self.entity_embeddings = np.load(entity_embeddings_file, mmap_mode="r")
370
+ if self.entity_embeddings.shape[0] != len(self.entity_to_id):
371
+ raise ValueError(
372
+ f"Entity embeddings shape {self.entity_embeddings.shape[0]} does not match "
373
+ f"the number of entities {len(self.entity_to_id)}. "
374
+ "Make sure `embeddings.py` and `entity_vocab.tsv` are consistent."
375
+ )
376
+
377
+ def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]:
378
+ mentions = self.entity_linker.detect_mentions(text)
379
+ model_inputs = preprocess_text(
380
+ text=text,
381
+ mentions=mentions,
382
+ title=None,
383
+ title_mentions=None,
384
+ tokenizer=self,
385
+ entity_vocab=self.entity_to_id,
386
+ )
387
+
388
+ # Prepare the inputs for the model
389
+ # This will add special tokens or truncate the input when specified in kwargs
390
+ # We exclude "return_tensors" from kwargs
391
+ # to avoid issues in passing the data to BatchEncoding outside this method
392
+ prepared_inputs = self.prepare_for_model(
393
+ model_inputs["input_ids"],
394
+ **{k: v for k, v in kwargs.items() if k != "return_tensors"},
395
+ )
396
+ model_inputs.update(prepared_inputs)
397
+
398
+ # Account for special tokens
399
+ if kwargs.get("add_special_tokens", True):
400
+ if prepared_inputs["input_ids"][0] != self.cls_token_id:
401
+ raise ValueError(
402
+ "We assume that the input IDs start with the [CLS] token with add_special_tokens = True."
403
+ )
404
+ # Shift the entity position IDs by 1 to account for the [CLS] token
405
+ model_inputs["entity_position_ids"] = [
406
+ [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"]
407
+ ]
408
+
409
+ # If there is no entities in the text, we output padding entity for the model
410
+ if not model_inputs["entity_ids"]:
411
+ model_inputs["entity_ids"] = [0] # The padding entity id is 0
412
+ model_inputs["entity_position_ids"] = [[0]]
413
+
414
+ # Count the number of special tokens at the end of the input
415
+ num_special_tokens_at_end = 0
416
+ input_ids = prepared_inputs["input_ids"]
417
+ if isinstance(input_ids, torch.Tensor):
418
+ input_ids = input_ids.tolist()
419
+ for input_id in input_ids[::-1]:
420
+ if int(input_id) not in {
421
+ self.sep_token_id,
422
+ self.pad_token_id,
423
+ self.cls_token_id,
424
+ }:
425
+ break
426
+ num_special_tokens_at_end += 1
427
+
428
+ # Remove entities that are not in truncated input
429
+ max_effective_pos = len(model_inputs["input_ids"]) - num_special_tokens_at_end
430
+ entity_indices_to_keep = list()
431
+ for i, position_ids in enumerate(model_inputs["entity_position_ids"]):
432
+ if len(position_ids) > 0 and max(position_ids) < max_effective_pos:
433
+ entity_indices_to_keep.append(i)
434
+ model_inputs["entity_ids"] = [model_inputs["entity_ids"][i] for i in entity_indices_to_keep]
435
+ model_inputs["entity_position_ids"] = [model_inputs["entity_position_ids"][i] for i in entity_indices_to_keep]
436
+
437
+ if self.entity_embeddings is not None:
438
+ model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32)
439
+ return model_inputs
440
+
441
+ def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding:
442
+ for unsupported_arg in ["text_pair", "text_target", "text_pair_target"]:
443
+ if unsupported_arg in kwargs:
444
+ raise ValueError(
445
+ f"Argument '{unsupported_arg}' is not supported by {self.__class__.__name__}. "
446
+ "This tokenizer only supports single text inputs. "
447
+ )
448
+
449
+ if isinstance(text, str):
450
+ processed_inputs = self._preprocess_text(text, **kwargs)
451
+ return BatchEncoding(
452
+ processed_inputs,
453
+ tensor_type=kwargs.get("return_tensors", None),
454
+ prepend_batch_axis=True,
455
+ )
456
+
457
+ processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text]
458
+ collated_inputs = {
459
+ key: [item[key] for item in processed_inputs_list] for key in processed_inputs_list[0].keys()
460
+ }
461
+ if kwargs.get("padding"):
462
+ collated_inputs = self.pad(
463
+ collated_inputs,
464
+ padding=kwargs["padding"],
465
+ max_length=kwargs.get("max_length"),
466
+ pad_to_multiple_of=kwargs.get("pad_to_multiple_of"),
467
+ return_attention_mask=kwargs.get("return_attention_mask"),
468
+ verbose=kwargs.get("verbose", True),
469
+ )
470
+ # Pad entity ids
471
+ max_num_entities = max(len(ids) for ids in collated_inputs["entity_ids"])
472
+ for entity_ids in collated_inputs["entity_ids"]:
473
+ entity_ids += [0] * (max_num_entities - len(entity_ids))
474
+ # Pad entity position ids
475
+ flattened_entity_length = [
476
+ len(ids) for ids_list in collated_inputs["entity_position_ids"] for ids in ids_list
477
+ ]
478
+ max_entity_token_length = max(flattened_entity_length) if flattened_entity_length else 0
479
+ for entity_position_ids_list in collated_inputs["entity_position_ids"]:
480
+ # pad entity_position_ids to max_entity_token_length
481
+ for entity_position_ids in entity_position_ids_list:
482
+ entity_position_ids += [0] * (max_entity_token_length - len(entity_position_ids))
483
+ # pad to max_num_entities
484
+ entity_position_ids_list += [[0 for _ in range(max_entity_token_length)]] * (
485
+ max_num_entities - len(entity_position_ids_list)
486
+ )
487
+ # Pad entity embeddings
488
+ if "entity_embeds" in collated_inputs:
489
+ for i in range(len(collated_inputs["entity_embeds"])):
490
+ collated_inputs["entity_embeds"][i] = np.pad(
491
+ collated_inputs["entity_embeds"][i],
492
+ pad_width=(
493
+ (
494
+ 0,
495
+ max_num_entities - len(collated_inputs["entity_embeds"][i]),
496
+ ),
497
+ (0, 0),
498
+ ),
499
+ mode="constant",
500
+ constant_values=0,
501
+ )
502
+ return BatchEncoding(collated_inputs, tensor_type=kwargs.get("return_tensors", None))
503
+
504
+ def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
505
+ os.makedirs(save_directory, exist_ok=True)
506
+ saved_files = list(super().save_vocabulary(save_directory, filename_prefix))
507
+
508
+ # Save entity linker data
509
+ entity_linker_save_dir = str(
510
+ Path(save_directory) / Path(self.vocab_files_names["entity_linker_data_file"]).parent
511
+ )
512
+ self.entity_linker.save(entity_linker_save_dir)
513
+ for file_name in self.vocab_files_names.values():
514
+ if file_name.startswith("entity_linker/"):
515
+ saved_files.append(file_name)
516
+
517
+ # Save entity vocabulary
518
+ entity_vocab_path = str(Path(save_directory) / self.vocab_files_names["entity_vocab_file"])
519
+ save_tsv_entity_vocab(entity_vocab_path, self.entity_to_id)
520
+ saved_files.append(self.vocab_files_names["entity_vocab_file"])
521
+
522
+ if self.entity_embeddings is not None:
523
+ entity_embeddings_path = str(Path(save_directory) / self.vocab_files_names["entity_embeddings_file"])
524
+ np.save(entity_embeddings_path, self.entity_embeddings)
525
+ saved_files.append(self.vocab_files_names["entity_embeddings_file"])
526
+ return tuple(saved_files)
tokenizer_config.json ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "auto_map": {
45
+ "AutoTokenizer": [
46
+ "tokenization_kpr.KPRBertTokenizer",
47
+ null
48
+ ]
49
+ },
50
+ "clean_up_tokenization_spaces": true,
51
+ "cls_token": "[CLS]",
52
+ "do_basic_tokenize": true,
53
+ "do_lower_case": true,
54
+ "extra_special_tokens": {},
55
+ "mask_token": "[MASK]",
56
+ "model_max_length": 512,
57
+ "never_split": null,
58
+ "pad_token": "[PAD]",
59
+ "sep_token": "[SEP]",
60
+ "strip_accents": null,
61
+ "tokenize_chinese_chars": true,
62
+ "tokenizer_class": "KPRBertTokenizer",
63
+ "unk_token": "[UNK]"
64
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff