
[BugFix] Fix save_output_specualate parameter bugs in suffix decoding #7566

Open
Deleter-D wants to merge 2 commits into PaddlePaddle:develop from Deleter-D:dev_fix_save_output_bugs

Conversation

@Deleter-D (Collaborator) commented Apr 22, 2026

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

proposer_share_inputs should be an optional parameter, since it is only needed in MTP prefill mode.
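The intended change can be sketched as follows. This is a minimal illustration, not the real FastDeploy implementation: the actual save_output_specualate signature has more parameters, and everything here beyond the names mentioned in this PR (function body, return values, placeholder dicts) is an assumption for demonstration.

```python
# Minimal sketch of the fix: proposer_share_inputs becomes optional and is
# only required on the MTP prefill path. The body is an illustrative stand-in.
def save_output_specualate(share_inputs, proposer_share_inputs=None, is_mtp_prefill=False):
    if is_mtp_prefill:
        # Only the MTP prefill path reads the proposer's buffers, so the
        # requirement is enforced here rather than in the signature.
        assert proposer_share_inputs is not None, (
            "proposer_share_inputs is required when is_mtp_prefill=True"
        )
        return ("mtp", proposer_share_inputs)
    # The non-MTP path no longer touches proposer_share_inputs at all.
    return ("default", share_inputs)

# Caller side (as in gpu_model_runner): pass the proposer inputs only for MTP.
spec_method = "mtp"
proposer_inputs = {"draft_tokens": [0, 1, 2]}
result = save_output_specualate(
    share_inputs={"tokens": [3, 4]},
    proposer_share_inputs=proposer_inputs if spec_method == "mtp" else None,
    is_mtp_prefill=(spec_method == "mtp"),
)
```

The design point is that the non-MTP call site can now simply omit the argument instead of constructing a dummy value to satisfy a mandatory parameter.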

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. If no unit tests are added, please state the reason in this PR.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.


paddle-bot Bot commented Apr 22, 2026

Thanks for your contribution!


@PaddlePaddle-bot left a comment


🤖 AI Code Review | 2026-04-22 20:06:05

📋 Review Summary

PR overview: fixes two save_output-related bugs in the speculative decoding path: the draft_tokens write range is too narrow, and proposer_share_inputs is passed illegally in non-MTP scenarios.
Changed files: spec_decode/base.py, model_executor/pre_and_post_process.py, worker/gpu_model_runner.py
Impact tags: [Speculative Decoding] [BugFix]

📝 PR Convention Check

Neither the Motivation nor the Modifications section of the PR is filled in, so the trigger conditions of the bugs and the fix approach cannot be inferred. Suggested description (can be copied directly):

## Motivation
In speculative decoding (MTP mode), there are two bugs:
1. In `prepare_dummy_speculative_drafts`, `draft_tokens` is written over `[:max_fake_drafts]`,
   while `seq_lens_this_time` is set to `max_fake_drafts + 1`. The mismatch leaves the last
   token column uninitialized during CUDAGraph capture, which may cause inference anomalies.
2. `save_output_specualate` still requires `proposer_share_inputs` in non-MTP prefill scenarios,
   so calls on the non-MTP path fail.

## Modifications
- `spec_decode/base.py`: change the `draft_tokens` write slice from `:max_fake_drafts` to `:max_fake_drafts + 1`, keeping it consistent with `seq_lens_this_time`.
- `model_executor/pre_and_post_process.py`: make `proposer_share_inputs` an optional parameter (default `None`),
  and add an `assert` guard at the entry of the `is_mtp_prefill=True` branch.
- `worker/gpu_model_runner.py`: on the caller side, pass `proposer.model_inputs` or `None` depending on whether `spec_method == MTP`.
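The off-by-one described in the first bug can be demonstrated with a small, self-contained sketch. The buffer shapes and the sentinel value 5 mirror the snippet in this PR, but the plain-list tensors and the stale marker are assumptions for illustration, not the real FastDeploy buffers.

```python
# Illustration of the off-by-one: seq_lens_this_time counts
# max_fake_drafts + 1 tokens, so writing only :max_fake_drafts columns
# leaves the last column holding stale data from a previous step.
batch_size, max_fake_drafts = 2, 3
stale = -1  # stands in for leftover values in the reused buffer
draft_tokens = [[stale] * (max_fake_drafts + 1) for _ in range(batch_size)]
seq_lens_this_time = max_fake_drafts + 1

# Buggy write: only the first max_fake_drafts columns are initialized.
for row in draft_tokens:
    for j in range(max_fake_drafts):
        row[j] = 5
assert draft_tokens[0][:seq_lens_this_time] == [5, 5, 5, stale]  # last column uninitialized

# Fixed write: cover all seq_lens_this_time columns, matching the sequence length.
for row in draft_tokens:
    for j in range(max_fake_drafts + 1):
        row[j] = 5
assert draft_tokens[0][:seq_lens_this_time] == [5, 5, 5, 5]
```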

Issues

Level | File | Summary
🟡 Suggestion | spec_decode/base.py:121 | The write-range fix is correct; consider adding a comment explaining the semantics of the +1

Overall Assessment

The three changes are logically coherent and correct in direction. Changing proposer_share_inputs from a mandatory to an as-needed parameter improves interface robustness, and the draft_tokens write-range fix removes a latent inconsistency risk under CUDAGraph capture. Consider completing the PR description for future traceability.

stop = share_inputs["stop_flags"][0].item()
if not stop:
-    share_inputs["draft_tokens"][:batch_size, :max_fake_drafts] = 5
+    share_inputs["draft_tokens"][:batch_size, : max_fake_drafts + 1] = 5

🟡 Suggestion: fixes the mismatch between the draft_tokens write range and seq_lens_this_time

Before the fix, only the [:max_fake_drafts] columns were written, while seq_lens_this_time was set to max_fake_drafts + 1, leaving the last token column uninitialized (still holding stale values) and potentially producing inconsistent behavior during CUDAGraph capture. After the fix, the write range matches the sequence length; the logic is correct.

Consider adding a comment explaining the semantics of the +1 (one target-token slot plus max_fake_drafts draft-token slots):

# draft_tokens layout: [target_token_placeholder, draft_token_0, ..., draft_token_{N-1}]
share_inputs["draft_tokens"][:batch_size, : max_fake_drafts + 1] = 5

@Deleter-D changed the title from "[BugFix] Dev fix save output bugs" to "[BugFix] Fix save_output_specualate parameter bugs in suffix decoding" on Apr 22, 2026
@codecov-commenter

Codecov Report

❌ Patch coverage is 0% with 2 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@b19754c). Learn more about missing BASE report.

Files with missing lines | Patch % | Lines
fastdeploy/model_executor/pre_and_post_process.py | 0.00% | 1 Missing ⚠️
fastdeploy/spec_decode/base.py | 0.00% | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7566   +/-   ##
==========================================
  Coverage           ?   72.46%           
==========================================
  Files              ?      419           
  Lines              ?    57649           
  Branches           ?     9037           
==========================================
  Hits               ?    41773           
  Misses             ?    13058           
  Partials           ?     2818           
Flag | Coverage Δ
GPU | 72.46% <0.00%> (?)

Flags with carried forward coverage won't be shown.


3 participants