@@ -77,6 +77,9 @@ def tokenize(self, text):
def convert_tokens_to_ids(self, tokens):
    """Map each token in *tokens* to its id via the tokenizer's value lookup.

    Returns a list of ids, one per input token, in the same order.
    """
    return list(map(self._value, tokens))
7979
def encode(self, text, add_special_tokens=True, **kwargs):
    """Tokenize *text* and return the corresponding list of token ids.

    ``add_special_tokens`` and extra kwargs are accepted for interface
    compatibility but are intentionally ignored by this stub.
    """
    tokens = self.tokenize(text)
    return self.convert_tokens_to_ids(tokens)
def decode(self, token_ids, **kwargs):
    """Render *token_ids* as a single space-separated string."""
    pieces = [str(t) for t in token_ids]
    return " ".join(pieces)
@@ -387,6 +390,89 @@ def test_process_request_dict_messages_template(self):
387390 self .assertTrue (processed ["enable_thinking" ])
388391 self .assertEqual (processed ["prompt_tokens" ], "system prompt hello" )
389392
def test_process_request_dict_messages_template_batch_encoding(self):
    """messages2ids extracts input_ids when encode() yields a BatchEncoding-like object."""

    class _FakeBatchEncoding:
        """Mimics HuggingFace BatchEncoding (a UserDict subclass, so
        ``hasattr(obj, "input_ids")`` is True and item access works)."""

        def __init__(self, ids):
            self.input_ids = ids

        def __getitem__(self, key):
            return getattr(self, key)

    class _BatchEncodingTokenizer(DummyTokenizer):
        def encode(self, text, add_special_tokens=True, **kwargs):
            return _FakeBatchEncoding([len(text)])

    processor = self.text_processor_module.DataProcessor("stub-model")
    processor.tokenizer = _BatchEncodingTokenizer()

    request = {
        "request_id": "chat",
        "messages": [{"role": "user", "content": "hello"}],
        "chat_template_kwargs": {"system": "system prompt"},
    }
    result = processor.process_request_dict(request, max_model_len=100)
    ids = result["prompt_token_ids"]
    self.assertIsInstance(ids, list)
    for item in ids:
        self.assertIsInstance(item, int)
def test_process_request_dict_messages_template_tensor(self):
    """messages2ids converts to a plain list when encode() yields a tensor-like object with tolist()."""

    class _FakeTensor:
        """Mimics a numpy/paddle/torch tensor exposing a tolist() method."""

        def __init__(self, ids):
            self._ids = ids

        def tolist(self):
            return self._ids

    class _TensorTokenizer(DummyTokenizer):
        def encode(self, text, add_special_tokens=True, **kwargs):
            return _FakeTensor([len(text)])

    processor = self.text_processor_module.DataProcessor("stub-model")
    processor.tokenizer = _TensorTokenizer()

    request = {
        "request_id": "chat",
        "messages": [{"role": "user", "content": "hello"}],
        "chat_template_kwargs": {"system": "system prompt"},
    }
    result = processor.process_request_dict(request, max_model_len=100)
    ids = result["prompt_token_ids"]
    self.assertIsInstance(ids, list)
    for item in ids:
        self.assertIsInstance(item, int)
def test_process_request_dict_messages_template_plain_dict(self):
    """messages2ids extracts input_ids from a plain dict instead of returning its keys."""

    class _PlainDictTokenizer(DummyTokenizer):
        def encode(self, text, add_special_tokens=True, **kwargs):
            return {"input_ids": [len(text)], "attention_mask": [1]}

    processor = self.text_processor_module.DataProcessor("stub-model")
    processor.tokenizer = _PlainDictTokenizer()

    request = {
        "request_id": "chat",
        "messages": [{"role": "user", "content": "hello"}],
        "chat_template_kwargs": {"system": "system prompt"},
    }
    result = processor.process_request_dict(request, max_model_len=100)
    ids = result["prompt_token_ids"]
    self.assertIsInstance(ids, list)
    for item in ids:
        self.assertIsInstance(item, int)
    # Guard against the bug where iterating the dict produced its key list
    # (['input_ids', 'attention_mask']) instead of the id values.
    self.assertNotIn("input_ids", ids)
390476 def test_process_request_dict_handles_sequences (self ):
391477 request = {
392478 "prompt" : [1 , 2 , 3 , 4 , 5 , 6 ],
0 commit comments