diff --git a/include/gc/Transforms/Passes.h b/include/gc/Transforms/Passes.h index 243a6f4f6..e5e4aee33 100644 --- a/include/gc/Transforms/Passes.h +++ b/include/gc/Transforms/Passes.h @@ -12,8 +12,34 @@ #include "mlir/Pass/Pass.h" namespace mlir { + +namespace LLVM { +class LLVMDialect; +} + +namespace scf { +class SCFDialect; +} + +namespace openmp { +class OpenMPDialect; +} + +namespace linalg { +class LinalgDialect; +} + +namespace MemRef { +class MemRefDialect; +} + +class PassManager; + namespace gc { +void populateFrontendPasses(mlir::PassManager &); +void populateCPUPipeline(mlir::PassManager &); + #define GEN_PASS_DECL #include "gc/Transforms/Passes.h.inc" diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index d31baa5a7..ff0cd8a90 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -31,4 +31,17 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> { ]; } +def GCCPUPipeline: Pass<"gc-cpu-pipeline"> { + let summary = "All-in-one pipeline for GC for CPU"; + let dependentDialects = ["onednn_graph::OneDNNGraphDialect", + "tensor::TensorDialect", + "memref::MemRefDialect", + "linalg::LinalgDialect", + "LLVM::LLVMDialect", + "scf::SCFDialect", + "bufferization::BufferizationDialect", + "omp::OpenMPDialect", + "vector::VectorDialect"]; +} + #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/CMakeLists.txt b/lib/gc/CMakeLists.txt index ea92ba80e..03f7023b8 100644 --- a/lib/gc/CMakeLists.txt +++ b/lib/gc/CMakeLists.txt @@ -6,4 +6,5 @@ include(functions) add_subdirectory(CAPI) add_subdirectory(Dialect) -add_subdirectory(Transforms) \ No newline at end of file +add_subdirectory(Transforms) +add_subdirectory(ExecutionEngine) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CMakeLists.txt b/lib/gc/ExecutionEngine/CMakeLists.txt new file mode 100644 index 000000000..8aa223412 --- /dev/null +++ b/lib/gc/ExecutionEngine/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(CPURuntime) diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt new file mode 100644 index 000000000..6be58e28f --- /dev/null +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -0,0 +1,15 @@ +find_package(OpenMP REQUIRED) + +if ("iomp" IN_LIST OpenMP_C_LIB_NAMES OR "omp" IN_LIST OpenMP_C_LIB_NAMES OR "omp5" IN_LIST OpenMP_C_LIB_NAMES) +else() + add_definitions("-DGC_NEEDS_OMP_WRAPPER=1") +endif() + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") +add_mlir_library(GCCpuRuntime + SHARED + Parallel.cpp + + EXCLUDE_FROM_LIBMLIR + ) \ No newline at end of file diff --git a/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp new file mode 100644 index 000000000..3a5b4c2c1 --- /dev/null +++ b/lib/gc/ExecutionEngine/CPURuntime/Parallel.cpp @@ -0,0 +1,188 @@ +//===-- Parallel.cpp - parallel ---------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include + +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + +#define WEAK_SYMBOL __attribute__((weak)) + +namespace { +struct barrier_t { + alignas(64) std::atomic pending_; + std::atomic rounds_; + uint64_t total_; + // pad barrier to size of cacheline to avoid false sharing + char padding_[64 - 4 * sizeof(int32_t)]; +}; + +using barrier_idle_func = uint64_t (*)(std::atomic *remaining, + int32_t expected_remain, int32_t tid, + void *args); +} // namespace + +extern "C" { +int gc_runtime_keep_alive = 0; +void gc_arrive_at_barrier(barrier_t *b, barrier_idle_func idle_func, + void *idle_args) { + auto cur_round = b->rounds_.load(std::memory_order_acquire); + auto cnt = --b->pending_; + assert(cnt >= 0); + if (cnt == 0) { + b->pending_.store(b->total_); + b->rounds_.store(cur_round + 1); + } else { + if (idle_func) { + if (cur_round != b->rounds_.load()) { + return; + } + idle_func(&b->rounds_, cur_round + 1, -1, idle_args); + } + while (cur_round == b->rounds_.load()) { + _mm_pause(); + } + } +} + +static_assert(sizeof(barrier_t) == 64, "size of barrier_t should be 64-byte"); + +void gc_init_barrier(barrier_t *b, int num_barriers, uint64_t thread_count) { + for (int i = 0; i < num_barriers; i++) { + b[i].total_ = thread_count; + b[i].pending_.store(thread_count); + b[i].rounds_.store(0); + } +} + +#if GC_NEEDS_OMP_WRAPPER +void WEAK_SYMBOL __kmpc_barrier(void *loc, int32_t global_tid) { +#pragma omp barrier +} + +int WEAK_SYMBOL __kmpc_global_thread_num(void *loc) { + return omp_get_thread_num(); +} + +// The implementation was extracted and simplified from LLVM libomp +// at openmp/runtime/src/kmp_sched.cpp +void WEAK_SYMBOL __kmpc_for_static_init_8u(void *loc, int32_t gtid, + int32_t schedtype, + int32_t *plastiter, uint64_t *plower, + uint64_t *pupper, int64_t *pstride, + int64_t incr, int64_t chunk) { + if (unlikely(schedtype != 34)) { + std::abort(); + } + const int32_t FALSE = 0; + const int32_t TRUE = 1; + using UT = uint64_t; + // using ST = int64_t; + /* this all has to be changed back to TID and such.. */ + uint32_t tid = gtid; + uint32_t nth = omp_get_num_threads(); + UT trip_count; + + /* special handling for zero-trip loops */ + if (incr > 0 ? (*pupper < *plower) : (*plower < *pupper)) { + if (plastiter != nullptr) + *plastiter = FALSE; + /* leave pupper and plower set to entire iteration space */ + *pstride = incr; /* value should never be used */ + return; + } + + if (nth == 1) { + if (plastiter != nullptr) + *plastiter = TRUE; + *pstride = + (incr > 0) ? (*pupper - *plower + 1) : (-(*plower - *pupper + 1)); + return; + } + + /* compute trip count */ + if (incr == 1) { + trip_count = *pupper - *plower + 1; + } else if (incr == -1) { + trip_count = *plower - *pupper + 1; + } else if (incr > 0) { + // upper-lower can exceed the limit of signed type + trip_count = (UT)(*pupper - *plower) / incr + 1; + } else { + trip_count = (UT)(*plower - *pupper) / (-incr) + 1; + } + if (trip_count < nth) { + if (tid < trip_count) { + *pupper = *plower = *plower + tid * incr; + } else { + // set bounds so non-active threads execute no iterations + *plower = *pupper + (incr > 0 ? 1 : -1); + } + if (plastiter != nullptr) + *plastiter = (tid == trip_count - 1); + } else { + UT small_chunk = trip_count / nth; + UT extras = trip_count % nth; + *plower += incr * (tid * small_chunk + (tid < extras ? tid : extras)); + *pupper = *plower + small_chunk * incr - (tid < extras ? 0 : incr); + if (plastiter != nullptr) + *plastiter = (tid == nth - 1); + } + *pstride = trip_count; +} + +void WEAK_SYMBOL __kmpc_for_static_fini(void *ptr, int32_t v) {} + +static thread_local int next_num_threads = 0; + +/*! +@ingroup PARALLEL +The type for a microtask which gets passed to @ref __kmpc_fork_call(). +The arguments to the outlined function are +@param global_tid the global thread identity of the thread executing the +function. +@param bound_tid the local identity of the thread executing the function +@param ... pointers to shared variables accessed by the function. +*/ +using kmpc_micro = void (*)(int32_t *global_tid, int32_t *bound_tid, ...); +void WEAK_SYMBOL __kmpc_fork_call(void *loc, int32_t argc, void *pfunc, ...) { + if (unlikely(argc != 1 && argc != 0)) { + std::abort(); + } + va_list ap; + va_start(ap, pfunc); + void *c = va_arg(ap, void *); + int32_t global_tid = 0; + if (unlikely(next_num_threads)) { +#pragma omp parallel num_threads(next_num_threads) + { + kmpc_micro func = (kmpc_micro)(pfunc); + func(&global_tid, nullptr, c); + } + next_num_threads = 0; + } else { +#pragma omp parallel + { + kmpc_micro func = (kmpc_micro)(pfunc); + func(&global_tid, nullptr, c); + } + } + va_end(ap); +} + +void WEAK_SYMBOL __kmpc_push_num_threads(void *loc, int32_t global_tid, + int32_t num_threads) { + next_num_threads = num_threads; +} +#endif +} diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index d25c8a027..09ec00b02 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS add_mlir_library(GCPasses OneDNNGraphToLinalg.cpp + Pipeline.cpp TileNamed.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp new file mode 100644 index 000000000..81fae6877 --- /dev/null +++ b/lib/gc/Transforms/Pipeline.cpp @@ -0,0 +1,154 @@ +//===- Pipeline.cpp - Graph Compiler all-in-one pipeline --------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Passes.h" + +#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc/Transforms/Passes.h" + +namespace mlir::gc { + +// linalg + linalgX + tensor +void populateFrontendPasses(mlir::PassManager &pm) { + // pm.addPass(onednn_graph::createConvertOneDNNGraphToLinalg()); +} + +// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack +void populateTensorPasses(mlir::PassManager &pm) { + // todo: padding propagation pass + // todo: layout propagation pass + // todo: tensor constant propagation pass + // todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass + // todo: fine-grain fusion pass + // todo: lower linalg to arith/math on virtual vector pass + + // REMOVE this pass after the above passes are added. Currently we add this + // pass to make the pipeline work properly + pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); +} + +// scf + arith + math + vector + tensor + linalg.brgemm +void populateVectorPasses(mlir::PassManager &pm) { + // todo: bf16 promotion pass, device dependent pass + // todo: bf16 cast elimilation pass, fast-math kind pass, designed to support + // oneDNN graph spec + pm.addNestedPass(arith::createArithExpandOpsPass()); + // todo: lower to physical vector pass, device dependent pass +} + +// scf + arith + math + vector + memref + linalg.brgemm +void populateBufferizationPasses(mlir::PassManager &pm) { + bufferization::OneShotBufferizationOptions options; + pm.addPass(bufferization::createOneShotBufferizePass(options)); + pm.addPass(createCSEPass()); + pm.addPass(mlir::func::createFuncBufferizePass()); + pm.addNestedPass( + bufferization::createBufferizationBufferizePass()); + pm.addNestedPass( + bufferization::createFinalizingBufferizePass()); + bufferization::BufferResultsToOutParamsOpts opt{}; + opt.hoistStaticAllocs = true; + pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt)); + // todo: buffer schedule pass + // todo: Need to improve this pass to support nested parallel. + pm.addNestedPass(bufferization::createBufferHoistingPass()); + pm.addNestedPass(bufferization::createBufferLoopHoistingPass()); + pm.addNestedPass(bufferization::createBufferDeallocationPass()); + pm.addPass(createBufferizationToMemRefPass()); +} + +// scf + arith + math + vector + memref + func/microkernel +void populateMicroKernelPasses(mlir::PassManager &pm) { + // todo: ConvertLinalgToMicrokernel pass + // todo: CleanupInvalidMicrokernel pass + // todo: InvariantMicrokernelMotion pass + // todo: ConvertMicrokernelToDnnlFunc to lower brgemm to dnnl call + // todo: ConvertMicrokernelToXsmm, to lower brgemm to libxsmm call + // todo: LowerMicrokernel pass + // todo: DispatchMicrokernel +} + +void populateCPURuntimePasses(mlir::PassManager &pm) { + // todo: flatten nested parallel pass to support coarse-grain usion + // remove this pass after we add FlattenNestedParallel + pm.addPass(createConvertSCFToOpenMPPass()); +} + +void populateLoweringToLLVMPasses(mlir::PassManager &pm) { + pm.addPass(createConvertSCFToCFPass()); + pm.addPass(cpuruntime::createCPURuntimeToLLVM()); + pm.addPass(createConvertOpenMPToLLVMPass()); + pm.addNestedPass(createConvertMathToLLVMPass()); + pm.addPass(createConvertMathToLibmPass()); + pm.addPass(createFinalizeMemRefToLLVMConversionPass()); + pm.addNestedPass(createArithToLLVMConversionPass()); + pm.addPass(createConvertFuncToLLVMPass()); + pm.addPass(createConvertControlFlowToLLVMPass()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + pm.addPass(createSymbolDCEPass()); +} + +void populateLLVMPasses(mlir::PassManager &pm) { + pm.addPass(memref::createExpandOpsPass()); + pm.addPass(memref::createExpandStridedMetadataPass()); + populateLoweringToLLVMPasses(pm); +} + +void populateCPUPipeline(mlir::PassManager &pm) { + // front-end, oneDNN graph dialect + populateFrontendPasses(pm); + // middle-end, LinalgX/Linalg/tensor dialects + populateTensorPasses(pm); + // middle-end, arith/math/vector dialects + populateVectorPasses(pm); + // back-end, arith/math/vector/memref dialects + populateBufferizationPasses(pm); + // REMOVE this pass after the TensorPasses are added. Currently we add this + // pass to make the pipeline work properly + pm.addNestedPass(createConvertLinalgToParallelLoopsPass()); + populateMicroKernelPasses(pm); + populateCPURuntimePasses(pm); + // // back-end, llvm dialect + populateLLVMPasses(pm); +} + +#define GEN_PASS_DEF_GCCPUPIPELINE +#include "gc/Transforms/Passes.h.inc" +namespace { + +class GCCPUPipeline : public impl::GCCPUPipelineBase { +public: + friend struct PassHelper; + using impl::GCCPUPipelineBase::GCCPUPipelineBase; + void runOnOperation() final { + auto op = getOperation(); + PassManager pm{op->getContext()}; + populateCPUPipeline(pm); + if (failed(pm.run(op))) + signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir::gc diff --git a/scripts/license.py b/scripts/license.py index 49f28eaa8..3ed2ce521 100644 --- a/scripts/license.py +++ b/scripts/license.py @@ -15,10 +15,10 @@ # SPDX-License-Identifier: Apache-2.0 import datetime, sys, re, argparse -from typing import Dict, Set +from typing import Dict, Set, List WIDTH: int = 80 -intel_license: list[str] = [ +intel_license: List[str] = [ 'Copyright \\(C\\) (\\d\\d\\d\\d-)?$YEAR Intel Corporation', '', 'Licensed under the Apache License, Version 2.0 (the "License");', @@ -35,7 +35,7 @@ 'SPDX-License-Identifier: Apache-2.0', ] -llvm_license: list[str] = [ +llvm_license: List[str] = [ "===-{1,2} $FILE - .* -*\\*- $LANG -\\*-===", '', 'This file is licensed under the Apache License v2.0 with LLVM Exceptions.', @@ -45,7 +45,7 @@ "===-*===", ] -def check_license(filepath: str, license: list[str], var: Dict[str, str], re_line: Set[int]): +def check_license(filepath: str, license: List[str], var: Dict[str, str], re_line: Set[int]): with open(filepath, 'r') as f: idx: int = 0 for line in f.readlines(): @@ -117,7 +117,7 @@ def use_llvm_license(path: str) -> bool: var: Dict[str, str] = {} re_line: Set[int] = set() - lic = list[str] + lic = List[str] if filepath.startswith("test/") or filepath.startswith("./test/"): continue diff --git a/src/gc-cpu-runner/CMakeLists.txt b/src/gc-cpu-runner/CMakeLists.txt index f3f768612..2599eef84 100644 --- a/src/gc-cpu-runner/CMakeLists.txt +++ b/src/gc-cpu-runner/CMakeLists.txt @@ -1,3 +1,20 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + if(GC_DEV_LINK_LLVM_DYLIB) set(LLVM_LINK_COMPONENTS LLVM @@ -36,7 +53,8 @@ endif() #LLVM_LINK_COMPONENTS is processed by LLVM cmake in add_llvm_executable set(gc_cpu_runner_libs - ${MLIR_LINK_COMPONENTS}) + ${MLIR_LINK_COMPONENTS} + GCCpuRuntime) add_mlir_tool(gc-cpu-runner gc-cpu-runner.cpp ) diff --git a/src/gc-cpu-runner/gc-cpu-runner.cpp b/src/gc-cpu-runner/gc-cpu-runner.cpp index 3ece8f2ff..353abffe9 100644 --- a/src/gc-cpu-runner/gc-cpu-runner.cpp +++ b/src/gc-cpu-runner/gc-cpu-runner.cpp @@ -27,7 +27,11 @@ #include "llvm/Support/TargetSelect.h" #include +extern int gc_runtime_keep_alive; + int main(int argc, char **argv) { + // keeps GCCPURuntime linked + gc_runtime_keep_alive = 0; llvm::InitLLVM y(argc, argv); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); diff --git a/test/gc/Transforms/Pipeline/run.mlir b/test/gc/Transforms/Pipeline/run.mlir new file mode 100644 index 000000000..71feb0843 --- /dev/null +++ b/test/gc/Transforms/Pipeline/run.mlir @@ -0,0 +1,23 @@ +// RUN: gc-opt %s --gc-cpu-pipeline | gc-cpu-runner -e main -entry-point-result=void | FileCheck %s + +module { +func.func @aaa() -> tensor<128xf32> { + %c2 = arith.constant 2.0 : f32 + %a = tensor.empty() : tensor<128xf32> + %2 = linalg.fill ins(%c2 : f32) outs(%a : tensor<128xf32>) -> tensor<128xf32> + return %2 : tensor<128xf32> +} + +func.func @main() { + %result = call @aaa() : ()-> tensor<128xf32> + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + scf.for %iv = %c0 to %c128 step %c1 { + %4 = tensor.extract %result[%iv] : tensor<128xf32> + cpuruntime.printf "%f\n" %4 : f32 + } + return +} +// CHECK-COUNT-128: 2.000000 +} \ No newline at end of file diff --git a/test/gc/Transforms/Pipeline/tensor_args.mlir b/test/gc/Transforms/Pipeline/tensor_args.mlir new file mode 100644 index 000000000..adcfb3bd8 --- /dev/null +++ b/test/gc/Transforms/Pipeline/tensor_args.mlir @@ -0,0 +1,13 @@ +// RUN: gc-opt %s --gc-cpu-pipeline | FileCheck %s + +module { +// CHECK: aaa +// check that the func returns void +// CHECK-NOT: ) -> !llvm.struct< +func.func @aaa(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> { + %out = tensor.empty() : tensor<128xf32> + %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> + // CHECK-NOT: memcpy + return %2 : tensor<128xf32> +} +} \ No newline at end of file diff --git a/test/gc/cpu-runner/tid.mlir b/test/gc/cpu-runner/tid.mlir new file mode 100644 index 000000000..aedcc0a20 --- /dev/null +++ b/test/gc/cpu-runner/tid.mlir @@ -0,0 +1,37 @@ +// RUN: gc-opt %s --convert-cpuruntime-to-llvm --convert-openmp-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --reconcile-unrealized-casts | gc-cpu-runner -e main -entry-point-result=void | FileCheck %s +module { + func.func private @omp_get_thread_num() -> i32 + + func.func @check_parallel() { + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %0 = llvm.mlir.constant(1 : i64) : i64 + omp.parallel num_threads(%c8: index) { + omp.wsloop { + omp.loop_nest (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c64) step (%c1, %c1) { + cpuruntime.printf "ITR %zu\n" %arg2 : index + omp.yield + } + omp.terminator + } + %tid = func.call @omp_get_thread_num() : () -> i32 + cpuruntime.printf "EXIT %d\n" %tid : i32 + omp.terminator + } + return + } + + func.func @main() { + %0 = func.call @omp_get_thread_num() : () -> i32 + cpuruntime.printf "TID %d\n" %0 : i32 + call @check_parallel() : ()->() + return + } + // CHECK: TID 0 + // CHECK-COUNT-64: ITR {{[0-9]+}} + // CHECK-NOT: ITR + // CHECK-COUNT-8: EXIT {{[0-9]+}} + // CHECK-NOT: EXIT +} \ No newline at end of file