|
import cv2 |
|
import numpy as np |
|
import base64 |
|
from typing import Iterable, List
|
from enum import Enum |
|
from skimage.metrics import structural_similarity as ssim |
|
|
|
class FrameSamplingMethod(str, Enum):
    """Strategies accepted by ``extract_frames`` for choosing which frames to keep.

    Because the enum also subclasses ``str``, each member compares equal to
    its plain string value, e.g. ``FrameSamplingMethod.UNIFORM == "uniform"``.
    """

    # Frames spaced evenly across the whole video.
    UNIFORM = "uniform"
    # Frames whose picture differs most from the frame just before them.
    CONTENT_AWARE = "content_aware"
|
|
|
def extract_frames(
    video_path: str,
    method: FrameSamplingMethod,
    sampling_rate: int,
) -> List[np.ndarray]:
    """Extract a fixed number of frames from a video.

    Despite its name, ``sampling_rate`` is the total number of frames to
    return for both methods:

    * ``UNIFORM`` -- frames at indices ``int(i * total / sampling_rate)`` for
      ``i`` in ``range(sampling_rate)``, i.e. spread evenly over the video.
    * ``CONTENT_AWARE`` -- frame 0 plus the ``sampling_rate - 1`` frames whose
      grayscale SSIM against the immediately preceding frame is lowest (the
      largest visual changes), returned in frame order.

    For either method, if ``sampling_rate`` is at least the frame count
    reported by OpenCV, every decodable frame is returned instead.

    Args:
        video_path: Video source passed to ``cv2.VideoCapture``.
        method: A ``FrameSamplingMethod`` member or its string value.
        sampling_rate: Number of frames to extract; ``<= 0`` returns ``[]``.

    Returns:
        The selected frames as decoded by ``cv2.VideoCapture.read``. The list
        may be shorter than requested when some frames fail to decode.

    Raises:
        ValueError: If ``method`` is not a ``FrameSamplingMethod`` value.
        IOError: If the video cannot be opened.
    """
    # Validate up front so an unknown method fails loudly instead of silently
    # yielding an empty list.
    method = FrameSamplingMethod(method)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video file: {video_path}")

    # try/finally releases the capture on every exit path, including when
    # decoding or SSIM raises (skimage rejects images smaller than its window).
    try:
        if sampling_rate <= 0:
            return []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if method is FrameSamplingMethod.UNIFORM:
            return _sample_uniform(cap, total_frames, sampling_rate)
        return _sample_content_aware(cap, total_frames, sampling_rate)
    finally:
        cap.release()


def _read_all_frames(cap: "cv2.VideoCapture") -> List[np.ndarray]:
    """Decode every remaining frame of ``cap`` sequentially."""
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            return frames
        frames.append(frame)


def _read_frames_at(cap: "cv2.VideoCapture", indices: Iterable[int]) -> List[np.ndarray]:
    """Seek to each frame index in turn and decode it, skipping failed reads."""
    frames = []
    for index in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, index)
        ok, frame = cap.read()
        if ok:
            frames.append(frame)
    return frames


def _sample_uniform(
    cap: "cv2.VideoCapture", total_frames: int, count: int
) -> List[np.ndarray]:
    """Return ``count`` (> 0) frames at evenly spaced indices."""
    if count >= total_frames:
        return _read_all_frames(cap)
    step = total_frames / count
    return _read_frames_at(cap, (int(i * step) for i in range(count)))


def _sample_content_aware(
    cap: "cv2.VideoCapture", total_frames: int, count: int
) -> List[np.ndarray]:
    """Return frame 0 plus the ``count - 1`` frames least similar to their predecessor."""
    if total_frames <= count:
        return _read_all_frames(cap)

    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    ok, first_frame = cap.read()
    if not ok:
        return []

    # Score each frame i >= 1 by SSIM against frame i - 1; a lower score means
    # a bigger change between consecutive frames.
    scores = []
    prev_gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
    for index in range(1, total_frames):
        ok, frame = cap.read()
        if not ok:
            break
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        # Only the mean score is needed, so skip building the full SSIM map.
        scores.append((ssim(prev_gray, gray), index))
        prev_gray = gray

    # Stable sort: among equal scores the earlier frame wins.
    most_changed = sorted(scores, key=lambda item: item[0])[: count - 1]
    # Frame 0 is always kept so the sample starts at the beginning.
    selected = {0} | {index for _, index in most_changed}
    return _read_frames_at(cap, sorted(selected))
|
|
|
def encode_frames_to_base64(frames: List[np.ndarray]) -> List[str]:
    """Encode OpenCV frames as base64 strings of their JPEG bytes.

    Args:
        frames: Images accepted by ``cv2.imencode``, such as the frames
            returned by ``extract_frames``.

    Returns:
        One base64 string per input frame, in the same order. An empty input
        yields an empty list.

    Raises:
        ValueError: If OpenCV reports that a frame could not be JPEG-encoded.
    """
    encoded = []
    for index, frame in enumerate(frames):
        ok, buffer = cv2.imencode('.jpg', frame)
        # cv2.imencode reports failure through its flag; check it so a bad
        # frame surfaces as an error instead of being silently base64-encoded.
        if not ok:
            raise ValueError(f"Failed to JPEG-encode frame at index {index}")
        encoded.append(base64.b64encode(buffer).decode('utf-8'))
    return encoded