from __future__ import annotations

import csv
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple

import numpy as np
import spacy
import torch
from marisa_trie import Trie
from transformers import BatchEncoding, BertTokenizer, PreTrainedTokenizerBase

NONE_ID = "<None>"


@dataclass
class Mention:
    kb_id: str | None
    text: str
    start: int
    end: int
    link_count: int | None
    total_link_count: int | None
    doc_count: int | None

    @property
    def span(self) -> tuple[int, int]:
        return self.start, self.end

    @property
    def link_prob(self) -> float | None:
        if self.doc_count is None or self.total_link_count is None:
            return None
        elif self.doc_count > 0:
            return min(1.0, self.total_link_count / self.doc_count)
        else:
            return 0.0

    @property
    def prior_prob(self) -> float | None:
        if self.link_count is None or self.total_link_count is None:
            return None
        elif self.total_link_count > 0:
            return min(1.0, self.link_count / self.total_link_count)
        else:
            return 0.0

    def __repr__(self):
        return f"<Mention {self.text} -> {self.kb_id}>"
|
|
|
|
|
def get_tokenizer(language: str) -> spacy.tokenizer.Tokenizer:
    language_obj = spacy.blank(language)
    return language_obj.tokenizer


class DictionaryEntityLinker:
    def __init__(
        self,
        name_trie: Trie,
        kb_id_trie: Trie,
        data: np.ndarray,
        offsets: np.ndarray,
        max_mention_length: int,
        case_sensitive: bool,
        min_link_prob: float | None,
        min_prior_prob: float | None,
        min_link_count: int | None,
    ):
        self._name_trie = name_trie
        self._kb_id_trie = kb_id_trie
        self._data = data
        self._offsets = offsets
        self._max_mention_length = max_mention_length
        self._case_sensitive = case_sensitive

        self._min_link_prob = min_link_prob
        self._min_prior_prob = min_prior_prob
        self._min_link_count = min_link_count

        self._tokenizer = get_tokenizer("en")

    @staticmethod
    def load(
        data_dir: str,
        min_link_prob: float | None = None,
        min_prior_prob: float | None = None,
        min_link_count: int | None = None,
    ) -> "DictionaryEntityLinker":
        data = np.load(os.path.join(data_dir, "data.npy"))
        offsets = np.load(os.path.join(data_dir, "offsets.npy"))
        name_trie = Trie()
        name_trie.load(os.path.join(data_dir, "name.trie"))
        kb_id_trie = Trie()
        kb_id_trie.load(os.path.join(data_dir, "kb_id.trie"))

        with open(os.path.join(data_dir, "config.json")) as config_file:
            config = json.load(config_file)

        if min_link_prob is None:
            min_link_prob = config.get("min_link_prob", None)

        if min_prior_prob is None:
            min_prior_prob = config.get("min_prior_prob", None)

        if min_link_count is None:
            min_link_count = config.get("min_link_count", None)

        return DictionaryEntityLinker(
            name_trie=name_trie,
            kb_id_trie=kb_id_trie,
            data=data,
            offsets=offsets,
            max_mention_length=config["max_mention_length"],
            case_sensitive=config["case_sensitive"],
            min_link_prob=min_link_prob,
            min_prior_prob=min_prior_prob,
            min_link_count=min_link_count,
        )

    def detect_mentions(self, text: str) -> list[Mention]:
        tokens = self._tokenizer(text)
        # Mentions are only allowed to end at token boundaries.
        end_offsets = frozenset(token.idx + len(token) for token in tokens)
        if not self._case_sensitive:
            text = text.lower()

        ret = []
        cur = 0
        for token in tokens:
            start = token.idx
            if cur > start:
                # This token is already covered by a previously detected mention.
                continue

            # Try dictionary names starting at this token, longest first (greedy matching).
            for prefix in sorted(
                self._name_trie.prefixes(text[start : start + self._max_mention_length]),
                key=len,
                reverse=True,
            ):
                end = start + len(prefix)
                if end in end_offsets:
                    matched = False
                    mention_idx = self._name_trie[prefix]
                    data_start, data_end = self._offsets[mention_idx : mention_idx + 2]
                    for item in self._data[data_start:data_end]:
                        if item.size == 4:
                            kb_idx, link_count, total_link_count, doc_count = item
                        elif item.size == 1:
                            (kb_idx,) = item
                            link_count, total_link_count, doc_count = None, None, None
                        else:
                            raise ValueError("Unexpected data array format")

                        mention = Mention(
                            kb_id=self._kb_id_trie.restore_key(kb_idx),
                            text=prefix,
                            start=start,
                            end=end,
                            link_count=link_count,
                            total_link_count=total_link_count,
                            doc_count=doc_count,
                        )
                        # Entries without statistics (item.size == 1) are always kept;
                        # otherwise apply the configured thresholds, treating a threshold
                        # of None as "no filtering".
                        if item.size == 1 or (
                            (self._min_link_prob is None or mention.link_prob >= self._min_link_prob)
                            and (self._min_prior_prob is None or mention.prior_prob >= self._min_prior_prob)
                            and (self._min_link_count is None or mention.link_count >= self._min_link_count)
                        ):
                            ret.append(mention)
                            matched = True

                    if matched:
                        cur = end
                        break

        return ret
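
    # Sketch of the matching behavior above (dictionary contents hypothetical): for the
    # text "New York City", if both "New York" and "New York City" are dictionary names,
    # the longer candidate is tried first; once a candidate is accepted, `cur` is advanced
    # to its end, so overlapping mentions starting inside it are skipped.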
|
|
|
    def detect_mentions_batch(self, texts: list[str]) -> list[list[Mention]]:
        return [self.detect_mentions(text) for text in texts]
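
    # Example usage (sketch; assumes a directory produced by the entity-linker build
    # step, containing data.npy, offsets.npy, name.trie, kb_id.trie and config.json):
    #
    #   linker = DictionaryEntityLinker.load("path/to/entity_linker")
    #   for mention in linker.detect_mentions("Paris is the capital of France."):
    #       print(mention.text, mention.kb_id, mention.prior_prob)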
|
|
|
    def save(self, data_dir: str) -> None:
        """
        Save the entity linker data to the specified directory.

        Args:
            data_dir: Directory to save the entity linker data
        """
        os.makedirs(data_dir, exist_ok=True)

        np.save(os.path.join(data_dir, "data.npy"), self._data)
        np.save(os.path.join(data_dir, "offsets.npy"), self._offsets)

        self._name_trie.save(os.path.join(data_dir, "name.trie"))
        self._kb_id_trie.save(os.path.join(data_dir, "kb_id.trie"))

        with open(os.path.join(data_dir, "config.json"), "w") as config_file:
            json.dump(
                {
                    "max_mention_length": self._max_mention_length,
                    "case_sensitive": self._case_sensitive,
                    "min_link_prob": self._min_link_prob,
                    "min_prior_prob": self._min_prior_prob,
                    "min_link_count": self._min_link_count,
                },
                config_file,
            )


def load_tsv_entity_vocab(file_path: str) -> dict[str, int]:
    vocab = {}
    with open(file_path, "r", encoding="utf-8") as file:
        reader = csv.reader(file, delimiter="\t")
        for row in reader:
            vocab[row[0]] = int(row[1])
    return vocab


def save_tsv_entity_vocab(file_path: str, entity_vocab: dict[str, int]) -> None:
    """
    Save entity vocabulary to a TSV file.

    Args:
        file_path: Path to save the entity vocabulary
        entity_vocab: Entity vocabulary to save
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t")
        for entity_id, idx in entity_vocab.items():
            writer.writerow([entity_id, idx])
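
# The TSV files read and written above contain one "<kb_id>\t<index>" pair per line,
# e.g. (hypothetical entries):
#
#   Q90     0
#   Q142    1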
|
|
|
|
|
class _Entity(NamedTuple):
    entity_id: int
    start: int
    end: int

    @property
    def length(self) -> int:
        return self.end - self.start


def preprocess_text(
    text: str,
    mentions: list[Mention] | None,
    title: str | None,
    title_mentions: list[Mention] | None,
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> dict[str, list[int]]:
    tokens = []
    entity_ids = []
    entity_position_ids = []
    if title is not None:
        if title_mentions is None:
            title_mentions = []

        title_tokens, title_entities = _tokenize_text_with_mentions(title, title_mentions, tokenizer, entity_vocab)
        tokens += title_tokens + [tokenizer.sep_token]
        for entity in title_entities:
            entity_ids.append(entity.entity_id)
            entity_position_ids.append(list(range(entity.start, entity.end)))

    if mentions is None:
        mentions = []

    entity_offset = len(tokens)
    text_tokens, text_entities = _tokenize_text_with_mentions(text, mentions, tokenizer, entity_vocab)
    tokens += text_tokens
    for entity in text_entities:
        entity_ids.append(entity.entity_id)
        entity_position_ids.append(list(range(entity.start + entity_offset, entity.end + entity_offset)))

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    return {
        "input_ids": input_ids,
        "entity_ids": entity_ids,
        "entity_position_ids": entity_position_ids,
    }
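
# Sketch of the output of preprocess_text (token IDs illustrative): for a text whose
# first token is covered by one mention that is present in entity_vocab, it returns
#
#   {
#       "input_ids": [...word-piece IDs; no special tokens are added here...],
#       "entity_ids": [entity_vocab[mention.kb_id]],
#       "entity_position_ids": [[0]],  # token positions covered by the mention
#   }
#
# When a title is given, its tokens (followed by [SEP]) are prepended and the positions
# of entities in the main text are offset accordingly.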
|
|
|
|
|
def _tokenize_text_with_mentions(
    text: str,
    mentions: list[Mention],
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> tuple[list[str], list[_Entity]]:
    """
    Tokenize text while preserving mention boundaries and mapping entities.

    Args:
        text: Input text to tokenize
        mentions: List of detected mentions in the text
        tokenizer: Pre-trained tokenizer to use for tokenization
        entity_vocab: Mapping from entity KB IDs to entity vocabulary indices

    Returns:
        Tuple containing:
            - List of tokens from the tokenized text
            - List of _Entity objects with entity IDs and token positions
    """
    target_mentions = [mention for mention in mentions if mention.kb_id is not None and mention.kb_id in entity_vocab]
    split_char_positions = {mention.start for mention in target_mentions} | {mention.end for mention in target_mentions}

    tokens: list[str] = []
    cur = 0
    char_to_token_mapping = {}
    for char_position in sorted(split_char_positions):
        target_text = text[cur:char_position]
        tokens += tokenizer.tokenize(target_text)
        char_to_token_mapping[char_position] = len(tokens)
        cur = char_position
    tokens += tokenizer.tokenize(text[cur:])

    entities = [
        _Entity(
            entity_vocab[mention.kb_id],
            char_to_token_mapping[mention.start],
            char_to_token_mapping[mention.end],
        )
        for mention in target_mentions
    ]
    return tokens, entities
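
# Minimal sketch of the boundary handling above (WordPiece output illustrative): for
# text = "Tokyo Tower" with a single mention over characters [0, 5) ("Tokyo"), the text
# is split at char positions {0, 5}, each segment is tokenized separately, and the
# mention is mapped to the token span [0, 1), i.e. _Entity(entity_id, start=0, end=1).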
|
|
|
|
|
class KPRBertTokenizer(BertTokenizer):
    vocab_files_names = {
        **BertTokenizer.vocab_files_names,
        "entity_linker_data_file": "entity_linker/data.npy",
        "entity_linker_offsets_file": "entity_linker/offsets.npy",
        "entity_linker_name_trie_file": "entity_linker/name.trie",
        "entity_linker_kb_id_trie_file": "entity_linker/kb_id.trie",
        "entity_linker_config_file": "entity_linker/config.json",
        "entity_vocab_file": "entity_vocab.tsv",
        "entity_embeddings_file": "entity_embeddings.npy",
    }
    model_input_names = [
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "entity_ids",
        "entity_position_ids",
    ]

    def __init__(
        self,
        vocab_file,
        entity_linker_data_file: str,
        entity_vocab_file: str,
        entity_embeddings_file: str | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(vocab_file=vocab_file, *args, **kwargs)
        entity_linker_dir = str(Path(entity_linker_data_file).parent)
        self.entity_linker = DictionaryEntityLinker.load(entity_linker_dir)
        self.entity_to_id = load_tsv_entity_vocab(entity_vocab_file)
        self.id_to_entity = {v: k for k, v in self.entity_to_id.items()}

        self.entity_embeddings = None
        if entity_embeddings_file:
            # Memory-map the embeddings so the full matrix is not loaded into RAM.
            self.entity_embeddings = np.load(entity_embeddings_file, mmap_mode="r")
            if self.entity_embeddings.shape[0] != len(self.entity_to_id):
                raise ValueError(
                    f"Entity embeddings shape {self.entity_embeddings.shape[0]} does not match "
                    f"the number of entities {len(self.entity_to_id)}. "
                    "Make sure `entity_embeddings.npy` and `entity_vocab.tsv` are consistent."
                )

    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]:
        mentions = self.entity_linker.detect_mentions(text)
        model_inputs = preprocess_text(
            text=text,
            mentions=mentions,
            title=None,
            title_mentions=None,
            tokenizer=self,
            entity_vocab=self.entity_to_id,
        )

        # Let the parent tokenizer add special tokens, truncate, pad, etc. to the
        # word-piece IDs. `return_tensors` is handled later in __call__.
        prepared_inputs = self.prepare_for_model(
            model_inputs["input_ids"],
            **{k: v for k, v in kwargs.items() if k != "return_tensors"},
        )
        model_inputs.update(prepared_inputs)

        if kwargs.get("add_special_tokens", True):
            if prepared_inputs["input_ids"][0] != self.cls_token_id:
                raise ValueError(
                    "We assume that the input IDs start with the [CLS] token with add_special_tokens = True."
                )

            # Shift entity positions by one to account for the prepended [CLS] token.
            model_inputs["entity_position_ids"] = [
                [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"]
            ]

        # Use a dummy entity so the entity tensors are never empty.
        if not model_inputs["entity_ids"]:
            model_inputs["entity_ids"] = [0]
            model_inputs["entity_position_ids"] = [[0]]

        # Count the trailing special tokens ([SEP]/[PAD]/[CLS]) so that entities whose
        # tokens were truncated away can be removed below.
        num_special_tokens_at_end = 0
        input_ids = prepared_inputs["input_ids"]
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        for input_id in input_ids[::-1]:
            if int(input_id) not in {
                self.sep_token_id,
                self.pad_token_id,
                self.cls_token_id,
            }:
                break
            num_special_tokens_at_end += 1

        # Keep only entities whose token positions fall inside the effective
        # (non-special) part of the sequence.
        max_effective_pos = len(model_inputs["input_ids"]) - num_special_tokens_at_end
        entity_indices_to_keep = []
        for i, position_ids in enumerate(model_inputs["entity_position_ids"]):
            if len(position_ids) > 0 and max(position_ids) < max_effective_pos:
                entity_indices_to_keep.append(i)
        model_inputs["entity_ids"] = [model_inputs["entity_ids"][i] for i in entity_indices_to_keep]
        model_inputs["entity_position_ids"] = [model_inputs["entity_position_ids"][i] for i in entity_indices_to_keep]

        if self.entity_embeddings is not None:
            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32)
        return model_inputs
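
    # Sketch (IDs illustrative): for "Paris is nice" with one detected mention over the
    # first word-piece, this returns roughly
    #   input_ids           = [<cls>, <paris>, <is>, <nice>, <sep>]
    #   entity_ids          = [<entity vocab index>]
    #   entity_position_ids = [[1]]   # shifted by +1 for the prepended [CLS]
    # plus entity_embeds when precomputed entity embeddings were loaded.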
|
|
|
    def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding:
        for unsupported_arg in ["text_pair", "text_target", "text_pair_target"]:
            if unsupported_arg in kwargs:
                raise ValueError(
                    f"Argument '{unsupported_arg}' is not supported by {self.__class__.__name__}. "
                    "This tokenizer only supports single text inputs."
                )

        if isinstance(text, str):
            processed_inputs = self._preprocess_text(text, **kwargs)
            return BatchEncoding(
                processed_inputs,
                tensor_type=kwargs.get("return_tensors", None),
                prepend_batch_axis=True,
            )

        # Batched input: preprocess each text and collate the results.
        processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text]
        collated_inputs = {
            key: [item[key] for item in processed_inputs_list] for key in processed_inputs_list[0].keys()
        }
        if kwargs.get("padding"):
            collated_inputs = self.pad(
                collated_inputs,
                padding=kwargs["padding"],
                max_length=kwargs.get("max_length"),
                pad_to_multiple_of=kwargs.get("pad_to_multiple_of"),
                return_attention_mask=kwargs.get("return_attention_mask"),
                verbose=kwargs.get("verbose", True),
            )

        # Pad entity IDs to the maximum number of entities in the batch.
        max_num_entities = max(len(ids) for ids in collated_inputs["entity_ids"])
        for entity_ids in collated_inputs["entity_ids"]:
            entity_ids += [0] * (max_num_entities - len(entity_ids))

        # Pad entity position IDs both to the longest entity (in tokens) and to the
        # maximum number of entities in the batch.
        flattened_entity_length = [
            len(ids) for ids_list in collated_inputs["entity_position_ids"] for ids in ids_list
        ]
        max_entity_token_length = max(flattened_entity_length) if flattened_entity_length else 0
        for entity_position_ids_list in collated_inputs["entity_position_ids"]:
            for entity_position_ids in entity_position_ids_list:
                entity_position_ids += [0] * (max_entity_token_length - len(entity_position_ids))

            entity_position_ids_list += [[0 for _ in range(max_entity_token_length)]] * (
                max_num_entities - len(entity_position_ids_list)
            )

        # Pad entity embeddings (if present) along the entity axis with zero vectors.
        if "entity_embeds" in collated_inputs:
            for i in range(len(collated_inputs["entity_embeds"])):
                collated_inputs["entity_embeds"][i] = np.pad(
                    collated_inputs["entity_embeds"][i],
                    pad_width=(
                        (
                            0,
                            max_num_entities - len(collated_inputs["entity_embeds"][i]),
                        ),
                        (0, 0),
                    ),
                    mode="constant",
                    constant_values=0,
                )
        return BatchEncoding(collated_inputs, tensor_type=kwargs.get("return_tensors", None))
|
|
|

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str, ...]:
        os.makedirs(save_directory, exist_ok=True)
        saved_files = list(super().save_vocabulary(save_directory, filename_prefix))

        # Save the entity linker files under the entity_linker/ subdirectory.
        entity_linker_save_dir = str(
            Path(save_directory) / Path(self.vocab_files_names["entity_linker_data_file"]).parent
        )
        self.entity_linker.save(entity_linker_save_dir)
        for file_name in self.vocab_files_names.values():
            if file_name.startswith("entity_linker/"):
                saved_files.append(file_name)

        # Save the entity vocabulary.
        entity_vocab_path = str(Path(save_directory) / self.vocab_files_names["entity_vocab_file"])
        save_tsv_entity_vocab(entity_vocab_path, self.entity_to_id)
        saved_files.append(self.vocab_files_names["entity_vocab_file"])

        # Save the entity embeddings, if they were loaded.
        if self.entity_embeddings is not None:
            entity_embeddings_path = str(Path(save_directory) / self.vocab_files_names["entity_embeddings_file"])
            np.save(entity_embeddings_path, self.entity_embeddings)
            saved_files.append(self.vocab_files_names["entity_embeddings_file"])
        return tuple(saved_files)
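

# Example usage (sketch; the checkpoint path below is a placeholder):
#
#   tokenizer = KPRBertTokenizer.from_pretrained("path/to/kpr-bert-checkpoint")
#   encoding = tokenizer(
#       ["Paris is the capital of France.", "Tokyo Tower is in Japan."],
#       padding=True,
#       return_tensors="pt",
#   )
#   # `encoding` holds the usual input_ids / token_type_ids / attention_mask plus
#   # entity_ids and entity_position_ids (and entity_embeds if embeddings were loaded).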
|
|