|
25 | 25 | from botorch.models.utils.priors import BetaPrior |
26 | 26 | from botorch.posteriors import GPyTorchPosterior |
27 | 27 | from botorch.posteriors.transformed import TransformedPosterior |
| 28 | +from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset |
28 | 29 | from botorch.utils.test_helpers import gen_multi_task_dataset |
29 | 30 | from botorch.utils.testing import BotorchTestCase |
30 | 31 | from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal |
@@ -803,6 +804,209 @@ def test_multitask_gp_unobserved_tasks(self) -> None: |
803 | 804 | samples = posterior.rsample(sample_shape=torch.Size([2])) |
804 | 805 | self.assertEqual(samples.shape, torch.Size([2, 3, 1])) |
805 | 806 |
|
def test_construct_inputs_heterogeneous(self) -> None:
    """Check ``MultiTaskGP.construct_inputs`` on a heterogeneous dataset.

    The target task carries features ``[a, b]`` and the source task
    ``[a, c]``. The assertions expect the mapped ``train_X`` to use the
    full feature set ``[a, b, c]`` followed by the task column, with
    zeros in the slots that a task does not provide.
    """
    tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
    n_target, n_source = 10, 8
    num_rows = n_target + n_source

    X_target = torch.rand(n_target, 2, **tkwargs)
    X_source = torch.rand(n_source, 2, **tkwargs)
    Y_target = torch.rand(n_target, 1, **tkwargs)
    Y_source = torch.rand(n_source, 1, **tkwargs)

    # Task ids live in the trailing column: 0 for target, 1 for source.
    target_task_col = torch.zeros(n_target, 1, **tkwargs)
    source_task_col = torch.ones(n_source, 1, **tkwargs)
    ds_target = SupervisedDataset(
        X=torch.cat([X_target, target_task_col], dim=-1),
        Y=Y_target,
        feature_names=["a", "b", "task"],
        outcome_names=["target"],
    )
    ds_source = SupervisedDataset(
        X=torch.cat([X_source, source_task_col], dim=-1),
        Y=Y_source,
        feature_names=["a", "c", "task"],
        outcome_names=["source"],
    )
    mt_dataset = MultiTaskDataset(
        datasets=[ds_target, ds_source],
        target_outcome_name="target",
        task_feature_index=-1,
    )
    self.assertTrue(mt_dataset.has_heterogeneous_features)

    construct = MultiTaskGP.construct_inputs

    # Without map_heterogeneous_to_full (the default), an UnsupportedError
    # is expected for heterogeneous feature sets.
    with self.assertRaises(UnsupportedError):
        construct(mt_dataset, task_feature=-1)

    # With map_heterogeneous_to_full=True the inputs are zero-padded.
    inputs = construct(
        mt_dataset, task_feature=-1, map_heterogeneous_to_full=True
    )
    train_X = inputs["train_X"]
    # Three features [a, b, c] plus the task column.
    self.assertEqual(train_X.shape, (num_rows, 4))

    target_rows = train_X[:n_target]
    source_rows = train_X[n_target:]
    # The task indicator stays in the last column.
    self.assertTrue((target_rows[:, -1] == 0).all())
    self.assertTrue((source_rows[:, -1] == 1).all())
    # Target rows: a and b are carried over, c (column 2) is zero.
    self.assertAllClose(target_rows[:, :2], X_target)
    self.assertTrue((target_rows[:, 2] == 0).all())
    # Source rows: a in column 0, b (column 1) is zero, c in column 2.
    self.assertAllClose(source_rows[:, [0, 2]], X_source)
    self.assertTrue((source_rows[:, 1] == 0).all())

    self.assertEqual(inputs["task_feature"], -1)
    self.assertEqual(inputs["all_tasks"], [0, 1])
    expected_Y = torch.cat([Y_target, Y_source])
    self.assertTrue(torch.equal(inputs["train_Y"], expected_Y))

    # Exercise the optional rank and task_covar_prior arguments.
    inputs_with_rank = construct(
        mt_dataset, task_feature=-1, map_heterogeneous_to_full=True, rank=2
    )
    self.assertEqual(inputs_with_rank["rank"], 2)
    prior = LKJCovariancePrior(2, 1.0, SmoothedBoxPrior(0.1, 2.0))
    inputs_with_prior = construct(
        mt_dataset,
        task_feature=-1,
        map_heterogeneous_to_full=True,
        task_covar_prior=prior,
    )
    self.assertIs(inputs_with_prior["task_covar_prior"], prior)

    # A dataset whose task_feature_index is not -1 should be rejected with
    # NotImplementedError. The heterogeneous flag is set by hand so the
    # mapping branch is the one exercised.
    mt_bad = MultiTaskDataset(
        datasets=[ds_target, ds_source],
        target_outcome_name="target",
        task_feature_index=0,
    )
    mt_bad.has_heterogeneous_features = True
    with self.assertRaises(NotImplementedError):
        construct(mt_bad, task_feature=-1, map_heterogeneous_to_full=True)
