# --------------------------------------------------------
# Copyright (c) 2025 NVIDIA
# Licensed under customized NSCLv1 [see LICENSE.md for details]
# --------------------------------------------------------

import dataclasses
from enum import IntEnum, auto
from typing import Any, Dict, List, Optional, Tuple, Union


class SeparatorStyle(IntEnum):
    """Separator styles understood by ``Conversation.get_prompt``."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    CHATGLM = auto()
    CHATML = auto()
    CHATINTERN = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()
    FALCON_CHAT = auto()
    CHATGLM3 = auto()
    INTERNVL_ZH = auto()
    MPT = auto()
    LLAMA3 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt; get_prompt() formats it with the
    # single keyword argument ``system_message``.
    system_template: str = '{system_message}'
    # The system message
    system_message: str = ''
    # The names of the two roles
    roles: Tuple[str, str] = ('USER', 'ASSISTANT')
    # All messages. Each item is [role, message]; the message of the turn still
    # awaiting a model response is typically None (see update_last_message).
    # A per-instance list so append_message() works on directly constructed
    # conversations and instances never share history.
    messages: List[List[str]] = dataclasses.field(default_factory=list)
    # The number of few-shot example messages skipped by the to_* converters
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = '\n'
    sep2: Optional[str] = None
    # Stop criteria (the default one is EOS token)
    stop_str: Optional[Union[str, List[str]]] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Optional[List[int]] = None

    def get_prompt(self) -> str:
        """Render the conversation into a single prompt string.

        The layout is selected by ``self.sep_style``. In every style, a message
        that is falsy (e.g. None or '') is rendered as an open turn: only the
        role header is emitted so that generation continues from there.

        Returns:
            The prompt text.

        Raises:
            ValueError: if ``self.sep_style`` is not a handled style.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            # Alternate separators: sep after even turns, sep2 after odd turns.
            seps = [self.sep, self.sep2]
            ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ': ' + message + seps[i % 2]
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ': '  # must end with a space
            return ret
        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            ret = '' if system_prompt == '' else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + message + self.sep
                else:
                    ret += role + '\n'
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    # Normalize CRLF and collapse blank lines inside the message,
                    # because a blank line is used as the turn separator below.
                    ret += (
                        role
                        + ': '
                        + message.replace('\r\n', '\n').replace('\n\n', '\n')
                    )
                    ret += '\n\n'
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            if self.system_message:
                ret = system_prompt
            else:
                ret = '[INST] '
            for i, (role, message) in enumerate(self.messages):
                # Role tags come from self.roles by turn parity, not from the message.
                tag = self.roles[i % 2]
                if message:
                    if i == 0:
                        # The first message follows the system prompt / '[INST] '
                        # directly, without a role tag.
                        ret += message + ' '
                    else:
                        ret += tag + ' ' + message + seps[i % 2]
                else:
                    ret += tag
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM:
            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
            # chatglm2 numbers rounds from 1, the original chatglm from 0.
            round_add_n = 1 if self.name == 'chatglm2' else 0
            if system_prompt:
                ret = system_prompt + self.sep
            else:
                ret = ''
            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    ret += f'[Round {i//2 + round_add_n}]{self.sep}'
                if message:
                    ret += f'{role}:{message}{self.sep}'
                else:
                    ret += f'{role}:'
            return ret
        elif self.sep_style == SeparatorStyle.CHATML:
            ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + message + self.sep + '\n'
                else:
                    ret += role + '\n'
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM3:
            ret = ''
            if self.system_message:
                ret += system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + ' ' + message
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.CHATINTERN:
            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ':' + message + seps[i % 2] + '\n'
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ':\n' + message + seps[i % 2]
                    # A blank line closes each completed (odd-indexed) reply.
                    if i % 2 == 1:
                        ret += '\n\n'
                else:
                    ret += role + ':\n'
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            # NOTE(review): the '' concatenations below are no-ops; they look like
            # wrapping markers that may have been lost. Confirm against the
            # upstream template before relying on this style.
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + '' + message + ''
                else:
                    ret += role + ': ' + ''
            return ret
        elif self.sep_style == SeparatorStyle.ROBIN:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ':\n' + message + self.sep
                else:
                    ret += role + ':\n'
            return ret
        elif self.sep_style == SeparatorStyle.FALCON_CHAT:
            ret = ''
            if self.system_message:
                ret += system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
            seps = [self.sep, self.sep2]
            # Uses the raw system message, bypassing system_template.
            ret = self.system_message + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ': ' + message + seps[i % 2]
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style in (SeparatorStyle.MPT, SeparatorStyle.LLAMA3):
            # MPT and LLAMA3 share one layout: role and message concatenated
            # directly, each completed turn followed by sep.
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    # A message may be given as a 3-tuple; only its first
                    # element (the text) is rendered.
                    if isinstance(message, tuple):
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f'Invalid style: {self.sep_style}')

    def set_system_message(self, system_message: str) -> None:
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: Optional[str]) -> None:
        """Append a new [role, message] turn to the history."""
        self.messages.append([role, message])

    def update_last_message(self, message: str) -> None:
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self) -> List[List[Optional[str]]]:
        """Convert the conversation to gradio chatbot format.

        Returns a list of [user_message, assistant_message] pairs built from the
        messages after ``self.offset``; an unanswered user turn has None as its
        second element.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self) -> List[Dict[str, Any]]:
        """Convert the conversation to OpenAI chat completion format.

        A system entry is always emitted first. Messages after ``self.offset``
        alternate user/assistant by position; assistant turns whose message is
        None are omitted.
        """
        ret = [{'role': 'system', 'content': self.system_message}]
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({'role': 'user', 'content': msg})
            else:
                if msg is not None:
                    ret.append({'role': 'assistant', 'content': msg})
        return ret

    def copy(self) -> 'Conversation':
        """Return an independent copy of this conversation.

        The message history and any list-valued stop criteria are copied, so
        mutating the copy (e.g. one returned by get_conv_template) never alters
        the original.
        """
        stop_str = list(self.stop_str) if isinstance(self.stop_str, list) else self.stop_str
        stop_token_ids = (
            list(self.stop_token_ids) if self.stop_token_ids is not None else None
        )
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=stop_str,
            stop_token_ids=stop_token_ids,
        )

    def dict(self) -> Dict[str, Any]:
        """Return a plain-dict summary of the template name, system message, roles and history."""
        return {
            'template_name': self.name,
            'system_message': self.system_message,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False) -> None:
    """Register a new conversation template under ``template.name``.

    Raises AssertionError (when assertions are enabled) if the name is already
    registered and ``override`` is False.
    """
    if not override:
        assert (
            template.name not in conv_templates
        ), f'{template.name} has been registered.'
    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Return an independent copy of the registered template ``name`` (KeyError if unknown)."""
    return conv_templates[name].copy()


register_conv_template(
    Conversation(
        name='bidirectional-llama-retriever',
        system_template='',
        system_message='',
        roles=('', ''),
        sep_style=SeparatorStyle.LLAMA3,
        sep='',
        stop_token_ids=[
            128259,
            128001,
        ],
    )
)