import os
import random
from typing import Dict, Optional, Sequence

from PIL import PngImagePlugin, Image, ImageFile

import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

from robohusky.train.tcsloader import TCSLoader
from robohusky.conversation import get_conv_template

# Conventional label value for positions that should not contribute to the
# loss; -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
IGNORE_INDEX = -100

# Disable Pillow's decompression-bomb pixel-count check so very large images load.
# NOTE(review): this also removes protection against maliciously oversized
# images; acceptable only if the image sources are trusted.
Image.MAX_IMAGE_PIXELS = None
# Let Pillow decode truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Raise Pillow's limit on PNG text-chunk size to 1024 MiB.
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte

# Start/end markers for image and video content in prompts (the image pair is
# concatenated into the image query by ``format_inputs``).
# NOTE(review): all four markers are empty strings here, so they add nothing to
# the prompt text -- confirm these are the intended values.
DEFAULT_IMG_START_TOKEN = ""
DEFAULT_IMG_END_TOKEN = ""
DEFAULT_VIDEO_START_TOKEN = ""
DEFAULT_VIDEO_END_TOKEN = ""

# File-name suffixes (compared case-insensitively) recognised as still images.
_IMAGE_EXTENSIONS = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
# File-name suffixes (compared case-insensitively) recognised as videos.
_VIDEO_EXTENSIONS = ('.mp4', '.mkv', '.avi', '.wmv', '.iso', ".webm")


def is_image(image_file):
    """Return True if *image_file* ends with a known image extension.

    The comparison is case-insensitive. *image_file* may be a ``str`` or any
    ``os.PathLike`` (e.g. ``pathlib.Path``).
    """
    return os.fspath(image_file).lower().endswith(_IMAGE_EXTENSIONS)


def is_video(image_file):
    """Return True if *image_file* ends with a known video extension.

    The comparison is case-insensitive. *image_file* may be a ``str`` or any
    ``os.PathLike`` (e.g. ``pathlib.Path``).
    """
    return os.fspath(image_file).lower().endswith(_VIDEO_EXTENSIONS)


def _ensure_rgb(img):
    """Return *img* converted to RGB mode, or *img* unchanged if already RGB."""
    return img.convert('RGB') if img.mode != 'RGB' else img


def build_transform(input_size):
    """Build the image preprocessing pipeline.

    The pipeline converts the image to RGB, then resizes it to a square
    ``input_size`` x ``input_size`` with bicubic interpolation (aspect ratio
    is not preserved). It then converts the result to a tensor and
    normalizes it with the standard ImageNet per-channel mean/std.

    Args:
        input_size: Side length, in pixels, of the square output image.

    Returns:
        A ``torchvision.transforms.Compose`` callable.
    """
    # Use a module-level function rather than a lambda so the transform stays
    # picklable, e.g. for DataLoader workers using the "spawn" start method.
    return T.Compose([
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])


def format_inputs(sources):
    # Apply prompt templates
    conv = get_conv_template("husky").copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            # vision is only supported for the human input
            if role == conv.roles[0]:
                value = sentence["value"]
                if "" in value:
                    if value.endswith("\n"):
                        value = "\n" + value.replace("\n", "")
                    image_query = DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
sentence["value"] = value.replace("", image_query) elif "