# kpr-bge-base-en-v1.5 / tokenization_kpr.py
# ikuyamada's picture
# Add new SentenceTransformer model
# 0c98849 verified
from __future__ import annotations
import csv
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple
import numpy as np
import torch
import spacy
from marisa_trie import Trie
from transformers import BatchEncoding, BertTokenizer, PreTrainedTokenizerBase
# Literal "<None>" marker string. It is not referenced anywhere else in the visible
# part of this file.
# NOTE(review): presumably a sentinel KB id meaning "no entity" - confirm against callers.
NONE_ID = "<None>"
@dataclass
class Mention:
    """One candidate entity for a span of text, with its optional count statistics.

    ``start`` and ``end`` are the character offsets of the span. The three count
    fields may all be ``None`` when no statistics are available for the candidate.
    """

    kb_id: str | None
    text: str
    start: int
    end: int
    link_count: int | None
    total_link_count: int | None
    doc_count: int | None

    @staticmethod
    def _capped_ratio(numerator: int | None, denominator: int | None) -> float | None:
        """Return ``numerator / denominator`` capped at 1.0.

        Yields ``None`` when either operand is missing and ``0.0`` when the
        denominator is not positive.
        """
        if numerator is None or denominator is None:
            return None
        if denominator > 0:
            return min(1.0, numerator / denominator)
        return 0.0

    @property
    def span(self) -> tuple[int, int]:
        """The ``(start, end)`` pair of this mention."""
        return (self.start, self.end)

    @property
    def link_prob(self) -> float | None:
        """``total_link_count / doc_count`` capped at 1.0, or ``None`` without counts."""
        return self._capped_ratio(self.total_link_count, self.doc_count)

    @property
    def prior_prob(self) -> float | None:
        """``link_count / total_link_count`` capped at 1.0, or ``None`` without counts."""
        return self._capped_ratio(self.link_count, self.total_link_count)

    def __repr__(self):
        return f"<Mention {self.text} -> {self.kb_id}>"
def get_tokenizer(language: str) -> spacy.tokenizer.Tokenizer:
    """Return the tokenizer of a blank spaCy pipeline created for ``language``."""
    return spacy.blank(language).tokenizer
