Upload tokenizer
tokenization_kpr.py  +49 -32
CHANGED
@@ -255,8 +255,7 @@ def preprocess_text(
 ) -> dict[str, list[int]]:
     tokens = []
     entity_ids = []
-    entity_start_positions = []
-    entity_lengths = []
+    entity_position_ids = []
     if title is not None:
         if title_mentions is None:
             title_mentions = []
@@ -265,8 +264,7 @@ def preprocess_text(
         tokens += title_tokens + [tokenizer.sep_token]
         for entity in title_entities:
             entity_ids.append(entity.entity_id)
-            entity_start_positions.append(entity.start)
-            entity_lengths.append(entity.end - entity.start)
+            entity_position_ids.append(list(range(entity.start, entity.end)))

     if mentions is None:
         mentions = []
@@ -276,16 +274,14 @@ def preprocess_text(
     tokens += text_tokens
     for entity in text_entities:
         entity_ids.append(entity.entity_id)
-        entity_start_positions.append(entity.start + entity_offset)
-        entity_lengths.append(entity.end - entity.start)
+        entity_position_ids.append(list(range(entity.start + entity_offset, entity.end + entity_offset)))

     input_ids = tokenizer.convert_tokens_to_ids(tokens)

     return {
         "input_ids": input_ids,
         "entity_ids": entity_ids,
-        "entity_start_positions": entity_start_positions,
-        "entity_lengths": entity_lengths,
+        "entity_position_ids": entity_position_ids,
     }


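A minimal sketch of what the reworked preprocess_text now emits per detected entity, using a made-up Entity dataclass and invented token offsets (the real mention objects come from the entity linker): each mention contributes one entry to entity_ids and a list of token positions to entity_position_ids, replacing the former start-position/length pair.

from dataclasses import dataclass

@dataclass
class Entity:  # hypothetical stand-in for the linker's mention objects
    entity_id: int
    start: int  # index of the mention's first token
    end: int    # one past the mention's last token

entities = [Entity(entity_id=42, start=3, end=5), Entity(entity_id=7, start=8, end=9)]

entity_ids = []
entity_position_ids = []
for entity in entities:
    entity_ids.append(entity.entity_id)
    # previously: entity_start_positions.append(entity.start); entity_lengths.append(entity.end - entity.start)
    entity_position_ids.append(list(range(entity.start, entity.end)))

print(entity_ids)           # [42, 7]
print(entity_position_ids)  # [[3, 4], [8]]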
@@ -349,8 +345,7 @@ class KPRBertTokenizer(BertTokenizer):
         "token_type_ids",
         "attention_mask",
         "entity_ids",
-        "entity_start_positions",
-        "entity_lengths",
+        "entity_position_ids",
     ]

     def __init__(
@@ -379,7 +374,7 @@ class KPRBertTokenizer(BertTokenizer):
                 "Make sure `embeddings.py` and `entity_vocab.tsv` are consistent."
             )

-    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int]]:
+    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]:
         mentions = self.entity_linker.detect_mentions(text)
         model_inputs = preprocess_text(
             text=text,
@@ -395,18 +390,26 @@ class KPRBertTokenizer(BertTokenizer):
         # We exclude "return_tensors" from kwargs
         # to avoid issues in passing the data to BatchEncoding outside this method
         prepared_inputs = self.prepare_for_model(
-            model_inputs["input_ids"], **{k: v for k, v in kwargs.items() if k != "return_tensors"}
+            model_inputs["input_ids"],
+            **{k: v for k, v in kwargs.items() if k != "return_tensors"},
         )
         model_inputs.update(prepared_inputs)

         # Account for special tokens
-        if kwargs.get("add_special_tokens"):
+        if kwargs.get("add_special_tokens", True):
             if prepared_inputs["input_ids"][0] != self.cls_token_id:
                 raise ValueError(
                     "We assume that the input IDs start with the [CLS] token with add_special_tokens = True."
                 )
-            # Shift the entity start positions by 1 to account for the [CLS] token
-            model_inputs["entity_start_positions"] = [pos + 1 for pos in model_inputs["entity_start_positions"]]
+            # Shift the entity position IDs by 1 to account for the [CLS] token
+            model_inputs["entity_position_ids"] = [
+                [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"]
+            ]
+
+        # If there is no entities in the text, we output padding entity for the model
+        if not model_inputs["entity_ids"]:
+            model_inputs["entity_ids"] = [0]  # The padding entity id is 0
+            model_inputs["entity_position_ids"] = [[0]]

         # Count the number of special tokens at the end of the input
         num_special_tokens_at_end = 0
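To make the special-token handling above concrete, a toy walk-through with invented values (not real tokenizer output): when special tokens are added, every entity token position is shifted by one for the prepended [CLS], and an input with no detected entities falls back to the padding entity with id 0.

entity_ids = [42, 7]
entity_position_ids = [[3, 4], [8]]

# Shift every token position by 1 to account for the [CLS] token at index 0
entity_position_ids = [[pos + 1 for pos in positions] for positions in entity_position_ids]
print(entity_position_ids)  # [[4, 5], [9]]

# If no entities were detected, emit the padding entity (id 0) at position 0
if not entity_ids:
    entity_ids = [0]
    entity_position_ids = [[0]]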
@@ -414,26 +417,25 @@ class KPRBertTokenizer(BertTokenizer):
         if isinstance(input_ids, torch.Tensor):
             input_ids = input_ids.tolist()
         for input_id in input_ids[::-1]:
-            if int(input_id) not in {self.sep_token_id, self.pad_token_id, self.cls_token_id}:
+            if int(input_id) not in {
+                self.sep_token_id,
+                self.pad_token_id,
+                self.cls_token_id,
+            }:
                 break
             num_special_tokens_at_end += 1

         # Remove entities that are not in truncated input
         max_effective_pos = len(model_inputs["input_ids"]) - num_special_tokens_at_end
         entity_indices_to_keep = list()
-        for i, (start_pos, length) in enumerate(
-            zip(model_inputs["entity_start_positions"], model_inputs["entity_lengths"])
-        ):
-            if (start_pos + length) <= max_effective_pos:
+        for i, position_ids in enumerate(model_inputs["entity_position_ids"]):
+            if len(position_ids) > 0 and max(position_ids) < max_effective_pos:
                 entity_indices_to_keep.append(i)
         model_inputs["entity_ids"] = [model_inputs["entity_ids"][i] for i in entity_indices_to_keep]
-        model_inputs["entity_start_positions"] = [
-            model_inputs["entity_start_positions"][i] for i in entity_indices_to_keep
-        ]
-        model_inputs["entity_lengths"] = [model_inputs["entity_lengths"][i] for i in entity_indices_to_keep]
+        model_inputs["entity_position_ids"] = [model_inputs["entity_position_ids"][i] for i in entity_indices_to_keep]

         if self.entity_embeddings is not None:
-            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]]
+            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32)
         return model_inputs

     def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding:
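A small illustration of the truncation filter above, again with invented values: an entity survives only if its largest token position falls before the special tokens at the end of the (possibly truncated) sequence.

input_ids = [101, 11, 12, 13, 14, 102]  # toy IDs; 101/102 stand in for [CLS]/[SEP]
entity_ids = [42, 7]
entity_position_ids = [[2, 3], [4, 5]]  # the second entity overlaps the trailing [SEP] position

num_special_tokens_at_end = 1           # the trailing [SEP]
max_effective_pos = len(input_ids) - num_special_tokens_at_end  # 5

keep = [
    i for i, positions in enumerate(entity_position_ids)
    if positions and max(positions) < max_effective_pos
]
entity_ids = [entity_ids[i] for i in keep]
entity_position_ids = [entity_position_ids[i] for i in keep]
print(entity_ids, entity_position_ids)  # [42] [[2, 3]]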
@@ -447,7 +449,9 @@ class KPRBertTokenizer(BertTokenizer):
         if isinstance(text, str):
             processed_inputs = self._preprocess_text(text, **kwargs)
             return BatchEncoding(
-                processed_inputs,
+                processed_inputs,
+                tensor_type=kwargs.get("return_tensors", None),
+                prepend_batch_axis=True,
             )

         processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text]
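For the single-text path, the BatchEncoding call above behaves roughly as in this sketch (toy feature values; only the BatchEncoding arguments mirror the code): with return_tensors="pt" the lists become tensors and a batch axis of size 1 is prepended.

from transformers import BatchEncoding

processed_inputs = {  # invented single-text features in the shape _preprocess_text returns
    "input_ids": [101, 11, 12, 102],
    "entity_ids": [42],
    "entity_position_ids": [[2, 3]],
}
encoding = BatchEncoding(processed_inputs, tensor_type="pt", prepend_batch_axis=True)
print(encoding["input_ids"].shape)            # torch.Size([1, 4])
print(encoding["entity_position_ids"].shape)  # torch.Size([1, 1, 2])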
@@ -463,20 +467,33 @@ class KPRBertTokenizer(BertTokenizer):
             return_attention_mask=kwargs.get("return_attention_mask"),
             verbose=kwargs.get("verbose", True),
         )
+        # Pad entity ids
         max_num_entities = max(len(ids) for ids in collated_inputs["entity_ids"])
         for entity_ids in collated_inputs["entity_ids"]:
             entity_ids += [0] * (max_num_entities - len(entity_ids))
+        # Pad entity position ids
+        flattened_entity_length = [
+            len(ids) for ids_list in collated_inputs["entity_position_ids"] for ids in ids_list
+        ]
+        max_entity_token_length = max(flattened_entity_length) if flattened_entity_length else 0
+        for entity_position_ids_list in collated_inputs["entity_position_ids"]:
+            # pad entity_position_ids to max_entity_token_length
+            for entity_position_ids in entity_position_ids_list:
+                entity_position_ids += [0] * (max_entity_token_length - len(entity_position_ids))
+            # pad to max_num_entities
+            entity_position_ids_list += [[0 for _ in range(max_entity_token_length)]] * (
+                max_num_entities - len(entity_position_ids_list)
+            )
+        # Pad entity embeddings
         if "entity_embeds" in collated_inputs:
             for i in range(len(collated_inputs["entity_embeds"])):
                 collated_inputs["entity_embeds"][i] = np.pad(
                     collated_inputs["entity_embeds"][i],
                     pad_width=(
-                        (0, max_num_entities - len(collated_inputs["entity_embeds"][i])),
+                        (
+                            0,
+                            max_num_entities - len(collated_inputs["entity_embeds"][i]),
+                        ),
                         (0, 0),
                     ),
                     mode="constant",
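Finally, a compact sketch of the batch collation above with an invented two-example batch: entity_ids are padded to the largest entity count, each example's entity_position_ids are padded to a rectangle of shape (max_num_entities, max_entity_token_length), and precomputed entity embeddings are padded along the entity axis with np.pad.

import numpy as np

# Toy batch of per-example entity features (values invented)
entity_ids_batch = [[42, 7], [13]]
entity_position_ids_batch = [[[1, 2, 3], [5]], [[4, 5]]]

max_num_entities = max(len(ids) for ids in entity_ids_batch)              # 2
lengths = [len(p) for plist in entity_position_ids_batch for p in plist]
max_entity_token_length = max(lengths) if lengths else 0                  # 3

for entity_ids in entity_ids_batch:
    entity_ids += [0] * (max_num_entities - len(entity_ids))
for plist in entity_position_ids_batch:
    for positions in plist:
        positions += [0] * (max_entity_token_length - len(positions))     # pad each span
    plist += [[0] * max_entity_token_length] * (max_num_entities - len(plist))  # pad entity count

print(entity_ids_batch)           # [[42, 7], [13, 0]]
print(entity_position_ids_batch)  # [[[1, 2, 3], [5, 0, 0]], [[4, 5, 0], [0, 0, 0]]]

# Entity embeddings are padded along the entity axis only
embeds = [np.ones((2, 4), dtype=np.float32), np.ones((1, 4), dtype=np.float32)]
embeds = [
    np.pad(e, pad_width=((0, max_num_entities - len(e)), (0, 0)), mode="constant")
    for e in embeds
]
print([e.shape for e in embeds])  # [(2, 4), (2, 4)]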