Skip to content

Commit 0dd33ec

Browse files
committed
Test retries and error handling, clean some lines
Signed-off-by: Frankie Siino <fsiino@nvidia.com>
1 parent 9e48bf7 commit 0dd33ec

2 files changed

Lines changed: 42 additions & 7 deletions

File tree

nemo_gym/global_config.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,9 @@ def parse(self, parse_config: Optional[GlobalConfigDictParserConfig] = None) ->
193193
# Do one pass through all the configs validate and populate various configs for our servers.
194194
default_host = global_config_dict.get(DEFAULT_HOST_KEY_NAME) or "127.0.0.1"
195195

196-
head_server_config = global_config_dict.get(HEAD_SERVER_KEY_NAME)
197-
if head_server_config:
198-
head_server_host = head_server_config.get("host") or default_host
199-
head_server_port = head_server_config.get("port") or DEFAULT_HEAD_SERVER_PORT
200-
else:
201-
head_server_host = default_host
202-
head_server_port = DEFAULT_HEAD_SERVER_PORT
196+
head_server_config = global_config_dict.get(HEAD_SERVER_KEY_NAME, {})
197+
head_server_host = head_server_config.get("host", default_host)
198+
head_server_port = head_server_config.get("port", DEFAULT_HEAD_SERVER_PORT)
203199

204200
self.validate_and_populate_defaults(server_instance_configs, default_host, head_server_host, head_server_port)
205201

tests/unit_tests/test_global_config.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import nemo_gym.global_config
1919
import nemo_gym.server_utils
2020
from nemo_gym.global_config import (
21+
DEFAULT_HEAD_SERVER_PORT,
2122
NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME,
23+
find_open_port,
2224
get_first_server_config_dict,
2325
get_global_config_dict,
2426
)
@@ -371,3 +373,40 @@ def test_get_first_server_config_dict(self) -> None:
371373
}
372374
)
373375
assert {"my_key": "my_value"} == get_first_server_config_dict(global_config_dict, "a")
376+
377+
def test_find_open_port_avoids_head_server_port(self, monkeypatch: MonkeyPatch) -> None:
378+
"""Test that find_open_port retries when the head server port is returned."""
379+
socket_mock = MagicMock()
380+
socket_instance = MagicMock()
381+
socket_mock.return_value.__enter__ = MagicMock(return_value=socket_instance)
382+
socket_mock.return_value.__exit__ = MagicMock(return_value=False)
383+
384+
socket_instance.getsockname.side_effect = [
385+
("", DEFAULT_HEAD_SERVER_PORT), # first attempt: 11000 (conflict)
386+
("", 12345), # second attempt (safe)
387+
]
388+
389+
monkeypatch.setattr(nemo_gym.global_config, "socket", socket_mock)
390+
391+
port = find_open_port(head_server_host="127.0.0.1", head_server_port=DEFAULT_HEAD_SERVER_PORT)
392+
393+
assert port == 12345
394+
assert socket_instance.getsockname.call_count == 2 # first: conflict, second: success
395+
396+
def test_find_open_port_raises_after_max_retries(self, monkeypatch: MonkeyPatch) -> None:
397+
"""Test that find_open_port raises RuntimeError after exhausting retries."""
398+
socket_mock = MagicMock()
399+
socket_instance = MagicMock()
400+
socket_mock.return_value.__enter__ = MagicMock(return_value=socket_instance)
401+
socket_mock.return_value.__exit__ = MagicMock(return_value=False)
402+
403+
socket_instance.getsockname.return_value = ("", DEFAULT_HEAD_SERVER_PORT) # force conflict
404+
405+
monkeypatch.setattr(nemo_gym.global_config, "socket", socket_mock)
406+
407+
with raises(RuntimeError) as exc_info:
408+
find_open_port(head_server_host="127.0.0.1", head_server_port=DEFAULT_HEAD_SERVER_PORT, max_retries=5)
409+
410+
assert "Unable to find an open port" in str(exc_info.value)
411+
assert "after 5 attempts" in str(exc_info.value)
412+
assert socket_instance.getsockname.call_count == 5

0 commit comments

Comments
 (0)