hlky (HF staff) committed
Commit c2c693f · verified · 1 Parent(s): ffd5f31

Update handler.py

Files changed (1)
  1. handler.py +46 -62
handler.py CHANGED
@@ -1,88 +1,72 @@
-from typing import Any, Dict, List, Tuple
+from typing import cast, Union
 
-import base64
+import PIL.Image
 import torch
 
 from diffusers import AutoencoderKL
-from safetensors.torch import _tobytes
-
-
-def prepare_tensor(tensor: torch.Tensor) -> Tuple[str, List[int], str]:
-    tensor_data = base64.b64encode(_tobytes(tensor, "inputs")).decode("utf-8")
-    shape = list(tensor.shape)
-    dtype = str(tensor.dtype).split(".")[-1]
-    return tensor_data, shape, dtype
-
-
-def unpack_tensor(tensor_data: bytes, shape: List[int], dtype: str) -> torch.Tensor:
-    tensor = tensor_data
-    DTYPE_MAP = {
-        "float16": torch.float16,
-        "float32": torch.float32,
-        "bfloat16": torch.bfloat16,
-    }
-    torch_dtype = DTYPE_MAP.get(dtype)
-    tensor = torch.frombuffer(bytearray(tensor), dtype=torch_dtype).reshape(shape)
-    return tensor
+from diffusers.image_processor import VaeImageProcessor
 
 
 class EndpointHandler:
     def __init__(self, path=""):
         self.device = "cuda"
         self.dtype = torch.float16
-        self.vae = AutoencoderKL.from_pretrained(path, torch_dtype=self.dtype).to(self.device, self.dtype).eval()
+        self.vae = cast(AutoencoderKL, AutoencoderKL.from_pretrained(path, torch_dtype=self.dtype).to(self.device, self.dtype).eval())
+
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     @torch.no_grad()
-    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
+    def __call__(self, data) -> Union[torch.Tensor, PIL.Image.Image]:
         """
         Args:
             data (:obj:):
                 includes the input data and the parameters for the inference.
         """
-        tensor_data = data["inputs"]
-        parameters = data.get("parameters", {})
-        if "shape" not in parameters:
-            raise ValueError("Expected `shape` in parameters.")
-        if "dtype" not in parameters:
-            raise ValueError("Expected `dtype` in parameters.")
-
-        shape = parameters.get("shape")
-        dtype = parameters.get("dtype")
-
-        tensor = unpack_tensor(tensor_data, shape, dtype)
+        tensor = cast(torch.Tensor, data["inputs"])
+        parameters = cast(dict, data.get("parameters", {}))
+        do_scaling = cast(bool, parameters.get("do_scaling", True))
+        output_type = cast(str, parameters.get("output_type", "pil"))
+        partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
+        if partial_postprocess and output_type != "pt":
+            output_type = "pt"
 
         tensor = tensor.to(self.device, self.dtype)
 
-        # unscale/denormalize the latents
-        # denormalize with the mean and std if available and not None
-        has_latents_mean = (
-            hasattr(self.vae.config, "latents_mean")
-            and self.vae.config.latents_mean is not None
-        )
-        has_latents_std = (
-            hasattr(self.vae.config, "latents_std")
-            and self.vae.config.latents_std is not None
-        )
-        if has_latents_mean and has_latents_std:
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, 4, 1, 1)
-                .to(tensor.device, tensor.dtype)
-            )
-            latents_std = (
-                torch.tensor(self.vae.config.latents_std)
-                .view(1, 4, 1, 1)
-                .to(tensor.device, tensor.dtype)
-            )
-            tensor = (
-                tensor * latents_std / self.vae.config.scaling_factor + latents_mean
-            )
-        else:
-            tensor = tensor / self.vae.config.scaling_factor
+        if do_scaling:
+            has_latents_mean = (
+                hasattr(self.vae.config, "latents_mean")
+                and self.vae.config.latents_mean is not None
+            )
+            has_latents_std = (
+                hasattr(self.vae.config, "latents_std")
+                and self.vae.config.latents_std is not None
+            )
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean)
+                    .view(1, 4, 1, 1)
+                    .to(tensor.device, tensor.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std)
+                    .view(1, 4, 1, 1)
+                    .to(tensor.device, tensor.dtype)
+                )
+                tensor = (
+                    tensor * latents_std / self.vae.config.scaling_factor + latents_mean
+                )
+            else:
+                tensor = tensor / self.vae.config.scaling_factor
 
         with torch.no_grad():
-            image: torch.Tensor = self.vae.decode(tensor, return_dict=False)[0]
+            image = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
 
-        image = (image * 0.5 + 0.5).clamp(0, 1)
-        image = image.permute(0, 2, 3, 1).contiguous().float()
-        image = (image * 255).round().to(torch.uint8)
+        if partial_postprocess:
+            image = (image * 0.5 + 0.5).clamp(0, 1)
+            image = image.permute(0, 2, 3, 1).contiguous().float()
+            image = (image * 255).round().to(torch.uint8)
+        elif output_type == "pil":
+            image = cast(PIL.Image.Image, self.image_processor.postprocess(image, output_type="pil")[0])
 
-        return _tobytes(image, "image")
+        return image
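After this change the handler takes the latent tensor directly under `inputs` (the old base64 `shape`/`dtype` envelope is gone) and steers postprocessing through three optional parameters: `do_scaling`, `output_type`, and `partial_postprocess`. A minimal local usage sketch, assuming the file is importable as `handler`, a hypothetical AutoencoderKL checkpoint path, and a 4-channel latent; a CUDA device is required since the handler hard-codes `self.device = "cuda"`:

    import torch
    from handler import EndpointHandler

    # Hypothetical checkpoint path; any diffusers AutoencoderKL repo works.
    handler = EndpointHandler(path="stabilityai/sd-vae-ft-mse")

    # Scaled latents as produced by a diffusion pipeline: (batch, 4, H/8, W/8).
    latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)

    # Defaults (do_scaling=True, output_type="pil") return a PIL.Image.Image.
    image = handler({"inputs": latents, "parameters": {}})
    image.save("decoded.png")

    # partial_postprocess=True forces output_type="pt" and returns a uint8
    # tensor of shape (batch, H, W, 3), leaving final encoding to the caller.
    pixels = handler({"inputs": latents, "parameters": {"partial_postprocess": True}})

With `partial_postprocess=False` and `output_type="pt"`, the handler returns the raw decoded float tensor instead, so the caller can apply its own normalization.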