ten-vad / examples /plot_pr_curves.py
Ziyi Lin
Update testset name
97e7f89
raw
history blame
8.3 kB
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
import os, glob, sys, torchaudio
import numpy as np
import scipy.io.wavfile as Wavfile
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
os.system('git clone https://github.com/snakers4/silero-vad.git') # Clone the silero-vad repo, using Silero V5
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "./silero-vad/src")))
from silero_vad.utils_vad import VADIterator, init_jit_model
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../include")))
from ten_vad import TenVad
def convert_label_to_framewise(label_file, hop_size, sample_rate=16000):
    """Expand a segment-level VAD label file into one label per frame.

    Only the first line of ``label_file`` is used. It is split on commas, the
    first field is skipped, and the remaining fields are consumed as repeating
    ``start, end, label`` triples, where ``label`` is 0 (non-speech) or
    1 (speech).

    Args:
        label_file: Path to the label file.
        hop_size: Frame hop in samples.
        sample_rate: Sampling rate used to convert ``hop_size`` into a frame
            duration. Defaults to 16000, the value previously hard-coded.

    Returns:
        A 1-D float array of 0.0/1.0 values, one entry per frame.

    Raises:
        ValueError: If the fields after the first do not form complete
            ``start, end, label`` triples, or if a label is neither 0 nor 1.
    """
    frame_duration = hop_size / sample_rate
    with open(label_file, "r") as f:
        lines = f.readlines()
    content = lines[0].strip().split(",")[1:]

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(content) % 3 != 0:
        raise ValueError(
            f"{label_file}: expected start,end,label triples, "
            f"got {len(content)} fields"
        )

    start = np.array(content[0::3], dtype=float)  # Start point of each audio segment
    end = np.array(content[1::3], dtype=float)  # End point of each audio segment
    # label, 0/1 stands for non-speech or speech, respectively
    lab_manual = np.array(content[2::3], dtype=int)

    # Reject unknown labels up front; otherwise a segment with an unexpected
    # label would silently inherit the previous segment's value.
    if not np.isin(lab_manual, (0, 1)).all():
        raise ValueError(f"{label_file}: labels must be 0 or 1")

    # Number of frames covered by each audio segment.
    num = np.round((end - start) / frame_duration).astype(np.int32)

    # Repeat every segment label once per frame of that segment.
    label_framewise = np.repeat(lab_manual.astype(float), num)

    # Per-segment rounding can overshoot the total duration, so cap the length
    # at the number of whole frames between the first start and the last end.
    frame_num = min(
        len(label_framewise), int((end[-1] - start[0]) / frame_duration)
    )
    return label_framewise[:frame_num]
def read_file(file_path):
    """Read a text file holding one number per line into a float array.

    Blank or whitespace-only lines are skipped.

    Args:
        file_path: Path to the text file.

    Returns:
        A 1-D float64 array with one entry per non-blank line; empty when the
        file contains no values.

    Raises:
        ValueError: If a non-blank line cannot be parsed as a float.
    """
    with open(file_path, "r") as f:
        # Collect in a list and convert once; growing an array with np.append
        # copies the whole array on every line.
        values = [float(line) for line in f if line.strip()]
    return np.array(values, dtype=float)
def get_precision_recall(VAD_result, label, threshold):
    """Compute precision, recall, FPR and FNR for thresholded VAD scores.

    A frame is predicted as speech when its score is ``>= threshold``. A frame
    counts as true speech when its label equals 1; any other label value is
    treated as non-speech.

    The four confusion counts are computed directly, so the result is
    well-defined even when only one class occurs in the data.

    Args:
        VAD_result: Array-like of per-frame speech scores.
        label: Array-like of per-frame reference labels, same shape as
            ``VAD_result``.
        threshold: Decision threshold applied to ``VAD_result``.

    Returns:
        Tuple ``(precision, recall, FPR, FNR)``. A metric whose denominator is
        zero is reported as 0.

    Raises:
        ValueError: If ``VAD_result`` and ``label`` differ in shape.
    """
    predicted = np.asarray(VAD_result) >= threshold
    actual = np.asarray(label) == 1
    if predicted.shape != actual.shape:
        raise ValueError(
            f"VAD_result and label must have the same shape, "
            f"got {predicted.shape} and {actual.shape}"
        )

    # Confusion counts
    TP = int(np.count_nonzero(predicted & actual))
    FP = int(np.count_nonzero(predicted & ~actual))
    FN = int(np.count_nonzero(~predicted & actual))
    TN = int(np.count_nonzero(~predicted & ~actual))

    # Compute precision, recall, false positive rate and false negative rate
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    FPR = FP / (FP + TN) if (FP + TN) > 0 else 0
    FNR = FN / (TP + FN) if (TP + FN) > 0 else 0
    return precision, recall, FPR, FNR
