qninhdt committed on
Commit b759b90 · 1 Parent(s): d2c8fab
cc.py CHANGED
@@ -4,18 +4,24 @@ from torchinfo import summary
 from swim.models.content_encoder import ContentEncoder
 from swim.models.decoder import Decoder
 from swim.models.style_encoder import StyleEncoder
-from swim.models.discriminator import Discriminator
-from swim.models.swim_gan import SwimGAN
+from swim.models.discriminator import FeatureDiscriminator
+import vision_aided_loss
+
+# from swim.models.swim_gan import SwimGAN
 
-model = SwimGAN().cuda()
 image = torch.randn(1, 3, 512, 512).to("cuda")
-sample = torch.randn(1, 4, 64, 64).to("cuda")
 style_emb = torch.randn(1, 256).to("cuda")
+content = torch.randn(1, 512, 64, 64).to("cuda")
+content_encoder = ContentEncoder().cuda()
+decoder = Decoder().cuda()
+style_encoder = StyleEncoder().cuda()
+discriminator = FeatureDiscriminator().cuda()
+# i_discriminator = vision_aided_loss.Discriminator(
+#     cv_type="clip", loss_type="multilevel_sigmoid_s", device="cuda"
+# ).to("cuda")
 
-# summary(content_encoder, input_data=(sample,))
-# summary(decoder, input_data=(sample, style_emb))
-# summary(style_encoder, input_data=sample)
-summary(
-    model,
-    input_data=(image),
-)
+summary(content_encoder, input_data=(image,))
+# summary(decoder, input_data=(content, style_emb))
+# summary(style_encoder, input_data=image)
+# summary(discriminator, input_data=(content, style_emb))
+# summary(i_discriminator, input_data=(image,))
 
 
configs/experiment/channel=64.yaml ADDED
@@ -0,0 +1,22 @@
+# @package _global_
+
+defaults:
+  - override /data: swim
+  - override /model: swim_gan
+  - override /callbacks: default
+  - override /trainer: gpu
+
+seed: 42
+
+trainer:
+  max_epochs: 100
+
+model:
+  channels: 64
+  z_c_channels: 256
+  n_enc_resnet_blocks: 4
+  n_dec_resnet_blocks: 6
+  n_f_d_resnet_blocks: 2
+
+data:
+  batch_size: 2
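Following the launch convention in the deleted example.yaml below (# python train.py experiment=example), these new experiment configs would presumably be selected with the experiment override; because this file name itself contains an equals sign, the override value may need quoting. Hypothetical commands, assuming the same entry point:

python train.py 'experiment=channel=64'
python train.py experiment=potato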
configs/experiment/example.yaml DELETED
@@ -1,41 +0,0 @@
-# @package _global_
-
-# to execute this experiment run:
-# python train.py experiment=example
-
-defaults:
-  - override /data: mnist
-  - override /model: mnist
-  - override /callbacks: default
-  - override /trainer: default
-
-# all parameters below will be merged with parameters from default configurations set above
-# this allows you to overwrite only specified parameters
-
-tags: ["mnist", "simple_dense_net"]
-
-seed: 12345
-
-trainer:
-  min_epochs: 10
-  max_epochs: 10
-  gradient_clip_val: 0.5
-
-model:
-  optimizer:
-    lr: 0.002
-  net:
-    lin1_size: 128
-    lin2_size: 256
-    lin3_size: 64
-  compile: false
-
-data:
-  batch_size: 64
-
-logger:
-  wandb:
-    tags: ${tags}
-    group: "mnist"
-  aim:
-    experiment: "mnist"
configs/experiment/potato.yaml ADDED
@@ -0,0 +1,21 @@
+# @package _global_
+
+defaults:
+  - override /data: swim
+  - override /model: swim_gan
+  - override /callbacks: default
+  - override /trainer: gpu
+
+seed: 42
+
+trainer:
+  max_epochs: 100
+
+model:
+  channels: 32
+  z_c_channels: 128
+  n_enc_resnet_blocks: 1
+  n_dec_resnet_blocks: 1
+
+data:
+  batch_size: 2
configs/model/swim_gan.yaml CHANGED
@@ -1,3 +1 @@
 _target_: swim.models.swim_gan.SwimGAN
-
-learning_rate: 1e-4
swim/models/blocks.py CHANGED
@@ -77,11 +77,9 @@ class DownSample(nn.Module):
 class ResnetBlock(nn.Module):
     def __init__(self, in_channels: int, out_channels: int, cond_channels: int = 0):
         super().__init__()
-        # First normalization and convolution layer
-        self.norm1 = normalization(in_channels)
         self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
+        self.norm1 = normalization(out_channels)
 
-        # cond layer
         self.cond_channels = cond_channels
         if cond_channels > 0:
             self.cond_proj = nn.Sequential(
@@ -90,40 +88,32 @@ class ResnetBlock(nn.Module):
                 nn.Linear(cond_channels, out_channels * 2),
             )
 
-        # Second normalization and convolution layer
-        self.norm2 = normalization(out_channels)
         self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
-        # `in_channels` to `out_channels` mapping layer for residual connection
+        self.norm2 = normalization(out_channels)
+
         if in_channels != out_channels:
-            self.nin_shortcut = nn.Conv2d(
-                in_channels, out_channels, 1, stride=1, padding=0
-            )
+            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
         else:
-            self.nin_shortcut = nn.Identity()
+            self.shortcut = nn.Identity()
 
     def forward(self, x: torch.Tensor, cond: torch.Tensor = None):
         h = x
 
-        # First normalization and convolution layer
+        h = self.conv1(h)
         h = self.norm1(h)
         h = F.silu(h)
-        h = self.conv1(h)
 
-        # cond layer
         if cond is not None:
             cond = self.cond_proj(cond)[..., None, None]
             cond_scale, cond_shift = torch.chunk(cond, 2, dim=1)
-            h = self.norm2(h)
             h = h * (1 + cond_scale) + cond_shift
-        else:
-            h = self.norm2(h)
 
-        # Second normalization and convolution layer
-        h = F.silu(h)
         h = self.conv2(h)
+        h = self.norm2(h)
+        h = h + self.shortcut(x)
+        h = F.silu(h)
 
-        # Map and add residual
-        return self.nin_shortcut(x) + h
+        return h
 
 
 def normalization(channels):
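Note on the conditioning path kept above: cond_proj maps the style embedding to 2 × out_channels values that are split into a per-channel scale and shift, i.e. FiLM-style modulation of the feature map. A minimal self-contained sketch of that mechanism (class and variable names here are illustrative, not from the repo):

import torch
import torch.nn as nn

class FiLM(nn.Module):
    # Projects a conditioning vector to a per-channel (scale, shift) pair.
    def __init__(self, cond_dim: int, num_channels: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, num_channels * 2)

    def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        scale, shift = self.proj(cond)[..., None, None].chunk(2, dim=1)
        return h * (1 + scale) + shift

film = FiLM(cond_dim=256, num_channels=64)
out = film(torch.randn(2, 64, 32, 32), torch.randn(2, 256))  # shape preserved: (2, 64, 32, 32)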
swim/models/content_encoder.py CHANGED
@@ -4,53 +4,56 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .blocks import ResnetBlock, DownSample, AttentionBlock, normalization
+from .blocks import ResnetBlock, normalization
 
 
 class ContentEncoder(nn.Module):
     def __init__(
         self,
-        in_channels: int = 4,
-        z_c_channels: int = 128,
         channels: int = 128,
-        channel_multipliers: List[int] = [1, 2, 4],
-        n_resnet_blocks: int = 2,
+        z_c_channels: int = 512,
+        downsample_channel_mults: List[int] = [1, 2, 4],
+        n_resnet_blocks: int = 4,
     ):
         super().__init__()
-        n_resolutions = len(channel_multipliers)
 
-        self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)
+        self.conv_in = nn.Conv2d(3, channels, 7, stride=1, padding=3)
+        self.norm_in = normalization(channels)
 
-        channels_list = [m * channels for m in [1] + channel_multipliers]
+        channel_list = [channels * mult for mult in downsample_channel_mults]
 
-        self.down = nn.ModuleList()
-        for i in range(n_resolutions):
-            resnet_blocks = nn.ModuleList()
-            for _ in range(n_resnet_blocks):
-                resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
-                channels = channels_list[i + 1]
-            down = nn.Module()
-            down.block = resnet_blocks
+        self.downsamples = nn.ModuleList()
+        for out_channels in channel_list:
+            self.downsamples.append(
+                nn.Sequential(
+                    nn.Conv2d(channels, out_channels, 4, stride=2, padding=1),
+                    normalization(out_channels),
+                    nn.SiLU(),
+                )
+            )
 
-            if i != n_resolutions - 1:
-                down.downsample = DownSample(channels)
-            else:
-                down.downsample = nn.Identity()
-            self.down.append(down)
+            channels = out_channels
+
+        self.resnet_blocks = nn.ModuleList()
+        for _ in range(n_resnet_blocks):
+            self.resnet_blocks.append(ResnetBlock(channels, channels))
 
-        self.norm_out = normalization(channels)
         self.conv_out = nn.Conv2d(channels, z_c_channels, 3, stride=1, padding=1)
+        self.norm_out = normalization(z_c_channels)
+
+    def forward(self, x: torch.Tensor):
+        h = self.conv_in(x)
+        h = self.norm_in(h)
+        h = F.silu(h)
 
-    def forward(self, img: torch.Tensor):
-        x = self.conv_in(img)
+        for downsample in self.downsamples:
+            h = downsample(h)
 
-        for down in self.down:
-            for block in down.block:
-                x = block(x)
-            x = down.downsample(x)
+        for resnet_block in self.resnet_blocks:
+            h = resnet_block(h)
 
-        x = self.norm_out(x)
-        x = F.silu(x)
-        x = self.conv_out(x)
+        h = self.conv_out(h)
+        h = self.norm_out(h)
+        h = F.silu(h)
 
-        return x
+        return h
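A rough shape check of the new encoder, assuming the defaults shown in the diff (channels=128, z_c_channels=512, downsample_channel_mults=[1, 2, 4], i.e. three stride-2 stages); the result matches the dummy content tensor used in cc.py:

import torch
from swim.models.content_encoder import ContentEncoder

encoder = ContentEncoder()
z_c = encoder(torch.randn(1, 3, 512, 512))
print(z_c.shape)  # expected: torch.Size([1, 512, 64, 64]) after three 2x downsamples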
swim/models/decoder.py CHANGED
@@ -4,68 +4,58 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .blocks import (
-    ResnetBlock,
-    AttentionBlock,
-    StyledSequential,
-    UpSample,
-    normalization,
-)
+from .blocks import ResnetBlock, normalization
 
 
 class Decoder(nn.Module):
+
     def __init__(
         self,
-        z_c_channels: int = 128,
-        out_channels: int = 4,
         channels: int = 128,
-        channel_multipliers: List[int] = [1, 2, 4],
-        n_resnet_blocks: int = 2,
+        z_c_channels: int = 512,
+        upsample_channel_mults: List[int] = [1, 2, 4],
+        n_resnet_blocks: int = 6,
         d_style_emb: int = 256,
     ):
         super().__init__()
-        num_resolutions = len(channel_multipliers)
 
-        channels_list = [m * channels for m in channel_multipliers]
+        channel_list = [channels * mult for mult in upsample_channel_mults]
 
-        channels = channels_list[-1]
+        self.conv_in = nn.Conv2d(z_c_channels, channel_list[-1], 3, stride=1, padding=1)
+        self.norm_in = normalization(channel_list[-1])
 
-        self.conv_in = nn.Conv2d(z_c_channels, channels, 3, stride=1, padding=1)
+        channels = channel_list[-1]
 
-        self.up = nn.ModuleList()
-        for i in reversed(range(num_resolutions)):
-            resnet_blocks = nn.ModuleList()
-            for _ in range(n_resnet_blocks + 1):
-                resnet_blocks.append(
-                    StyledSequential(
-                        ResnetBlock(channels, channels_list[i], d_style_emb)
-                    )
+        self.resnet_blocks = nn.ModuleList()
+        for _ in range(n_resnet_blocks):
+            self.resnet_blocks.append(ResnetBlock(channels, channels, d_style_emb))
+
+        self.upsamples = nn.ModuleList()
+        for out_channels in channel_list[::-1]:
+            self.upsamples.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(channels, out_channels, 4, stride=2, padding=1),
+                    normalization(out_channels),
+                    nn.SiLU(),
                 )
-                channels = channels_list[i]
+            )
 
-            up = nn.Module()
-            up.block = resnet_blocks
-            if i != 0:
-                up.upsample = UpSample(channels)
-            else:
-                up.upsample = nn.Identity()
-            self.up.insert(0, up)
+            channels = out_channels
 
-        self.norm_out = normalization(channels)
-        self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
+        self.conv_out = nn.Conv2d(channels, 3, 7, stride=1, padding=3)
 
-    def forward(self, z_c: torch.Tensor, z_s: torch.Tensor):
+    def forward(self, x: torch.Tensor, style_emb: torch.Tensor):
+        h = self.conv_in(x)
+        h = self.norm_in(h)
+        h = F.silu(h)
 
-        h = self.conv_in(z_c)
+        for resnet_block in self.resnet_blocks:
+            h = resnet_block(h, style_emb)
 
-        for up in reversed(self.up):
-            for block in up.block:
-                h = block(h, z_s)
-            h = up.upsample(h)
+        for upsample in self.upsamples:
+            h = upsample(h)
 
-        h = self.norm_out(h)
-        h = F.silu(h)
-        img = self.conv_out(h)
+        h = self.conv_out(h)
+        h = torch.tanh(h)
 
-        #
-        return img
+        return h
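And the corresponding check for the decoder, again assuming the defaults from the diff: it takes a 64×64 content map plus a 256-dim style embedding and, after three transposed-convolution upsamples and the final tanh, returns an RGB image in [-1, 1]:

import torch
from swim.models.decoder import Decoder

decoder = Decoder()
x_hat = decoder(torch.randn(1, 512, 64, 64), torch.randn(1, 256))  # (content map, style embedding)
print(x_hat.shape)  # expected: torch.Size([1, 3, 512, 512]), values in [-1, 1] due to tanh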
 
swim/models/discriminator.py CHANGED
@@ -10,10 +10,8 @@ class SNResnetBlock(nn.Module):
         in_channels: int,
         out_channels: int,
         cond_channels: int = 0,
-        downsample: bool = False,
     ):
         super().__init__()
-        self.downsample = downsample
         self.d_cond = cond_channels
         self.conv1 = spectral_norm(
             nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
@@ -32,85 +30,84 @@ class SNResnetBlock(nn.Module):
         )
 
         if in_channels != out_channels:
-            self.nin_shortcut = spectral_norm(
+            self.shortcut = spectral_norm(
                 nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
             )
         else:
-            self.nin_shortcut = nn.Identity()
+            self.shortcut = nn.Identity()
 
     def forward(self, x: torch.Tensor, cond: torch.Tensor = None):
         h = x
 
-        h = F.leaky_relu(h, 0.2)
         h = self.conv1(h)
+        h = F.leaky_relu(h, 0.2)
 
         if self.d_cond > 0:
             cond = self.cond_proj(cond)[..., None, None]
             cond_scale, cond_shift = torch.chunk(cond, 2, dim=1)
             h = h * (1 + cond_scale) + cond_shift
-        else:
-            assert cond is None
 
-        h = F.leaky_relu(h, 0.2)
         h = self.conv2(h)
+        h = h + self.shortcut(x)
+        h = F.leaky_relu(h, 0.2)
 
-        h_skip = self.nin_shortcut(x)
-
-        if self.downsample:
-            h = F.avg_pool2d(h, 2)
-            h_skip = F.avg_pool2d(h_skip, 2)
-
-        return h + h_skip
+        return h
 
 
-class Discriminator(nn.Module):
+class FeatureDiscriminator(nn.Module):
 
     def __init__(
         self,
-        in_channels,
-        channels,
-        channel_multipliers,
-        d_cond: int = 0,
+        z_c_channels: int = 512,
+        channels: int = 512,
+        d_style_emb: int = 256,
+        n_resnet_blocks: int = 2,
     ):
         super().__init__()
 
-        self.input_block = spectral_norm(nn.Conv2d(in_channels, channels, 3, padding=1))
+        self.conv_in = spectral_norm(
+            nn.Conv2d(z_c_channels, channels, 3, stride=1, padding=1)
+        )
 
-        self.blocks = nn.ModuleList()
+        self.resnet_blocks = nn.ModuleList()
+        for _ in range(n_resnet_blocks):
+            self.resnet_blocks.append(SNResnetBlock(channels, channels, d_style_emb))
 
-        n_resolutions = len(channel_multipliers)
-        channels_list = [m * channels for m in channel_multipliers]
-        for i in range(n_resolutions):
-            self.blocks.append(
-                SNResnetBlock(
-                    channels, channels_list[i], cond_channels=d_cond, downsample=True
-                )
-            )
-            channels = channels_list[i]
+        self.conv_out = spectral_norm(
+            nn.Conv2d(channels, channels, 3, stride=1, padding=1)
+        )
 
-        self.out = nn.Sequential(
-            nn.LeakyReLU(0.2), spectral_norm(nn.Conv2d(channels, 1, 3, 1, 1))
+        self.mlp = nn.Sequential(
+            spectral_norm(nn.Linear(channels, 256)),
+            nn.LeakyReLU(0.2),
+            spectral_norm(nn.Linear(256, 1)),
         )
 
     def forward(
-        self, x: torch.Tensor, cond: torch.Tensor = None, for_real: bool = False
+        self, x: torch.Tensor, style_emb: torch.Tensor, for_G=False, for_real=False
    ):
-        h = self.input_block(x)
+        h = self.conv_in(x)
+        h = F.leaky_relu(h, 0.2)
+
+        for resnet_block in self.resnet_blocks:
+            h = resnet_block(h, style_emb)
+            h = F.avg_pool2d(h, 2)
+
+        h = self.conv_out(h)
+        h = F.leaky_relu(h, 0.2)
 
-        for block in self.blocks:
-            h = block(h, cond)
+        h = F.adaptive_avg_pool2d(h, 1)
+        h = torch.flatten(h, 1)
+        h = self.mlp(h)
 
-        h = self.out(h)
+        if for_G:
+            for_real = True
 
         if for_real:
             target = torch.ones_like(h)
         else:
             target = torch.zeros_like(h)
 
-        loss = F.binary_cross_entropy_with_logits(h, target)
-        # if for_real:
-        #     loss = torch.relu(1 - h).mean()
-        # else:
-        #     loss = torch.relu(1 + h).mean()
+        loss = F.binary_cross_entropy_with_logits(h, target, reduction="none")
 
         return loss
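Worth noting: with reduction="none" the discriminator now returns the per-sample BCE loss rather than a scalar, which is why every call site in swim_gan.py below appends .mean(). A tiny illustration of that pattern with made-up logits:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 1)            # stand-in for the discriminator head output
target = torch.ones_like(logits)      # "real" targets
loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
print(loss.shape, loss.mean())        # per-sample (4, 1) tensor; the caller averages it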
swim/models/style_encoder.py CHANGED
@@ -1,62 +1,24 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torchvision.models.resnet import resnet18, ResNet18_Weights
 
-from lightning import LightningModule
 
-from .blocks import ResnetBlock, DownSample
+class StyleEncoder(nn.Module):
 
+    def __init__(self, d_style_emb=256):
+        super(StyleEncoder, self).__init__()
+        self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
+        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])
 
-class StyleEncoder(LightningModule):
-    def __init__(
-        self,
-        in_channels: int = 4,
-        n_styles: int = 4,
-        d_style_emb: int = 256,
-        d_hidden: int = 512,
-        channels: int = 64,
-        n_layers: int = 4,
-    ):
-        super().__init__()
+        self.fc = nn.Linear(512, d_style_emb)
 
-        self.n_styles = n_styles
+    def forward(self, x):
+        # resize input to 224x224
+        # x = F.interpolate(x, size=(224, 224), mode="bilinear")
 
-        # Initial convolution
-        self.conv_in = nn.Sequential(
-            nn.Conv2d(in_channels, channels, kernel_size=3, padding=1),
-            nn.GroupNorm(32, channels),
-            nn.SiLU(),
-        )
+        x = self.resnet(x)
+        x = torch.flatten(x, 1)
 
-        # Convolutional blocks with GroupNorm and single convolution
-        self.blocks = nn.ModuleList()
-        for i in range(n_layers):
-            self.blocks.append(
-                nn.Sequential(
-                    nn.Conv2d(
-                        channels, channels * 2, kernel_size=3, stride=2, padding=1
-                    ),  # Downsample
-                    nn.GroupNorm(32, channels * 2),
-                    nn.SiLU(),
-                )
-            )
-            channels *= 2
-
-        # Output MLP
-        self.out = nn.Sequential(
-            nn.AdaptiveAvgPool2d(1),  # Global average pooling
-            nn.Flatten(),  # Flatten spatial dimensions
-            nn.Linear(channels, d_hidden),  # First dense layer
-            nn.SiLU(),
-            nn.Linear(d_hidden, d_style_emb),  # Final style embedding
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        h = self.conv_in(x)  # Initial convolution
-
-        for block in self.blocks:
-            h = block(h)  # Pass through each block
-
-        h = self.out(h)  # Pool and process for style embedding
-
-        return h
+        x = self.fc(x)
+        return x
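The new style encoder swaps the hand-rolled CNN for a pretrained ResNet-18 trunk: children()[:-1] drops only the final fc layer but keeps the global average pool, so the trunk emits (B, 512, 1, 1) features that are flattened and projected to d_style_emb. A quick sketch of that same trunk construction:

import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18, ResNet18_Weights

trunk = nn.Sequential(*list(resnet18(weights=ResNet18_Weights.DEFAULT).children())[:-1])
feats = trunk(torch.randn(2, 3, 512, 512))
print(feats.shape)  # torch.Size([2, 512, 1, 1]); flatten + nn.Linear(512, d_style_emb) follows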
swim/models/swim_gan.py CHANGED
@@ -12,12 +12,14 @@ from lightning import LightningModule
 from diffusers import AutoencoderKL
 from diffusers.utils import make_image_grid
 
-from swim.utils.tensor_pool import GroupTensorPool, TensorPool
+from swim.utils.tensor_pool import GroupTensorPool
 
 from .style_encoder import StyleEncoder
 from .content_encoder import ContentEncoder
 from .decoder import Decoder
-from .discriminator import Discriminator
+from .discriminator import FeatureDiscriminator
+
+import vision_aided_loss
 
 
 class SwimGAN(LightningModule):
@@ -26,16 +28,18 @@ class SwimGAN(LightningModule):
         self,
         channels: int = 128,
         z_c_channels: int = 512,
-        channel_multipliers: list = [1, 2, 2, 4],
-        n_resnet_blocks: int = 2,
+        updown_channel_mults: List[int] = [1, 2, 4],
+        n_enc_resnet_blocks: int = 4,
+        n_dec_resnet_blocks: int = 6,
+        n_f_d_resnet_blocks: int = 2,
         n_styles: int = 5,
         d_style_emb: int = 128,
         input_size: int = 512,
         learning_rate: float = 1e-5,
-        weight_decay: float = 0,
+        weight_decay: float = 1e-4,
         lambda_cls: float = 10.0,
-        lambda_rec: float = 1.0,
-        lambda_cycle: float = 1.0,
+        lambda_rec: float = 10.0,
+        lambda_cycle: float = 10.0,
         lambda_c_g: float = 1.0,
         lambda_x_g: float = 1.0,
         lambda_c_const: float = 1.0,
@@ -58,43 +62,34 @@ class SwimGAN(LightningModule):
         self.lambda_x_g = lambda_x_g
 
         self.content_encoder = ContentEncoder(
-            in_channels=4,
-            z_c_channels=z_c_channels,
             channels=channels,
-            channel_multipliers=[1, 2, 2, 4],
-            n_resnet_blocks=2,
+            z_c_channels=z_c_channels,
+            downsample_channel_mults=[1, 2, 4],
+            n_resnet_blocks=n_enc_resnet_blocks,
         )
 
-        self.style_encoder = StyleEncoder(4, n_styles, d_style_emb)
+        self.style_encoder = StyleEncoder(d_style_emb=d_style_emb)
 
         self.decoder = Decoder(
-            z_c_channels=z_c_channels,
-            out_channels=4,
             channels=channels,
-            channel_multipliers=[1, 2, 2, 4],
-            n_resnet_blocks=2,
+            z_c_channels=z_c_channels,
+            upsample_channel_mults=updown_channel_mults,
+            n_resnet_blocks=n_dec_resnet_blocks,
             d_style_emb=d_style_emb,
         )
 
-        self.vae: AutoencoderKL = AutoencoderKL.from_pretrained(
-            "stabilityai/sd-turbo", subfolder="vae"
-        )
-        self.vae.requires_grad_(False)
-        self.vae.eval()
-
         self.style_classifier = nn.Linear(d_style_emb, n_styles)
 
         # training only
-        self.x_discriminator = Discriminator(
-            in_channels=4,
-            channels=128,
-            channel_multipliers=[1, 2, 4, 8],
+        self.i_discriminator = vision_aided_loss.Discriminator(
+            cv_type="clip", loss_type="multilevel_sigmoid", device="cpu"
        )
-        self.c_discriminator = Discriminator(
-            in_channels=z_c_channels,
+
+        self.f_discriminator = FeatureDiscriminator(
+            z_c_channels=z_c_channels,
             channels=z_c_channels,
-            channel_multipliers=[1, 1, 1],
-            d_cond=d_style_emb,
+            d_style_emb=d_style_emb,
+            n_resnet_blocks=n_f_d_resnet_blocks,
         )
 
         self.style_pool = GroupTensorPool(n_styles, 256)
@@ -102,117 +97,100 @@ class SwimGAN(LightningModule):
 
         self.cls_loss = nn.CrossEntropyLoss()
 
-    def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
-        return self.vae.encode(x).latent_dist.sample() * self.vae.config.scaling_factor
-
-    def vae_decode(self, z: torch.Tensor) -> torch.Tensor:
-        return self.vae.decode(z / self.vae.config.scaling_factor).sample.clamp(-1, 1)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        with torch.no_grad():
-            z = self.vae_encode(x)
-        z_c = self.content_encoder(z)
-        z_s = self.style_encoder(z)
-        z_rec = self.decoder(z_c, z_s)
-
-        return z, z_c, z_s, z_rec
+    def on_fit_start(self):
+        for model in self.i_discriminator.cv_ensemble.models:
+            model.to(self.device)
+            model.requires_grad_(False)
 
     def training_step(self, batch, batch_idx):
         x = batch["images"]
         gt_style = batch["styles"]
 
-        g_opt, cls_opt, x_d_opt, c_d_opt = self.optimizers()
+        g_opt, i_d_opt, f_d_opt = self.optimizers()
 
-        z = self.vae_encode(x)
+        # train the autoencoder
+        z_s = self.style_encoder(x)
+        z_c = self.content_encoder(x)
+        x_rec = self.decoder(z_c, z_s)
 
-        # train the cls
-        z_s = self.style_encoder(z)
         style_logits = self.style_classifier(z_s)
         cls_loss = self.cls_loss(style_logits, gt_style)
 
-        cls_opt.zero_grad()
-        self.manual_backward(cls_loss)
-        cls_opt.step()
-
-        # train the autoencoder
-        z_s = z_s.detach()
-        z_c = self.content_encoder(z)
-        z_rec = self.decoder(z_c, z_s)
-
-        rec_loss = F.l1_loss(z, z_rec)
+        rec_loss = F.l1_loss(x, x_rec)
 
         # # sample a random content and style feature
         z_c_hat, _ = self.content_pool.query(z_c, gt_style)
         z_s_hat, _ = self.style_pool.query(z_s, gt_style)
 
-        z1 = self.decoder(z_c, z_s_hat)
-        z2 = self.decoder(z_c_hat, z_s)
+        x1 = self.decoder(z_c, z_s_hat)
+        x2 = self.decoder(z_c_hat, z_s)
 
-        z_c_rec = self.content_encoder(z1)
-        z_s_hat_rec = self.style_encoder(z1)
+        z_c_rec = self.content_encoder(x1)
+        z_s_hat_rec = self.style_encoder(x1)
 
-        z_c_hat_rec = self.content_encoder(z2)
-        z_s_rec = self.style_encoder(z2)
+        z_c_hat_rec = self.content_encoder(x2)
+        z_s_rec = self.style_encoder(x2)
 
         c_const_loss = F.l1_loss(z_c, z_c_rec) + F.l1_loss(z_c_hat, z_c_hat_rec)
         s_const_loss = F.l1_loss(z_s, z_s_rec) + F.l1_loss(z_s_hat, z_s_hat_rec)
 
         # adversarial loss
-        x_g_loss = (
-            self.x_discriminator(z1, for_real=True)
-            + self.x_discriminator(z2, for_real=True)
+        i_g_loss = (
+            self.i_discriminator(x1, for_G=True).mean()
+            + self.i_discriminator(x2, for_G=True).mean()
         ) / 2
 
-        c_g_loss = self.c_discriminator(z_c, z_s, for_real=False)
+        c_g_loss = self.f_discriminator(z_c, z_s, for_real=False).mean()
 
         g_loss = (
             self.lambda_rec * rec_loss
             + self.lambda_c_const * c_const_loss
             + self.lambda_s_const * s_const_loss
-            + self.lambda_x_g * x_g_loss
+            + self.lambda_x_g * i_g_loss
             + self.lambda_c_g * c_g_loss
+            + self.lambda_cls * cls_loss
         )
 
         g_opt.zero_grad()
         self.manual_backward(g_loss)
         g_opt.step()
 
-        # train the x discriminator
-        x_d_loss = (
-            self.x_discriminator(z, for_real=True)
+        # train the image discriminator
+        i_d_loss = (
+            self.i_discriminator(x, for_real=True).mean()
             + (
-                self.x_discriminator(z1.detach(), for_real=False)
-                + self.x_discriminator(z2.detach(), for_real=False)
+                self.i_discriminator(x1.detach(), for_real=False).mean()
+                + self.i_discriminator(x2.detach(), for_real=False).mean()
             )
             / 2
         ) / 2
 
-        x_d_opt.zero_grad()
-        self.manual_backward(x_d_loss)
-        x_d_opt.step()
+        i_d_opt.zero_grad()
+        self.manual_backward(i_d_loss)
+        i_d_opt.step()
 
-        # train the content discriminator
-        c_d_loss = (
-            self.c_discriminator(z_c.detach(), z_s.detach(), for_real=True)
+        # train the feature discriminator
+        f_d_loss = (
+            self.f_discriminator(z_c.detach(), z_s.detach(), for_real=True).mean()
             + (
-                self.c_discriminator(z_c.detach(), z_s_hat, for_real=False)
-                + self.c_discriminator(z_c_hat, z_s.detach(), for_real=False)
+                self.f_discriminator(z_c.detach(), z_s_hat, for_real=False).mean()
+                + self.f_discriminator(z_c_hat, z_s.detach(), for_real=False).mean()
            )
             / 2
         ) / 2
 
-        c_d_opt.zero_grad()
-        self.manual_backward(c_d_loss)
-        c_d_opt.step()
+        f_d_opt.zero_grad()
+        self.manual_backward(f_d_loss)
+        f_d_opt.step()
 
         self.log_dict(
             {
                 "train/rec_loss": rec_loss,
                 "train/cls_loss": cls_loss,
-                "train/x_g_loss": x_g_loss,
-                "train/x_d_loss": x_d_loss,
+                "train/i_g_loss": i_g_loss,
+                "train/i_d_loss": i_d_loss,
                 "train/c_g_loss": c_g_loss,
-                "train/c_d_loss": c_d_loss,
+                "train/f_d_loss": f_d_loss,
                 "train/c_const_loss": c_const_loss,
                 "train/s_const_loss": s_const_loss,
             },
@@ -223,25 +201,20 @@ class SwimGAN(LightningModule):
 
     def validation_step(self, batch, batch_idx):
         x = batch["images"]
-        gt_style_logits = batch["styles"]  # B x n_styles
 
         x = x[torch.randperm(x.shape[0])]
-        z = self.vae_encode(x)
-        z_c = self.content_encoder(z)
-        z_s = self.style_encoder(z)
-        z_rec = self.decoder(z_c, z_s)
+        z_c = self.content_encoder(x)
+        z_s = self.style_encoder(x)
+        x_rec = self.decoder(z_c, z_s)
 
         x1, x2 = x.chunk(2, dim=0)
 
-        x_rec = self.vae_decode(z_rec)
         x1_rec, x2_rec = x_rec.chunk(2, dim=0)
 
         z1_c, z2_c = z_c.chunk(2, dim=0)
         z1_s, z2_s = z_s.chunk(2, dim=0)
-        z1_swap = self.decoder(z1_c, z2_s)
-        z2_swap = self.decoder(z2_c, z1_s)
-        x1_swap = self.vae_decode(z1_swap)
-        x2_swap = self.vae_decode(z2_swap)
+        x1_swap = self.decoder(z1_c, z2_s)
+        x2_swap = self.decoder(z2_c, z1_s)
 
         if self.trainer.is_global_zero:
             x1_img = self.postprocess_images(x1)
@@ -266,7 +239,7 @@ class SwimGAN(LightningModule):
                 wandb.Image(image, caption="orig | rec | swap") for image in images
             ]
 
-            wandb.log({"val/samples": images})
+            # wandb.log({"val/samples": images})
 
         self.log("val/lpips", -self.global_step, sync_dist=True)
 
@@ -282,28 +255,24 @@ class SwimGAN(LightningModule):
 
     def configure_optimizers(self):
         g_opt = torch.optim.AdamW(
-            list(self.content_encoder.parameters()) + list(self.decoder.parameters()),
+            list(self.content_encoder.parameters())
+            + list(self.style_encoder.parameters())
+            + list(self.style_classifier.parameters())
+            + list(self.decoder.parameters()),
             lr=self.learning_rate,
-            weight_decay=1e-4,
-        )
-
-        cls_opt = torch.optim.AdamW(
-            list(self.style_encoder.parameters())
-            + list(self.style_classifier.parameters()),
-            lr=self.learning_rate * 10,
-            weight_decay=1e-4,
+            weight_decay=self.weight_decay,
        )
 
-        x_d_opt = torch.optim.AdamW(
-            list(self.x_discriminator.parameters()),
-            lr=self.learning_rate * 10,
-            weight_decay=1e-4,
+        i_d_opt = torch.optim.AdamW(
+            list(self.i_discriminator.parameters()),
+            lr=self.learning_rate,
+            weight_decay=self.weight_decay,
        )
 
-        c_d_opt = torch.optim.AdamW(
-            list(self.c_discriminator.parameters()),
-            lr=self.learning_rate * 10,
-            weight_decay=1e-4,
+        f_d_opt = torch.optim.AdamW(
+            list(self.f_discriminator.parameters()),
+            lr=self.learning_rate,
+            weight_decay=self.weight_decay,
        )
 
-        return [g_opt, cls_opt, x_d_opt, c_d_opt], []
+        return [g_opt, i_d_opt, f_d_opt]
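Because training_step drives three optimizers itself via self.optimizers() and self.manual_backward(), the module must run under Lightning's manual optimization; presumably self.automatic_optimization = False is set in the unchanged part of __init__. A minimal sketch of that pattern (illustrative module, not the repo's code):

import torch
from lightning import LightningModule

class ManualThreeOpt(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # required for self.optimizers()/manual_backward()
        self.g = torch.nn.Linear(8, 8)
        self.d1 = torch.nn.Linear(8, 1)
        self.d2 = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        g_opt, d1_opt, d2_opt = self.optimizers()
        g_loss = self.d1(self.g(batch)).mean()
        g_opt.zero_grad()
        self.manual_backward(g_loss)
        g_opt.step()
        # ...the discriminator losses would be stepped the same way with d1_opt / d2_opt

    def configure_optimizers(self):
        make = lambda params: torch.optim.AdamW(params, lr=1e-4)
        return [make(self.g.parameters()), make(self.d1.parameters()), make(self.d2.parameters())]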
swim/train.py CHANGED
@@ -91,9 +91,6 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         model.compile()
 
     log.info("Starting training!")
-    from swim.models.swim_gan import SwimGAN
-
-    # model = SwimGAN.load_from_checkpoint(cfg.get("ckpt_path"), strict=False)
     trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
 
     train_metrics = trainer.callback_metrics