127 changes: 121 additions & 6 deletions tests/test_gpu.py
@@ -1,7 +1,8 @@
import ctypes
import platform
import random
import sys
import tempfile
import time
from textwrap import dedent

import mlir.extras.types as T
@@ -40,7 +41,7 @@

# noinspection PyUnresolvedReferences
from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext
-from util import hip_bindings_not_installed, hip_check, launch_kernel
+from util import hip_bindings_not_installed, hip_check, launch_kernel, hip_synchronize

# needed since the fixture isn't defined here nor in conftest.py
pytestmark = pytest.mark.usefixtures("ctx")
@@ -962,6 +963,7 @@ def test_amdgpu_vector(ctx: MLIRContext):

scale = 2
M, K, N = 2 * scale, 4 * scale, 6 * scale
    tz_a, tz_b, tz_c = 2, 2, 2
v2f32 = T.vector(2, T.f32())

@gpu_func
@@ -972,11 +974,11 @@ def smol_matmul(
):
cst = arith.constant(np.full([4], 0.0, np.float32), T.vector(4, T.f32()))
cst_0 = arith.constant(
-            np.full([2, 2], 0.0, np.float32), T.vector(2, 2, T.f32())
+            np.full([tz_a, tz_b], 0.0, np.float32), T.vector(tz_a, tz_b, T.f32())
)
-        for i, C, v0 in scf.range_(0, M, 2, iter_args=[C]):
-            for j, C, v1 in scf.range_(0, N, 2, iter_args=[C]):
-                for k, C, v2 in scf.range_(0, K, 2, iter_args=[C]):
+        for i, C, v0 in scf.range_(0, M, tz_a, iter_args=[C]):
+            for j, C, v1 in scf.range_(0, N, tz_b, iter_args=[C]):
+                for k, C, v2 in scf.range_(0, K, tz_c, iter_args=[C]):
cst[0::1] = A @ load(v2f32) @ [i, k]
cst[2::1] = A @ load(v2f32) @ [i + 1, k]
cst_0[0] = C @ load(v2f32) @ [i, j]
@@ -1078,3 +1080,116 @@ def gpu_module():
hip_check(hip.hipFree(c_d))

hip_check(hip.hipModuleUnload(hip_module))


@pytest.mark.skipif(hip_bindings_not_installed(), reason="hip not installed")
def test_amdgpu_bank_conflicts(ctx: MLIRContext):
from hip import hip

set_container_module(ctx.module)

M = 1024

@gpu_func
def no_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
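        # consecutive threads read and write consecutive addresses within a row,
        # so each warp's accesses coalesce into a few wide memory transactions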
for i in range(M):
a = A[i, thread_idx.x]
B[i, thread_idx.x] = a * a

@gpu_func
def all_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
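        # the store indices are swapped: consecutive threads write addresses
        # M floats apart, so the strided writes cannot coalesce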
for i in range(M):
a = A[i, thread_idx.x]
B[thread_idx.x, i] = a * a

props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
arch = props.gcnArchName.decode()

@module("naive", [f'#rocdl.target<chip = "{arch}">'])
def gpu_module():
no_bank_conflicts.emit()
all_bank_conflicts.emit()

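    # lower the GPU module through ROCDL and serialize it to a code object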
lowered_module = run_pipeline(
gpu_module,
Pipeline()
.Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
.rocdl_attach_target(chip=arch)
.gpu_to_llvm()
.lower_to_llvm()
.gpu_module_to_binary(),
)

hsaco = get_compile_object_bytes(lowered_module)
hip_module = hip_check(hip.hipModuleLoadData(hsaco))

a_h = np.arange(M).astype(dtype=np.float32)
a_h = np.tile(a_h, (M, 1))
b_h = np.zeros((M, M), dtype=np.float32)

a_num_bytes = a_h.size * a_h.itemsize
b_num_bytes = b_h.size * b_h.itemsize

a_d = hip_check(hip.hipMalloc(a_num_bytes))
b_d = hip_check(hip.hipMalloc(b_num_bytes))

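    # grid and block sizes are chosen so gridX * 32 == gridY * 8 == M,
    # i.e. one thread per element of the M x M array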
gridX = max(M // 32, 1)
gridY = max(M // 8, 1)
gridZ = 1
warp_size = 32
num_warps = 8
stream = 0
shared_memory = 0

times = {
no_bank_conflicts.__name__: 0,
all_bank_conflicts.__name__: 0,
}
runs = 10
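    # run both kernels in random order each iteration; iteration 0 serves as a
    # warmup and is excluded from the accumulated times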
for i in range(runs):
kernels = [no_bank_conflicts, all_bank_conflicts]
random.shuffle(kernels)
for kernel in kernels:
hip_check(
hip.hipMemcpy(
a_d, a_h, a_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice
)
)
hip_check(
hip.hipMemcpy(
b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice
)
)
function = hip_check(
hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
)

start = time.monotonic()
launch_kernel(
function.as_c_void_p(),
gridX,
gridY,
gridZ,
warp_size,
num_warps,
stream,
shared_memory,
a_d,
b_d,
)
hip_synchronize()
if i > 0:
times[kernel.__name__] += time.monotonic() - start

hip_check(
hip.hipMemcpy(
b_h, b_d, b_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
)
)

    # average over the runs - 1 timed iterations (iteration 0 was warmup)
    times[no_bank_conflicts.__name__] /= runs - 1
    times[all_bank_conflicts.__name__] /= runs - 1
for k, v in times.items():
print(f"{k}: {v:.3e}ms")
75 changes: 51 additions & 24 deletions tests/test_vector.py
@@ -20,6 +20,8 @@
ShapedType,
AffineMap,
AffineConstantExpr,
Attribute,
ArrayAttr,
)

from mlir.extras import types as T
@@ -28,6 +30,7 @@
# you need this to register the memref value caster
# noinspection PyUnresolvedReferences
from mlir.extras.dialects.ext import arith, linalg, memref, transform, vector, scf, func
from mlir.dialects import affine
from mlir.extras.dialects.ext.vector import outer, shuffle, load
from mlir.extras.dialects.ext.transform import (
get_parent_op,
@@ -71,7 +74,7 @@ def mod_transform():
@named_sequence("main", [any_op_t()], [])
def main(module_op: any_op_t()):
matmul = match(module_op, ops=["linalg.matmul"])
-        tiled_matmul, (_, _, inner_loop) = tile_to_scf_for(matmul, sizes=[2, 2, 2])
+        tiled_matmul, (_, _, inner_loop) = tile_to_scf_for(matmul, sizes=[4, 3, 8])
transform.structured.vectorize_children_and_apply_patterns(
get_parent_op(
transform_any_op_t(), tiled_matmul, isolated_from_above=True
@@ -136,40 +139,64 @@ def pats():
assert np.allclose(A @ B, C)


-def test_e2e_sugar(ctx: MLIRContext):
+testdata = (
+    (2, 2, 2),
+    (2, 3, 2),
+    (2, 3, 4),
+    (2, 4, 8),
+    (2, 4, 16),
+    (4, 4, 16),
+    (4, 8, 16),
+)
+
+
+@pytest.mark.parametrize("tz_a, tz_b, tz_c", testdata)
+def test_e2e_sugar(ctx: MLIRContext, tz_a, tz_b, tz_c):
backend = LLVMJITBackend()

scale = 16
M, K, N = 2 * scale, 4 * scale, 6 * scale
-    v2f32 = T.vector(2, T.f32())

+    vaf32 = T.vector(tz_a, T.f32())
+    vbf32 = T.vector(tz_b, T.f32())
+    vcf32 = T.vector(tz_c, T.f32())
+    vacrossbf32 = T.vector(tz_a, tz_b, T.f32())
+    vatimescf32 = T.vector(tz_a * tz_c, T.f32())

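+    # flattening a (tz_a, tz_c) index grid in column-major order gives the mask
+    # that transposes the tz_a concatenated row-slices of A loaded below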
+    shuffle_mask = np.arange(tz_a * tz_c).reshape(tz_a, tz_c).reshape((-1,), order="F")

@func.func(emit=True)
def smol_matmul(
A: T.memref(M, K, T.f32()),
B: T.memref(K, N, T.f32()),
C: T.memref(M, N, T.f32()),
):
-        cst = arith.constant(np.full([4], 0.0, np.float32), T.vector(4, T.f32()))
-        cst_0 = arith.constant(
-            np.full([2, 2], 0.0, np.float32), T.vector(2, 2, T.f32())
-        )
-        for i, C, v0 in scf.range_(0, M, 2, iter_args=[C]):
-            for j, C, v1 in scf.range_(0, N, 2, iter_args=[C]):
-                for k, C, v2 in scf.range_(0, K, 2, iter_args=[C]):
-                    cst[0::1] = A @ load(v2f32) @ [i, k]
-                    cst[2::1] = A @ load(v2f32) @ [i + 1, k]
-                    cst_0[0] = C @ load(v2f32) @ [i, j]
-                    cst_0[1] = C @ load(v2f32) @ [i + 1, j]
-                    cst = cst @ shuffle(mask=[0, 2, 1, 3]) @ cst
-
-                    v19 = cst[0:2:1] @ outer(cst_0) @ (B @ load(v2f32) @ [k, j])
-                    v20 = cst[2:4:1] @ outer(v19) @ (B @ load(v2f32) @ [k + 1, j])
-                    C[i, j] = v20[0]
-                    C[i + 1, j] = v20[1]
-
-                    scf.yield_(C)
-            scf.yield_(v2)
-        scf.yield_(v1)
+        cst = arith.constant(np.full([tz_a * tz_c], 0.0, np.float32), vatimescf32)
+        acc = arith.constant(np.full([tz_a, tz_b], 0.0, np.float32), vacrossbf32)

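+        # register-tiled matmul: each (m, n) step keeps a tz_a x tz_b accumulator
+        # tile in acc and folds in tz_c rank-1 (outer-product) updates per k step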
+        for m, C, v0 in scf.range_(0, M, tz_a, iter_args=[C]):
+            for n, C, v1 in scf.range_(0, N, tz_b, iter_args=[C]):
+                for k, C, v2 in scf.range_(0, K, tz_c, iter_args=[C]):
+                    for i in range(tz_a):
+                        cst[tz_c * i :: 1] = A @ load(vcf32) @ [m + i, k]
+                    cst = cst @ shuffle(mask=shuffle_mask) @ cst
+
+                    for i in range(tz_a):
+                        acc[i] = C @ load(vbf32) @ [m + i, n]
+
+                    for i in range(tz_c):
+                        acc = (
+                            (cst[i * tz_a : (i + 1) * tz_a : 1])
+                            @ outer(acc)
+                            @ (B @ load(vbf32) @ [k + i, n])
+                        )
+
+                    for i in range(tz_a):
+                        C[m + i, n] = acc[i]
+
+                    scf.yield_(results_=[C])
+            scf.yield_(results_=[v2])
+        scf.yield_(results_=[v1])

compiled_module = backend.compile(
ctx.module,
6 changes: 6 additions & 0 deletions tests/util.py
@@ -50,6 +50,12 @@ def hip_check(call_result):
return result


def hip_synchronize():
    from hip import hip

    # block until all queued work on the device has completed
    hip_check(hip.hipDeviceSynchronize())


def hip_bindings_not_installed():
try:
from hip import hip