
Commit 4234fbe

mferrato and cmunley1 authored
Simplify verifiers_agent to use upstream NeMoRLChatCompletionsClient (#1076)
## Summary

- Replaces the custom `VLLMOpenAIClient` and `parse_response_messages` runtime patch with verifiers' `NeMoRLChatCompletionsClient`, which handles token ID relocation natively
- Removes the `verifiers==0.1.9.post3` version pin
- Reduces the agent from ~340 to ~250 lines

## Dependency

> **Draft:** This PR depends on [PrimeIntellect-ai/verifiers#1141](PrimeIntellect-ai/verifiers#1141),
> which adds `NeMoRLChatCompletionsClient` to the verifiers library. This PR
> should not be merged until that PR is merged and released.

## What changed

- **Deleted:** `VLLMOpenAIClient` (70-line custom HTTP client) — replaced by `NeMoRLChatCompletionsClient`
- **Deleted:** `_patched_parse_response_messages` runtime patch (22 lines) — token IDs now flow through the standard `parse_tokens()` path
- **Deleted:** `max_concurrent_generation` / `max_concurrent_scoring` config — the new verifiers API handles concurrency internally
- **Changed:** `run_group()` call updated to the new API signature (`sampling_args` instead of `gen_sampling_args` + semaphores)
- **Changed:** `_convert_trajectory_to_output()` handles single-turn environments and Pydantic message objects
- **Added:** `state_columns=["trajectory"]` to `run_group()` to ensure token IDs are included in output

## Test plan

- [x] Integration: validated on A10 GPU with Qwen3-4B + acereason-math environment
- [x] Token ID validation: prompt_token_ids, generation_token_ids, generation_log_probs all populated
- [ ] Pending: multi-turn tool-calling environment validation
- [ ] Pending: NeMo RL training loop validation

---------

Signed-off-by: Matthew Ferrato <mferrato@nvidia.com>
Signed-off-by: Mauricio Ferrato <mferrato@nvidia.com>
Co-authored-by: Christian Munley <cmunley@nvidia.com>
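A minimal sketch of the flow the agent lands on after this change, assuming the `NeMoRLChatCompletionsClient` constructor and `run_group()` signature from the still-unreleased verifiers PR above; the `group_inputs` payload is a hypothetical placeholder, not the agent's real input:

```python
# Sketch only: assumes the API added in PrimeIntellect-ai/verifiers#1141;
# names and signatures may differ in the released version.
import asyncio

import verifiers as vf
from openai import AsyncOpenAI
from verifiers.clients import NeMoRLChatCompletionsClient


async def main() -> None:
    # Plain OpenAI-compatible client pointed at a vLLM /v1 endpoint; the
    # wrapper is expected to surface prompt/generation token IDs natively.
    openai_client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    client = NeMoRLChatCompletionsClient(openai_client)

    env = vf.load_environment("acereason-math")
    outputs = await env.run_group(
        group_inputs=[{"prompt": [{"role": "user", "content": "2 + 2 = ?"}]}],  # placeholder shape
        client=client,
        model="Qwen/Qwen3-4B",
        sampling_args={"max_tokens": 8192, "temperature": 1.0, "top_p": 1.0},
        state_columns=["trajectory"],  # ensure token IDs appear in the output
    )
    print(outputs[0].get("reward", 0.0))


if __name__ == "__main__":
    asyncio.run(main())
```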
1 parent 2d74930 commit 4234fbe

2 files changed

Lines changed: 55 additions & 136 deletions

File tree

responses_api_agents/verifiers_agent/app.py

Lines changed: 53 additions & 134 deletions
```diff
@@ -11,19 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from __future__ import annotations
 
 import logging
 import traceback
 from typing import Any
 
 import verifiers as vf
-import verifiers.envs.multiturn_env as _multiturn_env_module
 from fastapi import Body, Request, Response
-from openai.types.chat.chat_completion import ChatCompletion
+from openai import AsyncOpenAI
 from pydantic import ConfigDict, Field
-from verifiers.utils.async_utils import maybe_semaphore
-from verifiers.utils.response_utils import parse_response_messages as _original_parse_response_messages
+from verifiers.clients import NeMoRLChatCompletionsClient
 
 from nemo_gym.base_resources_server import BaseRunRequest, BaseVerifyResponse
 from nemo_gym.base_responses_api_agent import BaseResponsesAPIAgentConfig, SimpleResponsesAPIAgent
@@ -37,36 +36,11 @@
     NeMoGymResponseOutputMessageForTraining,
     NeMoGymResponseOutputText,
 )
-from nemo_gym.server_utils import get_global_aiohttp_client
 
 
 logger = logging.getLogger(__name__)
 
 
-# patch verifiers to include prompt and generation token ids and logprobs for
-# re-tokenization correction in replace_prefix_tokens (https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40)
-async def _patched_parse_response_messages(response, message_type):
-    messages = await _original_parse_response_messages(response, message_type)
-    if message_type == "chat" and isinstance(messages, list):
-        for msg in messages:
-            if isinstance(msg, dict) and msg.get("role") == "assistant":
-                if hasattr(response, "prompt_token_ids"):
-                    msg["prompt_token_ids"] = response.prompt_token_ids
-                if response.choices and hasattr(response.choices[0], "token_ids"):
-                    msg["generation_token_ids"] = response.choices[0].token_ids
-                if (
-                    response.choices
-                    and response.choices[0].logprobs
-                    and hasattr(response.choices[0].logprobs, "content")
-                    and response.choices[0].logprobs.content
-                ):
-                    msg["generation_log_probs"] = [t.logprob for t in response.choices[0].logprobs.content]
-    return messages
-
-
-_multiturn_env_module.parse_response_messages = _patched_parse_response_messages
-
-
 class VerifiersNeMoGymResponse(NeMoGymResponse):
     env_id: str
     group_id: str
```
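The runtime patch deleted in the hunk above grafted three fields onto each assistant message so NeMo RL could correct re-tokenization; the new client is expected to deliver the same data through verifiers' standard `parse_tokens()` path. For reference, roughly the shape the patch produced (field names come from the deleted code; token IDs and logprob values are made up):

```python
# Assistant message as the deleted _patched_parse_response_messages left it.
patched_assistant_msg = {
    "role": "assistant",
    "content": "The answer is 4.",
    "prompt_token_ids": [151644, 872, 198],  # from response.prompt_token_ids
    "generation_token_ids": [785, 4226, 374, 19],  # from choices[0].token_ids
    "generation_log_probs": [-0.02, -0.11, -0.05, -0.31],  # one per generated token
}
```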
```diff
@@ -84,90 +58,13 @@ class VerifiersAgentVerifyResponse(BaseVerifyResponse):
     reward: float
 
 
-class VLLMOpenAIClient:
-    def __init__(self, base_url: str) -> None:
-        self._base_url = base_url.rstrip("/")
-        self.chat = self._Chat(self)
-
-    class _Chat:
-        def __init__(self, client: "VLLMOpenAIClient") -> None:
-            self.completions = client
-
-    async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion:
-        request_body: dict[str, Any] = {
-            "model": kwargs.get("model", ""),
-            "messages": kwargs.get("messages", []),
-        }
-        for key in (
-            "temperature",
-            "max_tokens",
-            "max_completion_tokens",
-            "top_p",
-            "stop",
-            "n",
-            "tools",
-            "tool_choice",
-        ):
-            if key in kwargs and kwargs[key] is not None:
-                request_body[key] = kwargs[key]
-
-        url = f"{self._base_url}/chat/completions"
-        try:
-            session = get_global_aiohttp_client()
-            async with session.post(url, json=request_body) as resp:
-                if resp.status != 200:
-                    error_text = await resp.text()
-                    logger.error(f"Request to {url} failed with status {resp.status}: {error_text}")
-                resp.raise_for_status()
-                response_dict = await resp.json()
-        except Exception as e:
-            logger.error(f"Exception calling {url}: {type(e).__name__}: {e}")
-            raise
-
-        choice_dict = response_dict["choices"][0]
-        message_dict = choice_dict.get("message", {})
-
-        prompt_token_ids = message_dict.pop("prompt_token_ids", [])
-        generation_token_ids = message_dict.pop("generation_token_ids", [])
-        generation_log_probs = message_dict.pop("generation_log_probs", [])
-
-        if not generation_token_ids:
-            logger.warning(
-                f"No generation_token_ids in response! Full message keys were: {list(choice_dict.get('message', {}).keys())}"
-            )
-
-        if prompt_token_ids and isinstance(prompt_token_ids[0], str):
-            prompt_token_ids = [int(tid) for tid in prompt_token_ids]
-
-        if generation_token_ids and isinstance(generation_token_ids[0], str):
-            generation_token_ids = [int(tid) for tid in generation_token_ids]
-
-        if generation_token_ids and generation_log_probs:
-            choice_dict["logprobs"] = {
-                "content": [
-                    {"token": f"token_id:{tid}", "logprob": lp, "top_logprobs": []}
-                    for tid, lp in zip(generation_token_ids, generation_log_probs)
-                ]
-            }
-
-        response = ChatCompletion.model_validate(response_dict)
-        setattr(response, "prompt_token_ids", prompt_token_ids)
-        setattr(response.choices[0], "token_ids", generation_token_ids)
-        return response
-
-
 class VerifiersAgentConfig(BaseResponsesAPIAgentConfig):
     model_server: ModelServerRef
     model_name: str = Field(default="", description="Model name")
 
     vf_env_id: str = Field(default="", description="Verifiers environment ID")
     vf_env_args: dict = Field(default_factory=dict, description="Verifiers environment arguments")
 
-    max_concurrent_generation: int = Field(
-        default=-1, description="Max concurrent generation requests (-1 = unlimited)"
-    )
-    max_concurrent_scoring: int = Field(default=-1, description="Max concurrent scoring requests (-1 = unlimited)")
-
     max_tokens: int = Field(default=8192, description="Max tokens for generation")
 
     # nemo rl generation_config overrides these
@@ -193,17 +90,17 @@ class VerifiersAgent(SimpleResponsesAPIAgent):
     model_config = ConfigDict(arbitrary_types_allowed=True)
     config: VerifiersAgentConfig
 
-    envs_cache: dict[str, Any] = Field(default_factory=dict)  # vf.Environment
-    openai_client_cache: dict[str, VLLMOpenAIClient] = Field(default_factory=dict)
+    envs_cache: dict[str, Any] = Field(default_factory=dict)
+    client_cache: dict[str, NeMoRLChatCompletionsClient] = Field(default_factory=dict)
 
     def _get_env(self, vf_env_id: str) -> vf.Environment:
         if vf_env_id not in self.envs_cache:
             self.envs_cache[vf_env_id] = vf.load_environment(vf_env_id, **self.config.vf_env_args)
         return self.envs_cache[vf_env_id]
 
-    def _get_openai_client(self) -> VLLMOpenAIClient:
+    def _get_client(self) -> NeMoRLChatCompletionsClient:
         cache_key = self.config.model_server.name
-        if cache_key not in self.openai_client_cache:
+        if cache_key not in self.client_cache:
             server_config_dict = get_first_server_config_dict(
                 self.server_client.global_config_dict,
                 self.config.model_server.name,
@@ -213,25 +110,50 @@ def _get_openai_client(self) -> VLLMOpenAIClient:
             if not model_server_url.endswith("/v1"):
                 model_server_url = model_server_url.rstrip("/") + "/v1"
 
-            self.openai_client_cache[cache_key] = VLLMOpenAIClient(base_url=model_server_url)
+            openai_client = AsyncOpenAI(
+                base_url=model_server_url,
+                api_key="EMPTY",  # pragma: allowlist secret
+            )
+            self.client_cache[cache_key] = NeMoRLChatCompletionsClient(openai_client)
 
-        return self.openai_client_cache[cache_key]
+        return self.client_cache[cache_key]
 
-    def _convert_trajectory_to_output(self, state: dict) -> list:
+    def _convert_trajectory_to_output(self, rollout_output: dict) -> list:
         output = []
-        trajectory = state.get("trajectory", [])
-
-        for step in trajectory:
-            for msg in step.get("prompt", []):
-                if isinstance(msg, dict):
+        trajectory = rollout_output.get("trajectory", [])
+
+        # Build steps from trajectory if present, otherwise fall back to
+        # top-level prompt/completion (single-turn environments).
+        if trajectory:
+            steps = trajectory
+        else:
+            steps = [
+                {
+                    "prompt": rollout_output.get("prompt", []),
+                    "completion": rollout_output.get("completion", []),
+                    "tokens": None,
+                }
+            ]
+
+        for step in steps:
+            for msg in step.get("prompt", []) or []:
+                # Handle both plain dicts (serialized RolloutOutput) and
+                # Pydantic CustomBaseModel messages (which support .get()).
+                if hasattr(msg, "get"):
                     role = msg.get("role", "user")
                     content = msg.get("content", "")
                     output.append(NeMoGymEasyInputMessage(role=role, content=content).model_dump())
 
-            tokens = step.get("tokens")
-            for msg in step.get("completion", []):
-                if isinstance(msg, dict):
+            step_tokens = step.get("tokens") if hasattr(step, "get") else None
+            for msg in step.get("completion", []) or []:
+                if hasattr(msg, "get"):
                     content = msg.get("content", "")
+                    # For trajectory steps, tokens are on the step.
+                    # For single-turn fallback, tokens may be on the message
+                    # (ResponseMessage.tokens from verifiers).
+                    tokens = step_tokens
+                    if tokens is None:
+                        tokens = msg.get("tokens") if hasattr(msg, "get") else getattr(msg, "tokens", None)
                     if tokens:
                         output.append(
                             NeMoGymResponseOutputMessageForTraining(
```
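The hunk above handles two rollout shapes in `_convert_trajectory_to_output()`. A sketch of both, inferred from the code paths; the token payloads are hypothetical, since their real structure comes from verifiers' `parse_tokens()` and is not visible in this diff:

```python
# Hypothetical token payload; the real structure is not shown in this diff.
tokens = {"prompt_token_ids": [1, 2, 3], "generation_token_ids": [19, 20]}

# Multi-turn environment: a "trajectory" list is present and token IDs
# live on each step, so the converter iterates the steps directly.
multi_turn_rollout = {
    "trajectory": [
        {
            "prompt": [{"role": "user", "content": "2 + 2 = ?"}],
            "completion": [{"role": "assistant", "content": "4"}],
            "tokens": tokens,
        }
    ],
    "reward": 1.0,
}

# Single-turn environment: no "trajectory", so the converter wraps the
# top-level prompt/completion into one synthetic step and reads tokens
# off the completion message itself (ResponseMessage.tokens).
single_turn_rollout = {
    "prompt": [{"role": "user", "content": "2 + 2 = ?"}],
    "completion": [{"role": "assistant", "content": "4", "tokens": tokens}],
    "reward": 1.0,
}
```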
```diff
@@ -278,31 +200,28 @@ async def responses(
                 example_id=body.example_id,
             )
 
-        client = self._get_openai_client()
-
-        gen_sem = await maybe_semaphore(self.config.max_concurrent_generation)
-        score_sem = await maybe_semaphore(self.config.max_concurrent_scoring)
+        client = self._get_client()
 
-        # prefer NeMo RL generation config set in responses_create_params https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/experience/rollouts.py#L1045-L1046
+        # prefer NeMo RL generation config set in responses_create_params
+        # https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/experience/rollouts.py#L1045-L1046
         sampling_args = {
             "max_tokens": self.config.max_tokens,
             "temperature": getattr(body.responses_create_params, "temperature", None) or self.config.temperature,
             "top_p": getattr(body.responses_create_params, "top_p", None) or self.config.top_p,
         }
-        states = await vf_env.run_group(
+        outputs = await vf_env.run_group(
             group_inputs=[rollout_input],
             client=client,
             model=self.config.model_name,
-            gen_sampling_args=sampling_args,
-            gen_sem=gen_sem,
-            score_sem=score_sem,
+            sampling_args=sampling_args,
+            state_columns=["trajectory"],
         )
 
-        state = states[0]
-        reward = state.get("reward", 0.0) or 0.0
-        metrics = state.get("metrics", {}) or {}
+        rollout_output = outputs[0]
+        reward = rollout_output.get("reward", 0.0) or 0.0
+        metrics = rollout_output.get("metrics", {}) or {}
 
-        output = self._convert_trajectory_to_output(state)
+        output = self._convert_trajectory_to_output(rollout_output)
 
         return VerifiersNeMoGymResponse(
             id=f"verifiers-{vf_env_id}-{task_idx}",
```
Lines changed: 2 additions & 2 deletions
```diff
@@ -1,4 +1,4 @@
 -e nemo-gym[dev] @ ../../
-verifiers==0.1.9.post3
+verifiers
 --extra-index-url https://hub.primeintellect.ai/primeintellect/simple/
-acereason-math
+acereason-math
```