class DictionaryEntityLinker:
    """Dictionary-based mention detector backed by two tries and two NumPy arrays.

    ``name_trie`` maps a surface string to an index into ``offsets``;
    ``offsets[i]:offsets[i + 1]`` selects the rows of ``data`` holding the
    candidates for that surface string. Each row carries either a single KB
    index or ``(kb_idx, link_count, total_link_count, doc_count)``.
    ``kb_id_trie`` turns a KB index back into its KB id string.
    """

    def __init__(
        self,
        name_trie: Trie,
        kb_id_trie: Trie,
        data: np.ndarray,
        offsets: np.ndarray,
        max_mention_length: int,
        case_sensitive: bool,
        min_link_prob: float | None,
        min_prior_prob: float | None,
        min_link_count: int | None,
    ):
        """Store the lookup structures and thresholds and build an English tokenizer.

        Args:
            name_trie: Trie of mention surface strings.
            kb_id_trie: Trie used to restore KB id strings from their indices.
            data: Candidate rows (1 or 4 values per row).
            offsets: Row boundaries into ``data``, indexed by ``name_trie`` ids.
            max_mention_length: Maximum number of characters looked at from a
                token start when searching for a surface string.
            case_sensitive: When false, the text is lower-cased before lookup.
            min_link_prob: Minimum ``Mention.link_prob``; ``None`` disables the check.
            min_prior_prob: Minimum ``Mention.prior_prob``; ``None`` disables the check.
            min_link_count: Minimum ``Mention.link_count``; ``None`` disables the check.
        """
        self._name_trie = name_trie
        self._kb_id_trie = kb_id_trie
        self._data = data
        self._offsets = offsets
        self._max_mention_length = max_mention_length
        self._case_sensitive = case_sensitive
        self._min_link_prob = min_link_prob
        self._min_prior_prob = min_prior_prob
        self._min_link_count = min_link_count

        self._tokenizer = get_tokenizer("en")

    @staticmethod
    def load(
        data_dir: str,
        min_link_prob: float | None = None,
        min_prior_prob: float | None = None,
        min_link_count: int | None = None,
    ) -> "DictionaryEntityLinker":
        """Build a linker from the files written by :meth:`save`.

        Args:
            data_dir: Directory containing ``data.npy``, ``offsets.npy``,
                ``name.trie``, ``kb_id.trie`` and ``config.json``.
            min_link_prob: Overrides the value stored in ``config.json`` when given.
            min_prior_prob: Overrides the value stored in ``config.json`` when given.
            min_link_count: Overrides the value stored in ``config.json`` when given.

        Returns:
            The loaded linker. A threshold that is neither passed here nor present
            in the config stays ``None``, which disables that check.
        """
        data = np.load(os.path.join(data_dir, "data.npy"))
        offsets = np.load(os.path.join(data_dir, "offsets.npy"))

        name_trie = Trie()
        name_trie.load(os.path.join(data_dir, "name.trie"))
        kb_id_trie = Trie()
        kb_id_trie.load(os.path.join(data_dir, "kb_id.trie"))

        with open(os.path.join(data_dir, "config.json"), encoding="utf-8") as config_file:
            config = json.load(config_file)

        # Explicit arguments win; otherwise fall back to the stored configuration.
        if min_link_prob is None:
            min_link_prob = config.get("min_link_prob", None)
        if min_prior_prob is None:
            min_prior_prob = config.get("min_prior_prob", None)
        if min_link_count is None:
            min_link_count = config.get("min_link_count", None)

        return DictionaryEntityLinker(
            name_trie=name_trie,
            kb_id_trie=kb_id_trie,
            data=data,
            offsets=offsets,
            max_mention_length=config["max_mention_length"],
            case_sensitive=config["case_sensitive"],
            min_link_prob=min_link_prob,
            min_prior_prob=min_prior_prob,
            min_link_count=min_link_count,
        )

    @staticmethod
    def _meets_minimum(value, minimum) -> bool:
        """Return True when ``minimum`` is unset (``None``) or ``value >= minimum``."""
        return minimum is None or bool(value >= minimum)

    def _passes_filters(self, mention: Mention) -> bool:
        """Check a candidate with statistics against the configured thresholds.

        A threshold left as ``None`` is treated as "no minimum" instead of being
        compared against, which would raise ``TypeError``.
        """
        return (
            self._meets_minimum(mention.link_prob, self._min_link_prob)
            and self._meets_minimum(mention.prior_prob, self._min_prior_prob)
            and self._meets_minimum(mention.link_count, self._min_link_count)
        )

    def detect_mentions(self, text: str) -> list[Mention]:
        """Find dictionary mentions in ``text``, preferring the longest match.

        Scanning goes token by token. From each token start, dictionary prefixes
        of the following ``max_mention_length`` characters are tried longest
        first. A prefix is only considered when it ends on a token end offset.
        The first such prefix with at least one accepted candidate wins: all of
        its accepted candidates are returned (so several mentions can share a
        span) and tokens starting before its end are skipped.

        Args:
            text: Text to scan.

        Returns:
            Accepted mentions in scan order. When the linker is case-insensitive,
            ``Mention.text`` holds the lower-cased surface string.

        Raises:
            ValueError: If a candidate row has neither 1 nor 4 values.
        """
        tokens = self._tokenizer(text)
        end_offsets = frozenset(token.idx + len(token) for token in tokens)
        if not self._case_sensitive:
            # NOTE(review): str.lower() can change the length of a few Unicode
            # characters, which would shift offsets relative to the tokenizer's;
            # confirm inputs are not affected.
            text = text.lower()

        ret = []
        cur = 0  # character offset up to which text is already covered by a match
        for token in tokens:
            start = token.idx
            if cur > start:
                continue

            candidates = sorted(
                self._name_trie.prefixes(text[start : start + self._max_mention_length]),
                key=len,
                reverse=True,
            )
            for prefix in candidates:
                end = start + len(prefix)
                if end not in end_offsets:
                    # The surface string would stop in the middle of a token.
                    continue

                matched = False
                mention_idx = self._name_trie[prefix]
                data_start, data_end = self._offsets[mention_idx : mention_idx + 2]
                for item in self._data[data_start:data_end]:
                    if item.size == 4:
                        kb_idx, link_count, total_link_count, doc_count = item
                    elif item.size == 1:
                        (kb_idx,) = item
                        link_count, total_link_count, doc_count = None, None, None
                    else:
                        raise ValueError("Unexpected data array format")

                    mention = Mention(
                        kb_id=self._kb_id_trie.restore_key(kb_idx),
                        text=prefix,
                        start=start,
                        end=end,
                        link_count=link_count,
                        total_link_count=total_link_count,
                        doc_count=doc_count,
                    )
                    # Rows without statistics cannot be filtered and are always kept.
                    if item.size == 1 or self._passes_filters(mention):
                        ret.append(mention)
                        matched = True

                if matched:
                    cur = end
                    break

        return ret

    def detect_mentions_batch(self, texts: list[str]) -> list[list[Mention]]:
        """Run :meth:`detect_mentions` on every text and return the results in order."""
        return [self.detect_mentions(text) for text in texts]

    def save(self, data_dir: str) -> None:
        """
        Save the entity linker data to the specified directory.

        Writes ``data.npy``, ``offsets.npy``, ``name.trie``, ``kb_id.trie`` and
        ``config.json``; the directory is created when missing.

        Args:
            data_dir: Directory to save the entity linker data
        """
        os.makedirs(data_dir, exist_ok=True)

        # Save numpy arrays
        np.save(os.path.join(data_dir, "data.npy"), self._data)
        np.save(os.path.join(data_dir, "offsets.npy"), self._offsets)

        # Save tries
        self._name_trie.save(os.path.join(data_dir, "name.trie"))
        self._kb_id_trie.save(os.path.join(data_dir, "kb_id.trie"))

        # Save configuration
        with open(os.path.join(data_dir, "config.json"), "w", encoding="utf-8") as config_file:
            json.dump(
                {
                    "max_mention_length": self._max_mention_length,
                    "case_sensitive": self._case_sensitive,
                    "min_link_prob": self._min_link_prob,
                    "min_prior_prob": self._min_prior_prob,
                    "min_link_count": self._min_link_count,
                },
                config_file,
            )
