11#include < torch/csrc/jit/tensorexpr/kernel.h>
2+ #include < torch/csrc/jit/tensorexpr/ir_printer.h>
23#include < torch/csrc/jit/tensorexpr/schedule.h>
34
45using namespace torch ::jit;
@@ -120,12 +121,67 @@ Expr TensorExprKernel::demoteOutput(const Expr& e, const torch::jit::Value* v) {
120121 return e;
121122}
122123
124+ static bool isOne (Expr e) {
125+ auto const & n = e.AsNode <IntImm>();
126+ if (!n) {
127+ return false ;
128+ }
129+ return n->value () == 1 ;
130+ }
131+
132+ static std::vector<Expr> broadcastShapes (
133+ const std::vector<Expr>& a,
134+ const std::vector<Expr>& b) {
135+ auto at = a.rbegin ();
136+ auto bt = b.rbegin ();
137+ std::vector<Expr> ret;
138+ while (at != a.rend () || bt != b.rend ()) {
139+ if (at == a.rend ()) {
140+ ret.push_back (*bt++);
141+ continue ;
142+ }
143+ if (bt == b.rend ()) {
144+ ret.push_back (*at++);
145+ continue ;
146+ }
147+ // TODO: if neither *at nor *bt is 1, ensure they are identical
148+ // expressions. Nb: `==` doesn't work since that simply produces a new
149+ // Expr.
150+ Expr dim = isOne (*at) ? *bt : *at;
151+ ret.push_back (dim);
152+ at++;
153+ bt++;
154+ }
155+ std::reverse (ret.begin (), ret.end ());
156+ return ret;
157+ }
158+
159+ template <typename ... Args>
160+ static std::vector<Expr> broadcastShapes (
161+ const std::vector<Expr>& a,
162+ const std::vector<Expr>& b,
163+ Args... args) {
164+ return broadcastShapes (broadcastShapes (a, b), args...);
165+ }
166+
167+ std::vector<Expr> TensorExprKernel::valueShape (const torch::jit::Value* v) {
168+ auto it = tensors_.find (v->unique ());
169+ if (it == tensors_.end ()) {
170+ return {1 };
171+ }
172+ return it->second .dims ();
173+ }
174+
123175Tensor TensorExprKernel::ComputeOneOperand (
124176 const std::string& name,
125177 const torch::jit::Value* v,
126178 std::function<Expr(const Expr&)> inner_expr) {
179+ auto const & n = v->node ();
180+ auto const & shape = valueShape (n->inputs ()[0 ]);
127181 return Compute (
128- name, texprDims (v), [this , v, inner_expr](const std::vector<Var>& axes) {
182+ name,
183+ c10::fmap<DimArg>(shape),
184+ [this , v, inner_expr](const std::vector<Var>& axes) {
129185 auto const & n = v->node ();
130186 std::vector<Expr> inputs = {tensorOrConstant (n->inputs ()[0 ], axes)};
131187
@@ -139,8 +195,13 @@ Tensor TensorExprKernel::ComputeTwoOperand(
139195 const std::string& name,
140196 const torch::jit::Value* v,
141197 std::function<Expr(const Expr&, const Expr&)> inner_expr) {
198+ auto const & n = v->node ();
199+ auto const & shape =
200+ broadcastShapes (valueShape (n->inputs ()[0 ]), valueShape (n->inputs ()[1 ]));
142201 return Compute (
143- name, texprDims (v), [this , v, inner_expr](const std::vector<Var>& axes) {
202+ name,
203+ c10::fmap<DimArg>(shape),
204+ [this , v, inner_expr](const std::vector<Var>& axes) {
144205 auto const & n = v->node ();
145206 std::vector<Expr> inputs = {
146207 tensorOrConstant (n->inputs ()[0 ], axes),
@@ -157,8 +218,13 @@ Tensor TensorExprKernel::ComputeTwoOperandWithAlpha(
157218 const std::string& name,
158219 const torch::jit::Value* v,
159220 std::function<Expr(const Expr&, const Expr&)> inner_expr) {
221+ auto const & n = v->node ();
222+ auto const & shape =
223+ broadcastShapes (valueShape (n->inputs ()[0 ]), valueShape (n->inputs ()[1 ]));
160224 return Compute (
161- name, texprDims (v), [this , v, inner_expr](const std::vector<Var>& axes) {
225+ name,
226+ c10::fmap<DimArg>(shape),
227+ [this , v, inner_expr](const std::vector<Var>& axes) {
162228 auto const & n = v->node ();
163229 std::vector<Expr> inputs = {
164230 tensorOrConstant (n->inputs ()[0 ], axes),
@@ -176,8 +242,15 @@ Tensor TensorExprKernel::ComputeThreeOperand(
176242 const std::string& name,
177243 const torch::jit::Value* v,
178244 std::function<Expr(const Expr&, const Expr&, const Expr&)> inner_expr) {
245+ auto const & n = v->node ();
246+ auto const & shape = broadcastShapes (
247+ valueShape (n->inputs ()[0 ]),
248+ valueShape (n->inputs ()[1 ]),
249+ valueShape (n->inputs ()[2 ]));
179250 return Compute (
180- name, texprDims (v), [this , v, inner_expr](const std::vector<Var>& axes) {
251+ name,
252+ c10::fmap<DimArg>(shape),
253+ [this , v, inner_expr](const std::vector<Var>& axes) {
181254 auto const & n = v->node ();
182255 std::vector<Expr> inputs = {
183256 tensorOrConstant (n->inputs ()[0 ], axes),
@@ -194,9 +267,18 @@ Tensor TensorExprKernel::ComputeThreeOperand(
194267Tensor TensorExprKernel::ComputeFourOperand (
195268 const std::string& name,
196269 const torch::jit::Value* v,
197- std::function<Expr(const Expr&, const Expr&, const Expr&, const Expr&)> inner_expr) {
270+ std::function<Expr(const Expr&, const Expr&, const Expr&, const Expr&)>
271+ inner_expr) {
272+ auto const & n = v->node ();
273+ auto const & shape = broadcastShapes (
274+ valueShape (n->inputs ()[0 ]),
275+ valueShape (n->inputs ()[1 ]),
276+ valueShape (n->inputs ()[2 ]),
277+ valueShape (n->inputs ()[3 ]));
198278 return Compute (
199- name, texprDims (v), [this , v, inner_expr](const std::vector<Var>& axes) {
279+ name,
280+ c10::fmap<DimArg>(shape),
281+ [this , v, inner_expr](const std::vector<Var>& axes) {
200282 auto const & n = v->node ();
201283 std::vector<Expr> inputs = {
202284 tensorOrConstant (n->inputs ()[0 ], axes),
0 commit comments