Skip to content

Commit 9e48bf7

Browse files
committed
Add retry logic to find_open_port
Signed-off-by: Frankie Siino <fsiino@nvidia.com>
1 parent 901164a commit 9e48bf7

1 file changed

Lines changed: 36 additions & 7 deletions

File tree

nemo_gym/global_config.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,11 @@ def filter_for_server_instance_configs(self, global_config_dict: DictConfig) ->
121121
return server_instance_configs
122122

123123
def validate_and_populate_defaults(
124-
self, server_instance_configs: List[ServerInstanceConfig], default_host: str
124+
self,
125+
server_instance_configs: List[ServerInstanceConfig],
126+
default_host: str,
127+
head_server_host: Optional[str] = None,
128+
head_server_port: Optional[int] = None,
125129
) -> None:
126130
server_refs = [c.get_server_ref() for c in server_instance_configs]
127131
for server_instance_config in server_instance_configs:
@@ -142,7 +146,10 @@ def validate_and_populate_defaults(
142146
if not run_server_config_dict.get("host"):
143147
run_server_config_dict["host"] = default_host
144148
if not run_server_config_dict.get("port"):
145-
run_server_config_dict["port"] = find_open_port()
149+
run_server_config_dict["port"] = find_open_port(
150+
head_server_port=head_server_port,
151+
head_server_host=head_server_host,
152+
)
146153

147154
def parse(self, parse_config: Optional[GlobalConfigDictParserConfig] = None) -> DictConfig:
148155
if parse_config is None:
@@ -186,7 +193,15 @@ def parse(self, parse_config: Optional[GlobalConfigDictParserConfig] = None) ->
186193
# Do one pass through all the configs validate and populate various configs for our servers.
187194
default_host = global_config_dict.get(DEFAULT_HOST_KEY_NAME) or "127.0.0.1"
188195

189-
self.validate_and_populate_defaults(server_instance_configs, default_host)
196+
head_server_config = global_config_dict.get(HEAD_SERVER_KEY_NAME)
197+
if head_server_config:
198+
head_server_host = head_server_config.get("host") or default_host
199+
head_server_port = head_server_config.get("port") or DEFAULT_HEAD_SERVER_PORT
200+
else:
201+
head_server_host = default_host
202+
head_server_port = DEFAULT_HEAD_SERVER_PORT
203+
204+
self.validate_and_populate_defaults(server_instance_configs, default_host, head_server_host, head_server_port)
190205

191206
# Populate head server defaults
192207
if not global_config_dict.get(HEAD_SERVER_KEY_NAME):
@@ -261,7 +276,21 @@ def get_first_server_config_dict(global_config_dict: DictConfig, top_level_path:
261276
return server_config_dict
262277

263278

264-
def find_open_port() -> int: # pragma: no cover
265-
with socket() as s:
266-
s.bind(("", 0)) # Bind to a free port provided by the host.
267-
return s.getsockname()[1] # Return the port number assigned.
279+
def find_open_port(
280+
head_server_host: Optional[str] = None,
281+
head_server_port: Optional[int] = None,
282+
max_retries: int = 50,
283+
) -> int: # pragma: no cover
284+
# Find an open port that doesn't conflict with the head server.
285+
for _ in range(max_retries):
286+
with socket() as s:
287+
s.bind(("", 0)) # Bind to a free port provided by the host.
288+
port = s.getsockname()[1]
289+
290+
if head_server_port is None or port != head_server_port:
291+
return port # Return the port number assigned.
292+
293+
raise RuntimeError(
294+
f"Unable to find an open port that doesn't conflict with head server "
295+
f"({head_server_host}:{head_server_port}) after {max_retries} attempts"
296+
)

0 commit comments

Comments
 (0)