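"""StripedHyena block definitions with Mixture-of-Experts MLPs.

Attention (MHA) and gated-convolution (Hyena) blocks share the same residual
structure; their dense gated MLPs are replaced here by a DeepSpeed MoE layer,
so every block also returns an auxiliary load-balancing loss that is
accumulated by the top-level model.
"""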
import torch
import torch.nn as nn
import torch.nn.functional as F

from .cache import InferenceParams, RecurrentInferenceParams
from .engine import HyenaInferenceEngine
from .layers import ParallelGatedMLP, RMSNorm, VocabParallelEmbedding
from .utils import column_split, print_rank_0

from deepspeed.moe.layer import MoE


class MoEParallelGatedMLP(nn.Module):
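    """Mixture-of-Experts replacement for the dense gated MLP.

    Each expert is a two-layer GELU MLP; routing is handled by
    ``deepspeed.moe.layer.MoE`` with ``num_experts`` experts and
    top-``k`` gating (both read from the config, with defaults below).
    """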
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.ffn_hidden_size = config.get("ffn_hidden_size", 4 * self.hidden_size)
        self.num_experts = config.get("num_experts", 2)
        self.top_k = config.get("top_k", 1)
        self.moe = MoE(
            hidden_size=self.hidden_size,
            expert=nn.Sequential(
                nn.Linear(self.hidden_size, self.ffn_hidden_size),
                nn.GELU(),
                nn.Linear(self.ffn_hidden_size, self.hidden_size),
            ),
            num_experts=self.num_experts,
            k=self.top_k,
        )

    def forward(self, x):
        # DeepSpeed's MoE layer returns (output, aux loss, expert counts);
        # only the output and the load-balancing loss are propagated.
        output, moe_loss, _ = self.moe(x)
        return output, moe_loss

try:
    from flash_attn.modules.mha import MHA
except ImportError:
    print_rank_0("flash_attn not installed")

try:
    from .positional_embeddings import swap_mha_rope
except ImportError:
    print_rank_0("could not import swap_mha_rope from positional_embeddings.py")

# FlashDepthwiseConv1d is only needed when `use_flash_depthwise` is enabled;
# it is assumed to come from the flashfftconv package (guarded import).
try:
    from flashfftconv import FlashDepthwiseConv1d
except ImportError:
    print_rank_0("flashfftconv not installed")

from .tokenizer import ByteTokenizer


class AttentionBlock(nn.Module):
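    """Pre-norm multi-head attention block with a residual MoE MLP.

    Wraps flash-attn's MHA (optionally with grouped KV heads and rotary
    embeddings), followed by an RMSNorm'd MoE gated MLP; both sub-layers
    are residual. Returns the hidden states and the MoE auxiliary loss.
    """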
    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
        self.layer_idx = layer_idx
        self.proj_groups = config.get("proj_groups", 1)
        dtype = config.get("attn_block_dtype", torch.bfloat16)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads

        self.counter = 0
        self.inner_mha_cls = MHA(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_heads_kv=config.num_attention_heads // self.proj_groups,
            rotary_emb_dim=config.hidden_size // config.num_attention_heads,
            qkv_proj_bias=config.get("qkv_proj_bias", True),
            rotary_emb_base=config.get("rotary_emb_base", 10000),
            causal=True,
            layer_idx=layer_idx,
            out_proj_bias=config.get("mha_out_proj_bias", True),
            use_flash_attn=self.config.use_flash_attn,
        ).to(dtype=dtype)

        if config.get("use_interpolated_rotary_pos_emb", False):
            swap_mha_rope(
                mha=self.inner_mha_cls,
                kwargs_new_rope={"scaling_factor": config.get("rotary_emb_scaling_factor", 1.0)},
            )

        if self.config.get("smeared_gqa", False):
            self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
        self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)

        self.mlp = MoEParallelGatedMLP(config).to(dtype=mlp_dtype)

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
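        """Apply attention and the MoE MLP, each with a residual connection.

        Returns the updated hidden states and the MoE auxiliary loss.
        """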
        if type(padding_mask) == torch.Tensor:
            u = u * padding_mask[..., None]

        u = (
            self.inner_mha_cls(
                self.pre_norm(u),
                inference_params=inference_params,
            )
            + u
        )
        if type(padding_mask) == torch.Tensor:
            u = u * padding_mask[..., None]
        u_norm = self.post_norm(u)
        u_mlp_out, moe_loss = self.mlp(u_norm)
        u = u_mlp_out + u
        return u, moe_loss


class ParallelHyenaFilter(nn.Module):
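    """Hyena long-convolution operator used inside gated convolution blocks.

    Combines a short depthwise FIR filter with a long implicit filter whose
    impulse response is parameterized by complex poles and residues (a modal
    state-space form). Prefill runs the parallel path; decoding steps through
    cached FIR and IIR states.
    """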
    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hyena_filter_groups = config.get("hyena_filter_groups", self.config.hidden_size)

        self.use_flashfft = config.get("use_flashfft", False)
        self.state_size = config.state_size
        self.hidden_size = config.hidden_size
        self.num_filters = config.num_filters
        self.inference_mode = config.get("inference_mode", True)
        self.counter = 0
        self.column_split_hyena = config.get("column_split_hyena", True)

        assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size

        self.D = nn.Parameter(torch.zeros(self.hidden_size))

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads

        self.short_filter_length = config.short_filter_length
        self.short_filter_weight = nn.Parameter(torch.randn(3 * config.hidden_size, 1, config.short_filter_length))
        self.short_filter_bias = (
            nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None
        )

        self.engine = HyenaInferenceEngine(layer_idx=layer_idx)
        self.use_flash_depthwise = config.get("use_flash_depthwise", False)
        self.data_dtype = None

        if self.use_flash_depthwise:
            self.fir_fn = FlashDepthwiseConv1d(
                channels=3 * self.hidden_size,
                kernel_size=self.short_filter_length,
                padding=self.short_filter_length - 1,
                weights=self.short_filter_weight,
                bias=self.short_filter_bias,
                device=None,
                dtype=self.config.get("depthwise_dtype", torch.bfloat16),
            )
        else:
            self.fir_fn = F.conv1d

        self.fftconv_fn = None
        self.long_fir_threshold = config.get("long_fir_threshold", None)
        if self.long_fir_threshold is not None:
            assert self.use_flashfft is False, "long_fir_threshold not compatible with fused flashfft"

        self.num_systems = self.hidden_size // self.hyena_filter_groups

        poles = torch.randn(self.num_systems, self.state_size, 1, 2)
        poles[..., 0] = 1e-2 * torch.randn(self.num_systems, self.state_size, 1)
        poles[..., 1] = 1e-3 * torch.randn(self.num_systems, self.state_size, 1)

        self.poles = nn.Parameter(poles)

        self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
        self.h = None

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
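        """Dispatch to the recurrent path when cached FIR state exists for this
        layer, otherwise run the parallel (prefill) path."""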
        if inference_params is not None and self.layer_idx in inference_params.fir_state_dict.keys():
            return self.sequential_forward(u, inference_params)
        else:
            return self.parallel_forward(u, inference_params, padding_mask)

    def parallel_forward(self, u, inference_params=None, padding_mask=None):
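        """Prefill path: apply the short FIR filter, then the long IIR
        convolution over the full sequence, caching FIR state when inference
        params are provided."""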
        L = u.shape[1]
        z_pre, fir_state = self.engine.parallel_fir(
            self.fir_fn,
            u,
            self.short_filter_weight,
            self.short_filter_bias,
            L,
            fir_length=self.short_filter_length,
            inference_params=inference_params,
            padding_mask=padding_mask,
        )
        if inference_params:
            inference_params.fir_state_dict[self.layer_idx] = fir_state

        if self.h is None:
            h, filter_dtype, poles, residues = self.compute_filter(L, u.device)
        else:
            h = self.h
            filter_dtype = self.h.dtype

        if self.hyena_filter_groups > 1:
            h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 1)

        dims = (
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            self.state_size,
            self.hyena_filter_groups,
        )
        y = self.engine.parallel_iir(
            z_pre,
            h,
            self.D,
            L,
            t=self.t,
            poles=self.poles,
            residues=self.residues,
            dims=dims,
            inference_params=inference_params,
            layer_idx=self.layer_idx,
            prefill_style=self.config.get("prefill_style", "fft"),
            use_flashfft=self.use_flashfft,
            fftconv_fn=self.fftconv_fn,
            column_split_hyena=self.column_split_hyena,
            long_fir_threshold=self.long_fir_threshold,
            padding_mask=padding_mask,
        )

        return y, inference_params

    def sequential_forward(self, u, inference_params):
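        """Single-token decode step using the cached FIR and IIR states."""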
        if self.data_dtype is None:
            self.data_dtype = u.dtype
        if len(u.shape) > 2:
            u = u[:, -1]

        fir_state, iir_state = (
            inference_params.fir_state_dict[self.layer_idx],
            inference_params.state_dict[self.layer_idx],
        )

        z_pre, fir_state = self.engine.step_fir(
            u, fir_state, weight=self.short_filter_weight, bias=self.short_filter_bias
        )
        x2, x1, v = (
            column_split(z_pre, self.num_attention_heads, self.hidden_size_per_attention_head)
            if self.column_split_hyena
            else z_pre.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1)
        )

        y, iir_state = self.engine.step_iir(
            x2,
            x1,
            v,
            self.D,
            self.residues,
            self.poles,
            iir_state,
            iir_groups=self.hyena_filter_groups,
        )

        inference_params.fir_state_dict[self.layer_idx] = fir_state
        inference_params.state_dict[self.layer_idx] = iir_state
        y = y.to(dtype=self.data_dtype)
        return y[:, None], inference_params

    def update_time(self, L, device):
        """
        Set [0, 1, ..., L-1] where L is the length of the current batch of inputs.
        If L is greater than the length of the previous batch, then the time vector is
        reinitialized. Otherwise, the time vector is truncated from cache.
        """
        if not hasattr(self, "t"):
            self.t = torch.arange(L, device=device)[None, None]
        elif self.t.shape[-1] < L:
            self.t = torch.arange(L, device=device)[None, None]
        else:
            self.t = self.t[..., :L]

    def compute_filter(self, L, device):
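        """Materialize the long filter h(t) = Re[sum_s r_s * exp(log(p_s) * t)]
        from the complex poles and residues for the current sequence length."""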
        self.update_time(L, device)
        filter_dtype = torch.float32
        residues, log_poles = (
            torch.view_as_complex(self.residues.to(filter_dtype)),
            torch.view_as_complex(self.poles.to(filter_dtype)).log(),
        )
        h = (residues * (log_poles * self.t).exp()).real.sum(1)[None]
        return h, filter_dtype, log_poles, residues


class ParallelGatedConvBlock(nn.Module):
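    """Gated convolution (Hyena) block: pre-norm projection to three streams,
    the ParallelHyenaFilter, an output dense layer, and a residual MoE MLP.
    The projection+norm and MLP+norm paths can optionally be torch.compile'd.
    """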
    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.low_mem_mode = config.get("low_mem_mode", False)
        dtype = config.get("hyena_block_dtype", torch.float32)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
        self.pre_norm, self.post_norm = RMSNorm(config).to(dtype=dtype), RMSNorm(config).to(dtype=dtype)
        self.filter = ParallelHyenaFilter(config, layer_idx).to(dtype=dtype)
        self.projections = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size).to(dtype)
        self.mlp = MoEParallelGatedMLP(config).to(dtype=mlp_dtype)

        self.proj_norm_fn = self.proj_norm
        self.res_mlp_norm_fn = self.res_mlp_norm

        if self.config.get("compile", False):
            self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead")
            self.res_mlp_norm_fn = torch.compile(
                self.res_mlp_norm, fullgraph=True, dynamic=False, mode="reduce-overhead"
            )

    def proj_norm(self, x):
        return self.projections(self.pre_norm(x))

    def res_mlp_norm(self, x):
        x_norm = self.post_norm(x)
        y, moe_loss = self.mlp(x_norm)
        return y + x, moe_loss

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        z = self.proj_norm_fn(u)

        if type(padding_mask) == torch.Tensor:
            z = z * padding_mask[..., None]

        z, inference_params = self.filter(z, inference_params=inference_params, padding_mask=padding_mask)

        z_in = self.out_filter_dense(z) + u

        if type(padding_mask) == torch.Tensor:
            z_in = z_in * padding_mask[..., None]

        y, moe_loss = self.res_mlp_norm_fn(z_in)

        return y, inference_params, moe_loss


def get_block(config, layer_idx, flash_fft=None):
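    """Construct the block for `layer_idx`: attention for indices in
    `config.attn_layer_idxs`, gated convolution for `config.hyena_layer_idxs`."""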
    if layer_idx in config.attn_layer_idxs:
        return AttentionBlock(config, layer_idx)
    elif layer_idx in config.hyena_layer_idxs:
        block = ParallelGatedConvBlock(config, layer_idx)
        if config.get("use_flashfft", False):
            block.filter.fftconv_fn = flash_fft
        return block
    else:
        raise NotImplementedError


class StripedHyena(nn.Module):
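    """Striped architecture interleaving attention and Hyena blocks, with MoE
    MLPs in every block. Supports stateless (parallel) and stateful (cached,
    autoregressive) forward passes and returns the total MoE auxiliary loss
    alongside the outputs.
    """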
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding_layer = VocabParallelEmbedding(config)
        self.norm = RMSNorm(config) if config.get("final_norm", True) else None
        self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)

        if config.get("use_flashfft", False):
            try:
                from flashfftconv import FlashFFTConv
            except ImportError:
                raise ImportError("flashfftconv is required when use_flashfft is enabled")
            self.flash_fft = FlashFFTConv(2 * config.seqlen, dtype=torch.bfloat16)
        else:
            self.flash_fft = None

        self.blocks = nn.ModuleList(
            get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
        )

    def forward(self, x, inference_params_dict=None, padding_mask=None):
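        """Embed tokens, run all blocks (stateful if inference params are
        given), apply the final norm, optionally unembed, and return
        (output, inference_params_dict, total_moe_loss)."""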
        L = x.shape[1]
        x = self.embedding_layer.embed(x)
        if inference_params_dict is not None:
            x, inference_params_dict_out, total_moe_loss = self.stateful_forward(
                x,
                inference_params_dict=inference_params_dict,
            )
        else:
            x, inference_params_dict_out, total_moe_loss = self.stateless_forward(x, padding_mask=padding_mask)

        if self.norm is not None:
            x = self.norm(x)
        if self.config.unembed:
            x = self.unembed.unembed(x)
        return x, inference_params_dict_out, total_moe_loss

    def stateful_forward(self, x, inference_params_dict=None):
        moe_losses = []
        for block_idx, block in enumerate(self.blocks):
            block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
            inference_params = inference_params_dict[block_name]
            # Attention blocks return (x, moe_loss); hyena blocks also return
            # their inference params, so unpack by arity.
            outputs = block(x, inference_params=inference_params)
            if len(outputs) == 3:
                x, _, block_moe_loss = outputs
            else:
                x, block_moe_loss = outputs
            if block_moe_loss is not None:
                moe_losses.append(block_moe_loss)

        total_moe_loss = torch.stack(moe_losses).sum() if moe_losses else None
        return x, inference_params_dict, total_moe_loss

    def stateless_forward(self, x, padding_mask=None):
        if type(padding_mask) == torch.Tensor:
            x = x * padding_mask[..., None]

        moe_losses = []
        for _, block in enumerate(self.blocks):
            outputs = block(x, inference_params=None, padding_mask=padding_mask)
            if len(outputs) == 3:
                x, _, block_moe_loss = outputs
            else:
                x, block_moe_loss = outputs
            if block_moe_loss is not None:
                moe_losses.append(block_moe_loss)

        total_moe_loss = torch.stack(moe_losses).sum() if moe_losses else None
        return x, None, total_moe_loss

    def initialize_inference_params(self):
        print_rank_0("Initializing inference params...")
        inference_params_dict = {
            "mha": InferenceParams(
                max_seqlen=self.config.get("max_seqlen", 8192),
                max_batch_size=self.config.get("max_batch_size", 1),
                seqlen_offset=0,
            ),
            "hyena": RecurrentInferenceParams(
                fir_filter_length=self.config.short_filter_length,
                state_dim=self.config.state_size,
                seqlen_offset=0,
            ),
        }
        return inference_params_dict

    def precompute_filters(self, L, device):
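        """Precompute and cache the long filters h for every Hyena block at
        length L (or the block's long_fir_threshold), in fp16 for long sequences."""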
        for block_idx, block in enumerate(self.blocks):
            if type(block) == ParallelGatedConvBlock:
                if type(block.filter) == ParallelHyenaFilter:
                    L = block.filter.long_fir_threshold or L
                    print_rank_0(f"Precomputing filters, L={L}...")

                    filter_dtype = torch.float16 if L >= 2048 else torch.float32

                    block.filter.update_time(L, device)
                    residues, poles = (
                        torch.view_as_complex(block.filter.residues.to(torch.float16)),
                        torch.view_as_complex(block.filter.poles.to(torch.float16)),
                    )

                    block.filter.h = (residues * poles**block.filter.t).real.sum(1)[None]
                    block.filter.h = block.filter.h.to(dtype=filter_dtype)

    def load_poles_residues(self, path):
        "Load different poles and residues for each layer."
        for block_idx, block in enumerate(self.blocks):
            if type(block) == ParallelGatedConvBlock:
                if type(block.filter) == ParallelHyenaFilter:
                    print(f"Loading poles and residues for block {block_idx}")
                    poles = torch.load(path + f"/approx_poles_{block_idx+1}.pt", map_location="cpu")
                    poles = torch.view_as_real(poles)
                    residues = torch.load(path + f"/approx_residues_{block_idx+1}.pt", map_location="cpu")
                    residues = torch.view_as_real(residues)
                    poles = poles.permute(1, 0, 2).unsqueeze(-2)
                    residues = residues.permute(1, 0, 2).unsqueeze(-2)

                    block.filter.poles = nn.Parameter(poles)
                    block.filter.residues = nn.Parameter(residues)

    def to_bfloat16_except_poles_residues(self):
        """Convert all parameters to bfloat16 except for the poles and residues.

        Particularly important for longer prompts.
        """
        for k, p in self.named_parameters():
            if "poles" not in k and "residues" not in k:
                p.data = p.data.to(torch.bfloat16)

    def load_from_split_converted_state_dict(self, path):
        print("Loading from split converted state dict")

        embedding_weight = torch.load(path + "/layer_00.pt")["word_embeddings.weight"]
        self.embedding_layer.weight = nn.Parameter(embedding_weight.to(self.embedding_layer.weight.dtype))

        print("Loading embedding weight ok")

        if self.norm is not None:
            idx = len(self.blocks) + 1
            final_norm_scale = torch.load(path + f"/layer_{idx:02d}.pt")["norm.scale"]
            self.norm.scale = nn.Parameter(final_norm_scale.to(self.norm.scale.dtype))

            print("loading final norm ok")

        if not self.config.get("tie_embeddings", True):
            idx = len(self.blocks) + 2
            embedding_weight = torch.load(path + f"/layer_{idx:02d}.pt")["word_embeddings.weight"]
            self.unembed.weight = nn.Parameter(embedding_weight.to(self.unembed.weight.dtype))

            print("loading unembed weight ok")

        for block_idx, block in enumerate(self.blocks):
            print("loading block {}...".format(block_idx))

            strict = True

            loaded_dict = torch.load(path + f"/layer_{block_idx + 1:02d}.pt")
            block.load_state_dict(loaded_dict, strict=strict)