 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/tensorexpr/buffer.h>
+#include <torch/csrc/jit/tensorexpr/cuda_codegen.h>
 #include <torch/csrc/jit/tensorexpr/eval.h>
 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
 #include <torch/csrc/jit/tensorexpr/schedule.h>
@@ -303,13 +304,23 @@ std::vector<Expr> computeIndicesToBroadcast(
   return bcast;
 }

-struct TensorExprKernel {
-  std::vector<Buffer> buffer_args;
-  std::vector<Tensor> tensor_outputs;
-  std::unordered_map<int64_t, Tensor> tensors;
-  std::unique_ptr<CodeGen> codegen;
-  KernelArena kernel_arena;
-
+class TensorExprKernel {
+ private:
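+  // Backend a fused subgraph is lowered to; selected from the device of the
+  // inputs on the first run.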
+  enum BackendType {
+    kUninitialized,
+    kSimpleIREval,
+    kLLVMCodeGen,
+    kCudaCodeGen,
+  };
+  std::vector<Buffer> buffer_args_;
+  std::vector<Tensor> tensor_outputs_;
+  std::unordered_map<int64_t, Tensor> tensors_;
+  std::unique_ptr<CodeGen> codegen_;
+  KernelArena kernel_arena_;
+  BackendType backend_type_ = BackendType::kUninitialized;
+  at::Device device_ = at::kCPU;
+
+ private:
   Expr constant(torch::jit::Value* v) {
     if (v->node()->kind() == prim::Constant) {
       const auto val = toIValue(v).value();
@@ -332,8 +343,12 @@ struct TensorExprKernel {
   }

   template <typename T>
-  Expr chunk(const T& t, size_t chunk_idx, size_t dim, size_t chunks,
-             const std::vector<Var>& axes) {
+  Expr chunk(
+      const T& t,
+      size_t chunk_idx,
+      size_t dim,
+      size_t chunks,
+      const std::vector<Var>& axes) {
     auto sizes = bufferSizes(t);
     size_t step = sizes[dim] / chunks;

@@ -375,8 +390,8 @@ struct TensorExprKernel {
   }

   Expr tensorOrConstant(torch::jit::Value* v, const std::vector<Var>& axes) {
-    auto ti = tensors.find(v->unique());
-    if (ti != tensors.end()) {
+    auto ti = tensors_.find(v->unique());
+    if (ti != tensors_.end()) {
       return broadcast(ti->second, axes);
     }
     return constant(v);
@@ -699,22 +714,115 @@ struct TensorExprKernel {
     }
   }

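+  // Lowers the fused subgraph for the chosen backend: inlines all non-output
+  // tensors, applies a GPU execution config for CUDA outputs, and constructs
+  // the codegen object from the lowered statement and kernel parameters.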
+  void LowerToBackend(BackendType backend_type) {
+    torch::jit::tensorexpr::schedule::Schedule sch(tensor_outputs_);
+
+    // Compute non-output tensors_ inline
+    for (auto& p : tensors_) {
+      p.second.ComputeInline();
+    }
+    if (backend_type == kCudaCodeGen) {
+      for (auto& output : tensor_outputs_) {
+        // TODO: implement the universal fused dispatching config.
+        if (output.args().size() < 2) {
+          throw std::runtime_error(
+              "Only tensors with 2 or more dims are supported in CudaCodeGen");
+        }
+        Var x = output.arg(0);
+        Var y = output.arg(1);
+        output.GPUExecConfig({x}, {y});
+      }
+    }
+
+    Stmt stmt = sch.Lower();
+
+    // Set up formal params (inputs, then outputs) for kernel.
+    std::vector<CodeGen::BufferArg> params(
+        buffer_args_.begin(), buffer_args_.end());
+    for (auto& o : tensor_outputs_) {
+      params.push_back(o);
+    }
+
+    // Generate code.
+    switch (backend_type_) {
+      case kCudaCodeGen:
+        codegen_ = std::make_unique<CudaCodeGen>(stmt, params);
+        break;
+      case kLLVMCodeGen:
+        codegen_ = std::make_unique<LLVMCodeGen>(stmt, params);
+        break;
+      case kSimpleIREval:
+        codegen_ = std::make_unique<SimpleIREvaluator>(stmt, params);
+        break;
+      default:
+        throw std::runtime_error("invalid backend type");
+    }
+  }
+
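+  // Derives the backend from the device of the first input tensor. On the
+  // first call this lowers the kernel; later calls must see the same backend.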
+  void PickAndCheckBackendType(const at::ArrayRef<IValue>& inputs) {
+    at::Device device = inputs[0].toTensor().device();
+    BackendType backend_type = BackendType::kUninitialized;
+    if (device.type() == at::kCUDA) {
+      backend_type = kCudaCodeGen;
+    } else if (device.type() == at::kCPU) {
+#ifdef ENABLE_LLVM
+      backend_type = kLLVMCodeGen;
+#else
+      backend_type = kSimpleIREval;
+#endif
+    } else {
+      throw std::runtime_error("Invalid device type");
+    }
+
+    if (backend_type_ == kUninitialized) {
+      backend_type_ = backend_type;
+      device_ = device;
+      LowerToBackend(backend_type);
+    } else if (backend_type_ != backend_type) {
+      // TODO: if we have to support multiple backends with the same subgraph,
+      // we need to add kernel caching.
+      throw std::runtime_error(
+          "Inconsistent backend_type: " + std::to_string(backend_type_) +
+          " vs " + std::to_string(backend_type));
+    }
+  }
+
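+  // Invokes the generated kernel. CUDA and SimpleIREval take the packed call
+  // args directly; LLVMCodeGen still binds each buffer before run().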
+  void CodeGenRun(const std::vector<CodeGen::CallArg>& run_args) {
+    if (backend_type_ == kCudaCodeGen || backend_type_ == kSimpleIREval) {
+      codegen_->call(run_args);
+    } else if (backend_type_ == kLLVMCodeGen) {
+      for (int i = 0; i < buffer_args_.size(); i++) {
+        codegen_->bind(buffer_args_[i], run_args[i]);
+      }
+      int offset = buffer_args_.size();
+      for (int i = 0; i < tensor_outputs_.size(); i++) {
+        codegen_->bind(tensor_outputs_[i], run_args[i + offset]);
+      }
+      codegen_->run();
+    } else {
+      throw std::runtime_error(
+          "Invalid backend type: " + std::to_string(backend_type_));
+    }
+  }
+
+ public:
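+  // Builds tensor expressions for the inputs and nodes of the fused subgraph.
+  // Lowering to a concrete backend is deferred until the first run().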
   explicit TensorExprKernel(const Node* node) {
-    KernelScope kernel_scope(kernel_arena);
+    KernelScope kernel_scope(kernel_arena_);
     auto subgraph = node->g(attr::Subgraph);

     // Bind inputs to buffers.
     for (auto const& input : subgraph->inputs()) {
       Buffer in_buffer = texprBuffer(input);
-      tensors.emplace(
+      tensors_.emplace(
           input->unique(),
           Compute(
               "input",
               texprDims(input),
               [this, in_buffer](const std::vector<Var>& axes) {
                 return broadcast(in_buffer, axes);
               }));
-      buffer_args.push_back(std::move(in_buffer));
+      buffer_args_.push_back(std::move(in_buffer));
     }

     // Bind nodes to tensor compute expressions.
@@ -730,58 +838,36 @@ struct TensorExprKernel {
       }
     }

-    // Move output operands from `tensors` to `tensor_outputs`
+    // Move output operands from `tensors_` to `tensor_outputs_`
     for (const auto& output : subgraph->outputs()) {
-      CHECK(tensors.count(output->unique())) << "Output must be a tensor";
-      tensor_outputs.emplace_back(tensors.at(output->unique()));
-      tensors.erase(output->unique());
+      CHECK(tensors_.count(output->unique())) << "Output must be a tensor";
+      tensor_outputs_.emplace_back(tensors_.at(output->unique()));
+      tensors_.erase(output->unique());
     }
-
-    torch::jit::tensorexpr::schedule::Schedule sch(tensor_outputs);
-
-    // Compute non-output tensors inline
-    for (auto& p : tensors) {
-      p.second.ComputeInline();
-    }
-    Stmt stmt = sch.Lower();
-
-#if TX_DEBUG
-    std::cerr << stmt << "\n";
-#endif
-
-#ifdef ENABLE_LLVM
-    // Set up formal params (inputs, then outputs) for kernel.
-    std::vector<CodeGen::BufferArg> params(
-        buffer_args.begin(), buffer_args.end());
-    for (auto& o : tensor_outputs) {
-      params.push_back(o);
-    }
-
-    // Generate code.
-    codegen = std::make_unique<LLVMCodeGen>(stmt, params);
-#else
-    codegen = std::make_unique<SimpleIREvaluator>(stmt);
-#endif
   }

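+  // Interpreter entry point: pops the inputs off the stack, runs the
+  // generated kernel, and pushes the freshly allocated outputs.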
   void run(Stack& stack) {
-    KernelScope kernel_scope(kernel_arena);
+    KernelScope kernel_scope(kernel_arena_);
     // Set up arguments (inputs, then outputs) for kernel call.
-    auto inputs = last(stack, buffer_args.size());
-    for (int i = 0; i < buffer_args.size(); i++) {
-      codegen->bind(buffer_args[i], inputs[i].toTensor().data_ptr());
+    auto inputs = last(stack, buffer_args_.size());
+    PickAndCheckBackendType(inputs);
+
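+    // Pack raw data pointers for the inputs; output pointers are appended
+    // after the outputs are allocated below.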
+    std::vector<CodeGen::CallArg> run_args;
+    for (int i = 0; i < buffer_args_.size(); i++) {
+      run_args.push_back(inputs[i].toTensor().data_ptr());
     }
     std::vector<at::Tensor> outputs;
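+    // Allocate each output on the device recorded when the backend was picked.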
-    for (auto& o : tensor_outputs) {
-      outputs.push_back(at::empty(bufferSizes(o), tensorType(o)));
-      codegen->bind(o, outputs.back().data_ptr());
+    for (auto& o : tensor_outputs_) {
+      outputs.push_back(at::empty(
+          bufferSizes(o), c10::TensorOptions(tensorType(o)).device(device_)));
+      run_args.push_back(outputs.back().data_ptr());
     }

     // Call the kernel.
-    codegen->run();
+    CodeGenRun(run_args);

     // Update the stack.
-    drop(stack, buffer_args.size());
+    drop(stack, buffer_args_.size());
     for (auto& o : outputs) {
       push_one(stack, std::move(o));
     }