diff --git a/src/main/java/redis/clients/jedis/search/SearchResult.java b/src/main/java/redis/clients/jedis/search/SearchResult.java index 55afbe0b24..b51e791927 100644 --- a/src/main/java/redis/clients/jedis/search/SearchResult.java +++ b/src/main/java/redis/clients/jedis/search/SearchResult.java @@ -21,10 +21,16 @@ public class SearchResult { private final long totalResults; private final List documents; + private final List warnings; private SearchResult(long totalResults, List documents) { + this(totalResults, documents, (List) null); + } + + private SearchResult(long totalResults, List documents, List warnings) { this.totalResults = totalResults; this.documents = documents; + this.warnings = warnings; } public long getTotalResults() { @@ -35,10 +41,16 @@ public List getDocuments() { return Collections.unmodifiableList(documents); } + public List getWarnings() { + return warnings; + } + @Override public String toString() { return getClass().getSimpleName() + "{Total results:" + totalResults - + ", Documents:" + documents + "}"; + + ", Documents:" + documents + + (warnings != null ? ", Warnings:" + warnings : "") + + "}"; } public static class SearchResultBuilder extends Builder { @@ -104,6 +116,7 @@ public static final class PerFieldDecoderSearchResultBuilder extends Builder documentBuilder; @@ -120,20 +133,25 @@ public SearchResult build(Object data) { List list = (List) data; long totalResults = -1; List results = null; + List warnings = null; for (KeyValue kv : list) { String key = BuilderFactory.STRING.build(kv.getKey()); + Object rawVal = kv.getValue(); switch (key) { case TOTAL_RESULTS_STR: - totalResults = BuilderFactory.LONG.build(kv.getValue()); + totalResults = BuilderFactory.LONG.build(rawVal); break; case RESULTS_STR: - results = ((List) kv.getValue()).stream() + results = ((List) rawVal).stream() .map(documentBuilder::build) .collect(Collectors.toList()); break; + case WARNINGS_STR: + warnings = BuilderFactory.STRING_LIST.build(rawVal); + break; } } - return new SearchResult(totalResults, results); + return new SearchResult(totalResults, results, warnings); } }; /// <-- RESP3 diff --git a/src/main/java/redis/clients/jedis/search/aggr/AggregationResult.java b/src/main/java/redis/clients/jedis/search/aggr/AggregationResult.java index cec65f9cd9..3eba4ac1d1 100644 --- a/src/main/java/redis/clients/jedis/search/aggr/AggregationResult.java +++ b/src/main/java/redis/clients/jedis/search/aggr/AggregationResult.java @@ -19,37 +19,18 @@ public class AggregationResult { private final List> results; - private Long cursorId = -1L; - - private AggregationResult(Object resp, long cursorId) { - this(resp); - this.cursorId = cursorId; - } + private final List warnings; - private AggregationResult(Object resp) { - List list = (List) SafeEncoder.encodeObject(resp); - - // the first element is always the number of results - totalResults = (Long) list.get(0); - results = new ArrayList<>(list.size() - 1); + private Long cursorId = -1L; - for (int i = 1; i < list.size(); i++) { - List mapList = (List) list.get(i); - Map map = new HashMap<>(mapList.size() / 2, 1f); - for (int j = 0; j < mapList.size(); j += 2) { - Object r = mapList.get(j); - if (r instanceof JedisDataException) { - throw (JedisDataException) r; - } - map.put((String) r, mapList.get(j + 1)); - } - results.add(map); - } + private AggregationResult(long totalResults, List> results) { + this(totalResults, results, (List) null); } - private AggregationResult(long totalResults, List> results) { + private AggregationResult(long totalResults, List> results, List warnings) { this.totalResults = totalResults; this.results = results; + this.warnings = warnings; } private void setCursorId(Long cursorId) { @@ -80,12 +61,17 @@ public Row getRow(int index) { return new Row(results.get(index)); } + public List getWarnings() { + return warnings; + } + public static final Builder SEARCH_AGGREGATION_RESULT = new Builder() { private static final String TOTAL_RESULTS_STR = "total_results"; private static final String RESULTS_STR = "results"; // private static final String FIELDS_STR = "fields"; private static final String FIELDS_STR = "extra_attributes"; + private static final String WARNINGS_STR = "warning"; @Override public AggregationResult build(Object data) { @@ -96,14 +82,16 @@ public AggregationResult build(Object data) { List kvList = (List) data; long totalResults = -1; List> results = null; + List warnings = null; for (KeyValue kv : kvList) { String key = BuilderFactory.STRING.build(kv.getKey()); + Object rawVal = kv.getValue(); switch (key) { case TOTAL_RESULTS_STR: - totalResults = BuilderFactory.LONG.build(kv.getValue()); + totalResults = BuilderFactory.LONG.build(rawVal); break; case RESULTS_STR: - List> resList = (List>) kv.getValue(); + List> resList = (List>) rawVal; results = new ArrayList<>(resList.size()); for (List rikv : resList) { for (KeyValue ikv : rikv) { @@ -114,9 +102,12 @@ public AggregationResult build(Object data) { } } break; + case WARNINGS_STR: + warnings = BuilderFactory.STRING_LIST.build(rawVal); + break; } } - return new AggregationResult(totalResults, results); + return new AggregationResult(totalResults, results, warnings); } list = (List) SafeEncoder.encodeObject(data); diff --git a/src/test/java/redis/clients/jedis/modules/ConsolidatedConfigurationCommandsTest.java b/src/test/java/redis/clients/jedis/modules/ConsolidatedConfigurationCommandsTest.java index 741e719f26..c376554bbc 100644 --- a/src/test/java/redis/clients/jedis/modules/ConsolidatedConfigurationCommandsTest.java +++ b/src/test/java/redis/clients/jedis/modules/ConsolidatedConfigurationCommandsTest.java @@ -40,7 +40,8 @@ public void setSearchConfigGloballyTest() { @Test public void setReadOnlySearchConfigTest() { - JedisDataException de = assertThrows(JedisDataException.class, () -> jedis.configSet("search-max-doctablesize", "10")); + JedisDataException de = assertThrows(JedisDataException.class, + () -> jedis.configSet("search-max-doctablesize", "10")); assertThat(de.getMessage(), Matchers.not(Matchers.emptyOrNullString())); } diff --git a/src/test/java/redis/clients/jedis/modules/search/SearchDefaultDialectTest.java b/src/test/java/redis/clients/jedis/modules/search/SearchDefaultDialectTest.java index 8192f29a1a..819880877f 100644 --- a/src/test/java/redis/clients/jedis/modules/search/SearchDefaultDialectTest.java +++ b/src/test/java/redis/clients/jedis/modules/search/SearchDefaultDialectTest.java @@ -7,10 +7,12 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; +import static redis.clients.jedis.util.AssertUtil.assertEqualsByProtocol; import static redis.clients.jedis.util.AssertUtil.assertOK; import java.util.*; +import io.redis.test.annotations.SinceRedisVersion; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; @@ -21,6 +23,7 @@ import redis.clients.jedis.exceptions.JedisDataException; import redis.clients.jedis.search.*; import redis.clients.jedis.search.schemafields.NumericField; +import redis.clients.jedis.search.schemafields.TagField; import redis.clients.jedis.search.schemafields.TextField; import redis.clients.jedis.modules.RedisModuleCommandsTestBase; import redis.clients.jedis.search.aggr.AggregationBuilder; @@ -61,6 +64,14 @@ private void addDocument(Document doc) { client.hset(key, map); } + private static Map toMap(String... values) { + Map map = new HashMap<>(); + for (int i = 0; i < values.length; i += 2) { + map.put(values[i], values[i + 1]); + } + return map; + } + @Test public void testQueryParams() { Schema sc = new Schema().addNumericField("numval"); @@ -181,4 +192,33 @@ private void assertSyntaxError(Query query, UnifiedJedis client) { () -> client.ftExplain(INDEX, query)); assertThat(error.getMessage(), containsString("Syntax error")); } + + @Test + @SinceRedisVersion(value = "7.9.0") + public void warningMaxPrefixExpansions() { + final String configParam = "search-max-prefix-expansions"; + String defaultConfigValue = jedis.configGet(configParam).get(configParam); + try { + assertOK(client.ftCreate(INDEX, FTCreateParams.createParams().on(IndexDataType.HASH), + TextField.of("t"), TagField.of("t2"))); + + client.hset("doc13", toMap("t", "foo", "t2", "foo")); + + jedis.configSet(configParam, "1"); + + SearchResult srcResult = client.ftSearch(INDEX, "fo*"); + assertEqualsByProtocol(protocol, null, Arrays.asList(), srcResult.getWarnings()); + + client.hset("doc23", toMap("t", "fooo", "t2", "fooo")); + + AggregationResult aggResult = client.ftAggregate(INDEX, new AggregationBuilder("fo*").loadAll()); + assertEqualsByProtocol(protocol, + /* resp2 */ null, + Arrays.asList("Max prefix expansions limit was reached"), + aggResult.getWarnings()); + } finally { + jedis.configSet(configParam, defaultConfigValue); + } + } + }