Skip to content

Commit 0c73c32

Browse files
author
Mikhail Zolotukhin
committed
[RFC] Add LoopNest class that implements Schedule's API in a different way.
LoopNest is my attempt to simplify our core abstraction. The main idea behind this change is to merge two classes: `TensorExprNode` and `For` (derived from `Stmt`). Currently they represent basically the same thing, but in a slightly different way. `TensorExprNode` attaches some metadata and provides a different way of traversing through siblings/parents/children. `For` represents the same structure, but without any metadata. Once a kernel is lowered to `For` statements, they are immediately consumed by a codegen, which lowers them to LLVMIR or prints as a CUDA string. This PR adds some functionality to `For` statements (and to other types of statements as well) and implements `SplitWithTail` and `ComputeInline` using only those. The implementation is just a proof of concept: it doesn't cover all corner cases, but they are trivial to add. As a demo, I added a test where we create a simple tensor-expression, then split one of the axes and then lower it to a Stmt. The demo shows that we're producing exactly the same result. 
For the reference, below is the output of the test (Root stmt - produced by the new implementation, Ref stmt - the product of the existing one): ``` [ RUN ] TensorExprTest.LoopNest_LLVM Root stmt: for (int n = 0; n < N; n++) { for (int i = 0; i < 1024; i++) { for (int j_outer = 0; j_outer < ((256 - 0) / 17); j_outer++) { for (int j_inner = 0; j_inner < 17; j_inner++) { g[(((n * (1024 * 256)) + (i * 256)) + (((j_outer * 17) + j_inner) * 1))] = (((A[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))] + B[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]) + C[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]) + D[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]); } } for (int j_tail = 0; j_tail < ((256 - 0) % 17); j_tail++) { g[(((n * (1024 * 256)) + (i * 256)) + ((j_tail + (((256 - 0) / 17) * 17)) * 1))] = (((A[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))] + B[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]) + C[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]) + D[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]); } } } Ref stmt: for (int n = 0; n < N; n++) { for (int i = 0; i < 1024; i++) { for (int j_outer = 0; j_outer < ((256 - 0) / 17); j_outer++) { for (int j_inner = 0; j_inner < 17; j_inner++) { g[(((n * (1024 * 256)) + (i * 256)) + (((j_outer * 17) + j_inner) * 1))] = (((A[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))] + B[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]) + C[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]) + D[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + ((j_outer * 17) + j_inner))]); } } for (int j_tail = 0; j_tail < ((256 - 0) % 17); j_tail++) { g[(((n * (1024 * 256)) + (i * 256)) + ((j_tail + (((256 
- 0) / 17) * 17)) * 1))] = (((A[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))] + B[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]) + C[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]) + D[(((n * ((1 * 256) * 1024)) + (i * (1 * 256))) + (j_tail + (((256 - 0) / 17) * 17)))]); } } } [ OK ] TensorExprTest.LoopNest_LLVM (3 ms) ```
1 parent af20070 commit 0c73c32

File tree

5 files changed

+223
-4
lines changed

5 files changed

+223
-4
lines changed

test/cpp/tensorexpr/test_schedule.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,5 +545,67 @@ void testScheduleDynamicShape2D() {
545545
testWithSize(37, 11);
546546
}
547547

