|
""" |
|
srun -p INTERN2 --job-name='husky_multi_test' --gres=gpu:1 --cpus-per-task=8 --quotatype="auto" python -u demo/inference_new.py |
|
""" |
|
|
|
import os
from io import BytesIO

import requests
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from peft import PeftModel
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import (
    GenerationConfig,
    LlamaTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

from robohusky.compression import compress_module
from robohusky.conversation import get_conv_template
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
from robohusky.video_transformers import (
    GroupCenterCrop,
    GroupNormalize,
    GroupScale,
    Stack,
    ToTorchFormatTensor,
    get_index,
)
|
IGNORE_INDEX = -100
DEFAULT_UNK_TOKEN = "<unk>"
DEFAULT_IMG_START_TOKEN = "<img>"
DEFAULT_IMG_END_TOKEN = "</img>"
DEFAULT_VIDEO_START_TOKEN = "<vid>"
DEFAULT_VIDEO_END_TOKEN = "</vid>"

|
def get_gpu_memory(max_gpus=None):
    """Return the currently available memory (GiB) of each visible GPU."""
    gpu_memory = []
    num_gpus = (
        torch.cuda.device_count()
        if max_gpus is None
        else min(max_gpus, torch.cuda.device_count())
    )
    for gpu_id in range(num_gpus):
        with torch.cuda.device(gpu_id):
            device = torch.cuda.current_device()
            gpu_properties = torch.cuda.get_device_properties(device)
            total_memory = gpu_properties.total_memory / (1024 ** 3)
            allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3)
            gpu_memory.append(total_memory - allocated_memory)
    return gpu_memory
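
# Note: torch.cuda.memory_allocated() only counts tensors allocated by the
# current process, so on a shared node this "available" figure is optimistic.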
|
|
|
def load_model(
    model_path, device, num_gpus, max_gpu_memory=None, load_8bit=False, lora_weights=None
):
    if device == "cpu":
        kwargs = {}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if num_gpus == "auto":
            kwargs["device_map"] = "auto"
        else:
            num_gpus = int(num_gpus)
            if num_gpus != 1:
                kwargs["device_map"] = "auto"
                if max_gpu_memory is None:
                    # "sequential" placement handles GPUs with unequal free
                    # memory better than the default balanced mapping.
                    kwargs["device_map"] = "sequential"
                    available_gpu_memory = get_gpu_memory(num_gpus)
                    kwargs["max_memory"] = {
                        i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                        for i in range(num_gpus)
                    }
                else:
                    kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
    else:
        raise ValueError(f"Invalid device: {device}")

    tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)

    if lora_weights is None:
        model = HuskyForConditionalGeneration.from_pretrained(
            model_path, low_cpu_mem_usage=True, **kwargs
        )
    else:
        # With LoRA weights, let accelerate place the base model, then wrap
        # the language model with the PEFT adapter.
        kwargs["device_map"] = "auto"
        model = HuskyForConditionalGeneration.from_pretrained(
            model_path, low_cpu_mem_usage=True, **kwargs
        )
        model.language_model = PeftModel.from_pretrained(
            model.language_model, lora_weights, **kwargs
        )

    if load_8bit:
        compress_module(model, device)

    if (device == "cuda" and num_gpus == 1) or device == "mps":
        model.to(device)

    model = model.eval()
    return model, tokenizer
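
# Minimal usage sketch (the checkpoint path below is illustrative, not part
# of this repo):
#
#     model, tokenizer = load_model("./husky-13b", device="cuda", num_gpus=1)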
|
|
|
def load_image(image_file, input_size=224):
    """Load an image from a URL or local path and preprocess it to a CHW tensor."""
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')

    # Resize so that a 224/256 center crop yields `input_size` pixels.
    crop_pct = 224 / 256
    size = int(input_size / crop_pct)
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(size, interpolation=InterpolationMode.BICUBIC),
        T.CenterCrop(input_size),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    image = transform(image)
    return image
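
# Example (hypothetical path): the result is a single normalized CHW tensor.
#
#     img = load_image("demo/example.jpg")  # -> torch.Size([3, 224, 224])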
|
|
|
def load_video(video_path, num_segments=8):
    """Sample `num_segments` frames from a video and preprocess them into one tensor."""
    vr = VideoReader(video_path, ctx=cpu(0))
    num_frames = len(vr)
    frame_indices = get_index(num_frames, num_segments)

    crop_size = 224
    scale_size = 224
    # CLIP-style normalization statistics.
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]

    transform = T.Compose([
        GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
        GroupCenterCrop(crop_size),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(input_mean, input_std),
    ])

    images_group = [Image.fromarray(vr[idx].asnumpy()) for idx in frame_indices]
    video = transform(images_group)
    return video
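
# Note: Stack() + ToTorchFormatTensor() flatten the sampled frames into one
# (num_segments * 3, H, W) tensor; Chat.get_video_embedding below reshapes it
# back to (1, 3, T, H, W) before feature extraction.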
|
|
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any of the given stop-word id sequences appears
    at the end of the generated ids."""

    def __init__(self, stops):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop in self.stops:
            if input_ids.shape[-1] < len(stop):
                continue
            stop = stop.to(input_ids.device)
            if torch.all(stop == input_ids[0][-len(stop):]).item():
                return True
        return False
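
# Usage sketch: encode each stop word once, then hand the criteria to generate().
#
#     ids = tokenizer("###", return_tensors="pt", add_special_tokens=False)["input_ids"].squeeze()
#     criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[ids])])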
|
|
|
|
|
|
|
@torch.inference_mode()
def generate_stream(
    model, tokenizer, image_processor, params, device
):
    # image_processor is unused here: preprocessing is done by
    # load_image/load_video above.
    prompt = params["prompt"]
    images = params.get("images", None)
    videos = params.get("videos", None)
    temperature = float(params.get("temperature", 0.7))
    max_new_tokens = int(params.get("max_new_tokens", 1024))

    stop_words = ["Human: ", "Assistant: ", "###", "\n\n"]
    # Encode stop words without special tokens; otherwise the leading BOS id
    # would prevent them from ever matching the tail of the generated ids.
    stop_words_ids = [
        tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze()
        for stop_word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    generation_config = GenerationConfig(
        bos_token_id=1,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
    )

    pixel_values = None
    if images is not None:
        # (3, H, W) -> (1, 3, H, W): the model expects a batch dimension.
        pixel_values = load_image(images).unsqueeze(0).to(device, dtype=model.dtype)
        image_query = DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
        prompt = prompt.replace("<image>", image_query)
    elif videos is not None:
        pixel_values = load_video(videos)
        TC, H, W = pixel_values.shape
        # (T*C, H, W) -> (1, C, T, H, W), as in Chat.get_video_embedding.
        pixel_values = pixel_values.reshape(TC // 3, 3, H, W).transpose(0, 1)
        pixel_values = pixel_values.unsqueeze(0).to(device, dtype=model.dtype)
        video_query = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN
        prompt = prompt.replace("<video>", video_query)

    model_inputs = tokenizer([prompt], return_tensors="pt")
    model_inputs.pop("token_type_ids", None)
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    if pixel_values is not None:
        model_inputs["pixel_values"] = pixel_values
        # Stopping criteria must go to generate(), not GenerationConfig.
        generation_output = model.generate(
            **model_inputs,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
            return_dict_in_generate=True,
            output_scores=True,
        )
    else:
        generation_output = model.language_model.generate(
            **model_inputs,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
            return_dict_in_generate=True,
            output_scores=True,
        )

    preds = generation_output.sequences
    outputs = tokenizer.batch_decode(preds, skip_special_tokens=True)
    return outputs
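
# Usage sketch (paths and parameter values are illustrative):
#
#     outputs = generate_stream(
#         model, tokenizer, None,
#         {"prompt": "<image>\nDescribe the scene.", "images": "demo/example.jpg"},
#         device="cuda",
#     )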
|
|
|
class Chat:
    def __init__(
        self,
        model_path,
        device,
        num_gpus=1,
        load_8bit=False,
        temperature=0.3,
        max_new_tokens=512,
        lora_path=None,
    ):
        model, tokenizer = load_model(
            model_path, device, num_gpus, load_8bit=load_8bit, lora_weights=lora_path
        )

        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = model.dtype

        stop_words = ["Human: ", "Assistant: ", "###", "\n\n"]
        # Encode stop words without special tokens so they can match the tail
        # of the generated ids (see StoppingCriteriaSub above).
        stop_words_ids = [
            tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze()
            for stop_word in stop_words
        ]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

        self.conv = get_conv_template("husky")

        self.image_query = DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
        self.video_query = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN

        self.generation_config = GenerationConfig(
            bos_token_id=1,
            do_sample=True,
            top_k=20,
            top_p=0.9,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
|
    def ask(self, text, conv, modal_type="image"):
        assert modal_type in ["text", "image", "video"]
        conversations = []

        # Only the first visual turn gets the <img>/<vid> placeholder prefix.
        if len(conv.messages) > 0 or modal_type == "text":
            conv.append_message(conv.roles[0], text)
        elif modal_type == "image":
            conv.append_message(conv.roles[0], self.image_query + "\n" + text)
        else:
            conv.append_message(conv.roles[0], self.video_query + "\n" + text)

        conv.append_message(conv.roles[1], None)
        conversations.append(conv.get_prompt())
        return conversations
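
    # `ask` mutates `conv` in place and returns the prompt list consumed by
    # `answer`; on the first visual turn the query is prefixed with
    # "<img></img>\n" (or "<vid></vid>\n") to mark where features are injected.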
|
|
|
    @torch.no_grad()
    def get_image_embedding(self, image_file):
        pixel_values = load_image(image_file)
        pixel_values = pixel_values.unsqueeze(0).to(self.device, dtype=self.dtype)
        language_model_inputs = self.model.extract_feature(pixel_values)
        return language_model_inputs
|
|
|
    @torch.no_grad()
    def get_video_embedding(self, video_file):
        pixel_values = load_video(video_file)
        TC, H, W = pixel_values.shape
        # (T*C, H, W) -> (1, C, T, H, W): a batch of one video clip.
        pixel_values = pixel_values.reshape(TC // 3, 3, H, W).transpose(0, 1)
        pixel_values = pixel_values.unsqueeze(0).to(self.device, dtype=self.dtype)
        assert len(pixel_values.shape) == 5
        language_model_inputs = self.model.extract_feature(pixel_values)
        return language_model_inputs
|
|
|
    @torch.no_grad()
    def answer(self, conversations, language_model_inputs, modal_type="image"):
        model_inputs = self.tokenizer(
            conversations,
            return_tensors="pt",
        )
        model_inputs.pop("token_type_ids", None)

        input_ids = model_inputs["input_ids"].to(self.device)
        attention_mask = model_inputs["attention_mask"].to(self.device)

        if modal_type == "text":
            generation_output = self.model.language_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                generation_config=self.generation_config,
                stopping_criteria=self.stopping_criteria,
                return_dict_in_generate=True,
                output_scores=True,
            )
        else:
            # Visual features arrive precomputed via language_model_inputs;
            # the tokenizer output never carries pixel_values, so this is None.
            pixel_values = model_inputs.pop("pixel_values", None)
            if pixel_values is not None:
                pixel_values = pixel_values.to(self.device)

            generation_output = self.model.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                language_model_inputs=language_model_inputs,
                generation_config=self.generation_config,
                stopping_criteria=self.stopping_criteria,
                return_dict_in_generate=True,
                output_scores=True,
            )

        preds = generation_output.sequences
        outputs = self.tokenizer.batch_decode(preds, skip_special_tokens=True)[0]

        if modal_type == "text":
            # Strip the echoed prompt from the decoded text.
            skip_echo_len = len(conversations[0]) - conversations[0].count("</s>") * 3
            outputs = outputs[skip_echo_len:].strip()

        return outputs
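
# Typical flow (see __main__ below): extract visual features once per file,
# then call ask()/answer() for each user turn.
#
#     feats = chat.get_image_embedding("demo/example.jpg")  # hypothetical path
#     convs = chat.ask("What is in the image?", chat.conv, modal_type="image")
#     reply = chat.answer(convs, feats, modal_type="image")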
|
|
|
if __name__ == '__main__':
    model_path = "./"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    chat = Chat(model_path, device=device, num_gpus=1, max_new_tokens=1024, load_8bit=False)

    vision_feature = None
    image_state = False
    video_state = False

    while True:
        query = input("\n")

        # A path to an image or video file switches the conversation to that file.
        if query.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
            if os.path.exists(query):
                print("received.")
                vision_feature = chat.get_image_embedding(query)
                chat.conv = get_conv_template("husky").copy()
                image_state = True
            continue
        if query.lower().endswith(('.mp4', '.mkv', '.avi', '.wmv', '.iso', '.webm')):
            if os.path.exists(query):
                print("received.")
                vision_feature = chat.get_video_embedding(query)
                chat.conv = get_conv_template("husky").copy()
                video_state = True
            continue

        if query == "stop":
            break
        if query in ("clear", "", "\n"):
            chat.conv = get_conv_template("husky").copy()
            image_state = False
            video_state = False
            os.system("clear")
            print("Welcome to the husky-13b-zh model. Type a message to chat, 'clear' to reset the dialogue history, 'stop' to quit.")
            continue

        if image_state:
            modal_type = "image"
        elif video_state:
            modal_type = "video"
        else:
            modal_type = "text"

        conversations = chat.ask(text=query, conv=chat.conv, modal_type=modal_type)
        outputs = chat.answer(conversations, vision_feature, modal_type=modal_type)

        # Store the assistant's reply so the next turn sees the full history.
        chat.conv.messages[-1][1] = outputs.strip()
        print(f"Husky: \n{outputs}")
|
|