Skip to content

Commit d6510db

Browse files
committed
format
1 parent d2c60c9 commit d6510db

16 files changed

+24
-50
lines changed

tests/test_arith.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
from textwrap import dedent
2-
31
import mlir.extras.types as T
42
import pytest
53

64
from mlir.extras.ast.canonicalize import canonicalize
75
from mlir.extras.dialects.ext import arith
8-
from mlir.extras.dialects.ext.arith import Scalar
96
from mlir.extras.dialects.ext.func import func
107

118
# noinspection PyUnresolvedReferences

tests/test_async.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import platform
2-
3-
import numpy as np
42
from textwrap import dedent
53

4+
import numpy as np
65
import pytest
76

87
from mlir.extras.runtime.passes import Pipeline

tests/test_func.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
import inspect
22
import sys
33
import threading
4-
from textwrap import dedent
54
from typing import TypeVar
65

7-
import pytest
8-
96
import mlir.extras.types as T
7+
import pytest
8+
from mlir.ir import FunctionType
109

1110
from mlir.extras.ast.canonicalize import canonicalize
1211
from mlir.extras.context import mlir_mod_ctx, RAIIMLIRContextModule
12+
from mlir.extras.dialects.ext import linalg, arith, scf
1313
from mlir.extras.dialects.ext.arith import constant
1414
from mlir.extras.dialects.ext.func import func
15-
from mlir.extras.dialects.ext import linalg, arith, scf, memref
16-
from mlir.ir import FunctionType
1715

1816
# noinspection PyUnresolvedReferences
1917
from mlir.extras.testing import (
@@ -171,7 +169,6 @@ def test_func_no_context_2(ctx: MLIRContext):
171169

172170

173171
def test_generics_just_args(ctx: MLIRContext):
174-
175172
@func(generics=generics)
176173
def mat_product_kernel(
177174
A: "T.memref(M, K, dtype)",
@@ -212,7 +209,6 @@ def mat_product_kernel(
212209

213210

214211
def test_generics_with_canonicalizations(ctx: MLIRContext):
215-
216212
generics = M, K, N, dtype = list(map(TypeVar, ["M", "K", "N", "dtype"]))
217213

218214
@func(generics=generics)

tests/test_gpu.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,6 @@ def _():
710710

711711

712712
def test_amdgpu(ctx: MLIRContext):
713-
714713
set_container_module(ctx.module)
715714

716715
M, K, N, dtype = 32, 32, 32, T.f32()

tests/test_linalg.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
from textwrap import dedent
2-
1+
import mlir.extras.types as T
32
import pytest
43

5-
import mlir.extras.types as T
64
from mlir.extras.dialects.ext import linalg, memref, tensor
75

86
# noinspection PyUnresolvedReferences

tests/test_llvm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from textwrap import dedent
22

3+
import mlir.extras.types as T
34
import pytest
45

5-
import mlir.extras.types as T
66
from mlir.extras.dialects.ext import llvm
77
from mlir.extras.dialects.ext.func import func
88

tests/test_nvgpu_nvvm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,6 @@ def main(module: any_op_t()):
418418

419419
CUDA_RUNTIME_LIB_PATH = Path(_mlir_libs.__file__).parent / f"libmlir_cuda_runtime.so"
420420

421-
422421
NVIDIA_GPU = False
423422
try:
424423
subprocess.check_output("nvidia-smi")
@@ -669,7 +668,6 @@ def main(module: any_op_t()):
669668

670669

671670
def test_tma(ctx: MLIRContext):
672-
673671
M = K = N = 64
674672

675673
@gpu.func

tests/test_other_hosts.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import pytest
22

3-
from util import jax_not_installed, mlir_bindings_not_installed
3+
from util import jax_not_installed
44

55

66
@pytest.mark.skipif(jax_not_installed(), reason="jax not installed")
77
def test_jax_trampolines_smoke():
8-
from mlir import ir
98
from jaxlib.mlir import ir
109

1110
# noinspection PyUnresolvedReferences

tests/test_regions.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
1-
from textwrap import dedent
2-
3-
import pytest
4-
51
import mlir.extras.types as T
2+
import pytest
3+
from mlir.dialects.builtin import module
64
from mlir.dialects.func import return_
75
from mlir.dialects.memref import alloca_scope, alloca_scope_return
86
from mlir.dialects.scf import yield_ as scf_yield
97
from mlir.dialects.tensor import rank, yield_ as tensor_yield
10-
from mlir.dialects.builtin import module
8+
from mlir.extras.types import tensor
9+
1110
from mlir.extras.dialects.ext import linalg, memref
1211
from mlir.extras.dialects.ext.arith import constant
1312
from mlir.extras.dialects.ext.cf import br, cond_br
1413
from mlir.extras.dialects.ext.func import func
1514
from mlir.extras.dialects.ext.memref import alloca_scope
1615
from mlir.extras.dialects.ext.scf import execute_region
1716
from mlir.extras.dialects.ext.tensor import S, generate
18-
from mlir.extras.util import bb
1917

2018
# noinspection PyUnresolvedReferences
2119
from mlir.extras.testing import (
@@ -24,7 +22,7 @@
2422
filecheck_with_comments,
2523
MLIRContext,
2624
)
27-
from mlir.extras.types import tensor
25+
from mlir.extras.util import bb
2826

2927
# needed since the fix isn't defined here nor conftest.py
3028
pytest.mark.usefixtures("ctx")
@@ -409,7 +407,6 @@ def matmul_i16_i16(
409407

410408

411409
def test_defer_emit_1(ctx: MLIRContext):
412-
413410
matmul_i16_i16.emit(decl=True)
414411

415412
@module
@@ -428,7 +425,6 @@ def mod():
428425

429426

430427
def test_defer_emit_2(ctx: MLIRContext):
431-
432428
matmul_i16_i16.emit(force=True)
433429

434430
@module
@@ -459,7 +455,6 @@ def matmul_i16_i16(
459455

460456

461457
def test_defer_emit_3(ctx: MLIRContext):
462-
463458
matmul_i16_i16.emit(force=True)
464459

465460
@module

tests/test_runtime.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,24 @@
33
import re
44
from textwrap import dedent
55

6+
import mlir.extras.types as T
67
import numpy as np
78
import pytest
9+
from mlir.dialects.arith import sitofp, index_cast
10+
from mlir.dialects.memref import cast
811
from mlir.ir import UnitAttr, Module, StridedLayoutAttr, InsertionPoint, Context
912
from mlir.runtime import get_unranked_memref_descriptor, get_ranked_memref_descriptor
1013

11-
import mlir.extras.types as T
1214
from mlir.extras.ast.canonicalize import canonicalize
1315
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
1416
from mlir.extras.dialects.ext import linalg
15-
from mlir.dialects.arith import sitofp, index_cast
1617
from mlir.extras.dialects.ext.arith import constant
1718
from mlir.extras.dialects.ext.func import func
1819
from mlir.extras.dialects.ext.memref import load, store, S
1920
from mlir.extras.dialects.ext.scf import (
2021
canonicalizer,
2122
range_,
2223
)
23-
from mlir.dialects.memref import cast
2424
from mlir.extras.runtime.passes import Pipeline, run_pipeline
2525
from mlir.extras.runtime.refbackend import (
2626
LLVMJITBackend,

0 commit comments

Comments
 (0)