Liyang Chen
committed on
Commit
·
e1b47be
1
Parent(s):
88fbe87
full pipeline
Browse files
init_cross_attn.py
CHANGED
@@ -9,6 +9,17 @@ from lightning.pytorch import seed_everything
|
|
9 |
import random
|
10 |
from datetime import datetime
|
11 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
from ThinkSound.data.datamodule import DataModule
|
14 |
from ThinkSound.models import create_model_from_config
|
@@ -62,6 +73,7 @@ def main():
|
|
62 |
|
63 |
# Step 5: 初始化 cross-attn 模块(只初始化新增部分)
|
64 |
def init_cross_attn_weights(module):
|
|
|
65 |
if isinstance(module, nn.Linear):
|
66 |
nn.init.xavier_uniform_(module.weight)
|
67 |
if module.bias is not None:
|
@@ -69,7 +81,14 @@ def main():
|
|
69 |
elif isinstance(module, nn.LayerNorm):
|
70 |
nn.init.ones_(module.weight)
|
71 |
nn.init.zeros_(module.bias)
|
|
|
|
|
|
|
|
|
|
|
72 |
|
|
|
|
|
73 |
# 只遍历 cross-attn 模块进行初始化
|
74 |
for name, module in model.named_modules():
|
75 |
if 'cross_attn' in name:
|
@@ -77,7 +96,7 @@ def main():
|
|
77 |
print(f"[INIT] Initialized {name}")
|
78 |
|
79 |
# Step 6: 保存新权重
|
80 |
-
torch.save(model.state_dict(), 'ckpts/
|
81 |
print("[DONE] New checkpoint saved with old weights + initialized cross-attn.")
|
82 |
|
83 |
if __name__ == '__main__':
|
|
|
9 |
import random
|
10 |
from datetime import datetime
|
11 |
import numpy as np
|
12 |
+
import sys
|
13 |
+
|
14 |
+
# 获取当前脚本所在目录(ckpts/)
|
15 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
16 |
+
|
17 |
+
# 项目根目录 = ckpts 的上级目录
|
18 |
+
project_root = os.path.abspath(os.path.join(current_dir, '..'))
|
19 |
+
|
20 |
+
# 添加项目根目录到 sys.path
|
21 |
+
if project_root not in sys.path:
|
22 |
+
sys.path.insert(0, project_root)
|
23 |
|
24 |
from ThinkSound.data.datamodule import DataModule
|
25 |
from ThinkSound.models import create_model_from_config
|
|
|
73 |
|
74 |
# Step 5: 初始化 cross-attn 模块(只初始化新增部分)
|
75 |
def init_cross_attn_weights(module):
|
76 |
+
from einops.layers.torch import Rearrange
|
77 |
if isinstance(module, nn.Linear):
|
78 |
nn.init.xavier_uniform_(module.weight)
|
79 |
if module.bias is not None:
|
|
|
81 |
elif isinstance(module, nn.LayerNorm):
|
82 |
nn.init.ones_(module.weight)
|
83 |
nn.init.zeros_(module.bias)
|
84 |
+
elif isinstance(module, nn.RMSNorm) or module.__class__.__name__ == "RMSNorm":
|
85 |
+
if hasattr(module, 'weight'):
|
86 |
+
nn.init.ones_(module.weight)
|
87 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
88 |
+
nn.init.zeros_(module.bias)
|
89 |
|
90 |
+
import pdb; pdb.set_trace()
|
91 |
+
pass
|
92 |
# 只遍历 cross-attn 模块进行初始化
|
93 |
for name, module in model.named_modules():
|
94 |
if 'cross_attn' in name:
|
|
|
96 |
print(f"[INIT] Initialized {name}")
|
97 |
|
98 |
# Step 6: 保存新权重
|
99 |
+
torch.save(model.state_dict(), 'ckpts/row_thinksound_light_cross_attn.ckpt')
|
100 |
print("[DONE] New checkpoint saved with old weights + initialized cross-attn.")
|
101 |
|
102 |
if __name__ == '__main__':
|
row_thinksound_light_cross_attn.ckpt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:db9f641e234f91448d9ca1ec9254339ad64414cbdc2c637311ff06b985d8fb65
|
3 |
+
size 6026895638
|
thinksound_light_cross_attn.ckpt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:1a80397e1a4f44c9e698ce2b6cbacf5cc775ef598907e41b2241b285b9e7eb78
|
3 |
-
size 5909451670
|
|
|
|
|
|
|
|