
Commit ce0b6ec

t-vigchanan authored and committed
Add derivative to pow with scalar base (pytorch#12450)
Summary: Fixes: pytorch#12426
Thank you, DriesSmit, for the report!
Pull Request resolved: pytorch#12450
Differential Revision: D10238556
Pulled By: soumith
fbshipit-source-id: 8bf71467c6734ecc5ff30f15500304d731f7e155
1 parent 9d27e75 commit ce0b6ec
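As a quick illustration (not part of the commit itself), this change lets autograd propagate gradients through a power with a scalar base and a tensor exponent. A minimal usage sketch, assuming a PyTorch build that includes this patch:

import math
import torch

# Gradient of a scalar base raised to a tensor exponent,
# which this commit makes differentiable.
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.double, requires_grad=True)
y = torch.pow(2, x)        # elementwise 2 ** x
y.sum().backward()

# Analytically, d/dx 2**x = 2**x * ln(2).
print(x.grad)
print(torch.pow(2, x).detach() * math.log(2.0))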

File tree

3 files changed: +11 −0 lines changed


test/test_autograd.py

Lines changed: 4 additions & 0 deletions

@@ -2150,6 +2150,10 @@ def run_test(input_size, exponent):
         run_test((10, 10), torch.zeros(10, 10))
         run_test((10,), 0)
 
+    def test_pow_scalar_base(self):
+        a = torch.arange(1, 13, dtype=torch.double).view(3, 4).requires_grad_()
+        gradcheck(lambda a: torch.pow(2, a), (a,))
+
     @skipIfRocm
     def test_pinverse(self):
         # Why is pinverse tested this way, and not ordinarily as other linear algebra methods?

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 0 deletions

@@ -559,6 +559,9 @@
   self: pow_backward_self(grad, self, exponent)
   exponent: pow_backward_exponent(grad, self, exponent)
 
+- name: pow(Scalar base, Tensor self)
+  self: pow_backward_exponent(grad, base, self)
+
 - name: _prod(Tensor self, int64_t dim, bool keepdim)
   self: prod_backward(grad, self, result, dim, keepdim)
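Because the base is a constant Scalar in this overload, only the exponent (named self in the schema) needs a gradient; the relevant rule is d/dx b^x = b^x · ln(b), which is exactly what the pow_backward_exponent overload added in Functions.cpp below computes.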

tools/autograd/templates/Functions.cpp

Lines changed: 4 additions & 0 deletions

@@ -130,6 +130,10 @@ Tensor pow_backward_exponent(Tensor grad, const Tensor & self, const Tensor & exponent) {
   return grad * self.pow(exponent) * self.log();
 }
 
+Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor & exponent) {
+  return grad * at::pow(base, exponent) * std::log(base.toDouble());
+}
+
 Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
   Tensor args = at::arange(-p + 1, 1, -1, self.options()).div_(2.);
   args = args.add(self.unsqueeze(-1));
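For intuition, here is a small verification sketch (not part of the commit, assuming a build with this patch) that compares autograd's gradient against the analytic expression base**exponent * log(base) implemented by the new pow_backward_exponent overload:

import math
import torch

base = 2.0
exponent = torch.arange(1, 13, dtype=torch.double).view(3, 4).requires_grad_()

y = torch.pow(base, exponent)
y.backward(torch.ones_like(y))   # upstream gradient of ones

# Analytic gradient: base**exponent * ln(base)
analytic = torch.pow(base, exponent).detach() * math.log(base)
print(torch.allclose(exponent.grad, analytic))   # expected: True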
