# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil

from mmengine.config import Config, DictAction
from mmengine.fileio import PetrelBackend, get_file_backend

from xtuner.configs import cfgs_name_path
from xtuner.model.utils import guess_load_checkpoint
from xtuner.registry import BUILDER


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert the pth model to HuggingFace model')
    parser.add_argument('config', help='config file name or path.')
    parser.add_argument('pth_model', help='pth model file')
    parser.add_argument(
        'save_dir', help='the directory to save HuggingFace model')
    parser.add_argument(
        '--fp32',
        action='store_true',
        help='Save LLM in fp32. If not set, fp16 will be used by default.')
    parser.add_argument(
        '--max-shard-size',
        type=str,
        default='2GB',
        help='Only applicable for LLM. The maximum size for '
        'each sharded checkpoint.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # parse config
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')

    # load config
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    model_name = cfg.model.type if isinstance(
        cfg.model.type, str) else cfg.model.type.__name__
    if 'LLaVAModel' in model_name:
        cfg.model.pretrained_pth = None

    model = BUILDER.build(cfg.model)

    backend = get_file_backend(args.pth_model)
    if isinstance(backend, PetrelBackend):
        from xtuner.utils.fileio import patch_fileio
        with patch_fileio():
            state_dict = guess_load_checkpoint(args.pth_model)
    else:
        state_dict = guess_load_checkpoint(args.pth_model)

    model.load_state_dict(state_dict, strict=False)
    print(f'Load PTH model from {args.pth_model}')

    if 'LLaVAModel' in model_name:
        # Save the LLM (or its LoRA adapter) if it was trained.
        if cfg.model.get('llm') and (not cfg.model.get('freeze_llm', False)
                                     or cfg.model.get('llm_lora')):
            if 'PeftModel' in model.llm.__class__.__name__:
                llm_path = osp.join(args.save_dir, 'llm_adapter')
                print(f'Saving LLM adapter to {llm_path}')
            else:
                llm_path = args.save_dir
                print(f'Saving LLM tokenizer to {llm_path}')
                tokenizer = BUILDER.build(cfg.tokenizer)
                tokenizer.save_pretrained(llm_path)
                print(f'Saving LLM to {llm_path}')
            # StarCycle: the LLM has already been quantized, so the fp16
            # conversion done by the upstream script is skipped here.
            # if not args.fp32:
            #     print('Convert LLM to float16')
            #     model.llm.half()
            model.llm.save_pretrained(
                llm_path, max_shard_size=args.max_shard_size)

        # Save the visual encoder (or its LoRA adapter) if it was trained.
        if cfg.model.get('visual_encoder') and (
                not cfg.model.get('freeze_visual_encoder', False)
                or cfg.model.get('visual_encoder_lora')):
            if 'PeftModel' in model.visual_encoder.__class__.__name__:
                visual_encoder_path = osp.join(args.save_dir,
                                               'visual_encoder_adapter')
                print(
                    f'Saving visual_encoder adapter to {visual_encoder_path}')
            else:
                visual_encoder_path = osp.join(args.save_dir,
                                               'visual_encoder')
                print('Saving visual_encoder image_processor to '
                      f'{visual_encoder_path}')
                image_processor = BUILDER.build(cfg.image_processor)
                image_processor.save_pretrained(visual_encoder_path)
                print(f'Saving visual_encoder to {visual_encoder_path}')
            model.visual_encoder.save_pretrained(
                visual_encoder_path, max_shard_size=args.max_shard_size)

        if hasattr(model, 'projector'):
            projector_path = osp.join(args.save_dir, 'projector')
            print(f'Saving projector to {projector_path}')
            model.projector.save_pretrained(
                projector_path, max_shard_size=args.max_shard_size)
    else:
        llm_path = args.save_dir
        if 'PeftModel' in model.llm.__class__.__name__:
            print(f'Saving adapter to {llm_path}')
        else:
            print(f'Saving LLM tokenizer to {llm_path}')
            tokenizer = BUILDER.build(cfg.tokenizer)
            tokenizer.save_pretrained(llm_path)
            print(f'Saving LLM to {llm_path}')
        if not args.fp32:
            print('Convert LLM to float16')
            model.llm.half()
        model.llm.save_pretrained(
            llm_path,
            max_shard_size=args.max_shard_size,
            safe_serialization=False)

    shutil.copyfile(args.config, osp.join(args.save_dir, 'xtuner_config.py'))
    print('All done!')


if __name__ == '__main__':
    main()
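
# Example invocation (a minimal sketch; the script filename, config name,
# checkpoint path and output directory below are placeholders for
# illustration, not values taken from this repository):
#
#   python pth_to_hf.py \
#       <config-name-or-path> \
#       ./work_dirs/<exp>/iter_xxx.pth \
#       ./converted_hf \
#       --max-shard-size 2GB
#
# For a LLaVA-style config this writes `llm_adapter/` (or the LLM weights
# plus tokenizer), `visual_encoder_adapter/` (or `visual_encoder/`),
# `projector/`, and a copy of the config saved as `xtuner_config.py` inside
# the chosen save directory.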