Taejin's picture
Adding exampl.py
719134e
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the NVIDIA Open Model License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
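"""
Streaming multitalker (speaker-tagged) ASR example.

Loads one of the NeMo streaming speaker diarization models:
[Streaming Sortformer Diarizer v2](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2) or
[Streaming Sortformer Diarizer v2.1](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1),
together with a multitalker streaming ASR model, then streams an audio file through both
and prints the resulting speaker-tagged transcript.
"""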
# Example usage:
from nemo.collections.asr.models import SortformerEncLabelModel, ASRModel
import torch
# A speaker diarization model is needed for tracking the speech activity of each speaker.
# Both models are placed on the same CUDA device; this example assumes a GPU is available.
device = torch.device("cuda")
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2.1").eval().to(device)
# `from_pretrained` takes a model/repo name, not a checkpoint file name, so there is no ".nemo" suffix.
# To load a local checkpoint file instead, use `ASRModel.restore_from("/path/to/model.nemo")`.
asr_model = ASRModel.from_pretrained("nvidia/multitalker-parakeet-streaming-0.6b-v1").eval().to(device)
# Build the runtime configuration from the ready-made dataclass template
# `MultitalkerTranscriptionConfig` (see `multitalker_transcript_config.py`),
# then use it to configure the diarization model with streaming parameters.
from omegaconf import OmegaConf
from multitalker_transcript_config import MultitalkerTranscriptionConfig

template = MultitalkerTranscriptionConfig()
cfg = OmegaConf.structured(template)
# Input recording to transcribe (placeholder - replace with a real path).
cfg.audio_file = "/path/to/your/audio.wav"
# Output location for the transcription JSON (placeholder - replace with a real path).
cfg.output_path = "/path/to/output_transcription.json"
diar_model = MultitalkerTranscriptionConfig.init_diar_model(cfg, diar_model)
# Feed the audio file through a cache-aware streaming buffer so the models consume it
# chunk by chunk, simulating a real-time audio session.
from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer

# Single-entry sample list, consumed later by `generate_seglst_dicts_from_parallel_streaming`.
samples = [dict(audio_filepath=cfg.audio_file)]
streaming_buffer = CacheAwareStreamingAudioBuffer(
    model=asr_model,
    online_normalization=cfg.online_normalization,
    pad_and_drop_preencoded=cfg.pad_and_drop_preencoded,
)
# NOTE(review): stream_id=-1 presumably requests a new stream - confirm in the NeMo docs.
streaming_buffer.append_audio_file(audio_filepath=cfg.audio_file, stream_id=-1)
streaming_buffer_iter = iter(streaming_buffer)
# Use the helper class `SpeakerTaggedASR`, which handles all ASR and diarization cache data for streaming.
from nemo.collections.asr.parts.utils.multispk_transcribe_utils import SpeakerTaggedASR

multispk_asr_streamer = SpeakerTaggedASR(cfg, asr_model, diar_model)
for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
    # The first chunk drops 0 extra pre-encoded frames unless `cfg.pad_and_drop_preencoded`
    # is set; every later chunk uses the encoder's configured `drop_extra_pre_encoded`.
    drop_extra_pre_encoded = (
        0
        if step_num == 0 and not cfg.pad_and_drop_preencoded
        else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded
    )
    # `inference_mode` already disables gradient tracking, so no separate `no_grad()` is needed.
    # Autocast runs mixed precision on the diarization model's device type.
    with torch.inference_mode(), torch.amp.autocast(diar_model.device.type, enabled=True):
        multispk_asr_streamer.perform_parallel_streaming_stt_spk(
            step_num=step_num,
            chunk_audio=chunk_audio,
            chunk_lengths=chunk_lengths,
            is_buffer_empty=streaming_buffer.is_buffer_empty(),
            drop_extra_pre_encoded=drop_extra_pre_encoded,
        )
# Assemble the speaker-tagged transcript segments from the streamed results, then print them.
multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=samples)
seglst_dicts = multispk_asr_streamer.instance_manager.seglst_dict_list
print(seglst_dicts)