Skip to content

Commit a6587a0

Browse files
committed
feat: add head server deps
Signed-off-by: Sugam Devare <sdevare@nvidia.com>
1 parent 901164a commit a6587a0

2 files changed

Lines changed: 34 additions & 5 deletions

File tree

nemo_gym/cli.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import asyncio
1515
import json
1616
import shlex
17+
import subprocess
1718
import tomllib
1819
from glob import glob
1920
from os import environ, makedirs
@@ -27,13 +28,14 @@
2728
import rich
2829
import uvicorn
2930
from devtools import pprint
30-
from omegaconf import DictConfig, OmegaConf
31+
from omegaconf import DictConfig, OmegaConf, open_dict
3132
from pydantic import BaseModel, Field
3233
from tqdm.auto import tqdm
3334

3435
from nemo_gym import PARENT_DIR
3536
from nemo_gym.config_types import BaseNeMoGymCLIConfig
3637
from nemo_gym.global_config import (
38+
HEAD_SERVER_DEPS_KEY_NAME,
3739
NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME,
3840
NEMO_GYM_CONFIG_PATH_ENV_VAR_NAME,
3941
NEMO_GYM_RESERVED_TOP_LEVEL_KEYS,
@@ -48,18 +50,38 @@
4850
)
4951

5052

51-
def _setup_env_command(dir_path: Path) -> str: # pragma: no cover
53+
def _capture_head_server_dependencies(global_config_dict: DictConfig) -> None: # pragma: no cover
54+
try:
55+
result = subprocess.run(
56+
["uv", "pip", "freeze", "--exclude-editable"],
57+
capture_output=True,
58+
text=True,
59+
check=True,
60+
)
61+
head_server_deps = result.stdout
62+
except Exception as e:
63+
print(f"Warning: Could not capture head server dependencies: {e}")
64+
head_server_deps = None
65+
66+
with open_dict(global_config_dict):
67+
global_config_dict[HEAD_SERVER_DEPS_KEY_NAME] = head_server_deps
68+
69+
70+
def _setup_env_command(dir_path: Path, head_server_deps: Optional[str] = None) -> str: # pragma: no cover
71+
install_cmd = "uv pip install -r requirements.txt"
72+
if head_server_deps:
73+
install_cmd += f" --constraint <(cat << 'EOF'\n{head_server_deps}\nEOF\n)"
74+
5275
return f"""cd {dir_path} \\
5376
&& uv venv --allow-existing \\
5477
&& source .venv/bin/activate \\
55-
&& uv pip install -r requirements.txt \\
78+
&& {install_cmd} \\
5679
"""
5780

5881

5982
def _run_command(command: str, working_directory: Path) -> Popen: # pragma: no cover
6083
custom_env = environ.copy()
6184
custom_env["PYTHONPATH"] = f"{working_directory.absolute()}:{custom_env.get('PYTHONPATH', '')}"
62-
print(f"Executing command:\n{command}\n")
6385
return Popen(command, executable="/bin/bash", shell=True, env=custom_env)
6486

6587

@@ -114,6 +136,9 @@ class RunHelper: # pragma: no cover
114136
def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig) -> None:
115137
global_config_dict = get_global_config_dict(global_config_dict_parser_config=global_config_dict_parser_config)
116138

139+
# Capture head server dependencies and store in global config dict
140+
_capture_head_server_dependencies(global_config_dict)
141+
117142
# Assume Nemo Gym Run is for a single agent.
118143
escaped_config_dict_yaml_str = shlex.quote(OmegaConf.to_yaml(global_config_dict))
119144

@@ -149,7 +174,9 @@ def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig)
149174

150175
dir_path = PARENT_DIR / Path(first_key, second_key)
151176

152-
command = f"""{_setup_env_command(dir_path)} \\
177+
head_server_deps = global_config_dict.get(HEAD_SERVER_DEPS_KEY_NAME)
178+
179+
command = f"""{_setup_env_command(dir_path, head_server_deps)} \\
153180
&& {NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME}={escaped_config_dict_yaml_str} \\
154181
{NEMO_GYM_CONFIG_PATH_ENV_VAR_NAME}={shlex.quote(top_level_path)} \\
155182
python {str(entrypoint_fpath)}"""

nemo_gym/global_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@
3535
ENTRYPOINT_KEY_NAME = "entrypoint"
3636
DEFAULT_HOST_KEY_NAME = "default_host"
3737
HEAD_SERVER_KEY_NAME = "head_server"
38+
HEAD_SERVER_DEPS_KEY_NAME = "head_server_deps"
3839
NEMO_GYM_RESERVED_TOP_LEVEL_KEYS = [
3940
CONFIG_PATHS_KEY_NAME,
4041
ENTRYPOINT_KEY_NAME,
4142
DEFAULT_HOST_KEY_NAME,
4243
HEAD_SERVER_KEY_NAME,
44+
HEAD_SERVER_DEPS_KEY_NAME,
4345
]
4446

4547
POLICY_BASE_URL_KEY_NAME = "policy_base_url"

0 commit comments

Comments
 (0)