File size: 8,302 Bytes
1e05caf 97e7f89 1e05caf 8cbbe9b 1e05caf 8cbbe9b 1e05caf 8cbbe9b 1e05caf 8cbbe9b 1e05caf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
import glob
import os
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
import scipy.io.wavfile as Wavfile
import torchaudio
from sklearn.metrics import confusion_matrix
# Fetch the Silero VAD repo (V5) next to this script so that the model file
# and the Python sources resolve regardless of the current working directory
# (os.system cloned into the CWD while sys.path pointed at the script dir).
# The clone is skipped when a previous run already created the directory.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_silero_dir = os.path.join(_script_dir, "silero-vad")
if not os.path.isdir(_silero_dir):
    # List-form argv with shell=False: no shell string interpolation.
    subprocess.run(
        ["git", "clone", "https://github.com/snakers4/silero-vad.git", _silero_dir],
        check=True,
    )
sys.path.append(os.path.join(_silero_dir, "src"))
from silero_vad.utils_vad import VADIterator, init_jit_model
sys.path.append(os.path.abspath(os.path.join(_script_dir, "../include")))
from ten_vad import TenVad
def convert_label_to_framewise(label_file, hop_size):
    """Expand a segment-level VAD label file into a frame-wise 0/1 array.

    The label file's first line is comma-separated: an initial field
    (e.g. the audio name) followed by repeated triples of
    (start_sec, end_sec, label), where label is 0 for non-speech and
    1 for speech.

    Parameters
    ----------
    label_file : str
        Path to the label file described above.
    hop_size : int
        Hop size in samples at 16 kHz; one frame spans hop_size/16000 s.

    Returns
    -------
    np.ndarray
        Frame-wise labels as floats (0.0/1.0), truncated to the total
        time span covered by the segments.
    """
    frame_duration = hop_size / 16000
    with open(label_file, "r") as f:
        content = f.readlines()[0].strip().split(",")[1:]
    start = np.array(content[::3], dtype=float)      # segment start times (s)
    end = np.array(content[1::3], dtype=float)       # segment end times (s)
    lab_manual = np.array(content[2::3], dtype=int)  # 0 = non-speech, 1 = speech
    assert len(start) == len(end) == len(lab_manual)
    # Number of frames each segment covers.
    num = np.array(np.round((end - start) / frame_duration), dtype=np.int32)
    # Repeat each segment's label across its frames in one vectorized call.
    # The original per-segment np.append loop was O(n^2) and left the frame
    # buffer undefined for labels outside {0, 1}.
    label_framewise = np.repeat(lab_manual, num).astype(float)
    # Clip to the overall span so per-segment rounding cannot add frames.
    frame_num = min(len(label_framewise), int((end[-1] - start[0]) / frame_duration))
    return label_framewise[:frame_num]
def read_file(file_path):
    """Read a text file containing one float per line into a 1-D array.

    Parameters
    ----------
    file_path : str
        Path to the text file.

    Returns
    -------
    np.ndarray
        Floats in file order; blank lines (e.g. a trailing newline) are
        skipped instead of raising ValueError.
    """
    with open(file_path, "r") as f:
        # Build the list first: np.append per line re-allocates the whole
        # array each iteration (O(n^2) overall).
        return np.array([float(line) for line in f if line.strip()], dtype=float)
def get_precision_recall(VAD_result, label, threshold):
    """Binarize VAD probabilities and score them against reference labels.

    Parameters
    ----------
    VAD_result : np.ndarray
        Frame-wise speech probabilities.
    label : np.ndarray
        Frame-wise reference labels (0 = non-speech, 1 = speech).
    threshold : float
        Probabilities >= threshold are classified as speech.

    Returns
    -------
    tuple
        (precision, recall, FPR, FNR); each ratio is 0 when its
        denominator is 0.
    """
    vad_result_hard = np.where(VAD_result >= threshold, 1, 0)
    label = np.asarray(label).astype(int)
    # Count the confusion-matrix cells directly with numpy. Unlike
    # sklearn's confusion_matrix(...).ravel(), this cannot mis-unpack
    # (ValueError) when only one class appears in the labels or in the
    # thresholded predictions.
    TP = int(np.sum((vad_result_hard == 1) & (label == 1)))
    FP = int(np.sum((vad_result_hard == 1) & (label == 0)))
    FN = int(np.sum((vad_result_hard == 0) & (label == 1)))
    TN = int(np.sum((vad_result_hard == 0) & (label == 0)))
    # Precision, recall, false positive rate and false negative rate.
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    FPR = FP / (FP + TN) if (FP + TN) > 0 else 0
    FNR = FN / (TP + FN) if (TP + FN) > 0 else 0
    return precision, recall, FPR, FNR
