""" | |
Configuration class for MoLA-LM | |
""" | |
from typing import Dict, List, Optional

from transformers import PretrainedConfig
EXPERTS_LIST = [ | |
"0", | |
"1", | |
"2", | |
"3", | |
"4", | |
"5", | |
"6", | |
"7", | |
"8", | |
] | |
class MoLAConfig(PretrainedConfig):
    """Configuration class for the MoLA-LM (Mixture-of-LoRA) model.

    Args:
        base_model_name_or_path: Hub id or local path of the base model the
            LoRA experts are attached to.
        task_labels: One label per LoRA expert. Falls back to ``EXPERTS_LIST``
            when omitted (note: an explicitly empty list also triggers the
            fallback, preserving the original ``or`` semantics).
        router_config: Configuration dict for the expert router; defaults to
            an empty dict.
        lora_configs: Mapping from task label to that expert's LoRA config
            dict; defaults to an empty dict.
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
    """

    model_type = "mola_lm"

    def __init__(
        self,
        base_model_name_or_path: str = "Qwen/Qwen2.5-3B-Instruct",
        task_labels: Optional[List[str]] = None,
        router_config: Optional[Dict] = None,
        lora_configs: Optional[Dict[str, Dict]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name_or_path = base_model_name_or_path
        # Falsy (None or empty) task_labels fall back to the default experts.
        self.task_labels = task_labels or EXPERTS_LIST
        self.router_config = router_config or {}
        self.lora_configs = lora_configs or {}
        # Derived: one LoRA adapter per task label.
        self.num_loras = len(self.task_labels)