# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Build the language model and its tokenizer, and helpers for tokenizing conversations."""

import math
import os
import os.path as osp
import warnings
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
from huggingface_hub import file_exists, repo_exists
from huggingface_hub.utils import HFValidationError

import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from .conversation import default_conversation, SeparatorStyle

# These must be non-empty, distinct strings. The sentinel is located by token id in
# `infer_stop_tokens`, and each media token needs its own id.
SENTINEL_TOKEN = "<vila/sentinel>"
MEDIA_TOKENS = {
    "image": "<image>",
    "video": "<vila/video>",
}

# The helpers below are local stand-ins for these llava utilities:
# from llava.model.utils import packing
# from llava.utils.logging import logger
# from llava.utils.tokenizer import infer_stop_tokens

# A 20-turn human/gpt exchange used to probe the chat template in `infer_stop_tokens`.
# NOTE: `* 10` repeats references to the SAME two dicts, so these must never be mutated.
DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
] * 10

# Maps the dataset's sender names to chat-template roles.
_CHAT_ROLES = {"human": "user", "gpt": "assistant"}


def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
    """Tokenize `prompt` and return the first row of `input_ids`.

    Despite its name, this does no special image-token handling. It is plain tokenization
    of a single prompt.
    """
    return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]


def has_tokenizer(repo_id_or_path: str) -> bool:
    """Return True if `repo_id_or_path` holds a `tokenizer_config.json`.

    It checks a local directory first, then a Hugging Face Hub repository. An invalid
    Hub repo id is treated as "no tokenizer" instead of raising.
    """
    # Check if the tokenizer is in a local directory
    if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
        return True

    # Check if the tokenizer is in a Hugging Face Hub repo
    try:
        return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
    except HFValidationError:
        return False


def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    """Register SENTINEL_TOKEN as a special token on `tokenizer`, once.

    Also records `sentinel_token` and `sentinel_token_id` as attributes on the tokenizer.
    """
    if not hasattr(tokenizer, "sentinel_token"):
        tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
        tokenizer.sentinel_token = SENTINEL_TOKEN
        tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)


def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    """Tokenize `messages` using the hand-written `default_conversation` template.

    Args:
        messages: turns as `{"from": "human"|"gpt", "value": str}` dicts. The caller's
            sequence is never modified.
        tokenizer: tokenizer used to encode the rendered prompt.
        add_generation_prompt: append an empty assistant turn so the prompt ends where
            the model should start generating.
        overrides: optional per-sender replacement text (e.g. `{"gpt": SENTINEL_TOKEN}`).
        no_system_prompt: blank out the conversation's system prompt.

    Returns:
        The token ids of the rendered prompt, as a tensor.
    """
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Work on a private copy so appending the generation prompt never touches the caller's list.
    messages = list(messages)

    # Skip the first message if it is not from human
    if messages and messages[0]["from"] != "human":
        messages = messages[1:]

    # Add a generation prompt if needed
    if add_generation_prompt:
        messages.append({"from": "gpt", "value": None})

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        # Turns must strictly alternate human/gpt.
        assert role == conv.roles[turn % 2]
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")


