Skip to content

Commit 9fb1c7a

Browse files
Carl Hvarfner authored and meta-codesync[bot] committed
Extract get_heterogeneous_feature_mapping into MultiTaskDataset (meta-pytorch#3285)
Summary:
Pull Request resolved: meta-pytorch#3285

Extracts the shared feature-ordering and index-mapping logic from `HeterogeneousMTGP.construct_inputs` into `MultiTaskDataset.get_heterogeneous_feature_mapping()`, enabling reuse by `MultiTaskGP.construct_inputs` in the follow-up diff.

Reviewed By: sdaulton
Differential Revision: D101903471
fbshipit-source-id: 5a2c7fd85617d1f0e2eb10c7dd8f1fe0d2765d5a
1 parent 8a0cbef commit 9fb1c7a

3 files changed

Lines changed: 70 additions & 17 deletions

File tree

botorch/models/heterogeneous_mtgp.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -295,33 +295,21 @@ def construct_inputs(
295295
"Heterogeneous MTGP currently only supports output_tasks=[0]. "
296296
"The target task will be given the task value of 0."
297297
)
298-
child_datasets = training_data.datasets.copy()
299-
target_dataset = child_datasets.pop(training_data.target_outcome_name)
300-
all_datasets = [target_dataset] + list(child_datasets.values())
301-
# Use target's feature order as canonical (NO alphabetical sort).
302-
# Source-only features are appended at the end.
303-
all_features: list[str] = list(target_dataset.feature_names[:-1])
304-
for ds in all_datasets[1:]:
305-
for fn in ds.feature_names[:-1]:
306-
if fn not in all_features:
307-
all_features.append(fn)
308-
# Get indices mapping the features from a given dataset to all features.
309-
feature_indices = [
310-
[all_features.index(fn) for fn in ds.feature_names[:-1]]
311-
for ds in all_datasets
312-
]
298+
all_datasets, feature_indices, full_feature_dim = (
299+
training_data.get_heterogeneous_feature_mapping()
300+
)
313301
Xs = [ds.X[..., :-1] for ds in all_datasets]
314302
Ys = [ds.Y for ds in all_datasets]
315303
Yvars = (
316-
None if target_dataset.Yvar is None else [ds.Yvar for ds in all_datasets]
304+
None if all_datasets[0].Yvar is None else [ds.Yvar for ds in all_datasets]
317305
)
318306
all_tasks = list(range(len(all_datasets)))
319307
return {
320308
"train_Xs": Xs,
321309
"train_Ys": Ys,
322310
"train_Yvars": Yvars,
323311
"feature_indices": feature_indices,
324-
"full_feature_dim": len(all_features),
312+
"full_feature_dim": full_feature_dim,
325313
"rank": rank,
326314
"use_saas_prior": use_saas_prior,
327315
"use_combinatorial_kernel": use_combinatorial_kernel,

botorch/utils/datasets.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,46 @@ def __eq__(self, other: Any) -> bool:
542542
and self.task_feature_index == other.task_feature_index
543543
)
544544

545+
def get_heterogeneous_feature_mapping(
    self,
) -> tuple[list["SupervisedDataset"], list[list[int]], int]:
    """Compute canonical feature ordering for heterogeneous datasets.

    Target features come first (preserving order), then source-only
    features are appended in the order they are first encountered. The
    last entry of each dataset's ``feature_names`` is treated as the task
    column and is excluded from the mapping, which is why only
    ``task_feature_index == -1`` is supported.

    ``self.datasets`` is not modified; a shallow copy is reordered.

    Returns:
        A 3-tuple of:
        - Ordered datasets (target first, then sources in their
          original relative order).
        - Feature indices mapping each dataset's non-task features
          to the canonical ordering.
        - Full feature dimensionality (number of unique non-task features).

    Raises:
        NotImplementedError: If ``task_feature_index`` is not ``-1``.
        KeyError: If ``target_outcome_name`` is not a key of ``datasets``.
    """
    if self.task_feature_index != -1:
        raise NotImplementedError(
            "Heterogeneous feature mapping requires `task_feature_index` to be -1."
        )
    # Work on a shallow copy so popping the target leaves `self.datasets` intact.
    child_datasets = self.datasets.copy()
    target_dataset = child_datasets.pop(self.target_outcome_name)
    all_datasets = [target_dataset] + list(child_datasets.values())

    # Map feature name -> canonical column. Dicts preserve insertion order,
    # so visiting the target first makes its feature order canonical and
    # source-only features get the next free positions. Using a dict keeps
    # both the membership test and the index lookup O(1) per feature, and
    # guarantees the final size counts each feature name exactly once.
    canonical_position: dict[str, int] = {}
    for ds in all_datasets:
        for fn in ds.feature_names[:-1]:
            canonical_position.setdefault(fn, len(canonical_position))

    feature_indices = [
        [canonical_position[fn] for fn in ds.feature_names[:-1]]
        for ds in all_datasets
    ]
    return all_datasets, feature_indices, len(canonical_position)
584+
545585
def clone(
546586
self, deepcopy: bool = False, mask: Tensor | None = None
547587
) -> MultiTaskDataset:

test/utils/test_datasets.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,31 @@ def test_multi_task(self):
491491
MultiTaskDataset(datasets=[dataset_1, dataset_5], target_outcome_name="z"),
492492
)
493493

494+
def test_get_heterogeneous_feature_mapping(self) -> None:
    """Check ordering, index mapping and the task-index guard."""
    ds_target = make_dataset(
        d=3, feature_names=["a", "b", "task"], outcome_names=["y"]
    )
    ds_source = make_dataset(
        d=3, feature_names=["a", "c", "task"], outcome_names=["z"]
    )

    # Only a trailing task column is supported; anything else must raise.
    mt_err = MultiTaskDataset(
        datasets=[ds_target, ds_source],
        target_outcome_name="y",
        task_feature_index=0,
    )
    with self.assertRaises(NotImplementedError):
        mt_err.get_heterogeneous_feature_mapping()

    # The target must come out first regardless of the construction order,
    # so exercise both target-first and source-first inputs.
    for datasets in ([ds_target, ds_source], [ds_source, ds_target]):
        mt = MultiTaskDataset(
            datasets=datasets,
            target_outcome_name="y",
            task_feature_index=-1,
        )
        all_datasets, feature_indices, full_dim = (
            mt.get_heterogeneous_feature_mapping()
        )
        self.assertEqual(len(all_datasets), 2)
        self.assertEqual(
            [ds.outcome_names for ds in all_datasets], [["y"], ["z"]]
        )
        # Canonical order is a, b (target) then c (source-only).
        self.assertEqual(full_dim, 3)
        self.assertEqual(feature_indices, [[0, 1], [0, 2]])
518+
494519
def test_clone_multitask(self) -> None:
495520
for has_yvar in [False, True]:
496521
dataset_1 = make_dataset(outcome_names=["y"], has_yvar=has_yvar)

0 commit comments

Comments
 (0)