def silero_vad_inference_single_file(wav_path):
    """Run Silero VAD (V5 JIT model) over one wav file in 512-sample windows.

    Parameters
    ----------
    wav_path : str
        Path to the wav file to score.

    Returns
    -------
    tuple
        (speech_probs, window_size_samples): one speech probability per
        complete 512-sample window, and the window size used.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model = init_jit_model(f'{script_dir}/silero-vad/src/silero_vad/data/silero_vad.jit')
    vad_iterator = VADIterator(model)
    window = 512
    waveform, sample_rate = torchaudio.load(wav_path)
    waveform = waveform.squeeze(0)
    probs = []
    offset = 0
    # Score full windows only; a trailing partial window is dropped.
    while offset + window <= len(waveform):
        chunk = waveform[offset: offset + window]
        probs.append(model(chunk, sample_rate).item())
        offset += window
    vad_iterator.reset_states()  # reset model states after each audio
    return np.array(probs), window
def ten_vad_process_wav(ten_vad_instance, wav_path, hop_size=256):
    """Run a TEN VAD instance over a wav file, one hop at a time.

    Parameters
    ----------
    ten_vad_instance : object
        Object exposing process(samples) -> (voice_prob, flag).
    wav_path : str
        Path to the wav file; only the sample data is used here.
    hop_size : int, optional
        Samples fed to process() per frame (default 256).

    Returns
    -------
    np.ndarray
        One speech probability per complete hop; trailing samples that do
        not fill a hop are dropped.
    """
    _, data = Wavfile.read(wav_path)
    num_frames = data.shape[0] // hop_size
    # Collect into a list first: np.append inside the loop re-allocates the
    # whole array every iteration (O(n^2) overall).
    probs = [
        ten_vad_instance.process(data[i * hop_size: (i + 1) * hop_size])[0]
        for i in range(num_frames)
    ]
    return np.array(probs, dtype=float)
if __name__ == "__main__":
    # Evaluation driver: run TEN VAD and Silero VAD over every wav in
    # ../testset, sweep the decision threshold from 0 to 1, and save a
    # combined precision-recall curve plus the raw TEN VAD PR data.
    # Get the directory of the script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # testset dir
    test_dir = f"{script_dir}/../testset"
    # Initialization
    hop_size = 256  # TEN VAD frame size in samples (16 ms at 16 kHz)
    threshold = 0.5  # initial TenVad threshold; rebound later by the sweep loop
    label_all, vad_result_ten_vad_all = np.array([]), np.array([])
    label_hop_512_all, vad_result_silero_vad_all = np.array([]), np.array([])
    wav_list = glob.glob(f"{test_dir}/*.wav")
    # The WebRTC VAD is from the latest version of WebRTC and is not plotted here
    print("Start processing")
    for wav_path in wav_list:
        # Running TEN VAD
        ten_vad_instance = TenVad(hop_size, threshold)
        # NOTE(review): label files are expected beside each wav with a
        # ".scv" extension — confirm this is not a typo for ".csv".
        label_file = wav_path.replace(".wav", ".scv")
        label = convert_label_to_framewise(
            label_file, hop_size=hop_size
        )  # Convert the VAD label to frame-wise one
        vad_result_ten_vad = ten_vad_process_wav(
            ten_vad_instance, wav_path, hop_size=hop_size
        )
        frame_num = min(label.__len__(), vad_result_ten_vad.__len__())
        # Predictions are taken from frame 1 onward while labels stop at
        # frame_num-1, i.e. predictions are shifted one frame relative to
        # labels — presumably compensating one frame of algorithmic delay;
        # TODO confirm against the TEN VAD latency specification.
        vad_result_ten_vad_all = np.append(
            vad_result_ten_vad_all, vad_result_ten_vad[1:frame_num]
        )
        label_all = np.append(label_all, label[:frame_num - 1])
        del ten_vad_instance  # To prevent getting different results of each run
        # Running Silero VAD (512-sample windows, hence a separate label track)
        label_hop_512 = convert_label_to_framewise(
            label_file, hop_size=512
        )  # Convert the VAD label to frame-wise one for Silero VAD
        vad_result_silero_vad, _ = silero_vad_inference_single_file(wav_path)
        frame_num_silero_vad = min(label_hop_512.__len__(), vad_result_silero_vad.__len__())
        vad_result_silero_vad_all = np.append(vad_result_silero_vad_all, vad_result_silero_vad[:frame_num_silero_vad])
        label_hop_512_all = np.append(label_hop_512_all, label_hop_512[:frame_num_silero_vad])
    # Compute Precision and Recall at 101 thresholds (0.00 .. 1.00)
    threshold_arr = np.arange(0, 1.01, 0.01)
    pr_data_arr = np.zeros((threshold_arr.__len__(), 3))
    pr_data_silero_vad_arr = np.zeros((threshold_arr.__len__(), 3))
    for ind, threshold in enumerate(threshold_arr):
        precision, recall, FPR, FNR = get_precision_recall(vad_result_ten_vad_all, label_all, threshold)
        pr_data_arr[ind] = precision, recall, threshold
        precision_silero_vad, recall_silero_vad, FPR_silero_vad, FNR_silero_vad = get_precision_recall(vad_result_silero_vad_all, label_hop_512_all, threshold)
        pr_data_silero_vad_arr[ind] = precision_silero_vad, recall_silero_vad, threshold
    # Plot PR Curve (the last row, threshold 1.0, is dropped from the plot)
    print("Plotting PR Curve")
    pr_data_arr_to_plot = pr_data_arr[:-1]
    plt.plot(
        pr_data_arr_to_plot[:, 1],
        pr_data_arr_to_plot[:, 0],
        color="red",
        label="TEN VAD",
    )  # Precision on y-axis, Recall on x-axis
    pr_data_silero_vad_arr_to_plot = pr_data_silero_vad_arr[:-1]
    plt.plot(
        pr_data_silero_vad_arr_to_plot[:, 1],  # Recall (x-axis)
        pr_data_silero_vad_arr_to_plot[:, 0],  # Precision (y-axis)
        color="blue",
        label="Silero VAD",
    )
    plt.xlabel("Recall", fontsize=14, fontweight="bold", color="black")
    plt.ylabel("Precision", fontsize=14, fontweight="bold", color="black")
    legend = plt.legend()
    legend.get_texts()[0].set_fontweight("bold")
    legend.get_texts()[1].set_fontweight("bold")
    plt.grid(True)
    plt.xlim(0.65, 1)
    plt.ylim(0.7, 1)
    plt.title(
        "Precision-Recall Curve of TEN VAD on TEN-VAD-TestSet",
        fontsize=12,
        color="black",
        fontweight="bold",
    )
    save_path = f"{script_dir}/PR_Curves.png"
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"PR Curves png file saved, save path: {save_path}")
    # Save the PR data to txt file, one "threshold precision recall" per line
    pr_data_save_path = f"{script_dir}/PR_data_TEN_VAD.txt"
    with open(pr_data_save_path, "w") as f:
        for ind in range(pr_data_arr.shape[0]):
            precision, recall, threshold = (
                pr_data_arr[ind, 0],
                pr_data_arr[ind, 1],
                pr_data_arr[ind, 2],
            )
            f.write(f"{threshold:.2f} {precision:.4f} {recall:.4f}\n")
    print("Processing done!")
|