vidfom commited on
Commit
20c58a2
·
verified ·
1 Parent(s): 52a2105

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. __pycache__/folder_paths.cpython-311.pyc +0 -0
  3. __pycache__/latent_preview.cpython-311.pyc +0 -0
  4. __pycache__/node_helpers.cpython-311.pyc +0 -0
  5. __pycache__/nodes.cpython-311.pyc +3 -0
  6. folder_paths.py +270 -0
  7. latent_preview.py +94 -0
  8. models/clip/clip_l.safetensors +3 -0
  9. models/clip/t5xxl_fp8_e4m3fn.safetensors +3 -0
  10. models/unet/flux1-schnell.safetensors +3 -0
  11. models/vae/ae.sft +3 -0
  12. node_helpers.py +37 -0
  13. nodes.py +2073 -0
  14. totoro/__pycache__/checkpoint_pickle.cpython-311.pyc +0 -0
  15. totoro/__pycache__/cli_args.cpython-311.pyc +0 -0
  16. totoro/__pycache__/clip_model.cpython-311.pyc +0 -0
  17. totoro/__pycache__/clip_vision.cpython-311.pyc +0 -0
  18. totoro/__pycache__/conds.cpython-311.pyc +0 -0
  19. totoro/__pycache__/controlnet.cpython-311.pyc +0 -0
  20. totoro/__pycache__/diffusers_convert.cpython-311.pyc +0 -0
  21. totoro/__pycache__/diffusers_load.cpython-311.pyc +0 -0
  22. totoro/__pycache__/gligen.cpython-311.pyc +0 -0
  23. totoro/__pycache__/latent_formats.cpython-311.pyc +0 -0
  24. totoro/__pycache__/lora.cpython-311.pyc +0 -0
  25. totoro/__pycache__/model_base.cpython-311.pyc +0 -0
  26. totoro/__pycache__/model_detection.cpython-311.pyc +0 -0
  27. totoro/__pycache__/model_management.cpython-311.pyc +0 -0
  28. totoro/__pycache__/model_patcher.cpython-311.pyc +0 -0
  29. totoro/__pycache__/model_sampling.cpython-311.pyc +0 -0
  30. totoro/__pycache__/ops.cpython-311.pyc +0 -0
  31. totoro/__pycache__/options.cpython-311.pyc +0 -0
  32. totoro/__pycache__/sample.cpython-311.pyc +0 -0
  33. totoro/__pycache__/sampler_helpers.cpython-311.pyc +0 -0
  34. totoro/__pycache__/samplers.cpython-311.pyc +0 -0
  35. totoro/__pycache__/sd.cpython-311.pyc +0 -0
  36. totoro/__pycache__/sd1_clip.cpython-311.pyc +0 -0
  37. totoro/__pycache__/sdxl_clip.cpython-311.pyc +0 -0
  38. totoro/__pycache__/supported_models.cpython-311.pyc +0 -0
  39. totoro/__pycache__/supported_models_base.cpython-311.pyc +0 -0
  40. totoro/__pycache__/types.cpython-311.pyc +0 -0
  41. totoro/__pycache__/utils.cpython-311.pyc +0 -0
  42. totoro/checkpoint_pickle.py +13 -0
  43. totoro/cldm/__pycache__/cldm.cpython-311.pyc +0 -0
  44. totoro/cldm/__pycache__/control_types.cpython-311.pyc +0 -0
  45. totoro/cldm/__pycache__/mmdit.cpython-311.pyc +0 -0
  46. totoro/cldm/cldm.py +437 -0
  47. totoro/cldm/control_types.py +10 -0
  48. totoro/cldm/mmdit.py +77 -0
  49. totoro/cli_args.py +180 -0
  50. totoro/clip_config_bigg.json +23 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ __pycache__/nodes.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
37
+ models/vae/ae.sft filter=lfs diff=lfs merge=lfs -text
__pycache__/folder_paths.cpython-311.pyc ADDED
Binary file (17 kB). View file
 
__pycache__/latent_preview.cpython-311.pyc ADDED
Binary file (6.52 kB). View file
 
__pycache__/node_helpers.cpython-311.pyc ADDED
Binary file (1.76 kB). View file
 
__pycache__/nodes.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ede1805c76e641f174da26d129150e2a67be482e28bc0aa248fa606e69eb616
3
+ size 115175
folder_paths.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import logging
4
+ from typing import Set, List, Dict, Tuple
5
+
6
+ supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
7
+
8
+ SupportedFileExtensionsType = Set[str]
9
+ ScanPathType = List[str]
10
+ folder_names_and_paths: Dict[str, Tuple[ScanPathType, SupportedFileExtensionsType]] = {}
11
+
12
+ base_path = os.path.dirname(os.path.realpath(__file__))
13
+ models_dir = os.path.join(base_path, "models")
14
+ folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
15
+ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
16
+
17
+ folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
18
+ folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
19
+ folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
20
+ folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
21
+ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
22
+ folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
23
+ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
24
+ folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
25
+ folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions)
26
+
27
+ folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
28
+ folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
29
+
30
+ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
31
+
32
+ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], set())
33
+
34
+ folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
35
+
36
+ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
37
+
38
+ folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
39
+
40
+ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
41
+ temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
42
+ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
43
+ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")
44
+
45
+ filename_list_cache = {}
46
+
47
+ if not os.path.exists(input_directory):
48
+ try:
49
+ os.makedirs(input_directory)
50
+ except:
51
+ logging.error("Failed to create input directory")
52
+
53
+ def set_output_directory(output_dir):
54
+ global output_directory
55
+ output_directory = output_dir
56
+
57
+ def set_temp_directory(temp_dir):
58
+ global temp_directory
59
+ temp_directory = temp_dir
60
+
61
+ def set_input_directory(input_dir):
62
+ global input_directory
63
+ input_directory = input_dir
64
+
65
+ def get_output_directory():
66
+ global output_directory
67
+ return output_directory
68
+
69
+ def get_temp_directory():
70
+ global temp_directory
71
+ return temp_directory
72
+
73
+ def get_input_directory():
74
+ global input_directory
75
+ return input_directory
76
+
77
+
78
+ #NOTE: used in http server so don't put folders that should not be accessed remotely
79
+ def get_directory_by_type(type_name):
80
+ if type_name == "output":
81
+ return get_output_directory()
82
+ if type_name == "temp":
83
+ return get_temp_directory()
84
+ if type_name == "input":
85
+ return get_input_directory()
86
+ return None
87
+
88
+
89
+ # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
90
+ # otherwise use default_path as base_dir
91
+ def annotated_filepath(name):
92
+ if name.endswith("[output]"):
93
+ base_dir = get_output_directory()
94
+ name = name[:-9]
95
+ elif name.endswith("[input]"):
96
+ base_dir = get_input_directory()
97
+ name = name[:-8]
98
+ elif name.endswith("[temp]"):
99
+ base_dir = get_temp_directory()
100
+ name = name[:-7]
101
+ else:
102
+ return name, None
103
+
104
+ return name, base_dir
105
+
106
+
107
+ def get_annotated_filepath(name, default_dir=None):
108
+ name, base_dir = annotated_filepath(name)
109
+
110
+ if base_dir is None:
111
+ if default_dir is not None:
112
+ base_dir = default_dir
113
+ else:
114
+ base_dir = get_input_directory() # fallback path
115
+
116
+ return os.path.join(base_dir, name)
117
+
118
+
119
+ def exists_annotated_filepath(name):
120
+ name, base_dir = annotated_filepath(name)
121
+
122
+ if base_dir is None:
123
+ base_dir = get_input_directory() # fallback path
124
+
125
+ filepath = os.path.join(base_dir, name)
126
+ return os.path.exists(filepath)
127
+
128
+
129
+ def add_model_folder_path(folder_name, full_folder_path):
130
+ global folder_names_and_paths
131
+ if folder_name in folder_names_and_paths:
132
+ folder_names_and_paths[folder_name][0].append(full_folder_path)
133
+ else:
134
+ folder_names_and_paths[folder_name] = ([full_folder_path], set())
135
+
136
+ def get_folder_paths(folder_name):
137
+ return folder_names_and_paths[folder_name][0][:]
138
+
139
+ def recursive_search(directory, excluded_dir_names=None):
140
+ if not os.path.isdir(directory):
141
+ return [], {}
142
+
143
+ if excluded_dir_names is None:
144
+ excluded_dir_names = []
145
+
146
+ result = []
147
+ dirs = {}
148
+
149
+ # Attempt to add the initial directory to dirs with error handling
150
+ try:
151
+ dirs[directory] = os.path.getmtime(directory)
152
+ except FileNotFoundError:
153
+ logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")
154
+
155
+ logging.debug("recursive file list on directory {}".format(directory))
156
+ for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
157
+ subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
158
+ for file_name in filenames:
159
+ relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
160
+ result.append(relative_path)
161
+
162
+ for d in subdirs:
163
+ path = os.path.join(dirpath, d)
164
+ try:
165
+ dirs[path] = os.path.getmtime(path)
166
+ except FileNotFoundError:
167
+ logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
168
+ continue
169
+ logging.debug("found {} files".format(len(result)))
170
+ return result, dirs
171
+
172
+ def filter_files_extensions(files, extensions):
173
+ return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
174
+
175
+
176
+
177
+ def get_full_path(folder_name, filename):
178
+ global folder_names_and_paths
179
+ if folder_name not in folder_names_and_paths:
180
+ return None
181
+ folders = folder_names_and_paths[folder_name]
182
+ filename = os.path.relpath(os.path.join("/", filename), "/")
183
+ for x in folders[0]:
184
+ full_path = os.path.join(x, filename)
185
+ if os.path.isfile(full_path):
186
+ return full_path
187
+ elif os.path.islink(full_path):
188
+ logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
189
+
190
+ return None
191
+
192
+ def get_filename_list_(folder_name):
193
+ global folder_names_and_paths
194
+ output_list = set()
195
+ folders = folder_names_and_paths[folder_name]
196
+ output_folders = {}
197
+ for x in folders[0]:
198
+ files, folders_all = recursive_search(x, excluded_dir_names=[".git"])
199
+ output_list.update(filter_files_extensions(files, folders[1]))
200
+ output_folders = {**output_folders, **folders_all}
201
+
202
+ return (sorted(list(output_list)), output_folders, time.perf_counter())
203
+
204
+ def cached_filename_list_(folder_name):
205
+ global filename_list_cache
206
+ global folder_names_and_paths
207
+ if folder_name not in filename_list_cache:
208
+ return None
209
+ out = filename_list_cache[folder_name]
210
+
211
+ for x in out[1]:
212
+ time_modified = out[1][x]
213
+ folder = x
214
+ if os.path.getmtime(folder) != time_modified:
215
+ return None
216
+
217
+ folders = folder_names_and_paths[folder_name]
218
+ for x in folders[0]:
219
+ if os.path.isdir(x):
220
+ if x not in out[1]:
221
+ return None
222
+
223
+ return out
224
+
225
+ def get_filename_list(folder_name):
226
+ out = cached_filename_list_(folder_name)
227
+ if out is None:
228
+ out = get_filename_list_(folder_name)
229
+ global filename_list_cache
230
+ filename_list_cache[folder_name] = out
231
+ return list(out[0])
232
+
233
+ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
234
+ def map_filename(filename):
235
+ prefix_len = len(os.path.basename(filename_prefix))
236
+ prefix = filename[:prefix_len + 1]
237
+ try:
238
+ digits = int(filename[prefix_len + 1:].split('_')[0])
239
+ except:
240
+ digits = 0
241
+ return (digits, prefix)
242
+
243
+ def compute_vars(input, image_width, image_height):
244
+ input = input.replace("%width%", str(image_width))
245
+ input = input.replace("%height%", str(image_height))
246
+ return input
247
+
248
+ filename_prefix = compute_vars(filename_prefix, image_width, image_height)
249
+
250
+ subfolder = os.path.dirname(os.path.normpath(filename_prefix))
251
+ filename = os.path.basename(os.path.normpath(filename_prefix))
252
+
253
+ full_output_folder = os.path.join(output_dir, subfolder)
254
+
255
+ if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
256
+ err = "**** ERROR: Saving image outside the output folder is not allowed." + \
257
+ "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
258
+ "\n output_dir: " + output_dir + \
259
+ "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
260
+ logging.error(err)
261
+ raise Exception(err)
262
+
263
+ try:
264
+ counter = max(filter(lambda a: os.path.normcase(a[1][:-1]) == os.path.normcase(filename) and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
265
+ except ValueError:
266
+ counter = 1
267
+ except FileNotFoundError:
268
+ os.makedirs(full_output_folder, exist_ok=True)
269
+ counter = 1
270
+ return full_output_folder, filename, counter, subfolder, filename_prefix
latent_preview.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import struct
4
+ import numpy as np
5
+ from totoro.cli_args import args, LatentPreviewMethod
6
+ from totoro.taesd.taesd import TAESD
7
+ import totoro.model_management
8
+ import folder_paths
9
+ import totoro.utils
10
+ import logging
11
+
12
+ MAX_PREVIEW_RESOLUTION = 512
13
+
14
+ def preview_to_image(latent_image):
15
+ latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
16
+ .mul(0xFF) # to 0..255
17
+ ).to(device="cpu", dtype=torch.uint8, non_blocking=totoro.model_management.device_supports_non_blocking(latent_image.device))
18
+
19
+ return Image.fromarray(latents_ubyte.numpy())
20
+
21
+ class LatentPreviewer:
22
+ def decode_latent_to_preview(self, x0):
23
+ pass
24
+
25
+ def decode_latent_to_preview_image(self, preview_format, x0):
26
+ preview_image = self.decode_latent_to_preview(x0)
27
+ return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
28
+
29
+ class TAESDPreviewerImpl(LatentPreviewer):
30
+ def __init__(self, taesd):
31
+ self.taesd = taesd
32
+
33
+ def decode_latent_to_preview(self, x0):
34
+ x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
35
+ return preview_to_image(x_sample)
36
+
37
+
38
+ class Latent2RGBPreviewer(LatentPreviewer):
39
+ def __init__(self, latent_rgb_factors):
40
+ self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
41
+
42
+ def decode_latent_to_preview(self, x0):
43
+ self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
44
+ latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
45
+ return preview_to_image(latent_image)
46
+
47
+
48
+ def get_previewer(device, latent_format):
49
+ previewer = None
50
+ method = args.preview_method
51
+ if method != LatentPreviewMethod.NoPreviews:
52
+ # TODO previewer methods
53
+ taesd_decoder_path = None
54
+ if latent_format.taesd_decoder_name is not None:
55
+ taesd_decoder_path = next(
56
+ (fn for fn in folder_paths.get_filename_list("vae_approx")
57
+ if fn.startswith(latent_format.taesd_decoder_name)),
58
+ ""
59
+ )
60
+ taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
61
+
62
+ if method == LatentPreviewMethod.Auto:
63
+ method = LatentPreviewMethod.Latent2RGB
64
+
65
+ if method == LatentPreviewMethod.TAESD:
66
+ if taesd_decoder_path:
67
+ taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
68
+ previewer = TAESDPreviewerImpl(taesd)
69
+ else:
70
+ logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
71
+
72
+ if previewer is None:
73
+ if latent_format.latent_rgb_factors is not None:
74
+ previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
75
+ return previewer
76
+
77
+ def prepare_callback(model, steps, x0_output_dict=None):
78
+ preview_format = "JPEG"
79
+ if preview_format not in ["JPEG", "PNG"]:
80
+ preview_format = "JPEG"
81
+
82
+ previewer = get_previewer(model.load_device, model.model.latent_format)
83
+
84
+ pbar = totoro.utils.ProgressBar(steps)
85
+ def callback(step, x0, x, total_steps):
86
+ if x0_output_dict is not None:
87
+ x0_output_dict["x0"] = x0
88
+
89
+ preview_bytes = None
90
+ if previewer:
91
+ preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
92
+ pbar.update_absolute(step + 1, total_steps, preview_bytes)
93
+ return callback
94
+
models/clip/clip_l.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:660c6f5b1abae9dc498ac2d21e1347d2abdb0cf6c0c0c8576cd796491d9a6cdd
3
+ size 246144152
models/clip/t5xxl_fp8_e4m3fn.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d330da4816157540d6bb7838bf63a0f02f573fc48ca4d8de34bb0cbfd514f09
3
+ size 4893934904
models/unet/flux1-schnell.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9403429e0052277ac2a87ad800adece5481eecefd9ed334e1f348723621d2a0a
3
+ size 23782506688
models/vae/ae.sft ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:afc8e28272cd15db3919bacdb6918ce9c1ed22e96cb12c4d5ed0fba823529e38
3
+ size 335304388
node_helpers.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+
3
+ from totoro.cli_args import args
4
+
5
+ from PIL import ImageFile, UnidentifiedImageError
6
+
7
+ def conditioning_set_values(conditioning, values={}):
8
+ c = []
9
+ for t in conditioning:
10
+ n = [t[0], t[1].copy()]
11
+ for k in values:
12
+ n[1][k] = values[k]
13
+ c.append(n)
14
+
15
+ return c
16
+
17
+ def pillow(fn, arg):
18
+ prev_value = None
19
+ try:
20
+ x = fn(arg)
21
+ except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes totoroUI issue #3416
22
+ prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
23
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
24
+ x = fn(arg)
25
+ finally:
26
+ if prev_value is not None:
27
+ ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
28
+ return x
29
+
30
+ def hasher():
31
+ hashfuncs = {
32
+ "md5": hashlib.md5,
33
+ "sha1": hashlib.sha1,
34
+ "sha256": hashlib.sha256,
35
+ "sha512": hashlib.sha512
36
+ }
37
+ return hashfuncs[args.default_hashing_function]
nodes.py ADDED
@@ -0,0 +1,2073 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import os
4
+ import sys
5
+ import json
6
+ import hashlib
7
+ import traceback
8
+ import math
9
+ import time
10
+ import random
11
+ import logging
12
+
13
+ from PIL import Image, ImageOps, ImageSequence, ImageFile
14
+ from PIL.PngImagePlugin import PngInfo
15
+
16
+ import numpy as np
17
+ import safetensors.torch
18
+
19
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "totoro"))
20
+
21
+ import totoro.diffusers_load
22
+ import totoro.samplers
23
+ import totoro.sample
24
+ import totoro.sd
25
+ import totoro.utils
26
+ import totoro.controlnet
27
+
28
+ import totoro.clip_vision
29
+
30
+ import totoro.model_management
31
+ from totoro.cli_args import args
32
+
33
+ import importlib
34
+
35
+ import folder_paths
36
+ import latent_preview
37
+ import node_helpers
38
+
39
+ def before_node_execution():
40
+ totoro.model_management.throw_exception_if_processing_interrupted()
41
+
42
+ def interrupt_processing(value=True):
43
+ totoro.model_management.interrupt_current_processing(value)
44
+
45
+ MAX_RESOLUTION=16384
46
+
47
+ class CLIPTextEncode:
48
+ @classmethod
49
+ def INPUT_TYPES(s):
50
+ return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", )}}
51
+ RETURN_TYPES = ("CONDITIONING",)
52
+ FUNCTION = "encode"
53
+
54
+ CATEGORY = "conditioning"
55
+
56
+ def encode(self, clip, text):
57
+ tokens = clip.tokenize(text)
58
+ output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
59
+ cond = output.pop("cond")
60
+ return ([[cond, output]], )
61
+
62
+ class ConditioningCombine:
63
+ @classmethod
64
+ def INPUT_TYPES(s):
65
+ return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
66
+ RETURN_TYPES = ("CONDITIONING",)
67
+ FUNCTION = "combine"
68
+
69
+ CATEGORY = "conditioning"
70
+
71
+ def combine(self, conditioning_1, conditioning_2):
72
+ return (conditioning_1 + conditioning_2, )
73
+
74
+ class ConditioningAverage :
75
+ @classmethod
76
+ def INPUT_TYPES(s):
77
+ return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
78
+ "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
79
+ }}
80
+ RETURN_TYPES = ("CONDITIONING",)
81
+ FUNCTION = "addWeighted"
82
+
83
+ CATEGORY = "conditioning"
84
+
85
+ def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
86
+ out = []
87
+
88
+ if len(conditioning_from) > 1:
89
+ logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
90
+
91
+ cond_from = conditioning_from[0][0]
92
+ pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
93
+
94
+ for i in range(len(conditioning_to)):
95
+ t1 = conditioning_to[i][0]
96
+ pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
97
+ t0 = cond_from[:,:t1.shape[1]]
98
+ if t0.shape[1] < t1.shape[1]:
99
+ t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
100
+
101
+ tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
102
+ t_to = conditioning_to[i][1].copy()
103
+ if pooled_output_from is not None and pooled_output_to is not None:
104
+ t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
105
+ elif pooled_output_from is not None:
106
+ t_to["pooled_output"] = pooled_output_from
107
+
108
+ n = [tw, t_to]
109
+ out.append(n)
110
+ return (out, )
111
+
112
+ class ConditioningConcat:
113
+ @classmethod
114
+ def INPUT_TYPES(s):
115
+ return {"required": {
116
+ "conditioning_to": ("CONDITIONING",),
117
+ "conditioning_from": ("CONDITIONING",),
118
+ }}
119
+ RETURN_TYPES = ("CONDITIONING",)
120
+ FUNCTION = "concat"
121
+
122
+ CATEGORY = "conditioning"
123
+
124
+ def concat(self, conditioning_to, conditioning_from):
125
+ out = []
126
+
127
+ if len(conditioning_from) > 1:
128
+ logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
129
+
130
+ cond_from = conditioning_from[0][0]
131
+
132
+ for i in range(len(conditioning_to)):
133
+ t1 = conditioning_to[i][0]
134
+ tw = torch.cat((t1, cond_from),1)
135
+ n = [tw, conditioning_to[i][1].copy()]
136
+ out.append(n)
137
+
138
+ return (out, )
139
+
140
+ class ConditioningSetArea:
141
+ @classmethod
142
+ def INPUT_TYPES(s):
143
+ return {"required": {"conditioning": ("CONDITIONING", ),
144
+ "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
145
+ "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
146
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
147
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
148
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
149
+ }}
150
+ RETURN_TYPES = ("CONDITIONING",)
151
+ FUNCTION = "append"
152
+
153
+ CATEGORY = "conditioning"
154
+
155
+ def append(self, conditioning, width, height, x, y, strength):
156
+ c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8),
157
+ "strength": strength,
158
+ "set_area_to_bounds": False})
159
+ return (c, )
160
+
161
+ class ConditioningSetAreaPercentage:
162
+ @classmethod
163
+ def INPUT_TYPES(s):
164
+ return {"required": {"conditioning": ("CONDITIONING", ),
165
+ "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
166
+ "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
167
+ "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
168
+ "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
169
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
170
+ }}
171
+ RETURN_TYPES = ("CONDITIONING",)
172
+ FUNCTION = "append"
173
+
174
+ CATEGORY = "conditioning"
175
+
176
+ def append(self, conditioning, width, height, x, y, strength):
177
+ c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
178
+ "strength": strength,
179
+ "set_area_to_bounds": False})
180
+ return (c, )
181
+
182
+ class ConditioningSetAreaStrength:
183
+ @classmethod
184
+ def INPUT_TYPES(s):
185
+ return {"required": {"conditioning": ("CONDITIONING", ),
186
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
187
+ }}
188
+ RETURN_TYPES = ("CONDITIONING",)
189
+ FUNCTION = "append"
190
+
191
+ CATEGORY = "conditioning"
192
+
193
+ def append(self, conditioning, strength):
194
+ c = node_helpers.conditioning_set_values(conditioning, {"strength": strength})
195
+ return (c, )
196
+
197
+
198
+ class ConditioningSetMask:
199
+ @classmethod
200
+ def INPUT_TYPES(s):
201
+ return {"required": {"conditioning": ("CONDITIONING", ),
202
+ "mask": ("MASK", ),
203
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
204
+ "set_cond_area": (["default", "mask bounds"],),
205
+ }}
206
+ RETURN_TYPES = ("CONDITIONING",)
207
+ FUNCTION = "append"
208
+
209
+ CATEGORY = "conditioning"
210
+
211
+ def append(self, conditioning, mask, set_cond_area, strength):
212
+ set_area_to_bounds = False
213
+ if set_cond_area != "default":
214
+ set_area_to_bounds = True
215
+ if len(mask.shape) < 3:
216
+ mask = mask.unsqueeze(0)
217
+
218
+ c = node_helpers.conditioning_set_values(conditioning, {"mask": mask,
219
+ "set_area_to_bounds": set_area_to_bounds,
220
+ "mask_strength": strength})
221
+ return (c, )
222
+
223
+ class ConditioningZeroOut:
224
+ @classmethod
225
+ def INPUT_TYPES(s):
226
+ return {"required": {"conditioning": ("CONDITIONING", )}}
227
+ RETURN_TYPES = ("CONDITIONING",)
228
+ FUNCTION = "zero_out"
229
+
230
+ CATEGORY = "advanced/conditioning"
231
+
232
+ def zero_out(self, conditioning):
233
+ c = []
234
+ for t in conditioning:
235
+ d = t[1].copy()
236
+ pooled_output = d.get("pooled_output", None)
237
+ if pooled_output is not None:
238
+ d["pooled_output"] = torch.zeros_like(pooled_output)
239
+ n = [torch.zeros_like(t[0]), d]
240
+ c.append(n)
241
+ return (c, )
242
+
243
+ class ConditioningSetTimestepRange:
244
+ @classmethod
245
+ def INPUT_TYPES(s):
246
+ return {"required": {"conditioning": ("CONDITIONING", ),
247
+ "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
248
+ "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
249
+ }}
250
+ RETURN_TYPES = ("CONDITIONING",)
251
+ FUNCTION = "set_range"
252
+
253
+ CATEGORY = "advanced/conditioning"
254
+
255
+ def set_range(self, conditioning, start, end):
256
+ c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start,
257
+ "end_percent": end})
258
+ return (c, )
259
+
260
+ class VAEDecode:
261
+ @classmethod
262
+ def INPUT_TYPES(s):
263
+ return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
264
+ RETURN_TYPES = ("IMAGE",)
265
+ FUNCTION = "decode"
266
+
267
+ CATEGORY = "latent"
268
+
269
+ def decode(self, vae, samples):
270
+ return (vae.decode(samples["samples"]), )
271
+
272
+ class VAEDecodeTiled:
273
+ @classmethod
274
+ def INPUT_TYPES(s):
275
+ return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
276
+ "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
277
+ }}
278
+ RETURN_TYPES = ("IMAGE",)
279
+ FUNCTION = "decode"
280
+
281
+ CATEGORY = "_for_testing"
282
+
283
+ def decode(self, vae, samples, tile_size):
284
+ return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
285
+
286
+ class VAEEncode:
287
+ @classmethod
288
+ def INPUT_TYPES(s):
289
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
290
+ RETURN_TYPES = ("LATENT",)
291
+ FUNCTION = "encode"
292
+
293
+ CATEGORY = "latent"
294
+
295
+ def encode(self, vae, pixels):
296
+ t = vae.encode(pixels[:,:,:,:3])
297
+ return ({"samples":t}, )
298
+
299
+ class VAEEncodeTiled:
300
+ @classmethod
301
+ def INPUT_TYPES(s):
302
+ return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
303
+ "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
304
+ }}
305
+ RETURN_TYPES = ("LATENT",)
306
+ FUNCTION = "encode"
307
+
308
+ CATEGORY = "_for_testing"
309
+
310
+ def encode(self, vae, pixels, tile_size):
311
+ t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
312
+ return ({"samples":t}, )
313
+
314
+ class VAEEncodeForInpaint:
315
+ @classmethod
316
+ def INPUT_TYPES(s):
317
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
318
+ RETURN_TYPES = ("LATENT",)
319
+ FUNCTION = "encode"
320
+
321
+ CATEGORY = "latent/inpaint"
322
+
323
+ def encode(self, vae, pixels, mask, grow_mask_by=6):
324
+ x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
325
+ y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
326
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
327
+
328
+ pixels = pixels.clone()
329
+ if pixels.shape[1] != x or pixels.shape[2] != y:
330
+ x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
331
+ y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
332
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
333
+ mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
334
+
335
+ #grow mask by a few pixels to keep things seamless in latent space
336
+ if grow_mask_by == 0:
337
+ mask_erosion = mask
338
+ else:
339
+ kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
340
+ padding = math.ceil((grow_mask_by - 1) / 2)
341
+
342
+ mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
343
+
344
+ m = (1.0 - mask.round()).squeeze(1)
345
+ for i in range(3):
346
+ pixels[:,:,:,i] -= 0.5
347
+ pixels[:,:,:,i] *= m
348
+ pixels[:,:,:,i] += 0.5
349
+ t = vae.encode(pixels)
350
+
351
+ return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
352
+
353
+
354
+ class InpaintModelConditioning:
355
+ @classmethod
356
+ def INPUT_TYPES(s):
357
+ return {"required": {"positive": ("CONDITIONING", ),
358
+ "negative": ("CONDITIONING", ),
359
+ "vae": ("VAE", ),
360
+ "pixels": ("IMAGE", ),
361
+ "mask": ("MASK", ),
362
+ }}
363
+
364
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
365
+ RETURN_NAMES = ("positive", "negative", "latent")
366
+ FUNCTION = "encode"
367
+
368
+ CATEGORY = "conditioning/inpaint"
369
+
370
+ def encode(self, positive, negative, pixels, vae, mask):
371
+ x = (pixels.shape[1] // 8) * 8
372
+ y = (pixels.shape[2] // 8) * 8
373
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
374
+
375
+ orig_pixels = pixels
376
+ pixels = orig_pixels.clone()
377
+ if pixels.shape[1] != x or pixels.shape[2] != y:
378
+ x_offset = (pixels.shape[1] % 8) // 2
379
+ y_offset = (pixels.shape[2] % 8) // 2
380
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
381
+ mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
382
+
383
+ m = (1.0 - mask.round()).squeeze(1)
384
+ for i in range(3):
385
+ pixels[:,:,:,i] -= 0.5
386
+ pixels[:,:,:,i] *= m
387
+ pixels[:,:,:,i] += 0.5
388
+ concat_latent = vae.encode(pixels)
389
+ orig_latent = vae.encode(orig_pixels)
390
+
391
+ out_latent = {}
392
+
393
+ out_latent["samples"] = orig_latent
394
+ out_latent["noise_mask"] = mask
395
+
396
+ out = []
397
+ for conditioning in [positive, negative]:
398
+ c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
399
+ "concat_mask": mask})
400
+ out.append(c)
401
+ return (out[0], out[1], out_latent)
402
+
403
+
404
+ class SaveLatent:
405
+ def __init__(self):
406
+ self.output_dir = folder_paths.get_output_directory()
407
+
408
+ @classmethod
409
+ def INPUT_TYPES(s):
410
+ return {"required": { "samples": ("LATENT", ),
411
+ "filename_prefix": ("STRING", {"default": "latents/totoroUI"})},
412
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
413
+ }
414
+ RETURN_TYPES = ()
415
+ FUNCTION = "save"
416
+
417
+ OUTPUT_NODE = True
418
+
419
+ CATEGORY = "_for_testing"
420
+
421
+ def save(self, samples, filename_prefix="totoroUI", prompt=None, extra_pnginfo=None):
422
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
423
+
424
+ # support save metadata for latent sharing
425
+ prompt_info = ""
426
+ if prompt is not None:
427
+ prompt_info = json.dumps(prompt)
428
+
429
+ metadata = None
430
+ if not args.disable_metadata:
431
+ metadata = {"prompt": prompt_info}
432
+ if extra_pnginfo is not None:
433
+ for x in extra_pnginfo:
434
+ metadata[x] = json.dumps(extra_pnginfo[x])
435
+
436
+ file = f"{filename}_{counter:05}_.latent"
437
+
438
+ results = list()
439
+ results.append({
440
+ "filename": file,
441
+ "subfolder": subfolder,
442
+ "type": "output"
443
+ })
444
+
445
+ file = os.path.join(full_output_folder, file)
446
+
447
+ output = {}
448
+ output["latent_tensor"] = samples["samples"]
449
+ output["latent_format_version_0"] = torch.tensor([])
450
+
451
+ totoro.utils.save_torch_file(output, file, metadata=metadata)
452
+ return { "ui": { "latents": results } }
453
+
454
+
455
+ class LoadLatent:
456
+ @classmethod
457
+ def INPUT_TYPES(s):
458
+ input_dir = folder_paths.get_input_directory()
459
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
460
+ return {"required": {"latent": [sorted(files), ]}, }
461
+
462
+ CATEGORY = "_for_testing"
463
+
464
+ RETURN_TYPES = ("LATENT", )
465
+ FUNCTION = "load"
466
+
467
+ def load(self, latent):
468
+ latent_path = folder_paths.get_annotated_filepath(latent)
469
+ latent = safetensors.torch.load_file(latent_path, device="cpu")
470
+ multiplier = 1.0
471
+ if "latent_format_version_0" not in latent:
472
+ multiplier = 1.0 / 0.18215
473
+ samples = {"samples": latent["latent_tensor"].float() * multiplier}
474
+ return (samples, )
475
+
476
+ @classmethod
477
+ def IS_CHANGED(s, latent):
478
+ image_path = folder_paths.get_annotated_filepath(latent)
479
+ m = hashlib.sha256()
480
+ with open(image_path, 'rb') as f:
481
+ m.update(f.read())
482
+ return m.digest().hex()
483
+
484
+ @classmethod
485
+ def VALIDATE_INPUTS(s, latent):
486
+ if not folder_paths.exists_annotated_filepath(latent):
487
+ return "Invalid latent file: {}".format(latent)
488
+ return True
489
+
490
+
491
+ class CheckpointLoader:
492
+ @classmethod
493
+ def INPUT_TYPES(s):
494
+ return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
495
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}}
496
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
497
+ FUNCTION = "load_checkpoint"
498
+
499
+ CATEGORY = "advanced/loaders"
500
+
501
+ def load_checkpoint(self, config_name, ckpt_name):
502
+ config_path = folder_paths.get_full_path("configs", config_name)
503
+ ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
504
+ return totoro.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
505
+
506
+ class CheckpointLoaderSimple:
507
+ @classmethod
508
+ def INPUT_TYPES(s):
509
+ return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
510
+ }}
511
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
512
+ FUNCTION = "load_checkpoint"
513
+
514
+ CATEGORY = "loaders"
515
+
516
+ def load_checkpoint(self, ckpt_name):
517
+ ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
518
+ out = totoro.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
519
+ return out[:3]
520
+
521
+ class DiffusersLoader:
522
+ @classmethod
523
+ def INPUT_TYPES(cls):
524
+ paths = []
525
+ for search_path in folder_paths.get_folder_paths("diffusers"):
526
+ if os.path.exists(search_path):
527
+ for root, subdir, files in os.walk(search_path, followlinks=True):
528
+ if "model_index.json" in files:
529
+ paths.append(os.path.relpath(root, start=search_path))
530
+
531
+ return {"required": {"model_path": (paths,), }}
532
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
533
+ FUNCTION = "load_checkpoint"
534
+
535
+ CATEGORY = "advanced/loaders/deprecated"
536
+
537
+ def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
538
+ for search_path in folder_paths.get_folder_paths("diffusers"):
539
+ if os.path.exists(search_path):
540
+ path = os.path.join(search_path, model_path)
541
+ if os.path.exists(path):
542
+ model_path = path
543
+ break
544
+
545
+ return totoro.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
546
+
547
+
548
+ class unCLIPCheckpointLoader:
549
+ @classmethod
550
+ def INPUT_TYPES(s):
551
+ return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
552
+ }}
553
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
554
+ FUNCTION = "load_checkpoint"
555
+
556
+ CATEGORY = "loaders"
557
+
558
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
559
+ ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
560
+ out = totoro.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
561
+ return out
562
+
563
+ class CLIPSetLastLayer:
564
+ @classmethod
565
+ def INPUT_TYPES(s):
566
+ return {"required": { "clip": ("CLIP", ),
567
+ "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
568
+ }}
569
+ RETURN_TYPES = ("CLIP",)
570
+ FUNCTION = "set_last_layer"
571
+
572
+ CATEGORY = "conditioning"
573
+
574
+ def set_last_layer(self, clip, stop_at_clip_layer):
575
+ clip = clip.clone()
576
+ clip.clip_layer(stop_at_clip_layer)
577
+ return (clip,)
578
+
579
+ class LoraLoader:
580
+ def __init__(self):
581
+ self.loaded_lora = None
582
+
583
+ @classmethod
584
+ def INPUT_TYPES(s):
585
+ return {"required": { "model": ("MODEL",),
586
+ "clip": ("CLIP", ),
587
+ "lora_name": (folder_paths.get_filename_list("loras"), ),
588
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
589
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
590
+ }}
591
+ RETURN_TYPES = ("MODEL", "CLIP")
592
+ FUNCTION = "load_lora"
593
+
594
+ CATEGORY = "loaders"
595
+
596
+ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
597
+ if strength_model == 0 and strength_clip == 0:
598
+ return (model, clip)
599
+
600
+ lora_path = folder_paths.get_full_path("loras", lora_name)
601
+ lora = None
602
+ if self.loaded_lora is not None:
603
+ if self.loaded_lora[0] == lora_path:
604
+ lora = self.loaded_lora[1]
605
+ else:
606
+ temp = self.loaded_lora
607
+ self.loaded_lora = None
608
+ del temp
609
+
610
+ if lora is None:
611
+ lora = totoro.utils.load_torch_file(lora_path, safe_load=True)
612
+ self.loaded_lora = (lora_path, lora)
613
+
614
+ model_lora, clip_lora = totoro.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
615
+ return (model_lora, clip_lora)
616
+
617
+ class LoraLoaderModelOnly(LoraLoader):
618
+ @classmethod
619
+ def INPUT_TYPES(s):
620
+ return {"required": { "model": ("MODEL",),
621
+ "lora_name": (folder_paths.get_filename_list("loras"), ),
622
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
623
+ }}
624
+ RETURN_TYPES = ("MODEL",)
625
+ FUNCTION = "load_lora_model_only"
626
+
627
+ def load_lora_model_only(self, model, lora_name, strength_model):
628
+ return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
629
+
630
+ class VAELoader:
631
+ @staticmethod
632
+ def vae_list():
633
+ vaes = folder_paths.get_filename_list("vae")
634
+ approx_vaes = folder_paths.get_filename_list("vae_approx")
635
+ sdxl_taesd_enc = False
636
+ sdxl_taesd_dec = False
637
+ sd1_taesd_enc = False
638
+ sd1_taesd_dec = False
639
+ sd3_taesd_enc = False
640
+ sd3_taesd_dec = False
641
+
642
+ for v in approx_vaes:
643
+ if v.startswith("taesd_decoder."):
644
+ sd1_taesd_dec = True
645
+ elif v.startswith("taesd_encoder."):
646
+ sd1_taesd_enc = True
647
+ elif v.startswith("taesdxl_decoder."):
648
+ sdxl_taesd_dec = True
649
+ elif v.startswith("taesdxl_encoder."):
650
+ sdxl_taesd_enc = True
651
+ elif v.startswith("taesd3_decoder."):
652
+ sd3_taesd_dec = True
653
+ elif v.startswith("taesd3_encoder."):
654
+ sd3_taesd_enc = True
655
+ if sd1_taesd_dec and sd1_taesd_enc:
656
+ vaes.append("taesd")
657
+ if sdxl_taesd_dec and sdxl_taesd_enc:
658
+ vaes.append("taesdxl")
659
+ if sd3_taesd_dec and sd3_taesd_enc:
660
+ vaes.append("taesd3")
661
+ return vaes
662
+
663
+ @staticmethod
664
+ def load_taesd(name):
665
+ sd = {}
666
+ approx_vaes = folder_paths.get_filename_list("vae_approx")
667
+
668
+ encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
669
+ decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
670
+
671
+ enc = totoro.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
672
+ for k in enc:
673
+ sd["taesd_encoder.{}".format(k)] = enc[k]
674
+
675
+ dec = totoro.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
676
+ for k in dec:
677
+ sd["taesd_decoder.{}".format(k)] = dec[k]
678
+
679
+ if name == "taesd":
680
+ sd["vae_scale"] = torch.tensor(0.18215)
681
+ sd["vae_shift"] = torch.tensor(0.0)
682
+ elif name == "taesdxl":
683
+ sd["vae_scale"] = torch.tensor(0.13025)
684
+ sd["vae_shift"] = torch.tensor(0.0)
685
+ elif name == "taesd3":
686
+ sd["vae_scale"] = torch.tensor(1.5305)
687
+ sd["vae_shift"] = torch.tensor(0.0609)
688
+ return sd
689
+
690
+ @classmethod
691
+ def INPUT_TYPES(s):
692
+ return {"required": { "vae_name": (s.vae_list(), )}}
693
+ RETURN_TYPES = ("VAE",)
694
+ FUNCTION = "load_vae"
695
+
696
+ CATEGORY = "loaders"
697
+
698
+ #TODO: scale factor?
699
+ def load_vae(self, vae_name):
700
+ if vae_name in ["taesd", "taesdxl", "taesd3"]:
701
+ sd = self.load_taesd(vae_name)
702
+ else:
703
+ vae_path = folder_paths.get_full_path("vae", vae_name)
704
+ sd = totoro.utils.load_torch_file(vae_path)
705
+ vae = totoro.sd.VAE(sd=sd)
706
+ return (vae,)
707
+
708
+ class ControlNetLoader:
709
+ @classmethod
710
+ def INPUT_TYPES(s):
711
+ return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
712
+
713
+ RETURN_TYPES = ("CONTROL_NET",)
714
+ FUNCTION = "load_controlnet"
715
+
716
+ CATEGORY = "loaders"
717
+
718
+ def load_controlnet(self, control_net_name):
719
+ controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
720
+ controlnet = totoro.controlnet.load_controlnet(controlnet_path)
721
+ return (controlnet,)
722
+
723
+ class DiffControlNetLoader:
724
+ @classmethod
725
+ def INPUT_TYPES(s):
726
+ return {"required": { "model": ("MODEL",),
727
+ "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
728
+
729
+ RETURN_TYPES = ("CONTROL_NET",)
730
+ FUNCTION = "load_controlnet"
731
+
732
+ CATEGORY = "loaders"
733
+
734
+ def load_controlnet(self, model, control_net_name):
735
+ controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
736
+ controlnet = totoro.controlnet.load_controlnet(controlnet_path, model)
737
+ return (controlnet,)
738
+
739
+
740
+ class ControlNetApply:
741
+ @classmethod
742
+ def INPUT_TYPES(s):
743
+ return {"required": {"conditioning": ("CONDITIONING", ),
744
+ "control_net": ("CONTROL_NET", ),
745
+ "image": ("IMAGE", ),
746
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
747
+ }}
748
+ RETURN_TYPES = ("CONDITIONING",)
749
+ FUNCTION = "apply_controlnet"
750
+
751
+ CATEGORY = "conditioning/controlnet"
752
+
753
+ def apply_controlnet(self, conditioning, control_net, image, strength):
754
+ if strength == 0:
755
+ return (conditioning, )
756
+
757
+ c = []
758
+ control_hint = image.movedim(-1,1)
759
+ for t in conditioning:
760
+ n = [t[0], t[1].copy()]
761
+ c_net = control_net.copy().set_cond_hint(control_hint, strength)
762
+ if 'control' in t[1]:
763
+ c_net.set_previous_controlnet(t[1]['control'])
764
+ n[1]['control'] = c_net
765
+ n[1]['control_apply_to_uncond'] = True
766
+ c.append(n)
767
+ return (c, )
768
+
769
+
770
+ class ControlNetApplyAdvanced:
771
+ @classmethod
772
+ def INPUT_TYPES(s):
773
+ return {"required": {"positive": ("CONDITIONING", ),
774
+ "negative": ("CONDITIONING", ),
775
+ "control_net": ("CONTROL_NET", ),
776
+ "image": ("IMAGE", ),
777
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
778
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
779
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
780
+ }}
781
+
782
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING")
783
+ RETURN_NAMES = ("positive", "negative")
784
+ FUNCTION = "apply_controlnet"
785
+
786
+ CATEGORY = "conditioning/controlnet"
787
+
788
+ def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
789
+ if strength == 0:
790
+ return (positive, negative)
791
+
792
+ control_hint = image.movedim(-1,1)
793
+ cnets = {}
794
+
795
+ out = []
796
+ for conditioning in [positive, negative]:
797
+ c = []
798
+ for t in conditioning:
799
+ d = t[1].copy()
800
+
801
+ prev_cnet = d.get('control', None)
802
+ if prev_cnet in cnets:
803
+ c_net = cnets[prev_cnet]
804
+ else:
805
+ c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
806
+ c_net.set_previous_controlnet(prev_cnet)
807
+ cnets[prev_cnet] = c_net
808
+
809
+ d['control'] = c_net
810
+ d['control_apply_to_uncond'] = False
811
+ n = [t[0], d]
812
+ c.append(n)
813
+ out.append(c)
814
+ return (out[0], out[1])
815
+
816
+
817
+ class UNETLoader:
818
+ @classmethod
819
+ def INPUT_TYPES(s):
820
+ return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
821
+ "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
822
+ }}
823
+ RETURN_TYPES = ("MODEL",)
824
+ FUNCTION = "load_unet"
825
+
826
+ CATEGORY = "advanced/loaders"
827
+
828
+ def load_unet(self, unet_name, weight_dtype):
829
+ weight_dtype = {"default":None, "fp8_e4m3fn":torch.float8_e4m3fn, "fp8_e5m2":torch.float8_e4m3fn}[weight_dtype]
830
+ unet_path = folder_paths.get_full_path("unet", unet_name)
831
+ model = totoro.sd.load_unet(unet_path, dtype=weight_dtype)
832
+ return (model,)
833
+
834
+ class CLIPLoader:
835
+ @classmethod
836
+ def INPUT_TYPES(s):
837
+ return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
838
+ "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio"], ),
839
+ }}
840
+ RETURN_TYPES = ("CLIP",)
841
+ FUNCTION = "load_clip"
842
+
843
+ CATEGORY = "advanced/loaders"
844
+
845
+ def load_clip(self, clip_name, type="stable_diffusion"):
846
+ if type == "stable_cascade":
847
+ clip_type = totoro.sd.CLIPType.STABLE_CASCADE
848
+ elif type == "sd3":
849
+ clip_type = totoro.sd.CLIPType.SD3
850
+ elif type == "stable_audio":
851
+ clip_type = totoro.sd.CLIPType.STABLE_AUDIO
852
+ else:
853
+ clip_type = totoro.sd.CLIPType.STABLE_DIFFUSION
854
+
855
+ clip_path = folder_paths.get_full_path("clip", clip_name)
856
+ clip = totoro.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
857
+ return (clip,)
858
+
859
+ class DualCLIPLoader:
860
+ @classmethod
861
+ def INPUT_TYPES(s):
862
+ return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
863
+ "clip_name2": (folder_paths.get_filename_list("clip"), ),
864
+ "type": (["sdxl", "sd3", "flux"], ),
865
+ }}
866
+ RETURN_TYPES = ("CLIP",)
867
+ FUNCTION = "load_clip"
868
+
869
+ CATEGORY = "advanced/loaders"
870
+
871
+ def load_clip(self, clip_name1, clip_name2, type):
872
+ clip_path1 = folder_paths.get_full_path("clip", clip_name1)
873
+ clip_path2 = folder_paths.get_full_path("clip", clip_name2)
874
+ if type == "sdxl":
875
+ clip_type = totoro.sd.CLIPType.STABLE_DIFFUSION
876
+ elif type == "sd3":
877
+ clip_type = totoro.sd.CLIPType.SD3
878
+ elif type == "flux":
879
+ clip_type = totoro.sd.CLIPType.FLUX
880
+
881
+ clip = totoro.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
882
+ return (clip,)
883
+
884
+ class CLIPVisionLoader:
885
+ @classmethod
886
+ def INPUT_TYPES(s):
887
+ return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ),
888
+ }}
889
+ RETURN_TYPES = ("CLIP_VISION",)
890
+ FUNCTION = "load_clip"
891
+
892
+ CATEGORY = "loaders"
893
+
894
+ def load_clip(self, clip_name):
895
+ clip_path = folder_paths.get_full_path("clip_vision", clip_name)
896
+ clip_vision = totoro.clip_vision.load(clip_path)
897
+ return (clip_vision,)
898
+
899
+ class CLIPVisionEncode:
900
+ @classmethod
901
+ def INPUT_TYPES(s):
902
+ return {"required": { "clip_vision": ("CLIP_VISION",),
903
+ "image": ("IMAGE",)
904
+ }}
905
+ RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
906
+ FUNCTION = "encode"
907
+
908
+ CATEGORY = "conditioning"
909
+
910
+ def encode(self, clip_vision, image):
911
+ output = clip_vision.encode_image(image)
912
+ return (output,)
913
+
914
+ class StyleModelLoader:
915
+ @classmethod
916
+ def INPUT_TYPES(s):
917
+ return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}}
918
+
919
+ RETURN_TYPES = ("STYLE_MODEL",)
920
+ FUNCTION = "load_style_model"
921
+
922
+ CATEGORY = "loaders"
923
+
924
+ def load_style_model(self, style_model_name):
925
+ style_model_path = folder_paths.get_full_path("style_models", style_model_name)
926
+ style_model = totoro.sd.load_style_model(style_model_path)
927
+ return (style_model,)
928
+
929
+
930
+ class StyleModelApply:
931
+ @classmethod
932
+ def INPUT_TYPES(s):
933
+ return {"required": {"conditioning": ("CONDITIONING", ),
934
+ "style_model": ("STYLE_MODEL", ),
935
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
936
+ }}
937
+ RETURN_TYPES = ("CONDITIONING",)
938
+ FUNCTION = "apply_stylemodel"
939
+
940
+ CATEGORY = "conditioning/style_model"
941
+
942
+ def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
943
+ cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
944
+ c = []
945
+ for t in conditioning:
946
+ n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
947
+ c.append(n)
948
+ return (c, )
949
+
950
+ class unCLIPConditioning:
951
+ @classmethod
952
+ def INPUT_TYPES(s):
953
+ return {"required": {"conditioning": ("CONDITIONING", ),
954
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
955
+ "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
956
+ "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
957
+ }}
958
+ RETURN_TYPES = ("CONDITIONING",)
959
+ FUNCTION = "apply_adm"
960
+
961
+ CATEGORY = "conditioning"
962
+
963
+ def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
964
+ if strength == 0:
965
+ return (conditioning, )
966
+
967
+ c = []
968
+ for t in conditioning:
969
+ o = t[1].copy()
970
+ x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
971
+ if "unclip_conditioning" in o:
972
+ o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
973
+ else:
974
+ o["unclip_conditioning"] = [x]
975
+ n = [t[0], o]
976
+ c.append(n)
977
+ return (c, )
978
+
979
+ class GLIGENLoader:
980
+ @classmethod
981
+ def INPUT_TYPES(s):
982
+ return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
983
+
984
+ RETURN_TYPES = ("GLIGEN",)
985
+ FUNCTION = "load_gligen"
986
+
987
+ CATEGORY = "loaders"
988
+
989
+ def load_gligen(self, gligen_name):
990
+ gligen_path = folder_paths.get_full_path("gligen", gligen_name)
991
+ gligen = totoro.sd.load_gligen(gligen_path)
992
+ return (gligen,)
993
+
994
+ class GLIGENTextBoxApply:
995
+ @classmethod
996
+ def INPUT_TYPES(s):
997
+ return {"required": {"conditioning_to": ("CONDITIONING", ),
998
+ "clip": ("CLIP", ),
999
+ "gligen_textbox_model": ("GLIGEN", ),
1000
+ "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
1001
+ "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
1002
+ "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
1003
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1004
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1005
+ }}
1006
+ RETURN_TYPES = ("CONDITIONING",)
1007
+ FUNCTION = "append"
1008
+
1009
+ CATEGORY = "conditioning/gligen"
1010
+
1011
+ def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
1012
+ c = []
1013
+ cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
1014
+ for t in conditioning_to:
1015
+ n = [t[0], t[1].copy()]
1016
+ position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
1017
+ prev = []
1018
+ if "gligen" in n[1]:
1019
+ prev = n[1]['gligen'][2]
1020
+
1021
+ n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
1022
+ c.append(n)
1023
+ return (c, )
1024
+
1025
+ class EmptyLatentImage:
1026
+ def __init__(self):
1027
+ self.device = totoro.model_management.intermediate_device()
1028
+
1029
+ @classmethod
1030
+ def INPUT_TYPES(s):
1031
+ return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
1032
+ "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
1033
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
1034
+ RETURN_TYPES = ("LATENT",)
1035
+ FUNCTION = "generate"
1036
+
1037
+ CATEGORY = "latent"
1038
+
1039
+ def generate(self, width, height, batch_size=1):
1040
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
1041
+ return ({"samples":latent}, )
1042
+
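A short sketch of the LATENT convention used throughout this file: a plain dict whose "samples" tensor has shape (batch, 4, height // 8, width // 8); later nodes may add optional "noise_mask" and "batch_index" keys to the same dict.

node = EmptyLatentImage()
(latent,) = node.generate(width=512, height=512, batch_size=2)
assert tuple(latent["samples"].shape) == (2, 4, 64, 64)  # 512 // 8 == 64 in each spatial dimension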
1043
+
1044
+ class LatentFromBatch:
1045
+ @classmethod
1046
+ def INPUT_TYPES(s):
1047
+ return {"required": { "samples": ("LATENT",),
1048
+ "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
1049
+ "length": ("INT", {"default": 1, "min": 1, "max": 64}),
1050
+ }}
1051
+ RETURN_TYPES = ("LATENT",)
1052
+ FUNCTION = "frombatch"
1053
+
1054
+ CATEGORY = "latent/batch"
1055
+
1056
+ def frombatch(self, samples, batch_index, length):
1057
+ s = samples.copy()
1058
+ s_in = samples["samples"]
1059
+ batch_index = min(s_in.shape[0] - 1, batch_index)
1060
+ length = min(s_in.shape[0] - batch_index, length)
1061
+ s["samples"] = s_in[batch_index:batch_index + length].clone()
1062
+ if "noise_mask" in samples:
1063
+ masks = samples["noise_mask"]
1064
+ if masks.shape[0] == 1:
1065
+ s["noise_mask"] = masks.clone()
1066
+ else:
1067
+ if masks.shape[0] < s_in.shape[0]:
1068
+ masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
1069
+ s["noise_mask"] = masks[batch_index:batch_index + length].clone()
1070
+ if "batch_index" not in s:
1071
+ s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
1072
+ else:
1073
+ s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
1074
+ return (s,)
1075
+
1076
+ class RepeatLatentBatch:
1077
+ @classmethod
1078
+ def INPUT_TYPES(s):
1079
+ return {"required": { "samples": ("LATENT",),
1080
+ "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
1081
+ }}
1082
+ RETURN_TYPES = ("LATENT",)
1083
+ FUNCTION = "repeat"
1084
+
1085
+ CATEGORY = "latent/batch"
1086
+
1087
+ def repeat(self, samples, amount):
1088
+ s = samples.copy()
1089
+ s_in = samples["samples"]
1090
+
1091
+ s["samples"] = s_in.repeat((amount, 1,1,1))
1092
+ if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
1093
+ masks = samples["noise_mask"]
1094
+ if masks.shape[0] < s_in.shape[0]:
1095
+ masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
1096
+ s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
1097
+ if "batch_index" in s:
1098
+ offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
1099
+ s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
1100
+ return (s,)
1101
+
1102
+ class LatentUpscale:
1103
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
1104
+ crop_methods = ["disabled", "center"]
1105
+
1106
+ @classmethod
1107
+ def INPUT_TYPES(s):
1108
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
1109
+ "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1110
+ "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1111
+ "crop": (s.crop_methods,)}}
1112
+ RETURN_TYPES = ("LATENT",)
1113
+ FUNCTION = "upscale"
1114
+
1115
+ CATEGORY = "latent"
1116
+
1117
+ def upscale(self, samples, upscale_method, width, height, crop):
1118
+ if width == 0 and height == 0:
1119
+ s = samples
1120
+ else:
1121
+ s = samples.copy()
1122
+
1123
+ if width == 0:
1124
+ height = max(64, height)
1125
+ width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2]))
1126
+ elif height == 0:
1127
+ width = max(64, width)
1128
+ height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3]))
1129
+ else:
1130
+ width = max(64, width)
1131
+ height = max(64, height)
1132
+
1133
+ s["samples"] = totoro.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
1134
+ return (s,)
1135
+
1136
+ class LatentUpscaleBy:
1137
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
1138
+
1139
+ @classmethod
1140
+ def INPUT_TYPES(s):
1141
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
1142
+ "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}
1143
+ RETURN_TYPES = ("LATENT",)
1144
+ FUNCTION = "upscale"
1145
+
1146
+ CATEGORY = "latent"
1147
+
1148
+ def upscale(self, samples, upscale_method, scale_by):
1149
+ s = samples.copy()
1150
+ width = round(samples["samples"].shape[3] * scale_by)
1151
+ height = round(samples["samples"].shape[2] * scale_by)
1152
+ s["samples"] = totoro.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
1153
+ return (s,)
1154
+
1155
+ class LatentRotate:
1156
+ @classmethod
1157
+ def INPUT_TYPES(s):
1158
+ return {"required": { "samples": ("LATENT",),
1159
+ "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
1160
+ }}
1161
+ RETURN_TYPES = ("LATENT",)
1162
+ FUNCTION = "rotate"
1163
+
1164
+ CATEGORY = "latent/transform"
1165
+
1166
+ def rotate(self, samples, rotation):
1167
+ s = samples.copy()
1168
+ rotate_by = 0
1169
+ if rotation.startswith("90"):
1170
+ rotate_by = 1
1171
+ elif rotation.startswith("180"):
1172
+ rotate_by = 2
1173
+ elif rotation.startswith("270"):
1174
+ rotate_by = 3
1175
+
1176
+ s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
1177
+ return (s,)
1178
+
1179
+ class LatentFlip:
1180
+ @classmethod
1181
+ def INPUT_TYPES(s):
1182
+ return {"required": { "samples": ("LATENT",),
1183
+ "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
1184
+ }}
1185
+ RETURN_TYPES = ("LATENT",)
1186
+ FUNCTION = "flip"
1187
+
1188
+ CATEGORY = "latent/transform"
1189
+
1190
+ def flip(self, samples, flip_method):
1191
+ s = samples.copy()
1192
+ if flip_method.startswith("x"):
1193
+ s["samples"] = torch.flip(samples["samples"], dims=[2])
1194
+ elif flip_method.startswith("y"):
1195
+ s["samples"] = torch.flip(samples["samples"], dims=[3])
1196
+
1197
+ return (s,)
1198
+
1199
+ class LatentComposite:
1200
+ @classmethod
1201
+ def INPUT_TYPES(s):
1202
+ return {"required": { "samples_to": ("LATENT",),
1203
+ "samples_from": ("LATENT",),
1204
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1205
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1206
+ "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1207
+ }}
1208
+ RETURN_TYPES = ("LATENT",)
1209
+ FUNCTION = "composite"
1210
+
1211
+ CATEGORY = "latent"
1212
+
1213
+ def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
1214
+ x = x // 8
1215
+ y = y // 8
1216
+ feather = feather // 8
1217
+ samples_out = samples_to.copy()
1218
+ s = samples_to["samples"].clone()
1219
+ samples_to = samples_to["samples"]
1220
+ samples_from = samples_from["samples"]
1221
+ if feather == 0:
1222
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
1223
+ else:
1224
+ samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
1225
+ mask = torch.ones_like(samples_from)
1226
+ for t in range(feather):
1227
+ if y != 0:
1228
+ mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
1229
+
1230
+ if y + samples_from.shape[2] < samples_to.shape[2]:
1231
+ mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
1232
+ if x != 0:
1233
+ mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
1234
+ if x + samples_from.shape[3] < samples_to.shape[3]:
1235
+ mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
1236
+ rev_mask = torch.ones_like(mask) - mask
1237
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
1238
+ samples_out["samples"] = s
1239
+ return (samples_out,)
1240
+
1241
+ class LatentBlend:
1242
+ @classmethod
1243
+ def INPUT_TYPES(s):
1244
+ return {"required": {
1245
+ "samples1": ("LATENT",),
1246
+ "samples2": ("LATENT",),
1247
+ "blend_factor": ("FLOAT", {
1248
+ "default": 0.5,
1249
+ "min": 0,
1250
+ "max": 1,
1251
+ "step": 0.01
1252
+ }),
1253
+ }}
1254
+
1255
+ RETURN_TYPES = ("LATENT",)
1256
+ FUNCTION = "blend"
1257
+
1258
+ CATEGORY = "_for_testing"
1259
+
1260
+ def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
1261
+
1262
+ samples_out = samples1.copy()
1263
+ samples1 = samples1["samples"]
1264
+ samples2 = samples2["samples"]
1265
+
1266
+ if samples1.shape != samples2.shape:
1267
+ # latents are already NCHW, so common_upscale can be applied directly
1268
+ samples2 = totoro.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
1269
+
1270
+
1271
+ samples_blended = self.blend_mode(samples1, samples2, blend_mode)
1272
+ samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
1273
+ samples_out["samples"] = samples_blended
1274
+ return (samples_out,)
1275
+
1276
+ def blend_mode(self, img1, img2, mode):
1277
+ if mode == "normal":
1278
+ return img2
1279
+ else:
1280
+ raise ValueError(f"Unsupported blend mode: {mode}")
1281
+
1282
+ class LatentCrop:
1283
+ @classmethod
1284
+ def INPUT_TYPES(s):
1285
+ return {"required": { "samples": ("LATENT",),
1286
+ "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
1287
+ "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
1288
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1289
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1290
+ }}
1291
+ RETURN_TYPES = ("LATENT",)
1292
+ FUNCTION = "crop"
1293
+
1294
+ CATEGORY = "latent/transform"
1295
+
1296
+ def crop(self, samples, width, height, x, y):
1297
+ s = samples.copy()
1298
+ samples = samples['samples']
1299
+ x = x // 8
1300
+ y = y // 8
1301
+
1302
+ #enforce minimum size of 64
1303
+ if x > (samples.shape[3] - 8):
1304
+ x = samples.shape[3] - 8
1305
+ if y > (samples.shape[2] - 8):
1306
+ y = samples.shape[2] - 8
1307
+
1308
+ new_height = height // 8
1309
+ new_width = width // 8
1310
+ to_x = new_width + x
1311
+ to_y = new_height + y
1312
+ s['samples'] = samples[:,:,y:to_y, x:to_x]
1313
+ return (s,)
1314
+
1315
+ class SetLatentNoiseMask:
1316
+ @classmethod
1317
+ def INPUT_TYPES(s):
1318
+ return {"required": { "samples": ("LATENT",),
1319
+ "mask": ("MASK",),
1320
+ }}
1321
+ RETURN_TYPES = ("LATENT",)
1322
+ FUNCTION = "set_mask"
1323
+
1324
+ CATEGORY = "latent/inpaint"
1325
+
1326
+ def set_mask(self, samples, mask):
1327
+ s = samples.copy()
1328
+ s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
1329
+ return (s,)
1330
+
1331
+ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
1332
+ latent_image = latent["samples"]
1333
+ latent_image = totoro.sample.fix_empty_latent_channels(model, latent_image)
1334
+
1335
+ if disable_noise:
1336
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
1337
+ else:
1338
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
1339
+ noise = totoro.sample.prepare_noise(latent_image, seed, batch_inds)
1340
+
1341
+ noise_mask = None
1342
+ if "noise_mask" in latent:
1343
+ noise_mask = latent["noise_mask"]
1344
+
1345
+ callback = latent_preview.prepare_callback(model, steps)
1346
+ disable_pbar = not totoro.utils.PROGRESS_BAR_ENABLED
1347
+ samples = totoro.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
1348
+ denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
1349
+ force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
1350
+ out = latent.copy()
1351
+ out["samples"] = samples
1352
+ return (out, )
1353
+
1354
+ class KSampler:
1355
+ @classmethod
1356
+ def INPUT_TYPES(s):
1357
+ return {"required":
1358
+ {"model": ("MODEL",),
1359
+ "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
1360
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
1361
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
1362
+ "sampler_name": (totoro.samplers.KSampler.SAMPLERS, ),
1363
+ "scheduler": (totoro.samplers.KSampler.SCHEDULERS, ),
1364
+ "positive": ("CONDITIONING", ),
1365
+ "negative": ("CONDITIONING", ),
1366
+ "latent_image": ("LATENT", ),
1367
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
1368
+ }
1369
+ }
1370
+
1371
+ RETURN_TYPES = ("LATENT",)
1372
+ FUNCTION = "sample"
1373
+
1374
+ CATEGORY = "sampling"
1375
+
1376
+ def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
1377
+ return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
1378
+
1379
+ class KSamplerAdvanced:
1380
+ @classmethod
1381
+ def INPUT_TYPES(s):
1382
+ return {"required":
1383
+ {"model": ("MODEL",),
1384
+ "add_noise": (["enable", "disable"], ),
1385
+ "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
1386
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
1387
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
1388
+ "sampler_name": (totoro.samplers.KSampler.SAMPLERS, ),
1389
+ "scheduler": (totoro.samplers.KSampler.SCHEDULERS, ),
1390
+ "positive": ("CONDITIONING", ),
1391
+ "negative": ("CONDITIONING", ),
1392
+ "latent_image": ("LATENT", ),
1393
+ "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
1394
+ "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
1395
+ "return_with_leftover_noise": (["disable", "enable"], ),
1396
+ }
1397
+ }
1398
+
1399
+ RETURN_TYPES = ("LATENT",)
1400
+ FUNCTION = "sample"
1401
+
1402
+ CATEGORY = "sampling"
1403
+
1404
+ def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
1405
+ force_full_denoise = True
1406
+ if return_with_leftover_noise == "enable":
1407
+ force_full_denoise = False
1408
+ disable_noise = False
1409
+ if add_noise == "disable":
1410
+ disable_noise = True
1411
+ return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
1412
+
1413
+ class SaveImage:
1414
+ def __init__(self):
1415
+ self.output_dir = folder_paths.get_output_directory()
1416
+ self.type = "output"
1417
+ self.prefix_append = ""
1418
+ self.compress_level = 4
1419
+
1420
+ @classmethod
1421
+ def INPUT_TYPES(s):
1422
+ return {"required":
1423
+ {"images": ("IMAGE", ),
1424
+ "filename_prefix": ("STRING", {"default": "totoroUI"})},
1425
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
1426
+ }
1427
+
1428
+ RETURN_TYPES = ()
1429
+ FUNCTION = "save_images"
1430
+
1431
+ OUTPUT_NODE = True
1432
+
1433
+ CATEGORY = "image"
1434
+
1435
+ def save_images(self, images, filename_prefix="totoroUI", prompt=None, extra_pnginfo=None):
1436
+ filename_prefix += self.prefix_append
1437
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
1438
+ results = list()
1439
+ for (batch_number, image) in enumerate(images):
1440
+ i = 255. * image.cpu().numpy()
1441
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
1442
+ metadata = None
1443
+ if not args.disable_metadata:
1444
+ metadata = PngInfo()
1445
+ if prompt is not None:
1446
+ metadata.add_text("prompt", json.dumps(prompt))
1447
+ if extra_pnginfo is not None:
1448
+ for x in extra_pnginfo:
1449
+ metadata.add_text(x, json.dumps(extra_pnginfo[x]))
1450
+
1451
+ filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
1452
+ file = f"{filename_with_batch_num}_{counter:05}_.png"
1453
+ img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
1454
+ results.append({
1455
+ "filename": file,
1456
+ "subfolder": subfolder,
1457
+ "type": self.type
1458
+ })
1459
+ counter += 1
1460
+
1461
+ return { "ui": { "images": results } }
1462
+
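SaveImage embeds the workflow as PNG tEXt chunks ("prompt" plus any extra_pnginfo keys) via PngInfo. A small sketch of reading that metadata back with Pillow; the output path shown is hypothetical.

import json
from PIL import Image

with Image.open("output/totoroUI_00001_.png") as im:
    prompt = json.loads(im.text["prompt"])  # im.text exposes the PNG tEXt chunks written above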
1463
+ class PreviewImage(SaveImage):
1464
+ def __init__(self):
1465
+ self.output_dir = folder_paths.get_temp_directory()
1466
+ self.type = "temp"
1467
+ self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for x in range(5))
1468
+ self.compress_level = 1
1469
+
1470
+ @classmethod
1471
+ def INPUT_TYPES(s):
1472
+ return {"required":
1473
+ {"images": ("IMAGE", ), },
1474
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
1475
+ }
1476
+
1477
+ class LoadImage:
1478
+ @classmethod
1479
+ def INPUT_TYPES(s):
1480
+ input_dir = folder_paths.get_input_directory()
1481
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
1482
+ return {"required":
1483
+ {"image": (sorted(files), {"image_upload": True})},
1484
+ }
1485
+
1486
+ CATEGORY = "image"
1487
+
1488
+ RETURN_TYPES = ("IMAGE", "MASK")
1489
+ FUNCTION = "load_image"
1490
+ def load_image(self, image):
1491
+ image_path = folder_paths.get_annotated_filepath(image)
1492
+
1493
+ img = node_helpers.pillow(Image.open, image_path)
1494
+
1495
+ output_images = []
1496
+ output_masks = []
1497
+ w, h = None, None
1498
+
1499
+ excluded_formats = ['MPO']
1500
+
1501
+ for i in ImageSequence.Iterator(img):
1502
+ i = node_helpers.pillow(ImageOps.exif_transpose, i)
1503
+
1504
+ if i.mode == 'I':
1505
+ i = i.point(lambda i: i * (1 / 255))
1506
+ image = i.convert("RGB")
1507
+
1508
+ if len(output_images) == 0:
1509
+ w = image.size[0]
1510
+ h = image.size[1]
1511
+
1512
+ if image.size[0] != w or image.size[1] != h:
1513
+ continue
1514
+
1515
+ image = np.array(image).astype(np.float32) / 255.0
1516
+ image = torch.from_numpy(image)[None,]
1517
+ if 'A' in i.getbands():
1518
+ mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
1519
+ mask = 1. - torch.from_numpy(mask)
1520
+ else:
1521
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
1522
+ output_images.append(image)
1523
+ output_masks.append(mask.unsqueeze(0))
1524
+
1525
+ if len(output_images) > 1 and img.format not in excluded_formats:
1526
+ output_image = torch.cat(output_images, dim=0)
1527
+ output_mask = torch.cat(output_masks, dim=0)
1528
+ else:
1529
+ output_image = output_images[0]
1530
+ output_mask = output_masks[0]
1531
+
1532
+ return (output_image, output_mask)
1533
+
1534
+ @classmethod
1535
+ def IS_CHANGED(s, image):
1536
+ image_path = folder_paths.get_annotated_filepath(image)
1537
+ m = hashlib.sha256()
1538
+ with open(image_path, 'rb') as f:
1539
+ m.update(f.read())
1540
+ return m.digest().hex()
1541
+
1542
+ @classmethod
1543
+ def VALIDATE_INPUTS(s, image):
1544
+ if not folder_paths.exists_annotated_filepath(image):
1545
+ return "Invalid image file: {}".format(image)
1546
+
1547
+ return True
1548
+
1549
+ class LoadImageMask:
1550
+ _color_channels = ["alpha", "red", "green", "blue"]
1551
+ @classmethod
1552
+ def INPUT_TYPES(s):
1553
+ input_dir = folder_paths.get_input_directory()
1554
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
1555
+ return {"required":
1556
+ {"image": (sorted(files), {"image_upload": True}),
1557
+ "channel": (s._color_channels, ), }
1558
+ }
1559
+
1560
+ CATEGORY = "mask"
1561
+
1562
+ RETURN_TYPES = ("MASK",)
1563
+ FUNCTION = "load_image"
1564
+ def load_image(self, image, channel):
1565
+ image_path = folder_paths.get_annotated_filepath(image)
1566
+ i = node_helpers.pillow(Image.open, image_path)
1567
+ i = node_helpers.pillow(ImageOps.exif_transpose, i)
1568
+ if i.getbands() != ("R", "G", "B", "A"):
1569
+ if i.mode == 'I':
1570
+ i = i.point(lambda i: i * (1 / 255))
1571
+ i = i.convert("RGBA")
1572
+ mask = None
1573
+ c = channel[0].upper()
1574
+ if c in i.getbands():
1575
+ mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
1576
+ mask = torch.from_numpy(mask)
1577
+ if c == 'A':
1578
+ mask = 1. - mask
1579
+ else:
1580
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
1581
+ return (mask.unsqueeze(0),)
1582
+
1583
+ @classmethod
1584
+ def IS_CHANGED(s, image, channel):
1585
+ image_path = folder_paths.get_annotated_filepath(image)
1586
+ m = hashlib.sha256()
1587
+ with open(image_path, 'rb') as f:
1588
+ m.update(f.read())
1589
+ return m.digest().hex()
1590
+
1591
+ @classmethod
1592
+ def VALIDATE_INPUTS(s, image):
1593
+ if not folder_paths.exists_annotated_filepath(image):
1594
+ return "Invalid image file: {}".format(image)
1595
+
1596
+ return True
1597
+
1598
+ class ImageScale:
1599
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
1600
+ crop_methods = ["disabled", "center"]
1601
+
1602
+ @classmethod
1603
+ def INPUT_TYPES(s):
1604
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
1605
+ "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1606
+ "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1607
+ "crop": (s.crop_methods,)}}
1608
+ RETURN_TYPES = ("IMAGE",)
1609
+ FUNCTION = "upscale"
1610
+
1611
+ CATEGORY = "image/upscaling"
1612
+
1613
+ def upscale(self, image, upscale_method, width, height, crop):
1614
+ if width == 0 and height == 0:
1615
+ s = image
1616
+ else:
1617
+ samples = image.movedim(-1,1)
1618
+
1619
+ if width == 0:
1620
+ width = max(1, round(samples.shape[3] * height / samples.shape[2]))
1621
+ elif height == 0:
1622
+ height = max(1, round(samples.shape[2] * width / samples.shape[3]))
1623
+
1624
+ s = totoro.utils.common_upscale(samples, width, height, upscale_method, crop)
1625
+ s = s.movedim(1,-1)
1626
+ return (s,)
1627
+
1628
+ class ImageScaleBy:
1629
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
1630
+
1631
+ @classmethod
1632
+ def INPUT_TYPES(s):
1633
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
1634
+ "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}}
1635
+ RETURN_TYPES = ("IMAGE",)
1636
+ FUNCTION = "upscale"
1637
+
1638
+ CATEGORY = "image/upscaling"
1639
+
1640
+ def upscale(self, image, upscale_method, scale_by):
1641
+ samples = image.movedim(-1,1)
1642
+ width = round(samples.shape[3] * scale_by)
1643
+ height = round(samples.shape[2] * scale_by)
1644
+ s = totoro.utils.common_upscale(samples, width, height, upscale_method, "disabled")
1645
+ s = s.movedim(1,-1)
1646
+ return (s,)
1647
+
1648
+ class ImageInvert:
1649
+
1650
+ @classmethod
1651
+ def INPUT_TYPES(s):
1652
+ return {"required": { "image": ("IMAGE",)}}
1653
+
1654
+ RETURN_TYPES = ("IMAGE",)
1655
+ FUNCTION = "invert"
1656
+
1657
+ CATEGORY = "image"
1658
+
1659
+ def invert(self, image):
1660
+ s = 1.0 - image
1661
+ return (s,)
1662
+
1663
+ class ImageBatch:
1664
+
1665
+ @classmethod
1666
+ def INPUT_TYPES(s):
1667
+ return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}}
1668
+
1669
+ RETURN_TYPES = ("IMAGE",)
1670
+ FUNCTION = "batch"
1671
+
1672
+ CATEGORY = "image"
1673
+
1674
+ def batch(self, image1, image2):
1675
+ if image1.shape[1:] != image2.shape[1:]:
1676
+ image2 = totoro.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
1677
+ s = torch.cat((image1, image2), dim=0)
1678
+ return (s,)
1679
+
1680
+ class EmptyImage:
1681
+ def __init__(self, device="cpu"):
1682
+ self.device = device
1683
+
1684
+ @classmethod
1685
+ def INPUT_TYPES(s):
1686
+ return {"required": { "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
1687
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
1688
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
1689
+ "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
1690
+ }}
1691
+ RETURN_TYPES = ("IMAGE",)
1692
+ FUNCTION = "generate"
1693
+
1694
+ CATEGORY = "image"
1695
+
1696
+ def generate(self, width, height, batch_size=1, color=0):
1697
+ r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
1698
+ g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
1699
+ b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
1700
+ return (torch.cat((r, g, b), dim=-1), )
1701
+
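A worked example of the bit-shift arithmetic generate() uses to split the packed 0xRRGGBB color into per-channel floats in [0, 1]:

color = 0xFF8000                      # orange
r = ((color >> 16) & 0xFF) / 0xFF     # 0xFF -> 1.0
g = ((color >> 8) & 0xFF) / 0xFF      # 0x80 -> ~0.502
b = (color & 0xFF) / 0xFF             # 0x00 -> 0.0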
1702
+ class ImagePadForOutpaint:
1703
+
1704
+ @classmethod
1705
+ def INPUT_TYPES(s):
1706
+ return {
1707
+ "required": {
1708
+ "image": ("IMAGE",),
1709
+ "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1710
+ "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1711
+ "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1712
+ "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1713
+ "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1714
+ }
1715
+ }
1716
+
1717
+ RETURN_TYPES = ("IMAGE", "MASK")
1718
+ FUNCTION = "expand_image"
1719
+
1720
+ CATEGORY = "image"
1721
+
1722
+ def expand_image(self, image, left, top, right, bottom, feathering):
1723
+ d1, d2, d3, d4 = image.size()
1724
+
1725
+ new_image = torch.ones(
1726
+ (d1, d2 + top + bottom, d3 + left + right, d4),
1727
+ dtype=torch.float32,
1728
+ ) * 0.5
1729
+
1730
+ new_image[:, top:top + d2, left:left + d3, :] = image
1731
+
1732
+ mask = torch.ones(
1733
+ (d2 + top + bottom, d3 + left + right),
1734
+ dtype=torch.float32,
1735
+ )
1736
+
1737
+ t = torch.zeros(
1738
+ (d2, d3),
1739
+ dtype=torch.float32
1740
+ )
1741
+
1742
+ if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3:
1743
+
1744
+ for i in range(d2):
1745
+ for j in range(d3):
1746
+ dt = i if top != 0 else d2
1747
+ db = d2 - i if bottom != 0 else d2
1748
+
1749
+ dl = j if left != 0 else d3
1750
+ dr = d3 - j if right != 0 else d3
1751
+
1752
+ d = min(dt, db, dl, dr)
1753
+
1754
+ if d >= feathering:
1755
+ continue
1756
+
1757
+ v = (feathering - d) / feathering
1758
+
1759
+ t[i, j] = v * v
1760
+
1761
+ mask[top:top + d2, left:left + d3] = t
1762
+
1763
+ return (new_image, mask)
1764
+
1765
+
1766
+ NODE_CLASS_MAPPINGS = {
1767
+ "KSampler": KSampler,
1768
+ "CheckpointLoaderSimple": CheckpointLoaderSimple,
1769
+ "CLIPTextEncode": CLIPTextEncode,
1770
+ "CLIPSetLastLayer": CLIPSetLastLayer,
1771
+ "VAEDecode": VAEDecode,
1772
+ "VAEEncode": VAEEncode,
1773
+ "VAEEncodeForInpaint": VAEEncodeForInpaint,
1774
+ "VAELoader": VAELoader,
1775
+ "EmptyLatentImage": EmptyLatentImage,
1776
+ "LatentUpscale": LatentUpscale,
1777
+ "LatentUpscaleBy": LatentUpscaleBy,
1778
+ "LatentFromBatch": LatentFromBatch,
1779
+ "RepeatLatentBatch": RepeatLatentBatch,
1780
+ "SaveImage": SaveImage,
1781
+ "PreviewImage": PreviewImage,
1782
+ "LoadImage": LoadImage,
1783
+ "LoadImageMask": LoadImageMask,
1784
+ "ImageScale": ImageScale,
1785
+ "ImageScaleBy": ImageScaleBy,
1786
+ "ImageInvert": ImageInvert,
1787
+ "ImageBatch": ImageBatch,
1788
+ "ImagePadForOutpaint": ImagePadForOutpaint,
1789
+ "EmptyImage": EmptyImage,
1790
+ "ConditioningAverage": ConditioningAverage,
1791
+ "ConditioningCombine": ConditioningCombine,
1792
+ "ConditioningConcat": ConditioningConcat,
1793
+ "ConditioningSetArea": ConditioningSetArea,
1794
+ "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
1795
+ "ConditioningSetAreaStrength": ConditioningSetAreaStrength,
1796
+ "ConditioningSetMask": ConditioningSetMask,
1797
+ "KSamplerAdvanced": KSamplerAdvanced,
1798
+ "SetLatentNoiseMask": SetLatentNoiseMask,
1799
+ "LatentComposite": LatentComposite,
1800
+ "LatentBlend": LatentBlend,
1801
+ "LatentRotate": LatentRotate,
1802
+ "LatentFlip": LatentFlip,
1803
+ "LatentCrop": LatentCrop,
1804
+ "LoraLoader": LoraLoader,
1805
+ "CLIPLoader": CLIPLoader,
1806
+ "UNETLoader": UNETLoader,
1807
+ "DualCLIPLoader": DualCLIPLoader,
1808
+ "CLIPVisionEncode": CLIPVisionEncode,
1809
+ "StyleModelApply": StyleModelApply,
1810
+ "unCLIPConditioning": unCLIPConditioning,
1811
+ "ControlNetApply": ControlNetApply,
1812
+ "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
1813
+ "ControlNetLoader": ControlNetLoader,
1814
+ "DiffControlNetLoader": DiffControlNetLoader,
1815
+ "StyleModelLoader": StyleModelLoader,
1816
+ "CLIPVisionLoader": CLIPVisionLoader,
1817
+ "VAEDecodeTiled": VAEDecodeTiled,
1818
+ "VAEEncodeTiled": VAEEncodeTiled,
1819
+ "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
1820
+ "GLIGENLoader": GLIGENLoader,
1821
+ "GLIGENTextBoxApply": GLIGENTextBoxApply,
1822
+ "InpaintModelConditioning": InpaintModelConditioning,
1823
+
1824
+ "CheckpointLoader": CheckpointLoader,
1825
+ "DiffusersLoader": DiffusersLoader,
1826
+
1827
+ "LoadLatent": LoadLatent,
1828
+ "SaveLatent": SaveLatent,
1829
+
1830
+ "ConditioningZeroOut": ConditioningZeroOut,
1831
+ "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
1832
+ "LoraLoaderModelOnly": LoraLoaderModelOnly,
1833
+ }
1834
+
1835
+ NODE_DISPLAY_NAME_MAPPINGS = {
1836
+ # Sampling
1837
+ "KSampler": "KSampler",
1838
+ "KSamplerAdvanced": "KSampler (Advanced)",
1839
+ # Loaders
1840
+ "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
1841
+ "CheckpointLoaderSimple": "Load Checkpoint",
1842
+ "VAELoader": "Load VAE",
1843
+ "LoraLoader": "Load LoRA",
1844
+ "CLIPLoader": "Load CLIP",
1845
+ "ControlNetLoader": "Load ControlNet Model",
1846
+ "DiffControlNetLoader": "Load ControlNet Model (diff)",
1847
+ "StyleModelLoader": "Load Style Model",
1848
+ "CLIPVisionLoader": "Load CLIP Vision",
1849
+ "UpscaleModelLoader": "Load Upscale Model",
1850
+ "UNETLoader": "Load Diffusion Model",
1851
+ # Conditioning
1852
+ "CLIPVisionEncode": "CLIP Vision Encode",
1853
+ "StyleModelApply": "Apply Style Model",
1854
+ "CLIPTextEncode": "CLIP Text Encode (Prompt)",
1855
+ "CLIPSetLastLayer": "CLIP Set Last Layer",
1856
+ "ConditioningCombine": "Conditioning (Combine)",
1857
+ "ConditioningAverage": "Conditioning (Average)",
1858
+ "ConditioningConcat": "Conditioning (Concat)",
1859
+ "ConditioningSetArea": "Conditioning (Set Area)",
1860
+ "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
1861
+ "ConditioningSetMask": "Conditioning (Set Mask)",
1862
+ "ControlNetApply": "Apply ControlNet",
1863
+ "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
1864
+ # Latent
1865
+ "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
1866
+ "SetLatentNoiseMask": "Set Latent Noise Mask",
1867
+ "VAEDecode": "VAE Decode",
1868
+ "VAEEncode": "VAE Encode",
1869
+ "LatentRotate": "Rotate Latent",
1870
+ "LatentFlip": "Flip Latent",
1871
+ "LatentCrop": "Crop Latent",
1872
+ "EmptyLatentImage": "Empty Latent Image",
1873
+ "LatentUpscale": "Upscale Latent",
1874
+ "LatentUpscaleBy": "Upscale Latent By",
1875
+ "LatentComposite": "Latent Composite",
1876
+ "LatentBlend": "Latent Blend",
1877
+ "LatentFromBatch" : "Latent From Batch",
1878
+ "RepeatLatentBatch": "Repeat Latent Batch",
1879
+ # Image
1880
+ "SaveImage": "Save Image",
1881
+ "PreviewImage": "Preview Image",
1882
+ "LoadImage": "Load Image",
1883
+ "LoadImageMask": "Load Image (as Mask)",
1884
+ "ImageScale": "Upscale Image",
1885
+ "ImageScaleBy": "Upscale Image By",
1886
+ "ImageUpscaleWithModel": "Upscale Image (using Model)",
1887
+ "ImageInvert": "Invert Image",
1888
+ "ImagePadForOutpaint": "Pad Image for Outpainting",
1889
+ "ImageBatch": "Batch Images",
1890
+ # _for_testing
1891
+ "VAEDecodeTiled": "VAE Decode (Tiled)",
1892
+ "VAEEncodeTiled": "VAE Encode (Tiled)",
1893
+ }
1894
+
1895
+ EXTENSION_WEB_DIRS = {}
1896
+
1897
+
1898
+ def get_module_name(module_path: str) -> str:
1899
+ """
1900
+ Returns the module name based on the given module path.
1901
+ Examples:
1902
+ get_module_name("C:/Users/username/totoroUI/custom_nodes/my_custom_node.py") -> "my_custom_node"
1903
+ get_module_name("C:/Users/username/totoroUI/custom_nodes/my_custom_node") -> "my_custom_node"
1904
+ get_module_name("C:/Users/username/totoroUI/custom_nodes/my_custom_node/") -> "my_custom_node"
1905
+ get_module_name("C:/Users/username/totoroUI/custom_nodes/my_custom_node/__init__.py") -> "my_custom_node"
1906
+ get_module_name("C:/Users/username/totoroUI/custom_nodes/my_custom_node/__init__") -> "my_custom_node"
1907
+ get_module_name("C:/Users/username/totoroUI/custom_nodes/my_custom_node/__init__/") -> "my_custom_node"
1908
+ get_module_name("C:/Users/username/totoroUI/custom_nodes/my_custom_node.disabled") -> "my_custom_node"
1909
+ Args:
1910
+ module_path (str): The path of the module.
1911
+ Returns:
1912
+ str: The module name.
1913
+ """
1914
+ base_path = os.path.basename(module_path)
1915
+ if os.path.isfile(module_path):
1916
+ base_path = os.path.splitext(base_path)[0]
1917
+ return base_path
1918
+
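A quick sketch of how get_module_name behaves; the paths are hypothetical, and the result depends on whether the path exists as a file on disk because of the os.path.isfile check above.

get_module_name("/tmp/custom_nodes/my_custom_node")     # -> "my_custom_node" (directory or non-existent path)
get_module_name("/tmp/custom_nodes/my_custom_node.py")  # -> "my_custom_node" if the file exists, otherwise "my_custom_node.py"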
1919
+
1920
+ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
1921
+ module_name = os.path.basename(module_path)
1922
+ if os.path.isfile(module_path):
1923
+ sp = os.path.splitext(module_path)
1924
+ module_name = sp[0]
1925
+ try:
1926
+ logging.debug("Trying to load custom node {}".format(module_path))
1927
+ if os.path.isfile(module_path):
1928
+ module_spec = importlib.util.spec_from_file_location(module_name, module_path)
1929
+ module_dir = os.path.split(module_path)[0]
1930
+ else:
1931
+ module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
1932
+ module_dir = module_path
1933
+
1934
+ module = importlib.util.module_from_spec(module_spec)
1935
+ sys.modules[module_name] = module
1936
+ module_spec.loader.exec_module(module)
1937
+
1938
+ if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
1939
+ web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
1940
+ if os.path.isdir(web_dir):
1941
+ EXTENSION_WEB_DIRS[module_name] = web_dir
1942
+
1943
+ if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
1944
+ for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
1945
+ if name not in ignore:
1946
+ NODE_CLASS_MAPPINGS[name] = node_cls
1947
+ node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
1948
+ if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
1949
+ NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
1950
+ return True
1951
+ else:
1952
+ logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
1953
+ return False
1954
+ except Exception as e:
1955
+ logging.warning(traceback.format_exc())
1956
+ logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
1957
+ return False
1958
+
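A minimal sketch of the module contract load_custom_node looks for in a custom node package (for example custom_nodes/my_custom_node/__init__.py); the node class itself is hypothetical.

class ExampleNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",)}}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"
    CATEGORY = "example"

    def run(self, image):
        return (image,)

NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}            # required, or the module is skipped
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleNode": "Example Node"}  # optional display names
WEB_DIRECTORY = "./js"  # optional; registered in EXTENSION_WEB_DIRS when the folder exists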
1959
+ def init_external_custom_nodes():
1960
+ """
1961
+ Initializes the external custom nodes.
1962
+
1963
+ This function loads custom nodes from the specified folder paths and imports them into the application.
1964
+ It measures the import times for each custom node and logs the results.
1965
+
1966
+ Returns:
1967
+ None
1968
+ """
1969
+ base_node_names = set(NODE_CLASS_MAPPINGS.keys())
1970
+ node_paths = folder_paths.get_folder_paths("custom_nodes")
1971
+ node_import_times = []
1972
+ for custom_node_path in node_paths:
1973
+ possible_modules = os.listdir(os.path.realpath(custom_node_path))
1974
+ if "__pycache__" in possible_modules:
1975
+ possible_modules.remove("__pycache__")
1976
+
1977
+ for possible_module in possible_modules:
1978
+ module_path = os.path.join(custom_node_path, possible_module)
1979
+ if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
1980
+ if module_path.endswith(".disabled"): continue
1981
+ time_before = time.perf_counter()
1982
+ success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
1983
+ node_import_times.append((time.perf_counter() - time_before, module_path, success))
1984
+
1985
+ if len(node_import_times) > 0:
1986
+ logging.info("\nImport times for custom nodes:")
1987
+ for n in sorted(node_import_times):
1988
+ if n[2]:
1989
+ import_message = ""
1990
+ else:
1991
+ import_message = " (IMPORT FAILED)"
1992
+ logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
1993
+ logging.info("")
1994
+
1995
+ def init_builtin_extra_nodes():
1996
+ """
1997
+ Initializes the built-in extra nodes in totoroUI.
1998
+
1999
+ This function loads the extra node files located in the "totoro_extras" directory and imports them into totoroUI.
2000
+ If any of the extra node files fail to import, a warning message is logged.
2001
+
2002
+ Returns:
2003
+ None
2004
+ """
2005
+ extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "totoro_extras")
2006
+ extras_files = [
2007
+ "nodes_latent.py",
2008
+ "nodes_hypernetwork.py",
2009
+ "nodes_upscale_model.py",
2010
+ "nodes_post_processing.py",
2011
+ "nodes_mask.py",
2012
+ "nodes_compositing.py",
2013
+ "nodes_rebatch.py",
2014
+ "nodes_model_merging.py",
2015
+ "nodes_tomesd.py",
2016
+ "nodes_clip_sdxl.py",
2017
+ "nodes_canny.py",
2018
+ "nodes_freelunch.py",
2019
+ "nodes_custom_sampler.py",
2020
+ "nodes_hypertile.py",
2021
+ "nodes_model_advanced.py",
2022
+ "nodes_model_downscale.py",
2023
+ "nodes_images.py",
2024
+ "nodes_video_model.py",
2025
+ "nodes_sag.py",
2026
+ "nodes_perpneg.py",
2027
+ "nodes_stable3d.py",
2028
+ "nodes_sdupscale.py",
2029
+ "nodes_photomaker.py",
2030
+ "nodes_cond.py",
2031
+ "nodes_morphology.py",
2032
+ "nodes_stable_cascade.py",
2033
+ "nodes_differential_diffusion.py",
2034
+ "nodes_ip2p.py",
2035
+ "nodes_model_merging_model_specific.py",
2036
+ "nodes_pag.py",
2037
+ "nodes_align_your_steps.py",
2038
+ "nodes_attention_multiply.py",
2039
+ "nodes_advanced_samplers.py",
2040
+ "nodes_webcam.py",
2041
+ "nodes_audio.py",
2042
+ "nodes_sd3.py",
2043
+ "nodes_gits.py",
2044
+ "nodes_controlnet.py",
2045
+ "nodes_hunyuan.py",
2046
+ ]
2047
+
2048
+ import_failed = []
2049
+ for node_file in extras_files:
2050
+ if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="totoro_extras"):
2051
+ import_failed.append(node_file)
2052
+
2053
+ return import_failed
2054
+
2055
+
2056
+ def init_extra_nodes(init_custom_nodes=True):
2057
+ import_failed = init_builtin_extra_nodes()
2058
+
2059
+ if init_custom_nodes:
2060
+ init_external_custom_nodes()
2061
+ else:
2062
+ logging.info("Skipping loading of custom nodes")
2063
+
2064
+ if len(import_failed) > 0:
2065
+ logging.warning("WARNING: some totoro_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n")
2066
+ for node in import_failed:
2067
+ logging.warning("IMPORT FAILED: {}".format(node))
2068
+ logging.warning("\nThis issue might be caused by new dependencies that were added the last time you updated totoroUI.")
2069
+ if args.windows_standalone_build:
2070
+ logging.warning("Please run the update script: update/update_totoroui.bat")
2071
+ else:
2072
+ logging.warning("Please do a: pip install -r requirements.txt")
2073
+ logging.warning("")
totoro/__pycache__/checkpoint_pickle.cpython-311.pyc ADDED
Binary file (1.08 kB). View file
 
totoro/__pycache__/cli_args.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
totoro/__pycache__/clip_model.cpython-311.pyc ADDED
Binary file (18 kB). View file
 
totoro/__pycache__/clip_vision.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
totoro/__pycache__/conds.cpython-311.pyc ADDED
Binary file (5.49 kB). View file
 
totoro/__pycache__/controlnet.cpython-311.pyc ADDED
Binary file (38.2 kB). View file
 
totoro/__pycache__/diffusers_convert.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
totoro/__pycache__/diffusers_load.cpython-311.pyc ADDED
Binary file (2.36 kB). View file
 
totoro/__pycache__/gligen.cpython-311.pyc ADDED
Binary file (22 kB). View file
 
totoro/__pycache__/latent_formats.cpython-311.pyc ADDED
Binary file (8.56 kB). View file
 
totoro/__pycache__/lora.cpython-311.pyc ADDED
Binary file (15.7 kB). View file
 
totoro/__pycache__/model_base.cpython-311.pyc ADDED
Binary file (53.9 kB). View file
 
totoro/__pycache__/model_detection.cpython-311.pyc ADDED
Binary file (30.2 kB). View file
 
totoro/__pycache__/model_management.cpython-311.pyc ADDED
Binary file (40.8 kB). View file
 
totoro/__pycache__/model_patcher.cpython-311.pyc ADDED
Binary file (34 kB). View file
 
totoro/__pycache__/model_sampling.cpython-311.pyc ADDED
Binary file (21.7 kB). View file
 
totoro/__pycache__/ops.cpython-311.pyc ADDED
Binary file (15.6 kB). View file
 
totoro/__pycache__/options.cpython-311.pyc ADDED
Binary file (320 Bytes). View file
 
totoro/__pycache__/sample.cpython-311.pyc ADDED
Binary file (4.74 kB). View file
 
totoro/__pycache__/sampler_helpers.cpython-311.pyc ADDED
Binary file (4.64 kB). View file
 
totoro/__pycache__/samplers.cpython-311.pyc ADDED
Binary file (45.7 kB). View file
 
totoro/__pycache__/sd.cpython-311.pyc ADDED
Binary file (47.3 kB). View file
 
totoro/__pycache__/sd1_clip.cpython-311.pyc ADDED
Binary file (34.6 kB). View file
 
totoro/__pycache__/sdxl_clip.cpython-311.pyc ADDED
Binary file (9.91 kB). View file
 
totoro/__pycache__/supported_models.cpython-311.pyc ADDED
Binary file (30.8 kB). View file
 
totoro/__pycache__/supported_models_base.cpython-311.pyc ADDED
Binary file (5.92 kB). View file
 
totoro/__pycache__/types.cpython-311.pyc ADDED
Binary file (1.97 kB). View file
 
totoro/__pycache__/utils.cpython-311.pyc ADDED
Binary file (41.1 kB). View file
 
totoro/checkpoint_pickle.py ADDED
@@ -0,0 +1,13 @@
1
+ import pickle
2
+
3
+ load = pickle.load
4
+
5
+ class Empty:
6
+ pass
7
+
8
+ class Unpickler(pickle.Unpickler):
9
+ def find_class(self, module, name):
10
+ #TODO: safe unpickle
11
+ if module.startswith("pytorch_lightning"):
12
+ return Empty
13
+ return super().find_class(module, name)
totoro/cldm/__pycache__/cldm.cpython-311.pyc ADDED
Binary file (23 kB). View file
 
totoro/cldm/__pycache__/control_types.cpython-311.pyc ADDED
Binary file (379 Bytes). View file
 
totoro/cldm/__pycache__/mmdit.cpython-311.pyc ADDED
Binary file (3.93 kB). View file
 
totoro/cldm/cldm.py ADDED
@@ -0,0 +1,437 @@
1
+ #taken from: https://github.com/lllyasviel/ControlNet
2
+ #and modified
3
+
4
+ import torch
5
+ import torch as th
6
+ import torch.nn as nn
7
+
8
+ from ..ldm.modules.diffusionmodules.util import (
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from ..ldm.modules.attention import SpatialTransformer
14
+ from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
15
+ from ..ldm.util import exists
16
+ from .control_types import UNION_CONTROLNET_TYPES
17
+ from collections import OrderedDict
18
+ import totoro.ops
19
+ from totoro.ldm.modules.attention import optimized_attention
20
+
21
+ class OptimizedAttention(nn.Module):
22
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
23
+ super().__init__()
24
+ self.heads = nhead
25
+ self.c = c
26
+
27
+ self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
28
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
29
+
30
+ def forward(self, x):
31
+ x = self.in_proj(x)
32
+ q, k, v = x.split(self.c, dim=2)
33
+ out = optimized_attention(q, k, v, self.heads)
34
+ return self.out_proj(out)
35
+
36
+ class QuickGELU(nn.Module):
37
+ def forward(self, x: torch.Tensor):
38
+ return x * torch.sigmoid(1.702 * x)
39
+
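QuickGELU is the sigmoid approximation x * sigmoid(1.702 * x) of GELU; a small sanity-check sketch against torch's exact GELU (torch is already imported at the top of this file).

x = torch.linspace(-3, 3, 7)
approx = QuickGELU()(x)
exact = torch.nn.functional.gelu(x)
# over this range the two curves agree to within a few hundredths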
40
+ class ResBlockUnionControlnet(nn.Module):
41
+ def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
42
+ super().__init__()
43
+ self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
44
+ self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
45
+ self.mlp = nn.Sequential(
46
+ OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
47
+ ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
48
+ self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
49
+
50
+ def attention(self, x: torch.Tensor):
51
+ return self.attn(x)
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ x = x + self.attention(self.ln_1(x))
55
+ x = x + self.mlp(self.ln_2(x))
56
+ return x
57
+
58
+ class ControlledUnetModel(UNetModel):
59
+ #implemented in the ldm unet
60
+ pass
61
+
62
+ class ControlNet(nn.Module):
63
+ def __init__(
64
+ self,
65
+ image_size,
66
+ in_channels,
67
+ model_channels,
68
+ hint_channels,
69
+ num_res_blocks,
70
+ dropout=0,
71
+ channel_mult=(1, 2, 4, 8),
72
+ conv_resample=True,
73
+ dims=2,
74
+ num_classes=None,
75
+ use_checkpoint=False,
76
+ dtype=torch.float32,
77
+ num_heads=-1,
78
+ num_head_channels=-1,
79
+ num_heads_upsample=-1,
80
+ use_scale_shift_norm=False,
81
+ resblock_updown=False,
82
+ use_new_attention_order=False,
83
+ use_spatial_transformer=False, # custom transformer support
84
+ transformer_depth=1, # custom transformer support
85
+ context_dim=None, # custom transformer support
86
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
87
+ legacy=True,
88
+ disable_self_attentions=None,
89
+ num_attention_blocks=None,
90
+ disable_middle_self_attn=False,
91
+ use_linear_in_transformer=False,
92
+ adm_in_channels=None,
93
+ transformer_depth_middle=None,
94
+ transformer_depth_output=None,
95
+ attn_precision=None,
96
+ union_controlnet_num_control_type=None,
97
+ device=None,
98
+ operations=totoro.ops.disable_weight_init,
99
+ **kwargs,
100
+ ):
101
+ super().__init__()
102
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
103
+ if use_spatial_transformer:
104
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
105
+
106
+ if context_dim is not None:
107
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
108
+ # from omegaconf.listconfig import ListConfig
109
+ # if type(context_dim) == ListConfig:
110
+ # context_dim = list(context_dim)
111
+
112
+ if num_heads_upsample == -1:
113
+ num_heads_upsample = num_heads
114
+
115
+ if num_heads == -1:
116
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
117
+
118
+ if num_head_channels == -1:
119
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
120
+
121
+ self.dims = dims
122
+ self.image_size = image_size
123
+ self.in_channels = in_channels
124
+ self.model_channels = model_channels
125
+
126
+ if isinstance(num_res_blocks, int):
127
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
128
+ else:
129
+ if len(num_res_blocks) != len(channel_mult):
130
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
131
+ "as a list/tuple (per-level) with the same length as channel_mult")
132
+        self.num_res_blocks = num_res_blocks
+
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+
+        transformer_depth = transformer_depth[:]
+
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.dtype = dtype
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+        )
+
+        if self.num_classes is not None:
+            if isinstance(self.num_classes, int):
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            elif self.num_classes == "continuous":
+                print("setting up linear c_adm embedding layer")
+                self.label_emb = nn.Linear(1, time_embed_dim)
+            elif self.num_classes == "sequential":
+                assert adm_in_channels is not None
+                self.label_emb = nn.Sequential(
+                    nn.Sequential(
+                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
+                        nn.SiLU(),
+                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+                    )
+                )
+            else:
+                raise ValueError()
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+                )
+            ]
+        )
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
+
+        self.input_hint_block = TimestepEmbedSequential(
+            operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+            nn.SiLU(),
+            operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+        )
+
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype,
+                        device=device,
+                        operations=operations,
+                    )
+                ]
+                ch = mult * model_channels
+                num_transformers = transformer_depth.pop(0)
+                if num_transformers > 0:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        #num_heads = 1
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            SpatialTransformer(
+                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                            dtype=self.dtype,
+                            device=device,
+                            operations=operations
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            #num_heads = 1
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        mid_block = [
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype,
+                device=device,
+                operations=operations
+            )]
+        if transformer_depth_middle >= 0:
+            mid_block += [SpatialTransformer( # always uses a self-attn
+                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype,
+                device=device,
+                operations=operations
+            )]
+        self.middle_block = TimestepEmbedSequential(*mid_block)
+        self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
+        self._feature_size += ch
+
+        if union_controlnet_num_control_type is not None:
+            self.num_control_type = union_controlnet_num_control_type
+            num_trans_channel = 320
+            num_trans_head = 8
+            num_trans_layer = 1
+            num_proj_channel = 320
+            # task_scale_factor = num_trans_channel ** 0.5
+            self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
+
+            self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
+            self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
+            #-----------------------------------------------------------------------------------------------------
+
+            control_add_embed_dim = 256
+            class ControlAddEmbedding(nn.Module):
+                def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
+                    super().__init__()
+                    self.num_control_type = num_control_type
+                    self.in_dim = in_dim
+                    self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
+                    self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
+                def forward(self, control_type, dtype, device):
+                    c_type = torch.zeros((self.num_control_type,), device=device)
+                    c_type[control_type] = 1.0
+                    c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
+                    return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
+
+            self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
+        else:
+            self.task_embedding = None
+            self.control_add_embedding = None
+
+    def union_controlnet_merge(self, hint, control_type, emb, context):
+        # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
+        inputs = []
+        condition_list = []
+
+        for idx in range(min(1, len(control_type))):
+            controlnet_cond = self.input_hint_block(hint[idx], emb, context)
+            feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
+            if idx < len(control_type):
+                feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
+
+            inputs.append(feat_seq.unsqueeze(1))
+            condition_list.append(controlnet_cond)
+
+        x = torch.cat(inputs, dim=1)
+        x = self.transformer_layes(x)
+        controlnet_cond_fuser = None
+        for idx in range(len(control_type)):
+            alpha = self.spatial_ch_projs(x[:, idx])
+            alpha = alpha.unsqueeze(-1).unsqueeze(-1)
+            o = condition_list[idx] + alpha
+            if controlnet_cond_fuser is None:
+                controlnet_cond_fuser = o
+            else:
+                controlnet_cond_fuser += o
+        return controlnet_cond_fuser
+
+    def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
+        return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
+
+    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+        emb = self.time_embed(t_emb)
+
+        guided_hint = None
+        if self.control_add_embedding is not None: #Union Controlnet
+            control_type = kwargs.get("control_type", [])
+
+            if any([c >= self.num_control_type for c in control_type]):
+                max_type = max(control_type)
+                max_type_name = {
+                    v: k for k, v in UNION_CONTROLNET_TYPES.items()
+                }[max_type]
+                raise ValueError(
+                    f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
+                    f"({self.num_control_type}) supported.\n" +
+                    "Please consider using the ProMax ControlNet Union model.\n" +
+                    "https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
+                )
+
+            emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
+            if len(control_type) > 0:
+                if len(hint.shape) < 5:
+                    hint = hint.unsqueeze(dim=0)
+                guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
+
+        if guided_hint is None:
+            guided_hint = self.input_hint_block(hint, emb, context)
+
+        out_output = []
+        out_middle = []
+
+        hs = []
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+
+        h = x
+        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+            if guided_hint is not None:
+                h = module(h, emb, context)
+                h += guided_hint
+                guided_hint = None
+            else:
+                h = module(h, emb, context)
+            out_output.append(zero_conv(h, emb, context))
+
+        h = self.middle_block(h, emb, context)
+        out_middle.append(self.middle_block_out(h, emb, context))
+
+        return {"middle": out_middle, "output": out_output}
+
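Note on consumption (an editor's sketch, not part of the diff): unlike the base UNet, this ControlNet's forward returns per-block zero-conv features under "output" plus a single "middle" feature, and the caller is expected to add them to the matching UNet skip connections. A minimal illustration, where apply_control, hs, h and strength are hypothetical names for the caller's skip stack, middle activation and conditioning scale:

    def apply_control(hs, h, control, strength=1.0):
        # hs: skip tensors from the UNet input blocks, in the same order as control["output"]
        for i, residual in enumerate(control["output"]):
            if residual is not None:
                hs[i] = hs[i] + residual * strength
        # control["middle"] holds a single tensor for the UNet middle block
        h = h + control["middle"][0] * strength
        return hs, h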
totoro/cldm/control_types.py ADDED
@@ -0,0 +1,10 @@
+UNION_CONTROLNET_TYPES = {
+    "openpose": 0,
+    "depth": 1,
+    "hed/pidi/scribble/ted": 2,
+    "canny/lineart/anime_lineart/mlsd": 3,
+    "normal": 4,
+    "segment": 5,
+    "tile": 6,
+    "repaint": 7,
+}
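A brief usage sketch (illustrative only, not part of control_types.py): the integer values above are the indices a union ControlNet expects in its control_type list, so a caller might translate readable names like this, where control_type_indices is a hypothetical helper:

    from totoro.cldm.control_types import UNION_CONTROLNET_TYPES

    def control_type_indices(names):
        # e.g. control_type_indices(["openpose", "depth"]) -> [0, 1]
        return [UNION_CONTROLNET_TYPES[name] for name in names]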
totoro/cldm/mmdit.py ADDED
@@ -0,0 +1,77 @@
+import torch
+from typing import Dict, Optional
+import totoro.ldm.modules.diffusionmodules.mmdit
+
+class ControlNet(totoro.ldm.modules.diffusionmodules.mmdit.MMDiT):
+    def __init__(
+        self,
+        num_blocks = None,
+        dtype = None,
+        device = None,
+        operations = None,
+        **kwargs,
+    ):
+        super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
+        # controlnet_blocks
+        self.controlnet_blocks = torch.nn.ModuleList([])
+        for _ in range(len(self.joint_blocks)):
+            self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
+
+        self.pos_embed_input = totoro.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
+            None,
+            self.patch_size,
+            self.in_channels,
+            self.hidden_size,
+            bias=True,
+            strict_img_size=False,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        context: Optional[torch.Tensor] = None,
+        hint = None,
+    ) -> torch.Tensor:
+
+        #weird sd3 controlnet specific stuff
+        y = torch.zeros_like(y)
+
+        if self.context_processor is not None:
+            context = self.context_processor(context)
+
+        hw = x.shape[-2:]
+        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
+        x += self.pos_embed_input(hint)
+
+        c = self.t_embedder(timesteps, dtype=x.dtype)
+        if y is not None and self.y_embedder is not None:
+            y = self.y_embedder(y)
+            c = c + y
+
+        if context is not None:
+            context = self.context_embedder(context)
+
+        output = []
+
+        blocks = len(self.joint_blocks)
+        for i in range(blocks):
+            context, x = self.joint_blocks[i](
+                context,
+                x,
+                c=c,
+                use_checkpoint=self.use_checkpoint,
+            )
+
+            out = self.controlnet_blocks[i](x)
+            count = self.depth // blocks
+            if i == blocks - 1:
+                count -= 1
+            for j in range(count):
+                output.append(out)
+
+        return {"output": output}
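One reading of the tail of that loop, hedged since the diff carries no comments: when the ControlNet has fewer joint blocks than the base MMDiT's depth, each per-block output is appended depth // blocks times so that every base block can receive a residual, with the final block contributing one fewer (the count -= 1 branch). A throwaway sketch with assumed sizes:

    # Assumed sizes for illustration: base depth 24, ControlNet with 12 joint blocks.
    depth, blocks = 24, 12
    schedule = []
    for i in range(blocks):
        count = depth // blocks      # 2 repeats per ControlNet block here
        if i == blocks - 1:
            count -= 1               # mirrors the count -= 1 in forward()
        schedule.extend([i] * count)
    print(len(schedule))             # 23 residual entries in this example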
totoro/cli_args.py ADDED
@@ -0,0 +1,180 @@
+import argparse
+import enum
+import os
+from typing import Optional
+import totoro.options
+
+
+class EnumAction(argparse.Action):
+    """
+    Argparse action for handling Enums
+    """
+    def __init__(self, **kwargs):
+        # Pop off the type value
+        enum_type = kwargs.pop("type", None)
+
+        # Ensure an Enum subclass is provided
+        if enum_type is None:
+            raise ValueError("type must be assigned an Enum when using EnumAction")
+        if not issubclass(enum_type, enum.Enum):
+            raise TypeError("type must be an Enum when using EnumAction")
+
+        # Generate choices from the Enum
+        choices = tuple(e.value for e in enum_type)
+        kwargs.setdefault("choices", choices)
+        kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
+
+        super(EnumAction, self).__init__(**kwargs)
+
+        self._enum = enum_type
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        # Convert value back into an Enum
+        value = self._enum(values)
+        setattr(namespace, self.dest, value)
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0 (listens on all interfaces).")
+parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
+parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS and makes the app accessible at https://...; requires --tls-certfile to function.")
+parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS and makes the app accessible at https://...; requires --tls-keyfile to function.")
+parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
+parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
+
+parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
+parser.add_argument("--output-directory", type=str, default=None, help="Set the totoroUI output directory.")
+parser.add_argument("--temp-directory", type=str, default=None, help="Set the totoroUI temp directory (default is in the totoroUI directory).")
+parser.add_argument("--input-directory", type=str, default=None, help="Set the totoroUI input directory.")
+parser.add_argument("--auto-launch", action="store_true", help="Automatically launch totoroUI in the default browser.")
+parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
+parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
+cm_group = parser.add_mutually_exclusive_group()
+cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
+cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
+
+
+fp_group = parser.add_mutually_exclusive_group()
+fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
+fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
+
+fpunet_group = parser.add_mutually_exclusive_group()
+fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
+fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
+fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
+
+fpvae_group = parser.add_mutually_exclusive_group()
+fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
+fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
+fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
+
+parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
+
+fpte_group = parser.add_mutually_exclusive_group()
+fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
+fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
+fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
+fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
+
+parser.add_argument("--force-channels-last", action="store_true", help="Force channels-last memory format when running inference with the models.")
+
+parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
+
+parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
+
+class LatentPreviewMethod(enum.Enum):
+    NoPreviews = "none"
+    Auto = "auto"
+    Latent2RGB = "latent2rgb"
+    TAESD = "taesd"
+
+parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
+
+attn_group = parser.add_mutually_exclusive_group()
+attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization. Ignored when xformers is used.")
+attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
+
+parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
+
+upcast = parser.add_mutually_exclusive_group()
+upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
+upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
+
+
+vram_group = parser.add_mutually_exclusive_group()
+vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc.) on the GPU.")
+vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
+vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
+vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
+vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
+vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
+
+parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
+
+parser.add_argument("--disable-smart-memory", action="store_true", help="Force totoroUI to aggressively offload to regular ram instead of keeping models in vram when it can.")
+parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
+
+parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
+parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
+parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
+
+parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
+parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
+
+parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
+
+parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
+
+# The default built-in provider hosted under web/
+DEFAULT_VERSION_STRING = "totoroanonymous/totoroUI@latest"
+
+parser.add_argument(
+    "--front-end-version",
+    type=str,
+    default=DEFAULT_VERSION_STRING,
+    help="""
+    Specifies the version of the frontend to be used. This command needs internet connectivity to query and
+    download available frontend implementations from GitHub releases.
+
+    The version string should be in the format of:
+    [repoOwner]/[repoName]@[version]
+    where version is one of: "latest" or a valid version number (e.g. "1.0.0")
+    """,
+)
+
+def is_valid_directory(path: Optional[str]) -> Optional[str]:
+    """Validate if the given path is a directory."""
+    if path is None:
+        return None
+
+    if not os.path.isdir(path):
+        raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
+    return path
+
+parser.add_argument(
+    "--front-end-root",
+    type=is_valid_directory,
+    default=None,
+    help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
+)
+
+if totoro.options.args_parsing:
+    args = parser.parse_args()
+else:
+    args = parser.parse_args([])
+
+if args.windows_standalone_build:
+    args.auto_launch = True
+
+if args.disable_auto_launch:
+    args.auto_launch = False
+
+import logging
+logging_level = logging.INFO
+if args.verbose:
+    logging_level = logging.DEBUG
+
+logging.basicConfig(format="%(message)s", level=logging_level)
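Since this module builds the parser and resolves args at import time (honoring totoro.options.args_parsing), downstream code normally just imports the resolved namespace rather than re-parsing; a minimal sketch:

    # Hedged usage sketch: read the parsed flags from anywhere in the code base.
    from totoro.cli_args import args

    if args.lowvram:
        print("low VRAM mode requested")
    print(f"listening on {args.listen}:{args.port}")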
totoro/clip_config_bigg.json ADDED
@@ -0,0 +1,23 @@
+{
+  "architectures": [
+    "CLIPTextModel"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 0,
+  "dropout": 0.0,
+  "eos_token_id": 49407,
+  "hidden_act": "gelu",
+  "hidden_size": 1280,
+  "initializer_factor": 1.0,
+  "initializer_range": 0.02,
+  "intermediate_size": 5120,
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 77,
+  "model_type": "clip_text_model",
+  "num_attention_heads": 20,
+  "num_hidden_layers": 32,
+  "pad_token_id": 1,
+  "projection_dim": 1280,
+  "torch_dtype": "float32",
+  "vocab_size": 49408
+}
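These values appear to describe the larger SDXL-style OpenCLIP text encoder (hidden size 1280, 32 layers, 20 heads, i.e. 64 dimensions per head). A throwaway check, with the relative path assumed to be run from the repository root:

    import json

    with open("totoro/clip_config_bigg.json") as f:
        cfg = json.load(f)

    head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]
    print(head_dim)  # 1280 // 20 == 64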