from __future__ import annotations import csv import json import os from dataclasses import dataclass from pathlib import Path from typing import NamedTuple import numpy as np import torch import spacy from marisa_trie import Trie from transformers import BatchEncoding, BertTokenizer, PreTrainedTokenizerBase NONE_ID = "" @dataclass class Mention: kb_id: str | None text: str start: int end: int link_count: int | None total_link_count: int | None doc_count: int | None @property def span(self) -> tuple[int, int]: return self.start, self.end @property def link_prob(self) -> float | None: if self.doc_count is None or self.total_link_count is None: return None elif self.doc_count > 0: return min(1.0, self.total_link_count / self.doc_count) else: return 0.0 @property def prior_prob(self) -> float | None: if self.link_count is None or self.total_link_count is None: return None elif self.total_link_count > 0: return min(1.0, self.link_count / self.total_link_count) else: return 0.0 def __repr__(self): return f" {self.kb_id}>" def get_tokenizer(language: str) -> spacy.tokenizer.Tokenizer: language_obj = spacy.blank(language) return language_obj.tokenizer class DictionaryEntityLinker: def __init__( self, name_trie: Trie, kb_id_trie: Trie, data: np.ndarray, offsets: np.ndarray, max_mention_length: int, case_sensitive: bool, min_link_prob: float | None, min_prior_prob: float | None, min_link_count: int | None, ): self._name_trie = name_trie self._kb_id_trie = kb_id_trie self._data = data self._offsets = offsets self._max_mention_length = max_mention_length self._case_sensitive = case_sensitive self._min_link_prob = min_link_prob self._min_prior_prob = min_prior_prob self._min_link_count = min_link_count self._tokenizer = get_tokenizer("en") @staticmethod def load( data_dir: str, min_link_prob: float | None = None, min_prior_prob: float | None = None, min_link_count: int | None = None, ) -> "DictionaryEntityLinker": data = np.load(os.path.join(data_dir, "data.npy")) offsets = np.load(os.path.join(data_dir, "offsets.npy")) name_trie = Trie() name_trie.load(os.path.join(data_dir, "name.trie")) kb_id_trie = Trie() kb_id_trie.load(os.path.join(data_dir, "kb_id.trie")) with open(os.path.join(data_dir, "config.json")) as config_file: config = json.load(config_file) if min_link_prob is None: min_link_prob = config.get("min_link_prob", None) if min_prior_prob is None: min_prior_prob = config.get("min_prior_prob", None) if min_link_count is None: min_link_count = config.get("min_link_count", None) return DictionaryEntityLinker( name_trie=name_trie, kb_id_trie=kb_id_trie, data=data, offsets=offsets, max_mention_length=config["max_mention_length"], case_sensitive=config["case_sensitive"], min_link_prob=min_link_prob, min_prior_prob=min_prior_prob, min_link_count=min_link_count, ) def detect_mentions(self, text: str) -> list[Mention]: tokens = self._tokenizer(text) end_offsets = frozenset(token.idx + len(token) for token in tokens) if not self._case_sensitive: text = text.lower() ret = [] cur = 0 for token in tokens: start = token.idx if cur > start: continue for prefix in sorted( self._name_trie.prefixes(text[start : start + self._max_mention_length]), key=len, reverse=True, ): end = start + len(prefix) if end in end_offsets: matched = False mention_idx = self._name_trie[prefix] data_start, data_end = self._offsets[mention_idx : mention_idx + 2] for item in self._data[data_start:data_end]: if item.size == 4: kb_idx, link_count, total_link_count, doc_count = item elif item.size == 1: (kb_idx,) = item link_count, total_link_count, doc_count = None, None, None else: raise ValueError("Unexpected data array format") mention = Mention( kb_id=self._kb_id_trie.restore_key(kb_idx), text=prefix, start=start, end=end, link_count=link_count, total_link_count=total_link_count, doc_count=doc_count, ) if item.size == 1 or ( mention.link_prob >= self._min_link_prob and mention.prior_prob >= self._min_prior_prob and mention.link_count >= self._min_link_count ): ret.append(mention) matched = True if matched: cur = end break return ret def detect_mentions_batch(self, texts: list[str]) -> list[list[Mention]]: return [self.detect_mentions(text) for text in texts] def save(self, data_dir: str) -> None: """ Save the entity linker data to the specified directory. Args: data_dir: Directory to save the entity linker data """ os.makedirs(data_dir, exist_ok=True) # Save numpy arrays np.save(os.path.join(data_dir, "data.npy"), self._data) np.save(os.path.join(data_dir, "offsets.npy"), self._offsets) # Save tries self._name_trie.save(os.path.join(data_dir, "name.trie")) self._kb_id_trie.save(os.path.join(data_dir, "kb_id.trie")) # Save configuration with open(os.path.join(data_dir, "config.json"), "w") as config_file: json.dump( { "max_mention_length": self._max_mention_length, "case_sensitive": self._case_sensitive, "min_link_prob": self._min_link_prob, "min_prior_prob": self._min_prior_prob, "min_link_count": self._min_link_count, }, config_file, ) def load_tsv_entity_vocab(file_path: str) -> dict[str, int]: vocab = {} with open(file_path, "r", encoding="utf-8") as file: reader = csv.reader(file, delimiter="\t") for row in reader: vocab[row[0]] = int(row[1]) return vocab def save_tsv_entity_vocab(file_path: str, entity_vocab: dict[str, int]) -> None: """ Save entity vocabulary to a TSV file. Args: file_path: Path to save the entity vocabulary entity_vocab: Entity vocabulary to save """ os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, "w", encoding="utf-8") as f: writer = csv.writer(f, delimiter="\t") for entity_id, idx in entity_vocab.items(): writer.writerow([entity_id, idx]) class _Entity(NamedTuple): entity_id: int start: int end: int @property def length(self) -> int: return self.end - self.start def preprocess_text( text: str, mentions: list[Mention] | None, title: str | None, title_mentions: list[Mention] | None, tokenizer: PreTrainedTokenizerBase, entity_vocab: dict[str, int], ) -> dict[str, list[int]]: tokens = [] entity_ids = [] entity_position_ids = [] if title is not None: if title_mentions is None: title_mentions = [] title_tokens, title_entities = _tokenize_text_with_mentions(title, title_mentions, tokenizer, entity_vocab) tokens += title_tokens + [tokenizer.sep_token] for entity in title_entities: entity_ids.append(entity.entity_id) entity_position_ids.append(list(range(entity.start, entity.end))) if mentions is None: mentions = [] entity_offset = len(tokens) text_tokens, text_entities = _tokenize_text_with_mentions(text, mentions, tokenizer, entity_vocab) tokens += text_tokens for entity in text_entities: entity_ids.append(entity.entity_id) entity_position_ids.append(list(range(entity.start + entity_offset, entity.end + entity_offset))) input_ids = tokenizer.convert_tokens_to_ids(tokens) return { "input_ids": input_ids, "entity_ids": entity_ids, "entity_position_ids": entity_position_ids, } def _tokenize_text_with_mentions( text: str, mentions: list[Mention], tokenizer: PreTrainedTokenizerBase, entity_vocab: dict[str, int], ) -> tuple[list[str], list[_Entity]]: """ Tokenize text while preserving mention boundaries and mapping entities. Args: text: Input text to tokenize mentions: List of detected mentions in the text tokenizer: Pre-trained tokenizer to use for tokenization entity_vocab: Mapping from entity KB IDs to entity vocabulary indices Returns: Tuple containing: - List of tokens from the tokenized text - List of _Entity objects with entity IDs and token positions """ target_mentions = [mention for mention in mentions if mention.kb_id is not None and mention.kb_id in entity_vocab] split_char_positions = {mention.start for mention in target_mentions} | {mention.end for mention in target_mentions} tokens: list[str] = [] cur = 0 char_to_token_mapping = {} for char_position in sorted(split_char_positions): target_text = text[cur:char_position] tokens += tokenizer.tokenize(target_text) char_to_token_mapping[char_position] = len(tokens) cur = char_position tokens += tokenizer.tokenize(text[cur:]) entities = [ _Entity( entity_vocab[mention.kb_id], char_to_token_mapping[mention.start], char_to_token_mapping[mention.end], ) for mention in target_mentions ] return tokens, entities class KPRBertTokenizer(BertTokenizer): vocab_files_names = { **BertTokenizer.vocab_files_names, # Include the parent class files (vocab.txt) "entity_linker_data_file": "entity_linker/data.npy", "entity_linker_offsets_file": "entity_linker/offsets.npy", "entity_linker_name_trie_file": "entity_linker/name.trie", "entity_linker_kb_id_trie_file": "entity_linker/kb_id.trie", "entity_linker_config_file": "entity_linker/config.json", "entity_vocab_file": "entity_vocab.tsv", "entity_embeddings_file": "entity_embeddings.npy", } model_input_names = [ "input_ids", "token_type_ids", "attention_mask", "entity_ids", "entity_position_ids", ] def __init__( self, vocab_file, entity_linker_data_file: str, entity_vocab_file: str, entity_embeddings_file: str | None = None, *args, **kwargs, ): super().__init__(vocab_file=vocab_file, *args, **kwargs) entity_linker_dir = str(Path(entity_linker_data_file).parent) self.entity_linker = DictionaryEntityLinker.load(entity_linker_dir) self.entity_to_id = load_tsv_entity_vocab(entity_vocab_file) self.id_to_entity = {v: k for k, v in self.entity_to_id.items()} self.entity_embeddings = None if entity_embeddings_file: # Use memory-mapped loading for large embeddings self.entity_embeddings = np.load(entity_embeddings_file, mmap_mode="r") if self.entity_embeddings.shape[0] != len(self.entity_to_id): raise ValueError( f"Entity embeddings shape {self.entity_embeddings.shape[0]} does not match " f"the number of entities {len(self.entity_to_id)}. " "Make sure `embeddings.py` and `entity_vocab.tsv` are consistent." ) def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]: mentions = self.entity_linker.detect_mentions(text) model_inputs = preprocess_text( text=text, mentions=mentions, title=None, title_mentions=None, tokenizer=self, entity_vocab=self.entity_to_id, ) # Prepare the inputs for the model # This will add special tokens or truncate the input when specified in kwargs # We exclude "return_tensors" from kwargs # to avoid issues in passing the data to BatchEncoding outside this method prepared_inputs = self.prepare_for_model( model_inputs["input_ids"], **{k: v for k, v in kwargs.items() if k != "return_tensors"}, ) model_inputs.update(prepared_inputs) # Account for special tokens if kwargs.get("add_special_tokens", True): if prepared_inputs["input_ids"][0] != self.cls_token_id: raise ValueError( "We assume that the input IDs start with the [CLS] token with add_special_tokens = True." ) # Shift the entity position IDs by 1 to account for the [CLS] token model_inputs["entity_position_ids"] = [ [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"] ] # If there is no entities in the text, we output padding entity for the model if not model_inputs["entity_ids"]: model_inputs["entity_ids"] = [0] # The padding entity id is 0 model_inputs["entity_position_ids"] = [[0]] # Count the number of special tokens at the end of the input num_special_tokens_at_end = 0 input_ids = prepared_inputs["input_ids"] if isinstance(input_ids, torch.Tensor): input_ids = input_ids.tolist() for input_id in input_ids[::-1]: if int(input_id) not in { self.sep_token_id, self.pad_token_id, self.cls_token_id, }: break num_special_tokens_at_end += 1 # Remove entities that are not in truncated input max_effective_pos = len(model_inputs["input_ids"]) - num_special_tokens_at_end entity_indices_to_keep = list() for i, position_ids in enumerate(model_inputs["entity_position_ids"]): if len(position_ids) > 0 and max(position_ids) < max_effective_pos: entity_indices_to_keep.append(i) model_inputs["entity_ids"] = [model_inputs["entity_ids"][i] for i in entity_indices_to_keep] model_inputs["entity_position_ids"] = [model_inputs["entity_position_ids"][i] for i in entity_indices_to_keep] if self.entity_embeddings is not None: model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32) return model_inputs def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding: for unsupported_arg in ["text_pair", "text_target", "text_pair_target"]: if unsupported_arg in kwargs: raise ValueError( f"Argument '{unsupported_arg}' is not supported by {self.__class__.__name__}. " "This tokenizer only supports single text inputs. " ) if isinstance(text, str): processed_inputs = self._preprocess_text(text, **kwargs) return BatchEncoding( processed_inputs, tensor_type=kwargs.get("return_tensors", None), prepend_batch_axis=True, ) processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text] collated_inputs = { key: [item[key] for item in processed_inputs_list] for key in processed_inputs_list[0].keys() } if kwargs.get("padding"): collated_inputs = self.pad( collated_inputs, padding=kwargs["padding"], max_length=kwargs.get("max_length"), pad_to_multiple_of=kwargs.get("pad_to_multiple_of"), return_attention_mask=kwargs.get("return_attention_mask"), verbose=kwargs.get("verbose", True), ) # Pad entity ids max_num_entities = max(len(ids) for ids in collated_inputs["entity_ids"]) for entity_ids in collated_inputs["entity_ids"]: entity_ids += [0] * (max_num_entities - len(entity_ids)) # Pad entity position ids flattened_entity_length = [ len(ids) for ids_list in collated_inputs["entity_position_ids"] for ids in ids_list ] max_entity_token_length = max(flattened_entity_length) if flattened_entity_length else 0 for entity_position_ids_list in collated_inputs["entity_position_ids"]: # pad entity_position_ids to max_entity_token_length for entity_position_ids in entity_position_ids_list: entity_position_ids += [0] * (max_entity_token_length - len(entity_position_ids)) # pad to max_num_entities entity_position_ids_list += [[0 for _ in range(max_entity_token_length)]] * ( max_num_entities - len(entity_position_ids_list) ) # Pad entity embeddings if "entity_embeds" in collated_inputs: for i in range(len(collated_inputs["entity_embeds"])): collated_inputs["entity_embeds"][i] = np.pad( collated_inputs["entity_embeds"][i], pad_width=( ( 0, max_num_entities - len(collated_inputs["entity_embeds"][i]), ), (0, 0), ), mode="constant", constant_values=0, ) return BatchEncoding(collated_inputs, tensor_type=kwargs.get("return_tensors", None)) def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]: os.makedirs(save_directory, exist_ok=True) saved_files = list(super().save_vocabulary(save_directory, filename_prefix)) # Save entity linker data entity_linker_save_dir = str( Path(save_directory) / Path(self.vocab_files_names["entity_linker_data_file"]).parent ) self.entity_linker.save(entity_linker_save_dir) for file_name in self.vocab_files_names.values(): if file_name.startswith("entity_linker/"): saved_files.append(file_name) # Save entity vocabulary entity_vocab_path = str(Path(save_directory) / self.vocab_files_names["entity_vocab_file"]) save_tsv_entity_vocab(entity_vocab_path, self.entity_to_id) saved_files.append(self.vocab_files_names["entity_vocab_file"]) if self.entity_embeddings is not None: entity_embeddings_path = str(Path(save_directory) / self.vocab_files_names["entity_embeddings_file"]) np.save(entity_embeddings_path, self.entity_embeddings) saved_files.append(self.vocab_files_names["entity_embeddings_file"]) return tuple(saved_files)