Skip to content

Commit 924a247

Browse files
malfetpytorchmergebot
authored andcommitted
[MPS] Enable angle and atan2 for torch.long (#149017)
This check was added by #85817, that introduced no unit-tests and its content seems to be totally unrelated to title/subject of that PR. Anyway, right now it seems to be working fine on MacOS-13+ Pull Request resolved: #149017 Approved by: https://github.com/dcci
1 parent 7b78a2c commit 924a247

File tree

3 files changed

+0
-4
lines changed

3 files changed

+0
-4
lines changed

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,6 @@ static void add_sub_lerp_template(const Tensor& self,
397397
});
398398
}
399399
TORCH_IMPL_FUNC(atan2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
400-
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support atan2 op with int64 input");
401400
mps::binaryOpTensor(
402401
self, other, Scalar(1.0), output, "atan2", ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
403402
MPSGraph* mpsGraph = cachedGraph->graph();

aten/src/ATen/native/mps/operations/UnaryOps.mm

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,6 @@ static void unary_op(const Tensor& self,
272272
}
273273

274274
Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
275-
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support angle op with int64 input");
276275
if (mps::supportsComplex()) {
277276
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
278277
auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];

test/test_mps.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,6 @@ def mps_ops_modifier(ops):
710710
'index_add': [torch.int64],
711711
'log1p': [torch.int64],
712712
'sigmoid': [torch.int64],
713-
'atan2': [torch.int64],
714-
'angle': [torch.int64],
715713

716714
# Operations not supported for integral types
717715
'special.xlog1py': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

0 commit comments

Comments
 (0)