Skip to content

Commit 02ecfaa

Browse files
committed
docs: add tutorial on raising loops [skip ci]
1 parent 19a94cf commit 02ecfaa

File tree

4 files changed

+108
-1
lines changed

4 files changed

+108
-1
lines changed

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "8e83bb047e01f2f0499b5a0680405ff029330436"
7+
ENZYMEXLA_COMMIT = "a47b63c2bd7754cf33ab71c257c3749e63b4ff1b"
88

99
ENZYMEXLA_SHA256 = ""
1010

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
44
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
55
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
6+
PrettyChairmarks = "aafa11c5-44f9-44a1-b829-427e6ce1ffc2"
67
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
78
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
89
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

docs/src/.vitepress/config.mts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ export default defineConfig({
101101
text: "Persistent Compilation Cache",
102102
link: "/tutorials/persistent_compile_cache",
103103
},
104+
{ text: "Raising", link: "/tutorials/raising" }
104105
],
105106
},
106107
{
@@ -185,6 +186,7 @@ export default defineConfig({
185186
text: "Persistent Compilation Cache",
186187
link: "/tutorials/persistent_compile_cache",
187188
},
189+
{ text: "Raising", link: "/tutorials/raising" }
188190
],
189191
}
190192
],

docs/src/tutorials/raising.md

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Raising
2+
3+
## Raising GPU Kernels
4+
5+
<!-- TODO: write this section -->
6+
7+
## Raising Scalar Loops to Tensor IR
8+
9+
We will implement a simple N body simulation code in Reactant. Instead of using
10+
broadcasting or high-level abstractions, we will use loops and scalar operations
11+
to implement this.
12+
13+
```@example raising_stablehlo
14+
using Reactant, PrettyChairmarks
15+
16+
Reactant.allowscalar(true) # generally not recommended to turn on globally
17+
```
18+
19+
We will implement a naive function to compute the attractive force between each
20+
pair of particles in a system.
21+
22+
```@example raising_stablehlo
23+
function compute_attractive_force(
24+
positions::AbstractMatrix, masses::AbstractVector, G::Number
25+
)
26+
N = size(positions, 2)
27+
F = similar(positions, N, N)
28+
29+
@trace for i in 1:N
30+
@trace for j in 1:N
31+
dx = positions[1, i] - positions[1, j]
32+
dy = positions[2, i] - positions[2, j]
33+
dz = positions[3, i] - positions[3, j]
34+
35+
invr² = ifelse(i == j, dx, inv(dx^2 + dy^2 + dz^2))
36+
37+
Fx = G * masses[i] * masses[j] * invr² * dx
38+
Fy = G * masses[i] * masses[j] * invr² * dy
39+
Fz = G * masses[i] * masses[j] * invr² * dz
40+
F[i, j] = Fx + Fy + Fz
41+
end
42+
end
43+
44+
return F
45+
end
46+
```
47+
48+
```@example raising_stablehlo
49+
positions = randn(Float32, 3, 1024)
50+
masses = rand(Float32, 1024) .* 10
51+
52+
positions_ra = Reactant.to_rarray(positions)
53+
masses_ra = Reactant.to_rarray(masses)
54+
nothing # hide
55+
```
56+
57+
Let's see what the HLO IR looks like for this function (without enabling the loop
58+
raising).
59+
60+
```@example raising_stablehlo
61+
@code_hlo compile_options = CompileOptions(;
62+
disable_auto_batching_passes=true
63+
) compute_attractive_force(positions_ra, masses_ra, 2.0f0)
64+
```
65+
66+
This IR has a nested loop, but that won't work nicely for GPUs/TPUs. Even for CPUs, XLA
67+
often doens't do a great job with loops. By default, we will attempt to raise loops to a
68+
tensor IR.
69+
70+
```@example raising_stablehlo
71+
hlo = @code_hlo compute_attractive_force(positions_ra, masses_ra, 2.0f0)
72+
@assert !contains(repr(hlo), "stablehlo.while") #hide
73+
hlo
74+
```
75+
76+
This IR won't have any loops, instead it will be written in a tensor IR! Let ensure that
77+
the values are identical.
78+
79+
```@example raising_stablehlo
80+
y_jl = compute_attractive_force(positions, masses, 2.0f0)
81+
y_ra = @jit compute_attractive_force(positions_ra, masses_ra, 2.0f0)
82+
maximum(abs, Array(y_ra) .- y_jl)
83+
```
84+
85+
Let's time the execution of the two versions.
86+
87+
```@example raising_stablehlo
88+
fn1 = @compile sync=true compile_options=CompileOptions(;
89+
disable_auto_batching_passes=true
90+
) compute_attractive_force(positions_ra, masses_ra, 2.0f0)
91+
fn2 = @compile sync=true compute_attractive_force(positions_ra, masses_ra, 2.0f0)
92+
```
93+
94+
Runtime for non-raised function:
95+
96+
```@example raising_stablehlo
97+
@bs fn1(positions_ra, masses_ra, 2.0f0)
98+
```
99+
100+
Runtime for raised function:
101+
102+
```@example raising_stablehlo
103+
@bs fn2(positions_ra, masses_ra, 2.0f0)
104+
```

0 commit comments

Comments
 (0)