Skip to content

Commit

Permalink
ip rate limit
Browse files Browse the repository at this point in the history
  • Loading branch information
ermalkaleci committed Nov 30, 2023
1 parent 43cbd17 commit 2f02bb7
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 60 deletions.
11 changes: 8 additions & 3 deletions config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ extensions:
- path: /liveness
method: chain_getBlockHash
cors: all
rate_limit: # 20 requests per second per connection
burst: 20
period_secs: 1
rate_limit:
rules:
- burst: 20 # 20 requests per second per connection
period_secs: 1
apply_to: connection # default is ip
- burst: 500 # 500 requests per 10 seconds per ip
period_secs: 10
apply_to: ip

middlewares:
methods:
Expand Down
155 changes: 144 additions & 11 deletions src/extensions/rate_limit/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use futures::{future::BoxFuture, FutureExt};
use governor::{DefaultDirectRateLimiter, Jitter, Quota, RateLimiter};
use governor::{DefaultDirectRateLimiter, DefaultKeyedRateLimiter, Jitter, Quota, RateLimiter};
use jsonrpsee::{
server::{middleware::rpc::RpcServiceT, types::Request},
MethodResponse,
Expand All @@ -10,8 +10,21 @@ use std::{sync::Arc, time::Duration};

use super::{Extension, ExtensionRegistry};

#[derive(Deserialize, Debug, Copy, Clone, Default)]
#[derive(Deserialize, Debug, Clone, Default, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum RateLimitType {
#[default]
Ip,
Connection,
}

#[derive(Deserialize, Debug, Clone, Default)]
pub struct RateLimitConfig {
pub rules: Vec<Rule>,
}

