use datafusion catalog directly + cleanup
connortsui20 committed Mar 3, 2025
1 parent 9a4c6f1 commit 97d3db0
Showing 6 changed files with 129 additions and 20 deletions.
8 changes: 8 additions & 0 deletions optd-datafusion/README.md
@@ -6,3 +6,11 @@ To run the demo, execute the following command:
$ cargo run -p optd-datafusion --example demo -- <path_to_sql>
$ cargo run -p optd-datafusion --example demo -- optd-datafusion/sql/test_join.sql
```

# Tests

To run the tests, simply run:

```sh
$ cargo test --test run_queries
```
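
To run just one of the query tests, you can pass the test's name as a filter (standard `cargo test` behavior). For example, to run only the join test:

```sh
$ cargo test --test run_queries test_join
```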
9 changes: 8 additions & 1 deletion optd-datafusion/src/df_conversion/context.rs
@@ -8,7 +8,9 @@ pub(crate) struct OptdDFContext {
/// Maps table names to DataFusion [`TableProvider`]s.
pub(crate) providers: HashMap<String, Arc<dyn TableProvider>>,
/// DataFusion session state.
pub(crate) session_state: SessionState,
///
/// We only need to carry this around to create `DataFusion` Scan nodes.
session_state: SessionState,
}

impl OptdDFContext {
@@ -19,6 +21,11 @@ impl OptdDFContext {
session_state: session_state.clone(),
}
}

/// Returns the DataFusion session state.
pub(crate) fn session_state(&self) -> &SessionState {
&self.session_state
}
}

impl Debug for OptdDFContext {
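With `session_state` now private, code elsewhere in the crate reads it through the new accessor rather than the field. A minimal sketch of a hypothetical caller (the function name is made up; `catalog_list()` and `catalog_names()` are the same DataFusion calls used in `iceberg_conversion.rs` below):

```rust
use datafusion::execution::SessionState;

// Hypothetical helper: borrow the session state through the accessor and
// list the catalogs registered with DataFusion.
fn list_catalog_names(ctx: &OptdDFContext) -> Vec<String> {
    let state: &SessionState = ctx.session_state();
    state.catalog_list().catalog_names()
}
```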
6 changes: 5 additions & 1 deletion optd-datafusion/src/df_conversion/from_optd.rs
@@ -21,6 +21,7 @@ use optd_core::{
use std::{collections::HashMap, str::FromStr, sync::Arc};

impl OptdDFContext {
/// Converts an `optd` [`PhysicalPlan`] into an executable DataFusion [`ExecutionPlan`].
#[async_recursion]
pub(crate) async fn optd_to_df_relational(
&self,
@@ -41,7 +42,7 @@
// TODO(yuchen): support filters inside table scan.
let filters = vec![];
let plan = provider
.scan(&self.session_state, None, &filters, None)
.scan(self.session_state(), None, &filters, None)
.await?;

Ok(plan)
@@ -127,6 +128,9 @@
}
}

/// Converts an `optd` [`ScalarPlan`] into a physical DataFusion [`PhysicalExpr`].
///
/// TODO(connor): Is the context necessary if we have a catalog?
pub(crate) fn optd_to_df_scalar(
pred: &ScalarPlan,
context: &SchemaRef,
65 changes: 50 additions & 15 deletions optd-datafusion/src/iceberg_conversion.rs
@@ -1,28 +1,61 @@
use crate::NAMESPACE;
use datafusion::catalog::TableProvider;
use datafusion::common::arrow::datatypes::{DataType as DFType, Schema as DFSchema};
use datafusion::execution::SessionState;
use iceberg::spec::{NestedField, PrimitiveType, Schema, Type};
use iceberg::{Catalog, NamespaceIdent, Result, TableCreation, TableIdent};
use std::sync::atomic::{AtomicI32, Ordering};
use std::{collections::HashMap, sync::Arc};

static FIELD_ID: AtomicI32 = AtomicI32::new(0);
use std::sync::Arc;

// Ingest every table registered in the DataFusion session's catalog into an Iceberg [`Catalog`].
pub(crate) async fn ingest_providers<C>(
catalog: &C,
providers: &HashMap<String, Arc<dyn TableProvider>>,
iceberg_catalog: &C,
datafusion_session: &SessionState,
) -> Result<()>
where
C: Catalog,
{
let mut catalog_names = datafusion_session.catalog_list().catalog_names();
assert_eq!(
catalog_names.len(),
1,
"TODO(connor): There should only be 1 catalog by our usage"
);

let catalog_name = catalog_names.pop().expect("We checked non-empty above");

let datafusion_catalog = datafusion_session
.catalog_list()
.catalog(&catalog_name)
.expect("This catalog must exist if it was just listed");

// Despite the method name, a DataFusion "schema" is just a collection of tables.
let mut table_collection_names = datafusion_catalog.schema_names();
assert_eq!(
table_collection_names.len(),
1,
"TODO(connor): There should only be 1 catalog by our usage"
);

let table_collection_name = table_collection_names
.pop()
.expect("We checked non-empty above");

let table_collection = datafusion_catalog
.schema(&table_collection_name)
.expect("This collection must exist if it was just listed");

let namespace_ident = NamespaceIdent::from_vec(vec![NAMESPACE.to_string()]).unwrap();

for (name, provider) in providers {
for name in table_collection.table_names() {
let provider = table_collection
.table(&name)
.await
.expect("TODO(connor): Error handle")
.expect("This table must exist if it was just listed");

// Create the table identifier.
let table_ident = TableIdent::new(namespace_ident.clone(), name.clone());

if catalog.table_exists(&table_ident).await? {
if iceberg_catalog.table_exists(&table_ident).await? {
eprintln!("TODO(connor): Table update is unimplemented, doing nothing for now");
} else {
let df_schema = provider.schema();
@@ -37,7 +70,9 @@ where
sort_order: None,
};

catalog.create_table(&namespace_ident, create_table).await?;
iceberg_catalog
.create_table(&namespace_ident, create_table)
.await?;
}
}

@@ -48,12 +83,12 @@ where
fn df_to_iceberg_schema(df_schema: &DFSchema) -> Schema {
let fields = &df_schema.fields;

let fields = fields.iter().map(|field| {
let fields = fields.iter().enumerate().map(|(i, field)| {
let field_name = field.name();
let iceberg_type = df_to_iceberg_datatype(field.data_type());
let iceberg_type = df_to_iceberg_type(field.data_type());

Arc::new(NestedField {
id: FIELD_ID.fetch_add(1, Ordering::Relaxed),
id: i as i32,
name: field_name.clone(),
required: true,
field_type: Box::new(iceberg_type),
@@ -76,9 +111,9 @@ fn df_to_iceberg_schema(df_schema: &DFSchema) -> Schema {
/// See:
/// - https://docs.rs/datafusion/latest/datafusion/common/arrow/datatypes/enum.DataType.html
/// - https://docs.rs/iceberg/latest/iceberg/spec/enum.Type.html
fn df_to_iceberg_datatype(df_datatype: &DFType) -> Type {
fn df_to_iceberg_type(df_datatype: &DFType) -> Type {
match df_datatype {
DFType::Null => unimplemented!("All Iceberg types are nullable"),
DFType::Null => unimplemented!("TODO: all Iceberg types appear to be nullable"),
DFType::Boolean => Type::Primitive(PrimitiveType::Boolean),
DFType::Int8 => Type::Primitive(PrimitiveType::Int),
DFType::Int16 => Type::Primitive(PrimitiveType::Int),
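Since field ids are now taken from `enumerate` instead of the process-wide `FIELD_ID` counter, converting the same schema twice produces the same ids. A minimal sketch of the behavior, written as if it lived inside this module (the function and column names are made up):

```rust
use datafusion::common::arrow::datatypes::{DataType as DFType, Field, Schema as DFSchema};

// Illustrative only: with the enumerate-based assignment, `id` becomes
// Iceberg field id 0 and `name` becomes field id 1, no matter how many
// schemas were converted before this one.
fn example_schema_conversion() -> Schema {
    let df_schema = DFSchema::new(vec![
        Field::new("id", DFType::Int32, false),
        Field::new("name", DFType::Utf8, false),
    ]);
    df_to_iceberg_schema(&df_schema)
}
```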
10 changes: 7 additions & 3 deletions optd-datafusion/src/mock.rs
@@ -67,6 +67,7 @@ impl MockOptdOptimizer {
let optimized_plan =
cascades::match_any_physical_plan(self.0.memo.as_ref(), goal_id).await?;

// We are allowed to do anything we want with the catalog here.
std::hint::black_box(&self.0.catalog);

Ok(optimized_plan)
@@ -132,9 +133,12 @@ impl QueryPlanner for MockOptdOptimizer {

// The DataFusion to `optd` conversion will have read in all of the tables necessary to
// execute the query. Now we can update our own catalog with any new tables.
crate::iceberg_conversion::ingest_providers(self.0.catalog.as_ref(), &optd_ctx.providers)
.await
.expect("Unable to ingest providers");
crate::iceberg_conversion::ingest_providers(
self.0.catalog.as_ref(),
optd_ctx.session_state(),
)
.await
.expect("Unable to ingest providers");

// Run the `optd` optimizer on the `LogicalPlan`.
let optd_optimized_physical_plan = self
51 changes: 51 additions & 0 deletions optd-datafusion/tests/run_queries.rs
@@ -0,0 +1,51 @@
//! Simple tests that only check that the program doesn't panic.
use std::error::Error;
use std::fs;

use optd_datafusion::run_queries;

#[tokio::test]
async fn test_scan() -> Result<(), Box<dyn Error>> {
let file = fs::read_to_string("./sql/test_scan.sql")?;

// Retrieve all of the SQL queries from the file.
let queries: Vec<&str> = file
.split(';')
.filter(|query| !query.trim().is_empty())
.collect();

run_queries(&queries).await?;

Ok(())
}

#[tokio::test]
async fn test_filter() -> Result<(), Box<dyn Error>> {
let file = fs::read_to_string("./sql/test_filter.sql")?;

// Retrieve all of the SQL queries from the file.
let queries: Vec<&str> = file
.split(';')
.filter(|query| !query.trim().is_empty())
.collect();

run_queries(&queries).await?;

Ok(())
}

#[tokio::test]
async fn test_join() -> Result<(), Box<dyn Error>> {
let file = fs::read_to_string("./sql/test_join.sql")?;

// Retrieve all of the SQL queries from the file.
let queries: Vec<&str> = file
.split(';')
.filter(|query| !query.trim().is_empty())
.collect();

run_queries(&queries).await?;

Ok(())
}
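
The three tests share the same read-split-run pattern; a small helper along these lines (not part of this commit; the name `run_sql_file` is made up) would let each test collapse to a single call:

```rust
/// Reads a SQL file, splits it into individual queries, and runs them.
async fn run_sql_file(path: &str) -> Result<(), Box<dyn Error>> {
    let file = fs::read_to_string(path)?;

    // Retrieve all of the SQL queries from the file.
    let queries: Vec<&str> = file
        .split(';')
        .filter(|query| !query.trim().is_empty())
        .collect();

    run_queries(&queries).await?;

    Ok(())
}

#[tokio::test]
async fn test_scan() -> Result<(), Box<dyn Error>> {
    run_sql_file("./sql/test_scan.sql").await
}
```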