def load_tsv_entity_vocab(file_path: str) -> dict[str, int]:
    """
    Load an entity vocabulary from a TSV file.

    Each non-empty row holds the entity id in the first column and its integer
    index in the second one. Empty rows are skipped.

    Args:
        file_path: Path of the TSV file to read

    Returns:
        Mapping from entity id to vocabulary index
    """
    vocab: dict[str, int] = {}
    # newline="" leaves newline handling to the csv module, as its documentation requires.
    with open(file_path, "r", encoding="utf-8", newline="") as file:
        for row in csv.reader(file, delimiter="\t"):
            if not row:
                # A blank line would otherwise fail with IndexError on row[0].
                continue
            vocab[row[0]] = int(row[1])
    return vocab
def save_tsv_entity_vocab(file_path: str, entity_vocab: dict[str, int]) -> None:
    """
    Save entity vocabulary to a TSV file.

    One ``entity_id<TAB>index`` row is written per entry, in the mapping's
    iteration order. Missing parent directories are created.

    Args:
        file_path: Path to save the entity vocabulary
        entity_vocab: Entity vocabulary to save
    """
    parent_dir = os.path.dirname(file_path)
    # A bare file name has an empty dirname, and os.makedirs("") raises FileNotFoundError.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # newline="" keeps the csv writer's own line terminator from being translated again.
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        for entity_id, idx in entity_vocab.items():
            writer.writerow([entity_id, idx])
