lambertxiao commited on
Commit
094ba04
·
verified ·
1 Parent(s): 4027c26

Overwrite with converted Qwen2.5-3B model files (mirror delete: 11 files)

Browse files
Files changed (1) hide show
  1. utils.py +0 -157
utils.py DELETED
@@ -1,157 +0,0 @@
1
- """Utility functions"""
2
- import importlib
3
- import random
4
-
5
- import torch
6
- import numpy as np
7
- from PIL import Image
8
-
9
-
10
-
11
- class UnNormalize(object):
12
- """Unformalize image as: image = (image * std) + mean
13
- """
14
- def __init__(self, mean, std):
15
- self.mean = torch.tensor(mean)
16
- self.std = torch.tensor(std)
17
-
18
- def __call__(self, tensor):
19
- """
20
- Args:
21
- tensor: A tensor of shape [C, H, W] or [N, C, H, W]
22
-
23
- Returns:
24
- tensor: A tensor of shape [C, H, W] or [N, C, H, W]
25
- """
26
-
27
- std = self.std.to(tensor.device)
28
- mean = self.mean.to(tensor.device)
29
- if tensor.ndim == 3:
30
- std, mean = std.view(-1, 1, 1), mean.view(-1, 1, 1)
31
- elif tensor.ndim == 4:
32
- std, mean = std.view(1, -1, 1, 1), mean.view(1, -1, 1, 1)
33
- tensor = (tensor * std) + mean
34
- return tensor
35
-
36
-
37
- class VQVAEUnNormalize(UnNormalize):
38
- """Unformalize image as:
39
- First: image = (image * std) + mean
40
- Second: image = (image * 2) - 1
41
- """
42
- def __call__(self, tensor):
43
- """
44
- Args:
45
- tensor (Tensor): Tensor image of size (C, H, W) or (N, C, H, W)
46
- to be unnormalized.
47
- Returns:
48
- Tensor: UnNormalized image.
49
- """
50
- tensor = super().__call__(tensor)
51
- tensor = 2 * tensor - 1
52
- return tensor
53
-
54
- def normalize(image,rescale=True):
55
-
56
- if rescale:
57
- image = image.float() / 255.0 # Convert to float and rescale to [0, 1]
58
- normalize_image = 2*image-1 # normalize to [-1, 1]
59
-
60
- return normalize_image
61
-
62
- # train_transforms = transforms.Compose(
63
- # [
64
- # transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
65
- # transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
66
- # transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
67
- # transforms.ToTensor(),
68
- # transforms.Normalize([0.5], [0.5]),
69
- # ]
70
- # )
71
-
72
-
73
- def mean_list(l):
74
- l = [int(_l) for _l in l]
75
- return float(sum(l)) / len(l)
76
-
77
-
78
- def segment_mean(x, index):
79
- """Function as tf.segment_mean.
80
- """
81
- x = x.view(-1, x.shape[-1])
82
- index = index.view(-1)
83
-
84
- max_index = index.max() + 1
85
- sum_x = torch.zeros((max_index, x.shape[-1]),
86
- dtype=x.dtype,
87
- device=x.device)
88
- num_index = torch.zeros((max_index,),
89
- dtype=x.dtype,
90
- device=x.device)
91
-
92
- num_index = num_index.scatter_add_(
93
- 0, index, torch.ones_like(index, dtype=x.dtype))
94
- num_index = torch.where(torch.eq(num_index, 0),
95
- torch.ones_like(num_index, dtype=x.dtype),
96
- num_index)
97
-
98
- index_2d = index.view(-1, 1).expand(-1, x.shape[-1])
99
- sum_x = sum_x.scatter_add_(0, index_2d, x)
100
- mean_x = sum_x.div_(num_index.view(-1, 1))
101
-
102
- return mean_x
103
-
104
-
105
-
106
-
107
-
108
- def initiate_time_steps(step, total_timestep, batch_size, config):
109
- """A helper function to initiate time steps for the diffusion model.
110
-
111
- Args:
112
- step: An integer of the constant step
113
- total_timestep: An integer of the total timesteps of the diffusion model
114
- batch_size: An integer of the batch size
115
- config: A config object
116
-
117
- Returns:
118
- timesteps: A tensor of shape [batch_size,] of the time steps
119
- """
120
- if config.tta.rand_timestep_equal_int:
121
- # the same timestep for each image in the batch
122
- interval_val = total_timestep // batch_size
123
- start_point = random.randint(0, interval_val - 1)
124
- timesteps = torch.tensor(
125
- list(range(start_point, total_timestep, interval_val))
126
- ).long()
127
- return timesteps
128
- elif config.tta.random_timestep_per_iteration:
129
- # random timestep for each image in the batch
130
- return torch.randint(0, total_timestep, (batch_size,)).long() #default
131
- else:
132
- # why we need to do this?
133
- return torch.tensor([step] * batch_size).long()
134
-
135
-
136
- def instantiate_from_config(config):
137
- """A helper function to instantiate a class from a config object.
138
- See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py
139
- """
140
- if not "target" in config:
141
- if config == '__is_first_stage__':
142
- return None
143
- elif config == "__is_unconditional__":
144
- return None
145
- raise KeyError("Expected key `target` to instantiate.")
146
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
147
-
148
-
149
- def get_obj_from_str(string, reload=False):
150
- """A helper function to instantiate a class from a config object.
151
- See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py
152
- """
153
- module, cls = string.rsplit(".", 1)
154
- if reload:
155
- module_imp = importlib.import_module(module)
156
- importlib.reload(module_imp)
157
- return getattr(importlib.import_module(module, package=None), cls)