Skip to content

Commit

Permalink
feat: Make graceful shutdown add fatal messages in postgres
Browse files Browse the repository at this point in the history
Also contains modifications to make the graceful shutdown process wait
for graceful shutdown actions to finish.
  • Loading branch information
srh committed Jun 28, 2024
1 parent 3272593 commit 8fe1af2
Show file tree
Hide file tree
Showing 11 changed files with 219 additions and 72 deletions.
5 changes: 5 additions & 0 deletions packages/cubejs-api-gateway/src/sql-server.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import {
setupLogger,
registerInterface,
shutdownInterface,
execSql,
SqlInterfaceInstance,
Request as NativeRequest,
Expand Down Expand Up @@ -330,4 +331,8 @@ export class SQLServer {
/**
 * Closes the SQL server.
 *
 * Currently a no-op placeholder: the body is empty, so the returned promise
 * resolves immediately without releasing anything.
 */
public async close(): Promise<void> {
// @todo Implement
}

/**
 * Shuts down the native SQL interface by awaiting `shutdownInterface`.
 *
 * Resolves immediately when no interface instance is set, so this method can
 * be called unconditionally during server shutdown without passing an
 * undefined instance to the native layer.
 */
public async shutdown(): Promise<void> {
  // Nothing was registered (or it was never initialised): nothing to stop.
  if (!this.sqlInterfaceInstance) {
    return;
  }

  await shutdownInterface(this.sqlInterfaceInstance);
}
}
2 changes: 0 additions & 2 deletions packages/cubejs-backend-native/js/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,6 @@ export const shutdownInterface = async (instance: SqlInterfaceInstance): Promise
const native = loadNative();

await native.shutdownInterface(instance);

await new Promise((resolve) => setTimeout(resolve, 2000));
};

export const execSql = async (instance: SqlInterfaceInstance, sqlQuery: string, stream: any, securityContext?: any): Promise<void> => {
Expand Down
37 changes: 29 additions & 8 deletions packages/cubejs-backend-native/src/node_export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,10 @@ fn register_interface(mut cx: FunctionContext) -> JsResult<JsPromise> {

Ok(())
}));

CubeServices::wait_loops(loops).await.unwrap();
{
let mut w = services.processing_loop_handles.write().await;
*w = loops;
}
});
});

