Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion datafusion/common/src/utils/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

//! [`VecAllocExt`] and [`RawTableAllocExt`] to help tracking of memory allocations

use hashbrown::raw::{Bucket, RawTable};
use hashbrown::{
hash_table::HashTable,
raw::{Bucket, RawTable},
};
use std::mem::size_of;

/// Extension trait for [`Vec`] to account for allocations.
Expand Down Expand Up @@ -173,3 +176,71 @@ impl<T> RawTableAllocExt for RawTable<T> {
}
}
}

/// Extension trait for hash browns [`HashTable`] to account for allocations.
pub trait HashTableAllocExt {
/// Item type.
type T;

/// Insert new element into table and increase
/// `accounting` by any newly allocated bytes.
///
/// Returns the bucket where the element was inserted.
/// Note that allocation counts capacity, not size.
///
/// # Example:
/// ```
/// # use datafusion_common::utils::proxy::HashTableAllocExt;
/// # use hashbrown::hash_table::HashTable;
/// let mut table = HashTable::new();
/// let mut allocated = 0;
/// let hash_fn = |x: &u32| (*x as u64) % 1000;
/// // pretend 0x3117 is the hash value for 1
/// table.insert_accounted(1, hash_fn, &mut allocated);
/// assert_eq!(allocated, 64);
///
/// // insert more values
/// for i in 0..100 { table.insert_accounted(i, hash_fn, &mut allocated); }
/// assert_eq!(allocated, 400);
/// ```
fn insert_accounted(
&mut self,
x: Self::T,
hasher: impl Fn(&Self::T) -> u64,
accounting: &mut usize,
);
}

impl<T> HashTableAllocExt for HashTable<T>
where
T: Eq,
{
type T = T;

fn insert_accounted(
&mut self,
x: Self::T,
hasher: impl Fn(&Self::T) -> u64,
accounting: &mut usize,
) {
let hash = hasher(&x);

// NOTE: `find_entry` does NOT grow!
match self.find_entry(hash, |y| y == &x) {
Ok(_occupied) => {}
Err(_absent) => {
if self.len() == self.capacity() {
// need to request more memory
let bump_elements = self.capacity().max(16);
let bump_size = bump_elements * size_of::<T>();
*accounting = (*accounting).checked_add(bump_size).expect("overflow");

self.reserve(bump_elements, &hasher);
}

// still need to insert the element since first try failed
self.entry(hash, |y| y == &x, hasher).insert(x);
}
}
}
}
4 changes: 3 additions & 1 deletion datafusion/execution/src/memory_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ use std::{cmp::Ordering, sync::Arc};

mod pool;
pub mod proxy {
pub use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
pub use datafusion_common::utils::proxy::{
HashTableAllocExt, RawTableAllocExt, VecAllocExt,
};
}

pub use pool::*;
Expand Down
10 changes: 5 additions & 5 deletions datafusion/physical-expr-common/src/binary_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use arrow::array::{
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow::datatypes::DataType;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
use std::any::type_name;
use std::fmt::Debug;
use std::mem::{size_of, swap};
Expand Down Expand Up @@ -215,7 +215,7 @@ where
/// Should the output be String or Binary?
output_type: OutputType,
/// Underlying hash set for each distinct value
map: hashbrown::raw::RawTable<Entry<O, V>>,
map: hashbrown::hash_table::HashTable<Entry<O, V>>,
/// Total size of the map in bytes
map_size: usize,
/// In progress arrow `Buffer` containing all values
Expand Down Expand Up @@ -246,7 +246,7 @@ where
pub fn new(output_type: OutputType) -> Self {
Self {
output_type,
map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY),
map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY),
map_size: 0,
buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY),
offsets: vec![O::default()], // first offset is always 0
Expand Down Expand Up @@ -387,7 +387,7 @@ where
let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize);

// is value is already present in the set?
let entry = self.map.get_mut(hash, |header| {
let entry = self.map.find_mut(hash, |header| {
// compare value if hashes match
if header.len != value_len {
return false;
Expand Down Expand Up @@ -425,7 +425,7 @@ where
// value is not "small"
else {
// Check if the value is already present in the set
let entry = self.map.get_mut(hash, |header| {
let entry = self.map.find_mut(hash, |header| {
// compare value if hashes match
if header.len != value_len {
return false;
Expand Down
8 changes: 4 additions & 4 deletions datafusion/physical-expr-common/src/binary_view_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::array::cast::AsArray;
use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder};
use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
use std::fmt::Debug;
use std::sync::Arc;

Expand Down Expand Up @@ -122,7 +122,7 @@ where
/// Should the output be StringView or BinaryView?
output_type: OutputType,
/// Underlying hash set for each distinct value
map: hashbrown::raw::RawTable<Entry<V>>,
map: hashbrown::hash_table::HashTable<Entry<V>>,
/// Total size of the map in bytes
map_size: usize,

Expand All @@ -148,7 +148,7 @@ where
pub fn new(output_type: OutputType) -> Self {
Self {
output_type,
map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY),
map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY),
map_size: 0,
builder: GenericByteViewBuilder::new(),
random_state: RandomState::new(),
Expand Down Expand Up @@ -274,7 +274,7 @@ where
// get the value as bytes
let value: &[u8] = value.as_ref();

let entry = self.map.get_mut(hash, |header| {
let entry = self.map.find_mut(hash, |header| {
let v = self.builder.get_value(header.view_idx);

if v.len() != value.len() {
Expand Down