
Commit c239008

malfet authored and pytorchmergebot committed
[MPS] Fix index_select for scalar_types (pytorch#161206)
By copy-n-pasting logic from `index_select_out_cpu` (and `_cuda`), where essentially the resizing is done inside the op, which also fixes faulty logic for scalars.

Pull Request resolved: pytorch#161206
Approved by: https://github.com/manuelcandales
1 parent f09458c commit c239008
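
For context, a minimal repro sketch of the user-visible case the commit message calls out (index_select on a zero-dimensional "scalar" tensor on MPS). The tensor values and the CPU comparison are illustrative assumptions, not taken from the PR:

```python
# Assumed repro sketch (not from the PR): after this fix, index_select on a
# zero-dimensional MPS tensor should match the CPU reference in shape,
# dtype, and value instead of hitting the old faulty shape logic.
import torch

if torch.backends.mps.is_available():
    scalar_cpu = torch.tensor(3.0)                     # 0-dim ("scalar") tensor
    idx = torch.tensor([0])                            # single index into dim 0
    expected = torch.index_select(scalar_cpu, 0, idx)  # CPU reference result

    got = torch.index_select(scalar_cpu.to("mps"), 0, idx.to("mps"))
    torch.testing.assert_close(got.cpu(), expected)
```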

File tree

2 files changed: +7 −41 lines

- aten/src/ATen/native/mps/operations/Indexing.mm
- test/test_indexing.py

aten/src/ATen/native/mps/operations/Indexing.mm

Lines changed: 5 additions & 40 deletions

@@ -595,28 +595,7 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
 }
 
 Tensor index_select_mps(const Tensor& self, int64_t dim, const Tensor& index) {
-  IntArrayRef input_shape = self.sizes();
-  auto num_input_dims = input_shape.size();
-
-  auto num_indices = index.numel();
-  TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector");
-
-  dim = maybe_wrap_dim(dim, self.dim());
-  std::vector<int64_t> shape_data(num_input_dims);
-
-  // Calculate new shape
-  for (const auto i : c10::irange(num_input_dims)) {
-    if (i == static_cast<decltype(i)>(dim)) {
-      shape_data[i] = num_indices;
-    } else {
-      shape_data[i] = input_shape[i];
-    }
-  }
-
-  IntArrayRef output_shape = IntArrayRef(shape_data.data(), num_input_dims);
-
-  Tensor result = at::empty(output_shape, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
-
+  Tensor result = at::empty({0}, self.options());
   index_select_out_mps(self, dim, index, result);
   return result;
 }
@@ -638,25 +617,11 @@ Tensor index_select_mps(const Tensor& self, int64_t dim, const Tensor& index) {
   TORCH_CHECK(self.scalar_type() == output.scalar_type(),
               "index_select(): self and output must have the same scalar type");
   TORCH_CHECK(dim == 0 || dim < self.dim(), "index_select(): Indexing dim ", dim, " is out of bounds of tensor");
-  TORCH_CHECK(output.dim() == 0 || index.size(-1) == output.size(dim),
-              "index_select(): index and output must have the same size at `dim`th dimension, but got ",
-              index.size(-1),
-              " and ",
-              output.size(dim),
-              ".");
-
-  for (const auto i : irange(self.dim())) {
-    if (i == dim)
-      continue;
-    TORCH_CHECK(self.size(i) == output.size(i),
-                "index_select(): self and output must have the same dimensions except for `dim`th dimension, but got ",
-                self.size(i),
-                " and ",
-                output.size(i),
-                " at dimension ",
-                i,
-                ".");
+  auto output_size = self.sizes().vec();
+  if (self.dim() > 0) {
+    output_size[dim] = num_indices;
   }
+  at::native::resize_output(output, output_size);
 
   // Empty index
   if (num_indices == 0 || self.numel() == 0) {
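
To make the behavioral change concrete, here is a small Python sketch of the shape rule the rewritten out-variant now applies (mirroring `index_select_out_cpu`); the helper name `index_select_output_shape` is made up for illustration and is not part of PyTorch:

```python
# Illustrative sketch only (hypothetical helper), mirroring the new C++ logic:
# keep self's sizes, replace the size at `dim` with index.numel(), and leave a
# zero-dimensional input's output zero-dimensional. The out-variant then calls
# at::native::resize_output(output, output_size) before doing the gather.
def index_select_output_shape(self_sizes, dim, num_indices):
    output_size = list(self_sizes)       # auto output_size = self.sizes().vec();
    if len(self_sizes) > 0:              # if (self.dim() > 0) {
        output_size[dim] = num_indices   #   output_size[dim] = num_indices;
    return output_size                   # }

assert index_select_output_shape([3, 4], 0, 5) == [5, 4]  # dim 0 resized to 5 indices
assert index_select_output_shape([], 0, 1) == []          # scalar input stays 0-dim
```

Because the out-variant now resizes the output itself, the `index_select_mps` wrapper can simply allocate `at::empty({0}, self.options())` and delegate, matching the CPU and CUDA implementations.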

test/test_indexing.py

Lines changed: 2 additions & 1 deletion

@@ -25,6 +25,7 @@
     skipXLA,
 )
 from torch.testing._internal.common_dtype import (
+    all_mps_types_and,
     all_types_and,
     all_types_and_complex_and,
     all_types_complex_float8_and,
@@ -2046,8 +2047,8 @@ def test_index_fill(self, device, dtype):
 
     # The test fails for zero-dimensional tensors on XLA
     @onlyNativeDeviceTypes
-    @expectedFailureMPS  # See https://github.com/pytorch/pytorch/issues/160737
     @dtypes(*all_types_complex_float8_and(torch.half, torch.bool, torch.bfloat16))
+    @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat))
     def test_index_select(self, device, dtype):
         num_src, num_out = 3, 5
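
For a quick manual spot-check of the widened MPS dtype coverage outside the test suite, a hedged sketch (the dtype list and shapes are arbitrary choices, not taken from `test_index_select`):

```python
# Assumed spot-check, not part of test_indexing.py: exercise index_select on MPS
# for a few of the dtypes the new @dtypesIfMPS decorator is meant to cover.
import torch

if torch.backends.mps.is_available():
    for dtype in (torch.float32, torch.int64, torch.bool, torch.cfloat):
        src = torch.arange(6, device="mps").reshape(2, 3).to(dtype)
        idx = torch.tensor([2, 0], device="mps")
        out = torch.index_select(src, 1, idx)  # select columns 2 and 0
        assert out.shape == (2, 2) and out.dtype == dtype
```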
