File size: 7,210 Bytes
f1994e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from http.server import HTTPServer, BaseHTTPRequestHandler
import json
import argparse
import uuid

# Global registry: maps each session uid (a str(uuid.uuid4()) handed out by
# GET /get_uid) to that session's Tokenizer_Http instance. Nothing shown in this
# module ever removes an entry, so it grows by one per /get_uid request.
tokenizers = {}

class Tokenizer_Http():
    """Per-session chat tokenizer that tracks a conversation for incremental prompting.

    ``messages`` holds the chat history that is rendered through the tokenizer's
    chat template. ``token_ids`` holds the ids of the most recently rendered
    prompt, so ``encode`` can report only the ids appended since then (``diff``).
    ``token_ids_cache`` buffers ids passed to ``decode`` until they decode to
    text that does not end in an incomplete character.
    """

    DEFAULT_MODEL_ID = "qwen2.5_tokenizer"
    DEFAULT_SYSTEM_PROMPT = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."

    # Number of trailing ids dropped after rendering with add_generation_prompt=True,
    # so the stored prefix ends right after the last completed turn. Presumably
    # these are the ids of "<|im_start|>assistant\n" in the Qwen2.5 template —
    # TODO confirm if a different tokenizer is loaded.
    GENERATION_PROMPT_LEN = 3

    # Loaded tokenizers, keyed by model id and shared by every instance, so a new
    # session (one per /get_uid) does not reload the tokenizer from disk. Sharing
    # relies on nothing mutating tokenizer state (nothing in this class does).
    _tokenizer_cache = {}

    def __init__(self, model_id=DEFAULT_MODEL_ID, system_prompt=DEFAULT_SYSTEM_PROMPT):
        """Create a session whose history starts with a single system message.

        Args:
            model_id: name or path passed to ``AutoTokenizer.from_pretrained``.
            system_prompt: content of the initial system message.
        """
        tokenizer = self._tokenizer_cache.get(model_id)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            self._tokenizer_cache[model_id] = tokenizer
        self.tokenizer = tokenizer
        self.messages = [{"role": "system", "content": system_prompt}]
        # Ids of the last rendered prompt; empty until encode()/reset() runs.
        self.token_ids = []
        # Ids passed to decode() that have not yet been returned as text.
        self.token_ids_cache = []

    def _render(self):
        """Render ``self.messages`` with the chat template plus the generation prompt.

        Returns:
            (text, ids): the rendered prompt string and its token ids.
        """
        text = self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return text, self.tokenizer.encode(text)

    def _without_generation_prompt(self, ids):
        """Return ``ids`` without the trailing GENERATION_PROMPT_LEN generation-prompt ids."""
        return ids[:max(0, len(ids) - self.GENERATION_PROMPT_LEN)]

    def encode(self, prompt, last_reply=None):
        """Append a user turn and tokenize the whole conversation.

        Args:
            prompt: content of the new user message.
            last_reply: if given, the assistant's previous reply. It is appended
                to the history first and becomes part of the stored prefix.

        Returns:
            (token_ids, diff): ids of the full rendered prompt, and the suffix of
            those ids that lies past the previously stored ``token_ids``.
        """
        # A new turn begins: a partial character left over from the previous
        # generation can no longer be completed, so drop it.
        self.token_ids_cache.clear()
        if last_reply is not None:
            self.messages.append({"role": "assistant", "content": last_reply})
            _, history_ids = self._render()
            self.token_ids = self._without_generation_prompt(history_ids)
        self.messages.append({"role": "user", "content": prompt})

        text, token_ids = self._render()
        print("生成的文本:\n============\n", text, "============\n")
        # Only the ids past the stored prefix are new.
        diff = token_ids[len(self.token_ids):]
        self.token_ids = token_ids
        # Log through the raw tokenizer: self.decode() would mutate token_ids_cache.
        print(self.tokenizer.decode(diff))
        return token_ids, diff

    def decode(self, token_ids):
        """Decode ``token_ids`` together with any ids buffered by earlier calls.

        One character can be split across several ids. While the buffered ids
        end in such a partial character (which this code assumes the tokenizer
        renders as a trailing U+FFFD), they stay buffered and "" is returned.

        Args:
            token_ids: list of ids to append to the buffer.

        Returns:
            The text for all buffered ids (and the buffer is cleared), or "" while
            the text still ends in U+FFFD.
        """
        self.token_ids_cache += token_ids
        text = self.tokenizer.decode(self.token_ids_cache)
        # Only a *trailing* U+FFFD can be completed by more ids; holding back on
        # one anywhere else would stall decoding forever.
        if text.endswith("\ufffd"):
            print("text 中包含非法字符")
            return ""
        self.token_ids_cache.clear()
        return text

    @property
    def bos_id(self):
        """The tokenizer's BOS token id (may be None)."""
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        """The tokenizer's EOS token id (may be None)."""
        return self.tokenizer.eos_token_id

    @property
    def bos_token(self):
        """The tokenizer's BOS token string (may be None)."""
        return self.tokenizer.bos_token

    @property
    def eos_token(self):
        """The tokenizer's EOS token string (may be None)."""
        return self.tokenizer.eos_token

    def reset(self, system_prompt=DEFAULT_SYSTEM_PROMPT):
        """Start a new conversation that contains only a system message.

        Args:
            system_prompt: content of the system message.

        Returns:
            Token ids of the rendered system turn, without the generation prompt.
            The same list is stored as ``self.token_ids``.
        """
        self.messages = [{"role": "system", "content": system_prompt}]
        # Buffered partial ids belong to the previous conversation.
        self.token_ids_cache.clear()
        _, ids = self._render()
        self.token_ids = self._without_generation_prompt(ids)
        print(self.tokenizer.decode(self.token_ids))
        return self.token_ids
        

