Skip to content

Commit e49d497

Browse files
authored
Improve Random{NumberGenerator}.GetItems/String for non-power of 2 choices (#107988)
In .NET 9, we added an optimization to Random.GetItems and RandomNumberGenerator.GetItems/GetString that special-cases a power-of-2 number of choices that's <= 256. In such a case, we can avoid many trips to the RNG by requesting bytes in bulk, rather than requesting an Int32 per element. Each byte is masked to produce the index into the choices. This PR extends that optimization to also cover non-power-of-2 choices. It can't just mask off the bits as in the power-of-2 case, but it can mask off some bits and then do rejection sampling, which on average still yields big wins.
1 parent 7a31c17 commit e49d497

File tree

3 files changed

+230
-52
lines changed

3 files changed

+230
-52
lines changed

src/libraries/System.Private.CoreLib/src/System/Random.cs

Lines changed: 74 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -197,46 +197,94 @@ public void GetItems<T>(ReadOnlySpan<T> choices, Span<T> destination)
197197
throw new ArgumentException(SR.Arg_EmptySpan, nameof(choices));
198198
}
199199

200-
// The most expensive part of this operation is the call to get random data. We can
201-
// do so potentially many fewer times if:
202-
// - the instance was constructed as `new Random()` or is `Random.Shared`, such that it's not seeded nor is it
203-
// a custom derived type. We don't want to observably change the deterministically-produced sequence from previous releases.
204-
// - the number of choices is <= 256. This lets us get a single byte per choice.
205-
// - the number of choices is a power of two. This lets us use a byte and simply mask off
206-
// unnecessary bits cheaply rather than needing to use rejection sampling.
207-
// In such a case, we can grab a bunch of random bytes in one call.
200+
// The most expensive part of this operation is the call to get random data. If the number of
201+
// choices is <= 256 (which is the majority use case), we can use a single byte per element,
202+
// which means we can amortize the cost of getting random data by getting random bytes in bulk.
203+
// However, we can only do that if this instance is Random.Shared or an instance created with
204+
// `new Random()`. If it was created with a seed, changing which members we call and how many
205+
// times may result in a visible difference in the sequence of output items. Similarly if it's
206+
// a derived instance, which overrides get called, and when, is observable.
208207
ImplBase impl = _impl;
209208
if ((impl is null || impl.GetType() == typeof(XoshiroImpl)) &&
210-
BitOperations.IsPow2(choices.Length) &&
211209
choices.Length <= 256)
212210
{
213-
Span<byte> randomBytes = stackalloc byte[512]; // arbitrary size, a balance between stack consumed and number of random calls required
214-
while (!destination.IsEmpty)
211+
// Get stack space to store random bytes. This size was chosen to balance between
212+
// stack consumed and number of random calls required.
213+
Span<byte> randomBytes = stackalloc byte[512];
214+
215+
if (BitOperations.IsPow2(choices.Length))
215216
{
216-
if (destination.Length < randomBytes.Length)
217+
// To avoid bias, we can't just % all bytes to get them into range; that would cause
218+
// the lower values to be more likely than the higher values. If the number of choices
219+
// is a power of 2, though, we can just mask off the extraneous bits.
220+
221+
int mask = choices.Length - 1;
222+
223+
while (!destination.IsEmpty)
217224
{
218-
randomBytes = randomBytes.Slice(0, destination.Length);
225+
// If this will be the last iteration, avoid over-requesting randomness.
226+
if (destination.Length < randomBytes.Length)
227+
{
228+
randomBytes = randomBytes.Slice(0, destination.Length);
229+
}
230+
231+
NextBytes(randomBytes);
232+
233+
for (int i = 0; i < randomBytes.Length; i++)
234+
{
235+
destination[i] = choices[randomBytes[i] & mask];
236+
}
237+
238+
destination = destination.Slice(randomBytes.Length);
219239
}
240+
}
241+
else
242+
{
243+
// As the length isn't a power of two, we can't just mask off all extraneous bits, and
244+
// instead need to do rejection sampling. However, we can mask off the irrelevant bits, which
245+
// then reduces the chances of needing to reject a value.
220246

221-
NextBytes(randomBytes);
247+
int mask = (int)BitOperations.RoundUpToPowerOf2((uint)choices.Length) - 1;
222248

223-
int mask = choices.Length - 1;
224-
for (int i = 0; i < randomBytes.Length; i++)
249+
while (!destination.IsEmpty)
225250
{
226-
destination[i] = choices[randomBytes[i] & mask];
251+
// Unlike in the IsPow2 case, where every byte will be used, some bytes here may
252+
// be rejected. In the worst case, nearly half the bytes may be rejected, so we heuristically
253+
// choose to shrink to twice the destination length.
254+
if (destination.Length * 2 < randomBytes.Length)
255+
{
256+
randomBytes = randomBytes.Slice(0, destination.Length * 2);
257+
}
258+
259+
NextBytes(randomBytes);
260+
261+
int i = 0;
262+
foreach (byte b in randomBytes)
263+
{
264+
if ((uint)i >= (uint)destination.Length)
265+
{
266+
break;
267+
}
268+
269+
byte masked = (byte)(b & mask);
270+
if (masked < (uint)choices.Length)
271+
{
272+
destination[i++] = choices[masked];
273+
}
274+
}
275+
276+
destination = destination.Slice(i);
227277
}
228-
229-
destination = destination.Slice(randomBytes.Length);
230278
}
231-
232-
return;
233279
}
234-
235-
// Simple fallback: get each item individually, generating a new random Int32 for each
236-
// item. This is slower than the above, but it works for all types and sizes of choices.
237-
for (int i = 0; i < destination.Length; i++)
280+
else
238281
{
239-
destination[i] = choices[Next(choices.Length)];
282+
// Simple fallback: get each item individually, generating a new random Int32 for each
283+
// item. This is slower than the above, but it works for all types and sizes of choices.
284+
for (int i = 0; i < destination.Length; i++)
285+
{
286+
destination[i] = choices[Next(choices.Length)];
287+
}
240288
}
241289
}
242290

