"""Merge an existing ThinkSound checkpoint into a model that adds new
cross-attention modules: weights that match by name and shape are copied over,
the newly added cross-attn layers are re-initialized, and the result is saved
as a new checkpoint."""

import json
import os
import sys

import torch
import torch.nn as nn
from lightning.pytorch import seed_everything
from prefigure.prefigure import get_all_args

# Directory containing this script (ckpts/)
current_dir = os.path.dirname(os.path.abspath(__file__))
# Project root = the parent directory of ckpts/
project_root = os.path.abspath(os.path.join(current_dir, '..'))
# Add the project root to sys.path so the ThinkSound package resolves
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from ThinkSound.models import create_model_from_config


def main():
    args = get_all_args()
    if args.save_dir == '':
        args.save_dir = args.results_dir

    # Offset the seed per SLURM rank so parallel jobs don't collide.
    seed = args.seed
    slurm_procid = os.environ.get("SLURM_PROCID")
    if slurm_procid is not None:
        seed += int(slurm_procid)
    seed_everything(seed, workers=True)

    # Load the model config.
    if args.model_config == '':
        args.model_config = "ThinkSound/configs/model_configs/thinksound.json"
    with open(args.model_config) as f:
        model_config = json.load(f)

    # Derive sequence lengths from the clip duration: the audio VAE yields
    # 44100 Hz / (64 * 32) ≈ 21.5 latents per second; sync features run at
    # 24 per second and CLIP features at 8 per second.
    duration = float(args.duration_sec)
    sample_rate = model_config["sample_rate"]
    latent_length = round(44100 / 64 / 32 * duration)
    model_config["sample_size"] = duration * sample_rate
    model_config["model"]["diffusion"]["config"]["sync_seq_len"] = 24 * int(duration)
    model_config["model"]["diffusion"]["config"]["clip_seq_len"] = 8 * int(duration)
    model_config["model"]["diffusion"]["config"]["latent_seq_len"] = latent_length

    # Step 1: build the new model (including the added cross-attn modules).
    model = create_model_from_config(model_config)

    # Step 2: load the old checkpoint (trained without cross-attn).
    old_ckpt = torch.load(args.ckpt_dir, map_location='cpu')
    # Defensive: Lightning checkpoints often nest weights under 'state_dict'.
    if isinstance(old_ckpt, dict) and 'state_dict' in old_ckpt:
        old_ckpt = old_ckpt['state_dict']

    # Step 3: keep only weights whose name *and* shape match the new model.
    model_state = model.state_dict()
    matched_ckpt = {
        k: v for k, v in old_ckpt.items()
        if k in model_state and v.shape == model_state[k].shape
    }
    print(f"[INFO] Loaded {len(matched_ckpt)} keys from old checkpoint")

    # Step 4: merge the matched weights into the model.
    model_state.update(matched_ckpt)
    model.load_state_dict(model_state)

    # Step 5: re-initialize only the newly added cross-attn modules.
    def init_cross_attn_weights(module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        # nn.RMSNorm only exists in newer PyTorch releases, so fall back to a
        # class-name check for custom RMSNorm implementations.
        elif isinstance(module, getattr(nn, "RMSNorm", ())) or module.__class__.__name__ == "RMSNorm":
            if getattr(module, 'weight', None) is not None:
                nn.init.ones_(module.weight)
            if getattr(module, 'bias', None) is not None:
                nn.init.zeros_(module.bias)

    # Only walk modules whose name contains 'cross_attn'.
    for name, module in model.named_modules():
        if 'cross_attn' in name:
            module.apply(init_cross_attn_weights)
            print(f"[INIT] Initialized {name}")

    # Step 6: save the merged + re-initialized weights.
    torch.save(model.state_dict(), 'ckpts/row_thinksound_light_cross_attn.ckpt')
    print("[DONE] New checkpoint saved with old weights + initialized cross-attn.")


if __name__ == '__main__':
    main()
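
# ---------------------------------------------------------------------------
# Usage sketch (assumptions flagged): the CLI flag names below are inferred
# from the attribute names read via get_all_args() (--ckpt-dir,
# --model-config, --duration-sec); the actual names depend on this repo's
# prefigure defaults file, so verify them before running:
#
#   python ckpts/<this_script>.py \
#       --ckpt-dir ckpts/old_thinksound.ckpt \
#       --model-config ThinkSound/configs/model_configs/thinksound.json \
#       --duration-sec 9
#
# A quick sanity check that the saved checkpoint round-trips cleanly, using a
# model rebuilt from the same patched model_config as above:
#
#   state = torch.load('ckpts/row_thinksound_light_cross_attn.ckpt',
#                      map_location='cpu')
#   model = create_model_from_config(model_config)
#   missing, unexpected = model.load_state_dict(state, strict=False)
#   assert not missing and not unexpected, (missing, unexpected)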