diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDataFrameAnalyticsAction.java index 003bef914f72c..cd2ee3e01c6d4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDataFrameAnalyticsAction.java @@ -77,10 +77,10 @@ public TransportPreviewDataFrameAnalyticsAction( this.clusterService = clusterService; } - private static Map mergeRow(DataFrameDataExtractor.Row row, List fieldNames) { - return row.getValues() == null + private static Map mergeRow(String[] row, List fieldNames) { + return row == null ? Collections.emptyMap() - : IntStream.range(0, row.getValues().length).boxed().collect(Collectors.toMap(fieldNames::get, i -> row.getValues()[i])); + : IntStream.range(0, row.length).boxed().collect(Collectors.toMap(fieldNames::get, i -> row[i])); } @Override @@ -121,7 +121,7 @@ void preview(Task task, DataFrameAnalyticsConfig config, ActionListener { List fieldNames = extractor.getFieldNames(); - l.onResponse(new Response(rows.stream().map((r) -> mergeRow(r, fieldNames)).collect(Collectors.toList()))); + l.onResponse(new Response(rows.stream().map(r -> mergeRow(r, fieldNames)).collect(Collectors.toList()))); })); })); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index c890ab599c380..894115d76db72 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -19,7 +19,6 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; -import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; @@ -107,14 +106,14 @@ public void cancel() { isCancelled = true; } - public Optional> next() throws IOException { + public Optional next() throws IOException { if (hasNext() == false) { throw new NoSuchElementException(); } - Optional> hits = Optional.ofNullable(nextSearch()); - if (hits.isPresent() && hits.get().isEmpty() == false) { - lastSortKey = hits.get().get(hits.get().size() - 1).getSortKey(); + Optional hits = Optional.ofNullable(nextSearch()); + if (hits.isPresent() && hits.get().length > 0) { + lastSortKey = (long) hits.get()[hits.get().length - 1].getSortValues()[0]; } else { hasNext = false; } @@ -126,7 +125,7 @@ public Optional> next() throws IOException { * Does no sorting of the results. * @param listener To alert with the extracted rows */ - public void preview(ActionListener> listener) { + public void preview(ActionListener> listener) { SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client) // This ensures the search throws if there are failures and the scroll context gets cleared automatically @@ -155,22 +154,24 @@ public void preview(ActionListener> listener) { return; } - List rows = new ArrayList<>(searchResponse.getHits().getHits().length); + List rows = new ArrayList<>(searchResponse.getHits().getHits().length); for (SearchHit hit : searchResponse.getHits().getHits()) { - var unpooled = hit.asUnpooled(); - String[] extractedValues = extractValues(unpooled); - rows.add(extractedValues == null ? new Row(null, unpooled, true) : new Row(extractedValues, unpooled, false)); + String[] extractedValues = extractValues(hit); + rows.add(extractedValues); } delegate.onResponse(rows); }) ); } - protected List nextSearch() throws IOException { + protected SearchHit[] nextSearch() throws IOException { + if (isCancelled) { + return null; + } return tryRequestWithSearchResponse(() -> executeSearchRequest(buildSearchRequest())); } - private List tryRequestWithSearchResponse(Supplier request) throws IOException { + private SearchHit[] tryRequestWithSearchResponse(Supplier request) throws IOException { try { // We've set allow_partial_search_results to false which means if something @@ -179,7 +180,7 @@ private List tryRequestWithSearchResponse(Supplier request) try { LOGGER.trace(() -> "[" + context.jobId + "] Search response was obtained"); - List rows = processSearchResponse(searchResponse); + SearchHit[] rows = processSearchResponse(searchResponse); // Request was successfully executed and processed so we can restore the flag to retry if a future failure occurs hasPreviousSearchFailed = false; @@ -246,22 +247,12 @@ private void setFetchSource(SearchRequestBuilder searchRequestBuilder) { } } - private List processSearchResponse(SearchResponse searchResponse) { - if (searchResponse.getHits().getHits().length == 0) { + private SearchHit[] processSearchResponse(SearchResponse searchResponse) { + if (isCancelled || searchResponse.getHits().getHits().length == 0) { hasNext = false; return null; } - - SearchHits hits = searchResponse.getHits(); - List rows = new ArrayList<>(hits.getHits().length); - for (SearchHit hit : hits) { - if (isCancelled) { - hasNext = false; - break; - } - rows.add(createRow(hit)); - } - return rows; + return searchResponse.getHits().asUnpooled().getHits(); } private String extractNonProcessedValues(SearchHit hit, String organicFeature) { @@ -317,14 +308,13 @@ private String[] extractProcessedValue(ProcessedField processedField, SearchHit return extractedValue; } - private Row createRow(SearchHit hit) { - var unpooled = hit.asUnpooled(); - String[] extractedValues = extractValues(unpooled); + public Row createRow(SearchHit hit) { + String[] extractedValues = extractValues(hit); if (extractedValues == null) { - return new Row(null, unpooled, true); + return new Row(null, hit, true); } boolean isTraining = trainTestSplitter.get().isTraining(extractedValues); - Row row = new Row(extractedValues, unpooled, isTraining); + Row row = new Row(extractedValues, hit, isTraining); LOGGER.trace( () -> format( "[%s] Extracted row: sort key = [%s], is_training = [%s], values = %s", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index d4c10e25a2ade..6205653ce9c0f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -256,9 +257,14 @@ private static void writeDataRows( long rowsProcessed = 0; while (dataExtractor.hasNext()) { - Optional> rows = dataExtractor.next(); + Optional rows = dataExtractor.next(); if (rows.isPresent()) { - for (DataFrameDataExtractor.Row row : rows.get()) { + for (SearchHit searchHit : rows.get()) { + if (dataExtractor.isCancelled()) { + break; + } + rowsProcessed++; + DataFrameDataExtractor.Row row = dataExtractor.createRow(searchHit); if (row.shouldSkip()) { dataCountsTracker.incrementSkippedDocsCount(); } else { @@ -271,7 +277,6 @@ private static void writeDataRows( } } } - rowsProcessed += rows.get().size(); progressTracker.updateLoadingDataProgress(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows)); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java index ee91b0637bfc7..3e1968ca19ce1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; @@ -22,11 +23,9 @@ import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import java.io.IOException; -import java.util.Collections; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedList; -import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -97,6 +96,9 @@ private void addResultAndJoinIfEndOfBatch(RowResults rowResults) { private void joinCurrentResults() { try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) { while (currentResults.isEmpty() == false) { + if (dataExtractor.isCancelled()) { + break; + } RowResults result = currentResults.pop(); DataFrameDataExtractor.Row row = dataFrameRowsIterator.next(); checkChecksumsMatch(row, result); @@ -164,12 +166,12 @@ private void consumeDataExtractor() throws IOException { private class ResultMatchingDataFrameRows implements Iterator { - private List currentDataFrameRows = Collections.emptyList(); + private SearchHit[] currentDataFrameRows = SearchHits.EMPTY; private int currentDataFrameRowsIndex; @Override public boolean hasNext() { - return dataExtractor.hasNext() || currentDataFrameRowsIndex < currentDataFrameRows.size(); + return dataExtractor.hasNext() || currentDataFrameRowsIndex < currentDataFrameRows.length; } @Override @@ -177,7 +179,7 @@ public DataFrameDataExtractor.Row next() { DataFrameDataExtractor.Row row = null; while (hasNoMatch(row) && hasNext()) { advanceToNextBatchIfNecessary(); - row = currentDataFrameRows.get(currentDataFrameRowsIndex++); + row = dataExtractor.createRow(currentDataFrameRows[currentDataFrameRowsIndex++]); } if (hasNoMatch(row)) { @@ -191,13 +193,13 @@ private static boolean hasNoMatch(DataFrameDataExtractor.Row row) { } private void advanceToNextBatchIfNecessary() { - if (currentDataFrameRowsIndex >= currentDataFrameRows.size()) { - currentDataFrameRows = getNextDataRowsBatch().orElse(Collections.emptyList()); + if (currentDataFrameRowsIndex >= currentDataFrameRows.length) { + currentDataFrameRows = getNextDataRowsBatch().orElse(SearchHits.EMPTY); currentDataFrameRowsIndex = 0; } } - private Optional> getNextDataRowsBatch() { + private Optional getNextDataRowsBatch() { try { return dataExtractor.next(); } catch (IOException e) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index 993e00bd4adf4..2ba9146533b78 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -118,19 +118,19 @@ public void testTwoPageExtraction() throws IOException { assertThat(dataExtractor.hasNext(), is(true)); // First batch - Optional> rows = dataExtractor.next(); + Optional rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(3)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] { "11", "21" })); - assertThat(rows.get().get(1).getValues(), equalTo(new String[] { "12", "22" })); - assertThat(rows.get().get(2).getValues(), equalTo(new String[] { "13", "23" })); + assertThat(rows.get().length, equalTo(3)); + assertThat(dataExtractor.createRow(rows.get()[0]).getValues(), equalTo(new String[] { "11", "21" })); + assertThat(dataExtractor.createRow(rows.get()[1]).getValues(), equalTo(new String[] { "12", "22" })); + assertThat(dataExtractor.createRow(rows.get()[2]).getValues(), equalTo(new String[] { "13", "23" })); assertThat(dataExtractor.hasNext(), is(true)); // Second batch rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(1)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] { "31", "41" })); + assertThat(rows.get().length, equalTo(1)); + assertThat(dataExtractor.createRow(rows.get()[0]).getValues(), equalTo(new String[] { "31", "41" })); assertThat(dataExtractor.hasNext(), is(true)); // Third batch should return empty @@ -208,18 +208,18 @@ public void testRecoveryFromErrorOnSearch() throws IOException { assertThat(dataExtractor.hasNext(), is(true)); // First batch expected as normally since we'll retry after the error - Optional> rows = dataExtractor.next(); + Optional rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(2)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] { "11", "21" })); - assertThat(rows.get().get(1).getValues(), equalTo(new String[] { "12", "22" })); + assertThat(rows.get().length, equalTo(2)); + assertThat(dataExtractor.createRow(rows.get()[0]).getValues(), equalTo(new String[] { "11", "21" })); + assertThat(dataExtractor.createRow(rows.get()[1]).getValues(), equalTo(new String[] { "12", "22" })); assertThat(dataExtractor.hasNext(), is(true)); // We get second batch as we retried after the error rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(1)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] { "13", "23" })); + assertThat(rows.get().length, equalTo(1)); + assertThat(dataExtractor.createRow(rows.get()[0]).getValues(), equalTo(new String[] { "13", "23" })); assertThat(dataExtractor.hasNext(), is(true)); // Next batch should return empty @@ -262,10 +262,10 @@ public void testIncludeSourceIsFalseAndNoSourceFields() throws IOException { assertThat(dataExtractor.hasNext(), is(true)); - Optional> rows = dataExtractor.next(); + Optional rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(1)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] { "11", "21" })); + assertThat(rows.get().length, equalTo(1)); + assertThat(dataExtractor.createRow(rows.get()[0]).getValues(), equalTo(new String[] { "11", "21" })); assertThat(dataExtractor.hasNext(), is(true)); assertThat(dataExtractor.next(), isEmpty()); @@ -297,10 +297,10 @@ public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOExceptio assertThat(dataExtractor.hasNext(), is(true)); - Optional> rows = dataExtractor.next(); + Optional rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(1)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] { "11", "21" })); + assertThat(rows.get().length, equalTo(1)); + assertThat(dataExtractor.createRow(rows.get()[0]).getValues(), equalTo(new String[] { "11", "21" })); assertThat(dataExtractor.hasNext(), is(true)); assertThat(dataExtractor.next(), isEmpty()); @@ -364,18 +364,18 @@ public void testMissingValues_GivenSupported() throws IOException { assertThat(dataExtractor.hasNext(), is(true)); // First batch - Optional> rows = dataExtractor.next(); + Optional rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(3)); + assertThat(rows.get().length, equalTo(3)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] { "11", "21" })); - assertThat(rows.get().get(1).getValues()[0], equalTo(DataFrameDataExtractor.NULL_VALUE)); - assertThat(rows.get().get(1).getValues()[1], equalTo("22")); - assertThat(rows.get().get(2).getValues(), equalTo(new String[] { "13", "23" })); + assertThat(dataExtractor.createRow(rows.get()[0]).getValues(), equalTo(new String[] { "11", "21" })); + assertThat(dataExtractor.createRow(rows.get()[1]).getValues()[0], equalTo(DataFrameDataExtractor.NULL_VALUE)); + assertThat(dataExtractor.createRow(rows.get()[1]).getValues()[1], equalTo("22")); + assertThat(dataExtractor.createRow(rows.get()[2]).getValues(), equalTo(new String[] { "13", "23" })); - assertThat(rows.get().get(0).shouldSkip(), is(false)); - assertThat(rows.get().get(1).shouldSkip(), is(false)); - assertThat(rows.get().get(2).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[0]).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[1]).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[2]).shouldSkip(), is(false)); assertThat(dataExtractor.hasNext(), is(true)); @@ -399,17 +399,17 @@ public void testMissingValues_GivenNotSupported() throws IOException { assertThat(dataExtractor.hasNext(), is(true)); // First batch - Optional> rows = dataExtractor.next(); + Optional rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(3)); + assertThat(rows.get().length, equalTo(3)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] { "11", "21" })); - assertThat(rows.get().get(1).getValues(), is(nullValue())); - assertThat(rows.get().get(2).getValues(), equalTo(new String[] { "13", "23" })); + assertThat(dataExtractor.createRow(rows.get()[0]).getValues(), equalTo(new String[] { "11", "21" })); + assertThat(dataExtractor.createRow(rows.get()[1]).getValues(), is(nullValue())); + assertThat(dataExtractor.createRow(rows.get()[2]).getValues(), equalTo(new String[] { "13", "23" })); - assertThat(rows.get().get(0).shouldSkip(), is(false)); - assertThat(rows.get().get(1).shouldSkip(), is(true)); - assertThat(rows.get().get(2).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[0]).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[1]).shouldSkip(), is(true)); + assertThat(dataExtractor.createRow(rows.get()[2]).shouldSkip(), is(false)); assertThat(dataExtractor.hasNext(), is(true)); @@ -538,20 +538,20 @@ public void testExtractionWithProcessedFeatures() throws IOException { assertThat(dataExtractor.hasNext(), is(true)); // First batch - Optional> rows = dataExtractor.next(); + Optional rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(3)); + assertThat(rows.get().length, equalTo(3)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] { "21", "dog", "1", "0" })); + assertThat(dataExtractor.createRow(rows.get()[0]).getValues(), equalTo(new String[] { "21", "dog", "1", "0" })); assertThat( - rows.get().get(1).getValues(), + dataExtractor.createRow(rows.get()[1]).getValues(), equalTo(new String[] { "22", "dog", DataFrameDataExtractor.NULL_VALUE, DataFrameDataExtractor.NULL_VALUE }) ); - assertThat(rows.get().get(2).getValues(), equalTo(new String[] { "23", "dog", "0", "0" })); + assertThat(dataExtractor.createRow(rows.get()[2]).getValues(), equalTo(new String[] { "23", "dog", "0", "0" })); - assertThat(rows.get().get(0).shouldSkip(), is(false)); - assertThat(rows.get().get(1).shouldSkip(), is(false)); - assertThat(rows.get().get(2).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[0]).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[1]).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[2]).shouldSkip(), is(false)); } public void testExtractionWithMultipleScalarTypesInSource() throws IOException { @@ -577,17 +577,17 @@ public void testExtractionWithMultipleScalarTypesInSource() throws IOException { assertThat(dataExtractor.hasNext(), is(true)); // First batch - Optional> rows = dataExtractor.next(); + Optional rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(3)); + assertThat(rows.get().length, equalTo(3)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] { "1", "21", })); - assertThat(rows.get().get(1).getValues(), equalTo(new String[] { "true", "22" })); - assertThat(rows.get().get(2).getValues(), equalTo(new String[] { "false", "23" })); + assertThat(dataExtractor.createRow(rows.get()[0]).getValues(), equalTo(new String[] { "1", "21", })); + assertThat(dataExtractor.createRow(rows.get()[1]).getValues(), equalTo(new String[] { "true", "22" })); + assertThat(dataExtractor.createRow(rows.get()[2]).getValues(), equalTo(new String[] { "false", "23" })); - assertThat(rows.get().get(0).shouldSkip(), is(false)); - assertThat(rows.get().get(1).shouldSkip(), is(false)); - assertThat(rows.get().get(2).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[0]).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[1]).shouldSkip(), is(false)); + assertThat(dataExtractor.createRow(rows.get()[2]).shouldSkip(), is(false)); } public void testExtractionWithProcessedFieldThrows() { @@ -610,7 +610,7 @@ public void testExtractionWithProcessedFieldThrows() { assertThat(dataExtractor.hasNext(), is(true)); - expectThrows(RuntimeException.class, () -> dataExtractor.next()); + expectThrows(RuntimeException.class, () -> Arrays.stream(dataExtractor.next().get()).forEach(dataExtractor::createRow)); } private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java index 3a95a3bb65f10..cb02b8294b115 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java @@ -306,6 +306,7 @@ private void givenDataFrameBatches(List> batche DelegateStubDataExtractor delegateStubDataExtractor = new DelegateStubDataExtractor(batches); when(dataExtractor.hasNext()).thenAnswer(a -> delegateStubDataExtractor.hasNext()); when(dataExtractor.next()).thenAnswer(a -> delegateStubDataExtractor.next()); + when(dataExtractor.createRow(any(SearchHit.class))).thenAnswer(a -> delegateStubDataExtractor.makeRow(a.getArgument(0))); } private static SearchHit newHit(String json) { @@ -340,19 +341,32 @@ private void givenClientHasNoFailures() { private static class DelegateStubDataExtractor { - private final List> batches; + private final List batches; + private final Map rows = new HashMap<>(); private int batchIndex; - private DelegateStubDataExtractor(List> batches) { - this.batches = batches; + private DelegateStubDataExtractor(List> rows) { + batches = new ArrayList<>(rows.size()); + for (List batch : rows) { + List batchHits = new ArrayList<>(batch.size()); + for (DataFrameDataExtractor.Row row : batch) { + this.rows.put(row.getHit(), row); + batchHits.add(row.getHit()); + } + batches.add(batchHits.toArray(new SearchHit[0])); + } } public boolean hasNext() { return batchIndex < batches.size(); } - public Optional> next() { + public Optional next() { return Optional.of(batches.get(batchIndex++)); } + + public DataFrameDataExtractor.Row makeRow(SearchHit hit) { + return rows.get(hit); + } } }