Overwrite with converted Qwen2.5-3B model files (mirror delete: 11 files)
utils.py
DELETED
@@ -1,157 +0,0 @@
-"""Utility functions."""
-import importlib
-import random
-
-import torch
-import numpy as np
-from PIL import Image
-
-
-class UnNormalize(object):
-    """Unnormalize an image as: image = (image * std) + mean."""
-
-    def __init__(self, mean, std):
-        self.mean = torch.tensor(mean)
-        self.std = torch.tensor(std)
-
-    def __call__(self, tensor):
-        """
-        Args:
-            tensor: A tensor of shape [C, H, W] or [N, C, H, W]
-
-        Returns:
-            tensor: A tensor of shape [C, H, W] or [N, C, H, W]
-        """
-        std = self.std.to(tensor.device)
-        mean = self.mean.to(tensor.device)
-        if tensor.ndim == 3:
-            std, mean = std.view(-1, 1, 1), mean.view(-1, 1, 1)
-        elif tensor.ndim == 4:
-            std, mean = std.view(1, -1, 1, 1), mean.view(1, -1, 1, 1)
-        tensor = (tensor * std) + mean
-        return tensor
-
-
-class VQVAEUnNormalize(UnNormalize):
-    """Unnormalize an image as:
-    First: image = (image * std) + mean
-    Second: image = (image * 2) - 1
-    """
-
-    def __call__(self, tensor):
-        """
-        Args:
-            tensor (Tensor): Tensor image of size (C, H, W) or (N, C, H, W)
-                to be unnormalized.
-        Returns:
-            Tensor: Unnormalized image.
-        """
-        tensor = super().__call__(tensor)
-        tensor = 2 * tensor - 1
-        return tensor
-
-
-def normalize(image, rescale=True):
-    """Map an image to [-1, 1], optionally rescaling uint8 input to [0, 1] first."""
-    if rescale:
-        image = image.float() / 255.0  # convert to float and rescale to [0, 1]
-    normalize_image = 2 * image - 1  # normalize to [-1, 1]
-    return normalize_image
-
-
-# train_transforms = transforms.Compose(
-#     [
-#         transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
-#         transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
-#         transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
-#         transforms.ToTensor(),
-#         transforms.Normalize([0.5], [0.5]),
-#     ]
-# )
-
-
-def mean_list(l):
-    """Return the mean of a list of values after casting each to int."""
-    l = [int(_l) for _l in l]
-    return float(sum(l)) / len(l)
-
-
-def segment_mean(x, index):
-    """PyTorch counterpart of tf.segment_mean: mean of the rows of x grouped by segment index."""
-    x = x.view(-1, x.shape[-1])
-    index = index.view(-1)
-
-    max_index = index.max() + 1
-    sum_x = torch.zeros((max_index, x.shape[-1]),
-                        dtype=x.dtype,
-                        device=x.device)
-    num_index = torch.zeros((max_index,),
-                            dtype=x.dtype,
-                            device=x.device)
-
-    num_index = num_index.scatter_add_(
-        0, index, torch.ones_like(index, dtype=x.dtype))
-    num_index = torch.where(torch.eq(num_index, 0),
-                            torch.ones_like(num_index, dtype=x.dtype),
-                            num_index)  # avoid division by zero for empty segments
-
-    index_2d = index.view(-1, 1).expand(-1, x.shape[-1])
-    sum_x = sum_x.scatter_add_(0, index_2d, x)
-    mean_x = sum_x.div_(num_index.view(-1, 1))
-
-    return mean_x
-
-
-def initiate_time_steps(step, total_timestep, batch_size, config):
-    """A helper function to initiate time steps for the diffusion model.
-
-    Args:
-        step: An integer of the constant step
-        total_timestep: An integer of the total timesteps of the diffusion model
-        batch_size: An integer of the batch size
-        config: A config object
-
-    Returns:
-        timesteps: A tensor of shape [batch_size,] of the time steps
-    """
-    if config.tta.rand_timestep_equal_int:
-        # evenly spaced timesteps across the batch, shifted by a shared random offset
-        interval_val = total_timestep // batch_size
-        start_point = random.randint(0, interval_val - 1)
-        timesteps = torch.tensor(
-            list(range(start_point, total_timestep, interval_val))
-        ).long()
-        return timesteps
-    elif config.tta.random_timestep_per_iteration:
-        # an independent random timestep for each image in the batch (default)
-        return torch.randint(0, total_timestep, (batch_size,)).long()
-    else:
-        # the same fixed timestep `step` for every image in the batch
-        return torch.tensor([step] * batch_size).long()
-
-
-def instantiate_from_config(config):
-    """A helper function to instantiate a class from a config object.
-    See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py
-    """
-    if "target" not in config:
-        if config == '__is_first_stage__':
-            return None
-        elif config == "__is_unconditional__":
-            return None
-        raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
-def get_obj_from_str(string, reload=False):
-    """Resolve a dotted import path such as "package.module.Class" to the object it names.
-    See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py
-    """
-    module, cls = string.rsplit(".", 1)
-    if reload:
-        module_imp = importlib.import_module(module)
-        importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
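Note for reviewers: a minimal sketch of how the deleted unnormalization helpers were meant to be used. The mean/std values below are the common ImageNet statistics, chosen purely for illustration; they do not come from this repo.

import torch

unnorm = VQVAEUnNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
batch = torch.rand(4, 3, 224, 224)   # stand-in for a normalized [N, C, H, W] batch
images = unnorm(batch)               # undoes (x - mean) / std, then rescales to [-1, 1]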
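Similarly, segment_mean averaged the rows of x that share a segment id; a small worked example with illustrative values:

import torch

x = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
index = torch.tensor([0, 0, 1])      # rows 0 and 1 form segment 0, row 2 forms segment 1
segment_mean(x, index)               # tensor([[2., 3.], [5., 6.]])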
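Finally, instantiate_from_config followed the stable-diffusion convention of a {"target": ..., "params": ...} mapping; the target below is a stdlib class picked purely for illustration, not a class from this repo:

config = {"target": "collections.OrderedDict", "params": {}}
obj = instantiate_from_config(config)   # equivalent to collections.OrderedDict()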