|
17 | 17 | from mlir.extras.ast.canonicalize import canonicalize |
18 | 18 | from mlir.extras.dialects.ext import arith, scf, memref, rocdl, gpu |
19 | 19 | from mlir.extras.dialects.ext.func import func |
20 | | - |
21 | 20 | # noinspection PyUnresolvedReferences |
22 | 21 | from mlir.extras.dialects.ext.gpu import ( |
23 | 22 | all_reduce, |
|
38 | 37 | from mlir.extras.dialects.ext.scf import forall, in_parallel_ |
39 | 38 | from mlir.extras.dialects.ext.vector import outer, load, shuffle, print_ |
40 | 39 | from mlir.extras.runtime.passes import run_pipeline, Pipeline |
41 | | - |
42 | 40 | # noinspection PyUnresolvedReferences |
43 | 41 | from mlir.extras.testing import ( |
44 | 42 | mlir_ctx as ctx, |
@@ -78,10 +76,10 @@ def test_forall_insert_slice_no_region_with_for_with_gpu_mapping(ctx: MLIRContex |
78 | 76 | alpha = arith.constant(1, T.f32()) |
79 | 77 |
|
80 | 78 | for i, j in forall( |
81 | | - [1, 1], |
82 | | - [2, 2], |
83 | | - [3, 3], |
84 | | - device_mapping=[thread("x"), thread("y")], |
| 79 | + [1, 1], |
| 80 | + [2, 2], |
| 81 | + [3, 3], |
| 82 | + device_mapping=[thread("x"), thread("y")], |
85 | 83 | ): |
86 | 84 | a = memref.load(x, (i, j)) |
87 | 85 | b = memref.load(y, (i, j)) |
@@ -119,9 +117,9 @@ class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]): |
119 | 117 | @gpu_func(emit=True) |
120 | 118 | @canonicalize(using=scf.canonicalizer) |
121 | 119 | def mat_product_kernel( |
122 | | - A: T.memref(M, N, T.f32()), |
123 | | - B: T.memref(N, K, T.f32()), |
124 | | - C: T.memref(M, K, T.f32()), |
| 120 | + A: T.memref(M, N, T.f32()), |
| 121 | + B: T.memref(N, K, T.f32()), |
| 122 | + C: T.memref(M, K, T.f32()), |
125 | 123 | ): |
126 | 124 | x = block_idx.x |
127 | 125 | y = block_idx.y |
@@ -156,9 +154,9 @@ class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]): |
156 | 154 | @gpu_func(emit=True, emit_grid=True) |
157 | 155 | @canonicalize(using=scf.canonicalizer) |
158 | 156 | def mat_product_kernel( |
159 | | - A: T.memref(M, N, T.f32()), |
160 | | - B: T.memref(N, K, T.f32()), |
161 | | - C: T.memref(M, K, T.f32()), |
| 157 | + A: T.memref(M, N, T.f32()), |
| 158 | + B: T.memref(N, K, T.f32()), |
| 159 | + C: T.memref(M, K, T.f32()), |
162 | 160 | ): |
163 | 161 | x = block_idx.x |
164 | 162 | y = block_idx.y |
@@ -214,9 +212,9 @@ class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]): |
214 | 212 | @gpu_func(emit=True, emit_grid=True) |
215 | 213 | @canonicalize(using=scf.canonicalizer) |
216 | 214 | def mat_product_kernel( |
217 | | - A: T.memref(M, N, T.f32()), |
218 | | - B: T.memref(N, K, T.f32()), |
219 | | - C: T.memref(M, K, T.f32()), |
| 215 | + A: T.memref(M, N, T.f32()), |
| 216 | + B: T.memref(N, K, T.f32()), |
| 217 | + C: T.memref(M, K, T.f32()), |
220 | 218 | ): |
221 | 219 | x = block_idx.x |
222 | 220 | y = block_idx.y |
@@ -283,9 +281,9 @@ class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]): |
283 | 281 | @gpu_func(emit=True, emit_grid=True) |
284 | 282 | @canonicalize(using=scf.canonicalizer) |
285 | 283 | def mat_product_kernel( |
286 | | - A: T.memref(M, N, T.f32()), |
287 | | - B: T.memref(N, K, T.f32()), |
288 | | - C: T.memref(M, K, T.f32()), |
| 284 | + A: T.memref(M, N, T.f32()), |
| 285 | + B: T.memref(N, K, T.f32()), |
| 286 | + C: T.memref(M, K, T.f32()), |
289 | 287 | ): |
290 | 288 | x = block_idx.x |
291 | 289 | y = block_idx.y |
@@ -349,7 +347,7 @@ def main(): |
349 | 347 | data = memref.alloc((2, 6), T.i32()) |
350 | 348 | sum = memref.alloc((2,), T.i32()) |
351 | 349 |
|
352 | | - power_csts = [arith.constant(0)] + [arith.constant(2**i) for i in range(5)] |
| 350 | + power_csts = [arith.constant(0)] + [arith.constant(2 ** i) for i in range(5)] |
353 | 351 | odd_csts = [ |
354 | 352 | arith.constant(3), |
355 | 353 | arith.constant(6), |
@@ -440,7 +438,7 @@ def main(): |
440 | 438 | data = memref.alloc((2, 6), T.i32()) |
441 | 439 | sum = memref.alloc((2,), T.i32()) |
442 | 440 |
|
443 | | - power_csts = [arith.constant(0)] + [arith.constant(2**i) for i in range(5)] |
| 441 | + power_csts = [arith.constant(0)] + [arith.constant(2 ** i) for i in range(5)] |
444 | 442 | odd_csts = [ |
445 | 443 | arith.constant(3), |
446 | 444 | arith.constant(6), |
@@ -710,14 +708,13 @@ def _(): |
710 | 708 |
|
711 | 709 |
|
712 | 710 | def test_amdgpu(ctx: MLIRContext): |
713 | | - |
714 | 711 | set_container_module(ctx.module) |
715 | 712 |
|
716 | 713 | M, K, N, dtype = 32, 32, 32, T.f32() |
717 | 714 |
|
718 | 715 | @gpu_func |
719 | 716 | def mat_product_kernel( |
720 | | - A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype) |
| 717 | + A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype) |
721 | 718 | ): |
722 | 719 | x = block_dim.x * block_idx.x + thread_idx.x |
723 | 720 | y = block_dim.y * block_idx.y + thread_idx.y |
@@ -829,7 +826,7 @@ def test_amdgpu_square(ctx: MLIRContext): |
829 | 826 |
|
830 | 827 | @gpu_func |
831 | 828 | def mat_product_kernel( |
832 | | - A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype) |
| 829 | + A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype) |
833 | 830 | ): |
834 | 831 | x = block_dim.x * block_idx.x + thread_idx.x |
835 | 832 | y = block_dim.y * block_idx.y + thread_idx.y |
@@ -943,9 +940,9 @@ def test_amdgpu_vector(ctx: MLIRContext): |
943 | 940 |
|
944 | 941 | @gpu_func |
945 | 942 | def smol_matmul( |
946 | | - A: T.memref(M, K, T.f32()), |
947 | | - B: T.memref(K, N, T.f32()), |
948 | | - C: T.memref(M, N, T.f32()), |
| 943 | + A: T.memref(M, K, T.f32()), |
| 944 | + B: T.memref(K, N, T.f32()), |
| 945 | + C: T.memref(M, N, T.f32()), |
949 | 946 | ): |
950 | 947 | cst = arith.constant(np.full([4], 0.0, np.float32), T.vector(4, T.f32())) |
951 | 948 | cst_0 = arith.constant( |
@@ -1186,9 +1183,9 @@ def test_amdgpu_vector_wmma(ctx: MLIRContext): |
1186 | 1183 | @gpu_func |
1187 | 1184 | @canonicalize(using=scf.canonicalizer) |
1188 | 1185 | def smol_matmul( |
1189 | | - a: T.memref(M, K, T.f16()), |
1190 | | - b: T.memref(K, N, T.f16()), |
1191 | | - c: T.memref(M, N, T.f16()), |
| 1186 | + a: T.memref(M, K, T.f16()), |
| 1187 | + b: T.memref(K, N, T.f16()), |
| 1188 | + c: T.memref(M, N, T.f16()), |
1192 | 1189 | ): |
1193 | 1190 | lIdx = thread_idx.x |
1194 | 1191 | # a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b |
|
0 commit comments