import os

import torch
import torchaudio
import gradio as gr
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from pyctcdecode import build_ctcdecoder

# Load hyperparameters for the ASR model
hparams_file = "train.yaml"
with open(hparams_file, "r") as fin:
    hparams = load_hyperpyyaml(fin)

# Initialize the label encoder (loaded from the saved encoder file;
# the empty dataset is only there to satisfy the create path)
encoder = sb.dataio.encoder.CTCTextEncoder()
encoder.load_or_create(
    path=hparams["encoder_file"],
    from_didatasets=[[]],
    output_key="char_list",
    special_labels={"blank_label": 0, "unk_label": 1},
    sequence_input=True,
)

# Prepare labels for the CTC decoder: the blank token becomes the
# empty string, as pyctcdecode expects
ind2lab = encoder.ind2lab
labels = [ind2lab[x] for x in range(len(ind2lab))]
labels = [""] + labels[1:-1] + ["1"]

# Initialize the CTC beam-search decoder with a KenLM n-gram language model
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=hparams["ngram_lm_path"],
    alpha=0.5,  # language-model weight
    beta=1.0,   # word insertion bonus
)


# Define the ASR class with the `treat_wav` method
class ASR(sb.core.Brain):
    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text."""
        # torch.tensor([1]) is the relative-length tensor (full length)
        feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
        feats = self.modules.enc(feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)
        # Beam-search decode each item in the batch; return the first transcription
        predicted_words = []
        for logs in p_ctc:
            text = decoder.decode(logs.detach().cpu().numpy())
            predicted_words.append(text.split(" "))
        return " ".join(predicted_words[0])


# Initialize the ASR model and restore the latest checkpoint
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts={"device": "cpu"},
    checkpointer=hparams["checkpointer"],
)
asr_brain.tokenizer = encoder
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()


# Process an audio file coming from the microphone or an upload
def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    if file_mic is not None:
        wav = file_mic
    elif file_upload is not None:
        wav = file_upload
    else:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Read the audio, downmix to mono, and resample to 16 kHz
    info = torchaudio.info(wav)
    sr = info.sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    if len(sig.shape) > 1:
        sig = torch.mean(sig, dim=1)
    sig = torch.unsqueeze(sig, 0)
    tensor_wav = sig.to(device)
    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)

    # Transcribe the audio
    sentence = asr.treat_wav(resampled)
    return sentence


# Gradio interface
title = "Tunisian Speech Recognition"
description = '''This is a Tunisian ASR based on the WavLM model, fine-tuned on a 2.5-hour dataset, reaching a WER of 24% and a CER of 9%.

Interesting, isn't it?'''

gr.Interface(
    fn=treat_wav_file,
    inputs=[
        gr.Audio(sources=["microphone"], type="filepath", label="Record"),
        gr.Audio(sources=["upload"], type="filepath", label="Upload File"),
    ],
    outputs="text",
    title=title,
    description=description,
).launch()
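
# A minimal smoke test, assuming a local 16 kHz-compatible recording exists;
# "sample.wav" is a hypothetical path, not a file shipped with this repo.
# Uncomment and run it *before* the blocking `launch()` call above to check
# the pipeline without going through the Gradio UI:
#
# print(treat_wav_file(None, "sample.wav"))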