# swim_new/cc.py
# Author: qninhdt -- commit b759b90 ("cc")
import torch
from torchinfo import summary
from swim.models.content_encoder import ContentEncoder
from swim.models.decoder import Decoder
from swim.models.style_encoder import StyleEncoder
from swim.models.discriminator import FeatureDiscriminator
import vision_aided_loss
# from swim.models.swim_gan import SwimGAN
# Scratch script: print torchinfo layer summaries of the SWIM sub-models,
# driven by random dummy inputs. Only the content encoder is summarised by
# default; uncomment another `summary(...)` line to inspect a different model.

# Use the GPU when one is available, otherwise fall back to CPU so the script
# also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Random dummy inputs, allocated directly on the target device (no host->device
# copy). NOTE(review): the shapes are assumed to match what each model's
# forward() expects -- confirm against the model definitions.
image = torch.randn(1, 3, 512, 512, device=device)  # presumably (N, C, H, W)
style_emb = torch.randn(1, 256, device=device)
content = torch.randn(1, 512, 64, 64, device=device)

content_encoder = ContentEncoder().to(device)
decoder = Decoder().to(device)
style_encoder = StyleEncoder().to(device)
discriminator = FeatureDiscriminator().to(device)
# i_discriminator = vision_aided_loss.Discriminator(
#     cv_type="clip", loss_type="multilevel_sigmoid_s", device=str(device)
# ).to(device)

summary(content_encoder, input_data=(image,))
# summary(decoder, input_data=(content, style_emb))
# summary(style_encoder, input_data=image)
# summary(discriminator, input_data=(content, style_emb))
# summary(i_discriminator, input_data=(image,))