File size: 11,416 Bytes
e3f3842 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 |
from dataclasses import dataclass
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from segmentation_models_pytorch.base import SegmentationHead
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
from timm.layers.create_act import create_act_layer
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput
from .convlstm import ConvLSTM
class ACTUConfig(PretrainedConfig):
    """Configuration for ``ACTUForImageSegmentation``.

    Holds the encoder / ConvLSTM hyper-parameters of the base ACTU model plus
    the switches for its variants: an extra DEM input channel
    (``use_dem_input``) and a climate branch fused into the image features
    (``use_climate_branch``).

    The ``in_channels`` attribute stored on the instance is the *effective*
    number of encoder input channels. When ``use_dem_input`` is True it is one
    more than the value passed in, because the model concatenates the DEM as an
    extra channel. The value passed in is kept as ``base_in_channels``. A config
    that is serialized and reconstructed (e.g. ``save_pretrained`` followed by
    ``from_pretrained``) therefore does not add the DEM channel a second time.
    """

    model_type = "actu"

    def __init__(
        self,
        # Base ACTU parameters
        in_channels: int = 3,
        kernel_size: tuple[int, int] = (3, 3),
        padding="same",
        stride=(1, 1),
        backbone="resnet34",
        bias=True,
        batch_first=True,
        bidirectional=False,
        original_resolution=(256, 256),
        act_layer="sigmoid",
        n_classes=1,
        # Variant control parameters
        use_dem_input: bool = False,
        use_climate_branch: bool = False,
        # Climate branch parameters
        climate_seq_len=5,
        climate_input_dim=6,
        lstm_hidden_dim=128,
        num_lstm_layers=1,
        **kwargs,
    ):
        # A serialized config carries both the effective ``in_channels`` and the
        # original ``base_in_channels``. Prefer the latter so rebuilding from a
        # saved config is idempotent. It is popped before delegating to the base
        # class so it is handled only here.
        base_in_channels = kwargs.pop("base_in_channels", None)
        super().__init__(**kwargs)
        if base_in_channels is not None:
            in_channels = base_in_channels
        # Image channels as supplied by the caller (without the DEM band).
        self.base_in_channels = in_channels
        # Effective encoder input channels: +1 for the DEM band when enabled.
        self.in_channels = in_channels + 1 if use_dem_input else in_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.backbone = backbone
        self.bias = bias
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.original_resolution = original_resolution
        self.act_layer = act_layer
        self.n_classes = n_classes
        # Parameters to control variants
        self.use_dem_input = use_dem_input
        self.use_climate_branch = use_climate_branch
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
class ACTUForImageSegmentation(PreTrainedModel):
    """Temporal segmentation model built from these stages.

    1. A per-frame timm encoder shared across time steps.
    2. One ConvLSTM per encoder feature level.
    3. An optional extra DEM input channel and optional gated fusion of
       climate features.
    4. A U-Net decoder followed by a segmentation head and ``config.act_layer``.

    ``pixel_values`` is laid out as ``(batch, time, channels, height, width)``;
    see the ``"b t c h w"`` rearrange patterns used throughout.
    """

    config_class = ACTUConfig

    def __init__(self, config: ACTUConfig):
        super().__init__(config)
        self.config = config
        # features_only=True: the encoder returns a list of per-stage feature
        # maps instead of classification logits.
        self.encoder: nn.Module = timm.create_model(
            config.backbone, features_only=True, in_chans=config.in_channels
        )
        # Probe the encoder once to learn every feature level's shape.
        # The probe runs in eval mode so normalization layers do not fold the
        # random probe input into their running statistics; this matters for
        # BatchNorm, which the default resnet34 backbone has. The previous
        # mode is restored afterwards.
        was_training = self.encoder.training
        self.encoder.eval()
        with torch.no_grad():
            dummy_input = torch.randn(
                1, config.in_channels, *config.original_resolution, device=self.device
            )
            embs = self.encoder(dummy_input)
        self.encoder.train(was_training)
        # Per-level shapes (1, C, H, W) and their channel counts C.
        self.embs_shape = [e.shape for e in embs]
        self.encoder_channels = [e[1] for e in self.embs_shape]
        # One ConvLSTM per feature level, keeping the level's channel count.
        self.convlstm = nn.ModuleList(
            [
                ConvLSTM(
                    in_channels=shape[1],
                    hidden_channels=shape[1],
                    kernel_size=config.kernel_size,
                    padding=config.padding,
                    stride=config.stride,
                    bias=config.bias,
                    batch_first=config.batch_first,
                    bidirectional=config.bidirectional,
                )
                for shape in self.embs_shape
            ]
        )
        if self.config.use_climate_branch:
            self.climate_branch = ClimateBranchLSTM(
                output_shapes=[e[1:] for e in self.embs_shape],
                lstm_hidden_dim=config.lstm_hidden_dim,
                climate_seq_len=config.climate_seq_len,
                climate_input_dim=config.climate_input_dim,
                num_lstm_layers=config.num_lstm_layers,
            )
        # NOTE(review): the fusers are built even when the climate branch is
        # disabled, in which case they are unused. They are kept unconditional
        # so existing checkpoints' state_dict keys still match.
        self.fusers = nn.ModuleList(
            GatedFusion(enc, enc) for enc in self.encoder_channels
        )
        # The leading ``1`` pairs with the ``None`` placeholder feature passed
        # in ``_decode``.
        self.decoder = UnetDecoder(
            encoder_channels=[1] + self.encoder_channels,
            decoder_channels=self.encoder_channels[::-1],
            n_blocks=len(self.encoder_channels),
        )
        self.seg_head = nn.Sequential(
            SegmentationHead(
                in_channels=self.encoder_channels[0],
                out_channels=config.n_classes,
            ),
            create_act_layer(config.act_layer, inplace=True),
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        climate: torch.Tensor = None,
        dem: torch.Tensor = None,
        labels: torch.Tensor = None,
        **kwargs,
    ) -> SemanticSegmenterOutput:
        """Run the model on an image sequence.

        Args:
            pixel_values: image sequence shaped ``(batch, time, channels, H, W)``.
            climate: tensor forwarded to ``ClimateBranchLSTM``. Required when
                ``config.use_climate_branch`` is True.
            dem: ``(batch, channels, H, W)`` DEM map, repeated over every time
                step and appended as extra channels. Required when
                ``config.use_dem_input`` is True.
            labels: optional target masks. When given, ``loss`` is computed
                by ``_compute_loss``.
            **kwargs: accepted and ignored.

        Returns:
            ``SemanticSegmenterOutput`` with ``logits`` resized to the input
            ``(H, W)`` (note: *after* ``config.act_layer``) and ``loss``, which
            is None when no labels are given.

        Raises:
            ValueError: if a DEM or climate tensor is required but missing.
        """
        b, t = pixel_values.shape[:2]
        original_size = pixel_values.shape[-2:]
        # Handle DEM input: the same map is appended to every time step.
        if self.config.use_dem_input:
            if dem is None:
                raise ValueError(
                    "DEM tensor must be provided when use_dem_input is True."
                )
            dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
            pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)
        # 1. Encode images per time step
        encoded_sequence = self._encode_images(pixel_values)
        # 2. Handle Climate Branch Fusion
        if self.config.use_climate_branch:
            if climate is None:
                raise ValueError(
                    "Climate tensor must be provided when use_climate_branch is True."
                )
            encoded_sequence = self._fuse_climate(encoded_sequence, climate, b)
        # 3. Process sequence with ConvLSTM
        temporal_features = self._encode_timeseries(encoded_sequence)
        # 4. Decode to get the segmentation map
        logits = self._decode(temporal_features, size=original_size)
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)
        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
        )

    def _fuse_climate(
        self, encoded_sequence: list[torch.Tensor], climate: torch.Tensor, b: int
    ) -> list[torch.Tensor]:
        """Gate-fuse climate-branch features into each encoder feature level.

        Each level is flattened to ``(b*t, c, h, w)``, fused with the matching
        climate level by its ``GatedFusion`` module, and restored to
        ``(b, t, c, h, w)``.
        """
        climate_features = self.climate_branch(climate)
        fused_sequence = []
        for fuser, img, clim in zip(self.fusers, encoded_sequence, climate_features):
            fused = fuser(
                rearrange(img, "b t c h w -> (b t) c h w"),
                rearrange(clim, "b t c h w -> (b t) c h w"),
            )
            fused_sequence.append(rearrange(fused, "(b t) c h w -> b t c h w", b=b))
        return fused_sequence

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the training loss.

        ``logits`` is the output of ``seg_head``, i.e. *after*
        ``config.act_layer``.

        For ``n_classes == 1``, binary cross-entropy is used. ``labels`` may be
        ``(B, H, W)`` or ``(B, 1, H, W)`` and are presumably 0/1 masks (TODO
        confirm).
        - With the ``"sigmoid"`` activation, the outputs are already
          probabilities and are scored with ``binary_cross_entropy``.
        - With any other activation, the outputs are assumed to be raw scores
          (NOTE(review): verify for other probability-producing activations)
          and are scored with ``binary_cross_entropy_with_logits``.

        For ``n_classes > 1``, cross-entropy is used over class-index labels
        shaped ``(B, H, W)``; a singleton channel dimension is squeezed.
        """
        if self.config.n_classes == 1:
            target = labels.float()
            if target.ndim == logits.ndim - 1:
                target = target.unsqueeze(1)
            if self.config.act_layer == "sigmoid":
                # binary_cross_entropy refuses to run under CUDA autocast.
                # Disable autocast locally (a no-op when inactive) and score
                # in float32.
                with torch.autocast(device_type="cuda", enabled=False):
                    return F.binary_cross_entropy(logits.float(), target)
            return F.binary_cross_entropy_with_logits(logits.float(), target)
        class_labels = labels
        if class_labels.ndim == logits.ndim and class_labels.shape[1] == 1:
            class_labels = class_labels.squeeze(1)
        return F.cross_entropy(logits, class_labels.long())

    def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode every frame independently.

        Returns one ``(b, t, c, h, w)`` tensor per encoder feature level.
        """
        B = x.size(0)
        encoded_frames = self.encoder(rearrange(x, "b t c h w -> (b t) c h w"))
        return [
            rearrange(frames, "(b t) c h w -> b t c h w", b=B)
            for frames in encoded_frames
        ]

    def _encode_timeseries(self, timeseries: list[torch.Tensor]) -> list[torch.Tensor]:
        """Run each level's ConvLSTM and keep the output of the last time step.

        The result is in *reverse* level order, with the last encoder stage
        first.
        """
        outs = []
        for convlstm, encoded in reversed(list(zip(self.convlstm, timeseries))):
            lstm_out, (_, _) = convlstm(encoded)
            outs.append(lstm_out[:, -1, :, :, :])
        return outs

    def _decode(self, x: list[torch.Tensor], size: tuple[int, int]) -> torch.Tensor:
        """Decode per-level features into a segmentation map of spatial ``size``.

        ``x`` is in reverse level order (as produced by ``_encode_timeseries``).
        It is flipped back and prefixed with a ``None`` placeholder that pairs
        with the ``1`` channel entry given to ``UnetDecoder``.
        """
        trend_map = self.decoder(*[None] + x[::-1])
        trend_map = self.seg_head(trend_map)
        trend_map = F.interpolate(
            trend_map, size=size, mode="bilinear", align_corners=False
        )
        return trend_map
