
Commit f791fd3

zheng-xq authored and Mikhail Zolotukhin committed
Add end-to-end support and a PyTorch fuser example on CudaCodeGen (pytorch#104)
1 parent 3d9c42b commit f791fd3

12 files changed (+229, -71 lines)

test/test_tensorexpr.py

Lines changed: 20 additions & 0 deletions
@@ -15,6 +15,26 @@ def easy(x, y):
     np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())
 
 
+# TODO: combine this with the test_easy
+def test_easy_cuda():
+    if not torch.cuda.is_available():
+        return
+
+    def easy(x, y):
+        aaa = torch.add(x, y)
+        return aaa
+
+    traced = torch.jit.trace(easy, (torch.rand(32, 16, device='cuda'), torch.rand(32, 16, device='cuda')))
+
+    a = torch.rand(32, 16, device='cuda')
+    b = torch.rand(32, 16, device='cuda')
+    x = traced(a, b)
+    a_cpu = a.cpu()
+    b_cpu = b.cpu()
+    x_cpu = x.cpu()
+    np.testing.assert_allclose(a_cpu.numpy() + b_cpu.numpy(), x_cpu.numpy())
+
+
 def test_three_arg():
     def easy(x, y, z):
         aaa = torch.add(x, y)

torch/csrc/jit/passes/tensorexpr_fuser.cpp

Lines changed: 139 additions & 53 deletions
@@ -8,6 +8,7 @@
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/tensorexpr/buffer.h>
+#include <torch/csrc/jit/tensorexpr/cuda_codegen.h>
 #include <torch/csrc/jit/tensorexpr/eval.h>
 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
 #include <torch/csrc/jit/tensorexpr/schedule.h>
@@ -303,13 +304,23 @@ std::vector<Expr> computeIndicesToBroadcast(
   return bcast;
 }
 
-struct TensorExprKernel {
-  std::vector<Buffer> buffer_args;
-  std::vector<Tensor> tensor_outputs;
-  std::unordered_map<int64_t, Tensor> tensors;
-  std::unique_ptr<CodeGen> codegen;
-  KernelArena kernel_arena;
-
+class TensorExprKernel {
+ private:
+  enum BackendType {
+    kUninitialized,
+    kSimpleIREval,
+    kLLVMCodeGen,
+    kCudaCodeGen,
+  };
+  std::vector<Buffer> buffer_args_;
+  std::vector<Tensor> tensor_outputs_;
+  std::unordered_map<int64_t, Tensor> tensors_;
+  std::unique_ptr<CodeGen> codegen_;
+  KernelArena kernel_arena_;
+  BackendType backend_type_ = BackendType::kUninitialized;
+  at::Device device_ = at::kCPU;
+
+ private:
   Expr constant(torch::jit::Value* v) {
     if (v->node()->kind() == prim::Constant) {
       const auto val = toIValue(v).value();
@@ -332,8 +343,12 @@ struct TensorExprKernel {
   }
 
   template <typename T>
-  Expr chunk(const T& t, size_t chunk_idx, size_t dim, size_t chunks,
-             const std::vector<Var>& axes) {
+  Expr chunk(
+      const T& t,
+      size_t chunk_idx,
+      size_t dim,
+      size_t chunks,
+      const std::vector<Var>& axes) {
     auto sizes = bufferSizes(t);
     size_t step = sizes[dim] / chunks;
 
@@ -375,8 +390,8 @@ struct TensorExprKernel {
   }
 
   Expr tensorOrConstant(torch::jit::Value* v, const std::vector<Var>& axes) {
-    auto ti = tensors.find(v->unique());
-    if (ti != tensors.end()) {
+    auto ti = tensors_.find(v->unique());
+    if (ti != tensors_.end()) {
       return broadcast(ti->second, axes);
     }
     return constant(v);
@@ -699,22 +714,115 @@ struct TensorExprKernel {
     }
   }
 
+  void LowerToBackend(BackendType backend_type) {
+    torch::jit::tensorexpr::schedule::Schedule sch(tensor_outputs_);
+
+    // Compute non-output tensors_ inline
+    for (auto& p : tensors_) {
+      p.second.ComputeInline();
+    }
+    if (backend_type == kCudaCodeGen) {
+      for (auto& output : tensor_outputs_) {
+        // TODO: implement the universal fused dispatching config.
+        if (output.args().size() < 2) {
+          throw std::runtime_error(
+              "Only tensors with more than 2D is supported in CudaCodeGen");
+        }
+        Var x = output.arg(0);
+        Var y = output.arg(1);
+        output.GPUExecConfig({x}, {y});
+      }
+    }
+
+    Stmt stmt = sch.Lower();
+
+    // Set up formal params (inputs, then outputs) for kernel.
+    std::vector<CodeGen::BufferArg> params(
+        buffer_args_.begin(), buffer_args_.end());
+    for (auto& o : tensor_outputs_) {
+      params.push_back(o);
+    }
+
+    // Generate code.
+    switch (backend_type_) {
+      case kCudaCodeGen:
+        codegen_ = std::make_unique<CudaCodeGen>(stmt, params);
+        break;
+      case kLLVMCodeGen:
+        codegen_ = std::make_unique<LLVMCodeGen>(stmt, params);
+        break;
+      case kSimpleIREval:
+        codegen_ = std::make_unique<SimpleIREvaluator>(stmt, params);
+        break;
+      default:
+        throw std::runtime_error("invalid backend type");
+    }
+  }
+
+  void PickAndCheckBackendType(const at::ArrayRef<IValue>& inputs) {
+    at::Device device = inputs[0].toTensor().device();
+    BackendType backend_type = BackendType::kUninitialized;
+    if (device.type() == at::kCUDA) {
+      backend_type = kCudaCodeGen;
+    } else if (device.type() == at::kCPU) {
+#ifdef ENABLE_LLVM
+      backend_type = kLLVMCodeGen;
+#else
+      backend_type = kSimpleIREval;
+      ;
+#endif
+    } else {
+      throw std::runtime_error("Invalid device type");
+    }
+
+    if (backend_type_ == kUninitialized) {
+      backend_type_ = backend_type;
+      device_ = device;
+      LowerToBackend(backend_type);
+    } else if (backend_type_ != backend_type) {
+      // TODO: if we have to support muliptole backends with the same subgraph,
+      // we need to add kernel caching.
+      throw std::runtime_error(
+          "Inconsistent backend_type: " + std::to_string(backend_type_) +
+          " vs " + std::to_string(backend_type));
+    }
+  }
+
+  void CodeGenRun(const std::vector<CodeGen::CallArg>& run_args) {
+    if (backend_type_ == kCudaCodeGen || backend_type_ == kSimpleIREval) {
+      codegen_->call(run_args);
+    } else if (backend_type_ == kLLVMCodeGen) {
+      for (int i = 0; i < buffer_args_.size(); i++) {
+        codegen_->bind(buffer_args_[i], run_args[i]);
+      }
+      int offset = buffer_args_.size();
+      for (int i = 0; i < tensor_outputs_.size(); i++) {
+        codegen_->bind(tensor_outputs_[i], run_args[i + offset]);
+      }
+      codegen_->run();
+    } else {
+      throw std::runtime_error(
+          "Invalid backend type: " + std::to_string(backend_type_));
+    }
+  }
+
+ public:
   explicit TensorExprKernel(const Node* node) {
-    KernelScope kernel_scope(kernel_arena);
+    KernelScope kernel_scope(kernel_arena_);
     auto subgraph = node->g(attr::Subgraph);
 
     // Bind inputs to buffers.
     for (auto const& input : subgraph->inputs()) {
       Buffer in_buffer = texprBuffer(input);
-      tensors.emplace(
+      tensors_.emplace(
           input->unique(),
           Compute(
              "input",
              texprDims(input),
              [this, in_buffer](const std::vector<Var>& axes) {
                return broadcast(in_buffer, axes);
              }));
-      buffer_args.push_back(std::move(in_buffer));
+      buffer_args_.push_back(std::move(in_buffer));
     }
 
     // Bind nodes to tensor compute expressions.
@@ -730,58 +838,36 @@ struct TensorExprKernel {
      }
    }
 
-    // Move output operands from `tensors` to `tensor_outputs`
+    // Move output operands from `tensors_` to `tensor_outputs_`
    for (const auto& output : subgraph->outputs()) {
-      CHECK(tensors.count(output->unique())) << "Output must be a tensor";
-      tensor_outputs.emplace_back(tensors.at(output->unique()));
-      tensors.erase(output->unique());
+      CHECK(tensors_.count(output->unique())) << "Output must be a tensor";
+      tensor_outputs_.emplace_back(tensors_.at(output->unique()));
+      tensors_.erase(output->unique());
    }
-
-    torch::jit::tensorexpr::schedule::Schedule sch(tensor_outputs);
-
-    // Compute non-output tensors inline
-    for (auto& p : tensors) {
-      p.second.ComputeInline();
-    }
-    Stmt stmt = sch.Lower();
-
-#if TX_DEBUG
-    std::cerr << stmt << "\n";
-#endif
-
-#ifdef ENABLE_LLVM
-    // Set up formal params (inputs, then outputs) for kernel.
-    std::vector<CodeGen::BufferArg> params(
-        buffer_args.begin(), buffer_args.end());
-    for (auto& o : tensor_outputs) {
-      params.push_back(o);
-    }
-
-    // Generate code.
-    codegen = std::make_unique<LLVMCodeGen>(stmt, params);
-#else
-    codegen = std::make_unique<SimpleIREvaluator>(stmt);
-#endif
   }
 
   void run(Stack& stack) {
-    KernelScope kernel_scope(kernel_arena);
+    KernelScope kernel_scope(kernel_arena_);
     // Set up arguments (inputs, then outputs) for kernel call.
-    auto inputs = last(stack, buffer_args.size());
-    for (int i = 0; i < buffer_args.size(); i++) {
-      codegen->bind(buffer_args[i], inputs[i].toTensor().data_ptr());
+    auto inputs = last(stack, buffer_args_.size());
+    PickAndCheckBackendType(inputs);
+
+    std::vector<CodeGen::CallArg> run_args;
+    for (int i = 0; i < buffer_args_.size(); i++) {
+      run_args.push_back(inputs[i].toTensor().data_ptr());
     }
     std::vector<at::Tensor> outputs;
-    for (auto& o : tensor_outputs) {
-      outputs.push_back(at::empty(bufferSizes(o), tensorType(o)));
-      codegen->bind(o, outputs.back().data_ptr());
+    for (auto& o : tensor_outputs_) {
+      outputs.push_back(at::empty(
+          bufferSizes(o), c10::TensorOptions(tensorType(o)).device(device_)));
+      run_args.push_back(outputs.back().data_ptr());
    }
 
    // Call the kernel.
-    codegen->run();
+    CodeGenRun(run_args);
 
    // Update the stack.
-    drop(stack, buffer_args.size());
+    drop(stack, buffer_args_.size());
    for (auto& o : outputs) {
      push_one(stack, std::move(o));
    }
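Taken together, the new run() path is: pick a backend from the device of the first runtime input, lower the fused subgraph once for that backend, then hand the generated kernel a flat list of data pointers. The device-to-backend choice can be summarized by the standalone sketch below; it mirrors PickAndCheckBackendType from the diff above, but the free function pickBackend and the bare enum are illustrative only and not part of the commit.

#include <ATen/ATen.h>
#include <stdexcept>

// Illustrative stand-in for TensorExprKernel::BackendType in the diff above.
enum BackendType { kUninitialized, kSimpleIREval, kLLVMCodeGen, kCudaCodeGen };

// Mirrors PickAndCheckBackendType: CUDA inputs go to the new CudaCodeGen,
// CPU inputs go to LLVMCodeGen when it is compiled in, and otherwise fall
// back to the SimpleIREvaluator interpreter.
BackendType pickBackend(const at::Device& device) {
  if (device.type() == at::kCUDA) {
    return kCudaCodeGen;
  }
  if (device.type() == at::kCPU) {
#ifdef ENABLE_LLVM
    return kLLVMCodeGen;
#else
    return kSimpleIREval;
#endif
  }
  throw std::runtime_error("Invalid device type");
}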

torch/csrc/jit/tensorexpr/codegen.h

Lines changed: 12 additions & 0 deletions
@@ -21,11 +21,19 @@ class CodeGen {
       : ir_node_(const_cast<BaseStmtNode*>(stmt.node())),
         buffer_args_({BufferArg(ts)...}) {}
 
+  CodeGen(const Stmt& stmt, const std::vector<BufferArg>& buffer_args)
+      : ir_node_(const_cast<BaseStmtNode*>(stmt.node())),
+        buffer_args_(buffer_args) {}
+
   template <typename... Ts>
   CodeGen(const Expr& expr, Ts... ts)
       : ir_node_(const_cast<BaseExprNode*>(expr.node())),
         buffer_args_({BufferArg(ts)...}) {}
 
+  CodeGen(const Expr& expr, const std::vector<BufferArg>& buffer_args)
+      : ir_node_(const_cast<BaseExprNode*>(expr.node())),
+        buffer_args_(buffer_args) {}
+
   CodeGen(const IRNode* node) : ir_node_(const_cast<IRNode*>(node)) {}
 
   virtual ~CodeGen() {}
@@ -54,6 +62,10 @@ class CodeGen {
     LOG(FATAL) << "Unimplemented interface";
   }
 
+  TORCH_API virtual void call(const std::vector<CallArg>& args) {
+    LOG(FATAL) << "unimplemented call";
+  }
+
 private:
  IRNode* ir_node_ = nullptr;
  std::vector<BufferArg> buffer_args_;
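The two vector-based constructors and the virtual call() give every backend a uniform contract: a codegen object is constructed from a lowered Stmt plus the ordered BufferArgs (inputs first, then outputs), and invoked with a matching vector of CallArgs carrying the runtime pointers, which is how the fuser drives CudaCodeGen and SimpleIREvaluator above. Below is a minimal sketch of a backend written against this interface; MyBackend is hypothetical, and only the constructor overload and the call() override come from this header.

#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <vector>

using namespace torch::jit::tensorexpr;

// Hypothetical backend built on the new CodeGen interface: take the ordered
// BufferArgs at construction time and receive the runtime data pointers as a
// vector<CallArg> through the virtual call().
class MyBackend : public CodeGen {
 public:
  MyBackend(const Stmt& stmt, const std::vector<BufferArg>& buffer_args)
      : CodeGen(stmt, buffer_args) {}

  void call(const std::vector<CallArg>& args) override {
    // args[i] is the pointer for the i-th BufferArg, inputs followed by
    // outputs in the order the fuser pushed them. A real backend would bind
    // these and launch its generated kernel here.
    (void)args;  // illustrative stub; silences unused-parameter warnings
  }
};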

torch/csrc/jit/tensorexpr/cuda_codegen.cpp

Lines changed: 4 additions & 4 deletions
@@ -90,7 +90,7 @@ void CudaPrinter::visit(const For* v) {
   const LoopOptions& loop_options = v->loop_options();
   if (loop_options.is_gpu_block_index()) {
     ScopedVarName var_name(
-        name_manager_, v->var().node(), loop_options.gpu_block_index_str());
+        name_manager(), v->var().node(), loop_options.gpu_block_index_str());
     v->body().accept(this);
     int gpu_block_index = loop_options.gpu_block_index();
     if (gpu_block_extents_.size() <= gpu_block_index) {
@@ -104,7 +104,7 @@ void CudaPrinter::visit(const For* v) {
     gpu_block_extents_[gpu_block_index] = v->stop();
   } else if (loop_options.is_gpu_thread_index()) {
     ScopedVarName var_name(
-        name_manager_, v->var().node(), loop_options.gpu_thread_index_str());
+        name_manager(), v->var().node(), loop_options.gpu_thread_index_str());
     v->body().accept(this);
     int gpu_thread_index = loop_options.gpu_thread_index();
     if (gpu_thread_extents_.size() <= gpu_thread_index) {
@@ -122,7 +122,7 @@ void CudaPrinter::visit(const For* v) {
 }
 
 void CudaCodeGen::Initialize() {
-  printer_.reset(new CudaPrinter(&oss_, &name_manager_));
+  printer_.reset(new CudaPrinter(&oss_));
   // TODO: handle multiple kernels.
   // TODO: handle dynamic dimension.
   // TODO: call nvrtc.
@@ -135,7 +135,7 @@ void CudaCodeGen::Initialize() {
     const BufferArg& buffer_arg = buffer_args[i];
     const Var& var = buffer_arg.var();
     Dtype dtype = buffer_arg.dtype();
-    oss_ << dtype.ToCppString() << "* " << name_manager_.get_unique_name(var);
+    oss_ << dtype.ToCppString() << "* " << name_manager()->get_unique_name(var);
   }
   oss_ << ") {";
 
