Skip to content

Commit 542a284

Browse files
committed
format
1 parent d2c60c9 commit 542a284

19 files changed

+192
-236
lines changed

tests/test_arith.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
from textwrap import dedent
2-
31
import mlir.extras.types as T
42
import pytest
53

64
from mlir.extras.ast.canonicalize import canonicalize
75
from mlir.extras.dialects.ext import arith
8-
from mlir.extras.dialects.ext.arith import Scalar
96
from mlir.extras.dialects.ext.func import func
10-
117
# noinspection PyUnresolvedReferences
128
from mlir.extras.testing import (
139
mlir_ctx as ctx,
@@ -50,8 +46,8 @@ def test_arithmetic(ctx: MLIRContext):
5046
one // two
5147
except ValueError as e:
5248
assert (
53-
str(e)
54-
== "floordiv not supported for lhs=Scalar(%cst = arith.constant 1.000000e+00 : f32)"
49+
str(e)
50+
== "floordiv not supported for lhs=Scalar(%cst = arith.constant 1.000000e+00 : f32)"
5551
)
5652
one % two
5753

tests/test_async.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import platform
2-
3-
import numpy as np
42
from textwrap import dedent
53

4+
import numpy as np
65
import pytest
76

87
from mlir.extras.runtime.passes import Pipeline
98
from mlir.extras.runtime.refbackend import LLVMJITBackend
10-
119
# noinspection PyUnresolvedReferences
1210
from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext, backend
1311

tests/test_func.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
import inspect
22
import sys
33
import threading
4-
from textwrap import dedent
54
from typing import TypeVar
65

7-
import pytest
8-
96
import mlir.extras.types as T
7+
import pytest
8+
from mlir.ir import FunctionType
109

1110
from mlir.extras.ast.canonicalize import canonicalize
1211
from mlir.extras.context import mlir_mod_ctx, RAIIMLIRContextModule
12+
from mlir.extras.dialects.ext import linalg, arith, scf
1313
from mlir.extras.dialects.ext.arith import constant
1414
from mlir.extras.dialects.ext.func import func
15-
from mlir.extras.dialects.ext import linalg, arith, scf, memref
16-
from mlir.ir import FunctionType
17-
1815
# noinspection PyUnresolvedReferences
1916
from mlir.extras.testing import (
2017
mlir_ctx as ctx,
@@ -46,7 +43,8 @@ def demo_fun1():
4643

4744

4845
def test_declare_byte_rep(ctx: MLIRContext):
49-
def demo_fun1(): ...
46+
def demo_fun1():
47+
...
5048

5149
if sys.version_info.minor == 13:
5250
assert demo_fun1.__code__.co_code == b"\x95\x00g\x00"
@@ -152,9 +150,9 @@ def foo1():
152150

153151
@func(generics=list(map(TypeVar, ["M", "N"])))
154152
def matmul_i32_i32(
155-
A: "T.memref(M, N, T.i32())",
156-
B: "T.memref(M, N, T.i32())",
157-
C: "T.memref(M, N, T.i32())",
153+
A: "T.memref(M, N, T.i32())",
154+
B: "T.memref(M, N, T.i32())",
155+
C: "T.memref(M, N, T.i32())",
158156
):
159157
linalg.matmul(A, B, C)
160158

@@ -171,12 +169,11 @@ def test_func_no_context_2(ctx: MLIRContext):
171169

172170

173171
def test_generics_just_args(ctx: MLIRContext):
174-
175172
@func(generics=generics)
176173
def mat_product_kernel(
177-
A: "T.memref(M, K, dtype)",
178-
B: "T.memref(K, N, dtype)",
179-
C: "T.memref(M, N, dtype)",
174+
A: "T.memref(M, K, dtype)",
175+
B: "T.memref(K, N, dtype)",
176+
C: "T.memref(M, N, dtype)",
180177
):
181178
one = arith.constant(1.0, dtype)
182179

@@ -195,9 +192,9 @@ def test_generics_closure(ctx: MLIRContext):
195192

196193
@func(generics=generics)
197194
def mat_product_kernel(
198-
A: "T.memref(M, K, dtype)",
199-
B: "T.memref(K, N, dtype)",
200-
C: "T.memref(M, N, dtype)",
195+
A: "T.memref(M, K, dtype)",
196+
B: "T.memref(K, N, dtype)",
197+
C: "T.memref(M, N, dtype)",
201198
):
202199
one = arith.constant(1, dtype)
203200

@@ -212,15 +209,14 @@ def mat_product_kernel(
212209

213210

214211
def test_generics_with_canonicalizations(ctx: MLIRContext):
215-
216212
generics = M, K, N, dtype = list(map(TypeVar, ["M", "K", "N", "dtype"]))
217213

218214
@func(generics=generics)
219215
@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
220216
def mat_product_kernel(
221-
A: "T.memref(M, K, dtype)",
222-
B: "T.memref(K, N, dtype)",
223-
C: "T.memref(M, N, dtype)",
217+
A: "T.memref(M, K, dtype)",
218+
B: "T.memref(K, N, dtype)",
219+
C: "T.memref(M, N, dtype)",
224220
):
225221
x = arith.constant(1, index=True)
226222
y = arith.constant(1, index=True)

tests/test_gpu.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from mlir.extras.ast.canonicalize import canonicalize
1818
from mlir.extras.dialects.ext import arith, scf, memref, rocdl, gpu
1919
from mlir.extras.dialects.ext.func import func
20-
2120
# noinspection PyUnresolvedReferences
2221
from mlir.extras.dialects.ext.gpu import (
2322
all_reduce,
@@ -38,7 +37,6 @@
3837
from mlir.extras.dialects.ext.scf import forall, in_parallel_
3938
from mlir.extras.dialects.ext.vector import outer, load, shuffle, print_
4039
from mlir.extras.runtime.passes import run_pipeline, Pipeline
41-
4240
# noinspection PyUnresolvedReferences
4341
from mlir.extras.testing import (
4442
mlir_ctx as ctx,
@@ -78,10 +76,10 @@ def test_forall_insert_slice_no_region_with_for_with_gpu_mapping(ctx: MLIRContex
7876
alpha = arith.constant(1, T.f32())
7977

8078
for i, j in forall(
81-
[1, 1],
82-
[2, 2],
83-
[3, 3],
84-
device_mapping=[thread("x"), thread("y")],
79+
[1, 1],
80+
[2, 2],
81+
[3, 3],
82+
device_mapping=[thread("x"), thread("y")],
8583
):
8684
a = memref.load(x, (i, j))
8785
b = memref.load(y, (i, j))
@@ -119,9 +117,9 @@ class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]):
119117
@gpu_func(emit=True)
120118
@canonicalize(using=scf.canonicalizer)
121119
def mat_product_kernel(
122-
A: T.memref(M, N, T.f32()),
123-
B: T.memref(N, K, T.f32()),
124-
C: T.memref(M, K, T.f32()),
120+
A: T.memref(M, N, T.f32()),
121+
B: T.memref(N, K, T.f32()),
122+
C: T.memref(M, K, T.f32()),
125123
):
126124
x = block_idx.x
127125
y = block_idx.y
@@ -156,9 +154,9 @@ class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]):
156154
@gpu_func(emit=True, emit_grid=True)
157155
@canonicalize(using=scf.canonicalizer)
158156
def mat_product_kernel(
159-
A: T.memref(M, N, T.f32()),
160-
B: T.memref(N, K, T.f32()),
161-
C: T.memref(M, K, T.f32()),
157+
A: T.memref(M, N, T.f32()),
158+
B: T.memref(N, K, T.f32()),
159+
C: T.memref(M, K, T.f32()),
162160
):
163161
x = block_idx.x
164162
y = block_idx.y
@@ -214,9 +212,9 @@ class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]):
214212
@gpu_func(emit=True, emit_grid=True)
215213
@canonicalize(using=scf.canonicalizer)
216214
def mat_product_kernel(
217-
A: T.memref(M, N, T.f32()),
218-
B: T.memref(N, K, T.f32()),
219-
C: T.memref(M, K, T.f32()),
215+
A: T.memref(M, N, T.f32()),
216+
B: T.memref(N, K, T.f32()),
217+
C: T.memref(M, K, T.f32()),
220218
):
221219
x = block_idx.x
222220
y = block_idx.y
@@ -283,9 +281,9 @@ class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]):
283281
@gpu_func(emit=True, emit_grid=True)
284282
@canonicalize(using=scf.canonicalizer)
285283
def mat_product_kernel(
286-
A: T.memref(M, N, T.f32()),
287-
B: T.memref(N, K, T.f32()),
288-
C: T.memref(M, K, T.f32()),
284+
A: T.memref(M, N, T.f32()),
285+
B: T.memref(N, K, T.f32()),
286+
C: T.memref(M, K, T.f32()),
289287
):
290288
x = block_idx.x
291289
y = block_idx.y
@@ -349,7 +347,7 @@ def main():
349347
data = memref.alloc((2, 6), T.i32())
350348
sum = memref.alloc((2,), T.i32())
351349