class ClimateBranchLSTM(nn.Module):
    """Encode climate sequences with an LSTM and broadcast them to feature maps.

    Input shape: ``(B, T, T_clim, C_clim)``. Each of the ``B * T`` leading
    entries is an independent climate sequence of length ``T_clim`` with
    ``C_clim`` variables, where ``C_clim`` must equal ``climate_input_dim``.

    Output: a list with one tensor per entry of ``output_shapes``. Entry ``i``
    has shape ``(B, T, C_i, H_i, W_i)``, where ``(C_i, H_i, W_i)`` is
    ``output_shapes[i]``.
    """

    def __init__(
        self,
        output_shapes: list[tuple[int, int, int]],
        climate_input_dim=5,
        climate_seq_len=6,
        lstm_hidden_dim=64,
        num_lstm_layers=1,
    ):
        """
        Args:
            output_shapes: ``(channels, height, width)`` of every feature level
                the climate features must match.
            climate_input_dim: climate variables per step (LSTM input size).
            climate_seq_len: climate sequence length. It is only stored; the
                LSTM accepts any length.
            lstm_hidden_dim: LSTM hidden size.
            num_lstm_layers: number of stacked LSTM layers. Dropout of 0.3 is
                applied between layers when there is more than one.
        """
        super().__init__()
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
        # Width of the vector fed to every per-level upsampler.
        self.proj_dim = 128
        self.output_shapes = output_shapes
        self.lstm = nn.LSTM(
            input_size=climate_input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,  # input is (batch, seq_len, features)
            dropout=0.3 if num_lstm_layers > 1 else 0,
            bidirectional=False,
        )
        # Project the final LSTM hidden state to ``proj_dim``.
        self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)
        # One upsampler per feature level: (N, proj_dim, 1, 1) -> (N, C, H, H).
        self.upsamples = nn.ModuleList(
            _build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
        )

    def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
        """Return one ``(B, T, C_i, H_i, W_i)`` climate feature map per level."""
        batch, steps = climate_data.shape[:2]
        # Every (batch, step) pair is an independent climate sequence.
        sequences = rearrange(climate_data, "b t s c -> (b t) s c")
        # Call the module rather than ``.forward`` so registered hooks run.
        _, (hidden, _) = self.lstm(sequences)
        # hidden is (num_layers * num_directions, N, hidden_size). The last
        # layer's states are the final ``num_directions`` rows; average the
        # directions (a single row when unidirectional).
        num_directions = 2 if self.lstm.bidirectional else 1
        last_hidden = hidden[-num_directions:].mean(dim=0)
        projected = rearrange(self.fc(last_hidden), "n c -> n c 1 1")
        climate_features = []
        for upsample, shape in zip(self.upsamples, self.output_shapes):
            feat = upsample(projected)
            # The upsampler is only guaranteed to reach the target height, and
            # its output is square. Resize to the exact (H, W) when needed,
            # e.g. for non-square feature maps.
            target_hw = tuple(shape[-2:])
            if tuple(feat.shape[-2:]) != target_hw:
                feat = F.interpolate(feat, size=target_hw, mode="nearest")
            climate_features.append(
                rearrange(feat, "(b t) c h w -> b t c h w", b=batch, t=steps)
            )
        return climate_features
