Ligeng-Zhu's picture
Upload files with `vila-upload`.
d8c0285 verified
import glob
import io
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import cv2
import numpy as np
import PIL
import PIL.Image
import requests
from transformers import PretrainedConfig
# from llava.constants import MEDIA_TOKENS
# from llava.media import Image, Video
# from llava.utils import make_list
# from llava.utils.logging import logger
# Media kind -> placeholder string. extract_media writes these placeholders into
# message text where media parts appeared, and strips them from plain-text parts.
MEDIA_TOKENS = dict(
    image="<image>",
    video="<vila/video>",
)
class Media:
    """Base class for non-text parts of a prompt message."""


class File(Media):
    """Media referenced by a path string (a local path or an http(s) URL).

    Args:
        path: Location of the media. Interpretation is left to the consumer,
            e.g. ``_extract_image`` treats ``http://``/``https://`` prefixes
            as URLs and anything else as a local file.
    """

    def __init__(self, path: str) -> None:
        self.path = path

    def __repr__(self) -> str:
        # Use the concrete class name so subclasses print as e.g. "Image(path='a.png')".
        return f"{type(self).__name__}(path={self.path!r})"


class Image(File):
    """An image prompt part; ``extract_media`` loads it via ``_extract_image``."""


class Video(File):
    """A video prompt part; ``extract_media`` samples it via ``_extract_video``."""
def make_list(obj: Any) -> List:
    """Wrap ``obj`` in a one-element list unless it is already a list.

    Lists (including subclasses) are returned as the same object, not copied.
    """
    if not isinstance(obj, list):
        return [obj]
    return obj
def _extract_image(
    image: Union[Image, PIL.Image.Image], *, timeout: float = 30.0
) -> PIL.Image.Image:
    """Return a ``PIL.Image.Image`` for an image prompt part.

    ``PIL.Image.Image`` inputs are returned unchanged. ``Image`` parts are
    opened from their ``path``: paths starting with ``http://`` or
    ``https://`` are downloaded with ``requests``; anything else is opened as
    a local file.

    Args:
        image: An ``Image`` media reference or an already-loaded PIL image.
        timeout: Seconds ``requests`` waits on the server when downloading.

    Returns:
        The PIL image. Images opened here have their pixel data loaded.

    Raises:
        requests.HTTPError: If the URL download returns an error status.
        requests.RequestException: On other download failures (incl. timeout).
        OSError: If the image cannot be opened or decoded by PIL.
    """
    if not isinstance(image, Image):
        return image

    if image.path.startswith(("http://", "https://")):
        # Read the decoded body rather than the raw socket stream (``.raw`` skips
        # Content-Encoding decoding) and close the connection when done.
        with requests.get(image.path, timeout=timeout) as response:
            response.raise_for_status()
            data = response.content
        loaded = PIL.Image.open(io.BytesIO(data))
    else:
        loaded = PIL.Image.open(image.path)

    # Image.open is lazy and keeps the source open until pixels are read; load
    # now so the file handle is released and decode errors surface here.
    loaded.load()
    return loaded
def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
# Load video frames from a directory
if os.path.isdir(video_path):
frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
return [PIL.Image.open(frame_paths[index]) for index in indices]
# Load video frames from a video file
vidcap = cv2.VideoCapture(video_path)
# Find the last frame as frame count might not be accurate
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
while frame_count > 0:
vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
if vidcap.grab():
break
frame_count -= 1
else:
raise ValueError(f"Video '{video_path}' has no frames.")
# Extract frames uniformly
indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
frames = {}
for index in indices:
if index in frames:
continue
vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
success, frame = vidcap.read()
if not success:
print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
continue
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames[index] = PIL.Image.fromarray(frame)
return [frames[index] for index in indices if index in frames]
def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
num_frames = config.num_video_frames
if getattr(config, "fps") != 0:
print("Extracting frames from video with specified FPS is not supported yet. Ignored.")
frames = _load_video(video.path, num_frames=num_frames)
return frames
def extract_media(
    messages: List[Dict[str, Any]],
    config: Optional[PretrainedConfig] = None,
    draft: bool = False,
) -> Dict[str, List[Any]]:
    """Pull media parts out of ``messages`` and replace them with placeholder tokens.

    Each message's ``"value"`` (a single part or a list of parts) is rewritten
    in place to one string:

    * string parts are kept; any ``MEDIA_TOKENS`` value found in them is
      removed (with a printed notice) and the remaining text is stripped;
    * ``Image`` / ``PIL.Image.Image`` parts become ``MEDIA_TOKENS["image"]``;
    * ``Video`` parts become ``MEDIA_TOKENS["video"]``.

    Args:
        messages: Dicts with a ``"value"`` key; mutated in place.
        config: Passed to ``_extract_video`` for video parts when not ``draft``.
        draft: If true, collect media parts as given instead of loading them.

    Returns:
        A ``defaultdict(list)`` mapping ``"image"`` / ``"video"`` to the
        collected media, in order of appearance.

    Raises:
        ValueError: If a part has an unsupported type.
    """
    collected = defaultdict(list)

    for message in messages:
        pieces = []
        for part in make_list(message["value"]):
            if isinstance(part, str):
                for token in MEDIA_TOKENS.values():
                    if token not in part:
                        continue
                    print(f"Media token '{token}' found in text: '{part}'. Removed.")
                    part = part.replace(token, "").strip()
                pieces.append(part)
            elif isinstance(part, (Image, PIL.Image.Image)):
                collected["image"].append(part if draft else _extract_image(part))
                pieces.append(MEDIA_TOKENS["image"])
            elif isinstance(part, Video):
                collected["video"].append(part if draft else _extract_video(part, config))
                pieces.append(MEDIA_TOKENS["video"])
            else:
                raise ValueError(f"Unsupported prompt part type: {type(part)}")
        message["value"] = "".join(pieces)

    return collected