use datafusion_common::{internal_err, Result};
use std::{cmp::Ordering, sync::Arc};
mod pool;
/// Re-exports `HashTableAllocExt`, `RawTableAllocExt` and `VecAllocExt` from
/// `datafusion_common::utils::proxy` so they are reachable through this module.
pub mod proxy {
pub use datafusion_common::utils::proxy::{
HashTableAllocExt, RawTableAllocExt, VecAllocExt,
};
}
pub use pool::*;
/// Accounting interface for memory handed out through [`MemoryReservation`]s.
///
/// Implementors must be `Send + Sync + Debug`. Within this file a pool is held
/// as `Arc<dyn MemoryPool>` and every method is invoked through `&self`.
pub trait MemoryPool: Send + Sync + std::fmt::Debug {
/// Hook invoked by [`MemoryConsumer::register`] before the resulting
/// [`MemoryReservation`] is built. The default implementation does nothing.
fn register(&self, _consumer: &MemoryConsumer) {}
/// Hook invoked when the registration shared by a consumer's reservations
/// is dropped. The default implementation does nothing.
fn unregister(&self, _consumer: &MemoryConsumer) {}
/// Records that `reservation` takes `additional` more capacity.
///
/// There is no return value, so this call cannot report a refusal.
/// [`MemoryReservation::grow`] calls this *before* updating its own size,
/// so `reservation.size()` still reports the previous value here.
fn grow(&self, reservation: &MemoryReservation, additional: usize);
/// Records that `reservation` gives back `shrink` capacity.
///
/// [`MemoryReservation::shrink`] and [`MemoryReservation::try_shrink`] call
/// this *before* updating their own size, so `reservation.size()` still
/// reports the previous value here.
fn shrink(&self, reservation: &MemoryReservation, shrink: usize);
/// Fallible counterpart of [`MemoryPool::grow`].
///
/// # Errors
///
/// An `Err` signals that the request was not granted;
/// [`MemoryReservation::try_grow`] then leaves its size unchanged.
fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()>;
/// Returns the total capacity the pool currently accounts as reserved.
fn reserved(&self) -> usize;
}
/// A named participant that reserves memory from a [`MemoryPool`].
///
/// Create one with [`MemoryConsumer::new`], optionally adjust it with
/// [`MemoryConsumer::with_can_spill`], then convert it into a
/// [`MemoryReservation`] through [`MemoryConsumer::register`].
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct MemoryConsumer {
    /// Label supplied at construction time.
    name: String,
    /// Spill flag; `false` unless changed via `with_can_spill`.
    can_spill: bool,
}

impl MemoryConsumer {
    /// Builds a consumer called `name` whose spill flag is `false`.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            can_spill: false,
        }
    }

    /// Consumes the consumer and returns it with the spill flag replaced by
    /// `can_spill`; the name is carried over untouched.
    pub fn with_can_spill(mut self, can_spill: bool) -> Self {
        self.can_spill = can_spill;
        self
    }

    /// Returns the current spill flag.
    pub fn can_spill(&self) -> bool {
        self.can_spill
    }

    /// Returns the consumer's name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Registers this consumer with `pool` and returns an empty
    /// [`MemoryReservation`] (size `0`) bound to that pool.
    ///
    /// `pool.register` is notified first; the consumer is then moved into the
    /// registration record shared by the returned reservation.
    pub fn register(self, pool: &Arc<dyn MemoryPool>) -> MemoryReservation {
        pool.register(&self);
        let shared = SharedRegistration {
            pool: Arc::clone(pool),
            consumer: self,
        };
        MemoryReservation {
            registration: Arc::new(shared),
            size: 0,
        }
    }
}
/// Pairs a registered [`MemoryConsumer`] with the pool it was registered on.
///
/// Held behind an `Arc` by [`MemoryReservation`]; reservations produced by
/// `split` / `new_empty` clone that `Arc`, so they all share one instance.
#[derive(Debug)]
struct SharedRegistration {
// Pool the consumer was registered with.
pool: Arc<dyn MemoryPool>,
// Consumer that was passed to `pool.register`.
consumer: MemoryConsumer,
}
/// Notifies the pool once the last holder of this registration goes away.
impl Drop for SharedRegistration {
fn drop(&mut self) {
// Mirror of the `pool.register` call made in `MemoryConsumer::register`.
self.pool.unregister(&self.consumer);
}
}
/// A handle tracking how much capacity one consumer currently holds in a pool.
///
/// Dropping the reservation gives its remaining size back to the pool.
#[derive(Debug)]
pub struct MemoryReservation {
// Shared pool + consumer record; cloned (by `Arc`) into split-off reservations.
registration: Arc<SharedRegistration>,
// Amount currently attributed to this particular reservation.
size: usize,
}
impl MemoryReservation {
pub fn size(&self) -> usize {
self.size
}
pub fn consumer(&self) -> &MemoryConsumer {
&self.registration.consumer
}
pub fn free(&mut self) -> usize {
let size = self.size;
if size != 0 {
self.shrink(size)
}
size
}
pub fn shrink(&mut self, capacity: usize) {
let new_size = self.size.checked_sub(capacity).unwrap();
self.registration.pool.shrink(self, capacity);
self.size = new_size
}
pub fn try_shrink(&mut self, capacity: usize) -> Result<usize> {
if let Some(new_size) = self.size.checked_sub(capacity) {
self.registration.pool.shrink(self, capacity);
self.size = new_size;
Ok(new_size)
} else {
internal_err!(
"Cannot free the capacity {capacity} out of allocated size {}",
self.size
)
}
}
pub fn resize(&mut self, capacity: usize) {
match capacity.cmp(&self.size) {
Ordering::Greater => self.grow(capacity - self.size),
Ordering::Less => self.shrink(self.size - capacity),
_ => {}
}
}
pub fn try_resize(&mut self, capacity: usize) -> Result<()> {
match capacity.cmp(&self.size) {
Ordering::Greater => self.try_grow(capacity - self.size)?,
Ordering::Less => self.shrink(self.size - capacity),
_ => {}
};
Ok(())
}
pub fn grow(&mut self, capacity: usize) {
self.registration.pool.grow(self, capacity);
self.size += capacity;
}
pub fn try_grow(&mut self, capacity: usize) -> Result<()> {
self.registration.pool.try_grow(self, capacity)?;
self.size += capacity;
Ok(())
}
pub fn split(&mut self, capacity: usize) -> MemoryReservation {
self.size = self.size.checked_sub(capacity).unwrap();
Self {
size: capacity,
registration: Arc::clone(&self.registration),
}
}
pub fn new_empty(&self) -> Self {
Self {
size: 0,
registration: Arc::clone(&self.registration),
}
}
pub fn take(&mut self) -> MemoryReservation {
self.split(self.size)
}
}
/// Hands any capacity still held back to the pool when the reservation is dropped.
impl Drop for MemoryReservation {
fn drop(&mut self) {
// `free` shrinks by the full current size (and is a no-op at zero).
self.free();
}
}
/// Binary size multipliers, expressed in bytes.
pub mod units {
    /// 2^40 bytes.
    pub const TB: u64 = 1 << 40;
    /// 2^30 bytes.
    pub const GB: u64 = 1 << 30;
    /// 2^20 bytes.
    pub const MB: u64 = 1 << 20;
    /// 2^10 bytes.
    pub const KB: u64 = 1 << 10;
}

