Update to add SDPA support
modular_isaac.py  +131 -58

modular_isaac.py  CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from collections import defaultdict
-from typing import Any,

 import math
 import numpy as np
@@ -81,6 +81,91 @@ def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device)
     return cu_seqlens, max_seqlen


 class Siglip2VariableSequenceEmbeddings(nn.Module):
     def __init__(self, config: PixelShuffleSiglip2VisionConfig):
         super().__init__()
@@ -172,58 +257,42 @@ class Siglip2VariableLengthAttention(nn.Module):
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

     def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None):
-        # For variable-length attention, we need to reshape to (total_tokens, embed_dim)
         if batch_size != 1:
-            raise ValueError("
-            window_size_left=-1,
-            window_size_right=-1,
-            alibi_slopes=None,
-        )
-
-        # 4. Reshape attention output from (seq_len, n_heads, head_dim) to (seq_len, embed_dim)
-        attn_output = attn_output.reshape(seq_len, self.embed_dim)
-
-        # 5. Convert back to original dtype if needed
-        if attn_output.dtype != orig_dtype:
-            attn_output = attn_output.to(orig_dtype)
-
-        # 6. Project output
-        attn_output = self.out_proj(attn_output)  # (seq_len, embed_dim)
-
-        # 7. Add back batch dimension for compatibility
-        attn_output = attn_output.unsqueeze(0)  # (1, seq_len, embed_dim)

-


 class IsaacSiglip2EncoderLayer(nn.Module):
@@ -805,6 +874,7 @@ class IsaacConfig(Qwen3Config):
         pixel_shuffle_scale: int = 1,
         max_sequence_length: int = 16384,
         vision_token: str = "<image>",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -826,6 +896,7 @@ class IsaacConfig(Qwen3Config):
         # Processing parameters
         self.max_sequence_length = max_sequence_length
         self.vision_token = vision_token


 # ============================================================================
@@ -880,7 +951,6 @@ class IsaacProcessor(ProcessorMixin):
     attributes = ["tokenizer"]
     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

-
     def __init__(
         self,
         tokenizer: Qwen2Tokenizer,
@@ -992,8 +1062,8 @@

     def __call__(
         self,
-        text:
-        images:
         return_tensors: str | TensorType | None = TensorType.PYTORCH,
         **kwargs,
     ) -> BatchFeature:
@@ -1135,6 +1205,12 @@ class IsaacModel(Qwen3Model):
         self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)

         vision_cfg = config.vision_config
         if vision_cfg is None:
             raise ValueError("IsaacConfig should always have vision_config")

@@ -1418,9 +1494,7 @@
             causal_mask = attention_mask
         else:
             min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
-            )
             diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
@@ -1447,7 +1521,6 @@
         return causal_mask


-
 class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
     """Isaac multimodal model for conditional generation."""

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from collections import defaultdict
+from typing import Any, TypedDict

 import math
 import numpy as np
@@ -81,6 +81,91 @@ def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device)
     return cu_seqlens, max_seqlen


+def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int:
+    """Helper to compute max sequence length from cumulative sequence lengths."""
+    if cu is None or len(cu) < 2:
+        return fallback
+    return int((cu[1:] - cu[:-1]).max().item())
+
+
+def flash_attention_document_mask_forward(
+    q_lhd: torch.Tensor,  # (L, H, D)
+    k_lhd: torch.Tensor,  # (L, H, D)
+    v_lhd: torch.Tensor,  # (L, H, D)
+    attention_mask: torch.Tensor | None = None,  # unused for FA path
+    dropout: float = 0.0,
+    scaling: float | None = None,
+    cum_seq_q: torch.Tensor | None = None,
+    cum_seq_k: torch.Tensor | None = None,
+    max_seqlen: int | None = None,
+    is_causal: bool = False,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """FlashAttention that consumes (L, H, D) directly to avoid layout churn."""
+    L, H, D = q_lhd.shape
+
+    # Compute max block length once (honor caller when provided)
+    if max_seqlen is not None:
+        max_q = max_k = int(max_seqlen)
+    else:
+        max_q = _max_from_cu(cum_seq_q, L)
+        max_k = _max_from_cu(cum_seq_k, L)
+
+    # Ensure contiguity only if needed
+    if not q_lhd.is_contiguous():
+        q_lhd = q_lhd.contiguous()
+    if not k_lhd.is_contiguous():
+        k_lhd = k_lhd.contiguous()
+    if not v_lhd.is_contiguous():
+        v_lhd = v_lhd.contiguous()
+
+    out_lhd, *_ = torch.ops.aten._flash_attention_forward(
+        query=q_lhd,  # (L, H, D)
+        key=k_lhd,  # (L, H, D)
+        value=v_lhd,  # (L, H, D)
+        cum_seq_q=cum_seq_q,
+        cum_seq_k=cum_seq_k,
+        max_q=max_q,
+        max_k=max_k,
+        dropout_p=dropout,
+        is_causal=is_causal,
+        return_debug_mask=False,
+        scale=scaling,
+        window_size_left=-1,
+        window_size_right=-1,
+        alibi_slopes=None,
+    )
+    return out_lhd, None  # (L, H, D)
+
+
+def sdpa_document_mask_forward(
+    q_lhd: torch.Tensor,  # (L, H, D)
+    k_lhd: torch.Tensor,  # (L, H, D)
+    v_lhd: torch.Tensor,  # (L, H, D)
+    dropout: float,
+    scaling: float | None,
+    cu_seqlens: torch.Tensor | None,
+) -> torch.Tensor:
+    """SDPA with block-diagonal masking for variable-length sequences."""
+    L, H, D = q_lhd.shape
+
+    # Transpose to (1, H, L, D) format for SDPA
+    Q = q_lhd.permute(1, 0, 2).unsqueeze(0)
+    K = k_lhd.permute(1, 0, 2).unsqueeze(0)
+    V = v_lhd.permute(1, 0, 2).unsqueeze(0)
+
+    # Build block-diagonal mask for variable-length sequences
+    attn_mask = None
+    if cu_seqlens is not None:
+        seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
+        seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes), device=q_lhd.device), seq_sizes)
+        block_mask = seg_ids[:, None] != seg_ids[None, :]  # Cross-document attention blocked
+        attn_mask = torch.where(block_mask, -torch.inf, 0.0).to(q_lhd.dtype).view(1, 1, L, L)
+
+    Y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=dropout, scale=scaling)
+    return Y.squeeze(0).permute(1, 0, 2)  # Back to (L, H, D)
+
+
 class Siglip2VariableSequenceEmbeddings(nn.Module):
     def __init__(self, config: PixelShuffleSiglip2VisionConfig):
         super().__init__()
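The SDPA fallback above reproduces per-document attention over a packed sequence by adding a block-diagonal mask built from cu_seqlens. A minimal, self-contained sanity check of that equivalence (all sizes below are made up for illustration and are not part of this diff):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
H, D = 2, 8
cu_seqlens = torch.tensor([0, 3, 8])             # two documents packed into L = 8 tokens
L = int(cu_seqlens[-1])

q, k, v = (torch.randn(1, H, L, D) for _ in range(3))

# Same mask construction as sdpa_document_mask_forward
seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes)), seq_sizes)
attn_mask = torch.where(seg_ids[:, None] != seg_ids[None, :], -torch.inf, 0.0).view(1, 1, L, L)

packed = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Attending within each document separately should give the same result
pieces = [
    F.scaled_dot_product_attention(q[:, :, a:b], k[:, :, a:b], v[:, :, a:b])
    for a, b in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist())
]
print(torch.allclose(packed, torch.cat(pieces, dim=2), atol=1e-6))  # True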
@@ -172,58 +257,42 @@ class Siglip2VariableLengthAttention(nn.Module):
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

     def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None):
+        # Expect packed sequences with batch_size == 1
+        batch_size, L, _ = hidden_states.shape
         if batch_size != 1:
+            raise ValueError("packed variable-length attention expects batch_size=1")
+        x = hidden_states[0]  # (L, E)
+
+        H = self.num_heads
+        D = self.head_dim
+        p_drop = self.dropout if self.training else 0.0
+
+        # Project and reshape to (L, H, D)
+        q = self.q_proj(x).view(L, H, D)
+        k = self.k_proj(x).view(L, H, D)
+        v = self.v_proj(x).view(L, H, D)
+
+        attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3")
+
+        if attn_impl in ("flash_attention_2", "flash_attention_3"):
+            y_lhd, _ = flash_attention_document_mask_forward(
+                q,
+                k,
+                v,
+                attention_mask=None,
+                dropout=p_drop,
+                scaling=self.scale,
+                cum_seq_q=cu_seqlens,
+                cum_seq_k=cu_seqlens,
+                max_seqlen=max_seqlen,
+                is_causal=False,
+            )
+        else:
+            y_lhd = sdpa_document_mask_forward(q, k, v, dropout=p_drop, scaling=self.scale, cu_seqlens=cu_seqlens)

+        # Merge heads and project
+        y = self.out_proj(y_lhd.reshape(L, self.embed_dim))
+        return y.unsqueeze(0), None  # (1, L, E)


 class IsaacSiglip2EncoderLayer(nn.Module):
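A hedged sketch of exercising the new dispatch end to end: it assumes this file is importable as modular_isaac and that PixelShuffleSiglip2VisionConfig / Siglip2VariableLengthAttention can be constructed standalone with default sizes, which the diff does not show. Any _attn_implementation other than flash_attention_2 / flash_attention_3 now falls through to the SDPA branch.

import torch
from modular_isaac import PixelShuffleSiglip2VisionConfig, Siglip2VariableLengthAttention

cfg = PixelShuffleSiglip2VisionConfig()            # assumed to construct with defaults
cfg._attn_implementation = "sdpa"                  # anything outside FA2/FA3 -> SDPA branch

attn = Siglip2VariableLengthAttention(cfg).eval()  # eval() so p_drop == 0.0

L = 8                                              # two packed documents: 3 + 5 tokens
hidden_states = torch.randn(1, L, cfg.hidden_size)  # packed (1, L, E) layout
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)

with torch.no_grad():
    out, _ = attn(hidden_states, cu_seqlens=cu_seqlens, max_seqlen=5)
print(out.shape)                                   # expected: (1, 8, cfg.hidden_size)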
@@ -805,6 +874,7 @@ class IsaacConfig(Qwen3Config):
         pixel_shuffle_scale: int = 1,
         max_sequence_length: int = 16384,
         vision_token: str = "<image>",
+        vision_attn_implementation: str | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -826,6 +896,7 @@ class IsaacConfig(Qwen3Config):
         # Processing parameters
         self.max_sequence_length = max_sequence_length
         self.vision_token = vision_token
+        self.vision_attn_implementation = vision_attn_implementation


 # ============================================================================
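Together with the +874 hunk above, this stores a vision_attn_implementation knob that is independent of the text stack's attn_implementation. A hedged construction sketch; the import path and the ability to build IsaacConfig from keyword defaults are assumptions, only the field names and defaults come from the diff:

from modular_isaac import IsaacConfig   # import path assumed

# Vision tower on SDPA, language model left on the globally selected implementation
cfg = IsaacConfig(vision_attn_implementation="sdpa")
print(cfg.vision_attn_implementation)   # "sdpa"
print(cfg.vision_token)                 # "<image>" (default from this file)

# Leaving it as None (the default) keeps the previous behaviour: the vision tower
# follows config._attn_implementation.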
@@ -880,7 +951,6 @@ class IsaacProcessor(ProcessorMixin):
     attributes = ["tokenizer"]
     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

     def __init__(
         self,
         tokenizer: Qwen2Tokenizer,
@@ -992,8 +1062,8 @@

     def __call__(
         self,
+        text: str | list[str],
+        images: PIL.Image.Image | list[PIL.Image.Image] | None = None,
         return_tensors: str | TensorType | None = TensorType.PYTORCH,
         **kwargs,
     ) -> BatchFeature:
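With the tightened annotations, text accepts a single string or a list, and images may be a single PIL image, a list, or omitted. A hedged usage sketch; the AutoProcessor entry point, the checkpoint path, and the prompt format with the <image> token are placeholders rather than something this diff guarantees:

import PIL.Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("path/to/isaac-checkpoint", trust_remote_code=True)

image = PIL.Image.new("RGB", (224, 224))
batch = processor(
    text="Describe this image: <image>",   # vision_token placement is assumed
    images=image,                          # a single image or a list of images
    return_tensors="pt",                   # matches the TensorType.PYTORCH default
)
print(batch.keys())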
@@ -1135,6 +1205,12 @@ class IsaacModel(Qwen3Model):
         self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)

         vision_cfg = config.vision_config
+        # Use vision_attn_implementation if specified, otherwise fall back to general attn_implementation
+        vision_cfg._attn_implementation = (
+            config.vision_attn_implementation
+            if config.vision_attn_implementation is not None
+            else config._attn_implementation
+        )
         if vision_cfg is None:
             raise ValueError("IsaacConfig should always have vision_config")

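Because the override is resolved here, leaving vision_attn_implementation unset keeps the previous behaviour (the vision tower follows config._attn_implementation), while setting it retargets only the vision tower. A hedged end-to-end sketch; the checkpoint id is a placeholder, and the AutoModel entry point plus the attn_implementation kwarg are standard transformers plumbing assumed to apply to this repo:

from transformers import AutoConfig, AutoModelForCausalLM

cfg = AutoConfig.from_pretrained("path/to/isaac-checkpoint", trust_remote_code=True)
cfg.vision_attn_implementation = "sdpa"          # vision tower only

model = AutoModelForCausalLM.from_pretrained(
    "path/to/isaac-checkpoint",
    config=cfg,
    attn_implementation="flash_attention_2",     # text stack keeps FlashAttention
    trust_remote_code=True,
)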
@@ -1418,9 +1494,7 @@
             causal_mask = attention_mask
         else:
             min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
             diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
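Collapsing torch.full onto one line does not change the mask. For reference, a tiny worked example of the pattern this block produces; the sizes are made up, and the final multiply by diagonal_attend_mask is assumed to follow later in the function (as in the upstream Qwen3 mask helper), since it sits outside the lines shown here:

import torch

# Made-up sizes: 3 query tokens appended at cache positions 2..4, attending over 5 keys
dtype, device = torch.float32, "cpu"
sequence_length, target_length = 3, 5
cache_position = torch.arange(2, 5)

min_dtype = torch.finfo(dtype).min
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)

# Assumed follow-up: keep min_dtype only where attention must be blocked, zero where allowed
causal_mask = causal_mask * diagonal_attend_mask
print((causal_mask == 0).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)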
@@ -1447,7 +1521,6 @@
         return causal_mask


 class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
     """Isaac multimodal model for conditional generation."""
