Skip to content

Commit 76b3fea

Browse files
ezhulenevcopybara-github
authored andcommitted
[xla:runtime] NFC: Extract executable library from jitrt and move it to xla/runtime
PiperOrigin-RevId: 467332094
1 parent eba528e commit 76b3fea

File tree

3 files changed

+4
-723
lines changed

3 files changed

+4
-723
lines changed

backends/jitrt/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ licenses(["notice"])
111111
# "//third_party/tensorflow/compiler/xla/runtime:custom_call",
112112
# "//third_party/tensorflow/compiler/xla/runtime:custom_call_registry",
113113
# "//third_party/tensorflow/compiler/xla/runtime:diagnostics",
114+
# "//third_party/tensorflow/compiler/xla/runtime:executable",
114115
# "//third_party/tensorflow/compiler/xla/runtime:execution_engine",
115116
# "//third_party/tensorflow/compiler/xla/runtime:memory_mapper",
116117
# "//third_party/tensorflow/compiler/xla/runtime:symbolic_shape",

backends/jitrt/include/tfrt/jitrt/jitrt.h

Lines changed: 1 addition & 263 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "third_party/tensorflow/compiler/xla/runtime/constraints.h"
4848
#include "third_party/tensorflow/compiler/xla/runtime/custom_call.h"
4949
#include "third_party/tensorflow/compiler/xla/runtime/diagnostics.h"
50+
#include "third_party/tensorflow/compiler/xla/runtime/executable.h"
5051
#include "third_party/tensorflow/compiler/xla/runtime/execution_engine.h"
5152
#include "third_party/tensorflow/compiler/xla/runtime/memory_mapper.h"
5253
#include "third_party/tensorflow/compiler/xla/runtime/symbolic_shape.h"
@@ -173,14 +174,6 @@ namespace jitrt {
173174
// concrete shapes or values if needed.
174175
class JitExecutable;
175176

176-
// Forward declare the Executable class that represents a fully compiled module,
177-
// which in practice means that it has a function pointer to the compiled
178-
// function, and knows how to execute it, and return results to the caller.
179-
class Executable;
180-
181-
// Converts a custom call library into the execution engine symbols binding.
182-
ExecutionEngine::SymbolsBinding GetSymbolsBinding(DirectCustomCallLibrary lib);
183-
184177
struct CompilationOptions {
185178
// Calling convention defines an ABI for JitRt to call a compiled kernel. See
186179
// documentation and example below.
@@ -316,261 +309,6 @@ struct CompilationOptions {
316309
// otherwise.
317310
Expected<MemrefDesc> ConvertTensorToMemrefDesc(const Tensor& tensor);
318311

319-
//----------------------------------------------------------------------------//
320-
// Result of compiling MLIR module to executable kernel function.
321-
//----------------------------------------------------------------------------//
322-
323-
namespace internal {
324-
class JitCompilationContext;
325-
} // namespace internal
326-
327-
class Executable {
328-
public:
329-
// Forward declare types defined below.
330-
struct ArgumentsMemoryLayout;
331-
struct ResultsMemoryLayout;
332-
struct CallFrame;
333-
struct ExecuteOpts;
334-
struct KernelContext;
335-
336-
// Initializes call frame by adding all arguments according to the compiled
337-
// kernel ABI. Also allocates storage for the returned values according to the
338-
// results memory layout.
339-
//
340-
// If `verify_arguments` is true (in debug mode it's always on, independent of
341-
// the argument value) this function also verifies that operands passed at run
342-
// time matches the executable entrypoint signature (e.g. all statically known
343-
// dimensions of the memrefs matches the operands). Returns an error if finds
344-
// a mismatch.
345-
//
346-
// This function leaves the kernel context argument (the first argument of a
347-
// kernel function) uninitialized. It will be initialized in the `Execute`
348-
// function right before the actual execution.
349-
Error InitializeCallFrame(ArgumentsRef arguments, CallFrame* call_frame,
350-
bool verify_arguments = true) const;
351-
352-
// Converts returned values owned by the call frame using provided result
353-
// converter. If compiled function execution finished with an error (error
354-
// flag is `true` in the call frame) returns error for all results.
355-
Error ReturnResults(const ResultConverter& results,
356-
CallFrame* call_frame) const;
357-
358-
// Executes compiled function with given arguments.
359-
//
360-
// If `verify_arguments` is true (in debug mode it's always on, independent of
361-
// the argument value) this function also verifies that arguments passed at
362-
// run time matches the executable entrypoint signature. If some of the
363-
// arguments do not match the expected type, this function allocates error
364-
// async values for all results and returns an error.
365-
//
366-
// Returns compiled function results via the user-provided results converter.
367-
// If execution completed in the error state, returns error for all results.
368-
Error Execute(ArgumentsRef arguments, const ResultConverter& results,
369-
const ExecuteOpts& opts, bool verify_arguments = true) const;
370-
371-
// Executes compiled function using user provided call frame.
372-
//
373-
// It is the caller responsibility to handle the compiled function results
374-
// stored in the call frame.
375-
void Execute(CallFrame& call_frame, const ExecuteOpts& opts) const;
376-
377-
bool IsAsync() const { return results_memory_layout_.has_async_results; }
378-
379-
llvm::StringRef name() const { return name_; }
380-
381-
Optional<size_t> specialization() const { return specialization_; }
382-
383-
// Returns the number of results in the runtime signature.
384-
unsigned num_results() const;
385-
386-
// Signature of the compiled module entrypoint function before lowering to
387-
// the runtime dialects. See JitExecutable's `signature_` for more details.
388-
const FunctionType& signature() const;
389-
390-
// Signature of the compiled module entrypoint function after lowering it from
391-
// high level dialects to the dialects supported by the jitrt runtime.
392-
// See JitExecutable's `signature_` for more details.
393-
const FunctionType& runtime_signature() const;
394-
395-
std::chrono::milliseconds time_to_compile() const;
396-
397-
// Get the object file behind this executable (on linux for example, it will
398-
// be https://en.wikipedia.org/wiki/Executable_and_Linkable_Format
399-
// executable). Can be null.
400-
std::unique_ptr<llvm::MemoryBuffer> obj_file() const;
401-
402-
// CallFrame provides a pointer-stable storage for packed function arguments
403-
// and storage for returned values.
404-
struct CallFrame {
405-
// Pointers to compiled kernel arguments.
406-
llvm::SmallVector<void*, 32> args;
407-
408-
// We use single block of memory to store compiled kernel results. We need
409-
// to be able to store pointers to async values and tokens, and strided
410-
// memrefs which at runtime are represented as StridedMemrefType<T, rank>.
411-
//
412-
// Currently we only need to provide result storage for pointers and memref
413-
// sizes and strides (int64_t type). If we'll need to support more complex
414-
// return types we'll have to be more careful about alignment requirements.
415-
static_assert(sizeof(uintptr_t) == sizeof(int64_t),
416-
"uintptr_t size must be the same as int64_t");
417-
418-
// Memory where the compiled kernel will write its results.
419-
llvm::SmallVector<uint8_t, 128> results;
420-
421-
// Tracks whether any of the outputs were set.
422-
bool has_set_outputs = false;
423-
424-
// Indicates whether the kernel function execution finished with an error.
425-
bool is_error = false;
426-
427-
// The error message which is available only if `is_error` is true. The
428-
// assumption is that the error message string is owned by the compiled
429-
// binary and the call frame can safely keep a non-owning pointer.
430-
llvm::StringRef error;
431-
};
432-
433-
// Requirements for passing arguments to the compiled function.
434-
struct ArgumentsMemoryLayout {
435-
// Currently we always pass arguments as an array of pointers.
436-
size_t num_args_ptrs = 0;
437-
};
438-
439-
// Requirements for the contiguous block of memory to store compiled function
440-
// results. When we invoke a compiled fuction we allocate a block of memory,
441-
// and pass pointers to pre-computed offsets as output arguments to the
442-
// function.
443-
struct ResultsMemoryLayout {
444-
bool has_async_results = false; // true iff returns async results
445-
size_t size = 0; // number of bytes required
446-
llvm::SmallVector<size_t> offsets; // offsets in the block of memory
447-
};
448-
449-
// Options for configuring compiled kernel execution.
450-
struct ExecuteOpts {
451-
// Async task runner for executing async runtime tasks. Typically it
452-
// schedules async tasks into the underlying thread pool. It's the caller's
453-
// responsibility to guarantee that it will outlive the execution of all
454-
// async tasks started by the executable.
455-
AsyncTaskRunner* async_task_runner = nullptr;
456-
457-
// A container for passing arbitrary user-provided data to the custom call
458-
// handlers. Must outlive all async tasks launched by this executable.
459-
CustomCall::UserData* custom_call_data = nullptr;
460-
461-
// Diagnostic engine is responsible for passing runtime diagnostics back
462-
// to the caller through the diagnostic handler.
463-
DiagnosticEngine* diagnostic_engine = nullptr;
464-
};
465-
466-
// Loads executable from an object file. It is the caller responsibility to
467-
// guarantee that signatures do match the compiled function in the object
468-
// file, otherwise it will surely lead to crash.
469-
static Expected<Executable> LoadFromObjFile(
470-
llvm::StringRef name, std::unique_ptr<llvm::MemoryBuffer> obj_file,
471-
llvm::StringRef entrypoint, FunctionType signature,
472-
FunctionType runtime_signature,
473-
ExecutionEngine::SymbolsBinding runtime_symbol_map = {},
474-
llvm::StringRef memory_region_name = "");
475-
476-
// Verifies that all operands types in the entrypoint function signature are
477-
// supported at run time . Returns a pre-computed layout for the function
478-
// arguments. If some arguments are not supported returns an error.
479-
static Expected<ArgumentsMemoryLayout> GetArgumentsMemoryLayout(
480-
const FunctionType& signature);
481-
482-
// Verifies that all results types in the entrypoint function signature are
483-
// supported at run time . Returns a pre-computed layout for the function
484-
// results. If some results are not supported returns an error.
485-
static Expected<ResultsMemoryLayout> GetResultsMemoryLayout(
486-
const FunctionType& signature);
487-
488-
// TODO(ezhulenev): The following three functions should be decoupled from
489-
// the jitrt header file (maybe move them to runtime.h?) so that custom call
490-
// implementations do not have to depend on the `jitrt` target.
491-
492-
// Returns the user data passed via the ExecuteOpts to the compiled kernel.
493-
static CustomCall::UserData* GetUserData(xla::runtime::KernelContext* ctx);
494-
495-
// Returns the diagnostic engine passed via the ExecuteOpts to the compiled
496-
// kernel.
497-
static DiagnosticEngine* GetDiagnosticEngine(
498-
xla::runtime::KernelContext* ctx);
499-
500-
// Calls the custom call handler with the given runtime context, arguments and
501-
// attributes.
502-
static mlir::LogicalResult Call(xla::runtime::KernelContext* ctx,
503-
CustomCall& call, void** args, void** attrs);
504-
505-
private:
506-
friend class internal::JitCompilationContext;
507-
508-
Executable(llvm::StringRef name,
509-
std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper,
510-
std::unique_ptr<ExecutionEngine> engine, FunctionType signature,
511-
FunctionType runtime_signature,
512-
ArgumentsMemoryLayout arguments_memory_layout,
513-
ResultsMemoryLayout results_memory_layout,
514-
Optional<size_t> specialization,
515-
std::chrono::milliseconds time_to_compile)
516-
: name_(name.str()),
517-
memory_mapper_(std::move(memory_mapper)),
518-
engine_(std::move(engine)),
519-
fptr_(engine_->entrypoint()),
520-
signature_(std::move(signature)),
521-
runtime_signature_(std::move(runtime_signature)),
522-
arguments_memory_layout_(std::move(arguments_memory_layout)),
523-
results_memory_layout_(std::move(results_memory_layout)),
524-
specialization_(specialization),
525-
time_to_compile_(time_to_compile) {
526-
assert(fptr_ != nullptr && "kernel function must be not null");
527-
}
528-
529-
std::string name_; // name of the compiled kernel module
530-
531-
// Called by `engine_`'s destructor; must appear before it.
532-
std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper_; // optional
533-
534-
// JitRt execution engine owns the LLVM ORC jit compilation stack.
535-
std::unique_ptr<ExecutionEngine> engine_;
536-
537-
// Compiled function owned by the `engine_`.
538-
ExecutionEngine::EntrypointFunctionPtr fptr_;
539-
540-
// Signature of the compiled module entrypoint function before lowering to
541-
// the runtime dialects (see JitExecutable `signature_` for more details).
542-
FunctionType signature_;
543-
544-
// Signature of the compiled module entrypoint function after lowering it from
545-
// high level dialects to the dialects supported by the jitrt runtime.
546-
//
547-
// - Operands and results types converted to the types with well-defined ABI
548-
// (e.g. tensors converted to memrefs).
549-
//
550-
// - First argument is always a kernel context added to the function by the
551-
// lowering pipeline.
552-
//
553-
// From this signature executable infers how to pack runtime operands
554-
// according to the expected memory layout, and how to convert results
555-
// returned from the JIT-compiled function into high level types (e.g. how to
556-
// convert StridedMemrefType into Tensorflow Tensor).
557-
//
558-
// To infer the type of the returned value, executable looks at the type
559-
// defined by the `runtime_signature_` to get the memory layout of the
560-
// returned value, and at the type defined by the `signature_` to get the type
561-
// expected by the runtime.
562-
FunctionType runtime_signature_;
563-
564-
ArgumentsMemoryLayout arguments_memory_layout_;
565-
ResultsMemoryLayout results_memory_layout_;
566-
567-
// Specialization id if this executable is a specialization, or an empty
568-
// optional if this executable is a default one.
569-
Optional<size_t> specialization_;
570-
// The time it took to compile this binary.
571-
std::chrono::milliseconds time_to_compile_;
572-
};
573-
574312
//----------------------------------------------------------------------------//
575313
// JitExecutable to manage multiple compiled executables.
576314
//----------------------------------------------------------------------------//

0 commit comments

Comments
 (0)