ENH: Vectorize argsort and argselect with AVX2 #25610

Merged
seiko2plus merged 3 commits into numpy:main from r-devulap:avx2_arg
Jan 24, 2024

@r-devulap
Member

r-devulap commented Jan 17, 2024

Add AVX2 version of argsort and argselect. Benchmark numbers:

| Change   | Before [174ac7bc] <main>   | After [680b6823] <avx2_arg>   |   Ratio | Benchmark (Parameter)                                                                    |
|----------|----------------------------|-------------------------------|---------|------------------------------------------------------------------------------------------|
| +        | 63.8±0.5μs                 | 260±1μs                       |    4.07 | bench_function_base.Sort.time_argsort('quick', 'int64', ('ordered',))                    |
| +        | 69.6±0.07μs                | 210±3μs                       |    3.03 | bench_function_base.Sort.time_argsort('quick', 'int32', ('ordered',))                    |
| +        | 73.7±0.3μs                 | 220±3μs                       |    2.99 | bench_function_base.Sort.time_argsort('quick', 'float64', ('ordered',))                  |
| +        | 72.2±0.03μs                | 214±0.2μs                     |    2.96 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('ordered',))                   |
| +        | 79.1±0.07μs                | 227±0.1μs                     |    2.87 | bench_function_base.Sort.time_argsort('quick', 'float32', ('ordered',))                  |
| +        | 122±0.09μs                 | 250±5μs                       |    2.05 | bench_function_base.Sort.time_argsort('quick', 'int64', ('reversed',))                   |
| +        | 232±0.3μs                  | 468±0.4μs                     |    2.01 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 1000), 10)   |
| +        | 245±1μs                    | 492±2μs                       |    2.01 | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 1000), 100)    |
| +        | 234±0.3μs                  | 467±0.4μs                     |    2    | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 1000), 100)  |
| +        | 247±0.6μs                  | 492±1μs                       |    2    | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 1000), 10)     |
| +        | 105±0.2μs                  | 206±0.3μs                     |    1.95 | bench_function_base.Sort.time_argsort('quick', 'int32', ('reversed',))                   |
| +        | 163±1μs                    | 303±2μs                       |    1.86 | bench_function_base.Partition.time_argpartition('float32', ('ordered',), 100)            |
| +        | 163±1μs                    | 302±2μs                       |    1.85 | bench_function_base.Partition.time_argpartition('float32', ('ordered',), 1000)           |
| +        | 163±1μs                    | 300±3μs                       |    1.84 | bench_function_base.Partition.time_argpartition('float32', ('ordered',), 10)             |
| +        | 255±0.6μs                  | 469±1μs                       |    1.84 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 1000), 10)   |
| +        | 113±0.08μs                 | 206±3μs                       |    1.83 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('reversed',))                  |
| +        | 256±0.9μs                  | 467±1μs                       |    1.82 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 1000), 100)  |
| +        | 172±0.8μs                  | 301±0.2μs                     |    1.75 | bench_function_base.Partition.time_argpartition('int64', ('ordered',), 100)              |
| +        | 172±0.5μs                  | 300±0.4μs                     |    1.75 | bench_function_base.Partition.time_argpartition('int64', ('ordered',), 1000)             |
| +        | 123±0.08μs                 | 214±0.4μs                     |    1.75 | bench_function_base.Sort.time_argsort('quick', 'float64', ('reversed',))                 |
| +        | 261±0.1μs                  | 452±0.3μs                     |    1.73 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 1000), 10)     |
| +        | 261±0.4μs                  | 452±0.5μs                     |    1.73 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 1000), 100)    |
| +        | 172±0.2μs                  | 297±0.6μs                     |    1.73 | bench_function_base.Partition.time_argpartition('int64', ('ordered',), 10)               |
| +        | 127±0.2μs                  | 214±0.1μs                     |    1.69 | bench_function_base.Sort.time_argsort('quick', 'float32', ('reversed',))                 |
| +        | 184±1μs                    | 298±0.2μs                     |    1.63 | bench_function_base.Partition.time_argpartition('float64', ('ordered',), 100)            |
| +        | 183±0.7μs                  | 298±0.5μs                     |    1.63 | bench_function_base.Partition.time_argpartition('float64', ('ordered',), 1000)           |
| +        | 183±0.2μs                  | 296±0.4μs                     |    1.62 | bench_function_base.Partition.time_argpartition('float64', ('ordered',), 10)             |
| +        | 261±2μs                    | 416±6μs                       |    1.6  | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 100), 10)      |
| +        | 259±2μs                    | 415±6μs                       |    1.6  | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 100), 1000)    |
| +        | 262±2μs                    | 416±6μs                       |    1.59 | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 100), 100)     |
| +        | 264±2μs                    | 400±0.7μs                     |    1.51 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 100), 1000)  |
| +        | 271±0.8μs                  | 401±0.5μs                     |    1.48 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 100), 10)    |
| +        | 271±0.8μs                  | 400±0.8μs                     |    1.48 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 100), 100)   |
| +        | 279±3μs                    | 405±3μs                       |    1.45 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 100), 1000)  |
| +        | 281±3μs                    | 404±7μs                       |    1.44 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 100), 10)    |
| +        | 281±2μs                    | 403±4μs                       |    1.44 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 100), 100)   |
| +        | 194±2μs                    | 273±2μs                       |    1.41 | bench_function_base.Partition.time_argpartition('int32', ('ordered',), 100)              |
| +        | 193±1μs                    | 272±2μs                       |    1.41 | bench_function_base.Partition.time_argpartition('int32', ('ordered',), 1000)             |
| +        | 193±1μs                    | 270±2μs                       |    1.4  | bench_function_base.Partition.time_argpartition('int32', ('ordered',), 10)               |
| +        | 280±1μs                    | 383±0.4μs                     |    1.37 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 100), 100)     |
| +        | 279±0.9μs                  | 383±0.2μs                     |    1.37 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 100), 1000)    |
| +        | 280±1μs                    | 383±0.3μs                     |    1.36 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 100), 10)      |
| +        | 199±2μs                    | 212±0.8μs                     |    1.06 | bench_function_base.Sort.time_argsort('merge', 'float32', ('sorted_block', 10))          |
| -        | 496±0.9μs                  | 462±0.5μs                     |    0.93 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 1000), 1000) |
| -        | 170±0.7μs                  | 149±0.1μs                     |    0.88 | bench_function_base.Sort.time_argsort('merge', 'uint32', ('sorted_block', 10))           |
| -        | 533±0.6μs                  | 463±2μs                       |    0.87 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 1000), 1000) |
| -        | 532±0.3μs                  | 445±0.5μs                     |    0.84 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 1000), 1000)   |
| -        | 360±0.3μs                  | 282±0.4μs                     |    0.78 | bench_function_base.Sort.time_argsort('quick', 'int32', ('sorted_block', 1000))          |
| -        | 364±0.6μs                  | 275±0.7μs                     |    0.76 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('sorted_block', 1000))         |
| -        | 389±0.2μs                  | 293±0.6μs                     |    0.75 | bench_function_base.Sort.time_argsort('quick', 'float32', ('sorted_block', 1000))        |
| -        | 476±2μs                    | 341±3μs                       |    0.72 | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 10), 100)      |
| -        | 391±0.4μs                  | 280±0.8μs                     |    0.72 | bench_function_base.Sort.time_argsort('quick', 'float64', ('sorted_block', 1000))        |
| -        | 476±3μs                    | 340±2μs                       |    0.71 | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 10), 10)       |
| -        | 474±3μs                    | 334±3μs                       |    0.7  | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 10), 10)     |
| -        | 477±3μs                    | 336±2μs                       |    0.7  | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 10), 100)    |
| -        | 482±2μs                    | 335±3μs                       |    0.7  | bench_function_base.Partition.time_argpartition('int64', ('sorted_block', 10), 1000)     |
| -        | 474±4μs                    | 326±0.3μs                     |    0.69 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 10), 100)    |
| -        | 477±7μs                    | 326±0.2μs                     |    0.68 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 10), 10)     |
| -        | 485±3μs                    | 329±5μs                       |    0.68 | bench_function_base.Partition.time_argpartition('float64', ('sorted_block', 10), 1000)   |
| -        | 490±2μs                    | 323±0.2μs                     |    0.66 | bench_function_base.Partition.time_argpartition('float32', ('sorted_block', 10), 1000)   |
| -        | 476±2μs                    | 302±0.2μs                     |    0.64 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 10), 10)       |
| -        | 478±2μs                    | 301±0.4μs                     |    0.63 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 10), 100)      |
| -        | 483±2μs                    | 301±0.4μs                     |    0.62 | bench_function_base.Partition.time_argpartition('int32', ('sorted_block', 10), 1000)     |
| -        | 461±0.3μs                  | 284±5μs                       |    0.62 | bench_function_base.Sort.time_argsort('quick', 'int64', ('sorted_block', 10))            |
| -        | 445±0.5μs                  | 273±0.1μs                     |    0.61 | bench_function_base.Sort.time_argsort('quick', 'int64', ('sorted_block', 100))           |
| -        | 315±0.2μs                  | 169±1μs                       |    0.54 | bench_function_base.Partition.time_argpartition('float64', ('uniform',), 10)             |
| -        | 315±0.4μs                  | 169±1μs                       |    0.54 | bench_function_base.Partition.time_argpartition('float64', ('uniform',), 100)            |
| -        | 315±0.4μs                  | 169±2μs                       |    0.54 | bench_function_base.Partition.time_argpartition('float64', ('uniform',), 1000)           |
| -        | 284±0.2μs                  | 149±2μs                       |    0.53 | bench_function_base.Partition.time_argpartition('int64', ('uniform',), 1000)             |
| -        | 328±6μs                    | 169±1μs                       |    0.52 | bench_function_base.Partition.time_argpartition('float32', ('uniform',), 10)             |
| -        | 285±0.3μs                  | 149±2μs                       |    0.52 | bench_function_base.Partition.time_argpartition('int64', ('uniform',), 10)               |
| -        | 284±0.2μs                  | 149±0.2μs                     |    0.52 | bench_function_base.Partition.time_argpartition('int64', ('uniform',), 100)              |
| -        | 329±6μs                    | 169±1μs                       |    0.51 | bench_function_base.Partition.time_argpartition('float32', ('uniform',), 100)            |
| -        | 332±4μs                    | 169±1μs                       |    0.51 | bench_function_base.Partition.time_argpartition('float32', ('uniform',), 1000)           |
| -        | 471±0.1μs                  | 242±0.4μs                     |    0.51 | bench_function_base.Sort.time_argsort('quick', 'int32', ('sorted_block', 100))           |
| -        | 547±0.5μs                  | 276±5μs                       |    0.5  | bench_function_base.Sort.time_argsort('quick', 'int64', ('random',))                     |
| -        | 488±0.5μs                  | 238±2μs                       |    0.49 | bench_function_base.Sort.time_argsort('quick', 'int32', ('sorted_block', 10))            |
| -        | 479±0.4μs                  | 234±2μs                       |    0.49 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('sorted_block', 100))          |
| -        | 486±0.4μs                  | 233±3μs                       |    0.48 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('sorted_block', 10))           |
| -        | 283±2μs                    | 134±1μs                       |    0.47 | bench_function_base.Partition.time_argpartition('int32', ('uniform',), 10)               |
| -        | 286±2μs                    | 133±1μs                       |    0.47 | bench_function_base.Partition.time_argpartition('int32', ('uniform',), 100)              |
| -        | 284±2μs                    | 134±1μs                       |    0.47 | bench_function_base.Partition.time_argpartition('int32', ('uniform',), 1000)             |
| -        | 531±0.2μs                  | 252±4μs                       |    0.47 | bench_function_base.Sort.time_argsort('quick', 'float32', ('sorted_block', 10))          |
| -        | 514±0.2μs                  | 240±0.2μs                     |    0.47 | bench_function_base.Sort.time_argsort('quick', 'float32', ('sorted_block', 100))         |
| -        | 524±0.3μs                  | 232±5μs                       |    0.44 | bench_function_base.Sort.time_argsort('quick', 'float64', ('sorted_block', 100))         |
| -        | 542±0.2μs                  | 232±0.7μs                     |    0.43 | bench_function_base.Sort.time_argsort('quick', 'float64', ('sorted_block', 10))          |
| -        | 576±0.3μs                  | 235±0.2μs                     |    0.41 | bench_function_base.Sort.time_argsort('quick', 'int32', ('random',))                     |
| -        | 572±0.4μs                  | 226±2μs                       |    0.4  | bench_function_base.Sort.time_argsort('quick', 'uint32', ('random',))                    |
| -        | 631±0.4μs                  | 240±6μs                       |    0.38 | bench_function_base.Sort.time_argsort('quick', 'float32', ('random',))                   |
| -        | 1.06±0ms                   | 390±2μs                       |    0.37 | bench_function_base.Partition.time_argpartition('int64', ('random',), 1000)              |
| -        | 1.06±0ms                   | 378±1μs                       |    0.36 | bench_function_base.Partition.time_argpartition('int64', ('random',), 10)                |
| -        | 1.06±0ms                   | 378±2μs                       |    0.36 | bench_function_base.Partition.time_argpartition('int64', ('random',), 100)               |
| -        | 655±0.2μs                  | 232±6μs                       |    0.35 | bench_function_base.Sort.time_argsort('quick', 'float64', ('random',))                   |
| -        | 1.12±0.01ms                | 379±3μs                       |    0.34 | bench_function_base.Partition.time_argpartition('float32', ('random',), 1000)            |
| -        | 1.12±0.01ms                | 368±3μs                       |    0.33 | bench_function_base.Partition.time_argpartition('float32', ('random',), 10)              |
| -        | 1.12±0.01ms                | 368±3μs                       |    0.33 | bench_function_base.Partition.time_argpartition('float32', ('random',), 100)             |
| -        | 1.18±0ms                   | 388±1μs                       |    0.33 | bench_function_base.Partition.time_argpartition('float64', ('random',), 1000)            |
| -        | 1.06±0.01ms                | 347±3μs                       |    0.33 | bench_function_base.Partition.time_argpartition('int32', ('random',), 10)                |
| -        | 1.06±0.01ms                | 346±3μs                       |    0.33 | bench_function_base.Partition.time_argpartition('int32', ('random',), 100)               |
| -        | 1.07±0.01ms                | 356±3μs                       |    0.33 | bench_function_base.Partition.time_argpartition('int32', ('random',), 1000)              |
| -        | 1.17±0ms                   | 377±1μs                       |    0.32 | bench_function_base.Partition.time_argpartition('float64', ('random',), 10)              |
| -        | 1.17±0ms                   | 377±2μs                       |    0.32 | bench_function_base.Partition.time_argpartition('float64', ('random',), 100)             |
| -        | 75.4±0.05μs                | 15.7±0.4μs                    |    0.21 | bench_function_base.Sort.time_argsort('quick', 'int64', ('uniform',))                    |
| -        | 1.37±0ms                   | 258±0.9μs                     |    0.19 | bench_function_base.Partition.time_argpartition('float64', ('reversed',), 10)            |
| -        | 1.38±0.01ms                | 258±0.8μs                     |    0.19 | bench_function_base.Partition.time_argpartition('float64', ('reversed',), 100)           |
| -        | 1.37±0.01ms                | 258±3μs                       |    0.19 | bench_function_base.Partition.time_argpartition('float64', ('reversed',), 1000)          |
| -        | 1.56±0.01ms                | 261±2μs                       |    0.17 | bench_function_base.Partition.time_argpartition('float32', ('reversed',), 10)            |
| -        | 1.57±0.01ms                | 260±2μs                       |    0.17 | bench_function_base.Partition.time_argpartition('float32', ('reversed',), 100)           |
| -        | 1.57±0.01ms                | 260±2μs                       |    0.17 | bench_function_base.Partition.time_argpartition('float32', ('reversed',), 1000)          |
| -        | 1.48±0ms                   | 253±0.5μs                     |    0.17 | bench_function_base.Partition.time_argpartition('int64', ('reversed',), 10)              |
| -        | 1.48±0.01ms                | 253±0.6μs                     |    0.17 | bench_function_base.Partition.time_argpartition('int64', ('reversed',), 100)             |
| -        | 1.48±0.01ms                | 253±1μs                       |    0.17 | bench_function_base.Partition.time_argpartition('int64', ('reversed',), 1000)            |
| -        | 1.49±0.01ms                | 228±2μs                       |    0.15 | bench_function_base.Partition.time_argpartition('int32', ('reversed',), 10)              |
| -        | 1.49±0.01ms                | 228±2μs                       |    0.15 | bench_function_base.Partition.time_argpartition('int32', ('reversed',), 100)             |
| -        | 1.49±0.01ms                | 228±2μs                       |    0.15 | bench_function_base.Partition.time_argpartition('int32', ('reversed',), 1000)            |
| -        | 94.4±0.6μs                 | 14.3±0.08μs                   |    0.15 | bench_function_base.Sort.time_argsort('quick', 'int32', ('uniform',))                    |
| -        | 94.5±0.3μs                 | 14.1±0.1μs                    |    0.15 | bench_function_base.Sort.time_argsort('quick', 'uint32', ('uniform',))                   |
| -        | 131±0.1μs                  | 17.8±0.09μs                   |    0.14 | bench_function_base.Sort.time_argsort('quick', 'float32', ('uniform',))                  |
| -        | 143±0.08μs                 | 18.0±0.4μs                    |    0.13 | bench_function_base.Sort.time_argsort('quick', 'float64', ('uniform',))                  |
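
In the table above, Ratio is After divided by Before, so rows marked `+` are slower with this change and rows marked `-` are faster. A rough way to check the same pattern locally is sketched below. It is not the asv benchmark itself: the array size `N`, the seed, and the helper name `time_argsort` are assumptions made for illustration, and absolute timings will depend on the machine and the NumPy build.

```python
import timeit

import numpy as np

N = 10_000  # assumed array size; the asv suite may use a different one

rng = np.random.RandomState(0)
inputs = {
    "ordered": np.arange(N, dtype=np.int64),
    "reversed": np.arange(N, dtype=np.int64)[::-1].copy(),
    "uniform": np.ones(N, dtype=np.int64),
    "random": rng.randint(0, N, size=N).astype(np.int64),
}


def time_argsort(arr, repeat=3, number=20):
    """Best observed time per call, in microseconds (hypothetical helper)."""
    runs = timeit.repeat(
        lambda: np.argsort(arr, kind="quicksort"), repeat=repeat, number=number
    )
    return min(runs) / number * 1e6


timings = {name: time_argsort(arr) for name, arr in inputs.items()}
for name, microseconds in timings.items():
    print(f"{name:>8}: {microseconds:8.1f} us")
```

Running this once on a build before the change and once after gives a before/after pair per input pattern that can be compared with the Ratio column.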

seiko2plus added the "component: SIMD" label Jan 18, 2024
seiko2plus merged commit 221427b into numpy:main Jan 24, 2024
@seiko2plus
Member

Thank you @r-devulap!

@lesteve
Contributor

lesteve commented Feb 1, 2024

There was a new failure in scikit-learn when testing against numpy dev because of this change; see scikit-learn/scikit-learn#28326 for details. The output of np.argsort changes enough that HDBSCAN finds 3 clusters with numpy dev and 2 with numpy 1.26.3.

The scikit-learn tests pass in 0a4b2b8 (previous merge commit in main) and fail in 221427b (the merge commit for this PR).

I could not find any mention of this change in the changelog, but maybe I missed it?

I think it would be worth adding a changelog entry about this change to indicate that np.argsort and np.argselect results may change in numpy 2.0.
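
A result can change without being wrong because of ties: when the input contains equal values, any ordering of their indices sorts the array, and an unstable method may pick a different one after an implementation change. The sketch below checks validity rather than exact equality. The helper name `is_valid_argsort` and the small example array are hypothetical, chosen only for illustration.

```python
import numpy as np


def is_valid_argsort(x, idx):
    """True if idx is a permutation of range(len(x)) that sorts x."""
    idx = np.asarray(idx)
    is_permutation = np.array_equal(np.sort(idx), np.arange(len(x)))
    is_sorted = bool(np.all(np.diff(x[idx]) >= 0))
    return is_permutation and is_sorted


x = np.array([3, 1, 2, 1, 3, 2])

# Two index orders that differ only in how equal values are arranged.
a = np.array([1, 3, 2, 5, 0, 4])  # ties kept in original order
b = np.array([3, 1, 5, 2, 4, 0])  # ties swapped

valid_a = is_valid_argsort(x, a)
valid_b = is_valid_argsort(x, b)
same = np.array_equal(a, b)
print(valid_a, valid_b, same)
```

Both `a` and `b` sort `x`, yet they are different arrays, so downstream code that depends on the exact index order can behave differently.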

@rgommers
Member

rgommers commented Mar 5, 2024

> I think it would be worth adding a changelog entry about this change to indicate that np.argsort and np.argselect results may change in numpy 2.0.

I verified that this is indeed common when sorting integers.

>>> import hashlib
>>> import numpy as np

>>> rng = np.random.RandomState(seed=123098)  # minor: reproducers across versions should use RandomState
>>> x = rng.randint(100, size=10_000)
>>> hashlib.sha256(np.argsort(x).tobytes()).hexdigest()

I'll include this note in my next release notes PR:

Minor changes in behavior of sorting functions
----------------------------------------------

Due to algorithmic changes and use of SIMD code, sorting functions with methods
that aren't stable may return slightly different results in 2.0.0 compared to
1.26.x. This includes the default method of `~numpy.sort` and `~numpy.argsort`.
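
For code that needs the same index order regardless of which unstable implementation is in use, one option is to request a stable method explicitly. The sketch below assumes that a reproducible tie order is what the caller wants. It reuses the seed and array from the reproducer above and compares the stable result with an independent reference built with `np.lexsort`, which sorts by value and breaks ties by original index.

```python
import numpy as np

rng = np.random.RandomState(seed=123098)
x = rng.randint(100, size=10_000)

# A stable sort keeps equal values in their original relative order,
# so the index order is fully determined by the input.
stable = np.argsort(x, kind="stable")

# Reference: primary key is the value (last key passed to lexsort),
# secondary key is the original position.
reference = np.lexsort((np.arange(x.size), x))

print(np.array_equal(stable, reference))
```

A stable sort may be slower than the default method, so this is a trade-off between reproducible tie order and speed.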

rgommers added a commit to rgommers/numpy that referenced this pull request Mar 5, 2024
As asked for in numpygh-25610

[skip actions] [skip azp] [skip cirrus]
rgommers added a commit to rgommers/numpy that referenced this pull request Mar 6, 2024
As asked for in numpygh-25610

[skip actions] [skip azp] [skip cirrus]