qninhdt committed
Commit · 8f603e4 · 1 Parent(s): 9acdf9b
cc

Files changed:
- configs/experiment/ch_64.yaml +11 -6
- swim/data/swim_data.py +8 -3
- swim/models/discriminator.py +16 -26
- swim/models/style_encoder.py +6 -2
- swim/models/swim_gan.py +81 -41
- swim/train.py +1 -1
- swim/utils/tensor_pool.py +25 -12
- train_swim.sh +1 -1
configs/experiment/ch_64.yaml
CHANGED
@@ -10,15 +10,20 @@ seed: 42

 trainer:
   max_epochs: 100
+  # precision: 16-mixed

 model:
-  channels:
-  z_c_channels:
-  updown_channel_mults: [1, 2, 4]
-  n_enc_resnet_blocks:
-  n_dec_resnet_blocks:
+  channels: 128
+  z_c_channels: 512
+  updown_channel_mults: [1, 2, 2, 4]
+  n_enc_resnet_blocks: 8
+  n_dec_resnet_blocks: 8
   n_f_d_resnet_blocks: 4
-  learning_rate: 1e-
+  learning_rate: 1e-5
+  weight_decay: 1e-2
+  beta_1: 0.9
+  beta_2: 0.999

 data:
   batch_size: 4
+
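Note: these hyperparameters feed the SwimGAN constructor shown in the swim_gan.py diff below. A hypothetical sketch of the equivalent direct call, assuming the usual Hydra instantiation plumbing (which is not part of this commit):

# Hypothetical: the updated ch_64.yaml values as explicit keyword arguments.
from swim.models.swim_gan import SwimGAN

model = SwimGAN(
    channels=128,                       # base channel width
    z_c_channels=512,                   # content latent channels
    updown_channel_mults=[1, 2, 2, 4],  # one extra stage vs. the old [1, 2, 4]
    n_enc_resnet_blocks=8,
    n_dec_resnet_blocks=8,
    n_f_d_resnet_blocks=4,
    learning_rate=1e-5,
    weight_decay=1e-2,                  # new: forwarded to AdamW
    beta_1=0.9,                         # new: AdamW betas
    beta_2=0.999,
)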
swim/data/swim_data.py
CHANGED
@@ -26,9 +26,14 @@ class SwimDataset(Dataset):
         if split == "train":
             self.transform = T.Compose(
                 [
-
-
-                    T.RandomResizedCrop(
+                    T.Resize(self.img_size),
+                    T.RandomCrop(self.img_size),
+                    # T.RandomResizedCrop(
+                    #     self.img_size,
+                    #     scale=(0.5, 1.0),
+                    #     ratio=(1.0, 1.0),
+                    #     interpolation=T.InterpolationMode.LANCZOS,
+                    # ),
                     T.RandomHorizontalFlip(),
                     T.ToTensor(),
                     T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
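Note: the active pipeline now resizes the short edge and takes a plain random crop, instead of RandomResizedCrop's scale/ratio jitter. A minimal sketch of the resulting transform, assuming self.img_size is an int such as 512:

import torchvision.transforms as T

img_size = 512  # assumed; the real value comes from the dataset config
train_transform = T.Compose(
    [
        T.Resize(img_size),      # scale the shorter edge to img_size
        T.RandomCrop(img_size),  # then cut a random img_size x img_size window
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map pixels to [-1, 1]
    ]
)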
swim/models/discriminator.py
CHANGED
@@ -73,17 +73,7 @@ class FeatureDiscriminator(nn.Module):
         for _ in range(n_resnet_blocks):
             self.resnet_blocks.append(SNResnetBlock(channels, channels, d_style_emb))

-        self.conv_out = spectral_norm(
-            nn.Conv2d(channels, z_c_channels, 3, stride=1, padding=1)
-        )
-
-        self.mlp = nn.Sequential(
-            spectral_norm(nn.Linear(z_c_channels + d_style_emb, 256)),
-            nn.LeakyReLU(0.2),
-            spectral_norm(nn.Linear(256, 128)),
-            nn.LeakyReLU(0.2),
-            spectral_norm(nn.Linear(128, 1)),
-        )
+        self.conv_out = spectral_norm(nn.Conv2d(channels, 1, 3, stride=1, padding=1))

     def forward(
         self, x: torch.Tensor, style_emb: torch.Tensor, for_G=False, for_real=False
@@ -96,25 +86,25 @@
         h = F.avg_pool2d(h, 2)

         h = self.conv_out(h)
-        h = F.leaky_relu(h, 0.2)

-
-
-        h = self.mlp(torch.concat([h, style_emb], dim=1))
+        if for_G:
+            for_real = True

-
-
-
-
+        if for_real:
+            target = torch.full_like(h, 1.0)
+        else:
+            target = torch.full_like(h, 0.0)

-
+        loss = F.binary_cross_entropy_with_logits(h, target, reduction="none").mean(
+            dim=[2, 3]
+        )

         # hinge loss
-        if for_G:
-            loss = -h
-        elif for_real:
-            loss = F.relu(1.0 - h)
-        else:
-            loss = F.relu(1.0 + h)
+        # if for_G:
+        #     loss = -h
+        # elif for_real:
+        #     loss = F.relu(1.0 - h)
+        # else:
+        #     loss = F.relu(1.0 + h)

         return loss
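Note: the head changes from a global MLP score trained with a hinge loss to a 1-channel patch logit map trained with BCE-with-logits; style_emb no longer enters the score. A standalone sketch of the new objective, with h standing in for the conv_out logit map:

import torch
import torch.nn.functional as F

def patch_bce(h: torch.Tensor, for_G: bool = False, for_real: bool = False) -> torch.Tensor:
    if for_G:
        for_real = True  # the generator wants its outputs scored as real
    target = torch.full_like(h, 1.0 if for_real else 0.0)
    # per-patch BCE, averaged over the spatial grid -> one value per (sample, channel)
    return F.binary_cross_entropy_with_logits(h, target, reduction="none").mean(dim=[2, 3])

h = torch.randn(4, 1, 8, 8)               # (batch, 1, H', W') patch logits
print(patch_bce(h, for_real=True).shape)  # torch.Size([4, 1])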
swim/models/style_encoder.py
CHANGED
@@ -11,11 +11,15 @@ class StyleEncoder(nn.Module):
         self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
         self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])

-        self.fc = nn.
+        self.fc = nn.Sequential(
+            nn.Linear(512, 256),
+            nn.SiLU(),
+            nn.Linear(256, d_style_emb),
+        )

     def forward(self, x):
         # resize input to 224x224
-
+        x = F.interpolate(x, size=(224, 224), mode="bilinear")

         x = self.resnet(x)
         x = torch.flatten(x, 1)
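Note: putting the hunk together, the encoder now reads as below. The imports and the final return are filled in as assumptions (the hunk stops at the flatten):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights

class StyleEncoder(nn.Module):
    def __init__(self, d_style_emb: int = 128):
        super().__init__()
        # ResNet-18 backbone with its classification head stripped -> 512-d features
        self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])

        # new in this commit: an MLP head replaces the old one-line self.fc
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.SiLU(),
            nn.Linear(256, d_style_emb),
        )

    def forward(self, x):
        # resize input to 224x224, the resolution the backbone was pretrained at
        x = F.interpolate(x, size=(224, 224), mode="bilinear")
        x = self.resnet(x)
        x = torch.flatten(x, 1)
        return self.fc(x)  # assumed continuation after the hunk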
swim/models/swim_gan.py
CHANGED
@@ -9,7 +9,6 @@ from PIL import Image

 from lightning import LightningModule

-from diffusers import AutoencoderKL
 from diffusers.utils import make_image_grid

 from swim.utils.tensor_pool import GroupTensorPool
@@ -20,6 +19,7 @@ from .decoder import Decoder
 from .discriminator import FeatureDiscriminator

 import vision_aided_loss
+from lpips import LPIPS


 class SwimGAN(LightningModule):
@@ -29,21 +29,23 @@ class SwimGAN(LightningModule):
         channels: int = 128,
         z_c_channels: int = 512,
         updown_channel_mults: List[int] = [1, 2, 4],
-        n_enc_resnet_blocks: int =
+        n_enc_resnet_blocks: int = 6,
         n_dec_resnet_blocks: int = 6,
         n_f_d_resnet_blocks: int = 2,
         n_styles: int = 5,
         d_style_emb: int = 128,
         input_size: int = 512,
         learning_rate: float = 1e-5,
-        weight_decay: float = 1e-
+        weight_decay: float = 1e-2,
+        beta_1: float = 0.9,
+        beta_2: float = 0.999,
         lambda_cls: float = 10.0,
-        lambda_rec: float =
-        lambda_cycle: float =
-        lambda_f_g: float =
+        lambda_rec: float = 10.0,
+        lambda_cycle: float = 10.0,
+        lambda_f_g: float = 1.0,
         lambda_i_g: float = 1.0,
         lambda_c_const: float = 1.0,
-        lambda_s_const: float =
+        lambda_s_const: float = 1.0,
     ):
         super().__init__()

@@ -53,6 +55,8 @@ class SwimGAN(LightningModule):
         self.n_styles = n_styles
         self.learning_rate = learning_rate
         self.weight_decay = weight_decay
+        self.beta_1 = beta_1
+        self.beta_2 = beta_2
         self.lambda_rec = lambda_rec
         self.lambda_cycle = lambda_cycle
         self.lambda_cls = lambda_cls
@@ -82,7 +86,10 @@ class SwimGAN(LightningModule):

         # training only
         self.i_discriminator = vision_aided_loss.Discriminator(
-            cv_type="clip",
+            cv_type="clip",
+            num_classes=n_styles,
+            loss_type="multilevel_sigmoid_s",
+            device="cpu",
         )

         self.f_discriminator = FeatureDiscriminator(
@@ -95,6 +102,9 @@ class SwimGAN(LightningModule):
         self.style_pool = GroupTensorPool(n_styles, 32)
         self.content_pool = GroupTensorPool(n_styles, 32)

+        # self.lpips = LPIPS(net="vgg")
+        # self.lpips.requires_grad_(False)
+
         self.cls_loss = nn.CrossEntropyLoss()

     def on_fit_start(self):
@@ -108,7 +118,7 @@ class SwimGAN(LightningModule):

         g_opt, i_d_opt, f_d_opt = self.optimizers()

-        # train
+        # train g1
         z_s = self.style_encoder(x)
         z_c = self.content_encoder(x)
         x_rec = self.decoder(z_c, z_s)
@@ -116,11 +126,28 @@ class SwimGAN(LightningModule):
         style_logits = self.style_classifier(z_s)
         cls_loss = self.cls_loss(style_logits, gt_style)

-        rec_loss = F.l1_loss(x, x_rec)
+        rec_loss = F.l1_loss(x, x_rec)  # + self.lpips(x, x_rec).mean()
+
+        f_g_loss = self.f_discriminator(z_c, z_s, for_G=True).mean()
+
+        # g1_loss = (
+        #     self.lambda_rec * rec_loss
+        #     + self.lambda_f_g * f_g_loss
+        #     + self.lambda_cls * cls_loss
+        # )
+
+        # g1_opt.zero_grad()

-        #
+        # self.manual_backward(g1_loss)
+        # g1_opt.step()
+
+        # sample a random content and style feature
         z_c_hat, _ = self.content_pool.query(z_c, gt_style)
-        z_s_hat,
+        z_s_hat, gt_style_hat = self.style_pool.query(z_s, gt_style)
+
+        # train g2
+        z_c = z_c.detach()
+        z_s = z_s.detach()

         x1 = self.decoder(z_c, z_s_hat)
         x2 = self.decoder(z_c_hat, z_s)
@@ -128,29 +155,30 @@ class SwimGAN(LightningModule):
         z_c_rec = self.content_encoder(x1)
         z_s_rec = self.style_encoder(x2)

+        z_c_hat_rec = self.content_encoder(x2)
+        z_s_hat_rec = self.style_encoder(x1)
+
         x_cycle = self.decoder(z_c_rec, z_s_rec)

-        cycle_loss = F.l1_loss(x, x_cycle)
-        s_const_loss = F.l1_loss(z_s, z_s_rec)
+        cycle_loss = F.l1_loss(x, x_cycle)  # + self.lpips(x, x_cycle).mean()

-        c_const_loss = F.l1_loss(z_c, z_c_rec)
+        c_const_loss = F.l1_loss(z_c, z_c_rec) + F.l1_loss(z_c_hat, z_c_hat_rec)
+        s_const_loss = F.l1_loss(z_s, z_s_rec) + F.l1_loss(z_s_hat, z_s_hat_rec)

-        # adversarial loss
         i_g_loss = (
-            self.i_discriminator(x1, for_G=True).mean()
-            + self.i_discriminator(x2, for_G=True).mean()
-
-
-        c_g_loss = self.f_discriminator(z_c, z_s, for_G=True).mean()
+            self.i_discriminator(x1, gt_style_hat, for_G=True).mean()
+            + self.i_discriminator(x2, gt_style, for_G=True).mean()
+            + self.i_discriminator(x_cycle, gt_style, for_G=True).mean()
+        ) / 3.0

         g_loss = (
             self.lambda_rec * rec_loss
+            + self.lambda_f_g * f_g_loss
+            + self.lambda_cls * cls_loss
             + self.lambda_cycle * cycle_loss
             + self.lambda_c_const * c_const_loss
             + self.lambda_s_const * s_const_loss
             + self.lambda_i_g * i_g_loss
-            + self.lambda_f_g * c_g_loss
-            + self.lambda_cls * cls_loss
         )

         g_opt.zero_grad()
@@ -158,14 +186,19 @@ class SwimGAN(LightningModule):
         g_opt.step()

         # train the image discriminator
+        x1 = x1.detach()
+        x2 = x2.detach()
+        x_cycle = x_cycle.detach()
+
         i_d_loss = (
-            self.i_discriminator(x, for_real=True).mean()
+            self.i_discriminator(x, gt_style, for_real=True).mean()
             + (
-                self.i_discriminator(x1
-                + self.i_discriminator(x2
+                self.i_discriminator(x1, gt_style_hat, for_real=False).mean()
+                + self.i_discriminator(x2, gt_style, for_real=False).mean()
+                + self.i_discriminator(x_cycle, gt_style, for_real=False).mean()
             )
-            /
-        ) / 2
+            / 3.0
+        ) / 2.0

         i_d_opt.zero_grad()
         self.manual_backward(i_d_loss)
@@ -173,10 +206,10 @@ class SwimGAN(LightningModule):

         # train the feature discriminator
         f_d_loss = (
-            self.f_discriminator(z_c
+            self.f_discriminator(z_c, z_s, for_real=False).mean()
             + (
-                self.f_discriminator(z_c
-                + self.f_discriminator(z_c_hat, z_s
+                self.f_discriminator(z_c, z_s_hat, for_real=True).mean()
+                + self.f_discriminator(z_c_hat, z_s, for_real=True).mean()
             )
             / 2
         ) / 2
@@ -191,8 +224,8 @@ class SwimGAN(LightningModule):
         "train/cycle_loss": cycle_loss,
         "train/cls_loss": cls_loss,
         "train/i_g_loss": i_g_loss,
-        "train/i_d_loss":
-        "train/f_g_loss":
+        "train/i_d_loss": i_d_loss,
+        "train/f_g_loss": f_g_loss,
         "train/f_d_loss": f_d_loss,
         "train/c_const_loss": c_const_loss,
         "train/s_const_loss": s_const_loss,
@@ -257,31 +290,38 @@ class SwimGAN(LightningModule):
         return x

     def configure_optimizers(self):
-        g_opt = torch.optim.AdamW(
+        g1_opt = torch.optim.AdamW(
             [
                 {"params": self.content_encoder.parameters()},
-                {"params": self.style_encoder.
-                {
-                    "params": self.style_encoder.resnet.parameters(),
-                    "lr": self.learning_rate / 10,
-                },
+                {"params": self.style_encoder.parameters()},
                 {"params": self.style_classifier.parameters()},
                 {"params": self.decoder.parameters()},
             ],
             lr=self.learning_rate,
             weight_decay=self.weight_decay,
+            betas=(self.beta_1, self.beta_2),
         )

+        # g2_opt = torch.optim.AdamW(
+        #     [
+        #         {"params": self.decoder.parameters()},
+        #     ],
+        #     lr=self.learning_rate,
+        #     weight_decay=self.weight_decay,
+        # )
+
        i_d_opt = torch.optim.AdamW(
             list(self.i_discriminator.parameters()),
             lr=self.learning_rate,
             weight_decay=self.weight_decay,
+            betas=(self.beta_1, self.beta_2),
         )

         f_d_opt = torch.optim.AdamW(
             list(self.f_discriminator.parameters()),
-            lr=self.learning_rate *
+            lr=self.learning_rate * 10,
             weight_decay=self.weight_decay,
+            betas=(self.beta_1, self.beta_2),
         )

-        return [
+        return [g1_opt, i_d_opt, f_d_opt]
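Note: training_step drives all optimizers by hand (zero_grad / manual_backward / step) and detaches x1, x2, x_cycle before the discriminator passes. A toy two-optimizer sketch of that Lightning manual-optimization pattern, assuming automatic_optimization is disabled in __init__ (Lightning requires this for manual_backward; the commit itself uses three optimizers):

import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning import LightningModule

class TinyManualGAN(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # prerequisite for manual_backward
        self.gen = nn.Linear(8, 8)
        self.disc = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        g_opt, d_opt = self.optimizers()
        ones = torch.ones(batch.size(0), 1, device=batch.device)
        zeros = torch.zeros_like(ones)

        # generator step: only g_opt.step() runs, so only generator weights move
        fake = self.gen(batch)
        g_loss = F.binary_cross_entropy_with_logits(self.disc(fake), ones)
        g_opt.zero_grad()
        self.manual_backward(g_loss)
        g_opt.step()

        # discriminator step: detach the fakes, as the commit does with
        # x1 / x2 / x_cycle, so no gradient flows back into the generator
        fake = fake.detach()
        d_loss = (
            F.binary_cross_entropy_with_logits(self.disc(batch), ones)
            + F.binary_cross_entropy_with_logits(self.disc(fake), zeros)
        ) / 2.0
        d_opt.zero_grad()
        self.manual_backward(d_loss)
        d_opt.step()

    def configure_optimizers(self):
        g_opt = torch.optim.AdamW(self.gen.parameters(), lr=1e-5, weight_decay=1e-2)
        d_opt = torch.optim.AdamW(self.disc.parameters(), lr=1e-5, weight_decay=1e-2)
        return [g_opt, d_opt]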
swim/train.py
CHANGED
@@ -88,7 +88,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

     if cfg.get("train"):
         if cfg.compile:
-            model.compile()
+            model = torch.compile(model)

         log.info("Starting training!")
         trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
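Note: torch.compile(model) returns a new OptimizedModule that wraps the original, so the name has to be rebound, whereas nn.Module.compile() (the previous call) compiles in place but only exists in newer PyTorch releases. A quick illustration, assuming PyTorch 2.x:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
model = torch.compile(model)  # rebinding, as the commit now does
print(type(model).__name__)   # OptimizedModule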
swim/utils/tensor_pool.py
CHANGED
@@ -28,6 +28,10 @@ class TensorPool(object):
         return torch.stack(return_tensors)


+import torch
+from random import choice, randint
+
+
 class GroupTensorPool(object):

     def __init__(self, n_groups: int = 5, pool_size: int = 32):
@@ -40,25 +44,34 @@ class GroupTensorPool(object):
         return_groups = []

         tensors = tensors.detach().clone()
+
         for tensor, curr_group in zip(tensors, labels):
-
-
+            self.save_new_tensor(tensor, curr_group.item())
+
+        for tensor, curr_group in zip(tensors, labels):
+            # Choose a random group except the current one
             group = choice([i for i in range(self.n_groups) if i != curr_group.item()])
             pool = self.pools[group]

             if len(pool) == 0:
+                # If the selected group pool is empty, return the current tensor
                 return_tensors.append(tensor)
                 return_groups.append(curr_group)
-                pool.append((tensor, curr_group))
-            elif len(pool) < self.pool_size:
-                idx = randint(0, len(pool) - 1)
-                return_tensors.append(pool[idx][0])
-                return_groups.append(pool[idx][1])
-                pool.append((tensor, curr_group))
             else:
-
-
-
-
+                # Otherwise, select a random tensor from the pool
+                random_tensor = choice(pool)
+                return_tensors.append(random_tensor)
+                return_groups.append(torch.tensor(group))

         return torch.stack(return_tensors), torch.stack(return_groups)
+
+    def save_new_tensor(self, tensor: torch.Tensor, group: int):
+        pool = self.pools[group]
+
+        if len(pool) < self.pool_size:
+            # If the pool is not full, append the tensor
+            pool.append(tensor)
+        else:
+            # Replace a random item in the pool with the new tensor
+            replace_idx = randint(0, len(pool) - 1)
+            pool[replace_idx] = tensor
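Note: query() now banks every incoming tensor into its own group's pool first, then returns, per sample, a random tensor drawn from a different group (falling back to the input while that pool is still empty), much like the image-history buffer used in CycleGAN-style training. A small usage sketch, assuming labels is a 1-D tensor of group ids:

import torch
from swim.utils.tensor_pool import GroupTensorPool

pool = GroupTensorPool(n_groups=5, pool_size=32)

z_s = torch.randn(4, 128)              # e.g. style embeddings, one per sample
gt_style = torch.tensor([0, 1, 2, 3])  # group id per sample

# z_s_hat[i] comes from another group; gt_style_hat[i] reports which one
# (or the sample's own group on the empty-pool fallback)
z_s_hat, gt_style_hat = pool.query(z_s, gt_style)
print(z_s_hat.shape, gt_style_hat.shape)  # torch.Size([4, 128]) torch.Size([4])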
train_swim.sh
CHANGED
@@ -3,7 +3,7 @@ python swim/train.py \
     data.root_dir=/cm/shared/ninhnq3/datasets/swim_data \
     logger=wandb \
     logger.wandb.save_dir=/cm/shared/ninhnq3/workdirs/swim \
-    +trainer.val_check_interval=0.
+    +trainer.val_check_interval=0.05 \
     +trainer.limit_val_batches=0.01 \
     callbacks.model_checkpoint.dirpath=/cm/shared/ninhnq3/checkpoints/swim_final
