File size: 3,062 Bytes
dc39a0d
e0c0641
dc39a0d
6f23e96
 
dc39a0d
e06cb5f
dc39a0d
 
 
 
 
 
 
 
 
 
e06cb5f
 
 
 
 
dc39a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0c0641
dc39a0d
 
 
 
d063207
dc39a0d
 
 
 
e0c0641
 
dc39a0d
e0c0641
 
 
 
 
6f23e96
a57f986
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from typing import Dict, List, Any
import base64

import logging

import tensorflow as tf
from tensorflow import keras
from keras_cv.models.generative.stable_diffusion.text_encoder import TextEncoder
from keras_cv.models.generative.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.generative.stable_diffusion.constants import _UNCONDITIONAL_TOKENS

class EndpointHandler():
    """Turn a text prompt into Stable Diffusion text-encoder contexts.

    Calling the handler tokenizes the prompt, runs it through the KerasCV
    ``TextEncoder`` and returns the resulting context together with the
    unconditional context, each serialized as a base64 string of the raw
    array bytes.
    """

    # Fixed token sequence length fed to the text encoder.
    MAX_PROMPT_LENGTH = 77
    # Token id used to right-pad prompts up to MAX_PROMPT_LENGTH.
    # NOTE(review): presumably the tokenizer's end-of-text id -- confirm
    # against SimpleTokenizer before changing.
    PAD_TOKEN_ID = 49407

    def __init__(self, path=""):
        """Build the tokenizer and text encoder and load the encoder weights.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        self.tokenizer = SimpleTokenizer()
        self.text_encoder = TextEncoder(self.MAX_PROMPT_LENGTH)
        text_encoder_weights_fpath = keras.utils.get_file(
            origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
            file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
        )
        self.text_encoder.load_weights(text_encoder_weights_fpath)
        # Position ids 0..MAX_PROMPT_LENGTH-1 as a single-row int32 batch.
        self.pos_ids = tf.convert_to_tensor(
            [list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32
        )
        # Filled lazily by _get_unconditional_context().
        self._unconditional_context = None

    def _get_unconditional_context(self):
        """Return the encoder output for the unconditional token sequence.

        The inputs are constant, so the result is computed on first use and
        reused afterwards instead of re-running the encoder per request.
        The cache is not refreshed if the encoder weights are changed later.
        """
        cached = getattr(self, "_unconditional_context", None)
        if cached is None:
            unconditional_tokens = tf.convert_to_tensor(
                [_UNCONDITIONAL_TOKENS], dtype=tf.int32
            )
            cached = self.text_encoder.predict_on_batch(
                [unconditional_tokens, self.pos_ids]
            )
            self._unconditional_context = cached
        return cached

    def encode_text(self, prompt):
        """Tokenize ``prompt``, pad it, and run it through the text encoder.

        Args:
            prompt: Text to encode.

        Returns:
            The text encoder output for a batch holding the single prompt.

        Raises:
            ValueError: If the prompt tokenizes to more than
                ``MAX_PROMPT_LENGTH`` tokens.
        """
        token_ids = self.tokenizer.encode(prompt)
        if len(token_ids) > self.MAX_PROMPT_LENGTH:
            raise ValueError(
                f"Prompt is too long (should be <= {self.MAX_PROMPT_LENGTH} tokens)"
            )
        padding = [self.PAD_TOKEN_ID] * (self.MAX_PROMPT_LENGTH - len(token_ids))
        phrase = tf.convert_to_tensor([token_ids + padding], dtype=tf.int32)

        return self.text_encoder.predict_on_batch([phrase, self.pos_ids])

    def get_contexts(self, encoded_text, batch_size):
        """Tile the prompt context and the unconditional context.

        Args:
            encoded_text: Encoder output for the prompt. Size-1 dimensions
                are squeezed away; a rank-2 result is repeated
                ``batch_size`` times along a new leading axis, any other
                rank is passed through unchanged.
            batch_size: Number of repetitions along the batch axis.

        Returns:
            A ``(context, unconditional_context)`` tuple.
        """
        encoded_text = tf.squeeze(encoded_text)
        if encoded_text.shape.rank == 2:
            encoded_text = tf.repeat(
                tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
            )

        unconditional_context = tf.repeat(
            self._get_unconditional_context(), batch_size, axis=0
        )

        return encoded_text, unconditional_context

    @staticmethod
    def _to_b64str(tensor):
        """Serialize a tensor's raw bytes as a base64 text string."""
        return base64.b64encode(tensor.numpy().tobytes()).decode()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt string;
                the optional ``data["batch_size"]`` (default 1) is the
                number of copies of each context to return. The payload is
                read, not modified.

        Returns:
            A dict with ``context_b64str`` and
            ``unconditional_context_b64str``: base64-encoded raw bytes of
            the two arrays. Shape and dtype are not part of the payload.

        Raises:
            ValueError: If the prompt is missing or not a string, if
                ``batch_size`` is not a positive integer, or if the prompt
                is too long.
        """
        # Validate up front so a malformed request fails with a clear
        # message rather than an unrelated error inside the tokenizer or TF.
        prompt = data.get("inputs")
        if not isinstance(prompt, str):
            raise ValueError('"inputs" must be provided as a string prompt')

        batch_size = data.get("batch_size", 1)
        # bool is an int subclass; reject it explicitly.
        if (
            isinstance(batch_size, bool)
            or not isinstance(batch_size, int)
            or batch_size < 1
        ):
            raise ValueError('"batch_size" must be a positive integer')

        encoded_text = self.encode_text(prompt)
        context, unconditional_context = self.get_contexts(encoded_text, batch_size)

        return {
            "context_b64str": self._to_b64str(context),
            "unconditional_context_b64str": self._to_b64str(unconditional_context),
        }