use datafusion_common::{DataFusionError, Result};
use object_store::local::LocalFileSystem;
use object_store::ObjectStore;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use url::Url;
/// A URL identifying an object store: a scheme and authority only.
///
/// Values built through [`ObjectStoreUrl::parse`] always carry the path `/`
/// and have no query or fragment.
#[derive(Clone, Debug)]
pub struct ObjectStoreUrl {
    /// The underlying parsed URL.
    url: Url,
}
impl ObjectStoreUrl {
    /// Parses `s` into an [`ObjectStoreUrl`].
    ///
    /// The input may contain only a scheme and an authority, optionally
    /// followed by a single `/`. The stored URL always has its path set to `/`.
    ///
    /// # Errors
    ///
    /// * `DataFusionError::External` wrapping the `url` parse error when `s`
    ///   is not a valid URL.
    /// * `DataFusionError::Execution` when anything other than an empty string
    ///   or `/` follows the authority (a path, query or fragment).
    pub fn parse(s: impl AsRef<str>) -> Result<Self> {
        let mut parsed_url = Url::parse(s.as_ref())
            .map_err(|source| DataFusionError::External(Box::new(source)))?;

        // Everything from the start of the path onwards must be absent or "/".
        match &parsed_url[url::Position::BeforePath..] {
            "" | "/" => {}
            extra => {
                return Err(DataFusionError::Execution(format!(
                    "ObjectStoreUrl must only contain scheme and authority, got: {}",
                    extra
                )))
            }
        }

        // Normalize so every stored URL carries the path "/".
        parsed_url.set_path("/");
        Ok(Self { url: parsed_url })
    }

    /// Returns the [`ObjectStoreUrl`] for the local filesystem, `file://`.
    pub fn local_filesystem() -> Self {
        // The literal is a valid scheme-only URL, so parsing does not fail.
        Self::parse("file://").unwrap()
    }

    /// Returns the URL as a string slice.
    pub fn as_str(&self) -> &str {
        self.url.as_str()
    }
}
/// Borrows the serialized form of the URL.
impl AsRef<str> for ObjectStoreUrl {
    fn as_ref(&self) -> &str {
        self.url.as_str()
    }
}
/// Borrows the wrapped [`Url`], so an [`ObjectStoreUrl`] can be passed to
/// APIs that accept `impl AsRef<Url>`.
impl AsRef<Url> for ObjectStoreUrl {
    fn as_ref(&self) -> &Url {
        let Self { url } = self;
        url
    }
}
/// Formats the URL exactly as its string form, honouring any formatter
/// width / padding options the same way `str` does.
impl std::fmt::Display for ObjectStoreUrl {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Delegate explicitly to `str`'s `Display` implementation.
        std::fmt::Display::fmt(self.url.as_str(), f)
    }
}
/// A registry of [`ObjectStore`] instances that can be looked up by URL.
///
/// The registry uses interior mutability (an `RwLock`), so stores can be
/// registered through a shared reference.
pub struct ObjectStoreRegistry {
/// Registered stores, keyed by a `scheme://host` string (for example
/// `file://` for the store installed by `new`).
object_stores: RwLock<HashMap<String, Arc<dyn ObjectStore>>>,
}
/// Prints only the registered keys (under the field name `schemes`); the
/// stores themselves are not included in the output.
impl std::fmt::Debug for ObjectStoreRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Hold the read lock while the borrowed keys are being formatted.
        let guard = self.object_stores.read();
        let keys: Vec<&String> = guard.keys().collect();
        f.debug_struct("ObjectStoreRegistry")
            .field("schemes", &keys)
            .finish()
    }
}
impl Default for ObjectStoreRegistry {
fn default() -> Self {
Self::new()
}
}
impl ObjectStoreRegistry {
    /// Creates a registry with a [`LocalFileSystem`] store pre-registered
    /// under the key `file://`.
    pub fn new() -> Self {
        let local: Arc<dyn ObjectStore> = Arc::new(LocalFileSystem::new());
        let defaults = HashMap::from([("file://".to_string(), local)]);
        Self {
            object_stores: RwLock::new(defaults),
        }
    }

