use parking_lot::RwLock;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use crate::datasource::TableProvider;
use crate::error::{DataFusionError, Result};
pub trait SchemaProvider: Sync + Send {
fn as_any(&self) -> &dyn Any;
fn table_names(&self) -> Vec<String>;
fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>>;
#[allow(unused_variables)]
fn register_table(
&self,
name: String,
table: Arc<dyn TableProvider>,
) -> Result<Option<Arc<dyn TableProvider>>> {
Err(DataFusionError::Execution(
"schema provider does not support registering tables".to_owned(),
))
}
#[allow(unused_variables)]
fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
Err(DataFusionError::Execution(
"schema provider does not support deregistering tables".to_owned(),
))
}
fn table_exist(&self, name: &str) -> bool;
}
pub struct MemorySchemaProvider {
tables: RwLock<HashMap<String, Arc<dyn TableProvider>>>,
}
impl MemorySchemaProvider {
pub fn new() -> Self {
Self {
tables: RwLock::new(HashMap::new()),
}
}
}
impl Default for MemorySchemaProvider {
fn default() -> Self {
Self::new()
}
}
impl SchemaProvider for MemorySchemaProvider {
fn as_any(&self) -> &dyn Any {
self
}
fn table_names(&self) -> Vec<String> {
let tables = self.tables.read();
tables.keys().cloned().collect()
}
fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
let tables = self.tables.read();
tables.get(name).cloned()
}
fn register_table(
&self,
name: String,
table: Arc<dyn TableProvider>,
) -> Result<Option<Arc<dyn TableProvider>>> {
if self.table_exist(name.as_str()) {
return Err(DataFusionError::Execution(format!(
"The table {} already exists",
name
)));
}
let mut tables = self.tables.write();
Ok(tables.insert(name, table))
}
fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
let mut tables = self.tables.write();
Ok(tables.remove(name))
}
fn table_exist(&self, name: &str) -> bool {
let tables = self.tables.read();
tables.contains_key(name)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::datatypes::Schema;
use crate::assert_batches_eq;
use crate::catalog::catalog::{CatalogProvider, MemoryCatalogProvider};
use crate::catalog::schema::{MemorySchemaProvider, SchemaProvider};
use crate::datasource::empty::EmptyTable;
use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl};
use crate::prelude::SessionContext;
#[tokio::test]
async fn test_mem_provider() {
let provider = MemorySchemaProvider::new();
let table_name = "test_table_exist";
assert!(!provider.table_exist(table_name));
assert!(provider.deregister_table(table_name).unwrap().is_none());
let test_table = EmptyTable::new(Arc::new(Schema::empty()));
assert!(provider
.register_table(table_name.to_string(), Arc::new(test_table))
.unwrap()
.is_none());
assert!(provider.table_exist(table_name));
let other_table = EmptyTable::new(Arc::new(Schema::empty()));
let result =
provider.register_table(table_name.to_string(), Arc::new(other_table));
assert!(result.is_err());
}
#[tokio::test]
async fn test_schema_register_listing_table() {
let testdata = crate::test_util::parquet_test_data();
let testdir = if testdata.starts_with('/') {
format!("file://{}", testdata)
} else {
format!("file:///{}", testdata)
};
let filename = if testdir.ends_with('/') {
format!("{}{}", testdir, "alltypes_plain.parquet")
} else {
format!("{}/{}", testdir, "alltypes_plain.parquet")
};
let table_path = ListingTableUrl::parse(filename).unwrap();
let catalog = MemoryCatalogProvider::new();
let schema = MemorySchemaProvider::new();
let ctx = SessionContext::new();
let config = ListingTableConfig::new(table_path)
.infer(&ctx.state())
.await
.unwrap();
let table = ListingTable::try_new(config).unwrap();
schema
.register_table("alltypes_plain".to_string(), Arc::new(table))
.unwrap();
catalog.register_schema("active", Arc::new(schema)).unwrap();
ctx.register_catalog("cat", Arc::new(catalog));
let df = ctx
.sql("SELECT id, bool_col FROM cat.active.alltypes_plain")
.await
.unwrap();
let actual = df.collect().await.unwrap();
let expected = vec![
"+----+----------+",
"| id | bool_col |",
"+----+----------+",
"| 4 | true |",
"| 5 | false |",
"| 6 | true |",
"| 7 | false |",
"| 2 | true |",
"| 3 | false |",
"| 0 | true |",
"| 1 | false |",
"+----+----------+",
];
assert_batches_eq!(expected, &actual);
}
}