|
import glob |
|
import os |
|
from collections import defaultdict |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
import cv2 |
|
import numpy as np |
|
import PIL |
|
import PIL.Image |
|
import requests |
|
from transformers import PretrainedConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Placeholder tokens that stand in for media items inside message text.
# `extract_media` appends these to the rendered text wherever an image or
# video part appeared, and strips any stray occurrences from raw text.
MEDIA_TOKENS = {

    "image": "<image>",

    "video": "<vila/video>",

}
|
|
|
class Media:
    """Marker base class for all media inputs (images, videos, ...)."""
|
|
|
class File(Media):
    """A media item backed by a filesystem path or URL string."""

    def __init__(self, path: str) -> None:
        # Location of the media; may be a local path or an http(s) URL.
        self.path = path
|
|
|
class Image(File):
    """An image referenced by path/URL (loaded lazily by `_extract_image`)."""
|
|
|
|
|
class Video(File):
    """A video referenced by path/URL (frames sampled by `_extract_video`)."""
|
|
|
def make_list(obj: Any) -> List:
    """Return *obj* unchanged if it is already a list, else wrap it in one."""
    if isinstance(obj, list):
        return obj
    return [obj]
|
|
|
|
|
def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
    """Resolve an `Image` wrapper into a loaded PIL image.

    A `PIL.Image.Image` passed in is returned as-is; an `Image` wrapper is
    opened from its URL (http/https) or local path.
    """
    if not isinstance(image, Image):
        return image
    path = image.path
    if path.startswith(("http://", "https://")):
        # Stream the response so PIL can read directly from the raw socket.
        return PIL.Image.open(requests.get(path, stream=True).raw)
    return PIL.Image.open(path)
|
|
|
|
|
def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
    """Sample `num_frames` evenly spaced frames from a video.

    `video_path` may be either a directory of pre-extracted frame images
    (sampled in sorted filename order) or a video file readable by OpenCV.

    Returns:
        A list of PIL images (may be shorter than `num_frames` if some
        frames fail to decode).

    Raises:
        ValueError: if the directory or video contains no frames.
    """
    # Directory of pre-extracted frames: sample evenly by filename order.
    if os.path.isdir(video_path):
        frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
        if not frame_paths:
            # Fix: previously an empty directory produced np.linspace(0, -1, ...)
            # and bogus negative indices; fail loudly like the video branch does.
            raise ValueError(f"Video '{video_path}' has no frames.")
        indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
        return [PIL.Image.open(frame_paths[index]) for index in indices]

    vidcap = cv2.VideoCapture(video_path)
    try:
        # CAP_PROP_FRAME_COUNT is an estimate for some containers: walk
        # backwards from the reported count until a frame can actually be
        # grabbed, so `frame_count` reflects the last decodable frame.
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        while frame_count > 0:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
            if vidcap.grab():
                break
            frame_count -= 1
        else:
            raise ValueError(f"Video '{video_path}' has no frames.")

        indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
        frames = {}
        for index in indices:
            if index in frames:
                # Duplicate index (num_frames > frame_count); decode once.
                continue
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, frame = vidcap.read()
            if not success:
                print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
                continue
            # OpenCV decodes to BGR; PIL expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames[index] = PIL.Image.fromarray(frame)
        # Preserve sampling order, dropping indices whose decode failed.
        return [frames[index] for index in indices if index in frames]
    finally:
        # Fix: the capture was never released, leaking the OS handle on
        # both success and error paths.
        vidcap.release()
|
|
|
|
|
def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
    """Extract `config.num_video_frames` evenly spaced frames from `video`.

    FPS-based sampling is not implemented; a non-zero `config.fps` is
    reported and ignored.
    """
    num_frames = config.num_video_frames
    # Fix: `getattr(config, "fps")` without a default is just attribute access
    # and raised AttributeError on configs lacking `fps`; default to 0 (off).
    if getattr(config, "fps", 0) != 0:
        print("Extracting frames from video with specified FPS is not supported yet. Ignored.")

    frames = _load_video(video.path, num_frames=num_frames)
    return frames
|
|
|
|
|
def extract_media(
    messages: List[Dict[str, Any]],
    config: Optional[PretrainedConfig] = None,
    draft: bool = False,
) -> Dict[str, List[Any]]:
    """Separate media parts from chat `messages`, leaving placeholder tokens.

    Each message's "value" is rewritten in place to plain text in which every
    image/video part is replaced by its token from `MEDIA_TOKENS` (stray
    tokens already present in text parts are stripped). With `draft=True` the
    raw media objects are collected without loading them; otherwise images
    are opened and video frames are sampled using `config`.

    Returns:
        Mapping from media kind ("image" / "video") to the collected items,
        in encounter order.
    """
    collected = defaultdict(list)
    for message in messages:
        rendered = ""
        for part in make_list(message["value"]):
            if isinstance(part, str):
                # Remove any literal media tokens the text already contains.
                for token in MEDIA_TOKENS.values():
                    if token in part:
                        print(f"Media token '{token}' found in text: '{part}'. Removed.")
                        part = part.replace(token, "").strip()
                rendered += part
            elif isinstance(part, (Image, PIL.Image.Image)):
                collected["image"].append(part if draft else _extract_image(part))
                rendered += MEDIA_TOKENS["image"]
            elif isinstance(part, Video):
                collected["video"].append(part if draft else _extract_video(part, config))
                rendered += MEDIA_TOKENS["video"]
            else:
                raise ValueError(f"Unsupported prompt part type: {type(part)}")
        # Mutate the message so downstream code sees tokenized plain text.
        message["value"] = rendered
    return collected
|
|