cc

Changed files:
- cc.py +17 -11
- configs/experiment/channel=64.yaml +22 -0
- configs/experiment/example.yaml +0 -41
- configs/experiment/potato.yaml +21 -0
- configs/model/swim_gan.yaml +0 -2
- swim/models/blocks.py +10 -20
- swim/models/content_encoder.py +35 -32
- swim/models/decoder.py +34 -44
- swim/models/discriminator.py +40 -43
- swim/models/style_encoder.py +14 -52
- swim/models/swim_gan.py +86 -117
- swim/train.py +0 -3
cc.py
CHANGED
@@ -4,18 +4,24 @@ from torchinfo import summary
 from swim.models.content_encoder import ContentEncoder
 from swim.models.decoder import Decoder
 from swim.models.style_encoder import StyleEncoder
-from swim.models.discriminator import
+from swim.models.discriminator import FeatureDiscriminator
+import vision_aided_loss
 
-model = SwimGAN().cuda()
+# from swim.models.swim_gan import SwimGAN
+
 image = torch.randn(1, 3, 512, 512).to("cuda")
-sample = torch.randn(1, 4, 64, 64).to("cuda")
 style_emb = torch.randn(1, 256).to("cuda")
+content = torch.randn(1, 512, 64, 64).to("cuda")
+content_encoder = ContentEncoder().cuda()
+decoder = Decoder().cuda()
+style_encoder = StyleEncoder().cuda()
+discriminator = FeatureDiscriminator().cuda()
+# i_discriminator = vision_aided_loss.Discriminator(
+#     cv_type="clip", loss_type="multilevel_sigmoid_s", device="cuda"
+# ).to("cuda")
 
-# summary(decoder, input_data=(
-# summary(style_encoder, input_data=
-summary(
-    input_data=(image),
-)
+summary(content_encoder, input_data=(image,))
+# summary(decoder, input_data=(content, style_emb))
+# summary(style_encoder, input_data=image)
+# summary(discriminator, input_data=(content, style_emb))
+# summary(i_discriminator, input_data=(image,))
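A quick editorial sanity check, not part of the commit: with the default constructor arguments used in cc.py above, the pieces should compose as follows (shapes derived from the defaults introduced in this commit).

# Hedged sketch (not in the commit): expected shapes with the defaults above.
z_c = content_encoder(image)                     # (1, 512, 64, 64): three stride-2 convs take 512 -> 64
z_s = style_encoder(image)                       # (1, 256): pooled resnet18 features -> Linear(512, 256)
x_hat = decoder(z_c, z_s)                        # (1, 3, 512, 512): mirrored upsampling, tanh output in [-1, 1]
d_loss = discriminator(z_c, z_s, for_real=True)  # per-sample BCE-with-logits loss on the content code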
configs/experiment/channel=64.yaml
ADDED
@@ -0,0 +1,22 @@
+# @package _global_
+
+defaults:
+  - override /data: swim
+  - override /model: swim_gan
+  - override /callbacks: default
+  - override /trainer: gpu
+
+seed: 42
+
+trainer:
+  max_epochs: 100
+
+model:
+  channels: 64
+  z_c_channels: 256
+  n_enc_resnet_blocks: 4
+  n_dec_resnet_blocks: 6
+  n_f_d_resnet_blocks: 2
+
+data:
+  batch_size: 2
configs/experiment/example.yaml
DELETED
@@ -1,41 +0,0 @@
-# @package _global_
-
-# to execute this experiment run:
-# python train.py experiment=example
-
-defaults:
-  - override /data: mnist
-  - override /model: mnist
-  - override /callbacks: default
-  - override /trainer: default
-
-# all parameters below will be merged with parameters from default configurations set above
-# this allows you to overwrite only specified parameters
-
-tags: ["mnist", "simple_dense_net"]
-
-seed: 12345
-
-trainer:
-  min_epochs: 10
-  max_epochs: 10
-  gradient_clip_val: 0.5
-
-model:
-  optimizer:
-    lr: 0.002
-  net:
-    lin1_size: 128
-    lin2_size: 256
-    lin3_size: 64
-  compile: false
-
-data:
-  batch_size: 64
-
-logger:
-  wandb:
-    tags: ${tags}
-    group: "mnist"
-  aim:
-    experiment: "mnist"
configs/experiment/potato.yaml
ADDED
@@ -0,0 +1,21 @@
+# @package _global_
+
+defaults:
+  - override /data: swim
+  - override /model: swim_gan
+  - override /callbacks: default
+  - override /trainer: gpu
+
+seed: 42
+
+trainer:
+  max_epochs: 100
+
+model:
+  channels: 32
+  z_c_channels: 128
+  n_enc_resnet_blocks: 1
+  n_dec_resnet_blocks: 1
+
+data:
+  batch_size: 2
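Both new experiment configs are selected the same way the deleted example.yaml documented. A hedged example (the `=` inside the `channel=64` file name may need quoting so Hydra's override parser does not split on it):

# python train.py experiment=potato
# python train.py 'experiment=channel=64'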
configs/model/swim_gan.yaml
CHANGED
@@ -1,3 +1 @@
 _target_: swim.models.swim_gan.SwimGAN
-
-learning_rate: 1e-4
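With the YAML override gone, the learning rate now falls back to the `SwimGAN` constructor default (1e-5 in this commit). It can still be changed per run through a standard Hydra override; a hedged example:

# python train.py experiment=potato model.learning_rate=1e-4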
swim/models/blocks.py
CHANGED
@@ -77,11 +77,9 @@ class DownSample(nn.Module):
 class ResnetBlock(nn.Module):
     def __init__(self, in_channels: int, out_channels: int, cond_channels: int = 0):
         super().__init__()
-        # First normalization and convolution layer
-        self.norm1 = normalization(in_channels)
         self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
+        self.norm1 = normalization(in_channels)
 
-        # cond layer
         self.cond_channels = cond_channels
         if cond_channels > 0:
             self.cond_proj = nn.Sequential(
@@ -90,40 +88,32 @@ class ResnetBlock(nn.Module):
                 nn.Linear(cond_channels, out_channels * 2),
             )
 
-        # Second normalization and convolution layer
-        self.norm2 = normalization(out_channels)
         self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
+        self.norm2 = normalization(out_channels)
 
         if in_channels != out_channels:
-            self.nin_shortcut = nn.Conv2d(
-                in_channels, out_channels, 1, stride=1, padding=0
-            )
+            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
         else:
-            self.nin_shortcut = nn.Identity()
+            self.shortcut = nn.Identity()
 
     def forward(self, x: torch.Tensor, cond: torch.Tensor = None):
         h = x
 
-        h = self.norm1(h)
-        h = F.silu(h)
-        h = self.conv1(h)
+        h = self.conv1(h)
+        h = self.norm1(h)
+        h = F.silu(h)
 
-        # cond layer
         if cond is not None:
             cond = self.cond_proj(cond)[..., None, None]
             cond_scale, cond_shift = torch.chunk(cond, 2, dim=1)
-            h = self.norm2(h)
             h = h * (1 + cond_scale) + cond_shift
-        else:
-            h = self.norm2(h)
 
-        # Second normalization and convolution layer
-        h = F.silu(h)
         h = self.conv2(h)
+        h = self.norm2(h)
+        h = h + self.shortcut(x)
+        h = F.silu(h)
 
-        return self.nin_shortcut(x) + h
+        return h
 
 
 def normalization(channels):
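The block now runs conv -> norm -> SiLU instead of pre-activation, the FiLM-style scale/shift from the conditioning vector is applied right after the first convolution, and the residual is added before the final SiLU. A minimal usage sketch, not in the commit (in this commit the block is always constructed with in_channels == out_channels):

block = ResnetBlock(in_channels=512, out_channels=512, cond_channels=256)
x = torch.randn(2, 512, 64, 64)
style = torch.randn(2, 256)
y = block(x, style)        # (2, 512, 64, 64); with cond=None the FiLM branch is skipped entirely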
swim/models/content_encoder.py
CHANGED
@@ -4,53 +4,56 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .blocks import ResnetBlock,
+from .blocks import ResnetBlock, normalization
 
 
 class ContentEncoder(nn.Module):
     def __init__(
         self,
-        in_channels: int = 4,
-        z_c_channels: int = 128,
         channels: int = 128,
+        z_c_channels: int = 512,
+        downsample_channel_mults: List[int] = [1, 2, 4],
+        n_resnet_blocks: int = 4,
     ):
         super().__init__()
-        n_resolutions = len(channel_multipliers)
 
+        self.conv_in = nn.Conv2d(3, channels, 7, stride=1, padding=3)
+        self.norm_in = normalization(channels)
+
+        channel_list = [channels * mult for mult in downsample_channel_mults]
+
+        self.downsamples = nn.ModuleList()
+        for out_channels in channel_list:
+            self.downsamples.append(
+                nn.Sequential(
+                    nn.Conv2d(channels, out_channels, 4, stride=2, padding=1),
+                    normalization(out_channels),
+                    nn.SiLU(),
+                )
+            )
+            channels = out_channels
+
+        self.resnet_blocks = nn.ModuleList()
+        for _ in range(n_resnet_blocks):
+            self.resnet_blocks.append(ResnetBlock(channels, channels))
 
-        self.norm_out = normalization(channels)
         self.conv_out = nn.Conv2d(channels, z_c_channels, 3, stride=1, padding=1)
+        self.norm_out = normalization(z_c_channels)
+
+    def forward(self, x: torch.Tensor):
+        h = self.conv_in(x)
+        h = self.norm_in(h)
+        h = F.silu(h)
+
+        for downsample in self.downsamples:
+            h = downsample(h)
+
+        for resnet_block in self.resnet_blocks:
+            h = resnet_block(h)
+
+        h = self.conv_out(h)
+        h = self.norm_out(h)
+        h = F.silu(h)
+
+        return h
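With the defaults above (channels=128, mults [1, 2, 4], z_c_channels=512), conv_in lifts the RGB input to 128 channels and the three stride-2 convolutions take the channels through 128, 256 and 512 while halving the resolution each time, so a 512x512 image yields a 64x64 content code. A hedged check, not part of the commit:

encoder = ContentEncoder()                      # defaults from this commit
z_c = encoder(torch.randn(1, 3, 512, 512))      # -> (1, 512, 64, 64)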
swim/models/decoder.py
CHANGED
@@ -4,68 +4,58 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .blocks import (
-    ResnetBlock,
-    AttentionBlock,
-    StyledSequential,
-    UpSample,
-    normalization,
-)
+from .blocks import ResnetBlock, normalization
 
 
 class Decoder(nn.Module):
+
     def __init__(
         self,
-        z_c_channels: int = 128,
-        out_channels: int = 4,
         channels: int = 128,
+        z_c_channels: int = 512,
+        upsample_channel_mults: List[int] = [1, 2, 4],
+        n_resnet_blocks: int = 6,
         d_style_emb: int = 256,
     ):
         super().__init__()
-        num_resolutions = len(channel_multipliers)
-
-        up.block = resnet_blocks
-        if i != 0:
-            up.upsample = UpSample(channels)
-        else:
-            up.upsample = nn.Identity()
-        self.up.insert(0, up)
-
-        self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
-
-        h = block(h, z_s)
-        h = up.upsample(h)
-        img = self.conv_out(h)
-        return img
+
+        channel_list = [channels * mult for mult in upsample_channel_mults]
+
+        self.conv_in = nn.Conv2d(z_c_channels, channel_list[-1], 3, stride=1, padding=1)
+        self.norm_in = normalization(channel_list[-1])
+
+        channels = channel_list[-1]
+
+        self.resnet_blocks = nn.ModuleList()
+        for _ in range(n_resnet_blocks):
+            self.resnet_blocks.append(ResnetBlock(channels, channels, d_style_emb))
+
+        self.upsamples = nn.ModuleList()
+        for out_channels in channel_list[::-1]:
+            self.upsamples.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(channels, out_channels, 4, stride=2, padding=1),
+                    normalization(out_channels),
+                    nn.SiLU(),
+                )
+            )
+
+            channels = out_channels
+
+        self.conv_out = nn.Conv2d(channels, 3, 7, stride=1, padding=3)
+
+    def forward(self, x: torch.Tensor, style_emb: torch.Tensor):
+        h = self.conv_in(x)
+        h = self.norm_in(h)
+        h = F.silu(h)
+
+        for resnet_block in self.resnet_blocks:
+            h = resnet_block(h, style_emb)
+
+        for upsample in self.upsamples:
+            h = upsample(h)
+
+        h = self.conv_out(h)
+        h = torch.tanh(h)
+
+        return h
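The decoder mirrors the encoder: the style embedding conditions the ResNet blocks through their FiLM branch, three transposed convolutions bring the 64x64 content code back to 512x512, and the output is squashed to [-1, 1] by tanh. A hedged usage sketch, not part of the commit:

decoder = Decoder()                              # defaults: channels=128, z_c_channels=512, d_style_emb=256
img = decoder(torch.randn(1, 512, 64, 64), torch.randn(1, 256))   # -> (1, 3, 512, 512)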
swim/models/discriminator.py
CHANGED
@@ -10,10 +10,8 @@ class SNResnetBlock(nn.Module):
         in_channels: int,
         out_channels: int,
         cond_channels: int = 0,
-        downsample: bool = False,
     ):
         super().__init__()
-        self.downsample = downsample
         self.d_cond = cond_channels
         self.conv1 = spectral_norm(
             nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
@@ -32,85 +30,84 @@ class SNResnetBlock(nn.Module):
         )
 
         if in_channels != out_channels:
-            self.nin_shortcut = spectral_norm(
+            self.shortcut = spectral_norm(
                 nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
             )
         else:
-            self.nin_shortcut = nn.Identity()
+            self.shortcut = nn.Identity()
 
     def forward(self, x: torch.Tensor, cond: torch.Tensor = None):
         h = x
 
-        h = F.leaky_relu(h, 0.2)
         h = self.conv1(h)
+        h = F.leaky_relu(h, 0.2)
 
         if self.d_cond > 0:
             cond = self.cond_proj(cond)[..., None, None]
             cond_scale, cond_shift = torch.chunk(cond, 2, dim=1)
             h = h * (1 + cond_scale) + cond_shift
-        else:
-            assert cond is None
 
-        h = F.leaky_relu(h, 0.2)
         h = self.conv2(h)
+        h = h + self.shortcut(x)
+        h = F.leaky_relu(h, 0.2)
 
-        if self.downsample:
-            h = F.avg_pool2d(h, 2)
-            h_skip = F.avg_pool2d(h_skip, 2)
-
-        return h + h_skip
+        return h
 
 
+class FeatureDiscriminator(nn.Module):
 
     def __init__(
         self,
-        channels,
+        z_c_channels: int = 512,
+        channels: int = 512,
+        d_style_emb: int = 256,
+        n_resnet_blocks: int = 2,
     ):
         super().__init__()
 
-        self.blocks.append(
-            SNResnetBlock(
-                channels, channels_list[i], cond_channels=d_cond, downsample=True
-            )
-        )
-        channels = channels_list[i]
+        self.conv_in = spectral_norm(
+            nn.Conv2d(z_c_channels, channels, 3, stride=1, padding=1)
+        )
 
+        self.resnet_blocks = nn.ModuleList()
+        for _ in range(n_resnet_blocks):
+            self.resnet_blocks.append(SNResnetBlock(channels, channels, d_style_emb))
+
+        self.conv_out = spectral_norm(
+            nn.Conv2d(channels, channels, 3, stride=1, padding=1)
+        )
+
+        self.mlp = nn.Sequential(
+            spectral_norm(nn.Linear(channels, 256)),
+            nn.LeakyReLU(0.2),
+            spectral_norm(nn.Linear(256, 1)),
         )
 
     def forward(
-        self, x: torch.Tensor,
+        self, x: torch.Tensor, style_emb: torch.Tensor, for_G=False, for_real=False
     ):
+        h = self.conv_in(x)
+        h = F.leaky_relu(h, 0.2)
+
+        for resnet_block in self.resnet_blocks:
+            h = resnet_block(h, style_emb)
+            h = F.avg_pool2d(h, 2)
+
+        h = self.conv_out(h)
+        h = F.leaky_relu(h, 0.2)
 
+        h = F.adaptive_avg_pool2d(h, 1)
+        h = torch.flatten(h, 1)
+        h = self.mlp(h)
 
+        if for_G:
+            for_real = True
 
         if for_real:
             target = torch.ones_like(h)
         else:
             target = torch.zeros_like(h)
 
-        loss = F.binary_cross_entropy_with_logits(h, target)
-        # if for_real:
-        #     loss = torch.relu(1 - h).mean()
-        # else:
-        #     loss = torch.relu(1 + h).mean()
+        loss = F.binary_cross_entropy_with_logits(h, target, reduction="none")
 
         return loss
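FeatureDiscriminator scores a content code conditioned on a style embedding and returns the per-sample BCE-with-logits loss directly (reduction="none"); passing for_G=True flips the target to "real" for the generator update. A hedged sketch of the three call patterns, not part of the commit:

d = FeatureDiscriminator()                       # defaults: z_c_channels=512, d_style_emb=256
z_c = torch.randn(4, 512, 64, 64)
z_s = torch.randn(4, 256)
loss_real = d(z_c, z_s, for_real=True).mean()    # discriminator step, matched (real) pair
loss_fake = d(z_c, z_s, for_real=False).mean()   # discriminator step, mismatched (fake) pair
loss_g = d(z_c, z_s, for_G=True).mean()          # generator step: targets are ones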
swim/models/style_encoder.py
CHANGED
@@ -1,62 +1,24 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torchvision.models.resnet import resnet18, ResNet18_Weights
 
-from lightning import LightningModule
 
+class StyleEncoder(nn.Module):
 
-    def __init__(
-        self,
-        in_channels: int = 4,
-        n_styles: int = 4,
-        d_style_emb: int = 256,
-        d_hidden: int = 512,
-        channels: int = 64,
-        n_layers: int = 4,
-    ):
-        super().__init__()
-
-        self.conv_in = nn.Sequential(
-            nn.Conv2d(in_channels, channels, kernel_size=3, padding=1),
-            nn.GroupNorm(32, channels),
-            nn.SiLU(),
-        )
-
-        self.blocks = nn.ModuleList()
-        for i in range(n_layers):
-            self.blocks.append(
-                nn.Sequential(
-                    nn.Conv2d(
-                        channels, channels * 2, kernel_size=3, stride=2, padding=1
-                    ),  # Downsample
-                    nn.GroupNorm(32, channels * 2),
-                    nn.SiLU(),
-                )
-            )
-            channels *= 2
-
-        # Output MLP
-        self.out = nn.Sequential(
-            nn.AdaptiveAvgPool2d(1),  # Global average pooling
-            nn.Flatten(),  # Flatten spatial dimensions
-            nn.Linear(channels, d_hidden),  # First dense layer
-            nn.SiLU(),
-            nn.Linear(d_hidden, d_style_emb),  # Final style embedding
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        h = self.conv_in(x)  # Initial convolution
-
-        for block in self.blocks:
-            h = block(h)  # Pass through each block
-
-        h = self.out(h)  # Pool and process for style embedding
-
-        return h
+    def __init__(self, d_style_emb=256):
+        super(StyleEncoder, self).__init__()
+        self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
+        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])
+
+        self.fc = nn.Linear(512, d_style_emb)
+
+    def forward(self, x):
+        # resize input to 224x224
+        # x = F.interpolate(x, size=(224, 224), mode="bilinear")
+
+        x = self.resnet(x)
+        x = torch.flatten(x, 1)
+
+        x = self.fc(x)
+        return x
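The hand-rolled CNN is replaced by an ImageNet-pretrained resnet18 with its classification head stripped (`children()[:-1]` keeps everything up to the global average pool), so the pooled 512-d feature is projected to the style embedding. A hedged usage sketch, not part of the commit:

style_encoder = StyleEncoder(d_style_emb=256)
z_s = style_encoder(torch.randn(1, 3, 512, 512))   # -> (1, 256); the adaptive pool makes the input size flexible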
swim/models/swim_gan.py
CHANGED
@@ -12,12 +12,14 @@ from lightning import LightningModule
 from diffusers import AutoencoderKL
 from diffusers.utils import make_image_grid
 
 from swim.utils.tensor_pool import GroupTensorPool
 
 from .style_encoder import StyleEncoder
 from .content_encoder import ContentEncoder
 from .decoder import Decoder
-from .discriminator import
+from .discriminator import FeatureDiscriminator
+
+import vision_aided_loss
 
 
 class SwimGAN(LightningModule):
@@ -26,16 +28,18 @@ class SwimGAN(LightningModule):
         self,
         channels: int = 128,
         z_c_channels: int = 512,
+        updown_channel_mults: List[int] = [1, 2, 4],
+        n_enc_resnet_blocks: int = 4,
+        n_dec_resnet_blocks: int = 6,
+        n_f_d_resnet_blocks: int = 2,
         n_styles: int = 5,
         d_style_emb: int = 128,
         input_size: int = 512,
         learning_rate: float = 1e-5,
+        weight_decay: float = 1e-4,
         lambda_cls: float = 10.0,
+        lambda_rec: float = 10.0,
+        lambda_cycle: float = 10.0,
         lambda_c_g: float = 1.0,
         lambda_x_g: float = 1.0,
         lambda_c_const: float = 1.0,
@@ -58,43 +62,34 @@ class SwimGAN(LightningModule):
         self.lambda_x_g = lambda_x_g
 
         self.content_encoder = ContentEncoder(
-            in_channels=4,
-            z_c_channels=z_c_channels,
             channels=channels,
+            z_c_channels=z_c_channels,
+            downsample_channel_mults=[1, 2, 4],
+            n_resnet_blocks=n_enc_resnet_blocks,
         )
 
-        self.style_encoder = StyleEncoder(
+        self.style_encoder = StyleEncoder(d_style_emb=d_style_emb)
 
         self.decoder = Decoder(
-            z_c_channels=z_c_channels,
-            out_channels=4,
             channels=channels,
+            z_c_channels=z_c_channels,
+            upsample_channel_mults=updown_channel_mults,
+            n_resnet_blocks=n_dec_resnet_blocks,
             d_style_emb=d_style_emb,
         )
 
-        self.vae: AutoencoderKL = AutoencoderKL.from_pretrained(
-            "stabilityai/sd-turbo", subfolder="vae"
-        )
-        self.vae.requires_grad_(False)
-        self.vae.eval()
-
         self.style_classifier = nn.Linear(d_style_emb, n_styles)
 
         # training only
-            channels=128,
-            channel_multipliers=[1, 2, 4, 8],
+        self.i_discriminator = vision_aided_loss.Discriminator(
+            cv_type="clip", loss_type="multilevel_sigmoid", device="cpu"
         )
+
+        self.f_discriminator = FeatureDiscriminator(
+            z_c_channels=z_c_channels,
             channels=z_c_channels,
+            d_style_emb=d_style_emb,
+            n_resnet_blocks=n_f_d_resnet_blocks,
         )
 
         self.style_pool = GroupTensorPool(n_styles, 256)
@@ -102,117 +97,100 @@ class SwimGAN(LightningModule):
 
         self.cls_loss = nn.CrossEntropyLoss()
 
-    def vae_decode(self, z):
-        return self.vae.decode(z / self.vae.config.scaling_factor).sample.clamp(-1, 1)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        with torch.no_grad():
-            z = self.vae_encode(x)
-        z_c = self.content_encoder(z)
-        z_s = self.style_encoder(z)
-        z_rec = self.decoder(z_c, z_s)
-
-        return z, z_c, z_s, z_rec
+    def on_fit_start(self):
+        for model in self.i_discriminator.cv_ensemble.models:
+            model.to(self.device)
+            model.requires_grad_(False)
 
     def training_step(self, batch, batch_idx):
         x = batch["images"]
         gt_style = batch["styles"]
 
+        g_opt, i_d_opt, f_d_opt = self.optimizers()
 
-        # train the cls
-        z_s = self.style_encoder(z)
+        # train the autoencoder
+        z_s = self.style_encoder(x)
+        z_c = self.content_encoder(x)
+        x_rec = self.decoder(z_c, z_s)
+
         style_logits = self.style_classifier(z_s)
         cls_loss = self.cls_loss(style_logits, gt_style)
 
-        self.manual_backward(cls_loss)
-        cls_opt.step()
-
-        # train the autoencoder
-        z_s = z_s.detach()
-        z_c = self.content_encoder(z)
-        z_rec = self.decoder(z_c, z_s)
-
-        rec_loss = F.l1_loss(z, z_rec)
+        rec_loss = F.l1_loss(x, x_rec)
 
         # # sample a random content and style feature
         z_c_hat, _ = self.content_pool.query(z_c, gt_style)
         z_s_hat, _ = self.style_pool.query(z_s, gt_style)
 
+        x1 = self.decoder(z_c, z_s_hat)
+        x2 = self.decoder(z_c_hat, z_s)
 
-        z_c_rec = self.content_encoder(
-        z_s_hat_rec = self.style_encoder(
+        z_c_rec = self.content_encoder(x1)
+        z_s_hat_rec = self.style_encoder(x1)
 
-        z_c_hat_rec = self.content_encoder(
-        z_s_rec = self.style_encoder(
+        z_c_hat_rec = self.content_encoder(x2)
+        z_s_rec = self.style_encoder(x2)
 
         c_const_loss = F.l1_loss(z_c, z_c_rec) + F.l1_loss(z_c_hat, z_c_hat_rec)
         s_const_loss = F.l1_loss(z_s, z_s_rec) + F.l1_loss(z_s_hat, z_s_hat_rec)
 
         # adversarial loss
+        i_g_loss = (
+            self.i_discriminator(x1, for_G=True).mean()
+            + self.i_discriminator(x2, for_G=True).mean()
         ) / 2
 
+        c_g_loss = self.f_discriminator(z_c, z_s, for_real=False).mean()
 
         g_loss = (
             self.lambda_rec * rec_loss
             + self.lambda_c_const * c_const_loss
             + self.lambda_s_const * s_const_loss
+            + self.lambda_x_g * i_g_loss
             + self.lambda_c_g * c_g_loss
+            + self.lambda_cls * cls_loss
        )
 
         g_opt.zero_grad()
         self.manual_backward(g_loss)
         g_opt.step()
 
+        # train the image discriminator
+        i_d_loss = (
+            self.i_discriminator(x, for_real=True).mean()
             + (
+                self.i_discriminator(x1.detach(), for_real=False).mean()
+                + self.i_discriminator(x2.detach(), for_real=False).mean()
             )
             / 2
         ) / 2
 
+        i_d_opt.zero_grad()
+        self.manual_backward(i_d_loss)
+        i_d_opt.step()
 
+        # train the feature discriminator
+        f_d_loss = (
+            self.f_discriminator(z_c.detach(), z_s.detach(), for_real=True).mean()
             + (
+                self.f_discriminator(z_c.detach(), z_s_hat, for_real=False).mean()
+                + self.f_discriminator(z_c_hat, z_s.detach(), for_real=False).mean()
            )
            / 2
        ) / 2
 
+        f_d_opt.zero_grad()
+        self.manual_backward(f_d_loss)
+        f_d_opt.step()
 
         self.log_dict(
             {
                 "train/rec_loss": rec_loss,
                 "train/cls_loss": cls_loss,
+                "train/i_g_loss": i_g_loss,
+                "train/i_d_loss": i_d_loss,
                 "train/c_g_loss": c_g_loss,
+                "train/f_d_loss": f_d_loss,
                 "train/c_const_loss": c_const_loss,
                 "train/s_const_loss": s_const_loss,
             },
@@ -223,25 +201,20 @@ class SwimGAN(LightningModule):
 
     def validation_step(self, batch, batch_idx):
         x = batch["images"]
-        gt_style_logits = batch["styles"]  # B x n_styles
 
         x = x[torch.randperm(x.shape[0])]
-        z_rec = self.decoder(z_c, z_s)
+        z_c = self.content_encoder(x)
+        z_s = self.style_encoder(x)
+        x_rec = self.decoder(z_c, z_s)
 
         x1, x2 = x.chunk(2, dim=0)
 
-        x_rec = self.vae_decode(z_rec)
         x1_rec, x2_rec = x_rec.chunk(2, dim=0)
 
         z1_c, z2_c = z_c.chunk(2, dim=0)
         z1_s, z2_s = z_s.chunk(2, dim=0)
-        x1_swap = self.vae_decode(z1_swap)
-        x2_swap = self.vae_decode(z2_swap)
+        x1_swap = self.decoder(z1_c, z2_s)
+        x2_swap = self.decoder(z2_c, z1_s)
 
         if self.trainer.is_global_zero:
             x1_img = self.postprocess_images(x1)
@@ -266,7 +239,7 @@ class SwimGAN(LightningModule):
                 wandb.Image(image, caption="orig | rec | swap") for image in images
             ]
 
-            wandb.log({"val/samples": images})
+            # wandb.log({"val/samples": images})
 
         self.log("val/lpips", -self.global_step, sync_dist=True)
 
@@ -282,28 +255,24 @@ class SwimGAN(LightningModule):
 
     def configure_optimizers(self):
         g_opt = torch.optim.AdamW(
-            list(self.content_encoder.parameters())
+            list(self.content_encoder.parameters())
+            + list(self.style_encoder.parameters())
+            + list(self.style_classifier.parameters())
+            + list(self.decoder.parameters()),
             lr=self.learning_rate,
+            weight_decay=self.weight_decay,
         )
 
-        cls_opt = torch.optim.AdamW(
-            list(self.style_encoder.parameters())
-            + list(self.style_classifier.parameters()),
-            lr=self.learning_rate * 10,
-            weight_decay=1e-4,
-        )
-
+        i_d_opt = torch.optim.AdamW(
+            list(self.i_discriminator.parameters()),
+            lr=self.learning_rate,
+            weight_decay=self.weight_decay,
         )
 
+        f_d_opt = torch.optim.AdamW(
+            list(self.f_discriminator.parameters()),
+            lr=self.learning_rate,
+            weight_decay=self.weight_decay,
        )
 
-        return [g_opt,
+        return [g_opt, i_d_opt, f_d_opt]
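training_step now drives three optimizers itself (`self.optimizers()`, `manual_backward`, explicit `zero_grad()/step()`), which in Lightning requires manual optimization to be enabled. That flag is not visible in these hunks, so the line below is an assumption about what the rest of `__init__` contains rather than something shown in the commit:

self.automatic_optimization = False   # assumed: required for manual_backward with multiple optimizers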
swim/train.py
CHANGED
@@ -91,9 +91,6 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     model.compile()
 
     log.info("Starting training!")
-    from swim.models.swim_gan import SwimGAN
-
-    # model = SwimGAN.load_from_checkpoint(cfg.get("ckpt_path"), strict=False)
     trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
 
     train_metrics = trainer.callback_metrics