|
import torch |
|
import numpy as np |
|
from .processors import Processor_id |
|
|
|
|
|
class ControlNetConfigUnit:
    """Configuration record describing a single ControlNet.

    Only stores the values it is given; nothing is loaded or validated here.

    Attributes:
        processor_id: identifier of the processor to pair with the model.
        model_path: location of the ControlNet model (presumably a checkpoint
            path — not checked here).
        scale: weight for this unit (defaults to 1.0).
        skip_processor: boolean flag stored for callers (presumably means
            "do not run the processor"; verify against the consumer).
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
        # Fields are independent plain assignments; order does not matter.
        self.skip_processor = skip_processor
        self.scale = scale
        self.model_path = model_path
        self.processor_id = processor_id
|
|
|
|
|
class ControlNetUnit:
    """Bundle of a processor, a model and a scalar weight.

    The multi-ControlNet managers read ``processor``, ``model`` and ``scale``
    from each unit.
    """

    def __init__(self, processor, model, scale=1.0):
        # Store the three components as-is; no copying or validation.
        self.processor, self.model, self.scale = processor, model, scale
|
|
|
|
|
class MultiControlNetManager:
    """Run several ControlNet units and sum their scale-weighted residuals.

    Each unit supplies a ``processor`` (called on an input image to build a
    conditioning image), a ``model`` (called to produce a sequence of
    residuals) and a scalar ``scale``. At call time every model's residuals
    are multiplied by that unit's scale and summed element-wise across units.
    """

    def __init__(self, controlnet_units=None):
        """Unpack ``controlnet_units`` into parallel processor/model/scale lists.

        Args:
            controlnet_units: iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes (e.g. ``ControlNetUnit``).
                ``None`` (the default) means no units.
        """
        # Materialize once. The units are walked three times below, so a
        # one-shot iterator (e.g. a generator) would otherwise leave
        # ``models`` and ``scales`` empty. ``None`` also replaces the old
        # shared mutable ``[]`` default.
        units = [] if controlnet_units is None else list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def cpu(self):
        """Call ``.cpu()`` on every model. Processors are left untouched."""
        for model in self.models:
            model.cpu()

    def to(self, device):
        """Call ``.to(device)`` on every model and then on every processor."""
        for model in self.models:
            model.to(device)
        for processor in self.processors:
            processor.to(device)

    def process_image(self, image, processor_id=None):
        """Run processor(s) on ``image`` and stack the results into one tensor.

        Args:
            image: input passed unchanged to each processor.
            processor_id: optional index into ``self.processors``. When given,
                only that processor runs; otherwise all of them run in order.

        Returns:
            A float32 tensor with one entry per processor along dim 0. Each
            processor output is divided by 255 and its last axis is moved to
            the front. This assumes 3-D ``(H, W, C)`` outputs with values in
            [0, 255] — TODO confirm against the processors.
        """
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        processed_image = torch.concat([
            # permute(2, 0, 1) requires a 3-D array; unsqueeze adds the stacking dim.
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_image
        ], dim=0)
        return processed_image

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32, **kwargs
    ):
        """Run every unit's model and return the summed, scaled residuals.

        Args:
            sample, timestep, encoder_hidden_states: forwarded positionally to
                each model.
            conditionings: one conditioning per unit, in unit order. Pairs are
                formed with ``zip``, so any surplus on either side is ignored.
            tiled, tile_size, tile_stride: forwarded to each model as keywords.
            **kwargs: extra keyword arguments forwarded to each model.

        Returns:
            A list whose i-th element is the sum over units of
            ``scale * residual[i]``, or ``None`` when no unit ran.
        """
        res_stack = None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning, **kwargs,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                # Each model is told which processor produced its conditioning.
                processor_id=processor.processor_id
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                # Element-wise accumulation; zip truncates to the shorter stack.
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack
|
|
|
|
|
class FluxMultiControlNetManager(MultiControlNetManager):
    """Multi-ControlNet manager whose models return two residual stacks.

    It differs from ``MultiControlNetManager`` in two ways:

    - ``process_image`` returns the raw processor outputs, with no tensor
      conversion.
    - Each model call returns a pair ``(res_stack, single_res_stack)``. Both
      stacks are scaled by the unit's weight and summed across units
      independently.
    """

    def __init__(self, controlnet_units=None):
        """Store the units via the base class.

        Args:
            controlnet_units: iterable of objects exposing ``processor``,
                ``model`` and ``scale``. ``None`` (the default) means no units.
        """
        # Pass a fresh concrete list. This avoids a shared mutable default and
        # lets one-shot iterators (e.g. generators) be passed in safely.
        units = [] if controlnet_units is None else list(controlnet_units)
        super().__init__(controlnet_units=units)

    def process_image(self, image, processor_id=None):
        """Run processor(s) on ``image`` and return their outputs unchanged.

        Args:
            image: input passed unchanged to each processor.
            processor_id: optional index into ``self.processors``. When given,
                only that processor runs.

        Returns:
            A list of processor outputs: one per processor, or a single
            element when ``processor_id`` is given.
        """
        if processor_id is None:
            return [processor(image) for processor in self.processors]
        return [self.processors[processor_id](image)]

    def __call__(self, conditionings, **kwargs):
        """Run every unit and return both summed, scaled residual stacks.

        Args:
            conditionings: one conditioning per unit, in unit order. Pairs are
                formed with ``zip``, so any surplus on either side is ignored.
            **kwargs: forwarded to every model call.

        Returns:
            ``(res_stack, single_res_stack)``. Each is a list whose i-th
            element is the sum over units of ``scale * residual[i]``. Both are
            ``None`` when no unit ran.
        """
        res_stack, single_res_stack = None, None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            res_stack_, single_res_stack_ = model(
                controlnet_conditioning=conditioning,
                processor_id=processor.processor_id,
                **kwargs,
            )
            res_stack_ = [res * scale for res in res_stack_]
            single_res_stack_ = [res * scale for res in single_res_stack_]
            if res_stack is None:
                res_stack, single_res_stack = res_stack_, single_res_stack_
            else:
                # Element-wise accumulation; zip truncates to the shorter stack.
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
                single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
        return res_stack, single_res_stack
|
|