#[derive(Deserialize, Debug, Clone, Default)]
pub struct Rule {
// burst is the maximum number of requests that can be made in a period
pub burst: u32,
// period is the period of time in which the burst is allowed
Expand All @@ -22,6 +35,8 @@ pub struct RateLimitConfig {
// e.g. if jitter_up_to_millis is 1000, then additional delay of random(0, 1000) milliseconds will be added
#[serde(default = "default_jitter_up_to_millis")]
pub jitter_up_to_millis: u64,
#[serde(default)]
pub apply_to: RateLimitType,
}

fn default_period_secs() -> u64 {
Expand All @@ -34,28 +49,82 @@ fn default_jitter_up_to_millis() -> u64 {

pub struct RateLimitBuilder {
config: RateLimitConfig,
ip_jitter: Option<Jitter>,
ip_limiter: Option<Arc<DefaultKeyedRateLimiter<String>>>,
}

#[async_trait::async_trait]
impl Extension for RateLimitBuilder {
type Config = RateLimitConfig;

async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result<Self, anyhow::Error> {
Ok(Self::new(*config))
Ok(Self::new(config.clone()))
}
}

impl RateLimitBuilder {
pub fn new(config: RateLimitConfig) -> Self {
assert!(config.burst > 0, "burst must be greater than 0");
assert!(config.period_secs > 0, "period_secs must be greater than 0");
Self { config }
// make sure there is at least one rule
assert!(!config.rules.is_empty(), "must have at least one rule");
// make sure there is at most one ip rule
assert!(
config.rules.iter().filter(|r| r.apply_to == RateLimitType::Ip).count() <= 1,
"can only have one ip rule"
);
// make sure there is at most one connection rule
assert!(
config
.rules
.iter()
.filter(|r| r.apply_to == RateLimitType::Connection)
.count()
<= 1,
"can only have one connection rule"
);
// make sure all rules are valid
for rule in config.rules.iter() {
assert!(rule.burst > 0, "burst must be greater than 0");
assert!(rule.period_secs > 0, "period_secs must be greater than 0");
}

if let Some(rule) = config.rules.iter().find(|r| r.apply_to == RateLimitType::Ip) {
let burst = NonZeroU32::new(rule.burst).unwrap();
let replenish_interval_ns = Duration::from_secs(rule.period_secs).as_nanos() / (burst.get() as u128);
let quota = Quota::with_period(Duration::from_nanos(replenish_interval_ns as u64))
.unwrap()
.allow_burst(burst);
Self {
config: config.clone(),
ip_jitter: Some(Jitter::up_to(Duration::from_millis(rule.jitter_up_to_millis))),
ip_limiter: Some(Arc::new(RateLimiter::keyed(quota))),
}
} else {
Self {
config,
ip_jitter: None,
ip_limiter: None,
}
}
}
pub fn connection_limit(&self) -> Option<RateLimit> {
if let Some(rule) = self
.config
.rules
.iter()
.find(|r| r.apply_to == RateLimitType::Connection)
{
let burst = NonZeroU32::new(rule.burst).unwrap();
let period = Duration::from_secs(rule.period_secs);
let jitter = Jitter::up_to(Duration::from_millis(rule.jitter_up_to_millis));
Some(RateLimit::new(burst, period, jitter))
} else {
None
}
}
pub fn build(&self) -> RateLimit {
let burst = NonZeroU32::new(self.config.burst).unwrap();
let period = Duration::from_secs(self.config.period_secs);
let jitter = Jitter::up_to(Duration::from_millis(self.config.jitter_up_to_millis));
RateLimit::new(burst, period, jitter)
pub fn ip_limit(&self, remote_ip: String) -> Option<IpRateLimitService> {
self.ip_limiter.as_ref().map(|ip_limiter| {
IpRateLimitService::new(remote_ip, ip_limiter.clone(), self.ip_jitter.unwrap_or_default())
})
}
}

Expand Down Expand Up @@ -121,6 +190,70 @@ where
}
}

#[derive(Clone)]
pub struct IpRateLimitService {
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
}

impl IpRateLimitService {
pub fn new(key: String, limiter: Arc<DefaultKeyedRateLimiter<String>>, jitter: Jitter) -> Self {
Self {
ip_addr: key,
limiter,
jitter,
}
}
}

impl<S> tower::Layer<S> for IpRateLimitService {
type Service = IpRateLimitLayer<S>;

fn layer(&self, service: S) -> Self::Service {
IpRateLimitLayer::new(service, self.ip_addr.clone(), self.limiter.clone(), self.jitter)
}
}

#[derive(Clone)]
pub struct IpRateLimitLayer<S> {
service: S,
ip_addr: String,
limiter: Arc<DefaultKeyedRateLimiter<String>>,
jitter: Jitter,
}

impl<S> IpRateLimitLayer<S> {
pub fn new(service: S, ip_addr: String, limiter: Arc<DefaultKeyedRateLimiter<String>>, jitter: Jitter) -> Self {
Self {
service,
ip_addr,
limiter,
jitter,
}
}
}

impl<'a, S> RpcServiceT<'a> for IpRateLimitLayer<S>
where
S: RpcServiceT<'a> + Send + Sync + Clone + 'static,
{
type Future = BoxFuture<'a, MethodResponse>;

fn call(&self, req: Request<'a>) -> Self::Future {
let ip_addr = self.ip_addr.clone();
let jitter = self.jitter;
let service = self.service.clone();
let limiter = self.limiter.clone();

async move {
limiter.until_key_ready_with_jitter(&ip_addr, jitter).await;
service.call(req).await
}
.boxed()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
118 changes: 78 additions & 40 deletions src/extensions/server/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
use std::{future::Future, net::SocketAddr};

use async_trait::async_trait;
use http::header::HeaderValue;
use hyper::server::conn::AddrStream;
use hyper::service::Service;
use hyper::service::{make_service_fn, service_fn};
use jsonrpsee::server::{
middleware::rpc::RpcServiceBuilder, RandomStringIdProvider, RpcModule, ServerBuilder, ServerHandle,
middleware::rpc::RpcServiceBuilder, stop_channel, RandomStringIdProvider, RpcModule, ServerBuilder, ServerHandle,
};
use serde::ser::StdError;
use serde::Deserialize;
use std::str::FromStr;
use std::sync::Arc;
use std::{future::Future, net::SocketAddr};
use tower::ServiceBuilder;
use tower_http::cors::{AllowOrigin, CorsLayer};

use super::{rate_limit::RateLimit, Extension, ExtensionRegistry};
use proxy_get_request::ProxyGetRequestLayer;

use self::proxy_get_request::ProxyGetRequestMethod;
use super::{Extension, ExtensionRegistry};
use crate::extensions::rate_limit::RateLimitBuilder;

mod proxy_get_request;
use proxy_get_request::{ProxyGetRequestLayer, ProxyGetRequestMethod};

pub struct SubwayServerBuilder {
pub config: ServerConfig,
Expand Down Expand Up @@ -90,40 +95,73 @@ impl SubwayServerBuilder {

pub async fn build<Fut: Future<Output = anyhow::Result<RpcModule<()>>>>(
&self,
rate_limit: Option<RateLimit>,
rate_limit_builder: Option<Arc<RateLimitBuilder>>,
builder: impl FnOnce() -> Fut,
) -> anyhow::Result<(SocketAddr, ServerHandle)> {
let rpc_middleware = RpcServiceBuilder::new().option_layer(rate_limit);

let service_builder = tower::ServiceBuilder::new()
.layer(cors_layer(self.config.cors.clone()).expect("Invalid CORS config"))
.layer(
ProxyGetRequestLayer::new(
self.config
.http_methods
.iter()
.map(|m| ProxyGetRequestMethod {
path: m.path.clone(),
method: m.method.clone(),
})
.collect(),
)
.expect("Invalid health config"),
);

let server = ServerBuilder::default()
.set_rpc_middleware(rpc_middleware)
.set_http_middleware(service_builder)
.max_connections(self.config.max_connections)
.set_id_provider(RandomStringIdProvider::new(16))
.build((self.config.listen_address.as_str(), self.config.port))
.await?;

let module = builder().await?;

let addr = server.local_addr()?;
let server = server.start(module);

Ok((addr, server))
let config = self.config.clone();

let (stop_handle, server_handle) = stop_channel();
let handle = stop_handle.clone();
let methods = builder().await?;

// make_service handle each connection
let make_service = make_service_fn(move |socket: &AddrStream| {
let remote_ip = socket.remote_addr().ip().to_string();
let rpc_middleware = RpcServiceBuilder::new()
.option_layer(rate_limit_builder.as_ref().and_then(|r| r.ip_limit(remote_ip)))
.option_layer(rate_limit_builder.as_ref().and_then(|r| r.connection_limit()));

let http_middleware: ServiceBuilder<_> = tower::ServiceBuilder::new()
.layer(cors_layer(config.cors.clone()).expect("Invalid CORS config"))
.layer(
ProxyGetRequestLayer::new(
config
.http_methods
.iter()
.map(|m| ProxyGetRequestMethod {
path: m.path.clone(),
method: m.method.clone(),
})
.collect(),
)
.expect("Invalid health config"),
);

let service_builder = ServerBuilder::default()
.set_rpc_middleware(rpc_middleware)
.set_http_middleware(http_middleware)
.max_connections(config.max_connections)
.set_id_provider(RandomStringIdProvider::new(16))
.to_service_builder();

let methods = methods.clone();
let stop_handle = stop_handle.clone();
let service_builder = service_builder.clone();

async move {
// service_fn handle each request
Ok::<_, Box<dyn StdError + Send + Sync>>(service_fn(move |req| {
let methods = methods.clone();
let stop_handle = stop_handle.clone();
let service_builder = service_builder.clone();

let mut service = service_builder.build(methods, stop_handle);
service.call(req)
}))
}
});

let ip_addr = std::net::IpAddr::from_str(&self.config.listen_address)?;
let addr = SocketAddr::new(ip_addr, self.config.port);

let server = hyper::Server::bind(&addr).serve(make_service);
let addr = server.local_addr();

tokio::spawn(async move {
let graceful = server.with_graceful_shutdown(async move { handle.shutdown().await });
graceful.await.unwrap()
});

Ok((addr, server_handle))
}
}
8 changes: 2 additions & 6 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,13 @@ pub async fn build(config: Config) -> anyhow::Result<SubwayServerHandle> {
.get::<SubwayServerBuilder>()
.expect("Server extension not found");

let rate_limit = extensions_registry
.read()
.await
.get::<RateLimitBuilder>()
.map(|b| b.build());
let rate_limit_builder = extensions_registry.read().await.get::<RateLimitBuilder>();

let request_timeout_seconds = server_builder.config.request_timeout_seconds;

let registry = extensions_registry.clone();
let (addr, handle) = server_builder
.build(rate_limit, move || async move {
.build(rate_limit_builder, move || async move {
let mut module = RpcModule::new(());

let tracer = telemetry::Tracer::new("server");
Expand Down

0 comments on commit 2f02bb7

Please sign in to comment.