Skip to content

[XPU] Support Guided Decoding for xpu, Also fix import errors on XPU when torch is installed.#7531

Open
Jiajun-Ji wants to merge 1 commit into PaddlePaddle:develop from
Jiajun-Ji:guided-xpu
Open

[XPU] Support Guided Decoding for xpu, Also fix import errors on XPU when torch is installed.#7531
Jiajun-Ji wants to merge 1 commit into PaddlePaddle:develop from
Jiajun-Ji:guided-xpu

Conversation

@Jiajun-Ji
Copy link
Copy Markdown
Contributor

@Jiajun-Ji Jiajun-Ji commented Apr 21, 2026

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

💡 如若此PR是Cherry Pick,PR标题需遵循格式,在最开始加上[Cherry-Pick]标签,以及最后面加上原PR ID,例如[Cherry-Pick][CI] Add check trigger and logic(#5191)

在 XPU 下支持 Structured Outputs 功能。
需额外在 XPU 环境下安装以下依赖:
torch 2.6.0+cpu
xgrammar 0.1.19
测试结果如下
image

Modifications

Usage or Command

"""
End-to-end test script for XPU guided decoding.
Usage: python test_guided_decoding.py [--port PORT]
"""

import argparse
import json
import re
import sys
import traceback

import openai
from pydantic import BaseModel
from enum import Enum


def make_client(port):
    """Build an OpenAI-compatible client pointed at a server on localhost.

    Args:
        port: TCP port on 127.0.0.1 that the server listens on.

    Returns:
        An ``openai.Client`` whose base URL is ``http://127.0.0.1:<port>/v1``.
    """
    endpoint = f"http://127.0.0.1:{port}/v1"
    # The API key is the literal placeholder string "null".
    return openai.Client(base_url=endpoint, api_key="null")


def test_json_object(client):
    """Check that the ``json_object`` response format yields a JSON object.

    Args:
        client: OpenAI-compatible client used to send the chat request.

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
        AssertionError: If the decoded reply is not a dict.
    """
    request = {
        "model": "null",
        "messages": [{"role": "user", "content": "用JSON格式介绍北京的名称和人口"}],
        "response_format": {"type": "json_object"},
    }
    completion = client.chat.completions.create(**request)
    reply = completion.choices[0].message.content
    # Only the first 200 characters are echoed to keep the log readable.
    print(f"  Response: {reply[:200]}")
    parsed = json.loads(reply)
    assert isinstance(parsed, dict), f"Expected dict, got {type(parsed)}"


def test_json_schema(client):
    """Check that a ``json_schema`` response format is honoured by the reply.

    The schema is generated from a local pydantic model with ``author``,
    ``title`` and an enumerated ``genre``.

    Args:
        client: OpenAI-compatible client used to send the chat request.

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
        AssertionError: If a required field is missing or ``genre`` is not one
            of the enumerated values.
    """

    class BookType(str, Enum):
        romance = "Romance"
        historical = "Historical"
        adventure = "Adventure"

    class BookDescription(BaseModel):
        author: str
        title: str
        genre: BookType

    schema_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "book-description",
            "schema": BookDescription.model_json_schema(),
        },
    }
    prompt = [{"role": "user", "content": "生成一个JSON,描述一本中国的著作,要包含作者、标题和书籍类型。"}]
    completion = client.chat.completions.create(
        model="null",
        messages=prompt,
        response_format=schema_format,
    )
    reply = completion.choices[0].message.content
    print(f"  Response: {reply[:200]}")
    book = json.loads(reply)

    # Every field declared on the model must be present in the reply.
    for field in ("author", "title", "genre"):
        assert field in book, f"Missing '{field}' field"

    allowed_genres = [member.value for member in BookType]
    assert book["genre"] in allowed_genres, f"Invalid genre: {book['genre']}"


def test_choice(client):
    """Check that ``guided_choice`` restricts the reply to the given options.

    Args:
        client: OpenAI-compatible client used to send the chat request.

    Raises:
        AssertionError: If the reply is not exactly one of the options.
    """
    options = ["北京", "上海", "广州"]
    question = [{"role": "user", "content": "中国的首都是?"}]
    completion = client.chat.completions.create(
        model="null",
        messages=question,
        # guided_choice is sent through extra_body rather than as a named argument.
        extra_body={"guided_choice": options},
    )
    answer = completion.choices[0].message.content
    print(f"  Response: {answer}")
    assert answer in options, f"Expected one of {options}, got '{answer}'"


def test_regex(client):
    """Check that ``guided_regex`` constrains the reply to a URL pattern.

    Args:
        client: OpenAI-compatible client used to send the chat request.

    Raises:
        AssertionError: If the reply does not match the pattern from its start.
    """
    # The pattern ends with an explicit newline, so a matching reply ends in "\n".
    url_pattern = r"^https:\/\/www\.[a-zA-Z]+\.com\/?$\n"
    request = {
        "model": "null",
        "messages": [{"role": "user", "content": "生成一个标准格式的网络地址,包括协议、域名。"}],
        "extra_body": {"guided_regex": url_pattern},
    }
    completion = client.chat.completions.create(**request)
    reply = completion.choices[0].message.content
    # Stripped only for display; the unstripped reply is what gets matched.
    print(f"  Response: {reply.strip()}")
    matched = re.match(url_pattern, reply)
    assert matched, f"Does not match regex: '{reply}'"


def test_grammar(client):
    """Check that ``guided_grammar`` constrains the reply to an EBNF grammar.

    The grammar describes a single ``<h1>`` element with an optional style
    attribute; the check only inspects the opening and closing of the reply.

    Args:
        client: OpenAI-compatible client used to send the chat request.

    Raises:
        AssertionError: If the reply does not start with ``<h1`` or does not
            end (ignoring trailing whitespace) with ``</h1>``.
    """
    ebnf = r"""
    root ::= html_statement

    html_statement ::= "<h1" style_attribute? ">" text "</h1>"

    style_attribute ::= " style=" dq style_value dq

    style_value ::= (font_style ("; " font_weight)?) | (font_weight ("; " font_style)?)

    font_style ::= "font-family: '" font_name "'"

    font_weight ::= "font-weight: " weight_value

    font_name ::= "Arial" | "Times New Roman" | "Courier New"

    weight_value ::= "normal" | "bold"

    text ::= [A-Za-z0-9 ]+

    dq ::= ["]
    """
    prompt = [{"role": "user", "content": "生成一段html代码,对以下标题加粗、Times New Roman字体。标题:ERNIE Bot"}]
    completion = client.chat.completions.create(
        model="null",
        messages=prompt,
        extra_body={"guided_grammar": ebnf},
    )
    html = completion.choices[0].message.content
    print(f"  Response: {html}")

    opens_with_h1 = html.startswith("<h1")
    assert opens_with_h1, f"Expected <h1> tag, got: '{html}'"
    closes_with_h1 = html.rstrip().endswith("</h1>")
    assert closes_with_h1, f"Expected </h1> ending, got: '{html}'"


def test_structural_tag(client):
    """Exercise the ``structural_tag`` response format and validate the reply.

    The request declares a single structure: ``<function=get_current_date>``,
    then a JSON object with a required string ``timezone``, then
    ``</function>``; the structure is triggered by ``<function=``. Previously
    this check only printed the reply, so it passed on any output. It now
    validates every triggered structure found in the reply. A reply without
    any trigger is accepted, since the prompt leaves calling the function
    optional.

    Args:
        client: OpenAI-compatible client used to send the chat request.

    Raises:
        AssertionError: If the reply is empty, a trigger is not followed by the
            declared begin tag, or a completed payload is not an object with a
            string ``timezone``.
        json.JSONDecodeError: If a completed payload is not valid JSON.
    """
    trigger = "<function="
    begin_tag = "<function=get_current_date>"
    end_tag = "</function>"

    resp = client.chat.completions.create(
        model="null",
        messages=[
            {
                "role": "system",
                "content": (
                    "你有以下函数可以调用:\n"
                    '{"name": "get_current_date", "description": "根据给定的时区获取当前日期和时间", '
                    '"parameters": {"type": "object", "properties": {"timezone": {"type": "string"}}, "required": ["timezone"]}}\n'
                    "如果你选择调用函数,请按以下格式:<function=函数名>{参数JSON}</function>"
                ),
            },
            {"role": "user", "content": "你今天去上海出差"},
        ],
        response_format={
            "type": "structural_tag",
            "structures": [
                {
                    "begin": begin_tag,
                    "schema": {
                        "type": "object",
                        "properties": {"timezone": {"type": "string"}},
                        "required": ["timezone"],
                    },
                    "end": end_tag,
                }
            ],
            "triggers": [trigger],
        },
    )
    content = resp.choices[0].message.content
    # Checked before slicing so a missing reply fails with a readable message.
    assert isinstance(content, str) and content.strip(), f"Expected non-empty text, got {content!r}"
    print(f"  Response: {content[:200]}")

    cursor = content.find(trigger)
    while cursor != -1:
        tail = content[cursor:]
        if len(tail) < len(begin_tag) and begin_tag.startswith(tail):
            # The reply stops inside the begin tag (presumably a token-limit
            # cut-off), so there is nothing complete to validate.
            break
        assert content.startswith(begin_tag, cursor), f"Trigger not followed by '{begin_tag}': '{tail[:80]}'"

        payload_start = cursor + len(begin_tag)
        payload_end = content.find(end_tag, payload_start)
        if payload_end == -1:
            # No closing tag: the structure is unfinished, so skip it rather
            # than report a truncated payload as invalid JSON.
            break

        arguments = json.loads(content[payload_start:payload_end])
        assert isinstance(arguments, dict), f"Expected JSON object payload, got {type(arguments)}"
        assert isinstance(arguments.get("timezone"), str), f"Missing or non-string 'timezone': {arguments}"

        cursor = content.find(trigger, payload_end + len(end_tag))


