Skip to content

Commit 978b813

Browse files
K11OntheBoat, ZhangX-21, "liuruian"
authored
[Optimization] Accelerate the speed of tokenizer. (#7544)
* Change default workers and max-concurrency when launching the api-server * Change convert_tokens_to_ids to encode to get token ids --------- Co-authored-by: zhangxiao35 <zhangxiao35@baidu.com> Co-authored-by: "liuruian" <liuruian@baidu.com>
1 parent 4c8f7df commit 978b813

2 files changed

Lines changed: 99 additions & 2 deletions

File tree

fastdeploy/input/base_processor.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,19 @@ def messages2ids(self, request, **kwargs):
163163
)
164164
request["prompt_tokens"] = spliced_message
165165
req_id = request.get("request_id", None) if isinstance(request, dict) else None
166-
tokens = self.tokenizer.tokenize(spliced_message)
167-
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
166+
if self.tokenizer_type == "ernie4_5":
167+
# NOTE: ernie4_5 tokenizer will hang when meet long input when use .encode()
168+
token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(spliced_message))
169+
else:
170+
token_ids = self.tokenizer.encode(spliced_message, add_special_tokens=False)
171+
if hasattr(token_ids, "input_ids") or (isinstance(token_ids, dict) and "input_ids" in token_ids):
172+
token_ids = token_ids["input_ids"]
173+
if hasattr(token_ids, "ndim") and token_ids.ndim > 1:
174+
token_ids = token_ids[0]
175+
if hasattr(token_ids, "tolist"):
176+
token_ids = token_ids.tolist()
177+
if not isinstance(token_ids, list):
178+
token_ids = list(token_ids)
168179
log_request(
169180
level=1,
170181
message="req_id:{req_id}, token_ids: {token_ids}",

tests/input/test_text_processor.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def tokenize(self, text):
7777
def convert_tokens_to_ids(self, tokens):
7878
return [self._value(token) for token in tokens]
7979

80+
def encode(self, text, add_special_tokens=True, **kwargs):
    """Stub of tokenizer.encode: tokenize *text* and map tokens to ids.

    `add_special_tokens` and extra kwargs are accepted for signature
    compatibility but intentionally ignored by this test double.
    """
    tokens = self.tokenize(text)
    return self.convert_tokens_to_ids(tokens)
82+
8083
def decode(self, token_ids, **kwargs):
8184
return " ".join(str(t) for t in token_ids)
8285

@@ -387,6 +390,89 @@ def test_process_request_dict_messages_template(self):
387390
self.assertTrue(processed["enable_thinking"])
388391
self.assertEqual(processed["prompt_tokens"], "system prompt hello")
389392

393+
def test_process_request_dict_messages_template_batch_encoding(self):
    """messages2ids should extract input_ids when encode() returns a BatchEncoding-like object."""

    class BatchEncodingLike:
        """Mimics a HuggingFace BatchEncoding (UserDict subclass, so hasattr input_ids is True)."""

        def __init__(self, ids):
            self.input_ids = ids

        def __getitem__(self, key):
            return getattr(self, key)

    class BatchEncodingTokenizer(DummyTokenizer):
        def encode(self, text, add_special_tokens=True, **kwargs):
            return BatchEncodingLike([len(text)])

    proc = self.text_processor_module.DataProcessor("stub-model")
    proc.tokenizer = BatchEncodingTokenizer()

    result = proc.process_request_dict(
        {
            "request_id": "chat",
            "messages": [{"role": "user", "content": "hello"}],
            "chat_template_kwargs": {"system": "system prompt"},
        },
        max_model_len=100,
    )
    ids = result["prompt_token_ids"]
    self.assertIsInstance(ids, list)
    self.assertTrue(all(isinstance(item, int) for item in ids))
423+
def test_process_request_dict_messages_template_tensor(self):
    """messages2ids should convert a tensor-like encode() result (has tolist()) into a plain list."""

    class TensorLike:
        """Mimics a numpy/paddle/torch tensor: exposes a tolist() method."""

        def __init__(self, ids):
            self._ids = ids

        def tolist(self):
            return self._ids

    class TensorTokenizer(DummyTokenizer):
        def encode(self, text, add_special_tokens=True, **kwargs):
            return TensorLike([len(text)])

    proc = self.text_processor_module.DataProcessor("stub-model")
    proc.tokenizer = TensorTokenizer()

    result = proc.process_request_dict(
        {
            "request_id": "chat",
            "messages": [{"role": "user", "content": "hello"}],
            "chat_template_kwargs": {"system": "system prompt"},
        },
        max_model_len=100,
    )
    ids = result["prompt_token_ids"]
    self.assertIsInstance(ids, list)
    self.assertTrue(all(isinstance(item, int) for item in ids))
452+
453+
def test_process_request_dict_messages_template_plain_dict(self):
    """messages2ids should extract input_ids from a plain-dict encode() result, not its key list."""

    class PlainDictTokenizer(DummyTokenizer):
        def encode(self, text, add_special_tokens=True, **kwargs):
            return {"input_ids": [len(text)], "attention_mask": [1]}

    proc = self.text_processor_module.DataProcessor("stub-model")
    proc.tokenizer = PlainDictTokenizer()

    result = proc.process_request_dict(
        {
            "request_id": "chat",
            "messages": [{"role": "user", "content": "hello"}],
            "chat_template_kwargs": {"system": "system prompt"},
        },
        max_model_len=100,
    )
    ids = result["prompt_token_ids"]
    self.assertIsInstance(ids, list)
    self.assertTrue(all(isinstance(item, int) for item in ids))
    # Guard against the regression where the key list ['input_ids', 'attention_mask'] leaks through.
    self.assertNotIn("input_ids", ids)
475+
390476
def test_process_request_dict_handles_sequences(self):
391477
request = {
392478
"prompt": [1, 2, 3, 4, 5, 6],

0 commit comments

Comments (0)