ikuyamada committed
Commit 9355881 · verified · 1 Parent(s): 5720d2b

Upload tokenizer

Files changed (1)
  1. tokenization_kpr.py +49 -32
tokenization_kpr.py CHANGED
@@ -255,8 +255,7 @@ def preprocess_text(
 ) -> dict[str, list[int]]:
     tokens = []
     entity_ids = []
-    entity_start_positions = []
-    entity_lengths = []
+    entity_position_ids = []
     if title is not None:
         if title_mentions is None:
             title_mentions = []
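
This hunk replaces the two parallel lists entity_start_positions and entity_lengths with a single nested list, entity_position_ids, that stores every token index an entity covers. A minimal sketch of the equivalence, using invented spans rather than values from this repository:

    # Old representation: (start, length) pairs; these spans are invented for illustration.
    old_spans = [(3, 2), (7, 1)]

    # New representation: one explicit list of token positions per entity.
    entity_position_ids = [list(range(start, start + length)) for start, length in old_spans]
    assert entity_position_ids == [[3, 4], [7]]

    # Converting back recovers the original pairs.
    assert [(pos[0], len(pos)) for pos in entity_position_ids] == old_spans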
@@ -265,8 +264,7 @@ def preprocess_text(
         tokens += title_tokens + [tokenizer.sep_token]
         for entity in title_entities:
             entity_ids.append(entity.entity_id)
-            entity_start_positions.append(entity.start)
-            entity_lengths.append(entity.end - entity.start)
+            entity_position_ids.append(list(range(entity.start, entity.end)))

     if mentions is None:
         mentions = []
@@ -276,16 +274,14 @@ def preprocess_text(
     tokens += text_tokens
     for entity in text_entities:
         entity_ids.append(entity.entity_id)
-        entity_start_positions.append(entity.start + entity_offset)
-        entity_lengths.append(entity.end - entity.start)
+        entity_position_ids.append(list(range(entity.start + entity_offset, entity.end + entity_offset)))

     input_ids = tokenizer.convert_tokens_to_ids(tokens)

     return {
         "input_ids": input_ids,
         "entity_ids": entity_ids,
-        "entity_start_positions": entity_start_positions,
-        "entity_lengths": entity_lengths,
+        "entity_position_ids": entity_position_ids,
     }

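Entities found in the text body are recorded the same way, but shifted by entity_offset so their positions index into the concatenated title + [SEP] + text token sequence. A toy illustration of that offset, with invented tokens and spans (none of these values come from the repository):

    # Invented tokens for this sketch only.
    title_tokens = ["tokyo"]
    text_tokens = ["capital", "of", "japan"]
    tokens = title_tokens + ["[SEP]"] + text_tokens    # ["tokyo", "[SEP]", "capital", "of", "japan"]
    entity_offset = len(title_tokens) + 1              # title tokens plus the trailing [SEP]

    # "japan" spans text positions [2, 3); after the offset it sits at token index 4.
    start, end = 2, 3
    positions = list(range(start + entity_offset, end + entity_offset))
    assert positions == [4] and tokens[positions[0]] == "japan"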
 
@@ -349,8 +345,7 @@ class KPRBertTokenizer(BertTokenizer):
         "token_type_ids",
         "attention_mask",
         "entity_ids",
-        "entity_start_positions",
-        "entity_lengths",
+        "entity_position_ids",
     ]

     def __init__(
@@ -379,7 +374,7 @@ class KPRBertTokenizer(BertTokenizer):
                 "Make sure `embeddings.py` and `entity_vocab.tsv` are consistent."
             )

-    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int]]:
+    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]:
         mentions = self.entity_linker.detect_mentions(text)
         model_inputs = preprocess_text(
             text=text,
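
_preprocess_text feeds the detected mentions into preprocess_text, which only reads entity_id, start, and end from each detected entity. The linker's real classes are not shown in this diff, so the stand-in below is hypothetical, with names inferred from those attribute accesses:

    from dataclasses import dataclass

    @dataclass
    class DetectedEntity:
        # Hypothetical container mirroring the attributes preprocess_text reads.
        entity_id: int   # index into the entity vocabulary / embedding matrix
        start: int       # first token position of the mention (inclusive)
        end: int         # token position just past the mention (exclusive)

    mention = DetectedEntity(entity_id=42, start=3, end=5)
    assert list(range(mention.start, mention.end)) == [3, 4]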
@@ -395,18 +390,26 @@ class KPRBertTokenizer(BertTokenizer):
         # We exclude "return_tensors" from kwargs
         # to avoid issues in passing the data to BatchEncoding outside this method
         prepared_inputs = self.prepare_for_model(
-            model_inputs["input_ids"], **{k: v for k, v in kwargs.items() if k != "return_tensors"}
+            model_inputs["input_ids"],
+            **{k: v for k, v in kwargs.items() if k != "return_tensors"},
         )
         model_inputs.update(prepared_inputs)

         # Account for special tokens
-        if kwargs.get("add_special_tokens"):
+        if kwargs.get("add_special_tokens", True):
             if prepared_inputs["input_ids"][0] != self.cls_token_id:
                 raise ValueError(
                     "We assume that the input IDs start with the [CLS] token with add_special_tokens = True."
                 )
-            # Shift the entity start positions by 1 to account for the [CLS] token
-            model_inputs["entity_start_positions"] = [pos + 1 for pos in model_inputs["entity_start_positions"]]
+            # Shift the entity position IDs by 1 to account for the [CLS] token
+            model_inputs["entity_position_ids"] = [
+                [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"]
+            ]
+
+        # If there are no entities in the text, we output the padding entity for the model
+        if not model_inputs["entity_ids"]:
+            model_inputs["entity_ids"] = [0]  # The padding entity id is 0
+            model_inputs["entity_position_ids"] = [[0]]

         # Count the number of special tokens at the end of the input
         num_special_tokens_at_end = 0
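
Two behaviours introduced here are easy to check in isolation: every entity position is shifted by one when add_special_tokens prepends [CLS], and an input with no detected entities falls back to a single padding entity (ID 0 at position 0, per the comment above). A standalone sketch with invented values:

    def shift_for_cls(entity_position_ids: list[list[int]]) -> list[list[int]]:
        # Mirrors the +1 shift applied when [CLS] is prepended to the input IDs.
        return [[pos + 1 for pos in positions] for positions in entity_position_ids]

    assert shift_for_cls([[0, 1], [3]]) == [[1, 2], [4]]

    # Fallback when the entity linker finds nothing.
    entity_ids: list[int] = []
    entity_position_ids: list[list[int]] = []
    if not entity_ids:
        entity_ids = [0]              # 0 is the padding entity ID, as stated in the diff
        entity_position_ids = [[0]]
    assert (entity_ids, entity_position_ids) == ([0], [[0]])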
@@ -414,26 +417,25 @@ class KPRBertTokenizer(BertTokenizer):
         if isinstance(input_ids, torch.Tensor):
             input_ids = input_ids.tolist()
         for input_id in input_ids[::-1]:
-            if int(input_id) not in {self.sep_token_id, self.pad_token_id, self.cls_token_id}:
+            if int(input_id) not in {
+                self.sep_token_id,
+                self.pad_token_id,
+                self.cls_token_id,
+            }:
                 break
             num_special_tokens_at_end += 1

         # Remove entities that are not in truncated input
         max_effective_pos = len(model_inputs["input_ids"]) - num_special_tokens_at_end
         entity_indices_to_keep = list()
-        for i, (start_pos, length) in enumerate(
-            zip(model_inputs["entity_start_positions"], model_inputs["entity_lengths"])
-        ):
-            if (start_pos + length) <= max_effective_pos:
+        for i, position_ids in enumerate(model_inputs["entity_position_ids"]):
+            if len(position_ids) > 0 and max(position_ids) < max_effective_pos:
                 entity_indices_to_keep.append(i)
         model_inputs["entity_ids"] = [model_inputs["entity_ids"][i] for i in entity_indices_to_keep]
-        model_inputs["entity_start_positions"] = [
-            model_inputs["entity_start_positions"][i] for i in entity_indices_to_keep
-        ]
-        model_inputs["entity_lengths"] = [model_inputs["entity_lengths"][i] for i in entity_indices_to_keep]
+        model_inputs["entity_position_ids"] = [model_inputs["entity_position_ids"][i] for i in entity_indices_to_keep]

         if self.entity_embeddings is not None:
-            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]]
+            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32)
         return model_inputs

     def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding:
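
The truncation filter now keeps an entity only if its largest token position lies inside the effective (non-special, non-truncated) part of the input. Since max(position_ids) equals start + length - 1, the new condition max(position_ids) < max_effective_pos matches the old (start_pos + length) <= max_effective_pos test. A small sketch of the filter with made-up positions:

    # Invented example: three entities, the last one cut off by truncation.
    entity_ids = [11, 22, 33]
    entity_position_ids = [[1, 2], [4], [6, 7, 8]]
    max_effective_pos = 7   # positions >= 7 belong to truncated tokens or trailing special tokens

    keep = [
        i
        for i, positions in enumerate(entity_position_ids)
        if positions and max(positions) < max_effective_pos
    ]
    entity_ids = [entity_ids[i] for i in keep]
    entity_position_ids = [entity_position_ids[i] for i in keep]
    assert entity_ids == [11, 22] and entity_position_ids == [[1, 2], [4]]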
@@ -447,7 +449,9 @@ class KPRBertTokenizer(BertTokenizer):
         if isinstance(text, str):
             processed_inputs = self._preprocess_text(text, **kwargs)
             return BatchEncoding(
-                processed_inputs, tensor_type=kwargs.get("return_tensors", None), prepend_batch_axis=True
+                processed_inputs,
+                tensor_type=kwargs.get("return_tensors", None),
+                prepend_batch_axis=True,
             )

         processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text]
@@ -463,20 +467,33 @@ class KPRBertTokenizer(BertTokenizer):
             return_attention_mask=kwargs.get("return_attention_mask"),
             verbose=kwargs.get("verbose", True),
         )
-        # Collate entity_ids, entity_start_positions, and entity_lengths
+        # Pad entity ids
         max_num_entities = max(len(ids) for ids in collated_inputs["entity_ids"])
         for entity_ids in collated_inputs["entity_ids"]:
             entity_ids += [0] * (max_num_entities - len(entity_ids))
-        for entity_start_positions in collated_inputs["entity_start_positions"]:
-            entity_start_positions += [-1] * (max_num_entities - len(entity_start_positions))
-        for entity_lengths in collated_inputs["entity_lengths"]:
-            entity_lengths += [0] * (max_num_entities - len(entity_lengths))
+        # Pad entity position ids
+        flattened_entity_length = [
+            len(ids) for ids_list in collated_inputs["entity_position_ids"] for ids in ids_list
+        ]
+        max_entity_token_length = max(flattened_entity_length) if flattened_entity_length else 0
+        for entity_position_ids_list in collated_inputs["entity_position_ids"]:
+            # pad entity_position_ids to max_entity_token_length
+            for entity_position_ids in entity_position_ids_list:
+                entity_position_ids += [0] * (max_entity_token_length - len(entity_position_ids))
+            # pad to max_num_entities
+            entity_position_ids_list += [[0 for _ in range(max_entity_token_length)]] * (
+                max_num_entities - len(entity_position_ids_list)
+            )
+        # Pad entity embeddings
         if "entity_embeds" in collated_inputs:
             for i in range(len(collated_inputs["entity_embeds"])):
                 collated_inputs["entity_embeds"][i] = np.pad(
                     collated_inputs["entity_embeds"][i],
                     pad_width=(
-                        (0, max_num_entities - len(collated_inputs["entity_embeds"][i])),
+                        (
+                            0,
+                            max_num_entities - len(collated_inputs["entity_embeds"][i]),
+                        ),
                         (0, 0),
                     ),
                     mode="constant",
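
Collation now pads two nested dimensions so the batch can become a dense [batch, max_num_entities, max_entity_token_length] tensor: each example's entity list is padded to the largest entity count in the batch, each entity's position list to the longest mention, and the per-example embedding matrices are padded with np.pad. A self-contained sketch with an invented two-example batch (the embedding width of 4 is arbitrary):

    import numpy as np

    # Invented ragged batch: two examples with two and one entities respectively.
    batch_entity_ids = [[5, 9], [3]]
    batch_entity_position_ids = [[[1, 2, 3], [6]], [[2, 3]]]
    batch_entity_embeds = [np.ones((2, 4), dtype=np.float32), np.ones((1, 4), dtype=np.float32)]

    max_num_entities = max(len(ids) for ids in batch_entity_ids)
    lengths = [len(ids) for ids_list in batch_entity_position_ids for ids in ids_list]
    max_entity_token_length = max(lengths) if lengths else 0

    # Pad entity IDs per example.
    for entity_ids in batch_entity_ids:
        entity_ids += [0] * (max_num_entities - len(entity_ids))
    # Pad each position list to the longest mention, then pad the entity axis.
    for position_ids_list in batch_entity_position_ids:
        for position_ids in position_ids_list:
            position_ids += [0] * (max_entity_token_length - len(position_ids))
        position_ids_list += [
            [0] * max_entity_token_length for _ in range(max_num_entities - len(position_ids_list))
        ]
    # Pad the per-example embedding matrices along the entity axis only.
    padded_embeds = [
        np.pad(embeds, pad_width=((0, max_num_entities - len(embeds)), (0, 0)), mode="constant")
        for embeds in batch_entity_embeds
    ]

    assert np.array(batch_entity_ids).shape == (2, 2)
    assert np.array(batch_entity_position_ids).shape == (2, 2, 3)
    assert all(e.shape == (2, 4) for e in padded_embeds)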
 