Skip to content

Commit b24286d

Browse files
authored
[docs] Parallelization tutorial (cython#5184)
Parallelization tutorial to try to explain prange/parallel in a little more user-friendly way
1 parent ac6dd0a commit b24286d

File tree

13 files changed

+572
-0
lines changed

13 files changed

+572
-0
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# tag: openmp
2+
3+
from cython.parallel import parallel
4+
from cython.cimports.openmp import omp_get_thread_num
5+
import cython
6+
7+
@cython.cfunc
8+
@cython.nogil
9+
def long_running_task1() -> cython.void:
10+
pass
11+
12+
@cython.cfunc
13+
@cython.nogil
14+
def long_running_task2() -> cython.void:
15+
pass
16+
17+
def do_two_tasks():
18+
thread_num: cython.int
19+
with cython.nogil, parallel(num_threads=2):
20+
thread_num = omp_get_thread_num()
21+
if thread_num == 0:
22+
long_running_task1()
23+
elif thread_num == 1:
24+
long_running_task2()
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# tag: openmp
2+
3+
from cython.parallel cimport parallel
4+
from openmp cimport omp_get_thread_num
5+
6+
7+
8+
9+
cdef void long_running_task1() nogil:
10+
pass
11+
12+
13+
14+
cdef void long_running_task2() nogil:
15+
pass
16+
17+
def do_two_tasks():
18+
cdef int thread_num
19+
with nogil, parallel(num_threads=2):
20+
thread_num = omp_get_thread_num()
21+
if thread_num == 0:
22+
long_running_task1()
23+
elif thread_num == 1:
24+
long_running_task2()
25+
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# distutils: language = c++
2+
3+
from cython.parallel import parallel, prange
4+
from cython.cimports.libc.stdlib import malloc, free
5+
from cython.cimports.libcpp.algorithm import nth_element
6+
import cython
7+
from cython.operator import dereference
8+
9+
import numpy as np
10+
11+
@cython.boundscheck(False)
12+
@cython.wraparound(False)
13+
def median_along_axis0(x: cython.double[:,:]):
14+
out: cython.double[::1] = np.empty(x.shape[1])
15+
i: cython.Py_ssize_t
16+
j: cython.Py_ssize_t
17+
scratch: cython.pointer(cython.double)
18+
median_it: cython.pointer(cython.double)
19+
with cython.nogil, parallel():
20+
# allocate scratch space per loop
21+
scratch = cython.cast(
22+
cython.pointer(cython.double),
23+
malloc(cython.sizeof(cython.double)*x.shape[0]))
24+
try:
25+
for i in prange(x.shape[1]):
26+
# copy row into scratch space
27+
for j in range(x.shape[0]):
28+
scratch[j] = x[j, i]
29+
median_it = scratch + x.shape[0]//2
30+
nth_element(scratch, median_it, scratch + x.shape[0])
31+
# for the sake of a simple example, don't handle even lengths...
32+
out[i] = dereference(median_it)
33+
finally:
34+
free(scratch)
35+
return np.asarray(out)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# distutils: language = c++
2+
3+
from cython.parallel cimport parallel, prange
4+
from libcpp.vector cimport vector
5+
from libcpp.algorithm cimport nth_element
6+
cimport cython
7+
from cython.operator cimport dereference
8+
9+
import numpy as np
10+
11+
@cython.boundscheck(False)
12+
@cython.wraparound(False)
13+
def median_along_axis0(const double[:,:] x):
14+
cdef double[::1] out = np.empty(x.shape[1])
15+
cdef Py_ssize_t i, j
16+
17+
cdef vector[double] *scratch
18+
cdef vector[double].iterator median_it
19+
with nogil, parallel():
20+
# allocate scratch space per loop
21+
scratch = new vector[double](x.shape[0])
22+
try:
23+
for i in prange(x.shape[1]):
24+
# copy row into scratch space
25+
for j in range(x.shape[0]):
26+
dereference(scratch)[j] = x[j, i]
27+
median_it = scratch.begin() + scratch.size()//2
28+
nth_element(scratch.begin(), median_it, scratch.end())
29+
# for the sake of a simple example, don't handle even lengths...
30+
out[i] = dereference(median_it)
31+
finally:
32+
del scratch
33+
return np.asarray(out)
34+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from cython.parallel import prange
2+
import cython
3+
from cython.cimports.libc.math import sqrt
4+
5+
@cython.boundscheck(False)
6+
@cython.wraparound(False)
7+
def l2norm(x: cython.double[:]):
8+
total: cython.double = 0
9+
i: cython.Py_ssize_t
10+
for i in prange(x.shape[0], nogil=True):
11+
total += x[i]*x[i]
12+
return sqrt(total)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from cython.parallel cimport prange
2+
cimport cython
3+
from libc.math cimport sqrt
4+
5+
@cython.boundscheck(False)
6+
@cython.wraparound(False)
7+
def l2norm(double[:] x):
8+
cdef double total = 0
9+
cdef Py_ssize_t i
10+
for i in prange(x.shape[0], nogil=True):
11+
total += x[i]*x[i]
12+
return sqrt(total)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from cython.parallel import parallel, prange
2+
import cython
3+
from cython.cimports.libc.math import sqrt
4+
5+
@cython.boundscheck(False)
6+
@cython.wraparound(False)
7+
def normalize(x: cython.double[:]):
8+
i: cython.Py_ssize_t
9+
total: cython.double = 0
10+
norm: cython.double
11+
with cython.nogil, parallel():
12+
for i in prange(x.shape[0]):
13+
total += x[i]*x[i]
14+
norm = sqrt(total)
15+
for i in prange(x.shape[0]):
16+
x[i] /= norm
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from cython.parallel cimport parallel, prange
2+
cimport cython
3+
from libc.math cimport sqrt
4+
5+
@cython.boundscheck(False)
6+
@cython.wraparound(False)
7+
def normalize(double[:] x):
8+
cdef Py_ssize_t i
9+
cdef double total = 0
10+
cdef double norm
11+
with nogil, parallel():
12+
for i in prange(x.shape[0]):
13+
total += x[i]*x[i]
14+
norm = sqrt(total)
15+
for i in prange(x.shape[0]):
16+
x[i] /= norm
17+
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from cython.parallel import prange
2+
import cython
3+
from cython.cimports.libc.math import sin
4+
5+
import numpy as np
6+
7+
@cython.boundscheck(False)
8+
@cython.wraparound(False)
9+
def do_sine(input: cython.double[:,:]):
10+
output : cython.double[:,:] = np.empty_like(input)
11+
i : cython.Py_ssize_t
12+
j : cython.Py_ssize_t
13+
for i in prange(input.shape[0], nogil=True):
14+
for j in range(input.shape[1]):
15+
output[i, j] = sin(input[i, j])
16+
return np.asarray(output)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from cython.parallel cimport prange
2+
cimport cython
3+
from libc.math cimport sin
4+
5+
import numpy as np
6+
7+
@cython.boundscheck(False)
8+
@cython.wraparound(False)
9+
def do_sine(double[:,:] input):
10+
cdef double[:,:] output = np.empty_like(input)
11+
cdef Py_ssize_t i, j
12+
13+
for i in prange(input.shape[0], nogil=True):
14+
for j in range(input.shape[1]):
15+
output[i, j] = sin(input[i, j])
16+
return np.asarray(output)

0 commit comments

Comments
 (0)