//! [`SessionContext`] and related session state for registering data sources and executing queries.

use crate::{
catalog::{
catalog::{CatalogList, MemoryCatalogList},
information_schema::CatalogWithInformationSchema,
},
datasource::listing::{ListingOptions, ListingTable},
datasource::{
file_format::{
avro::{AvroFormat, DEFAULT_AVRO_EXTENSION},
csv::{CsvFormat, DEFAULT_CSV_EXTENSION},
json::{JsonFormat, DEFAULT_JSON_EXTENSION},
parquet::{ParquetFormat, DEFAULT_PARQUET_EXTENSION},
FileFormat,
},
MemTable, ViewTable,
},
logical_plan::{PlanType, ToStringifiedPlan},
optimizer::{
eliminate_filter::EliminateFilter, eliminate_limit::EliminateLimit,
optimizer::Optimizer,
},
physical_optimizer::{
aggregate_statistics::AggregateStatistics,
hash_build_probe_order::HashBuildProbeOrder, optimizer::PhysicalOptimizerRule,
},
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
use parking_lot::RwLock;
use std::sync::Arc;
use std::{
any::{Any, TypeId},
hash::{BuildHasherDefault, Hasher},
string::String,
};
use std::{
collections::{HashMap, HashSet},
fmt::Debug,
};
use arrow::datatypes::{DataType, SchemaRef};
use crate::catalog::{
catalog::{CatalogProvider, MemoryCatalogProvider},
schema::{MemorySchemaProvider, SchemaProvider},
};
use crate::dataframe::DataFrame;
use crate::datasource::listing::{ListingTableConfig, ListingTableUrl};
use crate::datasource::TableProvider;
use crate::error::{DataFusionError, Result};
use crate::logical_plan::{
provider_as_source, CreateCatalog, CreateCatalogSchema, CreateExternalTable,
CreateMemoryTable, CreateView, DropTable, FunctionRegistry, LogicalPlan,
LogicalPlanBuilder, UNNAMED_TABLE,
};
use crate::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
use crate::optimizer::filter_push_down::FilterPushDown;
use crate::optimizer::limit_push_down::LimitPushDown;
use crate::optimizer::optimizer::{OptimizerConfig, OptimizerRule};
use crate::optimizer::projection_push_down::ProjectionPushDown;
use crate::optimizer::reduce_outer_join::ReduceOuterJoin;
use crate::optimizer::simplify_expressions::SimplifyExpressions;
use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use crate::optimizer::subquery_filter_to_join::SubqueryFilterToJoin;
use datafusion_sql::{ResolvedTableReference, TableReference};
use crate::physical_optimizer::coalesce_batches::CoalesceBatches;
use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec;
use crate::physical_optimizer::repartition::Repartition;
use crate::config::{
ConfigOptions, OPT_BATCH_SIZE, OPT_COALESCE_BATCHES, OPT_COALESCE_TARGET_BATCH_SIZE,
OPT_FILTER_NULL_JOIN_KEYS, OPT_OPTIMIZER_SKIP_FAILED_RULES,
};
use crate::datasource::datasource::TableProviderFactory;
use crate::execution::runtime_env::RuntimeEnv;
use crate::logical_plan::plan::Explain;
use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, plan_to_parquet};
use crate::physical_plan::planner::DefaultPhysicalPlanner;
use crate::physical_plan::udaf::AggregateUDF;
use crate::physical_plan::udf::ScalarUDF;
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::PhysicalPlanner;
use crate::variable::{VarProvider, VarType};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion_common::ScalarValue;
use datafusion_expr::logical_plan::DropView;
use datafusion_expr::{TableSource, TableType};
use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
use datafusion_optimizer::type_coercion::TypeCoercion;
use datafusion_sql::{
parser::DFParser,
planner::{ContextProvider, SqlToRel},
};
use parquet::file::properties::WriterProperties;
use uuid::Uuid;
use super::options::{
AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
};
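/// Default catalog name used when a table reference does not specify one.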
const DEFAULT_CATALOG: &str = "datafusion";
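/// Default schema name used when a table reference does not specify one.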
const DEFAULT_SCHEMA: &str = "public";
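/// Main interface for executing queries with DataFusion.
///
/// A `SessionContext` holds the per-session state ([`SessionState`]) and offers
/// methods to register data sources (CSV, Parquet, Avro, NDJSON, in-memory and
/// custom [`TableProvider`]s), register UDFs/UDAFs, and plan and execute SQL or
/// DataFrame queries.
///
/// A minimal usage sketch (the CSV path below is hypothetical):
///
/// ```no_run
/// use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// ctx.register_csv("example", "path/to/example.csv", CsvReadOptions::new()).await?;
/// let df = ctx.sql("SELECT COUNT(*) FROM example").await?;
/// df.show().await?;
/// # Ok(())
/// # }
/// ```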
#[derive(Clone)]
pub struct SessionContext {
session_id: String,
pub session_start_time: DateTime<Utc>,
pub state: Arc<RwLock<SessionState>>,
pub table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
}
impl Default for SessionContext {
fn default() -> Self {
Self::new()
}
}
impl SessionContext {
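    /// Creates a new `SessionContext` using the default [`SessionConfig`].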
pub fn new() -> Self {
Self::with_config(SessionConfig::new())
}
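    /// Creates a new `SessionContext` with the given [`SessionConfig`] and a default [`RuntimeEnv`].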
pub fn with_config(config: SessionConfig) -> Self {
let runtime = Arc::new(RuntimeEnv::default());
Self::with_config_rt(config, runtime)
}
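    /// Creates a new `SessionContext` with the given [`SessionConfig`] and [`RuntimeEnv`].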
pub fn with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
let state = SessionState::with_config_rt(config, runtime);
Self {
session_id: state.session_id.clone(),
session_start_time: chrono::Utc::now(),
state: Arc::new(RwLock::new(state)),
table_factories: HashMap::default(),
}
}
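    /// Creates a new `SessionContext` from an existing [`SessionState`].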
pub fn with_state(state: SessionState) -> Self {
Self {
session_id: state.session_id.clone(),
session_start_time: chrono::Utc::now(),
state: Arc::new(RwLock::new(state)),
table_factories: HashMap::default(),
}
}
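    /// Registers a [`TableProviderFactory`] that handles `CREATE EXTERNAL TABLE`
    /// statements for the given (non built-in) file type.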
pub fn register_table_factory(
&mut self,
file_type: &str,
factory: Arc<dyn TableProviderFactory>,
) {
self.table_factories.insert(file_type.to_string(), factory);
}
pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
self.state.read().runtime_env.clone()
}
pub fn session_id(&self) -> String {
self.session_id.clone()
}
pub fn copied_config(&self) -> SessionConfig {
self.state.read().config.clone()
}
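    /// Creates a logical plan from the given SQL string and, for DDL statements
    /// (`CREATE EXTERNAL TABLE`, `CREATE [MEMORY] TABLE`, `CREATE VIEW`,
    /// `CREATE SCHEMA`, `CREATE DATABASE`, `DROP TABLE`, `DROP VIEW`), executes
    /// them eagerly against this context. All other statements are returned as a
    /// [`DataFrame`] that can be refined further or executed.
    ///
    /// A small sketch (the file path is hypothetical):
    ///
    /// ```no_run
    /// # use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new();
    /// ctx.sql("CREATE EXTERNAL TABLE t STORED AS CSV LOCATION 'path/to/t.csv'").await?;
    /// let df = ctx.sql("SELECT * FROM t LIMIT 10").await?;
    /// df.show().await?;
    /// # Ok(())
    /// # }
    /// ```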
pub async fn sql(&self, sql: &str) -> Result<Arc<DataFrame>> {
let plan = self.create_logical_plan(sql)?;
match plan {
LogicalPlan::CreateExternalTable(cmd) => match cmd.file_type.as_str() {
"PARQUET" | "CSV" | "JSON" | "AVRO" => {
self.create_listing_table(&cmd).await
}
_ => self.create_custom_table(&cmd).await,
},
LogicalPlan::CreateMemoryTable(CreateMemoryTable {
name,
input,
if_not_exists,
or_replace,
}) => {
let table = self.table(name.as_str());
match (if_not_exists, or_replace, table) {
(true, false, Ok(_)) => self.return_empty_dataframe(),
(false, true, Ok(_)) => {
self.deregister_table(name.as_str())?;
let physical =
Arc::new(DataFrame::new(self.state.clone(), &input));
let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(
Arc::new(input.schema().as_ref().into()),
batches,
)?);
self.register_table(name.as_str(), table)?;
self.return_empty_dataframe()
}
(true, true, Ok(_)) => Err(DataFusionError::Internal(
"'IF NOT EXISTS' cannot coexist with 'REPLACE'".to_string(),
)),
(_, _, Err(_)) => {
let physical =
Arc::new(DataFrame::new(self.state.clone(), &input));
let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(
Arc::new(input.schema().as_ref().into()),
batches,
)?);
self.register_table(name.as_str(), table)?;
self.return_empty_dataframe()
}
(false, false, Ok(_)) => Err(DataFusionError::Execution(format!(
"Table '{:?}' already exists",
name
))),
}
}
LogicalPlan::CreateView(CreateView {
name,
input,
or_replace,
definition,
}) => {
let view = self.table(name.as_str());
match (or_replace, view) {
(true, Ok(_)) => {
self.deregister_table(name.as_str())?;
let table =
Arc::new(ViewTable::try_new((*input).clone(), definition)?);
self.register_table(name.as_str(), table)?;
self.return_empty_dataframe()
}
(_, Err(_)) => {
let table =
Arc::new(ViewTable::try_new((*input).clone(), definition)?);
self.register_table(name.as_str(), table)?;
self.return_empty_dataframe()
}
(false, Ok(_)) => Err(DataFusionError::Execution(format!(
"Table '{:?}' already exists",
name
))),
}
}
LogicalPlan::DropTable(DropTable {
name, if_exists, ..
}) => {
let result = self.find_and_deregister(name.as_str(), TableType::Base);
match (result, if_exists) {
(Ok(true), _) => self.return_empty_dataframe(),
(_, true) => self.return_empty_dataframe(),
(_, _) => Err(DataFusionError::Execution(format!(
"Table {:?} doesn't exist.",
name
))),
}
}
LogicalPlan::DropView(DropView {
name, if_exists, ..
}) => {
let result = self.find_and_deregister(name.as_str(), TableType::View);
match (result, if_exists) {
(Ok(true), _) => self.return_empty_dataframe(),
(_, true) => self.return_empty_dataframe(),
(_, _) => Err(DataFusionError::Execution(format!(
"View {:?} doesn't exist.",
name
))),
}
}
LogicalPlan::CreateCatalogSchema(CreateCatalogSchema {
schema_name,
if_not_exists,
..
}) => {
let tokens: Vec<&str> = schema_name.split('.').collect();
let (catalog, schema_name) = match tokens.len() {
1 => Ok((DEFAULT_CATALOG, schema_name.as_str())),
2 => Ok((tokens[0], tokens[1])),
_ => Err(DataFusionError::Execution(format!(
"Unable to parse catalog from {}",
schema_name
))),
}?;
let catalog = self.catalog(catalog).ok_or_else(|| {
DataFusionError::Execution(format!(
"Missing '{}' catalog",
DEFAULT_CATALOG
))
})?;
let schema = catalog.schema(schema_name);
match (if_not_exists, schema) {
(true, Some(_)) => self.return_empty_dataframe(),
(true, None) | (false, None) => {
let schema = Arc::new(MemorySchemaProvider::new());
catalog.register_schema(schema_name, schema)?;
self.return_empty_dataframe()
}
(false, Some(_)) => Err(DataFusionError::Execution(format!(
"Schema '{:?}' already exists",
schema_name
))),
}
}
LogicalPlan::CreateCatalog(CreateCatalog {
catalog_name,
if_not_exists,
..
}) => {
let catalog = self.catalog(catalog_name.as_str());
match (if_not_exists, catalog) {
(true, Some(_)) => self.return_empty_dataframe(),
(true, None) | (false, None) => {
let new_catalog = Arc::new(MemoryCatalogProvider::new());
self.state
.write()
.catalog_list
.register_catalog(catalog_name, new_catalog);
self.return_empty_dataframe()
}
(false, Some(_)) => Err(DataFusionError::Execution(format!(
"Catalog '{:?}' already exists",
catalog_name
))),
}
}
plan => Ok(Arc::new(DataFrame::new(self.state.clone(), &plan))),
}
}
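    /// Returns a [`DataFrame`] over an empty relation, used as the result of DDL statements.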
fn return_empty_dataframe(&self) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::empty(false).build()?;
Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
}
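    /// Creates a table from a `CREATE EXTERNAL TABLE` statement whose file type is
    /// handled by a registered [`TableProviderFactory`].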
async fn create_custom_table(
&self,
cmd: &CreateExternalTable,
) -> Result<Arc<DataFrame>> {
        let factory = self.table_factories.get(&cmd.file_type).ok_or_else(|| {
            DataFusionError::Execution(format!(
                "Unable to find factory for {}",
                cmd.file_type
            ))
        })?;
        let table = factory.create(cmd.name.as_str(), cmd.location.as_str());
        self.register_table(cmd.name.as_str(), table)?;
        self.return_empty_dataframe()
}
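    /// Creates a [`ListingTable`] from a `CREATE EXTERNAL TABLE` statement for the
    /// built-in file types (CSV, Parquet, Avro, JSON).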
async fn create_listing_table(
&self,
cmd: &CreateExternalTable,
) -> Result<Arc<DataFrame>> {
let (file_format, file_extension) = match cmd.file_type.as_str() {
"CSV" => (
Arc::new(
CsvFormat::default()
.with_has_header(cmd.has_header)
.with_delimiter(cmd.delimiter as u8),
) as Arc<dyn FileFormat>,
DEFAULT_CSV_EXTENSION,
),
"PARQUET" => (
Arc::new(ParquetFormat::default()) as Arc<dyn FileFormat>,
DEFAULT_PARQUET_EXTENSION,
),
"AVRO" => (
Arc::new(AvroFormat::default()) as Arc<dyn FileFormat>,
DEFAULT_AVRO_EXTENSION,
),
"JSON" => (
Arc::new(JsonFormat::default()) as Arc<dyn FileFormat>,
DEFAULT_JSON_EXTENSION,
),
_ => Err(DataFusionError::Execution(
"Only known FileTypes can be ListingTables!".to_string(),
))?,
};
let table = self.table(cmd.name.as_str());
match (cmd.if_not_exists, table) {
(true, Ok(_)) => self.return_empty_dataframe(),
(_, Err(_)) => {
let provided_schema = if cmd.schema.fields().is_empty() {
None
} else {
Some(Arc::new(cmd.schema.as_ref().to_owned().into()))
};
let options = ListingOptions {
format: file_format,
collect_stat: false,
file_extension: file_extension.to_owned(),
target_partitions: self.copied_config().target_partitions,
table_partition_cols: cmd.table_partition_cols.clone(),
};
self.register_listing_table(
cmd.name.as_str(),
cmd.location.clone(),
options,
provided_schema,
cmd.definition.clone(),
)
.await?;
self.return_empty_dataframe()
}
(false, Ok(_)) => Err(DataFusionError::Execution(format!(
"Table '{:?}' already exists",
cmd.name
))),
}
}
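    /// Deregisters the given table reference if it exists and has the expected
    /// [`TableType`], returning whether a table was removed.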
fn find_and_deregister<'a>(
&self,
table_ref: impl Into<TableReference<'a>>,
table_type: TableType,
) -> Result<bool> {
let table_ref = table_ref.into();
let table_provider = self
.state
.read()
.schema_for_ref(table_ref)?
.table(table_ref.table());
if let Some(table_provider) = table_provider {
if table_provider.table_type() == table_type {
self.deregister_table(table_ref)?;
return Ok(true);
}
}
Ok(false)
}
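    /// Creates a [`LogicalPlan`] from a SQL string; only a single statement is supported.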
pub fn create_logical_plan(&self, sql: &str) -> Result<LogicalPlan> {
let mut statements = DFParser::parse_sql(sql)?;
if statements.len() != 1 {
return Err(DataFusionError::NotImplemented(
"The context currently only supports a single SQL statement".to_string(),
));
}
let state = self.state.read().clone();
let query_planner = SqlToRel::new(&state);
query_planner.statement_to_plan(statements.pop_front().unwrap())
}
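    /// Registers a variable provider for `@name` (user-defined) or `@@name` (system) variables.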
pub fn register_variable(
&mut self,
variable_type: VarType,
provider: Arc<dyn VarProvider + Send + Sync>,
) {
self.state
.write()
.execution_props
.add_var_provider(variable_type, provider);
}
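    /// Registers a scalar UDF so it can be referenced from SQL statements executed against this context.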
pub fn register_udf(&mut self, f: ScalarUDF) {
self.state
.write()
.scalar_functions
.insert(f.name.clone(), Arc::new(f));
}
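    /// Registers an aggregate UDF so it can be referenced from SQL statements executed against this context.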
pub fn register_udaf(&mut self, f: AggregateUDF) {
self.state
.write()
.aggregate_functions
.insert(f.name.clone(), Arc::new(f));
}
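    /// Creates a [`DataFrame`] for reading Avro data at `table_path` (a file, directory, or glob pattern).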
pub async fn read_avro(
&self,
table_path: impl AsRef<str>,
options: AvroReadOptions<'_>,
) -> Result<Arc<DataFrame>> {
let table_path = ListingTableUrl::parse(table_path)?;
let target_partitions = self.copied_config().target_partitions;
let listing_options = options.to_listing_options(target_partitions);
let resolved_schema = match options.schema {
Some(s) => s,
None => {
listing_options
.infer_schema(&self.state(), &table_path)
.await?
}
};
let config = ListingTableConfig::new(table_path)
.with_listing_options(listing_options)
.with_schema(resolved_schema);
let provider = ListingTable::try_new(config)?;
self.read_table(Arc::new(provider))
}
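    /// Creates a [`DataFrame`] for reading newline-delimited JSON data at `table_path`.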
pub async fn read_json(
        &self,
table_path: impl AsRef<str>,
options: NdJsonReadOptions<'_>,
) -> Result<Arc<DataFrame>> {
let table_path = ListingTableUrl::parse(table_path)?;
let target_partitions = self.copied_config().target_partitions;
let listing_options = options.to_listing_options(target_partitions);
let resolved_schema = match options.schema {
Some(s) => s,
None => {
listing_options
.infer_schema(&self.state(), &table_path)
.await?
}
};
let config = ListingTableConfig::new(table_path)
.with_listing_options(listing_options)
.with_schema(resolved_schema);
let provider = ListingTable::try_new(config)?;
self.read_table(Arc::new(provider))
}
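    /// Creates a [`DataFrame`] over an empty relation that produces a single row,
    /// useful for evaluating expressions without a table.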
pub fn read_empty(&self) -> Result<Arc<DataFrame>> {
Ok(Arc::new(DataFrame::new(
self.state.clone(),
&LogicalPlanBuilder::empty(true).build()?,
)))
}
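    /// Creates a [`DataFrame`] for reading CSV data at `table_path` (a file, directory, or glob pattern).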
pub async fn read_csv(
&self,
table_path: impl AsRef<str>,
options: CsvReadOptions<'_>,
) -> Result<Arc<DataFrame>> {
let table_path = ListingTableUrl::parse(table_path)?;
let target_partitions = self.copied_config().target_partitions;
let listing_options = options.to_listing_options(target_partitions);
let resolved_schema = match options.schema {
Some(s) => Arc::new(s.to_owned()),
None => {
listing_options
.infer_schema(&self.state(), &table_path)
.await?
}
};
let config = ListingTableConfig::new(table_path.clone())
.with_listing_options(listing_options)
.with_schema(resolved_schema);
let provider = ListingTable::try_new(config)?;
self.read_table(Arc::new(provider))
}
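    /// Creates a [`DataFrame`] for reading Parquet data at `table_path`, inferring the schema from the files.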
pub async fn read_parquet(
&self,
table_path: impl AsRef<str>,
options: ParquetReadOptions<'_>,
) -> Result<Arc<DataFrame>> {
let table_path = ListingTableUrl::parse(table_path)?;
let target_partitions = self.copied_config().target_partitions;
let listing_options = options.to_listing_options(target_partitions);
let resolved_schema = listing_options
.infer_schema(&self.state(), &table_path)
.await?;
let config = ListingTableConfig::new(table_path)
.with_listing_options(listing_options)
.with_schema(resolved_schema);
let provider = ListingTable::try_new(config)?;
self.read_table(Arc::new(provider))
}
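    /// Creates a [`DataFrame`] that scans the given [`TableProvider`] without registering it in the catalog.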
pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<Arc<DataFrame>> {
Ok(Arc::new(DataFrame::new(
self.state.clone(),
&LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
.build()?,
)))
}
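    /// Registers a [`ListingTable`] under `name`, inferring its schema from the
    /// files at `table_path` unless `provided_schema` is given.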
pub async fn register_listing_table(
&self,
name: &str,
table_path: impl AsRef<str>,
options: ListingOptions,
provided_schema: Option<SchemaRef>,
sql: Option<String>,
) -> Result<()> {
let table_path = ListingTableUrl::parse(table_path)?;
let resolved_schema = match provided_schema {
None => options.infer_schema(&self.state(), &table_path).await?,
Some(s) => s,
};
let config = ListingTableConfig::new(table_path)
.with_listing_options(options)
.with_schema(resolved_schema);
let table = ListingTable::try_new(config)?.with_definition(sql);
self.register_table(name, Arc::new(table))?;
Ok(())
}
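    /// Registers a CSV file (or directory of files) as a table named `name`, so it
    /// can be referenced from SQL statements executed against this context.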
pub async fn register_csv(
&self,
name: &str,
table_path: &str,
options: CsvReadOptions<'_>,
) -> Result<()> {
let listing_options =
options.to_listing_options(self.copied_config().target_partitions);
self.register_listing_table(
name,
table_path,
listing_options,
options.schema.map(|s| Arc::new(s.to_owned())),
None,
)
.await?;
Ok(())
}
pub async fn register_json(
&self,
name: &str,
table_path: &str,
options: NdJsonReadOptions<'_>,
) -> Result<()> {
let listing_options =
options.to_listing_options(self.copied_config().target_partitions);
self.register_listing_table(
name,
table_path,
listing_options,
options.schema,
None,
)
.await?;
Ok(())
}
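    /// Registers a Parquet file (or directory of files) as a table named `name`, so it
    /// can be referenced from SQL statements executed against this context.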
pub async fn register_parquet(
&self,
name: &str,
table_path: &str,
options: ParquetReadOptions<'_>,
) -> Result<()> {
let (target_partitions, parquet_pruning) = {
let conf = self.copied_config();
(conf.target_partitions, conf.parquet_pruning)
};
let listing_options = options
.parquet_pruning(parquet_pruning)
.to_listing_options(target_partitions);
self.register_listing_table(name, table_path, listing_options, None, None)
.await?;
Ok(())
}
pub async fn register_avro(
&self,
name: &str,
table_path: &str,
options: AvroReadOptions<'_>,
) -> Result<()> {
let listing_options =
options.to_listing_options(self.copied_config().target_partitions);
self.register_listing_table(
name,
table_path,
listing_options,
options.schema,
None,
)
.await?;
Ok(())
}
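    /// Registers a [`CatalogProvider`] under `name`, wrapping it with an
    /// `information_schema` provider when that option is enabled. Returns the
    /// catalog previously registered under the same name, if any.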
pub fn register_catalog(
&self,
name: impl Into<String>,
catalog: Arc<dyn CatalogProvider>,
) -> Option<Arc<dyn CatalogProvider>> {
let name = name.into();
let information_schema = self.copied_config().information_schema;
let state = self.state.read();
let catalog = if information_schema {
Arc::new(CatalogWithInformationSchema::new(
Arc::downgrade(&state.catalog_list),
catalog,
))
} else {
catalog
};
state.catalog_list.register_catalog(name, catalog)
}
pub fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
self.state.read().catalog_list.catalog(name)
}
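    /// Registers a [`TableProvider`] under the given table reference, returning any
    /// provider previously registered under that name.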
pub fn register_table<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
provider: Arc<dyn TableProvider>,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table_ref = table_ref.into();
self.state
.read()
.schema_for_ref(table_ref)?
.register_table(table_ref.table().to_owned(), provider)
}
pub fn deregister_table<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table_ref = table_ref.into();
self.state
.read()
.schema_for_ref(table_ref)?
.deregister_table(table_ref.table())
}
pub fn table_exist<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
) -> Result<bool> {
let table_ref = table_ref.into();
Ok(self
.state
.read()
.schema_for_ref(table_ref)?
.table_exist(table_ref.table()))
}
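    /// Returns a [`DataFrame`] that scans the table registered under the given
    /// reference, or an error if no such table exists.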
pub fn table<'a>(
&self,
table_ref: impl Into<TableReference<'a>>,
) -> Result<Arc<DataFrame>> {
let table_ref = table_ref.into();
let schema = self.state.read().schema_for_ref(table_ref)?;
match schema.table(table_ref.table()) {
Some(ref provider) => {
let plan = LogicalPlanBuilder::scan(
table_ref.table(),
provider_as_source(Arc::clone(provider)),
None,
)?
.build()?;
Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
}
_ => Err(DataFusionError::Plan(format!(
"No table named '{}'",
table_ref.table()
))),
}
}
#[deprecated(
note = "Please use the catalog provider interface (`SessionContext::catalog`) to examine available catalogs, schemas, and tables"
)]
pub fn tables(&self) -> Result<HashSet<String>> {
Ok(self
.state
.read()
.schema_for_ref(TableReference::Bare { table: "" })?
.table_names()
.iter()
.cloned()
.collect())
}
pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
self.state.read().optimize(plan)
}
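    /// Optimizes the given [`LogicalPlan`] and creates an [`ExecutionPlan`] from it,
    /// marking the start of query execution in the session's [`ExecutionProps`].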
pub async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
let state_cloned = {
let mut state = self.state.write();
state.execution_props.start_execution();
state.clone()
};
state_cloned.create_physical_plan(logical_plan).await
}
pub async fn write_csv(
&self,
plan: Arc<dyn ExecutionPlan>,
path: impl AsRef<str>,
) -> Result<()> {
let state = self.state.read().clone();
plan_to_csv(&state, plan, path).await
}
pub async fn write_json(
&self,
plan: Arc<dyn ExecutionPlan>,
path: impl AsRef<str>,
) -> Result<()> {
let state = self.state.read().clone();
plan_to_json(&state, plan, path).await
}
pub async fn write_parquet(
&self,
plan: Arc<dyn ExecutionPlan>,
path: impl AsRef<str>,
writer_properties: Option<WriterProperties>,
) -> Result<()> {
let state = self.state.read().clone();
plan_to_parquet(&state, plan, path, writer_properties).await
}
pub fn task_ctx(&self) -> Arc<TaskContext> {
Arc::new(TaskContext::from(self))
}
pub fn state(&self) -> SessionState {
self.state.read().clone()
}
}
impl FunctionRegistry for SessionContext {
fn udfs(&self) -> HashSet<String> {
self.state.read().udfs()
}
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
self.state.read().udf(name)
}
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
self.state.read().udaf(name)
}
}
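/// A planner that turns an (optimized) [`LogicalPlan`] into an [`ExecutionPlan`];
/// can be overridden per session via [`SessionState::with_query_planner`].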
#[async_trait]
pub trait QueryPlanner {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>>;
}
struct DefaultQueryPlanner {}
#[async_trait]
impl QueryPlanner for DefaultQueryPlanner {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
let planner = DefaultPhysicalPlanner::default();
planner
.create_physical_plan(logical_plan, session_state)
.await
}
}
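// Keys used by `SessionConfig::to_props` and `TaskContext::session_config` when
// flattening the configuration into key/value pairs.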
pub const TARGET_PARTITIONS: &str = "target_partitions";
pub const REPARTITION_JOINS: &str = "repartition_joins";
pub const REPARTITION_AGGREGATIONS: &str = "repartition_aggregations";
pub const REPARTITION_WINDOWS: &str = "repartition_windows";
pub const PARQUET_PRUNING: &str = "parquet_pruning";
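/// Map of opaque extension objects keyed by their [`TypeId`], stored in [`Arc`]s so
/// the map remains cheap to clone.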
type AnyMap =
HashMap<TypeId, Arc<dyn Any + Send + Sync + 'static>, BuildHasherDefault<IdHasher>>;
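/// Hasher for [`AnyMap`]: [`TypeId`]s are already high-quality hashes, so the value
/// passed to `write_u64` is used directly.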
#[derive(Default)]
struct IdHasher(u64);
impl Hasher for IdHasher {
fn write(&mut self, _: &[u8]) {
unreachable!("TypeId calls write_u64");
}
#[inline]
fn write_u64(&mut self, id: u64) {
self.0 = id;
}
#[inline]
fn finish(&self) -> u64 {
self.0
}
}
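/// Configuration options for a [`SessionContext`], such as the default catalog and
/// schema, target partition count, and repartitioning/pruning flags.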
#[derive(Clone)]
pub struct SessionConfig {
pub target_partitions: usize,
default_catalog: String,
default_schema: String,
create_default_catalog_and_schema: bool,
information_schema: bool,
pub repartition_joins: bool,
pub repartition_aggregations: bool,
pub repartition_windows: bool,
pub parquet_pruning: bool,
pub config_options: ConfigOptions,
extensions: AnyMap,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
target_partitions: num_cpus::get(),
default_catalog: DEFAULT_CATALOG.to_owned(),
default_schema: DEFAULT_SCHEMA.to_owned(),
create_default_catalog_and_schema: true,
information_schema: false,
repartition_joins: true,
repartition_aggregations: true,
repartition_windows: true,
parquet_pruning: true,
config_options: ConfigOptions::new(),
extensions: HashMap::with_capacity_and_hasher(
0,
BuildHasherDefault::default(),
),
}
}
}
impl SessionConfig {
pub fn new() -> Self {
Default::default()
}
pub fn from_env() -> Self {
Self {
config_options: ConfigOptions::from_env(),
..Default::default()
}
}
pub fn set(mut self, key: &str, value: ScalarValue) -> Self {
self.config_options.set(key, value);
self
}
pub fn set_bool(self, key: &str, value: bool) -> Self {
self.set(key, ScalarValue::Boolean(Some(value)))
}
pub fn set_u64(self, key: &str, value: u64) -> Self {
self.set(key, ScalarValue::UInt64(Some(value)))
}
pub fn with_batch_size(self, n: usize) -> Self {
assert!(n > 0);
self.set_u64(OPT_BATCH_SIZE, n.try_into().unwrap())
}
pub fn with_target_partitions(mut self, n: usize) -> Self {
assert!(n > 0);
self.target_partitions = n;
self
}
pub fn with_default_catalog_and_schema(
mut self,
catalog: impl Into<String>,
schema: impl Into<String>,
) -> Self {
self.default_catalog = catalog.into();
self.default_schema = schema.into();
self
}
pub fn create_default_catalog_and_schema(mut self, create: bool) -> Self {
self.create_default_catalog_and_schema = create;
self
}
pub fn with_information_schema(mut self, enabled: bool) -> Self {
self.information_schema = enabled;
self
}
pub fn with_repartition_joins(mut self, enabled: bool) -> Self {
self.repartition_joins = enabled;
self
}
pub fn with_repartition_aggregations(mut self, enabled: bool) -> Self {
self.repartition_aggregations = enabled;
self
}
pub fn with_repartition_windows(mut self, enabled: bool) -> Self {
self.repartition_windows = enabled;
self
}
pub fn with_parquet_pruning(mut self, enabled: bool) -> Self {
self.parquet_pruning = enabled;
self
}
pub fn batch_size(&self) -> usize {
self.config_options
.get_u64(OPT_BATCH_SIZE)
.try_into()
.unwrap()
}
pub fn config_options(&self) -> &ConfigOptions {
&self.config_options
}
pub fn to_props(&self) -> HashMap<String, String> {
let mut map = HashMap::new();
for (k, v) in self.config_options.options() {
map.insert(k.to_string(), format!("{}", v));
}
map.insert(
TARGET_PARTITIONS.to_owned(),
format!("{}", self.target_partitions),
);
map.insert(
REPARTITION_JOINS.to_owned(),
format!("{}", self.repartition_joins),
);
map.insert(
REPARTITION_AGGREGATIONS.to_owned(),
format!("{}", self.repartition_aggregations),
);
map.insert(
REPARTITION_WINDOWS.to_owned(),
format!("{}", self.repartition_windows),
);
map.insert(
PARQUET_PRUNING.to_owned(),
format!("{}", self.parquet_pruning),
);
map
}
pub fn with_extension<T>(mut self, ext: Arc<T>) -> Self
where
T: Send + Sync + 'static,
{
let ext = ext as Arc<dyn Any + Send + Sync + 'static>;
let id = TypeId::of::<T>();
self.extensions.insert(id, ext);
self
}
pub fn get_extension<T>(&self) -> Option<Arc<T>>
where
T: Send + Sync + 'static,
{
let id = TypeId::of::<T>();
self.extensions
.get(&id)
.cloned()
.map(|ext| Arc::downcast(ext).expect("TypeId unique"))
}
}
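/// Per-session state shared by [`SessionContext`] and the planners: catalogs,
/// registered functions, optimizer rules, configuration, and the runtime environment.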
#[derive(Clone)]
pub struct SessionState {
pub session_id: String,
pub optimizer: Optimizer,
pub physical_optimizers: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>>,
pub query_planner: Arc<dyn QueryPlanner + Send + Sync>,
pub catalog_list: Arc<dyn CatalogList>,
pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
pub config: SessionConfig,
pub execution_props: ExecutionProps,
pub runtime_env: Arc<RuntimeEnv>,
}
impl Debug for SessionState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SessionState")
.field("session_id", &self.session_id)
.finish()
}
}
pub fn default_session_builder(config: SessionConfig) -> SessionState {
SessionState::with_config_rt(config, Arc::new(RuntimeEnv::default()))
}
impl SessionState {
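    /// Creates a new [`SessionState`] with the given configuration and runtime,
    /// installing the default logical and physical optimizer rules and (optionally)
    /// the default catalog and schema.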
pub fn with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
let session_id = Uuid::new_v4().to_string();
let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc<dyn CatalogList>;
if config.create_default_catalog_and_schema {
let default_catalog = MemoryCatalogProvider::new();
default_catalog
.register_schema(
&config.default_schema,
Arc::new(MemorySchemaProvider::new()),
)
.expect("memory catalog provider can register schema");
let default_catalog: Arc<dyn CatalogProvider> = if config.information_schema {
Arc::new(CatalogWithInformationSchema::new(
Arc::downgrade(&catalog_list),
Arc::new(default_catalog),
))
} else {
Arc::new(default_catalog)
};
catalog_list
.register_catalog(config.default_catalog.clone(), default_catalog);
}
let mut rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
Arc::new(SimplifyExpressions::new()),
Arc::new(PreCastLitInComparisonExpressions::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Arc::new(SubqueryFilterToJoin::new()),
Arc::new(EliminateFilter::new()),
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateLimit::new()),
Arc::new(ProjectionPushDown::new()),
Arc::new(RewriteDisjunctivePredicate::new()),
];
if config.config_options.get_bool(OPT_FILTER_NULL_JOIN_KEYS) {
rules.push(Arc::new(FilterNullJoinKeys::default()));
}
rules.push(Arc::new(ReduceOuterJoin::new()));
rules.push(Arc::new(FilterPushDown::new()));
rules.push(Arc::new(TypeCoercion::new()));
rules.push(Arc::new(LimitPushDown::new()));
rules.push(Arc::new(SingleDistinctToGroupBy::new()));
let mut physical_optimizers: Vec<Arc<dyn PhysicalOptimizerRule + Sync + Send>> = vec![
Arc::new(AggregateStatistics::new()),
Arc::new(HashBuildProbeOrder::new()),
];
if config.config_options.get_bool(OPT_COALESCE_BATCHES) {
physical_optimizers.push(Arc::new(CoalesceBatches::new(
config
.config_options
.get_u64(OPT_COALESCE_TARGET_BATCH_SIZE)
.try_into()
.unwrap(),
)));
}
physical_optimizers.push(Arc::new(Repartition::new()));
physical_optimizers.push(Arc::new(AddCoalescePartitionsExec::new()));
SessionState {
session_id,
optimizer: Optimizer::new(rules),
physical_optimizers,
query_planner: Arc::new(DefaultQueryPlanner {}),
catalog_list,
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
config,
execution_props: ExecutionProps::new(),
runtime_env: runtime,
}
}
fn resolve_table_ref<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
) -> ResolvedTableReference<'a> {
table_ref
.into()
.resolve(&self.config.default_catalog, &self.config.default_schema)
}
fn schema_for_ref<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
) -> Result<Arc<dyn SchemaProvider>> {
let resolved_ref = self.resolve_table_ref(table_ref);
self.catalog_list
.catalog(resolved_ref.catalog)
.ok_or_else(|| {
DataFusionError::Plan(format!(
"failed to resolve catalog: {}",
resolved_ref.catalog
))
})?
.schema(resolved_ref.schema)
.ok_or_else(|| {
DataFusionError::Plan(format!(
"failed to resolve schema: {}",
resolved_ref.schema
))
})
}
pub fn with_query_planner(
mut self,
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
) -> Self {
self.query_planner = query_planner;
self
}
pub fn with_optimizer_rules(
mut self,
rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
) -> Self {
self.optimizer = Optimizer::new(rules);
self
}
pub fn with_physical_optimizer_rules(
mut self,
physical_optimizers: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>>,
) -> Self {
self.physical_optimizers = physical_optimizers;
self
}
pub fn add_optimizer_rule(
mut self,
optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
) -> Self {
self.optimizer.rules.push(optimizer_rule);
self
}
pub fn add_physical_optimizer_rule(
mut self,
optimizer_rule: Arc<dyn PhysicalOptimizerRule + Send + Sync>,
) -> Self {
self.physical_optimizers.push(optimizer_rule);
self
}
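    /// Optimizes the given logical plan; for `EXPLAIN` plans, each optimizer pass is
    /// also recorded as a stringified plan in the output.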
pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
let mut optimizer_config = OptimizerConfig::new().with_skip_failing_rules(
self.config
.config_options
.get_bool(OPT_OPTIMIZER_SKIP_FAILED_RULES),
);
optimizer_config.query_execution_start_time =
self.execution_props.query_execution_start_time;
if let LogicalPlan::Explain(e) = plan {
let mut stringified_plans = e.stringified_plans.clone();
let plan = self.optimizer.optimize(
e.plan.as_ref(),
&mut optimizer_config,
|optimized_plan, optimizer| {
let optimizer_name = optimizer.name().to_string();
let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name };
stringified_plans.push(optimized_plan.to_stringified(plan_type));
},
)?;
Ok(LogicalPlan::Explain(Explain {
verbose: e.verbose,
plan: Arc::new(plan),
stringified_plans,
schema: e.schema.clone(),
}))
} else {
self.optimizer
.optimize(plan, &mut optimizer_config, |_, _| {})
}
}
pub async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
let planner = self.query_planner.clone();
let logical_plan = self.optimize(logical_plan)?;
planner.create_physical_plan(&logical_plan, self).await
}
}
impl ContextProvider for SessionState {
fn get_table_provider(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
let resolved_ref = self.resolve_table_ref(name);
match self.schema_for_ref(resolved_ref) {
Ok(schema) => {
let provider = schema.table(resolved_ref.table).ok_or_else(|| {
DataFusionError::Plan(format!(
"'{}.{}.{}' not found",
resolved_ref.catalog, resolved_ref.schema, resolved_ref.table
))
})?;
Ok(provider_as_source(provider))
}
Err(e) => Err(e),
}
}
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.scalar_functions.get(name).cloned()
}
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.aggregate_functions.get(name).cloned()
}
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
if variable_names.is_empty() {
return None;
}
let provider_type = if is_system_variables(variable_names) {
VarType::System
} else {
VarType::UserDefined
};
self.execution_props
.var_providers
.as_ref()
.and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
}
}
impl FunctionRegistry for SessionState {
fn udfs(&self) -> HashSet<String> {
self.scalar_functions.keys().cloned().collect()
}
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
let result = self.scalar_functions.get(name);
result.cloned().ok_or_else(|| {
DataFusionError::Plan(format!(
"There is no UDF named \"{}\" in the registry",
name
))
})
}
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
let result = self.aggregate_functions.get(name);
result.cloned().ok_or_else(|| {
DataFusionError::Plan(format!(
"There is no UDAF named \"{}\" in the registry",
name
))
})
}
}
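/// Configuration carried by a [`TaskContext`]: either a full [`SessionConfig`] or a
/// flattened set of key/value pairs (as produced by [`SessionConfig::to_props`]).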
pub enum TaskProperties {
SessionConfig(SessionConfig),
KVPairs(HashMap<String, String>),
}
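/// Execution context handed to tasks and operators at runtime: identifies the
/// session (and optionally a task), and carries the configuration, registered
/// scalar and aggregate functions, and the [`RuntimeEnv`].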
pub struct TaskContext {
session_id: String,
task_id: Option<String>,
properties: TaskProperties,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
runtime: Arc<RuntimeEnv>,
}
impl TaskContext {
pub fn new(
task_id: String,
session_id: String,
task_props: HashMap<String, String>,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
runtime: Arc<RuntimeEnv>,
) -> Self {
Self {
task_id: Some(task_id),
session_id,
properties: TaskProperties::KVPairs(task_props),
scalar_functions,
aggregate_functions,
runtime,
}
}
pub fn session_config(&self) -> SessionConfig {
let task_props = &self.properties;
match task_props {
TaskProperties::KVPairs(props) => {
let session_config = SessionConfig::new();
if props.is_empty() {
session_config
} else {
session_config
.with_batch_size(
props.get(OPT_BATCH_SIZE).unwrap().parse().unwrap(),
)
.with_target_partitions(
props.get(TARGET_PARTITIONS).unwrap().parse().unwrap(),
)
.with_repartition_joins(
props.get(REPARTITION_JOINS).unwrap().parse().unwrap(),
)
.with_repartition_aggregations(
props
.get(REPARTITION_AGGREGATIONS)
.unwrap()
.parse()
.unwrap(),
)
.with_repartition_windows(
props.get(REPARTITION_WINDOWS).unwrap().parse().unwrap(),
)
.with_parquet_pruning(
props.get(PARQUET_PRUNING).unwrap().parse().unwrap(),
)
}
}
TaskProperties::SessionConfig(session_config) => session_config.clone(),
}
}
pub fn session_id(&self) -> String {
self.session_id.clone()
}
pub fn task_id(&self) -> Option<String> {
self.task_id.clone()
}
pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
self.runtime.clone()
}
}
impl From<&SessionContext> for TaskContext {
fn from(session: &SessionContext) -> Self {
let session_id = session.session_id.clone();
let (config, scalar_functions, aggregate_functions) = {
let session_state = session.state.read();
(
session_state.config.clone(),
session_state.scalar_functions.clone(),
session_state.aggregate_functions.clone(),
)
};
let runtime = session.runtime_env();
Self {
task_id: None,
session_id,
properties: TaskProperties::SessionConfig(config),
scalar_functions,
aggregate_functions,
runtime,
}
}
}
impl From<&SessionState> for TaskContext {
fn from(state: &SessionState) -> Self {
let session_id = state.session_id.clone();
let config = state.config.clone();
let scalar_functions = state.scalar_functions.clone();
let aggregate_functions = state.aggregate_functions.clone();
let runtime = state.runtime_env.clone();
Self {
task_id: None,
session_id,
properties: TaskProperties::SessionConfig(config),
scalar_functions,
aggregate_functions,
runtime,
}
}
}
impl FunctionRegistry for TaskContext {
fn udfs(&self) -> HashSet<String> {
self.scalar_functions.keys().cloned().collect()
}
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
let result = self.scalar_functions.get(name);
result.cloned().ok_or_else(|| {
DataFusionError::Internal(format!(
"There is no UDF named \"{}\" in the TaskContext",
name
))
})
}
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
let result = self.aggregate_functions.get(name);
result.cloned().ok_or_else(|| {
DataFusionError::Internal(format!(
"There is no UDAF named \"{}\" in the TaskContext",
name
))
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::execution::context::QueryPlanner;
use crate::test;
use crate::test_util::parquet_test_data;
use crate::variable::VarType;
use crate::{
assert_batches_eq,
logical_plan::{create_udf, Expr},
};
use crate::{logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator};
use arrow::array::ArrayRef;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_expr::Volatility;
use datafusion_physical_expr::functions::make_scalar_function;
use std::fs::File;
use std::sync::Weak;
use std::thread::{self, JoinHandle};
use std::{io::prelude::*, sync::Mutex};
use tempfile::TempDir;
#[tokio::test]
async fn shared_memory_and_disk_manager() {
let ctx1 = SessionContext::new();
let memory_manager = ctx1.runtime_env().memory_manager.clone();
let disk_manager = ctx1.runtime_env().disk_manager.clone();
let ctx2 =
SessionContext::with_config_rt(SessionConfig::new(), ctx1.runtime_env());
assert!(std::ptr::eq(
Arc::as_ptr(&memory_manager),
Arc::as_ptr(&ctx1.runtime_env().memory_manager)
));
assert!(std::ptr::eq(
Arc::as_ptr(&memory_manager),
Arc::as_ptr(&ctx2.runtime_env().memory_manager)
));
assert!(std::ptr::eq(
Arc::as_ptr(&disk_manager),
Arc::as_ptr(&ctx1.runtime_env().disk_manager)
));
assert!(std::ptr::eq(
Arc::as_ptr(&disk_manager),
Arc::as_ptr(&ctx2.runtime_env().disk_manager)
));
}
#[tokio::test]
async fn create_variable_expr() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let mut ctx = create_ctx(&tmp_dir, partition_count).await?;
let variable_provider = test::variable::SystemVar::new();
ctx.register_variable(VarType::System, Arc::new(variable_provider));
let variable_provider = test::variable::UserDefinedVar::new();
ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider));
let provider = test::create_table_dual();
ctx.register_table("dual", provider)?;
let results =
plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual")
.await?;
let expected = vec![
"+----------------------+------------------------+---------------------+",
"| @@version | @name | @integer + Int64(1) |",
"+----------------------+------------------------+---------------------+",
"| system-var-@@version | user-defined-var-@name | 42 |",
"+----------------------+------------------------+---------------------+",
];
assert_batches_eq!(expected, &results);
Ok(())
}
#[tokio::test]
async fn create_variable_err() -> Result<()> {
let ctx = SessionContext::new();
let err = plan_and_collect(&ctx, "SElECT @= X#=?!~ 5")
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"Execution error: variable [\"@\"] has no type information"
);
Ok(())
}
#[tokio::test]
async fn register_deregister() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = create_ctx(&tmp_dir, partition_count).await?;
let provider = test::create_table_dual();
ctx.register_table("dual", provider)?;
assert!(ctx.deregister_table("dual")?.is_some());
assert!(ctx.deregister_table("dual")?.is_none());
Ok(())
}
#[tokio::test]
async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
let mut ctx = SessionContext::new();
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();
let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);
ctx.register_udf(create_udf(
"MY_FUNC",
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
myfunc,
));
let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t")
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"Error during planning: Invalid function \'my_func\'"
);
let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;
let expected = vec![
"+--------------+",
"| MY_FUNC(t.i) |",
"+--------------+",
"| 1 |",
"+--------------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}
#[tokio::test]
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
let mut ctx = SessionContext::new();
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();
let my_avg = create_udaf(
"MY_AVG",
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);
ctx.register_udaf(my_avg);
let err = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t")
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"Error during planning: Invalid function \'my_avg\'"
);
let result = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?;
let expected = vec![
"+-------------+",
"| MY_AVG(t.i) |",
"+-------------+",
"| 1 |",
"+-------------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}
#[tokio::test]
async fn query_csv_with_custom_partition_extension() -> Result<()> {
let tmp_dir = TempDir::new()?;
let file_extension = ".tst";
let ctx = SessionContext::new();
let schema = populate_csv_partitions(&tmp_dir, 2, file_extension)?;
ctx.register_csv(
"test",
tmp_dir.path().to_str().unwrap(),
CsvReadOptions::new()
.schema(&schema)
.file_extension(file_extension),
)
.await?;
let results =
plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?;
assert_eq!(results.len(), 1);
let expected = vec![
"+--------------+--------------+-----------------+",
"| SUM(test.c1) | SUM(test.c2) | COUNT(UInt8(1)) |",
"+--------------+--------------+-----------------+",
"| 10 | 110 | 20 |",
"+--------------+--------------+-----------------+",
];
assert_batches_eq!(expected, &results);
Ok(())
}
#[tokio::test]
async fn send_context_to_threads() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = Arc::new(Mutex::new(create_ctx(&tmp_dir, partition_count).await?));
let threads: Vec<JoinHandle<Result<_>>> = (0..2)
.map(|_| ctx.clone())
.map(|ctx_clone| {
thread::spawn(move || {
let ctx = ctx_clone.lock().expect("Locked context");
ctx.create_logical_plan(
"SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3",
)
})
})
.collect();
for thread in threads {
thread.join().expect("Failed to join thread")?;
}
Ok(())
}
#[tokio::test]
async fn custom_query_planner() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
let session_state = SessionState::with_config_rt(SessionConfig::new(), runtime)
.with_query_planner(Arc::new(MyQueryPlanner {}));
let ctx = SessionContext::with_state(session_state);
let df = ctx.sql("SELECT 1").await?;
df.collect().await.expect_err("query not supported");
Ok(())
}
#[tokio::test]
async fn disabled_default_catalog_and_schema() -> Result<()> {
let ctx = SessionContext::with_config(
SessionConfig::new().create_default_catalog_and_schema(false),
);
assert!(matches!(
ctx.register_table("test", test::table_with_sequence(1, 1)?),
Err(DataFusionError::Plan(_))
));
assert!(matches!(
ctx.sql("select * from datafusion.public.test").await,
Err(DataFusionError::Plan(_))
));
Ok(())
}
#[tokio::test]
async fn custom_catalog_and_schema() {
let config = SessionConfig::new()
.create_default_catalog_and_schema(true)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
#[tokio::test]
async fn custom_catalog_and_schema_no_default() {
let config = SessionConfig::new()
.create_default_catalog_and_schema(false)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
#[tokio::test]
async fn custom_catalog_and_schema_and_information_schema() {
let config = SessionConfig::new()
.create_default_catalog_and_schema(true)
.with_information_schema(true)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
async fn catalog_and_schema_test(config: SessionConfig) {
let ctx = SessionContext::with_config(config);
let catalog = MemoryCatalogProvider::new();
let schema = MemorySchemaProvider::new();
schema
.register_table("test".to_owned(), test::table_with_sequence(1, 1).unwrap())
.unwrap();
catalog
.register_schema("my_schema", Arc::new(schema))
.unwrap();
ctx.register_catalog("my_catalog", Arc::new(catalog));
for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] {
let result = plan_and_collect(
&ctx,
&format!("SELECT COUNT(*) AS count FROM {}", table_ref),
)
.await
.unwrap();
let expected = vec![
"+-------+",
"| count |",
"+-------+",
"| 1 |",
"+-------+",
];
assert_batches_eq!(expected, &result);
}
}
#[tokio::test]
async fn cross_catalog_access() -> Result<()> {
let ctx = SessionContext::new();
let catalog_a = MemoryCatalogProvider::new();
let schema_a = MemorySchemaProvider::new();
schema_a
.register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
catalog_a.register_schema("schema_a", Arc::new(schema_a))?;
ctx.register_catalog("catalog_a", Arc::new(catalog_a));
let catalog_b = MemoryCatalogProvider::new();
let schema_b = MemorySchemaProvider::new();
schema_b
.register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
catalog_b.register_schema("schema_b", Arc::new(schema_b))?;
ctx.register_catalog("catalog_b", Arc::new(catalog_b));
let result = plan_and_collect(
&ctx,
"SELECT cat, SUM(i) AS total FROM (
SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a
UNION ALL
SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b
) AS all
GROUP BY cat
ORDER BY cat
",
)
.await?;
let expected = vec![
"+-----+-------+",
"| cat | total |",
"+-----+-------+",
"| a | 1 |",
"| b | 3 |",
"+-----+-------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}
#[tokio::test]
async fn catalogs_not_leaked() {
let ctx = SessionContext::with_config(
SessionConfig::new().with_information_schema(true),
);
let catalog = Arc::new(MemoryCatalogProvider::new());
let catalog_weak = Arc::downgrade(&catalog);
ctx.register_catalog("my_catalog", catalog);
let catalog_list_weak = {
let state = ctx.state.read();
Arc::downgrade(&state.catalog_list)
};
drop(ctx);
assert_eq!(Weak::strong_count(&catalog_list_weak), 0);
assert_eq!(Weak::strong_count(&catalog_weak), 0);
}
#[tokio::test]
async fn sql_create_schema() -> Result<()> {
let ctx = SessionContext::with_config(
SessionConfig::new().with_information_schema(true),
);
ctx.sql("CREATE SCHEMA abc").await?.collect().await?;
ctx.sql("CREATE TABLE abc.y AS VALUES (1,2,3)")
.await?
.collect()
.await?;
let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
assert_eq!(results[0].num_rows(), 1);
Ok(())
}
#[tokio::test]
async fn sql_create_catalog() -> Result<()> {
let ctx = SessionContext::with_config(
SessionConfig::new().with_information_schema(true),
);
ctx.sql("CREATE DATABASE test").await?.collect().await?;
ctx.sql("CREATE SCHEMA test.abc").await?.collect().await?;
ctx.sql("CREATE TABLE test.abc.y AS VALUES (1,2,3)")
.await?
.collect()
.await?;
let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_catalog='test' AND table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
assert_eq!(results[0].num_rows(), 1);
Ok(())
}
#[tokio::test]
async fn read_with_glob_path() -> Result<()> {
let ctx = SessionContext::new();
let df = ctx
.read_parquet(
format!("{}/alltypes_plain*.parquet", parquet_test_data()),
ParquetReadOptions::default(),
)
.await?;
let results = df.collect().await?;
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
assert_eq!(total_rows, 10);
Ok(())
}
#[tokio::test]
async fn read_with_glob_path_issue_2465() -> Result<()> {
let ctx = SessionContext::new();
let df = ctx
.read_parquet(
format!("{}/..//*/alltypes_plain*.parquet", parquet_test_data()),
ParquetReadOptions::default(),
)
.await?;
let results = df.collect().await?;
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
assert_eq!(total_rows, 10);
Ok(())
}
#[tokio::test]
async fn read_from_registered_table_with_glob_path() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_parquet(
"test",
&format!("{}/alltypes_plain*.parquet", parquet_test_data()),
ParquetReadOptions::default(),
)
.await?;
let df = ctx.sql("SELECT * FROM test").await?;
let results = df.collect().await?;
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
assert_eq!(total_rows, 10);
Ok(())
}
struct MyPhysicalPlanner {}
#[async_trait]
impl PhysicalPlanner for MyPhysicalPlanner {
async fn create_physical_plan(
&self,
_logical_plan: &LogicalPlan,
_session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
Err(DataFusionError::NotImplemented(
"query not supported".to_string(),
))
}
fn create_physical_expr(
&self,
_expr: &Expr,
_input_dfschema: &crate::logical_plan::DFSchema,
_input_schema: &Schema,
_session_state: &SessionState,
) -> Result<Arc<dyn crate::physical_plan::PhysicalExpr>> {
unimplemented!()
}
}
struct MyQueryPlanner {}
#[async_trait]
impl QueryPlanner for MyQueryPlanner {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
let physical_planner = MyPhysicalPlanner {};
physical_planner
.create_physical_plan(logical_plan, session_state)
.await
}
}
async fn plan_and_collect(
ctx: &SessionContext,
sql: &str,
) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}
fn populate_csv_partitions(
tmp_dir: &TempDir,
partition_count: usize,
file_extension: &str,
) -> Result<SchemaRef> {
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::UInt32, false),
Field::new("c2", DataType::UInt64, false),
Field::new("c3", DataType::Boolean, false),
]));
for partition in 0..partition_count {
let filename = format!("partition-{}.{}", partition, file_extension);
let file_path = tmp_dir.path().join(&filename);
let mut file = File::create(file_path)?;
for i in 0..=10 {
let data = format!("{},{},{}\n", partition, i, i % 2 == 0);
file.write_all(data.as_bytes())?;
}
}
Ok(schema)
}
async fn create_ctx(
tmp_dir: &TempDir,
partition_count: usize,
) -> Result<SessionContext> {
let ctx =
SessionContext::with_config(SessionConfig::new().with_target_partitions(8));
let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?;
ctx.register_csv(
"test",
tmp_dir.path().to_str().unwrap(),
CsvReadOptions::new().schema(&schema),
)
.await?;
Ok(ctx)
}
#[async_trait]
trait CallReadTrait {
async fn call_read_csv(&self) -> Arc<DataFrame>;
async fn call_read_avro(&self) -> Arc<DataFrame>;
async fn call_read_parquet(&self) -> Arc<DataFrame>;
}
struct CallRead {}
#[async_trait]
impl CallReadTrait for CallRead {
async fn call_read_csv(&self) -> Arc<DataFrame> {
let ctx = SessionContext::new();
ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap()
}
async fn call_read_avro(&self) -> Arc<DataFrame> {
let ctx = SessionContext::new();
ctx.read_avro("dummy", AvroReadOptions::default())
.await
.unwrap()
}
async fn call_read_parquet(&self) -> Arc<DataFrame> {
let ctx = SessionContext::new();
ctx.read_parquet("dummy", ParquetReadOptions::default())
.await
.unwrap()
}
}
}