File size: 5,234 Bytes
b743710
b50f2a2
 
 
 
b743710
 
 
 
8db09f4
b743710
 
 
 
 
8db09f4
 
b743710
8db09f4
b743710
8db09f4
 
 
 
 
 
b743710
8db09f4
 
 
 
 
 
b743710
8db09f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b743710
8db09f4
 
 
 
 
b743710
8db09f4
 
 
 
 
 
 
 
 
 
b743710
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#
#  Copyright © 2025 Agora
#  This file is part of TEN Framework, an open source project.
#  Licensed under the Apache License, Version 2.0, with certain conditions.
#  Refer to the "LICENSE" file in the root directory for more information.
#
from ctypes import c_int, c_int32, c_float, c_size_t, CDLL, c_void_p, POINTER
import numpy as np
import os
import platform

class TenVad:
    """ctypes wrapper around the native ``ten_vad`` shared library.

    An instance loads the platform-specific library, creates one native
    handle via ``ten_vad_create`` and releases it via ``ten_vad_destroy``
    when the instance is garbage collected.

    Attributes:
        hop_size: Number of int16 samples expected per ``process`` call.
        threshold: Value forwarded unchanged to ``ten_vad_create``.
        vad_library: The loaded ``ctypes.CDLL``.
        vad_handler: ``c_void_p`` holding the native handle.
        out_probability: ``c_float`` the library writes its probability to.
        out_flags: ``c_int32`` the library writes its flag to.

    Failures of the checks that already existed raise ``AssertionError``
    with the original messages so callers that catch it keep working; they
    are raised explicitly (not via ``assert``) so they, and the native
    calls they guard, still run under ``python -O``.
    """

    def __init__(self, hop_size: int = 256, threshold: float = 0.5):
        """Load the native library and create a handle.

        Args:
            hop_size: Samples per frame, passed to ``ten_vad_create``.
            threshold: Passed to ``ten_vad_create``.

        Raises:
            NotImplementedError: If the current platform is not supported.
            OSError: If the shared library cannot be loaded.
            AssertionError: If ``ten_vad_create`` reports a failure.
        """
        self.hop_size = hop_size
        self.threshold = threshold
        # Set before anything can fail so __del__ always finds a handle.
        self.vad_handler = c_void_p(0)
        # Keeps the most recent input frame alive while native code reads it.
        self._input_buffer = None
        self.vad_library = self._load_library()
        self.out_probability = c_float()
        self.out_flags = c_int32()

        self.vad_library.ten_vad_create.argtypes = [
            POINTER(c_void_p),
            c_size_t,
            c_float,
        ]
        self.vad_library.ten_vad_create.restype = c_int

        self.vad_library.ten_vad_destroy.argtypes = [POINTER(c_void_p)]
        self.vad_library.ten_vad_destroy.restype = c_int

        self.vad_library.ten_vad_process.argtypes = [
            c_void_p,
            c_void_p,
            c_size_t,
            POINTER(c_float),
            POINTER(c_int32),
        ]
        self.vad_library.ten_vad_process.restype = c_int
        self.create_and_init_handler()

    @staticmethod
    def _load_library() -> CDLL:
        """Locate and load the shared library for the current platform.

        The repository layout (``../lib/...``) is tried first; if that file
        does not exist the pip-installed layout (``./ten_vad_library/...``)
        is loaded instead. Both are resolved relative to this file.
        """
        system = platform.system()
        machine = platform.machine()

        if system == "Linux" and machine == "x86_64":
            git_rel = "../lib/Linux/x64/libten_vad.so"
            pip_rel = "./ten_vad_library/libten_vad.so"
        elif system == "Darwin":
            git_rel = "../lib/macOS/ten_vad.framework/Versions/A/ten_vad"
            pip_rel = "./ten_vad_library/libten_vad"
        elif system.upper() == 'WINDOWS':
            if machine.upper() in ['X64', 'X86_64', 'AMD64']:
                git_rel = "../lib/Windows/x64/ten_vad.dll"
            else:
                git_rel = "../lib/Windows/x86/ten_vad.dll"
            pip_rel = "./ten_vad_library/ten_vad.dll"
        else:
            raise NotImplementedError(
                f"Unsupported platform: {system} {machine}"
            )

        # realpath on every platform (the Windows branch already used it):
        # an absolute, symlink-free base that does not depend on the cwd.
        base_dir = os.path.dirname(os.path.realpath(__file__))
        git_path = os.path.join(base_dir, git_rel)
        if os.path.exists(git_path):
            return CDLL(git_path)
        return CDLL(os.path.join(base_dir, pip_rel))

    def create_and_init_handler(self):
        """Create the native handle and store it in ``self.vad_handler``.

        Raises:
            AssertionError: If ``ten_vad_create`` returns a non-zero status.
        """
        # The call must live outside an `assert`: asserts are removed under
        # `python -O`, which would skip handle creation altogether.
        status = self.vad_library.ten_vad_create(
            POINTER(c_void_p)(self.vad_handler),
            c_size_t(self.hop_size),
            c_float(self.threshold),
        )
        if status != 0:
            raise AssertionError("[TEN VAD]: create handler failure!")

    def __del__(self):
        """Release the native handle, if one was created."""
        library = getattr(self, "vad_library", None)
        handler = getattr(self, "vad_handler", None)
        # __init__ may have failed before the library was loaded or the
        # handle was created; there is nothing to release in that case.
        if library is None or handler is None or not handler.value:
            return
        status = library.ten_vad_destroy(POINTER(c_void_p)(handler))
        if status != 0:
            raise AssertionError("[TEN VAD]: destroy handler failure!")

    def get_input_data(self, audio_data: np.ndarray) -> c_void_p:
        """Validate one audio frame and return a raw pointer to its samples.

        Args:
            audio_data: Array that squeezes to shape ``(hop_size,)`` with
                dtype ``int16``.

        Returns:
            ``c_void_p`` addressing the first sample of a C-contiguous copy
            or view of the frame. The backing array is kept referenced on
            the instance until the next call, so the pointer stays valid.

        Raises:
            AssertionError: If the shape or dtype does not match.
        """
        frame = np.atleast_1d(np.squeeze(audio_data))
        if frame.ndim != 1 or frame.shape[0] != self.hop_size:
            raise AssertionError(
                "[TEN VAD]: audio data shape should be [%d]" % (self.hop_size)
            )
        # Check the array dtype rather than the type of element 0: it works
        # for empty arrays and rejects non-native byte order.
        if frame.dtype != np.int16:
            raise AssertionError(
                "[TEN VAD]: audio data type error, must be int16"
            )
        # A strided view (e.g. x[::2]) is not laid out as consecutive
        # samples; native code needs a contiguous buffer.
        frame = np.ascontiguousarray(frame)
        self._input_buffer = frame
        return c_void_p(frame.__array_interface__["data"][0])

    def process(self, audio_data: np.ndarray):
        """Run ``ten_vad_process`` on one frame.

        Args:
            audio_data: See ``get_input_data``.

        Returns:
            Tuple ``(out_probability, out_flags)`` holding the values the
            library wrote for this frame.

        Raises:
            AssertionError: If the frame has the wrong shape or dtype.
            RuntimeError: If ``ten_vad_process`` returns a non-zero status.
        """
        input_pointer = self.get_input_data(audio_data)
        status = self.vad_library.ten_vad_process(
            self.vad_handler,
            input_pointer,
            c_size_t(self.hop_size),
            POINTER(c_float)(self.out_probability),
            POINTER(c_int32)(self.out_flags),
        )
        # Without this check a failed call would silently return whatever
        # the output slots held from a previous frame.
        if status != 0:
            raise RuntimeError("[TEN VAD]: process failure!")
        return self.out_probability.value, self.out_flags.value