import os
import random

from typing import Dict

from PIL import PngImagePlugin, Image, ImageFile

import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

from robohusky.train.tcsloader import TCSLoader
from robohusky.conversation import get_conv_template

IGNORE_INDEX = -100
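
# Relax PIL's safety limits so very large or truncated images do not
# raise errors mid-training.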
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte

DEFAULT_IMG_START_TOKEN = "<img>"
DEFAULT_IMG_END_TOKEN = "</img>"

DEFAULT_VIDEO_START_TOKEN = "<vid>"
DEFAULT_VIDEO_END_TOKEN = "</vid>"


def is_image(image_file):
    return image_file.lower().endswith(
        ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff'))


def is_video(image_file):
    return image_file.lower().endswith(('.mp4', '.mkv', '.avi', '.wmv', '.iso', '.webm'))


def build_transform(input_size):
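    """ImageNet-style preprocessing: force RGB, bicubic-resize to a square
    of `input_size`, convert to a tensor, and normalize with the standard
    ImageNet mean/std."""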
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    return transform


def format_inputs(sources):
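    """Render raw conversations with the "husky" template.

    `<image>` / `<video>` placeholders that trail a user turn are moved to
    its front, then replaced by the `<img></img>` / `<vid></vid>` query spans
    the model expects. Returns the rendered prompt strings plus the template
    (needed later for its separator tokens).
    """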
    conv = get_conv_template("husky").copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    conversations = []

    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip a leading message that does not come from the human role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"unexpected role order in sample {i}"

            if role == conv.roles[0]:
                value = sentence["value"]
                if "<image>" in value:
                    if value.endswith("\n<image>"):
                        value = "<image>\n" + value.replace("\n<image>", "")
                    image_query = DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
                    sentence["value"] = value.replace("<image>", image_query)

                elif "<video>" in value:
                    if value.endswith("\n<video>"):
                        value = "<video>\n" + value.replace("\n<video>", "")
                    video_query = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN
                    sentence["value"] = value.replace("<video>", video_query)

            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    return conversations, conv


def process_func(examples, tokenizer, max_seq_length):
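    """Tokenize the rendered conversations and build `labels` for supervised
    fine-tuning: every token outside an assistant reply is masked with
    IGNORE_INDEX (FastChat-style preprocessing)."""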
    conversations, conv = format_inputs(examples['conversations'])
    model_inputs = tokenizer(
        conversations,
        max_length=max_seq_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Some tokenizers emit token_type_ids, which the model does not use.
    model_inputs.pop("token_type_ids", None)

    targets = model_inputs["input_ids"].clone()
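
    # Mask out everything except the assistant's replies so that only
    # response tokens contribute to the loss.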
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
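
        # Walk the prompt turn by turn; cur_len starts at 1 to skip the BOS token.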
        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
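
            # The "-2" offset is hardcoded for Llama-style tokenizers to drop
            # the BOS and separator tokens (following FastChat's preprocessing).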
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2
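
            # Non-legacy Llama tokenizers encode the leading space of later
            # turns differently, shifting the offsets by one (cf. FastChat).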
            if i != 0 and not tokenizer.legacy:
                instruction_len -= 1

            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # Same off-by-one adjustment as for instruction_len above.
                cur_len -= 1

        target[cur_len:] = IGNORE_INDEX
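
        # If the running offset disagrees with the true (unpadded) length,
        # the masking is misaligned; drop the example from the loss entirely.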
        if cur_len < tokenizer.model_max_length:
            # NOTE: tokenization above truncates to max_seq_length; if that
            # differs from tokenizer.model_max_length, truncated samples may
            # slip past this guard.
            if cur_len != total_len:
                target[:] = IGNORE_INDEX

    model_inputs["labels"] = targets
    return model_inputs


class BaseDataset(Dataset):
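    """Map-style dataset that reads images from the local filesystem and
    pairs them with text inputs processed by `husky_processor`."""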

    def __init__(self, dataset, processor, image_path="", input_size=224):
        super().__init__()
        self.dataset = dataset
        self.image_path = image_path

        self.transform = build_transform(input_size)
        self.husky_processor = processor

        # Cache processed samples in memory so repeated indices skip re-processing.
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        data = self.dataset[i]
        image_file = data.pop("image", None)

        if self.image_path != "":
            image_file = os.path.join(self.image_path, image_file)
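            # Skip samples whose image file is missing instead of crashing.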
            if not os.path.exists(image_file):
                return self.__getitem__((i + 1) % len(self.dataset))
        image = Image.open(image_file)

        # The husky processor expects batched inputs; wrap each field in a
        # singleton list, then unwrap the outputs.
        for k, v in data.items():
            data[k] = [v]
        ret = self.husky_processor(data)
        for k, v in ret.items():
            ret[k] = v[0]

        pixel_values = self.transform(image)
        ret["pixel_values"] = pixel_values

        self.cached_data_dict[i] = ret
        return ret


class CephDataset(Dataset):
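    """Variant of BaseDataset that streams images from Ceph/petrel object
    storage through TCSLoader rather than the local filesystem."""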

    def __init__(self, dataset, processor, input_size=224):
        super().__init__()
        self.dataset = dataset

        self.transform = build_transform(input_size)
        self.husky_processor = processor

        conf_path = "./petrelf.conf"
        self.conf_path = os.path.abspath(conf_path)

        self.initialized = False
        self._init_memcached()

    def _init_memcached(self):
        if not self.initialized:
            assert self.conf_path is not None
            self.mt_loader = TCSLoader(self.conf_path)
            self.initialized = True

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        data = self.dataset[i]
        image_file = data.pop("image", None)
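
        # Corrupt or unreadable objects are logged to error.txt, then a
        # random sample is drawn so one bad file never stalls training.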
        try:
            image = self.mt_loader(image_file).convert('RGB')
        except (AttributeError, OSError):
            with open("error.txt", 'a') as f:
                f.write(image_file + '\n')
            i = random.randint(0, len(self.dataset) - 1)
            return self.__getitem__(i)

        for k, v in data.items():
            data[k] = [v]

        ret = self.husky_processor(data)
        for k, v in ret.items():
            ret[k] = v[0]
        pixel_values = self.transform(image)
        ret["pixel_values"] = pixel_values
        return ret