# ten-vad / include / ten_vad.py
# Ziyi Lin
# Add macOS and Windows support with usage of prebuilt-lib
# 8db09f4
#
# Copyright © 2025 Agora
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0, with certain conditions.
# Refer to the "LICENSE" file in the root directory for more information.
#
from ctypes import c_int, c_int32, c_float, c_size_t, CDLL, c_void_p, POINTER
import numpy as np
import os
import platform
class TenVad:
    """ctypes wrapper around the prebuilt TEN VAD native library.

    On construction the platform-specific shared library is loaded, the
    signatures of ``ten_vad_create`` / ``ten_vad_process`` /
    ``ten_vad_destroy`` are declared, and a native handle is created.
    Frames of ``hop_size`` int16 samples are then fed through ``process``.
    """

    def __init__(self, hop_size: int = 256, threshold: float = 0.5):
        """Load the native library and create a VAD handle.

        Args:
            hop_size: Number of int16 samples expected in every frame given
                to ``process``; also forwarded to ``ten_vad_create``.
            threshold: Forwarded unchanged to ``ten_vad_create``
                (presumably the speech decision threshold -- confirm
                against the native API documentation).

        Raises:
            NotImplementedError: The current platform has no prebuilt library.
            OSError: The shared library could not be loaded.
            AssertionError: The native handle could not be created.
        """
        self.hop_size = hop_size
        self.threshold = threshold
        self.vad_library = self._load_library()

        self.vad_handler = c_void_p(0)
        # Output slots the native ten_vad_process call writes into.
        self.out_probability = c_float()
        self.out_flags = c_int32()

        self.vad_library.ten_vad_create.argtypes = [
            POINTER(c_void_p),
            c_size_t,
            c_float,
        ]
        self.vad_library.ten_vad_create.restype = c_int

        self.vad_library.ten_vad_destroy.argtypes = [POINTER(c_void_p)]
        self.vad_library.ten_vad_destroy.restype = c_int

        self.vad_library.ten_vad_process.argtypes = [
            c_void_p,
            c_void_p,
            c_size_t,
            POINTER(c_float),
            POINTER(c_int32),
        ]
        self.vad_library.ten_vad_process.restype = c_int

        self.create_and_init_handler()

    @staticmethod
    def _library_candidates():
        """Return ``(git_relative, pip_relative)`` library paths for this platform."""
        system = platform.system()
        machine = platform.machine()
        if system == "Linux" and machine == "x86_64":
            return (
                "../lib/Linux/x64/libten_vad.so",
                "./ten_vad_library/libten_vad.so",
            )
        if system == "Darwin":
            return (
                "../lib/macOS/ten_vad.framework/Versions/A/ten_vad",
                "./ten_vad_library/libten_vad",
            )
        if system.upper() == "WINDOWS":
            if machine.upper() in ("X64", "X86_64", "AMD64"):
                return (
                    "../lib/Windows/x64/ten_vad.dll",
                    "./ten_vad_library/ten_vad.dll",
                )
            # Every other Windows machine type uses the x86 build.
            return (
                "../lib/Windows/x86/ten_vad.dll",
                "./ten_vad_library/ten_vad.dll",
            )
        raise NotImplementedError(f"Unsupported platform: {system} {machine}")

    @classmethod
    def _load_library(cls):
        """Load the shared library, preferring the repository layout.

        The path next to a git checkout is tried first; when it does not
        exist the location used by the pip package is loaded instead.
        """
        git_relative, pip_relative = cls._library_candidates()
        # realpath on every platform: the location must not depend on the
        # current working directory or on symlinks pointing at this file.
        base_dir = os.path.dirname(os.path.realpath(__file__))
        git_path = os.path.join(base_dir, git_relative)
        if os.path.exists(git_path):
            return CDLL(git_path)
        return CDLL(os.path.join(base_dir, pip_relative))

    def create_and_init_handler(self):
        """Create the native VAD handle and store it in ``self.vad_handler``.

        Raises:
            AssertionError: ``ten_vad_create`` returned a non-zero status.
        """
        # The call must live outside an ``assert`` statement: asserts are
        # removed under ``python -O`` and the handle would never be created.
        status = self.vad_library.ten_vad_create(
            POINTER(c_void_p)(self.vad_handler),
            c_size_t(self.hop_size),
            c_float(self.threshold),
        )
        if status != 0:
            raise AssertionError("[TEN VAD]: create handler failure!")

    def __del__(self):
        """Release the native handle, if one was successfully created."""
        library = getattr(self, "vad_library", None)
        handler = getattr(self, "vad_handler", None)
        # __init__ may have failed before the library or handle existed;
        # in that case there is nothing to release.
        if library is None or handler is None or not handler.value:
            return
        status = library.ten_vad_destroy(POINTER(c_void_p)(handler))
        if status != 0:
            raise AssertionError("[TEN VAD]: destroy handler failure!")
        # Guard against releasing the same handle twice.
        handler.value = None

    def get_input_data(self, audio_data: np.ndarray):
        """Validate one audio frame and return a pointer to its samples.

        Args:
            audio_data: Array that squeezes to shape ``[hop_size]`` with
                dtype int16.

        Returns:
            ``c_void_p`` addressing ``hop_size`` contiguous int16 samples.
            The pointer stays valid until the next call on this instance
            (when the data had to be copied) or for as long as the
            caller's array is alive (when it was already contiguous).

        Raises:
            AssertionError: The frame has the wrong shape or dtype.
        """
        # atleast_1d restores the dimension squeeze drops when hop_size == 1.
        audio_data = np.atleast_1d(np.squeeze(audio_data))
        if audio_data.ndim != 1 or audio_data.shape[0] != self.hop_size:
            raise AssertionError(
                "[TEN VAD]: audio data shape should be [%d]" % (self.hop_size)
            )
        if audio_data.dtype != np.int16:
            raise AssertionError(
                "[TEN VAD]: audio data type error, must be int16"
            )
        # The native code reads hop_size consecutive samples, so a strided
        # view has to be copied into contiguous memory first.
        contiguous = np.ascontiguousarray(audio_data)
        # Keep the buffer referenced so a copy is not freed before use.
        self._input_buffer = contiguous
        data_pointer = contiguous.__array_interface__["data"][0]
        return c_void_p(data_pointer)

    def process(self, audio_data: np.ndarray):
        """Run the native VAD on one frame.

        Args:
            audio_data: One frame accepted by ``get_input_data``.

        Returns:
            Tuple ``(out_probability, out_flags)`` holding the values the
            native ``ten_vad_process`` call wrote for this frame.

        Raises:
            AssertionError: The frame has the wrong shape or dtype.
            RuntimeError: ``ten_vad_process`` returned a non-zero status.
        """
        input_pointer = self.get_input_data(audio_data)
        status = self.vad_library.ten_vad_process(
            self.vad_handler,
            input_pointer,
            c_size_t(self.hop_size),
            POINTER(c_float)(self.out_probability),
            POINTER(c_int32)(self.out_flags),
        )
        if status != 0:
            # Without this check stale outputs from an earlier frame would
            # be returned as if they belonged to this one.
            raise RuntimeError("[TEN VAD]: process failure!")
        return self.out_probability.value, self.out_flags.value