@@ -121,7 +121,11 @@ def filter_for_server_instance_configs(self, global_config_dict: DictConfig) ->
121121 return server_instance_configs
122122
123123 def validate_and_populate_defaults (
124- self , server_instance_configs : List [ServerInstanceConfig ], default_host : str
124+ self ,
125+ server_instance_configs : List [ServerInstanceConfig ],
126+ default_host : str ,
127+ head_server_host : Optional [str ] = None ,
128+ head_server_port : Optional [int ] = None ,
125129 ) -> None :
126130 server_refs = [c .get_server_ref () for c in server_instance_configs ]
127131 for server_instance_config in server_instance_configs :
@@ -142,7 +146,10 @@ def validate_and_populate_defaults(
142146 if not run_server_config_dict .get ("host" ):
143147 run_server_config_dict ["host" ] = default_host
144148 if not run_server_config_dict .get ("port" ):
145- run_server_config_dict ["port" ] = find_open_port ()
149+ run_server_config_dict ["port" ] = find_open_port (
150+ head_server_port = head_server_port ,
151+ head_server_host = head_server_host ,
152+ )
146153
147154 def parse (self , parse_config : Optional [GlobalConfigDictParserConfig ] = None ) -> DictConfig :
148155 if parse_config is None :
@@ -186,7 +193,15 @@ def parse(self, parse_config: Optional[GlobalConfigDictParserConfig] = None) ->
186193 # Do one pass through all the configs validate and populate various configs for our servers.
187194 default_host = global_config_dict .get (DEFAULT_HOST_KEY_NAME ) or "127.0.0.1"
188195
189- self .validate_and_populate_defaults (server_instance_configs , default_host )
196+ head_server_config = global_config_dict .get (HEAD_SERVER_KEY_NAME )
197+ if head_server_config :
198+ head_server_host = head_server_config .get ("host" ) or default_host
199+ head_server_port = head_server_config .get ("port" ) or DEFAULT_HEAD_SERVER_PORT
200+ else :
201+ head_server_host = default_host
202+ head_server_port = DEFAULT_HEAD_SERVER_PORT
203+
204+ self .validate_and_populate_defaults (server_instance_configs , default_host , head_server_host , head_server_port )
190205
191206 # Populate head server defaults
192207 if not global_config_dict .get (HEAD_SERVER_KEY_NAME ):
@@ -261,7 +276,21 @@ def get_first_server_config_dict(global_config_dict: DictConfig, top_level_path:
261276 return server_config_dict
262277
263278
264- def find_open_port () -> int : # pragma: no cover
265- with socket () as s :
266- s .bind (("" , 0 )) # Bind to a free port provided by the host.
267- return s .getsockname ()[1 ] # Return the port number assigned.
279+ def find_open_port (
280+ head_server_host : Optional [str ] = None ,
281+ head_server_port : Optional [int ] = None ,
282+ max_retries : int = 50 ,
283+ ) -> int : # pragma: no cover
284+ # Find an open port that doesn't conflict with the head server.
285+ for _ in range (max_retries ):
286+ with socket () as s :
287+ s .bind (("" , 0 )) # Bind to a free port provided by the host.
288+ port = s .getsockname ()[1 ]
289+
290+ if head_server_port is None or port != head_server_port :
291+ return port # Return the port number assigned.
292+
293+ raise RuntimeError (
294+ f"Unable to find an open port that doesn't conflict with head server "
295+ f"({ head_server_host } :{ head_server_port } ) after { max_retries } attempts"
296+ )
0 commit comments