recoilme commited on
Commit
21e5172
·
1 Parent(s): 68d00e7
asymmetric_vae_new/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2b0689cd2f3a6f81c14a95e1f2a7c4cee6b97b51f34700c5983ee2f28df17ef6
3
  size 421473052
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69c5a55938fb7e33849a58865e243ee02b3ad9cf6ff5a6f6b97ad025e38d64e0
3
  size 421473052
eval_asym.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import logging
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.data as data
6
+ import lpips
7
+ from tqdm import tqdm
8
+ from torchvision.transforms import (
9
+ Compose,
10
+ Resize,
11
+ ToTensor,
12
+ CenterCrop,
13
+ )
14
+ from diffusers import AutoencoderKL,AsymmetricAutoencoderKL
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ warnings.filterwarnings(
20
+ "ignore",
21
+ ".*Found keys that are not in the model state dict but in the checkpoint.*",
22
+ )
23
+
24
+ DEVICE = "cuda"
25
+ DTYPE = torch.float16
26
+ SHORT_AXIS_SIZE = 256
27
+ batch_size = 1
28
+
29
+ NAMES = [
30
+ # "asymmetric_vae",
31
+ # "asymmetric_vae_new",
32
+ # "madebyollin/sdxl-vae-fp16-fix",
33
+ # "KBlueLeaf/EQ-SDXL-VAE ",
34
+ "AiArtLab/simplevae ",
35
+ ]
36
+ BASE_MODELS = [
37
+ # "./asymmetric_vae",
38
+ # "./asymmetric_vae_new",
39
+ # "madebyollin/sdxl-vae-fp16-fix",
40
+ # "KBlueLeaf/EQ-SDXL-VAE",
41
+ "AiArtLab/simplevae",
42
+ ]
43
+ SUB_FOLDERS = [
44
+ "sdxs_vae",
45
+ # None,
46
+ # None,
47
+ # "sdxl_vae"
48
+ ]
49
+
50
+ def process(x):
51
+ return x * 2 - 1
52
+
53
+ def deprocess(x):
54
+ return x * 0.5 + 0.5
55
+
56
+ import torch.utils.data as data
57
+ from datasets import load_dataset
58
+
59
+ class ImageNetDataset(data.IterableDataset):
60
+ def __init__(self, split, transform=None, max_len=10, streaming=True):
61
+ self.split = split
62
+ self.transform = transform
63
+ self.dataset = load_dataset("evanarlian/imagenet_1k_resized_256", split=split, streaming=streaming)
64
+ self.max_len = max_len
65
+ self.iterator = iter(self.dataset)
66
+
67
+ def __iter__(self):
68
+ for i, entry in enumerate(self.iterator):
69
+ if self.max_len and i >= self.max_len:
70
+ break
71
+ img = entry["image"]
72
+ target = entry["label"]
73
+ if self.transform is not None:
74
+ img = self.transform(img)
75
+ yield img, target
76
+
77
+ if __name__ == "__main__":
78
+ lpips_loss = torch.compile(
79
+ lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
80
+ )
81
+
82
+ @torch.compile
83
+ def metrics(inp, recon):
84
+ mse = F.mse_loss(inp, recon)
85
+ psnr = 10 * torch.log10(1 / mse)
86
+ return (
87
+ mse.cpu(),
88
+ psnr.cpu(),
89
+ lpips_loss(inp, recon, normalize=True).mean().cpu(),
90
+ )
91
+
92
+ transform = Compose(
93
+ [
94
+ Resize(SHORT_AXIS_SIZE),
95
+ CenterCrop(SHORT_AXIS_SIZE),
96
+ ToTensor(),
97
+ ]
98
+ )
99
+ valid_dataset = ImageNetDataset("val", transform=transform, max_len=50000, streaming=True)
100
+ valid_loader = data.DataLoader(
101
+ valid_dataset,
102
+ batch_size=batch_size,
103
+ shuffle=False,
104
+ num_workers=2,
105
+ pin_memory=True,
106
+ pin_memory_device=DEVICE,
107
+ )
108
+
109
+ # Проверяем, что данные грузятся
110
+ for batch in valid_loader:
111
+ print("Batch shape:", batch[0].shape)
112
+ break
113
+
114
+ logger.info("Loading models...")
115
+ vaes = []
116
+ for base_model, sub_folder in zip(
117
+ BASE_MODELS, SUB_FOLDERS
118
+ ):
119
+ vae = AsymmetricAutoencoderKL.from_pretrained(base_model, subfolder=sub_folder)
120
+ vae = vae.to(DTYPE).eval().requires_grad_(False).to(DEVICE)
121
+ vae.encoder = torch.compile(vae.encoder)
122
+ vae.decoder = torch.compile(vae.decoder)
123
+ vaes.append(torch.compile(vae))
124
+
125
+ logger.info("Running Validation")
126
+ total = 0
127
+ all_latents = [[] for _ in range(len(vaes))]
128
+ all_mse = [[] for _ in range(len(vaes))]
129
+ all_psnr = [[] for _ in range(len(vaes))]
130
+ all_lpips = [[] for _ in range(len(vaes))]
131
+
132
+ for idx, batch in enumerate(tqdm(valid_loader)):
133
+ image = batch[0].to(DEVICE)
134
+ test_inp = process(image).to(DTYPE)
135
+ batch_size = test_inp.size(0)
136
+
137
+ for i, vae in enumerate(vaes):
138
+ latent = vae.encode(test_inp).latent_dist.mode()
139
+ recon = deprocess(vae.decode(latent).sample.float())
140
+ all_latents[i].append(latent.cpu().float())
141
+ mse, psnr, lpips_ = metrics(image, recon)
142
+ all_mse[i].append(mse.cpu() * batch_size)
143
+ all_psnr[i].append(psnr.cpu() * batch_size)
144
+ all_lpips[i].append(lpips_.cpu() * batch_size)
145
+
146
+ total += batch_size
147
+
148
+ for i in range(len(vaes)):
149
+ all_latents[i] = torch.cat(all_latents[i], dim=0)
150
+ all_mse[i] = torch.stack(all_mse[i]).sum() / total
151
+ all_psnr[i] = torch.stack(all_psnr[i]).sum() / total
152
+ all_lpips[i] = torch.stack(all_lpips[i]).sum() / total
153
+
154
+ logger.info(
155
+ f" - {NAMES[i]}: MSE: {all_mse[i]:.3e}, PSNR: {all_psnr[i]:.4f}, "
156
+ f"LPIPS: {all_lpips[i]:.4f}"
157
+ )
158
+
159
+ logger.info("End")
samples/sample_0_0.jpg DELETED

