|
12 | 12 | from copy import deepcopy |
13 | 13 | from functools import partial |
14 | 14 | from itertools import filterfalse |
15 | | -from typing import Any |
| 15 | +from typing import Any, TYPE_CHECKING |
16 | 16 | from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage |
17 | 17 |
|
18 | | -import jax |
| 18 | +if TYPE_CHECKING: |
| 19 | + from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP |
| 20 | + from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP |
| 21 | + |
19 | 22 | from botorch.exceptions.errors import ModelFittingError, UnsupportedError |
20 | 23 | from botorch.exceptions.warnings import OptimizationWarning |
21 | 24 | from botorch.logging import logger |
22 | 25 | from botorch.models import SingleTaskGP |
23 | | -from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP |
24 | | -from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP |
25 | 26 | from botorch.models.map_saas import get_map_saas_model |
26 | 27 | from botorch.models.model_list_gp_regression import ModelListGP |
27 | 28 | from botorch.models.transforms.input import InputTransform |
|
44 | 45 | from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood |
45 | 46 | from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood |
46 | 47 | from linear_operator.utils.errors import NotPSDError |
47 | | -from numpyro.infer import MCMC, NUTS |
48 | 48 | from torch import device, Tensor |
49 | 49 | from torch.nn import Parameter |
50 | 50 | from torch.utils.data import DataLoader |
@@ -367,6 +367,11 @@ def fit_fully_bayesian_model_nuts( |
367 | 367 | >>> gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y) |
368 | 368 | >>> fit_fully_bayesian_model_nuts(gp) |
369 | 369 | """ |
| 370 | + # Local import to avoid pulling in JAX/numpyro at module level, |
| 371 | + # which would break environments without NumPy >= 2.0. |
| 372 | + import jax |
| 373 | + from numpyro.infer import MCMC, NUTS |
| 374 | + |
370 | 375 | model.train() |
371 | 376 |
|
372 | 377 | # Do inference with NUTS |
|
0 commit comments