from transformers import AutoTokenizer, PreTrainedTokenizerFast
from http.server import HTTPServer, BaseHTTPRequestHandler
import json
import argparse
import uuid

# Global dict: maps each uid to its own Tokenizer_Http instance
tokenizers = {}


class Tokenizer_Http():

    def __init__(self):
        model_id = "qwen2.5_tokenizer"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.messages = [
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        ]
        self.token_ids = []
        self.token_ids_cache = []
    def encode(self, prompt, last_reply=None):
        # Fold the previous assistant reply back into the conversation before
        # appending the new user prompt.
        if last_reply is not None:
            self.messages.append({"role": "assistant", "content": last_reply})
            text = self.tokenizer.apply_chat_template(
                self.messages,
                tokenize=False,
                add_generation_prompt=True
            )
            # print("Generated text:\n============\n", text, "============\n")
            # Strip the trailing generation-prompt tokens added by the chat template
            # so the next diff starts right after the assistant reply.
            self.token_ids = self.tokenizer.encode(text)[:-3]
        self.messages.append({"role": "user", "content": prompt})
        text = self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True
        )
        print("Generated text:\n============\n", text, "============\n")
        token_ids = self.tokenizer.encode(text)
        # Keep only the newly added token ids
        diff = token_ids[len(self.token_ids):]
        self.token_ids = token_ids
        print(self.decode(diff))
        return token_ids, diff
    def decode(self, token_ids):
        # Accumulate ids until they decode cleanly; a multi-byte character split
        # across tokens shows up as U+FFFD and is held back until it is complete.
        self.token_ids_cache += token_ids
        text = self.tokenizer.decode(self.token_ids_cache)
        if "\ufffd" in text:
            print("text contains an invalid character")
            return ""
        else:
            self.token_ids_cache.clear()
            return text
    @property
    def bos_id(self):
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        return self.tokenizer.eos_token_id

    @property
    def bos_token(self):
        return self.tokenizer.bos_token

    @property
    def eos_token(self):
        return self.tokenizer.eos_token
    def reset(self, system_prompt="You are Qwen, created by Alibaba Cloud. You are a helpful assistant."):
        # Start a fresh conversation that contains only the system prompt.
        self.messages = [
            {"role": "system", "content": system_prompt},
        ]
        text = self.tokenizer.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True
        )
        token_ids = self.tokenizer.encode(text)[:-3]
        self.token_ids = token_ids
        print(self.decode(token_ids))
        return token_ids
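
# A minimal local usage sketch (assumes the "qwen2.5_tokenizer" files sit next to
# this script); kept as comments so importing the module stays side-effect free:
#
#   tok = Tokenizer_Http()
#   token_ids, diff = tok.encode("Hello")    # full prompt ids + newly added ids
#   print(tok.decode(diff))                  # decode only the new portion
#   tok.reset("You are a terse assistant.")  # start a fresh conversation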


class Request(BaseHTTPRequestHandler):
    timeout = 5
    server_version = 'Apache'

    def do_GET(self):
        print("GET path:", self.path)
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        # Endpoint: allocate a new uid (one tokenizer session per uid)
        if '/get_uid' in self.path:
            new_uid = str(uuid.uuid4())
            print("new uid:", new_uid)
            # Create a dedicated Tokenizer_Http instance for this uid
            tokenizers[new_uid] = Tokenizer_Http()
            msg = json.dumps({'uid': new_uid})
        elif '/bos_id' in self.path:
            # Read the uid query parameter (e.g. ?uid=xxx)
            uid = self.get_query_param("uid")
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                bos_id = instance.bos_id
                msg = json.dumps({'bos_id': bos_id if bos_id is not None else -1})
        elif '/eos_id' in self.path:
            uid = self.get_query_param("uid")
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                eos_id = instance.eos_id
                msg = json.dumps({'eos_id': eos_id if eos_id is not None else -1})
        else:
            msg = json.dumps({'error': 'Invalid GET endpoint'})
        print("response:", msg)
        self.wfile.write(msg.encode())
    def do_POST(self):
        content_length = int(self.headers.get('content-length', 0))
        data = self.rfile.read(content_length).decode()
        print("POST path:", self.path)
        print("received data:", data)
        req = json.loads(data)
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        if '/encode' in self.path:
            # The request body must contain uid and text; last_reply is optional
            uid = req.get('uid')
            prompt = req.get('text')
            last_reply = req.get('last_reply')
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                token_ids, diff = instance.encode(prompt, last_reply)
                msg = json.dumps({'token_ids': token_ids, 'diff': diff})
        elif '/decode' in self.path:
            uid = req.get('uid')
            token_ids = req.get('token_ids')
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                text = instance.decode(token_ids)
                msg = json.dumps({'text': text})
        elif '/reset' in self.path:
            uid = req.get("uid")
            system_prompt = req.get("system_prompt")
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                if system_prompt is not None:
                    print("system_prompt:", system_prompt)
                    token_ids = instance.reset(system_prompt)
                    msg = json.dumps({'token_ids': token_ids})
                else:
                    token_ids = instance.reset()
                    msg = json.dumps({'token_ids': token_ids})
        else:
            msg = json.dumps({'error': 'Invalid POST endpoint'})
        print("response:", msg)
        self.wfile.write(msg.encode())
    def get_query_param(self, key):
        """
        Helper: read a query parameter from the request URL,
        e.g. /bos_id?uid=xxx
        """
        from urllib.parse import urlparse, parse_qs
        query = urlparse(self.path).query
        params = parse_qs(query)
        values = params.get(key)
        return values[0] if values else None


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='0.0.0.0')
    parser.add_argument('--port', type=int, default=12345)
    args = parser.parse_args()
    host = (args.host, args.port)
    print('Server running at http://%s:%s' % host)
    server = HTTPServer(host, Request)
    server.serve_forever()
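
# Example requests against a running server (hypothetical host/port matching the
# defaults above; <uid> and the token ids are placeholders, endpoint and field
# names come from do_GET/do_POST):
#
#   curl "http://127.0.0.1:12345/get_uid"
#   curl "http://127.0.0.1:12345/bos_id?uid=<uid>"
#   curl -X POST "http://127.0.0.1:12345/encode" -d '{"uid": "<uid>", "text": "Hello"}'
#   curl -X POST "http://127.0.0.1:12345/decode" -d '{"uid": "<uid>", "token_ids": [1, 2, 3]}'
#   curl -X POST "http://127.0.0.1:12345/reset"  -d '{"uid": "<uid>", "system_prompt": "..."}'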