Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(connect): sql #3696

Merged
merged 5 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-catalog/python-catalog/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ impl DataCatalogTable for PythonTable {
}

/// Wrapper around a `daft.catalog.python_catalog.PythonCatalog`
#[derive(Debug)]
pub struct PythonCatalog {
python_catalog_pyobj: PyObject,
}
Expand Down
2 changes: 1 addition & 1 deletion src/daft-catalog/src/data_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{data_catalog_table::DataCatalogTable, errors::Result};
///
/// It allows registering and retrieving data sources, as well as querying their schemas.
/// The catalog is used by the query planner to resolve table references in queries.
pub trait DataCatalog: Sync + Send {
pub trait DataCatalog: Sync + Send + std::fmt::Debug {
/// Lists the fully-qualified names of tables in the catalog with the specified prefix
fn list_tables(&self, prefix: &str) -> Result<Vec<String>>;

Expand Down
24 changes: 22 additions & 2 deletions src/daft-catalog/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
///
/// Users of Daft can register various [`DataCatalog`] with Daft, enabling
/// discovery of tables across various [`DataCatalog`] implementations.
#[derive(Debug, Clone, Default)]
pub struct DaftMetaCatalog {
/// Map of catalog names to the DataCatalog impls.
///
Expand Down Expand Up @@ -95,16 +96,26 @@
}

/// Registers a LogicalPlan with a name in the DaftMetaCatalog
pub fn register_named_table(&mut self, name: &str, view: LogicalPlanBuilder) -> Result<()> {
pub fn register_named_table(
&mut self,
name: &str,
view: impl Into<LogicalPlanBuilder>,
) -> Result<()> {
if !name.chars().all(|c| c.is_alphanumeric() || c == '_') {
return Err(Error::InvalidTableName {
name: name.to_string(),
});
}
self.named_tables.insert(name.to_string(), view);
self.named_tables.insert(name.to_string(), view.into());
Ok(())
}

/// Check if a named table is registered in the DaftMetaCatalog
///
pub fn contains_named_table(&self, name: &str) -> bool {
self.named_tables.contains_key(name)
}

Check warning on line 117 in src/daft-catalog/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-catalog/src/lib.rs#L115-L117

Added lines #L115 - L117 were not covered by tests

/// Provides high-level functionality for reading a table of data against a [`DaftMetaCatalog`]
///
/// Resolves the provided table_identifier against the catalog:
Expand Down Expand Up @@ -146,6 +157,15 @@
table_id: searched_table_name.to_string(),
})
}
/// Copy from another catalog, using tables from other in case of conflict
pub fn copy_from(&mut self, other: &Self) {
for (name, plan) in &other.named_tables {
self.named_tables.insert(name.clone(), plan.clone());
}
for (name, catalog) in &other.data_catalogs {
self.data_catalogs.insert(name.clone(), catalog.clone());
}

Check warning on line 167 in src/daft-catalog/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-catalog/src/lib.rs#L166-L167

Added lines #L166 - L167 were not covered by tests
}
}

#[cfg(test)]
Expand Down
4 changes: 3 additions & 1 deletion src/daft-connect/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ async-stream = "0.3.6"
common-daft-config = {workspace = true, optional = true, features = ["python"]}
common-error = {workspace = true, optional = true, features = ["python"]}
common-file-formats = {workspace = true, optional = true, features = ["python"]}
daft-catalog = {path = "../daft-catalog", optional = true, features = ["python"]}
daft-core = {workspace = true, optional = true, features = ["python"]}
daft-dsl = {workspace = true, optional = true, features = ["python"]}
daft-local-execution = {workspace = true, optional = true, features = ["python"]}
Expand Down Expand Up @@ -43,7 +44,8 @@ python = [
"dep:daft-scan",
"dep:daft-schema",
"dep:daft-sql",
"dep:daft-table"
"dep:daft-table",
"dep:daft-catalog"
]

[lints]
Expand Down
17 changes: 12 additions & 5 deletions src/daft-connect/src/connect_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,24 @@
}
OpType::Command(command) => {
let command = command.command_type.required("command_type")?;

match command {
CommandType::WriteOperation(op) => {
let result = session.execute_write_operation(op, rb).await?;
Ok(Response::new(result))
}
CommandType::RegisterFunction(_) => todo!(),

Check warning on line 87 in src/daft-connect/src/connect_service.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/connect_service.rs#L87

Added line #L87 was not covered by tests
universalmind303 marked this conversation as resolved.
Show resolved Hide resolved
CommandType::CreateDataframeView(create_dataframe) => {
let result = session
.execute_create_dataframe_view(create_dataframe, rb)
.await?;

Check warning on line 91 in src/daft-connect/src/connect_service.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/connect_service.rs#L91

Added line #L91 was not covered by tests
Ok(Response::new(result))
}
CommandType::SqlCommand(sql) => {
let result = session.execute_sql_command(sql, rb).await?;
Ok(Response::new(result))
}
other => {
return not_yet_implemented!(
"Command type: {}",
command_type_to_str(&other)
)
not_yet_implemented!("CommandType '{:?}'", command_type_to_str(&other))

Check warning on line 99 in src/daft-connect/src/connect_service.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/connect_service.rs#L99

Added line #L99 was not covered by tests
}
}
}
Expand Down
135 changes: 133 additions & 2 deletions src/daft-connect/src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
use daft_micropartition::MicroPartition;
use daft_ray_execution::RayEngine;
use daft_table::Table;
use eyre::bail;
use eyre::{bail, Context};
use futures::{
stream::{self, BoxStream},
StreamExt, TryFutureExt, TryStreamExt,
Expand All @@ -19,7 +19,8 @@
use spark_connect::{
relation::RelType,
write_operation::{SaveMode, SaveType},
ExecutePlanResponse, Relation, ShowString, WriteOperation,
CreateDataFrameViewCommand, ExecutePlanResponse, Relation, ShowString, SqlCommand,
WriteOperation,
};
use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
use tracing::debug;
Expand Down Expand Up @@ -236,6 +237,136 @@
Ok(Box::pin(stream))
}