    /// Registers `store` under the key `{scheme}://{host}`.
    ///
    /// Returns the store previously registered under the same key, if any.
    pub fn register_store(
        &self,
        scheme: impl AsRef<str>,
        host: impl AsRef<str>,
        store: Arc<dyn ObjectStore>,
    ) -> Option<Arc<dyn ObjectStore>> {
        let key = format!("{}://{}", scheme.as_ref(), host.as_ref());
        self.object_stores.write().insert(key, store)
    }

    /// Looks up the store registered for `url`.
    ///
    /// The lookup key is the part of the URL from the scheme through the
    /// host; the port, path, query and fragment are ignored.
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::Internal` when no store is registered under
    /// that key.
    pub fn get_by_url(&self, url: impl AsRef<Url>) -> Result<Arc<dyn ObjectStore>> {
        let target: &Url = url.as_ref();
        let key = &target[url::Position::BeforeScheme..url::Position::AfterHost];

        let stores = self.object_stores.read();
        stores.get(key).cloned().ok_or_else(|| {
            DataFusionError::Internal(format!(
                "No suitable object store found for {}",
                target
            ))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasource::listing::ListingTableUrl;
    use std::sync::Arc;

    /// Checks normalization of accepted URLs and the error text of rejected ones.
    #[test]
    fn test_object_store_url() {
        // Object store URLs derived from listing table URLs.
        for (input, expected) in [
            ("file:///", "file:///"),
            ("s3://bucket/", "s3://bucket/"),
        ] {
            let listing = ListingTableUrl::parse(input).unwrap();
            let store = listing.object_store();
            assert_eq!(store.as_str(), expected);
        }

        // Directly parsed URLs are normalized to end in "/".
        for (input, expected) in [
            ("file://", "file:///"),
            ("s3://bucket", "s3://bucket/"),
            (
                "s3://username:password@host:123",
                "s3://username:password@host:123/",
            ),
        ] {
            let parsed = ObjectStoreUrl::parse(input).unwrap();
            assert_eq!(parsed.as_str(), expected);
        }

        // Invalid URLs and URLs with a path or query are rejected.
        for (input, message) in [
            ("s3://bucket:invalid", "External error: invalid port number"),
            ("s3://bucket?", "Execution error: ObjectStoreUrl must only contain scheme and authority, got: ?"),
            ("s3://bucket?foo=bar", "Execution error: ObjectStoreUrl must only contain scheme and authority, got: ?foo=bar"),
            ("s3://host:123/foo", "Execution error: ObjectStoreUrl must only contain scheme and authority, got: /foo"),
            ("s3://username:password@host:123/foo", "Execution error: ObjectStoreUrl must only contain scheme and authority, got: /foo"),
        ] {
            let err = ObjectStoreUrl::parse(input).unwrap_err();
            assert_eq!(err.to_string(), message);
        }
    }

    /// A store registered for `s3` + `bucket` is found for a key in that bucket.
    #[test]
    fn test_get_by_url_s3() {
        let registry = ObjectStoreRegistry::default();
        registry.register_store("s3", "bucket", Arc::new(LocalFileSystem::new()));
        let table_url = ListingTableUrl::parse("s3://bucket/key").unwrap();
        registry.get_by_url(&table_url).unwrap();
    }

    /// The default registry resolves `file://` URLs.
    #[test]
    fn test_get_by_url_file() {
        let registry = ObjectStoreRegistry::default();
        let table_url = ListingTableUrl::parse("file:///bucket/key").unwrap();
        registry.get_by_url(&table_url).unwrap();
    }

    /// The default registry resolves a relative path given without a scheme.
    #[test]
    fn test_get_by_url_local() {
        let registry = ObjectStoreRegistry::default();
        let table_url = ListingTableUrl::parse("../").unwrap();
        registry.get_by_url(&table_url).unwrap();
    }
}