class _Entity(NamedTuple):
    """Entity vocabulary index together with the half-open token range it covers."""

    entity_id: int
    start: int
    end: int

    @property
    def length(self) -> int:
        """Number of positions in ``[start, end)``."""
        first, last = self.start, self.end
        return last - first
def preprocess_text(
    text: str,
    mentions: list[Mention] | None,
    title: str | None,
    title_mentions: list[Mention] | None,
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> dict[str, list[int]]:
    """
    Convert a text (optionally preceded by a title) into token ids and entity inputs.

    When ``title`` is given, its tokens come first, followed by the tokenizer's
    separator token and then the tokens of ``text``. Entity positions are token
    indices into that combined sequence. No other special tokens are added here.

    Args:
        text: Main text
        mentions: Mentions detected in ``text``; ``None`` means no mentions
        title: Optional title placed before the text
        title_mentions: Mentions detected in ``title``; ``None`` means no mentions
        tokenizer: Tokenizer used to split the text and map tokens to ids
        entity_vocab: Mapping from entity KB IDs to entity vocabulary indices

    Returns:
        Dictionary with ``input_ids``, ``entity_ids`` and ``entity_position_ids``
        (one list of token positions per entity)
    """
    # (segment text, its mentions, whether a separator token follows it)
    segments: list[tuple[str, list[Mention], bool]] = []
    if title is not None:
        segments.append((title, [] if title_mentions is None else title_mentions, True))
    segments.append((text, [] if mentions is None else mentions, False))

    all_tokens: list[str] = []
    entity_ids: list[int] = []
    entity_position_ids: list[list[int]] = []
    for segment_text, segment_mentions, followed_by_sep in segments:
        # Entity positions inside a segment are relative to the segment start.
        offset = len(all_tokens)
        segment_tokens, segment_entities = _tokenize_text_with_mentions(
            segment_text, segment_mentions, tokenizer, entity_vocab
        )
        all_tokens.extend(segment_tokens)
        if followed_by_sep:
            all_tokens.append(tokenizer.sep_token)
        for entity in segment_entities:
            entity_ids.append(entity.entity_id)
            entity_position_ids.append(list(range(offset + entity.start, offset + entity.end)))

    return {
        "input_ids": tokenizer.convert_tokens_to_ids(all_tokens),
        "entity_ids": entity_ids,
        "entity_position_ids": entity_position_ids,
    }
def _tokenize_text_with_mentions(
    text: str,
    mentions: list[Mention],
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> tuple[list[str], list[_Entity]]:
    """
    Tokenize ``text`` piece by piece so that no token crosses a mention boundary.

    Only mentions whose ``kb_id`` is set and present in ``entity_vocab`` are used.
    The text is cut at every start and end offset of those mentions, each piece
    is tokenized on its own, and the running token count at each cut gives the
    token position of that character offset.

    Args:
        text: Input text to tokenize
        mentions: List of detected mentions in the text
        tokenizer: Pre-trained tokenizer to use for tokenization
        entity_vocab: Mapping from entity KB IDs to entity vocabulary indices

    Returns:
        Tuple containing:
            - List of tokens from the tokenized text
            - List of _Entity objects with entity IDs and token positions
    """
    linkable = [m for m in mentions if m.kb_id is not None and m.kb_id in entity_vocab]

    boundaries: set[int] = set()
    for m in linkable:
        boundaries.add(m.start)
        boundaries.add(m.end)

    pieces: list[str] = []
    token_index_at: dict[int, int] = {}
    previous = 0
    for boundary in sorted(boundaries):
        pieces.extend(tokenizer.tokenize(text[previous:boundary]))
        # Number of tokens produced so far == token index of this character offset.
        token_index_at[boundary] = len(pieces)
        previous = boundary
    # Remainder after the last boundary (the whole text when there are no mentions).
    pieces.extend(tokenizer.tokenize(text[previous:]))

    entities: list[_Entity] = []
    for m in linkable:
        entities.append(_Entity(entity_vocab[m.kb_id], token_index_at[m.start], token_index_at[m.end]))
    return pieces, entities
class KPRBertTokenizer(BertTokenizer):
    """BERT tokenizer that additionally emits entity inputs for each text.

    On top of the regular ``BertTokenizer`` outputs, every encoded text carries
    ``entity_ids`` and ``entity_position_ids`` for the mentions found by a
    ``DictionaryEntityLinker``, plus ``entity_embeds`` when an entity embedding
    matrix was loaded. Only single-text inputs are supported.
    """

    vocab_files_names = {
        **BertTokenizer.vocab_files_names,  # Include the parent class files (vocab.txt)
        "entity_linker_data_file": "entity_linker/data.npy",
        "entity_linker_offsets_file": "entity_linker/offsets.npy",
        "entity_linker_name_trie_file": "entity_linker/name.trie",
        "entity_linker_kb_id_trie_file": "entity_linker/kb_id.trie",
        "entity_linker_config_file": "entity_linker/config.json",
        "entity_vocab_file": "entity_vocab.tsv",
        "entity_embeddings_file": "entity_embeddings.npy",
    }
    model_input_names = [
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "entity_ids",
        "entity_position_ids",
    ]

    def __init__(
        self,
        vocab_file,
        entity_linker_data_file: str,
        entity_vocab_file: str,
        entity_embeddings_file: str | None = None,
        *args,
        **kwargs,
    ):
        """Load the word vocabulary, the entity linker, and the entity vocabulary.

        Args:
            vocab_file: WordPiece vocabulary file passed to ``BertTokenizer``.
            entity_linker_data_file: Path of the linker's ``data.npy``; the other
                linker files are read from the same directory.
            entity_vocab_file: TSV file mapping entity ids to indices.
            entity_embeddings_file: Optional ``.npy`` matrix with one row per
                entity; opened memory-mapped and read-only.
            *args: Forwarded to ``BertTokenizer``.
            **kwargs: Forwarded to ``BertTokenizer``.

        Raises:
            ValueError: If the number of embedding rows differs from the number
                of entities in the vocabulary.
        """
        super().__init__(vocab_file=vocab_file, *args, **kwargs)
        entity_linker_dir = str(Path(entity_linker_data_file).parent)
        self.entity_linker = DictionaryEntityLinker.load(entity_linker_dir)
        self.entity_to_id = load_tsv_entity_vocab(entity_vocab_file)
        self.id_to_entity = {v: k for k, v in self.entity_to_id.items()}

        self.entity_embeddings = None
        if entity_embeddings_file:
            # Use memory-mapped loading for large embeddings
            self.entity_embeddings = np.load(entity_embeddings_file, mmap_mode="r")
            if self.entity_embeddings.shape[0] != len(self.entity_to_id):
                raise ValueError(
                    f"Entity embeddings shape {self.entity_embeddings.shape[0]} does not match "
                    f"the number of entities {len(self.entity_to_id)}. "
                    "Make sure `entity_embeddings.npy` and `entity_vocab.tsv` are consistent."
                )

    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]:
        """Encode one text into token and entity inputs (plain Python lists).

        Args:
            text: Text to encode.
            **kwargs: Options forwarded to ``prepare_for_model`` (everything
                except ``return_tensors``), e.g. truncation or padding settings.

        Returns:
            The ``prepare_for_model`` outputs plus ``entity_ids``,
            ``entity_position_ids`` and, when embeddings are loaded,
            ``entity_embeds``. There is always at least one entity: if no entity
            is found or none survives truncation, a single padding entity
            (id 0 at position 0) is returned.

        Raises:
            ValueError: If special tokens are requested but the encoded input
                does not start with the [CLS] token.
        """
        mentions = self.entity_linker.detect_mentions(text)
        model_inputs = preprocess_text(
            text=text,
            mentions=mentions,
            title=None,
            title_mentions=None,
            tokenizer=self,
            entity_vocab=self.entity_to_id,
        )

        # Prepare the inputs for the model
        # This will add special tokens or truncate the input when specified in kwargs
        # We exclude "return_tensors" from kwargs
        # to avoid issues in passing the data to BatchEncoding outside this method
        prepared_inputs = self.prepare_for_model(
            model_inputs["input_ids"],
            **{k: v for k, v in kwargs.items() if k != "return_tensors"},
        )
        model_inputs.update(prepared_inputs)

        # Account for special tokens
        if kwargs.get("add_special_tokens", True):
            if prepared_inputs["input_ids"][0] != self.cls_token_id:
                raise ValueError(
                    "We assume that the input IDs start with the [CLS] token with add_special_tokens = True."
                )
            # Shift the entity position IDs by 1 to account for the [CLS] token
            model_inputs["entity_position_ids"] = [
                [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"]
            ]

        # Count the number of special tokens at the end of the input
        num_special_tokens_at_end = 0
        input_ids = prepared_inputs["input_ids"]
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        special_ids = {self.sep_token_id, self.pad_token_id, self.cls_token_id}
        for input_id in reversed(input_ids):
            if int(input_id) not in special_ids:
                break
            num_special_tokens_at_end += 1

        # Remove entities that are not in truncated input
        # NOTE(review): this assumes tokens are dropped from the right end when
        # truncating - confirm for tokenizers configured with left truncation.
        max_effective_pos = len(model_inputs["input_ids"]) - num_special_tokens_at_end
        entity_indices_to_keep = [
            i
            for i, position_ids in enumerate(model_inputs["entity_position_ids"])
            if len(position_ids) > 0 and max(position_ids) < max_effective_pos
        ]
        model_inputs["entity_ids"] = [model_inputs["entity_ids"][i] for i in entity_indices_to_keep]
        model_inputs["entity_position_ids"] = [model_inputs["entity_position_ids"][i] for i in entity_indices_to_keep]

        # If there is no entity left, we output a padding entity for the model.
        # This is done after the truncation filter so that texts whose entities
        # were all truncated away (or empty texts) also get the padding entity.
        if not model_inputs["entity_ids"]:
            model_inputs["entity_ids"] = [0]  # The padding entity id is 0
            model_inputs["entity_position_ids"] = [[0]]

        if self.entity_embeddings is not None:
            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32)
        return model_inputs

    def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding:
        """Encode a single text or a list of texts.

        Args:
            text: One text, or a list of texts encoded as a batch.
            **kwargs: Tokenization options. ``return_tensors`` selects the output
                tensor type; ``padding``, ``max_length``, ``pad_to_multiple_of``,
                ``return_attention_mask`` and ``verbose`` are also used for batch
                padding. The remaining options reach ``prepare_for_model``.

        Returns:
            A ``BatchEncoding``. For a list input, entity ids, entity positions
            and entity embeddings are always padded with zeros to a common shape;
            token-level fields are padded only when ``padding`` is set.

        Raises:
            ValueError: If ``text_pair``, ``text_target`` or ``text_pair_target``
                is passed.
        """
        for unsupported_arg in ["text_pair", "text_target", "text_pair_target"]:
            if unsupported_arg in kwargs:
                raise ValueError(
                    f"Argument '{unsupported_arg}' is not supported by {self.__class__.__name__}. "
                    "This tokenizer only supports single text inputs. "
                )

        if isinstance(text, str):
            processed_inputs = self._preprocess_text(text, **kwargs)
            return BatchEncoding(
                processed_inputs,
                tensor_type=kwargs.get("return_tensors", None),
                prepend_batch_axis=True,
            )

        processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text]
        collated_inputs = {
            key: [item[key] for item in processed_inputs_list] for key in processed_inputs_list[0].keys()
        }
        if kwargs.get("padding"):
            collated_inputs = self.pad(
                collated_inputs,
                padding=kwargs["padding"],
                max_length=kwargs.get("max_length"),
                pad_to_multiple_of=kwargs.get("pad_to_multiple_of"),
                return_attention_mask=kwargs.get("return_attention_mask"),
                verbose=kwargs.get("verbose", True),
            )

        # Pad entity ids
        max_num_entities = max(len(ids) for ids in collated_inputs["entity_ids"])
        for entity_ids in collated_inputs["entity_ids"]:
            entity_ids += [0] * (max_num_entities - len(entity_ids))

        # Pad entity position ids
        flattened_entity_length = [
            len(ids) for ids_list in collated_inputs["entity_position_ids"] for ids in ids_list
        ]
        max_entity_token_length = max(flattened_entity_length) if flattened_entity_length else 0
        for entity_position_ids_list in collated_inputs["entity_position_ids"]:
            # pad entity_position_ids to max_entity_token_length
            for entity_position_ids in entity_position_ids_list:
                entity_position_ids += [0] * (max_entity_token_length - len(entity_position_ids))
            # pad to max_num_entities; every padding row is a distinct list so
            # that the rows never alias each other
            num_missing = max_num_entities - len(entity_position_ids_list)
            entity_position_ids_list.extend([0] * max_entity_token_length for _ in range(num_missing))

        # Pad entity embeddings
        if "entity_embeds" in collated_inputs:
            entity_embeds = collated_inputs["entity_embeds"]
            for i, embeds in enumerate(entity_embeds):
                entity_embeds[i] = np.pad(
                    embeds,
                    pad_width=((0, max_num_entities - len(embeds)), (0, 0)),
                    mode="constant",
                    constant_values=0,
                )
        return BatchEncoding(collated_inputs, tensor_type=kwargs.get("return_tensors", None))

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str, ...]:
        """Save the word vocabulary, entity linker data, entity vocabulary and embeddings.

        Args:
            save_directory: Directory to write to; created when missing.
            filename_prefix: Forwarded to ``BertTokenizer.save_vocabulary``. The
                entity files always use the fixed names from ``vocab_files_names``.

        Returns:
            Paths of all saved files. Entity files are reported as paths inside
            ``save_directory``, like the files saved by the parent class.
        """
        os.makedirs(save_directory, exist_ok=True)
        saved_files = list(super().save_vocabulary(save_directory, filename_prefix))
        save_root = Path(save_directory)

        # Save entity linker data
        entity_linker_save_dir = str(save_root / Path(self.vocab_files_names["entity_linker_data_file"]).parent)
        self.entity_linker.save(entity_linker_save_dir)
        for file_name in self.vocab_files_names.values():
            if file_name.startswith("entity_linker/"):
                saved_files.append(str(save_root / file_name))

        # Save entity vocabulary
        entity_vocab_path = str(save_root / self.vocab_files_names["entity_vocab_file"])
        save_tsv_entity_vocab(entity_vocab_path, self.entity_to_id)
        saved_files.append(entity_vocab_path)

        if self.entity_embeddings is not None:
            entity_embeddings_path = str(save_root / self.vocab_files_names["entity_embeddings_file"])
            # NOTE(review): the embeddings are memory-mapped from the file they were
            # loaded from; saving into that same directory overwrites the mapped
            # file - confirm this is safe for the intended workflows.
            np.save(entity_embeddings_path, self.entity_embeddings)
            saved_files.append(entity_embeddings_path)
        return tuple(saved_files)