File size: 5,582 Bytes
f8ba0eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import uuid
import time
import psutil
import uvicorn
import torch
import cv2
import shutil
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from models.qwen import Qwen2VL
from models.gemma import Gemma
from models.minicpm import MiniCPM
from models.lfm import LFM2
from video_processor import extract_frames, FrameSamplingMethod
import argparse
import json
import logging



parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_path",
    type=str,
    default="Qwen/Qwen2.5-VL-3B-Instruct-AWQ",
    help="Hugging Face repo id or local model directory to load.",
)
args = parser.parse_args()


# --- Logging and temporary-file directory configuration ---
# Per-model output/log subdirectories are named after the last component of
# --model_path. normpath drops any trailing separator first, so "ckpts/my-model/"
# yields "my-model" instead of "" (which would make such runs share logs/ and outputs/).
MODEL_NAME = os.path.basename(os.path.normpath(args.model_path))
LOG_DIR = f"logs/{MODEL_NAME}"
OUTPUT_DIR = f"outputs/{MODEL_NAME}"
TEMP_VIDEO_DIR = "temp_videos"
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(TEMP_VIDEO_DIR, exist_ok=True)
# One log file per server start, named by the launch timestamp. Records go only to
# this file (basicConfig is given a filename, so no console handler is installed).
start_time = time.strftime('%Y%m%d_%H%M%S')
log_filename = f"{LOG_DIR}/{start_time}.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    filename=log_filename,
    filemode='a',
)

# --- FastAPI application initialization ---
app = FastAPI(title=f"{args.model_path} Video Inference Service")
total_output = {}  # NOTE(review): not referenced by the code shown in this file — confirm it is still needed.
# --- Load the model wrapper ---
# Wrapper class chosen by a case-insensitive substring match on --model_path.
# The first matching entry wins, in this order.
_MODEL_CLASSES = (
    ("qwen", Qwen2VL),
    ("gemma", Gemma),
    ("minicpm", MiniCPM),
    ("lfm", LFM2),
)
logging.info(f"Loading model: {args.model_path}")
_model_path_lower = args.model_path.lower()
_model_cls = next(
    (cls for key, cls in _MODEL_CLASSES if key in _model_path_lower),
    None,
)
if _model_cls is None:
    # Fail fast at startup instead of leaving `model` unbound, which would only
    # surface later as a NameError inside every request.
    _supported = ", ".join(key for key, _ in _MODEL_CLASSES)
    logging.error(f"Unsupported model path: {args.model_path} (expected one of: {_supported})")
    raise ValueError(
        f"Unsupported model path {args.model_path!r}: it must contain one of {_supported}"
    )
model_load_start = time.time()
model = _model_cls(args.model_path)
model_load_end = time.time()
# Memory currently allocated by PyTorch tensors on CUDA device 0 after loading.
GPU_MEMORY_USAGE = f"{torch.cuda.memory_allocated(0)/1024**2:.2f} MB" if torch.cuda.is_available() else "N/A"
logging.info(f"Model loaded in {model_load_end - model_load_start:.2f} seconds")
logging.info(f"GPU Memory Usage after model load: {GPU_MEMORY_USAGE}")

def _save_frames(frames, frame_dir):
    """Write each frame to ``frame_dir`` as a JPEG and return their absolute paths.

    Files are named ``frame_0000.jpg``, ``frame_0001.jpg``, ... in input order.
    Each element of ``frames`` is passed unchanged to ``cv2.imwrite``
    (presumably image arrays as produced by ``extract_frames`` — confirm).

    Raises:
        RuntimeError: if ``cv2.imwrite`` reports that a frame could not be written.
    """
    frame_paths = []
    for index, frame in enumerate(frames):
        frame_path = os.path.join(frame_dir, f"frame_{index:04d}.jpg")
        # cv2.imwrite signals a failed write through its return value rather than
        # raising; without this check the model would be handed a missing file.
        if not cv2.imwrite(frame_path, frame):
            raise RuntimeError(f"Failed to write frame {index} to {frame_path}")
        frame_paths.append(os.path.abspath(frame_path))
    return frame_paths


