|
39 | 39 | #include "tfrt/host_context/task_function.h"
|
40 | 40 | #include "tfrt/jitrt/results.h"
|
41 | 41 | #include "tfrt/support/forward_decls.h"
|
| 42 | +#include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/calling_convention.h" |
42 | 43 | #include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/specialization.h"
|
43 | 44 | #include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h"
|
44 | 45 | #include "third_party/tensorflow/compiler/xla/runtime/arguments.h"
|
@@ -75,6 +76,8 @@ class Tensor;
|
75 | 76 |
|
76 | 77 | namespace jitrt {
|
77 | 78 |
|
| 79 | +using xla::runtime::CallingConvention; |
| 80 | + |
78 | 81 | // Compiled module example:
|
79 | 82 | //
|
80 | 83 | // module @kernel attributes { tfrt.compiled } {
|
@@ -175,24 +178,6 @@ namespace jitrt {
|
175 | 178 | class JitExecutable;
|
176 | 179 |
|
177 | 180 | struct CompilationOptions {
|
178 |
| - // Calling convention defines an ABI for JitRt to call a compiled kernel. See |
179 |
| - // documentation and example below. |
180 |
| - using CallingConvention = |
181 |
| - std::function<mlir::FunctionType(mlir::FunctionType)>; |
182 |
| - |
183 |
| - // Returns a calling convention that only adds the kernel context argument. |
184 |
| - static CallingConvention DefaultCallingConvention(); |
185 |
| - |
186 |
| - // Returns a calling convention that uses user-provided type converter to |
187 |
| - // convert all inputs and results types, and adds the kernel context argument. |
188 |
| - static CallingConvention DefaultCallingConvention(mlir::TypeConverter); |
189 |
| - |
190 |
| - // Returns a calling convention that (1) prepends the kernel context argument, |
191 |
| - // (2) uses the user-provided type converter to convert all inputs and results |
192 |
| - // types, and (3) converts result types into out-params by appending them |
193 |
| - // to the arguments. |
194 |
| - static CallingConvention ResultsToOutsCallingConvention(mlir::TypeConverter); |
195 |
| - |
196 | 181 | // Compiled kernel can be specialized and recompiled at runtime to the
|
197 | 182 | // concrete input shapes and sometimes values (e.g. reduciton dimension).
|
198 | 183 | enum class Specialization {
|
@@ -237,48 +222,8 @@ struct CompilationOptions {
|
237 | 222 | // `rt-to-kernel-function` pass to convert regular functions to "kernels".
|
238 | 223 | std::function<void(mlir::PassManager&)> create_compilation_pipeline;
|
239 | 224 |
|
240 |
| - // Calling convention converts the compiled module entrypoint function type to |
241 |
| - // the function type with a well defined ABI (e.g. tensors do not have an ABI, |
242 |
| - // and must be passed across the function boundary as memrefs). In a nutshell |
243 |
| - // it tells the JitRt how to call the compiled kernel at run time, and how to |
244 |
| - // return results back to the JitRt. |
245 |
| - // |
246 |
| - // All types in the converted function signature should have a registered |
247 |
| - // type conversion (see `type_converter` below) to a type with defined |
248 |
| - // argument or result ABI (see Type::ArgumentAbi and Type::ResultAbi). |
249 |
| - // |
250 |
| - // If conversion is not possible, calling convention must return a null value. |
251 |
| - // |
252 |
| - // Example: abstract kernel defined in high level dialect, e.g. MHLO |
253 |
| - // |
254 |
| - // ```mlir |
255 |
| - // func @kernel(%arg0: tensor<?xf32>, |
256 |
| - // %arg1: tensor<?xf32>) -> tensor<?x?xf32> { ... } |
257 |
| - // ``` |
258 |
| - // |
259 |
| - // after calling convention conversion becomes: |
260 |
| - // |
261 |
| - // ```mlir |
262 |
| - // func @kernel(%ctx: !rt.kernel_context, |
263 |
| - // %arg0: memref<?xf32>, |
264 |
| - // %arg1: memref<?xf32>) -> memref<?x?xf32> { ... } |
265 |
| - // ``` |
266 |
| - // |
267 |
| - // Calling convention function type is not the same as the entrypoint function |
268 |
| - // type produced by the compilation pipeline for several reasons: |
269 |
| - // |
270 |
| - // 1) Compilation pipeline produces LLVM functions with LLVM types, and high |
271 |
| - // level information is lost, e.g. all memrefs are deconstructed into |
272 |
| - // primitive fields when passed as inputs. |
273 |
| - // |
274 |
| - // 2) Compiled kernel function always returns void, and uses runtime API to |
275 |
| - // return results back to the caller (see `rt-convert-to-entrypoint` pass). |
276 |
| - // |
277 |
| - // Calling convention function type is a JitRt-compatible description of the |
278 |
| - // compiled kernel ABI, so that JitRt can correctly initialize CallFrame |
279 |
| - // arguments, allocate memory for returned results, and then correctly decode |
280 |
| - // results memory into the high level types (e.g. convert returned memref |
281 |
| - // descriptor to a Tensorfow tensor). |
| 225 | + // Calling convention defines an ABI for XLA runtime to call an executable. |
| 226 | + // See `CallingConvention` documentation for details. |
282 | 227 | CallingConvention calling_convention = DefaultCallingConvention();
|
283 | 228 |
|
284 | 229 | // Type converter converts MLIR types to the corresponding run time types.
|
|
0 commit comments