yuanshuai committed on
Commit 03a5a24 · verified · 1 Parent(s): af40883

Upload folder using huggingface_hub

Files changed (2)
  1. README.md +105 -65
  2. draft/qwen2.py +641 -0
README.md CHANGED
@@ -12,33 +12,30 @@ base_model:
  <div align="center">

  # Baichuan-M2-32B
-
  [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
  [![Hugging Face](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-yellow)](https://huggingface.co/baichuan-inc/Baichuan-M2-32B)

  </div>

- ## 🌟 Model Overview
-
- Baichuan-M2-32B is Baichuan AI's medical-enhanced reasoning model and the second medical-enhanced model Baichuan has released as open source, designed for real-world medical reasoning tasks. Built on the Qwen2.5-32B base, it starts from real-world medical questions and performs medical-domain post-training alignment through an innovative Large Verifier System, achieving a breakthrough improvement in medical performance while preserving the model's general capabilities.

- **Model Features:**

- Baichuan-M2 adopts three core technical innovations: first, a **Large Verifier System** designed around the characteristics of medical scenarios provides a comprehensive medical verification framework, including a patient simulator and multi-dimensional verification mechanisms; second, **medical-domain adaptation** via Mid-Training achieves lightweight, efficient domain adaptation while preserving general capabilities; finally, a **multi-stage reinforcement learning** strategy decomposes the complex RL task into hierarchical training stages, progressively improving the model's medical knowledge, reasoning, and patient-interaction abilities.

- **Core Highlights:**
- - 🏆 **World's Strongest Open-Source Medical Model**: Surpasses all open-source models and many frontier closed-source models on HealthBench, making it the open-source LLM closest to GPT-5 in medical capability
- - 🧠 **Doctor-Thinking Alignment**: Trained on real case data and a patient simulator, with clinical diagnostic thinking and robust doctor-patient interaction
- - ⚡ **Efficient Deployment and Inference**: Supports 4-bit quantized deployment on a single RTX 4090; the MTP version delivers 58.5% higher token throughput in single-user scenarios

- ## 📊 Performance

- ### HealthBench Scores
-
- | Model Name | HealthBench | HealthBench-Hard | HealthBench-Consensus |
- |----------|-------------|------------------|-----------------------|
  | Baichuan-M2 | 60.1 | 34.7 | 91.5 |
  | gpt-oss-120b | 57.6 | 30 | 90 |
  | Qwen3-235B-A22B-Thinking-2507 | 55.2 | 25.9 | 90.6 |
@@ -47,80 +44,123 @@ Baichuan-M2 adopts three core technical innovations
  | Kimi-K2 | 43 | 10.7 | 90.9 |
  | gpt-oss-20b | 42.5 | 10.8 | 82.6 |

- ### General Benchmarks

- | Benchmark | Baichuan-M2-32B | Qwen3-32B |
- |--------|-----------------|-----------|
  | AIME24 | 83.4 | 81.4 |
  | AIME25 | 72.9 | 72.9 |
  | Arena-Hard-v2.0 | 45.8 | 44.5 |
  | CFBench | 77.6 | 75.7 |
  | WritingBench | 8.56 | 7.90 |

- *Note: max_tokens is set to 64k for AIME and 32k for the other benchmarks; temperature is 0.6 throughout.*
-

- ## 🛠️ Technical Features

- ### Large Verifier System
- - **Patient Simulator**: A virtual patient system built from real clinical cases
- - **Multi-Dimensional Verification**: 8 dimensions, including medical accuracy, response completeness, and follow-up awareness
- - **Dynamic Scoring**: Generates scoring rubrics in real time to adapt to complex clinical environments

- ### Medical Domain Adaptation
- - **Mid-Training**: Injects medical knowledge while preserving general capabilities
- - **Reinforcement Learning**: Multi-stage RL strategy optimization
- - **General-Specialist Balance**: Medical, general, and math data mixed at a 2:2:1 ratio

- ## 🔧 Quick Start
-
- ### Installation and Usage
-
- ```bash
- # Install dependencies
- pip install transformers torch vllm sglang
-
- # Using Transformers
  from transformers import AutoTokenizer, AutoModelForCausalLM
  model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan-M2-32B", trust_remote_code=True)

- # Using vLLM (recommended)
- from vllm import LLM
- llm = LLM(model="baichuan-inc/Baichuan-M2-32B", trust_remote_code=True)
-
- # Using SGLang
- python -m sglang.launch_server --model-path baichuan-inc/Baichuan-M2-32B
  ```

- ## ⚠️ Usage Notice

- 1. **Medical Disclaimer**: This model is for research and reference only and cannot replace professional medical diagnosis or treatment advice
- 2. **Intended Use Cases**: Medical education, health consultation, clinical decision support, and similar scenarios
- 3. **Safe Use**: Recommended for use under the guidance of medical professionals

- ## 📄 License
-
- This project is released under the [Apache License 2.0](LICENSE); research and commercial use are welcome.
-
- ## 🤝 Acknowledgements
-
- - Base model: Qwen2.5-32B
- - Training framework: VERL
- - Inference engines: vLLM, SGLang
- - Quantization methods: AutoRound, GPTQ, QuaRot, QQQ

- Thanks to the open-source community for its contributions; we will keep giving back to the community and advancing medical AI.

- ## 📞 Contact Us

- - More resources: [Baichuan AI website](https://www.baichuan-ai.com)

- - Technical exchange: [GitHub](https://github.com/baichuan-inc)

  ---
-
  <div align="center">

- **Empowering Healthcare with AI, Making Health Accessible to All**

- </div>

  <div align="center">

  # Baichuan-M2-32B

  [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
  [![Hugging Face](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-yellow)](https://huggingface.co/baichuan-inc/Baichuan-M2-32B)

  </div>

+ ## 🌟 Model Overview

+ Baichuan-M2-32B is Baichuan AI's medical-enhanced reasoning model and the second medical model released by Baichuan. Designed for real-world medical reasoning tasks, it builds on Qwen2.5-32B with an innovative Large Verifier System. Through domain-specific post-training on real-world medical questions, it achieves breakthrough medical performance while maintaining strong general capabilities.

+ **Model Features:**

+ Baichuan-M2 incorporates three core technical innovations. First, a **Large Verifier System** designed around the characteristics of medical scenarios provides a comprehensive verification framework, including a patient simulator and multi-dimensional verification mechanisms. Second, **medical domain adaptation** via Mid-Training achieves lightweight, efficient domain adaptation while preserving general capabilities. Finally, a **multi-stage reinforcement learning** strategy decomposes the complex RL task into hierarchical training stages, progressively enhancing the model's medical knowledge, reasoning, and patient-interaction capabilities.

+ **Core Highlights:**
+ - 🏆 **World's Leading Open-Source Medical Model**: Outperforms all open-source models and many proprietary models on HealthBench, achieving medical capabilities closest to GPT-5
+ - 🧠 **Doctor-Thinking Alignment**: Trained on real clinical cases and patient simulators, with clinical diagnostic thinking and robust patient interaction capabilities
+ - ⚡ **Efficient Deployment**: Supports 4-bit quantization for deployment on a single RTX 4090, with 58.5% higher token throughput in the MTP version for single-user scenarios (see the quantized-loading sketch below)

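+ The 4-bit, single-GPU path can be sketched as follows. This is a minimal illustration rather than an official recipe: the card does not state which quantization toolchain the RTX 4090 figure refers to (AutoRound and GPTQ are credited in the acknowledgements), so the sketch simply loads the BF16 checkpoint through the `transformers` bitsandbytes integration as one possible approach.

+ ```python
+ # Illustrative only: load Baichuan-M2-32B with on-the-fly 4-bit (NF4) quantization.
+ # Assumes `pip install bitsandbytes accelerate` and a single 24 GB GPU.
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+     "baichuan-inc/Baichuan-M2-32B",
+     quantization_config=bnb_config,
+     device_map="auto",
+     trust_remote_code=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-M2-32B")
+ ```
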
+ ## 📊 Performance Metrics

+ ### HealthBench Scores

+ | Model Name | HealthBench | HealthBench-Hard | HealthBench-Consensus |
+ |------------|-------------|------------------|-----------------------|
  | Baichuan-M2 | 60.1 | 34.7 | 91.5 |
  | gpt-oss-120b | 57.6 | 30 | 90 |
  | Qwen3-235B-A22B-Thinking-2507 | 55.2 | 25.9 | 90.6 |
  | Kimi-K2 | 43 | 10.7 | 90.9 |
  | gpt-oss-20b | 42.5 | 10.8 | 82.6 |

+ ### General Performance

+ | Benchmark | Baichuan-M2-32B | Qwen3-32B |
+ |-----------|-----------------|-----------|
  | AIME24 | 83.4 | 81.4 |
  | AIME25 | 72.9 | 72.9 |
  | Arena-Hard-v2.0 | 45.8 | 44.5 |
  | CFBench | 77.6 | 75.7 |
  | WritingBench | 8.56 | 7.90 |

+ *Note: AIME uses max_tokens=64k, others use 32k; temperature=0.6 for all tests.*
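
+ To reproduce these settings locally, the sampling configuration can be passed as sketched below. This is a minimal example using vLLM's offline API with the stated values (temperature 0.6, 32k max tokens for the non-AIME benchmarks); the exact evaluation harness is not specified in this card.

+ ```python
+ # Illustrative generation settings matching the note above (non-AIME case).
+ from vllm import LLM, SamplingParams
+
+ llm = LLM(model="baichuan-inc/Baichuan-M2-32B", trust_remote_code=True)
+ params = SamplingParams(temperature=0.6, max_tokens=32768)  # use 65536 for AIME
+ outputs = llm.chat(
+     [{"role": "user", "content": "A patient reports sudden swelling after an insect bite. What should they do first?"}],
+     params,
+ )
+ print(outputs[0].outputs[0].text)
+ ```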

+ ## 🔧 Technical Features

+ ### Large Verifier System
+ - **Patient Simulator**: Virtual patient system based on real clinical cases
+ - **Multi-Dimensional Verification**: 8 dimensions including medical accuracy, response completeness, and follow-up awareness
+ - **Dynamic Scoring**: Real-time generation of adaptive evaluation criteria for complex clinical scenarios (an illustrative scoring sketch follows below)

+ ### Medical Domain Adaptation
+ - **Mid-Training**: Medical knowledge injection while preserving general capabilities
+ - **Reinforcement Learning**: Multi-stage RL strategy optimization
+ - **General-Specialized Balance**: Carefully balanced medical, general, and mathematical composite training data

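+ As a purely illustrative sketch (not Baichuan's actual verifier, whose dimensions and weights are not published in this card), a multi-dimensional verifier can be thought of as scoring a response along several named axes and aggregating them into a single scalar reward for reinforcement learning:

+ ```python
+ # Hypothetical illustration of multi-dimensional verification; the weights and any
+ # dimension names beyond those listed above are invented for the example.
+ from dataclasses import dataclass
+
+ @dataclass
+ class DimensionScore:
+     name: str
+     score: float   # 0.0-1.0, from a dimension-specific judge
+     weight: float  # relative importance of this dimension
+
+ def aggregate_reward(scores: list[DimensionScore]) -> float:
+     """Weighted average of per-dimension scores, usable as a scalar RL reward."""
+     total_weight = sum(s.weight for s in scores)
+     return sum(s.score * s.weight for s in scores) / total_weight
+
+ reward = aggregate_reward([
+     DimensionScore("medical_accuracy", 0.92, 3.0),
+     DimensionScore("response_completeness", 0.80, 2.0),
+     DimensionScore("follow_up_awareness", 0.65, 1.0),
+ ])
+ print(f"aggregated reward: {reward:.3f}")
+ ```
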
+ ## ⚙️ Quick Start

+ ```python
+ # 1. Load the model and tokenizer
  from transformers import AutoTokenizer, AutoModelForCausalLM
  model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan-M2-32B", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-M2-32B")
+ # 2. Prepare the input prompt
+ prompt = "Got a big swelling after a bug bite. Need help reducing it."
+ # 3. Encode the input text for the model
+ messages = [
+     {"role": "user", "content": prompt}
+ ]
+ text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True,
+     thinking_mode='on'  # on/off/auto
+ )
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+ # 4. Generate text
+ generated_ids = model.generate(
+     **model_inputs,
+     max_new_tokens=4096
+ )
+ output_ids = [
+     output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+ ][0].tolist()
+ # 5. Parse the thinking content
+ try:
+     # rindex finding 151668 (</think>)
+     index = len(output_ids) - output_ids[::-1].index(151668)
+ except ValueError:
+     index = 0
+
+ thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
+ content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
+
+ print("thinking content:", thinking_content)
+ print("content:", content)
  ```

+ For deployment, you can use `sglang>=0.4.6.post1` or `vllm>=0.9.0` to create an OpenAI-compatible API endpoint (a client-side example follows the commands below):
+ - SGLang:
+ ```shell
+ python -m sglang.launch_server --model-path baichuan-inc/Baichuan-M2-32B --reasoning-parser qwen3
+ ```
+ - vLLM:
+ ```shell
+ vllm serve baichuan-inc/Baichuan-M2-32B --reasoning-parser qwen3
+ ```

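+ Once a server is running, any OpenAI-compatible client can query it. Below is a minimal sketch using the `openai` Python package, assuming the SGLang default port 30000 (use 8000 for a default `vllm serve`) and a placeholder API key:

+ ```python
+ # Query the OpenAI-compatible endpoint started above.
+ from openai import OpenAI
+
+ client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
+ response = client.chat.completions.create(
+     model="baichuan-inc/Baichuan-M2-32B",
+     messages=[{"role": "user", "content": "Got a big swelling after a bug bite. Need help reducing it."}],
+     temperature=0.6,
+ )
+ message = response.choices[0].message
+ # With --reasoning-parser qwen3, the separated reasoning is typically exposed as reasoning_content.
+ print(getattr(message, "reasoning_content", None))
+ print(message.content)
+ ```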

+ ## MTP Inference with SGLang

+ 1. Replace the qwen2.py file in the sglang installation directory with draft/qwen2.py (a snippet for locating the file follows below).
+ 2. Launch SGLang:
+ ```shell
+ python3 -m sglang.launch_server \
+     --model Baichuan-M2-32B \
+     --speculative-algorithm EAGLE3 \
+     --speculative-draft-model-path Baichuan-M2-32B/draft \
+     --speculative-num-steps 6 \
+     --speculative-eagle-topk 10 \
+     --speculative-num-draft-tokens 32 \
+     --mem-fraction 0.9 \
+     --cuda-graph-max-bs 2 \
+     --reasoning-parser qwen3 \
+     --dtype bfloat16
+ ```

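+ For step 1, one way to locate the installed file is shown below; this assumes a pip-installed sglang and overwrites its bundled Qwen2 implementation with the EAGLE3-aware version shipped in this repository:

+ ```shell
+ # Print the path of sglang's qwen2.py, then replace it with the draft version.
+ python3 -c "import sglang.srt.models.qwen2 as m; print(m.__file__)"
+ cp draft/qwen2.py "$(python3 -c 'import sglang.srt.models.qwen2 as m; print(m.__file__)')"
+ ```
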
+ ## ⚠️ Usage Notices
+ 1. **Medical Disclaimer**: For research and reference only; cannot replace professional medical diagnosis or treatment
+ 2. **Intended Use Cases**: Medical education, health consultation, clinical decision support
+ 3. **Safe Use**: Recommended under the guidance of medical professionals

+ ## 📄 License
+ Licensed under the [Apache License 2.0](LICENSE). Research and commercial use permitted.

+ ## 🤝 Acknowledgements
+ - Base Model: Qwen2.5-32B
+ - Training Framework: verl
+ - Inference Engines: vLLM, SGLang
+ - Quantization: AutoRound, GPTQ
+
+ Thank you to the open-source community; we are committed to continuing to contribute to it and to advancing healthcare AI.

+ ## 📞 Contact Us
+ - Resources: [Baichuan AI Website](https://www.baichuan-ai.com)
+ - Technical Support: [GitHub](https://github.com/baichuan-inc)

  ---

  <div align="center">

+ **Empowering Healthcare with AI, Making Health Accessible to All**
+
+ </div>
 
 
draft/qwen2.py ADDED
@@ -0,0 +1,641 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ # Adapted from llama2.py
16
+ # Modify details for the adaptation of Qwen2 model.
17
+ """Inference-only Qwen2 model compatible with HuggingFace weights."""
18
+ import logging
19
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union, List
20
+
21
+ import torch
22
+ from torch import nn
23
+
24
+ from sglang.srt.distributed import (
25
+ get_pp_group,
26
+ get_tensor_model_parallel_rank,
27
+ get_tensor_model_parallel_world_size,
28
+ )
29
+ from sglang.srt.layers.activation import SiluAndMul
30
+ from sglang.srt.layers.layernorm import RMSNorm
31
+ from sglang.srt.layers.linear import (
32
+ MergedColumnParallelLinear,
33
+ QKVParallelLinear,
34
+ RowParallelLinear,
35
+ )
36
+ from sglang.srt.layers.logits_processor import LogitsProcessor
37
+ from sglang.srt.layers.pooler import Pooler, PoolingType
38
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
39
+ from sglang.srt.layers.radix_attention import RadixAttention
40
+ from sglang.srt.layers.rotary_embedding import get_rope
41
+ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
42
+ from sglang.srt.layers.vocab_parallel_embedding import (
43
+ ParallelLMHead,
44
+ VocabParallelEmbedding,
45
+ )
46
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
47
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
48
+ from sglang.srt.model_loader.weight_utils import (
49
+ default_weight_loader,
50
+ kv_cache_scales_loader,
51
+ )
52
+ from sglang.srt.utils import add_prefix, make_layers
53
+
54
+ Qwen2Config = None
55
+
56
+
57
+ logger = logging.getLogger(__name__)
58
+
59
+
60
+ class Qwen2MLP(nn.Module):
61
+ def __init__(
62
+ self,
63
+ hidden_size: int,
64
+ intermediate_size: int,
65
+ hidden_act: str,
66
+ quant_config: Optional[QuantizationConfig] = None,
67
+ prefix: str = "",
68
+ ) -> None:
69
+ super().__init__()
70
+ self.gate_up_proj = MergedColumnParallelLinear(
71
+ hidden_size,
72
+ [intermediate_size] * 2,
73
+ bias=False,
74
+ quant_config=quant_config,
75
+ prefix=add_prefix("gate_up_proj", prefix),
76
+ )
77
+ self.down_proj = RowParallelLinear(
78
+ intermediate_size,
79
+ hidden_size,
80
+ bias=False,
81
+ quant_config=quant_config,
82
+ prefix=add_prefix("down_proj", prefix),
83
+ )
84
+ if hidden_act != "silu":
85
+ raise ValueError(
86
+ f"Unsupported activation: {hidden_act}. "
87
+ "Only silu is supported for now."
88
+ )
89
+ self.act_fn = SiluAndMul()
90
+
91
+ def forward(self, x):
92
+ gate_up, _ = self.gate_up_proj(x)
93
+ x = self.act_fn(gate_up)
94
+ x, _ = self.down_proj(x)
95
+ return x
96
+
97
+
98
+ class Qwen2Attention(nn.Module):
99
+ def __init__(
100
+ self,
101
+ hidden_size: int,
102
+ num_heads: int,
103
+ num_kv_heads: int,
104
+ head_dim: Optional[int] = None,
105
+ layer_id: int = 0,
106
+ rope_theta: float = 1000000,
107
+ rope_scaling: Optional[Dict[str, Any]] = None,
108
+ max_position_embeddings: int = 32768,
109
+ quant_config: Optional[QuantizationConfig] = None,
110
+ dual_chunk_attention_config: Optional[dict[str, Any]] = None,
111
+ prefix: str = "",
112
+ ) -> None:
113
+ super().__init__()
114
+ self.hidden_size = hidden_size
115
+ tp_size = get_tensor_model_parallel_world_size()
116
+ self.total_num_heads = num_heads
117
+ assert self.total_num_heads % tp_size == 0
118
+ self.num_heads = self.total_num_heads // tp_size
119
+ self.total_num_kv_heads = num_kv_heads
120
+ if self.total_num_kv_heads >= tp_size:
121
+ # Number of KV heads is greater than TP size, so we partition
122
+ # the KV heads across multiple tensor parallel GPUs.
123
+ assert self.total_num_kv_heads % tp_size == 0
124
+ else:
125
+ # Number of KV heads is less than TP size, so we replicate
126
+ # the KV heads across multiple tensor parallel GPUs.
127
+ assert tp_size % self.total_num_kv_heads == 0
128
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
129
+ if head_dim is not None:
130
+ self.head_dim = head_dim
131
+ else:
132
+ self.head_dim = hidden_size // self.total_num_heads
133
+ self.q_size = self.num_heads * self.head_dim
134
+ self.kv_size = self.num_kv_heads * self.head_dim
135
+ self.scaling = self.head_dim**-0.5
136
+ self.rope_theta = rope_theta
137
+ self.max_position_embeddings = max_position_embeddings
138
+
139
+ self.qkv_proj = QKVParallelLinear(
140
+ hidden_size,
141
+ self.head_dim,
142
+ self.total_num_heads,
143
+ self.total_num_kv_heads,
144
+ bias=True,
145
+ quant_config=quant_config,
146
+ prefix=add_prefix("qkv_proj", prefix),
147
+ )
148
+ self.o_proj = RowParallelLinear(
149
+ self.total_num_heads * self.head_dim,
150
+ hidden_size,
151
+ bias=False,
152
+ quant_config=quant_config,
153
+ prefix=add_prefix("o_proj", prefix),
154
+ )
155
+
156
+ self.rotary_emb = get_rope(
157
+ self.head_dim,
158
+ rotary_dim=self.head_dim,
159
+ max_position=max_position_embeddings,
160
+ base=rope_theta,
161
+ rope_scaling=rope_scaling,
162
+ dual_chunk_attention_config=dual_chunk_attention_config,
163
+ )
164
+ self.attn = RadixAttention(
165
+ self.num_heads,
166
+ self.head_dim,
167
+ self.scaling,
168
+ num_kv_heads=self.num_kv_heads,
169
+ layer_id=layer_id,
170
+ quant_config=quant_config,
171
+ prefix=add_prefix("attn", prefix),
172
+ )
173
+
174
+ def forward(
175
+ self,
176
+ positions: torch.Tensor,
177
+ hidden_states: torch.Tensor,
178
+ forward_batch: ForwardBatch,
179
+ ) -> torch.Tensor:
180
+ qkv, _ = self.qkv_proj(hidden_states)
181
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
182
+ q, k = self.rotary_emb(positions, q, k)
183
+ attn_output = self.attn(q, k, v, forward_batch)
184
+ output, _ = self.o_proj(attn_output)
185
+ return output
186
+
187
+
188
+ class Qwen2DecoderLayer(nn.Module):
189
+ def __init__(
190
+ self,
191
+ config: Qwen2Config,
192
+ layer_id: int = 0,
193
+ quant_config: Optional[QuantizationConfig] = None,
194
+ prefix: str = "",
195
+ alt_stream: Optional[torch.cuda.Stream] = None,
196
+ ) -> None:
197
+ super().__init__()
198
+ self.hidden_size = config.hidden_size
199
+ rope_theta = getattr(config, "rope_theta", 1000000)
200
+ rope_scaling = getattr(config, "rope_scaling", None)
201
+ max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
202
+ head_dim = getattr(config, "head_dim", None)
203
+ dual_chunk_attention_config = getattr(
204
+ config, "dual_chunk_attention_config", None
205
+ )
206
+ self.self_attn = Qwen2Attention(
207
+ hidden_size=self.hidden_size,
208
+ num_heads=config.num_attention_heads,
209
+ num_kv_heads=config.num_key_value_heads,
210
+ head_dim=head_dim,
211
+ layer_id=layer_id,
212
+ rope_theta=rope_theta,
213
+ rope_scaling=rope_scaling,
214
+ max_position_embeddings=max_position_embeddings,
215
+ quant_config=quant_config,
216
+ dual_chunk_attention_config=dual_chunk_attention_config,
217
+ prefix=add_prefix("self_attn", prefix),
218
+ )
219
+ self.mlp = Qwen2MLP(
220
+ hidden_size=self.hidden_size,
221
+ intermediate_size=config.intermediate_size,
222
+ hidden_act=config.hidden_act,
223
+ quant_config=quant_config,
224
+ prefix=add_prefix("mlp", prefix),
225
+ )
226
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
227
+ self.post_attention_layernorm = RMSNorm(
228
+ config.hidden_size, eps=config.rms_norm_eps
229
+ )
230
+
231
+ def forward(
232
+ self,
233
+ positions: torch.Tensor,
234
+ hidden_states: torch.Tensor,
235
+ forward_batch: ForwardBatch,
236
+ residual: Optional[torch.Tensor],
237
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
238
+ # Self Attention
239
+ if residual is None:
240
+ residual = hidden_states
241
+ hidden_states = self.input_layernorm(hidden_states)
242
+ else:
243
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
244
+ hidden_states = self.self_attn(
245
+ positions=positions,
246
+ hidden_states=hidden_states,
247
+ forward_batch=forward_batch,
248
+ )
249
+
250
+ # Fully Connected
251
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
252
+ hidden_states = self.mlp(hidden_states)
253
+ return hidden_states, residual
254
+
255
+
256
+ class Qwen2Model(nn.Module):
257
+ def __init__(
258
+ self,
259
+ config: Qwen2Config,
260
+ quant_config: Optional[QuantizationConfig] = None,
261
+ prefix: str = "",
262
+ decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
263
+ alt_stream: Optional[torch.cuda.Stream] = None,
264
+ ) -> None:
265
+ super().__init__()
266
+ self.config = config
267
+ self.padding_idx = config.pad_token_id
268
+ self.vocab_size = config.vocab_size
269
+ self.pp_group = get_pp_group()
270
+
271
+ if self.pp_group.is_first_rank:
272
+ self.embed_tokens = VocabParallelEmbedding(
273
+ config.vocab_size,
274
+ config.hidden_size,
275
+ quant_config=quant_config,
276
+ enable_tp=not global_server_args_dict["enable_dp_attention"],
277
+ prefix=add_prefix("embed_tokens", prefix),
278
+ )
279
+ else:
280
+ self.embed_tokens = PPMissingLayer()
281
+
282
+ # Use the provided decoder layer type or default to Qwen2DecoderLayer
283
+ decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
284
+ self.layers, self.start_layer, self.end_layer = make_layers(
285
+ config.num_hidden_layers,
286
+ lambda idx, prefix: decoder_layer_type(
287
+ layer_id=idx,
288
+ config=config,
289
+ quant_config=quant_config,
290
+ prefix=prefix,
291
+ alt_stream=alt_stream,
292
+ ),
293
+ pp_rank=self.pp_group.rank_in_group,
294
+ pp_size=self.pp_group.world_size,
295
+ prefix=add_prefix("layers", prefix),
296
+ )
297
+ if self.pp_group.is_last_rank:
298
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
299
+ else:
300
+ self.norm = PPMissingLayer(return_tuple=True)
301
+
302
+ # For EAGLE3 support
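+ # Indices of decoder layers whose incoming hidden states (with the residual added
+ # back) are collected in forward() as aux_hidden_states for the EAGLE3 draft model.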
303
+ self.layers_to_capture = []
304
+
305
+ def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
306
+ if hasattr(self.config, "scale_emb"):
307
+ return self.get_input_embeddings()(input_ids) * self.config.scale_emb
308
+ else:
309
+ return self.get_input_embeddings()(input_ids)
310
+
311
+ def get_input_embeddings(self) -> nn.Embedding:
312
+ return self.embed_tokens
313
+
314
+ def forward(
315
+ self,
316
+ input_ids: torch.Tensor,
317
+ positions: torch.Tensor,
318
+ forward_batch: ForwardBatch,
319
+ input_embeds: torch.Tensor = None,
320
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
321
+ ) -> Union[torch.Tensor, PPProxyTensors]:
322
+ if self.pp_group.is_first_rank:
323
+ if input_embeds is None:
324
+ hidden_states = self.embed_tokens(input_ids)
325
+ else:
326
+ hidden_states = input_embeds
327
+ residual = None
328
+ else:
329
+ assert pp_proxy_tensors is not None
330
+ hidden_states = pp_proxy_tensors["hidden_states"]
331
+ residual = pp_proxy_tensors["residual"]
332
+
333
+ aux_hidden_states = []
334
+ for i in range(self.start_layer, self.end_layer):
335
+ if i in self.layers_to_capture:
336
+ aux_hidden_states.append(
337
+ hidden_states + residual if residual is not None else hidden_states
338
+ )
339
+ layer = self.layers[i]
340
+ hidden_states, residual = layer(
341
+ positions,
342
+ hidden_states,
343
+ forward_batch,
344
+ residual,
345
+ )
346
+ if not self.pp_group.is_last_rank:
347
+ return PPProxyTensors(
348
+ {
349
+ "hidden_states": hidden_states,
350
+ "residual": residual,
351
+ }
352
+ )
353
+ else:
354
+ if hidden_states.shape[0] != 0:
355
+ if residual is None:
356
+ hidden_states = self.norm(hidden_states)
357
+ else:
358
+ hidden_states, _ = self.norm(hidden_states, residual)
359
+
360
+ if len(aux_hidden_states) == 0:
361
+ return hidden_states
362
+
363
+ return hidden_states, aux_hidden_states
364
+
365
+ # If this function is called, it should always initialize KV cache scale
366
+ # factors (or else raise an exception). Thus, handled exceptions should
367
+ # make sure to leave KV cache scale factors in a known good (dummy) state
368
+ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
369
+ tp_size = get_tensor_model_parallel_world_size()
370
+ tp_rank = get_tensor_model_parallel_rank()
371
+ for layer_idx, scaling_factor in kv_cache_scales_loader(
372
+ quantization_param_path,
373
+ tp_rank,
374
+ tp_size,
375
+ self.config.num_hidden_layers,
376
+ self.config.__class__.model_type,
377
+ ):
378
+ if not isinstance(self.layers[layer_idx], nn.Identity):
379
+ layer_self_attn = self.layers[layer_idx].self_attn
380
+ if hasattr(layer_self_attn.attn, "k_scale"):
381
+ layer_self_attn.attn.k_scale = scaling_factor
382
+ layer_self_attn.attn.v_scale = scaling_factor
383
+ else:
384
+ raise RuntimeError(
385
+ "Self attention has no KV cache scaling " "factor attribute!"
386
+ )
387
+
388
+
389
+ class Qwen2ForCausalLM(nn.Module):
390
+ # BitandBytes specific attributes
391
+ default_bitsandbytes_target_modules = [
392
+ ".gate_proj.",
393
+ ".down_proj.",
394
+ ".up_proj.",
395
+ ".q_proj.",
396
+ ".k_proj.",
397
+ ".v_proj.",
398
+ ".o_proj.",
399
+ ]
400
+ bitsandbytes_stacked_params_mapping = {
401
+ # shard_name, weight_name, index
402
+ "q_proj": ("qkv_proj", 0),
403
+ "k_proj": ("qkv_proj", 1),
404
+ "v_proj": ("qkv_proj", 2),
405
+ "gate_proj": ("gate_up_proj", 0),
406
+ "up_proj": ("gate_up_proj", 1),
407
+ }
408
+
409
+ def __init__(
410
+ self,
411
+ config: Qwen2Config,
412
+ quant_config: Optional[QuantizationConfig] = None,
413
+ prefix: str = "",
414
+ ) -> None:
415
+ super().__init__()
416
+ self.pp_group = get_pp_group()
417
+ self.config = config
418
+ self.quant_config = quant_config
419
+ self.model = Qwen2Model(
420
+ config, quant_config=quant_config, prefix=add_prefix("model", prefix)
421
+ )
422
+ self.capture_aux_hidden_states = False
423
+
424
+ # handle the lm head on different pp ranks
425
+ if self.pp_group.is_last_rank:
426
+ if self.pp_group.world_size == 1 and config.tie_word_embeddings:
427
+ self.lm_head = self.model.embed_tokens
428
+ else:
429
+ self.lm_head = ParallelLMHead(
430
+ config.vocab_size,
431
+ config.hidden_size,
432
+ quant_config=quant_config,
433
+ prefix=add_prefix("lm_head", prefix),
434
+ )
435
+ else:
436
+ # ranks other than the last rank will have a placeholder layer
437
+ self.lm_head = PPMissingLayer()
438
+
439
+ # perform weight tying for PP
440
+ if self.pp_group.world_size > 1 and config.tie_word_embeddings:
441
+ if self.pp_group.is_first_rank:
442
+ self.pp_group.send(
443
+ self.model.embed_tokens.weight, dst=self.pp_group.last_rank
444
+ )
445
+ else:
446
+ emb_token_weight = self.pp_group.recv(
447
+ size=(config.vocab_size, config.hidden_size),
448
+ dtype=next(self.model.parameters()).dtype,
449
+ src=self.pp_group.first_rank,
450
+ )
451
+ self.lm_head.weight.copy_(emb_token_weight)
452
+
453
+ self.logits_processor = LogitsProcessor(config)
454
+ self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
455
+
456
+ def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
457
+ return self.model.get_input_embedding(input_ids)
458
+
459
+ def get_input_embeddings(self) -> nn.Embedding:
460
+ return self.model.embed_tokens
461
+
462
+ @torch.no_grad()
463
+ def forward(
464
+ self,
465
+ input_ids: torch.Tensor,
466
+ positions: torch.Tensor,
467
+ forward_batch: ForwardBatch,
468
+ input_embeds: torch.Tensor = None,
469
+ get_embedding: bool = False,
470
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
471
+ ) -> torch.Tensor:
472
+ hidden_states = self.model(
473
+ input_ids,
474
+ positions,
475
+ forward_batch,
476
+ input_embeds,
477
+ pp_proxy_tensors=pp_proxy_tensors,
478
+ )
479
+ aux_hidden_states = None
480
+ if self.capture_aux_hidden_states:
481
+ hidden_states, aux_hidden_states = hidden_states
482
+
483
+ if self.pp_group.is_last_rank:
484
+ if not get_embedding:
485
+ return self.logits_processor(
486
+ input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
487
+ )
488
+ else:
489
+ return self.pooler(hidden_states, forward_batch)
490
+ else:
491
+ return hidden_states
492
+
493
+ @torch.no_grad()
494
+ def forward_split_prefill(
495
+ self,
496
+ input_ids: torch.Tensor,
497
+ positions: torch.Tensor,
498
+ forward_batch: ForwardBatch,
499
+ split_interval: Tuple[int, int], # [start, end) 0-based
500
+ input_embeds: torch.Tensor = None,
501
+ ):
502
+ start, end = split_interval
503
+ # embed
504
+ if start == 0:
505
+ if input_embeds is None:
506
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
507
+ else:
508
+ forward_batch.hidden_states = input_embeds
509
+ # decoder layer
510
+ for i in range(start, end):
511
+ layer = self.model.layers[i]
512
+ forward_batch.hidden_states, forward_batch.residual = layer(
513
+ positions,
514
+ forward_batch.hidden_states,
515
+ forward_batch,
516
+ forward_batch.residual,
517
+ )
518
+
519
+ if end == self.model.config.num_hidden_layers:
520
+ # norm
521
+ hidden_states, _ = self.model.norm(
522
+ forward_batch.hidden_states, forward_batch.residual
523
+ )
524
+ forward_batch.hidden_states = hidden_states
525
+ # logits process
526
+ result = self.logits_processor(
527
+ input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
528
+ )
529
+ else:
530
+ result = None
531
+
532
+ return result
533
+
534
+ @property
535
+ def start_layer(self):
536
+ return self.model.start_layer
537
+
538
+ @property
539
+ def end_layer(self):
540
+ return self.model.end_layer
541
+
542
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
543
+ stacked_params_mapping = [
544
+ # (param_name, shard_name, shard_id)
545
+ ("qkv_proj", "q_proj", "q"),
546
+ ("qkv_proj", "k_proj", "k"),
547
+ ("qkv_proj", "v_proj", "v"),
548
+ ("gate_up_proj", "gate_proj", 0),
549
+ ("gate_up_proj", "up_proj", 1),
550
+ ]
551
+
552
+ params_dict = dict(self.named_parameters())
553
+ for name, loaded_weight in weights:
554
+ layer_id = get_layer_id(name)
555
+ if (
556
+ layer_id is not None
557
+ and hasattr(self.model, "start_layer")
558
+ and (
559
+ layer_id < self.model.start_layer
560
+ or layer_id >= self.model.end_layer
561
+ )
562
+ ):
563
+ continue
564
+
565
+ if "rotary_emb.inv_freq" in name or "projector" in name:
566
+ continue
567
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
568
+ # Models trained using ColossalAI may include these tensors in
569
+ # the checkpoint. Skip them.
570
+ continue
571
+ if self.config.tie_word_embeddings and "lm_head.weight" in name:
572
+ if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
573
+ # Handle pp weight tying here
574
+ # find the embed_tokens.weight in the weights
575
+ embed_token_weights = next(
576
+ filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
577
+ )[1]
578
+ loaded_weight = embed_token_weights
579
+ else:
580
+ continue
581
+ if name.startswith("model.vision_tower") and name not in params_dict:
582
+ continue
583
+
584
+ for param_name, weight_name, shard_id in stacked_params_mapping:
585
+ if weight_name not in name:
586
+ continue
587
+ name = name.replace(weight_name, param_name)
588
+ # Skip loading extra bias for GPTQ models.
589
+ if name.endswith(".bias") and name not in params_dict:
590
+ continue
591
+ if name not in params_dict:
592
+ continue
593
+ param = params_dict[name]
594
+ weight_loader = param.weight_loader
595
+ weight_loader(param, loaded_weight, shard_id)
596
+ break
597
+ else:
598
+ # Skip loading extra bias for GPTQ models.
599
+ if name.endswith(".bias") and name not in params_dict:
600
+ continue
601
+
602
+ if name in params_dict.keys():
603
+ param = params_dict[name]
604
+ weight_loader = getattr(
605
+ param, "weight_loader", default_weight_loader
606
+ )
607
+ weight_loader(param, loaded_weight)
608
+ else:
609
+ logger.warning(f"Parameter {name} not found in params_dict")
610
+
611
+ def get_embed_and_head(self):
612
+ return self.model.embed_tokens.weight, self.lm_head.weight
613
+
614
+ def set_embed_and_head(self, embed, head):
615
+ del self.model.embed_tokens.weight
616
+ del self.lm_head.weight
617
+ self.model.embed_tokens.weight = embed
618
+ self.lm_head.weight = head
619
+ torch.cuda.empty_cache()
620
+ torch.cuda.synchronize()
621
+
622
+ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
623
+ self.model.load_kv_cache_scales(quantization_param_path)
624
+
625
+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
626
+ if not self.pp_group.is_last_rank:
627
+ return
628
+
629
+ self.capture_aux_hidden_states = True
630
+ if layer_ids is None:
631
+ num_layers = self.config.num_hidden_layers
632
+ self.model.layers_to_capture = [
633
+ 2,
634
+ num_layers // 2,
635
+ num_layers - 3,
636
+ ] # Specific layers for EAGLE3 support
637
+ else:
638
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
639
+
640
+
641
+ EntryClass = Qwen2ForCausalLM