src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Random.cs

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ public static void GetItems_Buffer_ArgValidation()
792792
}
793793

794794
[Fact]
795-
public static void GetItems_Allocating_Array_Seeded()
795+
public static void GetItems_Allocating_Array_Seeded_NonPower2()
796796
{
797797
Random random = new Random(0x70636A61);
798798
byte[] items = new byte[] { 1, 2, 3 };
@@ -808,7 +808,7 @@ public static void GetItems_Allocating_Array_Seeded()
808808
}
809809

810810
[Fact]
811-
public static void GetItems_Allocating_Span_Seeded()
811+
public static void GetItems_Allocating_Span_Seeded_NonPower2()
812812
{
813813
Random random = new Random(0x70636A61);
814814
ReadOnlySpan<byte> items = new byte[] { 1, 2, 3 };
@@ -824,7 +824,7 @@ public static void GetItems_Allocating_Span_Seeded()
824824
}
825825

826826
[Fact]
827-
public static void GetItems_Buffer_Seeded()
827+
public static void GetItems_Buffer_Seeded_NonPower2()
828828
{
829829
Random random = new Random(0x70636A61);
830830
ReadOnlySpan<byte> items = new byte[] { 1, 2, 3 };
@@ -840,6 +840,90 @@ public static void GetItems_Buffer_Seeded()
840840
AssertExtensions.SequenceEqual(new byte[] { 1, 1, 3, 1, 3, 2, 2 }, buffer);
841841
}
842842