class GatedFusion(nn.Module):
    """Blend image and climate feature maps with a learned per-pixel gate.

    The blend is ``out = g * img + (1 - g) * clim``, where ``g`` is produced by
    a 3x3 conv, ReLU, 1x1 conv and sigmoid applied to the channel-wise
    concatenation of both inputs.

    ``g`` has ``img_channels`` channels, so ``clim_feat`` must broadcast
    against it. In practice the two channel counts should match.
    """

    def __init__(self, img_channels, clim_channels):
        super().__init__()
        concat_channels = img_channels + clim_channels
        gate_layers = [
            nn.Conv2d(concat_channels, img_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(img_channels, img_channels, kernel_size=1),
            nn.Sigmoid(),  # gate values in (0, 1)
        ]
        # The outer wrapper is redundant but fixes the parameter names
        # (gate.0.0.*, gate.0.2.*) that existing checkpoints rely on.
        self.gate = nn.Sequential(nn.Sequential(*gate_layers))

    def forward(self, img_feat, clim_feat):
        stacked = torch.cat((img_feat, clim_feat), dim=1)
        weight = self.gate(stacked)
        return img_feat * weight + clim_feat * (1 - weight)
def _build_upsampler(
in_channels: int, target_channels: int, target_h: int
) -> nn.Sequential:
layers = []
current_h = 1
# Expand to target channels early (e.g., 1x1 → 1x1 with target_channels)
layers += [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()]
# Upsample spatially to target_h
while current_h < target_h:
next_h = min(current_h * 2, target_h)
layers += [
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1),
nn.GELU(),
]
current_h = next_h
return nn.Sequential(*layers)
|