Upload 14 files
Browse filesupload project from https://github.com/lovemefan/CT-Transformer-punctuation
- LICENSE +22 -0
- MANIFEST.in +3 -0
- README.md +122 -3
- cttPunctuator.py +64 -0
- cttpunctuator/__init__.py +5 -0
- cttpunctuator/src/onnx/configuration.json +20 -0
- cttpunctuator/src/onnx/punc.onnx +3 -0
- cttpunctuator/src/onnx/punc.yaml +0 -0
- cttpunctuator/src/punctuator.py +307 -0
- cttpunctuator/src/utils/OrtInferSession.py +98 -0
- cttpunctuator/src/utils/text_post_process.py +86 -0
- setup.py +64 -0
- test/test.py +38 -0
- version.txt +1 -0
LICENSE
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The MIT License (MIT)
|
2 |
+
|
3 |
+
Copyright (c) 2014-2017 Alexey Popravka
|
4 |
+
Copyright (c) 2021 Sean Stewart
|
5 |
+
|
6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
of this software and associated documentation files (the "Software"), to deal
|
8 |
+
in the Software without restriction, including without limitation the rights
|
9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
copies of the Software, and to permit persons to whom the Software is
|
11 |
+
furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
The above copyright notice and this permission notice shall be included in all
|
14 |
+
copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
SOFTWARE.
|
MANIFEST.in
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
include cttpunctuator/src/onnx/configuration.json
|
2 |
+
include cttpunctuator/src/onnx/punc.onnx
|
3 |
+
include cttpunctuator/src/onnx/punc.yaml
|
README.md
CHANGED
@@ -1,3 +1,122 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
<br/>
|
4 |
+
<h2 align="center">Ctt punctuator</h2>
|
5 |
+
<br/>
|
6 |
+
|
7 |
+
|
8 |
+

|
9 |
+

|
10 |
+

|
11 |
+