class Request(BaseHTTPRequestHandler):
    """JSON-over-HTTP front end for per-uid Tokenizer_Http sessions.

    GET endpoints:  /get_uid, /bos_id?uid=..., /eos_id?uid=...
    POST endpoints: /encode, /decode, /reset (JSON object body with a "uid" key)

    Endpoints are matched by substring of the request path. Every reply is sent
    with status 200 and a JSON body; failures are reported in an "error" key.
    """
    timeout = 5
    server_version = 'Apache'

    def do_GET(self):
        """Serve /get_uid, /bos_id and /eos_id."""
        print("GET 请求路径:", self.path)
        if '/get_uid' in self.path:
            new_uid = str(uuid.uuid4())
            print("新 uid:", new_uid)
            # Each uid gets its own Tokenizer_Http session.
            tokenizers[new_uid] = Tokenizer_Http()
            msg = json.dumps({'uid': new_uid})
        elif '/bos_id' in self.path:
            msg = self._special_id_msg('bos_id')
        elif '/eos_id' in self.path:
            msg = self._special_id_msg('eos_id')
        else:
            msg = json.dumps({'error': 'Invalid GET endpoint'})
        self._send_json(msg)

    def do_POST(self):
        """Serve /encode, /decode and /reset; the body must be a JSON object."""
        req = self._read_json_body()
        if req is None:
            msg = json.dumps({'error': 'Invalid JSON'})
        elif '/encode' in self.path:
            # Body: uid, text, and optional last_reply.
            instance = self._lookup(req.get('uid'))
            prompt = req.get('text')
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            elif not isinstance(prompt, str):
                # Otherwise a missing/non-string text would be stored in the
                # session's chat history.
                msg = json.dumps({'error': 'Invalid text'})
            else:
                token_ids, diff = instance.encode(prompt, req.get('last_reply'))
                msg = json.dumps({'token_ids': token_ids, 'diff': diff})
        elif '/decode' in self.path:
            instance = self._lookup(req.get('uid'))
            token_ids = req.get('token_ids')
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            elif not isinstance(token_ids, list):
                msg = json.dumps({'error': 'Invalid token_ids'})
            else:
                msg = json.dumps({'text': instance.decode(token_ids)})
        elif '/reset' in self.path:
            instance = self._lookup(req.get("uid"))
            system_prompt = req.get("system_prompt")
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            elif system_prompt is not None:
                print("system_prompt:", system_prompt)
                msg = json.dumps({'token_ids': instance.reset(system_prompt)})
            else:
                msg = json.dumps({'token_ids': instance.reset()})
        else:
            msg = json.dumps({'error': 'Invalid POST endpoint'})
        self._send_json(msg)

    def _read_json_body(self):
        """Read the request body; return it parsed if it is a JSON object, else None."""
        try:
            content_length = max(0, int(self.headers.get('content-length', 0)))
        except (TypeError, ValueError):
            content_length = 0
        raw = self.rfile.read(content_length)
        print("POST 请求路径:", self.path)
        try:
            data = raw.decode()
            print("接收到的数据:", data)
            req = json.loads(data)
        except ValueError:  # covers UnicodeDecodeError and json.JSONDecodeError
            return None
        return req if isinstance(req, dict) else None

    @staticmethod
    def _lookup(uid):
        """Return the session registered for ``uid``, or None if uid is not a known string."""
        # A non-string uid (e.g. a JSON list) could be unhashable and crash dict.get.
        if not isinstance(uid, str):
            return None
        return tokenizers.get(uid)

    def _special_id_msg(self, key):
        """Build the JSON reply for /bos_id or /eos_id.

        Args:
            key: 'bos_id' or 'eos_id'; names both the Tokenizer_Http property and
                the reply key. A None id is reported as -1.
        """
        instance = self._lookup(self.get_query_param("uid"))
        if instance is None:
            return json.dumps({'error': 'Invalid uid'})
        value = getattr(instance, key)
        return json.dumps({key: value if value is not None else -1})

    def _send_json(self, msg):
        """Send ``msg`` (a JSON string) as a 200 application/json response."""
        print("响应消息:", msg)
        body = msg.encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def get_query_param(self, key):
        """Return the first value of query parameter ``key`` in the request URL, or None.

        Example: for /bos_id?uid=xxx, get_query_param("uid") returns "xxx".
        """
        from urllib.parse import urlparse, parse_qs
        query = urlparse(self.path).query
        params = parse_qs(query)
        values = params.get(key)
        return values[0] if values else None

if __name__ == "__main__":
    # Command-line entry point: bind the tokenizer HTTP server and serve until killed.
    cli = argparse.ArgumentParser()
    cli.add_argument('--host', type=str, default='0.0.0.0')
    cli.add_argument('--port', type=int, default=12345)
    opts = cli.parse_args()

    address = (opts.host, opts.port)
    print('Server running at http://%s:%s' % address)
    HTTPServer(address, Request).serve_forever()