I'm trying to optimize some code and one of the time consuming operations is the following:
```python
import numpy as np

survivors = np.where(a > 0)[0]
pos = len(survivors)
a[:pos] = a[survivors]
b[:pos] = b[survivors]
c[:pos] = c[survivors]
```

In my code `a` is a very large NumPy array of floats (more than 100,000 elements). Many of them will be 0.
Is there a way to speed this up?
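For reference, the same compaction can also be written with a boolean mask instead of `np.where` (a minimal sketch with made-up inputs; it performs essentially the same work, so it is a reformulation rather than a speed-up):

```python
import numpy as np

# Hypothetical small inputs, just to make the snippet self-contained.
a = np.array([0., 1., 2., 0.])
b = np.array([1., 2., 3., 4.])
c = np.array([1., 2., 3., 4.])

# Boolean-mask form of the question's compaction: keep the rows
# where a > 0 and move them to the front of each array.
mask = a > 0
pos = mask.sum()
a[:pos] = a[mask]
b[:pos] = b[mask]
c[:pos] = c[mask]
```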
Top answer
As far as I can tell, there's nothing that could speed this up in pure NumPy. However, if you have numba you can write your own version of this "selection" as a jitted function:
```python
import numba as nb

@nb.njit
def selection(a, b, c):
    insert_idx = 0
    for idx, item in enumerate(a):
        if item > 0:
            a[insert_idx] = a[idx]
            b[insert_idx] = b[idx]
            c[insert_idx] = c[idx]
            insert_idx += 1
```

In my test runs this was roughly a factor of 2 faster than your NumPy code. However, numba may be a heavy dependency if you're not using conda.
Example:
```python
>>> import numpy as np
>>> a = np.array([0., 1., 2., 0.])
>>> b = np.array([1., 2., 3., 4.])
>>> c = np.array([1., 2., 3., 4.])
>>> selection(a, b, c)
>>> a, b, c
(array([ 1.,  2.,  2.,  0.]),
 array([ 2.,  3.,  3.,  4.]),
 array([ 2.,  3.,  3.,  4.]))
```

Timing:
It's hard to time this accurately because all approaches work in-place. I therefore used timeit.repeat with number=1 (which avoids timings being skewed by the in-place mutation of the inputs) and took the min of the resulting list of timings, since that's recommended as the most useful quantitative measure in the documentation:
Note

> It’s tempting to calculate mean and standard deviation from the result vector and report these. However, this is not very useful. In a typical case, the lowest value gives a lower bound for how fast your machine can run the given code snippet; higher values in the result vector are typically not caused by variability in Python’s speed, but by other processes interfering with your timing accuracy. So the min() of the result is probably the only number you should be interested in. After that, you should look at the entire vector and apply common sense rather than statistics.
Numba solution
import timeit min(timeit.repeat("""selection(a, b, c)""", """import numpy as np from __main__ import selection a = np.arange(1000000) % 3 b = a.copy() c = a.copy() """, repeat=100, number=1))0.007700118746939211
Original solution
import timeit min(timeit.repeat("""survivors = np.where(a > 0)[0] pos = len(survivors) a[:pos] = a[survivors] b[:pos] = b[survivors] c[:pos] = c[survivors]""", """import numpy as np a = np.arange(1000000) % 3 b = a.copy() c = a.copy() """, repeat=100, number=1))0.028622144571883723
Alexander McFarlane's solution (now deleted)
import timeit min(timeit.repeat("""survivors = comb_array[:, 0].nonzero()[0] comb_array[:len(survivors)] = comb_array[survivors]""", """import numpy as np a = np.arange(1000000) % 3 b = a.copy() c = a.copy() comb_array = np.vstack([a,b,c]).T""", repeat=100, number=1))0.058305527038669425
So the Numba solution can actually speed this up by a factor of 3-4, while Alexander McFarlane's solution is actually slower (2x) than the original approach. However, the timings may be somewhat noisy because each measurement runs the statement only once (number=1).
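As a quick sanity check (not part of the original answer), the loop body of `selection` can be run without the `@nb.njit` decorator and compared against the question's `np.where`-based compaction; both should produce the same compacted prefix:

```python
import numpy as np

def selection_plain(a, b, c):
    # Same loop as the jitted selection() above, shown without
    # @nb.njit so it runs even where numba is not installed.
    # Returns the number of surviving elements for convenience.
    insert_idx = 0
    for idx in range(len(a)):
        if a[idx] > 0:
            a[insert_idx] = a[idx]
            b[insert_idx] = b[idx]
            c[insert_idx] = c[idx]
            insert_idx += 1
    return insert_idx

rng = np.random.default_rng(0)
a1 = rng.integers(0, 3, size=1000).astype(float)
b1 = rng.random(1000)
c1 = rng.random(1000)
a2, b2, c2 = a1.copy(), b1.copy(), c1.copy()

# The question's np.where-based compaction, applied to the copies.
survivors = np.where(a2 > 0)[0]
pos = len(survivors)
a2[:pos] = a2[survivors]
b2[:pos] = b2[survivors]
c2[:pos] = c2[survivors]

n = selection_plain(a1, b1, c1)
```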