548+
void testLoopNest() {
549+
550+
KernelScope kernel_scope;
551+
const int kVectorSize = 8;
552+
const int kVectorCount = 128;
553+
const int kSize1 = 1024;
554+
const int kSize2 = 256;
555+
556+
VarHandle n("N", kHandle);
557+
Buffer a(VarHandle("A", kHandle), kFloat, {n, ExprHandle(kSize1), ExprHandle(kSize2)});
558+
Buffer b(VarHandle("B", kHandle), kFloat, {n, ExprHandle(kSize1), ExprHandle(kSize2)});
559+
Buffer c(VarHandle("C", kHandle), kFloat, {n, ExprHandle(kSize1), ExprHandle(kSize2)});
560+
Buffer d(VarHandle("D", kHandle), kFloat, {n, ExprHandle(kSize1), ExprHandle(kSize2)});
561+
562+
Tensor* e = Compute(
563+
"e",
564+
{{n, "n"}, {kSize1, "i"}, {kSize2, "j"}},
565+
[&](const VarHandle& n, const VarHandle& i, const VarHandle& j) {
566+
return a(n, i, j) + b(n, i, j);
567+
});
568+
Tensor* f = Compute(
569+
"f",
570+
{{n, "n"}, {kSize1, "i"}, {kSize2, "j"}},
571+
[&](const VarHandle& n, const VarHandle& i, const VarHandle& j) {
572+
return (*e)(n, i, j) + c(n, i, j);
573+
});
574+
Tensor* g = Compute(
575+
"g",
576+
{{n, "n"}, {kSize1, "i"}, {kSize2, "j"}},
577+
[&](const VarHandle& n, const VarHandle& i, const VarHandle& j) {
578+
return (*f)(n, i, j) + d(n, i, j);
579+
});
580+
581+
582+
// NEW API:
583+
{
584+
LoopNest l({e, f, g});
585+
l.ComputeInline(l.getLoopBodyFor(e));
586+
l.ComputeInline(l.getLoopBodyFor(f));
587+
std::vector<Stmt*> loops =
588+
l.getLoopStmtsFor(g); // gives a list of loops from outer to inner
589+
Stmt *j_outer, *j_inner, *j_tail;
590+
l.SplitWithTail(loops[2], 17, &j_outer, &j_inner, &j_tail);
591+
l.ApplyInlines();
592+
std::cerr << "Root stmt:\n" << *l.root_stmt();
593+
}
594+
595+
// CURRENT API:
596+
{
597+
Schedule sch({g});
598+
e->ComputeInline();
599+
f->ComputeInline();
600+
VarHandle j(g->function()->arg(2));
601+
VarHandle j_outer, j_inner, j_tail;
602+
TensorOperation* tail_op;
603+
g->SplitWithTail(j, 17, true, &j_outer, &j_inner, &j_tail, &tail_op);
604+
Stmt* s = sch.Lower();
605+
std::cerr << "Ref stmt:\n" << *s;
606+
}
607+
// Produced Stmts are identical in both Current and New APIs
608+
}
609+
548610
} // namespace jit
549611
} // namespace torch