def _normalize_message(message: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of `message` with surrounding whitespace stripped from its value."""
    value = message["value"]
    if isinstance(value, str):
        value = value.strip()
    return {**message, "value": value}


def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    """Tokenize `messages`, preferring the tokenizer's own chat template.

    If `default_conversation` is not in AUTO separator style, this delegates to
    `tokenize_conversation_legacy`. Otherwise it maps senders to chat roles and renders
    the conversation with `tokenizer.apply_chat_template`.

    Args:
        messages: same shape as for `tokenize_conversation_legacy`. They are not modified.
        tokenizer: tokenizer that supplies the chat template and the encoding.
        add_generation_prompt: forwarded to the template / legacy path.
        overrides: optional per-sender replacement text.
        no_system_prompt: force an empty system message.

    Returns:
        The token ids of the rendered prompt, as a tensor.

    Raises:
        ValueError: if a message comes from a sender other than "human" or "gpt"
            (AUTO path only).
    """
    # Normalize the conversation before tokenization. Build copies so the caller's dicts,
    # including the shared dicts in DUMMY_CONVERSATION, are left untouched.
    messages = [_normalize_message(m) for m in messages]

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    conversation = []
    for m in messages:
        sender = m["from"]
        if sender not in _CHAT_ROLES:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
        content = m["value"]
        if overrides is not None and sender in overrides:
            content = overrides[sender]
        conversation.append({"role": _CHAT_ROLES[sender], "content": content})

    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")


def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
    """Infer the strings that end an assistant turn under the tokenizer's template.

    It renders DUMMY_CONVERSATION with every assistant reply replaced by the sentinel
    token. Each token that immediately follows a sentinel is then decoded and collected,
    together with the tokenizer's EOS token.

    Returns:
        De-duplicated stop strings in a deterministic order (EOS first, if defined).
    """
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    # A dict acts as an insertion-ordered set. A plain set's iteration order can vary
    # between runs, which would make the stop-token list non-deterministic.
    stop_tokens: Dict[str, None] = {}
    if tokenizer.eos_token is not None:
        stop_tokens[tokenizer.eos_token] = None

    ids = template.tolist()
    for current_id, next_id in zip(ids, ids[1:]):
        if current_id == tokenizer.sentinel_token_id:
            stop_tokens[tokenizer.decode(next_id)] = None
    return list(stop_tokens)


def context_length_extension(config):
    """Enable linear RoPE scaling when `model_max_length` exceeds the native context.

    The scaling factor is `ceil(model_max_length / max_position_embeddings)`. `config` is
    modified in place and also returned. Nothing changes if either length is unset, or
    if the requested length fits in the native context.
    """
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    # Guard both values: comparing against None would raise TypeError.
    if orig_ctx_len and model_max_length is not None and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config


def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load the causal LM and its tokenizer, and attach the extras this project relies on.

    Args:
        model_name_or_path: local path or Hub id of the LLM. The tokenizer is looked up
            there first, then in its `llm/` subdirectory.
        config: outer model config. `model_dtype` and `chat_template` (optional) are read
            from it, and `hidden_size` is written back to it.
        attn_implementation: stored on the LLM config as `_attn_implementation`.
        model_max_length: if given, enables RoPE scaling when needed and sets the
            tokenizer's `model_max_length`.
        *args, **kwargs: forwarded to `AutoModelForCausalLM.from_pretrained`.

    Returns:
        `(llm, tokenizer)`. The tokenizer also carries `stop_tokens`, `stop_token_ids`,
        `media_tokens` and `media_token_ids`.

    Raises:
        ValueError: if no tokenizer can be found.
    """
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # NOTE(review): `config.model_dtype` goes through eval(). It is presumably a string
    # such as "torch.bfloat16" — confirm. eval() executes arbitrary code, so config files
    # must come from a trusted source.
    torch_dtype = eval(config.model_dtype)

    # Quantization related (this branch is currently disabled)
    quantization_restore_from_checkpoint = False
    if quantization_restore_from_checkpoint:
        fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
        llm = AutoModelForCausalLM.from_pretrained(
            fp8_model_name_or_path, *args, config=llm_cfg, torch_dtype=torch_dtype, **kwargs
        )
    else:
        llm = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, *args, config=llm_cfg, torch_dtype=torch_dtype, **kwargs
        )
    # NOTE(ligeng): not sure whether it affects the training
    # packing.patch(llm)

    # Locate the tokenizer.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
    if model_max_length is not None:
        tokenizer.model_max_length = model_max_length

    # Load chat template if specified.
    if getattr(config, "chat_template", None) is not None:
        print(f"Using chat template: {config.chat_template}")
        fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
        with open(fpath, encoding="utf-8") as fd:
            chat_template = fd.read()
        # Remove the 4-space indentation and newlines used for readability in the .jinja
        # file. Removing *every* space would fuse Jinja syntax such as
        # "for message in messages" and break the template.
        tokenizer.chat_template = chat_template.replace("    ", "").replace("\n", "")

    # NOTE(ligeng): originally marked "disable temporarily, let's see whether any bugs are
    # introduced" — the code below is active.
    # Set stop tokens for the tokenizer
    tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
    tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)

    # Add media tokens to the tokenizer.
    # NOTE(review): the LLM's embedding matrix is not resized here. Presumably the
    # checkpoint already contains these tokens — confirm.
    tokenizer.media_tokens = MEDIA_TOKENS
    tokenizer.media_token_ids = {}
    for name, token in MEDIA_TOKENS.items():
        tokenizer.add_tokens([token], special_tokens=True)
        tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)

    # TODO(ligeng): is this necessary for llava?
    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer