Skip to content

Commit 9603653

Browse files
authored
fix(gdpval): plumb judge_responses_create_params_overrides into create() call (#1174)
## Summary

`GDPValResourcesServerConfig.judge_responses_create_params_overrides` is documented as a way to override the kwargs passed to the rubric judge's `client.chat.completions.create(...)` call, but today only `model` and `api_key` are extracted from it (`resources_servers/gdpval/app.py:148-151`). Anything else — most importantly `max_tokens` — is silently dropped, which is unfortunate because the default `max_tokens=8192` in `scoring.py:132,265` truncates the rubric judge's JSON for ~50% of GDPVal-style rollouts.

This PR plumbs the rest of the dict through:

- `app.py` switches `model` / `api_key` extraction from `dict.get(...)` to `dict.pop(...)`, then forwards what's left as a new `create_overrides` kwarg.
- `score_with_rubric` and `score_with_rubric_visual` accept a `create_overrides: dict | None = None`, and merge it into the kwargs of `client.chat.completions.create` (user-supplied keys win over the in-function defaults).
- `_verify_comparison` is left alone — it goes through `comparison.run_trials`, which already exposes `max_output_tokens` separately and doesn't share the rubric path.

After this PR, the override config field actually does what its name promises:

```yaml
gdpval_resources_server:
  resources_servers:
    gdpval:
      judge_responses_create_params_overrides:
        max_tokens: 16384
        temperature: 0.0
```

or via Hydra at the CLI:

```
++gdpval_resources_server.resources_servers.gdpval.judge_responses_create_params_overrides.max_tokens=16384
```

## Test plan

- [x] `pytest resources_servers/gdpval/tests/test_app.py -v` — 8 passed (1 new + 7 existing). The new `test_verify_rubric_passes_create_overrides_through` builds a server with `judge_responses_create_params_overrides={"model": "custom-judge", "api_key": "sk-custom", "max_tokens": 16384, "temperature": 0.0}`, patches `score_with_rubric` to capture kwargs, and asserts `model` / `api_key` were popped and the rest reached the scoring function as `create_overrides`.
- [x] `ruff check` / `ruff format --check` clean.
- [x] Manual end-to-end: a 140-task rubric run on Ultra V3 SFT showed an `invalid_judge_response` rate of ~44% under load (vs. 10% on a 10-rollout smoke), all attributable to truncated rubric JSON at the 8192 cap. Bumping the override to 16384 takes the truncated-JSON fallback path out of play; rerun in progress.

---------

Signed-off-by: Serge Panev <spanev@nvidia.com>
1 parent 2af2d53 commit 9603653

3 files changed

Lines changed: 71 additions & 14 deletions

File tree

resources_servers/gdpval/app.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,11 @@ async def _verify_rubric(self, body: GDPValVerifyRequest) -> GDPValVerifyRespons
147147

148148
overrides = dict(self.config.judge_responses_create_params_overrides or {})
149149
judge_base_url = get_server_url(self.config.judge_model_server.name) + "/v1"
150-
judge_model_name = overrides.get("model", "judge")
151-
judge_api_key = overrides.get("api_key", "dummy")
150+
judge_model_name = overrides.pop("model", "judge")
151+
judge_api_key = overrides.pop("api_key", "dummy")
152+
# Anything left in `overrides` (max_tokens, temperature, top_p, …) is
153+
# merged into the judge's chat.completions.create kwargs.
154+
judge_create_overrides = overrides or None
152155

153156
deliverable_text = _safe_output_text(body.response)
154157
deliverable_content_blocks: Optional[List[Dict[str, Any]]] = None
@@ -185,6 +188,7 @@ async def _verify_rubric(self, body: GDPValVerifyRequest) -> GDPValVerifyRespons
185188
model_base_url=judge_base_url,
186189
model_name=judge_model_name,
187190
api_key=judge_api_key,
191+
create_overrides=judge_create_overrides,
188192
)
189193
else:
190194
from resources_servers.gdpval.scoring import score_with_rubric
@@ -198,6 +202,7 @@ async def _verify_rubric(self, body: GDPValVerifyRequest) -> GDPValVerifyRespons
198202
model_base_url=judge_base_url,
199203
model_name=judge_model_name,
200204
api_key=judge_api_key,
205+
create_overrides=judge_create_overrides,
201206
)
202207

