Skip to content

Commit 0710ecc

Browse files
zhuqi-lucas and alamb authored
Improve StringArray(Utf8) sort performance (~2-4x faster) (#7860)
# Which issue does this PR close? Improve StringArray(Utf8) sort performance - Closes [#7847](#7847) # Rationale for this change Support prefix comparison. I optimized it to use a u32 prefix and a u64 increment compare, which gave the best performance in my experiments. # What changes are included in this PR? Support prefix comparison. I optimized it to use a u32 prefix and a u64 increment compare, which gave the best performance in my experiments. # Are these changes tested? Yes ```rust critcmp issue_7847 main --filter "sort string\[" group issue_7847 main ----- ---------- ---- sort string[0-400] nulls to indices 2^12 1.00 51.4±0.56µs ? ?/sec 1.19 61.0±1.02µs ? ?/sec sort string[0-400] to indices 2^12 1.00 96.5±1.63µs ? ?/sec 1.23 118.3±0.91µs ? ?/sec sort string[10] dict nulls to indices 2^12 1.00 72.4±1.00µs ? ?/sec 1.00 72.5±0.61µs ? ?/sec sort string[10] dict to indices 2^12 1.00 137.1±1.51µs ? ?/sec 1.01 138.1±1.06µs ? ?/sec sort string[10] nulls to indices 2^12 1.00 47.5±0.69µs ? ?/sec 1.18 56.3±0.56µs ? ?/sec sort string[10] to indices 2^12 1.00 86.4±1.37µs ? ?/sec 1.20 103.5±1.13µs ? ?/sec ``` # Are there any user-facing changes? If there are user-facing changes then we may require documentation to be updated before approving the PR. If there are any breaking changes to public APIs, please call them out. --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent e4d359b commit 0710ecc

File tree

1 file changed

+377
-4
lines changed

1 file changed

+377
-4
lines changed

arrow-ord/src/sort.rs

Lines changed: 377 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,12 +345,88 @@ fn sort_bytes<T: ByteArrayType>(
345345
options: SortOptions,
346346
limit: Option<usize>,
347347
) -> UInt32Array {
348-
let mut valids = value_indices
348+
// Note: Why do we use 4‑byte prefix?
349+
// We compute the 4-byte prefix in BE order, zero-padding on the right if shorter.
350+
// Most byte‐sequences differ in their first few bytes, so by
351+
// comparing up to 4 bytes as a single u32 we avoid the overhead
352+
// of a full lexicographical compare for the vast majority of cases.
353+
354+
// 1. Build a vector of (index, prefix, length) tuples
355+
let mut valids: Vec<(u32, u32, u64)> = value_indices
349356
.into_iter()
350-
.map(|index| (index, values.value(index as usize).as_ref()))
351-
.collect::<Vec<(u32, &[u8])>>();
357+
.map(|idx| unsafe {
358+
let slice: &[u8] = values.value_unchecked(idx as usize).as_ref();
359+
let len = slice.len() as u64;
360+
// Compute the 4-byte prefix in BE order, zero-padding on the right if shorter
361+
let prefix = if slice.len() >= 4 {
362+
let raw = std::ptr::read_unaligned(slice.as_ptr() as *const u32);
363+
u32::from_be(raw)
364+
} else if slice.is_empty() {
365+
// Handle empty slice case to avoid shift overflow
366+
0u32
367+
} else {
368+
let mut v = 0u32;
369+
for &b in slice {
370+
v = (v << 8) | (b as u32);
371+
}
372+
// Safe shift: slice.len() is in range [1, 3], so shift is in range [8, 24]
373+
v << (8 * (4 - slice.len()))
374+
};
375+
(idx, prefix, len)
376+
})
377+
.collect();
352378

353-
sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into()
379+
// 2. compute the number of non-null entries to partially sort
380+
let vlimit = match (limit, options.nulls_first) {
381+
(Some(l), true) => l.saturating_sub(nulls.len()).min(valids.len()),
382+
_ => valids.len(),
383+
};
384+
385+
// 3. Comparator: compare prefix, then length (when either slice is shorter than 4 bytes), then the full slice
386+
let cmp_bytes = |a: &(u32, u32, u64), b: &(u32, u32, u64)| unsafe {
387+
let (ia, pa, la) = *a;
388+
let (ib, pb, lb) = *b;
389+
// 3.1 prefix (first 4 bytes)
390+
let ord = pa.cmp(&pb);
391+
if ord != Ordering::Equal {
392+
return ord;
393+
}
394+
// 3.2 length, only if either slice had length < 4 (so its prefix was zero-padded)
395+
if la < 4 || lb < 4 {
396+
let ord = la.cmp(&lb);
397+
if ord != Ordering::Equal {
398+
return ord;
399+
}
400+
}
401+
// 3.3 full lexicographical compare
402+
let a_bytes: &[u8] = values.value_unchecked(ia as usize).as_ref();
403+
let b_bytes: &[u8] = values.value_unchecked(ib as usize).as_ref();
404+
a_bytes.cmp(b_bytes)
405+
};
406+
407+
// 4. Partially sort according to ascending/descending
408+
if !options.descending {
409+
sort_unstable_by(&mut valids, vlimit, cmp_bytes);
410+
} else {
411+
sort_unstable_by(&mut valids, vlimit, |x, y| cmp_bytes(x, y).reverse());
412+
}
413+
414+
// 5. Assemble nulls and sorted indices into final output
415+
let total = valids.len() + nulls.len();
416+
let out_limit = limit.unwrap_or(total).min(total);
417+
let mut out = Vec::with_capacity(out_limit);
418+
419+
if options.nulls_first {
420+
out.extend_from_slice(&nulls[..nulls.len().min(out_limit)]);
421+
let rem = out_limit - out.len();
422+
out.extend(valids.iter().map(|&(i, _, _)| i).take(rem));
423+
} else {
424+
out.extend(valids.iter().map(|&(i, _, _)| i).take(out_limit));
425+
let rem = out_limit - out.len();
426+
out.extend_from_slice(&nulls[..rem]);
427+
}
428+
429+
out.into()
354430
}
355431

356432
fn sort_byte_view<T: ByteViewType>(
@@ -4841,4 +4917,301 @@ mod tests {
48414917
assert_eq!(valid, vec![0, 2]);
48424918
assert_eq!(nulls, vec![1, 3]);
48434919
}
4920+
4921+
// Test specific edge case strings that exercise the 4-byte prefix logic
4922+
#[test]
4923+
fn test_specific_edge_cases() {
4924+
let test_cases = vec![
4925+
// Key test cases for lengths 1-4 that test prefix padding
4926+
"a", "ab", "ba", "baa", "abba", "abbc", "abc", "cda",
4927+
// Test cases where first 4 bytes are same but subsequent bytes differ
4928+
"abcd", "abcde", "abcdf", "abcdaaa", "abcdbbb",
4929+
// Test cases with length < 4 that require padding
4930+
"z", "za", "zaa", "zaaa", "zaaab",
4931+
"", // Empty string (the cases below vary the length with the same prefix)
4932+
"test", "test1", "test12", "test123", "test1234",
4933+
];
4934+
4935+
// Use standard library sort as reference
4936+
let mut expected = test_cases.clone();
4937+
expected.sort();
4938+
4939+
// Use our sorting algorithm
4940+
let string_array = StringArray::from(test_cases.clone());
4941+
let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
4942+
let result = sort_bytes(
4943+
&string_array,
4944+
indices,
4945+
vec![], // no nulls
4946+
SortOptions::default(),
4947+
None,
4948+
);
4949+
4950+
// Verify results
4951+
let sorted_strings: Vec<&str> = result
4952+
.values()
4953+
.iter()
4954+
.map(|&idx| test_cases[idx as usize])
4955+
.collect();
4956+
4957+
assert_eq!(sorted_strings, expected);
4958+
}
4959+
4960+
// Test sorting correctness for different length combinations
4961+
#[test]
4962+
fn test_length_combinations() {
4963+
let test_cases = vec![
4964+
// Focus on testing strings of length 1-4, as these affect padding logic
4965+
("", 0),
4966+
("a", 1),
4967+
("ab", 2),
4968+
("abc", 3),
4969+
("abcd", 4),
4970+
("abcde", 5),
4971+
("b", 1),
4972+
("ba", 2),
4973+
("bab", 3),
4974+
("babc", 4),
4975+
("babcd", 5),
4976+
// Test same prefix with different lengths
4977+
("test", 4),
4978+
("test1", 5),
4979+
("test12", 6),
4980+
("test123", 7),
4981+
];
4982+
4983+
let strings: Vec<&str> = test_cases.iter().map(|(s, _)| *s).collect();
4984+
let mut expected = strings.clone();
4985+
expected.sort();
4986+
4987+
let string_array = StringArray::from(strings.clone());
4988+
let indices: Vec<u32> = (0..strings.len() as u32).collect();
4989+
let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None);
4990+
4991+
let sorted_strings: Vec<&str> = result
4992+
.values()
4993+
.iter()
4994+
.map(|&idx| strings[idx as usize])
4995+
.collect();
4996+
4997+
assert_eq!(sorted_strings, expected);
4998+
}
4999+
5000+
// Test UTF-8 string handling
5001+
#[test]
5002+
fn test_utf8_strings() {
5003+
let test_cases = vec![
5004+
"a",
5005+
"你", // 3-byte UTF-8 character
5006+
"你好", // 6 bytes
5007+
"你好世界", // 12 bytes
5008+
"🎉", // 4-byte emoji
5009+
"🎉🎊", // 8 bytes
5010+
"café", // Contains accent character
5011+
"naïve",
5012+
"Москва", // Cyrillic script
5013+
"東京", // Japanese kanji
5014+
"한국", // Korean
5015+
];
5016+
5017+
let mut expected = test_cases.clone();
5018+
expected.sort();
5019+
5020+
let string_array = StringArray::from(test_cases.clone());
5021+
let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
5022+
let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None);
5023+
5024+
let sorted_strings: Vec<&str> = result
5025+
.values()
5026+
.iter()
5027+
.map(|&idx| test_cases[idx as usize])
5028+
.collect();
5029+
5030+
assert_eq!(sorted_strings, expected);
5031+
}
5032+
5033+
// Fuzz testing: generate random UTF-8 strings and verify sort correctness
5034+
#[test]
5035+
fn test_fuzz_random_strings() {
5036+
let mut rng = StdRng::seed_from_u64(42); // Fixed seed for reproducibility
5037+
5038+
for _ in 0..100 {
5039+
// Run 100 rounds of fuzz testing
5040+
let mut test_strings = Vec::new();
5041+
5042+
// Generate 20-50 random strings
5043+
let num_strings = rng.random_range(20..=50);
5044+
5045+
for _ in 0..num_strings {
5046+
let string = generate_random_string(&mut rng);
5047+
test_strings.push(string);
5048+
}
5049+
5050+
// Use standard library sort as reference
5051+
let mut expected = test_strings.clone();
5052+
expected.sort();
5053+
5054+
// Use our sorting algorithm
5055+
let string_array = StringArray::from(test_strings.clone());
5056+
let indices: Vec<u32> = (0..test_strings.len() as u32).collect();
5057+
let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None);
5058+
5059+
let sorted_strings: Vec<String> = result
5060+
.values()
5061+
.iter()
5062+
.map(|&idx| test_strings[idx as usize].clone())
5063+
.collect();
5064+
5065+
assert_eq!(
5066+
sorted_strings, expected,
5067+
"Fuzz test failed with input: {test_strings:?}"
5068+
);
5069+
}
5070+
}
5071+
5072+
// Helper function to generate random UTF-8 strings
5073+
fn generate_random_string(rng: &mut StdRng) -> String {
5074+
// Bias towards generating short strings, especially length 1-4
5075+
let length = if rng.random_bool(0.6) {
5076+
rng.random_range(0..=4) // 60% probability for 0-4 length strings
5077+
} else {
5078+
rng.random_range(5..=20) // 40% probability for longer strings
5079+
};
5080+
5081+
if length == 0 {
5082+
return String::new();
5083+
}
5084+
5085+
let mut result = String::new();
5086+
let mut current_len = 0;
5087+
5088+
while current_len < length {
5089+
let c = generate_random_char(rng);
5090+
let char_len = c.len_utf8();
5091+
5092+
// Ensure we don't exceed target length
5093+
if current_len + char_len <= length {
5094+
result.push(c);
5095+
current_len += char_len;
5096+
} else {
5097+
// If adding this character would exceed length, fill with ASCII
5098+
let remaining = length - current_len;
5099+
for _ in 0..remaining {
5100+
result.push(rng.random_range('a'..='z'));
5101+
current_len += 1;
5102+
}
5103+
break;
5104+
}
5105+
}
5106+
5107+
result
5108+
}
5109+
5110+
// Generate random characters (including various UTF-8 characters)
5111+
fn generate_random_char(rng: &mut StdRng) -> char {
5112+
match rng.random_range(0..10) {
5113+
0..=5 => rng.random_range('a'..='z'), // 60% ASCII lowercase
5114+
6 => rng.random_range('A'..='Z'), // 10% ASCII uppercase
5115+
7 => rng.random_range('0'..='9'), // 10% digits
5116+
8 => {
5117+
// 10% Chinese characters
5118+
let chinese_chars = ['你', '好', '世', '界', '测', '试', '中', '文'];
5119+
chinese_chars[rng.random_range(0..chinese_chars.len())]
5120+
}
5121+
9 => {
5122+
// 10% other Unicode characters (single `char`s)
5123+
let special_chars = ['é', 'ï', '🎉', '🎊', 'α', 'β', 'γ'];
5124+
special_chars[rng.random_range(0..special_chars.len())]
5125+
}
5126+
_ => unreachable!(),
5127+
}
5128+
}
5129+
5130+
// Test descending sort order
5131+
#[test]
5132+
fn test_descending_sort() {
5133+
let test_cases = vec!["a", "ab", "ba", "baa", "abba", "abbc", "abc", "cda"];
5134+
5135+
let mut expected = test_cases.clone();
5136+
expected.sort();
5137+
expected.reverse(); // Descending order
5138+
5139+
let string_array = StringArray::from(test_cases.clone());
5140+
let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
5141+
let result = sort_bytes(
5142+
&string_array,
5143+
indices,
5144+
vec![],
5145+
SortOptions {
5146+
descending: true,
5147+
nulls_first: false,
5148+
},
5149+
None,
5150+
);
5151+
5152+
let sorted_strings: Vec<&str> = result
5153+
.values()
5154+
.iter()
5155+
.map(|&idx| test_cases[idx as usize])
5156+
.collect();
5157+
5158+
assert_eq!(sorted_strings, expected);
5159+
}
5160+
5161+
// Stress test: large number of strings with same prefix
5162+
#[test]
5163+
fn test_same_prefix_stress() {
5164+
let mut test_cases = Vec::new();
5165+
let prefix = "same";
5166+
5167+
// Generate many strings with the same prefix
5168+
for i in 0..1000 {
5169+
test_cases.push(format!("{prefix}{i:04}"));
5170+
}
5171+
5172+
let mut expected = test_cases.clone();
5173+
expected.sort();
5174+
5175+
let string_array = StringArray::from(test_cases.clone());
5176+
let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
5177+
let result = sort_bytes(&string_array, indices, vec![], SortOptions::default(), None);
5178+
5179+
let sorted_strings: Vec<String> = result
5180+
.values()
5181+
.iter()
5182+
.map(|&idx| test_cases[idx as usize].clone())
5183+
.collect();
5184+
5185+
assert_eq!(sorted_strings, expected);
5186+
}
5187+
5188+
// Test limit parameter
5189+
#[test]
5190+
fn test_with_limit() {
5191+
let test_cases = vec!["z", "y", "x", "w", "v", "u", "t", "s"];
5192+
let limit = 3;
5193+
5194+
let mut expected = test_cases.clone();
5195+
expected.sort();
5196+
expected.truncate(limit);
5197+
5198+
let string_array = StringArray::from(test_cases.clone());
5199+
let indices: Vec<u32> = (0..test_cases.len() as u32).collect();
5200+
let result = sort_bytes(
5201+
&string_array,
5202+
indices,
5203+
vec![],
5204+
SortOptions::default(),
5205+
Some(limit),
5206+
);
5207+
5208+
let sorted_strings: Vec<&str> = result
5209+
.values()
5210+
.iter()
5211+
.map(|&idx| test_cases[idx as usize])
5212+
.collect();
5213+
5214+
assert_eq!(sorted_strings, expected);
5215+
assert_eq!(sorted_strings.len(), limit);
5216+
}
48445217
}

0 commit comments

Comments
 (0)