File size: 2,814 Bytes
bfe4ff1
 
 
b402ddd
bfe4ff1
b402ddd
 
 
bfe4ff1
b402ddd
 
6691b0e
 
bfe4ff1
b402ddd
 
 
 
 
bfe4ff1
b402ddd
 
 
 
 
bfe4ff1
 
 
 
 
 
 
b402ddd
bfe4ff1
b402ddd
bfe4ff1
 
 
 
b402ddd
bfe4ff1
 
 
b402ddd
bfe4ff1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b402ddd
 
bfe4ff1
b402ddd
bfe4ff1
 
 
 
 
 
b402ddd
bfe4ff1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import cast, Union

import PIL.Image
import torch

from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor


class EndpointHandler:
    """Decode VAE latents into images for an inference endpoint.

    Loads an ``AutoencoderKL`` from ``path`` onto CUDA in float16. Each call
    optionally inverts the latent scaling configured on the VAE, decodes the
    latents, and post-processes the result.
    """

    def __init__(self, path=""):
        """Load the VAE and build the image post-processor.

        Args:
            path: Model identifier or local directory passed to
                ``AutoencoderKL.from_pretrained``.
        """
        self.device = "cuda"
        self.dtype = torch.float16
        self.vae = cast(AutoencoderKL, AutoencoderKL.from_pretrained(path, torch_dtype=self.dtype).to(self.device, self.dtype).eval())

        # Scale factor handed to VaeImageProcessor: 2 ** (len(block_out_channels) - 1).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    @torch.no_grad()
    def __call__(self, data) -> "Union[torch.Tensor, PIL.Image.Image]":
        """Decode a batch of latents.

        Args:
            data: dict with
                ``"inputs"``: latent tensor, moved to ``self.device`` /
                    ``self.dtype``. Presumably (batch, channels, height,
                    width): per-channel statistics are broadcast over dim 1
                    — TODO confirm against callers.
                ``"parameters"`` (optional, may be ``None``): dict of
                    ``do_scaling`` (bool, default True): invert the VAE's
                        latent scaling before decoding.
                    ``output_type`` (str, default ``"pil"``): ``"pil"``
                        returns a PIL image; any other value returns the
                        decoded tensor unchanged.
                    ``partial_postprocess`` (bool, default False): return a
                        uint8 NHWC tensor; takes precedence over
                        ``output_type``.

        Returns:
            - ``partial_postprocess``: uint8 tensor, channels-last, values
              in [0, 255].
            - ``output_type == "pil"``: the FIRST image of the batch only,
              as a ``PIL.Image.Image``.
            - otherwise: the tensor returned by ``vae.decode``.
        """
        tensor = cast(torch.Tensor, data["inputs"])
        # ``or {}`` also covers an explicit ``"parameters": None``.
        parameters = cast(dict, data.get("parameters") or {})
        do_scaling = cast(bool, parameters.get("do_scaling", True))
        output_type = cast(str, parameters.get("output_type", "pil"))
        partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))

        tensor = tensor.to(self.device, self.dtype)
        if do_scaling:
            tensor = self._unscale_latents(tensor)

        # The method is already decorated with @torch.no_grad().
        image = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])

        if partial_postprocess:
            return self._to_uint8_nhwc(image)
        if output_type == "pil":
            # NOTE: only index 0 of the post-processed batch is returned.
            return cast(PIL.Image.Image, self.image_processor.postprocess(image, output_type="pil")[0])
        return image

    def _unscale_latents(self, tensor: torch.Tensor) -> torch.Tensor:
        """Invert the latent scaling described by the VAE config.

        With ``latents_mean`` and ``latents_std`` both set, computes
        ``x * std / scaling_factor + mean`` per channel. Otherwise computes
        ``x / scaling_factor``, plus ``shift_factor`` when the config
        defines a non-None one.
        """
        config = self.vae.config
        latents_mean = getattr(config, "latents_mean", None)
        latents_std = getattr(config, "latents_std", None)
        if latents_mean is not None and latents_std is not None:
            # view(1, -1, 1, 1): one statistic per channel on dim 1, for any
            # channel count (not just 4).
            mean = torch.tensor(latents_mean).view(1, -1, 1, 1).to(tensor.device, tensor.dtype)
            std = torch.tensor(latents_std).view(1, -1, 1, 1).to(tensor.device, tensor.dtype)
            return tensor * std / config.scaling_factor + mean

        tensor = tensor / config.scaling_factor
        shift_factor = getattr(config, "shift_factor", None)
        if shift_factor is not None:
            tensor = tensor + shift_factor
        return tensor

    @staticmethod
    def _to_uint8_nhwc(image: torch.Tensor) -> torch.Tensor:
        """Map a decoded image to uint8 [0, 255] in channels-last layout.

        Assumes the decoder output lies in [-1, 1]; anything outside is
        clamped after the affine map to [0, 1].
        """
        image = (image * 0.5 + 0.5).clamp(0, 1)
        # Upcast to float32 so scaling and rounding are not done in float16.
        image = image.permute(0, 2, 3, 1).contiguous().float()
        return (image * 255).round().to(torch.uint8)