Adding array-api-compat fallback #159
Conversation
Codecov Report: ❌ Patch coverage is

```
@@           Coverage Diff           @@
##             main     #159   +/-  ##
==========================================
- Coverage   99.29%   97.32%   -1.97%
==========================================
  Files          21       21
  Lines         566      598      +32
==========================================
+ Hits          562      582      +20
- Misses          4       16      +12
```
Merging this PR will degrade performance by 13.15%

|   | Benchmark | BASE | HEAD | Efficiency |
|---|---|---|---|---|
| ❌ | `test_stats_benchmark[scipy.sparse.csc_array-2d-ax0-float32-is_constant]` | 2.9 ms | 3.3 ms | -13.15% |

Comparing amalia-k510:array-api-implementation (3d4ee3a) with main (3fac19a)
flying-sheep left a comment:
Good start. I wonder whether adding an `ArrayAPIObject` protocol that checks for `__array_namespace__` or so could be used instead of putting the Array API stuff in the `@singledispatch` fallback body. What did we do in the AnnData PR?
```python
if not isinstance(x, np.ndarray) and array_api_compat.is_array_api_obj(x):
    if to_cpu_memory:
        return np.asarray(x, order=order)
    return x  # already dense
```
I don’t think that follows, but I also don’t know if we can do better.
From what I read, the Array API standard only covers dense arrays right now, and sparse types are already caught by their own registered handlers before reaching this fallback. So, from my understanding, anything that lands here and passes `is_array_api_obj` is going to be dense. pydata/sparse is under active development, and once they add sparse support later, we can just register a new handler for it.
It does? I didn’t know it says anything about dense vs sparse. I assumed it just didn’t have any sparse-specific methods defined.
You were right to flag that. I assumed that since the standard focuses on dense arrays, anything reaching this fallback would be dense. But it doesn't actually distinguish dense from sparse, and, after review, pydata/sparse does implement `__array_namespace__`, so it would hit this fallback and skip the dense conversion. I don't think there's a generic way to detect sparsity in the standard either. Should we just register handlers for specific sparse types as they come up, or is there a better way of handling it?
> Should we just register handlers for specific sparse types as they come up
yeah I think that’s best for now.
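Registering concrete handlers as sparse types come up works because `singledispatch` prefers an exact class registration over a protocol/ABC match. A sketch under that assumption, where `to_dense_sketch`, `HasArrayNamespace`, and `FakeSparseArray` are all hypothetical illustration names:

```python
from functools import singledispatch
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasArrayNamespace(Protocol):
    def __array_namespace__(self, *args: Any, **kwargs: Any) -> Any: ...


@singledispatch
def to_dense_sketch(x: object) -> str:
    return "numpy fallback"


@to_dense_sketch.register
def _(x: HasArrayNamespace) -> str:
    # Would wrongly treat a sparse Array API object as already dense.
    return "dense array-api path"


class FakeSparseArray:
    """Stand-in for a pydata/sparse-style array that also implements
    __array_namespace__ (hypothetical)."""

    def __array_namespace__(self, *args: Any, **kwargs: Any) -> Any:
        return None


# A later, explicit registration: the exact class wins over the
# protocol match, so sparse input no longer hits the dense path.
@to_dense_sketch.register(FakeSparseArray)
def _(x: FakeSparseArray) -> str:
    return "sparse-specific path"


print(to_dense_sketch(FakeSparseArray()))  # → sparse-specific path
```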
OK, before we move on, I need to understand this comment:

> Catch array-api-compat-wrapped types that lack `__array_namespace__` (i.e. PyTorch)

Once I do, I can form an actual opinion about how I'd like this to look (all my comments about an ABC below rely on this). I commented on the line of the comment below so we keep that in a subthread.
OK, so basically torch doesn't actually support the Array API yet (see #159 (comment)), so all the fallback code is just for torch. I don't think torch should be part of this PR then; we should think about it separately if we want to support it. So please

Regarding the benchmarks: I'm not so sure if this is just static overhead … I think your
I am still running into mypy issues. It's failing on
```python
) -> NDArray[Any] | np.number[Any]:
    del keep_cupy_as_array
    arr = getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))
    return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr  # type: ignore[return-value]
```
`if isinstance(arr, types.CupyCOOMatrix)` can't be true here, right?

Also, what slowdown do you mean?
numpy has `__array_namespace__`, so without the explicit `np.ndarray` registration it'd match `HasArrayNamespace` and go through `array_api_compat.array_namespace()` instead of just `getattr(np, op)(...)`. That's the extra array-api-compat overhead that showed up in the benchmarks.
As for mypy: sorry, I meant that the mypy dependency group was removed, right? So jax isn't installed when mypy runs. Can I just add a `[[tool.mypy.overrides]]` for jax to the existing `[tool.mypy]` config?
nvm, just saw the edited comment.
```python
if TYPE_CHECKING:
    assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)
return x**n if dtype is None else np.power(x, n, dtype=dtype)  # type: ignore[operator]
```
There was a problem hiding this comment.
maybe just raise NotImplementedError here?
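That suggestion might look roughly like this; `power_fallback` is a hypothetical stand-in for the real fallback function, not the project's actual code:

```python
def power_fallback(x, n, *, dtype=None):
    """Hypothetical sketch of the suggestion: refuse types we don't
    know how to handle instead of guessing at operator support."""
    raise NotImplementedError(
        f"power() is not implemented for {type(x).__name__}"
    )


try:
    power_fallback(object(), 2)
except NotImplementedError as e:
    print(e)  # → power() is not implemented for object
```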
```toml
[[tool.mypy.overrides]]
module = [ "jax", "jax.*" ]
ignore_missing_imports = true
```
no, we want to typecheck it. I added it to the type check deps
OK, I started with the typing. As you can see, adding an overload to

Thanks for the example! I'd actually like to take that on myself if that's okay.
This PR adds `array-api-compat` as a fallback in the `singledispatch` functions across the stats and conv modules so that Array API-compatible arrays (JAX, PyTorch, and others) work out of the box without needing to register each backend individually. The approach is: in each fallback, try `array_api_compat.array_namespace(x)` first. If it recognizes the array, dispatch through the standard Array API; if not, fall through to the existing numpy path. This touches `_generic_ops.py` (sum/min/max), `_is_constant.py`, `_power.py`, and `conv/_to_dense.py`. `array-api-compat` is added as a dependency in `pyproject.toml`.

Tests are in `tests/test_jax.py`, covering all the stats functions and `to_dense` with JAX arrays. One thing to flag: `_mean_var.py` didn't need changes, since it calls `mean()` and `power()` internally, which already go through the fixed dispatchers, at least in my understanding, but JAX requires `jax_enable_x64` for the `dtype=np.float64` calls to work.
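The fallback pattern described above can be sketched without the real dependencies. Here a `getattr` check on `__array_namespace__` stands in for `array_api_compat.array_namespace(x)` (which raises `TypeError` for unrecognized inputs), and `my_sum` and `MiniArray` are hypothetical illustration names:

```python
from functools import singledispatch


@singledispatch
def my_sum(x):
    # Fallback: prefer the object's own Array API namespace when present;
    # otherwise fall through to the existing plain path.
    ns = getattr(x, "__array_namespace__", None)
    if ns is not None:
        xp = ns()  # the standard namespace for this array family
        return xp.sum(x)
    return sum(x)  # stand-in for the existing numpy path


class MiniArray(list):
    """Toy Array API object: its namespace is a tiny module-like shim."""

    def __array_namespace__(self, *, api_version=None):
        import types as _types

        return _types.SimpleNamespace(sum=lambda a: sum(a))


print(my_sum(MiniArray([1, 2, 3])))  # → 6
print(my_sum([1, 2, 3]))  # → 6
```

Concrete registrations (e.g. for `np.ndarray` or sparse classes) still take precedence over this fallback, which is what keeps the fast paths fast.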