| 891 | + |
def test_construct_inputs_heterogeneous_with_yvar(self) -> None:
    """Check that heterogeneous input mapping also returns ``train_Yvar``.

    The target task has features ``[a, b]`` and the source task only
    ``[a]``; both datasets carry observation variances.
    """
    tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
    n_target, n_source = 6, 4
    num_rows = n_target + n_source

    # Target dataset: two features followed by a task column of zeros.
    target_X = torch.cat(
        [torch.rand(n_target, 2, **tkwargs), torch.zeros(n_target, 1, **tkwargs)],
        dim=-1,
    )
    target_Y = torch.rand(n_target, 1, **tkwargs)
    ds_target = SupervisedDataset(
        X=target_X,
        Y=target_Y,
        Yvar=torch.full((n_target, 1), 0.1, **tkwargs),
        feature_names=["a", "b", "task"],
        outcome_names=["target"],
    )

    # Source dataset: a single feature followed by a task column of ones.
    source_X = torch.cat(
        [torch.rand(n_source, 1, **tkwargs), torch.ones(n_source, 1, **tkwargs)],
        dim=-1,
    )
    source_Y = torch.rand(n_source, 1, **tkwargs)
    ds_source = SupervisedDataset(
        X=source_X,
        Y=source_Y,
        Yvar=torch.full((n_source, 1), 0.2, **tkwargs),
        feature_names=["a", "task"],
        outcome_names=["source"],
    )

    mt_dataset = MultiTaskDataset(
        datasets=[ds_target, ds_source],
        target_outcome_name="target",
        task_feature_index=-1,
    )
    inputs = MultiTaskGP.construct_inputs(
        mt_dataset, task_feature=-1, map_heterogeneous_to_full=True
    )

    # Full feature set [a, b] plus the task column gives three columns.
    self.assertEqual(inputs["train_X"].shape, (num_rows, 3))
    self.assertIn("train_Yvar", inputs)
    self.assertEqual(inputs["train_Yvar"].shape, (num_rows, 1))
| 934 | + |
def test_e2e_multitask_gp_with_learned_feature_imputation(self) -> None:
    """End-to-end run of ``MultiTaskGP`` with ``LearnedFeatureImputation``.

    Builds a heterogeneous two-task dataset, constructs the model with a
    chained normalize + impute input transform, checks the posterior mean
    shape and finiteness, then fits briefly and checks that the imputation
    values moved away from their initial state.
    """
    from botorch.models.transforms.input import (
        ChainedInputTransform,
        LearnedFeatureImputation,
    )

    tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
    n_target, n_source = 15, 10

    # Target task observes [a, b]; source task observes [a, c].
    X_target = torch.rand(n_target, 2, **tkwargs)
    X_source = torch.rand(n_source, 2, **tkwargs)
    Y_target = torch.sin(X_target.sum(dim=-1, keepdim=True))
    Y_source = torch.cos(X_source.sum(dim=-1, keepdim=True))

    target_task_col = torch.zeros(n_target, 1, **tkwargs)
    source_task_col = torch.ones(n_source, 1, **tkwargs)
    ds_target = SupervisedDataset(
        X=torch.cat([X_target, target_task_col], dim=-1),
        Y=Y_target,
        feature_names=["a", "b", "task"],
        outcome_names=["target"],
    )
    ds_source = SupervisedDataset(
        X=torch.cat([X_source, source_task_col], dim=-1),
        Y=Y_source,
        feature_names=["a", "c", "task"],
        outcome_names=["source"],
    )
    mt_dataset = MultiTaskDataset(
        datasets=[ds_target, ds_source],
        target_outcome_name="target",
        task_feature_index=-1,
    )
    model_inputs = MultiTaskGP.construct_inputs(
        mt_dataset, task_feature=-1, map_heterogeneous_to_full=True
    )

    # Full feature set is [a, b, c] (d = 3) with the task column last.
    d = 3
    feature_indices = {0: [0, 1], 1: [0, 2]}
    normalize = Normalize(d=d + 1, indices=list(range(d)))
    impute = LearnedFeatureImputation(
        feature_indices=feature_indices,
        d=d,
        task_feature_index=-1,
        **tkwargs,
    )
    input_transform = ChainedInputTransform(normalize=normalize, impute=impute)
    model = MultiTaskGP(**model_inputs, input_transform=input_transform)

    # The posterior is queried with full-width inputs: the two target
    # features are random, the remaining feature and the task column stay
    # zero (task 0).
    n_test = 5
    test_X_full = torch.zeros(n_test, d + 1, **tkwargs)
    test_X_full[:, :2] = torch.rand(n_test, 2, **tkwargs)
    posterior = model.posterior(test_X_full)
    self.assertEqual(posterior.mean.shape, torch.Size([n_test, 1]))
    self.assertFalse(torch.isnan(posterior.mean).any())

    # Snapshot the imputation values, fit for a few iterations, and
    # verify that they changed.
    initial_imp = model.input_transform.impute.imputation_values.detach().clone()
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=OptimizationWarning)
        fit_gpytorch_mll(
            mll, optimizer_kwargs={"options": {"maxiter": 5}}, max_attempts=1
        )
    final_imp = model.input_transform.impute.imputation_values.detach()
    unchanged = torch.allclose(initial_imp, final_imp, atol=1e-6)
    self.assertFalse(unchanged)
| 1009 | + |
806 | 1010 |
|
807 | 1011 | class TestKroneckerMultiTaskGP(BotorchTestCase): |
808 | 1012 | def test_KroneckerMultiTaskGP_default(self) -> None: |
|
0 commit comments