qninhdt commited on
Commit
82e5f44
·
1 Parent(s): db72572
Files changed (1) hide show
  1. scripts/compute_latent.py +4 -3
scripts/compute_latent.py CHANGED
@@ -4,6 +4,7 @@ import PIL
4
 
5
  from itertools import batched
6
 
 
7
  import torch
8
  import torchvision.transforms as T
9
  from diffusers import AutoencoderKL
@@ -51,11 +52,11 @@ def compute_latent(model_name: str, swim_dir: str, batch_size: int):
51
  latents = latents.detach().cpu().numpy()
52
 
53
  for name, latent in zip(image_names, latents):
54
- torch.save(
55
- latent,
56
  os.path.join(
57
- output, name.replace(".jpg", ".pt").replace(".png", ".pt")
58
  ),
 
59
  )
60
 
61
 
 
4
 
5
  from itertools import batched
6
 
7
+ import numpy as np
8
  import torch
9
  import torchvision.transforms as T
10
  from diffusers import AutoencoderKL
 
52
  latents = latents.detach().cpu().numpy()
53
 
54
  for name, latent in zip(image_names, latents):
55
+ np.save(
 
56
  os.path.join(
57
+ output, name.replace(".jpg", ".npy").replace(".png", ".npy")
58
  ),
59
+ latent,
60
  )
61
 
62