qninhdt committed on
Commit 8f603e4 · 1 Parent(s): 9acdf9b
configs/experiment/ch_64.yaml CHANGED
@@ -10,15 +10,20 @@ seed: 42
 
 trainer:
   max_epochs: 100
+  # precision: 16-mixed
 
 model:
-  channels: 64
-  z_c_channels: 256
-  updown_channel_mults: [1, 2, 4]
-  n_enc_resnet_blocks: 4
-  n_dec_resnet_blocks: 6
+  channels: 128
+  z_c_channels: 512
+  updown_channel_mults: [1, 2, 2, 4]
+  n_enc_resnet_blocks: 8
+  n_dec_resnet_blocks: 8
   n_f_d_resnet_blocks: 4
-  learning_rate: 1e-4
+  learning_rate: 1e-5
+  weight_decay: 1e-2
+  beta_1: 0.9
+  beta_2: 0.999
 
 data:
   batch_size: 4
+
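Note (not part of the commit): the new values in the model block line up one-to-one with the updated SwimGAN constructor parameters later in this commit. A hedged sketch of the equivalent direct instantiation, assuming the repo's import layout:

# Sketch only: instantiate SwimGAN with the values this experiment config sets.
from swim.models.swim_gan import SwimGAN  # import path assumed from the repo layout

model = SwimGAN(
    channels=128,
    z_c_channels=512,
    updown_channel_mults=[1, 2, 2, 4],
    n_enc_resnet_blocks=8,
    n_dec_resnet_blocks=8,
    n_f_d_resnet_blocks=4,
    learning_rate=1e-5,
    weight_decay=1e-2,
    beta_1=0.9,
    beta_2=0.999,
)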
swim/data/swim_data.py CHANGED
@@ -26,9 +26,14 @@ class SwimDataset(Dataset):
         if split == "train":
             self.transform = T.Compose(
                 [
-                    # T.Resize(self.img_size),
-                    # T.RandomCrop(self.img_size),
-                    T.RandomResizedCrop(self.img_size, scale=(0.5, 1.0)),
+                    T.Resize(self.img_size),
+                    T.RandomCrop(self.img_size),
+                    # T.RandomResizedCrop(
+                    #     self.img_size,
+                    #     scale=(0.5, 1.0),
+                    #     ratio=(1.0, 1.0),
+                    #     interpolation=T.InterpolationMode.LANCZOS,
+                    # ),
                     T.RandomHorizontalFlip(),
                     T.ToTensor(),
                     T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
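Note (not part of the commit): with a single int, T.Resize scales the shorter side while preserving aspect ratio, and RandomCrop then takes a square patch, so the new pipeline drops the scale jitter of RandomResizedCrop. A minimal standalone sketch, with img_size assumed to be 512:

import torchvision.transforms as T

img_size = 512  # assumed; the dataset's actual attribute is self.img_size
transform = T.Compose([
    T.Resize(img_size),      # shorter side -> img_size, aspect ratio preserved
    T.RandomCrop(img_size),  # square img_size x img_size patch
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])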
swim/models/discriminator.py CHANGED
@@ -73,17 +73,7 @@ class FeatureDiscriminator(nn.Module):
         for _ in range(n_resnet_blocks):
             self.resnet_blocks.append(SNResnetBlock(channels, channels, d_style_emb))
 
-        self.conv_out = spectral_norm(
-            nn.Conv2d(channels, z_c_channels, 3, stride=1, padding=1)
-        )
-
-        self.mlp = nn.Sequential(
-            spectral_norm(nn.Linear(z_c_channels + d_style_emb, 256)),
-            nn.LeakyReLU(0.2),
-            spectral_norm(nn.Linear(256, 128)),
-            nn.LeakyReLU(0.2),
-            spectral_norm(nn.Linear(128, 1)),
-        )
+        self.conv_out = spectral_norm(nn.Conv2d(channels, 1, 3, stride=1, padding=1))
 
     def forward(
         self, x: torch.Tensor, style_emb: torch.Tensor, for_G=False, for_real=False
@@ -96,25 +86,25 @@
         h = F.avg_pool2d(h, 2)
 
         h = self.conv_out(h)
-        h = F.leaky_relu(h, 0.2)
 
-        h = F.adaptive_avg_pool2d(h, 1)
-        h = torch.flatten(h, 1)
-        h = self.mlp(torch.concat([h, style_emb], dim=1))
+        if for_G:
+            for_real = True
 
-        # if for_real:
-        #     target = torch.full_like(h, 1.0)
-        # else:
-        #     target = torch.zeros_like(h)
+        if for_real:
+            target = torch.full_like(h, 1.0)
+        else:
+            target = torch.full_like(h, 0.0)
 
-        # loss = F.binary_cross_entropy_with_logits(h, target, reduction="none")
+        loss = F.binary_cross_entropy_with_logits(h, target, reduction="none").mean(
+            dim=[2, 3]
+        )
 
         # hinge loss
-        if for_G:
-            loss = -h
-        elif for_real:
-            loss = F.relu(1.0 - h)
-        else:
-            loss = F.relu(1.0 + h)
+        # if for_G:
+        #     loss = -h
+        # elif for_real:
+        #     loss = F.relu(1.0 - h)
+        # else:
+        #     loss = F.relu(1.0 + h)
 
         return loss
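Note (not part of the commit): this switches FeatureDiscriminator from a global hinge loss over an MLP head to a patchwise non-saturating BCE loss on a 1-channel logit map, with for_G reusing the real-target branch. A minimal sketch of the new objective's shape behaviour, with hypothetical values:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 1, 8, 8)       # assumed conv_out output: (B, 1, H, W)
target = torch.full_like(logits, 1.0)  # "real" targets, as when for_real=True
loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none").mean(
    dim=[2, 3]
)
print(loss.shape)  # torch.Size([4, 1]): one patch-averaged loss per sample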
swim/models/style_encoder.py CHANGED
@@ -11,11 +11,15 @@ class StyleEncoder(nn.Module):
         self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
         self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])
 
-        self.fc = nn.Linear(512, d_style_emb)
+        self.fc = nn.Sequential(
+            nn.Linear(512, 256),
+            nn.SiLU(),
+            nn.Linear(256, d_style_emb),
+        )
 
     def forward(self, x):
         # resize input to 224x224
-        # x = F.interpolate(x, size=(224, 224), mode="bilinear")
+        x = F.interpolate(x, size=(224, 224), mode="bilinear")
 
         x = self.resnet(x)
         x = torch.flatten(x, 1)
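Note (not part of the commit): the commit also enables the 224×224 resize, so the ResNet-18 trunk always sees ImageNet-sized inputs. A hedged shape check of the updated head, assuming d_style_emb = 128:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18

trunk = nn.Sequential(*list(resnet18(weights=None).children())[:-1])  # weights=None skips the download
fc = nn.Sequential(nn.Linear(512, 256), nn.SiLU(), nn.Linear(256, 128))

x = torch.randn(2, 3, 512, 512)
x = F.interpolate(x, size=(224, 224), mode="bilinear")  # the newly enabled resize
feats = torch.flatten(trunk(x), 1)  # global-avg-pooled features: (2, 512)
print(fc(feats).shape)              # torch.Size([2, 128])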
swim/models/swim_gan.py CHANGED
@@ -9,7 +9,6 @@ from PIL import Image
 
 from lightning import LightningModule
 
-from diffusers import AutoencoderKL
 from diffusers.utils import make_image_grid
 
 from swim.utils.tensor_pool import GroupTensorPool
@@ -20,6 +19,7 @@ from .decoder import Decoder
 from .discriminator import FeatureDiscriminator
 
 import vision_aided_loss
+from lpips import LPIPS
 
 
 class SwimGAN(LightningModule):
@@ -29,21 +29,23 @@ class SwimGAN(LightningModule):
         channels: int = 128,
         z_c_channels: int = 512,
         updown_channel_mults: List[int] = [1, 2, 4],
-        n_enc_resnet_blocks: int = 4,
+        n_enc_resnet_blocks: int = 6,
         n_dec_resnet_blocks: int = 6,
         n_f_d_resnet_blocks: int = 2,
         n_styles: int = 5,
         d_style_emb: int = 128,
         input_size: int = 512,
         learning_rate: float = 1e-5,
-        weight_decay: float = 1e-4,
+        weight_decay: float = 1e-2,
+        beta_1: float = 0.9,
+        beta_2: float = 0.999,
         lambda_cls: float = 10.0,
-        lambda_rec: float = 1.0,
-        lambda_cycle: float = 1.0,
-        lambda_f_g: float = 10.0,
+        lambda_rec: float = 10.0,
+        lambda_cycle: float = 10.0,
+        lambda_f_g: float = 1.0,
         lambda_i_g: float = 1.0,
         lambda_c_const: float = 1.0,
-        lambda_s_const: float = 10.0,
+        lambda_s_const: float = 1.0,
     ):
         super().__init__()
 
@@ -53,6 +55,8 @@ class SwimGAN(LightningModule):
         self.n_styles = n_styles
         self.learning_rate = learning_rate
         self.weight_decay = weight_decay
+        self.beta_1 = beta_1
+        self.beta_2 = beta_2
         self.lambda_rec = lambda_rec
         self.lambda_cycle = lambda_cycle
         self.lambda_cls = lambda_cls
@@ -82,7 +86,10 @@
 
         # training only
         self.i_discriminator = vision_aided_loss.Discriminator(
-            cv_type="clip", loss_type="multilevel_sigmoid", device="cpu"
+            cv_type="clip",
+            num_classes=n_styles,
+            loss_type="multilevel_sigmoid_s",
+            device="cpu",
         )
 
         self.f_discriminator = FeatureDiscriminator(
@@ -95,6 +102,9 @@
         self.style_pool = GroupTensorPool(n_styles, 32)
         self.content_pool = GroupTensorPool(n_styles, 32)
 
+        # self.lpips = LPIPS(net="vgg")
+        # self.lpips.requires_grad_(False)
+
         self.cls_loss = nn.CrossEntropyLoss()
 
     def on_fit_start(self):
@@ -108,7 +118,7 @@
 
         g_opt, i_d_opt, f_d_opt = self.optimizers()
 
-        # train the autoencoder
+        # train g1
        z_s = self.style_encoder(x)
         z_c = self.content_encoder(x)
         x_rec = self.decoder(z_c, z_s)
@@ -116,11 +126,28 @@
         style_logits = self.style_classifier(z_s)
         cls_loss = self.cls_loss(style_logits, gt_style)
 
-        rec_loss = F.l1_loss(x, x_rec)
+        rec_loss = F.l1_loss(x, x_rec)  # + self.lpips(x, x_rec).mean()
+
+        f_g_loss = self.f_discriminator(z_c, z_s, for_G=True).mean()
+
+        # g1_loss = (
+        #     self.lambda_rec * rec_loss
+        #     + self.lambda_f_g * f_g_loss
+        #     + self.lambda_cls * cls_loss
+        # )
+
+        # g1_opt.zero_grad()
 
-        # # sample a random content and style feature
+        # self.manual_backward(g1_loss)
+        # g1_opt.step()
+
+        # sample a random content and style feature
         z_c_hat, _ = self.content_pool.query(z_c, gt_style)
-        z_s_hat, _ = self.style_pool.query(z_s, gt_style)
+        z_s_hat, gt_style_hat = self.style_pool.query(z_s, gt_style)
+
+        # train g2
+        z_c = z_c.detach()
+        z_s = z_s.detach()
 
         x1 = self.decoder(z_c, z_s_hat)
         x2 = self.decoder(z_c_hat, z_s)
@@ -128,29 +155,30 @@
         z_c_rec = self.content_encoder(x1)
         z_s_rec = self.style_encoder(x2)
 
+        z_c_hat_rec = self.content_encoder(x2)
+        z_s_hat_rec = self.style_encoder(x1)
+
         x_cycle = self.decoder(z_c_rec, z_s_rec)
 
-        c_const_loss = F.l1_loss(z_c, z_c_rec)
-        s_const_loss = F.l1_loss(z_s, z_s_rec)
+        cycle_loss = F.l1_loss(x, x_cycle)  # + self.lpips(x, x_cycle).mean()
 
-        cycle_loss = F.l1_loss(x, x_cycle)
+        c_const_loss = F.l1_loss(z_c, z_c_rec) + F.l1_loss(z_c_hat, z_c_hat_rec)
+        s_const_loss = F.l1_loss(z_s, z_s_rec) + F.l1_loss(z_s_hat, z_s_hat_rec)
 
-        # adversarial loss
         i_g_loss = (
-            self.i_discriminator(x1, for_G=True).mean()
-            + self.i_discriminator(x2, for_G=True).mean()
-        ) / 2
-
-        c_g_loss = self.f_discriminator(z_c, z_s, for_G=True).mean()
+            self.i_discriminator(x1, gt_style_hat, for_G=True).mean()
+            + self.i_discriminator(x2, gt_style, for_G=True).mean()
+            + self.i_discriminator(x_cycle, gt_style, for_G=True).mean()
+        ) / 3.0
 
         g_loss = (
             self.lambda_rec * rec_loss
+            + self.lambda_f_g * f_g_loss
+            + self.lambda_cls * cls_loss
             + self.lambda_cycle * cycle_loss
            + self.lambda_c_const * c_const_loss
             + self.lambda_s_const * s_const_loss
             + self.lambda_i_g * i_g_loss
-            + self.lambda_f_g * c_g_loss
-            + self.lambda_cls * cls_loss
         )
 
         g_opt.zero_grad()
@@ -158,14 +186,19 @@
         g_opt.step()
 
         # train the image discriminator
+        x1 = x1.detach()
+        x2 = x2.detach()
+        x_cycle = x_cycle.detach()
+
         i_d_loss = (
-            self.i_discriminator(x, for_real=True).mean()
+            self.i_discriminator(x, gt_style, for_real=True).mean()
             + (
-                self.i_discriminator(x1.detach(), for_real=False).mean()
-                + self.i_discriminator(x2.detach(), for_real=False).mean()
+                self.i_discriminator(x1, gt_style_hat, for_real=False).mean()
+                + self.i_discriminator(x2, gt_style, for_real=False).mean()
+                + self.i_discriminator(x_cycle, gt_style, for_real=False).mean()
             )
-            / 2
-        ) / 2
+            / 3.0
+        ) / 2.0
 
         i_d_opt.zero_grad()
         self.manual_backward(i_d_loss)
@@ -173,10 +206,10 @@
 
         # train the feature discriminator
         f_d_loss = (
-            self.f_discriminator(z_c.detach(), z_s.detach(), for_real=False).mean()
+            self.f_discriminator(z_c, z_s, for_real=False).mean()
             + (
-                self.f_discriminator(z_c.detach(), z_s_hat, for_real=True).mean()
-                + self.f_discriminator(z_c_hat, z_s.detach(), for_real=True).mean()
+                self.f_discriminator(z_c, z_s_hat, for_real=True).mean()
+                + self.f_discriminator(z_c_hat, z_s, for_real=True).mean()
             )
             / 2
         ) / 2
@@ -191,8 +224,8 @@
             "train/cycle_loss": cycle_loss,
             "train/cls_loss": cls_loss,
             "train/i_g_loss": i_g_loss,
-            "train/i_d_loss": f_d_loss,
-            "train/f_g_loss": c_g_loss,
+            "train/i_d_loss": i_d_loss,
+            "train/f_g_loss": f_g_loss,
             "train/f_d_loss": f_d_loss,
             "train/c_const_loss": c_const_loss,
             "train/s_const_loss": s_const_loss,
@@ -257,31 +290,38 @@
         return x
 
     def configure_optimizers(self):
-        g_opt = torch.optim.AdamW(
+        g1_opt = torch.optim.AdamW(
             [
                 {"params": self.content_encoder.parameters()},
-                {"params": self.style_encoder.fc.parameters()},
-                {
-                    "params": self.style_encoder.resnet.parameters(),
-                    "lr": self.learning_rate / 10,
-                },
+                {"params": self.style_encoder.parameters()},
                 {"params": self.style_classifier.parameters()},
                 {"params": self.decoder.parameters()},
             ],
             lr=self.learning_rate,
             weight_decay=self.weight_decay,
+            betas=(self.beta_1, self.beta_2),
         )
 
+        # g2_opt = torch.optim.AdamW(
+        #     [
+        #         {"params": self.decoder.parameters()},
+        #     ],
+        #     lr=self.learning_rate,
+        #     weight_decay=self.weight_decay,
+        # )
+
         i_d_opt = torch.optim.AdamW(
             list(self.i_discriminator.parameters()),
             lr=self.learning_rate,
             weight_decay=self.weight_decay,
+            betas=(self.beta_1, self.beta_2),
         )
 
         f_d_opt = torch.optim.AdamW(
             list(self.f_discriminator.parameters()),
-            lr=self.learning_rate * 2,
+            lr=self.learning_rate * 10,
             weight_decay=self.weight_decay,
+            betas=(self.beta_1, self.beta_2),
         )
 
-        return [g_opt, i_d_opt, f_d_opt]
+        return [g1_opt, i_d_opt, f_d_opt]
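Note (not part of the commit): training_step steps three optimizers by hand, which in Lightning requires manual optimization. A self-contained sketch of that pattern with toy modules (all names hypothetical, not the author's code):

import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning import LightningModule


class ManualGANSketch(LightningModule):
    def __init__(self):
        super().__init__()
        # Required so self.optimizers() returns raw optimizers to step by hand.
        self.automatic_optimization = False
        self.gen = nn.Linear(8, 8)
        self.disc = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        g_opt, d_opt = self.optimizers()

        # Generator update (non-saturating loss, like the BCE path above).
        fake = self.gen(batch)
        g_loss = F.softplus(-self.disc(fake)).mean()
        g_opt.zero_grad()
        self.manual_backward(g_loss)
        g_opt.step()

        # Discriminator update on detached fakes, mirroring the detach() calls.
        d_loss = (
            F.softplus(-self.disc(batch)).mean()
            + F.softplus(self.disc(fake.detach())).mean()
        ) / 2
        d_opt.zero_grad()
        self.manual_backward(d_loss)
        d_opt.step()

    def configure_optimizers(self):
        g_opt = torch.optim.AdamW(self.gen.parameters(), lr=1e-5, betas=(0.9, 0.999))
        d_opt = torch.optim.AdamW(self.disc.parameters(), lr=1e-5, betas=(0.9, 0.999))
        return [g_opt, d_opt]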
swim/train.py CHANGED
@@ -88,7 +88,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 
     if cfg.get("train"):
         if cfg.compile:
-            model.compile()
+            model = torch.compile(model)
 
         log.info("Starting training!")
         trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
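Note (not part of the commit): unlike the in-place nn.Module.compile() it replaces, torch.compile returns a new wrapper module, so rebinding `model` here matters; trainer.fit must receive the returned object. A minimal sketch with a hypothetical module:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)  # returns a wrapper around `model`
print(type(compiled).__name__)   # e.g. "OptimizedModule"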
swim/utils/tensor_pool.py CHANGED
@@ -28,6 +28,10 @@ class TensorPool(object):
         return torch.stack(return_tensors)
 
 
+import torch
+from random import choice, randint
+
+
 class GroupTensorPool(object):
 
     def __init__(self, n_groups: int = 5, pool_size: int = 32):
@@ -40,25 +44,34 @@
         return_groups = []
 
         tensors = tensors.detach().clone()
+
         for tensor, curr_group in zip(tensors, labels):
-            # choose a random group except the current one
-            # if the pool is empty, return the current tensor
+            self.save_new_tensor(tensor, curr_group.item())
+
+        for tensor, curr_group in zip(tensors, labels):
+            # Choose a random group except the current one
             group = choice([i for i in range(self.n_groups) if i != curr_group.item()])
             pool = self.pools[group]
 
             if len(pool) == 0:
+                # If the selected group pool is empty, return the current tensor
                 return_tensors.append(tensor)
                 return_groups.append(curr_group)
-                pool.append((tensor, curr_group))
-            elif len(pool) < self.pool_size:
-                idx = randint(0, len(pool) - 1)
-                return_tensors.append(pool[idx][0])
-                return_groups.append(pool[idx][1])
-                pool.append((tensor, curr_group))
             else:
-                idx = randint(0, self.pool_size - 1)
-                return_tensors.append(pool[idx][0])
-                return_groups.append(pool[idx][1])
-                pool[idx] = (tensor, curr_group)
+                # Otherwise, select a random tensor from the pool
+                random_tensor = choice(pool)
+                return_tensors.append(random_tensor)
+                return_groups.append(torch.tensor(group))
 
         return torch.stack(return_tensors), torch.stack(return_groups)
+
+    def save_new_tensor(self, tensor: torch.Tensor, group: int):
+        pool = self.pools[group]
+
+        if len(pool) < self.pool_size:
+            # If the pool is not full, append the tensor
+            pool.append(tensor)
+        else:
+            # Replace a random item in the pool with the new tensor
+            replace_idx = randint(0, len(pool) - 1)
+            pool[replace_idx] = tensor
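Note (not part of the commit): the reworked pool now banks every incoming tensor into its own group first, then answers each query with a random tensor from a different group (falling back to the item itself when that group's pool is still empty). A hypothetical usage sketch, assuming the repo is on PYTHONPATH:

import torch
from swim.utils.tensor_pool import GroupTensorPool

pool = GroupTensorPool(n_groups=5, pool_size=32)
z_s = torch.randn(4, 128)            # e.g. a batch of style codes
labels = torch.tensor([0, 1, 2, 3])  # ground-truth style index per item
z_s_hat, styles_hat = pool.query(z_s, labels)
assert z_s_hat.shape == z_s.shape    # one swapped code per input item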
train_swim.sh CHANGED
@@ -3,7 +3,7 @@ python swim/train.py \
     data.root_dir=/cm/shared/ninhnq3/datasets/swim_data \
     logger=wandb \
     logger.wandb.save_dir=/cm/shared/ninhnq3/workdirs/swim \
-    +trainer.val_check_interval=0.02 \
+    +trainer.val_check_interval=0.05 \
     +trainer.limit_val_batches=0.01 \
     callbacks.model_checkpoint.dirpath=/cm/shared/ninhnq3/checkpoints/swim_final
 