update modeling_yi.py

modeling_yi.py  (+11 -13)
@@ -6,7 +6,6 @@ import torch.utils.checkpoint
 from einops import repeat
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -18,17 +17,17 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_available,
     logging,
     replace_return_docstrings,
 )
 
 from .configuration_yi import YiConfig
 
-
-if is_flash_attn_available():
+is_flash_attn_available = True
+try:
     from flash_attn import flash_attn_func
-
+except Exception:
+    is_flash_attn_available = False
 
 logger = logging.get_logger(__name__)
 
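Note: the hunk above swaps the `is_flash_attn_available()` probe from transformers.utils for a module-level boolean set by a guarded import. A minimal standalone sketch of that pattern, with a hypothetical `describe_backend` helper (not part of modeling_yi.py) to show how the flag is consumed:

is_flash_attn_available = True
try:
    # Optional dependency: only importable when the flash-attn package is installed.
    from flash_attn import flash_attn_func
except Exception:
    # The import can fail with ImportError or a CUDA/ABI error; either way, fall back.
    is_flash_attn_available = False

def describe_backend() -> str:
    # Callers branch on the plain bool instead of re-probing transformers on every forward pass.
    return "flash_attn" if is_flash_attn_available else "eager attention"

print(describe_backend())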
@@ -224,7 +223,6 @@ class YiAttention(nn.Module):
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
-        flash_attn_available = is_flash_attn_available()
 
         query_states = self.q_proj(hidden_states).view(
             bsz, q_len, self.num_heads, self.head_dim
@@ -237,7 +235,7 @@ class YiAttention(nn.Module):
             bsz, q_len, self.num_key_value_heads, self.head_dim
         )
 
-        if not flash_attn_available:
+        if not is_flash_attn_available:
             if self.num_key_value_groups > 1:
                 key_states = repeat(
                     key_states, f"b n h d -> b n (h {self.num_key_value_groups}) d"
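Note: the `einops.repeat` call kept in the non-flash branch expands the key/value heads so grouped-query attention can be computed with ordinary batched matmuls. A small sketch with illustrative sizes (batch 2, seq 5, 2 KV heads shared by 4 query heads, head_dim 8):

import torch
from einops import repeat

num_key_value_groups = 2                # num_heads // num_key_value_heads, illustrative
key_states = torch.randn(2, 5, 2, 8)    # (batch, seq, kv_heads, head_dim)
key_states = repeat(
    key_states, f"b n h d -> b n (h {num_key_value_groups}) d"
)
print(key_states.shape)                 # torch.Size([2, 5, 4, 8]): one copy per query-head group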
@@ -251,13 +249,13 @@ class YiAttention(nn.Module):
             key_states = key_states.transpose(1, 2)
             value_states = value_states.transpose(1, 2)
 
-        seq_dim = 1 if flash_attn_available else 2
+        seq_dim = 1 if is_flash_attn_available else 2
         kv_seq_len = key_states.shape[seq_dim]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[seq_dim]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, position_ids, flash_attn_available
+            query_states, key_states, cos, sin, position_ids, is_flash_attn_available
         )
 
         if past_key_value is not None:
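Note: the `seq_dim` switch above encodes the two layouts the attention code works in: the flash-attn path keeps tensors as (batch, seq, heads, head_dim), while the eager path has already transposed to (batch, heads, seq, head_dim). A tiny sketch of that assumption with illustrative shapes:

import torch

b, n, h, d = 2, 5, 4, 8
flash_layout = torch.randn(b, n, h, d)       # flash-attn path: sequence length at dim 1
eager_layout = flash_layout.transpose(1, 2)  # eager path: sequence length at dim 2

assert flash_layout.shape[1] == n            # seq_dim = 1 when is_flash_attn_available
assert eager_layout.shape[2] == n            # seq_dim = 2 otherwise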
@@ -267,7 +265,7 @@ class YiAttention(nn.Module):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
-        if flash_attn_available:
+        if is_flash_attn_available:
             attn_output = flash_attn_func(
                 query_states, key_states, value_states, dropout_p=0.0, causal=True
             )
@@ -308,7 +306,7 @@ class YiAttention(nn.Module):
                 f" {attn_output.size()}"
             )
 
-        if not flash_attn_available:
+        if not is_flash_attn_available:
             attn_output = attn_output.transpose(1, 2)
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -541,7 +539,7 @@ class YiModel(YiPreTrainedModel):
     def _prepare_decoder_attention_mask(
         self, attention_mask, input_ids, inputs_embeds, past_key_values_length
     ):
-        input_shape = input_ids.shape
+        input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         combined_attention_mask = None
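Note: the new `input_shape` line lets the mask helper work when a caller passes `inputs_embeds` without `input_ids`; the batch/sequence shape is then read from the embeddings. A minimal sketch (hidden size 64 is illustrative):

import torch

input_ids = None
inputs_embeds = torch.randn(2, 7, 64)    # (batch, seq, hidden)

input_shape = (
    input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
)
print(input_shape)                       # torch.Size([2, 7])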
@@ -631,7 +629,7 @@ class YiModel(YiPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if not is_flash_attn_available():
+        if not is_flash_attn_available:
             # embed positions
             if attention_mask is None:
                 attention_mask = torch.ones(