Text Generation · Transformers · Safetensors · PyTorch · nvidia · conversational
suhara committed · Commit ea1e856 · verified · 1 parent: 91f6915

Upload modeling_nemotron_h.py

Files changed (1): modeling_nemotron_h.py  +15 -10
modeling_nemotron_h.py CHANGED
@@ -24,21 +24,21 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-from transformers.activations import ACT2FN
-from transformers.cache_utils import DynamicCache # we need __iter__ and __len__ of pkv
-from transformers.generation import GenerationMixin
-from transformers.modeling_attn_mask_utils import (
+from ...activations import ACT2FN
+from ...cache_utils import DynamicCache # we need __iter__ and __len__ of pkv
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
 )
-from transformers.utils.import_utils import (
+from ...utils.import_utils import (
     is_causal_conv1d_available,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
@@ -70,7 +70,7 @@ else:
     causal_conv1d_update, causal_conv1d_fn = None, None
 
 if is_flash_attn_2_available():
-    from transformers.modeling_flash_attention_utils import _flash_attention_forward
+    from ...modeling_flash_attention_utils import _flash_attention_forward
 
 is_fast_path_available = all(
     (
@@ -844,8 +844,8 @@ class NemotronHAttention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        if config.head_dim is not None:
-            self.head_dim = config.head_dim
+        if config.attention_head_dim is not None:
+            self.head_dim = config.attention_head_dim
         else:
             self.head_dim = config.hidden_size // config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
@@ -1542,6 +1542,11 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and empty_past_kv:
+            # TODO(pjin): workaround fix for properly extending inputs_embeds;
+            # longer term, may be better handled elsewhere in .generate().
+            if input_ids is not None and inputs_embeds.shape[1] < input_ids.shape[1]:
+                new_token_embeds = self.get_input_embeddings()(input_ids[:, inputs_embeds.shape[1]:])
+                inputs_embeds = torch.cat([inputs_embeds, new_token_embeds], dim=1)
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
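For context on the attention change above, here is a minimal sketch of the head-dimension fallback that the patch switches from head_dim to attention_head_dim. DummyConfig and resolve_head_dim are hypothetical stand-ins for the model's config class and the logic inside NemotronHAttention.__init__, not the actual implementation.

from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyConfig:
    # hypothetical stand-in for the model config; only the fields used below
    hidden_size: int = 4096
    num_attention_heads: int = 32
    attention_head_dim: Optional[int] = None  # renamed field read by the patched branch


def resolve_head_dim(config: DummyConfig) -> int:
    # mirrors the patched branch: prefer an explicit attention_head_dim,
    # otherwise derive the head size from hidden_size / num_attention_heads
    if config.attention_head_dim is not None:
        return config.attention_head_dim
    return config.hidden_size // config.num_attention_heads


assert resolve_head_dim(DummyConfig()) == 128
assert resolve_head_dim(DummyConfig(attention_head_dim=96)) == 96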
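The block added in prepare_inputs_for_generation extends inputs_embeds with embeddings of any trailing input_ids tokens before the first decoding step (the empty_past_kv case). Below is a self-contained sketch of just that concatenation, using a toy nn.Embedding in place of the model's get_input_embeddings(); all sizes are made up for illustration.

import torch
from torch import nn

embed = nn.Embedding(num_embeddings=100, embedding_dim=8)  # toy stand-in for get_input_embeddings()

input_ids = torch.randint(0, 100, (1, 6))  # full prompt: 6 token ids
inputs_embeds = embed(input_ids[:, :4])    # caller supplied embeddings for only the first 4 positions

if input_ids is not None and inputs_embeds.shape[1] < input_ids.shape[1]:
    # embed only the tokens that inputs_embeds does not yet cover ...
    new_token_embeds = embed(input_ids[:, inputs_embeds.shape[1]:])
    # ... and append them along the sequence dimension
    inputs_embeds = torch.cat([inputs_embeds, new_token_embeds], dim=1)

assert inputs_embeds.shape == (1, 6, 8)  # now matches the full prompt length

In the model itself this branch runs only while the key/value cache is still empty, i.e. on the first generation step, as the surrounding comment in the diff notes.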