|
import torch |
|
import torch.nn as nn |
|
from prefigure.prefigure import get_all_args, push_wandb_config |
|
import json |
|
import os |
|
import re |
|
import torchaudio |
|
from lightning.pytorch import seed_everything |
|
import random |
|
from datetime import datetime |
|
import numpy as np |
|
|
|
from ThinkSound.data.datamodule import DataModule |
|
from ThinkSound.models import create_model_from_config |
|
from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model |
|
from ThinkSound.inference.sampling import sample, sample_discrete_euler |
|
from pathlib import Path |
|
from tqdm import tqdm |
|
|
|
def _init_cross_attn_weights(module):
    """Initialise a single module in place if it is a Linear or LayerNorm.

    Linear layers get a Xavier-uniform weight and a zero bias (when they have
    one). LayerNorm layers get weight=1 and bias=0, skipping whichever affine
    parameter is absent (``elementwise_affine=False`` / ``bias=False`` leave
    it as ``None``). Any other module type is left untouched.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def _build_model_config(config_path, duration):
    """Load the model JSON config and size it for clips of ``duration`` seconds.

    Args:
        config_path: Path to the model JSON config.
        duration: Clip length in seconds.

    Returns:
        The loaded config dict, with ``sample_size`` and the diffusion
        ``sync_seq_len`` / ``clip_seq_len`` / ``latent_seq_len`` entries set.
    """
    with open(config_path, encoding="utf-8") as f:
        model_config = json.load(f)

    sample_rate = model_config["sample_rate"]
    model_config["sample_size"] = duration * sample_rate

    diffusion_cfg = model_config["model"]["diffusion"]["config"]
    # 24 sync tokens and 8 clip tokens per whole second; int() truncates
    # fractional durations.
    diffusion_cfg["sync_seq_len"] = 24 * int(duration)
    diffusion_cfg["clip_seq_len"] = 8 * int(duration)
    # Derived from the configured sample rate rather than a hard-coded 44100,
    # so it stays consistent with sample_size. 64 * 32 is presumably the
    # autoencoder's total temporal downsampling factor -- TODO confirm.
    diffusion_cfg["latent_seq_len"] = round(sample_rate / 64 / 32 * duration)
    return model_config


def _extract_state_dict(ckpt):
    """Return the parameter mapping from a loaded checkpoint object.

    Accepts either a bare state dict or a wrapper dict that stores it under a
    ``"state_dict"`` key; anything else is returned unchanged.
    """
    if isinstance(ckpt, dict) and isinstance(ckpt.get("state_dict"), dict):
        return ckpt["state_dict"]
    return ckpt


def _load_matching_weights(model, ckpt_state):
    """Copy into ``model`` every checkpoint tensor whose key and shape match.

    Returns:
        ``(matched, skipped)``: a dict of the tensors that were loaded, and a
        list of checkpoint keys that were ignored (unknown to the model, not
        a tensor, or shape mismatch).
    """
    model_state = model.state_dict()
    matched = {}
    skipped = []
    for key, value in ckpt_state.items():
        target = model_state.get(key)
        if target is not None and torch.is_tensor(value) and value.shape == target.shape:
            matched[key] = value
        else:
            skipped.append(key)
    model_state.update(matched)
    model.load_state_dict(model_state)
    return matched, skipped


def _reinit_cross_attn(model, tag="cross_attn"):
    """Re-initialise every submodule whose qualified name contains ``tag``.

    ``named_modules()`` yields a matching block *and* all of its children
    (whose names share the block's prefix), so initialising each yielded
    module directly touches every layer exactly once. Calling ``.apply`` on
    each match would instead redo children once per matching ancestor.

    Returns:
        Names of the top-most matching modules, in traversal order.
    """
    top_level = []
    for name, module in model.named_modules():
        if tag not in name:
            continue
        _init_cross_attn_weights(module)
        # named_modules() is pre-order, so any matching ancestor is already listed.
        if not any(name.startswith(parent + ".") for parent in top_level):
            top_level.append(name)
    return top_level


def main():
    """Build the model, load compatible old weights, re-init cross-attn, save.

    Tensors from ``args.ckpt_dir`` are copied into the freshly built model
    when key and shape match. Every module whose name contains
    ``cross_attn`` is then re-initialised. The result is written to
    ``ckpts/thinksound_light_cross_attn.ckpt``.
    """
    args = get_all_args()

    if args.save_dir == '':
        args.save_dir = args.results_dir

    # Offset the seed by the SLURM task id (when present) so each task gets a
    # distinct seed.
    seed = args.seed
    procid = os.environ.get("SLURM_PROCID")
    if procid is not None:
        seed += int(procid)
    seed_everything(seed, workers=True)

    if args.model_config == '':
        args.model_config = "ThinkSound/configs/model_configs/thinksound.json"

    model_config = _build_model_config(args.model_config, float(args.duration_sec))
    model = create_model_from_config(model_config)

    # NOTE: torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints (consider weights_only=True where the torch version allows).
    ckpt_state = _extract_state_dict(torch.load(args.ckpt_dir, map_location='cpu'))
    matched, skipped = _load_matching_weights(model, ckpt_state)
    print(f"[INFO] Loaded {len(matched)} keys from old checkpoint")
    if skipped:
        print(f"[INFO] Skipped {len(skipped)} checkpoint keys "
              "(not in model, not a tensor, or shape mismatch)")

    overwritten = [key for key in matched if 'cross_attn' in key]
    if overwritten:
        print(f"[WARN] {len(overwritten)} cross-attn tensors loaded from the "
              "checkpoint will be re-initialized")

    for name in _reinit_cross_attn(model):
        print(f"[INIT] Initialized {name}")

    out_path = Path("ckpts/thinksound_light_cross_attn.ckpt")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), out_path)
    print("[DONE] New checkpoint saved with old weights + initialized cross-attn.")
|
|
|
if __name__ == "__main__":
    # Perform the conversion only when run as a script, never on import.
    main()
|
|