|
import torch |
|
from torchinfo import summary |
|
|
|
from swim.models.content_encoder import ContentEncoder |
|
from swim.models.decoder import Decoder |
|
from swim.models.style_encoder import StyleEncoder |
|
from swim.models.discriminator import FeatureDiscriminator |
|
import vision_aided_loss |
|
|
|
|
|
|
|
# Smoke-test script: instantiate the SWIM models and print a layer summary of
# the content encoder for a dummy input.
#
# Fall back to CPU when no CUDA device is available, so the script still runs
# (more slowly) on CPU-only machines instead of failing at the first
# `.cuda()` / `.to("cuda")` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy inputs, allocated directly on the target device (no host allocation
# followed by a copy). Shapes are taken unchanged from the original script;
# presumably `image` is (batch, channels, H, W), `style_emb` a 256-d style
# vector and `content` a 512-channel 64x64 feature map — TODO confirm against
# the model definitions.
image = torch.randn(1, 3, 512, 512, device=device)
style_emb = torch.randn(1, 256, device=device)
content = torch.randn(1, 512, 64, 64, device=device)

# NOTE(review): style_emb, content, decoder, style_encoder and discriminator
# are built but not used below. Presumably they are kept for summarising the
# other models; the matching calls depend on each model's forward signature,
# which should be confirmed first.
content_encoder = ContentEncoder().to(device)
decoder = Decoder().to(device)
style_encoder = StyleEncoder().to(device)
discriminator = FeatureDiscriminator().to(device)

# Run `image` through the content encoder and print torchinfo's layer table.
summary(content_encoder, input_data=(image,))
|
|
|
|
|
|
|
|
|
|