import argparse
import logging
import os
import shutil
import time
import uuid

import ollama
import psutil
import uvicorn
from fastapi import FastAPI, Form, HTTPException

# Optional GPU metrics via NVML. The except clauses are split because referencing
# pynvml.NVMLError would raise a NameError when the import itself fails.
try:
    import pynvml
    pynvml.nvmlInit()
    GPU_METRICS_AVAILABLE = True
except ImportError:
    # pynvml is not installed.
    GPU_METRICS_AVAILABLE = False
except pynvml.NVMLError:
    # pynvml is installed but no usable NVIDIA driver/GPU is present.
    GPU_METRICS_AVAILABLE = False

from video_processor import extract_frames, FrameSamplingMethod, encode_frames_to_base64
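
# For reference, a minimal sketch of what UNIFORM extraction could look like
# (an illustration only; the real logic lives in video_processor.py, and
# extract_frames_uniform is a hypothetical name):
#
#   import cv2
#
#   def extract_frames_uniform(path: str, every_n: int = 5) -> list:
#       cap = cv2.VideoCapture(path)
#       frames, idx = [], 0
#       while True:
#           ok, frame = cap.read()
#           if not ok:
#               break
#           if idx % every_n == 0:  # keep one frame every `every_n` frames
#               frames.append(frame)
#           idx += 1
#       cap.release()
#       return frames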

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="openbmb/minicpm-v4:latest")
args = parser.parse_args()
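
# Example launch (a sketch; the filename is whatever this module is saved as,
# and it assumes an Ollama server is running locally with the model pulled):
#
#   ollama pull openbmb/minicpm-v4
#   python video_inference_service.py --model_name openbmb/minicpm-v4:latest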

# Note: slashes in the model name (e.g. 'openbmb/minicpm-v4:latest') create
# nested directories under logs/.
os.makedirs(f'logs/{args.model_name}', exist_ok=True)

# Initialize the FastAPI application.
app = FastAPI(title="Video Inference Service")

# Temporary directory for per-request working files.
TEMP_VIDEO_DIR = "temp_videos"
os.makedirs(TEMP_VIDEO_DIR, exist_ok=True)

# Use the current timestamp to build a unique log filename.
log_filename = f"logs/{args.model_name}/{time.strftime('%Y%m%d_%H%M%S')}.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    filename=log_filename,
    filemode='a',
)
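
# GPU_METRICS_AVAILABLE is set during startup but not consumed anywhere below; a
# minimal sketch of how it could feed GPU telemetry into the request logs
# (get_gpu_metrics is a hypothetical helper, not part of the original service):
def get_gpu_metrics(device_index: int = 0) -> dict:
    """Return utilization/memory stats for one GPU, or {} if NVML is unavailable."""
    if not GPU_METRICS_AVAILABLE:
        return {}
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    util = pynvml.nvmlDeviceGetUtilizationRates(handle)
    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return {
        "gpu_util_percent": util.gpu,
        "mem_used_mb": mem.used / (1024 ** 2),
        "mem_total_mb": mem.total / (1024 ** 2),
    }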

@app.post("/video-inference/")
async def video_inference(
    prompt: str = Form(...),
    video_path: str = Form(...),
    sampling_method: str = Form(...),
    sampling_rate: int = Form(5),
):
    """
    Run inference on a video against a text prompt and return the result.
    - prompt: the user's question.
    - video_path: server-side path to the video file.
    - sampling_method: frame sampling method ('UNIFORM' or 'CONTENT_AWARE').
    - sampling_rate: sampling rate, or scene-change threshold for content-aware sampling.
    """
    try:
        request_start_time = time.time()
        request_id = str(uuid.uuid4())
        logging.info(f"[{request_id}] Received new video inference request. Prompt: '{prompt}', Video: '{video_path}'")

        # Validate the supplied video path.
        if not os.path.exists(video_path):
            raise HTTPException(status_code=404, detail=f"Video file not found: {video_path}")

        if not video_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            logging.warning(f"[{request_id}] File '{video_path}' may not be a video file.")

        # Map the sampling-method string to the enum (case-insensitively);
        # unknown values fall back to CONTENT_AWARE.
        sampling_method_map = {
            "CONTENT_AWARE": FrameSamplingMethod.CONTENT_AWARE,
            "UNIFORM": FrameSamplingMethod.UNIFORM,
        }
        sampling_method = sampling_method_map.get(sampling_method.upper(), FrameSamplingMethod.CONTENT_AWARE)

        # Create a per-request temporary directory.
        temp_frame_dir = os.path.join(TEMP_VIDEO_DIR, request_id)
        os.makedirs(temp_frame_dir, exist_ok=True)
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"An error occurred during request setup: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}")

    try:
        # 1. Extract frames from the video.
        logging.info(f"[{request_id}] Extracting frames using method: {sampling_method.value}, rate/threshold: {sampling_rate}")

        frames = extract_frames(video_path, sampling_method, sampling_rate)
        if not frames:
            raise ValueError(f"Could not extract any frames from the video: {video_path}")

        logging.info(f"[{request_id}] Extracted {len(frames)} frames successfully.")
        # 2. Encode the frames as Base64.
        base64_frames = encode_frames_to_base64(frames)
        logging.info(f"[{request_id}] Encoded {len(base64_frames)} frames to Base64.")

        # 3. Build the video-oriented prompt (currently passed through unchanged).
        final_prompt = prompt
        
        # 4. Call the Ollama API.
        try:
            logging.info(f"[{request_id}] Sending request to Ollama model '{args.model_name}'...")

            # Prime psutil's CPU counters so the next cpu_percent() calls report
            # the average utilization over the duration of the Ollama request.
            psutil.cpu_percent(interval=None)
            psutil.cpu_percent(interval=None, percpu=True)

            ollama_start_time = time.time()
            response = ollama.chat(
                model=args.model_name,  # the model selected via --model_name
                messages=[
                    {
                        'role': 'user',
                        'content': final_prompt,
                        'images': base64_frames,
                    }
                ]
            )
            ollama_end_time = time.time()

            # Read CPU utilization immediately after the Ollama call so the
            # figures cover exactly the inference window.
            cpu_usage = psutil.cpu_percent(interval=None)
            cpu_core_utilization = psutil.cpu_percent(interval=None, percpu=True)

            logging.info(f"[{request_id}] Received response from Ollama successfully.")
            logging.info(
                f"[{request_id}] Ollama latency: {ollama_end_time - ollama_start_time:.2f}s, "
                f"total request time: {time.time() - request_start_time:.2f}s, "
                f"avg CPU: {cpu_usage}%, per-core: {cpu_core_utilization}"
            )
            return response
        
        except Exception as ollama_error:
            # Surface Ollama failures specifically as 503s.
            logging.error(f"[{request_id}] Ollama inference failed: {str(ollama_error)}", exc_info=True)
            raise HTTPException(status_code=503, detail=f"Ollama inference failed: {str(ollama_error)}")

    except HTTPException:
        # Re-raise HTTP errors (e.g. the 503 above) without wrapping them in a 500.
        raise
    except Exception as e:
        logging.error(f"[{request_id}] An error occurred during processing: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}")
    finally:
        # Clean up the per-request temporary directory.
        if os.path.exists(temp_frame_dir):
            shutil.rmtree(temp_frame_dir)
            logging.info(f"[{request_id}] Cleaned up temporary directory: {temp_frame_dir}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8008)