Skip to content

Adding array-api-compat fallback#159

Draft
amalia-k510 wants to merge 31 commits intoscverse:mainfrom
amalia-k510:array-api-implementation
Draft

Adding array-api-compat fallback#159
amalia-k510 wants to merge 31 commits intoscverse:mainfrom
amalia-k510:array-api-implementation

Conversation

@amalia-k510
Copy link
Copy Markdown

This PR adds array-api-compat as a fallback in the singledispatch functions across the stats and conv modules so that Array API-compatible arrays (JAX, PyTorch, and others) work out of the box without needing to register each backend individually. The approach is: in each fallback, try array_api_compat.array_namespace(x) first. If it recognizes the array, dispatch through the standard Array API; if not, fall through to the existing numpy path. This touches _generic_ops.py (sum/min/max), _is_constant.py, _power.py, and conv/_to_dense.py. array-api-compat is added as a dependency in pyproject.toml.

Tests are in tests/test_jax.py covering all the stats functions and to_dense with JAX arrays. One thing to flag: _mean_var.py didn't need changes since it calls mean() and power() internally which already go through the fixed dispatchers, at least in my understanding, but JAX requires jax_enable_x64 for the dtype=np.float64 calls to work.

@codecov
Copy link
Copy Markdown

codecov Bot commented Mar 23, 2026

Codecov Report

❌ Patch coverage is 95.65217% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 97.32%. Comparing base (3fac19a) to head (3d4ee3a).

Files with missing lines Patch % Lines
src/fast_array_utils/stats/_generic_ops.py 92.30% 1 Missing ⚠️
src/fast_array_utils/stats/_power.py 88.88% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #159      +/-   ##
==========================================
- Coverage   99.29%   97.32%   -1.97%     
==========================================
  Files          21       21              
  Lines         566      598      +32     
==========================================
+ Hits          562      582      +20     
- Misses          4       16      +12     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@codspeed-hq
Copy link
Copy Markdown

codspeed-hq Bot commented Mar 23, 2026

Merging this PR will degrade performance by 13.15%

⚠️ Different runtime environments detected

Some benchmarks with significant performance changes were compared across different runtime environments,
which may affect the accuracy of the results.

Open the report in CodSpeed to investigate

❌ 1 regressed benchmark
✅ 231 untouched benchmarks

⚠️ Please fix the performance issues or acknowledge them on CodSpeed.

Performance Changes

Benchmark BASE HEAD Efficiency
test_stats_benchmark[scipy.sparse.csc_array-2d-ax0-float32-is_constant] 2.9 ms 3.3 ms -13.15%

Comparing amalia-k510:array-api-implementation (3d4ee3a) with main (3fac19a)

Open in CodSpeed

Copy link
Copy Markdown
Member

@flying-sheep flying-sheep left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good start. I wonder if adding a ArrayAPIObject protocol that checks for __array_namespace__ or so could be used instead of putting the array api stuff in the @singledispatch fallback body. What did we do in the AnnData PR?

Comment thread src/fast_array_utils/conv/_to_dense.py Outdated
if not isinstance(x, np.ndarray) and array_api_compat.is_array_api_obj(x):
if to_cpu_memory:
return np.asarray(x, order=order)
return x # already dense
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t think that follows, but I also don’t know if we can do better.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I read, the array API standard only covers dense arrays right now, and sparse types are already caught by their own registered handlers before reaching this fallback. So, from my understanding, if anything that lands here and passes is_array_api_obj is going to be dense. Pydata/Sparse are under active development and once they add sparse support later, we can just register a new handler for it.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does? I didn’t know it says anything about dense vs sparse. I assumed it just didn’t have any sparse-specific methods defined.

Copy link
Copy Markdown
Author

@amalia-k510 amalia-k510 Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You were right to flag that. I assumed that since the standard focuses on dense arrays, anything reaching this fallback would be dense. But it doesn't actually distinguish dense from sparse, plus, after review, pydata/sparse does implement __array_namespace__, so it would hit this fallback and skip the dense step. I don't think there's a generic way to detect sparsity in the standard either. Should we just register handlers for specific sparse types as they come up, or is there a better way of handling it?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just register handlers for specific sparse types as they come up

yeah I think that’s best for now.

Comment thread src/fast_array_utils/stats/_generic_ops.py Outdated
Comment thread src/fast_array_utils/stats/_generic_ops.py Outdated
Comment thread 03_23_2026.log Outdated
Copy link
Copy Markdown
Member

