Skip to content

Commit 188a1b2

Browse files
lucylqfacebook-github-bot
authored andcommitted
Add checks for compute_slice (#15647)
Summary: Add safety checks to compute_slice, to ensure that we: 1. Do not read outside of the src tensor bounds 2. Do not write outside of the output tensor bounds Also pass in KernelRuntimeContext to use ET_KERNEL_CHECK_MSG and make errors non-fatal. Reviewed By: JacobSzwejbka Differential Revision: D86433966
1 parent b005f10 commit 188a1b2

File tree

6 files changed

+37
-4
lines changed

6 files changed

+37
-4
lines changed

backends/cadence/fusion_g3/operators/op_slice_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ Tensor& slice_copy_Tensor_out(
123123
InvalidArgument,
124124
out);
125125

126-
torch::executor::compute_slice(in, dim, start, length, step, out);
126+
torch::executor::compute_slice(ctx, in, dim, start, length, step, out);
127127
}
128128

129129
return out;

backends/cadence/hifi/operators/op_slice_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Tensor& slice_copy_Tensor_out(
6464
InvalidArgument,
6565
out);
6666

67-
compute_slice(in, dim, start, length, step, out);
67+
compute_slice(ctx, in, dim, start, length, step, out);
6868

6969
return out;
7070
}

kernels/portable/cpu/op_narrow_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Tensor& narrow_copy_out(
4646
out);
4747

4848
if (length != 0) {
49-
compute_slice(in, dim, start, length, 1, out);
49+
compute_slice(ctx, in, dim, start, length, 1, out);
5050
}
5151

5252
return out;

kernels/portable/cpu/op_slice_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Tensor& slice_copy_Tensor_out(
5555
InvalidArgument,
5656
out);
5757

58-
compute_slice(in, dim, start, length, step, out);
58+
compute_slice(ctx, in, dim, start, length, step, out);
5959

6060
return out;
6161
}

kernels/portable/cpu/util/slice_util.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,39 @@ int64_t adjust_slice_indices(
150150
}
151151

152152
void compute_slice(
153+
KernelRuntimeContext& ctx,
153154
const Tensor& in,
154155
int64_t dim,
155156
int64_t start,
156157
int64_t length,
157158
int64_t step,
158159
Tensor& out) {
160+
// No slicing requested.
161+
if (length <= 0) {
162+
return;
163+
}
164+
165+
ET_KERNEL_CHECK_MSG(
166+
ctx,
167+
dim < in.dim(),
168+
InvalidArgument,
169+
/* void */,
170+
"Requested dim is larger than input tensor dim");
159171
size_t dim_length = in.size(dim);
172+
ET_KERNEL_CHECK_MSG(
173+
ctx,
174+
start >= 0 && length >= 0 && step >= 0,
175+
InvalidArgument,
176+
/* void */,
177+
"Input args should be >= 0.");
178+
int64_t requested_slice = start + (length - 1) * step;
179+
ET_KERNEL_CHECK_MSG(
180+
ctx,
181+
static_cast<uint64_t>(requested_slice) <
182+
static_cast<uint64_t>(dim_length),
183+
InvalidArgument,
184+
/* void */,
185+
"Requested slice is larger than the dim size");
160186

161187
size_t leading_dims = getLeadingDims(in, dim);
162188
size_t trailing_dims = getTrailingDims(in, dim);
@@ -170,6 +196,12 @@ void compute_slice(
170196
const char* input_data = in.const_data_ptr<char>();
171197
char* dest = out.mutable_data_ptr<char>();
172198

199+
ET_KERNEL_CHECK_MSG(
200+
ctx,
201+
out.nbytes() >= (length * leading_dims * length_per_step),
202+
InvalidArgument,
203+
/* void */,
204+
"out.nbytes() is smaller than the expected slice size.");
173205
for (const auto i : c10::irange(leading_dims)) {
174206
const char* src = input_data + (i * dim_length + start) * length_per_step;
175207
for ([[maybe_unused]] const auto j : c10::irange(length)) {

kernels/portable/cpu/util/slice_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ int64_t adjust_slice_indices(
5555
int64_t step);
5656

5757
void compute_slice(
58+
KernelRuntimeContext& ctx,
5859
const Tensor& in,
5960
int64_t dim,
6061
int64_t start,

0 commit comments

Comments
 (0)