# ThinkSound / init_cross_attn.py
# Author: Liyang Chen — commit e1b47be ("full pipeline"), 3.73 kB
import torch
import torch.nn as nn
from prefigure.prefigure import get_all_args, push_wandb_config
import json
import os
import re
import torchaudio
from lightning.pytorch import seed_everything
import random
from datetime import datetime
import numpy as np
import sys
# Make the project root importable before the `ThinkSound.*` imports below.
# This script is expected to live in `ckpts/`, one level below the project
# root, so the root is the parent of this file's directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, os.pardir))
if project_root not in sys.path:
    # Prepend so this checkout is searched before anything else on the path.
    sys.path[:0] = [project_root]
from ThinkSound.data.datamodule import DataModule
from ThinkSound.models import create_model_from_config
from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
from ThinkSound.inference.sampling import sample, sample_discrete_euler
from pathlib import Path
from tqdm import tqdm
def _init_cross_attn_weights(module):
    """Re-initialize the parameters of one submodule in place.

    Meant to be passed to ``nn.Module.apply``:

    * ``nn.Linear``: Xavier-uniform weight, zero bias (when a bias exists).
    * ``nn.LayerNorm`` and any class named ``RMSNorm``: unit weight and zero
      bias, each only when that parameter exists and is not ``None``.

    Any other module type is left untouched.

    Args:
        module: the ``nn.Module`` currently visited by ``apply``.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm) or type(module).__name__ == "RMSNorm":
        # Matching on the class name covers torch.nn.RMSNorm (torch >= 2.4) and
        # project-local RMSNorm classes without touching ``nn.RMSNorm``, which
        # does not exist on older torch. Norms built without affine params /
        # bias have ``weight`` / ``bias`` set to None, so check before init.
        weight = getattr(module, "weight", None)
        if weight is not None:
            nn.init.ones_(weight)
        bias = getattr(module, "bias", None)
        if bias is not None:
            nn.init.zeros_(bias)


def _load_matching_weights(model, checkpoint_state):
    """Copy into ``model`` every checkpoint tensor whose key and shape match.

    Args:
        model: the freshly constructed model.
        checkpoint_state: mapping of parameter name -> tensor from the old
            checkpoint.

    Returns:
        The number of tensors loaded.

    Raises:
        RuntimeError: if no tensor matched. In that case the saved checkpoint
            would contain only random weights.
    """
    model_state = model.state_dict()
    matched = {}
    skipped = 0
    for key, tensor in checkpoint_state.items():
        if key in model_state and tensor.shape == model_state[key].shape:
            matched[key] = tensor
        else:
            skipped += 1
    print(f"[INFO] Loaded {len(matched)} keys from old checkpoint")
    if skipped:
        print(f"[INFO] Skipped {skipped} checkpoint keys (not in model or shape mismatch)")
    if not matched:
        raise RuntimeError(
            "No tensors in the old checkpoint matched the model's parameter "
            "names/shapes; check that it is a plain state dict for this model."
        )
    # Keys absent from `matched` keep the values the model was built with.
    model.load_state_dict(matched, strict=False)
    return len(matched)


def _init_cross_attn_modules(model):
    """Re-initialize every submodule whose qualified name contains 'cross_attn'.

    ``named_modules`` yields a parent before its children. Applying the
    initializer to an outermost match therefore already covers everything
    nested under it, so nested matches are skipped instead of being re-drawn.

    NOTE(review): this assumes the old checkpoint carried no cross_attn
    weights. Any that were loaded are overwritten here.
    """
    initialized = []
    for name, module in model.named_modules():
        if "cross_attn" not in name:
            continue
        if any(name.startswith(prefix + ".") for prefix in initialized):
            continue
        module.apply(_init_cross_attn_weights)
        initialized.append(name)
        print(f"[INIT] Initialized {name}")


def main():
    """Warm-start a cross-attention model from an older checkpoint.

    Steps:
    1. Parse CLI args (prefigure) and seed RNGs, offset by SLURM_PROCID.
    2. Load the model config, size its sequence lengths for
       ``args.duration_sec``, and build the model.
    3. Load every tensor from ``args.ckpt_dir`` whose name and shape match.
    4. Re-initialize all modules whose name contains ``cross_attn``.
    5. Save the state dict to ``ckpts/row_thinksound_light_cross_attn.ckpt``,
       relative to the current working directory.
    """
    args = get_all_args()
    if args.save_dir == '':
        args.save_dir = args.results_dir

    # Give each SLURM task its own seed.
    seed = args.seed
    slurm_rank = os.environ.get("SLURM_PROCID")
    if slurm_rank is not None:
        seed += int(slurm_rank)
    seed_everything(seed, workers=True)

    # Load the config and size its sequences for the requested duration.
    if args.model_config == '':
        args.model_config = "ThinkSound/configs/model_configs/thinksound.json"
    with open(args.model_config, encoding="utf-8") as f:
        model_config = json.load(f)
    duration = float(args.duration_sec)
    sample_rate = model_config["sample_rate"]
    # Latent frames = duration * sample_rate / (64 * 32), derived from the
    # config's sample_rate (as sample_size is below) rather than a fixed 44100.
    # NOTE(review): 64 * 32 is presumably the autoencoder's total downsampling
    # factor — confirm against the VAE config.
    latent_length = round(sample_rate / 64 / 32 * duration)
    model_config["sample_size"] = duration * sample_rate
    diffusion_cfg = model_config["model"]["diffusion"]["config"]
    # NOTE(review): 24 and 8 look like per-second rates of the sync and CLIP
    # conditioning features — confirm against the feature extractors.
    diffusion_cfg["sync_seq_len"] = 24 * int(duration)
    diffusion_cfg["clip_seq_len"] = 8 * int(duration)
    diffusion_cfg["latent_seq_len"] = latent_length
    model = create_model_from_config(model_config)

    # The old checkpoint predates cross-attention; load whatever still fits.
    old_ckpt = torch.load(args.ckpt_dir, map_location='cpu')
    _load_matching_weights(model, old_ckpt)

    # Give the newly added cross-attention modules a fresh initialization.
    _init_cross_attn_modules(model)

    out_path = 'ckpts/row_thinksound_light_cross_attn.ckpt'
    # The path is CWD-relative; create the directory so torch.save cannot fail on it.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torch.save(model.state_dict(), out_path)
    print("[DONE] New checkpoint saved with old weights + initialized cross-attn.")


if __name__ == '__main__':
    main()