Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop processing search requests when _msearch is canceled #17005

Merged
merged 3 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix Shallow copy snapshot failures on closed index ([#16868](https://github.com/opensearch-project/OpenSearch/pull/16868))
- Fix multi-value sort for unsigned long ([#16732](https://github.com/opensearch-project/OpenSearch/pull/16732))
- The `phone-search` analyzer no longer emits the tel/sip prefix, international calling code, extension numbers and unformatted input as a token ([#16993](https://github.com/opensearch-project/OpenSearch/pull/16993))
- Stop processing search requests when _msearch request is cancelled ([#17005](https://github.com/opensearch-project/OpenSearch/pull/17005))
- Fix GRPC AUX_TRANSPORT_PORT and SETTING_GRPC_PORT settings and remove lingering HTTP terminology ([#17037](https://github.com/opensearch-project/OpenSearch/pull/17037))
- Fix exists queries on nested flat_object fields throws exception ([#16803](https://github.com/opensearch-project/OpenSearch/pull/16803))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.core.tasks.TaskId;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -193,6 +196,19 @@
if (responseCounter.decrementAndGet() == 0) {
assert requests.isEmpty();
finish();
} else if (isCancelled(request.request.getParentTask())) {
// Drain the rest of the queue
SearchRequestSlot request;
while ((request = requests.poll()) != null) {
responses.set(
request.responseSlot,
new MultiSearchResponse.Item(null, new TaskCancelledException("Parent task was cancelled"))
);
if (responseCounter.decrementAndGet() == 0) {
assert requests.isEmpty();
finish();
}
}
} else {
if (thread == Thread.currentThread()) {
// we are on the same thread, we need to fork to another thread to avoid recursive stack overflow on a single thread
Expand Down Expand Up @@ -220,6 +236,14 @@
});
}

private boolean isCancelled(TaskId taskId) {
if (taskId.isSet()) {
CancellableTask task = taskManager.getCancellableTask(taskId.getId());
return task != null && task.isCancelled();
}
return false;

Check warning on line 244 in server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java#L244

Added line #L244 was not covered by tests
}

/**
* Slots a search request
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskListener;
import org.opensearch.tasks.TaskManager;
import org.opensearch.telemetry.tracing.noop.NoopTracer;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -62,7 +64,9 @@
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -289,4 +293,118 @@ public void testDefaultMaxConcurrentSearches() {
assertThat(result, equalTo(1));
}

public void testCancellation() {
// Initialize dependencies of TransportMultiSearchAction
Settings settings = Settings.builder().put("node.name", TransportMultiSearchActionTests.class.getSimpleName()).build();
ActionFilters actionFilters = mock(ActionFilters.class);
when(actionFilters.filters()).thenReturn(new ActionFilter[0]);
ThreadPool threadPool = new ThreadPool(settings);
TransportService transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
boundAddress -> DiscoveryNode.createLocal(settings, boundAddress.publishAddress(), UUIDs.randomBase64UUID()),
null,
Collections.emptySet(),
NoopTracer.INSTANCE
) {
@Override
public TaskManager getTaskManager() {
return taskManager;
}
};
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test")).build());

// Keep track of the number of concurrent searches started by multi search api,
// and if there are more searches than is allowed create an error and remember that.
int maxAllowedConcurrentSearches = 1; // Allow 1 search at a time.
AtomicInteger counter = new AtomicInteger();
AtomicReference<AssertionError> errorHolder = new AtomicReference<>();
// randomize whether or not requests are executed asynchronously
ExecutorService executorService = threadPool.executor(ThreadPool.Names.GENERIC);
final Set<SearchRequest> requests = Collections.newSetFromMap(Collections.synchronizedMap(new IdentityHashMap<>()));
CountDownLatch countDownLatch = new CountDownLatch(1);
CancellableTask[] parentTask = new CancellableTask[1];
NodeClient client = new NodeClient(settings, threadPool) {
@Override
public void search(final SearchRequest request, final ActionListener<SearchResponse> listener) {
if (parentTask[0] != null && parentTask[0].isCancelled()) {
fail("Should not execute search after parent task is cancelled");
}
try {
countDownLatch.await(10, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}

requests.add(request);
executorService.execute(() -> {
counter.decrementAndGet();
listener.onResponse(
new SearchResponse(
InternalSearchResponse.empty(),
null,
0,
0,
0,
0L,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
)
);
});
}

@Override
public String getLocalNodeId() {
return "local_node_id";
}
};

TransportMultiSearchAction action = new TransportMultiSearchAction(
threadPool,
actionFilters,
transportService,
clusterService,
10,
System::nanoTime,
client
);

// Execute the multi search api and fail if we find an error after executing:
try {
/*
* Allow for a large number of search requests in a single batch as previous implementations could stack overflow if the number
* of requests in a single batch was large
*/
int numSearchRequests = scaledRandomIntBetween(1024, 8192);
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
multiSearchRequest.maxConcurrentSearchRequests(maxAllowedConcurrentSearches);
for (int i = 0; i < numSearchRequests; i++) {
multiSearchRequest.add(new SearchRequest());
}
MultiSearchResponse[] responses = new MultiSearchResponse[1];
Exception[] exceptions = new Exception[1];
parentTask[0] = (CancellableTask) action.execute(multiSearchRequest, new TaskListener<>() {
@Override
public void onResponse(Task task, MultiSearchResponse items) {
responses[0] = items;
}

@Override
public void onFailure(Task task, Exception e) {
exceptions[0] = e;
}
});
parentTask[0].cancel("Giving up");
countDownLatch.countDown();

assertNull(responses[0]);
assertNull(exceptions[0]);
} finally {
assertTrue(OpenSearchTestCase.terminate(threadPool));
}
}
}
Loading