|
import click |
|
import torch |
|
import os |
|
from PIL import Image |
|
from io import BytesIO |
|
from itertools import batched |
|
from tqdm import tqdm |
|
import torchvision.transforms as T |
|
|
|
from diffusers import StableDiffusionImg2ImgPipeline |
|
|
|
|
|
# (width, height) the images are fed to the pipeline at, and the size written
# to disk. NOTE(review): inputs are stretched to 768x512 and outputs squashed
# back to 512x512 regardless of the source aspect ratio — confirm intended.
_WORK_SIZE = (768, 512)
_OUTPUT_SIZE = (512, 512)


def _list_images(directory):
    """Return the sorted names of regular files in *directory* Pillow knows.

    Subdirectories and files whose extension Pillow has no registered format
    for are skipped: they cannot be opened as images reliably, and Pillow picks
    the save format from the extension, so they could not be written back.
    Sorting makes batch composition and progress order deterministic.
    """
    known_extensions = Image.registered_extensions()
    return sorted(
        name
        for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
        and os.path.splitext(name)[1].lower() in known_extensions
    )


def _load_rgb(path):
    """Load the image at *path* as an RGB image and close the file handle.

    Converting guards against RGBA / palette / grayscale inputs, which
    presumably would not match the 3-channel input the model expects.
    ``convert`` returns a fully loaded copy, so it outlives the ``with`` block.
    """
    with Image.open(path) as image:
        return image.convert("RGB")


@click.command()
@click.option(
    "--input",
    required=True,
    type=click.Path(exists=True, file_okay=False),
    help="Directory containing the source images.",
)
@click.option(
    "--output",
    required=True,
    type=click.Path(file_okay=False),
    help="Directory to write edited images to (created if missing).",
)
@click.option("--prompt", required=True, help="Text prompt applied to every image.")
@click.option(
    "--strength",
    type=click.FloatRange(0.0, 1.0),
    default=0.5,
    help="Img2img strength passed to the pipeline (0-1).",
)
@click.option(
    "--batch_size",
    type=click.IntRange(min=1),
    default=1,
    help="Number of images per pipeline call.",
)
def sdedit(input, output, prompt, strength, batch_size):
    """Edit every image in a directory with Stable Diffusion v1.5 img2img.

    Each image in INPUT is converted to RGB, resized, passed through
    StableDiffusionImg2ImgPipeline conditioned on PROMPT, resized to the
    output size and saved under the same file name in OUTPUT.
    """
    # `input` shadows the builtin, but the name is fixed by the --input option.
    os.makedirs(output, exist_ok=True)

    names = _list_images(input)
    if not names:
        click.echo(f"No images found in {input}")
        return  # Avoid loading the model onto the GPU for nothing.

    batches = list(batched(names, batch_size))

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    for batch in tqdm(batches):
        images = [
            _load_rgb(os.path.join(input, name)).resize(_WORK_SIZE) for name in batch
        ]
        # One prompt per image: the final batch may be shorter than
        # batch_size, and the prompt list must match the image list length.
        prompts = [prompt] * len(images)

        output_images = pipe(prompt=prompts, image=images, strength=strength).images

        for name, output_image in zip(batch, output_images):
            output_image.resize(_OUTPUT_SIZE).save(os.path.join(output, name))
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: let click parse sys.argv and run the command.
    # Calling the command object directly delegates to .main(), so invoke it explicitly.
    sdedit.main()
|
|