suhara committed (verified)
Commit a8fe94a · 1 Parent(s): f75c274

Upload modeling_nemotron_h.py

Files changed (1):
  modeling_nemotron_h.py  +6 -1

modeling_nemotron_h.py CHANGED
@@ -42,7 +42,7 @@ from transformers.utils.import_utils import (
     is_causal_conv1d_available,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
-    is_mamba_2_ssm_available,
+    is_mamba_2_ssm_available,
 )
 from .configuration_nemotron_h import NemotronHConfig
 
@@ -1542,6 +1542,11 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and empty_past_kv:
+            # TODO(pjin): workaround fix for properly extending inputs_embeds;
+            # longer term, may be better handled elsewhere in .generate().
+            if input_ids is not None and inputs_embeds.shape[1] < input_ids.shape[1]:
+                new_token_embeds = self.get_input_embeddings()(input_ids[:,inputs_embeds.shape[1]:])
+                inputs_embeds = torch.cat([inputs_embeds, new_token_embeds], dim=1)
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
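For context, a minimal, self-contained sketch of the logic the added branch implements, using a toy nn.Embedding in place of self.get_input_embeddings() (the vocabulary size, hidden size, and token ids below are illustrative assumptions, not values from the model): when generation proceeds from inputs_embeds and input_ids covers more positions than the embedded prompt, the uncovered ids are embedded and appended so the forward pass sees the full sequence.

import torch
import torch.nn as nn

# Toy stand-in for the model's input embedding layer (assumed vocab 100, hidden 16).
embed_tokens = nn.Embedding(100, 16)

# Prompt embeddings cover 4 positions, but input_ids has grown to 6 positions,
# e.g. token ids appended after the originally embedded prompt.
input_ids = torch.tensor([[5, 7, 9, 11, 42, 43]])   # (1, 6)
inputs_embeds = embed_tokens(input_ids[:, :4])       # (1, 4, 16)

# Same shape check and concatenation as the added branch in
# prepare_inputs_for_generation: embed only the ids that inputs_embeds
# does not yet cover, then append them along the sequence dimension.
if input_ids is not None and inputs_embeds.shape[1] < input_ids.shape[1]:
    new_token_embeds = embed_tokens(input_ids[:, inputs_embeds.shape[1]:])
    inputs_embeds = torch.cat([inputs_embeds, new_token_embeds], dim=1)

print(inputs_embeds.shape)  # torch.Size([1, 6, 16])

In the model itself this branch only runs when the cache is still empty (empty_past_kv) and inputs_embeds was passed; before this commit, any positions of input_ids beyond the provided inputs_embeds were simply not used in that branch.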