from __future__ import annotations from collections import defaultdict from typing import Any, Union, TypedDict import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import PIL.Image from transformers import ( AutoTokenizer, BatchFeature, Cache, Qwen3Config, Qwen3ForCausalLM, Qwen3PreTrainedModel, ) from transformers.cache_utils import SlidingWindowCache, StaticCache from transformers.generation.utils import GenerationMixin from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3Model from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils import TensorType from transformers.modeling_attn_mask_utils import AttentionMaskConverter import re from transformers.models.siglip2.modeling_siglip2 import ( Siglip2MLP, ) from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig from perceptron.tensorstream import ( Event, Stream, TensorStream, TextType, VisionType, create_stream, group_streams, ) from perceptron.tensorstream.ops import ( compute_mrope_pos_tensor, modality_mask, reconstruct_tensor_stream_from_compact_dict, slice as ts_slice, tensor_stream_token_view, ) class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig): """Vision configuration for Isaac with Pixel Shuffle support. Extends Siglip2VisionConfig with additional fields for pixel shuffle. """ model_type = "pixel_shuffle_siglip2" base_config_key = "vision_config" def __init__( self, pixel_shuffle_scale_factor: int = 1, num_patches: int = 256, **kwargs, ): super().__init__(**kwargs) # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor self.num_patches = num_patches def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device) -> tuple[torch.Tensor, int]: """Create cumulative sequence lengths for variable-length attention.""" cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device) cu_seqlens[1:] = seq_sizes.cumsum(0) max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0 return cu_seqlens, max_seqlen class Siglip2VariableSequenceEmbeddings(nn.Module): def __init__(self, config: PixelShuffleSiglip2VisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.patch_size = config.patch_size self.patch_embedding = nn.Linear( in_features=config.num_channels * self.patch_size * self.patch_size, out_features=self.embed_dim, ) self.num_patches = config.num_patches self.position_embedding_size = int(self.num_patches**0.5) self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) def positional_embeddings( self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor] ) -> torch.Tensor: # Prepare positional embeddings grid: (1, embed_dim, h, w) positional_embeddings = ( self.position_embedding.weight.reshape(self.position_embedding_size, self.position_embedding_size, -1) .permute(2, 0, 1) .unsqueeze(0) ) _seq_patches, _seq_sizes, spatial_shapes = packed_seq_patches pos_embeds_list = [] mode = "bilinear" align_corners = False antialias = True for spatial_shape in spatial_shapes: height, width = spatial_shape # Guard to ensure height and width are positive for torch.compile if height > 0 and width > 0: resized_pos_embed = F.interpolate( positional_embeddings, size=(height, width), mode=mode, 
align_corners=align_corners, antialias=antialias, ) # Reshape from (1, embed_dim, height, width) to (height*width, embed_dim) resized_pos_embed = resized_pos_embed.reshape(self.embed_dim, height * width).transpose(0, 1) else: # Fallback - should never happen in practice resized_pos_embed = positional_embeddings.reshape( self.embed_dim, self.position_embedding_size * self.position_embedding_size ).transpose(0, 1)[: height * width] pos_embeds_list.append(resized_pos_embed) # Concatenate all positional embeddings along the sequence dimension pos_embeds = torch.cat(pos_embeds_list, dim=0) return pos_embeds def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]): seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches # Apply patch embeddings target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype)) pos_embeds = self.positional_embeddings(packed_seq_patches) # Add positional embeddings to patch embeddings embeddings = patch_embeds + pos_embeds return embeddings class Siglip2VariableLengthAttention(nn.Module): """Custom attention that supports variable-length sequences with flash attention.""" def __init__(self, config): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): batch_size, seq_len, _ = hidden_states.size() # For variable-length attention, we need to reshape to (total_tokens, embed_dim) if batch_size != 1: raise ValueError("Variable-length attention expects batch_size=1 for packed sequences") hidden_states = hidden_states.squeeze(0) # Remove batch dimension: (seq_len, embed_dim) # Store original dtype orig_dtype = hidden_states.dtype # 1. Linear projections Q = self.q_proj(hidden_states) # (seq_len, embed_dim) K = self.k_proj(hidden_states) # (seq_len, embed_dim) V = self.v_proj(hidden_states) # (seq_len, embed_dim) # 2. Reshape for multi-head attention: (seq_len, n_heads, head_dim) Q = Q.view(-1, self.num_heads, self.embed_dim // self.num_heads) K = K.view(-1, self.num_heads, self.embed_dim // self.num_heads) V = V.view(-1, self.num_heads, self.embed_dim // self.num_heads) # 3. Apply variable-length attention using flash attention attn_output, _, _, _, _ = torch.ops.aten._flash_attention_forward( query=Q, key=K, value=V, cum_seq_q=cu_seqlens, cum_seq_k=cu_seqlens, max_q=max_seqlen, max_k=max_seqlen, dropout_p=self.dropout if self.training else 0.0, is_causal=False, return_debug_mask=False, scale=self.scale, window_size_left=-1, window_size_right=-1, alibi_slopes=None, ) # 4. Reshape attention output from (seq_len, n_heads, head_dim) to (seq_len, embed_dim) attn_output = attn_output.reshape(seq_len, self.embed_dim) # 5. Convert back to original dtype if needed if attn_output.dtype != orig_dtype: attn_output = attn_output.to(orig_dtype) # 6. Project output attn_output = self.out_proj(attn_output) # (seq_len, embed_dim) # 7. 
Add back batch dimension for compatibility attn_output = attn_output.unsqueeze(0) # (1, seq_len, embed_dim) return attn_output, None class IsaacSiglip2EncoderLayer(nn.Module): """Siglip2 encoder layer with variable-length attention.""" def __init__(self, config: PixelShuffleSiglip2VisionConfig): super().__init__() self.embed_dim = config.hidden_size self.self_attn = Siglip2VariableLengthAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Siglip2MLP(config) # Use HF's Siglip2MLP self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor = None, max_seqlen: int = None, ) -> tuple[torch.FloatTensor]: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return (hidden_states,) class IsaacEncoder(nn.Module): """Encoder using Isaac encoder layers with variable-length attention support.""" def __init__(self, config: PixelShuffleSiglip2VisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList([IsaacSiglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) def forward( self, inputs_embeds, cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, output_hidden_states: bool = False, ): all_hidden_states = () if output_hidden_states else None hidden_states = inputs_embeds for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_outputs = encoder_layer( hidden_states, cu_seqlens, max_seqlen, ) hidden_states = layer_outputs[0] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) return hidden_states, all_hidden_states, None def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, token_grids: torch.Tensor, scale_factor: int = 1, device: torch.device | None = None, ) -> torch.Tensor: """ Build a gather-index map that tells us, for every *output* token after pixel-shuffle, which `scale_factor**2` *input* tokens are being merged. Args ---- seq_sizes : (num_images,) - #patches in each image (row-major order) token_grids : (num_images,2) - (height, width) for every image scale_factor : spatial down-scale factor (≥2) device : (optional) overrides `seq_sizes.device` Returns ------- gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor. gather_idx[i, j] is the *flat* index into the *original* packed sequence for the j-th sub-patch that forms the i-th output token. 
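
    Example
    -------
    Illustrative sketch (not exercised by the model code itself): one image with a
    2x4 patch grid and ``scale_factor=2`` yields two output tokens, each gathering
    one 2x2 block of the row-major input indices::

        seq_sizes = torch.tensor([8])
        token_grids = torch.tensor([[2, 4]])
        create_pixel_shuffle_index_map(seq_sizes, token_grids, scale_factor=2)
        # tensor([[0, 1, 4, 5],
        #         [2, 3, 6, 7]])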
""" if device is None: device = seq_sizes.device r = int(scale_factor) if r < 2: raise ValueError("`scale_factor` must be ≥ 2") # Safety: all spatial dims must be divisible by r # Cannot run under torch compile fullgraph mode hence if not torch.compiler.is_compiling(): if not ((token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()): raise AssertionError( f"Every (H,W) in `token_grids` must be divisible by scale_factor={r}, got {token_grids.tolist()}" ) gather_chunks: list[torch.Tensor] = [] tok_offset = 0 for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False): # Build the (H, W) grid of flat indices for this image grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset grid = grid.view(h, w) # (H, W) # -------- identical ordering to your fixed-res routine -------- # Step 1: split width into blocks of r grid = grid.view(h, w // r, r) # (H, W/r, r) # Step 2: now split height into blocks of r grid = grid.view(h // r, r, w // r, r) # (H/r, r, W/r, r) # Step 3: final permutation to (H/r, W/r, r, r) grid = grid.permute(0, 2, 1, 3).contiguous() # (H/r, W/r, r, r) # Step 4: each (r, r) block forms one output token gather_chunks.append(grid.reshape(-1, r * r)) # (H*W / r², r²) tok_offset += seq_len # Concatenate over all images in the packed batch gather_idx = torch.cat(gather_chunks, dim=0) # (Σ_i HᵢWᵢ/r², r²) return gather_idx def pixel_shuffle_varlen( x: torch.Tensor, token_grids: torch.Tensor, scale_factor: int = 1, ) -> torch.Tensor: r"""Apply pixel shuffle to a packed vision sequence without unpacking per image. Args: x (`torch.Tensor`): Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes produced by stacking image patches. token_grids (`torch.Tensor`): Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes corresponding to each image segment inside `x`. scale_factor (`int`, *optional*, defaults to 1): Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge `scale_factor**2` neighboring patches into a single embedding channel-group. Returns: `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention: `(seq_len, hidden_size * scale_factor**2)` when the input was 2D, or `(1, seq_len, hidden_size * scale_factor**2)` if the singleton batch dimension was present. Raises: ValueError: If more than one batch item is provided. 
""" keep_batch_dim = x.dim() == 3 if keep_batch_dim: if x.size(0) != 1: raise AssertionError("Packed sequence is expected to have batch_size == 1") x_ = x.squeeze(0) # (seq, embed) else: x_ = x # (seq, embed) embed_dim = x_.size(-1) r = int(scale_factor) # Calculate seq_sizes from token_grids seq_sizes = torch.prod(token_grids, dim=-1) # Build index map and gather in one go gather_idx = create_pixel_shuffle_index_map( seq_sizes=seq_sizes, token_grids=token_grids, scale_factor=r, device=x_.device, ) # (new_seq, r²) # Gather → (new_seq, r², embed_dim) gathered = x_[gather_idx] # fancy indexing keeps gradient # Merge the r² group dimension into channels to finish the shuffle out = gathered.reshape(gathered.size(0), embed_dim * r * r) # Restore batch dimension if needed if keep_batch_dim: out = out.unsqueeze(0) return out class Siglip2SequenceVisionTransformer(nn.Module): def __init__(self, config: PixelShuffleSiglip2VisionConfig): super().__init__() self.config = config self.embeddings = Siglip2VariableSequenceEmbeddings(config) self.encoder = IsaacEncoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): seq_patches, token_grids = packed_seq_patches seq_sizes = torch.prod(token_grids, dim=-1) # Get embeddings from packed sequence hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids)) # Add a pseudo batch dimension for the encoder hidden_states = hidden_states.unsqueeze(0) # Generate cumulative sequence lengths for variable-length attention cu_seqlens, max_seqlen = create_cumulative_seq_lengths(seq_sizes, hidden_states.device) # Pass through encoder with variable-length attention parameters hidden_states, _, _ = self.encoder( inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) # Apply final layer normalization hidden_states = self.post_layernorm(hidden_states) if self.pixel_shuffle_scale_factor > 1: hidden_states = pixel_shuffle_varlen( x=hidden_states, token_grids=token_grids, scale_factor=self.pixel_shuffle_scale_factor, ) # Remove the pseudo batch dimension we added earlier hidden_states = hidden_states.squeeze(0) # Return the full sequence of embeddings return hidden_states # ============================================================================ # Configuration # ============================================================================ MAX_PIXELS = 60_000_000 # 60‑megapixel ceiling ≈ 8200 × 7300 px # Vision preprocessing constants VISION_MEAN = (0.5, 0.5, 0.5) VISION_STD = (0.5, 0.5, 0.5) VISION_SCALE = 1 / 255 def _make_writeable(arr: np.ndarray) -> np.ndarray: """Return *arr* itself if it is already writeable, otherwise try to flip the write flag in-place and finally fall back to `arr.copy()`. This guarantees the buffer handed to `torch.from_numpy()` is always writeable, silencing the PyTorch warning about undefined behaviour. """ if arr.flags.writeable: return arr # First, try the cheap path — in‑place flag toggle (works for mmap'd arrays # and some shared memory buffers): try: arr.setflags(write=True) return arr # success: no data copy except ValueError: # Buffer is inherently read‑only (e.g. 
backed by PyAV / PIL): make copy
        return arr.copy()


def extract_image_pil(image: PIL.Image.Image) -> torch.Tensor | None:
    if image.width * image.height > MAX_PIXELS:
        raise ValueError(f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`")
    img = image if image.mode == "RGB" else image.convert("RGB")
    arr = np.asarray(img)
    arr = _make_writeable(arr)
    return torch.from_numpy(arr)


def get_image_size_for_max_num_patches(
    image_height: int,
    image_width: int,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: int | None = None,
    eps: float = 1e-5,
    pixel_shuffle_scale: int = 1,
) -> tuple[int, int]:
    r"""Compute a target resolution whose patch grid satisfies the patching constraints.

    Args:
        image_height (`int`):
            Height in pixels of the source image prior to any resizing.
        image_width (`int`):
            Width in pixels of the source image prior to any resizing.
        patch_size (`int`):
            Size of the square patch used by the vision encoder.
        max_num_patches (`int`):
            Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
        min_num_patches (`int`, *optional*):
            Lower bound on the number of patches. When provided the image will be scaled up if necessary.
        eps (`float`, *optional*, defaults to 1e-5):
            Convergence tolerance for the internal binary search to determine the target dimensions.
        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
            Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.

    Returns:
        `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
        and respect both the maximum and optional minimum patch-count constraints.
    """

    def get_scaled_image_size(scale, original_size, patch_size, pixel_shuffle_scale):
        scaled_size = scale * original_size
        divisor = patch_size * pixel_shuffle_scale
        scaled_size = math.ceil(scaled_size / divisor) * divisor
        scaled_size = max(divisor, scaled_size)
        return int(scaled_size)

    # Ensure divisibility
    divisor = patch_size * pixel_shuffle_scale
    adjusted_height = math.ceil(image_height / divisor) * divisor
    adjusted_height = max(divisor, adjusted_height)
    adjusted_width = math.ceil(image_width / divisor) * divisor
    adjusted_width = max(divisor, adjusted_width)

    num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)

    if min_num_patches is not None and num_patches < min_num_patches:
        # Scale up
        scale_min, scale_max = 1.0, 100.0
        while (scale_max - scale_min) >= eps:
            scale = (scale_min + scale_max) / 2
            target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
            target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
            num_patches = (target_height / patch_size) * (target_width / patch_size)
            if num_patches >= min_num_patches:
                scale_max = scale
            else:
                scale_min = scale
        scale = scale_max
        target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
        target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
        return target_height, target_width
    elif num_patches <= max_num_patches:
        return adjusted_height, adjusted_width
    else:
        # Scale down
        scale_min, scale_max = eps / 10, 1.0
        while (scale_max - scale_min) >= eps:
            scale = (scale_min + scale_max) / 2
            target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
            target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
            num_patches = (target_height / patch_size) * (target_width / patch_size)
            if num_patches <= max_num_patches:
                scale_min = scale
            else:
                scale_max = scale
        scale = scale_min
        target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
        target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
        return target_height, target_width


_MEAN_TENSOR = torch.tensor(VISION_MEAN, dtype=torch.float32).view(1, 1, 1, -1)
_STD_TENSOR = torch.tensor(VISION_STD, dtype=torch.float32).view(1, 1, 1, -1)


def prepare_image_tensor(
    image: torch.Tensor,
    scale: float = VISION_SCALE,
) -> torch.Tensor:
    r"""Standardize RGB images prior to patch extraction via rescaling and whitening.

    Args:
        image (`torch.Tensor`):
            Tensor with shape `(..., height, width, 3)` containing RGB values. The tensor is converted to
            floating point if needed.
        scale (`float`, *optional*, defaults to `VISION_SCALE`):
            Scalar multiplier applied before normalization.

    Returns:
        `torch.Tensor`: Normalized tensor with the same shape as the input and dtype `torch.float32`.
    """
    if not torch.is_floating_point(image):
        image = image.float()
    rescaled = image * scale

    # Use precomputed tensors and move to the correct device if needed
    mean_tensor = _MEAN_TENSOR.to(image.device)
    std_tensor = _STD_TENSOR.to(image.device)

    normalized = (rescaled - mean_tensor) / std_tensor
    return normalized


def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor:
    r"""Convert normalized images into flattened ViT-style patches.

    Args:
        image (`torch.Tensor`):
            Tensor of shape `(num_images, height, width, channels)`.
        patch_size (`int`):
            Edge length of the square patches.

    Returns:
        `torch.Tensor`: Patch tensor where each position stores the flattened pixels belonging to that patch.

    Raises:
        ValueError: If `height` or `width` is not divisible by `patch_size`.
    """
    num_images, height, width, channels = image.shape
    if height % patch_size or width % patch_size:
        raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.")

    patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels)
    patches = patches.permute(0, 1, 3, 2, 4, 5)
    patches = patches.reshape(num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size)
    return patches


def process_vision_for_patches(
    images: torch.Tensor,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: int | None = None,
    pixel_shuffle_scale: int = 1,
) -> tuple[torch.Tensor, list[int]]:
    r"""Resize, normalize, and patchify RGB images for the vision encoder.

    Args:
        images (`torch.Tensor`):
            Either `(height, width, channels)` for a single image or `(num_images, height, width, channels)` for a
            batch. Channels are expected to be RGB.
        patch_size (`int`):
            Edge length of square patches; implicitly controls resize grid granularity.
        max_num_patches (`int`):
            Maximum number of patches allowed after resizing.
        min_num_patches (`int`, *optional*):
            Minimum number of patches. If provided, the routine upsamples images as needed to satisfy the lower bound.
        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
            Pixel shuffle scale factor; influences the target grid that the function produces.

    Returns:
        `tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)` where `patches` has shape
        `(num_images, target_h / patch_size, target_w / patch_size, channels * patch_size**2)` and `dims_virtual`
        encodes effective `(images, height, width)` dimensions after optional pixel shuffling.
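
    Example:
        Shape-focused sketch (dtype and values are illustrative only):

            image = torch.randint(0, 256, (224, 224, 3)).float()
            patches, dims_virtual = process_vision_for_patches(image, patch_size=16, max_num_patches=256)
            # patches.shape == (1, 14, 14, 768)   # 16 * 16 * 3 = 768 values per patch
            # dims_virtual == [1, 14, 14]         # no pixel shuffle, so the virtual grid equals the patch grid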
""" # Add batch dim if single image if images.dim() == 3: images = images.unsqueeze(0) # Permute to channel first for resize images = images.permute(0, 3, 1, 2) # Get target dimensions _, _, orig_height, orig_width = images.shape target_height, target_width = get_image_size_for_max_num_patches( orig_height, orig_width, patch_size, max_num_patches, min_num_patches=min_num_patches, pixel_shuffle_scale=pixel_shuffle_scale, ) # Resize images = F.interpolate( images, size=(target_height, target_width), mode="bilinear", align_corners=False, ) # Back to channel last images = images.permute(0, 2, 3, 1) # Normalize images = prepare_image_tensor(images) # Patchify patches = patchify_vision(images, patch_size=patch_size) # Calculate dimensions for the patches n_images, h_patches, w_patches, _ = patches.shape dims_virtual = ( [1, h_patches, w_patches] if pixel_shuffle_scale == 1 else [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale] ) return patches, dims_virtual def precompute_inv_freq(theta: float, dim: int) -> torch.Tensor: """ Returns shape (dim//2,). """ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) return inv_freq # type: ignore[return-value] def precompute_cos_sin_3d( position_ids: torch.Tensor, # shape (3, B, T) inv_freq: torch.Tensor, # shape (dim//2,) mrope_half_section: list[int], # sum to dim//2 ) -> tuple[torch.Tensor, torch.Tensor]: r"""Generate 3D rotary embeddings for multi-axis positions. Args: position_ids (`torch.Tensor`): Tensor of shape `(3, batch_size, seq_len)` containing positional indices for the x/y/t axes. inv_freq (`torch.Tensor`): Precomputed inverse frequency vector used to derive rotary phases. mrope_half_section (`list[int]`): Sizes the axis-specific frequency blocks. Returns: `tuple[torch.Tensor, torch.Tensor]`: Cosine and sine tensors, each of shape `(batch_size, seq_len, dim)`, ready to be passed into rotary attention layers. 
""" B = position_ids.shape[1] T = position_ids.shape[2] dim_half = inv_freq.shape[0] device = position_ids.device # Initialize with full dimension (not half) to match LLaMA cos_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) sin_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) offset = 0 for d in range(3): block_size = mrope_half_section[d] freq_slice = inv_freq[offset : offset + block_size] # shape => (block_size,) # shape => (B, T, block_size) phase = position_ids[d].unsqueeze(-1).float() * freq_slice cos_part = phase.cos() sin_part = phase.sin() # Duplicate values for both halves of the dimension cos_3d[:, :, offset : offset + block_size] = cos_part cos_3d[:, :, dim_half + offset : dim_half + offset + block_size] = cos_part sin_3d[:, :, offset : offset + block_size] = sin_part sin_3d[:, :, dim_half + offset : dim_half + offset + block_size] = sin_part offset += block_size return cos_3d, sin_3d class RopeScaling(TypedDict, total=False): rope_type: str factor: float mrope_section: list[int] mrope_interleaved: bool low_freq_factor: float high_freq_factor: float original_max_position_embeddings: int class IsaacConfig(Qwen3Config): """Configuration class for Isaac multimodal model.""" model_type = "isaac" sub_configs = {"vision_config": PixelShuffleSiglip2VisionConfig} def __init__( self, vision_config=None, vision_patch_size: int = 16, vision_max_num_patches: int = 256, vision_min_num_patches: int | None = None, pixel_shuffle_scale: int = 1, max_sequence_length: int = 16384, vision_token: str = "", **kwargs, ): super().__init__(**kwargs) # Handle vision config - either dict or PixelShuffleSiglip2VisionConfig instance if isinstance(vision_config, dict): self.vision_config = self.sub_configs["vision_config"](**vision_config) elif vision_config is None: self.vision_config = self.sub_configs["vision_config"]() else: self.vision_config = vision_config # EventStreamProcessor parameters (for backward compatibility) self.video_patch_size = vision_patch_size self.vision_max_num_patches = vision_max_num_patches self.vision_min_num_patches = vision_min_num_patches self.pixel_shuffle_scale = pixel_shuffle_scale # Processing parameters self.max_sequence_length = max_sequence_length self.vision_token = vision_token # ============================================================================ # Processor Components # ============================================================================ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> Event: r"""Wrap a text into an `Event` compatible with the multimodal TensorStream. Args: tokenizer (`AutoTokenizer`): Tokenizer used to convert text into model vocabulary ids. text (`str`): Plain-text fragment to encode. time (`float`, *optional*, defaults to 0.0): Timeline coordinate associated with the event. Both start and end times use the same value because text segments are instantaneous in the scheduler. Returns: `Event`: Event carrying a `(num_tokens, 1)` tensor of token ids with matching metadata so that downstream processors can compute modality-specific embeddings. 
""" tokens = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").squeeze(0) # Calculate dimensions for the event num_tokens = len(tokens) dims_virtual = [num_tokens, 1] # [sequence_length, 1] dims_real = dims_virtual.copy() # Ensure tokens has the right shape for tensor_stream_token_view # It expects a 2D tensor where sum(dim=-1) gives the token IDs if tokens.dim() == 1: tokens = tokens.unsqueeze(-1) return Event( data=tokens, type=TextType.text, time=(time, time), dims_virtual=dims_virtual, dims_real=dims_real, idx_range=(0, num_tokens), ) # ============================================================================ # Processor # ============================================================================ class IsaacProcessor(ProcessorMixin): attributes = ["tokenizer"] tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") def __init__( self, tokenizer: Qwen2Tokenizer, config: IsaacConfig | dict, ): super().__init__(tokenizer) self.tokenizer = tokenizer if isinstance(config, dict): config = IsaacConfig(**config) self.config = config # Use vision token from config self.vision_token = config.vision_token # Processing parameters self.max_sequence_length = config.max_sequence_length # Vision processing parameters self.patch_size = config.video_patch_size self.max_num_patches = config.vision_max_num_patches self.min_num_patches = config.vision_min_num_patches self.pixel_shuffle_scale = config.pixel_shuffle_scale def apply_chat_template( self, messages: list[dict[str, Any]], tokenize: bool = False, add_generation_prompt: bool = False, **kwargs, ) -> Any: return self.tokenizer.apply_chat_template( messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt, **kwargs ) def build_event_stream_simple( self, text: str, images: list[PIL.Image.Image] | None = None, ) -> Stream: events = [] # Process text and images # Find all occurrences of vision token pattern = re.escape(self.vision_token) parts = re.split(f"({pattern})", text) # Keep the delimiter in the result image_idx = 0 for current_time, part in enumerate(parts): if part == self.vision_token: # Replace vision token with image event if image_idx < len(images): # Create vision event from PIL image image_tensor = extract_image_pil(images[image_idx]) if image_tensor is not None: # Create a vision event with the image tensor vision_event = Event( data=image_tensor.unsqueeze(0), # HWC format from extract_image_pil type=VisionType.image, # I-frame time=(current_time, current_time), ) events.append(vision_event) image_idx += 1 elif part: # Non-empty text part # tokens = self.text_processor.tokenize(part, add_special_tokens=False) text_event = create_text_event(self.tokenizer, part, time=current_time) events.append(text_event) # Process vision events if any if any(event.type == VisionType.image for event in events): # Separate text and vision events for processing text_events = [event for event in events if event.type == TextType.text] vision_events = [event for event in events if event.type == VisionType.image] # Process vision events using functional approach processed_vision_events = [] for vision_event in vision_events: # Process the vision data patches, dims_virtual = process_vision_for_patches( vision_event.data.squeeze(0), # Remove the extra dimension patch_size=self.patch_size, max_num_patches=self.max_num_patches, min_num_patches=self.min_num_patches, pixel_shuffle_scale=self.pixel_shuffle_scale, ) # Update event with processed data vision_event.data = patches.unsqueeze(1) # Add back frame dimension 
                    vision_event.dims_virtual = dims_virtual
                    vision_event.dims_real = (
                        dims_virtual
                        if self.pixel_shuffle_scale == 1
                        else [
                            dims_virtual[0],
                            dims_virtual[1] * self.pixel_shuffle_scale,
                            dims_virtual[2] * self.pixel_shuffle_scale,
                        ]
                    )
                    vision_event.idx_range = (0, math.prod(dims_virtual))

                    # Flatten the patches
                    vision_event.data = vision_event.data.reshape(-1, vision_event.data.shape[-1])
                    processed_vision_events.append(vision_event)

            events = text_events + processed_vision_events

        # Create the stream; events are scheduled by time with the given modality priority
        return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True)

    def __call__(
        self,
        text: Union[str, list[str]],
        images: Union[PIL.Image.Image, list[PIL.Image.Image], None] = None,
        return_tensors: str | TensorType | None = TensorType.PYTORCH,
        **kwargs,
    ) -> BatchFeature:
        """
        Process text and images into TensorStream format.

        Args:
            text: Input text or list of texts with vision tokens
            images: PIL image or list of images (optional)
            return_tensors: Format for output tensors

        Returns:
            BatchFeature with input_ids and tensor_stream
        """
        # Normalize inputs to lists
        if isinstance(text, str):
            texts = [text]
        else:
            texts = text

        if images is not None:
            if isinstance(images, PIL.Image.Image):
                images_list = [images]
            else:
                images_list = images
        else:
            images_list = None

        if len(texts) != 1:
            raise ValueError("IsaacProcessor currently supports batch_size=1")

        if images_list is not None:
            # Count vision tokens in text to validate image count
            vision_token_count = texts[0].count(self.vision_token)
            if vision_token_count != len(images_list):
                raise ValueError(
                    f"Number of {self.vision_token} tokens in text ({vision_token_count}) "
                    f"must match number of images ({len(images_list)})"
                )

        # Build event stream
        stream = self.build_event_stream_simple(
            text=texts[0],
            images=images_list,
        )

        # Create TensorStream
        tensor_stream = TensorStream([stream])

        # Slice to max length if needed
        _, T = tensor_stream.shape
        if T > self.max_sequence_length:
            tensor_stream = ts_slice(tensor_stream, start=T - self.max_sequence_length, end=T)

        # Get token view
        tokens = tensor_stream_token_view(tensor_stream)
        if return_tensors in (TensorType.PYTORCH, "pt"):
            input_ids = torch.as_tensor(tokens, dtype=torch.long)
        else:
            input_ids = tokens

        data = {
            "input_ids": input_ids,
            "tensor_stream": tensor_stream,
        }

        return BatchFeature(data=data)


# ============================================================================
# Model
# ============================================================================


def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    r"""Create 3D positional indices for token input.

    Args:
        input_ids (`torch.Tensor`):
            Tensor of shape `(batch_size, seq_len)` containing token ids.

    Returns:
        `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the
        1D position so it can be consumed by the 3-axis MRoPE rotary embedding.
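
    Example:
        Text-only sketch:

            pos = compute_position_ids_input_ids(torch.tensor([[11, 12, 13]]))
            # pos.shape == (1, 3, 3); every channel repeats positions [0, 1, 2]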
""" batch_size, seq_length = input_ids.shape position_ids = torch.arange(seq_length, device=input_ids.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # Add 3D for MRoPE return position_ids class IsaacRotaryEmbedding(nn.Module): def __init__(self, config: IsaacConfig, device=None): super().__init__() # Extract dimensions from config self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.head_dim = config.head_dim # Get rope_scaling config - use direct access when available rope_scaling = getattr(config, "rope_scaling", None) or {} # Read RopeScaling parameters self.rope_type = rope_scaling.get("rope_type", "default") self.mrope_section = [ self.head_dim // 4, # 2x more for temporal dim self.head_dim // 8, self.head_dim // 8, ] rope_base = getattr(config, "rope_theta", 10000.0) inv_freq = precompute_inv_freq(rope_base, self.head_dim) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, position_ids: torch.Tensor, modality_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: with torch.no_grad(): # Ensure non-spatial tokens have 1D rotation equivalence not_spatial = ~(modality_tensor == VisionType.image.value) # shape is [N, 1] data_1d = position_ids[not_spatial][..., 0].unsqueeze(-1) # now broadcast it from [N, 1] -> [N, D] so it matches pos[not_spatial] exactly data_1d = data_1d.expand(-1, position_ids.shape[-1]) # expand along the last dim position_ids = position_ids.clone() # Clone to avoid warning about in-place operations on expanded tensors position_ids[not_spatial] = data_1d position_ids = position_ids.permute(2, 0, 1) # pos dim first -> (3, B, L) cos, sin = precompute_cos_sin_3d(position_ids, self.inv_freq, self.mrope_section) return cos, sin class IsaacModel(Qwen3Model): def __init__(self, config: IsaacConfig): super().__init__(config) text_cfg = getattr(config, "get_text_config", lambda: config)() self.layers = torch.nn.ModuleList( [Qwen3DecoderLayer(text_cfg, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) vision_cfg = config.vision_config if vision_cfg is None: raise ValueError("IsaacConfig should always have vision_config") hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) self.vision_embedding = nn.Sequential( Siglip2SequenceVisionTransformer(vision_cfg), nn.Linear( hidden_dim, 4 * hidden_dim, bias=False, ), nn.SiLU(), nn.Linear(4 * hidden_dim, config.hidden_size, bias=False), ) # Dispatch table for TensorStream balanced embedding (text + vision) self.embed_fns = { TextType: self.embed_text_tokens, VisionType: self.embed_vision, } def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: """Embed text tokens, squeezing singleton dimensions.""" # Text events are shaped as (..., 1); squeeze the singleton index dim h = self.embed_tokens(token_ids) if h.dim() >= 2 and h.size(-2) == 1: h = h[..., 0, :] return h def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: """Embed vision tokens using the vision encoder.""" # vision tokens is (seq_patches, token_grids) return self.vision_embedding(vision_tokens) def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: """ Embed each modality stream independently, preserving the original TensorStream structure. 
""" flat_stream = tensor_stream.flat_stream() per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False) per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()} # Collect per-event grids for vision tokens (H, W like dims sans time) token_grids = defaultdict(list) for stream in tensor_stream.streams: for event in stream: token_grids[event.type].append(event.dims(virtual=False)) embedded_compact = {} for stream_type, modality_payload_tensor in per_modality_compact_stream.items(): if stream_type.modality == VisionType: # Build a (N_events, 2) grid tensor with spatial dims only grids = token_grids.get(stream_type, []) if len(grids) == 0: input_tensor = modality_payload_tensor else: token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:] input_tensor = (modality_payload_tensor, token_grids_tensor) embedded_compact[stream_type] = self.embed_fns[stream_type.modality](input_tensor) else: embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor) # Reconstruct a TensorStream with embedded payloads and compact embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact) h = embedded_ts.compact() # (B, T, D) return h def forward( self, input_ids: torch.LongTensor | None = None, tensor_stream: TensorStream | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, modality_tensor: torch.LongTensor | None = None, past_key_values: list[torch.FloatTensor] | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, cache_position: torch.LongTensor | None = None, **kwargs, ) -> tuple | BaseModelOutputWithPast: """ Forward pass with MRoPE position embeddings. Computes position embeddings once and passes them through all layers. 
""" output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Get inputs if tensor_stream is not None and inputs_embeds is not None: raise ValueError("You cannot specify both tensor_stream and inputs_embeds") elif tensor_stream is not None: # Embed TensorStream directly inputs_embeds = self.embed_stream(tensor_stream) # Create modality tensor if not provided if modality_tensor is None: modality_tensor = modality_mask(tensor_stream) elif input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: inputs_embeds = self.embed_tokens(input_ids) # Create text modality tensor if not provided if modality_tensor is None: batch_size, seq_length = input_ids.shape modality_tensor = torch.full( (batch_size, seq_length), TextType.text.value, device=input_ids.device, dtype=torch.long ) elif inputs_embeds is None: raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds") # Create default position_ids if not provided if position_ids is None: if tensor_stream is not None: position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) else: position_ids = compute_position_ids_input_ids(input_ids) # Compute MRoPE position embeddings if we have custom rotary_emb cos, sin = self.rotary_emb(position_ids, modality_tensor) cos = cos.to(inputs_embeds.dtype) sin = sin.to(inputs_embeds.dtype) # Prepare attention mask if attention_mask is not None: attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, False ) # Initialize hidden states hidden_states = inputs_embeds for decoder_layer in self.layers: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=(cos, sin), **kwargs, ) hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs # Final layer norm hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, ) def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. 
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if ( self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache) and not output_attentions ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, past_key_values=past_key_values, ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int, config: Qwen3Config, past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. 
config (`Qwen3Config`): The model's configuration class past_key_values (`Cache`): The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: sliding_attend_mask = torch.arange(target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( causal_mask.device ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): """Isaac multimodal model for conditional generation.""" config_class = IsaacConfig def __init__(self, config: IsaacConfig): Qwen3PreTrainedModel.__init__(self, config) self.model = IsaacModel(config) # Use our custom model self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Tracks rotary position offsets computed during a full forward pass so decode steps can reuse them. self.rope_deltas = None self.config = config def get_rope_index( self, input_ids: torch.Tensor | None, tensor_stream: TensorStream | None, attention_mask: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute MRoPE position ids from a TensorStream (or 1D fallback). Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE. rope_deltas is (B,1) used to advance positions in decode. 
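
        Example:
            Text-only sketch (`model` is any instantiated `IsaacForConditionalGeneration`; vision streams can
            yield negative deltas because a 2D patch grid compresses many tokens into a small positional span):

                pos, deltas = model.get_rope_index(torch.tensor([[5, 6, 7, 8]]), None, None)
                # pos.shape == (1, 4, 3); deltas == tensor([[0]]) since max position 3 == seq_len - 1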
""" # tensor_stream present: compute 3D coords if tensor_stream is None and input_ids is None: raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices") if tensor_stream is not None: pos_3d = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) else: pos_3d = compute_position_ids_input_ids(input_ids) B, L, _ = pos_3d.shape # Max position per batch across the 3 planes and sequence dimension: (B,) m_per_batch = pos_3d.amax(dim=(1, 2)) # Sequence lengths per batch: (B,) if attention_mask is None: seq_lens = torch.full_like(m_per_batch, L) else: seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device) rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1) return pos_3d, rope_deltas def forward( self, input_ids: torch.LongTensor | None = None, tensor_stream: TensorStream | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: list[torch.FloatTensor] | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, cache_position: torch.LongTensor | None = None, **kwargs, ) -> tuple | CausalLMOutputWithPast: """ Forward pass for conditional generation supporting both standard inputs and TensorStream. Uses our embed_stream approach for multimodal inputs. """ # Don't compute embeddings here - let the model handle it if tensor_stream is not None: input_ids = None if input_ids is None and inputs_embeds is None and tensor_stream is None: raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.") # Build position ids (MRoPE) if needed and tensor_stream is available # During decode we reuse `self.rope_deltas` computed on the initial forward pass; `rope_delta` captures how far # cached rotary phases have progressed so we can advance `position_ids` without rebuilding the TensorStream. if position_ids is None and tensor_stream is not None: position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) elif position_ids is None and input_ids is not None: # For text inputs build position ids and modality tensor position_ids = compute_position_ids_input_ids(input_ids) if cache_position is not None and self.rope_deltas is not None: # Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue # rotating in lockstep across generation steps. 
rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device) else: rope_delta = 0 if cache_position is not None and not isinstance(rope_delta, int): # otherwise `deltas` is an int `0` batch_size = input_ids.shape[0] rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0) position_ids = position_ids.add(rope_delta) if tensor_stream is not None: modality_tensor = modality_mask(tensor_stream) else: batch_size, seq_len = input_ids.shape modality_tensor = torch.empty(batch_size, seq_len, device=position_ids.device).fill_(TextType.text.value) outputs = self.model( input_ids=input_ids, tensor_stream=tensor_stream, attention_mask=attention_mask, position_ids=position_ids, modality_tensor=modality_tensor, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=None, ) def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, past_key_values: list[torch.FloatTensor] | None = None, attention_mask: torch.Tensor | None = None, inputs_embeds: torch.FloatTensor | None = None, tensor_stream: TensorStream | None = None, cache_position: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, use_cache: bool = True, **kwargs, ) -> dict[str, Any]: """ Prepare inputs for generation, handling TensorStream inputs properly. """ # Call parent preparation model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, **kwargs, ) # Handle TensorStream for first forward pass only if tensor_stream is not None and (cache_position is None or cache_position[0] == 0): model_inputs["tensor_stream"] = tensor_stream # Let forward rebuild position_ids using cached deltas during decode model_inputs["position_ids"] = None # Drop tensor_stream after step 0 if cache_position is not None and cache_position[0] != 0: model_inputs["tensor_stream"] = None return model_inputs def can_generate(self) -> bool: return True __all__ = [ "IsaacConfig", "IsaacModel", "IsaacForConditionalGeneration", "IsaacProcessor", ]
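

# Usage sketch (illustrative only; "path/to/isaac-checkpoint" is a hypothetical local path and the exact
# prompt template depends on how `config.vision_token` is set for the released weights):
#
#     config = IsaacConfig.from_pretrained("path/to/isaac-checkpoint")
#     tokenizer = AutoTokenizer.from_pretrained("path/to/isaac-checkpoint")
#     processor = IsaacProcessor(tokenizer=tokenizer, config=config)
#     model = IsaacForConditionalGeneration.from_pretrained("path/to/isaac-checkpoint")
#
#     image = PIL.Image.open("example.jpg")
#     prompt = f"Describe this image: {config.vision_token}"
#     inputs = processor(text=prompt, images=[image])
#     output_ids = model.generate(**inputs, max_new_tokens=64)
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))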