diff --git a/packages/cubejs-api-gateway/src/sql-server.ts b/packages/cubejs-api-gateway/src/sql-server.ts index 303ef5760691a..a1f9178ec8273 100644 --- a/packages/cubejs-api-gateway/src/sql-server.ts +++ b/packages/cubejs-api-gateway/src/sql-server.ts @@ -1,6 +1,7 @@ import { setupLogger, registerInterface, + shutdownInterface, execSql, SqlInterfaceInstance, Request as NativeRequest, @@ -330,4 +331,8 @@ export class SQLServer { public async close(): Promise { // @todo Implement } + + public async shutdown(): Promise { + await shutdownInterface(this.sqlInterfaceInstance!); + } } diff --git a/packages/cubejs-backend-native/js/index.ts b/packages/cubejs-backend-native/js/index.ts index e15f499b3baac..073e7489be199 100644 --- a/packages/cubejs-backend-native/js/index.ts +++ b/packages/cubejs-backend-native/js/index.ts @@ -332,8 +332,6 @@ export const shutdownInterface = async (instance: SqlInterfaceInstance): Promise const native = loadNative(); await native.shutdownInterface(instance); - - await new Promise((resolve) => setTimeout(resolve, 2000)); }; export const execSql = async (instance: SqlInterfaceInstance, sqlQuery: string, stream: any, securityContext?: any): Promise => { diff --git a/packages/cubejs-backend-native/src/node_export.rs b/packages/cubejs-backend-native/src/node_export.rs index 869b73c475b2e..90732cf1932f1 100644 --- a/packages/cubejs-backend-native/src/node_export.rs +++ b/packages/cubejs-backend-native/src/node_export.rs @@ -123,8 +123,10 @@ fn register_interface(mut cx: FunctionContext) -> JsResult { Ok(()) })); - - CubeServices::wait_loops(loops).await.unwrap(); + { + let mut w = services.processing_loop_handles.write().await; + *w = loops; + } }); }); @@ -140,13 +142,32 @@ fn shutdown_interface(mut cx: FunctionContext) -> JsResult { let services = interface.services.clone(); let runtime = tokio_runtime_node(&mut cx)?; - runtime.block_on(async move { - let _ = services - .stop_processing_loops() - .await - .or_else(|err| cx.throw_error(err.to_string())); + runtime.spawn(async move { + match services.stop_processing_loops().await { + Ok(_) => { + let mut handles = Vec::new(); + { + let mut w = services.processing_loop_handles.write().await; + std::mem::swap(&mut *w, &mut handles); + } + for h in handles { + let _ = h.await; + } + + deferred + .settle_with(&channel, move |mut cx| Ok(cx.undefined())) + .await + .unwrap(); + } + Err(err) => { + channel.send(move |mut cx| { + let err = JsError::error(&mut cx, err.to_string()).unwrap(); + deferred.reject(&mut cx, err); + Ok(()) + }); + } + }; }); - deferred.settle_with(&channel, move |mut cx| Ok(cx.undefined())); Ok(promise) } diff --git a/packages/cubejs-server/src/server.ts b/packages/cubejs-server/src/server.ts index fdf02f809c4b2..1ced01e2b59d4 100644 --- a/packages/cubejs-server/src/server.ts +++ b/packages/cubejs-server/src/server.ts @@ -229,6 +229,12 @@ export class CubejsServer { ); } + if (this.sqlServer) { + locks.push( + this.sqlServer.shutdown() + ); + } + if (this.server) { locks.push( this.server.stop( @@ -237,13 +243,19 @@ export class CubejsServer { ); } - if (graceful) { - // Await before all connections/refresh scheduler will end jobs - await Promise.all(locks); - } + const shutdownAll = async () => { + try { + if (graceful) { + // Await before all connections/refresh scheduler will end jobs + await Promise.all(locks); + } + await this.core.shutdown(); + } finally { + timeoutKiller.cancel(); + } + }; - await this.core.shutdown(); - await timeoutKiller.cancel(); + await Promise.any([shutdownAll(), timeoutKiller]); return 0; } catch (e: any) { diff --git a/rust/cubesql/cubesql/src/bin/cubesqld.rs b/rust/cubesql/cubesql/src/bin/cubesqld.rs index 05ca50853412d..385dad0cd49ed 100644 --- a/rust/cubesql/cubesql/src/bin/cubesqld.rs +++ b/rust/cubesql/cubesql/src/bin/cubesqld.rs @@ -5,7 +5,7 @@ use cubesql::{ use log::Level; use simple_logger::SimpleLogger; -use std::env; +use std::{env, sync::Arc}; use tokio::runtime::Builder; @@ -39,14 +39,14 @@ fn main() { let runtime = Builder::new_multi_thread().enable_all().build().unwrap(); runtime.block_on(async move { config.configure().await; - let services = config.cube_services().await; + let services = Arc::new(config.cube_services().await); log::debug!("Cube SQL Start"); stop_on_ctrl_c(&services).await; services.wait_processing_loops().await.unwrap(); }); } -async fn stop_on_ctrl_c(s: &CubeServices) { +async fn stop_on_ctrl_c(s: &Arc) { let s = s.clone(); tokio::spawn(async move { let mut counter = 0; diff --git a/rust/cubesql/cubesql/src/config/mod.rs b/rust/cubesql/cubesql/src/config/mod.rs index 0dbc2e9e7b1c0..a8b4d2a3bd15c 100644 --- a/rust/cubesql/cubesql/src/config/mod.rs +++ b/rust/cubesql/cubesql/src/config/mod.rs @@ -22,24 +22,14 @@ use std::{ use std::sync::Arc; use crate::sql::compiler_cache::{CompilerCache, CompilerCacheImpl}; -use tokio::task::JoinHandle; +use tokio::{sync::RwLock, task::JoinHandle}; -#[derive(Clone)] pub struct CubeServices { pub injector: Arc, + pub processing_loop_handles: RwLock>, } impl CubeServices { - pub async fn start_processing_loops(&self) -> Result<(), CubeError> { - let futures = self.spawn_processing_loops().await?; - tokio::spawn(async move { - if let Err(e) = Self::wait_loops(futures).await { - error!("Error in processing loop: {}", e); - } - }); - Ok(()) - } - pub async fn wait_processing_loops(&self) -> Result<(), CubeError> { let processing_loops = self.spawn_processing_loops().await?; Self::wait_loops(processing_loops).await @@ -57,9 +47,9 @@ impl CubeServices { let mut futures = Vec::new(); if self.injector.has_service_typed::().await { - let mysql_server = self.injector.get_service_typed::().await; + let postgres_server = self.injector.get_service_typed::().await; futures.push(tokio::spawn(async move { - if let Err(e) = mysql_server.processing_loop().await { + if let Err(e) = postgres_server.processing_loop().await { error!("{}", e.to_string()); }; @@ -347,6 +337,7 @@ impl Config { pub async fn cube_services(&self) -> CubeServices { CubeServices { injector: self.injector.clone(), + processing_loop_handles: RwLock::new(Vec::new()), } } diff --git a/rust/cubesql/cubesql/src/sql/postgres/service.rs b/rust/cubesql/cubesql/src/sql/postgres/service.rs index a75d3bd731961..a0b2489d08a0b 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/service.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/service.rs @@ -3,8 +3,9 @@ use log::{error, trace}; use std::sync::Arc; use tokio::{ net::TcpListener, - sync::{oneshot, watch, RwLock}, + sync::{watch, RwLock}, }; +use tokio_util::sync::CancellationToken; use crate::{ config::processing_loop::ProcessingLoop, @@ -33,6 +34,10 @@ impl ProcessingLoop for PostgresServer { println!("🔗 Cube SQL (pg) is listening on {}", self.address); + let shim_cancellation_token = CancellationToken::new(); + + let mut joinset = tokio::task::JoinSet::new(); + loop { let mut stop_receiver = self.close_socket_rx.write().await; let (socket, _) = tokio::select! { @@ -40,11 +45,17 @@ impl ProcessingLoop for PostgresServer { if res.is_err() || *stop_receiver.borrow() { trace!("[pg] Stopping processing_loop via channel"); - return Ok(()); + shim_cancellation_token.cancel(); + break; } else { continue; } } + Some(_) = joinset.join_next() => { + // We do nothing here; whatever is here needs to be in the join_next() cleanup + // after the loop. + continue; + } accept_res = listener.accept() => { match accept_res { Ok(res) => res, @@ -73,20 +84,17 @@ impl ProcessingLoop for PostgresServer { trace!("[pg] New connection {}", session.state.connection_id); - let (mut tx, rx) = oneshot::channel::<()>(); - let connection_id = session.state.connection_id; let session_manager = self.session_manager.clone(); - tokio::spawn(async move { - tx.closed().await; - - trace!("[pg] Removing connection {}", connection_id); - - session_manager.drop_session(connection_id).await; - }); - tokio::spawn(async move { - let handler = AsyncPostgresShim::run_on(socket, session.clone(), logger.clone()); + let connection_interruptor = shim_cancellation_token.clone(); + let join_handle: tokio::task::JoinHandle<()> = tokio::spawn(async move { + let handler = AsyncPostgresShim::run_on( + connection_interruptor, + socket, + session.clone(), + logger.clone(), + ); if let Err(e) = handler.await { logger.error( format!("Error during processing PostgreSQL connection: {}", e).as_str(), @@ -99,11 +107,27 @@ impl ProcessingLoop for PostgresServer { trace!("Backtrace: not found"); } }; + }); + + // We use a separate task because `handler` above, the result of + // `AsyncPostgresShim::run_on,` can panic, which we want to catch. (And which the + // JoinHandle catches.) + joinset.spawn(async move { + let _ = join_handle.await; - // Handler can finish with panic, it's why we are using additional channel to drop session by moving it here - std::mem::drop(rx); + trace!("[pg] Removing connection {}", connection_id); + + session_manager.drop_session(connection_id).await; }); } + + // Now that we've had the stop signal, wait for outstanding connection tasks to finish + // cleanly. + while let Some(_) = joinset.join_next().await { + // We do nothing here, same as the join_next() handler in the loop. + } + + Ok(()) } async fn stop_processing(&self) -> Result<(), CubeError> { diff --git a/rust/cubesql/cubesql/src/sql/postgres/shim.rs b/rust/cubesql/cubesql/src/sql/postgres/shim.rs index 609ed73aeefbb..aeb507090450d 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/shim.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/shim.rs @@ -37,6 +37,8 @@ use uuid::Uuid; pub struct AsyncPostgresShim { socket: TcpStream, + // If empty, this means socket is on a message boundary. + partial_write_buf: bytes::BytesMut, // Extended query cursors: HashMap, portals: HashMap, @@ -225,19 +227,34 @@ impl From for ConnectionError { impl AsyncPostgresShim { pub async fn run_on( + shutdown_interruptor: CancellationToken, socket: TcpStream, session: Arc, logger: Arc, ) -> Result<(), ConnectionError> { let mut shim = Self { socket, + partial_write_buf: bytes::BytesMut::new(), cursors: HashMap::new(), portals: HashMap::new(), session, logger, }; - match shim.run().await { + let run_result = tokio::select! { + _ = shutdown_interruptor.cancelled() => { + // We flush the partially written buf and add the fatal message -- it's another + // place's responsibility to impose a timeout and abort us. + shim.socket.write_all_buf(&mut shim.partial_write_buf).await?; + shim.partial_write_buf = bytes::BytesMut::new(); + shim.write_admin_shutdown_fatal_message().await?; + shim.socket.shutdown().await?; + return Ok(()); + } + res = shim.run() => res, + }; + + match run_result { Err(e) => { if let ConnectionError::Protocol(ProtocolError::IO { source, .. }, _) = &e { if source.kind() == ErrorKind::BrokenPipe @@ -250,6 +267,7 @@ impl AsyncPostgresShim { } else if let ConnectionError::CompilationError(CompilationError::Fatal(_, _), _) = &e { + assert!(shim.partial_write_buf.is_empty()); shim.write(e.to_error_response()).await?; shim.socket.shutdown().await?; return Ok(()); @@ -264,6 +282,16 @@ impl AsyncPostgresShim { } } + fn admin_shutdown_error() -> ConnectionError { + ConnectionError::Protocol( + ProtocolError::ErrorResponse { + source: ErrorResponse::admin_shutdown(), + backtrace: Backtrace::disabled(), + }, + None, + ) + } + pub async fn run(&mut self) -> Result<(), ConnectionError> { let initial_parameters = match self.process_initial_message().await? { StartupState::Success(parameters) => parameters, @@ -536,7 +564,7 @@ impl AsyncPostgresShim { &mut self, message: Vec, ) -> Result<(), ConnectionError> { - buffer::write_messages(&mut self.socket, message).await?; + buffer::write_messages(&mut self.partial_write_buf, &mut self.socket, message).await?; Ok(()) } @@ -546,8 +574,12 @@ impl AsyncPostgresShim { completion: PortalCompletion, ) -> Result<(), ConnectionError> { match completion { - PortalCompletion::Complete(c) => buffer::write_message(&mut self.socket, c).await?, - PortalCompletion::Suspended(s) => buffer::write_message(&mut self.socket, s).await?, + PortalCompletion::Complete(c) => { + buffer::write_message(&mut self.partial_write_buf, &mut self.socket, c).await? + } + PortalCompletion::Suspended(s) => { + buffer::write_message(&mut self.partial_write_buf, &mut self.socket, s).await? + } } Ok(()) @@ -557,7 +589,18 @@ impl AsyncPostgresShim { &mut self, message: Message, ) -> Result<(), ConnectionError> { - buffer::write_message(&mut self.socket, message).await?; + buffer::write_message(&mut self.partial_write_buf, &mut self.socket, message).await?; + + Ok(()) + } + + pub async fn write_admin_shutdown_fatal_message(&mut self) -> Result<(), ConnectionError> { + buffer::write_message( + &mut bytes::BytesMut::new(), + &mut self.socket, + Self::admin_shutdown_error().to_error_response(), + ) + .await?; Ok(()) } @@ -617,7 +660,12 @@ impl AsyncPostgresShim { startup_message.major, startup_message.minor, ), ); - buffer::write_message(&mut self.socket, error_response).await?; + buffer::write_message( + &mut self.partial_write_buf, + &mut self.socket, + error_response, + ) + .await?; return Ok(StartupState::Denied); } @@ -628,7 +676,12 @@ impl AsyncPostgresShim { protocol::ErrorCode::InvalidAuthorizationSpecification, "no PostgreSQL user name specified in startup packet".to_string(), ); - buffer::write_message(&mut self.socket, error_response).await?; + buffer::write_message( + &mut self.partial_write_buf, + &mut self.socket, + error_response, + ) + .await?; return Ok(StartupState::Denied); } @@ -675,7 +728,12 @@ impl AsyncPostgresShim { protocol::ErrorCode::InvalidPassword, format!("password authentication failed for user \"{}\"", &user), ); - buffer::write_message(&mut self.socket, error_response).await?; + buffer::write_message( + &mut self.partial_write_buf, + &mut self.socket, + error_response, + ) + .await?; return Ok(false); } @@ -875,14 +933,14 @@ impl AsyncPostgresShim { } match chunk { - PortalBatch::Rows(writer) if writer.has_data() => buffer::write_direct(&mut self.socket, writer).await?, + PortalBatch::Rows(writer) if writer.has_data() => buffer::write_direct(&mut self.partial_write_buf, &mut self.socket, writer).await?, PortalBatch::Completion(completion) => { self.session.state.end_query(); // TODO: match completion { - PortalCompletion::Complete(c) => buffer::write_message(&mut self.socket, c).await?, - PortalCompletion::Suspended(s) => buffer::write_message(&mut self.socket, s).await?, + PortalCompletion::Complete(c) => buffer::write_message(&mut self.partial_write_buf, &mut self.socket, c).await?, + PortalCompletion::Suspended(s) => buffer::write_message(&mut self.partial_write_buf, &mut self.socket, s).await?, } return Ok(()); @@ -1609,7 +1667,7 @@ impl AsyncPostgresShim { }, PortalBatch::Rows(writer) => { if writer.has_data() { - buffer::write_direct(&mut self.socket, writer).await? + buffer::write_direct(&mut self.partial_write_buf, &mut self.socket, writer).await? } } PortalBatch::Completion(completion) => return self.write_completion(completion).await, diff --git a/rust/cubesql/cubesql/src/sql/postgres/writer.rs b/rust/cubesql/cubesql/src/sql/postgres/writer.rs index 62c3860abaa7a..c90687bdebbad 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/writer.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/writer.rs @@ -394,7 +394,7 @@ mod tests { writer.write_value(true)?; writer.end_row()?; - buffer::write_direct(&mut cursor, writer).await?; + buffer::write_direct(&mut BytesMut::new(), &mut cursor, writer).await?; assert_eq!( cursor.get_ref()[0..], @@ -422,7 +422,7 @@ mod tests { writer.write_value(true)?; writer.end_row()?; - buffer::write_direct(&mut cursor, writer).await?; + buffer::write_direct(&mut BytesMut::new(), &mut cursor, writer).await?; assert_eq!( cursor.get_ref()[0..], @@ -450,7 +450,7 @@ mod tests { writer.write_value(Decimal128Value::new(2, 15))?; writer.end_row()?; - buffer::write_direct(&mut cursor, writer).await?; + buffer::write_direct(&mut BytesMut::new(), &mut cursor, writer).await?; assert_eq!( cursor.get_ref()[0..], @@ -488,7 +488,7 @@ mod tests { writer.write_value(ListValue::new(Arc::new(col.finish()) as ArrayRef))?; writer.end_row()?; - buffer::write_direct(&mut cursor, writer).await?; + buffer::write_direct(&mut BytesMut::new(), &mut cursor, writer).await?; assert_eq!( cursor.get_ref()[0..], diff --git a/rust/cubesql/pg-srv/src/buffer.rs b/rust/cubesql/pg-srv/src/buffer.rs index ab81d9cb10ccb..27988797e0d21 100644 --- a/rust/cubesql/pg-srv/src/buffer.rs +++ b/rust/cubesql/pg-srv/src/buffer.rs @@ -131,12 +131,18 @@ pub async fn read_format( /// Same as the write_message function, but it doesn’t append header for frame (code + size). pub async fn write_direct( + partial_write: &mut BytesMut, writer: &mut Writer, message: Message, ) -> Result<(), ProtocolError> { + let mut bytes_mut = BytesMut::new(); match message.serialize() { Some(buffer) => { - writer.write_all(&buffer).await?; + // TODO: Yet another memory copy. + bytes_mut.extend_from_slice(&buffer); + *partial_write = bytes_mut; + writer.write_all_buf(partial_write).await?; + *partial_write = BytesMut::new(); writer.flush().await?; } _ => {} @@ -170,8 +176,11 @@ fn message_serialize( Ok(()) } -/// Write multiple F messages with frame's headers to the writer. +/// Write multiple F messages with frame's headers to the writer. The variable +/// `*partial_write` is set for graceful shutdown attempts with partial writes. +/// Upon a successful write, it is left empty. pub async fn write_messages( + partial_write: &mut BytesMut, writer: &mut Writer, messages: Vec, ) -> Result<(), ProtocolError> { @@ -181,20 +190,34 @@ pub async fn write_messages( message_serialize(message, &mut buffer)?; } - writer.write_all(&buffer).await?; + // For simplicity we obviously don't save message boundary data with + // `*partial_write`, which means that a AdminShutdown fatal error message + // would have to be written after _all_ these messages. + *partial_write = buffer; + writer.write_all_buf(partial_write).await?; + *partial_write = BytesMut::new(); + + // (We _could_ reuse the buffer in *partial_write, doing fewer allocations -- after + // making other serialization logic allocate less and thinking about memory usage.) + writer.flush().await?; Ok(()) } -/// Write single F message with frame's headers to the writer. +/// Write single F message with frame's headers to the writer. As with the +/// function `write_messages`, `*partial_write` is set for graceful shutdown +/// attempts with partial writes. Upon a successful write, it is left empty. pub async fn write_message( + partial_write: &mut BytesMut, writer: &mut Writer, message: Message, ) -> Result<(), ProtocolError> { let mut buffer = BytesMut::with_capacity(64); message_serialize(message, &mut buffer)?; - writer.write_all(&buffer).await?; + *partial_write = buffer; + writer.write_all_buf(partial_write).await?; + *partial_write = BytesMut::new(); writer.flush().await?; Ok(()) } diff --git a/rust/cubesql/pg-srv/src/protocol.rs b/rust/cubesql/pg-srv/src/protocol.rs index 9928ff6d1bc67..257acf4abb7f1 100644 --- a/rust/cubesql/pg-srv/src/protocol.rs +++ b/rust/cubesql/pg-srv/src/protocol.rs @@ -217,6 +217,14 @@ impl ErrorResponse { message: "canceling statement due to user request".to_string(), } } + + pub fn admin_shutdown() -> Self { + Self { + severity: ErrorSeverity::Fatal, + code: ErrorCode::AdminShutdown, + message: "terminating connection due to shutdown signal".to_string(), + } + } } impl Serialize for ErrorResponse { @@ -958,6 +966,7 @@ pub enum ErrorCode { ObjectNotInPrerequisiteState, // Class 57 - Operator Intervention QueryCanceled, + AdminShutdown, // XX - Internal Error InternalError, } @@ -979,6 +988,7 @@ impl Display for ErrorCode { Self::ConfigurationLimitExceeded => "53400", Self::ObjectNotInPrerequisiteState => "55000", Self::QueryCanceled => "57014", + Self::AdminShutdown => "57P01", Self::InternalError => "XX000", }; write!(f, "{}", string) @@ -1123,7 +1133,12 @@ mod tests { // First step, We write struct to the buffer let mut cursor = Cursor::new(vec![]); - buffer::write_message(&mut cursor, expected_message.clone()).await?; + buffer::write_message( + &mut bytes::BytesMut::new(), + &mut cursor, + expected_message.clone(), + ) + .await?; // Second step, We read form the buffer and output structure must be the same as original let buffer = cursor.get_ref()[..].to_vec(); @@ -1348,7 +1363,7 @@ mod tests { async fn test_frontend_message_write_complete_parse() -> Result<(), ProtocolError> { let mut cursor = Cursor::new(vec![]); - buffer::write_message(&mut cursor, ParseComplete {}).await?; + buffer::write_message(&mut bytes::BytesMut::new(), &mut cursor, ParseComplete {}).await?; assert_eq!(cursor.get_ref()[0..], vec![49, 0, 0, 0, 4]); @@ -1375,7 +1390,7 @@ mod tests { Format::Text, ), ]); - buffer::write_message(&mut cursor, desc).await?; + buffer::write_message(&mut bytes::BytesMut::new(), &mut cursor, desc).await?; assert_eq!( cursor.get_ref()[0..],