Upload 4 files
- NOTICE.md +30 -0
- README.md +68 -3
- convlstm.py +209 -0
- modeling_actu.py +332 -0
NOTICE.md
ADDED
@@ -0,0 +1,30 @@
## Third-Party Software

This project incorporates components from the following third-party software:

### ConvLSTM (MIT License)

The ConvLSTM code is used in this project. The original license is as follows:

---
MIT License

Copyright (c) 2022 Seyong Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md
CHANGED
@@ -1,3 +1,68 @@
---
license: openrail
datasets:
- DarthReca/hydro-chronos
pipeline_tag: image-segmentation
tags:
- climate
- geospatial
- remote-sensing
- spatiotemporal
- multi-modal
- earth-observation
- time-series
- hydrology
library_name: transformers
---

# ACTU for Magnitude Regression

<!-- Provide a quick summary of what the model is/does. -->
This is ACTU for pixel-wise regression of MNDWI (Modified Normalized Difference Water Index).

## Model Details

<!-- Provide a longer summary of what this model is. -->
This architecture is a temporal UNet (with ConvLSTMs), featuring an LSTM branch to process climate time series and a gating mechanism to fuse them with the image features.
It is designed to receive a time series of Sentinel-2 images, a DEM, and a time series of climate variables, and to output a single real-valued mask of future MNDWI.

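The gating mechanism computes a per-pixel convex combination of image and climate features. A minimal sketch of the idea, mirroring `GatedFusion` in `modeling_actu.py` (the channel count here is illustrative):

```python
import torch
from torch import nn

channels = 8  # illustrative; the model uses each encoder stage's width
gate_net = nn.Sequential(
    nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(channels, channels, kernel_size=1),
    nn.Sigmoid(),  # gate values in (0, 1)
)
img_feat = torch.randn(1, channels, 16, 16)
clim_feat = torch.randn(1, channels, 16, 16)
gate = gate_net(torch.cat([img_feat, clim_feat], dim=1))
fused = gate * img_feat + (1 - gate) * clim_feat  # per-pixel convex combination
```
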
- **Developed by:** Daniele Rege Cambrin
- **Model type:** ACTU
- **License:** OpenRAIL
- **Repository:** [Github](https://github.com/DarthReca/hydro-chronos)
- **Paper:** [Arxiv](https://arxiv.org/abs/2506.14362)

## How to Get Started with the Model
The model is integrated into Transformers, so you can easily load it with the following code:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "DarthReca/actu-magnitude-regression",
    trust_remote_code=True,
    revision="main",  # or "dem-climate"
)
```

Select the desired configuration with the *revision* parameter (the branches of this repo). These configurations are available:

| Revision    | Backbone        | DEM | Climate |
|-------------|-----------------|:---:|:-------:|
| main        | ConvNeXtV2 Base | No  | No      |
| dem-climate | ConvNeXtV2 Base | Yes | Yes     |

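The forward pass expects batch-first tensors. A minimal inference sketch for the *dem-climate* revision, assuming the default shapes from `modeling_actu.py` (3 image channels, a 5-step climate series with 6 variables); the published checkpoint's channel counts may differ:

```python
import torch

B, T, H, W = 1, 4, 256, 256                # batch, image time steps, resolution
pixel_values = torch.randn(B, T, 3, H, W)  # Sentinel-2 time series
dem = torch.randn(B, 1, H, W)              # static DEM, repeated over time internally
climate = torch.randn(B, T, 5, 6)          # one 5-step, 6-variable series per image step

with torch.no_grad():
    out = model(pixel_values=pixel_values, dem=dem, climate=climate)
print(out.logits.shape)  # (B, 1, H, W): real-valued future-MNDWI mask
```
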
## Training Details
The model is pre-trained on Landsat-5 images and fine-tuned on the Sentinel-2 imagery of HydroChronos.

## Citation

```bibtex
@misc{cambrin2025hydrochronosforecastingdecadessurface,
      title={HydroChronos: Forecasting Decades of Surface Water Change},
      author={Daniele Rege Cambrin and Eleonora Poeta and Eliana Pastor and Isaac Corley and Tania Cerquitelli and Elena Baralis and Paolo Garza},
      year={2025},
      eprint={2506.14362},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2506.14362},
}
```

## Licensing
The project uses third-party software. For detailed information on the licensing of each component, please see the [**NOTICE.md**](NOTICE.md) file.

convlstm.py
ADDED
@@ -0,0 +1,209 @@
# Copyright (c) 2022 Seyong Kim

from typing import Any, Optional, Tuple, Union

import torch
from torch import Tensor, nn, sigmoid, tanh


class ConvGate(nn.Module):
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: Union[Tuple[int, int], int],
        padding: Union[Tuple[int, int], int],
        stride: Union[Tuple[int, int], int],
        bias: bool,
    ):
        super(ConvGate, self).__init__()
        self.conv_x = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_channels * 4,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        self.conv_h = nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=hidden_channels * 4,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        self.bn2d = nn.BatchNorm2d(hidden_channels * 4)

    def forward(self, x, hidden_state):
        # Joint convolution of input and hidden state yields the four stacked
        # LSTM gates (input, forget, cell, output).
        gated = self.conv_x(x) + self.conv_h(hidden_state)
        return self.bn2d(gated)


class ConvLSTMCell(nn.Module):
    def __init__(
        self, in_channels, hidden_channels, kernel_size, padding, stride, bias
    ):
        super().__init__()
        # To check the model structure with tools such as torchinfo, the custom
        # module must be wrapped in an nn.ModuleList
        self.gates = nn.ModuleList(
            [ConvGate(in_channels, hidden_channels, kernel_size, padding, stride, bias)]
        )

    def forward(
        self, x: Tensor, hidden_state: Tensor, cell_state: Tensor
    ) -> Tuple[Tensor, Tensor]:
        gated = self.gates[0](x, hidden_state)
        i_gated, f_gated, c_gated, o_gated = gated.chunk(4, dim=1)

        i_gated = sigmoid(i_gated)
        f_gated = sigmoid(f_gated)
        o_gated = sigmoid(o_gated)

        # Standard LSTM state update, applied per pixel.
        cell_state = f_gated.mul(cell_state) + i_gated.mul(tanh(c_gated))
        hidden_state = o_gated.mul(tanh(cell_state))

        return hidden_state, cell_state


class ConvLSTM(nn.Module):
    """ConvLSTM module"""

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        padding,
        stride,
        bias,
        batch_first,
        bidirectional,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.bidirectional = bidirectional
        self.batch_first = batch_first

        # To check the model structure with tools such as torchinfo, the custom
        # module must be wrapped in an nn.ModuleList
        self.conv_lstm_cells = nn.ModuleList(
            [
                ConvLSTMCell(
                    in_channels, hidden_channels, kernel_size, padding, stride, bias
                )
            ]
        )

        if self.bidirectional:
            self.conv_lstm_cells.append(
                ConvLSTMCell(
                    in_channels, hidden_channels, kernel_size, padding, stride, bias
                )
            )

        self.batch_size = None
        self.seq_len = None
        self.height = None
        self.width = None

    def forward(
        self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # size of x: B, T, C, H, W or T, B, C, H, W
        x = self._check_shape(x)
        hidden_state, cell_state, backward_hidden_state, backward_cell_state = (
            self.init_state(x, state)
        )

        output, hidden_state, cell_state = self._forward(
            self.conv_lstm_cells[0], x, hidden_state, cell_state
        )

        if self.bidirectional:
            # Run the second cell over the time-reversed sequence.
            x = torch.flip(x, [1])
            backward_output, backward_hidden_state, backward_cell_state = self._forward(
                self.conv_lstm_cells[1], x, backward_hidden_state, backward_cell_state
            )

            output = torch.cat([output, backward_output], dim=-3)
            hidden_state = torch.cat([hidden_state, backward_hidden_state], dim=-1)
            cell_state = torch.cat([cell_state, backward_cell_state], dim=-1)
        return output, (hidden_state, cell_state)

    def _forward(self, lstm_cell, x, hidden_state, cell_state):
        outputs = []
        for time_step in range(self.seq_len):
            x_t = x[:, time_step, :, :, :]
            hidden_state, cell_state = lstm_cell(x_t, hidden_state, cell_state)
            # NOTE: per-step outputs are stored detached from the graph;
            # gradients flow only through the final hidden and cell states.
            outputs.append(hidden_state.detach())
        output = torch.stack(outputs, dim=1)
        return output, hidden_state, cell_state

    def _check_shape(self, x: Tensor) -> Tensor:
        if self.batch_first:
            batch_size, self.seq_len = x.shape[0], x.shape[1]
        else:
            batch_size, self.seq_len = x.shape[1], x.shape[0]
            x = torch.swapaxes(x, 0, 1)  # convert (T, B, ...) to batch-first

        self.height = x.shape[-2]
        self.width = x.shape[-1]

        dim = len(x.shape)

        if dim == 4:
            # Single-channel input: insert the channel dimension.
            x = x.unsqueeze(dim=1)
            x = x.view(batch_size, self.seq_len, -1, self.height, self.width)
            x = x.contiguous()  # Reassign memory location
        elif dim <= 3:
            raise ValueError(
                f"Got {len(x.shape)} dimensional tensor. Input shape unmatched"
            )

        return x

    def init_state(
        self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]]
    ) -> Tuple[Union[Tensor, Any], Union[Tensor, Any], Optional[Any], Optional[Any]]:
        # If state doesn't enter as input, initialize state to zeros
        backward_hidden_state, backward_cell_state = None, None

        if state is None:
            self.batch_size = x.shape[0]
            hidden_state, cell_state = self._init_state(x.dtype, x.device)

            if self.bidirectional:
                backward_hidden_state, backward_cell_state = self._init_state(
                    x.dtype, x.device
                )
        else:
            if self.bidirectional:
                # Forward and backward states are stored concatenated along the
                # last dimension; split them back apart.
                hidden_state, backward_hidden_state = state[0].chunk(2, dim=-1)
                cell_state, backward_cell_state = state[1].chunk(2, dim=-1)
            else:
                hidden_state, cell_state = state

        return hidden_state, cell_state, backward_hidden_state, backward_cell_state

    def _init_state(self, dtype, device):
        # Zero initial states are registered as buffers so they move with the
        # module across devices and are excluded from gradient updates.
        self.register_buffer(
            "hidden_state",
            torch.zeros(
                (1, self.hidden_channels, self.height, self.width),
                dtype=dtype,
                device=device,
            ),
        )
        self.register_buffer(
            "cell_state",
            torch.zeros(
                (1, self.hidden_channels, self.height, self.width),
                dtype=dtype,
                device=device,
            ),
        )
        return self.hidden_state, self.cell_state

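For reference, a minimal usage sketch of the module above, assuming it is importable as `convlstm` and fed a batch-first input of shape (B, T, C, H, W):

```python
import torch

from convlstm import ConvLSTM

lstm = ConvLSTM(
    in_channels=8,
    hidden_channels=8,
    kernel_size=(3, 3),
    padding="same",
    stride=(1, 1),
    bias=True,
    batch_first=True,
    bidirectional=False,
)
x = torch.randn(2, 4, 8, 32, 32)  # B=2, T=4, C=8, H=W=32
output, (h, c) = lstm(x)          # per-step hidden states and final (h, c)
print(output.shape)               # torch.Size([2, 4, 8, 32, 32])
```
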
modeling_actu.py
ADDED
@@ -0,0 +1,332 @@
from dataclasses import dataclass

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from segmentation_models_pytorch.base import SegmentationHead
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
from timm.layers.create_act import create_act_layer
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput

from .convlstm import ConvLSTM


class ACTUConfig(PretrainedConfig):
    model_type = "actu"

    def __init__(
        self,
        # Base ACTU parameters
        in_channels: int = 3,
        kernel_size: tuple[int, int] = (3, 3),
        padding="same",
        stride=(1, 1),
        backbone="resnet34",
        bias=True,
        batch_first=True,
        bidirectional=False,
        original_resolution=(256, 256),
        act_layer="sigmoid",
        n_classes=1,
        # Variant control parameters
        use_dem_input: bool = False,
        use_climate_branch: bool = False,
        # Climate branch parameters
        climate_seq_len=5,
        climate_input_dim=6,
        lstm_hidden_dim=128,
        num_lstm_layers=1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.backbone = backbone
        self.bias = bias
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.original_resolution = original_resolution
        self.act_layer = act_layer
        self.n_classes = n_classes

        # Parameters to control variants
        self.use_dem_input = use_dem_input
        self.use_climate_branch = use_climate_branch
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers

        # Adjust in_channels if DEM is used
        if self.use_dem_input:
            self.in_channels += 1


class ACTUForImageSegmentation(PreTrainedModel):
    config_class = ACTUConfig

    def __init__(self, config: ACTUConfig):
        super().__init__(config)
        self.config = config

        self.encoder: nn.Module = timm.create_model(
            config.backbone, features_only=True, in_chans=config.in_channels
        )

        # Run a dummy forward pass to discover the shape of each encoder stage.
        with torch.no_grad():
            dummy_input_channels = config.in_channels
            dummy_input = torch.randn(
                1, dummy_input_channels, *config.original_resolution, device=self.device
            )
            embs = self.encoder(dummy_input)
            self.embs_shape = [e.shape for e in embs]
            self.encoder_channels = [e[1] for e in self.embs_shape]

        # One ConvLSTM per encoder stage to model the temporal dynamics.
        self.convlstm = nn.ModuleList(
            [
                ConvLSTM(
                    in_channels=shape[1],
                    hidden_channels=shape[1],
                    kernel_size=config.kernel_size,
                    padding=config.padding,
                    stride=config.stride,
                    bias=config.bias,
                    batch_first=config.batch_first,
                    bidirectional=config.bidirectional,
                )
                for shape in self.embs_shape
            ]
        )

        if self.config.use_climate_branch:
            self.climate_branch = ClimateBranchLSTM(
                output_shapes=[e[1:] for e in self.embs_shape],
                lstm_hidden_dim=config.lstm_hidden_dim,
                climate_seq_len=config.climate_seq_len,
                climate_input_dim=config.climate_input_dim,
                num_lstm_layers=config.num_lstm_layers,
            )
            self.fusers = nn.ModuleList(
                GatedFusion(enc, enc) for enc in self.encoder_channels
            )

        self.decoder = UnetDecoder(
            encoder_channels=[1] + self.encoder_channels,
            decoder_channels=self.encoder_channels[::-1],
            n_blocks=len(self.encoder_channels),
        )

        self.seg_head = nn.Sequential(
            SegmentationHead(
                in_channels=self.encoder_channels[0],
                out_channels=config.n_classes,
            ),
            create_act_layer(config.act_layer, inplace=True),
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        climate: torch.Tensor = None,
        dem: torch.Tensor = None,
        labels: torch.Tensor = None,
        **kwargs,
    ) -> SemanticSegmenterOutput:
        b, t = pixel_values.shape[:2]
        original_size = pixel_values.shape[-2:]

        # Handle DEM input: the static DEM is repeated along the time axis and
        # stacked as an extra band.
        if self.config.use_dem_input:
            if dem is None:
                raise ValueError(
                    "DEM tensor must be provided when use_dem_input is True."
                )
            dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
            pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)

        # 1. Encode images per time step
        encoded_sequence = self._encode_images(pixel_values)

        # 2. Handle Climate Branch Fusion
        if self.config.use_climate_branch:
            if climate is None:
                raise ValueError(
                    "Climate tensor must be provided when use_climate_branch is True."
                )

            climate_features = self.climate_branch(climate)

            # Reshape for fusion
            encoded_sequence_reshaped = [
                rearrange(f, "b t c h w -> (b t) c h w") for f in encoded_sequence
            ]
            climate_features_reshaped = [
                rearrange(f, "b t c h w -> (b t) c h w") for f in climate_features
            ]

            # Fuse features
            fused_features = [
                fuser(img, clim)
                for fuser, img, clim in zip(
                    self.fusers, encoded_sequence_reshaped, climate_features_reshaped
                )
            ]

            # Reshape back to sequence
            encoded_sequence = [
                rearrange(f, "(b t) c h w -> b t c h w", b=b) for f in fused_features
            ]

        # 3. Process sequence with ConvLSTM
        temporal_features = self._encode_timeseries(encoded_sequence)

        # 4. Decode to get the segmentation map
        logits = self._decode(temporal_features, size=original_size)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels.float().unsqueeze(1))

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
        )

    def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
        B = x.size(0)
        # Fold time into the batch so the 2D encoder sees every frame at once.
        encoded_frames = self.encoder(rearrange(x, "b t c h w -> (b t) c h w"))
        return [
            rearrange(frames, "(b t) c h w -> b t c h w", b=B)
            for frames in encoded_frames
        ]

    def _encode_timeseries(self, timeseries: list[torch.Tensor]) -> list[torch.Tensor]:
        # Deepest stage first; only the last time step of each ConvLSTM is kept.
        outs = []
        for convlstm, encoded in reversed(list(zip(self.convlstm, timeseries))):
            lstm_out, (_, _) = convlstm(encoded)
            outs.append(lstm_out[:, -1, :, :, :])
        return outs

    def _decode(self, x: list[torch.Tensor], size: tuple[int, int]) -> torch.Tensor:
        # None stands in for the full-resolution skip the UNet decoder expects.
        trend_map = self.decoder(*([None] + x[::-1]))
        trend_map = self.seg_head(trend_map)
        trend_map = F.interpolate(
            trend_map, size=size, mode="bilinear", align_corners=False
        )
        return trend_map


class ClimateBranchLSTM(nn.Module):
    """
    Processes climate time series data using an LSTM.
    Input shape: (B, T, T_1, C_clim) -> e.g., (B, 5, 6, 5)
    Output: a list with one (B, T, C, H, W) feature map per encoder stage.
    """

    def __init__(
        self,
        output_shapes: list[tuple[int, int, int]],
        climate_input_dim=5,
        climate_seq_len=6,
        lstm_hidden_dim=64,
        num_lstm_layers=1,
    ):
        super().__init__()
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
        self.proj_dim = 128
        self.output_shapes = output_shapes

        self.lstm = nn.LSTM(
            input_size=climate_input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,  # Crucial: expects input shape (batch, seq_len, features)
            dropout=0.3 if num_lstm_layers > 1 else 0,
            bidirectional=False,
        )

        # Linear layer to project LSTM output to the desired final dimension
        self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)

        # One upsampling head per encoder stage, mapping the (B, proj_dim, 1, 1)
        # climate embedding to that stage's channel count and height.
        self.upsamples = nn.ModuleList(
            _build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
        )

    def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
        # climate_data shape: (B, T, T_1, C_clim), e.g., (B, 5, 6, 5)
        B_img, B_cli, T, C = climate_data.shape

        # Reshape for LSTM: Treat each sequence independently
        lstm_input = rearrange(climate_data, "Bi Bc T C -> (Bi Bc) T C")

        # Pass through LSTM
        _, (hidden, _) = self.lstm(lstm_input)
        # Get the last layer's hidden state
        last_hidden = (
            hidden[[hidden.size(0) // 2, -1]] if self.lstm.bidirectional else hidden[-1]
        )
        if last_hidden.ndim == 3:
            # Bidirectional case: average the two direction states.
            last_hidden = last_hidden.mean(dim=0)

        # Pass the final hidden state through the fully connected layer(s) and upsample
        climate_features = self.fc(last_hidden)
        climate_features = rearrange(climate_features, "b c -> b c 1 1")
        climate_features = [
            rearrange(
                u(climate_features), "(Bi Bc) C H W -> Bi Bc C H W", Bi=B_img, Bc=B_cli
            )
            for u in self.upsamples
        ]

        return climate_features


class GatedFusion(nn.Module):
    def __init__(self, img_channels, clim_channels):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(
                    img_channels + clim_channels, img_channels, kernel_size=3, padding=1
                ),
                nn.ReLU(inplace=True),
                nn.Conv2d(img_channels, img_channels, kernel_size=1),
                nn.Sigmoid(),  # Gate values between 0 and 1
            )
        )

    def forward(self, img_feat, clim_feat):
        # Per-pixel convex combination of image and climate features.
        gate = self.gate(torch.cat([img_feat, clim_feat], dim=1))
        return gate * img_feat + (1 - gate) * clim_feat


def _build_upsampler(
    in_channels: int, target_channels: int, target_h: int
) -> nn.Sequential:
    layers = []
    current_h = 1

    # Expand to target channels early (e.g., 1x1 → 1x1 with target_channels)
    layers += [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()]

    # Upsample spatially to target_h
    while current_h < target_h:
        next_h = min(current_h * 2, target_h)
        layers += [
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1),
            nn.GELU(),
        ]
        current_h = next_h

    return nn.Sequential(*layers)

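For reference, a minimal smoke test of the dem-climate variant, assuming the two uploaded files sit in a package (the relative `from .convlstm import ConvLSTM` requires one; the package name `actu` here is hypothetical) and swapping in a small `resnet18` backbone for the published ConvNeXtV2 Base:

```python
import torch

from actu.modeling_actu import ACTUConfig, ACTUForImageSegmentation

config = ACTUConfig(
    backbone="resnet18",  # small stand-in for quick testing
    use_dem_input=True,
    use_climate_branch=True,
    original_resolution=(128, 128),
)
model = ACTUForImageSegmentation(config).eval()

B, T, H, W = 1, 3, 128, 128
with torch.no_grad():
    out = model(
        pixel_values=torch.randn(B, T, 3, H, W),  # image time series, 3 bands
        dem=torch.randn(B, 1, H, W),              # static elevation, one band
        climate=torch.randn(B, T, 5, 6),          # 5-step series of 6 climate variables
    )
print(out.logits.shape)  # torch.Size([1, 1, 128, 128])
```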