diff --git a/CHANGELOG.md b/CHANGELOG.md index e0df5d4edc..2c6252c46c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD) ### Features +* [Remote Vector Index Build] Introduce Remote Native Index Build feature flag, settings, and initial skeleton [#2525](https://github.com/opensearch-project/k-NN/pull/2525) +* [Remote Vector Index Build] Implement vector data upload and vector data size threshold setting [#2550](https://github.com/opensearch-project/k-NN/pull/2550) ### Enhancements * Introduce node level circuit breakers for k-NN [#2509](https://github.com/opensearch-project/k-NN/pull/2509) ### Bug Fixes diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index cd760aa548..c5d9a05748 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -12,7 +12,6 @@ import org.opensearch.OpenSearchParseException; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; -import org.opensearch.transport.client.Client; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Booleans; @@ -29,6 +28,7 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.monitor.os.OsProbe; +import org.opensearch.transport.client.Client; import java.security.InvalidParameterException; import java.util.Arrays; @@ -99,6 +99,7 @@ public class KNNSettings { public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled"; public static final String KNN_INDEX_REMOTE_VECTOR_BUILD = "index.knn.remote_index_build.enabled"; public static final String KNN_REMOTE_VECTOR_REPO = "knn.remote_index_build.vector_repo"; + public static final String KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD = "index.knn.remote_index_build.size_threshold"; /** * Default setting values @@ -129,6 +130,8 @@ public class KNNSettings { // 10% of the JVM heap public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60; public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = false; + // TODO: Tune this default value based on benchmarking + public static final ByteSizeValue KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_DEFAULT_VALUE = new ByteSizeValue(50, ByteSizeUnit.MB); /** * Settings Definition @@ -397,6 +400,15 @@ public class KNNSettings { */ public static final Setting KNN_REMOTE_VECTOR_REPO_SETTING = Setting.simpleString(KNN_REMOTE_VECTOR_REPO, Dynamic, NodeScope); + /** + * Index level setting which indicates the size threshold above which remote vector builds will be enabled. + */ + public static final Setting KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING = Setting.byteSizeSetting( + KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD, + KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_DEFAULT_VALUE, + Dynamic, + IndexScope + ); /** * Dynamic settings */ @@ -584,6 +596,10 @@ private Setting getSetting(String key) { return KNN_REMOTE_VECTOR_REPO_SETTING; } + if (KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD.equals(key)) { + return KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -611,7 +627,8 @@ public List> getSettings() { KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, KNN_DERIVED_SOURCE_ENABLED_SETTING, KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING, - KNN_REMOTE_VECTOR_REPO_SETTING + KNN_REMOTE_VECTOR_REPO_SETTING, + KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java index 4b991fa507..e6a842d11c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java @@ -7,14 +7,18 @@ import org.apache.lucene.index.FieldInfo; import org.opensearch.index.IndexSettings; +import org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.repositories.RepositoriesService; +import java.io.IOException; import java.util.function.Supplier; import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; /** * Creates the {@link NativeIndexBuildStrategy} @@ -34,11 +38,18 @@ public NativeIndexBuildStrategyFactory(Supplier repositorie } /** - * Creates or returns the desired {@link NativeIndexBuildStrategy} implementation. Intended to be used by {@link NativeIndexWriter} - * @param fieldInfo - * @return + * @param fieldInfo Field related attributes/info + * @param totalLiveDocs Number of documents with the vector field. This values comes from {@link org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsWriter#flush} + * and {@link org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsWriter#mergeOneField} + * @param knnVectorValues An instance of {@link KNNVectorValues} which is used to evaluate the size threshold KNN_REMOTE_VECTOR_BUILD_THRESHOLD + * @return The {@link NativeIndexBuildStrategy} to be used. Intended to be used by {@link NativeIndexWriter} + * @throws IOException */ - public NativeIndexBuildStrategy getBuildStrategy(final FieldInfo fieldInfo) { + public NativeIndexBuildStrategy getBuildStrategy( + final FieldInfo fieldInfo, + final int totalLiveDocs, + final KNNVectorValues knnVectorValues + ) throws IOException { final KNNEngine knnEngine = extractKNNEngine(fieldInfo); boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine; @@ -47,11 +58,15 @@ public NativeIndexBuildStrategy getBuildStrategy(final FieldInfo fieldInfo) { ? MemOptimizedNativeIndexBuildStrategy.getInstance() : DefaultIndexBuildStrategy.getInstance(); - if (repositoriesServiceSupplier != null + initializeVectorValues(knnVectorValues); + long vectorBlobLength = ((long) knnVectorValues.bytesPerVector()) * totalLiveDocs; + + if (KNNFeatureFlags.isKNNRemoteVectorBuildEnabled() + && repositoriesServiceSupplier != null && indexSettings != null && knnEngine.supportsRemoteIndexBuild() - && RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings)) { - return new RemoteIndexBuildStrategy(repositoriesServiceSupplier, strategy); + && RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings, vectorBlobLength)) { + return new RemoteIndexBuildStrategy(repositoriesServiceSupplier, strategy, indexSettings); } else { return strategy; } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index 636e209506..e0ff86e849 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -58,7 +58,7 @@ public class NativeIndexWriter { private final SegmentWriteState state; private final FieldInfo fieldInfo; - private final NativeIndexBuildStrategy indexBuilder; + private final NativeIndexBuildStrategyFactory indexBuilderFactory; @Nullable private final QuantizationState quantizationState; @@ -148,6 +148,11 @@ private void buildAndWriteIndex(final Supplier> knnVectorValu knnVectorValuesSupplier, totalLiveDocs ); + NativeIndexBuildStrategy indexBuilder = indexBuilderFactory.getBuildStrategy( + fieldInfo, + totalLiveDocs, + knnVectorValuesSupplier.get() + ); indexBuilder.buildAndWriteIndex(nativeIndexParams); CodecUtil.writeFooter(output); } @@ -316,6 +321,6 @@ private static NativeIndexWriter createWriter( @Nullable final QuantizationState quantizationState, NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory ) { - return new NativeIndexWriter(state, fieldInfo, nativeIndexBuildStrategyFactory.getBuildStrategy(fieldInfo), quantizationState); + return new NativeIndexWriter(state, fieldInfo, nativeIndexBuildStrategyFactory, quantizationState); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessor.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessor.java new file mode 100644 index 0000000000..97366bc92b --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessor.java @@ -0,0 +1,213 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.remote; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.common.CheckedTriFunction; +import org.opensearch.common.StreamContext; +import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer; +import org.opensearch.common.blobstore.BlobContainer; +import org.opensearch.common.blobstore.BlobPath; +import org.opensearch.common.blobstore.stream.write.WriteContext; +import org.opensearch.common.blobstore.stream.write.WritePriority; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.repositories.blobstore.BlobStoreRepository; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import static org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy.DOC_ID_FILE_EXTENSION; +import static org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy.VECTORS_PATH; +import static org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy.VECTOR_BLOB_FILE_EXTENSION; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; + +@Log4j2 +@AllArgsConstructor +public class DefaultVectorRepositoryAccessor implements VectorRepositoryAccessor { + private final BlobStoreRepository repository; + private final IndexSettings indexSettings; + + /** + * If the repository implements {@link AsyncMultiStreamBlobContainer}, then parallel uploads will be used. Parallel uploads are backed by a {@link WriteContext}, for which we have a custom + * {@link org.opensearch.common.blobstore.stream.write.StreamContextSupplier} implementation. + * + * @see DefaultVectorRepositoryAccessor#getStreamContext + * @see DefaultVectorRepositoryAccessor#getTransferPartStreamSupplier + * + * @param blobName Base name of the blobs we are writing, excluding file extensions + * @param totalLiveDocs Number of documents we are processing. This is used to compute the size of the blob we are writing + * @param vectorDataType Data type of the vector (FLOAT, BYTE, BINARY) + * @param knnVectorValuesSupplier Supplier for {@link KNNVectorValues} + * @throws IOException + * @throws InterruptedException + */ + @Override + public void writeToRepository( + String blobName, + int totalLiveDocs, + VectorDataType vectorDataType, + Supplier> knnVectorValuesSupplier + ) throws IOException, InterruptedException { + assert repository != null; + // Get the blob container based on blobName and the repo base path. This is where the blobs will be written to. + BlobPath path = repository.basePath().add(indexSettings.getUUID() + VECTORS_PATH); + BlobContainer blobContainer = repository.blobStore().blobContainer(path); + + KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + initializeVectorValues(knnVectorValues); + long vectorBlobLength = (long) knnVectorValues.bytesPerVector() * totalLiveDocs; + + if (blobContainer instanceof AsyncMultiStreamBlobContainer) { + // First initiate vectors upload + log.debug("Repository {} Supports Parallel Blob Upload", repository); + // WriteContext is the main entry point into asyncBlobUpload. It stores all of our upload configurations, analogous to + // BuildIndexParams + WriteContext writeContext = new WriteContext.Builder().fileName(blobName + VECTOR_BLOB_FILE_EXTENSION) + .streamContextSupplier((partSize) -> getStreamContext(partSize, vectorBlobLength, knnVectorValuesSupplier, vectorDataType)) + .fileSize(vectorBlobLength) + .failIfAlreadyExists(true) + .writePriority(WritePriority.NORMAL) + // TODO: Checksum implementations -- It is difficult to calculate a checksum on the knnVectorValues as + // there is no underlying file upon which we can create the checksum. We should be able to create a + // checksum still by iterating through once, however this will be an expensive operation. + .uploadFinalizer((bool) -> {}) + .doRemoteDataIntegrityCheck(false) + .expectedChecksum(null) + .build(); + + AtomicReference exception = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(1); + ((AsyncMultiStreamBlobContainer) blobContainer).asyncBlobUpload( + writeContext, + new LatchedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Void unused) { + log.debug( + "Parallel vector upload succeeded for blob {} with size {}", + blobName + VECTOR_BLOB_FILE_EXTENSION, + vectorBlobLength + ); + } + + @Override + public void onFailure(Exception e) { + log.error( + "Parallel vector upload failed for blob {} with size {}", + blobName + VECTOR_BLOB_FILE_EXTENSION, + vectorBlobLength, + e + ); + exception.set(e); + } + }, latch) + ); + + // Then upload doc id blob before waiting on vector uploads + // TODO: We wrap with a BufferedInputStream to support retries. We can tune this buffer size to optimize performance. + // Note: We do not use the parallel upload API here as the doc id blob will be much smaller than the vector blob + writeDocIds(knnVectorValuesSupplier.get(), vectorBlobLength, totalLiveDocs, blobName, blobContainer); + latch.await(); + if (exception.get() != null) { + throw new IOException(exception.get()); + } + } else { + log.debug("Repository {} Does Not Support Parallel Blob Upload", repository); + // Write Vectors + InputStream vectorStream = new BufferedInputStream(new VectorValuesInputStream(knnVectorValuesSupplier.get(), vectorDataType)); + log.debug("Writing {} bytes for {} docs to {}", vectorBlobLength, totalLiveDocs, blobName + VECTOR_BLOB_FILE_EXTENSION); + blobContainer.writeBlob(blobName + VECTOR_BLOB_FILE_EXTENSION, vectorStream, vectorBlobLength, true); + // Then write doc ids + writeDocIds(knnVectorValuesSupplier.get(), vectorBlobLength, totalLiveDocs, blobName, blobContainer); + } + } + + /** + * Helper method for uploading doc ids to repository, as it's re-used in both parallel and sequential upload cases + * @param knnVectorValues + * @param vectorBlobLength + * @param totalLiveDocs + * @param blobName + * @param blobContainer + * @throws IOException + */ + private void writeDocIds( + KNNVectorValues knnVectorValues, + long vectorBlobLength, + long totalLiveDocs, + String blobName, + BlobContainer blobContainer + ) throws IOException { + InputStream docStream = new BufferedInputStream(new DocIdInputStream(knnVectorValues)); + log.debug( + "Writing {} bytes for {} docs ids to {}", + vectorBlobLength, + totalLiveDocs * Integer.BYTES, + blobName + DOC_ID_FILE_EXTENSION + ); + blobContainer.writeBlob(blobName + DOC_ID_FILE_EXTENSION, docStream, totalLiveDocs * Integer.BYTES, true); + } + + /** + * Returns a {@link org.opensearch.common.StreamContext}. Intended to be invoked as a {@link org.opensearch.common.blobstore.stream.write.StreamContextSupplier}, + * which takes the partSize determined by the repository implementation and calculates the number of parts as well as handles the last part of the stream. + * + * @see DefaultVectorRepositoryAccessor#getTransferPartStreamSupplier + * + * @param partSize Size of each InputStream to be uploaded in parallel. Provided by repository implementation + * @param vectorBlobLength Total size of the vectors across all InputStreams + * @param knnVectorValuesSupplier Supplier for {@link KNNVectorValues} + * @param vectorDataType Data type of the vector (FLOAT, BYTE, BINARY) + * @return a {@link org.opensearch.common.StreamContext} with a function that will create {@link InputStream}s of {@param partSize} + */ + private StreamContext getStreamContext( + long partSize, + long vectorBlobLength, + Supplier> knnVectorValuesSupplier, + VectorDataType vectorDataType + ) { + long lastPartSize = (vectorBlobLength % partSize) != 0 ? vectorBlobLength % partSize : partSize; + int numberOfParts = (int) ((vectorBlobLength % partSize) == 0 ? vectorBlobLength / partSize : (vectorBlobLength / partSize) + 1); + return new StreamContext( + getTransferPartStreamSupplier(knnVectorValuesSupplier, vectorDataType), + partSize, + lastPartSize, + numberOfParts + ); + } + + /** + * This method handles creating {@link VectorValuesInputStream}s based on the part number, the requested size of the stream part, and the position that the stream starts at within the underlying {@link KNNVectorValues} + * + * @param knnVectorValuesSupplier Supplier for {@link KNNVectorValues} + * @param vectorDataType Data type of the vector (FLOAT, BYTE, BINARY) + * @return a function with which the repository implementation will use to create {@link VectorValuesInputStream}s of specific sizes and start positions. + */ + private CheckedTriFunction getTransferPartStreamSupplier( + Supplier> knnVectorValuesSupplier, + VectorDataType vectorDataType + ) { + return ((partNo, size, position) -> { + log.info("Creating InputStream for partNo: {}, size: {}, position: {}", partNo, size, position); + VectorValuesInputStream vectorValuesInputStream = new VectorValuesInputStream( + knnVectorValuesSupplier.get(), + vectorDataType, + position, + size + ); + return new InputStreamContainer(vectorValuesInputStream, size, position); + }); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DocIdInputStream.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DocIdInputStream.java new file mode 100644 index 0000000000..c1a3987018 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/DocIdInputStream.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.remote; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.DocIdSetIterator; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; + +/** + * {@link InputStream} implementation of doc ids backed by {@link KNNVectorValues} rather than any file. Intended for use by {@link RemoteIndexBuildStrategy} + */ +@Log4j2 +class DocIdInputStream extends InputStream { + private final KNNVectorValues knnVectorValues; + // Doc ids are 4 byte integers, byte read() only returns a single byte, so we will need to track the byte position within a doc id. + // For simplicity, and to maintain the byte ordering, we use a buffer with size of 1 int. + private ByteBuffer currentBuffer; + + /** + * Use to represent the doc ids of a {@link KNNVectorValues} as an {@link InputStream}. Expected to be used only with {@link org.opensearch.common.blobstore.BlobContainer#writeBlob}. + * @param knnVectorValues + * @throws IOException + * @see VectorValuesInputStream + */ + public DocIdInputStream(KNNVectorValues knnVectorValues) throws IOException { + this.currentBuffer = ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); + this.knnVectorValues = knnVectorValues; + initializeVectorValues(this.knnVectorValues); + reloadBuffer(); + } + + @Override + public int read() throws IOException { + if (currentBuffer == null) { + return -1; + } + + if (!currentBuffer.hasRemaining()) { + advanceAndReloadBuffer(); + if (currentBuffer == null) { + return -1; + } + } + + // Unsigned byte conversion is not technically needed as we are using a ByteBuffer, however we perform this operation still just in + // case. + return currentBuffer.get() & 0xFF; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (currentBuffer == null) { + return -1; + } + + int available = currentBuffer.remaining(); + if (available <= 0) { + advanceAndReloadBuffer(); + if (currentBuffer == null) { + return -1; + } + available = currentBuffer.remaining(); + } + + int bytesToRead = Math.min(available, len); + currentBuffer.get(b, off, bytesToRead); + return bytesToRead; + } + + /** + * Advances to the next doc, and then refills the buffer with the new doc. + * @throws IOException + */ + private void advanceAndReloadBuffer() throws IOException { + int docId = knnVectorValues.nextDoc(); + if (docId != -1 && docId != DocIdSetIterator.NO_MORE_DOCS) { + reloadBuffer(); + } else { + // Reset buffer to null to indicate that there are no more docs to be read + currentBuffer = null; + } + } + + /** + * Reload {@link currentBuffer} with the current doc id that {@link knnVectorValues} is pointing to + * @throws IOException + */ + private void reloadBuffer() throws IOException { + currentBuffer.clear(); + currentBuffer.putInt(knnVectorValues.docId()); + currentBuffer.position(0); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java index 8555e2ad68..72ab3aa1b2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java @@ -7,15 +7,13 @@ import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.NotImplementedException; -import org.apache.lucene.index.SegmentWriteState; import org.opensearch.common.StopWatch; +import org.opensearch.common.UUIDs; import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.index.IndexSettings; -import org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.Repository; import org.opensearch.repositories.RepositoryMissingException; @@ -25,6 +23,7 @@ import java.util.function.Supplier; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING; import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; /** @@ -37,28 +36,64 @@ public class RemoteIndexBuildStrategy implements NativeIndexBuildStrategy { private final Supplier repositoriesServiceSupplier; private final NativeIndexBuildStrategy fallbackStrategy; - private static final String VECTOR_BLOB_FILE_EXTENSION = ".knnvec"; - private static final String DOC_ID_FILE_EXTENSION = ".knndid"; + private final IndexSettings indexSettings; + + static final String VECTOR_BLOB_FILE_EXTENSION = ".knnvec"; + static final String DOC_ID_FILE_EXTENSION = ".knndid"; + static final String VECTORS_PATH = "_vectors"; /** - * Public constructor - * - * @param repositoriesServiceSupplier A supplier for {@link RepositoriesService} used for interacting with repository + * Public constructor, intended to be called by {@link org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory} based in + * part on the return value from {@link RemoteIndexBuildStrategy#shouldBuildIndexRemotely} + * @param repositoriesServiceSupplier A supplier for {@link RepositoriesService} used to interact with a repository + * @param fallbackStrategy Delegate {@link NativeIndexBuildStrategy} used to fall back to local build + * @param indexSettings {@link IndexSettings} used to retrieve information about the index */ - public RemoteIndexBuildStrategy(Supplier repositoriesServiceSupplier, NativeIndexBuildStrategy fallbackStrategy) { + public RemoteIndexBuildStrategy( + Supplier repositoriesServiceSupplier, + NativeIndexBuildStrategy fallbackStrategy, + IndexSettings indexSettings + ) { this.repositoriesServiceSupplier = repositoriesServiceSupplier; this.fallbackStrategy = fallbackStrategy; + this.indexSettings = indexSettings; } /** - * @return whether to use the remote build feature + * @param indexSettings {@link IndexSettings} used to check if index setting is enabled for the feature + * @param vectorBlobLength The size of the vector blob, used to determine if the size threshold is met + * @return true if remote index build should be used, else false */ - public static boolean shouldBuildIndexRemotely(IndexSettings indexSettings) { + public static boolean shouldBuildIndexRemotely(IndexSettings indexSettings, long vectorBlobLength) { + if (indexSettings == null) { + return false; + } + + // If setting is not enabled, return false + if (!indexSettings.getValue(KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING)) { + log.debug("Remote index build is disabled for index: [{}]", indexSettings.getIndex().getName()); + return false; + } + + // If vector repo is not configured, return false String vectorRepo = KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_REPO_SETTING.getKey()); - return KNNFeatureFlags.isKNNRemoteVectorBuildEnabled() - && indexSettings.getValue(KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING) - && vectorRepo != null - && !vectorRepo.isEmpty(); + if (vectorRepo == null || vectorRepo.isEmpty()) { + log.debug("Vector repo is not configured, falling back to local build for index: [{}]", indexSettings.getIndex().getName()); + return false; + } + + // If size threshold is not met, return false + if (vectorBlobLength < indexSettings.getValue(KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING).getBytes()) { + log.debug( + "Data size [{}] is less than remote index build threshold [{}], falling back to local build for index [{}]", + vectorBlobLength, + indexSettings.getValue(KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING).getBytes(), + indexSettings.getIndex().getName() + ); + return false; + } + + return true; } /** @@ -73,16 +108,19 @@ public static boolean shouldBuildIndexRemotely(IndexSettings indexSettings) { */ @Override public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { - // TODO: Metrics Collection StopWatch stopWatch; long time_in_millis; try { + VectorRepositoryAccessor vectorRepositoryAccessor = new DefaultVectorRepositoryAccessor(getRepository(), indexSettings); stopWatch = new StopWatch().start(); - writeToRepository( - indexInfo.getFieldName(), - indexInfo.getKnnVectorValuesSupplier(), + // We create a new time based UUID per file in order to avoid conflicts across shards. It is also very difficult to get the + // shard id in this context. + String blobName = UUIDs.base64UUID() + "_" + indexInfo.getFieldName() + "_" + indexInfo.getSegmentWriteState().segmentInfo.name; + vectorRepositoryAccessor.writeToRepository( + blobName, indexInfo.getTotalLiveDocs(), - indexInfo.getSegmentWriteState() + indexInfo.getVectorDataType(), + indexInfo.getKnnVectorValuesSupplier() ); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Repository write took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); @@ -98,7 +136,7 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); stopWatch = new StopWatch().start(); - readFromRepository(); + vectorRepositoryAccessor.readFromRepository(); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); } catch (Exception e) { @@ -126,25 +164,6 @@ private BlobStoreRepository getRepository() throws RepositoryMissingException { return (BlobStoreRepository) repository; } - /** - * Write relevant vector data to repository - * - * @param fieldName - * @param knnVectorValuesSupplier - * @param totalLiveDocs - * @param segmentWriteState - * @throws IOException - * @throws InterruptedException - */ - private void writeToRepository( - String fieldName, - Supplier> knnVectorValuesSupplier, - int totalLiveDocs, - SegmentWriteState segmentWriteState - ) throws IOException, InterruptedException { - throw new NotImplementedException(); - } - /** * Submit vector build request to remote vector build service * @@ -159,11 +178,4 @@ private void submitVectorBuild() { private void awaitVectorBuild() { throw new NotImplementedException(); } - - /** - * Read constructed vector file from remote repository and write to IndexOutput - */ - private void readFromRepository() { - throw new NotImplementedException(); - } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorRepositoryAccessor.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorRepositoryAccessor.java new file mode 100644 index 0000000000..7d93177d2f --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorRepositoryAccessor.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.remote; + +import org.apache.commons.lang.NotImplementedException; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.util.function.Supplier; + +/** + * Interface which dictates how we use we interact with a {@link org.opensearch.repositories.blobstore.BlobStoreRepository} from {@link RemoteIndexBuildStrategy} + */ +public interface VectorRepositoryAccessor { + /** + * This method is responsible for writing both the vector blobs and doc ids provided by {@param knnVectorValuesSupplier} to the configured repository + * + * @param blobName Base name of the blobs we are writing, excluding file extensions + * @param totalLiveDocs Number of documents we are processing. This is used to compute the size of the blob we are writing + * @param vectorDataType Data type of the vector (FLOAT, BYTE, BINARY) + * @param knnVectorValuesSupplier Supplier for {@link org.opensearch.knn.index.vectorvalues.KNNVectorValues} + * @throws java.io.IOException + * @throws InterruptedException + */ + void writeToRepository( + String blobName, + int totalLiveDocs, + VectorDataType vectorDataType, + Supplier> knnVectorValuesSupplier + ) throws IOException, InterruptedException; + + /** + * Read constructed vector file from remote repository and write to IndexOutput + */ + default void readFromRepository() { + throw new NotImplementedException(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorValuesInputStream.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorValuesInputStream.java new file mode 100644 index 0000000000..c46677e80f --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorValuesInputStream.java @@ -0,0 +1,208 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.remote; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.DocIdSetIterator; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static org.opensearch.knn.index.VectorDataType.BINARY; +import static org.opensearch.knn.index.VectorDataType.BYTE; +import static org.opensearch.knn.index.VectorDataType.FLOAT; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; + +/** + * {@link InputStream} implementation backed by {@link KNNVectorValues} rather than any file. Intended for use by {@link RemoteIndexBuildStrategy} + */ +@Log4j2 +class VectorValuesInputStream extends InputStream { + + private final KNNVectorValues knnVectorValues; + // It is difficult to avoid using a buffer in this class as we need to be able to convert from float[] to byte[]. this buffer + // will be filled 1 vector at a time. + private ByteBuffer currentBuffer; + private final int bytesPerVector; + private long bytesRemaining; + private final VectorDataType vectorDataType; + + /** + * Used to represent a part of a {@link KNNVectorValues} as an {@link InputStream}. Expected to be used with + * {@link org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer#asyncBlobUpload}. The repository will interact with this class only through the constructor and the read methods. + *

