brdhaker3 committed on
Commit 75a78b0 · verified · 1 Parent(s): 9c29f8b

Update app.py

Files changed (1)
  1. app.py +104 -108
app.py CHANGED
@@ -1,109 +1,105 @@
- import torch
- import gradio as gr
- import speechbrain as sb
- import torchaudio
- from hyperpyyaml import load_hyperpyyaml
- from pyctcdecode import build_ctcdecoder
- import os
-
- # Load hyperparameters and initialize the ASR model
- hparams_file = "train.yaml"
- with open(hparams_file, "r") as fin:
-     hparams = load_hyperpyyaml(fin)
-
- # Initialize the label encoder
- label_encoder = sb.dataio.encoder.CTCTextEncoder()
- lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
- special_labels = {
-     "blank_label": hparams["blank_index"],
-     "unk_label": hparams["unk_index"]
- }
- label_encoder.load_or_create(
-     path=lab_enc_file,
-     from_didatasets=[[]],
-     output_key="char_list",
-     special_labels=special_labels,
-     sequence_input=True,
- )
-
- # Prepare labels for the CTC decoder
- ind2lab = label_encoder.ind2lab
- labels = [ind2lab[x] for x in range(len(ind2lab))]
- labels = [""] + labels[1:-1] + ["1"]
-
- # Initialize the CTC decoder
- decoder = build_ctcdecoder(
-     labels,
-     kenlm_model_path=hparams["ngram_lm_path"],
-     alpha=0.5,
-     beta=1.0,
- )
-
-
- # Define the ASR class with the `treat_wav` method
- class ASR(sb.core.Brain):
-     def treat_wav(self, sig):
-         """Process a waveform and return the transcribed text."""
-         feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
-         feats = self.modules.enc(feats)
-         logits = self.modules.ctc_lin(feats)
-         p_ctc = self.hparams.log_softmax(logits)
-         predicted_words = []
-         for logs in p_ctc:
-             text = decoder.decode(logs.detach().cpu().numpy())
-             predicted_words.append(text.split(" "))
-         return " ".join(predicted_words[0])
-
-
- # Initialize the ASR model
- asr_brain = ASR(
-     modules=hparams["modules"],
-     hparams=hparams,
-     run_opts={"device": "cpu"},
-     checkpointer=hparams["checkpointer"],
- )
- asr_brain.tokenizer = label_encoder
- asr_brain.checkpointer.recover_if_possible()
- asr_brain.modules.eval()
-
-
- # Function to process audio files
- def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
-     if file_mic is not None:
-         wav = file_mic
-     elif file_upload is not None:
-         wav = file_upload
-     else:
-         return "ERROR: You have to either use the microphone or upload an audio file"
-
-     # Read and preprocess the audio file
-     info = torchaudio.info(wav)
-     sr = info.sample_rate
-     sig = sb.dataio.dataio.read_audio(wav)
-     if len(sig.shape) > 1:
-         sig = torch.mean(sig, dim=1)
-     sig = torch.unsqueeze(sig, 0)
-     tensor_wav = sig.to(device)
-     resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)
-
-     # Transcribe the audio
-     sentence = asr.treat_wav(resampled)
-     return sentence
-
-
- # Gradio interface
- title = "Tunisian Speech Recognition"
- description = ''' This is a Tunisian ASR based on WavLM Model, fine-tuned on a dataset of 2.5 Hours resulting in a W.E.R of 24% and a C.E.R of 9 %.
- \n
- \n Interesting isn\'t it !'''
-
- gr.Interface(
-     fn=treat_wav_file,
-     inputs=[
-         gr.Audio(sources="microphone", type='filepath', label="Record"),
-         gr.Audio(sources="upload", type='filepath', label="Upload File")
-     ],
-     outputs="text",
-     title=title,
-     description=description
 
+ import torch
+ import gradio as gr
+ import speechbrain as sb
+ import torchaudio
+ from hyperpyyaml import load_hyperpyyaml
+ from pyctcdecode import build_ctcdecoder
+ import os
+
+ # Load hyperparameters and initialize the ASR model
+ hparams_file = "train.yaml"
+ with open(hparams_file, "r") as fin:
+     hparams = load_hyperpyyaml(fin)
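+ # train.yaml must define every key read below: "encoder_file",
+ # "ngram_lm_path", "modules", "checkpointer", and a "log_softmax" activation.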
+
+ # Initialize the label encoder
+ encoder = sb.dataio.encoder.CTCTextEncoder()
+
+ encoder.load_or_create(
+     path=hparams["encoder_file"],
+     from_didatasets=[[]],
+     output_key="char_list",
+     special_labels={"blank_label": 0, "unk_label": 1},
+     sequence_input=True,
+ )
+
+ # Prepare labels for the CTC decoder
+ ind2lab = encoder.ind2lab
+ labels = [ind2lab[x] for x in range(len(ind2lab))]
+ labels = [""] + labels[1:-1] + ["1"]
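+ # Remap the specials for pyctcdecode: the blank (index 0) becomes "", and the
+ # last encoder symbol is replaced with "1" (assumed to match the vocabulary
+ # the KenLM model was built with).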
+
+ # Initialize the CTC decoder
+ decoder = build_ctcdecoder(
+     labels,
+     kenlm_model_path=hparams["ngram_lm_path"],
+     alpha=0.5,
+     beta=1.0,
+ )
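+ # alpha weights the KenLM language-model score; beta is the word-insertion
+ # bonus applied during pyctcdecode's beam search.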
+
+
+ # Define the ASR class with the `treat_wav` method
+ class ASR(sb.core.Brain):
+     def treat_wav(self, sig):
+         """Process a waveform and return the transcribed text."""
+         feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
+         feats = self.modules.enc(feats)
+         logits = self.modules.ctc_lin(feats)
+         p_ctc = self.hparams.log_softmax(logits)
+         predicted_words = []
+         for logs in p_ctc:
+             text = decoder.decode(logs.detach().cpu().numpy())
+             predicted_words.append(text.split(" "))
+         return " ".join(predicted_words[0])
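+ # (torch.tensor([1]) above is the batch of relative wav lengths SpeechBrain
+ # modules expect; 1 means the full signal is valid.)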
+
+
+ # Initialize the ASR model
+ asr_brain = ASR(
+     modules=hparams["modules"],
+     hparams=hparams,
+     run_opts={"device": "cpu"},
+     checkpointer=hparams["checkpointer"],
+ )
+ asr_brain.tokenizer = encoder
+ asr_brain.checkpointer.recover_if_possible()
+ asr_brain.modules.eval()
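+ # recover_if_possible() restores the latest checkpoint, if one exists;
+ # eval() disables dropout and batch-norm updates for inference.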
+
+
+ # Function to process audio files
+ def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
+     if file_mic is not None:
+         wav = file_mic
+     elif file_upload is not None:
+         wav = file_upload
+     else:
+         return "ERROR: You have to either use the microphone or upload an audio file"
+
+     # Read and preprocess the audio file
+     info = torchaudio.info(wav)
+     sr = info.sample_rate
+     sig = sb.dataio.dataio.read_audio(wav)
+     if len(sig.shape) > 1:
+         sig = torch.mean(sig, dim=1)
+     sig = torch.unsqueeze(sig, 0)
+     tensor_wav = sig.to(device)
+     resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)
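+     # (WavLM/wav2vec2 models operate on 16 kHz audio, hence the resample above.)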
+
+     # Transcribe the audio
+     sentence = asr.treat_wav(resampled)
+     return sentence
+
+
+ # Gradio interface
+ title = "Tunisian Speech Recognition"
+ description = '''This is a Tunisian ASR based on the WavLM model, fine-tuned on 2.5 hours of data and reaching a WER of 24% and a CER of 9%.
+
+ Interesting, isn't it?'''
+
+ gr.Interface(
+     fn=treat_wav_file,
+     inputs=[
+         gr.Audio(sources="microphone", type='filepath', label="Record"),
+         gr.Audio(sources="upload", type='filepath', label="Upload File")
+     ],
+     outputs="text",
+     title=title,
+     description=description
  ).launch()
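
For a quick check outside the UI, the handler can also be called directly; a minimal sketch ("sample.wav" is a hypothetical local recording):

    print(treat_wav_file("sample.wav", None))  # path goes in the mic slot; upload slot left empty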