use datafusion_catalog::memory::MemorySchemaProvider;
use datafusion_catalog::MemoryCatalogProvider;
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::{Arc, Weak};
use super::options::ReadOptions;
use crate::{
catalog::{
CatalogProvider, CatalogProviderList, TableProvider, TableProviderFactory,
},
catalog_common::listing_schema::ListingSchemaProvider,
dataframe::DataFrame,
datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
},
datasource::{provider_as_source, MemTable, ViewTable},
error::{DataFusionError, Result},
execution::{options::ArrowReadOptions, runtime_env::RuntimeEnv, FunctionRegistry},
logical_expr::AggregateUDF,
logical_expr::ScalarUDF,
logical_expr::{
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable,
DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable,
TableType, UNNAMED_TABLE,
},
physical_expr::PhysicalExpr,
physical_plan::ExecutionPlan,
variable::{VarProvider, VarType},
};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::{
config::{ConfigExtension, TableOptions},
exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNodeRecursion, TreeNodeVisitor},
DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference,
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
expr_rewriter::FunctionRewrite,
logical_plan::{DdlStatement, Statement},
planner::ExprPlanner,
Expr, UserDefinedLogicalNode, WindowUDF,
};
pub use crate::execution::session_state::SessionState;
use crate::datasource::dynamic_file::DynamicListTableFactory;
use crate::execution::session_state::SessionStateBuilder;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion_catalog::{
DynamicFileCatalog, SessionStore, TableFunction, TableFunctionImpl, UrlTableFactory,
};
pub use datafusion_execution::config::SessionConfig;
pub use datafusion_execution::TaskContext;
pub use datafusion_expr::execution_props::ExecutionProps;
use datafusion_optimizer::{AnalyzerRule, OptimizerRule};
use object_store::ObjectStore;
use parking_lot::RwLock;
use url::Url;
mod avro;
mod csv;
mod json;
#[cfg(feature = "parquet")]
mod parquet;
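/// Types that can be converted into a list of [`ListingTableUrl`]s.
///
/// Implemented for single paths (`&str`, `String`, `&String`) and for
/// `Vec<P>` of path-like values, so the `read_*` methods on
/// [`SessionContext`] accept either one path or many.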
pub trait DataFilePaths {
fn to_urls(self) -> Result<Vec<ListingTableUrl>>;
}
impl DataFilePaths for &str {
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
Ok(vec![ListingTableUrl::parse(self)?])
}
}
impl DataFilePaths for String {
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
Ok(vec![ListingTableUrl::parse(self)?])
}
}
impl DataFilePaths for &String {
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
Ok(vec![ListingTableUrl::parse(self)?])
}
}
impl<P> DataFilePaths for Vec<P>
where
P: AsRef<str>,
{
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
self.iter()
.map(ListingTableUrl::parse)
.collect::<Result<Vec<ListingTableUrl>>>()
}
}
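/// Main interface for executing queries with DataFusion.
///
/// A `SessionContext` carries a unique session id, the session start time,
/// and a shared, lockable [`SessionState`]. It is used to register data
/// sources (tables, catalogs, object stores), register user defined
/// functions, and turn SQL text or [`LogicalPlan`]s into executable
/// [`DataFrame`]s. Cloning a `SessionContext` is cheap and the clones share
/// the same underlying state.
///
/// A minimal usage sketch (assumes a Tokio runtime and the `datafusion`
/// prelude):
///
/// ```
/// use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.sql("SELECT 1 AS x").await?;
/// df.show().await?;
/// # Ok(())
/// # }
/// ```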
#[derive(Clone)]
pub struct SessionContext {
session_id: String,
session_start_time: DateTime<Utc>,
state: Arc<RwLock<SessionState>>,
}
impl Default for SessionContext {
fn default() -> Self {
Self::new()
}
}
impl SessionContext {
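/// Creates a new `SessionContext` using the default [`SessionConfig`].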
pub fn new() -> Self {
Self::new_with_config(SessionConfig::new())
}
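/// Walks every registered catalog and schema and refreshes any
/// [`ListingSchemaProvider`]s, so that tables added to the underlying
/// storage location since the last refresh become visible.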
pub async fn refresh_catalogs(&self) -> Result<()> {
let cat_names = self.catalog_names();
for cat_name in cat_names.iter() {
let cat = self.catalog(cat_name.as_str()).ok_or_else(|| {
DataFusionError::Internal("Catalog not found!".to_string())
})?;
for schema_name in cat.schema_names() {
let schema = cat.schema(schema_name.as_str()).ok_or_else(|| {
DataFusionError::Internal("Schema not found!".to_string())
})?;
let lister = schema.as_any().downcast_ref::<ListingSchemaProvider>();
if let Some(lister) = lister {
lister.refresh(&self.state()).await?;
}
}
}
Ok(())
}
pub fn new_with_config(config: SessionConfig) -> Self {
let runtime = Arc::new(RuntimeEnv::default());
Self::new_with_config_rt(config, runtime)
}
pub fn new_with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
let state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.build();
Self::new_with_state(state)
}
pub fn new_with_state(state: SessionState) -> Self {
Self {
session_id: state.session_id().to_string(),
session_start_time: Utc::now(),
state: Arc::new(RwLock::new(state)),
}
}
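/// Returns a new `SessionContext` (preserving the session id) in which
/// file URLs and paths can be queried directly in the `FROM` clause,
/// e.g. `SELECT * FROM 'data/example.csv'`. Unknown table references are
/// resolved through a [`DynamicFileCatalog`] backed by a
/// [`DynamicListTableFactory`].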
pub fn enable_url_table(self) -> Self {
let current_catalog_list = Arc::clone(self.state.read().catalog_list());
let factory = Arc::new(DynamicListTableFactory::new(SessionStore::new()));
let catalog_list = Arc::new(DynamicFileCatalog::new(
current_catalog_list,
Arc::clone(&factory) as Arc<dyn UrlTableFactory>,
));
let session_id = self.session_id.clone();
let ctx: SessionContext = self
.into_state_builder()
.with_session_id(session_id)
.with_catalog_list(catalog_list)
.build()
.into();
factory.session_store().with_state(ctx.state_weak_ref());
ctx
}
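/// Converts this `SessionContext` into a [`SessionStateBuilder`] so the
/// session can be rebuilt with modified settings. The existing
/// [`SessionState`] is reused without cloning when this context holds the
/// only reference to it.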
pub fn into_state_builder(self) -> SessionStateBuilder {
let SessionContext {
session_id: _,
session_start_time: _,
state,
} = self;
let state = match Arc::try_unwrap(state) {
Ok(rwlock) => rwlock.into_inner(),
Err(state) => state.read().clone(),
};
SessionStateBuilder::from(state)
}
pub fn session_start_time(&self) -> DateTime<Utc> {
self.session_start_time
}
pub fn with_function_factory(
self,
function_factory: Arc<dyn FunctionFactory>,
) -> Self {
self.state.write().set_function_factory(function_factory);
self
}
pub fn add_optimizer_rule(
&self,
optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
) {
self.state.write().append_optimizer_rule(optimizer_rule);
}
pub fn add_analyzer_rule(&self, analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>) {
self.state.write().add_analyzer_rule(analyzer_rule);
}
pub fn register_object_store(
&self,
url: &Url,
object_store: Arc<dyn ObjectStore>,
) -> Option<Arc<dyn ObjectStore>> {
self.runtime_env().register_object_store(url, object_store)
}
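/// Registers a single [`RecordBatch`] as a table named `table_name`,
/// returning any [`TableProvider`] previously registered under that name.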
pub fn register_batch(
&self,
table_name: &str,
batch: RecordBatch,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
self.register_table(
TableReference::Bare {
table: table_name.into(),
},
Arc::new(table),
)
}
pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
Arc::clone(self.state.read().runtime_env())
}
pub fn session_id(&self) -> String {
self.session_id.clone()
}
pub fn table_factory(
&self,
file_type: &str,
) -> Option<Arc<dyn TableProviderFactory>> {
self.state.read().table_factories().get(file_type).cloned()
}
pub fn enable_ident_normalization(&self) -> bool {
self.state
.read()
.config()
.options()
.sql_parser
.enable_ident_normalization
}
pub fn copied_config(&self) -> SessionConfig {
self.state.read().config().clone()
}
pub fn copied_table_options(&self) -> TableOptions {
self.state.read().default_table_options()
}
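/// Plans the given SQL text and returns a [`DataFrame`] representing the
/// query. Execution is deferred until the `DataFrame` is collected or
/// streamed. This is equivalent to calling [`Self::sql_with_options`] with
/// the default [`SQLOptions`], which permit DDL, DML and statements.
///
/// A short sketch (assumes a Tokio runtime):
///
/// ```
/// use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// // DDL such as CREATE TABLE is allowed by default
/// ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)").await?.collect().await?;
/// let df = ctx.sql("SELECT COUNT(*) FROM t").await?;
/// df.show().await?;
/// # Ok(())
/// # }
/// ```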
pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
self.sql_with_options(sql, SQLOptions::new()).await
}
pub async fn sql_with_options(
&self,
sql: &str,
options: SQLOptions,
) -> Result<DataFrame> {
let plan = self.state().create_logical_plan(sql).await?;
options.verify_plan(&plan)?;
self.execute_logical_plan(plan).await
}
pub fn parse_sql_expr(&self, sql: &str, df_schema: &DFSchema) -> Result<Expr> {
self.state.read().create_logical_expr(sql, df_schema)
}
pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
match plan {
LogicalPlan::Ddl(ddl) => {
match ddl {
DdlStatement::CreateExternalTable(cmd) => {
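// Box::pin the recursive async call (erasing its concrete type) so the
// future is heap allocated instead of inflating the caller's stack frame.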
(Box::pin(async move { self.create_external_table(&cmd).await })
as std::pin::Pin<Box<dyn futures::Future<Output = _> + Send>>)
.await
}
DdlStatement::CreateMemoryTable(cmd) => {
Box::pin(self.create_memory_table(cmd)).await
}
DdlStatement::CreateView(cmd) => {
Box::pin(self.create_view(cmd)).await
}
DdlStatement::CreateCatalogSchema(cmd) => {
Box::pin(self.create_catalog_schema(cmd)).await
}
DdlStatement::CreateCatalog(cmd) => {
Box::pin(self.create_catalog(cmd)).await
}
DdlStatement::DropTable(cmd) => Box::pin(self.drop_table(cmd)).await,
DdlStatement::DropView(cmd) => Box::pin(self.drop_view(cmd)).await,
DdlStatement::DropCatalogSchema(cmd) => {
Box::pin(self.drop_schema(cmd)).await
}
DdlStatement::CreateFunction(cmd) => {
Box::pin(self.create_function(cmd)).await
}
DdlStatement::DropFunction(cmd) => {
Box::pin(self.drop_function(cmd)).await
}
ddl => Ok(DataFrame::new(self.state(), LogicalPlan::Ddl(ddl))),
}
}
LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
self.set_variable(stmt).await
}
LogicalPlan::Statement(Statement::Prepare(Prepare {
name,
input,
data_types,
})) => {
if !data_types.is_empty() {
let param_names = input.get_parameter_names()?;
if param_names.len() != data_types.len() {
return plan_err!(
"Prepare specifies {} data types but query has {} parameters",
data_types.len(),
param_names.len()
);
}
}
self.state.write().store_prepared(name, data_types, input)?;
self.return_empty_dataframe()
}
LogicalPlan::Statement(Statement::Execute(execute)) => {
self.execute_prepared(execute)
}
LogicalPlan::Statement(Statement::Deallocate(deallocate)) => {
self.state
.write()
.remove_prepared(deallocate.name.as_str())?;
self.return_empty_dataframe()
}
plan => Ok(DataFrame::new(self.state(), plan)),
}
}
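/// Lowers a logical [`Expr`] into a [`PhysicalExpr`] that can be evaluated
/// against record batches matching `df_schema`.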
pub fn create_physical_expr(
&self,
expr: Expr,
df_schema: &DFSchema,
) -> Result<Arc<dyn PhysicalExpr>> {
self.state.read().create_physical_expr(expr, df_schema)
}
fn return_empty_dataframe(&self) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::empty(false).build()?;
Ok(DataFrame::new(self.state(), plan))
}
async fn create_external_table(
&self,
cmd: &CreateExternalTable,
) -> Result<DataFrame> {
let exist = self.table_exist(cmd.name.clone())?;
if cmd.temporary {
return not_impl_err!("Temporary tables not supported");
}
if exist {
match cmd.if_not_exists {
true => return self.return_empty_dataframe(),
false => {
return exec_err!("Table '{}' already exists", cmd.name);
}
}
}
let table_provider: Arc<dyn TableProvider> =
self.create_custom_table(cmd).await?;
self.register_table(cmd.name.clone(), table_provider)?;
self.return_empty_dataframe()
}
async fn create_memory_table(&self, cmd: CreateMemoryTable) -> Result<DataFrame> {
let CreateMemoryTable {
name,
input,
if_not_exists,
or_replace,
constraints,
column_defaults,
temporary,
} = cmd;
let input = Arc::unwrap_or_clone(input);
let input = self.state().optimize(&input)?;
if temporary {
return not_impl_err!("Temporary tables not supported");
}
let table = self.table(name.clone()).await;
match (if_not_exists, or_replace, table) {
(true, false, Ok(_)) => self.return_empty_dataframe(),
(false, true, Ok(_)) => {
self.deregister_table(name.clone())?;
let schema = Arc::new(input.schema().as_ref().into());
let physical = DataFrame::new(self.state(), input);
let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(
MemTable::try_new(schema, batches)?
.with_constraints(constraints)
.with_column_defaults(column_defaults.into_iter().collect()),
);
self.register_table(name.clone(), table)?;
self.return_empty_dataframe()
}
(true, true, Ok(_)) => {
exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'")
}
(_, _, Err(_)) => {
let df_schema = input.schema();
let schema = Arc::new(df_schema.as_ref().into());
let physical = DataFrame::new(self.state(), input);
let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(
MemTable::try_new(schema, batches)?
.with_constraints(constraints)
.with_column_defaults(column_defaults.into_iter().collect()),
);
self.register_table(name, table)?;
self.return_empty_dataframe()
}
(false, false, Ok(_)) => exec_err!("Table '{name}' already exists"),
}
}
async fn create_view(&self, cmd: CreateView) -> Result<DataFrame> {
let CreateView {
name,
input,
or_replace,
definition,
temporary,
} = cmd;
let view = self.table(name.clone()).await;
if temporary {
return not_impl_err!("Temporary views not supported");
}
match (or_replace, view) {
(true, Ok(_)) => {
self.deregister_table(name.clone())?;
let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?);
self.register_table(name, table)?;
self.return_empty_dataframe()
}
(_, Err(_)) => {
let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?);
self.register_table(name, table)?;
self.return_empty_dataframe()
}
(false, Ok(_)) => exec_err!("Table '{name}' already exists"),
}
}
async fn create_catalog_schema(&self, cmd: CreateCatalogSchema) -> Result<DataFrame> {
let CreateCatalogSchema {
schema_name,
if_not_exists,
..
} = cmd;
let tokens: Vec<&str> = schema_name.split('.').collect();
let (catalog, schema_name) = match tokens.len() {
1 => {
let state = self.state.read();
let name = &state.config().options().catalog.default_catalog;
let catalog = state.catalog_list().catalog(name).ok_or_else(|| {
DataFusionError::Execution(format!(
"Missing default catalog '{name}'"
))
})?;
(catalog, tokens[0])
}
2 => {
let name = &tokens[0];
let catalog = self.catalog(name).ok_or_else(|| {
DataFusionError::Execution(format!("Missing catalog '{name}'"))
})?;
(catalog, tokens[1])
}
_ => return exec_err!("Unable to parse catalog from {schema_name}"),
};
let schema = catalog.schema(schema_name);
match (if_not_exists, schema) {
(true, Some(_)) => self.return_empty_dataframe(),
(true, None) | (false, None) => {
let schema = Arc::new(MemorySchemaProvider::new());
catalog.register_schema(schema_name, schema)?;
self.return_empty_dataframe()
}
(false, Some(_)) => exec_err!("Schema '{schema_name}' already exists"),
}
}
async fn create_catalog(&self, cmd: CreateCatalog) -> Result<DataFrame> {
let CreateCatalog {
catalog_name,
if_not_exists,
..
} = cmd;
let catalog = self.catalog(catalog_name.as_str());
match (if_not_exists, catalog) {
(true, Some(_)) => self.return_empty_dataframe(),
(true, None) | (false, None) => {
let new_catalog = Arc::new(MemoryCatalogProvider::new());
self.state
.write()
.catalog_list()
.register_catalog(catalog_name, new_catalog);
self.return_empty_dataframe()
}
(false, Some(_)) => exec_err!("Catalog '{catalog_name}' already exists"),
}
}
async fn drop_table(&self, cmd: DropTable) -> Result<DataFrame> {
let DropTable {
name, if_exists, ..
} = cmd;
let result = self
.find_and_deregister(name.clone(), TableType::Base)
.await;
match (result, if_exists) {
(Ok(true), _) => self.return_empty_dataframe(),
(_, true) => self.return_empty_dataframe(),
(_, _) => exec_err!("Table '{name}' doesn't exist."),
}
}
async fn drop_view(&self, cmd: DropView) -> Result<DataFrame> {
let DropView {
name, if_exists, ..
} = cmd;
let result = self
.find_and_deregister(name.clone(), TableType::View)
.await;
match (result, if_exists) {
(Ok(true), _) => self.return_empty_dataframe(),
(_, true) => self.return_empty_dataframe(),
(_, _) => exec_err!("View '{name}' doesn't exist."),
}
}
async fn drop_schema(&self, cmd: DropCatalogSchema) -> Result<DataFrame> {
let DropCatalogSchema {
name,
if_exists: allow_missing,
cascade,
schema: _,
} = cmd;
let catalog = {
let state = self.state.read();
let catalog_name = match &name {
SchemaReference::Full { catalog, .. } => catalog.to_string(),
SchemaReference::Bare { .. } => {
state.config_options().catalog.default_catalog.to_string()
}
};
if let Some(catalog) = state.catalog_list().catalog(&catalog_name) {
catalog
} else if allow_missing {
return self.return_empty_dataframe();
} else {
return self.schema_doesnt_exist_err(name);
}
};
let dereg = catalog.deregister_schema(name.schema_name(), cascade)?;
match (dereg, allow_missing) {
(None, true) => self.return_empty_dataframe(),
(None, false) => self.schema_doesnt_exist_err(name),
(Some(_), _) => self.return_empty_dataframe(),
}
}
fn schema_doesnt_exist_err(&self, schemaref: SchemaReference) -> Result<DataFrame> {
exec_err!("Schema '{schemaref}' doesn't exist.")
}
async fn set_variable(&self, stmt: SetVariable) -> Result<DataFrame> {
let SetVariable {
variable, value, ..
} = stmt;
let mut state = self.state.write();
state.config_mut().options_mut().set(&variable, &value)?;
drop(state);
self.return_empty_dataframe()
}
async fn create_custom_table(
&self,
cmd: &CreateExternalTable,
) -> Result<Arc<dyn TableProvider>> {
let state = self.state.read().clone();
let file_type = cmd.file_type.to_uppercase();
let factory =
state
.table_factories()
.get(file_type.as_str())
.ok_or_else(|| {
DataFusionError::Execution(format!(
"Unable to find factory for {}",
cmd.file_type
))
})?;
let table = (*factory).create(&state, cmd).await?;
Ok(table)
}
async fn find_and_deregister(
&self,
table_ref: impl Into<TableReference>,
table_type: TableType,
) -> Result<bool> {
let table_ref = table_ref.into();
let table = table_ref.table().to_owned();
let maybe_schema = {
let state = self.state.read();
let resolved = state.resolve_table_ref(table_ref);
state
.catalog_list()
.catalog(&resolved.catalog)
.and_then(|c| c.schema(&resolved.schema))
};
if let Some(schema) = maybe_schema {
if let Some(table_provider) = schema.table(&table).await? {
if table_provider.table_type() == table_type {
schema.deregister_table(&table)?;
return Ok(true);
}
}
}
Ok(false)
}
async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
let function = {
let state = self.state.read().clone();
let function_factory = state.function_factory();
match function_factory {
Some(f) => f.create(&state, stmt).await?,
_ => Err(DataFusionError::Configuration(
"Function factory has not been configured".into(),
))?,
}
};
match function {
RegisterFunction::Scalar(f) => {
self.state.write().register_udf(f)?;
}
RegisterFunction::Aggregate(f) => {
self.state.write().register_udaf(f)?;
}
RegisterFunction::Window(f) => {
self.state.write().register_udwf(f)?;
}
RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
};
self.return_empty_dataframe()
}
async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
let mut dropped = false;
dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udtf(&stmt.name)?.is_some();
if !stmt.if_exists && !dropped {
exec_err!("Function does not exist")
} else {
self.return_empty_dataframe()
}
}
fn execute_prepared(&self, execute: Execute) -> Result<DataFrame> {
let Execute {
name, parameters, ..
} = execute;
let prepared = self.state.read().get_prepared(&name).ok_or_else(|| {
exec_datafusion_err!("Prepared statement '{}' does not exist", name)
})?;
let mut params: Vec<ScalarValue> = parameters
.into_iter()
.map(|e| match e {
Expr::Literal(scalar) => Ok(scalar),
_ => not_impl_err!("Unsupported parameter type: {}", e),
})
.collect::<Result<_>>()?;
if !prepared.data_types.is_empty() {
if params.len() != prepared.data_types.len() {
return exec_err!(
"Prepared statement '{}' expects {} parameters, but {} provided",
name,
prepared.data_types.len(),
params.len()
);
}
params = params
.into_iter()
.zip(prepared.data_types.iter())
.map(|(e, dt)| e.cast_to(dt))
.collect::<Result<_>>()?;
}
let params = ParamValues::List(params);
let plan = prepared
.plan
.as_ref()
.clone()
.replace_params_with_values(&params)?;
Ok(DataFrame::new(self.state(), plan))
}
pub fn register_variable(
&self,
variable_type: VarType,
provider: Arc<dyn VarProvider + Send + Sync>,
) {
self.state
.write()
.execution_props_mut()
.add_var_provider(variable_type, provider);
}
pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
self.state.write().register_udtf(name, fun)
}
pub fn register_udf(&self, f: ScalarUDF) {
let mut state = self.state.write();
state.register_udf(Arc::new(f)).ok();
}
pub fn register_udaf(&self, f: AggregateUDF) {
self.state.write().register_udaf(Arc::new(f)).ok();
}
pub fn register_udwf(&self, f: WindowUDF) {
self.state.write().register_udwf(Arc::new(f)).ok();
}
pub fn deregister_udf(&self, name: &str) {
self.state.write().deregister_udf(name).ok();
}
pub fn deregister_udaf(&self, name: &str) {
self.state.write().deregister_udaf(name).ok();
}
pub fn deregister_udwf(&self, name: &str) {
self.state.write().deregister_udwf(name).ok();
}
pub fn deregister_udtf(&self, name: &str) {
self.state.write().deregister_udtf(name).ok();
}
async fn _read_type<'a, P: DataFilePaths>(
&self,
table_paths: P,
options: impl ReadOptions<'a>,
) -> Result<DataFrame> {
let table_paths = table_paths.to_urls()?;
let session_config = self.copied_config();
let listing_options =
options.to_listing_options(&session_config, self.copied_table_options());
let option_extension = listing_options.file_extension.clone();
if table_paths.is_empty() {
return exec_err!("No table paths were provided");
}
for path in &table_paths {
let file_path = path.as_str();
if !file_path.ends_with(option_extension.clone().as_str())
&& !path.is_collection()
{
return exec_err!(
"File path '{file_path}' does not match the expected extension '{option_extension}'"
);
}
}
let resolved_schema = options
.get_resolved_schema(&session_config, self.state(), table_paths[0].clone())
.await?;
let config = ListingTableConfig::new_with_multi_paths(table_paths)
.with_listing_options(listing_options)
.with_schema(resolved_schema);
let provider = ListingTable::try_new(config)?;
self.read_table(Arc::new(provider))
}
pub async fn read_arrow<P: DataFilePaths>(
&self,
table_paths: P,
options: ArrowReadOptions<'_>,
) -> Result<DataFrame> {
self._read_type(table_paths, options).await
}
pub fn read_empty(&self) -> Result<DataFrame> {
Ok(DataFrame::new(
self.state(),
LogicalPlanBuilder::empty(true).build()?,
))
}
pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
Ok(DataFrame::new(
self.state(),
LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
.build()?,
))
}
pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
Ok(DataFrame::new(
self.state(),
LogicalPlanBuilder::scan(
UNNAMED_TABLE,
provider_as_source(Arc::new(provider)),
None,
)?
.build()?,
))
}
pub fn read_batches(
&self,
batches: impl IntoIterator<Item = RecordBatch>,
) -> Result<DataFrame> {
let mut batches = batches.into_iter().peekable();
let schema = if let Some(batch) = batches.peek() {
batch.schema()
} else {
Arc::new(Schema::empty())
};
let provider = MemTable::try_new(schema, vec![batches.collect()])?;
Ok(DataFrame::new(
self.state(),
LogicalPlanBuilder::scan(
UNNAMED_TABLE,
provider_as_source(Arc::new(provider)),
None,
)?
.build()?,
))
}
pub async fn register_listing_table(
&self,
table_ref: impl Into<TableReference>,
table_path: impl AsRef<str>,
options: ListingOptions,
provided_schema: Option<SchemaRef>,
sql_definition: Option<String>,
) -> Result<()> {
let table_path = ListingTableUrl::parse(table_path)?;
let resolved_schema = match provided_schema {
Some(s) => s,
None => options.infer_schema(&self.state(), &table_path).await?,
};
let config = ListingTableConfig::new(table_path)
.with_listing_options(options)
.with_schema(resolved_schema);
let table = ListingTable::try_new(config)?.with_definition(sql_definition);
self.register_table(table_ref, Arc::new(table))?;
Ok(())
}
fn register_type_check<P: DataFilePaths>(
&self,
table_paths: P,
extension: impl AsRef<str>,
) -> Result<()> {
let table_paths = table_paths.to_urls()?;
if table_paths.is_empty() {
return exec_err!("No table paths were provided");
}
let extension = extension.as_ref();
for path in &table_paths {
let file_path = path.as_str();
if !file_path.ends_with(extension) && !path.is_collection() {
return exec_err!(
"File path '{file_path}' does not match the expected extension '{extension}'"
);
}
}
Ok(())
}
pub async fn register_arrow(
&self,
name: &str,
table_path: &str,
options: ArrowReadOptions<'_>,
) -> Result<()> {
let listing_options = options
.to_listing_options(&self.copied_config(), self.copied_table_options());
self.register_listing_table(
name,
table_path,
listing_options,
options.schema.map(|s| Arc::new(s.to_owned())),
None,
)
.await?;
Ok(())
}
pub fn register_catalog(
&self,
name: impl Into<String>,
catalog: Arc<dyn CatalogProvider>,
) -> Option<Arc<dyn CatalogProvider>> {
let name = name.into();
self.state
.read()
.catalog_list()
.register_catalog(name, catalog)
}
pub fn catalog_names(&self) -> Vec<String> {
self.state.read().catalog_list().catalog_names()
}
pub fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
self.state.read().catalog_list().catalog(name)
}
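/// Registers a [`TableProvider`] under the given table reference, returning
/// any provider previously registered under the same name.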
pub fn register_table(
&self,
table_ref: impl Into<TableReference>,
provider: Arc<dyn TableProvider>,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table_ref: TableReference = table_ref.into();
let table = table_ref.table().to_owned();
self.state
.read()
.schema_for_ref(table_ref)?
.register_table(table, provider)
}
pub fn deregister_table(
&self,
table_ref: impl Into<TableReference>,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table_ref = table_ref.into();
let table = table_ref.table().to_owned();
self.state
.read()
.schema_for_ref(table_ref)?
.deregister_table(&table)
}
pub fn table_exist(&self, table_ref: impl Into<TableReference>) -> Result<bool> {
let table_ref: TableReference = table_ref.into();
let table = table_ref.table();
let table_ref = table_ref.clone();
Ok(self
.state
.read()
.schema_for_ref(table_ref)?
.table_exist(table))
}
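/// Returns a [`DataFrame`] that scans the table registered under
/// `table_ref`, or an error if no such table exists.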
pub async fn table(&self, table_ref: impl Into<TableReference>) -> Result<DataFrame> {
let table_ref: TableReference = table_ref.into();
let provider = self.table_provider(table_ref.clone()).await?;
let plan = LogicalPlanBuilder::scan(
table_ref,
provider_as_source(Arc::clone(&provider)),
None,
)?
.build()?;
Ok(DataFrame::new(self.state(), plan))
}
pub fn table_function(&self, name: &str) -> Result<Arc<TableFunction>> {
self.state
.read()
.table_functions()
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("Table function '{name}' not found"))
}
pub async fn table_provider(
&self,
table_ref: impl Into<TableReference>,
) -> Result<Arc<dyn TableProvider>> {
let table_ref = table_ref.into();
let table = table_ref.table().to_string();
let schema = self.state.read().schema_for_ref(table_ref)?;
match schema.table(&table).await? {
Some(ref provider) => Ok(Arc::clone(provider)),
_ => plan_err!("No table named '{table}'"),
}
}
pub fn task_ctx(&self) -> Arc<TaskContext> {
Arc::new(TaskContext::from(self))
}
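/// Returns a snapshot (clone) of the current [`SessionState`] with its
/// execution properties marked as started. Changes made to the context
/// after this call are not reflected in the returned state.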
pub fn state(&self) -> SessionState {
let mut state = self.state.read().clone();
state.execution_props_mut().start_execution();
state
}
pub fn state_ref(&self) -> Arc<RwLock<SessionState>> {
Arc::clone(&self.state)
}
pub fn state_weak_ref(&self) -> Weak<RwLock<SessionState>> {
Arc::downgrade(&self.state)
}
pub fn register_catalog_list(&self, catalog_list: Arc<dyn CatalogProviderList>) {
self.state.write().register_catalog_list(catalog_list)
}
pub fn register_table_options_extension<T: ConfigExtension>(&self, extension: T) {
self.state
.write()
.register_table_options_extension(extension)
}
}
impl FunctionRegistry for SessionContext {
fn udfs(&self) -> HashSet<String> {
self.state.read().udfs()
}
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
self.state.read().udf(name)
}
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
self.state.read().udaf(name)
}
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
self.state.read().udwf(name)
}
fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
self.state.write().register_udf(udf)
}
fn register_udaf(
&mut self,
udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
self.state.write().register_udaf(udaf)
}
fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
self.state.write().register_udwf(udwf)
}
fn register_function_rewrite(
&mut self,
rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
) -> Result<()> {
self.state.write().register_function_rewrite(rewrite)
}
fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
self.state.read().expr_planners()
}
fn register_expr_planner(
&mut self,
expr_planner: Arc<dyn ExprPlanner>,
) -> Result<()> {
self.state.write().register_expr_planner(expr_planner)
}
}
impl From<&SessionContext> for TaskContext {
fn from(session: &SessionContext) -> Self {
TaskContext::from(&*session.state.read())
}
}
impl From<SessionState> for SessionContext {
fn from(state: SessionState) -> Self {
Self::new_with_state(state)
}
}
impl From<SessionContext> for SessionStateBuilder {
fn from(session: SessionContext) -> Self {
session.into_state_builder()
}
}
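/// A pluggable planner that converts a [`LogicalPlan`] into an
/// [`ExecutionPlan`]. A custom implementation can be installed with
/// [`SessionStateBuilder::with_query_planner`].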
#[async_trait]
pub trait QueryPlanner: Debug {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>>;
}
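/// Hook invoked for `CREATE FUNCTION` statements. Implementations convert
/// the [`CreateFunction`] statement into a [`RegisterFunction`], which is
/// then registered with the session. Install one with
/// [`SessionContext::with_function_factory`].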
#[async_trait]
pub trait FunctionFactory: Debug + Sync + Send {
async fn create(
&self,
state: &SessionState,
statement: CreateFunction,
) -> Result<RegisterFunction>;
}
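/// The kinds of user defined functions a [`FunctionFactory`] can produce.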
pub enum RegisterFunction {
Scalar(Arc<ScalarUDF>),
Aggregate(Arc<AggregateUDF>),
Window(Arc<WindowUDF>),
Table(String, Arc<dyn TableFunctionImpl>),
}
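/// A [`SerializerRegistry`] that rejects (de)serialization of all user
/// defined logical plan nodes; a reasonable default when plan serialization
/// is not needed.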
#[derive(Debug)]
pub struct EmptySerializerRegistry;
impl SerializerRegistry for EmptySerializerRegistry {
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>> {
not_impl_err!(
"Serializing user defined logical plan node `{}` is not supported",
node.name()
)
}
fn deserialize_logical_plan(
&self,
name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
not_impl_err!(
"Deserializing user defined logical plan node `{name}` is not supported"
)
}
}
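/// Controls which categories of SQL are accepted by
/// [`SessionContext::sql_with_options`]: DDL (e.g. `CREATE TABLE`), DML
/// (e.g. `INSERT`, `COPY`), and other statements (e.g. `SET VARIABLE`).
/// All categories are allowed by default.
///
/// A short sketch (assumes a Tokio runtime):
///
/// ```
/// use datafusion::prelude::*;
/// use datafusion::execution::context::SQLOptions;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let options = SQLOptions::new().with_allow_ddl(false);
/// // The plan is rejected during verification, before execution
/// assert!(ctx.sql_with_options("CREATE TABLE t AS VALUES (1)", options).await.is_err());
/// # Ok(())
/// # }
/// ```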
#[derive(Clone, Debug, Copy)]
pub struct SQLOptions {
allow_ddl: bool,
allow_dml: bool,
allow_statements: bool,
}
impl Default for SQLOptions {
fn default() -> Self {
Self {
allow_ddl: true,
allow_dml: true,
allow_statements: true,
}
}
}
impl SQLOptions {
pub fn new() -> Self {
Default::default()
}
pub fn with_allow_ddl(mut self, allow: bool) -> Self {
self.allow_ddl = allow;
self
}
pub fn with_allow_dml(mut self, allow: bool) -> Self {
self.allow_dml = allow;
self
}
pub fn with_allow_statements(mut self, allow: bool) -> Self {
self.allow_statements = allow;
self
}
pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> {
plan.visit_with_subqueries(&mut BadPlanVisitor::new(self))?;
Ok(())
}
}
struct BadPlanVisitor<'a> {
options: &'a SQLOptions,
}
impl<'a> BadPlanVisitor<'a> {
fn new(options: &'a SQLOptions) -> Self {
Self { options }
}
}
impl<'n> TreeNodeVisitor<'n> for BadPlanVisitor<'_> {
type Node = LogicalPlan;
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
match node {
LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
plan_err!("DDL not supported: {}", ddl.name())
}
LogicalPlan::Dml(dml) if !self.options.allow_dml => {
plan_err!("DML not supported: {}", dml.op)
}
LogicalPlan::Copy(_) if !self.options.allow_dml => {
plan_err!("DML not supported: COPY")
}
LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
plan_err!("Statement not supported: {}", stmt.name())
}
_ => Ok(TreeNodeRecursion::Continue),
}
}
}
#[cfg(test)]
mod tests {
use super::{super::options::CsvReadOptions, *};
use crate::assert_batches_eq;
use crate::execution::memory_pool::MemoryConsumer;
use crate::test;
use crate::test_util::{plan_and_collect, populate_csv_partitions};
use arrow::datatypes::{DataType, TimeUnit};
use std::env;
use std::error::Error;
use std::path::PathBuf;
use datafusion_common_runtime::SpawnedTask;
use crate::catalog::SchemaProvider;
use crate::execution::session_state::SessionStateBuilder;
use crate::physical_planner::PhysicalPlanner;
use async_trait::async_trait;
use datafusion_expr::planner::TypePlanner;
use sqlparser::ast;
use tempfile::TempDir;
#[tokio::test]
async fn shared_memory_and_disk_manager() {
let ctx1 = SessionContext::new();
let memory_pool = ctx1.runtime_env().memory_pool.clone();
let mut reservation = MemoryConsumer::new("test").register(&memory_pool);
reservation.grow(100);
let disk_manager = ctx1.runtime_env().disk_manager.clone();
let ctx2 =
SessionContext::new_with_config_rt(SessionConfig::new(), ctx1.runtime_env());
assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 100);
assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 100);
drop(reservation);
assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 0);
assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 0);
assert!(std::ptr::eq(
Arc::as_ptr(&disk_manager),
Arc::as_ptr(&ctx1.runtime_env().disk_manager)
));
assert!(std::ptr::eq(
Arc::as_ptr(&disk_manager),
Arc::as_ptr(&ctx2.runtime_env().disk_manager)
));
}
#[tokio::test]
async fn create_variable_expr() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = create_ctx(&tmp_dir, partition_count).await?;
let variable_provider = test::variable::SystemVar::new();
ctx.register_variable(VarType::System, Arc::new(variable_provider));
let variable_provider = test::variable::UserDefinedVar::new();
ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider));
let provider = test::create_table_dual();
ctx.register_table("dual", provider)?;
let results =
plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual")
.await?;
let expected = [
"+----------------------+------------------------+---------------------+",
"| @@version | @name | @integer + Int64(1) |",
"+----------------------+------------------------+---------------------+",
"| system-var-@@version | user-defined-var-@name | 42 |",
"+----------------------+------------------------+---------------------+",
];
assert_batches_eq!(expected, &results);
Ok(())
}
#[tokio::test]
async fn create_variable_err() -> Result<()> {
let ctx = SessionContext::new();
let err = plan_and_collect(&ctx, "SElECT @= X3").await.unwrap_err();
assert_eq!(
err.strip_backtrace(),
"Error during planning: variable [\"@=\"] has no type information"
);
Ok(())
}
#[tokio::test]
async fn register_deregister() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = create_ctx(&tmp_dir, partition_count).await?;
let provider = test::create_table_dual();
ctx.register_table("dual", provider)?;
assert!(ctx.deregister_table("dual")?.is_some());
assert!(ctx.deregister_table("dual")?.is_none());
Ok(())
}
#[tokio::test]
async fn send_context_to_threads() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = Arc::new(create_ctx(&tmp_dir, partition_count).await?);
let threads: Vec<_> = (0..2)
.map(|_| ctx.clone())
.map(|ctx| {
SpawnedTask::spawn(async move {
ctx.sql("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")
.await
})
})
.collect();
for handle in threads {
handle.join().await.unwrap().unwrap();
}
Ok(())
}
#[tokio::test]
async fn with_listing_schema_provider() -> Result<()> {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let path = path.join("tests/tpch-csv");
let url = format!("file://{}", path.display());
let cfg = SessionConfig::new()
.set_str("datafusion.catalog.location", url.as_str())
.set_str("datafusion.catalog.format", "CSV")
.set_str("datafusion.catalog.has_header", "true");
let session_state = SessionStateBuilder::new()
.with_config(cfg)
.with_default_features()
.build();
let ctx = SessionContext::new_with_state(session_state);
ctx.refresh_catalogs().await?;
let result =
plan_and_collect(&ctx, "select c_name from default.customer limit 3;")
.await?;
let actual = arrow::util::pretty::pretty_format_batches(&result)
.unwrap()
.to_string();
let expected = r#"+--------------------+
| c_name             |
+--------------------+
| Customer#000000002 |
| Customer#000000003 |
| Customer#000000004 |
+--------------------+"#;
assert_eq!(actual, expected);
Ok(())
}
#[tokio::test]
async fn test_dynamic_file_query() -> Result<()> {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let path = path.join("tests/tpch-csv/customer.csv");
let url = format!("file://{}", path.display());
let cfg = SessionConfig::new();
let session_state = SessionStateBuilder::new()
.with_default_features()
.with_config(cfg)
.build();
let ctx = SessionContext::new_with_state(session_state).enable_url_table();
let result = plan_and_collect(
&ctx,
format!("select c_name from '{}' limit 3;", &url).as_str(),
)
.await?;
let actual = arrow::util::pretty::pretty_format_batches(&result)
.unwrap()
.to_string();
let expected = r#"+--------------------+
| c_name             |
+--------------------+
| Customer#000000002 |
| Customer#000000003 |
| Customer#000000004 |
+--------------------+"#;
assert_eq!(actual, expected);
Ok(())
}
#[tokio::test]
async fn custom_query_planner() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
let session_state = SessionStateBuilder::new()
.with_config(SessionConfig::new())
.with_runtime_env(runtime)
.with_default_features()
.with_query_planner(Arc::new(MyQueryPlanner {}))
.build();
let ctx = SessionContext::new_with_state(session_state);
let df = ctx.sql("SELECT 1").await?;
df.collect().await.expect_err("query not supported");
Ok(())
}
#[tokio::test]
async fn disabled_default_catalog_and_schema() -> Result<()> {
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_create_default_catalog_and_schema(false),
);
assert!(matches!(
ctx.register_table("test", test::table_with_sequence(1, 1)?),
Err(DataFusionError::Plan(_))
));
let err = ctx
.sql("select * from datafusion.public.test")
.await
.unwrap_err();
let err = err
.source()
.and_then(|err| err.downcast_ref::<DataFusionError>())
.unwrap();
assert!(matches!(err, &DataFusionError::Plan(_)));
Ok(())
}
#[tokio::test]
async fn custom_catalog_and_schema() {
let config = SessionConfig::new()
.with_create_default_catalog_and_schema(true)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
#[tokio::test]
async fn custom_catalog_and_schema_no_default() {
let config = SessionConfig::new()
.with_create_default_catalog_and_schema(false)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
#[tokio::test]
async fn custom_catalog_and_schema_and_information_schema() {
let config = SessionConfig::new()
.with_create_default_catalog_and_schema(true)
.with_information_schema(true)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
async fn catalog_and_schema_test(config: SessionConfig) {
let ctx = SessionContext::new_with_config(config);
let catalog = MemoryCatalogProvider::new();
let schema = MemorySchemaProvider::new();
schema
.register_table("test".to_owned(), test::table_with_sequence(1, 1).unwrap())
.unwrap();
catalog
.register_schema("my_schema", Arc::new(schema))
.unwrap();
ctx.register_catalog("my_catalog", Arc::new(catalog));
for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] {
let result = plan_and_collect(
&ctx,
&format!("SELECT COUNT(*) AS count FROM {table_ref}"),
)
.await
.unwrap();
let expected = [
"+-------+",
"| count |",
"+-------+",
"| 1 |",
"+-------+",
];
assert_batches_eq!(expected, &result);
}
}
#[tokio::test]
async fn cross_catalog_access() -> Result<()> {
let ctx = SessionContext::new();
let catalog_a = MemoryCatalogProvider::new();
let schema_a = MemorySchemaProvider::new();
schema_a
.register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
catalog_a.register_schema("schema_a", Arc::new(schema_a))?;
ctx.register_catalog("catalog_a", Arc::new(catalog_a));
let catalog_b = MemoryCatalogProvider::new();
let schema_b = MemorySchemaProvider::new();
schema_b
.register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
catalog_b.register_schema("schema_b", Arc::new(schema_b))?;
ctx.register_catalog("catalog_b", Arc::new(catalog_b));
let result = plan_and_collect(
&ctx,
"SELECT cat, SUM(i) AS total FROM (
SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a
UNION ALL
SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b
) AS all
GROUP BY cat
ORDER BY cat
",
)
.await?;
let expected = [
"+-----+-------+",
"| cat | total |",
"+-----+-------+",
"| a | 1 |",
"| b | 3 |",
"+-----+-------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}
#[tokio::test]
async fn catalogs_not_leaked() {
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_information_schema(true),
);
let catalog = Arc::new(MemoryCatalogProvider::new());
let catalog_weak = Arc::downgrade(&catalog);
ctx.register_catalog("my_catalog", catalog);
let catalog_list_weak = {
let state = ctx.state.read();
Arc::downgrade(state.catalog_list())
};
drop(ctx);
assert_eq!(Weak::strong_count(&catalog_list_weak), 0);
assert_eq!(Weak::strong_count(&catalog_weak), 0);
}
#[tokio::test]
async fn sql_create_schema() -> Result<()> {
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_information_schema(true),
);
ctx.sql("CREATE SCHEMA abc").await?.collect().await?;
ctx.sql("CREATE TABLE abc.y AS VALUES (1,2,3)")
.await?
.collect()
.await?;
let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
assert_eq!(results[0].num_rows(), 1);
Ok(())
}
#[tokio::test]
async fn sql_create_catalog() -> Result<()> {
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_information_schema(true),
);
ctx.sql("CREATE DATABASE test").await?.collect().await?;
ctx.sql("CREATE SCHEMA test.abc").await?.collect().await?;
ctx.sql("CREATE TABLE test.abc.y AS VALUES (1,2,3)")
.await?
.collect()
.await?;
let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_catalog='test' AND table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
assert_eq!(results[0].num_rows(), 1);
Ok(())
}
#[tokio::test]
async fn custom_type_planner() -> Result<()> {
let state = SessionStateBuilder::new()
.with_default_features()
.with_type_planner(Arc::new(MyTypePlanner {}))
.build();
let ctx = SessionContext::new_with_state(state);
let result = ctx
.sql("SELECT DATETIME '2021-01-01 00:00:00'")
.await?
.collect()
.await?;
let expected = [
"+-----------------------------+",
"| Utf8(\"2021-01-01 00:00:00\") |",
"+-----------------------------+",
"| 2021-01-01T00:00:00 |",
"+-----------------------------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}
#[test]
fn preserve_session_context_id() -> Result<()> {
let ctx = SessionContext::new();
assert_eq!(ctx.session_id(), ctx.enable_url_table().session_id());
Ok(())
}
struct MyPhysicalPlanner {}
#[async_trait]
impl PhysicalPlanner for MyPhysicalPlanner {
async fn create_physical_plan(
&self,
_logical_plan: &LogicalPlan,
_session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
not_impl_err!("query not supported")
}
fn create_physical_expr(
&self,
_expr: &Expr,
_input_dfschema: &DFSchema,
_session_state: &SessionState,
) -> Result<Arc<dyn PhysicalExpr>> {
unimplemented!()
}
}
#[derive(Debug)]
struct MyQueryPlanner {}
#[async_trait]
impl QueryPlanner for MyQueryPlanner {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
let physical_planner = MyPhysicalPlanner {};
physical_planner
.create_physical_plan(logical_plan, session_state)
.await
}
}
async fn create_ctx(
tmp_dir: &TempDir,
partition_count: usize,
) -> Result<SessionContext> {
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_target_partitions(8),
);
let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?;
ctx.register_csv(
"test",
tmp_dir.path().to_str().unwrap(),
CsvReadOptions::new().schema(&schema),
)
.await?;
Ok(ctx)
}
#[derive(Debug)]
struct MyTypePlanner {}
impl TypePlanner for MyTypePlanner {
fn plan_type(&self, sql_type: &ast::DataType) -> Result<Option<DataType>> {
match sql_type {
ast::DataType::Datetime(precision) => {
let precision = match precision {
Some(0) => TimeUnit::Second,
Some(3) => TimeUnit::Millisecond,
Some(6) => TimeUnit::Microsecond,
None | Some(9) => TimeUnit::Nanosecond,
_ => unreachable!(),
};
Ok(Some(DataType::Timestamp(precision, None)))
}
_ => Ok(None),
}
}
}
}