@@ -595,28 +595,7 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
595595}
596596
597597Tensor index_select_mps (const Tensor& self, int64_t dim, const Tensor& index) {
598- IntArrayRef input_shape = self.sizes ();
599- auto num_input_dims = input_shape.size ();
600-
601- auto num_indices = index.numel ();
602- TORCH_CHECK_INDEX (index.dim () <= 1 , " index_select(): Index is supposed to be a vector" );
603-
604- dim = maybe_wrap_dim (dim, self.dim ());
605- std::vector<int64_t > shape_data (num_input_dims);
606-
607- // Calculate new shape
608- for (const auto i : c10::irange (num_input_dims)) {
609- if (i == static_cast <decltype (i)>(dim)) {
610- shape_data[i] = num_indices;
611- } else {
612- shape_data[i] = input_shape[i];
613- }
614- }
615-
616- IntArrayRef output_shape = IntArrayRef (shape_data.data (), num_input_dims);
617-
618- Tensor result = at::empty (output_shape, self.scalar_type (), std::nullopt , kMPS , std::nullopt , std::nullopt );
619-
598+ Tensor result = at::empty ({0 }, self.options ());
620599 index_select_out_mps (self, dim, index, result);
621600 return result;
622601}
@@ -638,25 +617,11 @@ Tensor index_select_mps(const Tensor& self, int64_t dim, const Tensor& index) {
638617 TORCH_CHECK (self.scalar_type () == output.scalar_type (),
639618 " index_select(): self and output must have the same scalar type" );
640619 TORCH_CHECK (dim == 0 || dim < self.dim (), " index_select(): Indexing dim " , dim, " is out of bounds of tensor" );
641- TORCH_CHECK (output.dim () == 0 || index.size (-1 ) == output.size (dim),
642- " index_select(): index and output must have the same size at `dim`th dimension, but got " ,
643- index.size (-1 ),
644- " and " ,
645- output.size (dim),
646- " ." );
647-
648- for (const auto i : irange (self.dim ())) {
649- if (i == dim)
650- continue ;
651- TORCH_CHECK (self.size (i) == output.size (i),
652- " index_select(): self and output must have the same dimensions except for `dim`th dimension, but got " ,
653- self.size (i),
654- " and " ,
655- output.size (i),
656- " at dimension " ,
657- i,
658- " ." );
620+ auto output_size = self.sizes ().vec ();
621+ if (self.dim () > 0 ) {
622+ output_size[dim] = num_indices;
659623 }
624+ at::native::resize_output (output, output_size);
660625
661626 // Empty index
662627 if (num_indices == 0 || self.numel () == 0 ) {
0 commit comments