diff --git a/src/sql_injection/detect_sql_injection.rs b/src/sql_injection/detect_sql_injection.rs index 0bf3c7b..725dfdc 100644 --- a/src/sql_injection/detect_sql_injection.rs +++ b/src/sql_injection/detect_sql_injection.rs @@ -2,7 +2,6 @@ use super::have_comments_changed::have_comments_changed; use super::is_common_sql_string::is_common_sql_string; use super::tokenize_query::tokenize_query; use crate::diff_in_vec_len; -use sqlparser::tokenizer::Token; const SPACE_CHAR: char = ' '; diff --git a/src/sql_injection/detect_sql_injection_test.rs b/src/sql_injection/detect_sql_injection_test.rs index 3e76c13..51fa1b8 100644 --- a/src/sql_injection/detect_sql_injection_test.rs +++ b/src/sql_injection/detect_sql_injection_test.rs @@ -2,13 +2,40 @@ mod tests { use crate::sql_injection::detect_sql_injection::detect_sql_injection_str; + fn dialect(s: &str) -> i32 { + match s { + "mysql" => 8, + "postgresql" => 9, + "sqlite" => 12, + "clickhouse" => 3, + _ => panic!("Unknown dialect"), + } + } + + fn get_supported_dialects() -> Vec { + vec![ + dialect("mysql"), + dialect("postgresql"), + dialect("sqlite"), + dialect("clickhouse"), + ] + } + macro_rules! is_injection { ($query:expr, $input:expr) => { - assert!(detect_sql_injection_str( - &$query.to_lowercase(), - &$input.to_lowercase(), - 0 - )) + for dia in get_supported_dialects().iter() { + assert!( + detect_sql_injection_str( + &$query.to_lowercase(), + &$input.to_lowercase(), + dia.clone() + ), + "should be an injection\nquery: {}\ninput: {}\ndialect: {}\n", + $query, + $input, + dia.clone() + ) + } }; ($query:expr, $input:expr, $dialect:expr) => { assert!(detect_sql_injection_str( @@ -20,11 +47,19 @@ mod tests { } macro_rules! not_injection { ($query:expr, $input:expr) => { - assert!(!detect_sql_injection_str( - &$query.to_lowercase(), - &$input.to_lowercase(), - 0 - )) + for dia in get_supported_dialects().iter() { + assert!( + !detect_sql_injection_str( + &$query.to_lowercase(), + &$input.to_lowercase(), + dia.clone() + ), + "should not be an injection\nquery: {}\ninput: {}\ndialect: {}\n", + $query, + $input, + dia.clone() + ) + } }; ($query:expr, $input:expr, $dialect:expr) => { assert!(!detect_sql_injection_str( @@ -35,28 +70,22 @@ mod tests { }; } - fn dialect(s: &str) -> i32 { - match s { - "mysql" => 8, - "postgresql" => 9, - "sqlite" => 12, - _ => panic!("Unknown dialect"), - } - } - #[test] fn test_postgres_dollar_signs() { not_injection!( "SELECT * FROM users WHERE id = $$' OR 1=1 -- $$", - "' OR 1=1 -- " + "' OR 1=1 -- ", + dialect("postgresql") ); not_injection!( "SELECT * FROM users WHERE id = $$1; DROP TABLE users; -- $$", - "1; DROP TABLE users; -- " + "1; DROP TABLE users; -- ", + dialect("postgresql") ); is_injection!( "SELECT * FROM users WHERE id = $$1$$ OR 1=1 -- $$", - "1$$ OR 1=1 -- " + "1$$ OR 1=1 -- ", + dialect("postgresql") ); } @@ -64,15 +93,18 @@ mod tests { fn test_postgres_dollar_named_dollar_signs() { not_injection!( "SELECT * FROM users WHERE id = $name$' OR 1=1 -- $name$", - "' OR 1=1 -- " + "' OR 1=1 -- ", + dialect("postgresql") ); not_injection!( "SELECT * FROM users WHERE id = $name$1; DROP TABLE users; -- $name$", - "1; DROP TABLE users; -- " + "1; DROP TABLE users; -- ", + dialect("postgresql") ); is_injection!( "SELECT * FROM users WHERE id = $name$1$name$ OR 1=1 -- $name$", - "1$name$ OR 1=1 -- " + "1$name$ OR 1=1 -- ", + dialect("postgresql") ); }