streaming generate

#10
by weege007 - opened

from queue import Queue
from threading import Thread

import torch
import soundfile as sf
from transformers.generation.streamers import BaseStreamer

# tokenizer, model, Codec_model, extract_speech_ids and input_text are set up
# as in the full script shared further down in this thread.

class TokenStreamer(BaseStreamer):
    def __init__(self, skip_prompt: bool = False, timeout=None):
        self.skip_prompt = skip_prompt

        # variables used in the streaming process
        self.token_queue = Queue()
        self.stop_signal = None
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.token_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value

#TTS start!
with torch.no_grad():
 
    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

    # Tokenize the text
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
    ]

    input_ids = tokenizer.apply_chat_template(
        chat, 
        tokenize=True, 
        return_tensors='pt', 
        continue_final_message=True
    )
    input_ids = input_ids.to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
    streamer = TokenStreamer(skip_prompt=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_length=2048,  # We trained our model with a max length of 2048
        eos_token_id=speech_end_id,
        do_sample=True,    
        top_p=1,           #  Adjusts the diversity of generated content
        temperature=0.8,   #  Controls randomness in output
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    i = 0
    batch_size = 60
    generated_ids=[]
    j=0
    for token_id in streamer:
        print(token_id, end=',', flush=True)
        generated_ids.append(token_id)
        if i>0 and i % batch_size == 0:
            #print(generated_ids)
            speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids).cuda(), skip_special_tokens=True)
            # Convert  token <|s_23456|> to int 23456 
            speech_tokens = extract_speech_ids(speech_tokens)
            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
            # Decode the speech tokens to speech waveform
            gen_wav = Codec_model.decode_code(speech_tokens)
            sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
            generated_ids=[]
            j+=1
        i += 1
        #yield token_id
    if len(generated_ids)>0:
        speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids).cuda(), skip_special_tokens=True)
        # Convert  token <|s_23456|> to int 23456 
        speech_tokens = extract_speech_ids(speech_tokens)
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
        # Decode the speech tokens to speech waveform
        gen_wav = Codec_model.decode_code(speech_tokens)
        sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)

Colab notebook: https://github.com/weedge/doraemon-nb/blob/main/LLaSA.ipynb

Hello! I just implemented something very similar. It sort of works, but I notice that there are artifacts in the resulting waveform if the number of speech tokens is too small in the input to Codec_model.decode_code().

Is there a recommendation for the minimum chunk size?
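I have not seen an official minimum documented. One workaround is to keep the chunks reasonably large (e.g. the 60-token batches in the script above) and to join consecutive decoded chunks with a short crossfade so the boundaries do not click. The sketch below is only an illustration, not part of the Llasa/XCodec2 API: crossfade_concat and fade_ms are names and values I picked myself, and it just post-processes the numpy waveforms returned by Codec_model.decode_code.

import numpy as np

def crossfade_concat(chunks, sr=16000, fade_ms=20):
    """Join decoded waveform chunks (1-D float arrays) with a short linear
    crossfade to soften clicks at chunk boundaries. fade_ms=20 is a guess; tune by ear."""
    fade = int(sr * fade_ms / 1000)
    out = np.asarray(chunks[0], dtype=np.float32).copy()
    for nxt in chunks[1:]:
        nxt = np.asarray(nxt, dtype=np.float32)
        if fade > 0 and len(out) >= fade and len(nxt) >= fade:
            # Blend the tail of the previous chunk with the head of the next one.
            ramp = np.linspace(0.0, 1.0, fade, dtype=np.float32)
            out[-fade:] = out[-fade:] * (1.0 - ramp) + nxt[:fade] * ramp
            out = np.concatenate([out, nxt[fade:]])
        else:
            out = np.concatenate([out, nxt])
    return out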

Hey -- do you mind expanding a bit more on how to put this together for streaming? I've got the basic setup working and have been going through all the other examples, so I have a pretty good understanding now, but jumping to streaming is where I would love some guidance.

Please share :)
Thanks! Also, I love this project -- it rocks! I used a 15-second clip of my voice, and this is the best mimic I've seen. I'm just learning, so anything will help. Thanks!

from queue import Queue
from threading import Thread
import torch
from transformers.generation.streamers import BaseStreamer
from transformers import AutoTokenizer, AutoModelForCausalLM
import soundfile as sf

import os

from xcodec2.modeling_xcodec2 import XCodec2Model
 
llasa_3b = 'HKUST-Audio/Llasa-3B'

tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
model = AutoModelForCausalLM.from_pretrained(llasa_3b)
model.eval() 
model.to('cuda')

model_path = "HKUST-Audio/xcodec2"  
 
Codec_model = XCodec2Model.from_pretrained(model_path)
Codec_model.eval().cuda()   

class TokenStreamer(BaseStreamer):
    def __init__(self, skip_prompt: bool = False, timeout=None):
        self.skip_prompt = skip_prompt

        # variables used in the streaming process
        self.token_queue = Queue()
        self.stop_signal = None
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.token_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value

# only 16 kHz speech is supported!
prompt_wav, sr = sf.read("speakers/speaker_1.wav")   # you can find example wavs in the repo Files
#prompt_wav, sr = sf.read("Anna.wav") # English prompt
prompt_wav = torch.from_numpy(prompt_wav).float().unsqueeze(0)

prompt_text = "Hello, I am speaker number 1."
#prompt_text = "A chance to leave him alone, but... No. She just wanted to see him again. Anna, you don't know how it feels to lose a sister. Anna, I'm sorry, but your father asked me not to tell you anything."
target_text = 'Hi, how are you doing today?'
#target_text = "Dealing with family secrets is never easy. Yet, sometimes, omission is a form of protection, intending to safeguard some from the harsh truths. One day, I hope you understand the reasons behind my actions. Until then, Anna, please, bear with me."
input_text = prompt_text   + target_text

def ids_to_speech_tokens(speech_ids):
 
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str

def extract_speech_ids(speech_tokens_str):
 
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]

            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids

#TTS start!
with torch.no_grad():
    # Note: prompt_wav is loaded above but not encoded or used in this streaming example.

    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

    # Tokenize the text
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
    ]

    input_ids = tokenizer.apply_chat_template(
        chat, 
        tokenize=True, 
        return_tensors='pt', 
        continue_final_message=True
    )
    input_ids = input_ids.to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
    streamer = TokenStreamer(skip_prompt=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_length=2048,  # We trained our model with a max length of 2048
        eos_token_id=speech_end_id,
        do_sample=True,    
        top_p=1,           #  Adjusts the diversity of generated content
        temperature=0.8,   #  Controls randomness in output
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    i = 0
    batch_size = 60
    generated_ids=[]
    j=0
    for token_id in streamer:
        print(token_id, end=',', flush=True)
        generated_ids.append(token_id)
        if i>0 and i % batch_size == 0:
            #print(generated_ids)
            speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids).cuda(), skip_special_tokens=True)
            # Convert  token <|s_23456|> to int 23456 
            speech_tokens = extract_speech_ids(speech_tokens)
            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
            # Decode the speech tokens to speech waveform
            gen_wav = Codec_model.decode_code(speech_tokens)
            sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
            generated_ids=[]
            j+=1
        i += 1
        #yield token_id
    if len(generated_ids)>0:
        speech_tokens = tokenizer.batch_decode(torch.tensor(generated_ids).cuda(), skip_special_tokens=True)
        # Convert  token <|s_23456|> to int 23456 
        speech_tokens = extract_speech_ids(speech_tokens)
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
        # Decode the speech tokens to speech waveform
        gen_wav = Codec_model.decode_code(speech_tokens)
        sf.write(f"gen_{j}.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
