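"""Custom Hugging Face Inference Endpoints handler that encodes text prompts
with the KerasCV Stable Diffusion text encoder and returns the conditional
and unconditional context tensors as base64-encoded strings."""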
from typing import Dict, Any

import base64

import tensorflow as tf
from tensorflow import keras
from keras_cv.models.generative.stable_diffusion.text_encoder import TextEncoder
from keras_cv.models.generative.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.generative.stable_diffusion.constants import _UNCONDITIONAL_TOKENS

class EndpointHandler:
    def __init__(self, path=""):
        self.MAX_PROMPT_LENGTH = 77
        self.tokenizer = SimpleTokenizer()
        self.text_encoder = TextEncoder(self.MAX_PROMPT_LENGTH)

        # Download the pretrained text-encoder weights and load them.
        text_encoder_weights_fpath = keras.utils.get_file(
            origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
            file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
        )
        self.text_encoder.load_weights(text_encoder_weights_fpath)

        # Position ids (0..MAX_PROMPT_LENGTH-1), shared by every encoder call.
        self.pos_ids = tf.convert_to_tensor(
            [list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32
        )

    def _get_unconditional_context(self):
        unconditional_tokens = tf.convert_to_tensor(
            [_UNCONDITIONAL_TOKENS], dtype=tf.int32
        )
        unconditional_context = self.text_encoder.predict_on_batch(
            [unconditional_tokens, self.pos_ids]
        )
        return unconditional_context

    def encode_text(self, prompt):
        # Tokenize the prompt (i.e. the starting context).
        inputs = self.tokenizer.encode(prompt)
        if len(inputs) > self.MAX_PROMPT_LENGTH:
            raise ValueError(
                f"Prompt is too long (should be <= {self.MAX_PROMPT_LENGTH} tokens)"
            )
        # Pad to the maximum length with the end-of-text token (49407).
        phrase = inputs + [49407] * (self.MAX_PROMPT_LENGTH - len(inputs))
        phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
        context = self.text_encoder.predict_on_batch([phrase, self.pos_ids])
        return context

    def get_contexts(self, encoded_text, batch_size):
        # Broadcast the single encoded prompt across the batch dimension.
        encoded_text = tf.squeeze(encoded_text)
        if encoded_text.shape.rank == 2:
            encoded_text = tf.repeat(
                tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
            )
        context = encoded_text
        unconditional_context = tf.repeat(
            self._get_unconditional_context(), batch_size, axis=0
        )
        return context, unconditional_context

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Get inputs.
        prompt = data.pop("inputs", data)
        batch_size = data.pop("batch_size", 1)

        encoded_text = self.encode_text(prompt)
        context, unconditional_context = self.get_contexts(encoded_text, batch_size)

        # Serialize the float32 tensors so they survive the JSON response.
        context_b64str = base64.b64encode(context.numpy().tobytes()).decode()
        unconditional_context_b64str = base64.b64encode(
            unconditional_context.numpy().tobytes()
        ).decode()
        return {
            "context_b64str": context_b64str,
            "unconditional_context_b64str": unconditional_context_b64str,
        }