pub mod memory;
pub use datafusion_sql::{ResolvedTableReference, TableReference};
pub use memory::{
MemoryCatalogProvider, MemoryCatalogProviderList, MemorySchemaProvider,
};
use std::collections::BTreeSet;
use std::ops::ControlFlow;
mod r#async;
mod catalog;
mod dynamic_file;
pub mod information_schema;
mod schema;
mod session;
mod table;
pub use catalog::*;
pub use dynamic_file::catalog::*;
pub use r#async::*;
pub use schema::*;
pub use session::*;
pub use table::*;
pub mod streaming;
/// Collects the tables and views referenced by `statement`, with CTE aliases
/// collected separately.
///
/// # Arguments
///
/// * `statement` - the parsed statement to inspect.
/// * `enable_ident_normalization` - forwarded unchanged to
///   `object_name_to_table_reference` when converting each collected name.
///
/// # Returns
///
/// A `(table_refs, ctes)` tuple. `table_refs` holds every relation that was
/// referenced while no CTE of the same name was in scope; `ctes` holds the
/// alias of every CTE defined anywhere in the statement. Both lists are
/// produced from `BTreeSet`s, so they are deduplicated and come out in the
/// set's sorted order.
///
/// A name can legitimately show up in both lists, for example
/// `WITH t AS (SELECT * FROM t) SELECT * FROM t`: the inner `t` is a table
/// reference because a non-recursive CTE is not in scope inside its own
/// definition.
///
/// # Errors
///
/// Returns the error produced by `object_name_to_table_reference` if any
/// collected name cannot be converted into a [`TableReference`].
pub fn resolve_table_references(
    statement: &datafusion_sql::parser::Statement,
    enable_ident_normalization: bool,
) -> datafusion_common::Result<(Vec<TableReference>, Vec<TableReference>)> {
    use datafusion_sql::parser::{
        CopyToSource, CopyToStatement, Statement as DFStatement,
    };
    use datafusion_sql::planner::object_name_to_table_reference;
    use information_schema::INFORMATION_SCHEMA;
    use information_schema::INFORMATION_SCHEMA_TABLES;
    use sqlparser::ast::*;

    /// AST visitor that accumulates relation names and CTE aliases.
    struct RelationVisitor {
        /// Relations referenced while no CTE of the same name was in scope.
        relations: BTreeSet<ObjectName>,
        /// Aliases of all CTEs whose enclosing query has been fully visited.
        all_ctes: BTreeSet<ObjectName>,
        /// Stack of CTE aliases visible at the current point of the traversal.
        ctes_in_scope: Vec<ObjectName>,
    }

    impl RelationVisitor {
        /// Records a reference to `relation` unless it names a CTE that is
        /// currently in scope. Recording the same name twice is a no-op.
        fn insert_relation(&mut self, relation: &ObjectName) {
            if !self.relations.contains(relation)
                && !self.ctes_in_scope.contains(relation)
            {
                self.relations.insert(relation.clone());
            }
        }
    }

    impl Visitor for RelationVisitor {
        type Break = ();

        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
            self.insert_relation(relation);
            ControlFlow::Continue(())
        }

        fn pre_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
            if let Some(with) = &q.with {
                for cte in &with.cte_tables {
                    // A non-recursive CTE cannot see its own alias, so its body
                    // is walked *before* the alias is pushed onto the scope
                    // stack. A recursive CTE may refer to itself, so its alias
                    // goes into scope without this early walk.
                    if !with.recursive {
                        // The regular traversal of `q` is expected to reach this
                        // CTE again; that is harmless because `insert_relation`
                        // ignores names it has already recorded. None of the
                        // hooks defined here ever return `Break`, so the result
                        // is discarded explicitly.
                        let _ = cte.visit(self);
                    }
                    self.ctes_in_scope
                        .push(ObjectName(vec![cte.alias.name.clone()]));
                }
            }
            ControlFlow::Continue(())
        }

        fn post_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
            if let Some(with) = &q.with {
                // Leave the scope opened by `pre_visit_query`: pop exactly one
                // alias per CTE of this query and remember it as a known CTE.
                for _ in &with.cte_tables {
                    let cte = self.ctes_in_scope.pop().expect(
                        "pre_visit_query pushes one scope entry per CTE of the query",
                    );
                    self.all_ctes.insert(cte);
                }
            }
            ControlFlow::Continue(())
        }

        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
            // `SHOW CREATE TABLE/VIEW <name>` references `<name>` directly.
            if let Statement::ShowCreate {
                obj_type: ShowCreateObject::Table | ShowCreateObject::View,
                obj_name,
            } = statement
            {
                self.insert_relation(obj_name)
            }

            // NOTE(review): these SHOW statements are presumably planned as
            // queries against the information schema, hence every
            // information-schema table is registered as referenced.
            let requires_information_schema = matches!(
                statement,
                Statement::ShowFunctions { .. }
                    | Statement::ShowVariable { .. }
                    | Statement::ShowStatus { .. }
                    | Statement::ShowVariables { .. }
                    | Statement::ShowCreate { .. }
                    | Statement::ShowColumns { .. }
                    | Statement::ShowTables { .. }
                    | Statement::ShowCollation { .. }
            );
            if requires_information_schema {
                for s in INFORMATION_SCHEMA_TABLES {
                    // Inserted directly (bypassing the CTE-scope check).
                    self.relations.insert(ObjectName(vec![
                        Ident::new(INFORMATION_SCHEMA),
                        Ident::new(*s),
                    ]));
                }
            }
            ControlFlow::Continue(())
        }
    }

    let mut visitor = RelationVisitor {
        relations: BTreeSet::new(),
        all_ctes: BTreeSet::new(),
        ctes_in_scope: vec![],
    };

    /// Feeds a DataFusion-level statement to `visitor`, unwrapping the
    /// DataFusion-specific statement kinds around the underlying SQL AST.
    fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) {
        match statement {
            DFStatement::Statement(s) => {
                let _ = s.as_ref().visit(visitor);
            }
            DFStatement::CreateExternalTable(table) => {
                // The table being created is recorded unconditionally.
                visitor.relations.insert(table.name.clone());
            }
            DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
                CopyToSource::Relation(table_name) => {
                    visitor.insert_relation(table_name);
                }
                CopyToSource::Query(query) => {
                    // The visitor never breaks; discard the result explicitly,
                    // as for the plain-statement case above.
                    let _ = query.visit(visitor);
                }
            },
            // EXPLAIN wraps another statement; recurse into it.
            DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor),
        }
    }

    visit_statement(statement, &mut visitor);

    let table_refs = visitor
        .relations
        .into_iter()
        .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
        .collect::<datafusion_common::Result<_>>()?;
    let ctes = visitor
        .all_ctes
        .into_iter()
        .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
        .collect::<datafusion_common::Result<_>>()?;
    Ok((table_refs, ctes))
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_sql::parser::DFParser;

    /// Parses `sql`, takes the last parsed statement and resolves its
    /// references with identifier normalization enabled. Returns the
    /// `(table_refs, ctes)` pair rendered as strings, in the order produced by
    /// `resolve_table_references`.
    fn resolved_names(sql: &str) -> (Vec<String>, Vec<String>) {
        let statement = DFParser::parse_sql(sql).unwrap().pop_back().unwrap();
        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
        let render = |refs: Vec<TableReference>| -> Vec<String> {
            refs.iter().map(ToString::to_string).collect()
        };
        (render(table_refs), render(ctes))
    }

    #[test]
    fn resolve_table_references_shadowed_cte() {
        // Inside its own definition a non-recursive CTE is not in scope, so the
        // inner `t` is a real table while the outer `t` is the CTE.
        let (tables, ctes) = resolved_names("WITH t AS (SELECT * FROM t) SELECT * FROM t");
        assert_eq!(tables, ["t"]);
        assert_eq!(ctes, ["t"]);

        // The CTE is scoped to the first UNION branch only; the second branch
        // refers to a real table of the same name.
        let (tables, ctes) =
            resolved_names("(with t as (select 1) select * from t) union (select * from t)");
        assert_eq!(tables, ["t"]);
        assert_eq!(ctes, ["t"]);

        // `u` is a CTE only inside the definition of `t`; the outer reference
        // to `u` is a real table.
        let (tables, ctes) = resolved_names(
            "(with t as (with u as (select 1) select * from u) select * from u cross join t)",
        );
        assert_eq!(tables, ["u"]);
        assert_eq!(ctes, ["t", "u"]);
    }

    #[test]
    fn resolve_table_references_recursive_cte() {
        // A recursive CTE may reference itself, so `nodes` never counts as a table.
        let sql = "
WITH RECURSIVE nodes AS (
SELECT 1 as id
UNION ALL
SELECT id + 1 as id
FROM nodes
WHERE id < 10
)
SELECT * FROM nodes
";
        let (tables, ctes) = resolved_names(sql);
        assert!(tables.is_empty());
        assert_eq!(ctes, ["nodes"]);
    }
}