pub async fn execute_create_dataframe_view(
&self,
create_dataframe: CreateDataFrameViewCommand,
rb: ResponseBuilder<ExecutePlanResponse>,
) -> Result<ExecuteStream, Status> {
let CreateDataFrameViewCommand {
input,
name,
is_global,
replace,
} = create_dataframe;

if is_global {
return not_yet_implemented!("Global dataframe view");

Check warning on line 253 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L253

Added line #L253 was not covered by tests
}

let input = input.required("input")?;
let input = SparkAnalyzer::new(self)
.to_logical_plan(input)
.await

Check warning on line 259 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L259

Added line #L259 was not covered by tests
.map_err(|e| {
Status::internal(
textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"),
)

Check warning on line 263 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L261-L263

Added lines #L261 - L263 were not covered by tests
})?;

{
let catalog = self.catalog.read().unwrap();
if !replace && catalog.contains_named_table(&name) {
return Err(Status::internal("Dataframe view already exists"));

Check warning on line 269 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L269

Added line #L269 was not covered by tests
}
}

let mut catalog = self.catalog.write().unwrap();

catalog.register_named_table(&name, input).map_err(|e| {
Status::internal(textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"))

Check warning on line 276 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L276

Added line #L276 was not covered by tests
})?;

let response = rb.result_complete_response();
let stream = stream::once(ready(Ok(response)));
Ok(Box::pin(stream))
}

#[allow(deprecated)]
pub async fn execute_sql_command(
&self,
SqlCommand {
sql,
args,
pos_args,
named_arguments,
pos_arguments,
input,
}: SqlCommand,
res: ResponseBuilder<ExecutePlanResponse>,
) -> Result<ExecuteStream, Status> {
if !args.is_empty() {
return not_yet_implemented!("Named arguments");

Check warning on line 298 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L298

Added line #L298 was not covered by tests
}
if !pos_args.is_empty() {
return not_yet_implemented!("Positional arguments");

Check warning on line 301 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L301

Added line #L301 was not covered by tests
}
if !named_arguments.is_empty() {
return not_yet_implemented!("Named arguments");

Check warning on line 304 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L304

Added line #L304 was not covered by tests
}
if !pos_arguments.is_empty() {
return not_yet_implemented!("Positional arguments");

Check warning on line 307 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L307

Added line #L307 was not covered by tests
}

if input.is_some() {
return not_yet_implemented!("Input");

Check warning on line 311 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L311

Added line #L311 was not covered by tests
}

let catalog = self.catalog.read().unwrap();
let catalog = catalog.clone();

let mut planner = daft_sql::SQLPlanner::new(catalog);

let plan = planner
.plan_sql(&sql)
.wrap_err("Error planning SQL")
.map_err(|e| {
Status::internal(
textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"),
)

Check warning on line 325 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L323-L325

Added lines #L323 - L325 were not covered by tests
})?;

let plan = LogicalPlanBuilder::from(plan);

// TODO: code duplication
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this reminder for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's a pretty big block of code duplicated in a few spots that I want to clean up.

        // TODO: code duplication
        let result_complete = res.result_complete_response();

        let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(1);

        let this = self.clone();

        tokio::spawn(async move {
            let execution_fut = async {
                let mut result_stream = this.run_query(plan).await?;
                while let Some(result) = result_stream.next().await {
                    let result = result?;
                    let tables = result.get_tables()?;
                    for table in tables.as_slice() {
                        let response = res.arrow_batch_response(table)?;
                        if tx.send(Ok(response)).await.is_err() {
                            return Ok(());
                        }
                    }
                }
                Ok(())
            };
            if let Err(e) = execution_fut.await {
                let _ = tx.send(Err(e)).await;
            }
        });

        let stream = ReceiverStream::new(rx);

        let stream = stream
            .map_err(|e| {
                Status::internal(
                    textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"),
                )
            })
            .chain(stream::once(ready(Ok(result_complete))));

        Ok(Box::pin(stream))

let result_complete = res.result_complete_response();

let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(1);

let this = self.clone();

tokio::spawn(async move {
let execution_fut = async {
let mut result_stream = this.run_query(plan).await?;
while let Some(result) = result_stream.next().await {
let result = result?;
let tables = result.get_tables()?;
for table in tables.as_slice() {
let response = res.arrow_batch_response(table)?;
if tx.send(Ok(response)).await.is_err() {
return Ok(());

Check warning on line 346 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L346

Added line #L346 was not covered by tests
}
}
}
Ok(())
};
if let Err(e) = execution_fut.await {
let _ = tx.send(Err(e)).await;

Check warning on line 353 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L353

Added line #L353 was not covered by tests
}
});

let stream = ReceiverStream::new(rx);

let stream = stream
.map_err(|e| {
Status::internal(
textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"),
)

Check warning on line 363 in src/daft-connect/src/execute.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/execute.rs#L361-L363

Added lines #L361 - L363 were not covered by tests
})
.chain(stream::once(ready(Ok(result_complete))));

Ok(Box::pin(stream))
}

async fn show_string(
&self,
show_string: ShowString,
Expand Down
8 changes: 7 additions & 1 deletion src/daft-connect/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use std::{collections::BTreeMap, sync::Arc};
use std::{
collections::BTreeMap,
sync::{Arc, RwLock},
};

use daft_catalog::DaftMetaCatalog;
use daft_micropartition::partitioning::InMemoryPartitionSetCache;
use uuid::Uuid;

Expand All @@ -15,6 +19,7 @@ pub struct Session {
/// MicroPartitionSet associated with this session
/// this will be filled up as the user runs queries
pub(crate) psets: Arc<InMemoryPartitionSetCache>,
pub(crate) catalog: Arc<RwLock<DaftMetaCatalog>>,
}

impl Session {
Expand All @@ -34,6 +39,7 @@ impl Session {
id,
server_side_session_id,
psets: Arc::new(InMemoryPartitionSetCache::empty()),
catalog: Arc::new(RwLock::new(DaftMetaCatalog::default())),
}
}

Expand Down
38 changes: 37 additions & 1 deletion src/daft-connect/src/spark_analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
};
use daft_scan::builder::{CsvScanBuilder, ParquetScanBuilder};
use daft_schema::schema::{Schema, SchemaRef};
use daft_sql::SQLPlanner;
use daft_table::Table;
use datatype::to_daft_datatype;
pub use datatype::to_spark_datatype;
Expand All @@ -36,7 +37,7 @@
},
read::ReadType,
relation::RelType,
Deduplicate, Expression, Limit, Range, Relation, Sort,
Deduplicate, Expression, Limit, Range, Relation, Sort, Sql,
};
use tracing::debug;

