# MuFun-Instruct / text_preprocess.py
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
import copy
from typing import Any
# IGNORE_INDEX = -100
# IMAGE_TOKEN_INDEX = -200
# DEFAULT_IMAGE_TOKEN = "<audio>"
from .configuration import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from transformers import PreTrainedTokenizer
import torch
from abc import ABC, abstractmethod
SLOT = Union[str, List[str], Dict[str, str]]


@dataclass
class Formatter(ABC):
    slot: SLOT = ""

    @abstractmethod
    def apply(self, **kwargs) -> SLOT: ...


@dataclass
class EmptyFormatter(Formatter):
    def apply(self, **kwargs) -> SLOT:
        return self.slot


@dataclass
class StringFormatter(Formatter):
    def apply(self, **kwargs) -> SLOT:
        msg = self.slot
        for name, value in kwargs.items():
            if value is None:
                # A None value means an open-ended turn: keep only the role
                # prefix (everything up to and including the first colon).
                return self.slot.split(':')[0] + ":"
            if not isinstance(value, str):
                raise RuntimeError("Expected a string, got {}".format(value))
            # Fill placeholders in the accumulated message (not the raw slot),
            # so multiple placeholders all survive across iterations.
            msg = msg.replace("{{" + name + "}}", value, 1)
        return msg
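
# Example (an illustrative sketch, not part of the module's API surface):
#   StringFormatter(slot="USER: {{content}} ").apply(content="hi")
#   -> "USER: hi "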


@dataclass
class Template:
    format_image_token: "Formatter"
    format_user: "Formatter"
    format_assistant: "Formatter"
    system: "Formatter"
    separator: "Formatter"

    def encode(self, messages, tokenizer, mode='train'):
        """
        1. Split the message list (conversations: [{from: human, value: message},
           {from: gpt, value: message}]) into a question list and an answer list.
        2. Build the prompt string from the two lists.
        3. Tokenize the prompt.
        4. Build the training targets (labels).
        """
        question_list, answer_list = self.get_list_from_message(messages)
        if mode == 'rl':
            gt = answer_list[-1]
            answer_list[-1] = ''  # the last answer is left empty in RL mode
        prompt = self.prompt(question_list, answer_list)
        if mode == 'rl' and prompt.endswith(self.separator.apply()[1]):
            # Strip the trailing EOS so the model generates the last answer itself.
            prompt = prompt[:-len(self.separator.apply()[1])]
        input_ids = self.tokenizer_image_token(prompt, tokenizer, return_tensors='pt')
        if mode == 'train':
            labels = self.make_labels(input_ids, prompt, tokenizer)
            return dict(
                input_ids=input_ids,
                labels=labels
            )
        elif mode == 'rl':
            return dict(
                input_ids=input_ids,
                prompt=prompt,
                gt=gt
            )
        else:
            return dict(input_ids=input_ids, prompt=prompt)
    def get_list_from_message(self, messages):
        return self._get_list_from_message(messages)

    def _get_list_from_message(self, messages):
        """
        messages ====> [{from: human, value: message}, {from: gpt, value: message}]
        """
        question_list = []
        answer_list = []
        first_is_not_question = 0
        for i, message in enumerate(messages):
            if i == 0 and message['from'] != 'human':
                # Skip a leading non-human message and shift the parity check.
                first_is_not_question = 1
                continue
            if i % 2 == first_is_not_question:
                question_list.append(message['value'])
            else:
                answer_list.append(message['value'])
        assert len(question_list) == len(answer_list), \
            f"questions and answers do not match: length_q:{len(question_list)} vs length_a:{len(answer_list)}"
        return question_list, answer_list
    def prompt(
        self,
        question_list, answer_list
    ):
        if type(question_list) is str:
            question_list = [question_list]
        if type(answer_list) is str:
            answer_list = [answer_list]
        msg = self._prompt(question_list, answer_list)
        return msg

    def _prompt(
        self,
        question_list, answer_list,
    ):
        msg = ""
        for i, (question, answer) in enumerate(zip(question_list, answer_list)):
            if i == 0:
                msg += self.system.apply()
            # if DEFAULT_IMAGE_TOKEN in question:
            #     question = question.replace(DEFAULT_IMAGE_TOKEN, '').strip()
            #     question = self.format_image_token.apply(content=question).strip()
            msg += self.format_user.apply(content=question)
            msg += self.format_assistant.apply(content=answer)
        return msg
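
    # With the qwen2_instruct template registered below, one round renders as
    # (illustrative, with placeholder content):
    #   "<system prompt> USER: <question> ASSISTANT: <answer><|im_end|>"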
    def make_labels(self, input_ids, prompt, tokenizer):
        labels = copy.deepcopy(input_ids)
        sep, eos_token = self.separator.apply()
        total_len = int(labels.ne(tokenizer.pad_token_id).sum())
        if tokenizer.pad_token_id == tokenizer.eos_token_id:
            # EOS tokens were counted as padding above; add them back.
            total_len += prompt.count(eos_token)
        rounds = prompt.split(eos_token)
        eos_token_length = len(tokenizer.encode(eos_token))
        labels, cur_len = self._make_masks(labels, tokenizer, sep, eos_token_length, rounds)
        if cur_len < tokenizer.model_max_length:
            if (cur_len != total_len) and ((cur_len + 1) != total_len):
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
                print("number of rounds: ", len(rounds) - 1)
                print("rounds: ", rounds[:-1])
                print("prompt: ", prompt)
                print(labels)
                print(input_ids)
                # labels[:] = IGNORE_INDEX
        return labels
    def _make_masks(self, labels, tokenizer, sep, eos_token_length, rounds):
        cur_len = 0
        for rou in rounds:
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(self.tokenizer_image_token(rou, tokenizer)) + eos_token_length
            instruction_len = len(self.tokenizer_image_token(parts[0], tokenizer)) - 1
            # Mask the instruction part of this round; only the answer is supervised.
            labels[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        labels[cur_len:] = IGNORE_INDEX
        return labels, cur_len
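
    # Resulting label layout per round (illustrative):
    #   [IGNORE_INDEX ... IGNORE_INDEX | answer tokens | eos]
    # i.e. everything up to and including " ASSISTANT: " is masked out, and
    # only the answer (plus its terminator) contributes to the loss.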
    @classmethod
    def tokenizer_image_token(cls, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
        def _insert_separator(X, sep):
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
        input_ids = []
        offset = 0
        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
            # Keep a single leading BOS token; skip the BOS the tokenizer
            # prepends to every later chunk via the offset below.
            offset = 1
            input_ids.append(prompt_chunks[0][0])
        for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])
        if return_tensors is not None:
            if return_tensors == 'pt':
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f'Unsupported tensor type: {return_tensors}')
        return input_ids
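
    # Example (a sketch; the prompt text is a placeholder):
    #   prompt = "USER: <audio>\nDescribe the sound. ASSISTANT:"
    #   -> ids("USER: ") + [IMAGE_TOKEN_INDEX] + ids("\nDescribe the sound. ASSISTANT:")
    # The <audio> placeholder becomes a single sentinel index, which the
    # multimodal model is presumably expected to replace with audio features.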


TEMPLATE_FACTORY: Dict[str, Template] = {}


def TemplateFactory(version):
    template = TEMPLATE_FACTORY.get(version, None)
    assert template, f"{version} is not implemented"
    return template


def register_template(name):
    def register_template_cls(cls):
        if name in TEMPLATE_FACTORY:
            return TEMPLATE_FACTORY[name]
        TEMPLATE_FACTORY[name] = cls
        return cls
    return register_template_cls


system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."


@register_template('qwen2_instruct')
@dataclass
class Qwen2InstructTemplate(Template):
    format_image_token: "Formatter" = field(default_factory=lambda: StringFormatter(slot="<audio>\n{{content}}"))
    format_user: "Formatter" = field(default_factory=lambda: StringFormatter(slot="USER" + ": " + "{{content}}" + " "))
    format_assistant: "Formatter" = field(default_factory=lambda: StringFormatter(slot="ASSISTANT" + ": " + "{{content}}" + "<|im_end|>"))
    system: "Formatter" = field(default_factory=lambda: EmptyFormatter(slot=system + " "))
    separator: "Formatter" = field(default_factory=lambda: EmptyFormatter(slot=[' ASSISTANT: ', '<|im_end|>']))


class TextPreprocess:
    def __init__(self, tokenizer, version):
        self.tokenizer = tokenizer
        self.template = TemplateFactory(version)()

    def __call__(self, messages, mode='eval'):
        return self.template.encode(messages, self.tokenizer, mode)
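

# Example usage (a minimal sketch; the checkpoint name and message contents
# are placeholders, not part of this module):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
#   preprocess = TextPreprocess(tokenizer, version='qwen2_instruct')
#   messages = [
#       {'from': 'human', 'value': '<audio>\nWhat is said in this clip?'},
#       {'from': 'gpt', 'value': 'The speaker greets the listener.'},
#   ]
#   batch = preprocess(messages, mode='train')
#   # batch['input_ids']: 1-D LongTensor with IMAGE_TOKEN_INDEX at the <audio>
#   # position; batch['labels']: same shape, instruction tokens masked to
#   # IGNORE_INDEX.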