"""Inspect SWIM sub-networks with torchinfo.

Builds the content encoder, decoder, style encoder and feature
discriminator, creates random dummy inputs, and runs ``torchinfo.summary``
on the content encoder.  The remaining ``summary`` calls are left
commented out as toggles for inspecting the other networks.
"""

import torch
from torchinfo import summary

from swim.models.content_encoder import ContentEncoder
from swim.models.decoder import Decoder
from swim.models.style_encoder import StyleEncoder
from swim.models.discriminator import FeatureDiscriminator
import vision_aided_loss

# from swim.models.swim_gan import SwimGAN

# Use the GPU when one is present, otherwise fall back to the CPU so the
# script still runs on machines without CUDA instead of raising.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Random dummy inputs, allocated directly on the target device.
image = torch.randn(1, 3, 512, 512, device=device)
style_emb = torch.randn(1, 256, device=device)
content = torch.randn(1, 512, 64, 64, device=device)

content_encoder = ContentEncoder().to(device)
decoder = Decoder().to(device)
style_encoder = StyleEncoder().to(device)
discriminator = FeatureDiscriminator().to(device)
# i_discriminator = vision_aided_loss.Discriminator(
#     cv_type="clip", loss_type="multilevel_sigmoid_s", device=device
# ).to(device)

summary(content_encoder, input_data=(image,))
# summary(decoder, input_data=(content, style_emb))
# summary(style_encoder, input_data=image)
# summary(discriminator, input_data=(content, style_emb))
# summary(i_discriminator, input_data=(image,))