norm function changes dtypes (#13)
Browse files- convert norm function output to original dtype (771117017599deb36494828a10d72790c69929f6)
- modeling_intern_vit.py +3 -2
modeling_intern_vit.py
CHANGED
|
@@ -287,9 +287,10 @@ class InternVisionEncoderLayer(nn.Module):
|
|
| 287 |
Args:
|
| 288 |
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 289 |
"""
|
| 290 |
-
hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
|
| 291 |
|
| 292 |
-
hidden_states = hidden_states + self.
|
|
|
|
|
|
|
| 293 |
|
| 294 |
return hidden_states
|
| 295 |
|
|
|
|
| 287 |
Args:
|
| 288 |
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 289 |
"""
|
|
|
|
| 290 |
|
| 291 |
+
hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
|
| 292 |
+
|
| 293 |
+
hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
|
| 294 |
|
| 295 |
return hidden_states
|
| 296 |
|