diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 2d3075d07d..a191a70c80 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "42330ec03d8ec626df49fe24e2af9e65751e50bd" +ENZYMEXLA_COMMIT = "a47b63c2bd7754cf33ab71c257c3749e63b4ff1b" ENZYMEXLA_SHA256 = "" diff --git a/docs/Project.toml b/docs/Project.toml index c0ccd2bbcb..550ce7eb64 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,6 +3,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +PrettyChairmarks = "aafa11c5-44f9-44a1-b829-427e6ce1ffc2" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 2853e7abd5..792461abb5 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -101,6 +101,7 @@ export default defineConfig({ text: "Persistent Compilation Cache", link: "/tutorials/persistent_compile_cache", }, + { text: "Raising", link: "/tutorials/raising" } ], }, { @@ -185,6 +186,7 @@ export default defineConfig({ text: "Persistent Compilation Cache", link: "/tutorials/persistent_compile_cache", }, + { text: "Raising", link: "/tutorials/raising" } ], } ], diff --git a/docs/src/tutorials/raising.md b/docs/src/tutorials/raising.md new file mode 100644 index 0000000000..a807259b17 --- /dev/null +++ b/docs/src/tutorials/raising.md @@ -0,0 +1,104 @@ +# Raising + +## Raising GPU Kernels + + + +## Raising Scalar Loops to Tensor IR + +We will implement a simple N body simulation code in Reactant. 
Instead of using +broadcasting or high-level abstractions, we will use loops and scalar operations +to implement this. + +```@example raising_stablehlo +using Reactant, PrettyChairmarks + +Reactant.allowscalar(true) # generally not recommended to turn on globally +``` + +We will implement a naive function to compute the attractive force between each +pair of particles in a system. + +```@example raising_stablehlo +function compute_attractive_force( + positions::AbstractMatrix, masses::AbstractVector, G::Number +) + N = size(positions, 2) + F = similar(positions, N, N) + + @trace for i in 1:N + @trace for j in 1:N + dx = positions[1, i] - positions[1, j] + dy = positions[2, i] - positions[2, j] + dz = positions[3, i] - positions[3, j] + + invr² = ifelse(i == j, dx, inv(dx^2 + dy^2 + dz^2)) + + Fx = G * masses[i] * masses[j] * invr² * dx + Fy = G * masses[i] * masses[j] * invr² * dy + Fz = G * masses[i] * masses[j] * invr² * dz + F[i, j] = Fx + Fy + Fz + end + end + + return F +end +``` + +```@example raising_stablehlo +positions = randn(Float32, 3, 1024) +masses = rand(Float32, 1024) .* 10 + +positions_ra = Reactant.to_rarray(positions) +masses_ra = Reactant.to_rarray(masses) +nothing # hide +``` + +Let's see what the HLO IR looks like for this function (without enabling the loop +raising). + +```@example raising_stablehlo +@code_hlo compile_options = CompileOptions(; + disable_auto_batching_passes=true +) compute_attractive_force(positions_ra, masses_ra, 2.0f0) +``` + +This IR has a nested loop, but that won't work nicely for GPUs/TPUs. Even for CPUs, XLA +often doesn't do a great job with loops. By default, we will attempt to raise loops to a +tensor IR. + +```@example raising_stablehlo +hlo = @code_hlo compute_attractive_force(positions_ra, masses_ra, 2.0f0) +@assert !contains(repr(hlo), "stablehlo.while") #hide +hlo +``` + +This IR won't have any loops; instead it will be written in a tensor IR! Let's ensure that +the values are identical. 
+ +```@example raising_stablehlo +y_jl = compute_attractive_force(positions, masses, 2.0f0) +y_ra = @jit compute_attractive_force(positions_ra, masses_ra, 2.0f0) +maximum(abs, Array(y_ra) .- y_jl) +``` + +Let's time the execution of the two versions. + +```@example raising_stablehlo +fn1 = @compile sync=true compile_options=CompileOptions(; + disable_auto_batching_passes=true +) compute_attractive_force(positions_ra, masses_ra, 2.0f0) +fn2 = @compile sync=true compute_attractive_force(positions_ra, masses_ra, 2.0f0) +``` + +Runtime for non-raised function: + +```@example raising_stablehlo +@bs fn1(positions_ra, masses_ra, 2.0f0) +``` + +Runtime for raised function: + +```@example raising_stablehlo +@bs fn2(positions_ra, masses_ra, 2.0f0) +```