Skip to content

Commit

Permalink
Allow access to UDTF in SessionContext (apache#11071)
Browse files Browse the repository at this point in the history
* get table function from session context

* support deregistering UDTF

* add `state_ref()` for `SessionContext`

---------

Co-authored-by: Shehab <[email protected]>
  • Loading branch information
2 people authored and findepi committed Jul 16, 2024
1 parent d2d6177 commit ed1c28c
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 2 deletions.
5 changes: 5 additions & 0 deletions datafusion/core/src/datasource/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ impl TableFunction {
&self.name
}

/// Get the implementation of the table function
///
/// Returns a borrowed reference to the shared [`TableFunctionImpl`] held by
/// this `TableFunction`; callers that need ownership can clone the `Arc`.
pub fn function(&self) -> &Arc<dyn TableFunctionImpl> {
&self.fun
}

/// Get the function implementation and generate a table
pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
self.fun.call(args)
Expand Down
29 changes: 27 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
catalog::{CatalogProvider, CatalogProviderList, MemoryCatalogProvider},
dataframe::DataFrame,
datasource::{
function::TableFunctionImpl,
function::{TableFunction, TableFunctionImpl},
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
provider::TableProviderFactory,
},
Expand All @@ -52,7 +52,7 @@ use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use datafusion_common::{
config::{ConfigExtension, TableOptions},
exec_err, not_impl_err, plan_err,
exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNodeRecursion, TreeNodeVisitor},
DFSchema, SchemaReference, TableReference,
};
Expand Down Expand Up @@ -928,6 +928,7 @@ impl SessionContext {
dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udtf(&stmt.name)?.is_some();

// DROP FUNCTION IF EXISTS drops the specified function only if that
// function exists and in this way, it avoids error. While the DROP FUNCTION
Expand Down Expand Up @@ -1008,6 +1009,11 @@ impl SessionContext {
self.state.write().deregister_udwf(name).ok();
}

/// Deregisters the user defined table function (UDTF) called `name`
/// from this context.
///
/// The outcome of the underlying state-level deregistration (the removed
/// function, or an error) is discarded; this method returns nothing.
pub fn deregister_udtf(&self, name: &str) {
    let mut state = self.state.write();
    // Best-effort removal: the result is intentionally ignored.
    let _ = state.deregister_udtf(name);
}

/// Creates a [`DataFrame`] for reading a data source.
///
/// For more control such as reading multiple files, you can use
Expand Down Expand Up @@ -1266,6 +1272,20 @@ impl SessionContext {
Ok(DataFrame::new(self.state(), plan))
}

/// Looks up a [`TableFunction`] by name.
///
/// The lookup is performed against the table functions currently held by
/// the session state (see [`register_udtf`]). A clone of the shared handle
/// is returned, so the result stays usable after the state lock is released.
///
/// # Errors
///
/// Returns a planning error if no table function is registered under `name`.
///
/// [`register_udtf`]: SessionContext::register_udtf
pub fn table_function(&self, name: &str) -> Result<Arc<TableFunction>> {
    let state = self.state.read();
    match state.table_functions().get(name) {
        Some(func) => Ok(Arc::clone(func)),
        None => Err(plan_datafusion_err!("Table function '{name}' not found")),
    }
}

/// Return a [`TableProvider`] for the specified table.
pub async fn table_provider<'a>(
&self,
Expand Down Expand Up @@ -1303,6 +1323,11 @@ impl SessionContext {
state
}

/// Returns a shared handle to this context's [`SessionState`].
///
/// The returned `Arc` points at the same `RwLock<SessionState>` that the
/// context itself holds; only the reference count is incremented.
pub fn state_ref(&self) -> Arc<RwLock<SessionState>> {
    Arc::clone(&self.state)
}

/// Get weak reference to [`SessionState`]
pub fn state_weak_ref(&self) -> Weak<RwLock<SessionState>> {
Arc::downgrade(&self.state)
Expand Down
14 changes: 14 additions & 0 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,11 @@ impl SessionState {
&self.window_functions
}

/// Return reference to table_functions
///
/// The map is keyed by function name and holds a shared [`TableFunction`]
/// for each entry.
pub fn table_functions(&self) -> &HashMap<String, Arc<TableFunction>> {
&self.table_functions
}

/// Return [SerializerRegistry] for extensions
pub fn serializer_registry(&self) -> Arc<dyn SerializerRegistry> {
self.serializer_registry.clone()
Expand All @@ -851,6 +856,15 @@ impl SessionState {
Arc::new(TableFunction::new(name.to_owned(), fun)),
);
}

/// Deregister a user defined table function.
///
/// Removes the entry stored under `name` from the table function map and
/// returns the [`TableFunctionImpl`] it wrapped, or `None` if nothing was
/// registered under that name. This method always returns `Ok`.
pub fn deregister_udtf(
    &mut self,
    name: &str,
) -> datafusion_common::Result<Option<Arc<dyn TableFunctionImpl>>> {
    let removed = self
        .table_functions
        .remove(name)
        .map(|table_fn| Arc::clone(table_fn.function()));
    Ok(removed)
}
}

struct SessionContextProvider<'a> {
Expand Down
15 changes: 15 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_table_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ async fn test_simple_read_csv_udtf() -> Result<()> {
Ok(())
}

/// Registers a UDTF, checks it is visible in the session state, then
/// deregisters it and checks it is gone.
#[tokio::test]
async fn test_deregister_udtf() -> Result<()> {
    let udtf_name = "read_csv";
    let ctx = SessionContext::new();

    // After registration the name must be present in the state's table functions.
    ctx.register_udtf(udtf_name, Arc::new(SimpleCsvTableFunc {}));
    let registered = ctx.state().table_functions().contains_key(udtf_name);
    assert!(registered);

    // After deregistration the same name must no longer be present.
    ctx.deregister_udtf(udtf_name);
    let still_registered = ctx.state().table_functions().contains_key(udtf_name);
    assert!(!still_registered);

    Ok(())
}

struct SimpleCsvTable {
schema: SchemaRef,
exprs: Vec<Expr>,
Expand Down

0 comments on commit ed1c28c

Please sign in to comment.