Skip to content

Commit 0a042cb

Browse files
ezhulenevcopybara-github
authored andcommitted
[xla:runtime] NFC: Extract calling_convention library from jitrt and move it to xla/runtime
PiperOrigin-RevId: 467360160
1 parent b20ec05 commit 0a042cb

File tree

5 files changed

+10
-139
lines changed

5 files changed

+10
-139
lines changed

backends/jitrt/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ licenses(["notice"])
9898
# "@llvm-project//mlir:ToLLVMIRTranslation",
9999
# "@llvm-project//mlir:Transforms",
100100
# "@llvm-project//mlir:mlir_c_runner_utils",
101+
# "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:calling_convention",
101102
# "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:passes",
102103
# "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:specialization",
103104
# "//third_party/tensorflow/compiler/xla/mlir/transforms/runtime:type_converter",

backends/jitrt/cpp_tests/calling_convention_test.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ static const char* entrypoint = "log2_1d";
4747

4848
struct CallingConventionTestCase {
4949
std::string test_name;
50-
CompilationOptions::CallingConvention calling_convention;
50+
CallingConvention calling_convention;
5151
int expected_num_results;
5252
int expected_num_operands;
5353
};
@@ -117,11 +117,11 @@ INSTANTIATE_TEST_SUITE_P(
117117
CallingConventionTest, CallingConventionTest,
118118
testing::ValuesIn<CallingConventionTestCase>({
119119
{"DefaultCallingConvention",
120-
CompilationOptions::DefaultCallingConvention(
120+
xla::runtime::DefaultCallingConvention(
121121
mlir::bufferization::BufferizeTypeConverter()),
122122
/*expected_num_results=*/1, /*expected_num_operands=*/2},
123123
{"ResultsToOutsCallingConvention",
124-
CompilationOptions::ResultsToOutsCallingConvention(
124+
xla::runtime::ResultsToOutsCallingConvention(
125125
mlir::bufferization::BufferizeTypeConverter()),
126126
/*expected_num_results=*/0, /*expected_num_operands=*/3},
127127
}),

backends/jitrt/cpp_tests/end_to_end_example_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ TEST(EndToEndExampleTest, CompiledAndExecute) {
272272
// the ABI boundary. The expectation is that compiler pipeline will act
273273
// according to this calling convention, and the entrypoint will have the same
274274
// function signature.
275-
opts.calling_convention = CompilationOptions::DefaultCallingConvention(
275+
opts.calling_convention = xla::runtime::DefaultCallingConvention(
276276
mlir::bufferization::BufferizeTypeConverter());
277277

278278
// Add a conversion from the `!testlib.custom_arg` MLIR type to the run time

backends/jitrt/include/tfrt/jitrt/jitrt.h

Lines changed: 5 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "tfrt/host_context/task_function.h"
4040
#include "tfrt/jitrt/results.h"
4141
#include "tfrt/support/forward_decls.h"
42+
#include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/calling_convention.h"
4243
#include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/specialization.h"
4344
#include "third_party/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h"
4445
#include "third_party/tensorflow/compiler/xla/runtime/arguments.h"
@@ -75,6 +76,8 @@ class Tensor;
7576

7677
namespace jitrt {
7778

79+
using xla::runtime::CallingConvention;
80+
7881
// Compiled module example:
7982
//
8083
// module @kernel attributes { tfrt.compiled } {
@@ -175,24 +178,6 @@ namespace jitrt {
175178
class JitExecutable;
176179

177180
struct CompilationOptions {
178-
// Calling convention defines an ABI for JitRt to call a compiled kernel. See
179-
// documentation and example below.
180-
using CallingConvention =
181-
std::function<mlir::FunctionType(mlir::FunctionType)>;
182-
183-
// Returns a calling convention that only adds the kernel context argument.
184-
static CallingConvention DefaultCallingConvention();
185-
186-
// Returns a calling convention that uses user-provided type converter to
187-
// convert all inputs and results types, and adds the kernel context argument.
188-
static CallingConvention DefaultCallingConvention(mlir::TypeConverter);
189-
190-
// Returns a calling convention that (1) prepends the kernel context argument,
191-
// (2) uses the user-provided type converter to convert all inputs and results
192-
// types, and (3) converts result types into out-params by appending them
193-
// to the arguments.
194-
static CallingConvention ResultsToOutsCallingConvention(mlir::TypeConverter);
195-
196181
// Compiled kernel can be specialized and recompiled at runtime to the
197182
// concrete input shapes and sometimes values (e.g. reduciton dimension).
198183
enum class Specialization {
@@ -237,48 +222,8 @@ struct CompilationOptions {
237222
// `rt-to-kernel-function` pass to convert regular functions to "kernels".
238223
std::function<void(mlir::PassManager&)> create_compilation_pipeline;
239224

240-
// Calling convention converts the compiled module entrypoint function type to
241-
// the function type with a well defined ABI (e.g. tensors do not have an ABI,
242-
// and must be passed across the function boundary as memrefs). In a nutshell
243-
// it tells the JitRt how to call the compiled kernel at run time, and how to
244-
// return results back to the JitRt.
245-
//
246-
// All types in the converted function signature should have a registered
247-
// type conversion (see `type_converter` below) to a type with defined
248-
// argument or result ABI (see Type::ArgumentAbi and Type::ResultAbi).
249-
//
250-
// If conversion is not possible, calling convention must return a null value.
251-
//
252-
// Example: abstract kernel defined in high level dialect, e.g. MHLO
253-
//
254-
// ```mlir
255-
// func @kernel(%arg0: tensor<?xf32>,
256-
// %arg1: tensor<?xf32>) -> tensor<?x?xf32> { ... }
257-
// ```
258-
//
259-
// after calling convention conversion becomes:
260-
//
261-
// ```mlir
262-
// func @kernel(%ctx: !rt.kernel_context,
263-
// %arg0: memref<?xf32>,
264-
// %arg1: memref<?xf32>) -> memref<?x?xf32> { ... }
265-
// ```
266-
//
267-
// Calling convention function type is not the same as the entrypoint function
268-
// type produced by the compilation pipeline for several reasons:
269-
//
270-
// 1) Compilation pipeline produces LLVM functions with LLVM types, and high
271-
// level information is lost, e.g. all memrefs are deconstructed into
272-
// primitive fields when passed as inputs.
273-
//
274-
// 2) Compiled kernel function always returns void, and uses runtime API to
275-
// return results back to the caller (see `rt-convert-to-entrypoint` pass).
276-
//
277-
// Calling convention function type is a JitRt-compatible description of the
278-
// compiled kernel ABI, so that JitRt can correctly initialize CallFrame
279-
// arguments, allocate memory for returned results, and then correctly decode
280-
// results memory into the high level types (e.g. convert returned memref
281-
// descriptor to a Tensorfow tensor).
225+
// Calling convention defines an ABI for XLA runtime to call an executable.
226+
// See `CallingConvention` documentation for details.
282227
CallingConvention calling_convention = DefaultCallingConvention();
283228

284229
// Type converter converts MLIR types to the corresponding run time types.

backends/jitrt/lib/jitrt.cc

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -142,81 +142,6 @@ Expected<MemrefDesc> ConvertTensorToMemrefDesc(const Tensor& tensor) {
142142
return MakeStringError("unsupported tensor type: ", tensor.tensor_type());
143143
}
144144

145-
//----------------------------------------------------------------------------//
146-
// Default calling convention for kernels compiled for JitRt.
147-
//----------------------------------------------------------------------------//
148-
149-
using CallingConvention = CompilationOptions::CallingConvention;
150-
151-
/*static*/ CallingConvention CompilationOptions::DefaultCallingConvention() {
152-
return [](mlir::FunctionType func) {
153-
mlir::MLIRContext* ctx = func.getContext();
154-
155-
llvm::SmallVector<mlir::Type> inputs = {KernelContextType::get(ctx)};
156-
inputs.reserve(1 + func.getNumInputs());
157-
llvm::append_range(inputs, func.getInputs());
158-
159-
return mlir::FunctionType::get(ctx, inputs, func.getResults());
160-
};
161-
}
162-
163-
/*static*/ CallingConvention CompilationOptions::DefaultCallingConvention(
164-
mlir::TypeConverter type_converter) {
165-
return [c = std::move(type_converter)](mlir::FunctionType func) mutable {
166-
mlir::MLIRContext* ctx = func.getContext();
167-
168-
// Track if all type conversions were successful.
169-
bool failed_conversion = false;
170-
auto convert = [&](mlir::Type type) -> mlir::Type {
171-
auto converted = c.convertType(type);
172-
if (!converted) failed_conversion = true;
173-
return converted;
174-
};
175-
176-
// Add kernel context as the first argument.
177-
llvm::SmallVector<mlir::Type> inputs = {KernelContextType::get(ctx)};
178-
inputs.reserve(1 + func.getNumInputs());
179-
llvm::transform(func.getInputs(), std::back_inserter(inputs), convert);
180-
181-
// Apply type conversion to all results types.
182-
llvm::SmallVector<mlir::Type> results;
183-
results.reserve(func.getNumResults());
184-
llvm::transform(func.getResults(), std::back_inserter(results), convert);
185-
186-
// Return null if any of the type conversions failed.
187-
if (failed_conversion) return mlir::FunctionType();
188-
189-
return mlir::FunctionType::get(ctx, inputs, results);
190-
};
191-
}
192-
193-
/*static*/ CallingConvention CompilationOptions::ResultsToOutsCallingConvention(
194-
mlir::TypeConverter type_converter) {
195-
return [c = std::move(type_converter)](mlir::FunctionType func) mutable {
196-
mlir::MLIRContext* ctx = func.getContext();
197-
198-
// Track if all type conversions were successful.
199-
bool failed_conversion = false;
200-
201-
auto convert = [&](mlir::Type type) -> mlir::Type {
202-
auto converted = c.convertType(type);
203-
if (!converted) failed_conversion = true;
204-
return converted;
205-
};
206-
207-
llvm::SmallVector<mlir::Type> inputs;
208-
inputs.reserve(1 + func.getNumInputs() + func.getNumResults());
209-
inputs.push_back(KernelContextType::get(ctx));
210-
llvm::transform(func.getInputs(), std::back_inserter(inputs), convert);
211-
llvm::transform(func.getResults(), std::back_inserter(inputs), convert);
212-
213-
// Return null if any of the type conversions failed.
214-
if (failed_conversion) return mlir::FunctionType();
215-
216-
return mlir::FunctionType::get(ctx, inputs, {});
217-
};
218-
}
219-
220145
//----------------------------------------------------------------------------//
221146
// Setup MLIR pass pipeline to lower to LLVM dialect, and use ORC JIT to codegen
222147
// functions at runtime.

0 commit comments

Comments
 (0)