
Commit e9c24ac

soumith authored and facebook-github-bot committed
fix arange shape issue inconsistency across cpu and cuda (#18462)
Summary: Fixes pytorch/pytorch#18363
Pull Request resolved: pytorch/pytorch#18462
Differential Revision: D14620263
Pulled By: soumith
fbshipit-source-id: 223524cdda2f5d55c2ca8d4cdcf6f7a05a6c15eb
1 parent b800238 commit e9c24ac
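
To see the kind of inconsistency this fixes, consider how the output size of arange is derived. Below is a minimal, self-contained C++ sketch (the values are hypothetical, not taken from the linked issue): when the size is computed in float32, as the pre-patch GPU path effectively did for a float32 tensor, it can differ from the same computation carried out in double, as the pre-patch CPU path did.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical arange arguments, supplied as doubles (e.g. Python floats).
  double start = 0.0, end = 1000000.01, step = 1.0;

  // Pre-patch GPU-style path for a float32 tensor: accumulate in float.
  // 1000000.01 is not representable as a float and rounds down to 1000000.0f.
  float xstart = static_cast<float>(start);
  float xend   = static_cast<float>(end);
  float xstep  = static_cast<float>(step);
  int64_t size_float  = static_cast<int64_t>(std::ceil((xend - xstart) / xstep));

  // Pre-patch CPU-style path (and the patched path on both devices): accumulate in double.
  int64_t size_double = static_cast<int64_t>(std::ceil((end - start) / step));

  std::printf("size via float accumulation:  %lld\n", (long long)size_float);   // 1000000
  std::printf("size via double accumulation: %lld\n", (long long)size_double);  // 1000001
  return 0;
}

With the patch, size_d is derived from the double-precision start/end/step on both CPU and CUDA, so the two devices agree on the number of elements.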

File tree

2 files changed: +32, -2 lines changed


aten/src/ATen/native/RangeFactories.cpp

Lines changed: 16 additions & 1 deletion
@@ -116,14 +116,29 @@ Tensor& arange_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
   auto xend = end.to<accscalar_t>();
   auto xstep = step.to<accscalar_t>();
 
+  // we use double precision for (start - end) / step
+  // to compute size_d for consistency across devices.
+  // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t,
+  // but double on cpu for the same,
+  // and the effective output size starts differing on CPU vs GPU because of precision issues, which
+  // we dont want.
+  // the corner-case we do want to take into account is int64_t, which has higher precision than double
+  double size_d;
+  if (std::is_same<scalar_t, int64_t>::value) {
+    size_d = std::ceil(static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
+                       / step.to<accscalar_t>());
+  } else {
+    size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>())
+                       / step.to<double>());
+  }
+
   AT_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
   AT_CHECK(std::isfinite(static_cast<double>(xstart)) &&
            std::isfinite(static_cast<double>(xend)),
            "unsupported range: ", xstart, " -> ", xend);
   AT_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
            "upper bound and larger bound inconsistent with step sign");
 
-  double size_d = std::ceil(static_cast<double>(xend - xstart) / xstep);
   AT_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
            "invalid size, possible overflow?");
   int64_t size = static_cast<int64_t>(size_d);
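
The int64_t branch above handles the corner case named in the comment: a 64-bit integer can hold values that a double cannot represent exactly, so for scalar_t == int64_t the patch subtracts in the integer accumulation type before converting to double. A small standalone C++ sketch of why that ordering matters (values chosen for illustration, not taken from the commit):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // 2^53 + 1 cannot be represented exactly as a double; 2^53 + 4 can.
  int64_t start = (int64_t{1} << 53) + 1;
  int64_t end   = (int64_t{1} << 53) + 4;
  int64_t step  = 1;

  // Subtract in int64_t first (the patched int64_t branch): the difference, 3, is exact.
  double size_int_first = std::ceil(
      static_cast<double>(end - start) / static_cast<double>(step));

  // Convert to double first (the generic branch): start rounds to 2^53,
  // so the difference becomes 4 instead of 3.
  double size_double_first = std::ceil(
      (static_cast<double>(end) - static_cast<double>(start)) / static_cast<double>(step));

  std::printf("subtract in int64_t first: %.0f\n", size_int_first);     // 3
  std::printf("convert to double first:   %.0f\n", size_double_first);  // 4
  return 0;
}

The same change is applied to the CUDA kernel in the file below.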

aten/src/ATen/native/cuda/RangeFactories.cu

Lines changed: 16 additions & 1 deletion
@@ -136,14 +136,29 @@ Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
   auto xend = end.to<accscalar_t>();
   auto xstep = step.to<accscalar_t>();
 
+  // we use double precision for (start - end) / step
+  // to compute size_d for consistency across devices.
+  // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t,
+  // but double on cpu for the same,
+  // and the effective output size starts differing on CPU vs GPU because of precision issues, which
+  // we dont want.
+  // the corner-case we do want to take into account is int64_t, which has higher precision than double
+  double size_d;
+  if (std::is_same<scalar_t, int64_t>::value) {
+    size_d = std::ceil(static_cast<double>(end.to<accscalar_t>() - start.to<accscalar_t>())
+                       / step.to<accscalar_t>());
+  } else {
+    size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>())
+                       / step.to<double>());
+  }
+
   AT_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
   AT_CHECK(std::isfinite(static_cast<double>(xstart)) &&
            std::isfinite(static_cast<double>(xend)),
            "unsupported range: ", xstart, " -> ", xend);
   AT_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
            "upper bound and larger bound inconsistent with step sign");
 
-  double size_d = std::ceil(static_cast<double>(xend - xstart) / xstep);
   AT_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
            "invalid size, possible overflow?");
   int64_t size = static_cast<int64_t>(size_d);