203208
return GDPValVerifyResponse(

resources_servers/gdpval/scoring.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,17 @@ async def score_with_rubric(
9292
model_base_url: str,
9393
model_name: str,
9494
api_key: str = "dummy",
95+
create_overrides: dict | None = None,
9596
) -> tuple[float, dict | None]:
9697
"""Score a deliverable against a rubric using an LLM judge.
9798
9899
Returns ``(score, judge_response)`` where *score* is a float in [0, 1]
99100
and *judge_response* is the parsed JSON dict from the judge (or ``None``
100101
on failure).
102+
103+
*create_overrides* is merged into the kwargs passed to
104+
``client.chat.completions.create``; user-supplied keys win over defaults.
105+
Use it to bump ``max_tokens`` (default 8192), tweak ``temperature``, etc.
101106
"""
102107
from openai import AsyncOpenAI
103108

@@ -119,18 +124,21 @@ async def score_with_rubric(
119124
response = None
120125
for attempt in range(max_retries + 1):
121126
try:
122-
response = await client.chat.completions.create(
123-
model=model_name,
124-
messages=[
127+
create_kwargs: dict = {
128+
"model": model_name,
129+
"messages": [
125130
{
126131
"role": "system",
127132
"content": "You are an expert evaluator. You must respond with valid JSON only.",
128133
},
129134
{"role": "user", "content": judge_prompt},
130135
],
131-
temperature=0.1,
132-
max_tokens=8192,
133-
)
136+
"temperature": 0.1,
137+
"max_tokens": 8192,
138+
}
139+
if create_overrides:
140+
create_kwargs.update(create_overrides)
141+
response = await client.chat.completions.create(**create_kwargs)
134142
break
135143
except Exception as retry_err:
136144
err_str = str(retry_err)
@@ -217,6 +225,7 @@ async def score_with_rubric_visual(
217225
model_base_url: str,
218226
model_name: str,
219227
api_key: str = "dummy",
228+
create_overrides: dict | None = None,
220229
) -> tuple[float, dict | None]:
221230
"""Score deliverables visually using a multimodal judge (e.g., Gemini 3 Pro).
222231
@@ -226,6 +235,10 @@ async def score_with_rubric_visual(
226235
*deliverable_content_blocks* is a list of OpenAI-compatible content blocks
227236
(text and image_url) produced by ``file_reader.convert_deliverables_to_content_blocks()``.
228237
238+
*create_overrides* is merged into the kwargs passed to
239+
``client.chat.completions.create``; user-supplied keys win over defaults.
240+
Use it to bump ``max_tokens`` (default 8192), tweak ``temperature``, etc.
241+
229242
Returns ``(score, judge_response)`` — same contract as ``score_with_rubric``.
230243
"""
231244
from openai import AsyncOpenAI
@@ -252,18 +265,21 @@ async def score_with_rubric_visual(
252265
response = None
253266
for attempt in range(max_retries + 1):
254267
try:
255-
response = await client.chat.completions.create(
256-
model=model_name,
257-
messages=[
268+
create_kwargs: dict = {
269+
"model": model_name,
270+
"messages": [
258271
{
259272
"role": "system",
260273
"content": "You are an expert evaluator. You must respond with valid JSON only.",
261274
},
262275
{"role": "user", "content": content},
263276
],
264-
temperature=0.1,
265-
max_tokens=8192,
266-
)
277+
"temperature": 0.1,
278+
"max_tokens": 8192,
279+
}
280+
if create_overrides:
281+
create_kwargs.update(create_overrides)
282+
response = await client.chat.completions.create(**create_kwargs)
267283
break
268284
except Exception as retry_err:
269285
err_str = str(retry_err)

resources_servers/gdpval/tests/test_app.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,42 @@ async def fake_score_with_rubric(**_kwargs):
122122
assert resp.invalid_judge_response is False
123123
assert resp.judge_response == canned_result
124124

125+
@pytest.mark.asyncio
126+
async def test_verify_rubric_passes_create_overrides_through(self) -> None:
127+
"""``judge_responses_create_params_overrides`` must reach the scoring fn.
128+
129+
``model`` and ``api_key`` are pulled out as their own kwargs; everything
130+
else (e.g. ``max_tokens``, ``temperature``) flows through as
131+
``create_overrides`` and gets merged into ``client.chat.completions.create``.
132+
"""
133+
server = _server(
134+
reward_mode="rubric",
135+
judge_responses_create_params_overrides={
136+
"model": "custom-judge",
137+
"api_key": "sk-custom", # pragma: allowlist secret
138+
"max_tokens": 16384,
139+
"temperature": 0.0,
140+
},
141+
)
142+
143+
captured: dict = {}
144+
145+
async def fake_score_with_rubric(**kwargs):
146+
captured.update(kwargs)
147+
return 0.5, {"overall_score": 0.5}
148+
149+
body = _verify_request(rubric_json=[{"criterion": "clarity", "score": 1}])
150+
151+
with (
152+
patch("resources_servers.gdpval.scoring.score_with_rubric", side_effect=fake_score_with_rubric),
153+
patch("resources_servers.gdpval.app.get_server_url", return_value="http://localhost:9999"),
154+
):
155+
await server.verify(body)
156+
157+
assert captured["model_name"] == "custom-judge"
158+
assert captured["api_key"] == "sk-custom" # pragma: allowlist secret
159+
assert captured["create_overrides"] == {"max_tokens": 16384, "temperature": 0.0}
160+
125161
@pytest.mark.asyncio
126162
async def test_verify_comparison_missing_reference(self, tmp_path) -> None:
127163
server = _server(

0 commit comments

Comments
 (0)