Skip to content

gh-90213: Speed up right shifts of negative integers #30277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Lib/test/test_long.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,10 @@ def test_medium_rshift(self):
self.assertEqual((-1122) >> 9, -3)
self.assertEqual(2**128 >> 9, 2**119)
self.assertEqual(-2**128 >> 9, -2**119)
# Exercise corner case of the current algorithm, where the result of
# shifting a two-limb int by the limb size still has two limbs.
self.assertEqual((1 - BASE*BASE) >> SHIFT, -BASE)
self.assertEqual((BASE - 1 - BASE*BASE) >> SHIFT, -BASE)

def test_big_rshift(self):
self.assertEqual(42 >> 32, 0)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Speed up right shift of negative integers, by removing unnecessary creation
of temporaries. Original patch by Xinhang Xu, reworked by Mark Dickinson.
98 changes: 69 additions & 29 deletions Objects/longobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -4677,13 +4677,23 @@ divmod_shift(PyObject *shiftby, Py_ssize_t *wordshift, digit *remshift)
return 0;
}

/* Inner function for both long_rshift and _PyLong_Rshift, shifting an
integer right by PyLong_SHIFT*wordshift + remshift bits.
wordshift should be nonnegative. */

static PyObject *
long_rshift1(PyLongObject *a, Py_ssize_t wordshift, digit remshift)
{
PyLongObject *z = NULL;
Py_ssize_t newsize, hishift, i, j;
Py_ssize_t newsize, hishift, size_a;
twodigits accum;
int a_negative;

/* Total number of bits shifted must be nonnegative. */
assert(wordshift >= 0);
assert(remshift < PyLong_SHIFT);

/* Fast path for small a. */
if (IS_MEDIUM_VALUE(a)) {
stwodigits m, x;
digit shift;
Expand All @@ -4693,37 +4703,67 @@ long_rshift1(PyLongObject *a, Py_ssize_t wordshift, digit remshift)
return _PyLong_FromSTwoDigits(x);
}

if (Py_SIZE(a) < 0) {
/* Right shifting negative numbers is harder */
PyLongObject *a1, *a2;
a1 = (PyLongObject *) long_invert(a);
if (a1 == NULL)
return NULL;
a2 = (PyLongObject *) long_rshift1(a1, wordshift, remshift);
Py_DECREF(a1);
if (a2 == NULL)
return NULL;
z = (PyLongObject *) long_invert(a2);
Py_DECREF(a2);
a_negative = Py_SIZE(a) < 0;
size_a = Py_ABS(Py_SIZE(a));

if (a_negative) {
/* For negative 'a', adjust so that 0 < remshift <= PyLong_SHIFT,
while keeping PyLong_SHIFT*wordshift + remshift the same. This
ensures that 'newsize' is computed correctly below. */
if (remshift == 0) {
if (wordshift == 0) {
/* Can only happen if the original shift was 0. */
return long_long((PyObject *)a);
}
remshift = PyLong_SHIFT;
--wordshift;
}
}
else {
newsize = Py_SIZE(a) - wordshift;
if (newsize <= 0)
return PyLong_FromLong(0);
hishift = PyLong_SHIFT - remshift;
z = _PyLong_New(newsize);
if (z == NULL)
return NULL;
j = wordshift;
accum = a->ob_digit[j++] >> remshift;
for (i = 0; j < Py_SIZE(a); i++, j++) {
accum |= (twodigits)a->ob_digit[j] << hishift;
z->ob_digit[i] = (digit)(accum & PyLong_MASK);
accum >>= PyLong_SHIFT;

assert(wordshift >= 0);
newsize = size_a - wordshift;
if (newsize <= 0) {
/* Shifting all the bits of 'a' out gives either -1 or 0. */
return PyLong_FromLong(-a_negative);
}
z = _PyLong_New(newsize);
if (z == NULL) {
return NULL;
}
hishift = PyLong_SHIFT - remshift;

accum = a->ob_digit[wordshift];
if (a_negative) {
/*
For a positive integer a and nonnegative shift, we have:

(-a) >> shift == -((a + 2**shift - 1) >> shift).

In the addition `a + (2**shift - 1)`, the low `wordshift` digits of
`2**shift - 1` all have value `PyLong_MASK`, so we get a carry out
from the bottom `wordshift` digits when at least one of the least
significant `wordshift` digits of `a` is nonzero. Digit `wordshift`
of `2**shift - 1` has value `PyLong_MASK >> hishift`.
*/
Py_SET_SIZE(z, -newsize);

digit sticky = 0;
for (Py_ssize_t j = 0; j < wordshift; j++) {
sticky |= a->ob_digit[j];
}
z->ob_digit[i] = (digit)accum;
z = maybe_small_long(long_normalize(z));
accum += (PyLong_MASK >> hishift) + (digit)(sticky != 0);
}

accum >>= remshift;
for (Py_ssize_t i = 0, j = wordshift + 1; j < size_a; i++, j++) {
accum += (twodigits)a->ob_digit[j] << hishift;
z->ob_digit[i] = (digit)(accum & PyLong_MASK);
accum >>= PyLong_SHIFT;
}
assert(accum <= PyLong_MASK);
z->ob_digit[newsize - 1] = (digit)accum;

z = maybe_small_long(long_normalize(z));
return (PyObject *)z;
}

Expand Down