import argparse
import logging
import os
import shutil
import time
import uuid

import cv2
import psutil
import torch
import uvicorn
from fastapi import FastAPI, Form, HTTPException
from fastapi.responses import JSONResponse

from models.qwen import Qwen2VL
from models.gemma import Gemma
from models.minicpm import MiniCPM
from models.lfm import LFM2
from video_processor import extract_frames, FrameSamplingMethod
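
# --model_path selects which vision-language backend is instantiated at startup
# (the lowercased path is matched against "qwen", "gemma", "minicpm", "lfm").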
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="Qwen/Qwen2.5-VL-3B-Instruct-AWQ")
args = parser.parse_args()
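
# Per-model log/output directories, named after the last path component of model_path.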
LOG_DIR = f"logs/{args.model_path.split('/')[-1]}"
OUTPUT_DIR = f"outputs/{args.model_path.split('/')[-1]}"
TEMP_VIDEO_DIR = "temp_videos"
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(TEMP_VIDEO_DIR, exist_ok=True)

start_time = time.strftime('%Y%m%d_%H%M%S')
log_filename = f"{LOG_DIR}/{start_time}.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    filename=log_filename,
    filemode='a',
)
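
# Load the requested backend once at startup. torch.cuda.memory_allocated(0)
# reports only tensor memory currently allocated by this process on GPU 0.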
app = FastAPI(title=f"{args.model_path} Video Inference Service")

logging.info(f"Loading model: {args.model_path}")
model_load_start = time.time()
if "qwen" in args.model_path.lower():
    model = Qwen2VL(args.model_path)
elif "gemma" in args.model_path.lower():
    model = Gemma(args.model_path)
elif "minicpm" in args.model_path.lower():
    model = MiniCPM(args.model_path)
elif "lfm" in args.model_path.lower():
    model = LFM2(args.model_path)
else:
    # Fail fast on an unrecognized model path instead of hitting a NameError
    # on the first request.
    raise ValueError(f"Unsupported model path: {args.model_path}")
model_load_end = time.time()
GPU_MEMORY_USAGE = f"{torch.cuda.memory_allocated(0)/1024**2:.2f} MB" if torch.cuda.is_available() else "N/A"
logging.info(f"Model loaded in {model_load_end - model_load_start:.2f} seconds")
logging.info(f"GPU Memory Usage after model load: {GPU_MEMORY_USAGE}")
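
# Request flow: validate the .mp4 path, extract frames with the chosen sampling
# method, write them as JPEGs, run model.generate(frame_paths, prompt), and
# return the model output with timing and CPU stats attached.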
@app.post("/video-inference/")
async def video_inference(
    prompt: str = Form(...),
    video_file: str = Form(...),
    sampling_method: FrameSamplingMethod = Form(FrameSamplingMethod.CONTENT_AWARE),
    sampling_rate: int = Form(5),
):
    """
    Receive a server-side video path and a text prompt, run inference, and return the result.
    """
    request_start_time = time.time()
    request_id = str(uuid.uuid4())
    # Prime psutil: with interval=None, cpu_percent() measures usage since the
    # previous call, so these baseline calls make the post-inference readings
    # cover this request's window rather than the gap since the last request.
    psutil.cpu_percent(interval=None)
    psutil.cpu_percent(interval=None, percpu=True)
    logging.info(f"[{request_id}] Received new video inference request. Prompt: '{prompt}', Video: '{video_file}'")

    if not video_file.endswith(".mp4"):
        logging.error(f"[{request_id}] File '{video_file}' is not an .mp4 video.")
        raise HTTPException(status_code=400, detail="Only .mp4 video files are supported.")

    temp_frame_dir = os.path.join(TEMP_VIDEO_DIR, request_id)
    os.makedirs(temp_frame_dir, exist_ok=True)

    try:
        logging.info(f"[{request_id}] Extracting frames using method: {sampling_method.value}, rate/threshold: {sampling_rate}")

        frames = extract_frames(video_file, sampling_method, sampling_rate)
        if not frames:
            logging.error(f"[{request_id}] Could not extract any frames from the video: {video_file}")
            raise HTTPException(status_code=400, detail="Could not extract any frames from the video.")

        logging.info(f"[{request_id}] Extracted {len(frames)} frames successfully. Saving to temporary files...")

        # Write each frame to disk; the model wrappers consume image paths.
        frame_paths = []
        for i, frame in enumerate(frames):
            frame_path = os.path.join(temp_frame_dir, f"frame_{i:04d}.jpg")
            cv2.imwrite(frame_path, frame)
            frame_paths.append(os.path.abspath(frame_path))

        logging.info(f"[{request_id}] {len(frame_paths)} frames saved to {temp_frame_dir}")

        output = model.generate(frame_paths, prompt)

        logging.info(f"[{request_id}] Tokens per second: {output['tokens_per_second']}, Peak GPU memory MB: {output['peak_gpu_memory_mb']}")

        inference_end_time = time.time()
        cpu_usage = psutil.cpu_percent(interval=None)
        cpu_core_utilization = psutil.cpu_percent(interval=None, percpu=True)
        logging.info(f"[{request_id}] Inference time: {inference_end_time - request_start_time:.2f} seconds, CPU usage: {cpu_usage}%, CPU core utilization: {cpu_core_utilization}")
        output["inference_time"] = inference_end_time - request_start_time
        output["cpu_usage"] = cpu_usage
        output["cpu_core_utilization"] = cpu_core_utilization

        return JSONResponse(content=output)

    except HTTPException:
        # Let deliberate 4xx responses pass through instead of being wrapped as 500s.
        raise
    except Exception as e:
        logging.error(f"[{request_id}] An error occurred during processing: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}")
    finally:
        if os.path.exists(temp_frame_dir):
            shutil.rmtree(temp_frame_dir)
            logging.info(f"[{request_id}] Cleaned up temporary frame directory: {temp_frame_dir}")
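
# Example request, as a sketch: assumes the service is reachable on localhost:8010,
# that /path/to/video.mp4 exists on the server, and that "content_aware" is the
# form value of FrameSamplingMethod.CONTENT_AWARE (the enum's actual values are
# defined in video_processor and are an assumption here):
#
#   curl -X POST http://localhost:8010/video-inference/ \
#     -F "prompt=Describe this video" \
#     -F "video_file=/path/to/video.mp4" \
#     -F "sampling_method=content_aware" \
#     -F "sampling_rate=5"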
|
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8010)