Results

#3 opened by Kwissbeats

I can't get the same results as "nvidia/parakeet-tdt-0.6b-v2" with these ONNX files.

First I tried to replicate the feature extraction in C++, but I failed to get really close because of what a black box torch.stft is.
Then I saw your nemo128 file and thought: here is my solution!
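(For anyone attempting the same replication: a rough Python approximation of the log-mel frontend, under assumed NeMo defaults of a 25 ms Hann window, 10 ms hop, 512-point FFT and 128 mel bins with per-feature normalization, would look like the sketch below. Dither and preemphasis are omitted, so it will not match the nemo128 output exactly; the real values live in the model's preprocessor config.)

import numpy as np
import librosa

def log_mel_features(audio: np.ndarray, sr: int = 16000) -> np.ndarray:
    # Rough log-mel approximation of the nemo128 preprocessor; parameters are
    # assumed NeMo defaults, and dither/preemphasis are deliberately omitted.
    mel = librosa.feature.melspectrogram(
        y=audio, sr=sr,
        n_fft=512, hop_length=160, win_length=400, window="hann",
        n_mels=128, power=2.0,
    )
    log_mel = np.log(mel + 1e-9)
    # Per-feature (per mel bin) normalization over time, as NeMo does.
    mean = log_mel.mean(axis=1, keepdims=True)
    std = log_mel.std(axis=1, keepdims=True)
    return (log_mel - mean) / (std + 1e-9)  # [128, T]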

First, though, I tried to get it running in Python with the code below, but compared to the original non-ONNX version the output is missing a lot of punctuation and long words.
Is this expected? I have not yet compared the output of the nemo128 ONNX file with the original, but I will certainly try that next.
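(A sketch of that comparison, assuming the NeMo toolkit's from_pretrained API and reusing audio_input and features from the script below, might look like this:)

import torch
import nemo.collections.asr as nemo_asr

# Hypothetical check of the nemo128.onnx features against NeMo's own
# preprocessor; assumes nemo_toolkit is installed.
model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
model.eval()
with torch.no_grad():
    ref_features, ref_lens = model.preprocessor(
        input_signal=torch.tensor(audio_input),
        length=torch.tensor([audio_input.shape[1]]),
    )
# Frame counts may differ slightly due to padding, so compare the overlap.
T = min(ref_features.shape[-1], features.shape[-1])
print("max abs diff:", np.abs(ref_features.numpy()[..., :T] - features[..., :T]).max())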

But maybe I'm missing something obvious?

Here is the code:

import numpy as np
import librosa
import onnxruntime as ort

# Load and resample audio
audio_path = "test.wav"  # Replace with your audio file
target_sr = 16000
audio, sr = librosa.load(audio_path, sr=None)
if sr != target_sr:
    audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)

# Nemo128 feature extraction
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.intra_op_num_threads = 1
session_options.inter_op_num_threads = 1
nemo_session = ort.InferenceSession("nemo128.onnx", session_options, providers=["CPUExecutionProvider"])
audio_input = audio[np.newaxis, :].astype(np.float32)  # [1, N]
audio_length = np.array([len(audio)], dtype=np.int64)  # [1]
nemo_inputs = {
    "waveforms": audio_input,
    "waveforms_lens": audio_length,
}
nemo_outputs = nemo_session.run(None, nemo_inputs)
features, features_lens = nemo_outputs[0], nemo_outputs[1]  # [1, 128, T], [1]

# Load symbol table
symbol_table = {}
with open("vocab.txt", "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split()
        if not parts:
            continue
        # The blank entry may have an empty token field; keep its token as "".
        token, idx = parts if len(parts) == 2 else ("", parts[0])
        symbol_table[int(idx)] = token
blank_id = [k for k, v in symbol_table.items() if v == ""][0]  # Blank ID

# Load Parakeet models

encoder_session = ort.InferenceSession("encoder-model.onnx", session_options, providers=["CPUExecutionProvider"])
decoder_session = ort.InferenceSession("decoder_joint-model.onnx", session_options, providers=["CPUExecutionProvider"])
encoder_input_names = [inp.name for inp in encoder_session.get_inputs()]
decoder_input_names = [inp.name for inp in decoder_session.get_inputs()]

# Run encoder

encoder_inputs = {
    encoder_input_names[0]: features.astype(np.float32),
    encoder_input_names[1]: features_lens,
}
encoder_outputs = encoder_session.run(None, encoder_inputs)
encoder_out, encoded_lengths = encoder_outputs[0], encoder_outputs[1]
num_subsampled_frames = encoded_lengths[0]

# Greedy decoding
hypothesis = []
decoder_input = np.array([[blank_id]], dtype=np.int32)
state_shape = [2, 1, 640]
current_states = [np.zeros(state_shape, dtype=np.float32), np.zeros(state_shape, dtype=np.float32)]

for t in range(num_subsampled_frames):
    encoder_frame = encoder_out[:, :, t:t + 1]
    target_length = np.array([1], dtype=np.int32)
    decoder_inputs = {
        decoder_input_names[0]: encoder_frame,
        decoder_input_names[1]: decoder_input,
        decoder_input_names[2]: target_length,
        decoder_input_names[3]: current_states[0],
        decoder_input_names[4]: current_states[1],
    }
    decoder_outputs = decoder_session.run(None, decoder_inputs)
    logits, _, state1, state2 = decoder_outputs
    max_id = np.argmax(logits[0, 0, :])
    if max_id != blank_id:
        hypothesis.append(max_id)
        decoder_input[0, 0] = max_id
        current_states = [state1, state2]

# Decode tokens
result_text = "".join(symbol_table.get(i, "") for i in hypothesis)
result_text = result_text.replace("\u2581", " ")  # SentencePiece word-boundary marker U+2581 -> space
if result_text.startswith(" "):
    result_text = result_text[1:]
print("Transcription:", result_text)

Sorry for the late reply, I was on vacation. If this is still relevant: I think the problem is that in TDT models the logit vector consists of two parts, token probabilities and time-step (duration) probabilities. You can see the decoding code in my onnx-asr library.
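To illustrate, here is a minimal sketch of how the greedy loop above could be adapted for such a split head. It is a sketch based on that description, not the exact onnx-asr implementation: the duration set (assumed here to be [0, 1, 2, 3, 4]), the logit layout (token logits first, duration logits last), and the safety cap are all assumptions to verify against the model config.

# Hypothetical TDT greedy loop (a sketch, not the onnx-asr implementation).
# Assumes logits = [token logits (vocab incl. blank) | duration logits] and
# an assumed duration set of [0, 1, 2, 3, 4]; reuses the variables from the
# script above (decoder_session, encoder_out, state_shape, blank_id, ...).
durations = [0, 1, 2, 3, 4]        # assumed TDT duration vocabulary
num_tokens = len(symbol_table)     # size of the token part, incl. blank
max_symbols_per_frame = 10         # safety cap; real decoders bound this too

hypothesis = []
decoder_input = np.array([[blank_id]], dtype=np.int32)
current_states = [np.zeros(state_shape, dtype=np.float32),
                  np.zeros(state_shape, dtype=np.float32)]

t, emitted = 0, 0
while t < num_subsampled_frames:
    decoder_inputs = {
        decoder_input_names[0]: encoder_out[:, :, t:t + 1],
        decoder_input_names[1]: decoder_input,
        decoder_input_names[2]: np.array([1], dtype=np.int32),
        decoder_input_names[3]: current_states[0],
        decoder_input_names[4]: current_states[1],
    }
    logits, _, state1, state2 = decoder_session.run(None, decoder_inputs)
    token_logits = logits[0, 0, :num_tokens]     # token probabilities
    duration_logits = logits[0, 0, num_tokens:]  # time-step probabilities
    max_id = int(np.argmax(token_logits))
    if max_id != blank_id:
        hypothesis.append(max_id)
        decoder_input[0, 0] = max_id
        current_states = [state1, state2]
        emitted += 1
    # Advance by the predicted duration. Duration 0 lets the model emit
    # several tokens on one frame; blank (or the cap) forces progress so
    # the loop cannot stall.
    step = durations[int(np.argmax(duration_logits))]
    if max_id == blank_id or emitted >= max_symbols_per_frame:
        step = max(step, 1)
    if step > 0:
        t += step
        emitted = 0

The token IDs in hypothesis can then be detokenized exactly as in the script above.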
