import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from segmentation_models_pytorch.base import SegmentationHead
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
from timm.layers.create_act import create_act_layer
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput

from .convlstm import ConvLSTM


class ACTUConfig(PretrainedConfig):
    model_type = "actu"

    def __init__(
        self,
        # Base ACTU parameters
        in_channels: int = 3,
        kernel_size: tuple[int, int] = (3, 3),
        padding="same",
        stride=(1, 1),
        backbone="resnet34",
        bias=True,
        batch_first=True,
        bidirectional=False,
        original_resolution=(256, 256),
        act_layer="sigmoid",
        n_classes=1,
        # Variant control parameters
        use_dem_input: bool = False,
        use_climate_branch: bool = False,
        # Climate branch parameters
        climate_seq_len=5,
        climate_input_dim=6,
        lstm_hidden_dim=128,
        num_lstm_layers=1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.backbone = backbone
        self.bias = bias
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.original_resolution = original_resolution
        self.act_layer = act_layer
        self.n_classes = n_classes

        # Parameters to control variants
        self.use_dem_input = use_dem_input
        self.use_climate_branch = use_climate_branch
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers

        # The DEM is concatenated as one extra input channel
        if self.use_dem_input:
            self.in_channels += 1


class ACTUForImageSegmentation(PreTrainedModel):
    config_class = ACTUConfig

    def __init__(self, config: ACTUConfig):
        super().__init__(config)
        self.config = config
        self.encoder: nn.Module = timm.create_model(
            config.backbone, features_only=True, in_chans=config.in_channels
        )

        # Probe the encoder once to record the shape of each feature level
        with torch.no_grad():
            dummy_input = torch.randn(
                1, config.in_channels, *config.original_resolution, device=self.device
            )
            embs = self.encoder(dummy_input)
            self.embs_shape = [e.shape for e in embs]
        self.encoder_channels = [shape[1] for shape in self.embs_shape]

        # One ConvLSTM per encoder level to model temporal dynamics
        self.convlstm = nn.ModuleList(
            [
                ConvLSTM(
                    in_channels=shape[1],
                    hidden_channels=shape[1],
                    kernel_size=config.kernel_size,
                    padding=config.padding,
                    stride=config.stride,
                    bias=config.bias,
                    batch_first=config.batch_first,
                    bidirectional=config.bidirectional,
                )
                for shape in self.embs_shape
            ]
        )

        if self.config.use_climate_branch:
            self.climate_branch = ClimateBranchLSTM(
                output_shapes=[e[1:] for e in self.embs_shape],
                lstm_hidden_dim=config.lstm_hidden_dim,
                climate_seq_len=config.climate_seq_len,
                climate_input_dim=config.climate_input_dim,
                num_lstm_layers=config.num_lstm_layers,
            )
            self.fusers = nn.ModuleList(
                GatedFusion(enc, enc) for enc in self.encoder_channels
            )

        self.decoder = UnetDecoder(
            encoder_channels=[1] + self.encoder_channels,
            decoder_channels=self.encoder_channels[::-1],
            n_blocks=len(self.encoder_channels),
        )
        self.seg_head = nn.Sequential(
            SegmentationHead(
                in_channels=self.encoder_channels[0],
                out_channels=config.n_classes,
            ),
            create_act_layer(config.act_layer, inplace=True),
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        climate: torch.Tensor | None = None,
        dem: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> SemanticSegmenterOutput:
        b, t = pixel_values.shape[:2]
        original_size = pixel_values.shape[-2:]
        # Handle DEM input: repeat the static DEM across time and append it
        # as an extra channel
        if self.config.use_dem_input:
            if dem is None:
                raise ValueError(
                    "DEM tensor must be provided when use_dem_input is True."
                )
            dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
            pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)

        # 1. Encode images per time step
        encoded_sequence = self._encode_images(pixel_values)

        # 2. Handle climate branch fusion
        if self.config.use_climate_branch:
            if climate is None:
                raise ValueError(
                    "Climate tensor must be provided when use_climate_branch is True."
                )
            climate_features = self.climate_branch(climate)

            # Flatten batch and time for per-frame fusion
            encoded_sequence_reshaped = [
                rearrange(f, "b t c h w -> (b t) c h w") for f in encoded_sequence
            ]
            climate_features_reshaped = [
                rearrange(f, "b t c h w -> (b t) c h w") for f in climate_features
            ]

            # Fuse image and climate features at each encoder level
            fused_features = [
                fuser(img, clim)
                for fuser, img, clim in zip(
                    self.fusers, encoded_sequence_reshaped, climate_features_reshaped
                )
            ]

            # Reshape back to sequence
            encoded_sequence = [
                rearrange(f, "(b t) c h w -> b t c h w", b=b) for f in fused_features
            ]

        # 3. Process sequence with ConvLSTM
        temporal_features = self._encode_timeseries(encoded_sequence)

        # 4. Decode to get the segmentation map
        logits = self._decode(temporal_features, size=original_size)

        loss = None
        if labels is not None:
            # The segmentation head already applies the configured activation
            # (sigmoid by default), so compare probabilities against binary
            # targets with BCE rather than CrossEntropyLoss, which expects raw
            # logits over a class dimension.
            loss_fct = nn.BCELoss()
            loss = loss_fct(logits, labels.float().unsqueeze(1))

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
        )

    def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
        B = x.size(0)
        # Run the backbone on all frames at once, then restore the time axis
        encoded_frames = self.encoder(rearrange(x, "b t c h w -> (b t) c h w"))
        return [
            rearrange(frames, "(b t) c h w -> b t c h w", b=B)
            for frames in encoded_frames
        ]

    def _encode_timeseries(
        self, timeseries: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        # Process levels deepest-first; _decode re-reverses before the decoder
        outs = []
        for convlstm, encoded in reversed(list(zip(self.convlstm, timeseries))):
            lstm_out, (_, _) = convlstm(encoded)
            # Keep only the last time step of each ConvLSTM output
            outs.append(lstm_out[:, -1, :, :, :])
        return outs

    def _decode(self, x: list[torch.Tensor], size: tuple[int, int]) -> torch.Tensor:
        # UnetDecoder expects features ordered shallow-to-deep; the leading
        # None stands in for the full-resolution identity feature
        trend_map = self.decoder(*[None] + x[::-1])
        trend_map = self.seg_head(trend_map)
        trend_map = F.interpolate(
            trend_map, size=size, mode="bilinear", align_corners=False
        )
        return trend_map


class ClimateBranchLSTM(nn.Module):
    """
    Processes climate time series data using an LSTM.
    Input shape: (B, T, T_1, C_clim) -> e.g., (B, 5, 6, 5), where T is the
    number of image time steps and T_1 the per-step climate sequence length.
    Output: a list with one tensor of shape (B, T, C_i, H_i, W_i) per encoder
    level, spatially matched to the image features.
    """

    def __init__(
        self,
        output_shapes: list[tuple[int, int, int]],
        climate_input_dim=5,
        climate_seq_len=6,
        lstm_hidden_dim=64,
        num_lstm_layers=1,
    ):
        super().__init__()
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
        self.proj_dim = 128
        self.output_shapes = output_shapes

        self.lstm = nn.LSTM(
            input_size=climate_input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,  # Crucial: expects input shape (batch, seq_len, features)
            dropout=0.3 if num_lstm_layers > 1 else 0,
            bidirectional=False,
        )
        # Linear layer to project LSTM output to the desired final dimension
        self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)
        # One upsampler per encoder level: (proj_dim, 1, 1) -> (C_i, H_i, H_i)
        self.upsamples = nn.ModuleList(
            _build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
        )

    def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
        # climate_data shape: (B, T, T_1, C_clim), e.g., (B, 5, 6, 5);
        # B_img is the batch size, B_cli the number of image time steps
        B_img, B_cli, T, C = climate_data.shape

        # Reshape for the LSTM: treat each per-step sequence independently
        lstm_input = rearrange(climate_data, "Bi Bc T C -> (Bi Bc) T C")

        # Pass through the LSTM; only the final hidden state is used
        _, (hidden, _) = self.lstm(lstm_input)

        # Get the last layer's hidden state (both directions if bidirectional)
        last_hidden = hidden[-2:] if self.lstm.bidirectional else hidden[-1]
        if last_hidden.ndim == 3:
            # Average the forward and backward hidden states
            last_hidden = last_hidden.mean(dim=0)

        # Project the final hidden state through the fully connected layer,
        # then upsample it to every encoder level's spatial resolution
        climate_features = self.fc(last_hidden)
        climate_features = rearrange(climate_features, "b c -> b c 1 1")
        climate_features = [
            rearrange(
                u(climate_features), "(Bi Bc) C H W -> Bi Bc C H W", Bi=B_img, Bc=B_cli
            )
            for u in self.upsamples
        ]
        return climate_features


class GatedFusion(nn.Module):
    def __init__(self, img_channels, clim_channels):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Conv2d(
                img_channels + clim_channels, img_channels, kernel_size=3, padding=1
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(img_channels, img_channels, kernel_size=1),
            nn.Sigmoid(),  # Gate values between 0 and 1
        )

    def forward(self, img_feat, clim_feat):
        # Per-pixel convex combination of image and climate features
        gate = self.gate(torch.cat([img_feat, clim_feat], dim=1))
        return gate * img_feat + (1 - gate) * clim_feat


def _build_upsampler(
    in_channels: int, target_channels: int, target_h: int
) -> nn.Sequential:
    layers = []
    current_h = 1

    # Expand to target channels early (e.g., 1x1 -> 1x1 with target_channels)
    layers += [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()]

    # Upsample spatially toward target_h, at most doubling each step; an
    # explicit output size keeps non-power-of-two targets exact
    while current_h < target_h:
        next_h = min(current_h * 2, target_h)
        layers += [
            nn.Upsample(size=(next_h, next_h), mode="nearest"),
            nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1),
            nn.GELU(),
        ]
        current_h = next_h

    return nn.Sequential(*layers)
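

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): builds the full
# variant and runs one forward pass on random tensors. Shapes follow the
# config defaults above; the lighter resnet18 backbone and 128x128 resolution
# are assumptions for a quick smoke test. Because of the relative ConvLSTM
# import, run this as a module (python -m <package>.<this_module>) rather
# than as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = ACTUConfig(
        backbone="resnet18",  # assumed backbone for the demo
        original_resolution=(128, 128),
        use_dem_input=True,
        use_climate_branch=True,
    )
    model = ACTUForImageSegmentation(config).eval()

    B, T = 2, 5  # batch size and number of image time steps
    pixel_values = torch.randn(B, T, 3, 128, 128)
    dem = torch.randn(B, 1, 128, 128)  # static elevation map, repeated over T
    climate = torch.randn(B, T, config.climate_seq_len, config.climate_input_dim)
    labels = torch.randint(0, 2, (B, 128, 128))  # binary segmentation targets

    with torch.no_grad():
        out = model(
            pixel_values=pixel_values, climate=climate, dem=dem, labels=labels
        )
    print(out.logits.shape)  # expected: (B, n_classes, 128, 128) == (2, 1, 128, 128)
    print(out.loss)  # BCE loss against the binary labels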