/// Formats `size` (in bytes) with one decimal place and a unit suffix.
///
/// A unit is selected only once the value reaches at least twice that unit,
/// so `2047` is rendered as `"2047.0 B"` while `2048` becomes `"2.0 KB"`.
pub fn human_readable_size(size: usize) -> String {
    // Largest unit first, so the first match is the coarsest applicable one.
    const SCALES: [(u64, &str); 4] = [
        (units::TB, "TB"),
        (units::GB, "GB"),
        (units::MB, "MB"),
        (units::KB, "KB"),
    ];

    let bytes = size as u64;
    let (divisor, unit) = SCALES
        .iter()
        .copied()
        .find(|&(scale, _)| bytes >= 2 * scale)
        .unwrap_or((1, "B"));

    let value = bytes as f64 / divisor as f64;
    format!("{value:.1} {unit}")
}
// Unit tests exercising reservations against a `GreedyMemoryPool` built with 50
// (presumably its capacity limit; the type is defined in the `pool` module).
#[cfg(test)]
mod tests {
use super::*;
// Checks pool accounting across grow / free / try_grow / drop.
#[test]
fn test_memory_pool_underflow() {
let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
let mut a1 = MemoryConsumer::new("a1").register(&pool);
assert_eq!(pool.reserved(), 0);
// Infallible `grow` is accounted even though 100 exceeds the 50 passed to `new`.
a1.grow(100);
assert_eq!(pool.reserved(), 100);
// `free` reports how much was released and empties the pool again.
assert_eq!(a1.free(), 100);
assert_eq!(pool.reserved(), 0);
// A refused `try_grow` must leave the pool untouched.
a1.try_grow(100).unwrap_err();
assert_eq!(pool.reserved(), 0);
a1.try_grow(30).unwrap();
assert_eq!(pool.reserved(), 30);
// With 30 already reserved by `a1`, a second consumer cannot get 25.
let mut a2 = MemoryConsumer::new("a2").register(&pool);
a2.try_grow(25).unwrap_err();
assert_eq!(pool.reserved(), 30);
// Dropping `a1` releases its 30, after which the same request succeeds.
drop(a1);
assert_eq!(pool.reserved(), 0);
a2.try_grow(25).unwrap();
assert_eq!(pool.reserved(), 25);
}
// `split` moves capacity between reservations without changing the pool total.
#[test]
fn test_split() {
let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
let mut r1 = MemoryConsumer::new("r1").register(&pool);
r1.try_grow(20).unwrap();
assert_eq!(r1.size(), 20);
assert_eq!(pool.reserved(), 20);
let r2 = r1.split(5);
assert_eq!(r1.size(), 15);
assert_eq!(r2.size(), 5);
assert_eq!(pool.reserved(), 20);
// Dropping `r1` releases only its own 15; `r2` keeps its 5.
drop(r1);
assert_eq!(r2.size(), 5);
assert_eq!(pool.reserved(), 5);
}
// `new_empty` yields an independent zero-sized reservation on the same pool.
#[test]
fn test_new_empty() {
let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
let mut r1 = MemoryConsumer::new("r1").register(&pool);
r1.try_grow(20).unwrap();
let mut r2 = r1.new_empty();
r2.try_grow(5).unwrap();
assert_eq!(r1.size(), 20);
assert_eq!(r2.size(), 5);
// Pool total is the sum of both reservations.
assert_eq!(pool.reserved(), 25);
}
// `take` transfers the full size and leaves the source usable at zero.
#[test]
fn test_take() {
let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
let mut r1 = MemoryConsumer::new("r1").register(&pool);
r1.try_grow(20).unwrap();
let mut r2 = r1.take();
r2.try_grow(5).unwrap();
assert_eq!(r1.size(), 0);
assert_eq!(r2.size(), 25);
assert_eq!(pool.reserved(), 25);
// The emptied reservation can still grow afterwards.
r1.try_grow(3).unwrap();
assert_eq!(r1.size(), 3);
assert_eq!(r2.size(), 25);
assert_eq!(pool.reserved(), 28);
}
}