File size: 3,895 Bytes
f8ba0eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import cv2
import numpy as np
import base64
from typing import List
from enum import Enum
from skimage.metrics import structural_similarity as ssim

class FrameSamplingMethod(str, Enum):
    """Strategies for choosing which frames to pull out of a video.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (for example ``FrameSamplingMethod.UNIFORM == "uniform"``).
    """

    # Frames taken at evenly spaced positions across the video.
    UNIFORM = "uniform"

    # Frames chosen by how much they differ from the preceding frame.
    CONTENT_AWARE = "content_aware"

def extract_frames(
    video_path: str,
    method: FrameSamplingMethod,
    sampling_rate: int
) -> List[np.ndarray]:
    """Extract at most ``sampling_rate`` frames from a video file.

    For both sampling methods ``sampling_rate`` is the total number of frames
    requested, not a stride:

    * ``UNIFORM``: frames are taken at evenly spaced positions.
    * ``CONTENT_AWARE``: frame 0 is always kept, plus the ``sampling_rate - 1``
      frames whose SSIM against their predecessor is lowest (i.e. the frames
      that changed the most).

    If the video reports no more frames than requested, every readable frame
    is returned.

    Args:
        video_path: Path passed to ``cv2.VideoCapture``.
        method: Sampling strategy to use.
        sampling_rate: Number of frames to return; values ``<= 0`` yield ``[]``.

    Returns:
        The selected frames in ascending frame-index order. Fewer frames than
        requested may be returned if some reads fail. An unrecognised
        ``method`` yields an empty list.

    Raises:
        IOError: If the video file cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture handle is released on every path,
    # including when cvtColor or SSIM raises part-way through the video.
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file: {video_path}")

        if sampling_rate <= 0:
            return []

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if method == FrameSamplingMethod.UNIFORM:
            return _sample_uniform(cap, total_frames, sampling_rate)
        if method == FrameSamplingMethod.CONTENT_AWARE:
            return _sample_content_aware(cap, total_frames, sampling_rate)
        return []
    finally:
        cap.release()


def _read_all_frames(cap: "cv2.VideoCapture") -> List[np.ndarray]:
    """Read sequentially from the capture's current position until a read fails."""
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    return frames


def _sample_uniform(
    cap: "cv2.VideoCapture", total_frames: int, count: int
) -> List[np.ndarray]:
    """Return up to ``count`` frames taken at evenly spaced frame indices."""
    # More frames requested than the video reports: return everything.
    if count >= total_frames:
        return _read_all_frames(cap)

    # Fractional step so the sampled indices spread over the whole video.
    step = total_frames / count
    frames = []
    for i in range(count):
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(i * step))
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
    return frames


def _sample_content_aware(
    cap: "cv2.VideoCapture", total_frames: int, count: int
) -> List[np.ndarray]:
    """Return frame 0 plus the ``count - 1`` frames that differ most from their predecessor."""
    # The video reports no more frames than requested: return everything.
    if total_frames <= count:
        return _read_all_frames(cap)

    # --- Pass 1: SSIM score between every pair of adjacent frames ---
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    ret, prev_frame = cap.read()
    if not ret:
        return []
    prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)

    ssim_scores = []  # (ssim_score, frame_index)
    for frame_index in range(1, total_frames):
        ret, current_frame = cap.read()
        if not ret:
            # The reported frame count can exceed what is actually readable.
            break
        current_gray = cv2.cvtColor(current_frame, cv2.COLOR_BGR2GRAY)
        # Only the mean score is needed, so the full SSIM map is not requested.
        score = ssim(prev_gray, current_gray)
        ssim_scores.append((score, frame_index))
        prev_gray = current_gray

    # Lower SSIM means a larger change; the stable sort keeps earlier frames
    # first among equal scores.
    ssim_scores.sort(key=lambda item: item[0])
    selected_indices = {index for _, index in ssim_scores[:count - 1]}
    # The first frame is always part of the result.
    selected_indices.add(0)

    # --- Pass 2: fetch the chosen frames in ascending index order ---
    frames = []
    for index in sorted(selected_indices):
        cap.set(cv2.CAP_PROP_POS_FRAMES, index)
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
    return frames

def encode_frames_to_base64(frames: List[np.ndarray]) -> List[str]:
    """Encode each OpenCV frame as JPEG and return the base64 text of each one.

    The output list has one string per input frame, in the same order.
    """
    def _jpeg_base64(image):
        # imencode returns a (status, buffer) pair; only the buffer is used.
        _status, jpeg_buffer = cv2.imencode('.jpg', image)
        return base64.b64encode(jpeg_buffer).decode('utf-8')

    return [_jpeg_base64(image) for image in frames]