From 4f75bce011ec48ad3934a6641d2aafba64e10e81 Mon Sep 17 00:00:00 2001 From: Will Crichton Date: Wed, 21 Aug 2024 15:59:10 -0700 Subject: [PATCH] Cancellation solution --- crates/server/src/main.rs | 70 +++++++++++++++++++++++++++++---------- 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs index 836b05e..7d94e8e 100644 --- a/crates/server/src/main.rs +++ b/crates/server/src/main.rs @@ -17,10 +17,17 @@ async fn index(_req: Request) -> Response { } #[derive(Serialize, Deserialize)] -struct Messages { +struct MessagesRequest { messages: Vec, } +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "type")] +enum MessagesResponse { + Success { messages: Vec }, + Cancelled, +} + async fn load_docs(paths: Vec) -> Vec { let mut doc_futs = paths .into_iter() @@ -33,46 +40,74 @@ async fn load_docs(paths: Vec) -> Vec { docs } -type Payload = (Arc>, oneshot::Sender>); +type Payload = (Arc>, oneshot::Sender>>); -fn chatbot_thread() -> mpsc::Sender { - let (tx, mut rx) = mpsc::channel::(1024); +fn chatbot_thread() -> (mpsc::Sender, mpsc::Sender<()>) { + let (req_tx, mut req_rx) = mpsc::channel::(1024); + let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); tokio::spawn(async move { let mut chatbot = chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]); - while let Some((messages, responder)) = rx.recv().await { + while let Some((messages, responder)) = req_rx.recv().await { let doc_paths = chatbot.retrieval_documents(&messages); let docs = load_docs(doc_paths).await; - let response = chatbot.query_chat(&messages, &docs).await; - responder.send(response).unwrap(); + let chat_fut = chatbot.query_chat(&messages, &docs); + let cancel_fut = cancel_rx.recv(); + tokio::select! { + response = chat_fut => { + responder.send(Some(response)).unwrap(); + } + _ = cancel_fut => { + responder.send(None).unwrap(); + } + } } }); - tx + (req_tx, cancel_tx) } -async fn query_chat(messages: &Arc>) -> Vec { - static SENDER: LazyLock> = LazyLock::new(chatbot_thread); +static CHATBOT_THREAD: LazyLock<(mpsc::Sender, mpsc::Sender<()>)> = + LazyLock::new(chatbot_thread); +async fn query_chat(messages: &Arc>) -> Option> { let (tx, rx) = oneshot::channel(); - SENDER.send((Arc::clone(messages), tx)).await.unwrap(); + CHATBOT_THREAD + .0 + .send((Arc::clone(messages), tx)) + .await + .unwrap(); rx.await.unwrap() } +async fn cancel(_req: Request) -> Response { + CHATBOT_THREAD.1.send(()).await.unwrap(); + Ok(Content::Html("success".into())) +} + async fn chat(req: Request) -> Response { let Request::Post(body) = req else { return Err(StatusCode::METHOD_NOT_ALLOWED); }; - let Ok(mut data) = serde_json::from_str::(&body) else { + let Ok(mut data) = serde_json::from_str::(&body) else { return Err(StatusCode::INTERNAL_SERVER_ERROR); }; let messages = Arc::new(data.messages); - let (i, mut responses) = join!(chatbot::gen_random_number(), query_chat(&messages)); + let (i, responses_opt) = join!(chatbot::gen_random_number(), query_chat(&messages)); - let response = responses.remove(i % responses.len()); - data.messages = Arc::into_inner(messages).unwrap(); - data.messages.push(response); + let response = match responses_opt { + Some(mut responses) => { + let response = responses.remove(i % responses.len()); + data.messages = Arc::into_inner(messages).unwrap(); + data.messages.push(response); - Ok(Content::Json(serde_json::to_string(&data).unwrap())) + MessagesResponse::Success { + messages: data.messages, + } + } + None => MessagesResponse::Cancelled, + }; + + Ok(Content::Json(serde_json::to_string(&response).unwrap())) } #[tokio::main] @@ -80,6 +115,7 @@ async fn main() { miniserve::Server::new() .route("/", index) .route("/chat", chat) + .route("/cancel", cancel) .run() .await }