From 18722cdf006e2ca2c93e81dc7f24c2bc4776a363 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 18 Sep 2023 17:50:30 -0700 Subject: [PATCH] [mlir][c] Expose AsmState. Enable usage where capturing AsmState is good. Haven't plumbed through to python yet. --- mlir/include/mlir-c/IR.h | 26 ++++++++++++++- mlir/include/mlir/CAPI/IR.h | 1 + mlir/lib/Bindings/Python/IRCore.cpp | 4 ++- mlir/lib/CAPI/IR/IR.cpp | 49 +++++++++++++++++++++++++++-- mlir/test/CAPI/ir.c | 7 +++++ 5 files changed, 83 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index b5c6a3094bc67..68eccab6dbaca 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -48,6 +48,7 @@ extern "C" { }; \ typedef struct name name +DEFINE_C_API_STRUCT(MlirAsmState, void); DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void); DEFINE_C_API_STRUCT(MlirContext, void); DEFINE_C_API_STRUCT(MlirDialect, void); @@ -383,6 +384,29 @@ mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MLIR_CAPI_EXPORTED void mlirOperationStateEnableResultTypeInference(MlirOperationState *state); +//===----------------------------------------------------------------------===// +// AsmState API. +// While many of these are simple settings that could be represented in a +// struct, they are wrapped in a heap allocated object and accessed via +// functions to maximize the possibility of compatibility over time. +//===----------------------------------------------------------------------===// + +/// Creates new AsmState, as with AsmState the IR should not be mutated +/// in-between using this state. +/// Must be freed with a call to mlirAsmStateDestroy(). +// TODO: This should be expanded to handle location & resouce map. +MLIR_CAPI_EXPORTED MlirAsmState +mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags); + +/// Creates new AsmState from value. +/// Must be freed with a call to mlirAsmStateDestroy(). +// TODO: This should be expanded to handle location & resouce map. +MLIR_CAPI_EXPORTED MlirAsmState +mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags); + +/// Destroys printing flags created with mlirAsmStateCreate. +MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state); + //===----------------------------------------------------------------------===// // Op Printing flags API. // While many of these are simple settings that could be represented in a @@ -815,7 +839,7 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); /// Prints a value as an operand (i.e., the ValueID). MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, - MlirOpPrintingFlags flags, + MlirAsmState state, MlirStringCallback callback, void *userData); diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index b8ccec896c27b..1836cb0acb67e 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -21,6 +21,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState) DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig) DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext) DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b06937bc285e2..af713547cccbb 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3430,9 +3430,11 @@ void mlir::python::populateIRCore(py::module &m) { MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); if (useLocalScope) mlirOpPrintingFlagsUseLocalScope(flags); - mlirValuePrintAsOperand(self.get(), flags, printAccum.getCallback(), + MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags); + mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(), printAccum.getUserData()); mlirOpPrintingFlagsDestroy(flags); + mlirAsmStateDestroy(state); return printAccum.join(); }, py::arg("use_local_scope") = false, kGetNameAsOperand) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index ef234a912490e..7f5c2aaee6738 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -138,6 +138,51 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { delete unwrap(registry); } +//===----------------------------------------------------------------------===// +// AsmState API. +//===----------------------------------------------------------------------===// + +MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, + MlirOpPrintingFlags flags) { + return wrap(new AsmState(unwrap(op), *unwrap(flags))); +} + +static Operation *findParent(Operation *op, bool shouldUseLocalScope) { + do { + // If we are printing local scope, stop at the first operation that is + // isolated from above. + if (shouldUseLocalScope && op->hasTrait()) + break; + + // Otherwise, traverse up to the next parent. + Operation *parentOp = op->getParentOp(); + if (!parentOp) + break; + op = parentOp; + } while (true); + return op; +} + +MlirAsmState mlirAsmStateCreateForValue(MlirValue value, + MlirOpPrintingFlags flags) { + Operation *op; + mlir::Value val = unwrap(value); + if (auto result = llvm::dyn_cast(val)) { + op = result.getOwner(); + } else { + op = llvm::cast(val).getOwner()->getParentOp(); + if (!op) { + emitError(val.getLoc()) << "<>"; + return {nullptr}; + } + } + op = findParent(op, unwrap(flags)->shouldUseLocalScope()); + return wrap(new AsmState(op, *unwrap(flags))); +} + +/// Destroys printing flags created with mlirAsmStateCreate. +void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(state); } + //===----------------------------------------------------------------------===// // Printing flags API. //===----------------------------------------------------------------------===// @@ -840,11 +885,11 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback, unwrap(value).print(stream); } -void mlirValuePrintAsOperand(MlirValue value, MlirOpPrintingFlags flags, +void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData) { detail::CallbackOstream stream(callback, userData); Value cppValue = unwrap(value); - cppValue.printAsOperand(stream, *unwrap(flags)); + cppValue.printAsOperand(stream, *unwrap(state)); } MlirOpOperand mlirValueGetFirstUse(MlirValue value) { diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index 5725d05a3e132..c031e61945d03 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -487,6 +487,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) { // CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown) // clang-format on + MlirAsmState state = mlirAsmStateCreateForOperation(parentOperation, flags); + fprintf(stderr, "With state: |"); + mlirValuePrintAsOperand(value, state, printToStderr, NULL); + // CHECK: With state: |%0| + fprintf(stderr, "|\n"); + mlirAsmStateDestroy(state); + mlirOpPrintingFlagsDestroy(flags); }