test/cpp/tensorexpr/tests.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ namespace jit {
160160
_(LLVMBindDynamicShapeAdd) \
161161
_(LLVMTensorDynamicShapeAdd) \
162162
_(LLVMDynamicShape2D) \
163-
_(LLVMIfThenElseTest)
163+
_(LLVMIfThenElseTest) \
164+
_(LoopNest)
164165

165166
#define TH_FORALL_TESTS_CUDA(_) \
166167
_(CudaTestVectorAdd01) \

torch/csrc/jit/tensorexpr/schedule.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,109 @@ LoopAxis* LoopAxisTransform::NewAxis(
873873
return axis;
874874
}
875875

876+
// XXX
877+
LoopNest::LoopNest(const std::vector<Tensor*> tensors_to_compute) {
878+
std::vector<Tensor*> output_tensors(tensors_to_compute);
879+
880+
std::vector<Stmt*> loops;
881+
for (Tensor *t : tensors_to_compute) {
882+
Stmt* loop = LowerToStmt(t);
883+
loops.push_back(loop);
884+
}
885+
root_stmt_ = new Block(loops);
886+
}
887+
888+
// Lowers a single tensor to a nest of For statements wrapping its
// element statement, and records the stmt<->tensor correspondence so
// scheduling primitives can map statements back to tensors.
Stmt* LoopNest::LowerToStmt(Tensor* t) {
  Function* f = t->function();
  // TODO: Support multiple-output functions
  Stmt* stmt = f->ElementStmt(0);

  stmt_to_tensor_[stmt] = t;
  tensor_to_stmt_[t] = stmt;

  CHECK(f->ndim() >= 1);
  // Wrap the element statement in loops, innermost dimension first, so
  // that the first axis ends up outermost.
  for (size_t dim = f->ndim(); dim-- > 0;) {
    Range r(0, ExprHandle(f->dim(dim)));
    stmt = For::make(VarHandle(f->arg(dim)), r.start(), r.stop(), stmt);
  }
  return stmt;
}
906+
907+
// Marks the function whose loop body is `s` for inlining.  The actual
// rewrite is deferred until ApplyInlines().
void LoopNest::ComputeInline(Stmt* s) {
  // TODO: check if `s` is a body of a loop
  Tensor* t = stmt_to_tensor_.at(s);
  inlined_functions_.insert(t->function());
}
911+
912+
void LoopNest::ApplyInlines() {
913+
// TODO: check if `s` is a body of a loop
914+
std::vector<Function*> inlined_functions_vec(
915+
inlined_functions_.begin(), inlined_functions_.end());
916+
root_stmt_ = InjectInlines(root_stmt_, inlined_functions_vec);
917+
}
918+
919+
// Splits the For loop `s` by `factor` into an outer/inner pair plus a
// tail loop for the leftover iterations, replacing `s` in its parent
// Block and appending the tail loop after it.
//
// Out-parameters are filled in the order (outer, inner, tail).  This
// matches both the caller in test_schedule.cpp and the existing
// Schedule::SplitWithTail convention; the previous parameter order
// (inner, outer, tail) silently handed the inner loop to callers that
// passed &j_outer first.
void LoopNest::SplitWithTail(
    Stmt* s,
    int factor,
    Stmt** outer,
    Stmt** inner,
    Stmt** tail) {
  Block* p = dynamic_cast<Block*>(s->parent_);
  For* f = dynamic_cast<For*>(s);
  if (!f) {
    std::cerr << "Stmt is not a For loop!\n";
    return;
  }
  if (!p) {
    std::cerr << "Parent is not a Block!\n";
    return;
  }
  auto const& size = ExprHandle(f->stop()) - ExprHandle(f->start());
  auto const& split_count = size / factor;
  auto const& tail_size = size % factor;

  // TODO: handle a special case when the bounds are known and no tail loop is
  // needed.

  const std::string& loop_var_name = f->var()->name_hint();
  Dtype loop_var_dtype = f->var()->dtype();

  VarHandle i_inner(loop_var_name + "_inner", loop_var_dtype);
  VarHandle i_outer(loop_var_name + "_outer", loop_var_dtype);
  VarHandle i_tail(loop_var_name + "_tail", loop_var_dtype);

  // x -> x.outer * inner.size + x.inner
  auto combined_index1 = i_outer * factor + i_inner;
  // x -> x.tail + outer.size * inner.size
  auto combined_index2 = i_tail + split_count * factor;

  // Rewrite two copies of the body in terms of the new loop variables.
  Stmt* body_inner = Substitute(f->body(), {{f->var(), combined_index1}});
  Stmt* body_tail = Substitute(f->body(), {{f->var(), combined_index2}});

  *inner = For::make(i_inner, 0, factor, body_inner);
  *outer = For::make(i_outer, 0, split_count, *inner);
  *tail = For::make(i_tail, 0, tail_size, body_tail);

  // TODO: cleanup API for adding/removing statements
  p->replace_stmt(s, *outer);
  p->append_stmt(*tail);

  // TODO: record history of transformations
}
962+
963+
// Returns the For statements enclosing the computation of `t`, ordered
// from outermost to innermost.  Walks parent_ links from the loop body
// up to the root, then reverses the collected list.
std::vector<Stmt*> LoopNest::getLoopStmtsFor(Tensor* t) const {
  std::vector<Stmt*> loops;
  for (Stmt* cur = tensor_to_stmt_.at(t); cur; cur = cur->parent_) {
    if (dynamic_cast<For*>(cur)) {
      loops.push_back(cur);
    }
  }
  // Collected innermost-first; hand back outermost-first.
  return std::vector<Stmt*>(loops.rbegin(), loops.rend());
}
974+
975+
// Returns the innermost statement computing tensor `t` (the element
// statement recorded by LowerToStmt).  Throws std::out_of_range if `t`
// was not one of the tensors this LoopNest was built from.
Stmt* LoopNest::getLoopBodyFor(Tensor* t) const {
  return tensor_to_stmt_.at(t);
}
978+
876979
} // namespace schedule
877980
} // namespace tensorexpr
878981
} // namespace jit

torch/csrc/jit/tensorexpr/schedule.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,29 @@ class TORCH_API Schedule {
671671
ScheduleNode* node_ = nullptr;
672672
};
673673

