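"""Hugging Face Inference Endpoint handler for heart-sound classification.

Downloads an audio file from a URL, band-pass filters it (20-200 Hz),
renders a mel-spectrogram image, classifies the image with a fine-tuned
ViT model, and returns the prediction together with an estimated BPM,
the filtered audio, and a PCG plot (both base64-encoded).
"""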
import torch
import numpy as np
import librosa
import requests
import io
import os
import base64  # <-- new import
import traceback
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for headless servers
import matplotlib.pyplot as plt  # <-- new import
import soundfile as sf
from scipy.signal import butter, lfilter
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
from pydub import AudioSegment
# --- PREPROCESSING FUNCTIONS ---
TARGET_SR = 2000
IMAGE_HEIGHT = 128
def butter_bandpass_filter(data, fs, lowcut=20.0, highcut=200.0, order=3):
    # Band-pass filter covering the typical frequency range of heart sounds.
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    if high >= 1 or low <= 0:
        # Invalid band for this sample rate; return the signal unchanged.
        return data
    b, a = butter(order, [low, high], btype='band')
    return lfilter(b, a, data)
def create_spectrogram_image(y_cleaned, sr):
    mel_spec = librosa.feature.melspectrogram(y=y_cleaned, sr=sr, n_mels=IMAGE_HEIGHT)
    S_DB = librosa.power_to_db(mel_spec, ref=np.max)
    # Min-max scale to 0-255 and convert to an RGB image for the ViT pipeline.
    img_array = (S_DB - S_DB.min()) / (S_DB.max() - S_DB.min() + 1e-6) * 255.0
    img_array = img_array.astype(np.uint8)
    return Image.fromarray(img_array).convert("RGB")
def calculate_bpm(y_cleaned, sr):
    onset_env = librosa.onset.onset_strength(y=y_cleaned, sr=sr, aggregate=np.mean)
    bpm = librosa.beat.tempo(onset_envelope=onset_env, sr=sr)[0]
    return bpm
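# Minimal offline sketch (assumption: not part of the endpoint flow) showing
# how the helpers above chain together on two seconds of synthetic noise:
#
#     _y = np.random.randn(2 * TARGET_SR).astype(np.float32)
#     _y = butter_bandpass_filter(_y, fs=TARGET_SR)
#     _img = create_spectrogram_image(_y, TARGET_SR)  # PIL.Image, RGB
#     _bpm = calculate_bpm(_y, TARGET_SR)             # float, beats per minute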
# --- ENDPOINT HANDLER ---
class EndpointHandler:
    def __init__(self, path=""):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_directory = os.path.join(path, "modelo-vit-audio-final")
        processor = AutoImageProcessor.from_pretrained(model_directory)
        model = AutoModelForImageClassification.from_pretrained(model_directory).to(device)
        self.pipe = pipeline(
            "image-classification",
            model=model,
            image_processor=processor,
            device=device
        )
        print("ViT pipeline with audio preprocessing (via URL) loaded successfully.")
    def __call__(self, data: dict) -> list:
        audio_url = data.pop("inputs", None)
        if not audio_url or not isinstance(audio_url, str):
            return [{"error": "No 'inputs' (audio URL as a string) was provided."}]
        try:
            print(f"Downloading and processing audio from: {audio_url}")
            response = requests.get(audio_url, timeout=30)
            response.raise_for_status()
            audio_data = io.BytesIO(response.content)
            sound = AudioSegment.from_file(audio_data)
            sound = sound.set_channels(1)
            sr_original = sound.frame_rate
            y = np.array(sound.get_array_of_samples()).astype(np.float32)
            # Normalize by the actual sample width instead of assuming 16-bit PCM.
            y_normalized = y / float(1 << (8 * sound.sample_width - 1))
            if sr_original != TARGET_SR:
                y_resampled = librosa.resample(y=y_normalized, orig_sr=sr_original, target_sr=TARGET_SR)
            else:
                y_resampled = y_normalized
            y_cleaned = butter_bandpass_filter(y_resampled, fs=TARGET_SR)
            spectrogram_image = create_spectrogram_image(y_cleaned, TARGET_SR)
            # Write the filtered audio to an in-memory WAV buffer
            buffer = io.BytesIO()
            sf.write(buffer, y_cleaned, TARGET_SR, format='WAV')
            buffer.seek(0)
            # Encode it as Base64
            audio_base64 = base64.b64encode(buffer.read()).decode('utf-8')
            bpm = calculate_bpm(y_cleaned, TARGET_SR)
            print(f"Estimated BPM: {bpm:.0f}")
            # --- NEW BLOCK: GENERATE AND ENCODE THE PCG PLOT ---
            print("Generating and encoding the PCG plot for the response...")
            time_axis = np.arange(0, len(y_cleaned)) / TARGET_SR
            # Plot a fixed 1 s to 5 s window; slide it back if the audio is shorter.
            start_time, end_time = 1.0, 5.0
            start_index, end_index = int(start_time * TARGET_SR), int(end_time * TARGET_SR)
            if end_index > len(y_cleaned):
                end_index = len(y_cleaned)
                start_index = max(0, end_index - int(4 * TARGET_SR))
            fig, ax = plt.subplots(figsize=(15, 5))
            ax.plot(time_axis, y_cleaned, linewidth=0.7)
            ax.set_title("Phonocardiogram (PCG)")
            ax.set_xlabel("Time (seconds)")
            ax.set_ylabel("Amplitude")
            ax.grid(True, linestyle='--')
            ax.set_xlim(time_axis[start_index], time_axis[end_index - 1])
            # Save the plot to an in-memory buffer
            buf = io.BytesIO()
            plt.savefig(buf, format='png', bbox_inches='tight')
            plt.close(fig)  # close the figure to free memory
            buf.seek(0)
            # Encode the image as base64
            pcg_image_base64 = base64.b64encode(buf.read()).decode('utf-8')
            # ---------------------------------------------------------
print("Enviando espectrograma para o pipeline de predição...")
prediction = self.pipe(spectrogram_image)
print(f"Predição concluída: {prediction}")
# --- RESPOSTA FINAL ATUALIZADA ---
final_response = {
"classification_results": prediction,
"bpm_estimated": int(round(bpm)),
"pcg_image_base64": f'data:image/png;base64,{pcg_image_base64}', # <-- Adicionamos a imagem aqui
"audio_base64": f'data:audio/mp3;base64,{audio_base64}' #Audio codificado base64
}
return [final_response]
        except Exception as e:
            error_message = f"Error while processing the audio URL: {str(e)}"
            traceback.print_exc()
            return [{"error": error_message}]