Skip to content

Commit

Permalink
Merge pull request #243 from LemmyNet/main
Browse files Browse the repository at this point in the history
[pull] master from LemmyNet:main
  • Loading branch information
pull[bot] authored Jan 21, 2025
2 parents e09f1d2 + 31b8a4b commit cbe7bce
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 3 deletions.
38 changes: 38 additions & 0 deletions crates/db_schema/replaceable_schema/triggers.sql
Original file line number Diff line number Diff line change
Expand Up @@ -889,3 +889,41 @@ CALL r.create_inbox_combined_trigger ('person_post_mention');

CALL r.create_inbox_combined_trigger ('private_message');

-- Prevent using delete instead of uplete on action tables
CREATE FUNCTION r.require_uplete ()
RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
BEGIN
IF pg_trigger_depth() = 1 AND NOT starts_with (current_query(), '/**/') THEN
RAISE 'using delete instead of uplete is not allowed for this table';
END IF;
RETURN NULL;
END
$$;

CREATE TRIGGER require_uplete
BEFORE DELETE ON comment_actions
FOR EACH STATEMENT
EXECUTE FUNCTION r.require_uplete ();

CREATE TRIGGER require_uplete
BEFORE DELETE ON community_actions
FOR EACH STATEMENT
EXECUTE FUNCTION r.require_uplete ();

CREATE TRIGGER require_uplete
BEFORE DELETE ON instance_actions
FOR EACH STATEMENT
EXECUTE FUNCTION r.require_uplete ();

CREATE TRIGGER require_uplete
BEFORE DELETE ON person_actions
FOR EACH STATEMENT
EXECUTE FUNCTION r.require_uplete ();

CREATE TRIGGER require_uplete
BEFORE DELETE ON post_actions
FOR EACH STATEMENT
EXECUTE FUNCTION r.require_uplete ();

