
Commit 8a0cbef

Carl Hvarfner authored and meta-codesync[bot] committed
Add LearnedFeatureImputation input transform (meta-pytorch#3281)
Summary:
Pull Request resolved: meta-pytorch#3281
X-link: https://github.com/facebookexternal/botorch_fb/pull/32

Adds `LearnedFeatureImputation`, a composable `InputTransform` that learns imputation values for missing features in heterogeneous multi-task settings. This is the first step toward replacing the `HeterogeneousMTGP` model hierarchy (see review plan) with a standard `MultiTaskGP` plus composable input transforms.

**Design**:
- Expects normalized `[0, 1]` inputs; chain `Normalize` before this transform.
- Default bounds of `[0, 1]^d`; imputation values are initialized at the center (0.5).
- An `Interval` constraint keeps imputation values within bounds.
- Emits a `UserInputWarning` if non-default bounds are passed.
- Handles expansion from target-space to full-space during inference (when the target task has fewer features).
- `FillMissingParameters` (FMP) takes priority on the Ax side, so this transform only fills in the gaps *after* FMP has been applied.

The idea is to use this with heterogeneous search spaces, so that we can freely use `MultiTaskGP` and other models that are not specific to heterogeneous settings. The transform lives in `botorch_fb` while in the eval phase.

Reviewed By: saitcakmak

Differential Revision: D97625736

fbshipit-source-id: 9533bd635112890e5308c895292a538780e277eb
1 parent b3c6862 commit 8a0cbef
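
For orientation, here is a minimal sketch of the intended composition, assuming a hypothetical two-task setup over `d = 3` normalized features (the task layout, tensors, and values are illustrative, not taken from this diff):

```python
import torch
from botorch.models.transforms.input import LearnedFeatureImputation

# Hypothetical setup: task 0 observes all three feature columns, task 1
# observes only columns 0 and 1, so column 2 is imputed for task 1 rows.
tf = LearnedFeatureImputation(
    feature_indices={0: [0, 1, 2], 1: [0, 1]},
    d=3,
    target_task=0,
    bounds=torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], dtype=torch.float64),
)

# Inputs are normalized to [0, 1]; the last column holds the task id.
X = torch.tensor(
    [
        [0.2, 0.4, 0.9, 0.0],  # task 0 row: fully observed, left untouched
        [0.7, 0.1, 0.0, 1.0],  # task 1 row: column 2 replaced by a learned value
    ],
    dtype=torch.float64,
)
# With [0, 1] bounds, the Interval constraint maps the zero-initialized raw
# parameters to 0.5, so column 2 of the task-1 row starts at 0.5.
print(tf.transform(X))
```

In the intended end state, this transform would be chained after `Normalize` and passed to a standard `MultiTaskGP` as part of its input transform, per the summary above.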

3 files changed: 573 additions & 1 deletion

botorch/models/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from botorch.models.transforms.factory import get_rounding_input_transform
 from botorch.models.transforms.input import (
     ChainedInputTransform,
+    LearnedFeatureImputation,
     Normalize,
     Round,
     Warp,
@@ -25,6 +26,7 @@
     "Bilog",
     "ChainedInputTransform",
     "ChainedOutcomeTransform",
+    "LearnedFeatureImputation",
     "Log",
     "Normalize",
     "Power",

botorch/models/transforms/input.py

Lines changed: 220 additions & 1 deletion
@@ -35,7 +35,7 @@
 from botorch.models.utils import fantasize
 from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE
 from gpytorch import Module as GPyTorchModule
-from gpytorch.constraints import GreaterThan
+from gpytorch.constraints import GreaterThan, Interval
 from gpytorch.priors import Prior
 from torch import LongTensor, nn, Tensor
 from torch.nn import Module, ModuleDict
@@ -1915,3 +1915,222 @@ def equals(self, other: InputTransform) -> bool:
             and (self.transform_on_fantasize == other.transform_on_fantasize)
             and self.categorical_features == other.categorical_features
         )
+
+
+class LearnedFeatureImputation(InputTransform, GPyTorchModule):
+    r"""An input transform that learns imputation values for missing features
+    in heterogeneous multi-task settings.
+
+    In multi-task problems where different tasks observe different subsets of
+    features, this transform fills in the unobserved feature columns with
+    learned parameter values. This enables using a standard ``MultiTaskGP``
+    with composable input transforms instead of specialized model classes.
+
+    The input tensor ``X`` is expected to have shape ``batch_shape x n x (d+1)``,
+    where the column at ``task_feature_index`` contains the task identifier.
+    For each task, feature columns not listed in ``feature_indices[task_value]``
+    are replaced with the corresponding learned imputation values. Task values
+    need not be contiguous or 0-indexed.
+    """
+
+    def __init__(
+        self,
+        feature_indices: dict[int, list[int]],
+        d: int,
+        task_feature_index: int = -1,
+        target_task: int | None = None,
+        bounds: Tensor | None = None,
+        transform_on_train: bool = True,
+        transform_on_eval: bool = True,
+        transform_on_fantasize: bool = True,
+        dtype: torch.dtype = torch.float64,
+        device: torch.device | None = None,
+    ) -> None:
+        r"""Initialize LearnedFeatureImputation.
+
+        Args:
+            feature_indices: A mapping from integer task values (as they appear
+                in the task column of ``X``) to lists of observed X-column
+                indices for that task. Indices refer directly to columns of the
+                input tensor ``X`` and must not include the task column. When
+                ``task_feature_index=-1`` (the common case), the ``d`` feature
+                columns are ``0, 1, ..., d-1``. Task values need not be
+                contiguous or 0-indexed.
+            d: The total number of feature columns (excluding the task column).
+            task_feature_index: The column index in ``X`` that contains the
+                task identifier. Must be ``-1`` (last column). Defaults to ``-1``.
+            target_task: The task identifier to use when ``X`` has ``d``
+                columns (no task column). Required for d-dim inputs since the
+                task cannot be inferred from shape alone — two tasks may share
+                the same number of active dimensions. Must be a key in
+                ``feature_indices``. If ``None``, only ``(d+1)``-dim inputs
+                are supported.
+            bounds: A ``2 x d`` tensor of ``[lower, upper]`` bounds for each
+                feature. If provided, imputation values are constrained to lie
+                within these bounds via a GPyTorch ``Interval`` constraint.
+                Defaults to ``None`` (unconstrained). This transform is designed
+                to operate on normalized inputs; if bounds differ from
+                ``[0, 1]^d``, a warning is emitted suggesting to chain
+                ``Normalize`` before this transform.
+            transform_on_train: If ``True``, apply the transform in train mode.
+            transform_on_eval: If ``True``, apply the transform in eval mode.
+            transform_on_fantasize: If ``True``, apply the transform inside
+                ``fantasize`` calls.
+            dtype: The dtype for the imputation parameters.
+            device: The device for the imputation parameters.
+        """
+        super().__init__()
+        self.transform_on_train = transform_on_train
+        self.transform_on_eval = transform_on_eval
+        self.transform_on_fantasize = transform_on_fantasize
+
+        if target_task is not None and target_task not in feature_indices:
+            raise ValueError(
+                f"target_task={target_task} is not a key in feature_indices. "
+                f"Available tasks: {sorted(feature_indices.keys())}."
+            )
+        self.target_task = target_task
+
+        if task_feature_index != -1:
+            raise ValueError(
+                "LearnedFeatureImputation requires task_feature_index=-1 "
+                "(task column last). Different tasks may have different "
+                "feature counts, so a fixed non-last position is ambiguous."
+            )
+
+        task_values_sorted = sorted(feature_indices.keys())
+        self.num_tasks = len(task_values_sorted)
+        self.d = d
+
+        # Sorted task identifiers as they appear in the task column of X.
+        self.register_buffer(
+            "_task_values",
+            torch.tensor(task_values_sorted, dtype=torch.long, device=device),
+        )
+
+        if bounds is not None:
+            if bounds.shape != (2, d):
+                raise ValueError(f"bounds must have shape (2, {d}), got {bounds.shape}")
+            bounds = bounds.to(dtype=dtype, device=device)
+            if not ((bounds[0] == 0).all() and (bounds[1] == 1).all()):
+                warn(
+                    "Non-default bounds passed to LearnedFeatureImputation. "
+                    "This transform expects normalized [0, 1] inputs -- chain "
+                    "Normalize before this transform so that the default "
+                    "bounds=[0, 1]^d are appropriate.",
+                    UserInputWarning,
+                    stacklevel=2,
+                )
+            self.register_buffer("bounds", bounds)
+        else:
+            self.register_buffer("bounds", None)
+
+        # Validate that no feature index overlaps with the task column.
+        for task_value, feat_cols in feature_indices.items():
+            if d in feat_cols:
+                raise ValueError(
+                    f"feature_indices[{task_value}] contains the task column "
+                    f"index {d}. Feature indices must not include the "
+                    f"task column."
+                )
+
+        missing_mask = torch.ones(
+            self.num_tasks, d + 1, dtype=torch.bool, device=device
+        )
+        missing_mask[:, -1] = False
+        for task_pos, task_value in enumerate(task_values_sorted):
+            missing_mask[task_pos, feature_indices[task_value]] = False
+        self.register_buffer("missing_mask", missing_mask)
+
+        # Learnable imputation values, shape (num_tasks, d+1). The task column
+        # slot is unused but kept for index alignment with X columns.
+        self.register_parameter(
+            "raw_imputation_values",
+            nn.Parameter(
+                torch.zeros(self.num_tasks, d + 1, dtype=dtype, device=device)
+            ),
+        )
+        if bounds is not None:
+            # Pad bounds with dummy [0, 1] for the task column so the Interval
+            # constraint has shape (d+1,) matching raw_imputation_values.
+            padded_lower = torch.zeros(d + 1, dtype=dtype, device=device)
+            padded_upper = torch.ones(d + 1, dtype=dtype, device=device)
+            padded_lower[:d] = bounds[0]
+            padded_upper[:d] = bounds[1]
+            self.register_constraint(
+                "raw_imputation_values",
+                Interval(
+                    lower_bound=padded_lower,
+                    upper_bound=padded_upper,
+                ),
+            )
+
+    @property
+    def imputation_values(self) -> Tensor:
+        r"""The imputation values, mapped through the Interval constraint when
+        bounds are present, or the raw values otherwise."""
+        if self.bounds is not None:
+            return self.raw_imputation_values_constraint.transform(
+                self.raw_imputation_values
+            )
+        return self.raw_imputation_values
+
+    def transform(self, X: Tensor) -> Tensor:
+        r"""Impute missing features with learned values.
+
+        Args:
+            X: A ``batch_shape x n x (d+1)``-dim tensor of inputs where the
+                last column contains integer task identifiers, or a
+                ``batch_shape x n x d``-dim tensor when ``target_task`` was
+                configured at init (the task column is appended automatically).
+
+        Returns:
+            A ``batch_shape x n x (d+1)``-dim tensor with missing features
+            replaced by learned imputation values.
+        """
+        x_dim = X.shape[-1]
+        if x_dim == self.d:
+            if self.target_task is None:
+                raise ValueError(
+                    f"Received d-dim input (X.shape[-1]={self.d}) but no "
+                    "target_task was configured. When X lacks a task column, "
+                    "target_task must be specified at init so the transform "
+                    "knows which task's imputation pattern to apply."
+                )
+            task_col = torch.full(
+                X.shape[:-1] + (1,),
+                self.target_task,
+                dtype=X.dtype,
+                device=X.device,
+            )
+            X = torch.cat([X, task_col], dim=-1)
+        elif x_dim != self.d + 1:
+            raise ValueError(
+                f"Expected X.shape[-1] to be {self.d} (no task column) or "
+                f"{self.d + 1} (with task column), got {x_dim}."
+            )
+
+        X_new = X.clone()
+
+        task_ids = X_new[..., -1].long()
+        imputation_vals = self.imputation_values
+
+        # For each task, replace unobserved feature columns with learned values.
+        # torch.where with task_mask ensures rows belonging to other tasks are
+        # left untouched, even if the same column is observed for those tasks.
+        for task_pos in range(self.num_tasks):
+            task_value = self._task_values[task_pos]
+            task_mask = task_ids == task_value
+            if not task_mask.any():
+                continue
+            missing_cols = (
+                self.missing_mask[task_pos].nonzero(as_tuple=False).squeeze(-1)
+            )
+            if missing_cols.numel() == 0:
+                continue
+            X_new[..., missing_cols] = torch.where(
+                task_mask.unsqueeze(-1),
+                imputation_vals[task_pos, missing_cols],
+                X_new[..., missing_cols],
+            )
+        return X_new
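
The d-dim branch of `transform` is what handles the target-space-to-full-space expansion mentioned in the summary. A small sketch of that path, assuming hypothetical task ids 3 and 7 over `d = 2` features:

```python
import torch
from botorch.models.transforms.input import LearnedFeatureImputation

# Task values need not be contiguous or 0-indexed: task 3 observes both
# columns, task 7 observes only column 0.
tf = LearnedFeatureImputation(
    feature_indices={3: [0, 1], 7: [0]}, d=2, target_task=7
)

# A d-dim input without a task column: the transform appends a task column
# filled with target_task=7, then imputes the unobserved column 1 (here with
# the raw zero-initialized value, since no bounds were given).
X = torch.rand(4, 2, dtype=torch.float64)
X_full = tf.transform(X)
assert X_full.shape == (4, 3)         # task column appended
assert (X_full[..., -1] == 7).all()   # every row assigned the target task
assert (X_full[..., 1] == 0.0).all()  # column 1 imputed for task 7
```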
