# MuFun-Instruct / modeling_mufun.py
from typing import List, Optional, Tuple, Union
import re
import os
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from transformers import AutoConfig, AutoModelForCausalLM, Qwen3ForCausalLM, WhisperForConditionalGeneration, StoppingCriteria, AutoProcessor
from .audio_preprocess import AudioPreprocess, load_audios
from .text_preprocess import TextPreprocess
from .message import Message
from .configuration import TinyLlavaConfig, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
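# Serve-time constants apparently carried over from the LLaVA-style serving code;
# they are not referenced elsewhere in this file.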
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
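# Stops generation once any of the given keyword strings (e.g. the chat template
# separator used by chat() below) appears at the end of the decoded output.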
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
self.max_keyword_len = 0
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
cur_keyword_ids = cur_keyword_ids[1:]
if len(cur_keyword_ids) > self.max_keyword_len:
self.max_keyword_len = len(cur_keyword_ids)
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
for keyword_id in self.keyword_ids:
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
return True
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
outputs = []
for i in range(output_ids.shape[0]):
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
return all(outputs)
ACT_TYPE = {
'relu': nn.ReLU,
'gelu': nn.GELU
}
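# Two-layer MLP connector mapping stacked audio-encoder features to the LM hidden
# size. config.connector_type is expected to embed two integers suffixed with 'i'
# and 'x' (parsed by extract_numbers below): the multiplier on the encoder hidden
# size (the number of stacked feature layers) and the expansion factor of the
# first linear layer.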
class CNet(nn.Module):
def __init__(self, config):
super().__init__()
def extract_numbers(s):
match = re.findall(r'(\d+)[ix]', s)
if len(match) == 2:
return tuple(map(int, match))
return None
ix, hx = extract_numbers(config.connector_type)
act_type = 'gelu'
        self.act = ACT_TYPE[act_type]()
vdim = config.vision_hidden_size*ix
ldim = config.hidden_size
self.linear1 = nn.Linear(vdim, hx*vdim)
self.linear2 = nn.Linear(hx*vdim, ldim)
def forward(self, x):
x = self.act(self.linear1(x))
return self.linear2(x)
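# Connector base class: optionally loads pretrained '_connector.*' weights from a
# pytorch_model.bin checkpoint and freezes them.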
class Connector(nn.Module):
def __init__(self, config=None):
super().__init__()
self._connector = None
def load_model(self, **kwargs):
pretrained_connector_path = kwargs.get('pretrained_connector_path', None)
if pretrained_connector_path is not None:
pretrained_connector_path = os.path.join(pretrained_connector_path, 'pytorch_model.bin')
connector_weights = torch.load(pretrained_connector_path, map_location='cpu')
def get_w(weights, keyword):
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
self._connector.load_state_dict(get_w(connector_weights, '_connector'))
print(f'Loading connector from {pretrained_connector_path}...')
for p in self._connector.parameters():
p.requires_grad = False
def forward(self, x):
return self._connector(x)
class MLPConnector(Connector):
def __init__(self, config):
super().__init__()
self._connector = CNet(config)
def get_value_from_kwargs(kwargs, name):
if name in kwargs:
return kwargs.pop(name)
else:
return None
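# Audio encoder wrapper. The 'vision tower' / 'image' naming is carried over from
# the TinyLLaVA code this model adapts; the inputs here are audio features.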
class AudioTower(nn.Module):
def __init__(self, cfg):
super().__init__()
self._vision_tower = None
self._image_processor = None
self.config = cfg
def load_model(self, vision_tower_name, **kwargs):
self._load_model(vision_tower_name, **kwargs)
self._vision_tower.requires_grad_(False)
def _load_model(self, vision_tower_name, **kwargs):
pretrained_vision_tower_path = get_value_from_kwargs(kwargs, 'pretrained_vision_tower_path')
if isinstance(self._vision_tower, PreTrainedModel): # hf model
if pretrained_vision_tower_path is not None:
vision_tower_name = pretrained_vision_tower_path
self._vision_tower = self._vision_tower.from_pretrained(vision_tower_name, **kwargs)
else: # nn.Module
if pretrained_vision_tower_path is not None:
vision_tower_weights = torch.load(os.path.join(pretrained_vision_tower_path, 'pytorch_model.bin'), map_location='cpu')
def get_w(weights, keyword):
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
self._vision_tower.load_state_dict(vision_tower_weights)
print("Loading vision tower from ", vision_tower_name)
def forward(self, x, **kwargs):
image_features = self._vision_tower(x, output_hidden_states=True)
image_features = image_features.hidden_states[kwargs.get('vision_feature_layer', -2)]
if kwargs.get('vision_feature_select_strategy', 'patch') == 'patch':
image_features = image_features[:, 1:]
elif kwargs.get('vision_feature_select_strategy', 'patch') == 'cls_patch':
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {kwargs.get('vision_feature_select_strategy')}")
return image_features
@property
def vision_tower(self):
return self._vision_tower
@vision_tower.setter
def vision_tower(self, vision_tower):
self._vision_tower = vision_tower
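# Whisper-large-v3 encoder tower: concatenates hidden states from encoder layers
# [0, 7, 15, 32] along the feature dimension, then average-pools along time with
# stride 5 (1500 -> 300 frames per 30 s window).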
class WpmAudioTower(AudioTower):
def __init__(self, cfg):
super().__init__(cfg)
self._vision_tower = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3").get_encoder()
self._image_processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
self.pool_stride = 5
self.avg_pooler = nn.AvgPool1d(self.pool_stride, stride=self.pool_stride)
self.features_layers = [0, 7, 15, 32]
def _load_model(self, vision_tower_name, **kwargs):
pretrained_vision_tower_path = kwargs.pop('pretrained_vision_tower_path', None)
if pretrained_vision_tower_path is None:
print("Loading vision tower1 from ", vision_tower_name)
else: # nn.Module
if pretrained_vision_tower_path is not None:
vision_tower_weights = torch.load(os.path.join(pretrained_vision_tower_path, 'pytorch_model.bin'), map_location='cpu')
def get_w(weights, keyword):
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
self._vision_tower.load_state_dict(vision_tower_weights)
print("Loading vision tower from ", pretrained_vision_tower_path)
def forward(self, x, **kwargs):
        if len(x.shape) == 4:
            # collapse a singleton channel dim, e.g. (B, 1, n_mels, frames) -> (B, n_mels, frames)
            x = torch.squeeze(x, 1)
image_features = self._vision_tower(x, output_hidden_states=True).hidden_states
hidden_states = torch.cat([image_features[il] for il in self.features_layers], dim=-1)
hidden_states = hidden_states.permute(0, 2, 1)
hidden_states = self.avg_pooler(hidden_states)
hidden_states = hidden_states.permute(0, 2, 1)
return hidden_states
class TinyLlavaPreTrainedModel(PreTrainedModel):
config_class = TinyLlavaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
return self.language_model._supports_sdpa
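# Full audio-language model: a Qwen3 causal LM, the Whisper-based audio tower and
# an MLP connector that projects pooled audio features into the LM embedding space.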
class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel, GenerationMixin):
def __init__(self, config: TinyLlavaConfig):
super().__init__(config)
# apply_liger_kernel_to_qwen3()
self.language_model = Qwen3ForCausalLM(config.text_config)
self.vision_tower = WpmAudioTower(config.vision_config)
self.connector = MLPConnector(config)
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
        logits_to_keep=None  # accepted but unused
) -> Union[Tuple, CausalLMOutputWithPast]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
image_sizes
)
return self.language_model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if isinstance(images, list) and (images != []):
images = torch.cat(images, dim=0)
if images is not None:
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
_
) = self.prepare_inputs_labels_for_multimodal(
inputs,
position_ids,
attention_mask,
None,
None,
images,
image_sizes=image_sizes
)
else:
inputs_embeds = self.language_model.get_input_embeddings()(inputs)
return self.language_model.generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)
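    # Encode log-mel inputs with the audio tower and project them into the LM
    # embedding space via the connector; inputs longer than 3000 frames are
    # processed in chunks.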
def encode_images(self, images):
kwargs = {}
kwargs['vision_feature_layer'] = self.config.vision_feature_layer
kwargs['vision_feature_select_strategy'] = self.config.vision_feature_select_strategy
images = images.to(device=self.device, dtype=self.dtype)
        if images.shape[-1] != 3000:
            # Longer than one 30 s Whisper window: encode each 3000-frame chunk
            # separately and concatenate the resulting embeddings along the
            # sequence (time) dimension.
            splits = torch.split(images, 3000, dim=-1)
            image_features = torch.cat([self.connector(self.vision_tower(x, **kwargs)) for x in splits], dim=1)
else:
image_features = self.vision_tower(images, **kwargs)
image_features = self.connector(image_features)
return image_features
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = self.language_model.prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs['images'] = images
if image_sizes is not None:
inputs['image_sizes'] = image_sizes
return inputs
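    # Splice the encoded audio embeddings into the text embedding sequence at each
    # IMAGE_TOKEN_INDEX placeholder, mask those positions with IGNORE_INDEX in the
    # labels, and re-pad the batch. Largely follows the LLaVA/TinyLLaVA recipe.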
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels,
images, image_sizes=None
):
vision_tower = self.vision_tower
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
image_features = self.encode_images(images)
# TODO: image start / end is not implemented here to support pretraining.
if getattr(self.config, 'tune_mm_mlp_adapter', False):
raise NotImplementedError
# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask -- FIXME
_input_ids = input_ids
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
cur_image_size = image_sizes[batch_idx] if image_sizes is not None else None
if num_images == 0:
# cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.language_model.get_input_embeddings()(cur_input_ids)
# cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds_1)
new_labels.append(labels[batch_idx])
# cur_image_idx += 1
continue
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.language_model.get_input_embeddings()(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
img_size = cur_image_size[i]
cur_image_features = image_features[cur_image_idx:cur_image_idx + img_size]
cur_image_features = [img.squeeze(0) for img in cur_image_features]
cur_image_features = torch.cat(cur_image_features, dim=0)
cur_image_idx += img_size
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
# print(f"max_len: {max_len}")
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
new_input_embeds_padded.append(torch.cat((
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
cur_new_embed
), dim=0))
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
else:
new_input_embeds_padded.append(torch.cat((
cur_new_embed,
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
), dim=0))
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
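    # Note: self.tokenizer must be attached to the model before load_llm is
    # called, since it is read below to record pad-token settings on the config.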
def load_llm(self, **kwargs):
language_model_name = get_value_from_kwargs(kwargs, 'model_name_or_path')
pretrained_llm_path = get_value_from_kwargs(kwargs, 'pretrained_llm_path')
if pretrained_llm_path is not None:
language_model_name = pretrained_llm_path
if language_model_name is not None:
self.language_model = self.language_model.from_pretrained(
language_model_name, **kwargs
)
print('loading language model from ', language_model_name)
self.language_model.requires_grad_(False)
self.config.text_config.torch_dtype = kwargs.get('torch_dtype', None)
self.config.pad_token = getattr(self.tokenizer, 'pad_token', None)
self.config.pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
#self.config.tokenizer_padding_side = getattr(self.tokenizer, 'padding_side', None)
#self.config.tokenizer_model_max_length = getattr(self.tokenizer, 'model_max_length', None)
def load_vision_tower(self, **kwargs):
vision_tower_name = get_value_from_kwargs(kwargs, 'model_name_or_path')
self.vision_tower.load_model(vision_tower_name, **kwargs)
def load_connector(self, **kwargs):
self.connector.load_model(**kwargs)
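    # Single-turn convenience API: preprocess the prompt and audio files, prepend
    # an '<audio>' tag if missing, and generate a reply with a keyword-based
    # stopping criterion.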
def chat(
self,
tokenizer,
prompt,
audio_files,
segs = None,
max_new_tokens = 512,
temperature= 0.5,
top_k = 50,
top_p = 1.0,
):
        text_processor = TextPreprocess(tokenizer, 'qwen2_instruct')
audio_processor = AudioPreprocess(self.vision_tower._image_processor, self.config)
msg = Message()
audio_tensor, audio_size = load_audios(audio_processor, audio_files, segs)
if (audio_tensor) and ('<audio>' not in prompt):
prompt = '<audio>\n' + prompt
msg.add_message(prompt)
result = text_processor(msg.messages, mode='eval')
input_ids = result['input_ids'].unsqueeze(0).to(self.device)
stop_str = text_processor.template.separator.apply()[1]
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
with torch.inference_mode():
output_ids = self.generate(
input_ids,
images=audio_tensor,
do_sample=True if temperature > 0 else False,
temperature=temperature,
top_k=top_k,
top_p=top_p,
max_new_tokens=max_new_tokens,
use_cache=True,
pad_token_id = tokenizer.eos_token_id,
image_sizes=[audio_size] if audio_tensor is not None else None,
stopping_criteria=[stopping_criteria]
)
gen_text = tokenizer.decode(output_ids[0])
if gen_text.endswith(stop_str):
gen_text = gen_text[:-len(stop_str)]
return gen_text
AutoConfig.register("tinyllava", TinyLlavaConfig)
AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
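# Minimal usage sketch. The repo id/path and audio file below are placeholders;
# this assumes the checkpoint ships these modeling files so trust_remote_code can
# resolve the registered classes, and that audio_files is given as a list of paths.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "<path-or-repo-id>", trust_remote_code=True, torch_dtype="auto"
#   ).eval()
#   tokenizer = AutoTokenizer.from_pretrained("<path-or-repo-id>", trust_remote_code=True)
#   reply = model.chat(tokenizer, "Transcribe the speech in this clip.", ["example.wav"])
#   print(reply)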