Commit 8ced576

ItsMrLin authored and meta-codesync[bot] committed
Continuous relaxation with fallback for inequality-constrained ordinal dims (#3261)
Summary:
Pull Request resolved: #3261

`_setup_continuous_relaxation` in `optimize_mixed.py` blanket-excludes all constrained discrete dimensions from continuous relaxation, forcing them into discrete local search even when they have high cardinality. This is overly conservative for inequality constraints and causes severe performance degradation.

**Problem:** When ordinal parameters (e.g., integers 0-50) participate in linear inequality constraints (e.g., `x1 + x2 + x3 <= 100`), they are kept as discrete dims regardless of cardinality. In mixed search spaces, this inflates the discrete combination count (e.g., 51^4 x 20 = 135M), forces the use of `optimize_acqf_mixed_alternating`, and, with default optimizer budgets (`raw_samples=1024`, `maxiter_init=100`, `maxiter_alternating=64`) across many sequential candidates, produces ~900K+ acquisition function evaluations, taking hours instead of minutes.

**Fix:** Try continuous relaxation first for inequality-constrained dims, with automatic fallback to keeping them discrete if infeasible candidates result. Specifically, `optimize_acqf_mixed_alternating` now:

1. **Fast path**: Calls `_setup_continuous_relaxation` with `inequality_constraints=None`, allowing inequality-constrained dims to be relaxed and optimized continuously.
2. **Feasibility check**: After optimization, checks whether the candidates satisfy all constraints via `evaluate_feasibility`.
3. **Fallback**: If any candidates are infeasible (e.g., due to rounding violations with non-contiguous discrete choices or tight constraints), re-runs with inequality-constrained dims kept discrete.

`_setup_continuous_relaxation` itself retains the D94963154 behavior of excluding all constrained dims passed to it; the caller controls which constraints are relevant by choosing what to pass. The optimization body is extracted into `_run_alternating_optimization` to enable the fallback without code duplication.

Differential Revision: D99304800
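For concreteness, here is a minimal sketch of the scenario the summary describes. Everything below is hypothetical (toy data, a stand-in `SingleTaskGP`/`LogExpectedImprovement` model, made-up dimensions), not code from this PR; the dict form of `discrete_dims` assumes the API shown in the diff below.

```python
import torch
from botorch.acquisition import LogExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating

# Hypothetical search space: dims 0-3 are ordinal integers in [0, 50],
# dim 4 is continuous. Training data is a toy stand-in.
train_X = torch.rand(10, 5, dtype=torch.float64) * 50
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
acqf = LogExpectedImprovement(model=model, best_f=train_Y.max())

bounds = torch.tensor([[0.0] * 5, [50.0] * 5], dtype=torch.float64)
ordinal_values = [float(v) for v in range(51)]  # cardinality 51 per dim

# x0 + x1 + x2 <= 100, written in BoTorch's ">= rhs" convention as
# -x0 - x1 - x2 >= -100 via an (indices, coefficients, rhs) tuple.
ineq = [(torch.tensor([0, 1, 2]), -torch.ones(3, dtype=torch.float64), -100.0)]

# Previously the constrained ordinal dims always stayed discrete, inflating
# the discrete search; with this commit they are relaxed first and only kept
# discrete if the relaxed solution turns out infeasible.
candidate, acq_value = optimize_acqf_mixed_alternating(
    acq_function=acqf,
    bounds=bounds,
    discrete_dims={i: ordinal_values for i in range(4)},
    q=1,
    inequality_constraints=ineq,
)
```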
1 parent b16b28f commit 8ced576

2 files changed

Lines changed: 341 additions & 104 deletions

File tree

botorch/optim/optimize_mixed.py

Lines changed: 202 additions & 102 deletions
```diff
@@ -130,11 +130,14 @@ def _setup_continuous_relaxation(
     ``discrete_dims`` and ``post_processing_func`` is updated to round
     them to the nearest integer.
 
-    Dimensions that participate in constraints are NOT relaxed, as rounding
-    after projection could violate those constraints.
+    Dimensions that participate in the specified constraints are NOT
+    relaxed, as rounding after continuous optimization could violate
+    those constraints. The caller controls which constraints are relevant
+    by passing or omitting ``inequality_constraints`` and
+    ``equality_constraints``.
     """
 
-    # Identify dimensions involved in constraints
+    # Identify dimensions involved in the specified constraints.
    constrained_dims: set[int] = set()
    for constraints in [inequality_constraints, equality_constraints]:
        if constraints is not None:
```
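To make the caller contract concrete, here is a hedged sketch of how the two call patterns behave. The helper is private, the values are made up, and the relaxation threshold behavior is assumed from the docstring above.

```python
import torch
# Private helper; importing it directly is for illustration only.
from botorch.optim.optimize_mixed import _setup_continuous_relaxation

# Fast path (hypothetical values): with inequality_constraints=None, the
# helper sees no constrained dims, so a cardinality-51 integer dim exceeds
# max_discrete_values=20 and is relaxed to a continuous dim.
relaxed_dims, ppf = _setup_continuous_relaxation(
    discrete_dims={0: [float(v) for v in range(51)]},
    max_discrete_values=20,
    post_processing_func=None,
    inequality_constraints=None,  # caller withholds the constraints
    equality_constraints=None,
)
assert 0 not in relaxed_dims  # dim 0 relaxed; ppf rounds it back to an integer

# Fallback path: passing the real constraints excludes dim 0 from relaxation,
# so it stays discrete.
kept_dims, _ = _setup_continuous_relaxation(
    discrete_dims={0: [float(v) for v in range(51)]},
    max_discrete_values=20,
    post_processing_func=None,
    inequality_constraints=[(torch.tensor([0]), torch.tensor([-1.0]), -40.0)],
    equality_constraints=None,
)
assert 0 in kept_dims
```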
```diff
@@ -905,6 +908,140 @@ def continuous_step(
     return best_X.view_as(current_x), best_acq_values
 
 
+def _run_alternating_optimization(
+    opt_inputs: OptimizeAcqfInputs,
+    discrete_dims: dict[int, list[float]],
+    cat_dims: dict[int, list[float]],
+    return_acq_values: bool,
+) -> tuple[Tensor, Tensor | None]:
+    r"""Run the alternating discrete/continuous optimization loop.
+
+    This is the core optimization routine used by
+    ``optimize_acqf_mixed_alternating``. It handles fixed feature filtering,
+    starting point generation, the alternating optimization loop, and
+    post-processing.
+
+    Args:
+        opt_inputs: Common set of arguments for acquisition optimization.
+        discrete_dims: A dictionary mapping indices of discrete dimensions
+            to a list of allowed values, after continuous relaxation.
+        cat_dims: A dictionary mapping indices of categorical dimensions
+            to a list of allowed values.
+        return_acq_values: Whether to return acquisition values.
+
+    Returns:
+        A tuple of (candidates, acq_values_or_none).
+    """
+    acq_function = opt_inputs.acq_function
+    bounds = opt_inputs.bounds
+    q = opt_inputs.q
+    options = opt_inputs.options or {}
+    fixed_features = opt_inputs.fixed_features or {}
+    post_processing_func = opt_inputs.post_processing_func
+
+    base_X_pending = acq_function.X_pending if q > 1 else None
+    dim = bounds.shape[-1]
+    tkwargs: dict[str, Any] = {"device": bounds.device, "dtype": bounds.dtype}
+    # Remove fixed features from dims, so they don't get optimized.
+    discrete_dims = {
+        dim: values
+        for dim, values in discrete_dims.items()
+        if dim not in fixed_features
+    }
+    cat_dims = {
+        dim: values for dim, values in cat_dims.items() if dim not in fixed_features
+    }
+    non_cont_dims = [*discrete_dims.keys(), *cat_dims.keys()]
+    if len(non_cont_dims) == 0:
+        # If the problem is fully continuous, fall back to standard optimization.
+        return _optimize_acqf(
+            opt_inputs=dataclasses.replace(
+                opt_inputs,
+                return_best_only=True,
+                return_acq_values=return_acq_values,
+            )
+        )
+    if not (
+        isinstance(non_cont_dims, list)
+        and len(set(non_cont_dims)) == len(non_cont_dims)
+        and min(non_cont_dims) >= 0
+        and max(non_cont_dims) <= dim - 1
+    ):
+        raise ValueError(
+            "`discrete_dims` and `cat_dims` must be dictionaries with unique, disjoint "
+            "integers as keys between 0 and num_dims - 1."
+        )
+    discrete_dims_t = torch.tensor(
+        list(discrete_dims.keys()), dtype=torch.long, device=tkwargs["device"]
+    )
+    cat_dims_t = torch.tensor(
+        list(cat_dims.keys()), dtype=torch.long, device=tkwargs["device"]
+    )
+    non_cont_dims = torch.tensor(
+        non_cont_dims, dtype=torch.long, device=tkwargs["device"]
+    )
+    cont_dims = complement_indices_like(indices=non_cont_dims, d=dim)
+    # Fixed features are all in cont_dims. Remove them, so they don't get optimized.
+    ff_idcs = torch.tensor(
+        list(fixed_features.keys()), dtype=torch.long, device=tkwargs["device"]
+    )
+    cont_dims = cont_dims[(cont_dims.unsqueeze(-1) != ff_idcs).all(dim=-1)]
+    candidates = torch.empty(0, dim, **tkwargs)
+    for _q in range(q):
+        # Generate starting points.
+        best_X, best_acq_val = generate_starting_points(
+            opt_inputs=opt_inputs,
+            discrete_dims=discrete_dims,
+            cat_dims=cat_dims,
+            cont_dims=cont_dims,
+        )
+
+        done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
+        for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
+            starting_acq_val = best_acq_val.clone()
+            best_X[~done], best_acq_val[~done] = discrete_step(
+                opt_inputs=opt_inputs,
+                discrete_dims=discrete_dims,
+                cat_dims=cat_dims,
+                current_x=best_X[~done],
+            )
+
+            best_X[~done], best_acq_val[~done] = continuous_step(
+                opt_inputs=opt_inputs,
+                discrete_dims=discrete_dims_t,
+                cat_dims=cat_dims_t,
+                current_x=best_X[~done],
+            )
+
+            improvement = best_acq_val - starting_acq_val
+            done_now = improvement < options.get("tol", CONVERGENCE_TOL)
+            done = done | done_now
+            if done.float().mean() >= STOP_AFTER_SHARE_CONVERGED:
+                break
+
+        new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
+        candidates = torch.cat([candidates, new_candidate], dim=-2)
+        # Update pending points to include the new candidate.
+        if q > 1:
+            acq_function.set_X_pending(
+                torch.cat([base_X_pending, candidates], dim=-2)
+                if base_X_pending is not None
+                else candidates
+            )
+    if q > 1:
+        acq_function.set_X_pending(base_X_pending)
+
+    if post_processing_func is not None:
+        candidates = post_processing_func(candidates)
+
+    if not return_acq_values:
+        return candidates, None
+
+    with torch.no_grad():
+        acq_value = acq_function(candidates)  # compute joint acquisition value
+    return candidates, acq_value
+
+
 def optimize_acqf_mixed_alternating(
     acq_function: AcquisitionFunction,
     bounds: Tensor,
```
```diff
@@ -1049,18 +1186,38 @@ def optimize_acqf_mixed_alternating(
             "of freedom."
         )
 
-    # Update discrete dims and post processing functions to account for any
-    # dimensions that should be using continuous relaxation.
+    # Save pre-relaxation state for potential fallback.
+    _pre_relaxation_discrete_dims = discrete_dims
+    _original_ppf = post_processing_func
+
+    # Identify inequality-constrained discrete dims.
+    _ineq_dim_indices: set[int] = set()
+    if inequality_constraints is not None:
+        for indices, _, _ in inequality_constraints:
+            _ineq_dim_indices.update(indices.tolist())
+
+    max_discrete_values = assert_is_instance(
+        options.get("max_discrete_values", MAX_DISCRETE_VALUES), int
+    )
+
+    # First attempt: relax inequality-constrained dims for performance.
+    # By passing inequality_constraints=None, these dims are not excluded
+    # from continuous relaxation, allowing them to be optimized continuously.
+    # If this produces infeasible candidates (e.g., due to rounding violations
+    # with non-contiguous choices), we fall back to keeping them discrete.
     discrete_dims, post_processing_func = _setup_continuous_relaxation(
         discrete_dims=discrete_dims,
-        max_discrete_values=assert_is_instance(
-            options.get("max_discrete_values", MAX_DISCRETE_VALUES), int
-        ),
+        max_discrete_values=max_discrete_values,
         post_processing_func=post_processing_func,
-        inequality_constraints=inequality_constraints,
+        inequality_constraints=None,
         equality_constraints=equality_constraints,
     )
 
+    # Track whether any inequality-constrained dims were actually relaxed.
+    _ineq_dims_relaxed = (
+        _ineq_dim_indices & set(_pre_relaxation_discrete_dims.keys())
+    ) - set(discrete_dims.keys())
+
     opt_inputs = OptimizeAcqfInputs(
         acq_function=acq_function,
         bounds=bounds,
```
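As a worked example of the bookkeeping above (hypothetical numbers): each inequality constraint is an `(indices, coefficients, rhs)` tuple meaning `sum_i coefficients[i] * X[indices[i]] >= rhs`, so collecting each tuple's `indices` yields the inequality-constrained dims.

```python
import torch

# Hypothetical: x0 + x1 + x2 <= 100 in ">= rhs" form.
inequality_constraints = [(torch.tensor([0, 1, 2]), -torch.ones(3), -100.0)]

_ineq_dim_indices: set[int] = set()
for indices, _, _ in inequality_constraints:
    _ineq_dim_indices.update(indices.tolist())  # {0, 1, 2}

# Suppose dims {0, 1, 2, 3} were discrete before relaxation and only {3}
# survived it. The constrained dims that actually got relaxed are then
# ({0, 1, 2} & {0, 1, 2, 3}) - {3} == {0, 1, 2}, so the feasibility
# check (and possible fallback) below is armed.
```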
```diff
@@ -1088,106 +1245,49 @@ def optimize_acqf_mixed_alternating(
             opt_inputs=dataclasses.replace(opt_inputs, return_best_only=True)
         )
 
-    base_X_pending = acq_function.X_pending if q > 1 else None
-    dim = bounds.shape[-1]
-    tkwargs: dict[str, Any] = {"device": bounds.device, "dtype": bounds.dtype}
-    # Remove fixed features from dims, so they don't get optimized.
-    discrete_dims = {
-        dim: values
-        for dim, values in discrete_dims.items()
-        if dim not in fixed_features
-    }
-    cat_dims = {
-        dim: values for dim, values in cat_dims.items() if dim not in fixed_features
-    }
-    non_cont_dims = [*discrete_dims.keys(), *cat_dims.keys()]
-    if len(non_cont_dims) == 0:
-        # If the problem is fully continuous, fall back to standard optimization.
-        return _optimize_acqf(
-            opt_inputs=dataclasses.replace(
-                opt_inputs,
-                return_best_only=True,
-                return_acq_values=return_acq_values,
-            )
-        )
-    if not (
-        isinstance(non_cont_dims, list)
-        and len(set(non_cont_dims)) == len(non_cont_dims)
-        and min(non_cont_dims) >= 0
-        and max(non_cont_dims) <= dim - 1
-    ):
-        raise ValueError(
-            "`discrete_dims` and `cat_dims` must be dictionaries with unique, disjoint "
-            "integers as keys between 0 and num_dims - 1."
-        )
-    discrete_dims_t = torch.tensor(
-        list(discrete_dims.keys()), dtype=torch.long, device=tkwargs["device"]
-    )
-    cat_dims_t = torch.tensor(
-        list(cat_dims.keys()), dtype=torch.long, device=tkwargs["device"]
-    )
-    non_cont_dims = torch.tensor(
-        non_cont_dims, dtype=torch.long, device=tkwargs["device"]
-    )
-    cont_dims = complement_indices_like(indices=non_cont_dims, d=dim)
-    # Fixed features are all in cont_dims. Remove them, so they don't get optimized.
-    ff_idcs = torch.tensor(
-        list(fixed_features.keys()), dtype=torch.long, device=tkwargs["device"]
+    candidates, acq_value = _run_alternating_optimization(
+        opt_inputs=opt_inputs,
+        discrete_dims=discrete_dims,
+        cat_dims=cat_dims,
+        return_acq_values=return_acq_values,
     )
-    cont_dims = cont_dims[(cont_dims.unsqueeze(-1) != ff_idcs).all(dim=-1)]
-    candidates = torch.empty(0, dim, **tkwargs)
-    for _q in range(q):
-        # Generate starting points.
-        best_X, best_acq_val = generate_starting_points(
-            opt_inputs=opt_inputs,
-            discrete_dims=discrete_dims,
-            cat_dims=cat_dims,
-            cont_dims=cont_dims,
-        )
 
-        done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
-        for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
-            starting_acq_val = best_acq_val.clone()
-            best_X[~done], best_acq_val[~done] = discrete_step(
+    # Fallback: if continuous relaxation of inequality-constrained dims
+    # produced infeasible candidates (e.g., due to rounding violations with
+    # non-contiguous discrete choices or tight constraints), re-run with
+    # those dims kept discrete.
+    if _ineq_dims_relaxed:
+        is_feasible = evaluate_feasibility(
+            X=candidates.unsqueeze(-2),
+            inequality_constraints=inequality_constraints,
+            equality_constraints=equality_constraints,
+            nonlinear_inequality_constraints=None,
+        )
+        if not is_feasible.all():
+            warnings.warn(
+                "Continuous relaxation of inequality-constrained discrete "
+                "dims produced infeasible candidates. Retrying without "
+                "relaxation for constrained dims.",
+                OptimizationWarning,
+                stacklevel=2,
+            )
+            discrete_dims, post_processing_func = _setup_continuous_relaxation(
+                discrete_dims=_pre_relaxation_discrete_dims,
+                max_discrete_values=max_discrete_values,
+                post_processing_func=_original_ppf,
+                inequality_constraints=inequality_constraints,
+                equality_constraints=equality_constraints,
+            )
+            opt_inputs = dataclasses.replace(
+                opt_inputs, post_processing_func=post_processing_func
+            )
+            candidates, acq_value = _run_alternating_optimization(
                 opt_inputs=opt_inputs,
                 discrete_dims=discrete_dims,
                 cat_dims=cat_dims,
-                current_x=best_X[~done],
-            )
-
-            best_X[~done], best_acq_val[~done] = continuous_step(
-                opt_inputs=opt_inputs,
-                discrete_dims=discrete_dims_t,
-                cat_dims=cat_dims_t,
-                current_x=best_X[~done],
-            )
-
-            improvement = best_acq_val - starting_acq_val
-            done_now = improvement < options.get("tol", CONVERGENCE_TOL)
-            done = done | done_now
-            if done.float().mean() >= STOP_AFTER_SHARE_CONVERGED:
-                break
-
-        new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
-        candidates = torch.cat([candidates, new_candidate], dim=-2)
-        # Update pending points to include the new candidate.
-        if q > 1:
-            acq_function.set_X_pending(
-                torch.cat([base_X_pending, candidates], dim=-2)
-                if base_X_pending is not None
-                else candidates
+                return_acq_values=return_acq_values,
             )
-    if q > 1:
-        acq_function.set_X_pending(base_X_pending)
-
-    if post_processing_func is not None:
-        candidates = post_processing_func(candidates)
 
-    if not return_acq_values:
-        return candidates, None
-
-    with torch.no_grad():
-        acq_value = acq_function(candidates)  # compute joint acquisition value
     return candidates, acq_value
 
```
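A small illustration of what trips the fallback (hypothetical candidate; the import path for `evaluate_feasibility` is assumed to be `botorch.optim.parameter_constraints`, which the diff does not show):

```python
import torch
from botorch.optim.parameter_constraints import evaluate_feasibility

# A rounded candidate violating x0 + x1 + x2 <= 100: 40 + 35 + 30 = 105.
# Candidates of shape (num_candidates, d) are lifted to (num_candidates, 1, d),
# mirroring the `candidates.unsqueeze(-2)` call in the diff.
candidates = torch.tensor([[40.0, 35.0, 30.0]])
is_feasible = evaluate_feasibility(
    X=candidates.unsqueeze(-2),
    inequality_constraints=[(torch.tensor([0, 1, 2]), -torch.ones(3), -100.0)],
)
print(is_feasible)  # tensor([False]); the re-run without relaxation kicks in
```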