 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from __future__ import annotations
 
 import logging
 import traceback
 from typing import Any
 
 import verifiers as vf
-import verifiers.envs.multiturn_env as _multiturn_env_module
 from fastapi import Body, Request, Response
-from openai.types.chat.chat_completion import ChatCompletion
+from openai import AsyncOpenAI
 from pydantic import ConfigDict, Field
-from verifiers.utils.async_utils import maybe_semaphore
-from verifiers.utils.response_utils import parse_response_messages as _original_parse_response_messages
+from verifiers.clients import NeMoRLChatCompletionsClient
 
 from nemo_gym.base_resources_server import BaseRunRequest, BaseVerifyResponse
 from nemo_gym.base_responses_api_agent import BaseResponsesAPIAgentConfig, SimpleResponsesAPIAgent
     NeMoGymResponseOutputMessageForTraining,
     NeMoGymResponseOutputText,
 )
-from nemo_gym.server_utils import get_global_aiohttp_client
 
 
 logger = logging.getLogger(__name__)
 
 
-# patch verifiers to include prompt and generation token ids and logprobs for
-# re-tokenization correction in replace_prefix_tokens (https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40)
-async def _patched_parse_response_messages(response, message_type):
-    messages = await _original_parse_response_messages(response, message_type)
-    if message_type == "chat" and isinstance(messages, list):
-        for msg in messages:
-            if isinstance(msg, dict) and msg.get("role") == "assistant":
-                if hasattr(response, "prompt_token_ids"):
-                    msg["prompt_token_ids"] = response.prompt_token_ids
-                if response.choices and hasattr(response.choices[0], "token_ids"):
-                    msg["generation_token_ids"] = response.choices[0].token_ids
-                if (
-                    response.choices
-                    and response.choices[0].logprobs
-                    and hasattr(response.choices[0].logprobs, "content")
-                    and response.choices[0].logprobs.content
-                ):
-                    msg["generation_log_probs"] = [t.logprob for t in response.choices[0].logprobs.content]
-    return messages
-
-
-_multiturn_env_module.parse_response_messages = _patched_parse_response_messages
-
-
 class VerifiersNeMoGymResponse(NeMoGymResponse):
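+    """NeMoGymResponse tagged with the verifiers environment and rollout group."""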
     env_id: str
     group_id: str
@@ -84,90 +58,13 @@ class VerifiersAgentVerifyResponse(BaseVerifyResponse):
     reward: float
 
 
-class VLLMOpenAIClient:
-    def __init__(self, base_url: str) -> None:
-        self._base_url = base_url.rstrip("/")
-        self.chat = self._Chat(self)
-
-    class _Chat:
-        def __init__(self, client: "VLLMOpenAIClient") -> None:
-            self.completions = client
-
-    async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion:
-        request_body: dict[str, Any] = {
-            "model": kwargs.get("model", ""),
-            "messages": kwargs.get("messages", []),
-        }
-        for key in (
-            "temperature",
-            "max_tokens",
-            "max_completion_tokens",
-            "top_p",
-            "stop",
-            "n",
-            "tools",
-            "tool_choice",
-        ):
-            if key in kwargs and kwargs[key] is not None:
-                request_body[key] = kwargs[key]
-
-        url = f"{self._base_url}/chat/completions"
-        try:
-            session = get_global_aiohttp_client()
-            async with session.post(url, json=request_body) as resp:
-                if resp.status != 200:
-                    error_text = await resp.text()
-                    logger.error(f"Request to {url} failed with status {resp.status}: {error_text}")
-                    resp.raise_for_status()
-                response_dict = await resp.json()
-        except Exception as e:
-            logger.error(f"Exception calling {url}: {type(e).__name__}: {e}")
-            raise
-
-        choice_dict = response_dict["choices"][0]
-        message_dict = choice_dict.get("message", {})
-
-        prompt_token_ids = message_dict.pop("prompt_token_ids", [])
-        generation_token_ids = message_dict.pop("generation_token_ids", [])
-        generation_log_probs = message_dict.pop("generation_log_probs", [])
-
-        if not generation_token_ids:
-            logger.warning(
-                f"No generation_token_ids in response! Full message keys were: {list(choice_dict.get('message', {}).keys())}"
-            )
-
-        if prompt_token_ids and isinstance(prompt_token_ids[0], str):
-            prompt_token_ids = [int(tid) for tid in prompt_token_ids]
-
-        if generation_token_ids and isinstance(generation_token_ids[0], str):
-            generation_token_ids = [int(tid) for tid in generation_token_ids]
-
-        if generation_token_ids and generation_log_probs:
-            choice_dict["logprobs"] = {
-                "content": [
-                    {"token": f"token_id:{tid}", "logprob": lp, "top_logprobs": []}
-                    for tid, lp in zip(generation_token_ids, generation_log_probs)
-                ]
-            }
-
-        response = ChatCompletion.model_validate(response_dict)
-        setattr(response, "prompt_token_ids", prompt_token_ids)
-        setattr(response.choices[0], "token_ids", generation_token_ids)
-        return response
-
-
 class VerifiersAgentConfig(BaseResponsesAPIAgentConfig):
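+    """Configuration for the verifiers-backed NeMo Gym agent."""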
     model_server: ModelServerRef
     model_name: str = Field(default="", description="Model name")
 
     vf_env_id: str = Field(default="", description="Verifiers environment ID")
     vf_env_args: dict = Field(default_factory=dict, description="Verifiers environment arguments")
 
-    max_concurrent_generation: int = Field(
-        default=-1, description="Max concurrent generation requests (-1 = unlimited)"
-    )
-    max_concurrent_scoring: int = Field(default=-1, description="Max concurrent scoring requests (-1 = unlimited)")
-
     max_tokens: int = Field(default=8192, description="Max tokens for generation")
 
     # NeMo RL generation_config overrides these
@@ -193,17 +90,17 @@ class VerifiersAgent(SimpleResponsesAPIAgent):
     model_config = ConfigDict(arbitrary_types_allowed=True)
     config: VerifiersAgentConfig
 
-    envs_cache: dict[str, Any] = Field(default_factory=dict)  # vf.Environment
-    openai_client_cache: dict[str, VLLMOpenAIClient] = Field(default_factory=dict)
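+    # envs_cache holds loaded vf.Environment objects keyed by env ID;
+    # client_cache holds one chat-completions client per model server name.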
+    envs_cache: dict[str, Any] = Field(default_factory=dict)
+    client_cache: dict[str, NeMoRLChatCompletionsClient] = Field(default_factory=dict)
 
     def _get_env(self, vf_env_id: str) -> vf.Environment:
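+        """Load and cache the verifiers environment for the given env ID."""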
         if vf_env_id not in self.envs_cache:
             self.envs_cache[vf_env_id] = vf.load_environment(vf_env_id, **self.config.vf_env_args)
         return self.envs_cache[vf_env_id]
 
-    def _get_openai_client(self) -> VLLMOpenAIClient:
+    def _get_client(self) -> NeMoRLChatCompletionsClient:
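+        """Return a cached NeMoRLChatCompletionsClient for the configured model server."""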
         cache_key = self.config.model_server.name
-        if cache_key not in self.openai_client_cache:
+        if cache_key not in self.client_cache:
             server_config_dict = get_first_server_config_dict(
                 self.server_client.global_config_dict,
                 self.config.model_server.name,
@@ -213,25 +110,50 @@ def _get_openai_client(self) -> VLLMOpenAIClient:
             if not model_server_url.endswith("/v1"):
                 model_server_url = model_server_url.rstrip("/") + "/v1"
 
-            self.openai_client_cache[cache_key] = VLLMOpenAIClient(base_url=model_server_url)
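+            # NeMoRLChatCompletionsClient wraps the OpenAI-compatible endpoint and is
+            # expected to surface the prompt/generation token IDs and logprobs that
+            # NeMo RL needs for re-tokenization correction (replace_prefix_tokens).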
+            openai_client = AsyncOpenAI(
+                base_url=model_server_url,
+                api_key="EMPTY",  # pragma: allowlist secret
+            )
+            self.client_cache[cache_key] = NeMoRLChatCompletionsClient(openai_client)
 
-        return self.openai_client_cache[cache_key]
+        return self.client_cache[cache_key]
 
-    def _convert_trajectory_to_output(self, state: dict) -> list:
+    def _convert_trajectory_to_output(self, rollout_output: dict) -> list:
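+        """Flatten a verifiers rollout into NeMo Gym output items.
+
+        Prompt messages map to NeMoGymEasyInputMessage; completion messages
+        with token data map to NeMoGymResponseOutputMessageForTraining.
+        """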
         output = []
-        trajectory = state.get("trajectory", [])
-
-        for step in trajectory:
-            for msg in step.get("prompt", []):
-                if isinstance(msg, dict):
+        trajectory = rollout_output.get("trajectory", [])
+
+        # Build steps from trajectory if present, otherwise fall back to
+        # top-level prompt/completion (single-turn environments).
+        if trajectory:
+            steps = trajectory
+        else:
+            steps = [
+                {
+                    "prompt": rollout_output.get("prompt", []),
+                    "completion": rollout_output.get("completion", []),
+                    "tokens": None,
+                }
+            ]
+
+        for step in steps:
+            for msg in step.get("prompt", []) or []:
+                # Handle both plain dicts (serialized RolloutOutput) and
+                # Pydantic CustomBaseModel messages (which support .get()).
+                if hasattr(msg, "get"):
                     role = msg.get("role", "user")
                     content = msg.get("content", "")
                     output.append(NeMoGymEasyInputMessage(role=role, content=content).model_dump())
 
-            tokens = step.get("tokens")
-            for msg in step.get("completion", []):
-                if isinstance(msg, dict):
+            step_tokens = step.get("tokens") if hasattr(step, "get") else None
+            for msg in step.get("completion", []) or []:
+                if hasattr(msg, "get"):
                     content = msg.get("content", "")
+                    # For trajectory steps, tokens are on the step.
+                    # For single-turn fallback, tokens may be on the message
+                    # (ResponseMessage.tokens from verifiers).
+                    tokens = step_tokens
+                    if tokens is None:
+                        tokens = msg.get("tokens") if hasattr(msg, "get") else getattr(msg, "tokens", None)
                 if tokens:
                     output.append(
                         NeMoGymResponseOutputMessageForTraining(
@@ -278,31 +200,28 @@ async def responses(
             example_id=body.example_id,
         )
 
-        client = self._get_openai_client()
-
-        gen_sem = await maybe_semaphore(self.config.max_concurrent_generation)
-        score_sem = await maybe_semaphore(self.config.max_concurrent_scoring)
+        client = self._get_client()
 
-        # prefer NeMo RL generation config set in responses_create_params https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/experience/rollouts.py#L1045-L1046
+        # prefer NeMo RL generation config set in responses_create_params
+        # https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/experience/rollouts.py#L1045-L1046
        sampling_args = {
             "max_tokens": self.config.max_tokens,
             "temperature": getattr(body.responses_create_params, "temperature", None) or self.config.temperature,
             "top_p": getattr(body.responses_create_params, "top_p", None) or self.config.top_p,
         }
-        states = await vf_env.run_group(
+        outputs = await vf_env.run_group(
             group_inputs=[rollout_input],
             client=client,
             model=self.config.model_name,
-            gen_sampling_args=sampling_args,
-            gen_sem=gen_sem,
-            score_sem=score_sem,
+            sampling_args=sampling_args,
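+            # surface per-step "trajectory" state on the rollout output
+            # (read back by _convert_trajectory_to_output)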
+            state_columns=["trajectory"],
         )
 
-        state = states[0]
-        reward = state.get("reward", 0.0) or 0.0
-        metrics = state.get("metrics", {}) or {}
+        rollout_output = outputs[0]
+        reward = rollout_output.get("reward", 0.0) or 0.0
+        metrics = rollout_output.get("metrics", {}) or {}
 
-        output = self._convert_trajectory_to_output(state)
+        output = self._convert_trajectory_to_output(rollout_output)
 
         return VerifiersNeMoGymResponse(
             id=f"verifiers-{vf_env_id}-{task_idx}",