843+
[Fact]
844+
public static void GetItems_Allocating_Array_Seeded_Power2()
845+
{
846+
Random random = new Random(0x70636A61);
847+
byte[] items = new byte[] { 1, 2, 3, 4 };
848+
849+
byte[] result = random.GetItems(items, length: 7);
850+
Assert.Equal(new byte[] { 4, 1, 4, 2, 4, 4, 4 }, result);
851+
852+
result = random.GetItems(items, length: 7);
853+
Assert.Equal(new byte[] { 2, 2, 3, 1, 3, 3, 1 }, result);
854+
855+
result = random.GetItems(items, length: 7);
856+
Assert.Equal(new byte[] { 2, 1, 4, 2, 4, 2, 2 }, result);
857+
}
858+
859+
[Fact]
860+
public static void GetItems_Allocating_Span_Seeded_Power2()
861+
{
862+
Random random = new Random(0x70636A61);
863+
ReadOnlySpan<byte> items = new byte[] { 1, 2, 3, 4 };
864+
865+
byte[] result = random.GetItems(items, length: 7);
866+
Assert.Equal(new byte[] { 4, 1, 4, 2, 4, 4, 4 }, result);
867+
868+
result = random.GetItems(items, length: 7);
869+
Assert.Equal(new byte[] { 2, 2, 3, 1, 3, 3, 1 }, result);
870+
871+
result = random.GetItems(items, length: 7);
872+
Assert.Equal(new byte[] { 2, 1, 4, 2, 4, 2, 2 }, result);
873+
}
874+
875+
[Fact]
876+
public static void GetItems_Buffer_Seeded_Power2()
877+
{
878+
Random random = new Random(0x70636A61);
879+
ReadOnlySpan<byte> items = new byte[] { 1, 2, 3, 4 };
880+
881+
Span<byte> buffer = stackalloc byte[7];
882+
random.GetItems(items, buffer);
883+
AssertExtensions.SequenceEqual(new byte[] { 4, 1, 4, 2, 4, 4, 4 }, buffer);
884+
885+
random.GetItems(items, buffer);
886+
AssertExtensions.SequenceEqual(new byte[] { 2, 2, 3, 1, 3, 3, 1 }, buffer);
887+
888+
random.GetItems(items, buffer);
889+
AssertExtensions.SequenceEqual(new byte[] { 2, 1, 4, 2, 4, 2, 2 }, buffer);
890+
}
891+
892+
[Theory]
893+
[InlineData(0)]
894+
[InlineData(1)]
895+
[InlineData(2)]
896+
[InlineData(3)]
897+
[InlineData(4)]
898+
public static void GetItems_AllValuesInRange(int mode)
899+
{
900+
Random random = mode switch
901+
{
902+
0 => new Random(),
903+
1 => new Random(42),
904+
2 => new SubRandom(),
905+
3 => new SubRandom(42),
906+
_ => Random.Shared,
907+
};
908+
909+
foreach (int numItems in Enumerable.Range(1, 8).Append(300))
910+
{
911+
int[] items = Enumerable.Range(42, numItems).ToArray();
912+
for (int length = 1; length <= 16; length++)
913+
{
914+
int[] result = random.GetItems(items, length: length);
915+
Assert.All(result, b => Assert.InRange(b, 42, 42 + numItems - 1));
916+
917+
result = random.GetItems((ReadOnlySpan<int>)items, length: length);
918+
Assert.All(result, b => Assert.InRange(b, 42, 42 + numItems - 1));
919+
920+
Array.Clear(result);
921+
random.GetItems(items, (Span<int>)result);
922+
Assert.All(result, b => Assert.InRange(b, 42, 42 + numItems - 1));
923+
}
924+
}
925+
}
926+
843927
private static Random Create(bool derived, bool seeded) =>
844928
(derived, seeded) switch
845929
{

src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/RandomNumberGenerator.cs

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -346,44 +346,90 @@ private static void GetHexStringCore(Span<char> destination, bool lowercase)
346346

347347
private static void GetItemsCore<T>(ReadOnlySpan<T> choices, Span<T> destination)
348348
{
349-
// The most expensive part of this operation is the call to get random data. We can
350-
// do so potentially many fewer times if:
351-
// - the number of choices is <= 256. This lets us get a single byte per choice.
352-
// - the number of choices is a power of two. This lets us use a byte and simply mask off
353-
// unnecessary bits cheaply rather than needing to use rejection sampling.
354-
// In such a case, we can grab a bunch of random bytes in one call.
355-
if (BitOperations.IsPow2(choices.Length) && choices.Length <= 256)
349+
Debug.Assert(choices.Length > 0);
350+
351+
// The most expensive part of this operation is the call to get random data. If the number of
352+
// choices is <= 256 (which is the majority use case), we can use a single byte per element,
353+
// which means we can amortize the cost of getting random data by getting random bytes in bulk.
354+
if (choices.Length <= 256)
356355
{
357356
// Get stack space to store random bytes. This size was chosen to balance between
358357
// stack consumed and number of random calls required.
359358
Span<byte> randomBytes = stackalloc byte[512];
360359

361-
while (!destination.IsEmpty)
360+
if (BitOperations.IsPow2(choices.Length))
362361
{
363-
if (destination.Length < randomBytes.Length)
362+
// To avoid bias, we can't just % all bytes to get them into range; that would cause
363+
// the lower values to be more likely than the higher values. If the number of choices
364+
// is a power of 2, though, we can just mask off the extraneous bits.
365+
366+
int mask = choices.Length - 1;
367+
368+
while (!destination.IsEmpty)
364369
{
365-
randomBytes = randomBytes.Slice(0, destination.Length);
370+
// If this will be the last iteration, avoid over-requesting randomness.
371+
if (destination.Length < randomBytes.Length)
372+
{
373+
randomBytes = randomBytes.Slice(0, destination.Length);
374+
}
375+
376+
RandomNumberGeneratorImplementation.FillSpan(randomBytes);
377+
378+
for (int i = 0; i < randomBytes.Length; i++)
379+
{
380+
destination[i] = choices[randomBytes[i] & mask];
381+
}
382+
383+
destination = destination.Slice(randomBytes.Length);
366384
}
385+
}
386+
else
387+
{
388+
// As the length isn't a power of two, we can't just mask off all extraneous bits, and
389+
// instead need to do rejection sampling. However, we can mask off the irrelevant bits, which
390+
// then reduces the chances of needing to reject a value.
367391

368-
RandomNumberGeneratorImplementation.FillSpan(randomBytes);
392+
int mask = (int)BitOperations.RoundUpToPowerOf2((uint)choices.Length) - 1;
369393

370-
int mask = choices.Length - 1;
371-
for (int i = 0; i < randomBytes.Length; i++)
394+
while (!destination.IsEmpty)
372395
{
373-
destination[i] = choices[randomBytes[i] & mask];
396+
// Unlike in the IsPow2 case, where every byte will be used, some bytes here may
397+
// be rejected. In the worst case, nearly half the bytes may be rejected, so we heuristically
398+
// choose to shrink to twice the destination length.
399+
if (destination.Length * 2 < randomBytes.Length)
400+
{
401+
randomBytes = randomBytes.Slice(0, destination.Length * 2);
402+
}
403+
404+
RandomNumberGeneratorImplementation.FillSpan(randomBytes);
405+
406+
int i = 0;
407+
foreach (byte b in randomBytes)
408+
{
409+
if ((uint)i >= (uint)destination.Length)
410+
{
411+
break;
412+
}
413+
414+
byte masked = (byte)(b & mask);
415+
if (masked < (uint)choices.Length)
416+
{
417+
destination[i++] = choices[masked];
418+
}
419+
}
420+
421+
destination = destination.Slice(i);
374422
}
375-
376-
destination = destination.Slice(randomBytes.Length);
377423
}
378-
379-
return;
380424
}
381-
382-
// Simple fallback: get each item individually, generating a new random Int32 for each
383-
// item. This is slower than the above, but it works for all types and sizes of choices.
384-
for (int i = 0; i < destination.Length; i++)
425+
else
385426
{
386-
destination[i] = choices[GetInt32(choices.Length)];
427+
// Simple fallback: get each item individually, generating a new random Int32 for each
428+
// item. This is slower than the above, but it works for all types and sizes of choices.
429+
for (int i = 0; i < destination.Length; i++)
430+
{
431+
destination[i] = choices[GetInt32(choices.Length)];
432+
}
387433
}
388434
}
389435

0 commit comments

Comments
 (0)