Skip to content

Commit edd5940

Browse files
authored
Extend vector (#59)
1 parent 4453fea commit edd5940

File tree

7 files changed

+441
-229
lines changed

7 files changed

+441
-229
lines changed

examples/mwe.py

Lines changed: 137 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,159 @@
1-
import mlir.extras.types as T
21
import numpy as np
32

3+
import mlir.extras.types as T
4+
from mlir.dialects import builtin
5+
from mlir.dialects.transform import any_op_t
6+
from mlir.dialects.transform.extras import named_sequence, apply_patterns
7+
from mlir.extras.util import find_ops
8+
from mlir.ir import StringAttr, UnitAttr
9+
410
# Importing this module registers the memref value caster, so keep the import.
511
# noinspection PyUnresolvedReferences
612
import mlir.extras.dialects.ext.memref
7-
from mlir.extras.ast.canonicalize import canonicalize
813
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
9-
from mlir.extras.dialects.ext.arith import constant
14+
from mlir.extras.dialects.ext.bufferization import LayoutMapOption
15+
from mlir.dialects.transform.vector import (
16+
VectorContractLowering,
17+
VectorMultiReductionLowering,
18+
VectorTransferSplit,
19+
VectorTransposeLowering,
20+
)
21+
from mlir.extras.dialects.ext import linalg
1022
from mlir.extras.dialects.ext.func import func
11-
from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_ as range
12-
from mlir.extras.runtime.passes import Pipeline
23+
from mlir.extras.dialects.ext.transform import (
24+
match,
25+
tile_to_scf_for,
26+
get_parent_op,
27+
transform_any_op_t,
28+
)
29+
from mlir.extras.dialects.ext import transform
30+
from mlir.extras.runtime.passes import Pipeline, run_pipeline
1331
from mlir.extras.runtime.refbackend import LLVMJITBackend
1432

1533
ctx = RAIIMLIRContext()
1634
backend = LLVMJITBackend()
1735
module = ExplicitlyManagedModule()
1836

19-
K = 10
20-
memref_i64 = T.memref(K, K, T.i64())
37+
M, K, N = 2, 4, 6
2138

2239

2340
@func
24-
@canonicalize(using=scf)
25-
def memfoo(A: memref_i64, B: memref_i64, C: memref_i64):
26-
one = constant(1)
27-
two = constant(2)
28-
if one > two:
29-
three = constant(3)
30-
else:
31-
for i in range(0, K):
32-
for j in range(0, K):
33-
C[i, j] = A[i, j] * B[i, j]
34-
35-
36-
memfoo.emit()
37-
module = module.finish()
41+
def matmul_tensors(
42+
A: T.tensor(M, K, T.f32()),
43+
B: T.tensor(K, N, T.f32()),
44+
C: T.tensor(M, N, T.f32()),
45+
):
46+
return linalg.matmul(A, B, C)
47+
48+
49+
@builtin.module(attrs={"transform.target_tag": StringAttr.get("payload")})
50+
def payload():
51+
matmul_tensors.emit(force=True)
52+
53+
54+
@builtin.module(attrs={"transform.with_named_sequence": UnitAttr.get()})
55+
def mod_transform():
56+
@named_sequence("main", [any_op_t()], [])
57+
def main(module_op: any_op_t()):
58+
matmul = match(module_op, ops=["linalg.matmul"])
59+
tiled_matmul, (_, _, inner_loop) = tile_to_scf_for(matmul, sizes=[2, 2, 2])
60+
transform.structured.vectorize_children_and_apply_patterns(
61+
get_parent_op(transform_any_op_t(), tiled_matmul, isolated_from_above=True)
62+
)
63+
new_mod = transform.bufferization.one_shot_bufferize(
64+
module_op,
65+
function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap,
66+
bufferize_function_boundaries=True,
67+
)
68+
69+
func_op = match(new_mod, ops=["func.func"])
3870

39-
print(module)
71+
@apply_patterns(func_op)
72+
def pats():
73+
transform.apply_patterns.vector.lower_contraction(
74+
lowering_strategy=VectorContractLowering.OuterProduct
75+
)
76+
transform.apply_patterns.vector.transfer_permutation_patterns()
77+
transform.apply_patterns.vector.lower_multi_reduction(
78+
lowering_strategy=VectorMultiReductionLowering.InnerParallel
79+
)
80+
transform.apply_patterns.vector.split_transfer_full_partial(
81+
split_transfer_strategy=VectorTransferSplit.LinalgCopy
82+
)
83+
transform.apply_patterns.vector.transfer_to_scf(
84+
max_transfer_rank=1, full_unroll=True
85+
)
86+
transform.apply_patterns.vector.lower_transfer(max_transfer_rank=1)
87+
transform.apply_patterns.vector.lower_shape_cast()
88+
transform.apply_patterns.vector.lower_transpose(
89+
lowering_strategy=VectorTransposeLowering.Shuffle1D
90+
)
4091

41-
module = backend.compile(
92+
93+
module = module.finish()
94+
# print(module)
95+
96+
vectorized_module = run_pipeline(
4297
module,
43-
kernel_name=memfoo.__name__,
44-
pipeline=Pipeline().bufferize().lower_to_llvm(),
98+
pipeline=Pipeline().transform_interpreter(
99+
entry_point="main", debug_payload_root_tag="payload"
100+
),
101+
)
102+
103+
# print(vectorized_module)
104+
105+
# https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44
106+
lower_to_llvm = (
107+
Pipeline()
108+
.Func(
109+
Pipeline()
110+
# Blanket-convert any remaining high-level vector ops to loops.
111+
.convert_vector_to_scf()
112+
# Blanket-convert any remaining linalg ops to loops.
113+
.convert_linalg_to_loops()
114+
)
115+
# Lower any remaining affine ops.
116+
.lower_affine()
117+
# Convert SCF to CF (always needed).
118+
.convert_scf_to_cf()
119+
# Sprinkle some cleanups.
120+
.canonicalize()
121+
.cse()
122+
# Convert vector to LLVM (always needed).
123+
.convert_vector_to_llvm()
124+
# Convert Math to LLVM (always needed).
125+
.Func(Pipeline().convert_math_to_llvm())
126+
# Expand complicated MemRef operations before lowering them.
127+
.expand_strided_metadata()
128+
# The expansion may create affine expressions. Get rid of them.
129+
.lower_affine()
130+
# Convert MemRef to LLVM (always needed).
131+
.finalize_memref_to_llvm()
132+
# Convert Func to LLVM (always needed).
133+
.convert_func_to_llvm()
134+
# Convert Index to LLVM (always needed).
135+
.convert_index_to_llvm()
136+
# Convert remaining unrealized_casts (always needed).
137+
.reconcile_unrealized_casts()
45138
)
46139

47-
# windows defaults to int32
48-
A = np.random.randint(0, 10, (K, K)).astype(np.int64)
49-
B = np.random.randint(0, 10, (K, K)).astype(np.int64)
50-
C = np.zeros((K, K), dtype=np.int64)
51140

52-
backend.load(module).memfoo(A, B, C)
53-
assert np.array_equal(A * B, C)
141+
compiled_module = backend.compile(
142+
find_ops(
143+
vectorized_module.operation,
144+
lambda x: "transform.target_tag" in x.attributes
145+
and x.attributes["transform.target_tag"].value == "payload",
146+
single=True,
147+
),
148+
kernel_name=matmul_tensors.__name__,
149+
pipeline=lower_to_llvm,
150+
)
151+
152+
print(compiled_module)
153+
154+
A = np.random.randint(0, 10, (M, K)).astype(np.float32)
155+
B = np.random.randint(0, 10, (K, N)).astype(np.float32)
156+
C = np.zeros((M, N), dtype=np.float32)
157+
158+
backend.load(compiled_module).matmul_tensors_capi_wrapper(A, B, C)
159+
assert np.allclose(A @ B, C)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# noinspection PyUnresolvedReferences
2+
from ....dialects.bufferization import *

0 commit comments

Comments
 (0)