Skip to content

Commit

Permalink
Support warning messages in search/aggregation query results (#3958)
Browse files Browse the repository at this point in the history
* Support warning messages in search/aggregation query results

* private constructor

* Enable respective Test

* add test rule
  • Loading branch information
sazzad16 authored Jan 30, 2025
1 parent a46700a commit 80ba2d5
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 33 deletions.
26 changes: 22 additions & 4 deletions src/main/java/redis/clients/jedis/search/SearchResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ public class SearchResult {

private final long totalResults;
private final List<Document> documents;
private final List<String> warnings;

private SearchResult(long totalResults, List<Document> documents) {
this(totalResults, documents, (List<String>) null);
}

private SearchResult(long totalResults, List<Document> documents, List<String> warnings) {
this.totalResults = totalResults;
this.documents = documents;
this.warnings = warnings;
}

public long getTotalResults() {
Expand All @@ -35,10 +41,16 @@ public List<Document> getDocuments() {
return Collections.unmodifiableList(documents);
}

public List<String> getWarnings() {
return warnings;
}

@Override
public String toString() {
return getClass().getSimpleName() + "{Total results:" + totalResults
+ ", Documents:" + documents + "}";
+ ", Documents:" + documents
+ (warnings != null ? ", Warnings:" + warnings : "")
+ "}";
}

public static class SearchResultBuilder extends Builder<SearchResult> {
Expand Down Expand Up @@ -104,6 +116,7 @@ public static final class PerFieldDecoderSearchResultBuilder extends Builder<Sea

private static final String TOTAL_RESULTS_STR = "total_results";
private static final String RESULTS_STR = "results";
private static final String WARNINGS_STR = "warning";

private final Builder<Document> documentBuilder;

Expand All @@ -120,20 +133,25 @@ public SearchResult build(Object data) {
List<KeyValue> list = (List<KeyValue>) data;
long totalResults = -1;
List<Document> results = null;
List<String> warnings = null;
for (KeyValue kv : list) {
String key = BuilderFactory.STRING.build(kv.getKey());
Object rawVal = kv.getValue();
switch (key) {
case TOTAL_RESULTS_STR:
totalResults = BuilderFactory.LONG.build(kv.getValue());
totalResults = BuilderFactory.LONG.build(rawVal);
break;
case RESULTS_STR:
results = ((List<Object>) kv.getValue()).stream()
results = ((List<Object>) rawVal).stream()
.map(documentBuilder::build)
.collect(Collectors.toList());
break;
case WARNINGS_STR:
warnings = BuilderFactory.STRING_LIST.build(rawVal);
break;
}
}
return new SearchResult(totalResults, results);
return new SearchResult(totalResults, results, warnings);
}
};
/// <-- RESP3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,18 @@ public class AggregationResult {

private final List<Map<String, Object>> results;

private Long cursorId = -1L;

private AggregationResult(Object resp, long cursorId) {
this(resp);
this.cursorId = cursorId;
}
private final List<String> warnings;

private AggregationResult(Object resp) {
List<Object> list = (List<Object>) SafeEncoder.encodeObject(resp);

// the first element is always the number of results
totalResults = (Long) list.get(0);
results = new ArrayList<>(list.size() - 1);
private Long cursorId = -1L;

for (int i = 1; i < list.size(); i++) {
List<Object> mapList = (List<Object>) list.get(i);
Map<String, Object> map = new HashMap<>(mapList.size() / 2, 1f);
for (int j = 0; j < mapList.size(); j += 2) {
Object r = mapList.get(j);
if (r instanceof JedisDataException) {
throw (JedisDataException) r;
}
map.put((String) r, mapList.get(j + 1));
}
results.add(map);
}
private AggregationResult(long totalResults, List<Map<String, Object>> results) {
this(totalResults, results, (List<String>) null);
}

private AggregationResult(long totalResults, List<Map<String, Object>> results) {
private AggregationResult(long totalResults, List<Map<String, Object>> results, List<String> warnings) {
this.totalResults = totalResults;
this.results = results;
this.warnings = warnings;
}

private void setCursorId(Long cursorId) {
Expand Down Expand Up @@ -80,12 +61,17 @@ public Row getRow(int index) {
return new Row(results.get(index));
}

public List<String> getWarnings() {
return warnings;
}

public static final Builder<AggregationResult> SEARCH_AGGREGATION_RESULT = new Builder<AggregationResult>() {

private static final String TOTAL_RESULTS_STR = "total_results";
private static final String RESULTS_STR = "results";
// private static final String FIELDS_STR = "fields";
private static final String FIELDS_STR = "extra_attributes";
private static final String WARNINGS_STR = "warning";

@Override
public AggregationResult build(Object data) {
Expand All @@ -96,14 +82,16 @@ public AggregationResult build(Object data) {
List<KeyValue> kvList = (List<KeyValue>) data;
long totalResults = -1;
List<Map<String, Object>> results = null;
List<String> warnings = null;
for (KeyValue kv : kvList) {
String key = BuilderFactory.STRING.build(kv.getKey());
Object rawVal = kv.getValue();
switch (key) {
case TOTAL_RESULTS_STR:
totalResults = BuilderFactory.LONG.build(kv.getValue());
totalResults = BuilderFactory.LONG.build(rawVal);
break;
case RESULTS_STR:
List<List<KeyValue>> resList = (List<List<KeyValue>>) kv.getValue();
List<List<KeyValue>> resList = (List<List<KeyValue>>) rawVal;
results = new ArrayList<>(resList.size());
for (List<KeyValue> rikv : resList) {
for (KeyValue ikv : rikv) {
Expand All @@ -114,9 +102,12 @@ public AggregationResult build(Object data) {
}
}
break;
case WARNINGS_STR:
warnings = BuilderFactory.STRING_LIST.build(rawVal);
break;
}
}
return new AggregationResult(totalResults, results);
return new AggregationResult(totalResults, results, warnings);
}

list = (List<Object>) SafeEncoder.encodeObject(data);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public void setSearchConfigGloballyTest() {

@Test
public void setReadOnlySearchConfigTest() {
JedisDataException de = assertThrows(JedisDataException.class, () -> jedis.configSet("search-max-doctablesize", "10"));
JedisDataException de = assertThrows(JedisDataException.class,
() -> jedis.configSet("search-max-doctablesize", "10"));
assertThat(de.getMessage(), Matchers.not(Matchers.emptyOrNullString()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static redis.clients.jedis.util.AssertUtil.assertEqualsByProtocol;
import static redis.clients.jedis.util.AssertUtil.assertOK;

import java.util.*;

import io.redis.test.annotations.SinceRedisVersion;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -21,6 +23,7 @@
import redis.clients.jedis.exceptions.JedisDataException;
import redis.clients.jedis.search.*;
import redis.clients.jedis.search.schemafields.NumericField;
import redis.clients.jedis.search.schemafields.TagField;
import redis.clients.jedis.search.schemafields.TextField;
import redis.clients.jedis.modules.RedisModuleCommandsTestBase;
import redis.clients.jedis.search.aggr.AggregationBuilder;
Expand Down Expand Up @@ -61,6 +64,14 @@ private void addDocument(Document doc) {
client.hset(key, map);
}

private static Map<String, String> toMap(String... values) {
Map<String, String> map = new HashMap<>();
for (int i = 0; i < values.length; i += 2) {
map.put(values[i], values[i + 1]);
}
return map;
}

@Test
public void testQueryParams() {
Schema sc = new Schema().addNumericField("numval");
Expand Down Expand Up @@ -181,4 +192,33 @@ private void assertSyntaxError(Query query, UnifiedJedis client) {
() -> client.ftExplain(INDEX, query));
assertThat(error.getMessage(), containsString("Syntax error"));
}

@Test
@SinceRedisVersion(value = "7.9.0")
public void warningMaxPrefixExpansions() {
final String configParam = "search-max-prefix-expansions";
String defaultConfigValue = jedis.configGet(configParam).get(configParam);
try {
assertOK(client.ftCreate(INDEX, FTCreateParams.createParams().on(IndexDataType.HASH),
TextField.of("t"), TagField.of("t2")));

client.hset("doc13", toMap("t", "foo", "t2", "foo"));

jedis.configSet(configParam, "1");

SearchResult srcResult = client.ftSearch(INDEX, "fo*");
assertEqualsByProtocol(protocol, null, Arrays.asList(), srcResult.getWarnings());

client.hset("doc23", toMap("t", "fooo", "t2", "fooo"));

AggregationResult aggResult = client.ftAggregate(INDEX, new AggregationBuilder("fo*").loadAll());
assertEqualsByProtocol(protocol,
/* resp2 */ null,
Arrays.asList("Max prefix expansions limit was reached"),
aggResult.getWarnings());
} finally {
jedis.configSet(configParam, defaultConfigValue);
}
}

}

0 comments on commit 80ba2d5

Please sign in to comment.