import os
import random
from typing import Dict, Optional, Sequence, Iterator, List, Iterable, Union
from PIL import PngImagePlugin, Image, ImageFile, ImageOps
import numpy as np
import torch
from torch.utils.data import (
Dataset,
ConcatDataset,
Sampler,
WeightedRandomSampler
)
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from robohusky.train.tcsloader import TCSLoader
from decord import VideoReader, cpu
from robohusky.video_transformers import (
GroupNormalize,
GroupScale,
GroupCenterCrop,
Stack,
ToTorchFormatTensor,
get_index,
)
from robohusky.conversation import get_conv_template
# Per-channel normalization statistics (one value per channel).
# build_transform selects between the OpenAI CLIP and ImageNet-default sets.
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
IMAGENET_STANDARD_MEAN = [0.5] * 3
IMAGENET_STANDARD_STD = [0.5] * 3

# Presumably the label value excluded from the loss (matches PyTorch's
# CrossEntropyLoss default ignore_index) -- confirm at the use site.
IGNORE_INDEX = -100

# Relax Pillow's safety limits so very large, truncated, or metadata-heavy
# images can still be decoded.
# NOTE(review): disabling MAX_IMAGE_PIXELS removes Pillow's decompression-bomb
# guard; only appropriate for trusted data.
MegaByte = 1 << 20
MaximumDecompressedSize = 1024  # expressed in MegaByte units
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte  # 1 GiB
# Special tokens spliced into prompts. format_inputs builds the image
# placeholder as IMG_START + num_query_tokens * EMBED + IMG_END.
# NOTE(review): the token text appears to have been lost in this copy.
# - The first literal below spans a line break, which is not valid for a
#   plain double-quoted string.
# - The others are empty strings. That makes substring tests such as
#   `"" in value` always true, and makes `str.replace("", ...)` insert
#   between every character.
# Restore the real token strings.
DEFAULT_IMG_START_TOKEN = "
"
DEFAULT_IMG_END_TOKEN = ""
DEFAULT_VIDEO_START_TOKEN = ""
DEFAULT_VIDEO_END_TOKEN = ""
DEFAULT_EMBED_TOKEN = ""
# Looks like a placeholder path that must be edited before use. Presumably the
# config file for the TCSLoader storage client imported above -- confirm.
conf_path = "/your path to/petrelf.conf"
def is_image(image_file):
    """Return True when *image_file* ends with a known raster-image extension.

    Matching ignores case, so ``photo.JPG`` counts as an image.
    """
    image_suffixes = (
        '.bmp', '.dib', '.png', '.jpg', '.jpeg',
        '.pbm', '.pgm', '.ppm', '.tif', '.tiff',
    )
    return image_file.lower().endswith(image_suffixes)
def is_video(image_file):
    """Return True when *image_file* ends with a known video container extension.

    Matching ignores case, so ``movie.MKV`` counts as a video.
    """
    video_suffixes = ('.mp4', '.mkv', '.avi', '.wmv', '.iso', ".webm")
    return image_file.lower().endswith(video_suffixes)
def is_numpy(image_file):
    """Return True when *image_file* names a NumPy ``.npy`` array file.

    Matching ignores case, consistent with ``is_image`` and ``is_video``;
    previously ``FEATS.NPY`` fell through to the "text" media type.
    """
    return image_file.lower().endswith(".npy")
def get_media_type(image_file):
    """Classify *image_file* by extension as "image", "video", "numpy" or "text".

    The checks are tried in that order and the first match wins; any path
    that matches none of them is reported as "text".
    """
    for media_type, matches in (
        ("image", is_image),
        ("video", is_video),
        ("numpy", is_numpy),
    ):
        if matches(image_file):
            return media_type
    return "text"
def _convert_to_rgb(img):
    """Return *img* in RGB mode, converting only when it is not already RGB.

    Defined at module level instead of as a lambda so the image pipeline is
    picklable. Lambdas cannot be pickled, which breaks sending a Dataset that
    holds this transform to DataLoader workers started with ``spawn``.
    """
    if img.mode != 'RGB':
        return img.convert('RGB')
    return img


def build_transform(input_size, norm_type="openai", media_type="image"):
    """Build the preprocessing pipeline for one media type.

    Args:
        input_size: Target size in pixels. Images are resized to
            ``(input_size, input_size)``. For video it is passed to
            ``GroupScale`` and ``GroupCenterCrop``.
        norm_type: ``"openai"`` selects the CLIP normalization statistics;
            ``"imagenet"`` and any other value select the ImageNet defaults.
        media_type: ``"image"`` or ``"video"``.

    Returns:
        A ``torchvision.transforms.Compose`` for ``"image"`` or ``"video"``;
        ``None`` for any other ``media_type``.
    """
    if norm_type == "openai":
        mean, std = OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
    else:
        # "imagenet" and unrecognised values share the ImageNet statistics.
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    if media_type == "image":
        return T.Compose([
            T.Lambda(_convert_to_rgb),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
    if media_type == "video":
        # Group* / Stack / ToTorchFormatTensor come from
        # robohusky.video_transformers. Presumably they mirror the torchvision
        # transforms applied per frame -- confirm there.
        return T.Compose([
            GroupScale(int(input_size), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(input_size),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(mean=mean, std=std),
        ])
    return None
def check_format(data):
    """Validate a raw conversation sample before it is used for training.

    A valid sample meets all of the following:
    - It has ``id``, ``image`` and ``conversations`` keys.
    - ``conversations`` is non-empty and has an even number of turns.
    - Turns alternate ``human`` / ``gpt``, starting with ``human``.
    - Every turn has a non-empty ``value``.
    - The first turn's value starts or ends with a newline.

    On the first violation found, prints a diagnostic including the sample
    and returns False. Returns True when every check passes.
    """
    conversations = data.get('conversations') if 'conversations' in data else None
    if not ('id' in data and 'image' in data and conversations is not None
            and len(conversations) > 0 and len(conversations) % 2 == 0):
        # An empty conversation is rejected here too: downstream code
        # indexes the first turn. ("Lake" presumably means "Lack"; the
        # message text is kept unchanged for log compatibility.)
        print(f"Lake field: {data}")
        return False
    for i, message in enumerate(conversations):
        value = message.get('value')
        # Check for a missing or empty value first. The prefix/suffix test
        # below would raise AttributeError on None.
        if value is None or len(value) == 0:
            print(f"No Message: {data}")
            return False
        if i == 0 and not (value.startswith("\n") or value.endswith("\n")):
            print(f"No : {data}")
            return False
        expected_role = 'human' if i % 2 == 0 else 'gpt'
        if message.get('from') != expected_role:
            print(f"Not from {expected_role}: {data}")
            return False
    return True
def format_inputs(sources, conv_tempt="husky", num_query_tokens=256):
# Apply prompt templates
conv = get_conv_template(conv_tempt).copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
# vision is only supported for the human input
if role == conv.roles[0]:
value = sentence["value"]
if "" in value:
if value.endswith("\n"):
value = "\n" + value.replace("\n", "")
image_query = DEFAULT_IMG_START_TOKEN + num_query_tokens * DEFAULT_EMBED_TOKEN + DEFAULT_IMG_END_TOKEN
sentence["value"] = value.replace("", image_query)
elif "