@@ -130,11 +130,14 @@ def _setup_continuous_relaxation(
     ``discrete_dims`` and ``post_processing_func`` is updated to round
     them to the nearest integer.
 
-    Dimensions that participate in constraints are NOT relaxed, as rounding
-    after projection could violate those constraints.
+    Dimensions that participate in the specified constraints are NOT
+    relaxed, as rounding after continuous optimization could violate
+    those constraints. The caller controls which constraints are relevant
+    by passing or omitting ``inequality_constraints`` and
+    ``equality_constraints``.
     """
 
-    # Identify dimensions involved in constraints
+    # Identify dimensions involved in the specified constraints.
     constrained_dims: set[int] = set()
     for constraints in [inequality_constraints, equality_constraints]:
         if constraints is not None:
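To see why rounding after continuous relaxation can break a constraint, here is a minimal sketch (the specific constraint and values are illustrative, not taken from the diff). It uses the same (indices, coefficients, rhs) convention iterated over above, meaning sum_i coefficients[i] * x[indices[i]] >= rhs:

import torch

# BoTorch-style linear inequality constraint: 2 * x[0] + 1 * x[1] >= 5.
indices, coefficients, rhs = torch.tensor([0, 1]), torch.tensor([2.0, 1.0]), 5.0

x_relaxed = torch.tensor([1.4, 2.3])  # continuous optimum: 2 * 1.4 + 2.3 = 5.1, feasible
x_rounded = x_relaxed.round()         # rounding to the integer grid gives [1., 2.]
print((coefficients * x_relaxed[indices]).sum() >= rhs)  # tensor(True)
print((coefficients * x_rounded[indices]).sum() >= rhs)  # tensor(False): violated after rounding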
@@ -905,6 +908,140 @@ def continuous_step(
     return best_X.view_as(current_x), best_acq_values
 
 
+def _run_alternating_optimization(
+    opt_inputs: OptimizeAcqfInputs,
+    discrete_dims: dict[int, list[float]],
+    cat_dims: dict[int, list[float]],
+    return_acq_values: bool,
+) -> tuple[Tensor, Tensor | None]:
+    r"""Run the alternating discrete/continuous optimization loop.
+
+    This is the core optimization routine used by
+    ``optimize_acqf_mixed_alternating``. It handles fixed feature filtering,
+    starting point generation, the alternating optimization loop, and
+    post-processing.
+
+    Args:
+        opt_inputs: Common set of arguments for acquisition optimization.
+        discrete_dims: A dictionary mapping indices of discrete dimensions
+            to a list of allowed values, after continuous relaxation.
+        cat_dims: A dictionary mapping indices of categorical dimensions
+            to a list of allowed values.
+        return_acq_values: Whether to return acquisition values.
+
+    Returns:
+        A tuple of (candidates, acq_values_or_none).
+    """
+    acq_function = opt_inputs.acq_function
+    bounds = opt_inputs.bounds
+    q = opt_inputs.q
+    options = opt_inputs.options or {}
+    fixed_features = opt_inputs.fixed_features or {}
+    post_processing_func = opt_inputs.post_processing_func
+
+    base_X_pending = acq_function.X_pending if q > 1 else None
+    dim = bounds.shape[-1]
+    tkwargs: dict[str, Any] = {"device": bounds.device, "dtype": bounds.dtype}
+    # Remove fixed features from dims, so they don't get optimized.
+    discrete_dims = {
+        dim: values
+        for dim, values in discrete_dims.items()
+        if dim not in fixed_features
+    }
+    cat_dims = {
+        dim: values for dim, values in cat_dims.items() if dim not in fixed_features
+    }
+    non_cont_dims = [*discrete_dims.keys(), *cat_dims.keys()]
+    if len(non_cont_dims) == 0:
+        # If the problem is fully continuous, fall back to standard optimization.
+        return _optimize_acqf(
+            opt_inputs=dataclasses.replace(
+                opt_inputs,
+                return_best_only=True,
+                return_acq_values=return_acq_values,
+            )
+        )
+    if not (
+        isinstance(non_cont_dims, list)
+        and len(set(non_cont_dims)) == len(non_cont_dims)
+        and min(non_cont_dims) >= 0
+        and max(non_cont_dims) <= dim - 1
+    ):
+        raise ValueError(
+            "`discrete_dims` and `cat_dims` must be dictionaries with unique, disjoint "
+            "integers as keys between 0 and num_dims - 1."
+        )
+    discrete_dims_t = torch.tensor(
+        list(discrete_dims.keys()), dtype=torch.long, device=tkwargs["device"]
+    )
+    cat_dims_t = torch.tensor(
+        list(cat_dims.keys()), dtype=torch.long, device=tkwargs["device"]
+    )
+    non_cont_dims = torch.tensor(
+        non_cont_dims, dtype=torch.long, device=tkwargs["device"]
+    )
+    cont_dims = complement_indices_like(indices=non_cont_dims, d=dim)
+    # Fixed features are all in cont_dims. Remove them, so they don't get optimized.
+    ff_idcs = torch.tensor(
+        list(fixed_features.keys()), dtype=torch.long, device=tkwargs["device"]
+    )
+    cont_dims = cont_dims[(cont_dims.unsqueeze(-1) != ff_idcs).all(dim=-1)]
+    candidates = torch.empty(0, dim, **tkwargs)
+    for _q in range(q):
+        # Generate starting points.
+        best_X, best_acq_val = generate_starting_points(
+            opt_inputs=opt_inputs,
+            discrete_dims=discrete_dims,
+            cat_dims=cat_dims,
+            cont_dims=cont_dims,
+        )
+
+        done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
+        for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
+            starting_acq_val = best_acq_val.clone()
+            best_X[~done], best_acq_val[~done] = discrete_step(
+                opt_inputs=opt_inputs,
+                discrete_dims=discrete_dims,
+                cat_dims=cat_dims,
+                current_x=best_X[~done],
+            )
+
+            best_X[~done], best_acq_val[~done] = continuous_step(
+                opt_inputs=opt_inputs,
+                discrete_dims=discrete_dims_t,
+                cat_dims=cat_dims_t,
+                current_x=best_X[~done],
+            )
+
+            improvement = best_acq_val - starting_acq_val
+            done_now = improvement < options.get("tol", CONVERGENCE_TOL)
+            done = done | done_now
+            if done.float().mean() >= STOP_AFTER_SHARE_CONVERGED:
+                break
+
+        new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
+        candidates = torch.cat([candidates, new_candidate], dim=-2)
+        # Update pending points to include the new candidate.
+        if q > 1:
+            acq_function.set_X_pending(
+                torch.cat([base_X_pending, candidates], dim=-2)
+                if base_X_pending is not None
+                else candidates
+            )
+    if q > 1:
+        acq_function.set_X_pending(base_X_pending)
+
+    if post_processing_func is not None:
+        candidates = post_processing_func(candidates)
+
+    if not return_acq_values:
+        return candidates, None
+
+    with torch.no_grad():
+        acq_value = acq_function(candidates)  # compute joint acquisition value
+    return candidates, acq_value
+
+
 def optimize_acqf_mixed_alternating(
     acq_function: AcquisitionFunction,
     bounds: Tensor,
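For orientation, a minimal usage sketch of the public entry point whose signature begins above. The model, training data, bounds, and allowed values are placeholders, and the import path plus the dict-valued discrete_dims/cat_dims format are assumed to match the version this PR targets:

import torch
from botorch.acquisition import qLogExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating

# Toy 4-d problem: dims 0-1 continuous, dim 2 discrete, dim 3 categorical.
train_X = torch.rand(8, 4, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)
acqf = qLogExpectedImprovement(model=SingleTaskGP(train_X, train_Y), best_f=train_Y.max())

bounds = torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 5.0, 2.0]], dtype=torch.float64)
candidates, acq_value = optimize_acqf_mixed_alternating(
    acq_function=acqf,
    bounds=bounds,
    discrete_dims={2: [0.0, 1.0, 3.0, 5.0]},  # index -> allowed values
    cat_dims={3: [0.0, 1.0, 2.0]},
    q=2,
    options={"maxiter_alternating": 64},
)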
@@ -1049,18 +1186,38 @@ def optimize_acqf_mixed_alternating(
10491186 "of freedom."
10501187 )
10511188
1052- # Update discrete dims and post processing functions to account for any
1053- # dimensions that should be using continuous relaxation.
1189+ # Save pre-relaxation state for potential fallback.
1190+ _pre_relaxation_discrete_dims = discrete_dims
1191+ _original_ppf = post_processing_func
1192+
1193+ # Identify inequality-constrained discrete dims.
1194+ _ineq_dim_indices : set [int ] = set ()
1195+ if inequality_constraints is not None :
1196+ for indices , _ , _ in inequality_constraints :
1197+ _ineq_dim_indices .update (indices .tolist ())
1198+
1199+ max_discrete_values = assert_is_instance (
1200+ options .get ("max_discrete_values" , MAX_DISCRETE_VALUES ), int
1201+ )
1202+
1203+ # First attempt: relax inequality-constrained dims for performance.
1204+ # By passing inequality_constraints=None, these dims are not excluded
1205+ # from continuous relaxation, allowing them to be optimized continuously.
1206+ # If this produces infeasible candidates (e.g., due to rounding violations
1207+ # with non-contiguous choices), we fall back to keeping them discrete.
10541208 discrete_dims , post_processing_func = _setup_continuous_relaxation (
10551209 discrete_dims = discrete_dims ,
1056- max_discrete_values = assert_is_instance (
1057- options .get ("max_discrete_values" , MAX_DISCRETE_VALUES ), int
1058- ),
1210+ max_discrete_values = max_discrete_values ,
10591211 post_processing_func = post_processing_func ,
1060- inequality_constraints = inequality_constraints ,
1212+ inequality_constraints = None ,
10611213 equality_constraints = equality_constraints ,
10621214 )
10631215
1216+ # Track whether any inequality-constrained dims were actually relaxed.
1217+ _ineq_dims_relaxed = (
1218+ _ineq_dim_indices & set (_pre_relaxation_discrete_dims .keys ())
1219+ ) - set (discrete_dims .keys ())
1220+
10641221 opt_inputs = OptimizeAcqfInputs (
10651222 acq_function = acq_function ,
10661223 bounds = bounds ,
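The set arithmetic that defines _ineq_dims_relaxed can be illustrated with made-up values; a non-empty result is what later enables the feasibility fallback:

# Illustrative values only: suppose dims 0 and 2 appear in inequality constraints,
# and the relaxation above dropped dim 0 from discrete_dims (it had many values).
ineq_dim_indices = {0, 2}
pre_relaxation_discrete_dims = {0: list(range(50)), 2: [0, 1], 4: [1, 2]}
discrete_dims = {2: [0, 1], 4: [1, 2]}  # after relaxation with inequality_constraints=None
ineq_dims_relaxed = (
    ineq_dim_indices & set(pre_relaxation_discrete_dims.keys())
) - set(discrete_dims.keys())
print(ineq_dims_relaxed)  # {0} -> non-empty, so the candidates will be feasibility-checked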
@@ -1088,106 +1245,49 @@ def optimize_acqf_mixed_alternating(
             opt_inputs=dataclasses.replace(opt_inputs, return_best_only=True)
         )
 
-    base_X_pending = acq_function.X_pending if q > 1 else None
-    dim = bounds.shape[-1]
-    tkwargs: dict[str, Any] = {"device": bounds.device, "dtype": bounds.dtype}
-    # Remove fixed features from dims, so they don't get optimized.
-    discrete_dims = {
-        dim: values
-        for dim, values in discrete_dims.items()
-        if dim not in fixed_features
-    }
-    cat_dims = {
-        dim: values for dim, values in cat_dims.items() if dim not in fixed_features
-    }
-    non_cont_dims = [*discrete_dims.keys(), *cat_dims.keys()]
-    if len(non_cont_dims) == 0:
-        # If the problem is fully continuous, fall back to standard optimization.
-        return _optimize_acqf(
-            opt_inputs=dataclasses.replace(
-                opt_inputs,
-                return_best_only=True,
-                return_acq_values=return_acq_values,
-            )
-        )
-    if not (
-        isinstance(non_cont_dims, list)
-        and len(set(non_cont_dims)) == len(non_cont_dims)
-        and min(non_cont_dims) >= 0
-        and max(non_cont_dims) <= dim - 1
-    ):
-        raise ValueError(
-            "`discrete_dims` and `cat_dims` must be dictionaries with unique, disjoint "
-            "integers as keys between 0 and num_dims - 1."
-        )
-    discrete_dims_t = torch.tensor(
-        list(discrete_dims.keys()), dtype=torch.long, device=tkwargs["device"]
-    )
-    cat_dims_t = torch.tensor(
-        list(cat_dims.keys()), dtype=torch.long, device=tkwargs["device"]
-    )
-    non_cont_dims = torch.tensor(
-        non_cont_dims, dtype=torch.long, device=tkwargs["device"]
-    )
-    cont_dims = complement_indices_like(indices=non_cont_dims, d=dim)
-    # Fixed features are all in cont_dims. Remove them, so they don't get optimized.
-    ff_idcs = torch.tensor(
-        list(fixed_features.keys()), dtype=torch.long, device=tkwargs["device"]
+    candidates, acq_value = _run_alternating_optimization(
+        opt_inputs=opt_inputs,
+        discrete_dims=discrete_dims,
+        cat_dims=cat_dims,
+        return_acq_values=return_acq_values,
     )
-    cont_dims = cont_dims[(cont_dims.unsqueeze(-1) != ff_idcs).all(dim=-1)]
-    candidates = torch.empty(0, dim, **tkwargs)
-    for _q in range(q):
-        # Generate starting points.
-        best_X, best_acq_val = generate_starting_points(
-            opt_inputs=opt_inputs,
-            discrete_dims=discrete_dims,
-            cat_dims=cat_dims,
-            cont_dims=cont_dims,
-        )
 
-        done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
-        for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
-            starting_acq_val = best_acq_val.clone()
-            best_X[~done], best_acq_val[~done] = discrete_step(
+    # Fallback: if continuous relaxation of inequality-constrained dims
+    # produced infeasible candidates (e.g., due to rounding violations with
+    # non-contiguous discrete choices or tight constraints), re-run with
+    # those dims kept discrete.
+    if _ineq_dims_relaxed:
+        is_feasible = evaluate_feasibility(
+            X=candidates.unsqueeze(-2),
+            inequality_constraints=inequality_constraints,
+            equality_constraints=equality_constraints,
+            nonlinear_inequality_constraints=None,
+        )
+        if not is_feasible.all():
+            warnings.warn(
+                "Continuous relaxation of inequality-constrained discrete "
+                "dims produced infeasible candidates. Retrying without "
+                "relaxation for constrained dims.",
+                OptimizationWarning,
+                stacklevel=2,
+            )
+            discrete_dims, post_processing_func = _setup_continuous_relaxation(
+                discrete_dims=_pre_relaxation_discrete_dims,
+                max_discrete_values=max_discrete_values,
+                post_processing_func=_original_ppf,
+                inequality_constraints=inequality_constraints,
+                equality_constraints=equality_constraints,
+            )
+            opt_inputs = dataclasses.replace(
+                opt_inputs, post_processing_func=post_processing_func
+            )
+            candidates, acq_value = _run_alternating_optimization(
                 opt_inputs=opt_inputs,
                 discrete_dims=discrete_dims,
                 cat_dims=cat_dims,
-                current_x=best_X[~done],
-            )
-
-            best_X[~done], best_acq_val[~done] = continuous_step(
-                opt_inputs=opt_inputs,
-                discrete_dims=discrete_dims_t,
-                cat_dims=cat_dims_t,
-                current_x=best_X[~done],
-            )
-
-            improvement = best_acq_val - starting_acq_val
-            done_now = improvement < options.get("tol", CONVERGENCE_TOL)
-            done = done | done_now
-            if done.float().mean() >= STOP_AFTER_SHARE_CONVERGED:
-                break
-
-        new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
-        candidates = torch.cat([candidates, new_candidate], dim=-2)
-        # Update pending points to include the new candidate.
-        if q > 1:
-            acq_function.set_X_pending(
-                torch.cat([base_X_pending, candidates], dim=-2)
-                if base_X_pending is not None
-                else candidates
+                return_acq_values=return_acq_values,
             )
-    if q > 1:
-        acq_function.set_X_pending(base_X_pending)
-
-    if post_processing_func is not None:
-        candidates = post_processing_func(candidates)
 
-    if not return_acq_values:
-        return candidates, None
-
-    with torch.no_grad():
-        acq_value = acq_function(candidates)  # compute joint acquisition value
     return candidates, acq_value
 
 
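Finally, a plain-torch sketch of the feasibility screen that decides whether the fallback re-run happens, mirroring the intent of the evaluate_feasibility call above with illustrative candidates and constraints:

import torch

# Two candidates in a 3-d space and one constraint x[0] + x[1] >= 4.
candidates = torch.tensor([[1.0, 2.0, 0.5], [2.0, 3.0, 0.5]])
inequality_constraints = [(torch.tensor([0, 1]), torch.tensor([1.0, 1.0]), 4.0)]

feasible = torch.ones(candidates.shape[0], dtype=torch.bool)
for indices, coefficients, rhs in inequality_constraints:
    feasible &= (candidates[:, indices] * coefficients).sum(dim=-1) >= rhs
print(feasible)        # tensor([False,  True])
print(feasible.all())  # tensor(False) -> the retry without relaxation would kick in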