cc
Browse files
scripts/compute_latent.py
CHANGED
@@ -48,7 +48,7 @@ def compute_latent(model_name: str, swim_dir: str, batch_size: int):
|
|
48 |
with torch.no_grad():
|
49 |
images = torch.stack(images).cuda()
|
50 |
latents = model.encode(images).latent_dist.mode()
|
51 |
-
latents = latents.
|
52 |
|
53 |
for name, latent in zip(image_names, latents):
|
54 |
torch.save(
|
|
|
48 |
with torch.no_grad():
|
49 |
images = torch.stack(images).cuda()
|
50 |
latents = model.encode(images).latent_dist.mode()
|
51 |
+
latents = latents.detach().cpu().numpy()
|
52 |
|
53 |
for name, latent in zip(image_names, latents):
|
54 |
torch.save(
|