Expand All @@ -140,13 +142,32 @@ fn shutdown_interface(mut cx: FunctionContext) -> JsResult<JsPromise> {
let services = interface.services.clone();
let runtime = tokio_runtime_node(&mut cx)?;

runtime.block_on(async move {
let _ = services
.stop_processing_loops()
.await
.or_else(|err| cx.throw_error(err.to_string()));
runtime.spawn(async move {
match services.stop_processing_loops().await {
Ok(_) => {
let mut handles = Vec::new();
{
let mut w = services.processing_loop_handles.write().await;
std::mem::swap(&mut *w, &mut handles);
}
for h in handles {
let _ = h.await;
}

deferred
.settle_with(&channel, move |mut cx| Ok(cx.undefined()))
.await
.unwrap();
}
Err(err) => {
channel.send(move |mut cx| {
let err = JsError::error(&mut cx, err.to_string()).unwrap();
deferred.reject(&mut cx, err);
Ok(())
});
}
};
});
deferred.settle_with(&channel, move |mut cx| Ok(cx.undefined()));

Ok(promise)
}
Expand Down
24 changes: 18 additions & 6 deletions packages/cubejs-server/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ export class CubejsServer {
);
}

if (this.sqlServer) {
locks.push(
this.sqlServer.shutdown()
);
}

if (this.server) {
locks.push(
this.server.stop(
Expand All @@ -237,13 +243,19 @@ export class CubejsServer {
);
}

if (graceful) {
// Await before all connections/refresh scheduler will end jobs
await Promise.all(locks);
}
const shutdownAll = async () => {
try {
if (graceful) {
// Wait until all connections and refresh scheduler jobs have finished
await Promise.all(locks);
}
await this.core.shutdown();
} finally {
timeoutKiller.cancel();
}
};

await this.core.shutdown();
await timeoutKiller.cancel();
await Promise.any([shutdownAll(), timeoutKiller]);

return 0;
} catch (e: any) {
Expand Down
6 changes: 3 additions & 3 deletions rust/cubesql/cubesql/src/bin/cubesqld.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use cubesql::{

use log::Level;
use simple_logger::SimpleLogger;
use std::env;
use std::{env, sync::Arc};

use tokio::runtime::Builder;

Expand Down Expand Up @@ -39,14 +39,14 @@ fn main() {
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
runtime.block_on(async move {
config.configure().await;
let services = config.cube_services().await;
let services = Arc::new(config.cube_services().await);
log::debug!("Cube SQL Start");
stop_on_ctrl_c(&services).await;
services.wait_processing_loops().await.unwrap();
});
}

async fn stop_on_ctrl_c(s: &CubeServices) {
async fn stop_on_ctrl_c(s: &Arc<CubeServices>) {
let s = s.clone();
tokio::spawn(async move {
let mut counter = 0;
Expand Down
19 changes: 5 additions & 14 deletions rust/cubesql/cubesql/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,14 @@ use std::{
use std::sync::Arc;

use crate::sql::compiler_cache::{CompilerCache, CompilerCacheImpl};
use tokio::task::JoinHandle;
use tokio::{sync::RwLock, task::JoinHandle};

/// Bundle of the service injector and the handles of the processing loops.
///
/// `Clone` is deliberately not derived: the `RwLock` field does not implement
/// `Clone`. Share a `CubeServices` between tasks by wrapping it in an `Arc`.
pub struct CubeServices {
    /// Injector used to look up the configured services.
    pub injector: Arc<Injector>,
    /// Handles of the processing loops, kept behind a lock so they can be
    /// stored and later taken out through a shared reference.
    pub processing_loop_handles: RwLock<Vec<LoopHandle>>,
}

impl CubeServices {
pub async fn start_processing_loops(&self) -> Result<(), CubeError> {
let futures = self.spawn_processing_loops().await?;
tokio::spawn(async move {
if let Err(e) = Self::wait_loops(futures).await {
error!("Error in processing loop: {}", e);
}
});
Ok(())
}

pub async fn wait_processing_loops(&self) -> Result<(), CubeError> {
let processing_loops = self.spawn_processing_loops().await?;
Self::wait_loops(processing_loops).await
Expand All @@ -57,9 +47,9 @@ impl CubeServices {
let mut futures = Vec::new();

if self.injector.has_service_typed::<PostgresServer>().await {
let mysql_server = self.injector.get_service_typed::<PostgresServer>().await;
let postgres_server = self.injector.get_service_typed::<PostgresServer>().await;
futures.push(tokio::spawn(async move {
if let Err(e) = mysql_server.processing_loop().await {
if let Err(e) = postgres_server.processing_loop().await {
error!("{}", e.to_string());
};

Expand Down Expand Up @@ -347,6 +337,7 @@ impl Config {
/// Builds a `CubeServices` that shares this config's injector and starts
/// with an empty list of processing loop handles.
pub async fn cube_services(&self) -> CubeServices {
    let injector = self.injector.clone();
    let processing_loop_handles = RwLock::new(Vec::new());

    CubeServices {
        injector,
        processing_loop_handles,
    }
}

Expand Down
54 changes: 39 additions & 15 deletions rust/cubesql/cubesql/src/sql/postgres/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use log::{error, trace};
use std::sync::Arc;
use tokio::{
net::TcpListener,
sync::{oneshot, watch, RwLock},
sync::{watch, RwLock},
};
use tokio_util::sync::CancellationToken;

use crate::{
config::processing_loop::ProcessingLoop,
Expand Down Expand Up @@ -33,18 +34,28 @@ impl ProcessingLoop for PostgresServer {

println!("🔗 Cube SQL (pg) is listening on {}", self.address);

let shim_cancellation_token = CancellationToken::new();

let mut joinset = tokio::task::JoinSet::new();

loop {
let mut stop_receiver = self.close_socket_rx.write().await;
let (socket, _) = tokio::select! {
res = stop_receiver.changed() => {
if res.is_err() || *stop_receiver.borrow() {
trace!("[pg] Stopping processing_loop via channel");

return Ok(());
shim_cancellation_token.cancel();
break;
} else {
continue;
}
}
Some(_) = joinset.join_next() => {
// We do nothing here; whatever is here needs to be in the join_next() cleanup
// after the loop.
continue;
}
accept_res = listener.accept() => {
match accept_res {
Ok(res) => res,
Expand Down Expand Up @@ -73,20 +84,17 @@ impl ProcessingLoop for PostgresServer {

trace!("[pg] New connection {}", session.state.connection_id);

let (mut tx, rx) = oneshot::channel::<()>();

let connection_id = session.state.connection_id;
let session_manager = self.session_manager.clone();
tokio::spawn(async move {
tx.closed().await;

trace!("[pg] Removing connection {}", connection_id);

session_manager.drop_session(connection_id).await;
});

tokio::spawn(async move {
let handler = AsyncPostgresShim::run_on(socket, session.clone(), logger.clone());
let connection_interruptor = shim_cancellation_token.clone();
let join_handle: tokio::task::JoinHandle<()> = tokio::spawn(async move {
let handler = AsyncPostgresShim::run_on(
connection_interruptor,
socket,
session.clone(),
logger.clone(),
);
if let Err(e) = handler.await {
logger.error(
format!("Error during processing PostgreSQL connection: {}", e).as_str(),
Expand All @@ -99,11 +107,27 @@ impl ProcessingLoop for PostgresServer {
trace!("Backtrace: not found");
}
};
});

// We use a separate task because `handler` above, the result of
// `AsyncPostgresShim::run_on`, can panic, which we want to catch. (And which the
// JoinHandle catches.)
joinset.spawn(async move {
let _ = join_handle.await;

// Handler can finish with panic, it's why we are using additional channel to drop session by moving it here
std::mem::drop(rx);
trace!("[pg] Removing connection {}", connection_id);

session_manager.drop_session(connection_id).await;
});
}

// Now that we've had the stop signal, wait for outstanding connection tasks to finish
// cleanly.
while let Some(_) = joinset.join_next().await {
// We do nothing here, same as the join_next() handler in the loop.
}

Ok(())
}

async fn stop_processing(&self) -> Result<(), CubeError> {
Expand Down
Loading

0 comments on commit 8fe1af2

Please sign in to comment.