import os
from io import BytesIO
from itertools import batched

import click
import torch
import torchvision.transforms as T
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
from tqdm import tqdm


@click.command()
@click.option("--input", required=True, help="Directory containing the images to edit.")
@click.option("--output", required=True, help="Directory the edited images are written to.")
@click.option("--prompt", required=True, help="Text prompt applied to every image.")
@click.option(
    "--strength",
    type=click.FloatRange(0.0, 1.0),
    default=0.5,
    help="Strength value forwarded to the img2img pipeline.",
)
@click.option(
    "--batch_size",
    type=click.IntRange(min=1),
    default=1,
    help="Number of images sent through the pipeline per call.",
)
def sdedit(input, output, prompt, strength, batch_size):
    """Run Stable Diffusion img2img over every image file in a directory.

    Each regular file in ``input`` is opened with PIL, converted to RGB,
    resized to 768x512 (width x height, aspect ratio not preserved) and passed
    to ``StableDiffusionImg2ImgPipeline`` together with ``prompt``. Every
    result is resized to 512x512 and saved in ``output`` under the same file
    name as its source image.

    Args:
        input: Path of the directory holding the source images.
        output: Path of the directory for results; created if missing.
        prompt: Text prompt, repeated once per image in a batch.
        strength: ``strength`` argument forwarded to the pipeline.
        batch_size: Maximum number of images per pipeline call.
    """
    os.makedirs(output, exist_ok=True)

    # Only regular files are candidates (sub-directories cannot be opened as
    # images); sorting makes the processing order reproducible.
    names = sorted(
        name
        for name in os.listdir(input)
        if os.path.isfile(os.path.join(input, name))
    )
    if not names:
        # Bail out before paying for the model load when there is no work.
        click.echo(f"No input files found in {input!r}; nothing to do.")
        return

    # Materialized so tqdm can show a total.
    batches = list(batched(names, batch_size))

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    for batch in tqdm(batches):
        images = []
        for name in batch:
            # ``convert`` loads the pixel data and returns a new image, so the
            # underlying file handle can be closed right away.
            with Image.open(os.path.join(input, name)) as source:
                images.append(source.convert("RGB").resize((768, 512)))

        # The final batch may be shorter than ``batch_size``; the number of
        # prompts must match the number of images actually in this batch.
        prompts = [prompt] * len(batch)

        output_images = pipe(prompt=prompts, image=images, strength=strength).images
        for name, output_image in zip(batch, output_images):
            output_image.resize((512, 512)).save(os.path.join(output, name))


if __name__ == "__main__":
    sdedit()