chansung commited on
Commit
e06cb5f
·
1 Parent(s): e0c0641

add custom handler

Browse files
Files changed (2) hide show
  1. __pycache__/handler.cpython-38.pyc +0 -0
  2. handler.py +6 -0
__pycache__/handler.cpython-38.pyc CHANGED
Binary files a/__pycache__/handler.cpython-38.pyc and b/__pycache__/handler.cpython-38.pyc differ
 
handler.py CHANGED
@@ -2,6 +2,7 @@ from typing import Dict, List, Any
2
  import base64
3
 
4
  import tensorflow as tf
 
5
  from keras_cv.models.generative.stable_diffusion.text_encoder import TextEncoder
6
  from keras_cv.models.generative.stable_diffusion.clip_tokenizer import SimpleTokenizer
7
  from keras_cv.models.generative.stable_diffusion.constants import _UNCONDITIONAL_TOKENS
@@ -12,6 +13,11 @@ class EndpointHandler():
12
 
13
  self.tokenizer = SimpleTokenizer()
14
  self.text_encoder = TextEncoder(self.MAX_PROMPT_LENGTH)
 
 
 
 
 
15
  self.pos_ids = tf.convert_to_tensor([list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32)
16
 
17
  def _get_unconditional_context(self):
 
2
  import base64
3
 
4
  import tensorflow as tf
5
+ from tensorflow import keras
6
  from keras_cv.models.generative.stable_diffusion.text_encoder import TextEncoder
7
  from keras_cv.models.generative.stable_diffusion.clip_tokenizer import SimpleTokenizer
8
  from keras_cv.models.generative.stable_diffusion.constants import _UNCONDITIONAL_TOKENS
 
13
 
14
  self.tokenizer = SimpleTokenizer()
15
  self.text_encoder = TextEncoder(self.MAX_PROMPT_LENGTH)
16
+ text_encoder_weights_fpath = keras.utils.get_file(
17
+ origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
18
+ file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
19
+ )
20
+ self.text_encoder.load_weights(text_encoder_weights_fpath)
21
  self.pos_ids = tf.convert_to_tensor([list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32)
22
 
23
  def _get_unconditional_context(self):