
Commit f5eb05f (1 parent: f3665dd)

Revert "Reland #2 of "Added {logical_not, trace} refs, moved logical ops to use method overloads""

This reverts commit f3665dd. Reverted pytorch#79819 on behalf of https://github.com/malfet due to a land race with the softshrink refs.

File tree: 5 files changed, +27 -61 lines


test/test_meta.py (5 additions, 1 deletion)

```diff
@@ -400,6 +400,8 @@ def run_meta_crossref(
     torch.mode: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::mode
     torch.multinomial: {bf16, f32, f64},  # aten::multinomial, aten::multinomial.out
     torch.mvlgamma: {bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::_local_scalar_dense, aten::mvlgamma.out
+    torch.nanmean: {bf16, f16, f32, f64},
+    torch.nanquantile: {f32, f64},
     torch.nn.functional.conv1d: {bf16, f32, f64, i64},
     torch.nn.functional.conv2d: {bf16, f32, f64, i64},
     torch.nn.functional.conv_transpose1d: {f32, f64, i64},
@@ -463,9 +465,9 @@ def run_meta_crossref(
     torch.functional.cdist: {f32, f64},
     torch.functional.tensordot: {bf16, f32, f64, i16, i32, i64, i8, u8},
     torch.inner: {bf16, f32, f64, i16, i32, i64, i8, u8},
+    torch.logical_not: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},
     torch.nn.functional.cross_entropy: {bf16, f32, f64},
     torch.nn.functional.interpolate: {bf16, f32, f64, u8},
-    torch.nanmean: {bf16, f16, f32, f64},  # TODO(chilli): Doesn't seem to work for some reason?
     torch.nn.functional.nll_loss: {bf16, f32, f64},  # TODO
     torch.linalg.pinv: {f32, f64},
     torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
@@ -625,6 +627,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
     aten.log_sigmoid_forward.output: {bf16, f64, f32},
     aten.logcumsumexp.default: {bf16, f64, f32},
     aten.logcumsumexp.out: {bf16, f64, f32},
+    aten.logical_not.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
+    aten.logical_not_.default: {bf16, f16, f64, f32},
     aten.masked_select.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
     aten.masked_select.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
     aten.max_pool3d_with_indices.default: {f64, f32},
```
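The sets above are expected-failure lists for test_meta's cross-reference check: each entry names an op and the dtypes for which meta coverage is known to be missing, so re-adding torch.logical_not here records that the op loses meta support together with the reverted ref. As a minimal sketch of what the meta device does (using torch.add, an op with meta support; the shapes are illustrative):

```python
import torch

# Meta tensors carry only shape/dtype/device metadata, no storage.
# Running an op on them exercises just the shape-propagation path,
# which test_meta.py cross-references against real-device results.
a = torch.empty(2, 3, device="meta")
b = torch.empty(2, 3, device="meta")

out = torch.add(a, b)
print(out.shape, out.dtype, out.device)  # torch.Size([2, 3]) torch.float32 meta
```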

torch/_decomp/decompositions.py (10 additions, 0 deletions)

```diff
@@ -1005,6 +1005,11 @@ def _fused_dropout_decomposition(input, p, generator=None):
     return (res, mask)
 
 
+@register_decomposition(aten.logical_not)
+def logical_not(self: Tensor) -> Tensor:
+    return ~self.to(dtype=torch.bool)
+
+
 @register_decomposition(aten.xlogy.Tensor)
 @pw_cast_for_int_to_real
 def xlogy(self: Tensor, other: Tensor) -> Tensor:
@@ -1161,6 +1166,11 @@ def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
     return result.log().add(maxes_squeezed)
 
 
+@register_decomposition(aten.trace.default)
+def trace(self: Tensor) -> Tensor:
+    return torch.sum(torch.diag(self))
+
+
 # nb: Should use acc_t, not op_math
 @register_decomposition(aten.log_sigmoid_forward)
 @out_wrapper_multi('output', 'buffer')
```
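The two restored decompositions are direct rewrites in terms of existing ops, so their behavior can be sanity-checked against the eager kernels. A quick equivalence sketch (the sample tensor is arbitrary):

```python
import torch

x = torch.tensor([[0.0, 1.5], [2.0, 0.0]])

# logical_not(self) is rewritten as ~self.to(dtype=torch.bool):
# nonzero values map to True, and bitwise-not on bool is logical not.
assert torch.equal(torch.logical_not(x), ~x.to(dtype=torch.bool))

# trace(self) is rewritten as the sum of the main diagonal.
assert torch.equal(torch.trace(x), torch.sum(torch.diag(x)))
```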

torch/_prims/context.py (1 addition, 7 deletions)

```diff
@@ -27,13 +27,7 @@ def torch_to_refs_map():
         (torch.nn.functional, torch._refs.nn.functional),
         (torch.special, torch._refs.special),
     ]
-    r: Dict[Any, Any] = {
-        torch.Tensor.__invert__: torch._refs.bitwise_not,
-        torch.Tensor.__xor__: torch._refs.bitwise_xor,
-        torch.Tensor.__and__: torch._refs.bitwise_and,
-        torch.Tensor.__or__: torch._refs.bitwise_or,
-        torch.Tensor.__eq__: torch._refs.eq,
-    }
+    r = {}
     for mod_torch, mod_refs in modules:
         for s in mod_refs.__all__:  # type: ignore[attr-defined]
             r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)
```
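With the hand-written dunder entries gone, torch_to_refs_map is built purely from the __all__ exports of each refs module, so Tensor operator overloads such as __eq__ and __and__ are no longer intercepted. A simplified sketch of the mapping that remains (showing only the (torch, torch._refs) module pair):

```python
import torch
import torch._refs

# Key each public torch callable by its _refs counterpart of the same name.
# getattr with a None default guards names that torch itself does not expose.
r = {}
for name in torch._refs.__all__:
    torch_fn = getattr(torch, name, None)
    if torch_fn is not None:
        r[torch_fn] = getattr(torch._refs, name)

print(r[torch.add] is torch._refs.add)  # True
```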

torch/_refs/__init__.py (10 additions, 28 deletions)

```diff
@@ -88,7 +88,6 @@
     "square",
     "tan",
     "tanh",
-    "trace",
     #
     # Elementwise Binary References
     #
@@ -120,7 +119,6 @@
     # 'ldexp',
     "le",
     "logical_and",
-    "logical_not",
     "logical_or",
     "logical_xor",
     "lt",
@@ -998,10 +996,10 @@ def _lcm(a: TensorLikeType, b: TensorLikeType):
 
 def _logical_and(a: TensorLikeType, b: TensorLikeType):
     if not utils.is_boolean_dtype(a.dtype):
-        a = a != 0
+        a = ne(a, 0)
     if not utils.is_boolean_dtype(b.dtype):
-        b = b != 0
-    return a & b
+        b = ne(b, 0)
+    return bitwise_and(a, b)
 
 
 logical_and = _make_elementwise_binary_reference(
@@ -1011,21 +1009,12 @@ def _logical_and(a: TensorLikeType, b: TensorLikeType):
 )
 
 
-@_make_elementwise_unary_reference(
-    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=torch.ops.aten.logical_not
-)
-def logical_not(a: TensorLikeType):
-    if not utils.is_boolean_dtype(a.dtype):
-        return a == 0
-    return ~a
-
-
 def _logical_or(a: TensorLikeType, b: TensorLikeType):
     if not utils.is_boolean_dtype(a.dtype):
-        a = a != 0
+        a = ne(a, 0)
     if not utils.is_boolean_dtype(b.dtype):
-        b = b != 0
-    return a | b
+        b = ne(b, 0)
+    return bitwise_or(a, b)
 
 
 logical_or = _make_elementwise_binary_reference(
@@ -1037,10 +1026,10 @@ def _logical_or(a: TensorLikeType, b: TensorLikeType):
 
 def _logical_xor(a: TensorLikeType, b: TensorLikeType):
     if not utils.is_boolean_dtype(a.dtype):
-        a = a != 0
+        a = ne(a, 0)
     if not utils.is_boolean_dtype(b.dtype):
-        b = b != 0
-    return a ^ b
+        b = ne(b, 0)
+    return bitwise_xor(a, b)
 
 
 # TODO: skip unnecessary conversion of long to float
@@ -2625,13 +2614,6 @@ def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
     return item(all(eq(a, b)))  # type: ignore[return-value]
 
 
-@register_decomposition(torch.ops.aten.trace)
-def trace(self: TensorLikeType) -> TensorLikeType:
-    utils.check(
-        self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}"
-    )
-    return torch.sum(torch.diag(self, 0))
-
-
+# populate the decomp table
 import torch._refs.nn.functional
 import torch._refs.special
```
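The _logical_* helpers move back from Tensor operator overloads (!=, &, |, ^) to explicit refs (ne, bitwise_and, and so on); since the dunder-to-refs entries were removed from torch_to_refs_map above, the explicit calls are what keeps every step inside the refs system under tracing. Both spellings compute the same values, as this sketch with arbitrary inputs shows:

```python
import torch

a = torch.tensor([0, 1, 2])
b = torch.tensor([2, 0, 2])

# Overload spelling, as used by the reverted PR ...
via_overloads = (a != 0) & (b != 0)
# ... versus the explicit-refs spelling this commit restores.
via_refs = torch.bitwise_and(torch.ne(a, 0), torch.ne(b, 0))

assert torch.equal(via_overloads, via_refs)
assert torch.equal(via_refs, torch.logical_and(a, b))
```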

torch/testing/_internal/common_methods_invocations.py (1 addition, 25 deletions)

```diff
@@ -3662,10 +3662,6 @@ def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs):
                              requires_grad=requires_grad))),)
 
 
-def error_inputs_trace(op, device):
-    yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix")
-
-
 def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
     cases = (((S, S, S), (2, 1, 0.5)),
@@ -4334,6 +4330,7 @@ def error_inputs_embedding(op_info, device, **kwargs):
 def error_inputs_t(op_info, device, **kwargs):
     yield ErrorInput(
         SampleInput(torch.randn(2, 3, 4, 5, device=device)),
+        error_type=RuntimeError,
         error_regex="expects a tensor with <= 2",
     )
 
@@ -17637,7 +17634,6 @@ def error_inputs_mean(op_info, device, **kwargs):
         dtypes=all_types_and_complex(),
         dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
         backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
-        error_inputs_func=error_inputs_trace,
         supports_inplace_autograd=False,
         supports_out=False,
         supports_forward_ad=True,
@@ -20624,16 +20620,6 @@ def __init__(
             ),
         )
     ),
-    ElementwiseUnaryPythonRefInfo(
-        "_refs.logical_not",
-        torch_opinfo_name="logical_not",
-        skips=(
-            DecorateInfo(
-                # NotImplementedError: argument of type: <class 'complex'>
-                unittest.skip("Fails aten complex and nvfuser doesn't support eq(a, 0)"), 'TestCommon', 'test_python_ref_executor'
-            ),
-        )
-    ),
     ElementwiseBinaryPythonRefInfo(
         "_refs.logical_or",
         torch_opinfo_name="logical_or",
@@ -21207,16 +21193,6 @@ def __init__(
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
         ),
     ),
-    PythonRefInfo(
-        "_refs.trace",
-        torch_opinfo_name="trace",
-        decorators=(
-            # TODO: torch.diag is currently not supported by either refs, meta funcs, or NVFuser
-            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
-            DecorateInfo(unittest.skip("diag is not supported by meta"), 'TestCommon', 'test_python_ref_meta'),
-            DecorateInfo(unittest.skip("diag is not supported by nvfuser"), 'TestCommon', 'test_python_ref_executor'),
-        ),
-    ),
     #
     # Tensor Creation Reference OpInfos
     #
```
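The error_inputs_* functions above yield ErrorInput records that pair a failing SampleInput with the expected exception type and a message regex; this commit restores the explicit error_type=RuntimeError on error_inputs_t. Roughly how a harness consumes such a record (a simplified sketch, not the actual test loop):

```python
import re

import torch
from torch.testing._internal.common_methods_invocations import ErrorInput, SampleInput

def error_inputs_t(op_info, device, **kwargs):
    yield ErrorInput(
        SampleInput(torch.randn(2, 3, 4, 5, device=device)),
        error_type=RuntimeError,
        error_regex="expects a tensor with <= 2",
    )

# Simplified consumption: run the op on the bad input and check the failure.
for ei in error_inputs_t(None, "cpu"):
    try:
        torch.t(ei.sample_input.input)
    except ei.error_type as e:
        assert re.search(ei.error_regex, str(e))
```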
