Skip to content

Commit b01d6f2

Browse files
cybershiptrooper authored and
pytorchmergebot committed
1 parent 5842e5c commit b01d6f2

File tree

5 files changed

+15
-4
lines changed

5 files changed

+15
-4
lines changed

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,8 +568,6 @@ aten::adaptive_max_pool3d_backward
568568
aten::adaptive_max_pool3d_backward.grad_input
569569
aten::addbmm
570570
aten::addbmm.out
571-
aten::addmv
572-
aten::addmv.out
573571
aten::addr_
574572
aten::affine_grid_generator
575573
aten::affine_grid_generator.out

test/functorch/test_aotdispatch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2463,7 +2463,6 @@ def forward(self, x):
24632463
}
24642464

24652465
symbolic_aot_autograd_failures = {
2466-
xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition
24672466
xfail('amax', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
24682467
xfail('amin', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
24692468
xfail('block_diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides

test/test_decomp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
174174
(torch.float16, torch.ops.aten.var_mean.dim): 5e-7,
175175
(torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
176176
(torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
177+
# see https://github.com/pytorch/pytorch/pull/96264
178+
(torch.float16, torch.ops.aten.mv.default): 1e-5,
177179
}
178180
if ref.is_floating_point():
179181
orig_diff = (orig - ref).abs().max()

test/test_proxy_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,6 @@ def f(a, b, c, d, e):
13321332
xfail('linalg.eigvals'),
13331333
skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel
13341334
xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
1335-
xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition
13361335
xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
13371336
xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
13381337
xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel

torch/_decomp/decompositions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,19 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
11361136
return out + beta * self
11371137

11381138

1139+
@register_decomposition(aten.addmv)
@out_wrapper()
@pw_cast_for_opmath
def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1):
    """Decompose aten.addmv as ``beta * self + alpha * (mat1 @ vec)``.

    Mirrors the ``addmm`` decomposition above, specialized to a
    matrix-vector product via ``torch.mv``.
    """
    # Non-floating, non-complex inputs only admit integral scale factors,
    # so coerce beta/alpha to int in that case (same guard as addmm).
    if not (self.is_floating_point() or self.is_complex()):
        beta, alpha = int(beta), int(alpha)
    scaled_mv = torch.mv(mat1, vec) * alpha
    # When beta == 0 the `self` term is dropped entirely rather than
    # multiplied by zero — matches aten semantics (avoids propagating
    # values from `self` that should be ignored).
    if beta == 0:
        return scaled_mv
    return scaled_mv + beta * self
1150+
1151+
11391152
@register_decomposition(aten.native_group_norm_backward)
11401153
@pw_cast_for_opmath
11411154
def native_group_norm_backward(

0 commit comments

Comments
 (0)