{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "407171be-ab46-442b-a0bd-83ca75173eba", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AsymmetricAutoencoderKL(\n", " (encoder): Encoder(\n", " (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (down_blocks): ModuleList(\n", " (0): DownEncoderBlock2D(\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (downsamplers): ModuleList(\n", " (0): Downsample2D(\n", " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n", " )\n", " )\n", " )\n", " (1): DownEncoderBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (downsamplers): ModuleList(\n", " (0): Downsample2D(\n", " (conv): Conv2d(256, 256, 
kernel_size=(3, 3), stride=(2, 2))\n", " )\n", " )\n", " )\n", " (2): DownEncoderBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (downsamplers): ModuleList(\n", " (0): Downsample2D(\n", " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))\n", " )\n", " )\n", " )\n", " (3): DownEncoderBlock2D(\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", " )\n", " (mid_block): UNetMidBlock2D(\n", " (attentions): ModuleList(\n", " (0): Attention(\n", " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (to_q): Linear(in_features=512, out_features=512, bias=True)\n", " (to_k): Linear(in_features=512, out_features=512, bias=True)\n", " (to_v): Linear(in_features=512, out_features=512, bias=True)\n", " (to_out): 
ModuleList(\n", " (0): Linear(in_features=512, out_features=512, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", " (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv_act): SiLU()\n", " (conv_out): Conv2d(512, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " (decoder): MaskConditionDecoder(\n", " (conv_in): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (up_blocks): ModuleList(\n", " (0-1): 2 x UpDecoderBlock2D(\n", " (resnets): ModuleList(\n", " (0-3): 4 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (upsamplers): ModuleList(\n", " (0): Upsample2D(\n", " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (2): UpDecoderBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " 
(conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1-3): 3 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (upsamplers): ModuleList(\n", " (0): Upsample2D(\n", " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (3): UpDecoderBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1-3): 3 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", " )\n", " (mid_block): UNetMidBlock2D(\n", " (attentions): ModuleList(\n", " (0): Attention(\n", " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (to_q): Linear(in_features=512, out_features=512, bias=True)\n", " (to_k): Linear(in_features=512, out_features=512, bias=True)\n", " (to_v): Linear(in_features=512, out_features=512, bias=True)\n", " (to_out): ModuleList(\n", " (0): 
Linear(in_features=512, out_features=512, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", " (condition_encoder): MaskConditionEncoder(\n", " (layers): Sequential(\n", " (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (2): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (3): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (4): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " )\n", " )\n", " (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (conv_act): SiLU()\n", " (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " (quant_conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))\n", " (post_quant_conv): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))\n", ")\n" ] } ], "source": [ "from diffusers.models import AsymmetricAutoencoderKL\n", "import torch\n", "\n", "config = {\n", " \"_class_name\": \"AsymmetricAutoencoderKL\",\n", " \"act_fn\": \"silu\",\n", " \"down_block_out_channels\": [128, 256, 512, 512],\n", " \"down_block_types\": [\n", " \"DownEncoderBlock2D\",\n", " \"DownEncoderBlock2D\",\n", " \"DownEncoderBlock2D\",\n", " \"DownEncoderBlock2D\",\n", " ],\n", " \"in_channels\": 3,\n", " \"latent_channels\": 4,\n", " \"norm_num_groups\": 32,\n", " \"out_channels\": 3,\n", " \"sample_size\": 1024,\n", " \"scaling_factor\": 0.13025,\n", " \"shift_factor\": 0,\n", 
" \"up_block_out_channels\": [128, 256, 512, 512],\n", " \"up_block_types\": [\n", " \"UpDecoderBlock2D\",\n", " \"UpDecoderBlock2D\",\n", " \"UpDecoderBlock2D\",\n", " \"UpDecoderBlock2D\",\n", " ],\n", "}\n", "\n", "# Build the asymmetric VAE from the config dict.\n", "# NOTE: _class_name and shift_factor exist in the config for bookkeeping only;\n", "# they are not AsymmetricAutoencoderKL constructor arguments and are not passed.\n", "vae = AsymmetricAutoencoderKL(\n", "    act_fn=config[\"act_fn\"],\n", "    down_block_out_channels=config[\"down_block_out_channels\"],\n", "    down_block_types=config[\"down_block_types\"],\n", "    in_channels=config[\"in_channels\"],\n", "    latent_channels=config[\"latent_channels\"],\n", "    norm_num_groups=config[\"norm_num_groups\"],\n", "    out_channels=config[\"out_channels\"],\n", "    sample_size=config[\"sample_size\"],\n", "    scaling_factor=config[\"scaling_factor\"],\n", "    up_block_out_channels=config[\"up_block_out_channels\"],\n", "    up_block_types=config[\"up_block_types\"],\n", "    layers_per_down_block=2,\n", "    layers_per_up_block=3,\n", ")\n", "\n", "vae.save_pretrained(\"asymmetric_vae_empty\")\n", "print(vae)" ] }, { "cell_type": "code", "execution_count": 3, "id": "a2950158-5203-42b9-8791-e231ddbf1063", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The config attributes {'block_out_channels': [128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. 
Please verify your config.json configuration file.\n", "Перенос весов: 100%|██████████| 248/248 [00:00<00:00, 30427.29it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Статистика переноса: {'перенесено': 248, 'несовпадение_размеров': 0, 'пропущено': 0}\n", "Неперенесенные ключи в новой модели:\n", "decoder.condition_encoder.layers.0.bias\n", "decoder.condition_encoder.layers.0.weight\n", "decoder.condition_encoder.layers.1.bias\n", "decoder.condition_encoder.layers.1.weight\n", "decoder.condition_encoder.layers.2.bias\n", "decoder.condition_encoder.layers.2.weight\n", "decoder.condition_encoder.layers.3.bias\n", "decoder.condition_encoder.layers.3.weight\n", "decoder.condition_encoder.layers.4.bias\n", "decoder.condition_encoder.layers.4.weight\n", "decoder.up_blocks.0.resnets.3.conv1.bias\n", "decoder.up_blocks.0.resnets.3.conv1.weight\n", "decoder.up_blocks.0.resnets.3.conv2.bias\n", "decoder.up_blocks.0.resnets.3.conv2.weight\n", "decoder.up_blocks.0.resnets.3.norm1.bias\n", "decoder.up_blocks.0.resnets.3.norm1.weight\n", "decoder.up_blocks.0.resnets.3.norm2.bias\n", "decoder.up_blocks.0.resnets.3.norm2.weight\n", "decoder.up_blocks.1.resnets.3.conv1.bias\n", "decoder.up_blocks.1.resnets.3.conv1.weight\n", "decoder.up_blocks.1.resnets.3.conv2.bias\n", "decoder.up_blocks.1.resnets.3.conv2.weight\n", "decoder.up_blocks.1.resnets.3.norm1.bias\n", "decoder.up_blocks.1.resnets.3.norm1.weight\n", "decoder.up_blocks.1.resnets.3.norm2.bias\n", "decoder.up_blocks.1.resnets.3.norm2.weight\n", "decoder.up_blocks.2.resnets.3.conv1.bias\n", "decoder.up_blocks.2.resnets.3.conv1.weight\n", "decoder.up_blocks.2.resnets.3.conv2.bias\n", "decoder.up_blocks.2.resnets.3.conv2.weight\n", "decoder.up_blocks.2.resnets.3.norm1.bias\n", "decoder.up_blocks.2.resnets.3.norm1.weight\n", "decoder.up_blocks.2.resnets.3.norm2.bias\n", "decoder.up_blocks.2.resnets.3.norm2.weight\n", "decoder.up_blocks.3.resnets.3.conv1.bias\n", 
"decoder.up_blocks.3.resnets.3.conv1.weight\n", "decoder.up_blocks.3.resnets.3.conv2.bias\n", "decoder.up_blocks.3.resnets.3.conv2.weight\n", "decoder.up_blocks.3.resnets.3.norm1.bias\n", "decoder.up_blocks.3.resnets.3.norm1.weight\n", "decoder.up_blocks.3.resnets.3.norm2.bias\n", "decoder.up_blocks.3.resnets.3.norm2.weight\n" ] } ], "source": [ "import torch\n", "from diffusers import AsymmetricAutoencoderKL, AutoencoderKL\n", "from tqdm import tqdm\n", "\n", "def log(message):\n", "    \"\"\"Print a diagnostic message (kept as a hook so output can be redirected later).\"\"\"\n", "    print(message)\n", "\n", "def main():\n", "    \"\"\"Transfer matching weights from a symmetric VAE into the asymmetric VAE skeleton.\n", "\n", "    Keys present in both state dicts with identical shapes are copied; the\n", "    remaining keys (mask-condition encoder, extra decoder resnets) keep their\n", "    fresh initialization and are listed at the end.\n", "    \"\"\"\n", "    checkpoint_path_old = \"vae\"\n", "    checkpoint_path_new = \"asymmetric_vae_empty\"\n", "    device = \"cuda\"\n", "    dtype = torch.float32\n", "\n", "    # Load donor (old_vae) and freshly initialized target (new_vae).\n", "    # Both objects are autoencoders, so the earlier *_unet names were misleading.\n", "    old_vae = AutoencoderKL.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n", "    new_vae = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n", "\n", "    old_state_dict = old_vae.state_dict()\n", "    new_state_dict = new_vae.state_dict()\n", "\n", "    transferred_state_dict = {}\n", "    transfer_stats = {\n", "        \"перенесено\": 0,\n", "        \"несовпадение_размеров\": 0,\n", "        \"пропущено\": 0\n", "    }\n", "\n", "    transferred_keys = set()\n", "\n", "    # Walk every key of the old model; copy it only on exact name + shape match.\n", "    for old_key in tqdm(old_state_dict.keys(), desc=\"Перенос весов\"):\n", "        new_key = old_key\n", "\n", "        if new_key in new_state_dict:\n", "            if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n", "                transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n", "                transferred_keys.add(new_key)\n", "                transfer_stats[\"перенесено\"] += 1\n", "            else:\n", "                log(f\"✗ Несовпадение размеров: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})\")\n", "                transfer_stats[\"несовпадение_размеров\"] += 1\n", "        else:\n", "            log(f\"? Ключ не найден в новой модели: {old_key} -> {old_state_dict[old_key].shape}\")\n", "            transfer_stats[\"пропущено\"] += 1\n", "\n", "    # Merge the transferred tensors into the new model's state dict and persist.\n", "    new_state_dict.update(transferred_state_dict)\n", "    new_vae.load_state_dict(new_state_dict)\n", "    new_vae.save_pretrained(\"asymmetric_vae\")\n", "\n", "    # Report the keys that kept their fresh (random) initialization.\n", "    non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n", "\n", "    print(\"Статистика переноса:\", transfer_stats)\n", "    print(\"Неперенесенные ключи в новой модели:\")\n", "    for key in non_transferred_keys:\n", "        print(key)\n", "\n", "if __name__ == \"__main__\":\n", "    main()" ] }, { "cell_type": "code", "execution_count": 1, "id": "b316ee6c-d295-4396-9177-78e39a53055b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The config attributes {'block_out_channels': [128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. 
Please verify your config.json configuration file.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ok\n" ] } ], "source": [ "import torch\n", "\n", "from torchvision import transforms, utils\n", "\n", "import diffusers\n", "from diffusers import AsymmetricAutoencoderKL\n", "\n", "from diffusers.utils import load_image\n", "\n", "def crop_image_to_nearest_divisible_by_8(img):\n", "    \"\"\"Center-crop a CHW image tensor so height and width are divisible by 8.\n", "\n", "    The VAE downsamples spatially by a factor of 8, so both dims must be\n", "    multiples of 8; the input is returned unchanged when it already complies.\n", "    \"\"\"\n", "    if img.shape[1] % 8 == 0 and img.shape[2] % 8 == 0:\n", "        return img\n", "    # Closest lower resolution divisible by 8\n", "    new_height = img.shape[1] - (img.shape[1] % 8)\n", "    new_width = img.shape[2] - (img.shape[2] % 8)\n", "    # BUGFIX: CenterCrop takes only a target size. It slices pixels and does no\n", "    # resampling, so it has no `interpolation` parameter — passing one raised\n", "    # TypeError whenever a crop was actually needed.\n", "    transform = transforms.CenterCrop((new_height, new_width))\n", "    return transform(img).to(torch.float32).clamp(-1, 1)\n", "\n", "to_tensor = transforms.ToTensor()\n", "\n", "device = \"cuda\"\n", "dtype = torch.float16\n", "vae = AsymmetricAutoencoderKL.from_pretrained(\"asymmetric_vae\", torch_dtype=dtype).to(device).eval()\n", "\n", "image = load_image(\"123456789.jpg\")\n", "image = crop_image_to_nearest_divisible_by_8(to_tensor(image)).unsqueeze(0).to(device, dtype=dtype)\n", "\n", "# Pure inference: skip autograd graph construction to save memory.\n", "with torch.no_grad():\n", "    upscaled_image = vae(image).sample\n", "\n", "# Save the reconstructed image\n", "utils.save_image(upscaled_image, \"test.png\")\n", "print('ok')" ] }, { "cell_type": "code", "execution_count": null, "id": "5a01b8e9-73c9-4da7-a097-e334019bd8e9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" 
} }, "nbformat": 4, "nbformat_minor": 5 }