From 86a8905d196a4952359e135507b2eaa8b190fb9b Mon Sep 17 00:00:00 2001 From: msm-code Date: Wed, 3 Jun 2020 22:29:22 +0200 Subject: [PATCH] Fix performance and accuracy for wide ascii strings (#155) Due to a refactoring earlier (probably), we started returning many more results than we should. This commit should fix this, without affecting performance negatively. --- libursa/DatabaseSnapshot.cpp | 5 ++- libursa/OnDiskDataset.cpp | 47 +++++++++--------------- libursa/OnDiskDataset.h | 16 +------- libursa/OnDiskIndex.h | 5 --- libursa/Query.cpp | 71 +++++++++++++++++++++++------------- libursa/Query.h | 41 +++++++++++++++++---- src/Tests.cpp | 10 ++--- teste2e/test_select.py | 68 +++++++++++++++++----------------- teste2e/util.py | 23 +++++++++--- 9 files changed, 159 insertions(+), 127 deletions(-) diff --git a/libursa/DatabaseSnapshot.cpp b/libursa/DatabaseSnapshot.cpp index bf073388..a908b61a 100644 --- a/libursa/DatabaseSnapshot.cpp +++ b/libursa/DatabaseSnapshot.cpp @@ -233,7 +233,8 @@ QueryCounters DatabaseSnapshot::execute(const Query &query, } } - const QueryGraphCollection graphs{query, types_to_query, config}; + Query query_copy{query.clone()}; + query_copy.precompute(types_to_query, config); task->spec().estimate_work(datasets_to_query.size()); @@ -243,7 +244,7 @@ QueryCounters DatabaseSnapshot::execute(const Query &query, if (!ds->has_all_taints(taints)) { continue; } - ds->execute(graphs, out, &counters); + ds->execute(query_copy, out, &counters); } return counters; } diff --git a/libursa/OnDiskDataset.cpp b/libursa/OnDiskDataset.cpp index 135fd54f..d6b586e7 100644 --- a/libursa/OnDiskDataset.cpp +++ b/libursa/OnDiskDataset.cpp @@ -64,19 +64,27 @@ std::string OnDiskDataset::get_file_name(FileId fid) const { return files_index->get_file_name(fid); } -QueryResult OnDiskDataset::query(const QueryGraphCollection &graphs, +QueryResult OnDiskDataset::query(const Query &query, QueryCounters *counters) const { - QueryResult result = QueryResult::everything(); - for (auto &ndx : indices) { - auto subresult{ndx.query(graphs.get(ndx.index_type()), counters)}; - result.do_and(subresult, &counters->ands()); - } - return result; + return query.run( + [this](auto &graphs, QueryCounters *counters) { + QueryResult result = QueryResult::everything(); + for (auto &ndx : indices) { + if (graphs.count(ndx.index_type()) == 0) { + throw std::runtime_error("Unexpected graph type in query"); + } + auto subresult{ + ndx.query(graphs.at(ndx.index_type()), counters)}; + result.do_and(subresult, &counters->ands()); + } + return result; + }, + counters); } -void OnDiskDataset::execute(const QueryGraphCollection &graphs, - ResultWriter *out, QueryCounters *counters) const { - QueryResult result = query(graphs, counters); +void OnDiskDataset::execute(const Query &query, ResultWriter *out, + QueryCounters *counters) const { + QueryResult result = this->query(query, counters); if (result.is_everything()) { files_index->for_each_filename( [&out](const std::string &fname) { out->push_back(fname); }); @@ -296,22 +304,3 @@ std::vector OnDiskDataset::get_compact_candidates( return out; } - -QueryGraphCollection::QueryGraphCollection( - const Query &query, const std::unordered_set &types, - const DatabaseConfig &config) { - graphs_.reserve(types.size()); - for (const auto type : types) { - graphs_.emplace(type, std::move(query.to_graph(type, config))); - } -} - -const QueryGraph &QueryGraphCollection::get(IndexType type) const { - const auto it = graphs_.find(type); - if (it == graphs_.end()) { - throw std::runtime_error( - "QueryGraphCollection doesn't contain a graph of the requested " - "type"); - } - return it->second; -} diff --git a/libursa/OnDiskDataset.h b/libursa/OnDiskDataset.h index f2fb6947..903c981e 100644 --- a/libursa/OnDiskDataset.h +++ b/libursa/OnDiskDataset.h @@ -15,17 +15,6 @@ #include "ResultWriter.h" #include "Task.h" -class QueryGraphCollection { - std::unordered_map graphs_; - - public: - QueryGraphCollection(const Query &query, - const std::unordered_set &types, - const DatabaseConfig &config); - - const QueryGraph &get(IndexType type) const; -}; - class OnDiskDataset { std::string name; fs::path db_base; @@ -37,8 +26,7 @@ class OnDiskDataset { return taints == other.taints; } std::string get_file_name(FileId fid) const; - QueryResult query(const QueryGraphCollection &graphs, - QueryCounters *counters) const; + QueryResult query(const Query &query, QueryCounters *counters) const; const OnDiskIndex &get_index_with_type(IndexType index_type) const; void drop_file(const std::string &fname) const; @@ -54,7 +42,7 @@ class OnDiskDataset { } void toggle_taint(const std::string &taint); bool has_all_taints(const std::set &taints) const; - void execute(const QueryGraphCollection &graphs, ResultWriter *out, + void execute(const Query &query, ResultWriter *out, QueryCounters *counters) const; uint64_t get_file_count() const { return files_index->get_file_count(); } void for_each_filename(std::function cb) const { diff --git a/libursa/OnDiskIndex.h b/libursa/OnDiskIndex.h index 4a335a22..043315ed 100644 --- a/libursa/OnDiskIndex.h +++ b/libursa/OnDiskIndex.h @@ -26,11 +26,6 @@ class OnDiskIndex { std::vector query_primitive(TriGram trigram, QueryCounter *counter) const; std::pair get_run_offsets(TriGram trigram) const; - bool internal_expand(QString::const_iterator qit, uint8_t *out, size_t pos, - size_t comb_len, const TrigramGenerator &gen, - QueryResult &res) const; - QueryResult expand_wildcards(const QString &qstr, size_t len, - const TrigramGenerator &gen) const; static void on_disk_merge_core(const std::vector &indexes, RawFile *out, TaskSpec *task); diff --git a/libursa/Query.cpp b/libursa/Query.cpp index 7b6f2093..7a0add24 100644 --- a/libursa/Query.cpp +++ b/libursa/Query.cpp @@ -67,7 +67,17 @@ const QString &Query::as_value() const { return value; } -std::string Query::as_string_repr() const { return "[primitive]"; } +std::string Query::as_string_repr() const { + std::string out = ""; + for (const auto &token : value) { + if (token.num_possible_values() == 1) { + out += token.possible_values()[0]; + } else { + out += "?"; + } + } + return out; +} Query q(QString &&qstr) { return Query(std::move(qstr)); } @@ -197,39 +207,50 @@ QueryGraph to_query_graph(const QString &str, int size, return result; } -QueryGraph Query::to_graph(IndexType ntype, - const DatabaseConfig &config) const { +void Query::precompute(const std::unordered_set &types_to_query, + const DatabaseConfig &config) { if (type == QueryType::PRIMITIVE) { - TokenValidator validator = get_validator_for(ntype); - size_t input_len = get_ngram_size_for(ntype); - return to_query_graph(value, input_len, config, validator); + value_graphs.clear(); + for (const auto &ntype : types_to_query) { + TokenValidator validator = get_validator_for(ntype); + size_t input_len = get_ngram_size_for(ntype); + auto graph{to_query_graph(value, input_len, config, validator)}; + value_graphs.emplace(ntype, std::move(graph)); + } + } else { + for (auto &query : queries) { + query.precompute(types_to_query, config); + } } +} - if (type == QueryType::AND) { - QueryGraph result; +QueryResult Query::run(const QueryPrimitive &primitive, + QueryCounters *counters) const { + if (type == QueryType::PRIMITIVE) { + return primitive(value_graphs, counters); + } else if (type == QueryType::AND) { + auto result = QueryResult::everything(); for (const auto &query : queries) { - result.and_(query.to_graph(ntype, config)); + result.do_and(query.run(primitive, counters), &counters->ands()); } return result; - } - - if (type == QueryType::OR) { - if (queries.empty()) { - return QueryGraph(); - } - QueryGraph result = std::move(queries[0].to_graph(ntype, config)); - for (size_t i = 1; i < queries.size(); i++) { - result.or_(queries[i].to_graph(ntype, config)); + } else if (type == QueryType::OR) { + auto result = QueryResult::empty(); + for (const auto &query : queries) { + result.do_or(query.run(primitive, counters), &counters->ors()); } return result; - } - - if (type == QueryType::MIN_OF) { - std::vector subgraphs; + } else if (type == QueryType::MIN_OF) { + std::vector results; + std::vector results_ptrs; + results.reserve(queries.size()); + results_ptrs.reserve(queries.size()); for (const auto &query : queries) { - subgraphs.emplace_back(query.to_graph(ntype, config)); + results.emplace_back(query.run(primitive, counters)); + results_ptrs.emplace_back(&results.back()); } - return QueryGraph::min_of(count, std::move(subgraphs)); + return QueryResult::do_min_of(count, results_ptrs, &counters->minofs()); + } else { + throw std::runtime_error("Unexpected query type"); } - throw std::runtime_error("Unknown query type."); } diff --git a/libursa/Query.h b/libursa/Query.h index ef76b632..77eff3b3 100644 --- a/libursa/Query.h +++ b/libursa/Query.h @@ -1,22 +1,41 @@ #pragma once +#include #include #include +#include +#include #include #include "DatabaseConfig.h" #include "QString.h" #include "QueryGraph.h" +#include "QueryResult.h" #include "Utils.h" -enum QueryType { PRIMITIVE = 1, AND = 2, OR = 3, MIN_OF = 4 }; +enum class QueryType { PRIMITIVE = 1, AND = 2, OR = 3, MIN_OF = 4 }; + +using QueryPrimitive = std::function &, QueryCounters *counter)>; class Query { + private: + Query(const Query &other) + : type(other.type), value_graphs(), count(other.count) { + queries.reserve(other.queries.size()); + for (const auto &query : other.queries) { + queries.emplace_back(query.clone()); + } + value.reserve(other.value.size()); + for (const auto &token : other.value) { + value.emplace_back(token.clone()); + } + } + public: explicit Query(QString &&qstr); explicit Query(uint32_t count, std::vector &&queries); explicit Query(const QueryType &type, std::vector &&queries); - Query(const Query &other) = delete; Query(Query &&other) = default; const std::vector &as_queries() const; @@ -26,14 +45,22 @@ class Query { const QueryType &get_type() const; bool operator==(const Query &other) const; - // Converts this instance of Query to equivalent QueryGraph. - QueryGraph to_graph(IndexType ntype, const DatabaseConfig &config) const; + QueryResult run(const QueryPrimitive &primitive, + QueryCounters *counters) const; + void precompute(const std::unordered_set &types_to_query, + const DatabaseConfig &config); + + Query clone() const { return Query(*this); } private: QueryType type; - uint32_t count; // used for QueryType::MIN_OF - QString value; // used for QueryType::PRIMITIVE - std::vector queries; // used for QueryType::AND/OR + // used for QueryType::PRIMITIVE + QString value; + std::unordered_map value_graphs; + // used for QueryType::MIN_OF + uint32_t count; + // used for QueryType::AND/OR/MIN_OF + std::vector queries; }; // Creates a literal query. Literals can contain wildcards and alternatives. diff --git a/src/Tests.cpp b/src/Tests.cpp index 99879d4b..6a4ce3c0 100644 --- a/src/Tests.cpp +++ b/src/Tests.cpp @@ -52,12 +52,10 @@ QString mqs(const std::string &str) { return out; } -QueryGraph mqg(const std::string &str, IndexType type) { - QString out; - for (const auto &c : str) { - out.emplace_back(QToken::single(c)); - } - return q(std::move(out)).to_graph(type, DatabaseConfig()); +QueryGraph mqg(const std::string &str, IndexType ntype) { + TokenValidator validator = get_validator_for(ntype); + size_t input_len = get_ngram_size_for(ntype); + return to_query_graph(mqs(str), input_len, DatabaseConfig(), validator); } TEST_CASE("packing 3grams", "[internal]") { diff --git a/teste2e/test_select.py b/teste2e/test_select.py index 56c6bf5f..e30a9e80 100644 --- a/teste2e/test_select.py +++ b/teste2e/test_select.py @@ -4,16 +4,12 @@ def test_select_with_taints(ursadb: UrsadbTestContext): - store_files( - ursadb, "gram3", {"tainted": b"test",}, - ) + store_files(ursadb, "gram3", {"tainted": b"test"}) topology = ursadb.check_request("topology;") dsname = list(topology["result"]["datasets"].keys())[0] - store_files( - ursadb, "gram3", {"untainted": b"test",}, - ) + store_files(ursadb, "gram3", {"untainted": b"test"}) ursadb.check_request(f'dataset "{dsname}" taint "test";') @@ -29,40 +25,49 @@ def test_select_with_weird_filenames(ursadb: UrsadbTestContext): "hmm \" ' hmm", "hmm \\ hmm", "hmm $(ls) $$ $shell hmm", - "hmm <> hmm" + "hmm <> hmm", ] - store_files( - ursadb, "gram3", { - name: b"test" for name in weird_names - }, - ) + store_files(ursadb, "gram3", {name: b"test" for name in weird_names}) check_query(ursadb, '"test"', weird_names) def test_select_with_datasets(ursadb: UrsadbTestContext): - store_files( - ursadb, "gram3", {"first": b"test",}, - ) + store_files(ursadb, "gram3", {"first": b"test"}) topology = ursadb.check_request("topology;") dsname = list(topology["result"]["datasets"].keys())[0] + store_files(ursadb, "gram3", {"second": b"test"}) + + check_query(ursadb, f'with datasets ["{dsname}"] "test"', ["first"]) + + +def test_select_ascii_wide(ursadb: UrsadbTestContext): + # ensure that `ascii wide` strings don't produce too large results store_files( - ursadb, "gram3", {"second": b"test",}, + ursadb, + ["gram3", "text4", "wide8"], + { + "ascii": b"koty", + "wide": b"k\x00o\x00t\x00y\x00", + "falsepositive": b"k\x00o\x00???\x00o\x00t\x00y\x00", + }, ) - check_query(ursadb, f'with datasets ["{dsname}"] "test"', ["first"]) + check_query(ursadb, f'"koty"', ["ascii"]) + check_query(ursadb, f'"k\\x00o\\x00t\\x00y\\x00"', ["wide"]) + check_query(ursadb, f'("koty" | "k\\x00o\\x00t\\x00y\\x00")', ["ascii", "wide"]) @pytest.mark.parametrize( "ursadb", - [UrsadbConfig(query_max_ngram=256*256, query_max_edge=255)], + [UrsadbConfig(query_max_ngram=256 * 256, query_max_edge=255)], indirect=["ursadb"], ) def test_select_with_wildcards(ursadb: UrsadbTestContext): store_files( - ursadb, "gram3", {"first": b"first", "fiRst": b"fiRst", "second": b"second"}, + ursadb, "gram3", {"first": b"first", "fiRst": b"fiRst", "second": b"second"} ) check_query(ursadb, '"first"', ["first"]) @@ -70,21 +75,18 @@ def test_select_with_wildcards(ursadb: UrsadbTestContext): check_query(ursadb, '"fi\\x??st"', ["first", "fiRst"]) check_query(ursadb, '"fi\\x?2st"', ["first", "fiRst"]) - check_query(ursadb, '{66 69 72 73 74}', ["first"]) - check_query(ursadb, '{66 69 52 73 74}', ["fiRst"]) - check_query(ursadb, '{66 69 ?? 73 74}', ["first", "fiRst"]) - check_query(ursadb, '{66 69 ?2 73 74}', ["first", "fiRst"]) - + check_query(ursadb, "{66 69 72 73 74}", ["first"]) + check_query(ursadb, "{66 69 52 73 74}", ["fiRst"]) + check_query(ursadb, "{66 69 ?? 73 74}", ["first", "fiRst"]) + check_query(ursadb, "{66 69 ?2 73 74}", ["first", "fiRst"]) @pytest.mark.parametrize( - "ursadb", - [UrsadbConfig(query_max_ngram=16, query_max_edge=16)], - indirect=["ursadb"], + "ursadb", [UrsadbConfig(query_max_ngram=16, query_max_edge=16)], indirect=["ursadb"] ) def test_select_with_wildcards_with_limits(ursadb: UrsadbTestContext): store_files( - ursadb, "gram3", {"first": b"first", "fiRst": b"fiRst", "second": b"second"}, + ursadb, "gram3", {"first": b"first", "fiRst": b"fiRst", "second": b"second"} ) check_query(ursadb, '"first"', ["first"]) @@ -93,8 +95,8 @@ def test_select_with_wildcards_with_limits(ursadb: UrsadbTestContext): check_query(ursadb, '"fi\\x?2st"', ["first", "fiRst"]) check_query(ursadb, '"fi\\x?2s\\x??"', ["first", "fiRst"]) - check_query(ursadb, '{66 69 72 73 74}', ["first"]) - check_query(ursadb, '{66 69 52 73 74}', ["fiRst"]) - check_query(ursadb, '{66 69 ?? 73 74}', ["first", "fiRst", "second"]) - check_query(ursadb, '{66 69 ?2 73 74}', ["first", "fiRst"]) - check_query(ursadb, '{66 69 ?2 73 ??}', ["first", "fiRst"]) + check_query(ursadb, "{66 69 72 73 74}", ["first"]) + check_query(ursadb, "{66 69 52 73 74}", ["fiRst"]) + check_query(ursadb, "{66 69 ?? 73 74}", ["first", "fiRst", "second"]) + check_query(ursadb, "{66 69 ?2 73 74}", ["first", "fiRst"]) + check_query(ursadb, "{66 69 ?2 73 ??}", ["first", "fiRst"]) diff --git a/teste2e/util.py b/teste2e/util.py index 83d90402..88a21d70 100644 --- a/teste2e/util.py +++ b/teste2e/util.py @@ -7,7 +7,7 @@ import resource from pathlib import Path import pytest -from typing import Dict, Any, List +from typing import Dict, Any, List, Union import zmq import hashlib import shutil @@ -153,7 +153,7 @@ def match_pattern(value: Any, pattern: Any): def store_files( ursadb: UrsadbTestContext, - type: str, + type: Union[str, List[str]], data: Dict[str, bytes], expect_error: bool = False, taints: List[str] = [], @@ -166,17 +166,23 @@ def store_files( (tmpdir / name).write_bytes(value) filenames.append(str(tmpdir / name)) - filenames = [f.replace('\\', '\\\\').replace('"', '\\"') for f in filenames] + filenames = [f.replace("\\", "\\\\").replace('"', '\\"') for f in filenames] ursa_names = " ".join(f'"{f}"' for f in filenames) taints_mod = "" if taints: - taint_list = ','.join(f'"{t}"' for t in taints) + taint_list = ",".join(f'"{t}"' for t in taints) taints_mod = f" with taints [{taint_list}]" - res = ursadb.request(f"index {ursa_names} with [{type}]{taints_mod};") + if isinstance(type, str): + types = f"[{type}]" + else: + types = "[" + ", ".join(type) + "]" + + query = f"index {ursa_names} with {types}{taints_mod};" + res = ursadb.request(query) if ("error" in res) != expect_error: - print(f"index {ursa_names} with [{type}]{taints_mod};") + print(query) print(res) assert False @@ -185,6 +191,10 @@ def check_query(ursadb: UrsadbTestContext, query: str, expected: List[str]): response = ursadb.check_request(f"select {query};") assert response["type"] == "select" assert response["result"]["mode"] == "raw" + if len(response["result"]["files"]) != len(expected): + print("length mismatch") + print(response["result"]["files"]) + print(expected) assert len(response["result"]["files"]) == len(expected) for fpath in response["result"]["files"]: @@ -206,6 +216,7 @@ def check_query(ursadb: UrsadbTestContext, query: str, expected: List[str]): for fpath in files["result"]["files"]: assert any(fpath.endswith(f"/{fname}") for fname in expected) + def get_index_hash(ursadb: UrsadbTestContext, type: str) -> str: """ Tries to find sha256 hash of the provided index """ indexes = list(ursadb.ursadb_dir.glob(f"{type}*"))