@app.post("/video-inference/")
async def video_inference(
    prompt: str = Form(...),
    video_file: str = Form(...),
    sampling_method: FrameSamplingMethod = Form(FrameSamplingMethod.CONTENT_AWARE),
    sampling_rate: int = Form(5),
):
    """Run the loaded model on frames sampled from a server-side ``.mp4`` file.

    Args:
        prompt: Text prompt passed to ``model.generate``.
        video_file: Path to an ``.mp4`` file that this server process can read.
            NOTE(review): this is a client-supplied filesystem path, not an upload,
            so any readable .mp4 on the host is reachable. Restrict it to an allowed
            directory if the service is exposed beyond trusted callers.
        sampling_method: Frame-sampling strategy forwarded to ``extract_frames``.
        sampling_rate: Rate/threshold forwarded to ``extract_frames``.

    Returns:
        JSONResponse containing the dict returned by ``model.generate``, plus
        ``inference_time`` (wall-clock seconds since the request arrived),
        ``cpu_usage`` and ``cpu_core_utilization`` (psutil percentages).

    Raises:
        HTTPException: 400 if the path does not end in ``.mp4`` or no frames could
            be extracted; 500 for any other failure during processing.
    """
    request_start_time = time.time()
    request_id = str(uuid.uuid4())
    logging.info(f"[{request_id}] Received new video inference request. Prompt: '{prompt}', Video: '{video_file}'")

    # cpu_percent(interval=None) reports usage since the previous call on the same
    # counter (and 0.0 on the first call). Prime the aggregate and per-core counters
    # now, so the readings taken after inference cover this request.
    psutil.cpu_percent(interval=None)
    psutil.cpu_percent(interval=None, percpu=True)

    # Case-insensitive, so "clip.MP4" is accepted too.
    if not video_file.lower().endswith(".mp4"):
        logging.error(f"[{request_id}] Uploaded file '{video_file}' is not a video.")
        raise HTTPException(status_code=400, detail="Uploaded file is not a video.")

    # Per-request scratch directory for extracted frames; removed in `finally`.
    temp_frame_dir = os.path.join(TEMP_VIDEO_DIR, request_id)

    try:
        os.makedirs(temp_frame_dir, exist_ok=True)
        logging.info(f"[{request_id}] Extracting frames from {video_file} using method: {sampling_method.value}, rate/threshold: {sampling_rate}")

        # NOTE(review): extract_frames and model.generate are synchronous calls inside an
        # `async def`, so they block the event loop while running (with a single worker,
        # requests are effectively processed one at a time).
        frames = extract_frames(video_file, sampling_method, sampling_rate)
        if not frames:
            logging.error(f"[{request_id}] Could not extract any frames from the video: {video_file}")
            raise HTTPException(status_code=400, detail="Could not extract any frames from the video.")

        logging.info(f"[{request_id}] Extracted {len(frames)} frames successfully. Saving to temporary files...")
        frame_paths = _save_frames(frames, temp_frame_dir)
        logging.info(f"[{request_id}] {len(frame_paths)} frames saved to {temp_frame_dir}")

        output = model.generate(frame_paths, prompt)
        logging.info(f"[{request_id}] Tokens per second: {output['tokens_per_second']}, Peak GPU memory MB: {output['peak_gpu_memory_mb']}")

        inference_end_time = time.time()
        # Wall time since the request arrived, so it includes frame extraction and saving.
        inference_time = inference_end_time - request_start_time
        cpu_usage = psutil.cpu_percent(interval=None)
        cpu_core_utilization = psutil.cpu_percent(interval=None, percpu=True)
        logging.info(f"[{request_id}] Inference time: {inference_time:.2f} seconds, CPU usage: {cpu_usage}%, CPU core utilization: {cpu_core_utilization}")
        output["inference_time"] = inference_time
        output["cpu_usage"] = cpu_usage
        output["cpu_core_utilization"] = cpu_core_utilization

        return JSONResponse(content=output)

    except HTTPException:
        # Already carries the intended status (e.g. 400); do not rewrap it as a 500.
        raise
    except Exception as e:
        logging.error(f"[{request_id}] An error occurred during processing: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}") from e
    finally:
        if os.path.exists(temp_frame_dir):
            # A cleanup failure must not replace the response (or the original error).
            try:
                shutil.rmtree(temp_frame_dir)
                logging.info(f"[{request_id}] Cleaned up temporary frame directory: {temp_frame_dir}")
            except OSError:
                logging.warning(f"[{request_id}] Failed to clean up temporary frame directory: {temp_frame_dir}", exc_info=True)


if __name__ == "__main__":
    # Direct execution: serve the FastAPI app with uvicorn on every network
    # interface at the fixed port 8010.
    _bind_host, _bind_port = "0.0.0.0", 8010
    uvicorn.run(app, host=_bind_host, port=_bind_port)