import base64
import os
import tempfile

import torch

from inference import Chat, get_conv_template

def save_base64_to_tempfile(base64_str, suffix):
    """Decode a base64 string (optionally a data URI) into a temp file and return its path."""
    header_removed = base64_str

    # Strip a data-URI header such as "data:image/jpeg;base64," if one is present.
    if ',' in base64_str:
        header_removed = base64_str.split(',', 1)[1]

    data = base64.b64decode(header_removed)
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp.write(data)
    tmp.close()
    return tmp.name
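
# Illustrative sketch only: how the helper is typically used. The base64 payload
# below is a placeholder, not real image data, and the caller is responsible for
# removing the temp file once it has been consumed.
#
#     tmp_path = save_base64_to_tempfile("data:image/jpeg;base64,<payload>", suffix=".jpg")
#     try:
#         feature = chat.get_image_embedding(tmp_path)
#     finally:
#         os.unlink(tmp_path)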


class EndpointHandler:
    """Stateful handler: keeps the conversation and the last vision feature between calls."""

    def __init__(self, model_path: str):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.chat = Chat(
            model_path=model_path,
            device=device,
            num_gpus=1,
            max_new_tokens=1024,
            load_8bit=False,
        )
        self.vision_feature = None
        self.modal_type = "text"
        self.chat.conv = get_conv_template("husky").copy()

    def __call__(self, data: dict) -> dict:
        """Handle one request.

        Expected keys in ``data``: ``inputs`` (prompt text), optional ``image`` /
        ``video`` (a local path or a base64 string), and optional ``clear_history``.
        """
        # Reset the conversation and any cached vision feature on request.
        if data.get("clear_history"):
            self.chat.conv = get_conv_template("husky").copy()
            self.vision_feature = None
            self.modal_type = "text"

        prompt = data.get("inputs", "")
        image_input = data.get("image", None)
        video_input = data.get("video", None)

        print("📨 Received prompt:", repr(prompt))

        # A new visual input replaces any cached feature and restarts the conversation.
        if image_input:
            if os.path.exists(image_input):
                # Treat the value as a local file path.
                self.vision_feature = self.chat.get_image_embedding(image_input)
            else:
                # Otherwise treat it as base64-encoded image data.
                tmp_path = save_base64_to_tempfile(image_input, suffix=".jpg")
                self.vision_feature = self.chat.get_image_embedding(tmp_path)
                os.unlink(tmp_path)
            self.modal_type = "image"
            self.chat.conv = get_conv_template("husky").copy()

        elif video_input:
            if os.path.exists(video_input):
                self.vision_feature = self.chat.get_video_embedding(video_input)
            else:
                tmp_path = save_base64_to_tempfile(video_input, suffix=".mp4")
                print("📼 Saved temporary video to:", tmp_path)
                self.vision_feature = self.chat.get_video_embedding(tmp_path)
                os.unlink(tmp_path)
            self.modal_type = "video"
            self.chat.conv = get_conv_template("husky").copy()

            if isinstance(self.vision_feature, torch.Tensor):
                print("📏 Vision feature shape:", self.vision_feature.shape)
            else:
                print("❌ self.vision_feature is not a tensor, got:", type(self.vision_feature))

        else:
            # Text-only request: drop any previously cached vision feature.
            self.modal_type = "text"
            self.vision_feature = None

        try:
            print("🧠 Current modal_type:", self.modal_type)
            print("🧠 Vision feature present:", self.vision_feature is not None)

            conversations = self.chat.ask(prompt, self.chat.conv, modal_type=self.modal_type)
            output = self.chat.answer(conversations, self.vision_feature, modal_type=self.modal_type)
            output = output.strip()

            print("📤 Inference output:", repr(output))

            # Write the answer back into the conversation so follow-up turns can see it.
            self.chat.conv.messages[-1][1] = output
            return {"output": output}

        except Exception as e:
            import traceback
            print("❌ Inference failed:")
            traceback.print_exc()
            return {"error": str(e)}