Git LFS Details

  • SHA256: fa157903dd5a4118d9c38e32c25c5a02a3eeaddb59d3a1c9d8fe7e9eb57e3f14
  • Pointer size: 130 Bytes
  • Size of remote file: 98 kB
samples/sample_0_1.jpg DELETED

Git LFS Details

  • SHA256: 7cba73cbeeb41f97f6247043e00a5346cf10f6bf67f4ffa4ac8a736c6841a2be
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
samples/sample_0_2.jpg DELETED

Git LFS Details

  • SHA256: 2cdfd5107c48e41eb4d9475b9360f2c5a98b25509649e37df9eac75065ffbd96
  • Pointer size: 130 Bytes
  • Size of remote file: 93.4 kB
samples/sample_673_0.jpg DELETED

Git LFS Details

  • SHA256: ecb6610fe8119c402581c2181181aea871f7a6f3a211b48c1927cea878d9babb
  • Pointer size: 130 Bytes
  • Size of remote file: 95.5 kB
samples/sample_673_1.jpg DELETED

Git LFS Details

  • SHA256: e370fb4119a38245baad69f7e243506d69e40437878253e91d683ebba1f443af
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
samples/sample_673_2.jpg DELETED

Git LFS Details

  • SHA256: ff7edcb0dbc7a36cd3a5a344e4a47b6e13ea1153455c115b738025beb2d45fbc
  • Pointer size: 130 Bytes
  • Size of remote file: 90.3 kB