Skip to content

Make JAX/numpyro imports lazy in fully Bayesian models (#3292)#3292

Open
sdaulton wants to merge 2 commits intometa-pytorch:mainfrom
sdaulton:export-D102367287
Open

Make JAX/numpyro imports lazy in fully Bayesian models (#3292)#3292
sdaulton wants to merge 2 commits intometa-pytorch:mainfrom
sdaulton:export-D102367287

Conversation

@sdaulton
Copy link
Copy Markdown
Contributor

@sdaulton sdaulton commented Apr 24, 2026

Summary:

botorch.models.fully_bayesian and botorch.models.fully_bayesian_multitask eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an ImportError even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a try/except block, deferring the failure to when AbstractFullyBayesianSingleTaskGP or SaasFullyBayesianMultiTaskGP is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since from __future__ import annotations is already present, type annotations referencing jax.Array remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287

@meta-cla meta-cla Bot added the CLA Signed Do not delete this pull request or issue due to inactivity. label Apr 24, 2026
@meta-codesync
Copy link
Copy Markdown

meta-codesync Bot commented Apr 24, 2026

@sdaulton has exported this pull request. If you are a Meta employee, you can view the originating Diff in D102367287.

sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 24, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 24, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 99.98%. Comparing base (fa4bc2b) to head (07f286c).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #3292   +/-   ##
=======================================
  Coverage   99.98%   99.98%           
=======================================
  Files         225      225           
  Lines       22249    22250    +1     
=======================================
+ Hits        22245    22246    +1     
  Misses          4        4           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 24, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
@sdaulton sdaulton force-pushed the export-D102367287 branch 2 times, most recently from 2979002 to 3f376a6 Compare April 24, 2026 19:03
sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 24, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
@sdaulton sdaulton force-pushed the export-D102367287 branch 3 times, most recently from 93466c0 to 6701523 Compare April 24, 2026 20:33
sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 24, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 24, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 26, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
@meta-codesync meta-codesync Bot changed the title Make JAX/numpyro imports lazy in fully Bayesian models Make JAX/numpyro imports lazy in fully Bayesian models (#3292) Apr 26, 2026
@sdaulton sdaulton force-pushed the export-D102367287 branch from 6701523 to 76262ca Compare April 26, 2026 20:29
sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 26, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
…ytorch#3291)

Summary:

X-link: facebook/Ax#5190

Reviewed By: saitcakmak

Differential Revision: D102367257
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
@sdaulton sdaulton force-pushed the export-D102367287 branch from 76262ca to 07f286c Compare April 26, 2026 20:31
sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 26, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 26, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 26, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
sdaulton added a commit to sdaulton/botorch that referenced this pull request Apr 26, 2026
…3292)

Summary:

`botorch.models.fully_bayesian` and `botorch.models.fully_bayesian_multitask` eagerly import JAX and numpyro at module level. When JAX requires NumPy >= 2.0 but the environment has an older NumPy (e.g., pyper Bento kernel with NumPy 1.24), this causes an `ImportError` even for callers that never use fully Bayesian models.

This diff wraps the JAX/numpyro imports in a `try/except` block, deferring the failure to when `AbstractFullyBayesianSingleTaskGP` or `SaasFullyBayesianMultiTaskGP` is actually instantiated. The error message is preserved and clearly states the NumPy >= 2.0 requirement. Since `from __future__ import annotations` is already present, type annotations referencing `jax.Array` remain valid as strings.

Reviewed By: saitcakmak

Differential Revision: D102367287
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed Do not delete this pull request or issue due to inactivity. fb-exported meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant