Skip to content

Commit

Permalink
[FEAT] consolidate Spark session fixture into conftest.py (#3341)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka authored Nov 20, 2024
1 parent ec24c80 commit bdfb8c6
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 60 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1208,10 +1208,11 @@ def sql_expr(sql: str) -> PyExpr: ...
def list_sql_functions() -> list[SQLFunctionStub]: ...
def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ...
def to_struct(inputs: list[PyExpr]) -> PyExpr: ...
def connect_start(addr: str) -> ConnectionHandle: ...
def connect_start(addr: str = "sc://0.0.0.0:0") -> ConnectionHandle: ...

class ConnectionHandle:
def shutdown(self) -> None: ...
def port(self) -> int: ...

# expr numeric ops
def abs(expr: PyExpr) -> PyExpr: ...
Expand Down
1 change: 1 addition & 0 deletions src/daft-connect/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[dependencies]
arrow2 = {workspace = true}
async-stream = "0.3.6"
common-daft-config = {workspace = true}
daft-local-execution = {workspace = true}
daft-local-plan = {workspace = true}
Expand Down
51 changes: 37 additions & 14 deletions src/daft-connect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub mod util;
/// Handle returned by `start` for a running Daft-Connect server.
///
/// It carries the means to stop the server and the port the server's
/// listener was bound to.
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct ConnectionHandle {
/// Sending half of the one-shot channel whose receiving half is awaited by
/// the server thread; sending `()` on it makes the server stop.
/// NOTE(review): presumably wrapped in `Option` so `shutdown` can take it
/// exactly once — confirm against the full `shutdown` body.
shutdown_signal: Option<tokio::sync::oneshot::Sender<()>>,
/// Port taken from the bound listener's local address. This is the value
/// to use when the caller asked for port 0 and cannot know the port up front.
port: u16,
}

#[cfg_attr(feature = "python", pyo3::pymethods)]
Expand All @@ -47,12 +48,19 @@ impl ConnectionHandle {
};
shutdown_signal.send(()).unwrap();
}

/// Returns the port recorded in this handle when the server was started.
///
/// `start` fills this from the bound listener's local address, so it is the
/// actual listening port even when the requested address used port 0.
pub fn port(&self) -> u16 {
self.port
}
}

pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {
info!("Daft-Connect server listening on {addr}");
let addr = util::parse_spark_connect_address(addr)?;

let listener = std::net::TcpListener::bind(addr)?;
let port = listener.local_addr()?.port();

let service = DaftSparkConnectService::default();

info!("Daft-Connect server listening on {addr}");
Expand All @@ -61,25 +69,40 @@ pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {

let handle = ConnectionHandle {
shutdown_signal: Some(shutdown_signal),
port,
};

std::thread::spawn(move || {
let runtime = tokio::runtime::Runtime::new().unwrap();
let result = runtime
.block_on(async {
tokio::select! {
result = Server::builder()
.add_service(SparkConnectServiceServer::new(service))
.serve(addr) => {
result
}
_ = shutdown_receiver => {
info!("Received shutdown signal");
Ok(())
let result = runtime.block_on(async {
let incoming = {
let listener = tokio::net::TcpListener::from_std(listener)
.wrap_err("Failed to create TcpListener from std::net::TcpListener")?;

async_stream::stream! {
loop {
match listener.accept().await {
Ok((stream, _)) => yield Ok(stream),
Err(e) => yield Err(e),
}
}
}
})
.wrap_err_with(|| format!("Failed to start server on {addr}"));
};

let result = tokio::select! {
result = Server::builder()
.add_service(SparkConnectServiceServer::new(service))
.serve_with_incoming(incoming)=> {
result
}
_ = shutdown_receiver => {
info!("Received shutdown signal");
Ok(())
}
};

result.wrap_err_with(|| format!("Failed to start server on {addr}"))
});

if let Err(e) = result {
eprintln!("Daft-Connect server error: {e:?}");
Expand Down Expand Up @@ -363,7 +386,7 @@ impl SparkConnectService for DaftSparkConnectService {

/// Python binding for `start`, exported to Python under the name `connect_start`.
///
/// Only compiled when the `python` feature is enabled. Any error returned by
/// `start` is rendered with its `Debug` formatting and raised on the Python
/// side as a `RuntimeError`.
#[cfg(feature = "python")]
#[pyo3::pyfunction]
#[pyo3(name = "connect_start")]
#[pyo3(name = "connect_start", signature = (addr = "sc://0.0.0.0:0"))]
pub fn py_connect_start(addr: &str) -> pyo3::PyResult<ConnectionHandle> {
start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}")))
}
Expand Down
Empty file removed tests/connect/__init__.py
Empty file.
29 changes: 29 additions & 0 deletions tests/connect/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark_session():
    """
    Fixture to create and clean up a Spark session.

    This fixture is available to all test files and creates a single
    Spark session for the entire test suite run.

    A Daft Connect server is started with ``connect_start()`` (no explicit
    address), and the Spark Connect client is pointed at whichever port the
    server reports through ``server.port()``.

    Yields:
        The ``SparkSession`` connected to the Daft Connect server.
    """
    from daft.daft import connect_start

    # Start Daft Connect server
    server = connect_start()

    # Stays None if session creation fails, so teardown knows what to release.
    session = None
    try:
        url = f"sc://localhost:{server.port()}"

        # Initialize Spark Connect session
        session = SparkSession.builder.appName("DaftConfigTest").remote(url).getOrCreate()

        yield session
    finally:
        # Cleanup. The server is shut down even when the session could not be
        # created; otherwise its thread and port would leak for the rest of
        # the test run.
        server.shutdown()
        if session is not None:
            session.stop()
24 changes: 0 additions & 24 deletions tests/connect/test_config_simple.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,5 @@
from __future__ import annotations

import time

import pytest
from pyspark.sql import SparkSession


@pytest.fixture
def spark_session():
    """Provide a Spark Connect session backed by a freshly started Daft Connect server.

    The server and the client both use the fixed address
    ``sc://localhost:50051``. After the test, the server is shut down, the
    session is stopped, and the fixture waits two seconds before returning.
    """
    from daft.daft import connect_start

    address = "sc://localhost:50051"

    # The server has to be running before the client tries to connect.
    daft_server = connect_start(address)
    builder = SparkSession.builder.appName("DaftConfigTest")
    spark = builder.remote(address).getOrCreate()

    yield spark

    # Teardown: server first, then the client session.
    daft_server.shutdown()
    spark.stop()
    time.sleep(2)  # Allow time for session cleanup


def test_set_operation(spark_session):
"""Test the Set operation with various data types and edge cases."""
Expand Down
21 changes: 0 additions & 21 deletions tests/connect/test_range_simple.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,5 @@
from __future__ import annotations

import pytest
from pyspark.sql import SparkSession


@pytest.fixture
def spark_session():
    """Provide a Spark Connect session backed by a freshly started Daft Connect server.

    The server and the client both use the fixed address
    ``sc://localhost:50051``. After the test, the server is shut down and
    then the session is stopped.
    """
    from daft.daft import connect_start

    address = "sc://localhost:50051"

    # The server has to be running before the client tries to connect.
    daft_server = connect_start(address)
    builder = SparkSession.builder.appName("DaftConfigTest")
    spark = builder.remote(address).getOrCreate()

    yield spark

    # Teardown: server first, then the client session.
    daft_server.shutdown()
    spark.stop()


def test_range_operation(spark_session):
# Create a range using Spark
Expand Down

0 comments on commit bdfb8c6

Please sign in to comment.