# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import multiprocessing import shutil from collections import OrderedDict from pathlib import Path from pprint import pprint from typing import Dict import matplotlib.pyplot as plt import numpy as np import seaborn as sns import sox from scipy.stats import expon from tqdm import tqdm from nemo.collections.asr.parts.utils.vad_utils import ( get_nonspeech_segments, load_speech_overlap_segments_from_rttm, plot_sample_from_rttm, ) """ This script analyzes multi-speaker speech dataset and generates statistics. The input directory is required to contain the following files: - rttm files (*.rttm) - wav files (*.wav) Usage: python /scripts/speaker_tasks/multispeaker_data_analysis.py \ \ --session_dur 20 \ --silence_mean 0.2 \ --silence_var 100 \ --overlap_mean 0.15 \ --overlap_var 50 \ --num_workers 8 \ --num_samples 10 \ --output_dir """ def process_sample(sess_dict: Dict) -> Dict: """ Process each synthetic sample Args: sess_dict (dict): dictionary containing the following keys rttm_file (str): path to the rttm file session_dur (float): duration of the session (specified by argument) precise (bool): whether to measure the precise duration of the session using sox Returns: results (dict): dictionary containing the following keys session_dur (float): duration of the session silence_len_list (list): list of silence durations of each silence occurrence silence_dur (float): total silence duration in a session silence_ratio (float): ratio of silence duration to session duration overlap_len_list (list): list of overlap durations of each overlap occurrence overlap_dur (float): total overlap duration overlap_ratio (float): ratio of overlap duration to speech (non-silence) duration """ rttm_file = sess_dict["rttm_file"] session_dur = sess_dict["session_dur"] precise = sess_dict["precise"] if precise or session_dur is None: wav_file = rttm_file.parent / Path(rttm_file.stem + ".wav") session_dur = sox.file_info.duration(str(wav_file)) speech_seg, overlap_seg = load_speech_overlap_segments_from_rttm(rttm_file) speech_dur = sum([sess_dict[1] - sess_dict[0] for sess_dict in speech_seg]) silence_seg = get_nonspeech_segments(speech_seg, session_dur) silence_len_list = [sess_dict[1] - sess_dict[0] for sess_dict in silence_seg] silence_dur = max(0, session_dur - speech_dur) silence_ratio = silence_dur / session_dur overlap_len_list = [sess_dict[1] - sess_dict[0] for sess_dict in overlap_seg] overlap_dur = sum(overlap_len_list) if len(overlap_len_list) else 0 overlap_ratio = overlap_dur / speech_dur results = { "session_dur": session_dur, "silence_len_list": silence_len_list, "silence_dur": silence_dur, "silence_ratio": silence_ratio, "overlap_len_list": overlap_len_list, "overlap_dur": overlap_dur, "overlap_ratio": overlap_ratio, } return results def run_multispeaker_data_analysis( input_dir, session_dur=None, silence_mean=None, silence_var=None, overlap_mean=None, overlap_var=None, precise=False, save_path=None, num_workers=1, ) -> Dict: rttm_list = list(Path(input_dir).glob("*.rttm")) """ Analyze the multispeaker data and plot the distribution of silence and overlap durations. Args: input_dir (str): path to the directory containing the rttm files session_dur (float): duration of the session (specified by argument) silence_mean (float): mean of the silence duration distribution silence_var (float): variance of the silence duration distribution overlap_mean (float): mean of the overlap duration distribution overlap_var (float): variance of the overlap duration distribution precise (bool): whether to measure the precise duration of the session using sox save_path (str): path to save the plots Returns: stats (dict): dictionary containing the statistics of the analyzed data """ print(f"Found {len(rttm_list)} files to be processed") if len(rttm_list) == 0: raise ValueError(f"No rttm files found in {input_dir}") silence_duration = 0.0 total_duration = 0.0 overlap_duration = 0.0 silence_ratio_all = [] overlap_ratio_all = [] silence_length_all = [] overlap_length_all = [] queue = [] for rttm_file in tqdm(rttm_list): queue.append( {"rttm_file": rttm_file, "session_dur": session_dur, "precise": precise,} ) if num_workers <= 1: results = [process_sample(sess_dict) for sess_dict in tqdm(queue)] else: with multiprocessing.Pool(processes=num_workers) as p: results = list(tqdm(p.imap(process_sample, queue), total=len(queue), desc='Processing', leave=True,)) for item in results: total_duration += item["session_dur"] silence_duration += item["silence_dur"] overlap_duration += item["overlap_dur"] silence_length_all += item["silence_len_list"] overlap_length_all += item["overlap_len_list"] silence_ratio_all.append(item["silence_ratio"]) overlap_ratio_all.append(item["overlap_ratio"]) actual_silence_mean = silence_duration / total_duration actual_silence_var = np.var(silence_ratio_all) actual_overlap_mean = overlap_duration / (total_duration - silence_duration) actual_overlap_var = np.var(overlap_ratio_all) stats = OrderedDict() stats["total duration (hours)"] = f"{total_duration / 3600:.2f}" stats["number of sessions"] = len(rttm_list) stats["average session duration (seconds)"] = f"{total_duration / len(rttm_list):.2f}" stats["actual silence ratio mean/var"] = f"{actual_silence_mean:.4f}/{actual_silence_var:.4f}" stats["actual overlap ratio mean/var"] = f"{actual_overlap_mean:.4f}/{actual_overlap_var:.4f}" stats["expected silence ratio mean/var"] = f"{silence_mean}/{silence_var}" stats["expected overlap ratio mean/var"] = f"{overlap_mean}/{overlap_var}" stats["save_path"] = save_path print("-----------------------------------------------") print(" Results ") print("-----------------------------------------------") for k, v in stats.items(): print(k, ": ", v) print("-----------------------------------------------") fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 14)) fig.suptitle( f"Average session={total_duration/len(rttm_list):.2f} seconds, num sessions={len(rttm_list)}, total={total_duration/3600:.2f} hours" ) sns.histplot(silence_ratio_all, ax=ax1) ax1.set_xlabel("Silence ratio in a session") ax1.set_title( f"Target silence mean={silence_mean}, var={silence_var}. \nActual silence ratio={actual_silence_mean:.4f}, var={actual_silence_var:.4f}" ) _, scale = expon.fit(silence_length_all, floc=0) sns.histplot(silence_length_all, ax=ax2) ax2.set_xlabel("Per-silence length in seconds") ax2.set_title(f"Per-silence length histogram, \nfitted exponential distribution with mean={scale:.4f}") sns.histplot(overlap_ratio_all, ax=ax3) ax3.set_title( f"Target overlap mean={overlap_mean}, var={overlap_var}. \nActual ratio={actual_overlap_mean:.4f}, var={actual_overlap_var:.4f}" ) ax3.set_xlabel("Overlap ratio in a session") _, scale2 = expon.fit(overlap_length_all, floc=0) sns.histplot(overlap_length_all, ax=ax4) ax4.set_title(f"Per overlap length histogram, \nfitted exponential distribution with mean={scale2:.4f}") ax4.set_xlabel("Duration in seconds") if save_path: fig.savefig(save_path) print(f"Figure saved at: {save_path}") return stats def visualize_multispeaker_data(input_dir: str, output_dir: str, num_samples: int = 10) -> None: """ Visualize a set of randomly sampled data in the input directory Args: input_dir (str): Path to the input directory output_dir (str): Path to the output directory num_samples (int): Number of samples to visualize """ rttm_list = list(Path(input_dir).glob("*.rttm")) idx_list = np.random.permutation(len(rttm_list))[:num_samples] print(f"Visualizing {num_samples} random samples") for idx in idx_list: rttm_file = rttm_list[idx] audio_file = rttm_file.parent / Path(rttm_file.stem + ".wav") output_file = Path(output_dir) / Path(rttm_file.stem + ".png") plot_sample_from_rttm(audio_file=audio_file, rttm_file=rttm_file, save_path=str(output_file), show=False) print(f"Sample plots saved at: {output_dir}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("input_dir", default="", help="Input directory") parser.add_argument("-sd", "--session_dur", default=None, type=float, help="Duration per session in seconds") parser.add_argument("-sm", "--silence_mean", default=None, type=float, help="Expected silence ratio mean") parser.add_argument("-sv", "--silence_var", default=None, type=float, help="Expected silence ratio variance") parser.add_argument("-om", "--overlap_mean", default=None, type=float, help="Expected overlap ratio mean") parser.add_argument("-ov", "--overlap_var", default=None, type=float, help="Expected overlap ratio variance") parser.add_argument("-w", "--num_workers", default=1, type=int, help="Number of CPU workers to use") parser.add_argument("-s", "--num_samples", default=10, type=int, help="Number of random samples to plot") parser.add_argument("-o", "--output_dir", default="analysis/", type=str, help="Directory for saving output figure") parser.add_argument( "--precise", action="store_true", help="Set to get precise duration, with significant time cost" ) args = parser.parse_args() print("Running with params:") pprint(vars(args)) output_dir = Path(args.output_dir) if output_dir.exists(): print(f"Removing existing output directory: {args.output_dir}") shutil.rmtree(str(output_dir)) output_dir.mkdir(parents=True) run_multispeaker_data_analysis( input_dir=args.input_dir, session_dur=args.session_dur, silence_mean=args.silence_mean, silence_var=args.silence_var, overlap_mean=args.overlap_mean, overlap_var=args.overlap_var, precise=args.precise, save_path=str(Path(args.output_dir, "statistics.png")), num_workers=args.num_workers, ) visualize_multispeaker_data(input_dir=args.input_dir, output_dir=args.output_dir, num_samples=args.num_samples) print("The multispeaker data analysis has been completed.") print(f"Please check the output directory: \n{args.output_dir}")