From bdfb8c6feb9609b7665b971e1d64d937ed2e357c Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Wed, 20 Nov 2024 11:52:41 -0800
Subject: [PATCH] [FEAT] consolidate Spark session fixture into conftest.py
 (#3341)

---
 Cargo.lock                          |  1 +
 daft/daft/__init__.pyi              |  3 +-
 src/daft-connect/Cargo.toml         |  1 +
 src/daft-connect/src/lib.rs         | 51 +++++++++++++++++++++--------
 tests/connect/__init__.py           |  0
 tests/connect/conftest.py           | 29 ++++++++++++++++
 tests/connect/test_config_simple.py | 24 --------------
 tests/connect/test_range_simple.py  | 21 ------------
 8 files changed, 70 insertions(+), 60 deletions(-)
 delete mode 100644 tests/connect/__init__.py
 create mode 100644 tests/connect/conftest.py

diff --git a/Cargo.lock b/Cargo.lock
index 33c6acd968..36e3598748 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1932,6 +1932,7 @@ name = "daft-connect"
 version = "0.3.0-dev0"
 dependencies = [
  "arrow2",
+ "async-stream",
  "common-daft-config",
  "daft-local-execution",
  "daft-local-plan",
diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi
index 368af56011..b821f58115 100644
--- a/daft/daft/__init__.pyi
+++ b/daft/daft/__init__.pyi
@@ -1208,10 +1208,11 @@ def sql_expr(sql: str) -> PyExpr: ...
 def list_sql_functions() -> list[SQLFunctionStub]: ...
 def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ...
 def to_struct(inputs: list[PyExpr]) -> PyExpr: ...
-def connect_start(addr: str) -> ConnectionHandle: ...
+def connect_start(addr: str = "sc://0.0.0.0:0") -> ConnectionHandle: ...
 
 class ConnectionHandle:
     def shutdown(self) -> None: ...
+    def port(self) -> int: ...
 
 # expr numeric ops
 def abs(expr: PyExpr) -> PyExpr: ...
diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml
index 7955bdacd4..f94cb284be 100644
--- a/src/daft-connect/Cargo.toml
+++ b/src/daft-connect/Cargo.toml
@@ -1,5 +1,6 @@
 [dependencies]
 arrow2 = {workspace = true}
+async-stream = "0.3.6"
 common-daft-config = {workspace = true}
 daft-local-execution = {workspace = true}
 daft-local-plan = {workspace = true}
diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs
index 95cd9ce75a..882b2af1af 100644
--- a/src/daft-connect/src/lib.rs
+++ b/src/daft-connect/src/lib.rs
@@ -37,6 +37,7 @@ pub mod util;
 #[cfg_attr(feature = "python", pyo3::pyclass)]
 pub struct ConnectionHandle {
     shutdown_signal: Option<tokio::sync::oneshot::Sender<()>>,
+    port: u16,
 }
 
 #[cfg_attr(feature = "python", pyo3::pymethods)]
@@ -47,12 +48,19 @@ impl ConnectionHandle {
         };
         shutdown_signal.send(()).unwrap();
     }
+
+    pub fn port(&self) -> u16 {
+        self.port
+    }
 }
 
 pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {
     info!("Daft-Connect server listening on {addr}");
     let addr = util::parse_spark_connect_address(addr)?;
 
+    let listener = std::net::TcpListener::bind(addr)?;
+    let port = listener.local_addr()?.port();
+
     let service = DaftSparkConnectService::default();
 
     info!("Daft-Connect server listening on {addr}");
@@ -61,25 +69,40 @@ pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {
 
     let handle = ConnectionHandle {
         shutdown_signal: Some(shutdown_signal),
+        port,
     };
 
     std::thread::spawn(move || {
         let runtime = tokio::runtime::Runtime::new().unwrap();
-        let result = runtime
-            .block_on(async {
-                tokio::select! {
-                    result = Server::builder()
-                        .add_service(SparkConnectServiceServer::new(service))
-                        .serve(addr) => {
-                        result
-                    }
-                    _ = shutdown_receiver => {
-                        info!("Received shutdown signal");
-                        Ok(())
+        let result = runtime.block_on(async {
+            let incoming = {
+                let listener = tokio::net::TcpListener::from_std(listener)
+                    .wrap_err("Failed to create TcpListener from std::net::TcpListener")?;
+
+                async_stream::stream! {
+                    loop {
+                        match listener.accept().await {
+                            Ok((stream, _)) => yield Ok(stream),
+                            Err(e) => yield Err(e),
+                        }
                     }
                 }
-            })
-            .wrap_err_with(|| format!("Failed to start server on {addr}"));
+            };
+
+            let result = tokio::select! {
+                result = Server::builder()
+                    .add_service(SparkConnectServiceServer::new(service))
+                    .serve_with_incoming(incoming)=> {
+                    result
+                }
+                _ = shutdown_receiver => {
+                    info!("Received shutdown signal");
+                    Ok(())
+                }
+            };
+
+            result.wrap_err_with(|| format!("Failed to start server on {addr}"))
+        });
 
         if let Err(e) = result {
             eprintln!("Daft-Connect server error: {e:?}");
@@ -363,7 +386,7 @@ impl SparkConnectService for DaftSparkConnectService {
 
 #[cfg(feature = "python")]
 #[pyo3::pyfunction]
-#[pyo3(name = "connect_start")]
+#[pyo3(name = "connect_start", signature = (addr = "sc://0.0.0.0:0"))]
 pub fn py_connect_start(addr: &str) -> pyo3::PyResult<ConnectionHandle> {
     start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}")))
 }
diff --git a/tests/connect/__init__.py b/tests/connect/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/connect/conftest.py b/tests/connect/conftest.py
new file mode 100644
index 0000000000..60c5ae9986
--- /dev/null
+++ b/tests/connect/conftest.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+import pytest
+from pyspark.sql import SparkSession
+
+
+@pytest.fixture(scope="session")
+def spark_session():
+    """
+    Fixture to create and clean up a Spark session.
+
+    This fixture is available to all test files and creates a single
+    Spark session for the entire test suite run.
+ """ + from daft.daft import connect_start + + # Start Daft Connect server + server = connect_start() + + url = f"sc://localhost:{server.port()}" + + # Initialize Spark Connect session + session = SparkSession.builder.appName("DaftConfigTest").remote(url).getOrCreate() + + yield session + + # Cleanup + server.shutdown() + session.stop() diff --git a/tests/connect/test_config_simple.py b/tests/connect/test_config_simple.py index de65c7c0f2..9a472f24e2 100644 --- a/tests/connect/test_config_simple.py +++ b/tests/connect/test_config_simple.py @@ -1,29 +1,5 @@ from __future__ import annotations -import time - -import pytest -from pyspark.sql import SparkSession - - -@pytest.fixture -def spark_session(): - """Fixture to create and clean up a Spark session.""" - from daft.daft import connect_start - - # Start Daft Connect server - server = connect_start("sc://localhost:50051") - - # Initialize Spark Connect session - session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate() - - yield session - - # Cleanup - server.shutdown() - session.stop() - time.sleep(2) # Allow time for session cleanup - def test_set_operation(spark_session): """Test the Set operation with various data types and edge cases.""" diff --git a/tests/connect/test_range_simple.py b/tests/connect/test_range_simple.py index 86f348470e..b277d38481 100644 --- a/tests/connect/test_range_simple.py +++ b/tests/connect/test_range_simple.py @@ -1,26 +1,5 @@ from __future__ import annotations -import pytest -from pyspark.sql import SparkSession - - -@pytest.fixture -def spark_session(): - """Fixture to create and clean up a Spark session.""" - from daft.daft import connect_start - - # Start Daft Connect server - server = connect_start("sc://localhost:50051") - - # Initialize Spark Connect session - session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate() - - yield session - - # Cleanup - server.shutdown() - session.stop() - def test_range_operation(spark_session): # Create a range using Spark