BUG: Include broadcasting for rtol argument in matrix_rank#25877
BUG: Include broadcasting for rtol argument in matrix_rank#25877mhvk merged 1 commit intonumpy:mainfrom
rtol argument in matrix_rank#25877Conversation
|
This looks like it fixes the problem, although I'm not sure if the broadcast_to is necessary. Should we add tests? |
|
Thanks! A regression test would be great indeed. |
numpy/linalg/_linalg.py
Outdated
|
|
||
| if rtol is None: | ||
| rtol = max(A.shape[-2:]) * finfo(S.dtype).eps | ||
| rtol = broadcast_to(rtol, A.shape[:-2])[..., newaxis] |
There was a problem hiding this comment.
This should be inside the if tol is None branch. Indeed, the if rtol is None statement should be moved there - no point calculating rtol if it is not going to be used anyway.
There was a problem hiding this comment.
I changed it to rtol = asarray(rtol)[..., newaxis] as we still want to call asarray (similarly asarray(tol) later) to accept list inputs, like: rtol = [0.2, 0.3, 0.1].
numpy/linalg/tests/test_linalg.py
Outdated
| assert_equal(matrix_rank(ms), np.array([3, 4, 0])) | ||
| # works on scalar | ||
| assert_equal(matrix_rank(1), 1) | ||
| assert_equal(matrix_rank(I, tol=0.0), matrix_rank(I, rtol=0.0)) |
There was a problem hiding this comment.
I'm fairly sure this would not have failed without the fix - would suggest using the explicit example from the issue.
There was a problem hiding this comment.
Sure! I added a regression test with an example from the issue.
rtol argument in matrix_rankrtol argument in matrix_rank
0009fa6 to
cd652a7
Compare
I changed |
numpy/linalg/_linalg.py
Outdated
| if rtol is not None and tol is not None: | ||
| raise ValueError("`tol` and `rtol` can't be both set.") | ||
|
|
||
| if rtol is None: |
There was a problem hiding this comment.
Could you still move this and the next two lines inside the if tol is None branch? Also, since in this branch, rtol is guaranteed to be a scalar, it can be,
if tol is None:
if rtol is None:
rtol = max(A.shape[-2:]) * finfo(S.dtype).eps
else:
rtol = asarray(rtol)[..., newaxis]
tol = S.max(axis=-1, keepdims=True) * rtol
else:
...
There was a problem hiding this comment.
Ah right, I confused tol with rtol in your previous comment - now I should be right.
cd652a7 to
b5f5d18
Compare
mhvk
left a comment
There was a problem hiding this comment.
OK, thanks for the quick change. Looks all good now!
Hi @asmeurer,
This PR fixes broadcasting of
rtolargument inmatrix_rankfunction that you pointed out in #25437 (comment).