Skip to content

Commit c810514

Browse files
committed
feat: use global config
Signed-off-by: Sugam Devare <sdevare@nvidia.com>
1 parent 93977d3 commit c810514

2 files changed

Lines changed: 18 additions & 16 deletions

File tree

nemo_gym/cli.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626

2727
import uvicorn
2828
from devtools import pprint
29-
from omegaconf import DictConfig, OmegaConf
29+
from omegaconf import DictConfig, OmegaConf, open_dict
3030
from pydantic import BaseModel
3131
from tqdm.auto import tqdm
3232

3333
from nemo_gym import PARENT_DIR
3434
from nemo_gym.global_config import (
35+
HEAD_SERVER_DEPS_KEY_NAME,
3536
NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME,
3637
NEMO_GYM_CONFIG_PATH_ENV_VAR_NAME,
3738
NEMO_GYM_RESERVED_TOP_LEVEL_KEYS,
@@ -46,30 +47,27 @@
4647
)
4748

4849

49-
def _capture_head_server_dependencies() -> Optional[Path]: # pragma: no cover
50+
def _capture_head_server_dependencies(global_config_dict: DictConfig) -> None: # pragma: no cover
5051
try:
5152
result = subprocess.run(
5253
["uv", "pip", "freeze", "--exclude-editable"],
5354
capture_output=True,
5455
text=True,
5556
check=True,
5657
)
57-
frozen_deps = result.stdout
58-
constraints_file = Path("/tmp/head_server_constraints.txt")
59-
constraints_file.parent.mkdir(parents=True, exist_ok=True)
60-
with open(constraints_file, "w") as f:
61-
f.write(frozen_deps)
62-
63-
return constraints_file
58+
head_server_deps = result.stdout
6459
except Exception as e:
6560
print(f"Warning: Could not capture head server dependencies: {e}")
66-
return None
61+
head_server_deps = None
62+
63+
with open_dict(global_config_dict):
64+
global_config_dict[HEAD_SERVER_DEPS_KEY_NAME] = head_server_deps
6765

6866

69-
def _setup_env_command(dir_path: Path, head_server_deps_file: Optional[str] = None) -> str: # pragma: no cover
67+
def _setup_env_command(dir_path: Path, head_server_deps: Optional[str] = None) -> str: # pragma: no cover
7068
install_cmd = "uv pip install -r requirements.txt"
71-
if head_server_deps_file:
72-
install_cmd += f" --constraint {head_server_deps_file.absolute()}"
69+
if head_server_deps:
70+
install_cmd += f" --constraint <(cat << 'EOF'\n{head_server_deps}\nEOF\n)"
7371

7472
return f"""cd {dir_path} \\
7573
&& uv venv --allow-existing \\
@@ -127,8 +125,8 @@ class RunHelper: # pragma: no cover
127125
def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig) -> None:
128126
global_config_dict = get_global_config_dict(global_config_dict_parser_config=global_config_dict_parser_config)
129127

130-
# Capture head server dependencies to use as constraints for other servers
131-
head_server_deps_file = _capture_head_server_dependencies()
128+
# Capture head server dependencies and store in global config dict
129+
_capture_head_server_dependencies(global_config_dict)
132130

133131
# Assume Nemo Gym Run is for a single agent.
134132
escaped_config_dict_yaml_str = shlex.quote(OmegaConf.to_yaml(global_config_dict))
@@ -165,7 +163,9 @@ def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig)
165163

166164
dir_path = PARENT_DIR / Path(first_key, second_key)
167165

168-
command = f"""{_setup_env_command(dir_path, head_server_deps_file)} \\
166+
head_server_deps = global_config_dict.get(HEAD_SERVER_DEPS_KEY_NAME)
167+
168+
command = f"""{_setup_env_command(dir_path, head_server_deps)} \\
169169
&& {NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME}={escaped_config_dict_yaml_str} \\
170170
{NEMO_GYM_CONFIG_PATH_ENV_VAR_NAME}={shlex.quote(top_level_path)} \\
171171
python {str(entrypoint_fpath)}"""

nemo_gym/global_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@
3535
ENTRYPOINT_KEY_NAME = "entrypoint"
3636
DEFAULT_HOST_KEY_NAME = "default_host"
3737
HEAD_SERVER_KEY_NAME = "head_server"
38+
HEAD_SERVER_DEPS_KEY_NAME = "head_server_deps"
3839
NEMO_GYM_RESERVED_TOP_LEVEL_KEYS = [
3940
CONFIG_PATHS_KEY_NAME,
4041
ENTRYPOINT_KEY_NAME,
4142
DEFAULT_HOST_KEY_NAME,
4243
HEAD_SERVER_KEY_NAME,
44+
HEAD_SERVER_DEPS_KEY_NAME,
4345
]
4446

4547
POLICY_BASE_URL_KEY_NAME = "policy_base_url"

0 commit comments

Comments
 (0)