add run_evaluation.py
run_evaluation.py (+5 -3)
@@ -16,13 +16,13 @@ python run_evaluation.y -m <wav2vec2 model_name> -d <Zindi dataset directory> -o


 class KenLM:
-    def __init__(self, tokenizer, model_name, num_workers=8, beam_width=128):
+    def __init__(self, tokenizer, model_name, unigrams=None, num_workers=8, beam_width=128):
         self.num_workers = num_workers
         self.beam_width = beam_width
         vocab_dict = tokenizer.get_vocab()
         self.vocabulary = [x[0] for x in sorted(vocab_dict.items(), key=lambda x: x[1], reverse=False)]
         self.vocabulary = self.vocabulary[:-2]
-        self.decoder = build_ctcdecoder(self.vocabulary, model_name)
+        self.decoder = build_ctcdecoder(self.vocabulary, model_name, unigrams=unigrams)

     @staticmethod
     def lm_postprocess(text):
@@ -52,6 +52,8 @@ def main():
                         help="Batch size")
     parser.add_argument("-k", "--kenlm", type=str, required=False, default=False,
                         help="Path to KenLM model")
+    parser.add_argument("-u", "--unigrams", type=str, required=False, default=False,
+                        help="Path to unigrams file")
     parser.add_argument("--num_workers", type=int, required=False, default=8,
                         help="KenLM's number of workers")
     parser.add_argument("-w", "--beam_width", type=int, required=False, default=128,
@@ -67,7 +69,7 @@ def main():
     model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
     kenlm = None
     if args.kenlm:
-        kenlm = KenLM(processor.tokenizer, args.kenlm)
+        kenlm = KenLM(processor.tokenizer, args.kenlm, args.unigrams)

     # Preprocessing the datasets.
     # We need to read the audio files as arrays
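The new unigrams argument is forwarded unchanged to pyctcdecode's build_ctcdecoder, whose unigrams parameter is an optional collection of word strings. Since the new CLI flag supplies a file path, the sketch below shows one way such a file could be read into a word list and used to build the decoder; the checkpoint name, file paths, one-word-per-line format and load_unigrams helper are illustrative assumptions, not part of this commit.

    # Sketch only: paths, checkpoint and file format are assumptions.
    from pyctcdecode import build_ctcdecoder
    from transformers import Wav2Vec2Processor

    def load_unigrams(path):
        # Assumes a plain-text file with one word per line.
        with open(path, encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    vocab_dict = processor.tokenizer.get_vocab()
    # Tokens sorted by id; the class above additionally drops the last two tokens
    # (self.vocabulary[:-2]); keep that step if your tokenizer's vocab is two
    # entries larger than the model's output layer.
    vocabulary = [tok for tok, _ in sorted(vocab_dict.items(), key=lambda kv: kv[1])]

    decoder = build_ctcdecoder(
        vocabulary,
        "path/to/kenlm.arpa",                            # KenLM model path (placeholder)
        unigrams=load_unigrams("path/to/unigrams.txt"),  # word list, not the raw path string
    )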
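Downstream, a decoder built this way is applied to the wav2vec2 CTC logits. The short sketch below continues the one above (it reuses processor and decoder); the checkpoint, the dummy waveform and the beam width of 128, which mirrors the script's default, are placeholders, and the KenLM.lm_postprocess helper seen in the first hunk is presumably applied to the raw output afterwards.

    # Continues the sketch above; reuses `processor` and `decoder`.
    import numpy as np
    import torch
    from transformers import Wav2Vec2ForCTC

    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")  # placeholder checkpoint
    audio = np.zeros(16_000, dtype=np.float32)  # replace with a real 16 kHz waveform

    inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs.input_values).logits[0].cpu().numpy()

    # Beam-search decode with the KenLM-backed decoder.
    transcription = decoder.decode(logits, beam_width=128)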