lambertxiao's picture
Overwrite with converted Qwen2.5-3B model files
492f6af verified
"""Utility functions"""
import importlib
import random
import re
import torch
import numpy as np
from PIL import Image
def normalize(image, rescale=True):
    """Map an image tensor to the [-1, 1] range.

    Args:
        image: Image tensor. When ``rescale`` is True it is cast to float and
            divided by 255 first (so 8-bit pixel values land in [0, 1]);
            otherwise it is assumed to already be in [0, 1].
        rescale: Whether to apply the ``/ 255`` rescaling step.

    Returns:
        The tensor transformed as ``2 * x - 1``.
    """
    unit_range = image.float() / 255.0 if rescale else image
    # Affine map [0, 1] -> [-1, 1].
    return 2 * unit_range - 1
def process_caption(caption):
    """Clean up a caption: drop a trailing unfinished sentence and remove
    repeated sentences.

    If ``caption`` does not end with a sentence terminator (``.``, ``!`` or
    ``?``), everything after the last terminator is discarded as an
    incomplete sentence. A caption containing no terminator at all is left
    unchanged. The text is then split into sentences at whitespace that
    follows a terminator. Only the first occurrence of each exact sentence
    is kept, in original order.

    Args:
        caption: A string containing the caption text.

    Returns:
        processed_caption: The cleaned caption, with sentences joined by a
        single space.
    """
    terminators = '.!?'
    if not caption.endswith(tuple(terminators)):
        # Use the same terminator set as the split below, so a complete
        # final sentence ending in '!' or '?' is not cut off.
        last_end = max(caption.rfind(ch) for ch in terminators)
        if last_end != -1:
            caption = caption[:last_end + 1]
    sentences = re.split(r'(?<=[.!?])\s+', caption)
    # dict.fromkeys de-duplicates in O(n) while preserving first-seen order;
    # empty fragments are skipped as before.
    unique_sentences = dict.fromkeys(s for s in sentences if s)
    processed_caption = ' '.join(unique_sentences)
    return processed_caption
def initiate_time_steps(step, total_timestep, batch_size, config):
    """A helper function to initiate time steps for the diffusion model.

    The mode is chosen by flags on ``config``, checked in this order:

    * ``config.rand_timestep_equal_int``: evenly spaced time steps. The range
      ``[0, total_timestep)`` is divided into ``batch_size`` intervals of
      width ``total_timestep // batch_size``. A single random offset inside
      the first interval is shared, so sample ``i`` gets
      ``offset + i * width``.
    * ``config.random_timestep_per_iteration``: an independent uniform random
      time step in ``[0, total_timestep)`` for each sample.
    * otherwise: the constant ``step`` for every sample.

    Args:
        step: An integer of the constant step (used only in the last mode).
        total_timestep: An integer of the total timesteps of the diffusion model.
        batch_size: An integer of the batch size.
        config: A config object exposing the two boolean flags above.

    Returns:
        timesteps: A long tensor of shape [batch_size,] of the time steps.

    Raises:
        ValueError: In the evenly spaced mode, if ``batch_size`` is not
            positive or is larger than ``total_timestep``.
    """
    if config.rand_timestep_equal_int:
        if batch_size <= 0 or batch_size > total_timestep:
            raise ValueError(
                f"batch_size ({batch_size}) must be in [1, total_timestep "
                f"({total_timestep})] for evenly spaced time steps"
            )
        interval_val = total_timestep // batch_size
        start_point = random.randint(0, interval_val - 1)
        # Generate exactly batch_size entries. range(start, total, interval)
        # yields extra entries when total_timestep % batch_size != 0. The
        # largest value is interval*batch_size - 1 <= total_timestep - 1.
        timesteps = start_point + interval_val * torch.arange(
            batch_size, dtype=torch.long
        )
        return timesteps
    if config.random_timestep_per_iteration:
        # Independent random timestep for each image in the batch.
        return torch.randint(0, total_timestep, (batch_size,)).long()
    # Fixed mode: every image in the batch uses the caller-supplied step.
    return torch.tensor([step] * batch_size).long()