# scripts/neural_machine_translation/collect_tokenizer_dataset_stats.py
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
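"""
Collects token-length statistics (sample count, mean, stdev, min, max, median) over one
or more text files, tokenizing lines in parallel with a NeMo-supported tokenizer.

Usage sketch (the paths, tokenizer library value, and input file names below are
illustrative examples, not part of the original script):

    python collect_tokenizer_dataset_stats.py \
        --tokenizer_library sentencepiece \
        --tokenizer_model /path/to/tokenizer.model \
        --out_dir /path/to/stats_out \
        train.src train.tgt

When --out_dir is set, the script writes lengths.txt (one token count per line),
stats.json, and lengths_hist.pdf into that directory.
"""
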
import argparse
import json
import multiprocessing as mp
import os
import numpy as np
from matplotlib import pyplot as plt
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
# =============================================================================#
# Auxiliary methods
# =============================================================================#
worker_data = {
    "tokenizer": None,
}


def init_tokenizer(library, tokenizer_model):
    # Build the tokenizer once per worker process, so it is not re-created for every line.
    tokenizer = get_nmt_tokenizer(library=library, tokenizer_model=tokenizer_model)
    worker_data["tokenizer"] = tokenizer


def read_batch(fh, batch_size):
    """
    Reads a batch (or a smaller, final chunk) of lines.
    """
    lines = []
    for _ in range(batch_size):
        line = fh.readline()
        if not line:
            break
        lines.append(line.strip())

    return lines


def tokenize_line(line, tokenizer):
    """
    Returns the token ids of a tokenized text line.
    """
    line = line.rstrip("\n")
    tokens = tokenizer.text_to_ids(line)
    return tokens


def line_len(line, tokenizer=None):
    """
    Returns the tokenized length (number of tokens) of a text line.
    """
    if tokenizer is None:
        # Running inside a pool worker: use the tokenizer created by init_tokenizer.
        tokenizer = worker_data["tokenizer"]
    tokens = tokenize_line(line, tokenizer)
    return len(tokens)


# =============================================================================#
# Main script
# =============================================================================#
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Collects statistics over a tokenized dataset')
    parser.add_argument('input_files', metavar='N', type=str, nargs='+', help='Input text files to parse')
    parser.add_argument(
        '--tokenizer_library', type=str, required=True, help='Tokenizer library supported by NeMo (e.g., sentencepiece)'
    )
    parser.add_argument(
        '--tokenizer_model', type=str, required=True, help='Path to pre-trained nemo-supported tokenizer model'
    )
    parser.add_argument(
        '--num_workers', type=int, default=mp.cpu_count(), help='Number of workers (defaults to number of CPUs)'
    )
    parser.add_argument('--max_lines', type=int, default=-1, help='Max number of lines to parse')
    parser.add_argument('--batch_size', type=int, default=10000000, help='Batch size (in lines) to parse in parallel')
    parser.add_argument('--out_dir', type=str, default="", help='Path to store data and plots')

    args = parser.parse_args()

    # build the tokenizer once in the main process (workers create their own copy in init_tokenizer)
    tokenizer = get_nmt_tokenizer(library=args.tokenizer_library, tokenizer_model=args.tokenizer_model)

    all_len = []
    for fn in args.input_files:
        print(f"Parsing fn = {fn}")
        # read file
        fh = open(fn)

        # read all batches
        while True:
            lines = read_batch(fh, args.batch_size)
            # move to next file when no lines are read
            if not lines:
                break

            # tokenize lines in parallel
            with mp.Pool(
                args.num_workers, initializer=init_tokenizer, initargs=(args.tokenizer_library, args.tokenizer_model)
            ) as p:
                all_len.extend(p.map(line_len, lines))

            print(f"{fn}: Parsed {len(all_len)} lines")

            # early stop, if required (breaks out of the batch loop)
            if (args.max_lines > 0) and (len(all_len) >= args.max_lines):
                all_len = all_len[: args.max_lines]
                break

        fh.close()

        # early stop, if required (breaks out of the file loop)
        if (args.max_lines > 0) and (len(all_len) >= args.max_lines):
            all_len = all_len[: args.max_lines]
            break

    # compute stats over all collected token lengths
    stats = {
        "samples": int(len(all_len)),
        "mean": float(np.mean(all_len)),
        "stdev": float(np.std(all_len)),
        "min": float(np.min(all_len)),
        "max": float(np.max(all_len)),
        "median": float(np.median(all_len)),
    }
    print(f"stats = \n{stats}")

    # save all results
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

        # per-line token lengths
        with open(os.path.join(args.out_dir, "lengths.txt"), "w") as fh:
            fh.writelines(["{l}\n".format(l=l) for l in all_len])

        # summary statistics
        with open(os.path.join(args.out_dir, "stats.json"), "w") as fh:
            json.dump(stats, fh)

        # histogram of token lengths
        plt.hist(all_len)
        plt.savefig(os.path.join(args.out_dir, "lengths_hist.pdf"))