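"""Custom handler for a Hugging Face Inference Endpoint that runs pyannote
speaker diarization on base64-encoded 16 kHz, 16-bit PCM audio."""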
import base64
import os
from typing import Dict

import numpy as np
import torch
from pyannote.audio import Pipeline

# Incoming audio is expected to be 16 kHz, 16-bit PCM.
SAMPLE_RATE = 16000

class EndpointHandler:
    def __init__(self, path=""):
        self.pipeline = Pipeline.from_pretrained(
            "pyannote/[email protected]",  # 3.0 and later is not supported as of yet (Dec 2023)
            use_auth_token=os.environ.get("HF_API_TOKEN"),
        )
        # Run on GPU when available, otherwise fall back to CPU.
        self.pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    def __call__(self, data: Dict) -> Dict:
        """
        Args:
            data (Dict):
                'inputs': Base64-encoded audio bytes (16 kHz, 16-bit mono PCM)
                'parameters': Additional diarization parameters (currently unused)
        Returns:
            Dict: Speaker diarization results
        """
        inputs = data.get("inputs")
        parameters = data.get("parameters", {})  # Unused for now: the pipeline no longer takes a speaker count

        # Decode the base64 payload into 16-bit PCM samples
        audio_data = base64.b64decode(inputs)
        audio_nparray = np.frombuffer(audio_data, dtype=np.int16)

        # Handle multi-channel audio (average channels to create mono)
        if audio_nparray.ndim > 1:
            audio_nparray = audio_nparray.mean(axis=0)

        # Convert to a float tensor scaled to [-1, 1]; .copy() is needed because
        # np.frombuffer returns a read-only view.
        audio_tensor = torch.from_numpy(audio_nparray.copy()).float() / 32768.0
        # pyannote expects a (channel, time) waveform
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)

        pyannote_input = {"waveform": audio_tensor, "sample_rate": SAMPLE_RATE}

        # Run the diarization pipeline (without num_speakers)
        try:
            diarization = self.pipeline(pyannote_input)
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            return {"error": "Diarization failed unexpectedly"}

        # Build a friendly JSON response
        processed_diarization = [
            {
                "label": str(label),
                "start": str(segment.start),
                "stop": str(segment.end),
            }
            for segment, _, label in diarization.itertracks(yield_label=True)
        ]

        return {"diarization": processed_diarization}