Liyang Chen commited on
Commit
e1b47be
·
1 Parent(s): 88fbe87

full pipeline

Browse files
init_cross_attn.py CHANGED
@@ -9,6 +9,17 @@ from lightning.pytorch import seed_everything
9
  import random
10
  from datetime import datetime
11
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  from ThinkSound.data.datamodule import DataModule
14
  from ThinkSound.models import create_model_from_config
@@ -62,6 +73,7 @@ def main():
62
 
63
  # Step 5: 初始化 cross-attn 模块(只初始化新增部分)
64
  def init_cross_attn_weights(module):
 
65
  if isinstance(module, nn.Linear):
66
  nn.init.xavier_uniform_(module.weight)
67
  if module.bias is not None:
@@ -69,7 +81,14 @@ def main():
69
  elif isinstance(module, nn.LayerNorm):
70
  nn.init.ones_(module.weight)
71
  nn.init.zeros_(module.bias)
 
 
 
 
 
72
 
 
 
73
  # 只遍历 cross-attn 模块进行初始化
74
  for name, module in model.named_modules():
75
  if 'cross_attn' in name:
@@ -77,7 +96,7 @@ def main():
77
  print(f"[INIT] Initialized {name}")
78
 
79
  # Step 6: 保存新权重
80
- torch.save(model.state_dict(), 'ckpts/thinksound_light_cross_attn.ckpt')
81
  print("[DONE] New checkpoint saved with old weights + initialized cross-attn.")
82
 
83
  if __name__ == '__main__':
 
9
  import random
10
  from datetime import datetime
11
  import numpy as np
12
+ import sys
13
+
14
+ # 获取当前脚本所在目录(ckpts/)
15
+ current_dir = os.path.dirname(os.path.abspath(__file__))
16
+
17
+ # 项目根目录 = ckpts 的上级目录
18
+ project_root = os.path.abspath(os.path.join(current_dir, '..'))
19
+
20
+ # 添加项目根目录到 sys.path
21
+ if project_root not in sys.path:
22
+ sys.path.insert(0, project_root)
23
 
24
  from ThinkSound.data.datamodule import DataModule
25
  from ThinkSound.models import create_model_from_config
 
73
 
74
  # Step 5: 初始化 cross-attn 模块(只初始化新增部分)
75
  def init_cross_attn_weights(module):
76
+ from einops.layers.torch import Rearrange
77
  if isinstance(module, nn.Linear):
78
  nn.init.xavier_uniform_(module.weight)
79
  if module.bias is not None:
 
81
  elif isinstance(module, nn.LayerNorm):
82
  nn.init.ones_(module.weight)
83
  nn.init.zeros_(module.bias)
84
+ elif isinstance(module, nn.RMSNorm) or module.__class__.__name__ == "RMSNorm":
85
+ if hasattr(module, 'weight'):
86
+ nn.init.ones_(module.weight)
87
+ if hasattr(module, 'bias') and module.bias is not None:
88
+ nn.init.zeros_(module.bias)
89
 
90
+ import pdb; pdb.set_trace()
91
+ pass
92
  # 只遍历 cross-attn 模块进行初始化
93
  for name, module in model.named_modules():
94
  if 'cross_attn' in name:
 
96
  print(f"[INIT] Initialized {name}")
97
 
98
  # Step 6: 保存新权重
99
+ torch.save(model.state_dict(), 'ckpts/row_thinksound_light_cross_attn.ckpt')
100
  print("[DONE] New checkpoint saved with old weights + initialized cross-attn.")
101
 
102
  if __name__ == '__main__':
row_thinksound_light_cross_attn.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1a80397e1a4f44c9e698ce2b6cbacf5cc775ef598907e41b2241b285b9e7eb78
3
- size 5909451670
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db9f641e234f91448d9ca1ec9254339ad64414cbdc2c637311ff06b985d8fb65
3
+ size 6026895638
thinksound_light_cross_attn.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1a80397e1a4f44c9e698ce2b6cbacf5cc775ef598907e41b2241b285b9e7eb78
3
- size 5909451670