Expand Down Expand Up @@ -144,6 +145,7 @@
RelType::ShowString(_) => unreachable!("should already be handled in execute"),
RelType::Deduplicate(rel) => self.deduplicate(*rel).await,
RelType::Sort(rel) => self.sort(*rel).await,
RelType::Sql(sql) => self.sql(sql).await,
plan => not_yet_implemented!("relation type: \"{}\"", rel_name(&plan))?,
}
}
Expand Down Expand Up @@ -644,6 +646,40 @@
Ok(result)
}

#[allow(deprecated)]
async fn sql(&self, sql: Sql) -> eyre::Result<LogicalPlanBuilder> {
let Sql {
query,
args,
pos_args,
named_arguments,
pos_arguments,
} = sql;
if !args.is_empty() {
not_yet_implemented!("args")?;

Check warning on line 659 in src/daft-connect/src/spark_analyzer.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/spark_analyzer.rs#L659

Added line #L659 was not covered by tests
}
if !pos_args.is_empty() {
not_yet_implemented!("pos_args")?;

Check warning on line 662 in src/daft-connect/src/spark_analyzer.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/spark_analyzer.rs#L662

Added line #L662 was not covered by tests
}
if !named_arguments.is_empty() {
not_yet_implemented!("named_arguments")?;

Check warning on line 665 in src/daft-connect/src/spark_analyzer.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/spark_analyzer.rs#L665

Added line #L665 was not covered by tests
}
if !pos_arguments.is_empty() {
not_yet_implemented!("pos_arguments")?;

Check warning on line 668 in src/daft-connect/src/spark_analyzer.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/spark_analyzer.rs#L668

Added line #L668 was not covered by tests
}

let catalog = self
.session
.catalog
.read()
.map_err(|e| eyre::eyre!("Failed to read catalog: {e}"))?;
let catalog = catalog.clone();

let mut planner = SQLPlanner::new(catalog);
let plan = planner.plan_sql(&query)?;
Ok(plan.into())
}

pub fn to_daft_expr(&self, expression: &Expression) -> eyre::Result<daft_dsl::ExprRef> {
if let Some(common) = &expression.common {
if common.origin.is_some() {
Expand Down
Loading
Loading