def main():
    """Run every guided-decoding check against a local server, then exit.

    Reads ``--port`` (default 9401) from the command line, runs each check in
    order, prints a PASS/FAIL line per check followed by a summary, and exits
    with status 1 if any check failed, otherwise 0.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--port", type=int, default=9401)
    client = make_client(arg_parser.parse_args().port)

    suite = (
        ("JSON Object", test_json_object),
        ("JSON Schema", test_json_schema),
        ("Choice", test_choice),
        ("Regex", test_regex),
        ("Grammar (EBNF)", test_grammar),
        ("Structural Tag", test_structural_tag),
    )

    # One boolean per executed check: True for a pass, False for a failure.
    outcomes = []
    for label, check in suite:
        print(f"[TEST] {label}")
        try:
            check(client)
            print("  -> PASS\n")
            outcomes.append(True)
        except Exception as exc:
            # A failing check must not stop the remaining checks from running.
            print(f"  -> FAIL: {exc}")
            traceback.print_exc()
            print()
            outcomes.append(False)

    passed = sum(outcomes)
    failed = len(outcomes) - passed
    print("=" * 40)
    print(f"Results: {passed} passed, {failed} failed, {passed + failed} total")
    sys.exit(1 if failed else 0)


# Run the end-to-end checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 21, 2026 08:30
@paddle-bot
Copy link
Copy Markdown

paddle-bot Bot commented Apr 21, 2026

Thanks for your contribution!

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

该 PR 旨在让 XPU 路径支持 Guided Decoding(Structured Outputs),并修复在 XPU 环境下安装了 torch 时 Triton 兼容层初始化导致的导入/运行问题。

Changes:

  • XPUModelRunner 中接入 guided decoding backend,并在采样前后增加 guided decoding 的 pre/post 处理流程
  • 调整 Triton 兼容驱动初始化条件,避免在非 CUDA 场景(如 XPU)错误创建/使用 driver
  • 放开配置层面对 XPU guided decoding 的限制,并更新相关告警逻辑

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

File Description
fastdeploy/worker/xpu_model_runner.py 为 XPU 推理执行流接入 guided decoding backend,初始化 logits processor,并在 sampling 前后维护 guided decoding 状态
fastdeploy/model_executor/ops/triton_ops/triton_utils.py 仅在 torch 存在且 Paddle 编译为 CUDA 时创建 Triton driver,避免 XPU+torch 场景下的兼容层问题
fastdeploy/config.py 移除 “XPU 不支持 guided decoding” 的限制逻辑,并调整 postprocess 告警;但 check() 中依赖校验逻辑需要进一步修正
Comments suppressed due to low confidence (1)

fastdeploy/config.py:2324

  • 这里在 guided_decoding_backend != "off" 时无条件 import xgrammar,这会导致当用户选择 guided_decoding_backend="guidance" 且未安装 xgrammar 时,配置校验直接失败(与 postprocess() 中对 guidance 仅校验 llguidance 的逻辑不一致)。建议按 backend 分支分别校验依赖:xgrammar backend 才 import xgrammar;guidance backend 则校验/导入 llguidance。
            if self.structured_outputs_config.guided_decoding_backend != "off":
                # TODO: speculative decoding support guided_decoding
                assert (
                    self.speculative_config.method is None
                ), "speculative decoding currently do not support guided_decoding"

                try:
                    import xgrammar  # noqa
                except Exception as e:
                    raise Exception(
                        f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
                    )

Copy link
Copy Markdown
Collaborator

@cmcamdy cmcamdy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@codecov-commenter
Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 33.33333% with 2 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@b79b094). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...ploy/model_executor/ops/triton_ops/triton_utils.py 33.33% 0 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7531   +/-   ##
==========================================
  Coverage           ?   73.13%           
==========================================
  Files              ?      419           
  Lines              ?    57475           
  Branches           ?     9002           
==========================================
  Hits               ?    42037           
  Misses             ?    12610           
  Partials           ?     2828           
Flag Coverage Δ
GPU 73.13% <33.33%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants