import os import random from typing import Dict, Optional, Sequence, Iterator, List, Iterable, Union from PIL import PngImagePlugin, Image, ImageFile, ImageOps import numpy as np import torch from torch.utils.data import ( Dataset, ConcatDataset, Sampler, WeightedRandomSampler ) import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode from robohusky.train.tcsloader import TCSLoader from decord import VideoReader, cpu from robohusky.video_transformers import ( GroupNormalize, GroupScale, GroupCenterCrop, Stack, ToTorchFormatTensor, get_index, ) from robohusky.conversation import get_conv_template IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5] IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5] OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] IGNORE_INDEX = -100 Image.MAX_IMAGE_PIXELS = None ImageFile.LOAD_TRUNCATED_IMAGES = True MaximumDecompressedSize = 1024 MegaByte = 2 ** 20 PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte DEFAULT_IMG_START_TOKEN = "" DEFAULT_IMG_END_TOKEN = "" DEFAULT_VIDEO_START_TOKEN = "" DEFAULT_VIDEO_END_TOKEN = "" DEFAULT_EMBED_TOKEN = "" conf_path = "/your path to/petrelf.conf" def is_image(image_file): if image_file.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): return True else: return False def is_video(image_file): if image_file.lower().endswith(('.mp4', '.mkv', '.avi', '.wmv', '.iso', ".webm")): return True else: return False def is_numpy(image_file): if image_file.endswith(".npy"): return True else: return False def get_media_type(image_file): if is_image(image_file): return "image" elif is_video(image_file): return "video" elif is_numpy(image_file): return "numpy" else: return "text" def build_transform(input_size, norm_type="openai", media_type="image"): if norm_type == "openai": mean = OPENAI_CLIP_MEAN std = OPENAI_CLIP_STD elif norm_type == "imagenet": mean = IMAGENET_DEFAULT_MEAN std = IMAGENET_DEFAULT_STD else: mean = IMAGENET_DEFAULT_MEAN std = IMAGENET_DEFAULT_STD if media_type == "image": transform = T.Compose([ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std) ]) elif media_type == "video": transform = T.Compose([ GroupScale(int(input_size), interpolation=InterpolationMode.BICUBIC), GroupCenterCrop(input_size), Stack(), ToTorchFormatTensor(), GroupNormalize(mean=mean, std=std) ]) else: transform = None return transform def check_format(data): if not ('id' in data and 'image' in data and 'conversations' in data and len(data['conversations']) % 2 == 0): print(f"Lake field: {data}") return False for i, message in enumerate(data['conversations']): if i == 0: if not (message['value'].startswith("\n") or message['value'].endswith("\n")): print(f"No : {data}") return False if i % 2 == 0: if not (message['from'] == 'human'): print(f"Not from human: {data}") return False else: if not (message['from'] == 'gpt'): print(f"Not from gpt: {data}") return False if message['value'] is None or (len(message['value']) == 0): print(f"No Message: {data}") return False return True def format_inputs(sources, conv_tempt="husky", num_query_tokens=256): # Apply prompt templates conv = get_conv_template(conv_tempt).copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" # vision is only supported for the human input if role == conv.roles[0]: value = sentence["value"] if "" in value: if value.endswith("\n"): value = "\n" + value.replace("\n", "") image_query = DEFAULT_IMG_START_TOKEN + num_query_tokens * DEFAULT_EMBED_TOKEN + DEFAULT_IMG_END_TOKEN sentence["value"] = value.replace("", image_query) elif "