@flying-sheep flying-sheep left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, before we move on, I need to understand this comment:

Catch array-api-compat-wrapped types that lack __array_namespace__ (i.e. PyTorch)

Once I do, I can form an actual opinion about how I’d like this to look (all my comments about an ABC below rely on this). I commented on the line of the comment below so we keep that in a subthread.

Comment thread src/fast_array_utils/stats/_generic_ops.py Outdated
Comment thread src/fast_array_utils/stats/_generic_ops.py Outdated
Comment thread pyproject.toml Outdated
Comment thread src/fast_array_utils/stats/_is_constant.py Outdated
Comment thread src/fast_array_utils/stats/_is_constant.py Outdated
Comment thread src/fast_array_utils/stats/_mean_var.py Outdated
@flying-sheep
Copy link
Copy Markdown
Member

flying-sheep commented Apr 20, 2026

OK, so basically torch doesn’t actually support array API yet (see #159 (comment)), so all the fallback code is just for torch. I don’t think torch should be part of this PR then, we should think about it separately if we want to support it.

So please

  1. move the torch parts (i.e. the fallback code and the torch tests) out if this PR
  2. so wherever there is only a if array_api_compat.is_array_api_obj branch and no @register(... | HasArrayNamespace) branch, create that register
  3. revert the Any changes. I can fix the types but essentially unions will gain a | HasArrayNamespace member and all touched functions gain an overload like def func[A: HasArrayNamespace](x: A, ...) -> A: ...

Regarding the benchmarks: I’m not so sure if this is just static overhead … I think your xp.pow(xp.astype(x, dtype), n) is just genuinely slower than the branch numpy code took before. Can you improve that?

@amalia-k510
Copy link
Copy Markdown
Author

amalia-k510 commented Apr 27, 2026

I am still running into mypy issues. It's failing on import jax since jax isn't in the mypy environment. Also, from what I understand, I can't add a mypy override either since the config was removed from main. What would be the best way to handle it? @flying-sheep

Copy link
Copy Markdown
Member

@flying-sheep flying-sheep left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome, this is looking very close to perfect! I’ll take a look at the type issues

I can't add a mypy override either since the config was removed from main

what do you mean? the mypy dependencies are directly in .pre-commit-config.yaml (sadly)

) -> NDArray[Any] | np.number[Any]:
del keep_cupy_as_array
arr = getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))
return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr # type: ignore[return-value]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if isinstance(arr, types.CupyCOOMatrix) can’t be true here, right?

also what slowdown do you mean?

Copy link
Copy Markdown
Author

@amalia-k510 amalia-k510 Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numpy has __array_namespace__, so without the explicit np.ndarray registration it'd match HasArrayNamespace and go through array_api_compat.array_namespace() instead of just getattr(np, op)(...). That's the extra array_api_compat overhead showed up in the benchmarks.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for mypy, sorry, I meant the mypy dependency group was removed, right? So jax isn't installed when mypy runs. Can I just add a [[tool.mypy.overrides]] for jax to the existing [tool.mypy] config?

Copy link
Copy Markdown
Author

@amalia-k510 amalia-k510 Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm, just saw the edited comment.

Comment thread src/fast_array_utils/stats/_power.py Outdated
Comment on lines +29 to +31
if TYPE_CHECKING:
assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)
return x**n if dtype is None else np.power(x, n, dtype=dtype) # type: ignore[operator]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just raise NotImplementedError here?

Comment thread pyproject.toml Outdated
Comment on lines +149 to +152
[[tool.mypy.overrides]]
module = [ "jax", "jax.*" ]
ignore_missing_imports = true

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, we want to typecheck it. I added it to the type check deps

Suggested change
[[tool.mypy.overrides]]
module = [ "jax", "jax.*" ]
ignore_missing_imports = true

@flying-sheep
Copy link
Copy Markdown
Member

OK, I started with the typing. As you can see, adding an overload to sum makes sum work where it’s called. This should be done everywhere. I can do it, as you wish.

@amalia-k510
Copy link
Copy Markdown
Author

OK, I started with the typing. As you can see, adding an overload to sum makes sum work where it’s called. This should be done everywhere. I can do it, as you wish.

Thanks for the example! I'd actually like to take that on myself if that's okay.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants