Skip to content

Commit

Permalink
Introduce RemoteIndexBuildStrategy, refactor NativeIndexBuildStrategy…
Browse files Browse the repository at this point in the history
… to accept vector value supplier (opensearch-project#2525)

Signed-off-by: Jay Deng <[email protected]>
  • Loading branch information
jed326 authored Feb 19, 2025
1 parent 6203ba1 commit f9d8b53
Show file tree
Hide file tree
Showing 30 changed files with 898 additions and 330 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import java.util.Locale;
import java.util.Optional;

import static org.opensearch.knn.TestUtils.*;
import static org.opensearch.knn.TestUtils.BWC_VERSION;
import static org.opensearch.knn.TestUtils.CLIENT_TIMEOUT_VALUE;
import static org.opensearch.knn.TestUtils.KNN_BWC_PREFIX;
import static org.opensearch.knn.TestUtils.RESTART_UPGRADE_OLD_CLUSTER;

public abstract class AbstractRestartUpgradeTestCase extends KNNRestTestCase {
protected String testIndex;
Expand Down Expand Up @@ -58,6 +61,7 @@ protected static final boolean isRunningAgainstOldCluster() {
return Boolean.parseBoolean(System.getProperty(RESTART_UPGRADE_OLD_CLUSTER));
}

@Override
protected final Optional<String> getBWCVersion() {
return Optional.ofNullable(System.getProperty(BWC_VERSION, null));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@
package org.opensearch.knn.bwc;

import org.junit.Before;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.test.rest.OpenSearchRestTestCase;

import java.util.Locale;
import java.util.Optional;

import static org.opensearch.knn.TestUtils.*;
import static org.opensearch.knn.TestUtils.BWCSUITE_CLUSTER;
import static org.opensearch.knn.TestUtils.BWC_VERSION;
import static org.opensearch.knn.TestUtils.CLIENT_TIMEOUT_VALUE;
import static org.opensearch.knn.TestUtils.KNN_BWC_PREFIX;
import static org.opensearch.knn.TestUtils.MIXED_CLUSTER;
import static org.opensearch.knn.TestUtils.OLD_CLUSTER;
import static org.opensearch.knn.TestUtils.ROLLING_UPGRADE_FIRST_ROUND;
import static org.opensearch.knn.TestUtils.UPGRADED_CLUSTER;

public abstract class AbstractRollingUpgradeTestCase extends KNNRestTestCase {
protected String testIndex;
Expand Down Expand Up @@ -81,6 +88,7 @@ protected final boolean isFirstMixedRound() {
return Boolean.parseBoolean(System.getProperty(ROLLING_UPGRADE_FIRST_ROUND, "false"));
}

@Override
protected final Optional<String> getBWCVersion() {
return Optional.ofNullable(System.getProperty(BWC_VERSION, null));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class KNNFeatureFlags {

// Feature flags
private static final String KNN_FORCE_EVICT_CACHE_ENABLED = "knn.feature.cache.force_evict.enabled";
private static final String KNN_REMOTE_VECTOR_BUILD = "knn.feature.remote_index_build.enabled";

@VisibleForTesting
public static final Setting<Boolean> KNN_FORCE_EVICT_CACHE_ENABLED_SETTING = Setting.boolSetting(
Expand All @@ -35,8 +36,18 @@ public class KNNFeatureFlags {
Dynamic
);

/**
* Feature flag to control remote index build at the cluster level
*/
public static final Setting<Boolean> KNN_REMOTE_VECTOR_BUILD_SETTING = Setting.boolSetting(
KNN_REMOTE_VECTOR_BUILD,
false,
NodeScope,
Dynamic
);

public static List<Setting<?>> getFeatureFlags() {
return ImmutableList.of(KNN_FORCE_EVICT_CACHE_ENABLED_SETTING);
return ImmutableList.of(KNN_FORCE_EVICT_CACHE_ENABLED_SETTING, KNN_REMOTE_VECTOR_BUILD_SETTING);
}

/**
Expand All @@ -46,4 +57,11 @@ public static List<Setting<?>> getFeatureFlags() {
public static boolean isForceEvictCacheEnabled() {
return Booleans.parseBoolean(KNNSettings.state().getSettingValue(KNN_FORCE_EVICT_CACHE_ENABLED).toString(), false);
}

/**
* @return true if remote vector index build feature flag is enabled
*/
public static boolean isKNNRemoteVectorBuildEnabled() {
return Booleans.parseBooleanStrict(KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_BUILD).toString(), false);
}
}
31 changes: 29 additions & 2 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@

import static java.util.stream.Collectors.toUnmodifiableMap;
import static org.opensearch.common.settings.Setting.Property.Dynamic;
import static org.opensearch.common.settings.Setting.Property.Final;
import static org.opensearch.common.settings.Setting.Property.IndexScope;
import static org.opensearch.common.settings.Setting.Property.NodeScope;
import static org.opensearch.common.settings.Setting.Property.Final;
import static org.opensearch.common.settings.Setting.Property.UnmodifiableOnRestore;
import static org.opensearch.common.unit.MemorySizeValue.parseBytesSizeValueOrHeapRatio;
import static org.opensearch.core.common.unit.ByteSizeValue.parseBytesSizeValue;
Expand Down Expand Up @@ -94,6 +94,8 @@ public class KNNSettings {
public static final String KNN_FAISS_AVX512_SPR_DISABLED = "knn.faiss.avx512_spr.disabled";
public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled";
public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled";
public static final String KNN_INDEX_REMOTE_VECTOR_BUILD = "index.knn.remote_index_build.enabled";
public static final String KNN_REMOTE_VECTOR_REPO = "knn.remote_index_build.vector_repo";

/**
* Default setting values
Expand Down Expand Up @@ -371,6 +373,21 @@ public class KNNSettings {
NodeScope
);

/**
* Index level setting to control whether remote index build is enabled or not.
*/
public static final Setting<Boolean> KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING = Setting.boolSetting(
KNN_INDEX_REMOTE_VECTOR_BUILD,
false,
Dynamic,
IndexScope
);

/**
* Cluster level setting which indicates the repository that the remote index build should write to.
*/
public static final Setting<String> KNN_REMOTE_VECTOR_REPO_SETTING = Setting.simpleString(KNN_REMOTE_VECTOR_REPO, Dynamic, NodeScope);

/**
* Dynamic settings
*/
Expand Down Expand Up @@ -525,6 +542,14 @@ private Setting<?> getSetting(String key) {
return KNN_DERIVED_SOURCE_ENABLED_SETTING;
}

if (KNN_INDEX_REMOTE_VECTOR_BUILD.equals(key)) {
return KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING;
}

if (KNN_REMOTE_VECTOR_REPO.equals(key)) {
return KNN_REMOTE_VECTOR_REPO_SETTING;
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
}

Expand All @@ -550,7 +575,9 @@ public List<Setting<?>> getSettings() {
QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING,
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING,
KNN_DERIVED_SOURCE_ENABLED_SETTING
KNN_DERIVED_SOURCE_ENABLED_SETTING,
KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING,
KNN_REMOTE_VECTOR_REPO_SETTING
);
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory;
import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams;
import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams;
import org.opensearch.knn.index.engine.KNNEngine;
Expand Down Expand Up @@ -44,6 +45,7 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
private final Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier;
private Function<KNNScalarQuantizedVectorsFormatParams, KnnVectorsFormat> scalarQuantizedVectorsFormatSupplier;
private final NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory;
private static final String MAX_CONNECTIONS = "max_connections";
private static final String BEAM_WIDTH = "beam_width";

Expand All @@ -54,11 +56,26 @@ public BasePerFieldKnnVectorsFormat(
Supplier<KnnVectorsFormat> defaultFormatSupplier,
Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier
) {
this.mapperService = mapperService;
this.defaultMaxConnections = defaultMaxConnections;
this.defaultBeamWidth = defaultBeamWidth;
this.defaultFormatSupplier = defaultFormatSupplier;
this.vectorsFormatSupplier = vectorsFormatSupplier;
this(mapperService, defaultMaxConnections, defaultBeamWidth, defaultFormatSupplier, vectorsFormatSupplier, null);
}

public BasePerFieldKnnVectorsFormat(
Optional<MapperService> mapperService,
int defaultMaxConnections,
int defaultBeamWidth,
Supplier<KnnVectorsFormat> defaultFormatSupplier,
Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier,
Function<KNNScalarQuantizedVectorsFormatParams, KnnVectorsFormat> scalarQuantizedVectorsFormatSupplier
) {
this(
mapperService,
defaultMaxConnections,
defaultBeamWidth,
defaultFormatSupplier,
vectorsFormatSupplier,
scalarQuantizedVectorsFormatSupplier,
new NativeIndexBuildStrategyFactory()
);
}

@Override
Expand Down Expand Up @@ -141,7 +158,8 @@ private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() {
int approximateThreshold = getApproximateThresholdValue();
return new NativeEngines990KnnVectorsFormat(
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()),
approximateThreshold
approximateThreshold,
nativeIndexBuildStrategyFactory
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@
package org.opensearch.knn.index.codec.KNN80Codec;

import lombok.extern.log4j.Log4j2;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.DocValuesConsumer;
Expand All @@ -19,8 +14,13 @@
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.plugin.stats.KNNGraphValue;

import java.io.IOException;
Expand Down Expand Up @@ -72,9 +72,9 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
// For BDV it is fine to use knnVectorValues.totalLiveDocs() as we already run the full loop to calculate total
// live docs
if (isMerge) {
NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs());
NativeIndexWriter.getWriter(field, state).mergeIndex(() -> knnVectorValues, (int) knnVectorValues.totalLiveDocs());
} else {
NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs());
NativeIndexWriter.getWriter(field, state).flushIndex(() -> knnVectorValues, (int) knnVectorValues.totalLiveDocs());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory;
import org.opensearch.knn.index.engine.KNNEngine;

import java.util.Optional;
Expand All @@ -25,6 +26,13 @@ public class KNN9120PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsForma
private static final Tuple<Integer, ExecutorService> DEFAULT_MERGE_THREAD_COUNT_AND_EXECUTOR_SERVICE = Tuple.tuple(1, null);

public KNN9120PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
this(mapperService, new NativeIndexBuildStrategyFactory());
}

public KNN9120PerFieldKnnVectorsFormat(
final Optional<MapperService> mapperService,
NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory
) {
super(
mapperService,
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Expand Down Expand Up @@ -67,7 +75,8 @@ public KNN9120PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServi
// Executor service
mergeThreadCountAndExecutorService.v2()
);
}
},
nativeIndexBuildStrategyFactory
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory;
import org.opensearch.knn.index.engine.KNNEngine;

import java.io.IOException;
Expand All @@ -33,6 +34,7 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat {
private static FlatVectorsFormat flatVectorsFormat;
private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat";
private static int approximateThreshold;
private final NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory;

public NativeEngines990KnnVectorsFormat() {
this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()));
Expand All @@ -47,9 +49,18 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsForma
}

public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat, int approximateThreshold) {
this(flatVectorsFormat, approximateThreshold, new NativeIndexBuildStrategyFactory());
}

public NativeEngines990KnnVectorsFormat(
final FlatVectorsFormat flatVectorsFormat,
int approximateThreshold,
final NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory
) {
super(FORMAT_NAME);
NativeEngines990KnnVectorsFormat.flatVectorsFormat = flatVectorsFormat;
NativeEngines990KnnVectorsFormat.approximateThreshold = approximateThreshold;
this.nativeIndexBuildStrategyFactory = nativeIndexBuildStrategyFactory;
}

/**
Expand All @@ -59,7 +70,12 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsForma
*/
@Override
public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOException {
return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), approximateThreshold);
return new NativeEngines990KnnVectorsWriter(
state,
flatVectorsFormat.fieldsWriter(state),
approximateThreshold,
nativeIndexBuildStrategyFactory
);
}

/**
Expand Down
Loading

0 comments on commit f9d8b53

Please sign in to comment.