diff --git a/mlir/test/Conversion/TosaToEmitC/fix_mem.py b/mlir/test/Conversion/TosaToEmitC/fix_mem.py new file mode 100644 index 0000000000000..dc456069dc24f --- /dev/null +++ b/mlir/test/Conversion/TosaToEmitC/fix_mem.py @@ -0,0 +1,145 @@ +import argparse +import dataclasses +import re +import pathlib +import typing + + +@dataclasses.dataclass(frozen=True) +class Pattern: + match: str + substitution: str | typing.Callable[[str], str] + name: str + + def substitute(self, input: str) -> str: + return re.sub(self.match, self.substitution, input, 0, re.MULTILINE) + + +SUBSTITUTIONS = [ + # Insert additional constant with 0 + Pattern( + r"func.func(.*)\{", + "func.func\\1{\\n %zzz = arith.constant 0 : index", + "const_0", + ), + # Convert all 0D memrefs 1D memrefs + Pattern(r"memref<(\D\S*)>", "memref<1x\\1>", "memref-1d"), + # memref.load + Pattern( + r"memref.load %(.*)\[()\] : memref<(.*)>", + "memref.load %\\1[%zzz] : memref<\\3>", + "memref-load", + ), + # memref.store + Pattern( + r"memref.store %(.*), %(.*)\[()\] : memref<(.*)>", + "memref.store %\\1, %\\2[%zzz] : memref<\\4>", + "memref-store", + ), + # memref.alloca alignment + Pattern( + r"%(.*) \= memref.alloca\(\) \{alignment \= .* : .*\} : memref\<(.*)\>", + "%\\1 = memref.alloca() : memref<\\2>", + "memref-alloca", + ), + # memref.copy + Pattern( + r"memref.copy %(.*), %(.*) : memref<(.*)> to memref<(.*)>", + "linalg.copy ins(%\\1 : memref<\\3>) outs(%\\2 : memref<\\4>)", + "memref-copy", + ), + # arith.extf + Pattern( + r"%(.*) = arith.extf %(.*) : (.*) to (.*)", + "%\\1 = emitc.cast %\\2 : \\3 to \\4", + "arith-extf", + ), + # arith.truncf + Pattern( + r"%(.*) = arith.truncf %(.*) : (.*) to (.*)", + "%\\1 = emitc.cast %\\2 : \\3 to \\4", + "arith-truncf", + ), + # arith.index_cast + Pattern( + r"%(.*) = arith.index_cast %(.*) : index to (.*)", + "%\\1 = emitc.cast %\\2 : index to \\3", + "arith-index-cast", + ), + # arith.cmpf ogt + Pattern( + r"%(.*) = arith.cmpf ogt, %(.*), %(.*) : (.*)", + "%\\1 = emitc.cmp gt , %\\2, %\\3 : (\\4, \\4) -> i1", + "arith-cmpf", + ), + # arith.cmpf ugt + Pattern( + r"%(.*) = arith.cmpf ugt, %(.*), %(.*) : (.*)", + "%\\1 = emitc.cmp gt , %\\2, %\\3 : (\\4, \\4) -> i1", + "arith-cmpf", + ), + # arith.cmpf ult + Pattern( + r"%(.*) = arith.cmpf ult, %(.*), %(.*) : (.*)", + "%\\1 = emitc.cmp lt , %\\2, %\\3 : (\\4, \\4) -> i1", + "arith-cmpf", + ), + # arith.cmpf uno + Pattern( + r"%(.*) = arith.cmpf uno, %(.*), %(.*) : (.*)", + "%\\1 = emitc.cmp ne , %\\2, %\\3 : (\\4, \\4) -> i1", + "arith-cmpf", + ), + # args + Pattern( + r"func.func @forward\(%(.*): memref<(.*)>, %(.*): memref<(.*)>\) \{", + 'func.func @forward(%xxx: !emitc.array<\\2>, %yyy: !emitc.array<\\4>) {\\n %\\1 = "builtin.unrealized_conversion_cast"(%xxx) : (!emitc.array<\\2>) -> memref<\\2>\\n %\\3 = "builtin.unrealized_conversion_cast"(%yyy) : (!emitc.array<\\4>) -> memref<\\4>', + "func-args", + ), + # TODO: replace tensor.concat with insert_slice + # TODO: add output_shape to tensor.expand_shape + # TODO: arith: truncf, extf, cmpf + # # tensor.concat + # Pattern( + # r"%(.*) = tensor.concat.*\n", + # lambda m: m, + # "tensor-concat", + # ), +] + + +def substitute(input: str, names: list[str] | None) -> str: + for pattern in SUBSTITUTIONS: + if names is None or pattern.name in names: + input = pattern.substitute(input) + return input + + +def run(input_path: str, output_path: str, patterns: list[str] | None): + input = pathlib.Path(input_path) + if not input.exists(): + raise ValueError(f"File not found: {input}") + + input_str = input.read_text() + output_str = substitute(input_str, patterns) + + output = pathlib.Path(output_path) + output.write_text(output_str) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "input_path", metavar="input-path", help="Path to input mlir file" + ) + parser.add_argument( + "output_path", metavar="output-path", help="Path to output mlir file" + ) + parser.add_argument("-p", "--pattern", action="append") + args = parser.parse_args() + + run(args.input_path, args.output_path, args.pattern) + + +if __name__ == "__main__": + main() diff --git a/mlir/test/Conversion/TosaToEmitC/tosa-to-emitc.mlir b/mlir/test/Conversion/TosaToEmitC/tosa-to-emitc.mlir new file mode 100644 index 0000000000000..e2d69a2cd3d6e --- /dev/null +++ b/mlir/test/Conversion/TosaToEmitC/tosa-to-emitc.mlir @@ -0,0 +1,49 @@ + +// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,canonicalize,linalg-generalize-named-ops,tosa-to-arith,tosa-to-tensor,canonicalize))" %s -o %t.model.linalg.mlir +// RUN: mlir-opt --canonicalize --linalg-fuse-elementwise-ops --linalg-inline-scalar-operands --linalg-fold-unit-extent-dims --fold-tensor-subset-ops --canonicalize %t.model.linalg.mlir -o %t.model.linalg.opt.mlir +// RUN: mlir-opt --pass-pipeline='builtin.module(one-shot-bufferize{allow-unknown-ops bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}, canonicalize)' %t.model.linalg.opt.mlir -o %t.model.buffers.mlir + +// RUN: mlir-opt --canonicalize --buffer-results-to-out-params --buffer-hoisting --buffer-loop-hoisting --promote-buffers-to-stack --fold-memref-alias-ops --canonicalize --buffer-deallocation-pipeline --canonicalize %t.model.buffers.mlir -o %t.model.buffers.opt.mlir +// RUN: mlir-opt --canonicalize --convert-linalg-to-loops --fold-memref-alias-ops --canonicalize %t.model.buffers.opt.mlir -o %t.model.scf.mlir + +// RUN: python %S/fix_mem.py -p memref-copy %t.model.scf.mlir %t.model.scf.1.mlir + +// RUN: mlir-opt --canonicalize --convert-linalg-to-loops --canonicalize %t.model.scf.1.mlir -o %t.model.scf.2.mlir +// RUN: mlir-opt --canonicalize --fold-memref-alias-ops --normalize-memrefs --canonicalize %t.model.scf.2.mlir -o %t.model.scf.3.mlir + +// RUN: mlir-opt --arith-expand --canonicalize %t.model.scf.3.mlir -o %t.model.scf.4.mlir + +// RUN: python %S/fix_mem.py %t.model.scf.4.mlir %t.model.scf.5.mlir + +// RUN: mlir-opt --convert-math-to-libm --canonicalize %t.model.scf.5.mlir -o %t.model.scf.6.mlir +// RUN: mlir-opt --convert-func-to-emitc --convert-scf-to-emitc --convert-arith-to-emitc --convert-memref-to-emitc --canonicalize %t.model.scf.6.mlir -o %t.model.emitc.mlir + +// RUN: mlir-translate --mlir-to-cpp %t.model.emitc.mlir | FileCheck %s + +// CHECK: Fail this test + +// ----- + + +module attributes {tf_saved_model.semantics} { + func.func @main(%arg0: tensor {ml_program.identifier = "serve_b:0", tf_saved_model.index_path = ["b"]}, %arg1: tensor {ml_program.identifier = "serve_a:0", tf_saved_model.index_path = ["a"]}) -> (tensor {ml_program.identifier = "PartitionedCall:0", tf_saved_model.index_path = ["output_0"]}) attributes {tf_saved_model.exported_names = ["serve"]} { + %0 = tosa.add %arg1, %arg0 : (tensor, tensor) -> tensor + %1 = tosa.mul %arg1, %arg0 {shift = 0 : i8} : (tensor, tensor) -> tensor + %2 = tosa.add %1, %arg1 : (tensor, tensor) -> tensor + %3 = tosa.add %2, %1 : (tensor, tensor) -> tensor + %4 = tosa.add %arg1, %3 : (tensor, tensor) -> tensor + %5 = tosa.mul %1, %arg1 {shift = 0 : i8} : (tensor, tensor) -> tensor + %6 = tosa.sub %5, %4 : (tensor, tensor) -> tensor + %7 = tosa.reciprocal %5 : (tensor) -> tensor + %8 = tosa.mul %6, %7 {shift = 0 : i8} : (tensor, tensor) -> tensor + %9 = tosa.sub %1, %arg0 : (tensor, tensor) -> tensor + %10 = tosa.add %0, %arg1 : (tensor, tensor) -> tensor + %11 = tosa.add %arg0, %10 : (tensor, tensor) -> tensor + %12 = tosa.add %9, %11 : (tensor, tensor) -> tensor + %13 = tosa.add %1, %12 : (tensor, tensor) -> tensor + %14 = tosa.add %8, %13 : (tensor, tensor) -> tensor + return %14 : tensor + } +} + +