def silero_vad_inference_single_file(wav_path):
    """Run the Silero VAD model over one wav file, window by window.

    The JIT model is loaded from the ``silero-vad`` checkout located next to
    this script. The audio is fed to the model in consecutive windows of 512
    samples; a trailing window shorter than that is not evaluated.

    Args:
        wav_path: Path of the wav file to analyse.

    Returns:
        Tuple ``(speech_probs, window_size_samples)``: a 1-D array with one
        model output per full window, and the window size (512).
    """
    window_size_samples = 512

    current_directory = os.path.dirname(os.path.abspath(__file__))
    model = init_jit_model(f'{current_directory}/silero-vad/src/silero_vad/data/silero_vad.jit')
    vad_iterator = VADIterator(model)

    wav, sr = torchaudio.load(wav_path)
    wav = wav.squeeze(0)

    # Only whole windows are scored; the remainder at the end is ignored.
    collected = []
    full_windows = len(wav) // window_size_samples
    for window_idx in range(full_windows):
        begin = window_idx * window_size_samples
        chunk = wav[begin: begin + window_size_samples]
        collected.append(model(chunk, sr).item())

    vad_iterator.reset_states()  # reset model states after each audio
    speech_probs = np.array(collected, dtype=np.float64)
    return speech_probs, window_size_samples
def ten_vad_process_wav(ten_vad_instance, wav_path, hop_size=256):
    """Feed a wav file to a VAD instance frame by frame.

    The samples are cut into consecutive frames of ``hop_size`` samples and
    each frame is passed to ``ten_vad_instance.process``; samples left over
    after the last full frame are not processed.

    Args:
        ten_vad_instance: Object whose ``process(frame)`` returns a pair; the
            first item is kept as the frame's voice probability.
        wav_path: Path of the wav file to read.
        hop_size: Number of samples per frame.

    Returns:
        A 1-D float array holding one probability per full frame.
    """
    _, samples = Wavfile.read(wav_path)
    frame_count = samples.shape[0] // hop_size

    probabilities = []
    for begin in range(0, frame_count * hop_size, hop_size):
        frame = samples[begin: begin + hop_size]
        probability, _flag = ten_vad_instance.process(frame)
        probabilities.append(probability)
    return np.array(probabilities, dtype=np.float64)
