 from botorch.models.utils import fantasize
 from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE
 from gpytorch import Module as GPyTorchModule
-from gpytorch.constraints import GreaterThan
+from gpytorch.constraints import GreaterThan, Interval
 from gpytorch.priors import Prior
 from torch import LongTensor, nn, Tensor
 from torch.nn import Module, ModuleDict
@@ -1915,3 +1915,222 @@ def equals(self, other: InputTransform) -> bool:
             and (self.transform_on_fantasize == other.transform_on_fantasize)
             and self.categorical_features == other.categorical_features
         )
+
+
+class LearnedFeatureImputation(InputTransform, GPyTorchModule):
+    r"""An input transform that learns imputation values for missing features
+    in heterogeneous multi-task settings.
+
+    In multi-task problems where different tasks observe different subsets of
+    features, this transform fills in the unobserved feature columns with
+    learned parameter values. This enables using a standard ``MultiTaskGP``
+    with composable input transforms instead of specialized model classes.
+
+    The input tensor ``X`` is expected to have shape ``batch_shape x n x (d+1)``,
+    where the column at ``task_feature_index`` contains the task identifier.
+    For each task, feature columns not listed in ``feature_indices[task_value]``
+    are replaced with the corresponding learned imputation values. Task values
+    need not be contiguous or 0-indexed.
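+
+    Example (a minimal, illustrative sketch: d=3 features and two tasks,
+    where task 0 observes all three columns and task 1 observes only
+    columns 0 and 1, so column 2 is imputed for task-1 rows):
+        >>> tf = LearnedFeatureImputation(
+        ...     feature_indices={0: [0, 1, 2], 1: [0, 1]},
+        ...     d=3,
+        ... )
+        >>> X = torch.tensor(
+        ...     [[0.2, 0.4, 0.6, 0.0], [0.1, 0.3, 0.5, 1.0]],
+        ...     dtype=torch.float64,
+        ... )
+        >>> # Column 2 of the second row (task 1) is replaced by the
+        >>> # learned imputation value; the first row is untouched.
+        >>> X_imputed = tf.transform(X)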
+    """
+
+    def __init__(
+        self,
+        feature_indices: dict[int, list[int]],
+        d: int,
+        task_feature_index: int = -1,
+        target_task: int | None = None,
+        bounds: Tensor | None = None,
+        transform_on_train: bool = True,
+        transform_on_eval: bool = True,
+        transform_on_fantasize: bool = True,
+        dtype: torch.dtype = torch.float64,
+        device: torch.device | None = None,
+    ) -> None:
+        r"""Initialize LearnedFeatureImputation.
+
+        Args:
+            feature_indices: A mapping from integer task values (as they appear
+                in the task column of ``X``) to lists of observed X-column
+                indices for that task. Indices refer directly to columns of the
+                input tensor ``X`` and must not include the task column. When
+                ``task_feature_index=-1`` (the common case), the ``d`` feature
+                columns are ``0, 1, ..., d-1``. Task values need not be
+                contiguous or 0-indexed.
+            d: The total number of feature columns (excluding the task column).
+            task_feature_index: The column index in ``X`` that contains the
+                task identifier. Must be ``-1`` (last column). Defaults to ``-1``.
+            target_task: The task identifier to use when ``X`` has ``d``
+                columns (no task column). Required for d-dim inputs since the
+                task cannot be inferred from shape alone (two tasks may share
+                the same number of active dimensions). Must be a key in
+                ``feature_indices``. If ``None``, only ``(d+1)``-dim inputs
+                are supported.
+            bounds: A ``2 x d`` tensor of ``[lower, upper]`` bounds for each
+                feature. If provided, imputation values are constrained to lie
+                within these bounds via a GPyTorch ``Interval`` constraint.
+                Defaults to ``None`` (unconstrained). This transform is designed
+                to operate on normalized inputs; if bounds differ from
+                ``[0, 1]^d``, a warning is emitted suggesting to chain
+                ``Normalize`` before this transform (see the example below).
+            transform_on_train: If ``True``, apply the transform in train mode.
+            transform_on_eval: If ``True``, apply the transform in eval mode.
+            transform_on_fantasize: If ``True``, apply the transform inside
+                ``fantasize`` calls.
+            dtype: The dtype for the imputation parameters.
+            device: The device for the imputation parameters.
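+
+        Example (sketch of the bounds/``Normalize`` chaining described above;
+        assumes 3 feature columns plus a trailing task column, so ``Normalize``
+        is built with ``d=4`` but only normalizes indices 0-2):
+            >>> bounds01 = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
+            >>> tf = ChainedInputTransform(
+            ...     normalize=Normalize(d=4, indices=[0, 1, 2]),
+            ...     impute=LearnedFeatureImputation(
+            ...         feature_indices={0: [0, 1, 2], 1: [0, 1]},
+            ...         d=3,
+            ...         bounds=bounds01,
+            ...     ),
+            ... )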
+        """
+        super().__init__()
+        self.transform_on_train = transform_on_train
+        self.transform_on_eval = transform_on_eval
+        self.transform_on_fantasize = transform_on_fantasize
+
+        if target_task is not None and target_task not in feature_indices:
+            raise ValueError(
+                f"target_task={target_task} is not a key in feature_indices. "
+                f"Available tasks: {sorted(feature_indices.keys())}."
+            )
+        self.target_task = target_task
+
+        if task_feature_index != -1:
+            raise ValueError(
+                "LearnedFeatureImputation requires task_feature_index=-1 "
+                "(task column last). Different tasks may have different "
+                "feature counts, so a fixed non-last position is ambiguous."
+            )
+
+        task_values_sorted = sorted(feature_indices.keys())
+        self.num_tasks = len(task_values_sorted)
+        self.d = d
+
+        # Sorted task identifiers as they appear in the task column of X.
+        self.register_buffer(
+            "_task_values",
+            torch.tensor(task_values_sorted, dtype=torch.long, device=device),
+        )
+
+        if bounds is not None:
+            if bounds.shape != (2, d):
+                raise ValueError(f"bounds must have shape (2, {d}), got {bounds.shape}")
+            bounds = bounds.to(dtype=dtype, device=device)
+            if not ((bounds[0] == 0).all() and (bounds[1] == 1).all()):
+                warn(
+                    "Non-default bounds passed to LearnedFeatureImputation. "
+                    "This transform expects normalized [0, 1] inputs -- chain "
+                    "Normalize before this transform so that the default "
+                    "bounds=[0, 1]^d are appropriate.",
+                    UserInputWarning,
+                    stacklevel=2,
+                )
+            self.register_buffer("bounds", bounds)
+        else:
+            self.register_buffer("bounds", None)
+
+        # Validate that no feature index overlaps with the task column.
+        for task_value, feat_cols in feature_indices.items():
+            if d in feat_cols:
+                raise ValueError(
+                    f"feature_indices[{task_value}] contains the task column "
+                    f"index {d}. Feature indices must not include the "
+                    f"task column."
+                )
+
+        missing_mask = torch.ones(
+            self.num_tasks, d + 1, dtype=torch.bool, device=device
+        )
+        missing_mask[:, -1] = False
+        for task_pos, task_value in enumerate(task_values_sorted):
+            missing_mask[task_pos, feature_indices[task_value]] = False
+        self.register_buffer("missing_mask", missing_mask)
+
+        # Learnable imputation values, shape (num_tasks, d+1). The task column
+        # slot is unused but kept for index alignment with X columns.
+        self.register_parameter(
+            "raw_imputation_values",
+            nn.Parameter(
+                torch.zeros(self.num_tasks, d + 1, dtype=dtype, device=device)
+            ),
+        )
+        if bounds is not None:
+            # Pad bounds with dummy [0, 1] for the task column so the Interval
+            # constraint has shape (d+1,) matching raw_imputation_values.
+            padded_lower = torch.zeros(d + 1, dtype=dtype, device=device)
+            padded_upper = torch.ones(d + 1, dtype=dtype, device=device)
+            padded_lower[:d] = bounds[0]
+            padded_upper[:d] = bounds[1]
+            self.register_constraint(
+                "raw_imputation_values",
+                Interval(
+                    lower_bound=padded_lower,
+                    upper_bound=padded_upper,
+                ),
+            )
+
+    @property
+    def imputation_values(self) -> Tensor:
+        r"""The imputation values, mapped through the Interval constraint when
+        bounds are present, or the raw values otherwise."""
+        if self.bounds is not None:
+            return self.raw_imputation_values_constraint.transform(
+                self.raw_imputation_values
+            )
+        return self.raw_imputation_values
+
+    def transform(self, X: Tensor) -> Tensor:
+        r"""Impute missing features with learned values.
+
+        Args:
+            X: A ``batch_shape x n x (d+1)``-dim tensor of inputs where the
+                last column contains integer task identifiers, or a
+                ``batch_shape x n x d``-dim tensor when ``target_task`` was
+                configured at init (the task column is appended automatically).
+
+        Returns:
+            A ``batch_shape x n x (d+1)``-dim tensor with missing features
+            replaced by learned imputation values.
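+
+        Example (sketch of the d-dim path; shapes are illustrative):
+            >>> tf = LearnedFeatureImputation(
+            ...     feature_indices={0: [0, 1, 2], 1: [0, 1]},
+            ...     d=3,
+            ...     target_task=1,
+            ... )
+            >>> X = torch.rand(5, 3, dtype=torch.float64)
+            >>> tf.transform(X).shape  # task column appended, column 2 imputed
+            torch.Size([5, 4])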
+        """
+        x_dim = X.shape[-1]
+        if x_dim == self.d:
+            if self.target_task is None:
+                raise ValueError(
+                    f"Received d-dim input (X.shape[-1]={self.d}) but no "
+                    "target_task was configured. When X lacks a task column, "
+                    "target_task must be specified at init so the transform "
+                    "knows which task's imputation pattern to apply."
+                )
+            task_col = torch.full(
+                X.shape[:-1] + (1,),
+                self.target_task,
+                dtype=X.dtype,
+                device=X.device,
+            )
+            X = torch.cat([X, task_col], dim=-1)
+        elif x_dim != self.d + 1:
+            raise ValueError(
+                f"Expected X.shape[-1] to be {self.d} (no task column) or "
+                f"{self.d + 1} (with task column), got {x_dim}."
+            )
+
+        X_new = X.clone()
+
+        task_ids = X_new[..., -1].long()
+        imputation_vals = self.imputation_values
+
+        # For each task, replace unobserved feature columns with learned values.
+        # torch.where with task_mask ensures rows belonging to other tasks are
+        # left untouched, even if the same column is observed for those tasks.
+        for task_pos in range(self.num_tasks):
+            task_value = self._task_values[task_pos]
+            task_mask = task_ids == task_value
+            if not task_mask.any():
+                continue
+            missing_cols = (
+                self.missing_mask[task_pos].nonzero(as_tuple=False).squeeze(-1)
+            )
+            if missing_cols.numel() == 0:
+                continue
+            X_new[..., missing_cols] = torch.where(
+                task_mask.unsqueeze(-1),
+                imputation_vals[task_pos, missing_cols],
+                X_new[..., missing_cols],
+            )
+        return X_new