Conversation
|
|
|
Edit: I'm wrong. Thanks to #10374, I think that there's a deeper problem here with numpy type promotion, and how |
|
Ah, I see. x = np.ndarray(1, dtype=np.float32)
x **= 1.0 / 2.0
x.dtype # => float32
x = np.float32(1)
x **= 1.0 / 2.0
x.dtype # => float64
I think the code in this PR is generic (i.e., works with both scalar and array). import numpy as np
print('scalar', np.linalg.norm(np.ones((2,), dtype=np.float32), 3, 0, False).dtype)
print('array', np.linalg.norm(np.ones((2,2), dtype=np.float32), 3, 0, False).dtype)Output from NumPy 1.14.1: Output from NumPy 1.14.1 + this PR: |
|
Fix looks good, but this needs a test. |
I tested
np.linalg.normwith the following code, but it seems the type is not preserved correctly.Without this fix (NumPy 1.14.1):
With NumPy 1.14.1 + this fix, nothing should be printed.
Related to #10368 and cupy/cupy#875 (comment)