Skip to content

Commit e61e49b

Browse files
sdaultonfacebook-github-bot
authored andcommitted
Make JAX/numpyro imports lazy in fully Bayesian models (meta-pytorch#3292)
Summary: `botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models. This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings. Reviewed By: saitcakmak Differential Revision: D102367287
1 parent b9fd180 commit e61e49b

4 files changed

Lines changed: 57 additions & 32 deletions

File tree

botorch/models/fully_bayesian.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,17 @@
3838
from math import log, sqrt
3939
from typing import Any, TypeVar
4040

41-
import jax.numpy as jnp
4241
import numpy as np
4342

44-
if np.lib.NumpyVersion(np.__version__) < "2.0.0":
45-
raise ImportError(
46-
"BoTorch's fully Bayesian models require NumPy >= 2.0 "
47-
"(python-scientific-stack version 3). "
48-
"Please update your PACKAGE file to set "
49-
'"python-scientific-stack": "3".'
50-
)
51-
52-
import jax
53-
import numpyro
54-
import numpyro.distributions as numpyro_dist
43+
try:
44+
import jax
45+
import jax.numpy as jnp
46+
import numpyro
47+
import numpyro.distributions as numpyro_dist
48+
49+
_HAS_JAX = True
50+
except ImportError: # pragma: no cover
51+
_HAS_JAX = False
5552
import torch
5653
from botorch.acquisition.objective import PosteriorTransform
5754
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
@@ -90,6 +87,16 @@
9087
_sqrt5 = math.sqrt(5)
9188

9289

90+
def _check_jax_available() -> None:
91+
if not _HAS_JAX:
92+
raise ImportError(
93+
"BoTorch's fully Bayesian models require JAX, numpyro, and "
94+
"NumPy >= 2.0 (python-scientific-stack version 3). "
95+
"Please update your PACKAGE file to set "
96+
'"python-scientific-stack": "3".'
97+
)
98+
99+
93100
def matern52_kernel(X: jax.Array, lengthscale: jax.Array) -> jax.Array:
94101
"""Matern-5/2 kernel."""
95102
dist = compute_dists(X=X, lengthscale=lengthscale)
@@ -872,6 +879,7 @@ def __init__(
872879
indices_to_warp: An optional list of indices to warp. The default
873880
is to warp all inputs.
874881
"""
882+
_check_jax_available()
875883
if not (
876884
train_X.ndim == train_Y.ndim == 2
877885
and len(train_X) == len(train_Y)

botorch/models/fully_bayesian_multitask.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,22 @@
1111
from collections.abc import Mapping
1212
from typing import Any, NoReturn, TypeVar
1313

14-
import jax.numpy as jnp
15-
import numpyro
16-
import numpyro.distributions as numpyro_dist
1714
import torch
1815
from botorch.acquisition.objective import PosteriorTransform
1916
from botorch.models.fully_bayesian import (
17+
_check_jax_available,
18+
_HAS_JAX,
2019
matern52_kernel,
2120
MCMC_DIM,
2221
MIN_INFERRED_NOISE_LEVEL,
2322
reshape_and_detach,
2423
SaasPyroModel,
2524
)
25+
26+
if _HAS_JAX:
27+
import jax.numpy as jnp
28+
import numpyro
29+
import numpyro.distributions as numpyro_dist
2630
from botorch.models.gpytorch import (
2731
BatchedMultiOutputGPyTorchModel,
2832
MultiTaskGPyTorchModel,
@@ -353,6 +357,7 @@ def __init__(
353357
input are expected tasks values. If false, unexpected task values
354358
will be mapped to the first output_task if supplied.
355359
"""
360+
_check_jax_available()
356361
if not (
357362
train_X.ndim == train_Y.ndim == 2
358363
and len(train_X) == len(train_Y)

test/models/test_fully_bayesian.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,20 +1284,16 @@ class TestFullyBayesianLinearWarpingSingleTaskGP(TestFullyBayesianLinearSingleTa
12841284

12851285

12861286
class TestNumpyVersionCheck(BotorchTestCase):
1287-
def test_old_numpy_raises_import_error(self) -> None:
1288-
"""Test that importing fully_bayesian with NumPy < 2.0 raises ImportError."""
1289-
import importlib
1290-
import sys
1291-
1292-
from numpy.lib import NumpyVersion
1293-
1294-
with patch.object(NumpyVersion, "__lt__", return_value=True):
1295-
# Remove cached module so the version check re-runs on import.
1296-
mod_name = "botorch.models.fully_bayesian"
1297-
saved = sys.modules.pop(mod_name)
1298-
try:
1299-
with self.assertRaises(ImportError):
1300-
importlib.import_module(mod_name)
1301-
finally:
1302-
# Restore the module so other tests are not affected.
1303-
sys.modules[mod_name] = saved
1287+
def test_missing_jax_raises_on_instantiation(self) -> None:
1288+
"""Test that missing JAX raises ImportError at model instantiation."""
1289+
from botorch.models import fully_bayesian
1290+
from botorch.models.fully_bayesian import _check_jax_available
1291+
1292+
with patch.object(fully_bayesian, "_HAS_JAX", False):
1293+
with self.assertRaises(ImportError):
1294+
_check_jax_available()
1295+
with self.assertRaises(ImportError):
1296+
SaasFullyBayesianSingleTaskGP(
1297+
train_X=torch.rand(10, 2),
1298+
train_Y=torch.rand(10, 1),
1299+
)

test/models/test_fully_bayesian_multitask.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
import itertools
9+
from unittest.mock import patch
910

1011
import jax.numpy as jnp
1112
import numpyro.handlers
@@ -182,6 +183,21 @@ def _get_mcmc_samples(self, num_samples: int, dim: int, task_rank: int, **tkwarg
182183
}
183184
return mcmc_samples
184185

186+
def test_missing_jax_raises_on_instantiation(self) -> None:
187+
"""Test that missing JAX raises ImportError at model instantiation."""
188+
from botorch.models import fully_bayesian
189+
190+
tkwargs = {"device": self.device, "dtype": torch.double}
191+
train_X, train_Y, train_Yvar = self._get_base_data(**tkwargs)
192+
with patch.object(fully_bayesian, "_HAS_JAX", False):
193+
with self.assertRaises(ImportError):
194+
SaasFullyBayesianMultiTaskGP(
195+
train_X=train_X,
196+
train_Y=train_Y,
197+
train_Yvar=train_Yvar,
198+
task_feature=4,
199+
)
200+
185201
def test_raises(self):
186202
tkwargs = {"device": self.device, "dtype": torch.double}
187203
with self.assertRaisesRegex(

0 commit comments

Comments
 (0)