[Speculative Decoding] Refine ngram kernel signature and adapt ngram proposer#7774
[Speculative Decoding] Refine ngram kernel signature and adapt ngram proposer#7774NKNaN wants to merge 1 commit intoPaddlePaddle:developfrom
Conversation
|
Thanks for your contribution! |
PaddlePaddle-bot
left a comment
There was a problem hiding this comment.
🤖 Paddle-CI-Agent | pr_review |
2026-05-11 16:27:40
📋 Review 摘要
PR 概述:重构 ngram kernel 接口,移除冗余的 input_ids/input_ids_len 参数,统一由 token_ids_all 承载 prompt 与生成 token,并修复 ngram 偏移量 bug(cur_step_idx + 1 - ngram_size → cur_step_idx - ngram_size)。
变更范围:custom_ops/gpu_ops/speculate_decoding/、fastdeploy/spec_decode/ngram.py、fastdeploy/worker/gpu_model_runner.py、fastdeploy/config.py、tests/spec_decode/
影响面 Tag:[Speculative Decoding] [OP] [FDConfig]
📝 PR 规范检查
标题含 [Speculative Decoding](官方 Tag)✓;PR body 结构包含 Motivation / Modifications / Usage or Command / Accuracy Tests / Checklist 所有必填段落 ✓,无需修改建议。
问题
| 级别 | 文件 | 概述 |
|---|---|---|
| 🟡 建议 | fastdeploy/worker/gpu_model_runner.py:2100 |
CUDAGraph capture 新增 NGRAM,需确认其他硬件 Runner 是否同步 |
| ❓ 疑问 | tests/spec_decode/test_ngram_gpu_kernel.py:285 |
_make_mixed_test_data 的 step_idx 保留旧语义(gen_len-1),与 _make_ngram_test_data 新语义(gen_len)不一致,需确认是否有意为之 |
总体评价
本次 PR 将 ngram kernel 接口从需要单独传入 input_ids/input_ids_len 简化为完全依赖 token_ids_all,同时修复了 ngram 偏移量 bug,逻辑清晰正确、端到端验证充分(有详细的 VERIFY-WRITE/READ 及 E2E 匹配日志)。主要关注点是多硬件 Runner 同步及测试数据语义一致性,无阻塞性问题。
| elif self.speculative_decoding and self.spec_method in [ | ||
| SpecMethod.MTP, | ||
| SpecMethod.SUFFIX, | ||
| SpecMethod.NGRAM, |
There was a problem hiding this comment.
🟡 建议 其他硬件 Runner 的同步检查
gpu_model_runner.py 的 capture_model() 新增了 SpecMethod.NGRAM,按照 A6 多硬件同步原则,如果 dcu_model_runner.py、iluvatar_model_runner.py 等文件的 capture_model() 中也存在形如 self.spec_method in [SpecMethod.MTP, SpecMethod.SUFFIX] 的分支,则同样需要加入 SpecMethod.NGRAM。
建议确认其他硬件 Runner 是否支持投机解码,若支持则补充同步。
| pre_ids[b, :gen_len] = input_ids[b, src : src + gen_len] | ||
| # step_idx = last valid position (0-based index) | ||
| # step_idx = last valid position (0-based index), matches hybrid kernel semantics | ||
| step_idx[b] = gen_len - 1 |
There was a problem hiding this comment.
❓ 疑问 _make_mixed_test_data 中 step_idx 语义与 _make_ngram_test_data 不一致
本次 PR 将 _make_ngram_test_data 的 step_idx 从 gen_len - 1(0-based 最后有效位置)改为 gen_len(生成 token 数量),同时内核偏移公式也从 cur_step_idx + 1 - ngram_size 改为 cur_step_idx - ngram_size。
但此函数 _make_mixed_test_data 仍保持 step_idx[b] = gen_len - 1,注释说明是 matches hybrid kernel semantics。请确认:
_make_mixed_test_data是否对应一个未在本次修改的独立 kernel 路径(即旧语义 kernel)?- 若对应同一个
ngram_matchkernel,则step_idx值应同步改为gen_len,否则测试数据与内核语义不符,可能导致 CPU 参考实现与 GPU kernel 计算结果不一致。
CI报告基于以下代码生成(30分钟更新一次): 1 任务总览
2 任务状态汇总2.1 Required任务 : 7/10 通过
2.2 可选任务 — 21/26 通过
3 失败详情(仅 required)Approval — 流程审批(置信度: 高)Approval
根因详情: 关键日志: 修复建议:
修复建议摘要: 请 @freeliuzc 或 @Deleter-D 审批本PR 关联变更: PR 修改了 |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #7774 +/- ##
==========================================
Coverage ? 71.66%
==========================================
Files ? 396
Lines ? 55706
Branches ? 8712
==========================================
Hits ? 39921
Misses ? 13040
Partials ? 2745
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Motivation
投机解码 ngram 方法端到端结果验证
Modifications
测试脚本(AI studio A800单卡环境能够跑通):
修改ngram kernel接口:
由于 input_ids 和 pre_ids 目前全部并入 token_ids_all,将原本接口中的 input_ids 删除,prompt tokens 和 predict tokens 完全由 token_ids_all 负责记录。
确认修改后的 ngram match kernel 端到端执行正确:
token_ids_all 为 5 时是 dummy batch,seq_lens_decoder=0。除此之外可以看到 token_ids_all prompt 部分的读和写的内容一致。
kernel 中的 ngram 地址计算 bug 修复后日志打印结果显示能够匹配到,且存在经过 verify 后在一步 decode 中接受了多个token的情况,如:
[NGRAM-DEBUG] call=22 slt=[6, 6] step_idx=[[25], [24], [13], [13], [13], [13], [13], [13]] prompt_lens=[[168], [54], [2048], [2048], [1024], [1024], [1024], [1024]] draft_token_num=[5, 5, 5, 5, 5, 5, 5, 5] seq_dec=[192, 77, 0, 0, 0, 0, 0, 0]
[NGRAM-DEBUG] call=23 slt=[6, 6] step_idx=[[31], [25], [13], [13], [13], [13], [13], [13]] prompt_lens=[[168], [54], [2048], [2048], [1024], [1024], [1024], [1024]] draft_token_num=[5, 5, 5, 5, 5, 5, 5, 5] seq_dec=[198, 78, 0, 0, 0, 0, 0, 0]
相邻两次proposer.run()时,step_idx[0] 从 25 增加到 31,seq_len_decoder[0] 从 192 增加到 198
CUDAGraph 适配
Overlap Schedule 适配
Usage or Command
Accuracy Tests
Checklist
[FDConfig],[APIServer],[Engine],[Scheduler],[PD Disaggregation],[Executor],[Graph Optimization],[Speculative Decoding],[RL],[Models],[Quantization],[Loader],[OP],[KVCache],[DataProcessor],[BugFix],[Docs],[CI],[Optimization],[Feature],[Benchmark],[Others],[XPU],[HPU],[GCU],[DCU],[Iluvatar],[Metax]]pre-commitbefore commit.releasebranch, make sure the PR has been submitted to thedevelopbranch, then cherry-pick it to thereleasebranch with the[Cherry-Pick]PR tag.