if __name__ == "__main__":
    # Evaluate TEN VAD and Silero VAD on every wav file of the test set, sweep
    # the decision threshold, then plot both precision-recall curves and save
    # the TEN VAD curve data as text.

    # Get the directory of the script
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # testset dir
    test_dir = f"{script_dir}/../testset"

    # Initialization
    hop_size = 256
    threshold = 0.5

    # Sorted so the processing order does not depend on the file system.
    wav_list = sorted(glob.glob(f"{test_dir}/*.wav"))
    if not wav_list:
        # Without any audio the metric computation below has nothing to work on.
        raise SystemExit(f"No .wav files found in {test_dir}")

    # Per-file results are collected in lists and concatenated once at the end.
    ten_vad_labels, ten_vad_probs = [], []
    silero_labels, silero_probs = [], []

    # The WebRTC VAD is from the latest version of WebRTC and is not plotted here
    print("Start processing")
    for wav_path in wav_list:
        # Swap only the file extension; str.replace would also rewrite any
        # other ".wav" occurrence in the path (e.g. in a directory name).
        label_file = os.path.splitext(wav_path)[0] + ".scv"

        # Running TEN VAD
        ten_vad_instance = TenVad(hop_size, threshold)
        label = convert_label_to_framewise(
            label_file, hop_size=hop_size
        )  # Convert the VAD label to frame-wise one
        vad_result_ten_vad = ten_vad_process_wav(
            ten_vad_instance, wav_path, hop_size=hop_size
        )
        del ten_vad_instance  # To prevent getting different results of each run

        frame_num = min(len(label), len(vad_result_ten_vad))
        # Probability i + 1 is paired with label i, as in the original script.
        # NOTE(review): presumably this compensates a one-frame output delay
        # of TEN VAD - confirm.
        # Clamp at zero so an empty result cannot turn the label slice into
        # `label[:-1]` and misalign the two arrays.
        aligned_num = max(frame_num - 1, 0)
        ten_vad_probs.append(vad_result_ten_vad[1:1 + aligned_num])
        ten_vad_labels.append(label[:aligned_num])

        # Running Silero VAD
        vad_result_silero_vad, silero_hop_size = silero_vad_inference_single_file(wav_path)
        # Convert the VAD label to frame-wise one for Silero VAD, using the
        # window size reported by the inference helper.
        label_silero_vad = convert_label_to_framewise(
            label_file, hop_size=silero_hop_size
        )
        frame_num_silero_vad = min(len(label_silero_vad), len(vad_result_silero_vad))
        silero_probs.append(vad_result_silero_vad[:frame_num_silero_vad])
        silero_labels.append(label_silero_vad[:frame_num_silero_vad])

    label_all = np.concatenate(ten_vad_labels)
    vad_result_ten_vad_all = np.concatenate(ten_vad_probs)
    label_silero_vad_all = np.concatenate(silero_labels)
    vad_result_silero_vad_all = np.concatenate(silero_probs)

    # Compute Precision and Recall
    threshold_arr = np.arange(0, 1.01, 0.01)
    pr_data_arr = np.zeros((len(threshold_arr), 3))
    pr_data_silero_vad_arr = np.zeros((len(threshold_arr), 3))
    # A separate loop name keeps the TEN VAD `threshold` setting intact.
    for ind, cur_threshold in enumerate(threshold_arr):
        precision, recall, _, _ = get_precision_recall(
            vad_result_ten_vad_all, label_all, cur_threshold
        )
        pr_data_arr[ind] = precision, recall, cur_threshold

        precision_silero_vad, recall_silero_vad, _, _ = get_precision_recall(
            vad_result_silero_vad_all, label_silero_vad_all, cur_threshold
        )
        pr_data_silero_vad_arr[ind] = (
            precision_silero_vad,
            recall_silero_vad,
            cur_threshold,
        )

    # Plot PR Curve
    print("Plotting PR Curve")
    # The row of the final threshold is left out of both plotted curves.
    pr_data_arr_to_plot = pr_data_arr[:-1]
    plt.plot(
        pr_data_arr_to_plot[:, 1],
        pr_data_arr_to_plot[:, 0],
        color="red",
        label="TEN VAD",
    )  # Precision on y-axis, Recall on x-axis

    pr_data_silero_vad_arr_to_plot = pr_data_silero_vad_arr[:-1]
    plt.plot(
        pr_data_silero_vad_arr_to_plot[:, 1],  # Recall (x-axis)
        pr_data_silero_vad_arr_to_plot[:, 0],  # Precision (y-axis)
        color="blue",
        label="Silero VAD",
    )

    plt.xlabel("Recall", fontsize=14, fontweight="bold", color="black")
    plt.ylabel("Precision", fontsize=14, fontweight="bold", color="black")
    legend = plt.legend()
    for legend_text in legend.get_texts():
        legend_text.set_fontweight("bold")
    plt.grid(True)
    plt.xlim(0.65, 1)
    plt.ylim(0.7, 1)
    plt.title(
        "Precision-Recall Curve of TEN VAD on TEN-VAD-TestSet",
        fontsize=12,
        color="black",
        fontweight="bold",
    )

    save_path = f"{script_dir}/PR_Curves.png"
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"PR Curves png file saved, save path: {save_path}")

    # Save the PR data to txt file (one "threshold precision recall" row per
    # threshold, TEN VAD only)
    pr_data_save_path = f"{script_dir}/PR_data_TEN_VAD.txt"
    with open(pr_data_save_path, "w") as f:
        for precision, recall, cur_threshold in pr_data_arr:
            f.write(f"{cur_threshold:.2f} {precision:.4f} {recall:.4f}\n")

    print("Processing done!")