Upload modeling_nemotron_h.py
Browse files
modeling_nemotron_h.py  +6 -1  CHANGED
@@ -42,7 +42,7 @@ from transformers.utils.import_utils import (
     is_causal_conv1d_available,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
-    is_mamba_2_ssm_available,
+    is_mamba_2_ssm_available,
 )
 from .configuration_nemotron_h import NemotronHConfig
 
@@ -1542,6 +1542,11 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and empty_past_kv:
+            # TODO(pjin): workaround fix for properly extending inputs_embeds;
+            # longer term, may be better handled elsewhere in .generate().
+            if input_ids is not None and inputs_embeds.shape[1] < input_ids.shape[1]:
+                new_token_embeds = self.get_input_embeddings()(input_ids[:,inputs_embeds.shape[1]:])
+                inputs_embeds = torch.cat([inputs_embeds, new_token_embeds], dim=1)
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
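For context, here is a minimal, self-contained sketch of what the added branch does, assuming a toy torch.nn.Embedding in place of self.get_input_embeddings(); the helper name extend_inputs_embeds is illustrative and not part of the model:

import torch
import torch.nn as nn

embed = nn.Embedding(10, 4)  # toy stand-in for self.get_input_embeddings()

def extend_inputs_embeds(input_ids, inputs_embeds):
    # Mirrors the workaround: when input_ids has grown past the length of the
    # precomputed inputs_embeds, embed only the trailing token ids and append.
    if input_ids is not None and inputs_embeds.shape[1] < input_ids.shape[1]:
        new_token_embeds = embed(input_ids[:, inputs_embeds.shape[1]:])
        inputs_embeds = torch.cat([inputs_embeds, new_token_embeds], dim=1)
    return inputs_embeds

prompt_ids = torch.tensor([[1, 2, 3]])       # 3-token prompt embedded up front
inputs_embeds = embed(prompt_ids)            # shape (1, 3, 4)
grown_ids = torch.tensor([[1, 2, 3, 4, 5]])  # caller appended 2 more token ids
extended = extend_inputs_embeds(grown_ids, inputs_embeds)
assert extended.shape == (1, 5, 4)           # embeddings now cover all 5 tokens

The motivation, per the TODO, appears to be that generation can produce an input_ids tensor longer than the prompt that was originally embedded; without the new branch, the shorter inputs_embeds would be passed through unchanged and the trailing token ids would never reach the model.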