MAINT: Fix computation of numpy.array_api.linalg.vector_norm#21084
MAINT: Fix computation of numpy.array_api.linalg.vector_norm#21084seberg merged 4 commits intonumpy:mainfrom
Conversation
Various pieces were incorrect due to a lack of complete coverage of this function in the array API test suite.
Previously it would always give float64 because an internal calculation involved a NumPy scalar and a Python float. The fix is to use a 0-D array instead of a NumPy scalar so that it type promotes with the float correctly. Fixes numpy#21083 I don't have a test for this yet because I'm unclear how exactly to test it.
numpy/array_api/linalg.py
Outdated
| if a.ndim == 0: | ||
| a = a[None] | ||
| else: | ||
| a = a.flatten() |
There was a problem hiding this comment.
I don't understand, flatten should always return a 1-D array?
EDIT: Actually, why not .ravel()? flatten forces a copy and that is really not necessary (ravel copies more often than you would expect too, but at least not always.)
There was a problem hiding this comment.
Oh I never realized that about flatten. We have to make it a 1-D array to force it to do a vector norm (this weird shape-based behavior is why we split norm into vector_norm and matrix_norm in the array API).
There was a problem hiding this comment.
Yes, but np.array(0).ravel() (or flatten) also is 1D, this is not <= 1 but == 1, so ravel() covers the a[None] path just as well.
| (np.prod([a.shape[i] for i in axis], dtype=int), *[a.shape[i] for i in rest])) | ||
| _axis = 0 | ||
| else: | ||
| _axis = axis |
There was a problem hiding this comment.
We do have a normalize_axis_tuple helper in numpy/core/numeric.py, maybe worthwhile?
There was a problem hiding this comment.
Ah, I should have guessed there are probably already helpers in NumPy to do some of these things. I'll see if I can find anything that makes this simpler.
numpy/array_api/linalg.py
Outdated
| elif isinstance(axis, int): | ||
| _axis = (axis,) | ||
| else: | ||
| _axis = axis |
There was a problem hiding this comment.
Might put this together with the setup (i.e. figuring out the result shape there? But looks fine.
|
Frankly, no, this is almost certainly ancient code from the early days of NumPy (or possibly older!), I would just git grep for |
|
Yeah, I figured out where the norm tests are. I just am really unclear how exactly they work. Just reading them it looks to me like this ought to already be tested (specifically here). Actually, looking closer now, I think I see the problem, which is that the test here doesn't actually test the output dtype, just that it is a floating dtype. |
|
On first glance, I think the tests are just buggy. It would be nice to just replace this stuff with |
|
I agree. I don't know if this sort of thing is standard in the NumPy test suite or if it's just a weird leftover thing from ages ago that hasn't been modernized. |
|
Just a weird leftover, probably it was a good and typical pattern at some point. I would not mind if it was just parametrized even if that is a biggish diff, on first sight, it looks straight forward enough. |
|
I also didn't realize those test are also testing integer inputs to norm() (which aren't even in scope in the array API). So testing the correct dtype is a little more complicated. Is there a helper function to get the correct promoted type? Also, I hope this isn't a problem, but I'm probably not going to refactor those tests to use parameterization. It would make a good first issue for someone to do that, but I really just want to add a test for my fix if it's not too complicated. |
I dunno
nah sure, but doesn't hurt to hope ;). |
|
@seberg's comments on using |
|
I think fiddling with the existing linalg tests is going to be a deeper rabbit hole than I really want to get down for this PR. If you want I can add a simple, separate regression test for the norm() fix. Otherwise, I would suggest opening a "good first contribution" issue for refactoring those tests to use pytest parameterization, and also to more explicitly test the result type in each case. I have cleaned up the implementation of numpy.array_api.vector_norm a little bit as per the review comments. |
|
Unfortunately there is now a (probably very small) merge conflict. @honno otherwise you were happy with this and we should just merge it? |
Yep LGTM other then merge conflicts. Again @seberg's previous suggestions on However let's just see if @asmeurer is okay with this PR still, as there have been more questions and clarifications on the |
|
The actual fix in the array_api is still relevant. I have fixed the conflicts. |
|
Lets put it in then. Thanks! |
There were several issues due to the fact that it was mostly untested in the array API test suite.
This also fixes #21083. I'm unclear exactly how I should add a test for this.
For the array API vector_norm, tests are being added to the array API test suite.