+ * Note: For S3 (but generically too), the retryable input stream is backed by a buffer with the same size as the stream, so all bytes are loaded onto heap + * at once (16mb chunks by default) when a given {@link VectorValuesInputStream} is being processed. + *

+ *

+ * Note: The S3 implementation will only request up to {@param size} bytes from this InputStream. However, that is implementation specific and may not be + * true for all implementations, so we do our own size enforcement here as well. + *

+ * + * @param knnVectorValues + * @param vectorDataType + * @param startPosition + * @param size + * @throws IOException + * @see DocIdInputStream + */ + public VectorValuesInputStream(KNNVectorValues knnVectorValues, VectorDataType vectorDataType, long startPosition, long size) + throws IOException { + this.bytesRemaining = size; + this.knnVectorValues = knnVectorValues; + this.vectorDataType = vectorDataType; + initializeVectorValues(this.knnVectorValues); + this.bytesPerVector = this.knnVectorValues.bytesPerVector(); + // We use currentBuffer == null to indicate that there are no more vectors to be read + this.currentBuffer = ByteBuffer.allocate(bytesPerVector).order(ByteOrder.LITTLE_ENDIAN); + // Position the InputStream at the specific byte within the specific vector that startPosition references + setPosition(startPosition); + } + + /** + * Used to represent the entire {@link KNNVectorValues} as a single {@link InputStream}. Expected to be used with + * {@link org.opensearch.common.blobstore.BlobContainer#writeBlob} + * + * @param knnVectorValues + * @param vectorDataType + * @throws IOException + * @see DocIdInputStream + */ + public VectorValuesInputStream(KNNVectorValues knnVectorValues, VectorDataType vectorDataType) throws IOException { + this(knnVectorValues, vectorDataType, 0, Long.MAX_VALUE); + } + + @Override + public int read() throws IOException { + if (bytesRemaining <= 0 || currentBuffer == null) { + return -1; + } + + if (!currentBuffer.hasRemaining()) { + advanceAndReloadBuffer(); + if (currentBuffer == null) { + return -1; + } + } + + bytesRemaining--; + // Unsigned byte conversion is not technically needed as we are using a ByteBuffer, however we perform this operation still just in + // case. + return currentBuffer.get() & 0xFF; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (bytesRemaining <= 0 || currentBuffer == null) { + return -1; + } + + int available = currentBuffer.remaining(); + if (available <= 0) { + advanceAndReloadBuffer(); + if (currentBuffer == null) { + return -1; + } + available = currentBuffer.remaining(); + } + + int bytesToRead = Math.min(available, len); + int bytesActuallyRead = (int) Math.min(bytesRemaining, bytesToRead); + currentBuffer.get(b, off, bytesActuallyRead); + bytesRemaining -= bytesActuallyRead; + return bytesActuallyRead; + } + + /** + * This class does not support skipping. Instead, use {@link VectorValuesInputStream#setPosition}. + * + * @param n the number of bytes to be skipped. + * @return + * @throws IOException + */ + @Override + public long skip(long n) throws IOException { + throw new UnsupportedOperationException("VectorValuesInputStream does not support skip"); + } + + /** + * Advances n bytes forward in the knnVectorValues. + * Note: {@link KNNVectorValues#advance} is not supported when we are merging segments, so we do not use it here. + * Note: {@link KNNVectorValues#nextDoc} is relatively efficient, but {@link KNNVectorValues#getVector} may + * perform a disk read, so we avoid using {@link VectorValuesInputStream#reloadBuffer()} here. + * + * @param n + * @return + * @throws IOException + */ + private void setPosition(long n) throws IOException { + if (currentBuffer.position() != 0) { + throw new UnsupportedOperationException("setPosition is only supported from the start of a vector"); + } + + long bytesSkipped = 0; + int vectorsToSkip = (int) (n / bytesPerVector); + log.debug("Skipping {} bytes, {} vectors", n, vectorsToSkip); + int docId = knnVectorValues.docId(); + while (docId != -1 && docId != DocIdSetIterator.NO_MORE_DOCS && vectorsToSkip > 0) { + docId = knnVectorValues.nextDoc(); + bytesSkipped += bytesPerVector; + vectorsToSkip--; + } + + // After skipping the correct number of vectors, fill the buffer with the current vector + reloadBuffer(); + + // Advance to the correct position within the current vector + long remainingBytes = n - bytesSkipped; + if (remainingBytes > 0) { + currentBuffer.position((int) remainingBytes); + } + } + + /** + * Reload {@link currentBuffer} with the current vector that {@link knnVectorValues} is pointing to + * @throws IOException + */ + private void reloadBuffer() throws IOException { + currentBuffer.clear(); + if (vectorDataType == FLOAT) { + float[] floatVector = ((KNNFloatVectorValues) knnVectorValues).getVector(); + currentBuffer.asFloatBuffer().put(floatVector); + } else if (vectorDataType == BYTE) { + byte[] byteVector = ((KNNByteVectorValues) knnVectorValues).getVector(); + currentBuffer.put(byteVector); + } else if (vectorDataType == BINARY) { + byte[] binaryVector = ((KNNBinaryVectorValues) knnVectorValues).getVector(); + currentBuffer.put(binaryVector); + } else { + throw new IllegalArgumentException("Unsupported vector data type: " + vectorDataType); + } + // Reset buffer to head for future reads + currentBuffer.position(0); + } + + /** + * Advances to the next doc, and then refills the buffer with the new doc. + * @throws IOException + */ + private void advanceAndReloadBuffer() throws IOException { + int docId = knnVectorValues.nextDoc(); + if (docId != -1 && docId != DocIdSetIterator.NO_MORE_DOCS) { + reloadBuffer(); + } else { + // Reset buffer to null to indicate that there are no more docs to be read + currentBuffer = null; + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/package-info.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/package-info.java new file mode 100644 index 0000000000..9db94445cd --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/package-info.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Package which includes the classes used for building vector indices remotely. + *

+ * For repository uploads, there are 2 methods we can use -- [1] {@link org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer#asyncBlobUpload}, which uses multiple {@link java.io.InputStream}s + * to perform blob upload in parallel across streams and [2] {@link org.opensearch.common.blobstore.BlobContainer#writeBlob}, which uses a single {@link java.io.InputStream} to perform blob upload sequentially. + *

+ *

+ * In order to optimize the performance of vector blob uploads, we use the parallel upload method to upload vector blobs, as since doc id blobs will be relatively small we only use the sequential upload method there. + * The parallel blob upload method takes in {@link org.opensearch.common.blobstore.stream.write.WriteContext}, which takes in a {@link org.opensearch.common.blobstore.stream.write.StreamContextSupplier} + * that subsequently creates {@link org.opensearch.common.io.InputStreamContainer}s of a part size determined by the repository. + *

+ *

+ * We are splitting the {@link org.opensearch.knn.index.vectorvalues.KNNVectorValues} into N streams, however since it is an iterator we will need to create N instances in order to iterator through the vector + * values in parallel. {@link org.opensearch.knn.index.codec.nativeindex.remote.VectorValuesInputStream} takes in both a position and a size in the constructor arguments, which will iterate to the exact + * byte specified by position and set the head of the InputStream to that position. The stream will then only allow size bytes to be read from it. + *

+ *

+ * The part size (and therefore number of parts) is determined by the repository implementation, so from this package we are only responsible for creating correctly sized and positioned InputStreams based on the + * the part size requested by the repository. + *

+ */ +package org.opensearch.knn.index.codec.nativeindex.remote; diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessorTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessorTests.java new file mode 100644 index 0000000000..20106760bc --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/DefaultVectorRepositoryAccessorTests.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.remote; + +import org.mockito.Mockito; +import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer; +import org.opensearch.common.blobstore.BlobContainer; +import org.opensearch.common.blobstore.BlobPath; +import org.opensearch.common.blobstore.BlobStore; +import org.opensearch.common.blobstore.fs.FsBlobStore; +import org.opensearch.index.IndexSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.repositories.blobstore.BlobStoreRepository; + +import java.io.IOException; +import java.nio.file.Path; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy.DOC_ID_FILE_EXTENSION; +import static org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy.VECTOR_BLOB_FILE_EXTENSION; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; + +public class DefaultVectorRepositoryAccessorTests extends RemoteIndexBuildTests { + + /** + * Test that whenever an non AsyncMultiStreamBlobContainer is used, writeBlob is invoked twice + */ + public void testRepositoryInteractionWithAsyncMultiStreamBlobContainer() throws IOException, InterruptedException { + RepositoriesService repositoriesService = mock(RepositoriesService.class); + BlobStoreRepository mockRepository = mock(BlobStoreRepository.class); + BlobPath testBasePath = new BlobPath().add("testBasePath"); + BlobStore mockBlobStore = mock(BlobStore.class); + + when(repositoriesService.repository(any())).thenReturn(mockRepository); + when(mockRepository.basePath()).thenReturn(testBasePath); + when(mockRepository.blobStore()).thenReturn(mockBlobStore); + + BlobContainer testContainer = Mockito.spy(new TestBlobContainer(mock(FsBlobStore.class), testBasePath, mock(Path.class))); + when(mockBlobStore.blobContainer(any())).thenReturn(testContainer); + + VectorRepositoryAccessor objectUnderTest = new DefaultVectorRepositoryAccessor(mockRepository, mock(IndexSettings.class)); + + String BLOB_NAME = "test_blob"; + int NUM_DOCS = 100; + objectUnderTest.writeToRepository(BLOB_NAME, NUM_DOCS, VectorDataType.FLOAT, knnVectorValuesSupplier); + + KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + initializeVectorValues(knnVectorValues); + verify(testContainer).writeBlob( + eq(BLOB_NAME + VECTOR_BLOB_FILE_EXTENSION), + any(), + eq((long) NUM_DOCS * knnVectorValues.bytesPerVector()), + eq(true) + ); + verify(testContainer).writeBlob(eq(BLOB_NAME + DOC_ID_FILE_EXTENSION), any(), eq((long) NUM_DOCS * Integer.BYTES), eq(true)); + verify(mockBlobStore).blobContainer(any()); + verify(mockRepository).basePath(); + } + + /** + * Test that whenever an AsyncMultiStreamBlobContainer is used, both asyncBlobUpload and writeBlob are invoked once and only once + */ + public void testRepositoryInteractionWithBlobContainer() throws IOException, InterruptedException { + RepositoriesService repositoriesService = mock(RepositoriesService.class); + BlobStoreRepository mockRepository = mock(BlobStoreRepository.class); + BlobPath testBasePath = new BlobPath().add("testBasePath"); + BlobStore mockBlobStore = mock(BlobStore.class); + + when(repositoriesService.repository(any())).thenReturn(mockRepository); + when(mockRepository.basePath()).thenReturn(testBasePath); + when(mockRepository.blobStore()).thenReturn(mockBlobStore); + + AsyncMultiStreamBlobContainer testContainer = Mockito.spy( + new TestAsyncBlobContainer(mock(FsBlobStore.class), testBasePath, mock(Path.class), false) + ); + when(mockBlobStore.blobContainer(any())).thenReturn(testContainer); + + VectorRepositoryAccessor objectUnderTest = new DefaultVectorRepositoryAccessor(mockRepository, mock(IndexSettings.class)); + + String BLOB_NAME = "test_blob"; + int NUM_DOCS = 100; + objectUnderTest.writeToRepository(BLOB_NAME, NUM_DOCS, VectorDataType.FLOAT, knnVectorValuesSupplier); + + verify(testContainer).asyncBlobUpload(any(), any()); + verify(testContainer).writeBlob(eq(BLOB_NAME + DOC_ID_FILE_EXTENSION), any(), eq((long) NUM_DOCS * Integer.BYTES), eq(true)); + verify(mockBlobStore).blobContainer(any()); + verify(mockRepository).basePath(); + } + + /** + * Test that when an exception is thrown during asyncBlobUpload, the exception is rethrown. + */ + public void testAsyncUploadThrowsException() throws InterruptedException, IOException { + RepositoriesService repositoriesService = mock(RepositoriesService.class); + BlobStoreRepository mockRepository = mock(BlobStoreRepository.class); + BlobPath testBasePath = new BlobPath().add("testBasePath"); + BlobStore mockBlobStore = mock(BlobStore.class); + + when(repositoriesService.repository(any())).thenReturn(mockRepository); + when(mockRepository.basePath()).thenReturn(testBasePath); + when(mockRepository.blobStore()).thenReturn(mockBlobStore); + + AsyncMultiStreamBlobContainer testContainer = Mockito.spy( + new TestAsyncBlobContainer(mock(FsBlobStore.class), testBasePath, mock(Path.class), true) + ); + when(mockBlobStore.blobContainer(any())).thenReturn(testContainer); + + VectorRepositoryAccessor objectUnderTest = new DefaultVectorRepositoryAccessor(mockRepository, mock(IndexSettings.class)); + + String BLOB_NAME = "test_blob"; + int NUM_DOCS = 100; + assertThrows( + IOException.class, + () -> objectUnderTest.writeToRepository(BLOB_NAME, NUM_DOCS, VectorDataType.FLOAT, knnVectorValuesSupplier) + ); + + verify(testContainer).asyncBlobUpload(any(), any()); + // Doc ids should still get written because exception is handled after awaiting on asyncBlobUpload + verify(testContainer).writeBlob(eq(BLOB_NAME + DOC_ID_FILE_EXTENSION), any(), eq((long) NUM_DOCS * Integer.BYTES), eq(true)); + verify(mockBlobStore).blobContainer(any()); + verify(mockRepository).basePath(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/KnnVectorValuesInputStreamTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/KnnVectorValuesInputStreamTests.java new file mode 100644 index 0000000000..612b9bbe22 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/KnnVectorValuesInputStreamTests.java @@ -0,0 +1,281 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.remote; + +import org.apache.lucene.search.DocIdSetIterator; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; +import static org.opensearch.knn.index.vectorvalues.TestVectorValues.getRandomByteVector; +import static org.opensearch.knn.index.vectorvalues.TestVectorValues.getRandomVector; + +public class KnnVectorValuesInputStreamTests extends KNNTestCase { + + /** + * Tests that reading doc IDs out of a DocIdInputStream yields the same results as reading the doc ids + */ + public void testDocIdInputStream() throws IOException { + int NUM_DOCS = randomIntBetween(1, 1000); + + List vectorValues = getRandomFloatVectors(NUM_DOCS, 1); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValuesForStream = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + InputStream docIdInputStream = new DocIdInputStream(knnVectorValuesForStream); + + // 1. Read all input stream bytes + byte[] docIdStreamBytes = docIdInputStream.readAllBytes(); + + // 2. Read all of knnVectorValues into a byte buffer: + ByteBuffer buffer = ByteBuffer.allocate(NUM_DOCS * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); + initializeVectorValues(knnVectorValues); + int docId = knnVectorValues.docId(); + while (docId != -1 && docId != DocIdSetIterator.NO_MORE_DOCS) { + buffer.putInt(docId); + docId = knnVectorValues.nextDoc(); + } + + // Check the 2 arrays have the same content + assertArrayEquals(docIdStreamBytes, buffer.array()); + } + + /** + * Tests that reading float vectors out of a VectorValuesInputStream yields the same results as reading the doc vectors + */ + public void testFloatVectorValuesInputStream() throws IOException { + int NUM_DOCS = randomIntBetween(1, 1000); + int NUM_DIMENSION = randomIntBetween(1, 1000); + + List vectorValues = getRandomFloatVectors(NUM_DOCS, NUM_DIMENSION); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValuesForStream = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + InputStream vectorValuesInputStream = new VectorValuesInputStream(knnVectorValuesForStream, VectorDataType.FLOAT); + + // 1. Read all input stream bytes + byte[] vectorStreamBytes = vectorValuesInputStream.readAllBytes(); + FloatBuffer vectorStreamFloats = ByteBuffer.wrap(vectorStreamBytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); + + // 2. Read all of knnVectorValues into a byte buffer: + initializeVectorValues(knnVectorValues); + FloatBuffer expectedBuffer = ByteBuffer.allocate(NUM_DOCS * knnVectorValues.bytesPerVector()) + .order(ByteOrder.LITTLE_ENDIAN) + .asFloatBuffer(); + int docId = knnVectorValues.docId(); + while (docId != -1 && docId != DocIdSetIterator.NO_MORE_DOCS) { + expectedBuffer.put(knnVectorValues.getVector()); + docId = knnVectorValues.nextDoc(); + } + expectedBuffer.position(0); + + // Check the 2 arrays have the same content + assertEquals(expectedBuffer, vectorStreamFloats); + } + + public void testByteVectorValuesInputStream() throws IOException { + int NUM_DOCS = randomIntBetween(1, 1000); + int NUM_DIMENSION = randomIntBetween(1, 1000); + + List vectorValues = getRandomByteVectors(NUM_DOCS, NUM_DIMENSION); + final TestVectorValues.PreDefinedByteVectorValues randomVectorValues = new TestVectorValues.PreDefinedByteVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValuesForStream = KNNVectorValuesFactory.getVectorValues( + VectorDataType.BYTE, + randomVectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, randomVectorValues); + + InputStream vectorValuesInputStream = new VectorValuesInputStream(knnVectorValuesForStream, VectorDataType.BYTE); + + // 1. Read all input stream bytes + byte[] vectorStreamBytes = vectorValuesInputStream.readAllBytes(); + + // 2. Read all of knnVectorValues into a byte buffer: + initializeVectorValues(knnVectorValues); + ByteBuffer expectedBuffer = ByteBuffer.allocate(NUM_DOCS * knnVectorValues.bytesPerVector()).order(ByteOrder.LITTLE_ENDIAN); + int docId = knnVectorValues.docId(); + while (docId != -1 && docId != DocIdSetIterator.NO_MORE_DOCS) { + expectedBuffer.put(knnVectorValues.getVector()); + docId = knnVectorValues.nextDoc(); + } + + // Check the 2 arrays have the same content + assertArrayEquals(expectedBuffer.array(), vectorStreamBytes); + } + + /** + * Tests that creating N VectorValuesInputStream over the same KNNVectorValues yields the same result as reading it all from the same VectorValuesInputStream + */ + public void testMultiPartVectorValueInputStream() throws IOException { + final int NUM_DOCS = randomIntBetween(100, 1000); + final int NUM_DIMENSION = randomIntBetween(1, 1000); + final int NUM_PARTS = randomIntBetween(1, NUM_DOCS / 10); + final int PART_SIZE; + final int LAST_PART_SIZE; + + List vectorValues = getRandomFloatVectors(NUM_DOCS, NUM_DIMENSION); + final Supplier randomVectorValuesSupplier = + () -> new TestVectorValues.PreDefinedFloatVectorValues(vectorValues); + + final Supplier> knnVectorValuesSupplier = () -> KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValuesSupplier.get() + ); + + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + initializeVectorValues(knnVectorValues); + int vectorBlobLength = knnVectorValues.bytesPerVector() * NUM_DOCS; + PART_SIZE = vectorBlobLength / NUM_PARTS; + LAST_PART_SIZE = (vectorBlobLength % PART_SIZE) != 0 ? vectorBlobLength % PART_SIZE : PART_SIZE; + + // 1. Create NUM_PARTS input streams + final List streamList = new ArrayList<>(NUM_PARTS); + for (int partNumber = 0; partNumber < NUM_PARTS; partNumber++) { + System.out.println(partNumber); + streamList.add( + new VectorValuesInputStream(knnVectorValuesSupplier.get(), VectorDataType.FLOAT, (long) partNumber * PART_SIZE, PART_SIZE) + ); + } + + // Last part only needs to be written if it is a different size, otherwise previous for loop would cover it + if (LAST_PART_SIZE != PART_SIZE) { + streamList.add( + new VectorValuesInputStream( + knnVectorValuesSupplier.get(), + VectorDataType.FLOAT, + vectorBlobLength - LAST_PART_SIZE, + LAST_PART_SIZE + ) + ); + } + + // 2. Read all input stream parts into the same buffer + ByteBuffer testBuffer = ByteBuffer.allocate(vectorBlobLength).order(ByteOrder.LITTLE_ENDIAN); + for (VectorValuesInputStream stream : streamList) { + byte[] partBytes = stream.readAllBytes(); + testBuffer.put(partBytes); + } + + // 3. Read all knnVectorValues into a buffer: + VectorValuesInputStream expectedStream = new VectorValuesInputStream(knnVectorValuesSupplier.get(), VectorDataType.FLOAT); + assertArrayEquals(expectedStream.readAllBytes(), testBuffer.array()); + } + + /** + * Tests that invoking {@link VectorValuesInputStream#read()} N times yields the same results as {@link VectorValuesInputStream#read(byte[], 0, N)} + */ + public void testVectorValuesInputStreamReadByte() throws IOException { + final int NUM_DIMENSION = randomIntBetween(1, 1000); + // We use only 1 doc here because VectorValuesInputStream.read will only read up to 1 vector maximum at a time. + // To read all the vectors we would need to call readNBytes, however we want to specifically test the methods we have overridden + // here + List vectorValues = getRandomFloatVectors(1, NUM_DIMENSION); + + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + + // Read stream byte by byte + final KNNVectorValues knnVectorValuesForReadByte = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + initializeVectorValues(knnVectorValuesForReadByte); + int vectorBlobLength = knnVectorValuesForReadByte.bytesPerVector(); + InputStream vectorStreamForReadByte = new VectorValuesInputStream(knnVectorValuesForReadByte, VectorDataType.FLOAT); + ByteBuffer bufferReadByByte = ByteBuffer.allocate(vectorBlobLength).order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < vectorBlobLength; i++) { + bufferReadByByte.put((byte) vectorStreamForReadByte.read()); + } + bufferReadByByte.position(0); + + // Read stream with entire length + final KNNVectorValues knnVectorValuesForRead = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + InputStream vectorStreamForRead = new VectorValuesInputStream(knnVectorValuesForRead, VectorDataType.FLOAT); + ByteBuffer bufferRead = ByteBuffer.allocate(vectorBlobLength).order(ByteOrder.LITTLE_ENDIAN); + assertEquals(vectorBlobLength, vectorStreamForRead.read(bufferRead.array(), 0, vectorBlobLength)); + + assertArrayEquals(bufferRead.array(), bufferReadByByte.array()); + } + + /** + * Tests that invoking {@link DocIdInputStream#read()} N times yields the same results as {@link DocIdInputStream#read(byte[], 0, N)} + */ + public void testDocIdInputStreamReadByte() throws IOException { + // We use only 1 doc here because DocIdInputStream.read will only read up to 1 doc id at a time + // To read all the vectors we would need to call readNBytes, however we want to specifically test the methods we have overridden + // here + List vectorValues = getRandomFloatVectors(1, 1); + + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + + // Read stream byte by byte + final KNNVectorValues docIdsForReadByte = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + initializeVectorValues(docIdsForReadByte); + int blobLength = Integer.BYTES; + InputStream docStreamForReadByte = new DocIdInputStream(docIdsForReadByte); + ByteBuffer bufferReadByByte = ByteBuffer.allocate(blobLength).order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < blobLength; i++) { + bufferReadByByte.put((byte) docStreamForReadByte.read()); + } + bufferReadByByte.position(0); + + // Read stream with entire length + final KNNVectorValues docIdsForRead = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + InputStream docStreamForRead = new DocIdInputStream(docIdsForRead); + ByteBuffer bufferRead = ByteBuffer.allocate(blobLength).order(ByteOrder.LITTLE_ENDIAN); + assertEquals(blobLength, docStreamForRead.read(bufferRead.array(), 0, blobLength)); + + assertArrayEquals(bufferRead.array(), bufferReadByByte.array()); + } + + private List getRandomFloatVectors(int numDocs, int dimension) { + ArrayList vectorValues = new ArrayList<>(); + for (int i = 0; i < numDocs; i++) { + vectorValues.add(getRandomVector(dimension)); + } + return vectorValues; + } + + private List getRandomByteVectors(int numDocs, int dimension) { + ArrayList vectorValues = new ArrayList<>(); + for (int i = 0; i < numDocs; i++) { + vectorValues.add(getRandomByteVector(dimension)); + } + return vectorValues; + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java index 1589021f67..fc2281f5e3 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java @@ -5,63 +5,78 @@ package org.opensearch.knn.index.codec.nativeindex.remote; -import org.mockito.Mockito; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy; -import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; -import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.store.IndexOutputWithBuffer; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; -import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.common.SetOnce; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.core.index.Index; +import org.opensearch.index.IndexSettings; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.RepositoryMissingException; -import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.List; -import java.util.Map; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING; +import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; -public class RemoteIndexBuildStrategyTests extends OpenSearchTestCase { +public class RemoteIndexBuildStrategyTests extends RemoteIndexBuildTests { + private static final String TEST_INDEX = "test-index"; - static int fallbackCounter = 0; - - private static class TestIndexBuildStrategy implements NativeIndexBuildStrategy { - - @Override - public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { - fallbackCounter++; - } - } + /** + * Test that we fallback to the fallback NativeIndexBuildStrategy when an exception is thrown + */ + public void testRemoteIndexBuildStrategyFallback() throws IOException { + RepositoriesService repositoriesService = mock(RepositoriesService.class); + when(repositoriesService.repository(any())).thenThrow(new RepositoryMissingException("Fallback")); - public void testFallback() throws IOException { - List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 }); - final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( - vectorValues + final SetOnce fallback = new SetOnce<>(); + RemoteIndexBuildStrategy objectUnderTest = new RemoteIndexBuildStrategy( + () -> repositoriesService, + new TestIndexBuildStrategy(fallback), + mock(IndexSettings.class) ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + objectUnderTest.buildAndWriteIndex(buildIndexParams); + assertTrue(fallback.get()); + } - RepositoriesService repositoriesService = mock(RepositoriesService.class); - when(repositoriesService.repository(any())).thenThrow(new RepositoryMissingException("Fallback")); + public void testShouldBuildIndexRemotely() { + IndexSettings indexSettings; + ClusterSettings clusterSettings; + Index index = mock(Index.class); + when(index.getName()).thenReturn(TEST_INDEX); + // Check index settings null + assertFalse(RemoteIndexBuildStrategy.shouldBuildIndexRemotely(null, 0)); - RemoteIndexBuildStrategy objectUnderTest = new RemoteIndexBuildStrategy(() -> repositoriesService, new TestIndexBuildStrategy()); + // Check index setting disabled + indexSettings = mock(IndexSettings.class); + when(indexSettings.getValue(KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING)).thenReturn(false); + when(indexSettings.getIndex()).thenReturn(index); + assertFalse(RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings, 0)); - IndexOutputWithBuffer indexOutputWithBuffer = Mockito.mock(IndexOutputWithBuffer.class); + // Check repo not configured + indexSettings = mock(IndexSettings.class); + when(indexSettings.getIndex()).thenReturn(index); + when(indexSettings.getValue(KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING)).thenReturn(true); + clusterSettings = mock(ClusterSettings.class); + when(clusterSettings.get(KNN_REMOTE_VECTOR_REPO_SETTING)).thenReturn(""); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + KNNSettings.state().setClusterService(clusterService); + assertFalse(RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings, 0)); - BuildIndexParams buildIndexParams = BuildIndexParams.builder() - .indexOutputWithBuffer(indexOutputWithBuffer) - .knnEngine(KNNEngine.FAISS) - .vectorDataType(VectorDataType.FLOAT) - .parameters(Map.of("index", "param")) - .knnVectorValuesSupplier(() -> knnVectorValues) - .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) - .build(); + // Check size threshold + int BYTE_SIZE = randomIntBetween(50, 1000); + when(indexSettings.getValue(KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING)).thenReturn(new ByteSizeValue(BYTE_SIZE)); + assertFalse(RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings, randomInt(BYTE_SIZE - 1))); - objectUnderTest.buildAndWriteIndex(buildIndexParams); - assertEquals(1, fallbackCounter); + // Check happy path + clusterSettings = mock(ClusterSettings.class); + when(clusterSettings.get(KNN_REMOTE_VECTOR_REPO_SETTING)).thenReturn("test-vector-repo"); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + KNNSettings.state().setClusterService(clusterService); + assertTrue(RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings, randomIntBetween(BYTE_SIZE - 1, BYTE_SIZE * 2))); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java new file mode 100644 index 0000000000..904eb2c927 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java @@ -0,0 +1,159 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.remote; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.Version; +import org.junit.Before; +import org.mockito.Mockito; +import org.opensearch.common.SetOnce; +import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer; +import org.opensearch.common.blobstore.BlobPath; +import org.opensearch.common.blobstore.DeleteResult; +import org.opensearch.common.blobstore.fs.FsBlobContainer; +import org.opensearch.common.blobstore.fs.FsBlobStore; +import org.opensearch.common.blobstore.stream.read.ReadContext; +import org.opensearch.common.blobstore.stream.write.WriteContext; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; + +/** + * Base test class for remote index build tests + */ +abstract class RemoteIndexBuildTests extends KNNTestCase { + + final List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 }); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(vectorValues); + final Supplier> knnVectorValuesSupplier = KNNVectorValuesFactory.getVectorValuesSupplier( + VectorDataType.FLOAT, + randomVectorValues + ); + final IndexOutputWithBuffer indexOutputWithBuffer = Mockito.mock(IndexOutputWithBuffer.class); + final String segmentName = "test-segment-name"; + final SegmentInfo segmentInfo = new SegmentInfo( + mock(Directory.class), + mock(Version.class), + mock(Version.class), + segmentName, + 0, + false, + false, + mock(Codec.class), + mock(Map.class), + new byte[16], + mock(Map.class), + mock(Sort.class) + ); + final SegmentWriteState segmentWriteState = new SegmentWriteState( + mock(InfoStream.class), + mock(Directory.class), + segmentInfo, + mock(FieldInfos.class), + null, + mock(IOContext.class) + ); + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + final BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexOutputWithBuffer(indexOutputWithBuffer) + .knnEngine(KNNEngine.FAISS) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("index", "param")) + .knnVectorValuesSupplier(knnVectorValuesSupplier) + .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) + .segmentWriteState(segmentWriteState) + .build(); + + record TestIndexBuildStrategy(SetOnce fallback) implements NativeIndexBuildStrategy { + @Override + public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { + fallback.set(true); + } + } + + static class TestAsyncBlobContainer extends FsBlobContainer implements AsyncMultiStreamBlobContainer { + private final boolean throwsException; + + public TestAsyncBlobContainer(FsBlobStore blobStore, BlobPath blobPath, Path path, boolean throwsException) { + super(blobStore, blobPath, path); + this.throwsException = throwsException; + + } + + @Override + public void asyncBlobUpload(WriteContext writeContext, ActionListener actionListener) throws IOException { + if (this.throwsException) { + actionListener.onFailure(new IOException("Test Exception")); + } else { + actionListener.onResponse(null); + } + } + + @Override + public void writeBlob(String blobName, InputStream inputStream, long blobSize, boolean failIfAlreadyExists) throws IOException {} + + @Override + public void readBlobAsync(String s, ActionListener actionListener) {} + + @Override + public boolean remoteIntegrityCheckSupported() { + return false; + } + + @Override + public void deleteAsync(ActionListener actionListener) {} + + @Override + public void deleteBlobsAsyncIgnoringIfNotExists(List list, ActionListener actionListener) {} + } + + static class TestBlobContainer extends FsBlobContainer { + + public TestBlobContainer(FsBlobStore blobStore, BlobPath blobPath, Path path) { + super(blobStore, blobPath, path); + } + + @Override + public void writeBlob(String blobName, InputStream inputStream, long blobSize, boolean failIfAlreadyExists) {} + } + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + ClusterSettings clusterSettings = mock(ClusterSettings.class); + when(clusterSettings.get(KNN_REMOTE_VECTOR_REPO_SETTING)).thenReturn("test-repo-name"); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + KNNSettings.state().setClusterService(clusterService); + } +} diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java index 337ab6c489..e63418b4de 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java @@ -15,6 +15,7 @@ import java.util.Arrays; import java.util.List; +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomByte; import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat; public class TestVectorValues { @@ -407,4 +408,12 @@ public static float[] getRandomVector(int dimension) { } return data; } + + public static byte[] getRandomByteVector(int dimension) { + byte[] data = new byte[dimension]; + for (int i = 0; i < dimension; i++) { + data[i] = randomByte(); + } + return data; + } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index e1b0e27005..e0992f38ee 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -112,6 +112,7 @@ import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD; import static org.opensearch.knn.index.SpaceType.L2; import static org.opensearch.knn.index.engine.KNNEngine.FAISS; import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT; @@ -993,8 +994,10 @@ protected Settings buildKNNIndexSettings(int approximateThreshold) { .put(KNN_INDEX, true) .put(INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, approximateThreshold); - if (isRemoteIndexBuildSupported(getBWCVersion())) { - builder.put(KNN_INDEX_REMOTE_VECTOR_BUILD, randomBoolean()); + // Randomly enable remote index build feature to test fallbacks + if (isRemoteIndexBuildSupported(getBWCVersion()) && randomBoolean()) { + builder.put(KNN_INDEX_REMOTE_VECTOR_BUILD, true); + builder.put(KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD, "0mb"); } return builder.build(); }