|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
A enterprise-grade Chinese-English code switch punctuator [funasr](https://github.com/alibaba-damo-academy/FunASR/).
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
<br/>
|
20 |
+
<h2 align="center">Key Features</h2>
|
21 |
+
<br/>
|
22 |
+
|
23 |
+
- **General**
|
24 |
+
|
25 |
+
ctt punctuator was trained on chinese-english code switch corpora.
|
26 |
+
- [x] offline punctuator
|
27 |
+
- [x] online punctuator
|
28 |
+
- [x] punctuator for chinese-english code switch
|
29 |
+
|
30 |
+
the onnx model file is 279M, you can download it from [here](https://github.com/lovemefan/CT-Transformer-punctuation/raw/main/cttpunctuator/src/onnx/punc.onnx)
|
31 |
+
|
32 |
+
- **Highly Portable**
|
33 |
+
|
34 |
+
ctt-punctuator reaps benefits from the rich ecosystems built around **ONNX** running everywhere where these runtimes are available.
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
## Installation
|
39 |
+
|
40 |
+
```bash
|
41 |
+
sudo apt install git-lfs
|
42 |
+
# if the code raise : failed:Protobuf parsing failed.
|
43 |
+
# you should install git-lfs and run git lfs install
|
44 |
+
git lfs install
|
45 |
+
# use lfs download onnx file
|
46 |
+
git clone https://github.com/lovemefan/CT-Transformer-punctuation.git
|
47 |
+
cd CT-Transformer-punctuation
|
48 |
+
pip install -e .
|
49 |
+
```
|
50 |
+
|
51 |
+
## Usage
|
52 |
+
|
53 |
+
```python
|
54 |
+
from cttPunctuator import CttPunctuator
|
55 |
+
import logging
|
56 |
+
logging.basicConfig(
|
57 |
+
level=logging.INFO,
|
58 |
+
format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
|
59 |
+
)
|
60 |
+
# offline mode
|
61 |
+
punc = CttPunctuator()
|
62 |
+
text = "据报道纽约时报使用ChatGPT创建了一个情人节消息生成器用户只需输入几个提示就可以得到一封自动生成的情书"
|
63 |
+
logging.info(punc.punctuate(text)[0])
|
64 |
+
|
65 |
+
# online mode
|
66 |
+
punc = CttPunctuator(online=True)
|
67 |
+
text_in = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
|
68 |
+
|
69 |
+
vads = text_in.split("|")
|
70 |
+
rec_result_all = ""
|
71 |
+
param_dict = {"cache": []}
|
72 |
+
for vad in vads:
|
73 |
+
result = punc.punctuate(vad, param_dict=param_dict)
|
74 |
+
rec_result_all += result[0]
|
75 |
+
logging.info(f"Part: {rec_result_all}")
|
76 |
+
|
77 |
+
logging.info(f"Final: {rec_result_all}")
|
78 |
+
```
|
79 |
+
## Result
|
80 |
+
```bash
|
81 |
+
[2023-04-19 01:12:39,308 INFO] [ctt-punctuator.py:50 ctt-punctuator.__init__] Initializing punctuator model with offline mode.
|
82 |
+
[2023-04-19 01:12:55,854 INFO] [ctt-punctuator.py:52 ctt-punctuator.__init__] Offline model initialized.
|
83 |
+
[2023-04-19 01:12:55,854 INFO] [ctt-punctuator.py:55 ctt-punctuator.__init__] Model initialized.
|
84 |
+
[2023-04-19 01:12:55,868 INFO] [ctt-punctuator.py:67 ctt-punctuator.<module>] 据报道,纽约时报使用ChatGPT创建了一个情人节消息生成器,用户只需输入几个提示,就可以得到一封自动生成的情书。
|
85 |
+
[2023-04-19 01:12:55,868 INFO] [ctt-punctuator.py:40 ctt-punctuator.__init__] Initializing punctuator model with online mode.
|
86 |
+
[2023-04-19 01:13:12,499 INFO] [ctt-punctuator.py:43 ctt-punctuator.__init__] Online model initialized.
|
87 |
+
[2023-04-19 01:13:12,499 INFO] [ctt-punctuator.py:55 ctt-punctuator.__init__] Model initialized.
|
88 |
+
[2023-04-19 01:13:12,502 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸
|
89 |
+
[2023-04-19 01:13:12,508 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员
|
90 |
+
[2023-04-19 01:13:12,521 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险
|
91 |
+
[2023-04-19 01:13:12,547 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切
|
92 |
+
[2023-04-19 01:13:12,553 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制
|
93 |
+
[2023-04-19 01:13:12,559 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是
|
94 |
+
[2023-04-19 01:13:12,560 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们
|
95 |
+
[2023-04-19 01:13:12,567 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的
|
96 |
+
[2023-04-19 01:13:12,572 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的任何开发利用,都会经过科学
|
97 |
+
[2023-04-19 01:13:12,578 INFO] [ctt-punctuator.py:77 ctt-punctuator.<module>] Partial: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的任何开发利用,都会经过科学规划和论证,兼顾上下游的利益
|
98 |
+
[2023-04-19 01:13:12,578 INFO] [ctt-punctuator.py:79 ctt-punctuator.<module>] Final: 跨境河流是养育沿岸人民的生命之源。长期以来,为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难,甚至冒着生命危险,向印方提供汛期水文资料处理紧急事件。中方重视印方在跨境河流>问题上的关切,愿意进一步完善双方联合工作机制。凡是中方能做的,我们都会去做,而且会做得更好。我请印度朋友们放心,中国在上游的任何开发利用,都会经过科学规划和论证,兼顾上下游的利益
|
99 |
+
```
|
100 |
+
|
101 |
+
## Citation
|
102 |
+
```
|
103 |
+
@inproceedings{chen2020controllable,
|
104 |
+
title={Controllable Time-Delay Transformer for Real-Time Punctuation Prediction and Disfluency Detection},
|
105 |
+
author={Chen, Qian and Chen, Mengzhe and Li, Bo and Wang, Wen},
|
106 |
+
booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
|
107 |
+
pages={8069--8073},
|
108 |
+
year={2020},
|
109 |
+
organization={IEEE}
|
110 |
+
}
|
111 |
+
```
|
112 |
+
```
|
113 |
+
@misc{FunASR,
|
114 |
+
author = {Speech Lab, Alibaba Group, China},
|
115 |
+
title = {FunASR: A Fundamental End-to-End Speech Recognition Toolkit},
|
116 |
+
year = {2023},
|
117 |
+
publisher = {GitHub},
|
118 |
+
journal = {GitHub repository},
|
119 |
+
howpublished = {\url{https://github.com/alibaba-damo-academy/FunASR/}},
|
120 |
+
}
|
121 |
+
|
122 |
+
```
|
cttPunctuator.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# @FileName :ctt-punctuator.py
|
3 |
+
# @Time :2023/4/13 15:03
|
4 |
+
# @Author :lovemefan
|
5 |
+
# @Email :[email protected]
|
6 |
+
|
7 |
+
|
8 |
+
__author__ = "lovemefan"
|
9 |
+
__copyright__ = "Copyright (C) 2023 lovemefan"
|
10 |
+
__license__ = "MIT"
|
11 |
+
__version__ = "v0.0.1"
|
12 |
+
|
13 |
+
import logging
|
14 |
+
import threading
|
15 |
+
|
16 |
+
from cttpunctuator.src.punctuator import (CT_Transformer,
|
17 |
+
CT_Transformer_VadRealtime)
|
18 |
+
|
19 |
+
logging.basicConfig(
|
20 |
+
level=logging.INFO,
|
21 |
+
format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
|
22 |
+
)
|
23 |
+
|
24 |
+
lock = threading.RLock()
|
25 |
+
|
26 |
+
|
27 |
+
class CttPunctuator:
|
28 |
+
_offline_model = None
|
29 |
+
_online_model = None
|
30 |
+
|
31 |
+
def __init__(self, online: bool = False):
|
32 |
+
"""
|
33 |
+
punctuator with singleton pattern
|
34 |
+
:param online:
|
35 |
+
"""
|
36 |
+
self.online = online
|
37 |
+
|
38 |
+
if online:
|
39 |
+
if CttPunctuator._online_model is None:
|
40 |
+
with lock:
|
41 |
+
if CttPunctuator._online_model is None:
|
42 |
+
logging.info("Initializing punctuator model with online mode.")
|
43 |
+
CttPunctuator._online_model = CT_Transformer_VadRealtime()
|
44 |
+
self.param_dict = {"cache": []}
|
45 |
+
logging.info("Online model initialized.")
|
46 |
+
self.model = CttPunctuator._online_model
|
47 |
+
|
48 |
+
else:
|
49 |
+
if CttPunctuator._offline_model is None:
|
50 |
+
with lock:
|
51 |
+
if CttPunctuator._offline_model is None:
|
52 |
+
logging.info("Initializing punctuator model with offline mode.")
|
53 |
+
CttPunctuator._offline_model = CT_Transformer()
|
54 |
+
logging.info("Offline model initialized.")
|
55 |
+
self.model = CttPunctuator._offline_model
|
56 |
+
|
57 |
+
logging.info("Model initialized.")
|
58 |
+
|
59 |
+
def punctuate(self, text: str, param_dict=None):
|
60 |
+
if self.online:
|
61 |
+
param_dict = param_dict or self.param_dict
|
62 |
+
return self.model(text, self.param_dict)
|
63 |
+
else:
|
64 |
+
return self.model(text)
|
cttpunctuator/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# @FileName :__init__.py.py
|
3 |
+
# @Time :2023/4/13 14:58
|
4 |
+
# @Author :lovemefan
|
5 |
+
# @Email :[email protected]
|
cttpunctuator/src/onnx/configuration.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"framework": "onnx",
|
3 |
+
"task" : "punctuation",
|
4 |
+
"model" : {
|
5 |
+
"type" : "generic-punc",
|
6 |
+
"punc_model_name" : "punc.pb",
|
7 |
+
"punc_model_config" : {
|
8 |
+
"type": "pytorch",
|
9 |
+
"code_base": "funasr",
|
10 |
+
"mode": "punc",
|
11 |
+
"lang": "zh-cn",
|
12 |
+
"batch_size": 1,
|
13 |
+
"punc_config": "punc.yaml",
|
14 |
+
"model": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
15 |
+
}
|
16 |
+
},
|
17 |
+
"pipeline": {
|
18 |
+
"type":"punc-inference"
|
19 |
+
}
|
20 |
+
}
|
cttpunctuator/src/onnx/punc.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:06ae02f3fce2d6bfbcdd988672467808c0113e77b6eed7dc52835ff627e12330
|
3 |
+
size 292007354
|
cttpunctuator/src/onnx/punc.yaml
ADDED
The diff for this file is too large to render.
See raw diff
|
|
cttpunctuator/src/punctuator.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os.path
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Tuple, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from cttpunctuator.src.utils.OrtInferSession import (ONNXRuntimeError,
|
9 |
+
OrtInferSession)
|
10 |
+
from cttpunctuator.src.utils.text_post_process import (TokenIDConverter,
|
11 |
+
code_mix_split_words,
|
12 |
+
read_yaml,
|
13 |
+
split_to_mini_sentence)
|
14 |
+
|
15 |
+
|
16 |
+
class CT_Transformer:
|
17 |
+
"""
|
18 |
+
Author: Speech Lab, Alibaba Group, China
|
19 |
+
CT-Transformer: Controllable time-delay transformer
|
20 |
+
for real-time punctuation prediction and disfluency detection
|
21 |
+
https://arxiv.org/pdf/2003.01309.pdf
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
model_dir: Union[str, Path] = None,
|
27 |
+
batch_size: int = 1,
|
28 |
+
device_id: Union[str, int] = "-1",
|
29 |
+
quantize: bool = False,
|
30 |
+
intra_op_num_threads: int = 4,
|
31 |
+
):
|
32 |
+
model_dir = model_dir or os.path.join(os.path.dirname(__file__), "onnx")
|
33 |
+
if model_dir is None or not Path(model_dir).exists():
|
34 |
+
raise FileNotFoundError(f"{model_dir} does not exist.")
|
35 |
+
|
36 |
+
model_file = os.path.join(model_dir, "punc.onnx")
|
37 |
+
if quantize:
|
38 |
+
model_file = os.path.join(model_dir, "model_quant.onnx")
|
39 |
+
config_file = os.path.join(model_dir, "punc.yaml")
|
40 |
+
config = read_yaml(config_file)
|
41 |
+
|
42 |
+
self.converter = TokenIDConverter(config["token_list"])
|
43 |
+
self.ort_infer = OrtInferSession(
|
44 |
+
model_file, device_id, intra_op_num_threads=intra_op_num_threads
|
45 |
+
)
|
46 |
+
self.batch_size = 1
|
47 |
+
self.punc_list = config["punc_list"]
|
48 |
+
self.period = 0
|
49 |
+
for i in range(len(self.punc_list)):
|
50 |
+
if self.punc_list[i] == ",":
|
51 |
+
self.punc_list[i] = ","
|
52 |
+
elif self.punc_list[i] == "?":
|
53 |
+
self.punc_list[i] = "?"
|
54 |
+
elif self.punc_list[i] == "。":
|
55 |
+
self.period = i
|
56 |
+
|
57 |
+
def __call__(self, text: Union[list, str], split_size=20):
|
58 |
+
split_text = code_mix_split_words(text)
|
59 |
+
split_text_id = self.converter.tokens2ids(split_text)
|
60 |
+
mini_sentences = split_to_mini_sentence(split_text, split_size)
|
61 |
+
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
|
62 |
+
assert len(mini_sentences) == len(mini_sentences_id)
|
63 |
+
cache_sent = []
|
64 |
+
cache_sent_id = []
|
65 |
+
new_mini_sentence = ""
|
66 |
+
new_mini_sentence_punc = []
|
67 |
+
cache_pop_trigger_limit = 200
|
68 |
+
for mini_sentence_i in range(len(mini_sentences)):
|
69 |
+
mini_sentence = mini_sentences[mini_sentence_i]
|
70 |
+
mini_sentence_id = mini_sentences_id[mini_sentence_i]
|
71 |
+
mini_sentence = cache_sent + mini_sentence
|
72 |
+
mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype="int64")
|
73 |
+
data = {
|
74 |
+
"text": mini_sentence_id[None, :],
|
75 |
+
"text_lengths": np.array([len(mini_sentence_id)], dtype="int32"),
|
76 |
+
}
|
77 |
+
try:
|
78 |
+
outputs = self.infer(data["text"], data["text_lengths"])
|
79 |
+
y = outputs[0]
|
80 |
+
punctuations = np.argmax(y, axis=-1)[0]
|
81 |
+
assert punctuations.size == len(mini_sentence)
|
82 |
+
except ONNXRuntimeError:
|
83 |
+
logging.warning("error")
|
84 |
+
|
85 |
+
# Search for the last Period/QuestionMark as cache
|
86 |
+
if mini_sentence_i < len(mini_sentences) - 1:
|
87 |
+
sentenceEnd = -1
|
88 |
+
last_comma_index = -1
|
89 |
+
for i in range(len(punctuations) - 2, 1, -1):
|
90 |
+
if (
|
91 |
+
self.punc_list[punctuations[i]] == "。"
|
92 |
+
or self.punc_list[punctuations[i]] == "?"
|
93 |
+
):
|
94 |
+
sentenceEnd = i
|
95 |
+
break
|
96 |
+
if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
|
97 |
+
last_comma_index = i
|
98 |
+
|
99 |
+
if (
|
100 |
+
sentenceEnd < 0
|
101 |
+
and len(mini_sentence) > cache_pop_trigger_limit
|
102 |
+
and last_comma_index >= 0
|
103 |
+
):
|
104 |
+
# The sentence it too long, cut off at a comma.
|
105 |
+
sentenceEnd = last_comma_index
|
106 |
+
punctuations[sentenceEnd] = self.period
|
107 |
+
cache_sent = mini_sentence[sentenceEnd + 1 :]
|
108 |
+
cache_sent_id = mini_sentence_id[sentenceEnd + 1 :].tolist()
|
109 |
+
mini_sentence = mini_sentence[0 : sentenceEnd + 1]
|
110 |
+
punctuations = punctuations[0 : sentenceEnd + 1]
|
111 |
+
|
112 |
+
new_mini_sentence_punc += [int(x) for x in punctuations]
|
113 |
+
words_with_punc = []
|
114 |
+
for i in range(len(mini_sentence)):
|
115 |
+
if i > 0:
|
116 |
+
if (
|
117 |
+
len(mini_sentence[i][0].encode()) == 1
|
118 |
+
and len(mini_sentence[i - 1][0].encode()) == 1
|
119 |
+
):
|
120 |
+
mini_sentence[i] = " " + mini_sentence[i]
|
121 |
+
words_with_punc.append(mini_sentence[i])
|
122 |
+
if self.punc_list[punctuations[i]] != "_":
|
123 |
+
words_with_punc.append(self.punc_list[punctuations[i]])
|
124 |
+
new_mini_sentence += "".join(words_with_punc)
|
125 |
+
# Add Period for the end of the sentence
|
126 |
+
new_mini_sentence_out = new_mini_sentence
|
127 |
+
new_mini_sentence_punc_out = new_mini_sentence_punc
|
128 |
+
if mini_sentence_i == len(mini_sentences) - 1:
|
129 |
+
if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
|
130 |
+
new_mini_sentence_out = new_mini_sentence[:-1] + "。"
|
131 |
+
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
|
132 |
+
self.period
|
133 |
+
]
|
134 |
+
elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
|
135 |
+
new_mini_sentence_out = new_mini_sentence + "。"
|
136 |
+
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
|
137 |
+
self.period
|
138 |
+
]
|
139 |
+
return new_mini_sentence_out, new_mini_sentence_punc_out
|
140 |
+
|
141 |
+
def infer(
|
142 |
+
self, feats: np.ndarray, feats_len: np.ndarray
|
143 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
144 |
+
outputs = self.ort_infer([feats, feats_len])
|
145 |
+
return outputs
|
146 |
+
|
147 |
+
|
148 |
+
class CT_Transformer_VadRealtime(CT_Transformer):
|
149 |
+
"""
|
150 |
+
Author: Speech Lab, Alibaba Group, China
|
151 |
+
CT-Transformer: Controllable time-delay transformer for
|
152 |
+
real-time punctuation prediction and disfluency detection
|
153 |
+
https://arxiv.org/pdf/2003.01309.pdf
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
model_dir: Union[str, Path] = None,
|
159 |
+
batch_size: int = 1,
|
160 |
+
device_id: Union[str, int] = "-1",
|
161 |
+
quantize: bool = False,
|
162 |
+
intra_op_num_threads: int = 4,
|
163 |
+
):
|
164 |
+
super(CT_Transformer_VadRealtime, self).__init__(
|
165 |
+
model_dir, batch_size, device_id, quantize, intra_op_num_threads
|
166 |
+
)
|
167 |
+
|
168 |
+
def __call__(self, text: str, param_dict: map, split_size=20):
|
169 |
+
cache_key = "cache"
|
170 |
+
assert cache_key in param_dict
|
171 |
+
cache = param_dict[cache_key]
|
172 |
+
if cache is not None and len(cache) > 0:
|
173 |
+
precache = "".join(cache)
|
174 |
+
else:
|
175 |
+
precache = ""
|
176 |
+
cache = []
|
177 |
+
full_text = precache + text
|
178 |
+
split_text = code_mix_split_words(full_text)
|
179 |
+
split_text_id = self.converter.tokens2ids(split_text)
|
180 |
+
mini_sentences = split_to_mini_sentence(split_text, split_size)
|
181 |
+
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
|
182 |
+
new_mini_sentence_punc = []
|
183 |
+
assert len(mini_sentences) == len(mini_sentences_id)
|
184 |
+
|
185 |
+
cache_sent = []
|
186 |
+
cache_sent_id = np.array([], dtype="int32")
|
187 |
+
sentence_punc_list = []
|
188 |
+
sentence_words_list = []
|
189 |
+
cache_pop_trigger_limit = 200
|
190 |
+
skip_num = 0
|
191 |
+
for mini_sentence_i in range(len(mini_sentences)):
|
192 |
+
mini_sentence = mini_sentences[mini_sentence_i]
|
193 |
+
mini_sentence_id = mini_sentences_id[mini_sentence_i]
|
194 |
+
mini_sentence = cache_sent + mini_sentence
|
195 |
+
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
|
196 |
+
text_length = len(mini_sentence_id)
|
197 |
+
data = {
|
198 |
+
"input": mini_sentence_id[None, :],
|
199 |
+
"text_lengths": np.array([text_length], dtype="int32"),
|
200 |
+
"vad_mask": self.vad_mask(text_length, len(cache))[
|
201 |
+
None, None, :, :
|
202 |
+
].astype(np.float32),
|
203 |
+
"sub_masks": np.tril(
|
204 |
+
np.ones((text_length, text_length), dtype=np.float32)
|
205 |
+
)[None, None, :, :].astype(np.float32),
|
206 |
+
}
|
207 |
+
try:
|
208 |
+
outputs = self.infer(
|
209 |
+
data["input"],
|
210 |
+
data["text_lengths"],
|
211 |
+
data["vad_mask"],
|
212 |
+
data["sub_masks"],
|
213 |
+
)
|
214 |
+
y = outputs[0]
|
215 |
+
punctuations = np.argmax(y, axis=-1)[0]
|
216 |
+
assert punctuations.size == len(mini_sentence)
|
217 |
+
except ONNXRuntimeError:
|
218 |
+
logging.warning("error")
|
219 |
+
|
220 |
+
# Search for the last Period/QuestionMark as cache
|
221 |
+
if mini_sentence_i < len(mini_sentences) - 1:
|
222 |
+
sentenceEnd = -1
|
223 |
+
last_comma_index = -1
|
224 |
+
for i in range(len(punctuations) - 2, 1, -1):
|
225 |
+
if (
|
226 |
+
self.punc_list[punctuations[i]] == "。"
|
227 |
+
or self.punc_list[punctuations[i]] == "?"
|
228 |
+
):
|
229 |
+
sentenceEnd = i
|
230 |
+
break
|
231 |
+
if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
|
232 |
+
last_comma_index = i
|
233 |
+
|
234 |
+
if (
|
235 |
+
sentenceEnd < 0
|
236 |
+
and len(mini_sentence) > cache_pop_trigger_limit
|
237 |
+
and last_comma_index >= 0
|
238 |
+
):
|
239 |
+
# The sentence it too long, cut off at a comma.
|
240 |
+
sentenceEnd = last_comma_index
|
241 |
+
punctuations[sentenceEnd] = self.period
|
242 |
+
cache_sent = mini_sentence[sentenceEnd + 1 :]
|
243 |
+
cache_sent_id = mini_sentence_id[sentenceEnd + 1 :]
|
244 |
+
mini_sentence = mini_sentence[0 : sentenceEnd + 1]
|
245 |
+
punctuations = punctuations[0 : sentenceEnd + 1]
|
246 |
+
|
247 |
+
punctuations_np = [int(x) for x in punctuations]
|
248 |
+
new_mini_sentence_punc += punctuations_np
|
249 |
+
sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
|
250 |
+
sentence_words_list += mini_sentence
|
251 |
+
|
252 |
+
assert len(sentence_punc_list) == len(sentence_words_list)
|
253 |
+
words_with_punc = []
|
254 |
+
sentence_punc_list_out = []
|
255 |
+
for i in range(0, len(sentence_words_list)):
|
256 |
+
if i > 0:
|
257 |
+
if (
|
258 |
+
len(sentence_words_list[i][0].encode()) == 1
|
259 |
+
and len(sentence_words_list[i - 1][-1].encode()) == 1
|
260 |
+
):
|
261 |
+
sentence_words_list[i] = " " + sentence_words_list[i]
|
262 |
+
if skip_num < len(cache):
|
263 |
+
skip_num += 1
|
264 |
+
else:
|
265 |
+
words_with_punc.append(sentence_words_list[i])
|
266 |
+
if skip_num >= len(cache):
|
267 |
+
sentence_punc_list_out.append(sentence_punc_list[i])
|
268 |
+
if sentence_punc_list[i] != "_":
|
269 |
+
words_with_punc.append(sentence_punc_list[i])
|
270 |
+
sentence_out = "".join(words_with_punc)
|
271 |
+
|
272 |
+
sentenceEnd = -1
|
273 |
+
for i in range(len(sentence_punc_list) - 2, 1, -1):
|
274 |
+
if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
|
275 |
+
sentenceEnd = i
|
276 |
+
break
|
277 |
+
cache_out = sentence_words_list[sentenceEnd + 1 :]
|
278 |
+
if sentence_out[-1] in self.punc_list:
|
279 |
+
sentence_out = sentence_out[:-1]
|
280 |
+
sentence_punc_list_out[-1] = "_"
|
281 |
+
param_dict[cache_key] = cache_out
|
282 |
+
return sentence_out, sentence_punc_list_out, cache_out
|
283 |
+
|
284 |
+
def vad_mask(self, size, vad_pos, dtype=np.bool_):
|
285 |
+
"""Create mask for decoder self-attention.
|
286 |
+
|
287 |
+
:param int size: size of mask
|
288 |
+
:param int vad_pos: index of vad index
|
289 |
+
:param torch.dtype dtype: result dtype
|
290 |
+
:rtype: torch.Tensor (B, Lmax, Lmax)
|
291 |
+
"""
|
292 |
+
ret = np.ones((size, size), dtype=dtype)
|
293 |
+
if vad_pos <= 0 or vad_pos >= size:
|
294 |
+
return ret
|
295 |
+
sub_corner = np.zeros((vad_pos - 1, size - vad_pos), dtype=dtype)
|
296 |
+
ret[0 : vad_pos - 1, vad_pos:] = sub_corner
|
297 |
+
return ret
|
298 |
+
|
299 |
+
def infer(
|
300 |
+
self,
|
301 |
+
feats: np.ndarray,
|
302 |
+
feats_len: np.ndarray,
|
303 |
+
vad_mask: np.ndarray,
|
304 |
+
sub_masks: np.ndarray,
|
305 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
306 |
+
outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
|
307 |
+
return outputs
|
cttpunctuator/src/utils/OrtInferSession.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# @FileName :OrtInferSession.py
|
3 |
+
# @Time :2023/4/13 15:13
|
4 |
+
# @Author :lovemefan
|
5 |
+
# @Email :[email protected]
|
6 |
+
import logging
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import List, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from onnxruntime import (GraphOptimizationLevel, InferenceSession,
|
12 |
+
SessionOptions, get_available_providers, get_device)
|
13 |
+
|
14 |
+
|
15 |
+
class ONNXRuntimeError(Exception):
|
16 |
+
pass
|
17 |
+
|
18 |
+
|
19 |
+
class OrtInferSession:
|
20 |
+
def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
|
21 |
+
device_id = str(device_id)
|
22 |
+
sess_opt = SessionOptions()
|
23 |
+
sess_opt.intra_op_num_threads = intra_op_num_threads
|
24 |
+
sess_opt.log_severity_level = 4
|
25 |
+
sess_opt.enable_cpu_mem_arena = False
|
26 |
+
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
27 |
+
|
28 |
+
cuda_ep = "CUDAExecutionProvider"
|
29 |
+
cuda_provider_options = {
|
30 |
+
"device_id": device_id,
|
31 |
+
"arena_extend_strategy": "kNextPowerOfTwo",
|
32 |
+
"cudnn_conv_algo_search": "EXHAUSTIVE",
|
33 |
+
"do_copy_in_default_stream": "true",
|
34 |
+
}
|
35 |
+
cpu_ep = "CPUExecutionProvider"
|
36 |
+
cpu_provider_options = {
|
37 |
+
"arena_extend_strategy": "kSameAsRequested",
|
38 |
+
}
|
39 |
+
|
40 |
+
EP_list = []
|
41 |
+
if (
|
42 |
+
device_id != "-1"
|
43 |
+
and get_device() == "GPU"
|
44 |
+
and cuda_ep in get_available_providers()
|
45 |
+
):
|
46 |
+
EP_list = [(cuda_ep, cuda_provider_options)]
|
47 |
+
EP_list.append((cpu_ep, cpu_provider_options))
|
48 |
+
|
49 |
+
self._verify_model(model_file)
|
50 |
+
self.session = InferenceSession(
|
51 |
+
model_file, sess_options=sess_opt, providers=EP_list
|
52 |
+
)
|
53 |
+
|
54 |
+
if device_id != "-1" and cuda_ep not in self.session.get_providers():
|
55 |
+
logging.warnings.warn(
|
56 |
+
f"{cuda_ep} is not avaiable for current env, "
|
57 |
+
f"the inference part is automatically shifted to be executed under {cpu_ep}.\n"
|
58 |
+
"Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
|
59 |
+
"you can check their relations from the offical web site: "
|
60 |
+
"https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
|
61 |
+
RuntimeWarning,
|
62 |
+
)
|
63 |
+
|
64 |
+
def __call__(
|
65 |
+
self, input_content: List[Union[np.ndarray, np.ndarray]]
|
66 |
+
) -> np.ndarray:
|
67 |
+
input_dict = dict(zip(self.get_input_names(), input_content))
|
68 |
+
try:
|
69 |
+
return self.session.run(self.get_output_names(), input_dict)
|
70 |
+
except Exception as e:
|
71 |
+
raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e
|
72 |
+
|
73 |
+
def get_input_names(
|
74 |
+
self,
|
75 |
+
):
|
76 |
+
return [v.name for v in self.session.get_inputs()]
|
77 |
+
|
78 |
+
def get_output_names(
|
79 |
+
self,
|
80 |
+
):
|
81 |
+
return [v.name for v in self.session.get_outputs()]
|
82 |
+
|
83 |
+
def get_character_list(self, key: str = "character"):
|
84 |
+
return self.meta_dict[key].splitlines()
|
85 |
+
|
86 |
+
def have_key(self, key: str = "character") -> bool:
|
87 |
+
self.meta_dict = self.session.get_modelmeta().custom_metadata_map
|
88 |
+
if key in self.meta_dict.keys():
|
89 |
+
return True
|
90 |
+
return False
|
91 |
+
|
92 |
+
@staticmethod
|
93 |
+
def _verify_model(model_path):
|
94 |
+
model_path = Path(model_path)
|
95 |
+
if not model_path.exists():
|
96 |
+
raise FileNotFoundError(f"{model_path} does not exists.")
|
97 |
+
if not model_path.is_file():
|
98 |
+
raise FileExistsError(f"{model_path} is not a file.")
|
cttpunctuator/src/utils/text_post_process.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# @FileName :text_post_process.py
|
3 |
+
# @Time :2023/4/13 15:09
|
4 |
+
# @Author :lovemefan
|
5 |
+
# @Email :[email protected]
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Dict, Iterable, List, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import yaml
|
11 |
+
from typeguard import check_argument_types
|
12 |
+
|
13 |
+
|
14 |
+
class TokenIDConverterError(Exception):
|
15 |
+
pass
|
16 |
+
|
17 |
+
|
18 |
+
class TokenIDConverter:
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
token_list: Union[List, str],
|
22 |
+
):
|
23 |
+
check_argument_types()
|
24 |
+
|
25 |
+
self.token_list = token_list
|
26 |
+
self.unk_symbol = token_list[-1]
|
27 |
+
self.token2id = {v: i for i, v in enumerate(self.token_list)}
|
28 |
+
self.unk_id = self.token2id[self.unk_symbol]
|
29 |
+
|
30 |
+
def get_num_vocabulary_size(self) -> int:
|
31 |
+
return len(self.token_list)
|
32 |
+
|
33 |
+
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
|
34 |
+
if isinstance(integers, np.ndarray) and integers.ndim != 1:
|
35 |
+
raise TokenIDConverterError(
|
36 |
+
f"Must be 1 dim ndarray, but got {integers.ndim}"
|
37 |
+
)
|
38 |
+
return [self.token_list[i] for i in integers]
|
39 |
+
|
40 |
+
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
41 |
+
|
42 |
+
return [self.token2id.get(i, self.unk_id) for i in tokens]
|
43 |
+
|
44 |
+
|
45 |
+
def split_to_mini_sentence(words: list, word_limit: int = 20):
|
46 |
+
assert word_limit > 1
|
47 |
+
if len(words) <= word_limit:
|
48 |
+
return [words]
|
49 |
+
sentences = []
|
50 |
+
length = len(words)
|
51 |
+
sentence_len = length // word_limit
|
52 |
+
for i in range(sentence_len):
|
53 |
+
sentences.append(words[i * word_limit : (i + 1) * word_limit])
|
54 |
+
if length % word_limit > 0:
|
55 |
+
sentences.append(words[sentence_len * word_limit :])
|
56 |
+
return sentences
|
57 |
+
|
58 |
+
|
59 |
+
def code_mix_split_words(text: str):
|
60 |
+
words = []
|
61 |
+
segs = text.split()
|
62 |
+
for seg in segs:
|
63 |
+
# There is no space in seg.
|
64 |
+
current_word = ""
|
65 |
+
for c in seg:
|
66 |
+
if len(c.encode()) == 1:
|
67 |
+
# This is an ASCII char.
|
68 |
+
current_word += c
|
69 |
+
else:
|
70 |
+
# This is a Chinese char.
|
71 |
+
if len(current_word) > 0:
|
72 |
+
words.append(current_word)
|
73 |
+
current_word = ""
|
74 |
+
words.append(c)
|
75 |
+
if len(current_word) > 0:
|
76 |
+
words.append(current_word)
|
77 |
+
return words
|
78 |
+
|
79 |
+
|
80 |
+
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
|
81 |
+
if not Path(yaml_path).exists():
|
82 |
+
raise FileExistsError(f"The {yaml_path} does not exist.")
|
83 |
+
|
84 |
+
with open(str(yaml_path), "rb") as f:
|
85 |
+
data = yaml.load(f, Loader=yaml.Loader)
|
86 |
+
return data
|
setup.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# @FileName :setup.py
|
3 |
+
# @Time :2023/4/4 11:22
|
4 |
+
# @Author :lovemefan
|
5 |
+
# @Email :[email protected]
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
from setuptools import find_namespace_packages, setup
|
10 |
+
|
11 |
+
dirname = Path(os.path.dirname(__file__))
|
12 |
+
version_file = dirname / "version.txt"
|
13 |
+
with open(version_file, "r") as f:
|
14 |
+
version = f.read().strip()
|
15 |
+
|
16 |
+
requirements = {
|
17 |
+
"install": [
|
18 |
+
"setuptools<=65.0",
|
19 |
+
"PyYAML",
|
20 |
+
"typeguard==2.13.3",
|
21 |
+
"onnxruntime==1.14.1",
|
22 |
+
],
|
23 |
+
"setup": [
|
24 |
+
"numpy==1.24.2",
|
25 |
+
],
|
26 |
+
"all": [],
|
27 |
+
}
|
28 |
+
requirements["all"].extend(requirements["install"])
|
29 |
+
|
30 |
+
install_requires = requirements["install"]
|
31 |
+
setup_requires = requirements["setup"]
|
32 |
+
|
33 |
+
|
34 |
+
setup(
|
35 |
+
name="cttpunctuator",
|
36 |
+
version=version,
|
37 |
+
url="https://github.com/lovemefan/CT-Transformer-punctuation",
|
38 |
+
author="Lovemefan, Yunnan Key Laboratory of Artificial Intelligence, "
|
39 |
+
"Kunming University of Science and Technology, Kunming, Yunnan ",
|
40 |
+
author_email="[email protected]",
|
41 |
+
description="ctt-punctuator: A enterprise-grade punctuator after chinese asr based "
|
42 |
+
"on ct-transformer from funasr opensource",
|
43 |
+
long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),
|
44 |
+
long_description_content_type="text/markdown",
|
45 |
+
license="The MIT License",
|
46 |
+
packages=find_namespace_packages(),
|
47 |
+
include_package_data=True,
|
48 |
+
install_requires=install_requires,
|
49 |
+
python_requires=">=3.7.0",
|
50 |
+
classifiers=[
|
51 |
+
"Programming Language :: Python",
|
52 |
+
"Programming Language :: Python :: 3.7",
|
53 |
+
"Programming Language :: Python :: 3.8",
|
54 |
+
"Programming Language :: Python :: 3.9",
|
55 |
+
"Programming Language :: Python :: 3.10",
|
56 |
+
"Development Status :: 5 - Production/Stable",
|
57 |
+
"Intended Audience :: Science/Research",
|
58 |
+
"Operating System :: POSIX :: Linux",
|
59 |
+
"License :: OSI Approved :: Apache Software License",
|
60 |
+
"Topic :: Multimedia :: Sound/Audio :: Speech",
|
61 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
62 |
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
63 |
+
],
|
64 |
+
)
|
test/test.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# @FileName :test.py.py
|
3 |
+
# @Time :2023/4/19 13:39
|
4 |
+
# @Author :lovemefan
|
5 |
+
# @Email :[email protected]
|
6 |
+
|
7 |
+
import logging
|
8 |
+
|
9 |
+
from cttPunctuator import CttPunctuator
|
10 |
+
|
11 |
+
logging.basicConfig(
|
12 |
+
level=logging.INFO,
|
13 |
+
format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
|
14 |
+
)
|
15 |
+
# offline mode
|
16 |
+
punc = CttPunctuator()
|
17 |
+
text = "据报道纽约时报使用ChatGPT创建了一个情人节消息生成器用户只需输入几个提示就可以得到一封自动生成的情书"
|
18 |
+
logging.info(punc.punctuate(text)[0])
|
19 |
+
|
20 |
+
# online mode
|
21 |
+
punc = CttPunctuator(online=True)
|
22 |
+
text_in = (
|
23 |
+
"跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|"
|
24 |
+
"在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|"
|
25 |
+
"向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|"
|
26 |
+
"愿意进一步完善双方联合工作机制|凡是|中方能做的我们|"
|
27 |
+
"都会去做而且会做得更好我请印度朋友们放心中国在上游的|"
|
28 |
+
"任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
|
29 |
+
)
|
30 |
+
|
31 |
+
vads = text_in.split("|")
|
32 |
+
rec_result_all = ""
|
33 |
+
for vad in vads:
|
34 |
+
result = punc.punctuate(vad)
|
35 |
+
rec_result_all += result[0]
|
36 |
+
logging.info(f"Part: {rec_result_all}")
|
37 |
+
|
38 |
+
logging.info(f"Final: {rec_result_all}")
|
version.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
0.0.1
|