Fix the torch.take() wrapper to make axis optional for ndim = 1#47
Fix the torch.take() wrapper to make axis optional for ndim = 1#47asmeurer merged 1 commit intodata-apis:mainfrom
Conversation
| if axis is None: | ||
| if x.ndim != 1: | ||
| raise ValueError("axis must be specified when ndim > 1") | ||
| axis = 0 |
There was a problem hiding this comment.
Or you can dispatch to torch.take, either should be alright.
There was a problem hiding this comment.
torch.take doesn't support an axis argument.
There was a problem hiding this comment.
exactly :D torch.take() == numpy.take(axis=None)
There was a problem hiding this comment.
it's a bit more general, as it sees the whole tensor as a 0dim tensor and indexes into it, so it also works for ndim > 1.
There was a problem hiding this comment.
Right. The flattening behavior is also there in NumPy, but I think we wanted to avoid that with the array API take. I'm not sure to what degree we should try to avoid that for the compat library, though (c.f. #34 (comment))
betatim
left a comment
There was a problem hiding this comment.
Looks like a sensible way to implement this "sometimes optional" behaviour
Closes #34