674+
class TORCH_API LoopNest {
675+
public:
676+
LoopNest(const std::vector<Tensor*> tensors_to_compute);
677+
Stmt* root_stmt() const {
678+
return root_stmt_;
679+
}
680+
681+
std::vector<Stmt*> getLoopStmtsFor(Tensor*) const;
682+
Stmt* getLoopBodyFor(Tensor*) const;
683+
std::unordered_map<Tensor*, Stmt*> tensor_to_stmt_;
684+
685+
void ComputeInline(Stmt* s);
686+
void ApplyInlines();
687+
void SplitWithTail(Stmt *s, int factor, Stmt** inner, Stmt **outer, Stmt **tail);
688+
689+
private:
690+
Stmt* LowerToStmt(Tensor *t);
691+
692+
std::unordered_set<Function*> inlined_functions_;
693+
std::unordered_map<Stmt*, Tensor*> stmt_to_tensor_;
694+
Stmt* root_stmt_;
695+
};
696+
674697
} // namespace schedule
675698
} // namespace tensorexpr
676699
} // namespace jit

torch/csrc/jit/tensorexpr/stmt.h

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ class Stmt : public KernelScopedObject {
1616
Stmt() {}
1717
TORCH_API virtual void accept(IRVisitor* visitor) const = 0;
1818
virtual Stmt* accept_mutator(IRMutator* mutator) = 0;
19+
20+
Stmt* parent_ = nullptr;
1921
};
2022

2123
template <class Op>
@@ -84,9 +86,26 @@ class Block : public StmtNode<Block> {
8486
Stmt* stmt(int index) const {
8587
return stmts_[index];
8688
}
89+
// Appends `s` to the end of this block and claims its parent link.
// (Previously parent_ was left untouched, so statements appended after
// construction — e.g. the tail loop from SplitWithTail — kept a
// stale/null parent and could not be found by parent walks.)
void append_stmt(Stmt* s) {
  s->parent_ = this;
  stmts_.push_back(s);
}
92+
// Replaces the first occurrence of `old_stmt` with `new_stmt`, fixing
// up both statements' parent links (previously new_stmt->parent_ was
// never set, leaving it stale/null).  Returns false if `old_stmt` is
// not a direct child of this block.
bool replace_stmt(Stmt* old_stmt, Stmt* new_stmt) {
  for (size_t i = 0; i < stmts_.size(); i++) {
    if (stmts_[i] == old_stmt) {
      stmts_[i] = new_stmt;
      new_stmt->parent_ = this;
      old_stmt->parent_ = nullptr;
      return true;
    }
  }
  return false;
}
87101

102+
// Constructs a block from `stmts`, claiming each child's parent link so
// that transforms (e.g. LoopNest) can walk up the statement tree.
explicit Block(const std::vector<Stmt*>& stmts) : stmts_(stmts) {
  for (auto s : stmts) {
    s->parent_ = this;
  }
}
88107
private:
89-
explicit Block(const std::vector<Stmt*>& stmts) : stmts_(stmts) {}
108+
// TODO: change to a list to facilitate insertions and removals
90109
std::vector<Stmt*> stmts_;
91110
};
92111

@@ -358,8 +377,14 @@ class For : public StmtNode<For> {
358377
}
359378

360379
// Constructs a loop over [start, stop) with induction variable `var`.
// The body is normalized to a Block so that scheduling transforms can
// splice statements in and out uniformly, and the block's parent link
// is set to this loop.
For(const Var* var, const Expr* start, const Expr* stop, Stmt* body)
    : var_(var), start_(start), stop_(stop) {
  CHECK(var && start && stop && body);
  // Wrap a bare statement in a single-element Block if needed.
  Block *b = dynamic_cast<Block*>(body);
  if (!b) {
    b = new Block({body});
  }
  body_ = b;
  body_->parent_ = this;
}
364389

365390
For(const Var* var,
@@ -370,9 +395,14 @@ class For : public StmtNode<For> {
370395
: var_(var),
371396
start_(start),
372397
stop_(stop),
373-
body_(body),
374398
loop_options_(loop_options) {
375399
CHECK(var && start && stop && body);
400+
Block *b = dynamic_cast<Block*>(body);
401+
if (!b) {
402+
b = new Block({body});
403+
}
404+
body_ = b;
405+
body_->parent_ = this;
376406
}
377407

378408
private:

0 commit comments

Comments
 (0)