use hashbrown::{
hash_table::HashTable,
raw::{Bucket, RawTable},
};
use std::mem::size_of;
/// Extension methods for [`Vec`] that report how many bytes of element
/// capacity the vector holds, so callers can keep a running memory tally.
pub trait VecAllocExt {
    /// The element type of the container.
    type T;

    /// Appends `x`. If the push made the vector move to a larger capacity,
    /// adds the extra capacity in bytes
    /// (`(capacity_after - capacity_before) * size_of::<T>()`) to `accounting`.
    ///
    /// `accounting` is left untouched when the push fits in the existing
    /// capacity.
    ///
    /// # Panics
    ///
    /// Panics with `"overflow"` if the new total would not fit in a `usize`.
    fn push_accounted(&mut self, x: Self::T, accounting: &mut usize);

    /// Returns the number of bytes covered by the vector's capacity, i.e.
    /// `capacity() * size_of::<T>()`.
    ///
    /// Unused capacity is included; the `Vec` header itself is not.
    fn allocated_size(&self) -> usize;
}

impl<T> VecAllocExt for Vec<T> {
    type T = T;

    fn push_accounted(&mut self, x: Self::T, accounting: &mut usize) {
        let capacity_before = self.capacity();
        self.push(x);
        let capacity_after = self.capacity();

        // The push fit in the existing buffer: nothing new to account for.
        if capacity_after <= capacity_before {
            return;
        }

        let grown_bytes = size_of::<T>() * (capacity_after - capacity_before);
        *accounting = accounting.checked_add(grown_bytes).expect("overflow");
    }

    fn allocated_size(&self) -> usize {
        let element_bytes = size_of::<T>();
        self.capacity() * element_bytes
    }
}
/// Extension methods for hashbrown's [`RawTable`] that keep a running tally
/// of the bytes requested when an insert has to grow the table.
pub trait RawTableAllocExt {
    /// The element type of the table.
    type T;

    /// Inserts `x` using the hash `hasher(&x)` and returns the bucket holding it.
    ///
    /// No equality check is performed, so `x` is inserted even if an equal
    /// element is already stored. If there is no free slot, the table first
    /// reserves room for `max(capacity(), 16)` more elements. It adds that
    /// element count times `size_of::<T>()` to `accounting`. The tally is
    /// derived from the requested element count rather than measured from the
    /// table after it grows.
    ///
    /// # Panics
    ///
    /// Panics with `"overflow"` if `accounting` would exceed `usize::MAX`.
    /// Panics with `"just grew the container"` if the insert still fails after
    /// reserving.
    fn insert_accounted(
        &mut self,
        x: Self::T,
        hasher: impl Fn(&Self::T) -> u64,
        accounting: &mut usize,
    ) -> Bucket<Self::T>;
}

impl<T> RawTableAllocExt for RawTable<T> {
    type T = T;

    fn insert_accounted(
        &mut self,
        x: Self::T,
        hasher: impl Fn(&Self::T) -> u64,
        accounting: &mut usize,
    ) -> Bucket<Self::T> {
        let hash = hasher(&x);

        // Fast path: a free slot exists, so no memory is requested.
        let x = match self.try_insert_no_grow(hash, x) {
            Ok(bucket) => return bucket,
            Err(returned) => returned,
        };

        // Slow path: request room for at least 16 more elements (or the
        // current capacity, if larger) and record the bytes before reserving.
        let extra_elements = self.capacity().max(16);
        let extra_bytes = size_of::<T>() * extra_elements;
        *accounting = accounting.checked_add(extra_bytes).expect("overflow");
        self.reserve(extra_elements, hasher);

        // Retry now that room was reserved. `unwrap_or_else` + `panic!` is used
        // instead of `expect` because `T` is not required to implement `Debug`.
        self.try_insert_no_grow(hash, x)
            .unwrap_or_else(|_| panic!("just grew the container"))
    }
}
/// Extension methods for hashbrown's [`HashTable`] that insert with set
/// semantics and keep a running tally of the bytes requested when the table
/// has to grow.
pub trait HashTableAllocExt {
    /// The element type of the table.
    type T;

    /// Inserts `x` unless an equal element is already present.
    ///
    /// `hasher` computes the hash used to locate `x`. It is also handed to the
    /// table when it has to grow, so it must hash existing elements the same
    /// way.
    ///
    /// If `x` is absent and the table is full (`len() == capacity()`), the
    /// table first reserves room for `max(capacity(), 16)` more elements. It
    /// adds that element count times `size_of::<T>()` to `accounting`.
    ///
    /// If an equal element is already present, `x` is dropped and neither the
    /// table nor `accounting` is modified.
    ///
    /// # Panics
    ///
    /// Panics with `"overflow"` if `accounting` would exceed `usize::MAX`.
    fn insert_accounted(
        &mut self,
        x: Self::T,
        hasher: impl Fn(&Self::T) -> u64,
        accounting: &mut usize,
    );
}

impl<T> HashTableAllocExt for HashTable<T>
where
    T: Eq,
{
    type T = T;

    fn insert_accounted(
        &mut self,
        x: Self::T,
        hasher: impl Fn(&Self::T) -> u64,
        accounting: &mut usize,
    ) {
        let hash = hasher(&x);

        // `find` is a read-only probe: it never grows the table, so no
        // unaccounted allocation can happen here.
        if self.find(hash, |y| y == &x).is_some() {
            // Set semantics: keep the existing element and drop `x`.
            return;
        }

        if self.len() == self.capacity() {
            // No free room left: request more memory and account for it
            // before the table grows.
            let bump_elements = self.capacity().max(16);
            let bump_size = bump_elements * size_of::<T>();
            *accounting = (*accounting).checked_add(bump_size).expect("overflow");

            self.reserve(bump_elements, &hasher);
        }

        // `x` is known to be absent. Insert it directly instead of going
        // through `entry`, which would probe again and repeat the equality
        // comparisons already done above.
        self.insert_unique(hash, x, hasher);
    }
}