|
1 | 1 | import re |
| 2 | +import subprocess |
2 | 3 | from pathlib import Path |
3 | 4 | from textwrap import dedent |
4 | 5 |
|
|
8 | 9 | from mlir.dialects.memref import cast |
9 | 10 | from mlir.dialects.nvgpu import ( |
10 | 11 | TensorMapDescriptorType, |
11 | | - TensorMapSwizzleKind, |
| 12 | + TensorMapInterleaveKind, |
12 | 13 | TensorMapL2PromoKind, |
13 | 14 | TensorMapOOBKind, |
14 | | - TensorMapInterleaveKind, |
| 15 | + TensorMapSwizzleKind, |
| 16 | + tma_create_descriptor, |
15 | 17 | ) |
16 | | -from mlir.dialects.nvgpu import tma_create_descriptor |
17 | 18 | from mlir.dialects.transform import any_op_t |
18 | 19 | from mlir.dialects.transform.extras import named_sequence |
19 | 20 | from mlir.dialects.transform.structured import MatchInterfaceEnum |
20 | 21 | from mlir.ir import StringAttr, UnitAttr |
21 | 22 |
|
22 | 23 | from mlir import _mlir_libs |
23 | 24 | from mlir.extras.ast.canonicalize import canonicalize |
24 | | -from mlir.extras.dialects.ext import arith, memref, scf, gpu, linalg, transform, nvgpu |
| 25 | +from mlir.extras.dialects.ext import arith, gpu, linalg, memref, nvgpu, scf, transform |
25 | 26 | from mlir.extras.dialects.ext.func import func |
26 | 27 | from mlir.extras.dialects.ext.gpu import smem_space |
27 | 28 | from mlir.extras.dialects.ext.llvm import llvm_ptr_t |
28 | | -from mlir.extras.runtime.passes import run_pipeline, Pipeline |
| 29 | +from mlir.extras.runtime.passes import Pipeline, run_pipeline |
29 | 30 | from mlir.extras.runtime.refbackend import LLVMJITBackend |
30 | 31 |
|
31 | 32 | # noinspection PyUnresolvedReferences |
32 | | -from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext |
| 33 | +from mlir.extras.testing import MLIRContext, filecheck, mlir_ctx as ctx |
33 | 34 | from mlir.extras.util import find_ops |
34 | 35 |
|
35 | 36 | # needed since the fix isn't defined here nor conftest.py |
@@ -200,7 +201,8 @@ def payload(): |
200 | 201 | compute_linspace_val.emit() |
201 | 202 |
|
202 | 203 | @func |
203 | | - def printMemrefF32(x: T.memref(T.f32())): ... |
| 204 | + def printMemrefF32(x: T.memref(T.f32())): |
| 205 | + ... |
204 | 206 |
|
205 | 207 | printMemrefF32_.append(printMemrefF32) |
206 | 208 |
|
@@ -421,8 +423,15 @@ def main(module: any_op_t()): |
421 | 423 | CUDA_RUNTIME_LIB_PATH = Path(_mlir_libs.__file__).parent / f"libmlir_cuda_runtime.so" |
422 | 424 |
|
423 | 425 |
|
| 426 | +NVIDIA_GPU = False |
| 427 | +try: |
| 428 | + subprocess.check_output("nvidia-smi") |
| 429 | + NVIDIA_GPU = True |
| 430 | +except Exception: |
| 431 | + print("No Nvidia GPU in system!") |
| 432 | + |
424 | 433 | # based on https://github.com/llvm/llvm-project/blob/9cc2122bf5a81f7063c2a32b2cb78c8d615578a1/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir#L6 |
425 | | -@pytest.mark.skipif(not CUDA_RUNTIME_LIB_PATH.exists(), reason="no cuda library") |
| 434 | +@pytest.mark.skipif(not NVIDIA_GPU, reason="no NVIDIA GPU in system") |
426 | 435 | def test_transform_mma_sync_matmul_f16_f16_accum_run(ctx: MLIRContext, capfd): |
427 | 436 | range_ = scf.range_ |
428 | 437 |
|
@@ -549,7 +558,8 @@ def payload(): |
549 | 558 | compute_linspace_val.emit() |
550 | 559 |
|
551 | 560 | @func |
552 | | - def printMemrefF32(x: T.memref(T.f32())): ... |
| 561 | + def printMemrefF32(x: T.memref(T.f32())): |
| 562 | + ... |
553 | 563 |
|
554 | 564 | printMemrefF32_.append(printMemrefF32) |
555 | 565 |
|
|
0 commit comments