Skip to content

Commit fa4bc2b

Browse files
Carl Hvarfner and meta-codesync[bot]
authored and committed
Add map_heterogeneous_to_full to MultiTaskGP.construct_inputs (meta-pytorch#3286)
Summary: Pull Request resolved: meta-pytorch#3286 Adds a `map_heterogeneous_to_full` parameter to `MultiTaskGP.construct_inputs` that zero-pads heterogeneous `MultiTaskDataset` features into the union feature space, enabling `MultiTaskGP` + `LearnedFeatureImputation` as a drop-in replacement for `ImputedMultiTaskGP`. Reviewed By: sdaulton Differential Revision: D101903488 fbshipit-source-id: 04fb51c60d2fef44f9b29e1520a586ee860ede29
1 parent 9fb1c7a commit fa4bc2b

2 files changed

Lines changed: 256 additions & 0 deletions

File tree

botorch/models/multitask.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ def construct_inputs(
486486
task_covar_prior: Prior | _DefaultType | None = DEFAULT,
487487
prior_config: dict | None = None,
488488
rank: int | None = None,
489+
map_heterogeneous_to_full: bool = False,
489490
) -> dict[str, Any]:
490491
r"""Construct ``Model`` keyword arguments from a dataset and other args.
491492
@@ -503,6 +504,13 @@ def construct_inputs(
503504
Must contain ``use_LKJ_prior`` indicator and should contain float
504505
value ``eta``.
505506
rank: The rank of the cross-task covariance matrix.
507+
map_heterogeneous_to_full: If True and ``training_data`` is a
508+
``MultiTaskDataset`` with heterogeneous features, zero-pad each
509+
task's features into the union feature space and concatenate into
510+
a single ``train_X`` tensor. The zero-padded entries are intended
511+
to be overwritten by a ``LearnedFeatureImputation`` input
512+
transform. If False (default), heterogeneous features will raise
513+
``UnsupportedError`` via ``training_data.X``.
506514
"""
507515
if (
508516
task_covar_prior is not DEFAULT
@@ -525,6 +533,50 @@ def construct_inputs(
525533
raise ValueError(f"eta must be a real number, your eta was {eta}.")
526534
task_covar_prior = LKJCovariancePrior(num_tasks, eta, sd_prior)
527535

536+
# Handle heterogeneous MultiTaskDataset by zero-padding into the union
537+
# feature space. This branch bypasses super().construct_inputs() since
538+
# training_data.X would raise UnsupportedError.
539+
if (
540+
map_heterogeneous_to_full
541+
and isinstance(training_data, MultiTaskDataset)
542+
and training_data.has_heterogeneous_features
543+
):
544+
all_datasets, feature_indices, full_feature_dim = (
545+
training_data.get_heterogeneous_feature_mapping()
546+
)
547+
548+
# Zero-pad each task's X into the full feature space + task column.
549+
all_Xs = []
550+
for task_idx, (ds, fi) in enumerate(
551+
zip(all_datasets, feature_indices, strict=True)
552+
):
553+
X_task = ds.X[..., :-1] # strip task feature column
554+
X_full = torch.zeros(
555+
*X_task.shape[:-1],
556+
full_feature_dim + 1,
557+
dtype=X_task.dtype,
558+
device=X_task.device,
559+
)
560+
X_full[..., fi] = X_task
561+
X_full[..., -1] = task_idx
562+
all_Xs.append(X_full)
563+
564+
all_Yvars = [ds.Yvar for ds in all_datasets]
565+
base_inputs: dict[str, Any] = {
566+
"train_X": torch.cat(all_Xs, dim=0),
567+
"train_Y": torch.cat([ds.Y for ds in all_datasets], dim=0),
568+
}
569+
if all_Yvars[0] is not None:
570+
base_inputs["train_Yvar"] = torch.cat(all_Yvars, dim=0)
571+
base_inputs["task_feature"] = -1
572+
base_inputs["all_tasks"] = list(range(len(all_datasets)))
573+
base_inputs["output_tasks"] = output_tasks
574+
if task_covar_prior is not DEFAULT:
575+
base_inputs["task_covar_prior"] = task_covar_prior
576+
if rank is not None:
577+
base_inputs["rank"] = rank
578+
return base_inputs
579+
528580
# Call Model.construct_inputs to parse training data
529581
base_inputs = super().construct_inputs(training_data=training_data)
530582
if (

test/models/test_multitask.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from botorch.models.utils.priors import BetaPrior
2626
from botorch.posteriors import GPyTorchPosterior
2727
from botorch.posteriors.transformed import TransformedPosterior
28+
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
2829
from botorch.utils.test_helpers import gen_multi_task_dataset
2930
from botorch.utils.testing import BotorchTestCase
3031
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
@@ -803,6 +804,209 @@ def test_multitask_gp_unobserved_tasks(self) -> None:
803804
samples = posterior.rsample(sample_shape=torch.Size([2]))
804805
self.assertEqual(samples.shape, torch.Size([2, 3, 1]))
805806

807+
def test_construct_inputs_heterogeneous(self) -> None:
808+
tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
809+
810+
# Build a heterogeneous MultiTaskDataset: target has features [a, b],
811+
# source has features [a, c].
812+
n_target, n_source = 10, 8
813+
X_target = torch.rand(n_target, 2, **tkwargs)
814+
X_source = torch.rand(n_source, 2, **tkwargs)
815+
Y_target = torch.rand(n_target, 1, **tkwargs)
816+
Y_source = torch.rand(n_source, 1, **tkwargs)
817+
818+
ds_target = SupervisedDataset(
819+
X=torch.cat([X_target, torch.zeros(n_target, 1, **tkwargs)], dim=-1),
820+
Y=Y_target,
821+
feature_names=["a", "b", "task"],
822+
outcome_names=["target"],
823+
)
824+
ds_source = SupervisedDataset(
825+
X=torch.cat([X_source, torch.ones(n_source, 1, **tkwargs)], dim=-1),
826+
Y=Y_source,
827+
feature_names=["a", "c", "task"],
828+
outcome_names=["source"],
829+
)
830+
mt_dataset = MultiTaskDataset(
831+
datasets=[ds_target, ds_source],
832+
target_outcome_name="target",
833+
task_feature_index=-1,
834+
)
835+
self.assertTrue(mt_dataset.has_heterogeneous_features)
836+
837+
# map_heterogeneous_to_full=False (default) → raises UnsupportedError
838+
with self.assertRaises(UnsupportedError):
839+
MultiTaskGP.construct_inputs(mt_dataset, task_feature=-1)
840+
841+
# map_heterogeneous_to_full=True → zero-padded train_X
842+
data_dict = MultiTaskGP.construct_inputs(
843+
mt_dataset, task_feature=-1, map_heterogeneous_to_full=True
844+
)
845+
train_X = data_dict["train_X"]
846+
# Full features: [a, b, c] + task → 4 columns
847+
self.assertEqual(train_X.shape, (n_target + n_source, 4))
848+
# Task column is last
849+
self.assertTrue((train_X[:n_target, -1] == 0).all())
850+
self.assertTrue((train_X[n_target:, -1] == 1).all())
851+
# Target rows: a, b filled; c (index 2) is zero
852+
self.assertAllClose(train_X[:n_target, 0], X_target[:, 0])
853+
self.assertAllClose(train_X[:n_target, 1], X_target[:, 1])
854+
self.assertTrue((train_X[:n_target, 2] == 0).all())
855+
# Source rows: a filled at index 0; b (index 1) is zero; c at index 2
856+
self.assertAllClose(train_X[n_target:, 0], X_source[:, 0])
857+
self.assertTrue((train_X[n_target:, 1] == 0).all())
858+
self.assertAllClose(train_X[n_target:, 2], X_source[:, 1])
859+
860+
self.assertEqual(data_dict["task_feature"], -1)
861+
self.assertEqual(data_dict["all_tasks"], [0, 1])
862+
self.assertTrue(
863+
torch.equal(data_dict["train_Y"], torch.cat([Y_target, Y_source]))
864+
)
865+
866+
# Cover rank and task_covar_prior conditional branches.
867+
data_dict = MultiTaskGP.construct_inputs(
868+
mt_dataset, task_feature=-1, map_heterogeneous_to_full=True, rank=2
869+
)
870+
self.assertEqual(data_dict["rank"], 2)
871+
prior = LKJCovariancePrior(2, 1.0, SmoothedBoxPrior(0.1, 2.0))
872+
data_dict = MultiTaskGP.construct_inputs(
873+
mt_dataset,
874+
task_feature=-1,
875+
map_heterogeneous_to_full=True,
876+
task_covar_prior=prior,
877+
)
878+
self.assertIs(data_dict["task_covar_prior"], prior)
879+
880+
# task_feature_index != -1 → NotImplementedError
881+
mt_bad = MultiTaskDataset(
882+
datasets=[ds_target, ds_source],
883+
target_outcome_name="target",
884+
task_feature_index=0,
885+
)
886+
mt_bad.has_heterogeneous_features = True
887+
with self.assertRaises(NotImplementedError):
888+
MultiTaskGP.construct_inputs(
889+
mt_bad, task_feature=-1, map_heterogeneous_to_full=True
890+
)
891+
892+
def test_construct_inputs_heterogeneous_with_yvar(self) -> None:
893+
tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
894+
895+
n_target, n_source = 6, 4
896+
ds_target = SupervisedDataset(
897+
X=torch.cat(
898+
[
899+
torch.rand(n_target, 2, **tkwargs),
900+
torch.zeros(n_target, 1, **tkwargs),
901+
],
902+
dim=-1,
903+
),
904+
Y=torch.rand(n_target, 1, **tkwargs),
905+
Yvar=torch.full((n_target, 1), 0.1, **tkwargs),
906+
feature_names=["a", "b", "task"],
907+
outcome_names=["target"],
908+
)
909+
ds_source = SupervisedDataset(
910+
X=torch.cat(
911+
[
912+
torch.rand(n_source, 1, **tkwargs),
913+
torch.ones(n_source, 1, **tkwargs),
914+
],
915+
dim=-1,
916+
),
917+
Y=torch.rand(n_source, 1, **tkwargs),
918+
Yvar=torch.full((n_source, 1), 0.2, **tkwargs),
919+
feature_names=["a", "task"],
920+
outcome_names=["source"],
921+
)
922+
mt_dataset = MultiTaskDataset(
923+
datasets=[ds_target, ds_source],
924+
target_outcome_name="target",
925+
task_feature_index=-1,
926+
)
927+
data_dict = MultiTaskGP.construct_inputs(
928+
mt_dataset, task_feature=-1, map_heterogeneous_to_full=True
929+
)
930+
# Full features: [a, b] + task → 3 columns
931+
self.assertEqual(data_dict["train_X"].shape, (n_target + n_source, 3))
932+
self.assertIn("train_Yvar", data_dict)
933+
self.assertEqual(data_dict["train_Yvar"].shape, (n_target + n_source, 1))
934+
935+
def test_e2e_multitask_gp_with_learned_feature_imputation(self) -> None:
936+
from botorch.models.transforms.input import (
937+
ChainedInputTransform,
938+
LearnedFeatureImputation,
939+
)
940+
941+
tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
942+
943+
# Heterogeneous dataset: target has [a, b], source has [a, c].
944+
n_target, n_source = 15, 10
945+
X_target = torch.rand(n_target, 2, **tkwargs)
946+
X_source = torch.rand(n_source, 2, **tkwargs)
947+
Y_target = torch.sin(X_target.sum(dim=-1, keepdim=True))
948+
Y_source = torch.cos(X_source.sum(dim=-1, keepdim=True))
949+
950+
ds_target = SupervisedDataset(
951+
X=torch.cat([X_target, torch.zeros(n_target, 1, **tkwargs)], dim=-1),
952+
Y=Y_target,
953+
feature_names=["a", "b", "task"],
954+
outcome_names=["target"],
955+
)
956+
ds_source = SupervisedDataset(
957+
X=torch.cat([X_source, torch.ones(n_source, 1, **tkwargs)], dim=-1),
958+
Y=Y_source,
959+
feature_names=["a", "c", "task"],
960+
outcome_names=["source"],
961+
)
962+
mt_dataset = MultiTaskDataset(
963+
datasets=[ds_target, ds_source],
964+
target_outcome_name="target",
965+
task_feature_index=-1,
966+
)
967+
968+
model_inputs = MultiTaskGP.construct_inputs(
969+
mt_dataset, task_feature=-1, map_heterogeneous_to_full=True
970+
)
971+
972+
# Full features = [a, b, c], d=3, task at -1.
973+
d = 3
974+
feature_indices = {0: [0, 1], 1: [0, 2]}
975+
input_transform = ChainedInputTransform(
976+
normalize=Normalize(d=d + 1, indices=list(range(d))),
977+
impute=LearnedFeatureImputation(
978+
feature_indices=feature_indices,
979+
d=d,
980+
task_feature_index=-1,
981+
**tkwargs,
982+
),
983+
)
984+
model = MultiTaskGP(**model_inputs, input_transform=input_transform)
985+
986+
# Verify the model can produce posteriors.
987+
test_X = torch.cat(
988+
[torch.rand(5, 2, **tkwargs), torch.zeros(5, 1, **tkwargs)], dim=-1
989+
)
990+
# Posterior call on target-task-dimensioned input won't work directly
991+
# because MultiTaskGP expects full-dimensional input; test with full X.
992+
test_X_full = torch.zeros(5, d + 1, **tkwargs)
993+
test_X_full[:, :2] = test_X[:, :2]
994+
test_X_full[:, -1] = 0
995+
posterior = model.posterior(test_X_full)
996+
self.assertEqual(posterior.mean.shape, torch.Size([5, 1]))
997+
self.assertFalse(torch.isnan(posterior.mean).any())
998+
999+
# Fit the model briefly and verify imputation values change.
1000+
mll = ExactMarginalLogLikelihood(model.likelihood, model)
1001+
initial_imp = model.input_transform.impute.imputation_values.clone().detach()
1002+
with warnings.catch_warnings():
1003+
warnings.filterwarnings("ignore", category=OptimizationWarning)
1004+
fit_gpytorch_mll(
1005+
mll, optimizer_kwargs={"options": {"maxiter": 5}}, max_attempts=1
1006+
)
1007+
final_imp = model.input_transform.impute.imputation_values.detach()
1008+
self.assertFalse(torch.allclose(initial_imp, final_imp, atol=1e-6))
1009+
8061010

8071011
class TestKroneckerMultiTaskGP(BotorchTestCase):
8081012
def test_KroneckerMultiTaskGP_default(self) -> None:

0 commit comments

Comments (0)