import os
import random

from typing import Dict, Iterator, List, Optional, Sequence

from PIL import Image, ImageFile, ImageOps, PngImagePlugin

import numpy as np

import torch
from torch.utils.data import (
    ConcatDataset,
    Dataset,
    Sampler,
    WeightedRandomSampler,
)
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

from robohusky.train.tcsloader import TCSLoader

from decord import VideoReader, cpu
from robohusky.video_transformers import (
    GroupCenterCrop,
    GroupNormalize,
    GroupScale,
    Stack,
    ToTorchFormatTensor,
    get_index,
)

from robohusky.conversation import get_conv_template

# Normalization statistics for the supported vision backbones.
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

# Label value ignored by the language-modeling loss.
IGNORE_INDEX = -100

# Relax PIL limits so very large or truncated images can still be loaded.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte

# Special tokens that delimit visual embeddings inside the text stream.
DEFAULT_IMG_START_TOKEN = "<img>"
DEFAULT_IMG_END_TOKEN = "</img>"

DEFAULT_VIDEO_START_TOKEN = "<vid>"
DEFAULT_VIDEO_END_TOKEN = "</vid>"

DEFAULT_EMBED_TOKEN = "<quad>"

# Config consumed by TCSLoader for remote (s3://) access.
conf_path = "/your path to/petrelf.conf"

def is_image(image_file):
    return image_file.lower().endswith(
        ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
    )


def is_video(image_file):
    return image_file.lower().endswith(('.mp4', '.mkv', '.avi', '.wmv', '.iso', '.webm'))


def is_numpy(image_file):
    return image_file.endswith(".npy")


def get_media_type(image_file):
    if is_image(image_file):
        return "image"
    elif is_video(image_file):
        return "video"
    elif is_numpy(image_file):
        return "numpy"
    else:
        return "text"

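# Illustrative examples (file names are hypothetical):
#   get_media_type("photos/cat.JPG")   -> "image"
#   get_media_type("clips/demo.mp4")   -> "video"
#   get_media_type("feats/frame.npy")  -> "numpy"
#   get_media_type("notes.txt")        -> "text"
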
def build_transform(input_size, norm_type="openai", media_type="image"):
    # Select normalization statistics; anything other than "openai"
    # falls back to the ImageNet defaults.
    if norm_type == "openai":
        mean = OPENAI_CLIP_MEAN
        std = OPENAI_CLIP_STD
    else:
        mean = IMAGENET_DEFAULT_MEAN
        std = IMAGENET_DEFAULT_STD

    if media_type == "image":
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)
        ])
    elif media_type == "video":
        transform = T.Compose([
            GroupScale(int(input_size), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(input_size),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(mean=mean, std=std)
        ])
    else:
        transform = None
    return transform

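# Minimal usage sketch (illustrative; "example.jpg" is a hypothetical file):
#
#   transform = build_transform(input_size=224, norm_type="openai", media_type="image")
#   img = Image.open("example.jpg")
#   pixel_values = transform(img)   # float tensor of shape (3, 224, 224)
#
# For media_type="video" the transform instead expects a list of PIL frames and
# returns a stacked (T * 3, H, W) tensor via Stack/ToTorchFormatTensor.
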
def check_format(data):
    # A valid record has an id, an image, and an even number of
    # conversation turns alternating human -> gpt.
    if not ('id' in data and 'image' in data and 'conversations' in data and len(data['conversations']) % 2 == 0):
        print(f"Missing field: {data}")
        return False
    for i, message in enumerate(data['conversations']):
        if i == 0:
            if not (message['value'].startswith("<image>\n") or message['value'].endswith("\n<image>")):
                print(f"No <image>: {data}")
                return False
        if i % 2 == 0:
            if not (message['from'] == 'human'):
                print(f"Not from human: {data}")
                return False
        else:
            if not (message['from'] == 'gpt'):
                print(f"Not from gpt: {data}")
                return False
        if message['value'] is None or len(message['value']) == 0:
            print(f"Empty message: {data}")
            return False
    return True

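# Shape of a record that passes check_format (contents are made up):
#
#   {
#       "id": "0001",
#       "image": "images/0001.jpg",
#       "conversations": [
#           {"from": "human", "value": "<image>\nWhat is in this picture?"},
#           {"from": "gpt", "value": "A dog playing in the snow."},
#       ],
#   }
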
def format_inputs(sources, conv_tempt="husky", num_query_tokens=256):
    conv = get_conv_template(conv_tempt).copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    conversations = []

    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the human role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"

            if role == conv.roles[0]:
                value = sentence["value"]
                if "<image>" in value:
                    # Move a trailing <image> tag to the front, then expand it
                    # into <img> + num_query_tokens * <quad> + </img>.
                    if value.endswith("\n<image>"):
                        value = "<image>\n" + value.replace("\n<image>", "")

                    image_query = DEFAULT_IMG_START_TOKEN + num_query_tokens * DEFAULT_EMBED_TOKEN + DEFAULT_IMG_END_TOKEN
                    sentence["value"] = value.replace("<image>", image_query)

                elif "<video>" in value:
                    if value.endswith("\n<video>"):
                        value = "<video>\n" + value.replace("\n<video>", "")

                    video_query = DEFAULT_VIDEO_START_TOKEN + num_query_tokens * DEFAULT_EMBED_TOKEN + DEFAULT_VIDEO_END_TOKEN
                    sentence["value"] = value.replace("<video>", video_query)

            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    return conversations, conv

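# Sketch of the expansion performed above (illustrative, with
# num_query_tokens=2 for brevity):
#
#   "<image>\nDescribe the scene."
# becomes
#   "<img><quad><quad></img>\nDescribe the scene."
#
# before being laid out into the conversation template's prompt format.
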
def process_func(examples, tokenizer, max_seq_length=-1, conv_tempt="husky", num_query_tokens=256):
    conversations, conv = format_inputs(examples['conversations'], conv_tempt, num_query_tokens)
    if max_seq_length < 0:
        model_inputs = tokenizer(
            conversations,
            return_tensors="pt",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
    else:
        model_inputs = tokenizer(
            conversations,
            max_length=max_seq_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    model_inputs.pop("token_type_ids", None)

    targets = model_inputs["input_ids"].clone()

    # Mask everything except the assistant responses so the loss is
    # computed only on what the model should generate.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy tokenizers handle the token
                # after sep2 differently.
                instruction_len -= 1

            # Ignore the user instruction for this turn.
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                cur_len -= 1

        target[cur_len:] = IGNORE_INDEX

        # If the rebuilt length disagrees with the tokenized length, mask the
        # whole example rather than train on misaligned labels.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX

    model_inputs["labels"] = targets
    return model_inputs

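# Usage sketch (illustrative; assumes a LLaMA-style tokenizer and records
# shaped like the check_format example above; the model path is hypothetical):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("path/to/model")
#   batch = {"conversations": [record["conversations"] for record in records]}
#   model_inputs = process_func(batch, tokenizer, max_seq_length=2048)
#   # model_inputs["labels"] equals input_ids with all non-assistant
#   # positions replaced by IGNORE_INDEX (-100).
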
class BaseDataset(Dataset):
    def __init__(
        self,
        dataset,
        processor,
        image_path="",
        input_size=224,
        num_segments=8,
        norm_type="openai",
        media_type="image"
    ):
        super().__init__()
        self.dataset = dataset
        self.image_path = image_path
        self.input_size = input_size
        self.num_segments = num_segments

        self.media_type = media_type
        self.transform = build_transform(input_size, norm_type, media_type)
        self.husky_processor = processor
        self.tcs_loader = TCSLoader(os.path.abspath(conf_path), media_type=media_type)

        # Cache of fully processed samples, keyed by index.
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        data = self.dataset[i]
        image_file = data["image"] if "image" in data else data["video"]

        if self.media_type == "llm" or image_file == "":
            # Text-only sample: no visual input to load.
            pixel_values = None
        else:
            if self.image_path != "":
                image_file = os.path.join(self.image_path, image_file)
            if "s3://" not in image_file and not os.path.exists(image_file):
                # Missing file: retry with a random replacement sample.
                i = random.randint(0, len(self.dataset) - 1)
                return self.__getitem__(i)

            try:
                if self.media_type == "image":
                    if "s3://" in image_file:
                        image = self.tcs_loader(image_file)
                    else:
                        image = Image.open(image_file).convert('RGB')

                    # PIL reports (width, height). Pad extreme aspect ratios
                    # toward square before the resize.
                    width, height = image.size
                    if width / height >= 1.8:
                        delta = width - height
                        padding = (0, delta // 2, 0, delta - delta // 2)
                        image = ImageOps.expand(image, padding)
                    elif width / height <= 0.56:
                        delta = height - width
                        padding = (delta // 2, 0, delta - delta // 2, 0)
                        image = ImageOps.expand(image, padding)
                    pixel_values = self.transform(image)
                elif self.media_type == "video":
                    if "s3://" in image_file:
                        vr = self.tcs_loader(image_file)
                    else:
                        vr = VideoReader(image_file, ctx=cpu(0))

                    # Sample num_segments frames, then reshape the stacked
                    # (T * 3, H, W) tensor to (3, T, H, W).
                    num_frames = len(vr)
                    frame_indices = get_index(num_frames, self.num_segments)
                    images_group = list()
                    for frame_index in frame_indices:
                        img = Image.fromarray(vr[frame_index].asnumpy())
                        images_group.append(img)
                    pixel_values = self.transform(images_group)
                    TC, H, W = pixel_values.shape
                    pixel_values = pixel_values.reshape(TC // 3, 3, H, W).transpose(0, 1)
                else:
                    # Pre-extracted features stored as .npy arrays.
                    if "s3://" in image_file:
                        pixel_values = self.tcs_loader(image_file)
                    else:
                        pixel_values = np.load(image_file)
                    pixel_values = torch.tensor(pixel_values).transpose(0, 1)
            except (AttributeError, OSError):
                with open("error.txt", 'a') as f:
                    f.write(image_file + '\n')
                i = random.randint(0, len(self.dataset) - 1)
                return self.__getitem__(i)

        # The processor expects batched inputs; wrap, process, then unwrap.
        for k, v in data.items():
            data[k] = [v]
        ret = self.husky_processor(data)
        for k, v in ret.items():
            ret[k] = v[0]

        if pixel_values is not None:
            ret["pixel_values"] = pixel_values

        self.cached_data_dict[i] = ret
        return ret

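# Usage sketch (illustrative; the annotation file, image root, and processor
# are placeholders):
#
#   import json
#   with open("annotations.json") as f:
#       records = json.load(f)
#   ds = BaseDataset(records, processor, image_path="/data/images",
#                    input_size=224, norm_type="openai", media_type="image")
#   sample = ds[0]   # dict of processor outputs plus "pixel_values"
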
class WeightedConcatDataset(ConcatDataset):
    def __init__(
        self,
        datasets: List[Dataset],
        weights: Optional[Sequence[float]] = None,
        replacement: bool = True,
        batch_size: int = -1,
        generator=None
    ) -> None:
        super().__init__(datasets)
        if weights is None:
            weights = [1.0] * len(self.datasets)
        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if len(weights_tensor.shape) != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             "weights have shape {}".format(tuple(weights_tensor.shape)))
        self.weights = weights_tensor
        self.batch_size = batch_size

        self.replacement = replacement
        self.generator = generator

        if self.batch_size <= 0:
            # Sample-level weighting: every index is an independent weighted draw.
            self.num_samples = sum(len(d) for d in datasets)
            self.sampler = WeightedRandomSampler(
                weights=self.weights,
                num_samples=self.num_samples,
                replacement=self.replacement
            )
        else:
            # Batch-level weighting: each batch of indices repeats a single
            # weighted draw (see WeightedBatchSampler below).
            self.task_batches = [len(d) // batch_size for d in datasets]
            self.num_samples = sum(self.task_batches) * batch_size
            self.sampler = WeightedBatchSampler(
                weights=self.weights,
                num_samples=self.num_samples,
                batch_size=self.batch_size,
                replacement=self.replacement
            )

    def __iter__(self) -> Iterator[int]:
        return iter(self.sampler)

    def __len__(self) -> int:
        return self.num_samples

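# Behavior sketch (illustrative; ds_a and ds_b stand in for any map-style
# datasets). Iterating the dataset delegates to its internal sampler:
#
#   mixed = WeightedConcatDataset([ds_a, ds_b], weights=[0.7, 0.3], batch_size=4)
#   indices = list(iter(mixed))   # e.g. [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, ...]
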
class WeightedBatchSampler(Sampler[int]):
    weights: torch.Tensor
    num_samples: int
    batch_size: int
    replacement: bool

    def __init__(
        self,
        weights: Sequence[float],
        num_samples: int,
        batch_size: int,
        replacement: bool = True,
        generator=None
    ) -> None:
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if len(weights_tensor.shape) != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             "weights have shape {}".format(tuple(weights_tensor.shape)))

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_batches = num_samples // batch_size
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        # Draw one weighted index per batch, then repeat it batch_size times
        # so every index within a batch comes from the same draw.
        rand_tensor = torch.multinomial(self.weights, self.num_batches, self.replacement, generator=self.generator)
        rand_tensor = rand_tensor.repeat_interleave(self.batch_size)

        yield from rand_tensor.tolist()

    def __len__(self) -> int:
        return self.num_samples

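# Behavior sketch for WeightedBatchSampler (illustrative):
#
#   sampler = WeightedBatchSampler(weights=[0.5, 0.5], num_samples=8, batch_size=4)
#   list(sampler)   # e.g. [1, 1, 1, 1, 0, 0, 0, 0]
#
# Each consecutive run of batch_size indices repeats one multinomial draw, so a
# downstream batch is built entirely from whichever entry that draw selected.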
|