import copy
from typing import Optional, Type

import torch
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.modules.mha import MHA


class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding with linear position interpolation.

    Position indices are divided by ``scaling_factor`` before the cos/sin cache
    is built, which stretches the usable context window by that factor.
    """

    def __init__(
        self,
        dim: int,
        scaling_factor: float = 1.0,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        super().__init__(
            dim=dim,
            base=base,
            interleaved=interleaved,
            scale_base=scale_base,
            pos_idx_in_fp32=pos_idx_in_fp32,
            device=device,
        )
        self._linear_scaling_factor = scaling_factor

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Identical to RotaryEmbedding._update_cos_sin_cache in flash_attn, except that
        # the position indices t are divided by the linear scaling factor.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # Compute positions in fp32 if requested, to avoid precision loss in low-precision dtypes.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t = t / self._linear_scaling_factor
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t = t / self._linear_scaling_factor
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)


def swap_mha_rope(
    mha: MHA,
    new_rope: Type[RotaryEmbedding] = LinearlyScaledRotaryEmbedding,
    kwargs_new_rope: Optional[dict] = None,
) -> MHA:
    """Swap the rotary embedding of a flash_attn MHA block for ``new_rope``, in place."""
    # Infer dtype/device from the projection weights (Wq for cross-attention, Wqkv otherwise).
    dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
    device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device

    # Carry over the configuration of the existing rotary embedding.
    kwargs_old_rope = dict(
        dim=mha.rotary_emb.dim,
        base=mha.rotary_emb.base,
        interleaved=mha.rotary_emb.interleaved,
        scale_base=mha.rotary_emb.scale_base,
        pos_idx_in_fp32=mha.rotary_emb.pos_idx_in_fp32,
        device=mha.rotary_emb.inv_freq.device,
    )
    del mha.rotary_emb

    kwargs_new_rope = kwargs_new_rope or {"scaling_factor": 1.0}
    scaled_rope = new_rope(
        **kwargs_new_rope,
        **kwargs_old_rope,
    ).to(device=device, dtype=dtype)

    mha.rotary_emb = scaled_rope
    assert isinstance(mha.rotary_emb, new_rope)
    return mha
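

# --- Usage sketch (illustrative, not part of the original module) -------------------
# A minimal, hedged example of how the pieces above fit together. It assumes the
# flash_attn MHA constructor exposes `embed_dim`, `num_heads`, and `rotary_emb_dim`
# (true for recent flash_attn releases, but check your installed version), and it
# only exercises cache construction and the swap, so no attention kernel is run.
if __name__ == "__main__":
    # 1) Sanity check: with scaling_factor=2, the scaled cache at position 2*i should
    #    match the unscaled cache at position i, i.e. the angle grid is stretched 2x.
    base_rope = RotaryEmbedding(dim=64)
    scaled = LinearlyScaledRotaryEmbedding(dim=64, scaling_factor=2.0)
    base_rope._update_cos_sin_cache(16, device=torch.device("cpu"), dtype=torch.float32)
    scaled._update_cos_sin_cache(16, device=torch.device("cpu"), dtype=torch.float32)
    assert torch.allclose(scaled._cos_cached[::2], base_rope._cos_cached[:8])

    # 2) Swap the rotary embedding of an MHA block, stretching positions 4x
    #    (e.g. to run a model trained on 2k tokens at 8k). embed_dim=512 with
    #    num_heads=8 gives head_dim=64; rotary is applied to the full head dim.
    mha = MHA(embed_dim=512, num_heads=8, rotary_emb_dim=64)
    mha = swap_mha_rope(mha, kwargs_new_rope={"scaling_factor": 4.0})
    print(type(mha.rotary_emb).__name__)  # -> LinearlyScaledRotaryEmbedding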