From 27d48eb17cf7f5234a57338e4f1026dd763736dc Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Tue, 9 Jul 2024 16:04:58 +0800 Subject: [PATCH 1/7] add link to tpp --- CMakeLists.txt | 14 ++++++++ include/gc/Transforms/Passes.td | 17 ++++++++++ lib/gc/Transforms/CMakeLists.txt | 4 +++ lib/gc/Transforms/Pipeline.cpp | 32 +++++++++++++++++++ .../mlir/test/gc/Transforms/Pipeline/gpu.mlir | 12 +++++++ 5 files changed, 79 insertions(+) create mode 100644 test/mlir/test/gc/Transforms/Pipeline/gpu.mlir diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d82a7af8..b5f7be5f7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,6 +69,20 @@ include_directories( ${PROJECT_SOURCE_DIR}/include ) +if(TPP_DIR) + message(STATUS "Using TPP_DIR in: ${TPP_DIR}") + add_definitions("-DTPP_ENABLED") + include_directories(${TPP_DIR}/include) + include_directories(${TPP_DIR}/build/include) + link_directories(${TPP_DIR}/build) + link_directories(${TPP_DIR}/build/lib) + set(TPP_AVAILABLE_LIBS + TPPCheckDialect TPPCheckToLoops TPPGPU TPPIR TPPLinalgToFunc TPPLinalgToXSMM TPPPassBundles + TPPPerfDialect TPPPerfToFunc TPPPerfToLoop TPPPipeline TPPRunner TPPTestLib TPPTransforms + TPPTransformsUtils TPPXsmmDialect tpp_xsmm_runner_utils TPPXsmmToFunc xsmm + ) +endif() + # The paths are added in the subfolders using the gc_add_path() function. # These lists are also used by tests. set(GC_LIB_SOURCES CACHE INTERNAL "The graph_compiler library source paths") diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index aaea602b6..4663f59b8 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -46,4 +46,21 @@ def GCCPUPipeline: Pass<"gc-cpu-pipeline"> { "vector::VectorDialect"]; } +def GCGPUPipeline: Pass<"gc-gpu-pipeline"> { + let summary = "All-in-one pipeline for GC for GPU"; + let dependentDialects = ["onednn_graph::OneDNNGraphDialect", + "tensor::TensorDialect", + "memref::MemRefDialect", + "linalg::LinalgDialect", + "linalgx::LinalgxDialect", + "LLVM::LLVMDialect", + "scf::SCFDialect", + "bufferization::BufferizationDialect", + "omp::OpenMPDialect", + "gpu::GPUDialect", + "xegpu::XeGPUDialect", + "math::MathDialect", + "vector::VectorDialect"]; +} + #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 7be337566..610052ad9 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -25,4 +25,8 @@ add_mlir_library(GCPasses MLIROneDNNGraph ) +if(TPP_DIR) + target_link_libraries(GCPasses PRIVATE ${TPP_AVAILABLE_LIBS}) +endif() + set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCPasses) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 6e5151e9e..e4b455c61 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -13,17 +13,23 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" +#ifdef TPP_ENABLED +#include "TPP/Passes.h" +#endif + #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "gc/Dialect/Linalgx/LinalgxDialect.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" @@ -143,7 +149,20 @@ void populateCPUPipeline(mlir::PassManager &pm) { populateLLVMPasses(pm); } +void populateGPUPipeline(mlir::PassManager &pm) { + // middle-end, arith/math/vector dialects + populateVectorPasses(pm); + // back-end, arith/math/vector/memref dialects + populateBufferizationPasses(pm); +#ifdef TPP_ENABLED + // + pm.addNestedPass( + tpp::createLinalgToXeGPU(tpp::LinalgToXeGPUOptions{32, 1, 16})); +#endif +} + #define GEN_PASS_DEF_GCCPUPIPELINE +#define GEN_PASS_DEF_GCGPUPIPELINE #include "gc/Transforms/Passes.h.inc" namespace { @@ -162,5 +181,18 @@ class GCCPUPipeline : public impl::GCCPUPipelineBase { } }; +class GCGPUPipeline : public impl::GCGPUPipelineBase { +public: + friend struct PassHelper; + using impl::GCGPUPipelineBase::GCGPUPipelineBase; + void runOnOperation() final { + auto op = getOperation(); + PassManager pm{op->getContext()}; + populateGPUPipeline(pm); + if (failed(pm.run(op))) + signalPassFailure(); + } +}; + } // namespace } // namespace mlir::gc diff --git a/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir new file mode 100644 index 000000000..e2aa6dd46 --- /dev/null +++ b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir @@ -0,0 +1,12 @@ +// RUN: gc-opt --gc-gpu-pipeline %s | FileCheck %s + +func.func @mlp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x16xf32>) -> tensor<8x16xf32> + %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<8x16xf32>, tensor<16x16xf32>) + outs(%1 : tensor<8x16xf32>) -> tensor<8x16xf32> + %3 = tensor.empty() : tensor<8x16xf32> + %4 = linalg.add ins(%arg2, %2 : tensor<8x16xf32>, tensor<8x16xf32>) outs(%3 : tensor<8x16xf32>) -> tensor<8x16xf32> + return %4 : tensor<8x16xf32> +} From 9460cd66bcf7d9316a631b09bce98cd864731f06 Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Tue, 9 Jul 2024 16:05:33 +0800 Subject: [PATCH 2/7] update --- .../mlir/test/gc/Transforms/Pipeline/gpu.mlir | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir index e2aa6dd46..eeb2a81a8 100644 --- a/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir +++ b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir @@ -10,3 +10,23 @@ func.func @mlp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor< %4 = linalg.add ins(%arg2, %2 : tensor<8x16xf32>, tensor<8x16xf32>) outs(%3 : tensor<8x16xf32>) -> tensor<8x16xf32> return %4 : tensor<8x16xf32> } + +// func.func @mlp(%arg0: memref<8x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<8x16xf32>, %arg3: memref<8x16xf32>) { +// %c0 = arith.constant 0 : index +// %cst = arith.constant 0.000000e+00 : f32 +// %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x16xf32> +// linalg.fill ins(%cst : f32) outs(%alloc : memref<8x16xf32>) +// linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<8x16xf32>, memref<16x16xf32>) outs(%alloc : memref<8x16xf32>) +// %0 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %1 = xegpu.update_nd_offset %0, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %2 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32> +// %3 = xegpu.create_nd_tdesc %alloc[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %4 = xegpu.update_nd_offset %3, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32> +// %6 = arith.addf %2, %5 : vector<8x16xf32> +// %7 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %8 = xegpu.update_nd_offset %7, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// xegpu.store_nd %6, %8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// memref.dealloc %alloc : memref<8x16xf32> +// return +// } \ No newline at end of file From b75864ee01b7ce8a288189e5bab090ae9414e741 Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Tue, 9 Jul 2024 16:42:28 +0800 Subject: [PATCH 3/7] update --- include/gc/Transforms/Passes.td | 10 ++++++++++ lib/gc/Transforms/Pipeline.cpp | 10 +++++----- test/mlir/test/gc/Transforms/Pipeline/gpu.mlir | 2 +- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 4663f59b8..678f4969c 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -61,6 +61,16 @@ def GCGPUPipeline: Pass<"gc-gpu-pipeline"> { "xegpu::XeGPUDialect", "math::MathDialect", "vector::VectorDialect"]; + let options = [ + Option<"kTile", "k-tile", "int64_t", + /*default=*/"32", + "GEMM tile size for reduction dimension.">, + Option<"stages", "stages", "int64_t", + /*default=*/"1", + "Number of cooperative prefetch stages.">, + ListOption<"dpasTile", "dpas-tile", "int64_t", + "DPAS register block sizes MxNxK">, + ]; } #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index e4b455c61..b6b6aab40 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -149,15 +149,14 @@ void populateCPUPipeline(mlir::PassManager &pm) { populateLLVMPasses(pm); } -void populateGPUPipeline(mlir::PassManager &pm) { +void populateGPUPipeline(mlir::PassManager &pm, + tpp::LinalgToXeGPUOptions options) { // middle-end, arith/math/vector dialects populateVectorPasses(pm); // back-end, arith/math/vector/memref dialects populateBufferizationPasses(pm); #ifdef TPP_ENABLED - // - pm.addNestedPass( - tpp::createLinalgToXeGPU(tpp::LinalgToXeGPUOptions{32, 1, 16})); + pm.addNestedPass(tpp::createLinalgToXeGPU(options)); #endif } @@ -188,7 +187,8 @@ class GCGPUPipeline : public impl::GCGPUPipelineBase { void runOnOperation() final { auto op = getOperation(); PassManager pm{op->getContext()}; - populateGPUPipeline(pm); + tpp::LinalgToXeGPUOptions options{kTile, stages, dpasTile}; + populateGPUPipeline(pm, options); if (failed(pm.run(op))) signalPassFailure(); } diff --git a/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir index eeb2a81a8..c19483f7d 100644 --- a/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir +++ b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt --gc-gpu-pipeline %s | FileCheck %s +// RUN: gc-opt --gc-gpu-pipeline="dpas-tile=8,16,16 k-tile=16" -canonicalize %s | FileCheck %s func.func @mlp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> { %cst = arith.constant 0.000000e+00 : f32 From 6ee1a1e0c499783808880f5a776b7e1cd695572e Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Tue, 9 Jul 2024 16:50:39 +0800 Subject: [PATCH 4/7] update --- .../mlir/test/gc/Transforms/Pipeline/gpu.mlir | 86 ++++++++++++++++--- 1 file changed, 73 insertions(+), 13 deletions(-) diff --git a/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir index c19483f7d..07ddacdd2 100644 --- a/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir +++ b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir @@ -1,32 +1,92 @@ // RUN: gc-opt --gc-gpu-pipeline="dpas-tile=8,16,16 k-tile=16" -canonicalize %s | FileCheck %s -func.func @mlp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> { +func.func @matmul(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { + linalg.matmul ins(%arg0, %arg1 : memref<8x16xf16>, memref<16x16xf16>) + outs(%arg2 : memref<8x16xf32>) + return +} + +// func.func @matmul(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { +// %c1024 = arith.constant 1024 : index +// %c16 = arith.constant 16 : index +// %c0 = arith.constant 0 : index +// %0 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %1 = xegpu.update_nd_offset %0, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32> +// %3 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> +// %4 = xegpu.update_nd_offset %3, [0, 0] : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> +// %5 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> +// %6 = xegpu.update_nd_offset %5, [0, 0] : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> +// %7:3 = scf.for %arg3 = %c0 to %c16 step %c16 iter_args(%arg4 = %2, %arg5 = %4, %arg6 = %6) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>) { +// %8 = arith.remui %arg3, %c1024 : index +// %9 = arith.cmpi eq, %8, %c0 : index +// scf.if %9 { +// gpu.barrier +// } +// %10 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, vnni_axis = 1 : i64}> : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> -> vector<8x8x2xf16> +// %11 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, vnni_axis = 0 : i64}> : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> -> vector<8x16x2xf16> +// %12 = xegpu.update_nd_offset %arg5, [0, 16] : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> +// %13 = xegpu.update_nd_offset %arg6, [16, 0] : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> +// xegpu.prefetch_nd %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> +// xegpu.prefetch_nd %13 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> +// %14 = xegpu.dpas %10, %11, %arg4 : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> +// scf.yield %14, %12, %13 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> +// } +// xegpu.store_nd %7#0, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// return +// } + +func.func @mlp(%arg0: tensor<8x16xf16>, %arg1: tensor<16x16xf16>, %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> { %cst = arith.constant 0.000000e+00 : f32 %0 = tensor.empty() : tensor<8x16xf32> %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x16xf32>) -> tensor<8x16xf32> - %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<8x16xf32>, tensor<16x16xf32>) + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<8x16xf16>, tensor<16x16xf16>) outs(%1 : tensor<8x16xf32>) -> tensor<8x16xf32> %3 = tensor.empty() : tensor<8x16xf32> %4 = linalg.add ins(%arg2, %2 : tensor<8x16xf32>, tensor<8x16xf32>) outs(%3 : tensor<8x16xf32>) -> tensor<8x16xf32> return %4 : tensor<8x16xf32> } -// func.func @mlp(%arg0: memref<8x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<8x16xf32>, %arg3: memref<8x16xf32>) { +// func.func @mlp(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>, %arg3: memref<8x16xf32>) { +// %c1024 = arith.constant 1024 : index +// %c16 = arith.constant 16 : index // %c0 = arith.constant 0 : index // %cst = arith.constant 0.000000e+00 : f32 // %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x16xf32> // linalg.fill ins(%cst : f32) outs(%alloc : memref<8x16xf32>) -// linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<8x16xf32>, memref<16x16xf32>) outs(%alloc : memref<8x16xf32>) -// %0 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %0 = xegpu.create_nd_tdesc %alloc[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> // %1 = xegpu.update_nd_offset %0, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -// %2 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32> -// %3 = xegpu.create_nd_tdesc %alloc[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -// %4 = xegpu.update_nd_offset %3, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -// %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32> -// %6 = arith.addf %2, %5 : vector<8x16xf32> -// %7 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -// %8 = xegpu.update_nd_offset %7, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -// xegpu.store_nd %6, %8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32> +// %3 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> +// %4 = xegpu.update_nd_offset %3, [0, 0] : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> +// %5 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> +// %6 = xegpu.update_nd_offset %5, [0, 0] : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> +// %7:3 = scf.for %arg4 = %c0 to %c16 step %c16 iter_args(%arg5 = %2, %arg6 = %4, %arg7 = %6) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>) { +// %17 = arith.remui %arg4, %c1024 : index +// %18 = arith.cmpi eq, %17, %c0 : index +// scf.if %18 { +// gpu.barrier +// } +// %19 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, vnni_axis = 1 : i64}> : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> -> vector<8x8x2xf16> +// %20 = xegpu.load_nd %arg7 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, vnni_axis = 0 : i64}> : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> -> vector<8x16x2xf16> +// %21 = xegpu.update_nd_offset %arg6, [0, 16] : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> +// %22 = xegpu.update_nd_offset %arg7, [16, 0] : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> +// xegpu.prefetch_nd %21 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> +// xegpu.prefetch_nd %22 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> +// %23 = xegpu.dpas %19, %20, %arg5 : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> +// scf.yield %23, %21, %22 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> +// } +// xegpu.store_nd %7#0, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %8 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %9 = xegpu.update_nd_offset %8, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %10 = xegpu.load_nd %9 : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32> +// %11 = xegpu.create_nd_tdesc %alloc[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %12 = xegpu.update_nd_offset %11, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %13 = xegpu.load_nd %12 : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32> +// %14 = arith.addf %10, %13 : vector<8x16xf32> +// %15 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// %16 = xegpu.update_nd_offset %15, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> +// xegpu.store_nd %14, %16 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> // memref.dealloc %alloc : memref<8x16xf32> // return // } \ No newline at end of file From efbdc48e92fa24e67647bd81e10ac9162297f479 Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Tue, 9 Jul 2024 17:00:22 +0800 Subject: [PATCH 5/7] update --- lib/gc/Transforms/Pipeline.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index b6b6aab40..9eb3bd19b 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -149,16 +149,16 @@ void populateCPUPipeline(mlir::PassManager &pm) { populateLLVMPasses(pm); } +#ifdef TPP_ENABLED void populateGPUPipeline(mlir::PassManager &pm, tpp::LinalgToXeGPUOptions options) { // middle-end, arith/math/vector dialects populateVectorPasses(pm); // back-end, arith/math/vector/memref dialects populateBufferizationPasses(pm); -#ifdef TPP_ENABLED pm.addNestedPass(tpp::createLinalgToXeGPU(options)); -#endif } +#endif #define GEN_PASS_DEF_GCCPUPIPELINE #define GEN_PASS_DEF_GCGPUPIPELINE @@ -186,11 +186,16 @@ class GCGPUPipeline : public impl::GCGPUPipelineBase { using impl::GCGPUPipelineBase::GCGPUPipelineBase; void runOnOperation() final { auto op = getOperation(); +#ifdef TPP_ENABLED PassManager pm{op->getContext()}; tpp::LinalgToXeGPUOptions options{kTile, stages, dpasTile}; populateGPUPipeline(pm, options); if (failed(pm.run(op))) signalPassFailure(); +#elif + op->emitError() << "No TPP passes.\n"; + signalPassFailure(); +#endif } }; From 1c226c0239244c70e668c9e15e051a187a0f5687 Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Tue, 9 Jul 2024 17:01:26 +0800 Subject: [PATCH 6/7] update --- test/mlir/test/gc/Transforms/Pipeline/gpu.mlir | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir index 07ddacdd2..ac02c18a8 100644 --- a/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir +++ b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir @@ -1,4 +1,5 @@ -// RUN: gc-opt --gc-gpu-pipeline="dpas-tile=8,16,16 k-tile=16" -canonicalize %s | FileCheck %s +// RUN: gc-opt %s -o=/dev/null 2>&1 +// gc-opt --gc-gpu-pipeline="dpas-tile=8,16,16 k-tile=16" -canonicalize %s | FileCheck %s func.func @matmul(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { linalg.matmul ins(%arg0, %arg1 : memref<8x16xf16>, memref<16x16xf16>) From 3980a9bb02612a3938c41c544e3e45c56347991f Mon Sep 17 00:00:00 2001 From: Longsheng Du Date: Tue, 9 Jul 2024 17:04:46 +0800 Subject: [PATCH 7/7] fix --- lib/gc/Transforms/Pipeline.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 9eb3bd19b..66f273007 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -192,7 +192,7 @@ class GCGPUPipeline : public impl::GCGPUPipelineBase { populateGPUPipeline(pm, options); if (failed(pm.run(op))) signalPassFailure(); -#elif +#else op->emitError() << "No TPP passes.\n"; signalPassFailure(); #endif