Skip to content

Commit

Permalink
Cancellation solution
Browse files Browse the repository at this point in the history
  • Loading branch information
willcrichton committed Sep 4, 2024
1 parent 1203af6 commit 264df8e
Showing 1 changed file with 53 additions and 17 deletions.
70 changes: 53 additions & 17 deletions crates/server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@ async fn index(_req: Request) -> Response {
}

#[derive(Serialize, Deserialize)]
struct Messages {
struct MessagesRequest {
messages: Vec<String>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
enum MessagesResponse {
Success { messages: Vec<String> },
Cancelled,
}

async fn load_docs(paths: Vec<PathBuf>) -> Vec<String> {
let mut doc_futs = paths
.into_iter()
Expand All @@ -33,53 +40,82 @@ async fn load_docs(paths: Vec<PathBuf>) -> Vec<String> {
docs
}

type Payload = (Arc<Vec<String>>, oneshot::Sender<Vec<String>>);
type Payload = (Arc<Vec<String>>, oneshot::Sender<Option<Vec<String>>>);

fn chatbot_thread() -> mpsc::Sender<Payload> {
let (tx, mut rx) = mpsc::channel::<Payload>(1024);
fn chatbot_thread() -> (mpsc::Sender<Payload>, mpsc::Sender<()>) {
let (req_tx, mut req_rx) = mpsc::channel::<Payload>(1024);
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
tokio::spawn(async move {
let mut chatbot = chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]);
while let Some((messages, responder)) = rx.recv().await {
while let Some((messages, responder)) = req_rx.recv().await {
let doc_paths = chatbot.retrieval_documents(&messages);
let docs = load_docs(doc_paths).await;
let response = chatbot.query_chat(&messages, &docs).await;
responder.send(response).unwrap();
let chat_fut = chatbot.query_chat(&messages, &docs);
let cancel_fut = cancel_rx.recv();
tokio::select! {
response = chat_fut => {
responder.send(Some(response)).unwrap();
}
_ = cancel_fut => {
responder.send(None).unwrap();
}
}
}
});
tx
(req_tx, cancel_tx)
}

async fn query_chat(messages: &Arc<Vec<String>>) -> Vec<String> {
static SENDER: LazyLock<mpsc::Sender<Payload>> = LazyLock::new(chatbot_thread);
static CHATBOT_THREAD: LazyLock<(mpsc::Sender<Payload>, mpsc::Sender<()>)> =
LazyLock::new(chatbot_thread);

async fn query_chat(messages: &Arc<Vec<String>>) -> Option<Vec<String>> {
let (tx, rx) = oneshot::channel();
SENDER.send((Arc::clone(messages), tx)).await.unwrap();
CHATBOT_THREAD
.0
.send((Arc::clone(messages), tx))
.await
.unwrap();
rx.await.unwrap()
}

async fn cancel(_req: Request) -> Response {
CHATBOT_THREAD.1.send(()).await.unwrap();
Ok(Content::Html("success".into()))
}

async fn chat(req: Request) -> Response {
let Request::Post(body) = req else {
return Err(StatusCode::METHOD_NOT_ALLOWED);
};
let Ok(mut data) = serde_json::from_str::<Messages>(&body) else {
let Ok(mut data) = serde_json::from_str::<MessagesRequest>(&body) else {
return Err(StatusCode::INTERNAL_SERVER_ERROR);
};

let messages = Arc::new(data.messages);
let (i, mut responses) = join!(chatbot::gen_random_number(), query_chat(&messages));
let (i, responses_opt) = join!(chatbot::gen_random_number(), query_chat(&messages));

let response = responses.remove(i % responses.len());
data.messages = Arc::into_inner(messages).unwrap();
data.messages.push(response);
let response = match responses_opt {
Some(mut responses) => {
let response = responses.remove(i % responses.len());
data.messages = Arc::into_inner(messages).unwrap();
data.messages.push(response);

Ok(Content::Json(serde_json::to_string(&data).unwrap()))
MessagesResponse::Success {
messages: data.messages,
}
}
None => MessagesResponse::Cancelled,
};

Ok(Content::Json(serde_json::to_string(&response).unwrap()))
}

#[tokio::main]
async fn main() {
miniserve::Server::new()
.route("/", index)
.route("/chat", chat)
.route("/cancel", cancel)
.run()
.await
}

0 comments on commit 264df8e

Please sign in to comment.