-import mlir.extras.types as T
 import numpy as np

+import mlir.extras.types as T
+from mlir.dialects import builtin
+from mlir.dialects.transform import any_op_t
+from mlir.dialects.transform.extras import named_sequence, apply_patterns
+from mlir.extras.util import find_ops
+from mlir.ir import StringAttr, UnitAttr
+
 # you need this to register the memref value caster
 # noinspection PyUnresolvedReferences
 import mlir.extras.dialects.ext.memref
-from mlir.extras.ast.canonicalize import canonicalize
 from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
-from mlir.extras.dialects.ext.arith import constant
+from mlir.extras.dialects.ext.bufferization import LayoutMapOption
+from mlir.dialects.transform.vector import (
+    VectorContractLowering,
+    VectorMultiReductionLowering,
+    VectorTransferSplit,
+    VectorTransposeLowering,
+)
+from mlir.extras.dialects.ext import linalg
 from mlir.extras.dialects.ext.func import func
-from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_ as range
-from mlir.extras.runtime.passes import Pipeline
+from mlir.extras.dialects.ext.transform import (
+    match,
+    tile_to_scf_for,
+    get_parent_op,
+    transform_any_op_t,
+)
+from mlir.extras.dialects.ext import transform
+from mlir.extras.runtime.passes import Pipeline, run_pipeline
 from mlir.extras.runtime.refbackend import LLVMJITBackend

 ctx = RAIIMLIRContext()
 backend = LLVMJITBackend()
 module = ExplicitlyManagedModule()

-K = 10
-memref_i64 = T.memref(K, K, T.i64())
+M, K, N = 2, 4, 6
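+# Shapes: A is (M, K), B is (K, N), and C accumulates the (M, N) result.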


 @func
-@canonicalize(using=scf)
-def memfoo(A: memref_i64, B: memref_i64, C: memref_i64):
-    one = constant(1)
-    two = constant(2)
-    if one > two:
-        three = constant(3)
-    else:
-        for i in range(0, K):
-            for j in range(0, K):
-                C[i, j] = A[i, j] * B[i, j]
-
-
-memfoo.emit()
-module = module.finish()
+def matmul_tensors(
+    A: T.tensor(M, K, T.f32()),
+    B: T.tensor(K, N, T.f32()),
+    C: T.tensor(M, N, T.f32()),
+):
+    return linalg.matmul(A, B, C)
+
+
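+# The payload IR lives in a nested module tagged "payload"; both the transform
+# interpreter and the later find_ops() lookup locate it by this tag.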
+@builtin.module(attrs={"transform.target_tag": StringAttr.get("payload")})
+def payload():
+    matmul_tensors.emit(force=True)
+
+
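+# The transform script: with_named_sequence marks this module as holding named
+# transform sequences; "main" below tiles, vectorizes, and bufferizes the payload.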
+@builtin.module(attrs={"transform.with_named_sequence": UnitAttr.get()})
+def mod_transform():
+    @named_sequence("main", [any_op_t()], [])
+    def main(module_op: any_op_t()):
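+        # Match the linalg.matmul in the payload and tile all three (m, n, k)
+        # loops by 2, producing an scf.for loop nest.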
+        matmul = match(module_op, ops=["linalg.matmul"])
+        tiled_matmul, (_, _, inner_loop) = tile_to_scf_for(matmul, sizes=[2, 2, 2])
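+        # Vectorize the tiled matmul inside its enclosing isolated-from-above func.func.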
+        transform.structured.vectorize_children_and_apply_patterns(
+            get_parent_op(transform_any_op_t(), tiled_matmul, isolated_from_above=True)
+        )
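+        # Bufferize the whole module in one shot, rewriting tensors to memrefs;
+        # function boundaries get identity-layout memrefs.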
+        new_mod = transform.bufferization.one_shot_bufferize(
+            module_op,
+            function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap,
+            bufferize_function_boundaries=True,
+        )
+
+        func_op = match(new_mod, ops=["func.func"])

-print(module)
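+        # Progressively lower the resulting vector ops: contractions to outer
+        # products, then transfers, multi-reductions, shape_casts, and transposes.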
+        @apply_patterns(func_op)
+        def pats():
+            transform.apply_patterns.vector.lower_contraction(
+                lowering_strategy=VectorContractLowering.OuterProduct
+            )
+            transform.apply_patterns.vector.transfer_permutation_patterns()
+            transform.apply_patterns.vector.lower_multi_reduction(
+                lowering_strategy=VectorMultiReductionLowering.InnerParallel
+            )
+            transform.apply_patterns.vector.split_transfer_full_partial(
+                split_transfer_strategy=VectorTransferSplit.LinalgCopy
+            )
+            transform.apply_patterns.vector.transfer_to_scf(
+                max_transfer_rank=1, full_unroll=True
+            )
+            transform.apply_patterns.vector.lower_transfer(max_transfer_rank=1)
+            transform.apply_patterns.vector.lower_shape_cast()
+            transform.apply_patterns.vector.lower_transpose(
+                lowering_strategy=VectorTransposeLowering.Shuffle1D
+            )

-module = backend.compile(
+
+module = module.finish()
+# print(module)
+
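+# Run the "main" transform sequence over the module tagged "payload" via the
+# transform-interpreter pass.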
+vectorized_module = run_pipeline(
     module,
-    kernel_name=memfoo.__name__,
-    pipeline=Pipeline().bufferize().lower_to_llvm(),
+    pipeline=Pipeline().transform_interpreter(
+        entry_point="main", debug_payload_root_tag="payload"
+    ),
+)
+
+# print(vectorized_module)
+
+# https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44
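+# The pipeline below mirrors the reference lowering in the linked TestLowerToLLVM.cpp.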
+lower_to_llvm = (
+    Pipeline()
+    .Func(
+        Pipeline()
+        # Blanket-convert any remaining high-level vector ops to loops.
+        .convert_vector_to_scf()
+        # Blanket-convert any remaining linalg ops to loops.
+        .convert_linalg_to_loops()
+    )
+    # Blanket-convert any remaining affine ops.
+    .lower_affine()
+    # Convert SCF to CF (always needed).
+    .convert_scf_to_cf()
+    # Sprinkle some cleanups.
+    .canonicalize()
+    .cse()
+    # Convert vector to LLVM (always needed).
+    .convert_vector_to_llvm()
+    # Convert Math to LLVM (always needed).
+    .Func(Pipeline().convert_math_to_llvm())
+    # Expand complicated MemRef operations before lowering them.
+    .expand_strided_metadata()
+    # The expansion may create affine expressions. Get rid of them.
+    .lower_affine()
+    # Convert MemRef to LLVM (always needed).
+    .finalize_memref_to_llvm()
+    # Convert Func to LLVM (always needed).
+    .convert_func_to_llvm()
+    # Convert Index to LLVM (always needed).
+    .convert_index_to_llvm()
+    # Convert remaining unrealized_casts (always needed).
+    .reconcile_unrealized_casts()
 )

-# windows defaults to int32
-A = np.random.randint(0, 10, (K, K)).astype(np.int64)
-B = np.random.randint(0, 10, (K, K)).astype(np.int64)
-C = np.zeros((K, K), dtype=np.int64)

-backend.load(module).memfoo(A, B, C)
-assert np.array_equal(A * B, C)
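+# Fish the tagged payload module back out of the transformed IR and compile
+# only that module down to LLVM.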
+compiled_module = backend.compile(
+    find_ops(
+        vectorized_module.operation,
+        lambda x: "transform.target_tag" in x.attributes
+        and x.attributes["transform.target_tag"].value == "payload",
+        single=True,
+    ),
+    kernel_name=matmul_tensors.__name__,
+    pipeline=lower_to_llvm,
+)
+
+print(compiled_module)
+
+A = np.random.randint(0, 10, (M, K)).astype(np.float32)
+B = np.random.randint(0, 10, (K, N)).astype(np.float32)
+C = np.zeros((M, N), dtype=np.float32)
+
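+# With function boundaries bufferized, the kernel takes memrefs, so NumPy
+# arrays can be passed directly through the C-API wrapper. linalg.matmul
+# accumulates into C, and C starts at zero, so C == A @ B afterward.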
+backend.load(compiled_module).matmul_tensors_capi_wrapper(A, B, C)
+assert np.allclose(A @ B, C)