DarthReca committed (verified) · Commit dd7cdee · Parent: 4970bed

Upload 4 files

Files changed (4):
  1. NOTICE.md +30 -0
  2. README.md +68 -3
  3. convlstm.py +209 -0
  4. modeling_actu.py +332 -0
NOTICE.md ADDED
@@ -0,0 +1,30 @@
+ ## Third-Party Software
+
+ This project incorporates components from the following third-party software:
+
+ ### ConvLSTM (MIT License)
+
+ The ConvLSTM code is used in this project. Its original license is reproduced below:
+
+ ---
+ MIT License
+
+ Copyright (c) 2022 Seyong Kim
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,68 @@
- ---
- license: openrail
- ---
+ ---
+ license: openrail
+ datasets:
+ - DarthReca/hydro-chronos
+ pipeline_tag: image-segmentation
+ tags:
+ - climate
+ - geospatial
+ - remote-sensing
+ - spatiotemporal
+ - multi-modal
+ - earth-observation
+ - time-series
+ - hydrology
+ library_name: transformers
+ ---
+
+ # ACTU for Magnitude Regression
+
+ <!-- Provide a quick summary of what the model is/does. -->
+ This is ACTU for pixelwise regression of MNDWI (Modified Normalized Difference Water Index).
+
+ ## Model Details
+
+ <!-- Provide a longer summary of what this model is. -->
+ This architecture is a temporal UNet (built with ConvLSTMs), featuring an LSTM branch that processes climate time series and a gating mechanism that fuses the two streams.
+ It is designed to receive a time series of Sentinel-2 images, a DEM, and time series of climate variables, and to output a single real-valued mask of future MNDWI.
+
+ - **Developed by:** Daniele Rege Cambrin
+ - **Model type:** ACTU
+ - **License:** OpenRAIL
+ - **Repository:** [GitHub](https://github.com/DarthReca/hydro-chronos)
+ - **Paper:** [arXiv](https://arxiv.org/abs/2506.14362)
+
+
+ ## How to Get Started with the Model
+ The model is integrated into Transformers, so you can easily load it with the following code:
+
+ ```python
+ AutoModel.from_pretrained("DarthReca/actu-magnitude-regression", trust_remote_code=True, revision="<revision>")
+ ```
+
+ Select the desired configuration with the *revision* parameter (the branches of this repo). These configurations are available:
+
+ | Revision    | Backbone        | DEM | Climate |
+ |-------------|-----------------|:---:|:-------:|
+ | main        | ConvNeXtV2 Base | No  | No      |
+ | dem-climate | ConvNeXtV2 Base | Yes | Yes     |
+
50
+ ## Training Details
51
+ The model is pre-trained on Landsat-5 images and fine-tuned on Sentinel-2 of HydroChronos.
52
+
53
+ ## Citation
54
+
55
+ ```bibtex
56
+ @misc{cambrin2025hydrochronosforecastingdecadessurface,
57
+ title={HydroChronos: Forecasting Decades of Surface Water Change},
58
+ author={Daniele Rege Cambrin and Eleonora Poeta and Eliana Pastor and Isaac Corley and Tania Cerquitelli and Elena Baralis and Paolo Garza},
59
+ year={2025},
60
+ eprint={2506.14362},
61
+ archivePrefix={arXiv},
62
+ primaryClass={cs.CV},
63
+ url={https://arxiv.org/abs/2506.14362},
64
+ }
65
+ ```
66
+
67
+ ## Licensing
68
+ The project uses third-party software. For detailed information on the licensing of each component, please see the [**NOTICE.md**](NOTICE.md) file.
convlstm.py ADDED
@@ -0,0 +1,209 @@
+ # Copyright (c) 2022 Seyong Kim
+
+ from typing import Any, Optional, Tuple, Union
+
+ import torch
+ from torch import Tensor, nn, sigmoid, tanh
+
+
+ class ConvGate(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         hidden_channels: int,
+         kernel_size: Union[Tuple[int, int], int],
+         padding: Union[Tuple[int, int], int],
+         stride: Union[Tuple[int, int], int],
+         bias: bool,
+     ):
+         super().__init__()
+         self.conv_x = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=hidden_channels * 4,
+             kernel_size=kernel_size,
+             padding=padding,
+             stride=stride,
+             bias=bias,
+         )
+         self.conv_h = nn.Conv2d(
+             in_channels=hidden_channels,
+             out_channels=hidden_channels * 4,
+             kernel_size=kernel_size,
+             padding=padding,
+             stride=stride,
+             bias=bias,
+         )
+         self.bn2d = nn.BatchNorm2d(hidden_channels * 4)
+
+     def forward(self, x, hidden_state):
+         gated = self.conv_x(x) + self.conv_h(hidden_state)
+         return self.bn2d(gated)
+
+
+ class ConvLSTMCell(nn.Module):
+     def __init__(
+         self, in_channels, hidden_channels, kernel_size, padding, stride, bias
+     ):
+         super().__init__()
+         # To inspect the model structure with tools such as torchinfo, the
+         # custom module needs to be wrapped in an nn.ModuleList
+         self.gates = nn.ModuleList(
+             [ConvGate(in_channels, hidden_channels, kernel_size, padding, stride, bias)]
+         )
+
+     def forward(
+         self, x: Tensor, hidden_state: Tensor, cell_state: Tensor
+     ) -> Tuple[Tensor, Tensor]:
+         gated = self.gates[0](x, hidden_state)
+         i_gated, f_gated, c_gated, o_gated = gated.chunk(4, dim=1)
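+         # Standard ConvLSTM update (Shi et al., 2015): all four gate
+         # pre-activations come from one convolution (plus BatchNorm) and are
+         # split channel-wise above:
+         #   i_t = sigmoid(i_gated)                      input gate
+         #   f_t = sigmoid(f_gated)                      forget gate
+         #   o_t = sigmoid(o_gated)                      output gate
+         #   c_t = f_t * c_{t-1} + i_t * tanh(c_gated)   cell state
+         #   h_t = o_t * tanh(c_t)                       hidden state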
+
+         i_gated = sigmoid(i_gated)
+         f_gated = sigmoid(f_gated)
+         o_gated = sigmoid(o_gated)
+
+         cell_state = f_gated.mul(cell_state) + i_gated.mul(tanh(c_gated))
+         hidden_state = o_gated.mul(tanh(cell_state))
+
+         return hidden_state, cell_state
+
+
+ class ConvLSTM(nn.Module):
+     """ConvLSTM module"""
+
+     def __init__(
+         self,
+         in_channels,
+         hidden_channels,
+         kernel_size,
+         padding,
+         stride,
+         bias,
+         batch_first,
+         bidirectional,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.hidden_channels = hidden_channels
+         self.bidirectional = bidirectional
+         self.batch_first = batch_first
+
+         # To inspect the model structure with tools such as torchinfo, the
+         # custom module needs to be wrapped in an nn.ModuleList
+         self.conv_lstm_cells = nn.ModuleList(
+             [
+                 ConvLSTMCell(
+                     in_channels, hidden_channels, kernel_size, padding, stride, bias
+                 )
+             ]
+         )
+
+         if self.bidirectional:
+             self.conv_lstm_cells.append(
+                 ConvLSTMCell(
+                     in_channels, hidden_channels, kernel_size, padding, stride, bias
+                 )
+             )
+
+         self.batch_size = None
+         self.seq_len = None
+         self.height = None
+         self.width = None
+
+     def forward(
+         self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]] = None
+     ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+         # size of x: (B, T, C, H, W) or (T, B, C, H, W)
+         x = self._check_shape(x)
+         hidden_state, cell_state, backward_hidden_state, backward_cell_state = (
+             self.init_state(x, state)
+         )
+
+         output, hidden_state, cell_state = self._forward(
+             self.conv_lstm_cells[0], x, hidden_state, cell_state
+         )
+
+         if self.bidirectional:
+             x = torch.flip(x, [1])
+             backward_output, backward_hidden_state, backward_cell_state = self._forward(
+                 self.conv_lstm_cells[1], x, backward_hidden_state, backward_cell_state
+             )
+
+             output = torch.cat([output, backward_output], dim=-3)
+             hidden_state = torch.cat([hidden_state, backward_hidden_state], dim=-1)
+             cell_state = torch.cat([cell_state, backward_cell_state], dim=-1)
+         return output, (hidden_state, cell_state)
+
+     def _forward(self, lstm_cell, x, hidden_state, cell_state):
+         outputs = []
+         for time_step in range(self.seq_len):
+             x_t = x[:, time_step, :, :, :]
+             hidden_state, cell_state = lstm_cell(x_t, hidden_state, cell_state)
+             # Note: per-step outputs are detached, so gradients flow only
+             # through the returned final states
+             outputs.append(hidden_state.detach())
+         output = torch.stack(outputs, dim=1)
+         return output, hidden_state, cell_state
+
+     def _check_shape(self, x: Tensor) -> Tensor:
+         if self.batch_first:
+             batch_size, self.seq_len = x.shape[0], x.shape[1]
+         else:
+             batch_size, self.seq_len = x.shape[1], x.shape[0]
+             x = torch.swapaxes(x, 0, 1)  # (T, B, ...) -> (B, T, ...)
+
+         self.height = x.shape[-2]
+         self.width = x.shape[-1]
+
+         dim = len(x.shape)
+
+         if dim == 4:
+             x = x.unsqueeze(dim=1)  # add the missing channel dimension
+             x = x.view(batch_size, self.seq_len, -1, self.height, self.width)
+             x = x.contiguous()  # reassign memory location
+         elif dim <= 3:
+             raise ValueError(
+                 f"Got {len(x.shape)} dimensional tensor. Input shape unmatched"
+             )
+
+         return x
+
+     def init_state(
+         self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]]
+     ) -> Tuple[Union[Tensor, Any], Union[Tensor, Any], Optional[Any], Optional[Any]]:
+         # If no state is given as input, initialize the state to zeros
+         backward_hidden_state, backward_cell_state = None, None
+
+         if state is None:
+             self.batch_size = x.shape[0]
+             hidden_state, cell_state = self._init_state(x.dtype, x.device)
+
+             if self.bidirectional:
+                 backward_hidden_state, backward_cell_state = self._init_state(
+                     x.dtype, x.device
+                 )
+         else:
+             if self.bidirectional:
+                 hidden_state, backward_hidden_state = state[0].chunk(2, dim=-1)
+                 cell_state, backward_cell_state = state[1].chunk(2, dim=-1)
+             else:
+                 hidden_state, cell_state = state
+
+         return hidden_state, cell_state, backward_hidden_state, backward_cell_state
+
+     def _init_state(self, dtype, device):
+         self.register_buffer(
+             "hidden_state",
+             torch.zeros(
+                 (1, self.hidden_channels, self.height, self.width),
+                 dtype=dtype,
+                 device=device,
+             ),
+         )
+         self.register_buffer(
+             "cell_state",
+             torch.zeros(
+                 (1, self.hidden_channels, self.height, self.width),
+                 dtype=dtype,
+                 device=device,
+             ),
+         )
+         return self.hidden_state, self.cell_state
modeling_actu.py ADDED
@@ -0,0 +1,332 @@
+ from dataclasses import dataclass
+
+ import numpy as np
+ import timm
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+ from segmentation_models_pytorch.base import SegmentationHead
+ from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
+ from timm.layers.create_act import create_act_layer
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import SemanticSegmenterOutput
+
+ from .convlstm import ConvLSTM
+
+
+ class ACTUConfig(PretrainedConfig):
+     model_type = "actu"
+
+     def __init__(
+         self,
+         # Base ACTU parameters
+         in_channels: int = 3,
+         kernel_size: tuple[int, int] = (3, 3),
+         padding="same",
+         stride=(1, 1),
+         backbone="resnet34",
+         bias=True,
+         batch_first=True,
+         bidirectional=False,
+         original_resolution=(256, 256),
+         act_layer="sigmoid",
+         n_classes=1,
+         # Variant control parameters
+         use_dem_input: bool = False,
+         use_climate_branch: bool = False,
+         # Climate branch parameters
+         climate_seq_len=5,
+         climate_input_dim=6,
+         lstm_hidden_dim=128,
+         num_lstm_layers=1,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.in_channels = in_channels
+         self.kernel_size = kernel_size
+         self.padding = padding
+         self.stride = stride
+         self.backbone = backbone
+         self.bias = bias
+         self.batch_first = batch_first
+         self.bidirectional = bidirectional
+         self.original_resolution = original_resolution
+         self.act_layer = act_layer
+         self.n_classes = n_classes
+
+         # Parameters to control variants
+         self.use_dem_input = use_dem_input
+         self.use_climate_branch = use_climate_branch
+         self.climate_seq_len = climate_seq_len
+         self.climate_input_dim = climate_input_dim
+         self.lstm_hidden_dim = lstm_hidden_dim
+         self.num_lstm_layers = num_lstm_layers
+
+         # Adjust in_channels if DEM is used
+         if self.use_dem_input:
+             self.in_channels += 1
+
70
+
71
+ class ACTUForImageSegmentation(PreTrainedModel):
72
+ config_class = ACTUConfig
73
+
74
+ def __init__(self, config: ACTUConfig):
75
+ super().__init__(config)
76
+ self.config = config
77
+
78
+ self.encoder: nn.Module = timm.create_model(
79
+ config.backbone, features_only=True, in_chans=config.in_channels
80
+ )
81
+
82
+ with torch.no_grad():
83
+ dummy_input_channels = config.in_channels
84
+ dummy_input = torch.randn(
85
+ 1, dummy_input_channels, *config.original_resolution, device=self.device
86
+ )
87
+ embs = self.encoder(dummy_input)
88
+ self.embs_shape = [e.shape for e in embs]
89
+ self.encoder_channels = [e[1] for e in self.embs_shape]
90
+
91
+ self.convlstm = nn.ModuleList(
92
+ [
93
+ ConvLSTM(
94
+ in_channels=shape[1],
95
+ hidden_channels=shape[1],
96
+ kernel_size=config.kernel_size,
97
+ padding=config.padding,
98
+ stride=config.stride,
99
+ bias=config.bias,
100
+ batch_first=config.batch_first,
101
+ bidirectional=config.bidirectional,
102
+ )
103
+ for shape in self.embs_shape
104
+ ]
105
+ )
106
+
107
+ if self.config.use_climate_branch:
108
+ self.climate_branch = ClimateBranchLSTM(
109
+ output_shapes=[e[1:] for e in self.embs_shape],
110
+ lstm_hidden_dim=config.lstm_hidden_dim,
111
+ climate_seq_len=config.climate_seq_len,
112
+ climate_input_dim=config.climate_input_dim,
113
+ num_lstm_layers=config.num_lstm_layers,
114
+ )
115
+ self.fusers = nn.ModuleList(
116
+ GatedFusion(enc, enc) for enc in self.encoder_channels
117
+ )
118
+
119
+ self.decoder = UnetDecoder(
120
+ encoder_channels=[1] + self.encoder_channels,
121
+ decoder_channels=self.encoder_channels[::-1],
122
+ n_blocks=len(self.encoder_channels),
123
+ )
124
+
125
+ self.seg_head = nn.Sequential(
126
+ SegmentationHead(
127
+ in_channels=self.encoder_channels[0],
128
+ out_channels=config.n_classes,
129
+ ),
130
+ create_act_layer(config.act_layer, inplace=True),
131
+ )
132
+
133
+ def forward(
134
+ self,
135
+ pixel_values: torch.Tensor,
136
+ climate: torch.Tensor = None,
137
+ dem: torch.Tensor = None,
138
+ labels: torch.Tensor = None,
139
+ **kwargs,
140
+ ) -> SemanticSegmenterOutput:
141
+ b, t = pixel_values.shape[:2]
142
+ original_size = pixel_values.shape[-2:]
143
+
144
+ # Handle DEM input
145
+ if self.config.use_dem_input:
146
+ if dem is None:
147
+ raise ValueError(
148
+ "DEM tensor must be provided when use_dem_input is True."
149
+ )
150
+ dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
151
+ pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)
152
+
153
+ # 1. Encode images per time step
154
+ encoded_sequence = self._encode_images(pixel_values)
155
+
156
+ # 2. Handle Climate Branch Fusion
157
+ if self.config.use_climate_branch:
158
+ if climate is None:
159
+ raise ValueError(
160
+ "Climate tensor must be provided when use_climate_branch is True."
161
+ )
162
+
163
+ climate_features = self.climate_branch(climate)
164
+
165
+ # Reshape for fusion
166
+ encoded_sequence_reshaped = [
167
+ rearrange(f, "b t c h w -> (b t) c h w") for f in encoded_sequence
168
+ ]
169
+ climate_features_reshaped = [
170
+ rearrange(f, "b t c h w -> (b t) c h w") for f in climate_features
171
+ ]
172
+
173
+ # Fuse features
174
+ fused_features = [
175
+ fuser(img, clim)
176
+ for fuser, img, clim in zip(
177
+ self.fusers, encoded_sequence_reshaped, climate_features_reshaped
178
+ )
179
+ ]
180
+
181
+ # Reshape back to sequence
182
+ encoded_sequence = [
183
+ rearrange(f, "(b t) c h w -> b t c h w", b=b) for f in fused_features
184
+ ]
185
+
186
+ # 3. Process sequence with ConvLSTM
187
+ temporal_features = self._encode_timeseries(encoded_sequence)
188
+
189
+ # 4. Decode to get the segmentation map
190
+ logits = self._decode(temporal_features, size=original_size)
191
+
192
+ loss = None
193
+ if labels is not None:
194
+ loss_fct = nn.CrossEntropyLoss()
195
+ loss = loss_fct(logits, labels.float().unsqueeze(1))
196
+
197
+ return SemanticSegmenterOutput(
198
+ loss=loss,
199
+ logits=logits,
200
+ )
201
+
202
+ def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
203
+ B = x.size(0)
204
+ encoded_frames = self.encoder(rearrange(x, "b t c h w -> (b t) c h w"))
205
+ return [
206
+ rearrange(frames, "(b t) c h w -> b t c h w", b=B)
207
+ for frames in encoded_frames
208
+ ]
209
+
210
+ def _encode_timeseries(self, timeseries: torch.Tensor) -> list[torch.Tensor]:
211
+ outs = []
212
+ for convlstm, encoded in reversed(list(zip(self.convlstm, timeseries))):
213
+ lstm_out, (_, _) = convlstm(encoded)
214
+ outs.append(lstm_out[:, -1, :, :, :])
215
+ return outs
216
+
217
+ def _decode(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
218
+ trend_map = self.decoder(*[None] + x[::-1])
219
+ trend_map = self.seg_head(trend_map)
220
+ trend_map = F.interpolate(
221
+ trend_map, size=size, mode="bilinear", align_corners=False
222
+ )
223
+ return trend_map
224
+
225
+
226
+ class ClimateBranchLSTM(nn.Module):
227
+ """
228
+ Processes climate time series data using an LSTM.
229
+ Input shape: (B, T, T_1, C_clim) -> e.g., (B, 5, 6, 5)
230
+ Output shape: (B, T, output_dim) -> e.g., (B, 5, 128)
231
+ """
232
+
233
+ def __init__(
234
+ self,
235
+ output_shapes: list[tuple[int, int, int]],
236
+ climate_input_dim=5,
237
+ climate_seq_len=6,
238
+ lstm_hidden_dim=64,
239
+ num_lstm_layers=1,
240
+ ):
241
+ super().__init__()
242
+ self.climate_seq_len = climate_seq_len
243
+ self.climate_input_dim = climate_input_dim
244
+ self.lstm_hidden_dim = lstm_hidden_dim
245
+ self.num_lstm_layers = num_lstm_layers
246
+ self.proj_dim = 128
247
+ self.output_shapes = output_shapes
248
+
249
+ self.lstm = nn.LSTM(
250
+ input_size=climate_input_dim,
251
+ hidden_size=lstm_hidden_dim,
252
+ num_layers=num_lstm_layers,
253
+ batch_first=True, # Crucial: expects input shape (batch, seq_len, features)
254
+ dropout=0.3 if num_lstm_layers > 1 else 0,
255
+ bidirectional=False,
256
+ )
257
+
258
+ # Linear layer to project LSTM output to the desired final dimension
259
+ self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)
260
+
261
+ self.upsamples = nn.ModuleList(
262
+ _build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
263
+ )
264
+
265
+ def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
266
+ # climate_data shape: (B, T, T_1, C_clim), e.g., (B, 5, 6, 5)
267
+ B_img, B_cli, T, C = climate_data.shape
268
+
269
+ # Reshape for LSTM: Treat each sequence independently
270
+ lstm_input = rearrange(climate_data, "Bi Bc T C -> (Bi Bc) T C")
271
+
272
+ # Pass through LSTM
273
+ _, (hidden, _) = self.lstm.forward(lstm_input)
274
+ # Get the last layer's hidden state
275
+ last_hidden = (
276
+ hidden[[hidden.size(0) // 2, -1]] if self.lstm.bidirectional else hidden[-1]
277
+ )
278
+ if last_hidden.ndim == 3:
279
+ last_hidden = hidden.mean(dim=0)
280
+
281
+ # Pass the final hidden state through the fully connected layer(s) and upsample
282
+ climate_features = self.fc(last_hidden)
283
+ climate_features = rearrange(climate_features, "b c -> b c 1 1")
284
+ climate_features = [
285
+ rearrange(
286
+ u(climate_features), "(Bi Bc) C H W -> Bi Bc C H W", Bi=B_img, Bc=B_cli
287
+ )
288
+ for u in self.upsamples
289
+ ]
290
+
291
+ return climate_features
292
+
293
+
294
+ class GatedFusion(nn.Module):
295
+ def __init__(self, img_channels, clim_channels):
296
+ super().__init__()
297
+ self.gate = nn.Sequential(
298
+ nn.Sequential(
299
+ nn.Conv2d(
300
+ img_channels + clim_channels, img_channels, kernel_size=3, padding=1
301
+ ),
302
+ nn.ReLU(inplace=True),
303
+ nn.Conv2d(img_channels, img_channels, kernel_size=1),
304
+ nn.Sigmoid(), # Gate values between 0 and 1
305
+ )
306
+ )
307
+
308
+ def forward(self, img_feat, clim_feat):
309
+ gate = self.gate(torch.cat([img_feat, clim_feat], dim=1))
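+         # Convex combination: where gate -> 1 the image features dominate,
+         # where gate -> 0 the climate features take over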
+         return gate * img_feat + (1 - gate) * clim_feat
+
+
+ def _build_upsampler(
+     in_channels: int, target_channels: int, target_h: int
+ ) -> nn.Sequential:
+     layers = []
+     current_h = 1
+
+     # Expand to target channels early (e.g., 1x1 -> 1x1 with target_channels)
+     layers += [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()]
+
+     # Upsample spatially to target_h
+     while current_h < target_h:
+         next_h = min(current_h * 2, target_h)
+         layers += [
+             nn.Upsample(scale_factor=2, mode="nearest"),
+             nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1),
+             nn.GELU(),
+         ]
+         current_h = next_h
+
+     return nn.Sequential(*layers)
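+
+ # For example, _build_upsampler(128, 96, 64) maps a (B, 128, 1, 1) climate
+ # embedding to a (B, 96, 64, 64) feature map via six nearest-neighbor
+ # doublings, each followed by a 3x3 convolution and GELU.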