|
18 | 18 | import nemo_gym.global_config |
19 | 19 | import nemo_gym.server_utils |
20 | 20 | from nemo_gym.global_config import ( |
| 21 | + DEFAULT_HEAD_SERVER_PORT, |
21 | 22 | NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME, |
| 23 | + find_open_port, |
22 | 24 | get_first_server_config_dict, |
23 | 25 | get_global_config_dict, |
24 | 26 | ) |
@@ -371,3 +373,40 @@ def test_get_first_server_config_dict(self) -> None: |
371 | 373 | } |
372 | 374 | ) |
373 | 375 | assert {"my_key": "my_value"} == get_first_server_config_dict(global_config_dict, "a") |
| 376 | + |
| 377 | + def test_find_open_port_avoids_head_server_port(self, monkeypatch: MonkeyPatch) -> None: |
| 378 | + """Test that find_open_port retries when the head server port is returned.""" |
| 379 | + socket_mock = MagicMock() |
| 380 | + socket_instance = MagicMock() |
| 381 | + socket_mock.return_value.__enter__ = MagicMock(return_value=socket_instance) |
| 382 | + socket_mock.return_value.__exit__ = MagicMock(return_value=False) |
| 383 | + |
| 384 | + socket_instance.getsockname.side_effect = [ |
| 385 | + ("", DEFAULT_HEAD_SERVER_PORT), # first attempt: 11000 (conflict) |
| 386 | + ("", 12345), # second attempt (safe) |
| 387 | + ] |
| 388 | + |
| 389 | + monkeypatch.setattr(nemo_gym.global_config, "socket", socket_mock) |
| 390 | + |
| 391 | + port = find_open_port(head_server_host="127.0.0.1", head_server_port=DEFAULT_HEAD_SERVER_PORT) |
| 392 | + |
| 393 | + assert port == 12345 |
| 394 | + assert socket_instance.getsockname.call_count == 2 # first: conflict, second: success |
| 395 | + |
| 396 | + def test_find_open_port_raises_after_max_retries(self, monkeypatch: MonkeyPatch) -> None: |
| 397 | + """Test that find_open_port raises RuntimeError after exhausting retries.""" |
| 398 | + socket_mock = MagicMock() |
| 399 | + socket_instance = MagicMock() |
| 400 | + socket_mock.return_value.__enter__ = MagicMock(return_value=socket_instance) |
| 401 | + socket_mock.return_value.__exit__ = MagicMock(return_value=False) |
| 402 | + |
| 403 | + socket_instance.getsockname.return_value = ("", DEFAULT_HEAD_SERVER_PORT) # force conflict |
| 404 | + |
| 405 | + monkeypatch.setattr(nemo_gym.global_config, "socket", socket_mock) |
| 406 | + |
| 407 | + with raises(RuntimeError) as exc_info: |
| 408 | + find_open_port(head_server_host="127.0.0.1", head_server_port=DEFAULT_HEAD_SERVER_PORT, max_retries=5) |
| 409 | + |
| 410 | + assert "Unable to find an open port" in str(exc_info.value) |
| 411 | + assert "after 5 attempts" in str(exc_info.value) |
| 412 | + assert socket_instance.getsockname.call_count == 5 |
0 commit comments