from transformers import AutoTokenizer


class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()
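

# A minimal sketch of a custom streamer built on `BaseStreamer` (the class
# name `ListStreamer` is illustrative, not part of the library): it simply
# collects every token id that `.generate()` pushes, which can be handy for
# tests or for post-processing once generation finishes.
class ListStreamer(BaseStreamer):
    """Example streamer that accumulates raw token ids in memory."""

    def __init__(self):
        self.token_ids = []

    def put(self, value):
        # `value` is a tensor of token ids; flatten it and store the ids.
        self.token_ids.extend(value.flatten().tolist())

    def end(self):
        # Nothing to flush: all ids were stored as they arrived.
        pass

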
class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

    >>> tok = AutoTokenizer.from_pretrained("gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
    >>> streamer = TextStreamer(tok)

    >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
    >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
    An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
    ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # Variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token(s) to the cache and decode the entire cache so far
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After a newline, flush the whole cache
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # Otherwise, only print up to the last space: anything after it may be an
        # incomplete word that the next token could still change
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush whatever is left in the cache
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # The next tokens pushed by `.generate()` will belong to a new prompt
        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        print(text, flush=True, end="" if not stream_end else None)
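

# A minimal sketch of the intended extension point, assuming you want the
# generated text somewhere other than stdout: subclass `TextStreamer` and
# override `on_finalized_text`. `CapturingStreamer` is an illustrative name,
# not a class from the library.
class CapturingStreamer(TextStreamer):
    """Example subclass that accumulates finalized text instead of printing it."""

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.generated_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Append each finalized chunk; `stream_end` marks the final flush from `end()`.
        self.generated_text += text


# Usage mirrors the docstring example above: pass the streamer via
# `model.generate(..., streamer=streamer)` and read `streamer.generated_text`
# once generation finishes.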