|
|
|
import os |
|
import math |
|
import re |
|
import torch |
|
import numpy as np |
|
import random |
|
import gc |
|
from datetime import datetime |
|
from pathlib import Path |
|
|
|
import torchvision.transforms as transforms |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader, Dataset |
|
from torch.optim.lr_scheduler import LambdaLR |
|
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL |
|
from accelerate import Accelerator |
|
from PIL import Image, UnidentifiedImageError |
|
from tqdm import tqdm |
|
import bitsandbytes as bnb |
|
import wandb |
|
import lpips |
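
# ---------------------------------------------------------------------------
# Fine-tunes the decoder of a (possibly asymmetric) diffusers VAE with an
# L1 + LPIPS reconstruction objective. Typical invocation (assuming this
# file is saved as train_vae.py): accelerate launch train_vae.py
# ---------------------------------------------------------------------------

# --- Configuration: paths, hyperparameters, and loss-balancing knobs. ---
# Note: "project" doubles as the wandb project name and the from_pretrained path.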
|
|
|
|
|
ds_path = "/workspace/png" |
|
project = "asymmetric_vae" |
|
batch_size = 2 |
|
base_learning_rate = 1e-6 |
|
min_learning_rate = 8e-7 |
|
num_epochs = 8 |
|
sample_interval_share = 10 |
|
use_wandb = True |
|
save_model = True |
|
use_decay = True |
|
asymmetric = True |
|
optimizer_type = "adam8bit" |
|
dtype = torch.float32 |
|
|
|
model_resolution = 512 |
|
|
|
high_resolution = 1024 |
|
limit = 0 |
|
save_barrier = 1.03 |
|
warmup_percent = 0.01 |
|
percentile_clipping = 95 |
|
beta2 = 0.97 |
|
eps = 1e-6 |
|
clip_grad_norm = 1.0 |
|
mixed_precision = "no" |
|
gradient_accumulation_steps = 8 |
|
generated_folder = "samples" |
|
save_as = "asymmetric_vae_new" |
|
perceptual_loss_weight = 0.03 |
|
num_workers = 0 |
|
device = None |
|
|
|
|
|
lpips_ratio = 0.9 |
|
|
|
min_perceptual_weight = 0.1 |
|
max_perceptual_weight = 99 |
|
|
|
|
|
resize_long_side = 1280 |
|
|
|
Path(generated_folder).mkdir(parents=True, exist_ok=True) |
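
# Accelerator handles device placement, gradient accumulation, and (optional)
# mixed precision; accumulate() below makes each optimizer step span
# gradient_accumulation_steps micro-batches.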
|
|
|
accelerator = Accelerator( |
|
mixed_precision=mixed_precision, |
|
gradient_accumulation_steps=gradient_accumulation_steps |
|
) |
|
device = accelerator.device |
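
# Seed all RNGs from today's date so reruns on the same day are reproducible.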
|
|
|
|
|
seed = int(datetime.now().strftime("%Y%m%d")) |
|
torch.manual_seed(seed) |
|
np.random.seed(seed) |
|
random.seed(seed) |
|
|
|
torch.backends.cudnn.benchmark = True |
|
|
|
|
|
if use_wandb and accelerator.is_main_process: |
|
wandb.init(project=project, config={ |
|
"batch_size": batch_size, |
|
"base_learning_rate": base_learning_rate, |
|
"num_epochs": num_epochs, |
|
"optimizer_type": optimizer_type, |
|
"model_resolution": model_resolution, |
|
"high_resolution": high_resolution, |
|
"gradient_accumulation_steps": gradient_accumulation_steps, |
|
}) |
|
|
|
|
|
if model_resolution == high_resolution and not asymmetric:
|
vae = AutoencoderKL.from_pretrained(project).to(dtype) |
|
else: |
|
vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype) |
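
# Freeze the whole VAE, then selectively unfreeze the decoder (up_blocks and
# mid_block). The encoder stays fixed, so the latent space is preserved and
# only the reconstruction path is trained.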
|
|
|
|
|
for p in vae.parameters(): |
|
p.requires_grad = False |
|
|
|
decoder = getattr(vae, "decoder", None) |
|
if decoder is None: |
|
raise RuntimeError("vae.decoder not found — не могу применить стратегию разморозки. Проверь структуру модели.") |
|
|
|
unfrozen_param_names = [] |
|
|
|
if not hasattr(decoder, "up_blocks"): |
|
raise RuntimeError("decoder.up_blocks не найдены — ожидается список блоков декодера.") |
|
|
|
|
|
n_up = len(decoder.up_blocks) |
|
start_idx = 0 |
|
for idx in range(start_idx, n_up): |
|
block = decoder.up_blocks[idx] |
|
for name, p in block.named_parameters(): |
|
p.requires_grad = True |
|
unfrozen_param_names.append(f"decoder.up_blocks.{idx}.{name}") |
|
|
|
if hasattr(decoder, "mid_block"): |
|
for name, p in decoder.mid_block.named_parameters(): |
|
p.requires_grad = True |
|
unfrozen_param_names.append(f"decoder.mid_block.{name}") |
|
else: |
|
print("[WARN] decoder.mid_block не найден — mid_block не разморожен.") |
|
|
|
print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:") |
|
for nm in unfrozen_param_names[:200]: |
|
print(" ", nm) |
|
|
|
|
|
trainable_module = vae.decoder |
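
# Recursively collects PNGs under root_dir, drops unreadable files via
# Image.verify(), and optionally shrinks the long side to resize_long_side.
# (The resolution argument is stored but cropping happens in collate_fn.)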
|
|
|
|
|
class PngFolderDataset(Dataset): |
|
def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0): |
|
|
|
self.root_dir = root_dir |
|
self.resolution = resolution |
|
self.paths = [] |
|
|
|
for root, _, files in os.walk(root_dir): |
|
for fname in files: |
|
if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)): |
|
self.paths.append(os.path.join(root, fname)) |
|
|
|
if limit: |
|
self.paths = self.paths[:limit] |
|
|
|
valid = [] |
|
for p in self.paths: |
|
try: |
|
with Image.open(p) as im: |
|
im.verify() |
|
valid.append(p) |
|
except (OSError, UnidentifiedImageError): |
|
|
|
continue |
|
self.paths = valid |
|
if len(self.paths) == 0: |
|
raise RuntimeError(f"No valid PNG images found under {root_dir}") |
|
|
|
random.shuffle(self.paths) |
|
|
|
def __len__(self): |
|
return len(self.paths) |
|
|
|
def __getitem__(self, idx): |
|
p = self.paths[idx % len(self.paths)] |
|
|
|
with Image.open(p) as img: |
|
img = img.convert("RGB") |
|
|
|
if not resize_long_side or resize_long_side <= 0: |
|
return img |
|
w, h = img.size |
|
long = max(w, h) |
|
if long <= resize_long_side: |
|
return img |
|
scale = resize_long_side / float(long) |
|
new_w = int(round(w * scale)) |
|
new_h = int(round(h * scale)) |
|
return img.resize((new_w, new_h), Image.LANCZOS) |
|
|
|
|
|
|
|
def random_crop(img, sz): |
|
w, h = img.size |
|
if w < sz or h < sz: |
|
img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS) |
|
    x = random.randint(0, max(0, img.width - sz))
    y = random.randint(0, max(0, img.height - sz))
|
return img.crop((x, y, x + sz, y + sz)) |
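
# Map images to [-1, 1], the input range diffusers VAEs are trained on.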
|
|
|
tfm = transforms.Compose([ |
|
transforms.ToTensor(), |
|
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) |
|
]) |
|
|
|
|
|
dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit) |
|
if len(dataset) < batch_size: |
|
raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}") |
|
|
|
|
|
def collate_fn(batch): |
|
imgs = [] |
|
for img in batch: |
|
img = random_crop(img, high_resolution) |
|
imgs.append(tfm(img)) |
|
return torch.stack(imgs) |
|
|
|
dataloader = DataLoader( |
|
dataset, |
|
batch_size=batch_size, |
|
shuffle=True, |
|
collate_fn=collate_fn, |
|
num_workers=num_workers, |
|
pin_memory=True, |
|
drop_last=True |
|
) |
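
# Standard weight-decay split: biases and normalization weights are excluded
# from decay, everything else gets weight_decay.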
|
|
|
|
|
def get_param_groups(module, weight_decay=0.001): |
|
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"] |
|
decay_params = [] |
|
no_decay_params = [] |
|
for n, p in module.named_parameters(): |
|
if not p.requires_grad: |
|
continue |
|
if any(nd in n for nd in no_decay): |
|
no_decay_params.append(p) |
|
else: |
|
decay_params.append(p) |
|
return [ |
|
{"params": decay_params, "weight_decay": weight_decay}, |
|
{"params": no_decay_params, "weight_decay": 0.0}, |
|
] |
|
|
|
def create_optimizer(name, param_groups): |
|
if name == "adam8bit": |
|
        return bnb.optim.AdamW8bit(
            param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps,
            percentile_clipping=percentile_clipping
        )
|
raise ValueError(name) |
|
|
|
param_groups = get_param_groups(trainable_module, weight_decay=0.001) |
|
optimizer = create_optimizer(optimizer_type, param_groups) |
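
# Optimizer steps per epoch = ceil(micro-batches / accumulation steps); e.g.
# 100 micro-batches with 8 accumulation steps yields 13 optimizer steps.
# lr_lambda implements linear warmup over warmup_percent of training followed
# by cosine decay down to min_learning_rate / base_learning_rate.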
|
|
|
|
|
batches_per_epoch = len(dataloader) |
|
steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps))) |
|
total_steps = steps_per_epoch * num_epochs |
|
|
|
def lr_lambda(step): |
|
if not use_decay: |
|
return 1.0 |
|
x = float(step) / float(max(1, total_steps)) |
|
warmup = float(warmup_percent) |
|
min_ratio = float(min_learning_rate) / float(base_learning_rate) |
|
if x < warmup: |
|
return min_ratio + (1.0 - min_ratio) * (x / warmup) |
|
decay_ratio = (x - warmup) / (1.0 - warmup) |
|
return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio)) |
|
|
|
scheduler = LambdaLR(optimizer, lr_lambda) |
|
|
|
|
|
dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler) |
|
|
|
trainable_params = [p for p in accelerator.unwrap_model(vae).decoder.parameters() if p.requires_grad]
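
# A fixed set of validation crops, reused at every sampling step so the
# logged reconstructions are directly comparable across training.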
|
|
|
|
|
@torch.no_grad() |
|
def get_fixed_samples(n=3): |
|
idx = random.sample(range(len(dataset)), min(n, len(dataset))) |
|
pil_imgs = [dataset[i] for i in idx] |
|
tensors = [] |
|
for img in pil_imgs: |
|
img = random_crop(img, high_resolution) |
|
tensors.append(tfm(img)) |
|
return torch.stack(tensors).to(accelerator.device, dtype) |
|
|
|
fixed_samples = get_fixed_samples() |
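
# Lazily build a single VGG-flavored LPIPS network; it is kept in float32 and
# reused for both the training loss and the sample-time metric.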
|
|
|
_lpips_net = None |
|
def _get_lpips(): |
|
global _lpips_net |
|
if _lpips_net is None: |
|
|
|
        _lpips_net = lpips.LPIPS(net='vgg', verbose=False).to(accelerator.device).eval()
|
return _lpips_net |
|
|
|
@torch.no_grad() |
|
def generate_and_save_samples(step=None): |
|
try: |
|
temp_vae = accelerator.unwrap_model(vae).eval() |
|
lpips_net = _get_lpips() |
|
with torch.no_grad(): |
|
|
|
orig_high = fixed_samples |
|
|
|
            if model_resolution != high_resolution:
                orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
            else:
                orig_low = orig_high
|
|
|
|
|
model_dtype = next(temp_vae.parameters()).dtype |
|
orig_low = orig_low.to(dtype=model_dtype) |
|
|
|
latent_dist = temp_vae.encode(orig_low).latent_dist |
|
latents = latent_dist.mean |
|
rec = temp_vae.decode(latents).sample |
|
|
|
|
|
|
|
if rec.shape[-2:] != orig_high.shape[-2:]: |
|
rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False) |
|
|
|
rec_img = ((rec.float() / 2.0 + 0.5).clamp(0, 1) * 255).cpu().numpy() |
|
for i in range(rec_img.shape[0]): |
|
arr = rec_img[i].transpose(1, 2, 0).astype(np.uint8) |
|
Image.fromarray(arr).save(f"{generated_folder}/sample_{step if step is not None else 'init'}_{i}.jpg", quality=95) |
|
|
|
|
|
lpips_scores = [] |
|
for i in range(rec.shape[0]): |
|
orig_full = orig_high[i:i+1] |
|
rec_full = rec[i:i+1] |
|
|
|
if rec_full.shape[-2:] != orig_full.shape[-2:]: |
|
rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False) |
|
rec_full = rec_full.to(torch.float32) |
|
orig_full = orig_full.to(torch.float32) |
|
lpips_val = lpips_net(orig_full, rec_full).item() |
|
lpips_scores.append(lpips_val) |
|
avg_lpips = float(np.mean(lpips_scores)) |
|
if use_wandb and accelerator.is_main_process: |
|
wandb.log({ |
|
"generated_images": [wandb.Image(Image.fromarray(rec_img[i].transpose(1,2,0).astype(np.uint8))) for i in range(rec_img.shape[0])], |
|
"lpips_mean": avg_lpips |
|
}, step=step) |
|
finally: |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
if accelerator.is_main_process and save_model: |
|
print("Генерация сэмплов до старта обучения...") |
|
generate_and_save_samples(0) |
|
|
|
accelerator.wait_for_everyone() |
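
# Main training loop. global_step counts micro-batches, while sample_interval
# is derived from optimizer-step totals; n_micro below converts the interval
# back into micro-batch counts for windowed averages.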
|
|
|
|
|
|
|
progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process) |
|
global_step = 0 |
|
min_loss = float("inf") |
|
sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs)) |
|
|
|
for epoch in range(num_epochs): |
|
vae.train() |
|
batch_losses = [] |
|
batch_losses_mae = [] |
|
batch_losses_lpips = [] |
|
batch_losses_perc = [] |
|
batch_grads = [] |
|
for imgs in dataloader: |
|
with accelerator.accumulate(vae): |
|
|
|
imgs = imgs.to(accelerator.device) |
|
|
|
|
|
            if model_resolution != high_resolution:
                imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
            else:
                imgs_low = imgs
|
|
|
|
|
model_dtype = next(vae.parameters()).dtype |
|
if imgs_low.dtype != model_dtype: |
|
imgs_low_model = imgs_low.to(dtype=model_dtype) |
|
else: |
|
imgs_low_model = imgs_low |
|
|
|
|
|
latent_dist = vae.encode(imgs_low_model).latent_dist |
|
latents = latent_dist.mean |
|
rec = vae.decode(latents).sample |
|
|
|
|
|
if rec.shape[-2:] != imgs.shape[-2:]: |
|
rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False) |
|
|
|
|
|
rec_f32 = rec.to(torch.float32) |
|
imgs_f32 = imgs.to(torch.float32) |
|
|
|
|
|
mae_loss = F.l1_loss(rec_f32, imgs_f32) |
|
|
|
|
|
lpips_loss = _get_lpips()(rec_f32, imgs_f32).mean() |
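
            # Adaptive loss balancing: pick a perceptual weight so that LPIPS
            # contributes roughly lpips_ratio of the total loss, then smooth it
            # by averaging the most recent weights before applying it.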
|
|
|
|
|
if float(mae_loss.detach().cpu().item()) > 1e-12: |
|
desired_multiplier = lpips_ratio / max(1.0 - lpips_ratio, 1e-12) |
|
new_weight = (mae_loss.item() / float(lpips_loss.detach().cpu().item())) * desired_multiplier |
|
else: |
|
new_weight = perceptual_loss_weight |
|
|
|
perceptual_loss_weight = float(np.clip(new_weight, min_perceptual_weight, max_perceptual_weight)) |
|
batch_losses_perc.append(perceptual_loss_weight) |
|
            avg_perc = float(np.mean(batch_losses_perc[-sample_interval:]))
|
|
|
total_loss = mae_loss + avg_perc * lpips_loss |
|
|
|
if torch.isnan(total_loss) or torch.isinf(total_loss): |
|
print("NaN/Inf loss – stopping") |
|
raise RuntimeError("NaN/Inf loss") |
|
|
|
accelerator.backward(total_loss) |
|
|
|
grad_norm = torch.tensor(0.0, device=accelerator.device) |
|
if accelerator.sync_gradients: |
|
grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm) |
|
optimizer.step() |
|
scheduler.step() |
|
optimizer.zero_grad(set_to_none=True) |
|
|
|
global_step += 1 |
|
progress.update(1) |
|
|
|
|
|
if accelerator.is_main_process: |
|
try: |
|
current_lr = optimizer.param_groups[0]["lr"] |
|
except Exception: |
|
current_lr = scheduler.get_last_lr()[0] |
|
|
|
batch_losses.append(total_loss.detach().item()) |
|
batch_losses_mae.append(mae_loss.detach().item()) |
|
batch_losses_lpips.append(lpips_loss.detach().item()) |
|
                batch_grads.append(float(grad_norm))
|
|
|
if use_wandb and accelerator.sync_gradients: |
|
wandb.log({ |
|
"mae_loss": mae_loss.detach().item(), |
|
"lpips_loss": lpips_loss.detach().item(), |
|
"perceptual_loss_weight": avg_perc, |
|
"total_loss": total_loss.detach().item(), |
|
"learning_rate": current_lr, |
|
"epoch": epoch, |
|
"grad_norm": batch_grads[-1], |
|
}, step=global_step) |
|
|
|
|
|
if global_step > 0 and global_step % sample_interval == 0: |
|
|
|
if accelerator.is_main_process: |
|
generate_and_save_samples(global_step) |
|
|
|
accelerator.wait_for_everyone() |
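
            # sample_interval is in optimizer steps; multiply by the
            # accumulation factor to get the matching micro-batch window.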
|
|
|
|
|
n_micro = sample_interval * gradient_accumulation_steps |
|
|
|
if len(batch_losses) >= n_micro: |
|
avg_loss = float(np.mean(batch_losses[-n_micro:])) |
|
avg_loss_mae = float(np.mean(batch_losses_mae[-n_micro:])) |
|
avg_loss_lpips = float(np.mean(batch_losses_lpips[-n_micro:])) |
|
else: |
|
avg_loss = float(np.mean(batch_losses)) if batch_losses else float("nan") |
|
avg_loss_mae = float(np.mean(batch_losses_mae)) if batch_losses_mae else float("nan") |
|
avg_loss_lpips = float(np.mean(batch_losses_lpips)) if batch_losses_lpips else float("nan") |
|
|
|
            avg_grad = float(np.mean(batch_grads[-n_micro:])) if batch_grads else 0.0
|
|
|
if accelerator.is_main_process: |
|
print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}") |
|
                if save_model and avg_loss < min_loss * save_barrier:
                    min_loss = min(min_loss, avg_loss)
|
accelerator.unwrap_model(vae).save_pretrained(save_as) |
|
if use_wandb: |
|
wandb.log({"interm_loss": avg_loss,"interm_loss_mae": avg_loss_mae,"interm_loss_lpips": avg_loss_lpips, "interm_grad": avg_grad}, step=global_step) |
|
|
|
if accelerator.is_main_process: |
|
epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan") |
|
print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}") |
|
if use_wandb: |
|
wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step) |
|
|
|
|
|
if accelerator.is_main_process: |
|
print("Training finished – saving final model") |
|
if save_model: |
|
accelerator.unwrap_model(vae).save_pretrained(save_as) |
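
# Release cached memory and, in distributed runs, tear down the process group.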
|
|
|
accelerator.free_memory() |
|
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
torch.distributed.destroy_process_group() |
|
print("Готово!") |