Skip to content

Commit

Permalink
🐳 Add limits and timeouts. Fixes #1
Browse files Browse the repository at this point in the history
  • Loading branch information
coreequip committed Feb 10, 2025
1 parent e7f3ca7 commit eae8c69
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use trust_dns_resolver::TokioAsyncResolver;
#[derive(Debug, Clone)]
struct SpfChecker {
resolver: Arc<TokioAsyncResolver>,
max_depth: usize,
}

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -46,11 +47,15 @@ struct ErrorResponse {

impl SpfChecker {
async fn new() -> Result<Self> {
let resolver =
TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
let mut opts = ResolverOpts::default();
opts.timeout = std::time::Duration::from_secs(2);
opts.attempts = 2;

let resolver = TokioAsyncResolver::tokio(ResolverConfig::default(), opts);

Ok(Self {
resolver: Arc::new(resolver),
max_depth: 10,
})
}

Expand All @@ -77,6 +82,15 @@ impl SpfChecker {
target: String,
visited: &mut HashSet<String>,
) -> Result<(bool, Option<String>, Option<Vec<String>>)> {
if visited.len() >= self.max_depth {
log_message(&format!(
"Maximum recursion depth of {} reached. Visited domains: {:?}",
self.max_depth,
visited.iter().collect::<Vec<_>>()
));
return Ok((false, None, None));
}

if !visited.insert(domain.clone()) {
return Ok((false, None, None));
}
Expand Down Expand Up @@ -151,8 +165,6 @@ async fn check_spf(
Query(params): Query<SpfCheckParams>,
checker: axum::extract::State<Arc<SpfChecker>>,
) -> impl IntoResponse {
log_message(&format!("Request to check \"{}\" for \"{}\"", params.domain, params.target));

let start = std::time::Instant::now();
let mut visited = HashSet::new();

Expand All @@ -161,19 +173,36 @@ async fn check_spf(
.await
{
Ok((found, spf_record, included_domains)) => {
let elapsed_ms = start.elapsed().as_millis() as u64;
log_message(&format!(
"Successfully checked \"{}\" for \"{}\" ({}ms)",
params.domain,
params.target,
elapsed_ms
));

let response = SpfCheckResponse {
found,
checked_domains: visited.len(),
domain: params.domain,
target: params.target,
elapsed_ms: start.elapsed().as_millis() as u64,
elapsed_ms,
has_spf_record: spf_record.is_some(),
spf_record,
included_domains,
};
(StatusCode::OK, Json(response)).into_response()
}
Err(err) => {
let elapsed_ms = start.elapsed().as_millis() as u64;
log_message(&format!(
"Failed to check \"{}\" for \"{}\": {} ({}ms)",
params.domain,
params.target,
err,
elapsed_ms
));

let error = ErrorResponse {
error: err.to_string(),
};
Expand Down

0 comments on commit eae8c69

Please sign in to comment.