|
47 | 47 | #include "third_party/tensorflow/compiler/xla/runtime/constraints.h"
|
48 | 48 | #include "third_party/tensorflow/compiler/xla/runtime/custom_call.h"
|
49 | 49 | #include "third_party/tensorflow/compiler/xla/runtime/diagnostics.h"
|
| 50 | +#include "third_party/tensorflow/compiler/xla/runtime/executable.h" |
50 | 51 | #include "third_party/tensorflow/compiler/xla/runtime/execution_engine.h"
|
51 | 52 | #include "third_party/tensorflow/compiler/xla/runtime/memory_mapper.h"
|
52 | 53 | #include "third_party/tensorflow/compiler/xla/runtime/symbolic_shape.h"
|
@@ -173,14 +174,6 @@ namespace jitrt {
|
173 | 174 | // concrete shapes or values if needed.
|
174 | 175 | class JitExecutable;
|
175 | 176 |
|
176 |
| -// Forward declare the Executable class that represents a fully compiled module, |
177 |
| -// which in practice means that it has a function pointer to the compiled |
178 |
| -// function, and knows how to execute it, and return results to the caller. |
179 |
| -class Executable; |
180 |
| - |
181 |
| -// Converts a custom call library into the execution engine symbols binding. |
182 |
| -ExecutionEngine::SymbolsBinding GetSymbolsBinding(DirectCustomCallLibrary lib); |
183 |
| - |
184 | 177 | struct CompilationOptions {
|
185 | 178 | // Calling convention defines an ABI for JitRt to call a compiled kernel. See
|
186 | 179 | // documentation and example below.
|
@@ -316,261 +309,6 @@ struct CompilationOptions {
|
316 | 309 | // otherwise.
|
317 | 310 | Expected<MemrefDesc> ConvertTensorToMemrefDesc(const Tensor& tensor);
|
318 | 311 |
|
319 |
| -//----------------------------------------------------------------------------// |
320 |
| -// Result of compiling MLIR module to executable kernel function. |
321 |
| -//----------------------------------------------------------------------------// |
322 |
| - |
323 |
| -namespace internal { |
324 |
| -class JitCompilationContext; |
325 |
| -} // namespace internal |
326 |
| - |
327 |
| -class Executable { |
328 |
| - public: |
329 |
| - // Forward declare types defined below. |
330 |
| - struct ArgumentsMemoryLayout; |
331 |
| - struct ResultsMemoryLayout; |
332 |
| - struct CallFrame; |
333 |
| - struct ExecuteOpts; |
334 |
| - struct KernelContext; |
335 |
| - |
336 |
| - // Initializes call frame by adding all arguments according to the compiled |
337 |
| - // kernel ABI. Also allocates storage for the returned values according to the |
338 |
| - // results memory layout. |
339 |
| - // |
340 |
| - // If `verify_arguments` is true (in debug mode it's always on, independent of |
341 |
| - // the argument value) this function also verifies that operands passed at run |
342 |
| - // time match the executable entrypoint signature (e.g. all statically known |
343 |
| - // dimensions of the memrefs match the operands). Returns an error if it finds |
344 |
| - // a mismatch. |
345 |
| - // |
346 |
| - // This function leaves the kernel context argument (the first argument of a |
347 |
| - // kernel function) uninitialized. It will be initialized in the `Execute` |
348 |
| - // function right before the actual execution. |
349 |
| - Error InitializeCallFrame(ArgumentsRef arguments, CallFrame* call_frame, |
350 |
| - bool verify_arguments = true) const; |
351 |
| - |
352 |
| - // Converts returned values owned by the call frame using provided result |
353 |
| - // converter. If compiled function execution finished with an error (error |
354 |
| - // flag is `true` in the call frame) returns error for all results. |
355 |
| - Error ReturnResults(const ResultConverter& results, |
356 |
| - CallFrame* call_frame) const; |
357 |
| - |
358 |
| - // Executes compiled function with given arguments. |
359 |
| - // |
360 |
| - // If `verify_arguments` is true (in debug mode it's always on, independent of |
361 |
| - // the argument value) this function also verifies that arguments passed at |
362 |
| - // run time match the executable entrypoint signature. If some of the |
363 |
| - // arguments do not match the expected type, this function allocates error |
364 |
| - // async values for all results and returns an error. |
365 |
| - // |
366 |
| - // Returns compiled function results via the user-provided results converter. |
367 |
| - // If execution completed in the error state, returns error for all results. |
368 |
| - Error Execute(ArgumentsRef arguments, const ResultConverter& results, |
369 |
| - const ExecuteOpts& opts, bool verify_arguments = true) const; |
370 |
| - |
371 |
| - // Executes compiled function using user provided call frame. |
372 |
| - // |
373 |
| - // It is the caller's responsibility to handle the compiled function results |
374 |
| - // stored in the call frame. |
375 |
| - void Execute(CallFrame& call_frame, const ExecuteOpts& opts) const; |
376 |
| - |
377 |
| - bool IsAsync() const { return results_memory_layout_.has_async_results; } |
378 |
| - |
379 |
| - llvm::StringRef name() const { return name_; } |
380 |
| - |
381 |
| - Optional<size_t> specialization() const { return specialization_; } |
382 |
| - |
383 |
| - // Returns the number of results in the runtime signature. |
384 |
| - unsigned num_results() const; |
385 |
| - |
386 |
| - // Signature of the compiled module entrypoint function before lowering to |
387 |
| - // the runtime dialects. See JitExecutable's `signature_` for more details. |
388 |
| - const FunctionType& signature() const; |
389 |
| - |
390 |
| - // Signature of the compiled module entrypoint function after lowering it from |
391 |
| - // high level dialects to the dialects supported by the jitrt runtime. |
392 |
| - // See JitExecutable's `signature_` for more details. |
393 |
| - const FunctionType& runtime_signature() const; |
394 |
| - |
395 |
| - std::chrono::milliseconds time_to_compile() const; |
396 |
| - |
397 |
| - // Get the object file behind this executable (on linux for example, it will |
398 |
| - // be https://en.wikipedia.org/wiki/Executable_and_Linkable_Format |
399 |
| - // executable). Can be null. |
400 |
| - std::unique_ptr<llvm::MemoryBuffer> obj_file() const; |
401 |
| - |
402 |
| - // CallFrame provides a pointer-stable storage for packed function arguments |
403 |
| - // and storage for returned values. |
404 |
| - struct CallFrame { |
405 |
| - // Pointers to compiled kernel arguments. |
406 |
| - llvm::SmallVector<void*, 32> args; |
407 |
| - |
408 |
| - // We use single block of memory to store compiled kernel results. We need |
409 |
| - // to be able to store pointers to async values and tokens, and strided |
410 |
| - // memrefs which at runtime are represented as StridedMemrefType<T, rank>. |
411 |
| - // |
412 |
| - // Currently we only need to provide result storage for pointers and memref |
413 |
| - // sizes and strides (int64_t type). If we'll need to support more complex |
414 |
| - // return types we'll have to be more careful about alignment requirements. |
415 |
| - static_assert(sizeof(uintptr_t) == sizeof(int64_t), |
416 |
| - "uintptr_t size must be the same as int64_t"); |
417 |
| - |
418 |
| - // Memory where the compiled kernel will write its results. |
419 |
| - llvm::SmallVector<uint8_t, 128> results; |
420 |
| - |
421 |
| - // Tracks whether any of the outputs were set. |
422 |
| - bool has_set_outputs = false; |
423 |
| - |
424 |
| - // Indicates whether the kernel function execution finished with an error. |
425 |
| - bool is_error = false; |
426 |
| - |
427 |
| - // The error message which is available only if `is_error` is true. The |
428 |
| - // assumption is that the error message string is owned by the compiled |
429 |
| - // binary and the call frame can safely keep a non-owning pointer. |
430 |
| - llvm::StringRef error; |
431 |
| - }; |
432 |
| - |
433 |
| - // Requirements for passing arguments to the compiled function. |
434 |
| - struct ArgumentsMemoryLayout { |
435 |
| - // Currently we always pass arguments as an array of pointers. |
436 |
| - size_t num_args_ptrs = 0; |
437 |
| - }; |
438 |
| - |
439 |
| - // Requirements for the contiguous block of memory to store compiled function |
440 |
| - // results. When we invoke a compiled function we allocate a block of memory, |
441 |
| - // and pass pointers to pre-computed offsets as output arguments to the |
442 |
| - // function. |
443 |
| - struct ResultsMemoryLayout { |
444 |
| - bool has_async_results = false; // true iff returns async results |
445 |
| - size_t size = 0; // number of bytes required |
446 |
| - llvm::SmallVector<size_t> offsets; // offsets in the block of memory |
447 |
| - }; |
448 |
| - |
449 |
| - // Options for configuring compiled kernel execution. |
450 |
| - struct ExecuteOpts { |
451 |
| - // Async task runner for executing async runtime tasks. Typically it |
452 |
| - // schedules async tasks into the underlying thread pool. It's the caller's |
453 |
| - // responsibility to guarantee that it will outlive the execution of all |
454 |
| - // async tasks started by the executable. |
455 |
| - AsyncTaskRunner* async_task_runner = nullptr; |
456 |
| - |
457 |
| - // A container for passing arbitrary user-provided data to the custom call |
458 |
| - // handlers. Must outlive all async tasks launched by this executable. |
459 |
| - CustomCall::UserData* custom_call_data = nullptr; |
460 |
| - |
461 |
| - // Diagnostic engine is responsible for passing runtime diagnostics back |
462 |
| - // to the caller through the diagnostic handler. |
463 |
| - DiagnosticEngine* diagnostic_engine = nullptr; |
464 |
| - }; |
465 |
| - |
466 |
| - // Loads executable from an object file. It is the caller's responsibility to |
467 |
| - // guarantee that signatures do match the compiled function in the object |
468 |
| - // file, otherwise it will surely lead to a crash. |
469 |
| - static Expected<Executable> LoadFromObjFile( |
470 |
| - llvm::StringRef name, std::unique_ptr<llvm::MemoryBuffer> obj_file, |
471 |
| - llvm::StringRef entrypoint, FunctionType signature, |
472 |
| - FunctionType runtime_signature, |
473 |
| - ExecutionEngine::SymbolsBinding runtime_symbol_map = {}, |
474 |
| - llvm::StringRef memory_region_name = ""); |
475 |
| - |
476 |
| - // Verifies that all operand types in the entrypoint function signature are |
477
| - // supported at run time. Returns a pre-computed layout for the function |
478 |
| - // arguments. If some arguments are not supported returns an error. |
479 |
| - static Expected<ArgumentsMemoryLayout> GetArgumentsMemoryLayout( |
480 |
| - const FunctionType& signature); |
481 |
| - |
482 |
| - // Verifies that all result types in the entrypoint function signature are |
483
| - // supported at run time. Returns a pre-computed layout for the function |
484 |
| - // results. If some results are not supported returns an error. |
485 |
| - static Expected<ResultsMemoryLayout> GetResultsMemoryLayout( |
486 |
| - const FunctionType& signature); |
487 |
| - |
488 |
| - // TODO(ezhulenev): The following three functions should be decoupled from |
489 |
| - // the jitrt header file (maybe move them to runtime.h?) so that custom call |
490 |
| - // implementations do not have to depend on the `jitrt` target. |
491 |
| - |
492 |
| - // Returns the user data passed via the ExecuteOpts to the compiled kernel. |
493 |
| - static CustomCall::UserData* GetUserData(xla::runtime::KernelContext* ctx); |
494 |
| - |
495 |
| - // Returns the diagnostic engine passed via the ExecuteOpts to the compiled |
496 |
| - // kernel. |
497 |
| - static DiagnosticEngine* GetDiagnosticEngine( |
498 |
| - xla::runtime::KernelContext* ctx); |
499 |
| - |
500 |
| - // Calls the custom call handler with the given runtime context, arguments and |
501 |
| - // attributes. |
502 |
| - static mlir::LogicalResult Call(xla::runtime::KernelContext* ctx, |
503 |
| - CustomCall& call, void** args, void** attrs); |
504 |
| - |
505 |
| - private: |
506 |
| - friend class internal::JitCompilationContext; |
507 |
| - |
508 |
| - Executable(llvm::StringRef name, |
509 |
| - std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper, |
510 |
| - std::unique_ptr<ExecutionEngine> engine, FunctionType signature, |
511 |
| - FunctionType runtime_signature, |
512 |
| - ArgumentsMemoryLayout arguments_memory_layout, |
513 |
| - ResultsMemoryLayout results_memory_layout, |
514 |
| - Optional<size_t> specialization, |
515 |
| - std::chrono::milliseconds time_to_compile) |
516 |
| - : name_(name.str()), |
517 |
| - memory_mapper_(std::move(memory_mapper)), |
518 |
| - engine_(std::move(engine)), |
519 |
| - fptr_(engine_->entrypoint()), |
520 |
| - signature_(std::move(signature)), |
521 |
| - runtime_signature_(std::move(runtime_signature)), |
522 |
| - arguments_memory_layout_(std::move(arguments_memory_layout)), |
523 |
| - results_memory_layout_(std::move(results_memory_layout)), |
524 |
| - specialization_(specialization), |
525 |
| - time_to_compile_(time_to_compile) { |
526 |
| - assert(fptr_ != nullptr && "kernel function must be not null"); |
527 |
| - } |
528 |
| - |
529 |
| - std::string name_; // name of the compiled kernel module |
530 |
| - |
531 |
| - // Called by `engine_`'s destructor; must appear before it. |
532 |
| - std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper_; // optional |
533 |
| - |
534 |
| - // JitRt execution engine owns the LLVM ORC jit compilation stack. |
535 |
| - std::unique_ptr<ExecutionEngine> engine_; |
536 |
| - |
537 |
| - // Compiled function owned by the `engine_`. |
538 |
| - ExecutionEngine::EntrypointFunctionPtr fptr_; |
539 |
| - |
540 |
| - // Signature of the compiled module entrypoint function before lowering to |
541 |
| - // the runtime dialects (see JitExecutable `signature_` for more details). |
542 |
| - FunctionType signature_; |
543 |
| - |
544 |
| - // Signature of the compiled module entrypoint function after lowering it from |
545 |
| - // high level dialects to the dialects supported by the jitrt runtime. |
546 |
| - // |
547 |
| - // - Operands and results types converted to the types with well-defined ABI |
548 |
| - // (e.g. tensors converted to memrefs). |
549 |
| - // |
550 |
| - // - First argument is always a kernel context added to the function by the |
551 |
| - // lowering pipeline. |
552 |
| - // |
553 |
| - // From this signature executable infers how to pack runtime operands |
554 |
| - // according to the expected memory layout, and how to convert results |
555 |
| - // returned from the JIT-compiled function into high level types (e.g. how to |
556 |
| - // convert StridedMemrefType into Tensorflow Tensor). |
557 |
| - // |
558 |
| - // To infer the type of the returned value, executable looks at the type |
559 |
| - // defined by the `runtime_signature_` to get the memory layout of the |
560 |
| - // returned value, and at the type defined by the `signature_` to get the type |
561 |
| - // expected by the runtime. |
562 |
| - FunctionType runtime_signature_; |
563 |
| - |
564 |
| - ArgumentsMemoryLayout arguments_memory_layout_; |
565 |
| - ResultsMemoryLayout results_memory_layout_; |
566 |
| - |
567 |
| - // Specialization id if this executable is a specialization, or an empty |
568 |
| - // optional if this executable is a default one. |
569 |
| - Optional<size_t> specialization_; |
570 |
| - // The time it took to compile this binary. |
571 |
| - std::chrono::milliseconds time_to_compile_; |
572 |
| -}; |
573 |
| - |
574 | 312 | //----------------------------------------------------------------------------//
|
575 | 313 | // JitExecutable to manage multiple compiled executables.
|
576 | 314 | //----------------------------------------------------------------------------//
|
|
0 commit comments