File size: 1,302 Bytes
cc7ad25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import click
import torch
import os
from PIL import Image
from io import BytesIO
from itertools import batched
from tqdm import tqdm
import torchvision.transforms as T

from diffusers import StableDiffusionImg2ImgPipeline


@click.command()
@click.option("--input", required=True, help="Directory containing the source images.")
@click.option(
    "--output",
    required=True,
    help="Directory the edited images are written to (created if missing).",
)
@click.option("--prompt", required=True, help="Text prompt applied to every image.")
@click.option(
    "--strength",
    # The diffusers img2img pipeline rejects strength values outside [0, 1];
    # validating here gives a CLI error before the model is downloaded/loaded.
    type=click.FloatRange(0.0, 1.0),
    default=0.5,
    show_default=True,
    help="Img2img strength passed to the pipeline.",
)
@click.option(
    "--batch_size",
    # itertools.batched raises ValueError for n < 1.
    type=click.IntRange(min=1),
    default=1,
    show_default=True,
    help="Number of images sent to the pipeline per call.",
)
def sdedit(input, output, prompt, strength, batch_size):
    """Edit every image in a directory with Stable Diffusion img2img.

    Loads ``runwayml/stable-diffusion-v1-5`` in float16 on CUDA. The script
    processes image files from ``input`` in sorted name order, in batches of
    ``batch_size``, using the same ``prompt`` and ``strength`` for every image.
    Each result is written to ``output`` under the same file name as its
    source image.

    Args:
        input: Directory of source images. Sub-directories and files whose
            extension Pillow does not recognise are skipped.
        output: Destination directory. It is created if it does not exist.
        prompt: Text prompt used for every image.
        strength: Value forwarded to the pipeline's ``strength`` argument.
        batch_size: Maximum number of images per pipeline call.
    """
    os.makedirs(output, exist_ok=True)

    # Sort for a deterministic processing order. Skip sub-directories and
    # non-image files (e.g. .DS_Store), which would otherwise crash Image.open.
    image_extensions = Image.registered_extensions()
    names = sorted(
        name
        for name in os.listdir(input)
        if os.path.isfile(os.path.join(input, name))
        and os.path.splitext(name)[1].lower() in image_extensions
    )
    batches = list(batched(names, batch_size))

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    for batch in tqdm(batches):
        images = []
        for name in batch:
            # Close the source file promptly. Converting to RGB drops alpha and
            # expands grayscale, so every input has the 3 channels the model
            # expects. convert() returns a copy, so it stays valid after close.
            with Image.open(os.path.join(input, name)) as source:
                # NOTE(review): images are stretched to 768x512 (width x height)
                # regardless of their aspect ratio, and the results are squashed
                # back to 512x512 below. Confirm this distortion is intended.
                images.append(source.convert("RGB").resize((768, 512)))

        # One prompt per image. The final batch from batched() can be shorter
        # than batch_size, so the prompt list must follow the real batch length.
        prompts = [prompt] * len(batch)

        output_images = pipe(prompt=prompts, image=images, strength=strength).images

        for name, output_image in zip(batch, output_images):
            output_image.resize((512, 512)).save(os.path.join(output, name))


if __name__ == "__main__":
    # Run the click command explicitly: Command.main() parses sys.argv,
    # invokes sdedit, and exits with the resulting status code. Calling
    # the command object directly is shorthand for this same call.
    sdedit.main()