|
import os |
|
import torch |
|
import numpy as np |
|
from PIL import Image, ImageDraw, ImageFont |
|
import random |
|
import json |
|
import gradio as gr |
|
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models |
|
from modelscope import dataset_snapshot_download |
|
|
|
|
|
# Fetch the EliGen entity-control example assets (prompt json + mask images).
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/entity_control/*")

example_json = 'data/examples/eligen/entity_control/ui_examples.json'
with open(example_json, 'r') as f:
    examples = json.load(f)['examples']

# Attach the entity mask images to each example: one RGB mask per local
# prompt, stored as example_<id>/<index>.png in the downloaded dataset.
for example in examples:
    example_id = example['example_id']
    entity_prompts = example['local_prompt_list']
    example['mask_lists'] = [
        Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB')
        for i in range(len(entity_prompts))
    ]
|
|
|
def create_canvas_data(background, masks):
    """Build a Gradio ImageEditor-style dict from a background and masks.

    background: HxWx3 or HxWx4 uint8 array; masks: list of single-channel
    (or multi-channel, first channel used) uint8 arrays, or None entries.
    Returns {"background", "layers", "composite"} with RGBA arrays.
    """
    # Promote an RGB background to RGBA with a fully opaque alpha channel.
    if background.shape[-1] == 3:
        alpha = np.full(background.shape[:2], 255, dtype=np.uint8)
        background = np.dstack([background, alpha])

    layers = []
    for mask in masks:
        if mask is None:
            # Missing mask -> fully transparent layer.
            layers.append(np.zeros_like(background))
            continue
        channel = mask if mask.ndim == 2 else mask[..., 0]
        rows, cols = channel.shape
        layer = np.zeros((rows, cols, 4), dtype=np.uint8)
        layer[..., -1] = channel  # mask values drive the layer's alpha
        layers.append(layer)

    # Stack layers over the background: wherever a layer has alpha, it wins.
    composite = background.copy()
    for layer in layers:
        if layer.size > 0:
            composite = np.where(layer[..., -1:] > 0, layer, composite)

    return {
        "background": background,
        "layers": layers,
        "composite": composite,
    }
|
|
|
def load_example(load_example_button):
    """Resolve a "Load Example N" button label into the UI output values.

    Returns, in output order: inference steps, global prompt, negative
    prompt, seed, one local prompt per painter layer (padded with ""),
    then one canvas dict per painter layer (padded with blank masks).
    """
    example = examples[int(load_example_button.split()[-1]) - 1]
    max_layers = config["max_num_painter_layers"]

    local_prompts = list(example["local_prompt_list"])
    result = [
        50,
        example["global_prompt"],
        example["negative_prompt"],
        example["seed"],
    ]
    # Pad local prompts with empty strings up to the painter layer count.
    result += local_prompts + [""] * (max_layers - len(local_prompts))

    # Single-channel masks for each entity, padded with blank masks.
    masks = [np.array(mask.convert("L")) for mask in example["mask_lists"]]
    while len(masks) < max_layers:
        blank = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
        masks.append(blank)

    rows, cols = masks[0].shape
    background = np.ones((rows, cols, 4), dtype=np.uint8) * 255
    result += [create_canvas_data(background, [mask]) for mask in masks]
    return result
|
|
|
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
    """Persist the entity masks and prompts for one generation run.

    Saves each mask as <i>.png plus a prompts.json containing the global
    prompt, the per-entity prompts, and the seed, all under
    workdirs/tmp_mask/<random_dir>/.
    """
    save_dir = os.path.join('workdirs/tmp_mask', random_dir)
    print(f'save to {save_dir}')
    os.makedirs(save_dir, exist_ok=True)
    for i, mask in enumerate(masks):
        save_path = os.path.join(save_dir, f'{i}.png')
        mask.save(save_path)
    sample = {
        "global_prompt": global_prompt,
        "mask_prompts": mask_prompts,
        "seed": seed,
    }
    # Plain string literal (was an f-string with no placeholders).
    with open(os.path.join(save_dir, "prompts.json"), 'w') as f:
        json.dump(sample, f, indent=4)
|
|
|
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
    """Overlay each entity mask on the image as a translucent colour patch,
    labelled with its prompt near the mask's top-left corner.

    White (255, 255, 255) pixels in each mask become the mask's colour;
    every other pixel becomes fully transparent. Returns an RGBA image.
    """
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
    # Eight distinct translucent colours, repeated to cover up to 16 masks.
    colors = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ] * 2

    if use_random_colors:
        colors = [
            (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80)
            for _ in range(len(masks))
        ]

    try:
        font = ImageFont.truetype("arial", font_size)
    except IOError:
        font = ImageFont.load_default(font_size)

    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
        if mask is None:
            continue

        # Recolour: pure-white mask pixels take the translucent colour,
        # everything else becomes fully transparent.
        mask_rgba = mask.convert('RGBA')
        recolored = [
            color if pixel[:3] == (255, 255, 255) else (0, 0, 0, 0)
            for pixel in mask_rgba.getdata()
        ]
        mask_rgba.putdata(recolored)

        bbox = mask.getbbox()
        if bbox is None:
            continue  # empty mask: nothing to draw or label
        draw = ImageDraw.Draw(mask_rgba)
        draw.text((bbox[0] + 10, bbox[1] + 10), mask_prompt, fill=(255, 255, 255, 255), font=font)

        overlay = Image.alpha_composite(overlay, mask_rgba)

    return Image.alpha_composite(image.convert('RGBA'), overlay)
|
|
|
# Global app configuration: supported model families, their pipeline classes
# and default sampling parameters, plus UI limits.
config = {
    "model_config": {
        "FLUX": {
            "model_folder": "models/FLUX",
            "pipeline_class": FluxImagePipeline,
            "default_parameters": {
                "cfg_scale": 3.0,
                "embedded_guidance": 3.5,
                "num_inference_steps": 30,
            }
        },
    },
    # Number of entity tabs (local prompt + mask painter) shown in the UI.
    "max_num_painter_layers": 8,
    # NOTE(review): declared but not enforced anywhere visible in this file.
    "max_num_model_cache": 1,
}

# Cache of loaded models keyed by "<model_type>:<model_path>".
model_dict = {}
|
|
|
def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
    """Load (or fetch from cache) the FLUX pipeline with the EliGen LoRA.

    Returns a (model_manager, pipeline) tuple, cached in the module-level
    model_dict keyed by "<model_type>:<model_path>".

    NOTE(review): the base model id is hard-coded to "FLUX.1-dev" below;
    model_path only participates in the cache key. (A dead local that joined
    model_folder + model_path without ever using it has been removed.)
    """
    global model_dict
    model_key = f"{model_type}:{model_path}"
    if model_key in model_dict:
        return model_dict[model_key]
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
    # Apply the EliGen entity-control LoRA on top of the base model.
    model_manager.load_lora(
        download_customized_models(
            model_id="DiffSynth-Studio/Eligen",
            origin_file_path="model_bf16.safetensors",
            local_dir="models/lora/entity_control",
        ),
        lora_alpha=1,
    )
    pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
    model_dict[model_key] = model_manager, pipe
    return model_manager, pipe
|
|
|
|
|
with gr.Blocks() as app:
    # Header / usage instructions shown at the top of the page.
    gr.Markdown(
        """## EliGen: Entity-Level Controllable Text-to-Image Model
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
"""
    )

    # Shown while the model loads at startup; the main UI stays hidden
    # until initialize_model has run.
    loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
    main_interface = gr.Column(visible=False)

    def initialize_model():
        # Load the default model once when the app starts. The main UI is
        # revealed whether or not loading succeeded; the status box reports
        # any failure.
        try:
            load_model()
            return {
                loading_status: gr.update(value="Model loaded successfully!", visible=False),
                main_interface: gr.update(visible=True),
            }
        except Exception as e:
            print(f'Failed to load model with error: {e}')
            return {
                loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
                main_interface: gr.update(visible=True),
            }

    app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
|
|
|
    with main_interface:
        with gr.Row():
            local_prompt_list = []  # Textbox components, one per entity tab
            canvas_list = []        # ImageEditor components, one per entity tab
            # Per-session directory name used when saving masks/prompts.
            random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
            with gr.Column(scale=382, min_width=100):
                model_type = gr.State('FLUX')
                model_path = gr.State('FLUX.1-dev')
                with gr.Accordion(label="Global prompt"):
                    prompt = gr.Textbox(label="Global Prompt", lines=3)
                    negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
                with gr.Accordion(label="Inference Options", open=True):
                    seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
                    cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
                    embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
                    height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
                    width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
                with gr.Accordion(label="Inpaint Input Image", open=False):
                    input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
                    # Hidden/disabled: background weighting is not exposed yet.
                    background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)

                    with gr.Column():
                        reset_input_button = gr.Button(value="Reset Inpaint Input")
                        send_input_to_painter = gr.Button(value="Set as painter's background")

                    @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
                    def reset_input_image(input_image):
                        # Clear the inpaint input image.
                        return None
|
|
|
            with gr.Column(scale=618, min_width=100):
                with gr.Accordion(label="Entity Painter"):
                    # One tab per entity: a local prompt plus a mask painter.
                    for painter_layer_id in range(config["max_num_painter_layers"]):
                        with gr.Tab(label=f"Entity {painter_layer_id}"):
                            local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
                            canvas = gr.ImageEditor(
                                canvas_size=(512, 512),
                                sources=None,
                                layers=False,
                                interactive=True,
                                image_mode="RGBA",
                                brush=gr.Brush(
                                    default_size=50,
                                    default_color="#000000",
                                    colors=["#000000"],
                                ),
                                label="Entity Mask Painter",
                                key=f"canvas_{painter_layer_id}",
                                width=width,
                                height=height,
                            )

                            # Keep each painter canvas in sync with the requested
                            # output size (also fired when the canvas is cleared).
                            @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
                            def resize_canvas(height, width, canvas):
                                h, w = canvas["background"].shape[:2]
                                if h != height or width != w:
                                    # Size changed: return a fresh white canvas.
                                    return np.ones((height, width, 3), dtype=np.uint8) * 255
                                else:
                                    return canvas

                            local_prompt_list.append(local_prompt)
                            canvas_list.append(canvas)
                with gr.Accordion(label="Results"):
                    run_button = gr.Button(value="Generate", variant="primary")
                    output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
                    with gr.Row():
                        with gr.Column():
                            output_to_painter_button = gr.Button(value="Set as painter's background")
                        with gr.Column():
                            return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
                            # Hidden/disabled: feeding output back as inpaint
                            # input is not exposed yet.
                            output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
                    # Hold the raw generated image and its mask-annotated version.
                    real_output = gr.State(None)
                    mask_out = gr.State(None)
|
|
|
            @gr.on(
                inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
                outputs=[output_image, real_output, mask_out],
                triggers=run_button.click
            )
            def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
                """Run the pipeline with the global prompt plus every entity
                whose local prompt is non-empty, using its painted mask."""
                _, pipe = load_model(model_type, model_path)
                input_params = {
                    "prompt": prompt,
                    "negative_prompt": negative_prompt,
                    "cfg_scale": cfg_scale,
                    "num_inference_steps": num_inference_steps,
                    "height": height,
                    "width": width,
                    "progress_bar_cmd": progress.tqdm,
                }
                if isinstance(pipe, FluxImagePipeline):
                    input_params["embedded_guidance"] = embedded_guidance
                    # Optional inpainting: use the provided image as the base.
                    if input_image is not None:
                        input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
                        input_params["enable_eligen_inpaint"] = True

                # *args packs the entity textboxes followed by the canvases,
                # max_num_painter_layers of each.
                local_prompt_list, canvas_list = (
                    args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
                    args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
                )
                local_prompts, masks = [], []
                for local_prompt, canvas in zip(local_prompt_list, canvas_list):
                    # Only entities with a non-empty prompt participate; the
                    # mask is the alpha channel of the painter's first layer.
                    if isinstance(local_prompt, str) and len(local_prompt) > 0:
                        local_prompts.append(local_prompt)
                        masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
                entity_masks = None if len(masks) == 0 else masks
                entity_prompts = None if len(local_prompts) == 0 else local_prompts
                input_params.update({
                    "eligen_entity_prompts": entity_prompts,
                    "eligen_entity_masks": entity_masks,
                })
                torch.manual_seed(seed)

                image = pipe(**input_params)
                masks = [mask.resize(image.size) for mask in masks]
                image_with_mask = visualize_masks(image, masks, local_prompts)

                # NOTE(review): fresh gr.State wrappers are returned here, and
                # the downstream handlers read `.value` from what they receive —
                # unusual, but consistent across this file; confirm against the
                # Gradio version in use before changing.
                real_output = gr.State(image)
                mask_out = gr.State(image_with_mask)

                if return_with_mask:
                    return image_with_mask, real_output, mask_out
                return image, real_output, mask_out
|
|
|
            @gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
            def send_input_to_painter_background(input_image, *canvas_list):
                # Copy the inpaint input image onto every painter canvas
                # background, resized to each canvas's current size.
                if input_image is None:
                    return tuple(canvas_list)
                for canvas in canvas_list:
                    h, w = canvas["background"].shape[:2]
                    canvas["background"] = input_image.resize((w, h))
                return tuple(canvas_list)

            @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
            def send_output_to_painter_background(real_output, *canvas_list):
                # Copy the latest generated image onto every painter canvas
                # background, resized to each canvas's current size.
                if real_output is None:
                    return tuple(canvas_list)
                for canvas in canvas_list:
                    h, w = canvas["background"].shape[:2]
                    canvas["background"] = real_output.value.resize((w, h))
                return tuple(canvas_list)

            @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
            def show_output(return_with_mask, real_output, mask_out):
                # Toggle between the raw result and the mask-annotated result.
                if return_with_mask:
                    return mask_out.value
                else:
                    return real_output.value

            @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
            def send_output_to_pipe_input(real_output):
                # Feed the last generated image back in as the inpaint input.
                return real_output.value
|
|
|
with gr.Column(): |
|
gr.Markdown("## Examples") |
|
for i in range(0, len(examples), 2): |
|
with gr.Row(): |
|
if i < len(examples): |
|
example = examples[i] |
|
with gr.Column(): |
|
example_image = gr.Image( |
|
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png", |
|
label=example["description"], |
|
interactive=False, |
|
width=1024, |
|
height=512 |
|
) |
|
load_example_button = gr.Button(value=f"Load Example {example['example_id']}") |
|
load_example_button.click( |
|
load_example, |
|
inputs=[load_example_button], |
|
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list |
|
) |
|
|
|
if i + 1 < len(examples): |
|
example = examples[i + 1] |
|
with gr.Column(): |
|
example_image = gr.Image( |
|
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png", |
|
label=example["description"], |
|
interactive=False, |
|
width=1024, |
|
height=512 |
|
) |
|
load_example_button = gr.Button(value=f"Load Example {example['example_id']}") |
|
load_example_button.click( |
|
load_example, |
|
inputs=[load_example_button], |
|
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list |
|
) |
|
app.config["show_progress"] = "hidden" |
|
app.launch() |
|
|