Load one of the NeMo speaker diarization models, for example [Streaming Sortformer Diarizer v2.1](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1), and use it together with the multitalker streaming ASR model to produce a speaker-tagged streaming transcription:

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the NVIDIA Open Model License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from nemo.collections.asr.models import ASRModel, SortformerEncLabelModel

# A speaker diarization model is needed for tracking the speech activity of each speaker.
diar_model = (
    SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2.1")
    .eval()
    .to(torch.device("cuda"))
)
asr_model = (
    ASRModel.from_pretrained("nvidia/multitalker-parakeet-streaming-0.6b-v1")
    .eval()
    .to(torch.device("cuda"))
)

# Use the pre-defined dataclass template `MultitalkerTranscriptionConfig` from
# `multitalker_transcript_config.py` and configure the diarization model with its
# streaming parameters.
from omegaconf import OmegaConf

from multitalker_transcript_config import MultitalkerTranscriptionConfig

cfg = OmegaConf.structured(MultitalkerTranscriptionConfig())
cfg.audio_file = "/path/to/your/audio.wav"
cfg.output_path = "/path/to/output_transcription.json"
diar_model = MultitalkerTranscriptionConfig.init_diar_model(cfg, diar_model)

# Load your audio file into a streaming audio buffer to simulate a real-time audio session.
from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer

samples = [{'audio_filepath': cfg.audio_file}]
streaming_buffer = CacheAwareStreamingAudioBuffer(
    model=asr_model,
    online_normalization=cfg.online_normalization,
    pad_and_drop_preencoded=cfg.pad_and_drop_preencoded,
)
streaming_buffer.append_audio_file(audio_filepath=cfg.audio_file, stream_id=-1)
streaming_buffer_iter = iter(streaming_buffer)

# The helper class `SpeakerTaggedASR` handles all ASR and diarization cache data for streaming.
from nemo.collections.asr.parts.utils.multispk_transcribe_utils import SpeakerTaggedASR

multispk_asr_streamer = SpeakerTaggedASR(cfg, asr_model, diar_model)

# Feed the audio chunk by chunk, running ASR and diarization in parallel on each chunk.
for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
    drop_extra_pre_encoded = (
        0
        if step_num == 0 and not cfg.pad_and_drop_preencoded
        else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded
    )
    # inference_mode disables autograd; autocast runs the step in mixed precision.
    with torch.inference_mode(), torch.amp.autocast(diar_model.device.type, enabled=True):
        multispk_asr_streamer.perform_parallel_streaming_stt_spk(
            step_num=step_num,
            chunk_audio=chunk_audio,
            chunk_lengths=chunk_lengths,
            is_buffer_empty=streaming_buffer.is_buffer_empty(),
            drop_extra_pre_encoded=drop_extra_pre_encoded,
        )

# Generate the speaker-tagged transcript and print it.
multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=samples)
print(multispk_asr_streamer.instance_manager.seglst_dict_list)
```
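
The snippet above only prints the in-memory list of speaker-tagged segment dictionaries. If you also want to persist them to the `cfg.output_path` configured earlier, a minimal sketch is shown below; it is not part of the NeMo API, assumes the entries in `seglst_dict_list` are JSON-serializable, and makes no assumption about their exact keys, which depend on your NeMo version.

```python
# Minimal sketch (not part of the NeMo API): write the speaker-tagged segments
# produced above to the JSON path configured in `cfg.output_path`.
# Assumes each entry in `seglst_dict_list` is JSON-serializable; the exact keys
# depend on the NeMo version in use.
import json

seglst_dicts = multispk_asr_streamer.instance_manager.seglst_dict_list
with open(cfg.output_path, "w", encoding="utf-8") as f:
    json.dump(seglst_dicts, f, indent=2, ensure_ascii=False)

print(f"Wrote speaker-tagged transcript to {cfg.output_path}")
```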