
Commit 3d9c42b

Mikhail Zolotukhin authored and resistor committed
Add support for IfThenElse (pytorch#103)
1 parent a7e9127 commit 3d9c42b

15 files changed: 165 additions, 1 deletion
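
As a quick illustration, here is a minimal sketch distilled from the tests added below (assuming the tensorexpr headers and an active KernelScope): the new ifThenElse helper builds an expression that selects between two values based on an integer condition.

  KernelScope kernel_scope;
  // condition, then-value, else-value
  Expr v = ifThenElse(Expr(1), Expr(1.0f), Expr(2.0f));

  SimpleIREvaluator eval(v);
  eval();
  float result = eval.value().as<float>();  // 1.0f, since the condition is non-zero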

test/cpp/tensorexpr/test_expr.cpp

Lines changed: 26 additions & 0 deletions
@@ -324,5 +324,31 @@ void testCond01() {
   ExpectAllNear(a_v, a_ref, 1e-5);
 }
 
+void testIfThenElse01() {
+  KernelScope kernel_scope;
+  Expr v = ifThenElse(Expr(1), Expr(1.0f), Expr(2.0f));
+
+  std::ostringstream oss;
+  oss << v;
+  ASSERT_EQ(oss.str(), "IfThenElse(1, 1, 2)");
+
+  SimpleIREvaluator eval(v);
+  eval();
+  ASSERT_EQ(eval.value().as<float>(), 1.0f);
+}
+
+void testIfThenElse02() {
+  KernelScope kernel_scope;
+  Expr v = ifThenElse(Expr(0), Expr(1.0f), Expr(2.0f));
+
+  std::ostringstream oss;
+  oss << v;
+  ASSERT_EQ(oss.str(), "IfThenElse(0, 1, 2)");
+
+  SimpleIREvaluator eval(v);
+  eval();
+  ASSERT_EQ(eval.value().as<float>(), 2.0f);
+}
+
 } // namespace jit
 } // namespace torch

test/cpp/tensorexpr/test_llvm.cpp

Lines changed: 24 additions & 0 deletions
@@ -153,6 +153,30 @@ void testLLVMLoadStoreTest() {
   EXPECT_EQ(b_buffer[0], 42);
 }
 
+void testLLVMIfThenElseTest() {
+  KernelScope kernel_scope;
+  Buffer a(Var("A", kHandle), kInt32, {1});
+  Buffer b(Var("B", kHandle), kInt32, {1});
+  Buffer c(Var("C", kHandle), kInt32, {1});
+  std::vector<int32_t> a_buffer = {42};
+  std::vector<int32_t> b_buffer = {-11};
+  std::vector<int32_t> c_buffer = {1};
+
+  auto store = Store::make(
+      b,
+      IntImm::make(0),
+      IfThenElse::make(
+          Load::make(c, IntImm::make(0), IntImm::make(1)), // cond
+          Load::make(a, IntImm::make(0), IntImm::make(1)), // then
+          IntImm::make(0)), // else
+      IntImm::make(1));
+  LLVMCodeGen cg(store, {a, b, c});
+  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
+  EXPECT_EQ(cg.value<int>(args), 0);
+  EXPECT_EQ(a_buffer[0], 42);
+  EXPECT_EQ(b_buffer[0], 42);
+}
+
 void testLLVMVecLoadStoreTest() {
   KernelScope kernel_scope;
   Buffer a(Var("A", kHandle), kInt32, {1});

test/cpp/tensorexpr/tests.h

Lines changed: 4 additions & 1 deletion
@@ -41,6 +41,8 @@ namespace jit {
   _(AsmjitIntMulTest) \
   _(AsmjitIntDivTest) \
   _(Cond01) \
+  _(IfThenElse01) \
+  _(IfThenElse02) \
   _(ATen_cast_Float) \
   _(ATennegInt) \
   _(ATennegFloat) \
@@ -109,7 +111,8 @@ namespace jit {
   _(LLVMComputeMul) \
   _(LLVMBroadcastAdd) \
   _(LLVMDynamicShapeAdd) \
-  _(LLVMBindDynamicShapeAdd)
+  _(LLVMBindDynamicShapeAdd) \
+  _(LLVMIfThenElseTest)
 
 #define TH_FORALL_TESTS_CUDA(_) \
   _(CudaTestVectorAdd01) \

torch/csrc/jit/tensorexpr/eval.h

Lines changed: 9 additions & 0 deletions
@@ -379,6 +379,15 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
     }
   }
 
+  TORCH_API void visit(const IfThenElse* v) override {
+    v->condition().accept(this);
+    if (value_.as<int>()) {
+      v->true_value().accept(this);
+    } else {
+      v->false_value().accept(this);
+    }
+  }
+
   TORCH_API void visit(const Load* v) override {
     const Variable* base_node = v->base_handle().node();
     auto iter = buffer_mapping_.find(base_node);

torch/csrc/jit/tensorexpr/expr.cpp

Lines changed: 5 additions & 0 deletions
@@ -193,6 +193,11 @@ Expr fmod(const Expr& v1, const Expr& v2) {
 Expr remainder(const Expr& v1, const Expr& v2) {
   return Intrinsics::make(kRemainder, v1, v2);
 }
+
+Expr ifThenElse(const Expr& c, const Expr& t, const Expr& f) {
+  return IfThenElse::make(c, t, f);
+}
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch

torch/csrc/jit/tensorexpr/expr.h

Lines changed: 3 additions & 0 deletions
@@ -261,6 +261,9 @@ TORCH_API Expr pow(const Expr& v1, const Expr& v2);
 TORCH_API Expr fmod(const Expr& v1, const Expr& v2);
 TORCH_API Expr remainder(const Expr& v1, const Expr& v2);
 
+TORCH_API Expr ifThenElse(const Expr& c, const Expr& t, const Expr& f);
+
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch

torch/csrc/jit/tensorexpr/ir.h

Lines changed: 34 additions & 0 deletions
@@ -606,6 +606,40 @@ class Broadcast : public ExprNode<Broadcast> {
   Expr value_;
   int lanes_;
 };
+class IfThenElse : public ExprNode<IfThenElse> {
+ public:
+  const Expr& condition() const {
+    return condition_;
+  }
+
+  // Lazily evaluated only if condition is true
+  const Expr& true_value() const {
+    return true_;
+  }
+
+  // Lazily evaluated only if condition is false
+  const Expr& false_value() const {
+    return false_;
+  }
+
+  static Expr make(const Expr& c, const Expr& t, const Expr& f) {
+    return Expr(new IfThenElse(c, t, f));
+  }
+
+ private:
+  IfThenElse(const Expr& c, const Expr& t, const Expr& f)
+      : ExprNodeBase(t.dtype()),
+        condition_(c),
+        true_(t),
+        false_(f) {
+    CHECK_EQ(c.dtype().scalar_type(), kInt32);
+    CHECK_EQ(c.dtype().lanes(), 1);
+    CHECK_EQ(t.dtype(), f.dtype());
+  }
+  Expr condition_;
+  Expr true_;
+  Expr false_;
+};
 
 class BaseCallNode : public BaseExprNode {
  public:
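
The constructor's checks require the condition to be a scalar kInt32 expression and both branches to share a dtype. A short sketch of building the node directly (reusing the Buffer/Load pattern from the LLVM test above; the buffer names are illustrative):

  KernelScope kernel_scope;
  Buffer a(Var("A", kHandle), kInt32, {1});
  Buffer c(Var("C", kHandle), kInt32, {1});
  // Only the branch selected by the condition needs to be evaluated downstream.
  Expr e = IfThenElse::make(
      Load::make(c, IntImm::make(0), IntImm::make(1)),  // cond: scalar kInt32
      Load::make(a, IntImm::make(0), IntImm::make(1)),  // then
      IntImm::make(0));                                 // else: same dtype as then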

torch/csrc/jit/tensorexpr/ir_mutator.cpp

Lines changed: 16 additions & 0 deletions
@@ -150,6 +150,22 @@ Expr IRMutator::mutate(const Broadcast* v) {
   return Broadcast::make(value_new, lanes);
 }
 
+Expr IRMutator::mutate(const IfThenElse* v) {
+  Expr condition = v->condition();
+  Expr true_value = v->true_value();
+  Expr false_value = v->false_value();
+  Expr condition_new = condition.accept_mutator(this);
+  Expr true_value_new = true_value.accept_mutator(this);
+  Expr false_value_new = false_value.accept_mutator(this);
+  if (same_node(condition, condition_new) &&
+      same_node(true_value, true_value_new) &&
+      same_node(false_value, false_value_new)) {
+    return Expr(v);
+  }
+
+  return IfThenElse::make(condition_new, true_value_new, false_value_new);
+}
+
 Expr IRMutator::mutate(const Intrinsics* v) {
   const BaseCallNode* base = v;
   return this->mutate(base);

torch/csrc/jit/tensorexpr/ir_mutator.h

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@ class For;
 class Block;
 class Store;
 class Broadcast;
+class IfThenElse;
 class Expr;
 class Stmt;
 class BaseCallNode;
@@ -52,6 +53,7 @@ class TORCH_API IRMutator {
   virtual Expr mutate(const Ramp* v);
   virtual Expr mutate(const Load* v);
   virtual Expr mutate(const Broadcast* v);
+  virtual Expr mutate(const IfThenElse* v);
   // BaseCallNode is the base class for all call nodes.
   // For any visitors that only needs the common behavior, only override this
   // function is enough. This is because all derived class handlers will call
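
Downstream passes can hook the new node through this virtual. A minimal sketch (assuming only the declarations above; the class name is illustrative) of a mutator that counts IfThenElse nodes while delegating the rewriting to the default implementation:

  class IfThenElseCounter : public IRMutator {
   public:
    int count = 0;
    Expr mutate(const IfThenElse* v) override {
      count++;
      // Fall back to the default traversal, which rebuilds the node only if a child changed.
      return IRMutator::mutate(v);
    }
  };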

torch/csrc/jit/tensorexpr/ir_printer.cpp

Lines changed: 4 additions & 0 deletions
@@ -161,6 +161,10 @@ void IRPrinter::visit(const Broadcast* v) {
   os() << "Broadcast(" << v->value() << ", " << v->lanes() << ")";
 }
 
+void IRPrinter::visit(const IfThenElse* v) {
+  os() << "IfThenElse(" << v->condition() << ", " << v->true_value() << ", " << v->false_value() << ")";
+}
+
 void IRPrinter::visit(const BaseCallNode* v) {
   os() << v->func_name() << "(";
   for (int i = 0; i < v->nparams(); i++) {
