From 6b158807232913197784a31d1745c59e016b804c Mon Sep 17 00:00:00 2001
From: Ji Bin
Date: Sat, 13 Jan 2024 23:04:32 +0800
Subject: [PATCH] Support range search, fix #245

---
Review notes (text between "---" and the first "diff --git" is not applied):
 * Angle-bracket text (author address, #include targets, template arguments)
   is absent from this copy of the patch and is left as found -- TODO restore
   it from the upstream commit before applying.
 * Blank context lines were re-inserted so every hunk matches its @@ line
   counts; their exact positions should be confirmed against upstream.

 src/impl/MilvusClientImpl.cpp              |  9 ++-
 src/impl/types/SearchArguments.cpp         | 62 +++++++++++++++++--
 src/include/milvus/types/SearchArguments.h | 46 +++++++++++++-
 test/st/TestSearch.cpp                     | 70 ++++++++++++++++++++++
 test/ut/TestSearchArguments.cpp            | 21 +++++++
 5 files changed, 201 insertions(+), 7 deletions(-)

diff --git a/src/impl/MilvusClientImpl.cpp b/src/impl/MilvusClientImpl.cpp
index e4b7c1a..712a88d 100644
--- a/src/impl/MilvusClientImpl.cpp
+++ b/src/impl/MilvusClientImpl.cpp
@@ -22,7 +22,6 @@

 #include "TypeUtils.h"
 #include "common.pb.h"
-#include "milvus.grpc.pb.h"
 #include "milvus.pb.h"
 #include "schema.pb.h"

@@ -741,7 +740,13 @@ MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& result

     kv_pair = rpc_request.add_search_params();
     kv_pair->set_key(milvus::KeyParams());
-    kv_pair->set_value(arguments.ExtraParams());
+    // start from the user's extra params; add the range bounds only if SetRange() enabled them
+    auto json = nlohmann::json::parse(arguments.ExtraParams());
+    if (arguments.RangeSearch()) {
+        json["range_filter"] = arguments.RangeFilter();
+        json["radius"] = arguments.Radius();
+    }
+    kv_pair->set_value(json.dump());

     rpc_request.set_travel_timestamp(arguments.TravelTimestamp());
     rpc_request.set_guarantee_timestamp(arguments.GuaranteeTimestamp());
diff --git a/src/impl/types/SearchArguments.cpp b/src/impl/types/SearchArguments.cpp
index 057d8bb..31474c6 100644
--- a/src/impl/types/SearchArguments.cpp
+++ b/src/impl/types/SearchArguments.cpp
@@ -17,6 +17,7 @@
 #include "milvus/types/SearchArguments.h"

 #include
+#include

 namespace milvus {
 namespace {
@@ -28,7 +29,7 @@ struct Validation {
     bool required;

     Status
-    Validate(const SearchArguments& data, std::unordered_map params) const {
+    Validate(const SearchArguments&, const std::unordered_map& params) const {
         auto it = params.find(param);
         if (it != params.end()) {
             auto value = it->second;
@@ -43,7 +44,7 @@ struct Validation {
 };

 Status
-validate(const SearchArguments& data, std::unordered_map params) {
+validate(const SearchArguments& data, const std::unordered_map& params) {
     auto status = Status::OK();
     auto validations = {
         Validation{"nprobe", 1, 65536, false},
@@ -128,7 +129,7 @@ SearchArguments::TargetVectors() const {

 Status
 SearchArguments::AddTargetVector(std::string field_name, const std::string& vector) {
-    return AddTargetVector(field_name, std::string{vector});
+    return AddTargetVector(std::move(field_name), std::string{vector});
 }

 Status
@@ -223,6 +224,20 @@ SearchArguments::TopK() const {
     return topk_;
 }

+int64_t
+SearchArguments::Nprobe() const {
+    // look the key up once; report 1 when "nprobe" was never set
+    const auto it = extra_params_.find("nprobe");
+    const bool found = it != extra_params_.end();
+    return found ? it->second : 1;
+}
+
+Status
+SearchArguments::SetNprobe(int64_t nprobe) {
+    extra_params_["nprobe"] = nprobe;
+    return Status::OK();
+}
+
 Status
 SearchArguments::SetRoundDecimal(int round_decimal) {
     round_decimal_ = round_decimal;
@@ -236,6 +251,12 @@ SearchArguments::RoundDecimal() const {

 Status
 SearchArguments::SetMetricType(::milvus::MetricType metric_type) {
+    if (((metric_type == MetricType::IP && metric_type_ == MetricType::L2) ||
+         (metric_type == MetricType::L2 && metric_type_ == MetricType::IP)) &&
+        range_search_) {
+        // SetRange() stores the bounds in opposite order for IP and L2, so swap them on a switch
+        std::swap(radius_, range_filter_);
+    }
     metric_type_ = metric_type;
     return Status::OK();
 }
@@ -251,7 +272,7 @@ SearchArguments::AddExtraParam(std::string key, int64_t value) {
     return Status::OK();
 }

-const std::string
+std::string
 SearchArguments::ExtraParams() const {
     return ::nlohmann::json(extra_params_).dump();
 }
@@ -261,4 +282,37 @@ SearchArguments::Validate() const {
     return validate(*this, extra_params_);
 }

+float
+SearchArguments::Radius() const {
+    return radius_;
+}
+
+float
+SearchArguments::RangeFilter() const {
+    return range_filter_;
+}
+
+Status
+SearchArguments::SetRange(float from, float to) {
+    auto low = std::min(from, to);
+    auto high = std::max(from, to);
+    if (metric_type_ == MetricType::IP) {
+        radius_ = low;
+        range_filter_ = high;
+        range_search_ = true;
+    } else if (metric_type_ == MetricType::L2) {
+        radius_ = high;
+        range_filter_ = low;
+        range_search_ = true;
+    } else {
+        return {StatusCode::INVALID_AGUMENT, "Metric type is not supported"};
+    }
+    return Status::OK();
+}
+
+bool
+SearchArguments::RangeSearch() const {
+    return range_search_;
+}
+
 }  // namespace milvus
diff --git a/src/include/milvus/types/SearchArguments.h b/src/include/milvus/types/SearchArguments.h
index 22ae34d..eeede41 100644
--- a/src/include/milvus/types/SearchArguments.h
+++ b/src/include/milvus/types/SearchArguments.h
@@ -164,6 +164,18 @@ class SearchArguments {
     int64_t
     TopK() const;

+    /**
+     * @brief Get the "nprobe" extra param, or 1 if it has not been set
+     */
+    int64_t
+    Nprobe() const;
+
+    /**
+     * @brief Set the "nprobe" extra param
+     */
+    Status
+    SetNprobe(int64_t nprobe);
+
     /**
      * @brief Specifies the decimal place of the returned results.
      */
@@ -197,7 +209,7 @@ class SearchArguments {
     /**
      * @brief Get extra param
      */
-    const std::string
+    std::string
     ExtraParams() const;

     /**
@@ -207,6 +219,35 @@ class SearchArguments {
     Status
     Validate() const;

+    /**
+     * @brief Get range radius
+     * @return value sent as the "radius" search param; 0 until SetRange() succeeds
+     */
+    float
+    Radius() const;
+
+    /**
+     * @brief Get range filter
+     * @return value sent as the "range_filter" search param; 0 until SetRange() succeeds
+     */
+    float
+    RangeFilter() const;
+
+    /**
+     * @brief Enable range search between two bounds; fails unless the metric type is IP or L2
+     * @param from one bound of the range (from/to may be given in either order)
+     * @param to the other bound of the range
+     */
+    Status
+    SetRange(float from, float to);
+
+    /**
+     * @brief Whether range search has been enabled
+     * @return true once SetRange() has succeeded, false otherwise
+     */
+    bool
+    RangeSearch() const;
+
  private:
     std::string collection_name_;
     std::set partition_names_;
@@ -225,6 +266,9 @@ class SearchArguments {

     int64_t topk_{1};
     int round_decimal_{-1};
+    float radius_{0.0f};
+    float range_filter_{0.0f};
+    bool range_search_{false};

     ::milvus::MetricType metric_type_{::milvus::MetricType::L2};
 };
diff --git a/test/st/TestSearch.cpp b/test/st/TestSearch.cpp
index f8f46a8..c41406a 100644
--- a/test/st/TestSearch.cpp
+++ b/test/st/TestSearch.cpp
@@ -139,6 +139,76 @@ TEST_F(MilvusServerTestSearch, SearchWithoutIndex) {
     dropCollection();
 }

+TEST_F(MilvusServerTestSearch, RangeSearch) {
+    std::vector fields{
+        std::make_shared("age", std::vector{12, 13, 14, 15, 16, 17, 18}),
+        std::make_shared(
+            "name", std::vector{"Tom", "Jerry", "Lily", "Foo", "Bar", "Jake", "Jonathon"}),
+        std::make_shared("face", std::vector>{
+                                     std::vector{0.1f, 0.2f, 0.3f, 0.4f},
+                                     std::vector{0.2f, 0.3f, 0.4f, 0.5f},
+                                     std::vector{0.3f, 0.4f, 0.5f, 0.6f},
+                                     std::vector{0.4f, 0.5f, 0.6f, 0.7f},
+                                     std::vector{0.5f, 0.6f, 0.7f, 0.8f},
+                                     std::vector{0.6f, 0.7f, 0.8f, 0.9f},
+                                     std::vector{0.7f, 0.8f, 0.9f, 1.0f},
+                                 })};
+
+    createCollectionAndPartitions(true);
+    auto dml_results = insertRecords(fields);
+    loadCollection();
+
+    milvus::SearchArguments arguments{};
+    arguments.SetCollectionName(collection_name);
+    arguments.AddPartitionName(partition_name);
+    EXPECT_TRUE(arguments.SetRange(0.3f, 1.0f).IsOk());
+    arguments.SetTopK(10);
+    arguments.AddOutputField("age");
+    arguments.AddOutputField("name");
+    arguments.AddTargetVector("face", std::vector{0.f, 0.f, 0.f, 0.f});
+    arguments.AddTargetVector("face", std::vector{1.f, 1.f, 1.f, 1.f});
+    milvus::SearchResults search_results{};
+    auto status = client_->Search(arguments, search_results);
+    EXPECT_EQ(status.Message(), "OK");
+    EXPECT_TRUE(status.IsOk());
+
+    const auto& results = search_results.Results();
+    EXPECT_EQ(results.size(), 2);
+
+    // validate results
+    auto validateScores = [&results](int firstRet, int secondRet) {
+        // every returned score must lie inside the requested range
+        for (const auto& result : results) {
+            for (const auto& score : result.Scores()) {
+                EXPECT_GE(score, 0.3);
+                EXPECT_LE(score, 1.0);
+            }
+        }
+        EXPECT_EQ(results.at(0).Ids().IntIDArray().size(), firstRet);
+        EXPECT_EQ(results.at(1).Ids().IntIDArray().size(), secondRet);
+    };
+
+    // expect 3 hits for the first target vector and 2 for the second
+    validateScores(3, 2);
+
+    // insert the same rows again and search once more; expect 6 and 4 hits
+    insertRecords(fields);
+    loadCollection();
+    status = client_->Search(arguments, search_results);
+    EXPECT_TRUE(status.IsOk());
+    validateScores(6, 4);
+
+    // insert twice more: 12 and 8 rows would match, but topk is 10, so expect 10 and 8
+    insertRecords(fields);
+    insertRecords(fields);
+    loadCollection();
+    status = client_->Search(arguments, search_results);
+    EXPECT_TRUE(status.IsOk());
+    validateScores(10, 8);
+
+    dropCollection();
+}
+
 TEST_F(MilvusServerTestSearch, SearchWithStringFilter) {
     std::vector fields{
         std::make_shared("age", std::vector{12, 13}),
diff --git a/test/ut/TestSearchArguments.cpp b/test/ut/TestSearchArguments.cpp
index 922a049..780b89b 100644
--- a/test/ut/TestSearchArguments.cpp
+++ b/test/ut/TestSearchArguments.cpp
@@ -170,3 +170,24 @@ TEST_F(SearchArgumentsTest, ValidateTesting) {
         EXPECT_TRUE(status.IsOk());
     }
 }
+
+TEST_F(SearchArgumentsTest, Nprobe) {
+    milvus::SearchArguments arguments;
+    arguments.AddExtraParam("nprobe", 10);
+    EXPECT_EQ(10, arguments.Nprobe());
+
+    arguments.SetNprobe(20);
+    EXPECT_EQ(20, arguments.Nprobe());
+}
+
+TEST_F(SearchArgumentsTest, RangeSearchParams) {
+    milvus::SearchArguments arguments;
+    arguments.SetMetricType(milvus::MetricType::IP);
+    EXPECT_TRUE(arguments.SetRange(0.1f, 0.2f).IsOk());
+    EXPECT_NEAR(0.1, arguments.Radius(), 0.00001);
+    EXPECT_NEAR(0.2, arguments.RangeFilter(), 0.00001);
+
+    arguments.SetMetricType(milvus::MetricType::L2);
+    EXPECT_NEAR(0.2, arguments.Radius(), 0.00001);
+    EXPECT_NEAR(0.1, arguments.RangeFilter(), 0.00001);
+}