5 changes: 4 additions & 1 deletion crates/db_schema/src/utils/uplete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ impl QueryFragment<Pg> for UpleteQuery {
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> Result<(), Error> {
assert_ne!(self.set_null_columns.len(), 0, "`set_null` was not called");

// This is checked by require_uplete triggers
out.push_sql("/**/");

// Declare `update_keys` and `delete_keys` CTEs, which select primary keys
for (prefix, subquery) in [
("WITH update_keys", &self.update_subquery),
Expand Down Expand Up @@ -357,7 +360,7 @@ mod tests {
let update_count = "SELECT count(*) FROM update_result";
let delete_count = "SELECT count(*) FROM delete_result";

format!(r#"WITH {with_queries} SELECT ({update_count}), ({delete_count}) -- binds: []"#)
format!(r#"/**/WITH {with_queries} SELECT ({update_count}), ({delete_count}) -- binds: []"#)
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions crates/utils/src/rate_limit/rate_limiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ static START_TIME: LazyLock<Instant> = LazyLock::new(Instant::now);

/// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't
/// store nanoseconds
#[derive(PartialEq, Debug, Clone, Copy)]
#[derive(PartialEq, Debug, Clone, Copy, Hash)]
pub struct InstantSecs {
secs: u32,
pub secs: u32,
}

#[allow(clippy::expect_used)]
Expand Down
176 changes: 176 additions & 0 deletions src/idempotency_middleware.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use actix_web::{
body::EitherBody,
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
http::Method,
Error,
HttpMessage,
HttpResponse,
};
use futures_util::future::LocalBoxFuture;
use lemmy_api_common::lemmy_db_views::structs::LocalUserView;
use lemmy_db_schema::newtypes::LocalUserId;
use lemmy_utils::rate_limit::rate_limiter::InstantSecs;
use std::{
collections::HashSet,
future::{ready, Ready},
hash::{Hash, Hasher},
sync::{Arc, RwLock},
time::Duration,
};

/// https://www.ietf.org/archive/id/draft-ietf-httpapi-idempotency-key-header-01.html
const IDEMPOTENCY_HEADER: &str = "Idempotency-Key";

/// Delete idempotency keys older than this
const CLEANUP_INTERVAL_SECS: u32 = 120;

#[derive(Debug)]
struct Entry {
user_id: LocalUserId,
key: String,
// Creation time is ignored for Eq, Hash and only used to cleanup old entries
created: InstantSecs,
}

impl PartialEq for Entry {
fn eq(&self, other: &Self) -> bool {
self.user_id == other.user_id && self.key == other.key
}
}
impl Eq for Entry {}

impl Hash for Entry {
fn hash<H: Hasher>(&self, state: &mut H) {
self.user_id.hash(state);
self.key.hash(state);
}
}

#[derive(Clone)]
pub struct IdempotencySet {
set: Arc<RwLock<HashSet<Entry>>>,
}

impl Default for IdempotencySet {
fn default() -> Self {
let set: Arc<RwLock<HashSet<Entry>>> = Default::default();

let set_ = set.clone();
tokio::spawn(async move {
let interval = Duration::from_secs(CLEANUP_INTERVAL_SECS.into());
let state_weak_ref = Arc::downgrade(&set_);

// Run at every interval to delete entries older than the interval.
// This loop stops when all other references to `state` are dropped.
while let Some(state) = state_weak_ref.upgrade() {
tokio::time::sleep(interval).await;
let now = InstantSecs::now();
#[allow(clippy::expect_used)]
let mut lock = state.write().expect("lock failed");
lock.retain(|e| e.created.secs > now.secs.saturating_sub(CLEANUP_INTERVAL_SECS));
lock.shrink_to_fit();
}
});
Self { set }
}
}

pub struct IdempotencyMiddleware {
idempotency_set: IdempotencySet,
}

impl IdempotencyMiddleware {
pub fn new(idempotency_set: IdempotencySet) -> Self {
Self { idempotency_set }
}
}

impl<S, B> Transform<S, ServiceRequest> for IdempotencyMiddleware
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type InitError = ();
type Transform = IdempotencyService<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;

fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(IdempotencyService {
service,
idempotency_set: self.idempotency_set.clone(),
}))
}
}

pub struct IdempotencyService<S> {
service: S,
idempotency_set: IdempotencySet,
}

impl<S, B> Service<ServiceRequest> for IdempotencyService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

forward_ready!(service);

#[allow(clippy::expect_used)]
fn call(&self, req: ServiceRequest) -> Self::Future {
let is_post_or_put = req.method() == Method::POST || req.method() == Method::PUT;
let idempotency = req
.headers()
.get(IDEMPOTENCY_HEADER)
.map(|i| i.to_str().unwrap_or_default().to_string())
// Ignore values longer than 32 chars
.and_then(|i| (i.len() <= 32).then_some(i))
// Only use idempotency for POST and PUT requests
.and_then(|i| is_post_or_put.then_some(i));

let user_id = {
let ext = req.extensions();
ext.get().map(|u: &LocalUserView| u.local_user.id)
};

if let (Some(key), Some(user_id)) = (idempotency, user_id) {
let value = Entry {
user_id,
key,
created: InstantSecs::now(),
};
if self
.idempotency_set
.set
.read()
.expect("lock failed")
.contains(&value)
{
// Duplicate request, return error
let (req, _pl) = req.into_parts();
let response = HttpResponse::UnprocessableEntity()
.finish()
.map_into_right_body();
return Box::pin(async { Ok(ServiceResponse::new(req, response)) });
} else {
// New request, store key and continue
self
.idempotency_set
.set
.write()
.expect("lock failed")
.insert(value);
}
}

let fut = self.service.call(req);

Box::pin(async move { fut.await.map(ServiceResponse::map_into_left_body) })
}
}
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod api_routes_v3;
pub mod api_routes_v4;
pub mod code_migrations;
pub mod idempotency_middleware;
pub mod prometheus_metrics;
pub mod scheduled_tasks;
pub mod session_middleware;
Expand All @@ -18,6 +19,7 @@ use actix_web::{
};
use actix_web_prom::PrometheusMetricsBuilder;
use clap::{Parser, Subcommand};
use idempotency_middleware::{IdempotencyMiddleware, IdempotencySet};
use lemmy_api::sitemap::get_sitemap;
use lemmy_api_common::{
context::LemmyContext,
Expand Down Expand Up @@ -334,6 +336,9 @@ fn create_http_server(
.build()
.map_err(|e| LemmyErrorType::Unknown(format!("Should always be buildable: {e}")))?;

// Must create this outside of HTTP server so that duplicate requests get detected across threads.
let idempotency_set = IdempotencySet::default();

// Create Http server
let bind = (settings.bind, settings.port);
let server = HttpServer::new(move || {
Expand All @@ -355,6 +360,7 @@ fn create_http_server(
.app_data(Data::new(context.clone()))
.app_data(Data::new(rate_limit_cell.clone()))
.wrap(FederationMiddleware::new(federation_config.clone()))
.wrap(IdempotencyMiddleware::new(idempotency_set.clone()))
.wrap(SessionMiddleware::new(context.clone()))
.wrap(Condition::new(
SETTINGS.prometheus.is_some(),
Expand Down

0 comments on commit cbe7bce

Please sign in to comment.