import base64
import heapq
from enum import Enum
from operator import itemgetter
from typing import List

import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim


class FrameSamplingMethod(str, Enum):
    """Strategy used by ``extract_frames`` to decide which frames to keep."""

    UNIFORM = "uniform"
    CONTENT_AWARE = "content_aware"


def _read_all_frames(cap: cv2.VideoCapture) -> List[np.ndarray]:
    """Decode every remaining frame from ``cap`` in stream order."""
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(frame)
    return frames


def _sample_uniform(
    cap: cv2.VideoCapture, total_frames: int, count: int
) -> List[np.ndarray]:
    """Return up to ``count`` frames spaced evenly across ``total_frames``.

    Frames whose seek/read fails are skipped, so fewer than ``count`` frames
    may be returned.
    """
    step = total_frames / count
    frames = []
    for i in range(count):
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(i * step))
        ok, frame = cap.read()
        if ok:
            frames.append(frame)
    return frames


def _sample_content_aware(cap: cv2.VideoCapture, count: int) -> List[np.ndarray]:
    """Return frame 0 plus the ``count - 1`` frames that differ most from their predecessor.

    "Difference" is measured as the SSIM between each frame and the one
    before it, on grayscale images; lower SSIM means a larger change.
    Frames are returned in stream order.
    """
    # --- Pass 1: SSIM between every pair of adjacent frames ---
    ok, prev_frame = cap.read()
    if not ok:
        return []
    prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)

    scores = []  # (ssim_score, frame_index) for frames 1..end
    index = 0
    # Read to end-of-stream rather than trusting CAP_PROP_FRAME_COUNT, which
    # is only an estimate for many containers/codecs.
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        index += 1
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        # Only the scalar score is needed; full=True would also build an
        # unused per-pixel similarity map.
        scores.append((ssim(prev_gray, gray), index))
        prev_gray = gray

    # --- Select the n-1 most-changed frames (lowest SSIM), always keep frame 0 ---
    # nsmallest(k, ..., key=...) is equivalent to sorted(..., key=...)[:k],
    # including tie order, without sorting the whole list.
    selected = {0}
    selected.update(
        idx for _, idx in heapq.nsmallest(count - 1, scores, key=itemgetter(0))
    )

    # --- Pass 2: collect the selected frames ---
    # Rewind once and read sequentially. Per-frame CAP_PROP_POS_FRAMES seeks
    # can land on a different frame than the one scored in pass 1 for
    # keyframe-based codecs. Sequential reading guarantees the indices match.
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    frames = []
    for idx in range(max(selected) + 1):
        if idx in selected:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(frame)
        elif not cap.grab():  # advance without retrieving the image
            break
    return frames


def extract_frames(
    video_path: str, method: FrameSamplingMethod, sampling_rate: int
) -> List[np.ndarray]:
    """Extract frames from a video file.

    Args:
        video_path: Path to a video readable by ``cv2.VideoCapture``.
        method: ``FrameSamplingMethod`` member or its string value
            (``"uniform"`` / ``"content_aware"``).
        sampling_rate: Total number of frames to extract, for both methods.
            UNIFORM picks frames evenly spaced over the video. CONTENT_AWARE
            picks frame 0 plus the frames with the largest change from their
            predecessor (lowest adjacent-frame SSIM).

    Returns:
        The selected frames as OpenCV images, in stream order. Returns an
        empty list when ``sampling_rate <= 0``. When ``sampling_rate`` is at
        least the reported frame count, every decodable frame is returned.

    Raises:
        IOError: If the video cannot be opened.
        ValueError: If ``method`` is not a valid ``FrameSamplingMethod``.
    """
    # Accepts enum members or their plain string values; rejects anything
    # else instead of silently returning no frames.
    method = FrameSamplingMethod(method)

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file: {video_path}")
        if sampling_rate <= 0:
            return []

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Asking for at least as many frames as the video has: return them all.
        if sampling_rate >= total_frames:
            return _read_all_frames(cap)

        if method is FrameSamplingMethod.UNIFORM:
            return _sample_uniform(cap, total_frames, sampling_rate)
        return _sample_content_aware(cap, sampling_rate)
    finally:
        # Always release the capture, including when decoding/SSIM raises.
        cap.release()


def encode_frames_to_base64(frames: List[np.ndarray]) -> List[str]:
    """Encode OpenCV frames as JPEG and return them as base64 strings.

    Raises:
        ValueError: If OpenCV fails to JPEG-encode a frame.
    """
    encoded = []
    for i, frame in enumerate(frames):
        ok, buffer = cv2.imencode('.jpg', frame)
        if not ok:
            raise ValueError(f"Failed to encode frame {i} as JPEG")
        encoded.append(base64.b64encode(buffer).decode('utf-8'))
    return encoded