use std::sync::Arc;
use async_trait::async_trait;
use datafusion_common::DataFusionError;
use hashbrown::raw::{Bucket, RawTable};
use super::{ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager};
/// A memory consumer that only tracks bytes reported to it through `alloc`.
///
/// It identifies itself as `ConsumerType::Tracking` and its `spill`
/// implementation always returns a `ResourcesExhausted` error, i.e. it never
/// frees memory on request.
pub struct MemoryConsumerProxy {
    /// Human-readable name; returned by `name()` and embedded in the spill
    /// error message.
    name: String,
    /// Identity registered with the memory manager when the proxy is created.
    id: MemoryConsumerId,
    /// Manager that memory growth is requested from and that is notified on drop.
    memory_manager: Arc<MemoryManager>,
    /// Total bytes accounted so far via `alloc`.
    used: usize,
}
impl MemoryConsumerProxy {
    /// Builds a proxy called `name` with identity `id`.
    ///
    /// The id is registered with `memory_manager` (via `register_requester`)
    /// before the proxy value is built; the proxy starts with zero bytes used.
    pub fn new(
        name: impl Into<String>,
        id: MemoryConsumerId,
        memory_manager: Arc<MemoryManager>,
    ) -> Self {
        memory_manager.register_requester(&id);
        let name: String = name.into();
        Self {
            used: 0,
            memory_manager,
            id,
            name,
        }
    }

    /// Requests `bytes` more memory via `try_grow` and, once granted, adds
    /// them to the locally tracked total.
    ///
    /// # Errors
    /// Propagates any error returned by `try_grow`; the tracked total is left
    /// unchanged in that case.
    ///
    /// # Panics
    /// Panics with `"overflow"` if the tracked total would exceed `usize::MAX`.
    pub async fn alloc(&mut self, bytes: usize) -> Result<(), DataFusionError> {
        self.try_grow(bytes).await?;
        let updated = self.used.checked_add(bytes).expect("overflow");
        self.used = updated;
        Ok(())
    }
}
#[async_trait]
impl MemoryConsumer for MemoryConsumerProxy {
    /// Returns an owned copy of the proxy's name.
    fn name(&self) -> String {
        self.name.to_owned()
    }

    /// Returns the identity this proxy was created with.
    fn id(&self) -> &crate::execution::MemoryConsumerId {
        &self.id
    }

    /// Returns a new handle to the shared memory manager.
    fn memory_manager(&self) -> Arc<MemoryManager> {
        Arc::clone(&self.memory_manager)
    }

    /// This consumer only tracks usage.
    fn type_(&self) -> &ConsumerType {
        &ConsumerType::Tracking
    }

    /// Always fails: a tracking proxy has nothing it can spill.
    async fn spill(&self) -> Result<usize, DataFusionError> {
        let message = format!("Cannot spill {}", self.name);
        Err(DataFusionError::ResourcesExhausted(message))
    }

    /// Bytes accounted so far through `alloc`.
    fn mem_used(&self) -> usize {
        self.used
    }
}
impl Drop for MemoryConsumerProxy {
    /// Notifies the memory manager (via `drop_consumer`) of this proxy's id and
    /// the number of bytes it had accounted when it went away.
    fn drop(&mut self) {
        let used_bytes = self.mem_used();
        let consumer_id = self.id();
        self.memory_manager.drop_consumer(consumer_id, used_bytes);
    }
}
/// Extension trait for [`Vec`] that records heap growth in a caller-supplied
/// byte counter.
pub trait VecAllocExt {
    /// Element type of the container.
    type T;

    /// Pushes `x`, adding to `accounting` the number of bytes by which the
    /// vector's buffer grew (nothing is added if no reallocation was needed).
    ///
    /// # Panics
    /// Panics with `"overflow"` if `accounting` would exceed `usize::MAX`, and
    /// (like [`Vec::reserve`]) if the new capacity overflows.
    fn push_accounted(&mut self, x: Self::T, accounting: &mut usize);
}

impl<T> VecAllocExt for Vec<T> {
    type T = T;

    fn push_accounted(&mut self, x: Self::T, accounting: &mut usize) {
        if self.capacity() == self.len() {
            let prev_capacity = self.capacity();

            // Request room for 2x the current capacity (at least 2 elements).
            let bump_elements = prev_capacity.saturating_mul(2).max(2);
            self.reserve(bump_elements);

            // `reserve` may allocate more than requested, so charge the
            // capacity actually gained, sized by the real element type
            // (previously this was hard-coded to `u32` and used the requested
            // count, mis-accounting any non-4-byte `T`).
            let grown_elements = self.capacity() - prev_capacity;
            let bump_size = grown_elements * std::mem::size_of::<T>();
            *accounting = (*accounting).checked_add(bump_size).expect("overflow");
        }
        self.push(x);
    }
}
/// Extension trait for [`RawTable`] that adds an estimate of memory growth to
/// a caller-supplied byte counter whenever an insert forces the table to grow.
pub trait RawTableAllocExt {
    /// Item type stored in the table.
    type T;

    /// Inserts `x` (hashed with `hasher`) and returns the bucket holding it.
    ///
    /// When no free slot is available, the table first reserves room for
    /// `max(2 * capacity, 16)` additional items and adds
    /// `that count * size_of::<T>()` bytes to `accounting`. This is an
    /// estimate based on the requested count; what the table actually
    /// allocates may differ.
    ///
    /// # Panics
    /// Panics with `"overflow"` if `accounting` would exceed `usize::MAX`, or
    /// with `"just grew the container"` if the insert still fails after growing.
    fn insert_accounted(
        &mut self,
        x: Self::T,
        hasher: impl Fn(&Self::T) -> u64,
        accounting: &mut usize,
    ) -> Bucket<Self::T>;
}

impl<T> RawTableAllocExt for RawTable<T> {
    type T = T;

    fn insert_accounted(
        &mut self,
        x: Self::T,
        hasher: impl Fn(&Self::T) -> u64,
        accounting: &mut usize,
    ) -> Bucket<Self::T> {
        let hash = hasher(&x);

        // Fast path: a free slot already exists.
        let x = match self.try_insert_no_grow(hash, x) {
            Ok(bucket) => return bucket,
            Err(rejected) => rejected,
        };

        // Slow path: charge the estimated growth, grow, then retry the insert.
        let extra_items = (self.capacity() * 2).max(16);
        let charged_bytes = extra_items * std::mem::size_of::<T>();
        *accounting = (*accounting).checked_add(charged_bytes).expect("overflow");
        self.reserve(extra_items, hasher);

        self.try_insert_no_grow(hash, x)
            .unwrap_or_else(|_| panic!("just grew the container"))
    }
}