352-
power_csts = [arith.constant(0)] + [arith.constant(2**i) for i in range(5)]
350+
power_csts = [arith.constant(0)] + [arith.constant(2 ** i) for i in range(5)]
353351
odd_csts = [
354352
arith.constant(3),
355353
arith.constant(6),
@@ -440,7 +438,7 @@ def main():
440438
data = memref.alloc((2, 6), T.i32())
441439
sum = memref.alloc((2,), T.i32())
442440

443-
power_csts = [arith.constant(0)] + [arith.constant(2**i) for i in range(5)]
441+
power_csts = [arith.constant(0)] + [arith.constant(2 ** i) for i in range(5)]
444442
odd_csts = [
445443
arith.constant(3),
446444
arith.constant(6),
@@ -710,14 +708,13 @@ def _():
710708

711709

712710
def test_amdgpu(ctx: MLIRContext):
713-
714711
set_container_module(ctx.module)
715712

716713
M, K, N, dtype = 32, 32, 32, T.f32()
717714

718715
@gpu_func
719716
def mat_product_kernel(
720-
A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)
717+
A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)
721718
):
722719
x = block_dim.x * block_idx.x + thread_idx.x
723720
y = block_dim.y * block_idx.y + thread_idx.y
@@ -829,7 +826,7 @@ def test_amdgpu_square(ctx: MLIRContext):
829826

