diff --git a/rama-http/Cargo.toml b/rama-http/Cargo.toml index ac596aa22..47dfa0dfd 100644 --- a/rama-http/Cargo.toml +++ b/rama-http/Cargo.toml @@ -59,6 +59,7 @@ tokio = { workspace = true, features = ["macros", "fs", "io-std"] } tokio-util = { workspace = true, features = ["io"] } tracing = { workspace = true } uuid = { workspace = true, features = ["v4"] } +matchit = "0.8.6" [dev-dependencies] brotli = { workspace = true } diff --git a/rama-http/src/service/web/mod.rs b/rama-http/src/service/web/mod.rs index 7c39386ba..58680c7ba 100644 --- a/rama-http/src/service/web/mod.rs +++ b/rama-http/src/service/web/mod.rs @@ -11,3 +11,7 @@ pub use endpoint::{extract, EndpointServiceFn, IntoEndpointService}; pub mod k8s; #[doc(inline)] pub use k8s::{k8s_health, k8s_health_builder}; + +mod router; +#[doc(inline)] +pub use router::{Router}; diff --git a/rama-http/src/service/web/router.rs b/rama-http/src/service/web/router.rs new file mode 100644 index 000000000..47508aeb1 --- /dev/null +++ b/rama-http/src/service/web/router.rs @@ -0,0 +1,173 @@ +use super::{IntoEndpointService}; +use crate::{Request, Response, StatusCode}; +use rama_core::{service::{BoxService, Service, service_fn}, Context}; +use std::{convert::Infallible, sync::Arc}; +use std::collections::HashMap; +use http::{Method}; + +use matchit::Router as MatchitRouter; +use rama_http_types::IntoResponse; + +/// A basic web router that can be used to serve HTTP requests based on path matching. +/// It will also provide extraction of path parameters and wildcards out of the box so +/// you can define your paths accordingly. + +pub struct Router { + routes: MatchitRouter>>>, + not_found: Arc>, +} + +impl std::fmt::Debug for Router { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Router").finish() + } +} + +impl Clone for Router { + fn clone(&self) -> Self { + Self { + routes: self.routes.clone(), + not_found: self.not_found.clone(), + } + } +} + +/// default trait +impl Default for Router +where + State: Clone + Send + Sync + 'static, +{ + fn default() -> Self { + Self::new() + } +} + +impl Router +where + State: Clone + Send + Sync + 'static, +{ + /// create a new web router + pub(crate) fn new() -> Self { + Self { + routes: MatchitRouter::new(), + not_found: Arc::new( + service_fn(|| async { Ok(StatusCode::NOT_FOUND.into_response()) }).boxed(), + ), + } + } + + pub fn route(mut self, method: Method, path: &str, service: I) -> Self + where + I: IntoEndpointService, + { + let boxed_service = Arc::new(BoxService::new(service.into_endpoint_service())); + match self.routes.insert(path.to_string(), HashMap::new()) { + Ok(_) => { + if let Some(entry) = self.routes.at_mut(path).ok() { + entry.value.insert(method, boxed_service); + } + }, + Err(_err) => { + if let Some(existing) = self.routes.at_mut(path).ok() { + existing.value.insert(method, boxed_service); + } + } + }; + self + } + + pub fn get(self, path: &str, service: I) -> Self + where + I: IntoEndpointService, + { + self.route(Method::GET, path, service) + } + + pub fn post(self, path: &str, service: I) -> Self + where + I: IntoEndpointService, + { + self.route(Method::POST, path, service) + } + + pub fn put(self, path: &str, service: I) -> Self + where + I: IntoEndpointService, + { + self.route(Method::PUT, path, service) + } + + pub fn delete(self, path: &str, service: I) -> Self + where + I: IntoEndpointService, + { + self.route(Method::DELETE, path, service) + } + + pub fn patch(self, path: &str, service: I) -> Self + where + I: IntoEndpointService, + { + self.route(Method::PATCH, path, service) + } + + pub fn head(self, path: &str, service: I) -> Self + where + I: IntoEndpointService, + { + self.route(Method::HEAD, path, service) + } + + pub fn options(self, path: &str, service: I) -> Self + where + I: IntoEndpointService, + { + self.route(Method::OPTIONS, path, service) + } + + pub fn trace(self, path: &str, service: I) -> Self + where + I: IntoEndpointService, + { + self.route(Method::TRACE, path, service) + } + + /// use the given service in case no match could be found. + pub fn not_found(mut self, service: I) -> Self + where + I: IntoEndpointService, + { + self.not_found = Arc::new(service.into_endpoint_service().boxed()); + self + } +} + +impl Service for Router +where + State: Clone + Send + Sync + 'static, +{ + type Response = Response; + type Error = Infallible; + + async fn serve( + &self, + mut ctx: Context, + req: Request<>, + ) -> Result { + let uri_string = req.uri().to_string(); + match &self.routes.at(uri_string.as_str()) { + Ok(matched) => { + let params: HashMap = matched.params.clone().iter().map(|(k, v)| (k.to_string(), v.to_string())).collect(); + ctx.insert(params); + if let Some(service) = matched.value.get(&req.method()) { + service.boxed().serve(ctx, req).await + } else { + self.not_found.serve(ctx, req).await + } + }, + Err(_err) => { + self.not_found.serve(ctx, req).await + } + } + } +}