|
"""Utility functions""" |
|
import importlib |
|
import random |
|
import re |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
|
|
|
|
def normalize(image, rescale=True):
    """Apply the affine map ``2 * x - 1`` to an image.

    Args:
        image: The image values to transform. When ``rescale`` is true the
            object must provide a ``.float()`` method (presumably a torch
            tensor -- confirm against callers).
        rescale: If true, first cast to float and divide by 255.0, so that
            an input of 255 maps to 1 and an input of 0 maps to -1.

    Returns:
        The transformed image, ``2 * x - 1``, where ``x`` is the (optionally
        rescaled) input.
    """
    # Rescaling is optional so callers can pass values already divided by 255.
    scaled = image.float() / 255.0 if rescale else image
    return 2 * scaled - 1
|
|
|
|
|
|
|
def process_caption(caption):
    """Trim an unfinished trailing sentence and drop repeated sentences.

    A sentence is considered finished when it ends with ``.``, ``!`` or
    ``?`` -- the same terminators the sentence splitter uses -- so a caption
    whose last complete sentence ends in ``!`` or ``?`` is not cut back to an
    earlier period.

    Args:
        caption: A string containing the caption text

    Returns:
        processed_caption: A string with the trailing fragment removed (when
            at least one terminator exists) and each distinct sentence kept
            once, in order of first appearance, joined by single spaces.
    """
    terminators = ('.', '!', '?')

    if not caption.endswith(terminators):
        # Cut everything after the last finished sentence. When the caption
        # contains no terminator at all, it is left untouched.
        last_end = max(caption.rfind(mark) for mark in terminators)
        if last_end != -1:
            caption = caption[:last_end + 1]

    # Split on whitespace that follows a sentence terminator.
    sentences = re.split(r'(?<=[.!?])\s+', caption)

    # dict preserves insertion order, giving an order-stable de-duplication
    # without a linear scan of the kept sentences for every candidate.
    unique_sentences = dict.fromkeys(s for s in sentences if s)

    return ' '.join(unique_sentences)
|
|
|
|
|
def initiate_time_steps(step, total_timestep, batch_size, config):
    """A helper function to initiate time steps for the diffusion model.

    Three modes are selected from ``config``, checked in this order:

    * ``config.rand_timestep_equal_int``: ``batch_size`` evenly spaced time
      steps with spacing ``total_timestep // batch_size`` and a random
      offset in ``[0, spacing - 1]``.
    * ``config.random_timestep_per_iteration``: ``batch_size`` independent
      uniform draws from ``[0, total_timestep)``.
    * otherwise: the constant ``step`` repeated ``batch_size`` times.

    Args:
        step: An integer of the constant step
        total_timestep: An integer of the total timesteps of the diffusion model
        batch_size: An integer of the batch size
        config: A config object exposing ``rand_timestep_equal_int`` and
            ``random_timestep_per_iteration``

    Returns:
        timesteps: A long tensor of shape [batch_size,] of the time steps

    Raises:
        ValueError: In the evenly spaced mode, if ``batch_size`` exceeds
            ``total_timestep`` (the spacing would be smaller than 1).
    """
    if config.rand_timestep_equal_int:
        interval_val = total_timestep // batch_size
        if interval_val < 1:
            raise ValueError(
                "batch_size must be between 1 and total_timestep for "
                f"evenly spaced time steps, got batch_size={batch_size}, "
                f"total_timestep={total_timestep}"
            )
        start_point = random.randint(0, interval_val - 1)
        # Emit exactly batch_size steps. Stepping a range up to
        # total_timestep over-produces whenever total_timestep is not a
        # multiple of batch_size. The largest value here is
        # interval_val * batch_size - 1, which is always < total_timestep.
        return start_point + interval_val * torch.arange(
            batch_size, dtype=torch.long
        )

    if config.random_timestep_per_iteration:
        return torch.randint(0, total_timestep, (batch_size,)).long()

    return torch.tensor([step] * batch_size).long()