830827
@gpu_func
831828
def mat_product_kernel(
832-
A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)
829+
A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)
833830
):
834831
x = block_dim.x * block_idx.x + thread_idx.x
835832
y = block_dim.y * block_idx.y + thread_idx.y
@@ -943,9 +940,9 @@ def test_amdgpu_vector(ctx: MLIRContext):
943940

944941
@gpu_func
945942
def smol_matmul(
946-
A: T.memref(M, K, T.f32()),
947-
B: T.memref(K, N, T.f32()),
948-
C: T.memref(M, N, T.f32()),
943+
A: T.memref(M, K, T.f32()),
944+
B: T.memref(K, N, T.f32()),
945+
C: T.memref(M, N, T.f32()),
949946
):
950947
cst = arith.constant(np.full([4], 0.0, np.float32), T.vector(4, T.f32()))
951948
cst_0 = arith.constant(
@@ -1186,9 +1183,9 @@ def test_amdgpu_vector_wmma(ctx: MLIRContext):
11861183
@gpu_func
11871184
@canonicalize(using=scf.canonicalizer)
11881185
def smol_matmul(
1189-
a: T.memref(M, K, T.f16()),
1190-
b: T.memref(K, N, T.f16()),
1191-
c: T.memref(M, N, T.f16()),
1186+
a: T.memref(M, K, T.f16()),
1187+
b: T.memref(K, N, T.f16()),
1188+
c: T.memref(M, N, T.f16()),
11921189
):
11931190
lIdx = thread_idx.x
11941191
# a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b

tests/test_linalg.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
from textwrap import dedent
2-
1+
import mlir.extras.types as T
32
import pytest
43

5-
import mlir.extras.types as T
64
from mlir.extras.dialects.ext import linalg, memref, tensor
7-
85
# noinspection PyUnresolvedReferences
96
from mlir.extras.testing import (
107
MLIRContext,

tests/test_llvm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from textwrap import dedent
22

3+
import mlir.extras.types as T
34
import pytest
45

5-
import mlir.extras.types as T
66
from mlir.extras.dialects.ext import llvm
77
from mlir.extras.dialects.ext.func import func
8-
98
# noinspection PyUnresolvedReferences
109
from mlir.extras.testing import MLIRContext, filecheck, mlir_ctx as ctx
1110
from util import llvm_bindings_not_installed

tests/test_memref.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
yield_,
2525
canonicalizer,
2626
)
27-
2827
# noinspection PyUnresolvedReferences
2928
from mlir.extras.testing import (
3029
mlir_ctx as ctx,
@@ -187,8 +186,8 @@ def test_ellipsis_and_full_slice_plus_coordinate_1(ctx: MLIRContext):
187186
w = mem[1, :, :, :, :]
188187
except IndexError as e:
189188
assert (
190-
str(e)
191-
== "Too many indices for shaped type with rank: 5 non-None/Ellipsis indices for dim 4."
189+
str(e)
190+
== "Too many indices for shaped type with rank: 5 non-None/Ellipsis indices for dim 4."
192191
)
193192

194193

@@ -205,7 +204,7 @@ def test_ellipsis_and_full_slice_plus_coordinate_2(ctx: MLIRContext):
205204

206205
golden_w_1_strides = (np.array(golden_w_1.strides) // dtype_size_in_bytes).tolist()
207206
golden_w_1_rank_reduce_strides = (
208-
np.array(golden_w_1_rank_reduce.strides) // dtype_size_in_bytes
207+
np.array(golden_w_1_rank_reduce.strides) // dtype_size_in_bytes
209208
).tolist()
210209
golden_w_2_strides = (np.array(golden_w_2.strides) // dtype_size_in_bytes).tolist()
211210
golden_w_3_strides = (np.array(golden_w_3.strides) // dtype_size_in_bytes).tolist()
@@ -214,7 +213,7 @@ def test_ellipsis_and_full_slice_plus_coordinate_2(ctx: MLIRContext):
214213

215214
golden_w_1_offset = get_np_view_offset(golden_w_1) // dtype_size_in_bytes
216215
golden_w_1_rank_reduce_offset = (
217-
get_np_view_offset(golden_w_1_rank_reduce) // dtype_size_in_bytes
216+
get_np_view_offset(golden_w_1_rank_reduce) // dtype_size_in_bytes
218217
)
219218
golden_w_2_offset = get_np_view_offset(golden_w_2) // dtype_size_in_bytes
220219
golden_w_3_offset = get_np_view_offset(golden_w_3) // dtype_size_in_bytes
@@ -276,10 +275,10 @@ def test_ellipsis_and_full_slice_plus_coordinate_3(ctx: MLIRContext):
276275
golden_w_8_strides = (np.array(golden_w_8.strides) // dtype_size_in_bytes).tolist()
277276
golden_w_9_strides = (np.array(golden_w_9.strides) // dtype_size_in_bytes).tolist()
278277
golden_w_10_strides = (
279-
np.array(golden_w_10.strides) // dtype_size_in_bytes
278+
np.array(golden_w_10.strides) // dtype_size_in_bytes
280279
).tolist()
281280
golden_w_11_strides = (
282-
np.array(golden_w_11.strides) // dtype_size_in_bytes
281+
np.array(golden_w_11.strides) // dtype_size_in_bytes
283282
).tolist()
284283

285284
golden_w_1_offset = get_np_view_offset(golden_w_1) // dtype_size_in_bytes

0 commit comments

Comments
 (0)