lovemefan committed
Commit ec97ce5 · 1 Parent(s): b366fad

Upload 14 files

upload project from https://github.com/lovemefan/CT-Transformer-punctuation

LICENSE ADDED
@@ -0,0 +1,22 @@
+ The MIT License (MIT)
+
+ Copyright (c) 2014-2017 Alexey Popravka
+ Copyright (c) 2021 Sean Stewart
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
MANIFEST.in ADDED
@@ -0,0 +1,3 @@
+ include cttpunctuator/src/onnx/configuration.json
+ include cttpunctuator/src/onnx/punc.onnx
+ include cttpunctuator/src/onnx/punc.yaml
README.md CHANGED
@@ -1,3 +1,122 @@
- ---
- license: mit
- ---
+
+ <br/>
+ <h2 align="center">Ctt punctuator</h2>
+ <br/>
+
+ ![python3.7](https://img.shields.io/badge/python-3.7-green.svg)
+ ![python3.8](https://img.shields.io/badge/python-3.8-green.svg)
+ ![python3.9](https://img.shields.io/badge/python-3.9-green.svg)
+ ![python3.10](https://img.shields.io/badge/python-3.10-green.svg)
+
+ An enterprise-grade Chinese-English code-switch punctuator, based on [FunASR](https://github.com/alibaba-damo-academy/FunASR/).
+
+ <br/>
+ <h2 align="center">Key Features</h2>
+ <br/>
+
+ - **General**
+
+   The ctt punctuator was trained on Chinese-English code-switch corpora.
+   - [x] offline punctuator
+   - [x] online punctuator
+   - [x] punctuator for Chinese-English code switch
+
+   The ONNX model file is 279 MB; you can download it from [here](https://github.com/lovemefan/CT-Transformer-punctuation/raw/main/cttpunctuator/src/onnx/punc.onnx).
+
+ - **Highly Portable**
+
+   ctt-punctuator benefits from the rich ecosystem built around **ONNX**, running everywhere these runtimes are available.
+
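+   For example, the exported graph can be inspected with nothing but `onnxruntime` (a minimal sketch; the input and output names depend on the export, so they are queried rather than assumed):
+
+   ```python
+   import onnxruntime as ort
+
+   # Load the LFS-downloaded model on CPU and list its I/O signature.
+   sess = ort.InferenceSession(
+       "cttpunctuator/src/onnx/punc.onnx", providers=["CPUExecutionProvider"]
+   )
+   print([i.name for i in sess.get_inputs()])
+   print([o.name for o in sess.get_outputs()])
+   ```
+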
+ ## Installation
+
+ ```bash
+ sudo apt install git-lfs
+ # if the code raises "failed: Protobuf parsing failed.",
+ # install git-lfs and run `git lfs install`
+ git lfs install
+ # use git-lfs to download the ONNX file
+ git clone https://github.com/lovemefan/CT-Transformer-punctuation.git
+ cd CT-Transformer-punctuation
+ # sanity check: punc.onnx should be ~279 MB on disk; a file of a few hundred
+ # bytes is only the LFS pointer (fix it with `git lfs pull`)
+ ls -lh cttpunctuator/src/onnx/punc.onnx
+ pip install -e .
+ ```
+
+ ## Usage
+
+ ```python
+ import logging
+
+ from cttPunctuator import CttPunctuator
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
+ )
+
+ # offline mode
+ punc = CttPunctuator()
+ text = "据报道纽约时报使用ChatGPT创建了一个情人节消息生成器用户只需输入几个提示就可以得到一封自动生成的情书"
+ logging.info(punc.punctuate(text)[0])
+
+ # online mode
+ punc = CttPunctuator(online=True)
+ text_in = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
+
+ vads = text_in.split("|")
+ rec_result_all = ""
+ param_dict = {"cache": []}
+ for vad in vads:
+     result = punc.punctuate(vad, param_dict=param_dict)
+     rec_result_all += result[0]
+     logging.info(f"Part: {rec_result_all}")
+
+ logging.info(f"Final: {rec_result_all}")
+ ```
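+
+ `punctuate()` passes through the underlying model's return value: offline mode yields a `(punctuated_text, punctuation_ids)` pair, while online mode additionally carries the rolling cache, so `result[0]` is always the text (a sketch of the shapes, per `cttpunctuator/src/punctuator.py`):
+
+ ```python
+ sentence, punc_ids = punc.punctuate(text)        # offline mode
+ sentence, punc_ids, cache = punc.punctuate(vad)  # online mode
+ ```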
+ ## Result
+
+ ```bash
+ [2023-04-19 01:12:39,308 INFO] [ctt-punctuator.py:50 ctt-punctuator.__init__] Initializing punctuator model with offline mode.
+ [2023-04-19 01:12:55,854 INFO] [ctt-punctuator.py:52 ctt-punctuator.__init__] Offline model initialized.
+ [2023-04-19 01:12:55,854 INFO] [ctt-punctuator.py:55 ctt-punctuator.__init__] Model initialized.
+ [2023-04-19 01:12:55,868 INFO] [ctt-punctuator.py:67 ctt-punctuator.<module>] 据报道,纽约时报使用ChatGPT创建了一个情人节消息生成器,用户只需输入几个提示,就可以得到一封自动生成的情书。
+ [2023-04-19 01:12:55,868 INFO] [ctt-punctuator.py:40 ctt-punctuator.__init__] Initializing punctuator model with online mode.
+ [2023-04-19 01:13:12,499 INFO] [ctt-punctuator.py:43 ctt-punctuator.__init__] Online model initialized.
+ [2023-04-19 01:13:12,499 INFO] [ctt-punctuator.py:55 ctt-punctuator.__init__] Model initialized.
+ [2023-04-19 01:13:12,502 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸
+ [2023-04-19 01:13:12,508 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员
+ [2023-04-19 01:13:12,521 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险
+ [2023-04-19 01:13:12,547 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切
+ [2023-04-19 01:13:12,553 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制
+ [2023-04-19 01:13:12,559 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是
+ [2023-04-19 01:13:12,560 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们
+ [2023-04-19 01:13:12,567 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的
+ [2023-04-19 01:13:12,572 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的任何开发利用,都会经过科学
+ [2023-04-19 01:13:12,578 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的任何开发利用,都会经过科学规划和论证,兼顾上下游的利益
+ [2023-04-19 01:13:12,578 INFO] [ctt-punctuator.py:79 ctt-punctuator.<module>] Final: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的任何开发利用,都会经过科学规划和论证,兼顾上下游的利益
+ ```
+
+ ## Citation
+
+ ```
+ @inproceedings{chen2020controllable,
+   title={Controllable Time-Delay Transformer for Real-Time Punctuation Prediction and Disfluency Detection},
+   author={Chen, Qian and Chen, Mengzhe and Li, Bo and Wang, Wen},
+   booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
+   pages={8069--8073},
+   year={2020},
+   organization={IEEE}
+ }
+ ```
+
+ ```
+ @misc{FunASR,
+   author = {Speech Lab, Alibaba Group, China},
+   title = {FunASR: A Fundamental End-to-End Speech Recognition Toolkit},
+   year = {2023},
+   publisher = {GitHub},
+   journal = {GitHub repository},
+   howpublished = {\url{https://github.com/alibaba-damo-academy/FunASR/}},
+ }
+ ```
cttPunctuator.py ADDED
@@ -0,0 +1,64 @@
+ # -*- coding:utf-8 -*-
+ # @FileName  :ctt-punctuator.py
+ # @Time      :2023/4/13 15:03
+ # @Author    :lovemefan
+ # @Email     :[email protected]
+
+
+ __author__ = "lovemefan"
+ __copyright__ = "Copyright (C) 2023 lovemefan"
+ __license__ = "MIT"
+ __version__ = "v0.0.1"
+
+ import logging
+ import threading
+
+ from cttpunctuator.src.punctuator import (CT_Transformer,
+                                           CT_Transformer_VadRealtime)
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
+ )
+
+ lock = threading.RLock()
+
+
+ class CttPunctuator:
+     # Class-level singletons: the heavyweight ONNX sessions are created once
+     # and shared by every instance of the same mode.
+     _offline_model = None
+     _online_model = None
+
+     def __init__(self, online: bool = False):
+         """Punctuator with a singleton pattern (double-checked locking).
+
+         :param online: if True, use the streaming (VAD realtime) model,
+             otherwise the offline model.
+         """
+         self.online = online
+
+         if online:
+             if CttPunctuator._online_model is None:
+                 with lock:
+                     if CttPunctuator._online_model is None:
+                         logging.info("Initializing punctuator model with online mode.")
+                         CttPunctuator._online_model = CT_Transformer_VadRealtime()
+                         logging.info("Online model initialized.")
+             # Each instance keeps its own rolling cache.
+             self.param_dict = {"cache": []}
+             self.model = CttPunctuator._online_model
+
+         else:
+             if CttPunctuator._offline_model is None:
+                 with lock:
+                     if CttPunctuator._offline_model is None:
+                         logging.info("Initializing punctuator model with offline mode.")
+                         CttPunctuator._offline_model = CT_Transformer()
+                         logging.info("Offline model initialized.")
+             self.model = CttPunctuator._offline_model
+
+         logging.info("Model initialized.")
+
+     def punctuate(self, text: str, param_dict=None):
+         if self.online:
+             # Fall back to this instance's cache when the caller does not
+             # manage its own param_dict.
+             param_dict = param_dict or self.param_dict
+             return self.model(text, param_dict)
+         else:
+             return self.model(text)
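
Because the loaded models live on the class, every `CttPunctuator` instance of a given mode shares one ONNX session; a quick check (a sketch, assuming the package is installed as described in the README):

```python
from cttPunctuator import CttPunctuator

a = CttPunctuator()
b = CttPunctuator()
assert a.model is b.model  # both offline instances share one CT_Transformer
```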
cttpunctuator/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # -*- coding:utf-8 -*-
+ # @FileName  :__init__.py
+ # @Time      :2023/4/13 14:58
+ # @Author    :lovemefan
+ # @Email     :[email protected]
cttpunctuator/src/onnx/configuration.json ADDED
@@ -0,0 +1,20 @@
+ {
+     "framework": "onnx",
+     "task": "punctuation",
+     "model": {
+         "type": "generic-punc",
+         "punc_model_name": "punc.pb",
+         "punc_model_config": {
+             "type": "pytorch",
+             "code_base": "funasr",
+             "mode": "punc",
+             "lang": "zh-cn",
+             "batch_size": 1,
+             "punc_config": "punc.yaml",
+             "model": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+         }
+     },
+     "pipeline": {
+         "type": "punc-inference"
+     }
+ }
cttpunctuator/src/onnx/punc.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:06ae02f3fce2d6bfbcdd988672467808c0113e77b6eed7dc52835ff627e12330
+ size 292007354
cttpunctuator/src/onnx/punc.yaml ADDED
The diff for this file is too large to render. See raw diff
 
cttpunctuator/src/punctuator.py ADDED
@@ -0,0 +1,307 @@
+ import logging
+ import os.path
+ from pathlib import Path
+ from typing import Tuple, Union
+
+ import numpy as np
+
+ from cttpunctuator.src.utils.OrtInferSession import (ONNXRuntimeError,
+                                                      OrtInferSession)
+ from cttpunctuator.src.utils.text_post_process import (TokenIDConverter,
+                                                        code_mix_split_words,
+                                                        read_yaml,
+                                                        split_to_mini_sentence)
+
+
+ class CT_Transformer:
+     """
+     Author: Speech Lab, Alibaba Group, China
+     CT-Transformer: Controllable time-delay transformer
+     for real-time punctuation prediction and disfluency detection
+     https://arxiv.org/pdf/2003.01309.pdf
+     """
+
+     def __init__(
+         self,
+         model_dir: Union[str, Path] = None,
+         batch_size: int = 1,
+         device_id: Union[str, int] = "-1",
+         quantize: bool = False,
+         intra_op_num_threads: int = 4,
+     ):
+         model_dir = model_dir or os.path.join(os.path.dirname(__file__), "onnx")
+         if not Path(model_dir).exists():
+             raise FileNotFoundError(f"{model_dir} does not exist.")
+
+         model_file = os.path.join(model_dir, "punc.onnx")
+         if quantize:
+             model_file = os.path.join(model_dir, "model_quant.onnx")
+         config_file = os.path.join(model_dir, "punc.yaml")
+         config = read_yaml(config_file)
+
+         self.converter = TokenIDConverter(config["token_list"])
+         self.ort_infer = OrtInferSession(
+             model_file, device_id, intra_op_num_threads=intra_op_num_threads
+         )
+         self.batch_size = 1  # fixed to 1 regardless of the constructor argument
+         self.punc_list = config["punc_list"]
+         self.period = 0
+         # Normalize half-width punctuation to full-width and remember the
+         # index of the period token.
+         for i in range(len(self.punc_list)):
+             if self.punc_list[i] == ",":
+                 self.punc_list[i] = ","
+             elif self.punc_list[i] == "?":
+                 self.punc_list[i] = "?"
+             elif self.punc_list[i] == "。":
+                 self.period = i
+
+     def __call__(self, text: Union[list, str], split_size=20):
+         split_text = code_mix_split_words(text)
+         split_text_id = self.converter.tokens2ids(split_text)
+         mini_sentences = split_to_mini_sentence(split_text, split_size)
+         mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+         assert len(mini_sentences) == len(mini_sentences_id)
+         cache_sent = []
+         cache_sent_id = []
+         new_mini_sentence = ""
+         new_mini_sentence_punc = []
+         cache_pop_trigger_limit = 200
+         for mini_sentence_i in range(len(mini_sentences)):
+             mini_sentence = mini_sentences[mini_sentence_i]
+             mini_sentence_id = mini_sentences_id[mini_sentence_i]
+             mini_sentence = cache_sent + mini_sentence
+             mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype="int64")
+             data = {
+                 "text": mini_sentence_id[None, :],
+                 "text_lengths": np.array([len(mini_sentence_id)], dtype="int32"),
+             }
+             try:
+                 outputs = self.infer(data["text"], data["text_lengths"])
+                 y = outputs[0]
+                 punctuations = np.argmax(y, axis=-1)[0]
+                 assert punctuations.size == len(mini_sentence)
+             except ONNXRuntimeError:
+                 # The code below relies on `punctuations`, so re-raise after
+                 # logging instead of silently continuing.
+                 logging.warning("ONNX inference failed for mini-sentence %d", mini_sentence_i)
+                 raise
+
+             # Search for the last period/question mark; everything after it
+             # is kept as cache for the next mini-sentence.
+             if mini_sentence_i < len(mini_sentences) - 1:
+                 sentenceEnd = -1
+                 last_comma_index = -1
+                 for i in range(len(punctuations) - 2, 1, -1):
+                     if (
+                         self.punc_list[punctuations[i]] == "。"
+                         or self.punc_list[punctuations[i]] == "?"
+                     ):
+                         sentenceEnd = i
+                         break
+                     if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
+                         last_comma_index = i
+
+                 if (
+                     sentenceEnd < 0
+                     and len(mini_sentence) > cache_pop_trigger_limit
+                     and last_comma_index >= 0
+                 ):
+                     # The sentence is too long; cut it off at a comma.
+                     sentenceEnd = last_comma_index
+                     punctuations[sentenceEnd] = self.period
+                 cache_sent = mini_sentence[sentenceEnd + 1 :]
+                 cache_sent_id = mini_sentence_id[sentenceEnd + 1 :].tolist()
+                 mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+                 punctuations = punctuations[0 : sentenceEnd + 1]
+
+             new_mini_sentence_punc += [int(x) for x in punctuations]
+             words_with_punc = []
+             for i in range(len(mini_sentence)):
+                 if i > 0:
+                     # Insert a space between two adjacent ASCII (English) words.
+                     if (
+                         len(mini_sentence[i][0].encode()) == 1
+                         and len(mini_sentence[i - 1][0].encode()) == 1
+                     ):
+                         mini_sentence[i] = " " + mini_sentence[i]
+                 words_with_punc.append(mini_sentence[i])
+                 if self.punc_list[punctuations[i]] != "_":
+                     words_with_punc.append(self.punc_list[punctuations[i]])
+             new_mini_sentence += "".join(words_with_punc)
+             # Add a period at the end of the final sentence.
+             new_mini_sentence_out = new_mini_sentence
+             new_mini_sentence_punc_out = new_mini_sentence_punc
+             if mini_sentence_i == len(mini_sentences) - 1:
+                 if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
+                     new_mini_sentence_out = new_mini_sentence[:-1] + "。"
+                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                         self.period
+                     ]
+                 elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
+                     new_mini_sentence_out = new_mini_sentence + "。"
+                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                         self.period
+                     ]
+         return new_mini_sentence_out, new_mini_sentence_punc_out
+
+     def infer(
+         self, feats: np.ndarray, feats_len: np.ndarray
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         outputs = self.ort_infer([feats, feats_len])
+         return outputs
+
+
+ class CT_Transformer_VadRealtime(CT_Transformer):
+     """
+     Author: Speech Lab, Alibaba Group, China
+     CT-Transformer: Controllable time-delay transformer for
+     real-time punctuation prediction and disfluency detection
+     https://arxiv.org/pdf/2003.01309.pdf
+     """
+
+     def __init__(
+         self,
+         model_dir: Union[str, Path] = None,
+         batch_size: int = 1,
+         device_id: Union[str, int] = "-1",
+         quantize: bool = False,
+         intra_op_num_threads: int = 4,
+     ):
+         super(CT_Transformer_VadRealtime, self).__init__(
+             model_dir, batch_size, device_id, quantize, intra_op_num_threads
+         )
+
+     def __call__(self, text: str, param_dict: dict, split_size=20):
+         cache_key = "cache"
+         assert cache_key in param_dict
+         cache = param_dict[cache_key]
+         if cache is not None and len(cache) > 0:
+             precache = "".join(cache)
+         else:
+             precache = ""
+             cache = []
+         full_text = precache + text
+         split_text = code_mix_split_words(full_text)
+         split_text_id = self.converter.tokens2ids(split_text)
+         mini_sentences = split_to_mini_sentence(split_text, split_size)
+         mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+         new_mini_sentence_punc = []
+         assert len(mini_sentences) == len(mini_sentences_id)
+
+         cache_sent = []
+         cache_sent_id = np.array([], dtype="int32")
+         sentence_punc_list = []
+         sentence_words_list = []
+         cache_pop_trigger_limit = 200
+         skip_num = 0
+         for mini_sentence_i in range(len(mini_sentences)):
+             mini_sentence = mini_sentences[mini_sentence_i]
+             mini_sentence_id = mini_sentences_id[mini_sentence_i]
+             mini_sentence = cache_sent + mini_sentence
+             mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+             text_length = len(mini_sentence_id)
+             data = {
+                 "input": mini_sentence_id[None, :],
+                 "text_lengths": np.array([text_length], dtype="int32"),
+                 "vad_mask": self.vad_mask(text_length, len(cache))[
+                     None, None, :, :
+                 ].astype(np.float32),
+                 "sub_masks": np.tril(
+                     np.ones((text_length, text_length), dtype=np.float32)
+                 )[None, None, :, :].astype(np.float32),
+             }
+             try:
+                 outputs = self.infer(
+                     data["input"],
+                     data["text_lengths"],
+                     data["vad_mask"],
+                     data["sub_masks"],
+                 )
+                 y = outputs[0]
+                 punctuations = np.argmax(y, axis=-1)[0]
+                 assert punctuations.size == len(mini_sentence)
+             except ONNXRuntimeError:
+                 # As in the offline model, re-raise after logging so the
+                 # undefined `punctuations` cannot be used below.
+                 logging.warning("ONNX inference failed for mini-sentence %d", mini_sentence_i)
+                 raise
+
+             # Search for the last period/question mark as cache.
+             if mini_sentence_i < len(mini_sentences) - 1:
+                 sentenceEnd = -1
+                 last_comma_index = -1
+                 for i in range(len(punctuations) - 2, 1, -1):
+                     if (
+                         self.punc_list[punctuations[i]] == "。"
+                         or self.punc_list[punctuations[i]] == "?"
+                     ):
+                         sentenceEnd = i
+                         break
+                     if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
+                         last_comma_index = i
+
+                 if (
+                     sentenceEnd < 0
+                     and len(mini_sentence) > cache_pop_trigger_limit
+                     and last_comma_index >= 0
+                 ):
+                     # The sentence is too long; cut it off at a comma.
+                     sentenceEnd = last_comma_index
+                     punctuations[sentenceEnd] = self.period
+                 cache_sent = mini_sentence[sentenceEnd + 1 :]
+                 cache_sent_id = mini_sentence_id[sentenceEnd + 1 :]
+                 mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+                 punctuations = punctuations[0 : sentenceEnd + 1]
+
+             punctuations_np = [int(x) for x in punctuations]
+             new_mini_sentence_punc += punctuations_np
+             sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
+             sentence_words_list += mini_sentence
+
+         assert len(sentence_punc_list) == len(sentence_words_list)
+         words_with_punc = []
+         sentence_punc_list_out = []
+         for i in range(0, len(sentence_words_list)):
+             if i > 0:
+                 # Insert a space between two adjacent ASCII (English) words.
+                 if (
+                     len(sentence_words_list[i][0].encode()) == 1
+                     and len(sentence_words_list[i - 1][-1].encode()) == 1
+                 ):
+                     sentence_words_list[i] = " " + sentence_words_list[i]
+             if skip_num < len(cache):
+                 # Words carried over from the previous call's cache were
+                 # already emitted, so skip them in the output.
+                 skip_num += 1
+             else:
+                 words_with_punc.append(sentence_words_list[i])
+             if skip_num >= len(cache):
+                 sentence_punc_list_out.append(sentence_punc_list[i])
+                 if sentence_punc_list[i] != "_":
+                     words_with_punc.append(sentence_punc_list[i])
+         sentence_out = "".join(words_with_punc)
+
+         # Everything after the last sentence-final mark becomes the new cache.
+         sentenceEnd = -1
+         for i in range(len(sentence_punc_list) - 2, 1, -1):
+             if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
+                 sentenceEnd = i
+                 break
+         cache_out = sentence_words_list[sentenceEnd + 1 :]
+         if sentence_out[-1] in self.punc_list:
+             sentence_out = sentence_out[:-1]
+             sentence_punc_list_out[-1] = "_"
+         param_dict[cache_key] = cache_out
+         return sentence_out, sentence_punc_list_out, cache_out
+
+     def vad_mask(self, size, vad_pos, dtype=np.bool_):
+         """Create a mask for decoder self-attention.
+
+         :param int size: size of the (square) mask
+         :param int vad_pos: position of the VAD boundary, i.e. the length
+             of the cached prefix from the previous call
+         :param np.dtype dtype: result dtype
+         :rtype: np.ndarray of shape (size, size)
+         """
+         ret = np.ones((size, size), dtype=dtype)
+         if vad_pos <= 0 or vad_pos >= size:
+             return ret
+         sub_corner = np.zeros((vad_pos - 1, size - vad_pos), dtype=dtype)
+         ret[0 : vad_pos - 1, vad_pos:] = sub_corner
+         return ret
+
+     def infer(
+         self,
+         feats: np.ndarray,
+         feats_len: np.ndarray,
+         vad_mask: np.ndarray,
+         sub_masks: np.ndarray,
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
+         return outputs
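
For intuition about `vad_mask`: tokens in the cached prefix (everything before the VAD boundary) are blocked from attending to the newly appended text, while all other positions attend freely. A standalone sketch reproducing the same computation:

```python
import numpy as np

size, vad_pos = 5, 3  # 5 tokens total, cached prefix of length 3
mask = np.ones((size, size), dtype=np.bool_)
mask[0 : vad_pos - 1, vad_pos:] = False  # prefix rows cannot see the new segment
print(mask.astype(int))
# [[1 1 1 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 1]
#  [1 1 1 1 1]
#  [1 1 1 1 1]]
```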
cttpunctuator/src/utils/OrtInferSession.py ADDED
@@ -0,0 +1,98 @@
+ # -*- coding:utf-8 -*-
+ # @FileName  :OrtInferSession.py
+ # @Time      :2023/4/13 15:13
+ # @Author    :lovemefan
+ # @Email     :[email protected]
+ import warnings
+ from pathlib import Path
+ from typing import List
+
+ import numpy as np
+ from onnxruntime import (GraphOptimizationLevel, InferenceSession,
+                          SessionOptions, get_available_providers, get_device)
+
+
+ class ONNXRuntimeError(Exception):
+     pass
+
+
+ class OrtInferSession:
+     def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
+         device_id = str(device_id)
+         sess_opt = SessionOptions()
+         sess_opt.intra_op_num_threads = intra_op_num_threads
+         sess_opt.log_severity_level = 4
+         sess_opt.enable_cpu_mem_arena = False
+         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+         cuda_ep = "CUDAExecutionProvider"
+         cuda_provider_options = {
+             "device_id": device_id,
+             "arena_extend_strategy": "kNextPowerOfTwo",
+             "cudnn_conv_algo_search": "EXHAUSTIVE",
+             "do_copy_in_default_stream": "true",
+         }
+         cpu_ep = "CPUExecutionProvider"
+         cpu_provider_options = {
+             "arena_extend_strategy": "kSameAsRequested",
+         }
+
+         # Prefer CUDA when a device id is given and the GPU build is present;
+         # always keep the CPU provider as a fallback.
+         EP_list = []
+         if (
+             device_id != "-1"
+             and get_device() == "GPU"
+             and cuda_ep in get_available_providers()
+         ):
+             EP_list = [(cuda_ep, cuda_provider_options)]
+         EP_list.append((cpu_ep, cpu_provider_options))
+
+         self._verify_model(model_file)
+         self.session = InferenceSession(
+             model_file, sess_options=sess_opt, providers=EP_list
+         )
+
+         if device_id != "-1" and cuda_ep not in self.session.get_providers():
+             warnings.warn(
+                 f"{cuda_ep} is not available in the current env, "
+                 f"so inference automatically falls back to {cpu_ep}.\n"
+                 "Please ensure the installed onnxruntime-gpu version matches "
+                 "your CUDA and cuDNN versions; their relations are listed on "
+                 "the official web site: "
+                 "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
+                 RuntimeWarning,
+             )
+
+     def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
+         input_dict = dict(zip(self.get_input_names(), input_content))
+         try:
+             return self.session.run(self.get_output_names(), input_dict)
+         except Exception as e:
+             raise ONNXRuntimeError("ONNXRuntime inference failed.") from e
+
+     def get_input_names(
+         self,
+     ):
+         return [v.name for v in self.session.get_inputs()]
+
+     def get_output_names(
+         self,
+     ):
+         return [v.name for v in self.session.get_outputs()]
+
+     def get_character_list(self, key: str = "character"):
+         return self.meta_dict[key].splitlines()
+
+     def have_key(self, key: str = "character") -> bool:
+         self.meta_dict = self.session.get_modelmeta().custom_metadata_map
+         if key in self.meta_dict.keys():
+             return True
+         return False
+
+     @staticmethod
+     def _verify_model(model_path):
+         model_path = Path(model_path)
+         if not model_path.exists():
+             raise FileNotFoundError(f"{model_path} does not exist.")
+         if not model_path.is_file():
+             raise FileExistsError(f"{model_path} is not a file.")
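
`OrtInferSession.__call__` zips the arrays you pass with the session's input names, so callers only supply them in graph order. A raw invocation of the bundled offline model (a sketch: the two-input signature and dtypes mirror how `CT_Transformer` calls it, and the token ids here are toy values):

```python
import numpy as np

from cttpunctuator.src.utils.OrtInferSession import OrtInferSession

sess = OrtInferSession("cttpunctuator/src/onnx/punc.onnx", device_id=-1)
ids = np.array([[1, 2, 3]], dtype="int64")  # toy token ids, batch of 1
lengths = np.array([3], dtype="int32")
logits = sess([ids, lengths])[0]  # per-token scores over the punctuation list
print(logits.shape)
```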
cttpunctuator/src/utils/text_post_process.py ADDED
@@ -0,0 +1,86 @@
+ # -*- coding:utf-8 -*-
+ # @FileName  :text_post_process.py
+ # @Time      :2023/4/13 15:09
+ # @Author    :lovemefan
+ # @Email     :[email protected]
+ from pathlib import Path
+ from typing import Dict, Iterable, List, Union
+
+ import numpy as np
+ import yaml
+ from typeguard import check_argument_types
+
+
+ class TokenIDConverterError(Exception):
+     pass
+
+
+ class TokenIDConverter:
+     def __init__(
+         self,
+         token_list: Union[List, str],
+     ):
+         check_argument_types()
+
+         self.token_list = token_list
+         self.unk_symbol = token_list[-1]
+         self.token2id = {v: i for i, v in enumerate(self.token_list)}
+         self.unk_id = self.token2id[self.unk_symbol]
+
+     def get_num_vocabulary_size(self) -> int:
+         return len(self.token_list)
+
+     def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+         if isinstance(integers, np.ndarray) and integers.ndim != 1:
+             raise TokenIDConverterError(
+                 f"Must be 1 dim ndarray, but got {integers.ndim}"
+             )
+         return [self.token_list[i] for i in integers]
+
+     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
+         # Unknown tokens are mapped to the <unk> id.
+         return [self.token2id.get(i, self.unk_id) for i in tokens]
+
+
+ def split_to_mini_sentence(words: list, word_limit: int = 20):
+     assert word_limit > 1
+     if len(words) <= word_limit:
+         return [words]
+     sentences = []
+     length = len(words)
+     sentence_len = length // word_limit
+     for i in range(sentence_len):
+         sentences.append(words[i * word_limit : (i + 1) * word_limit])
+     if length % word_limit > 0:
+         sentences.append(words[sentence_len * word_limit :])
+     return sentences
+
+
+ def code_mix_split_words(text: str):
+     words = []
+     segs = text.split()
+     for seg in segs:
+         # There is no space in seg.
+         current_word = ""
+         for c in seg:
+             if len(c.encode()) == 1:
+                 # This is an ASCII char.
+                 current_word += c
+             else:
+                 # This is a Chinese char.
+                 if len(current_word) > 0:
+                     words.append(current_word)
+                     current_word = ""
+                 words.append(c)
+         if len(current_word) > 0:
+             words.append(current_word)
+     return words
+
+
+ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
+     if not Path(yaml_path).exists():
+         raise FileNotFoundError(f"The {yaml_path} does not exist.")
+
+     with open(str(yaml_path), "rb") as f:
+         data = yaml.load(f, Loader=yaml.Loader)
+     return data
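
A quick sanity check of the two helpers above (assuming the package is installed as in the README): `code_mix_split_words` keeps ASCII runs together while splitting each Chinese character, and `split_to_mini_sentence` chunks the resulting token list:

```python
from cttpunctuator.src.utils.text_post_process import (code_mix_split_words,
                                                       split_to_mini_sentence)

print(code_mix_split_words("ChatGPT创建了一个generator"))
# ['ChatGPT', '创', '建', '了', '一', '个', 'generator']
print(split_to_mini_sentence(list("abcdefg"), word_limit=3))
# [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']]
```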
setup.py ADDED
@@ -0,0 +1,64 @@
+ # -*- coding:utf-8 -*-
+ # @FileName  :setup.py
+ # @Time      :2023/4/4 11:22
+ # @Author    :lovemefan
+ # @Email     :[email protected]
+ import os
+ from pathlib import Path
+
+ from setuptools import find_namespace_packages, setup
+
+ dirname = Path(os.path.dirname(__file__))
+ version_file = dirname / "version.txt"
+ with open(version_file, "r") as f:
+     version = f.read().strip()
+
+ requirements = {
+     "install": [
+         "setuptools<=65.0",
+         "PyYAML",
+         "typeguard==2.13.3",
+         "onnxruntime==1.14.1",
+     ],
+     "setup": [
+         "numpy==1.24.2",
+     ],
+     "all": [],
+ }
+ requirements["all"].extend(requirements["install"])
+
+ install_requires = requirements["install"]
+ setup_requires = requirements["setup"]
+
+
+ setup(
+     name="cttpunctuator",
+     version=version,
+     url="https://github.com/lovemefan/CT-Transformer-punctuation",
+     author="Lovemefan, Yunnan Key Laboratory of Artificial Intelligence, "
+     "Kunming University of Science and Technology, Kunming, Yunnan",
+     author_email="[email protected]",
+     description="ctt-punctuator: an enterprise-grade punctuator for Chinese ASR, "
+     "based on the CT-Transformer from the open-source FunASR project",
+     long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),
+     long_description_content_type="text/markdown",
+     license="The MIT License",
+     packages=find_namespace_packages(),
+     include_package_data=True,
+     install_requires=install_requires,
+     python_requires=">=3.7.0",
+     classifiers=[
+         "Programming Language :: Python",
+         "Programming Language :: Python :: 3.7",
+         "Programming Language :: Python :: 3.8",
+         "Programming Language :: Python :: 3.9",
+         "Programming Language :: Python :: 3.10",
+         "Development Status :: 5 - Production/Stable",
+         "Intended Audience :: Science/Research",
+         "Operating System :: POSIX :: Linux",
+         "License :: OSI Approved :: MIT License",
+         "Topic :: Multimedia :: Sound/Audio :: Speech",
+         "Topic :: Scientific/Engineering :: Artificial Intelligence",
+         "Topic :: Software Development :: Libraries :: Python Modules",
+     ],
+ )
test/test.py ADDED
@@ -0,0 +1,38 @@
+ # -*- coding:utf-8 -*-
+ # @FileName  :test.py
+ # @Time      :2023/4/19 13:39
+ # @Author    :lovemefan
+ # @Email     :[email protected]
+
+ import logging
+
+ from cttPunctuator import CttPunctuator
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
+ )
+ # offline mode
+ punc = CttPunctuator()
+ text = "据报道纽约时报使用ChatGPT创建了一个情人节消息生成器用户只需输入几个提示就可以得到一封自动生成的情书"
+ logging.info(punc.punctuate(text)[0])
+
+ # online mode
+ punc = CttPunctuator(online=True)
+ text_in = (
+     "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|"
+     "在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|"
+     "向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|"
+     "愿意进一步完善双方联合工作机制|凡是|中方能做的我们|"
+     "都会去做而且会做得更好我请印度朋友们放心中国在上游的|"
+     "任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
+ )
+
+ vads = text_in.split("|")
+ rec_result_all = ""
+ for vad in vads:
+     result = punc.punctuate(vad)
+     rec_result_all += result[0]
+     logging.info(f"Part: {rec_result_all}")
+
+ logging.info(f"Final: {rec_result_all}")
version.txt ADDED
@@ -0,0 +1 @@
+ 0.0.1