# T4_code/ollama_minicpm/video_processor.py
# Uploaded by Wangtwohappy via huggingface_hub (commit f8ba0eb).
import base64
from enum import Enum
from typing import List, Tuple

import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
class FrameSamplingMethod(str, Enum):
    """Strategy used by ``extract_frames`` to decide which frames to keep.

    Because the enum also subclasses ``str``, each member compares equal to
    its plain string value (``FrameSamplingMethod.UNIFORM == "uniform"``).
    """

    # Evenly spaced frames across the whole video.
    UNIFORM = "uniform"
    # Frame 0 plus the frames whose SSIM against the previous frame is lowest.
    CONTENT_AWARE = "content_aware"
# Default output size for extracted frames, as (width, height) -- the order cv2.resize expects.
_DEFAULT_TARGET_SIZE = (420, 280)


def _read_all_frames(
    cap: cv2.VideoCapture, target_size: Tuple[int, int]
) -> List[np.ndarray]:
    """Read every remaining frame from ``cap``, resized to ``target_size``."""
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(cv2.resize(frame, target_size))
    return frames


def _read_frames_at(
    cap: cv2.VideoCapture, indices: List[int], target_size: Tuple[int, int]
) -> List[np.ndarray]:
    """Seek to each index in ``indices`` and return the frames that could be read.

    Indices whose read fails are skipped, so the result may be shorter than ``indices``.
    """
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frames.append(cv2.resize(frame, target_size))
    return frames


def _uniform_indices(total_frames: int, count: int) -> List[int]:
    """Return ``count`` evenly spaced frame indices in ``[0, total_frames)``."""
    step = total_frames / count
    return [int(i * step) for i in range(count)]


def _content_aware_indices(
    cap: cv2.VideoCapture, total_frames: int, count: int
) -> List[int]:
    """Return sorted indices: frame 0 plus the ``count - 1`` frames most unlike their predecessor.

    "Unlike" means lowest SSIM between the grayscale frame and the one before it.
    Returns ``[]`` if the first frame cannot be read.
    """
    # Pass 1: SSIM score of every pair of adjacent frames.
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    ret, prev_frame = cap.read()
    if not ret:
        return []
    prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
    scores = []  # (ssim_score, frame_index)
    for i in range(1, total_frames):
        ret, current_frame = cap.read()
        if not ret:
            break
        current_gray = cv2.cvtColor(current_frame, cv2.COLOR_BGR2GRAY)
        # Only the mean SSIM is used; the default full=False avoids building
        # (and discarding) the per-pixel similarity map for every frame pair.
        scores.append((ssim(prev_gray, current_gray), i))
        prev_gray = current_gray

    # Lower SSIM = larger change. The sort is stable, so ties keep the earlier frame first.
    scores.sort(key=lambda item: item[0])
    selected = {idx for _, idx in scores[:count - 1]}
    selected.add(0)  # always keep the first frame
    return sorted(selected)


def extract_frames(
    video_path: str,
    method: FrameSamplingMethod,
    sampling_rate: int,
    *,
    target_size: Tuple[int, int] = _DEFAULT_TARGET_SIZE,
) -> List[np.ndarray]:
    """Extract frames from a video.

    For UNIFORM, ``sampling_rate`` is the total number of frames to extract,
    evenly spaced across the video.

    For CONTENT_AWARE, ``sampling_rate`` is also the total number of frames.
    The frames chosen are frame 0 plus the frames that changed most relative
    to their predecessor (lowest SSIM).

    In both modes, if the video reports no more frames than ``sampling_rate``,
    every frame is returned.

    Args:
        video_path: Path passed to ``cv2.VideoCapture``.
        method: Sampling strategy. Plain strings equal to a member's value also work.
        sampling_rate: Number of frames to extract. Values ``<= 0`` yield ``[]``.
        target_size: Output size ``(width, height)`` of every returned frame.

    Returns:
        The selected frames, in video order, resized to ``target_size``.
        Returns ``[]`` for an unrecognized ``method``.

    Raises:
        IOError: If the video file cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture is released even if decoding,
    # resizing or SSIM raises part-way through.
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file: {video_path}")
        if sampling_rate <= 0:
            return []

        # NOTE: the container-reported frame count may be approximate for some formats.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if method == FrameSamplingMethod.UNIFORM:
            if sampling_rate >= total_frames:
                return _read_all_frames(cap, target_size)
            indices = _uniform_indices(total_frames, sampling_rate)
            return _read_frames_at(cap, indices, target_size)

        if method == FrameSamplingMethod.CONTENT_AWARE:
            if total_frames <= sampling_rate:
                return _read_all_frames(cap, target_size)
            indices = _content_aware_indices(cap, total_frames, sampling_rate)
            return _read_frames_at(cap, indices, target_size)

        # Unrecognized method: kept as an empty result for backward compatibility.
        return []
    finally:
        cap.release()
def encode_frames_to_base64(frames: List[np.ndarray]) -> List[str]:
    """Encode each frame as JPEG and return the results as base64 strings.

    Args:
        frames: Images accepted by ``cv2.imencode``.

    Returns:
        One base64 string (decoded as UTF-8 text) per input frame, in the same order.

    Raises:
        ValueError: If OpenCV reports that a frame could not be JPEG-encoded.
    """
    base64_frames = []
    for position, frame in enumerate(frames):
        ok, buffer = cv2.imencode('.jpg', frame)
        # imencode reports failure through its first return value. Ignoring it
        # would silently emit an empty or invalid payload for this frame.
        if not ok:
            raise ValueError(f"Failed to JPEG-encode frame at index {position}")
        base64_frames.append(base64.b64encode(buffer).decode('utf-8'))
    return base64_frames