Remove in-place operations

#40
by yyyyifan - opened
Files changed (1) hide show
  1. modeling_molmo.py +3 -3
modeling_molmo.py CHANGED
@@ -1163,7 +1163,7 @@ class MultiHeadAttentionPool(nn.Module):
1163
  if self.dropout:
1164
  attn_output = self.residual_dropout(attn_output)
1165
  if self.mean_residual:
1166
- attn_output += inputs_kv.mean(dim=1, keepdim=True)
1167
 
1168
  return attn_output
1169
 
@@ -1879,7 +1879,7 @@ class Molmo(nn.Module):
1879
  # For hf demo/endpoint
1880
  image_features = image_features.to(x.device)
1881
 
1882
- x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
1883
 
1884
  if not self.config.rope:
1885
  # Get positional embeddings.
@@ -2145,7 +2145,7 @@ class MolmoForCausalLM(PreTrainedModel):
2145
  z_loss = z_loss.view(input_ids.shape[0], -1)
2146
  z_loss = z_loss * loss_masks
2147
  z_loss = z_loss.sum() / batch_size_in_tokens
2148
- loss += z_loss
2149
  else:
2150
  # Shift so that tokens < n predict n
2151
  shift_logits = logits[..., :-1, :].contiguous()
 
1163
  if self.dropout:
1164
  attn_output = self.residual_dropout(attn_output)
1165
  if self.mean_residual:
1166
+ attn_output = attn_output + inputs_kv.mean(dim=1, keepdim=True)
1167
 
1168
  return attn_output
1169
 
 
1879
  # For hf demo/endpoint
1880
  image_features = image_features.to(x.device)
1881
 
1882
+ x[batch_idx[valid], image_input_idx[valid]] = x[batch_idx[valid], image_input_idx[valid]] + image_features[valid]
1883
 
1884
  if not self.config.rope:
1885
  # Get positional embeddings.
 
2145
  z_loss = z_loss.view(input_ids.shape[0], -1)
2146
  z_loss = z_loss * loss_masks
2147
  z_loss = z_loss.sum() / batch_size_in_tokens
2148
+ loss = loss + z_loss
2149
  else:
2150
  # Shift so that tokens < n predict n
2151
  shift_logits = logits[..., :-1, :].contiguous()