From 2a096278066cace68893cdbfba56f2fbe322f35b Mon Sep 17 00:00:00 2001 From: David Kyle Date: Fri, 31 Jan 2025 16:42:27 +0000 Subject: [PATCH] [ML] Fix infer on an elasticsearch service endpoint created with a deployment id (#121428) Fixes a bug where the deployment Id was lost when creating the text embedding model configuration --- docs/changelog/121428.yaml | 6 +++ .../inference/CreateFromDeploymentIT.java | 27 ++++++++++ .../CustomElandInternalServiceSettings.java | 10 ---- ...dInternalTextEmbeddingServiceSettings.java | 26 +++------ .../ElasticRerankerServiceSettings.java | 10 +++- .../ElasticsearchInternalService.java | 18 ++----- .../ElasticsearchInternalServiceSettings.java | 16 +----- .../ElserInternalServiceSettings.java | 35 +++++------- ...lingualE5SmallInternalServiceSettings.java | 16 +++++- ...rnalTextEmbeddingServiceSettingsTests.java | 6 +++ ...ticsearchInternalServiceSettingsTests.java | 2 +- .../ElasticsearchInternalServiceTests.java | 54 ++++++++++++++----- .../ElserInternalModelTests.java | 2 +- .../ElserInternalServiceSettingsTests.java | 7 ++- 14 files changed, 133 insertions(+), 102 deletions(-) create mode 100644 docs/changelog/121428.yaml diff --git a/docs/changelog/121428.yaml b/docs/changelog/121428.yaml new file mode 100644 index 0000000000000..b0a31b2399611 --- /dev/null +++ b/docs/changelog/121428.yaml @@ -0,0 +1,6 @@ +pr: 121428 +summary: Fix infer on and elasticsearch service endpoint created with a deployment + id +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java index 0a2200ff912ac..44a42defd2d97 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java +++ 
b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java @@ -40,6 +40,14 @@ public void testAttachToDeployment() throws IOException { is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", "attach_to_deployment", "deployment_id", "existing_deployment")) ); + var getModel = getModel(inferenceId); + serviceSettings = getModel.get("service_settings"); + assertThat( + getModel.toString(), + serviceSettings, + is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", "attach_to_deployment", "deployment_id", "existing_deployment")) + ); + var results = infer(inferenceId, List.of("washing machine")); assertNotNull(results.get("sparse_embedding")); @@ -80,6 +88,25 @@ public void testAttachWithModelId() throws IOException { ) ); + var getModel = getModel(inferenceId); + serviceSettings = getModel.get("service_settings"); + assertThat( + getModel.toString(), + serviceSettings, + is( + Map.of( + "num_allocations", + 1, + "num_threads", + 1, + "model_id", + "attach_with_model_id", + "deployment_id", + "existing_deployment_with_model_id" + ) + ) + ); + var results = infer(inferenceId, List.of("washing machine")); assertNotNull(results.get("sparse_embedding")); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java index 3cc7e0c6c2b53..4591418419ded 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; import 
org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import java.io.IOException; @@ -20,15 +19,6 @@ public CustomElandInternalServiceSettings(ElasticsearchInternalServiceSettings o super(other); } - public CustomElandInternalServiceSettings( - Integer numAllocations, - int numThreads, - String modelId, - AdaptiveAllocationsSettings adaptiveAllocationsSettings - ) { - super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); - } - public CustomElandInternalServiceSettings(StreamInput in) throws IOException { super(in); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java index 133be5e2b7623..d5f4143b65d36 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; @@ -105,33 +106,17 @@ private static CommonFields commonFieldsFromMap(Map map, Validat private final SimilarityMeasure similarityMeasure; private final DenseVectorFieldMapper.ElementType elementType; - public CustomElandInternalTextEmbeddingServiceSettings( - int numAllocations, - int numThreads, - String 
modelId, - AdaptiveAllocationsSettings adaptiveAllocationsSettings - ) { - this( - numAllocations, - numThreads, - modelId, - adaptiveAllocationsSettings, - null, - SimilarityMeasure.COSINE, - DenseVectorFieldMapper.ElementType.FLOAT - ); - } - - public CustomElandInternalTextEmbeddingServiceSettings( + CustomElandInternalTextEmbeddingServiceSettings( int numAllocations, int numThreads, String modelId, AdaptiveAllocationsSettings adaptiveAllocationsSettings, + @Nullable String deploymentId, Integer dimensions, SimilarityMeasure similarityMeasure, DenseVectorFieldMapper.ElementType elementType ) { - super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, deploymentId); this.dimensions = dimensions; this.similarityMeasure = Objects.requireNonNull(similarityMeasure); this.elementType = Objects.requireNonNull(elementType); @@ -159,7 +144,8 @@ private CustomElandInternalTextEmbeddingServiceSettings(CommonFields commonField commonFields.internalServiceSettings.getNumAllocations(), commonFields.internalServiceSettings.getNumThreads(), commonFields.internalServiceSettings.modelId(), - commonFields.internalServiceSettings.getAdaptiveAllocationsSettings() + commonFields.internalServiceSettings.getAdaptiveAllocationsSettings(), + commonFields.internalServiceSettings.getDeploymentId() ); this.dimensions = dimensions; similarityMeasure = commonFields.similarityMeasure; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java index 316dc092e03c7..2b7904e615682 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java @@ -14,21 +14,27 @@ import java.io.IOException; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID; + public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings { public static final String NAME = "elastic_reranker_service_settings"; + public static ElasticRerankerServiceSettings defaultEndpointSettings() { + return new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)); + } + public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) { super(other); } - public ElasticRerankerServiceSettings( + private ElasticRerankerServiceSettings( Integer numAllocations, int numThreads, String modelId, AdaptiveAllocationsSettings adaptiveAllocationsSettings ) { - super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null); } public ElasticRerankerServiceSettings(StreamInput in) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index ac4dbe68a8f83..07084a09f04d7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -39,7 +39,6 @@ import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import 
org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; @@ -562,6 +561,7 @@ private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomE model.getServiceSettings().getNumThreads(), model.getServiceSettings().modelId(), model.getServiceSettings().getAdaptiveAllocationsSettings(), + model.getServiceSettings().getDeploymentId(), embeddingSize, model.getServiceSettings().similarity(), model.getServiceSettings().elementType() @@ -902,12 +902,7 @@ private List defaultConfigs(boolean useLinuxOptimizedModel) { DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, NAME, - new ElserInternalServiceSettings( - null, - 1, - useLinuxOptimizedModel ? ELSER_V2_MODEL_LINUX_X86 : ELSER_V2_MODEL, - new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32) - ), + ElserInternalServiceSettings.defaultEndpointSettings(useLinuxOptimizedModel), ElserMlNodeTaskSettings.DEFAULT, ChunkingSettingsBuilder.DEFAULT_SETTINGS ); @@ -915,19 +910,14 @@ private List defaultConfigs(boolean useLinuxOptimizedModel) { DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, NAME, - new MultilingualE5SmallInternalServiceSettings( - null, - 1, - useLinuxOptimizedModel ? 
MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 : MULTILINGUAL_E5_SMALL_MODEL_ID, - new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32) - ), + MultilingualE5SmallInternalServiceSettings.defaultEndpointSettings(useLinuxOptimizedModel), ChunkingSettingsBuilder.DEFAULT_SETTINGS ); var defaultRerank = new ElasticRerankerModel( DEFAULT_RERANK_ID, TaskType.RERANK, NAME, - new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)), + ElasticRerankerServiceSettings.defaultEndpointSettings(), RerankTaskSettings.DEFAULT_SETTINGS ); return List.of(defaultElser, defaultE5, defaultRerank); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 244108edc3dd4..98730f33d10f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.ToXContentObject; @@ -108,25 +109,12 @@ protected static ElasticsearchInternalServiceSettings.Builder fromMap( .setDeploymentId(deploymentId); } - public ElasticsearchInternalServiceSettings( - Integer numAllocations, - int numThreads, - String modelId, - AdaptiveAllocationsSettings adaptiveAllocationsSettings - ) { - this.numAllocations = numAllocations; - this.numThreads = 
numThreads; - this.modelId = Objects.requireNonNull(modelId); - this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; - this.deploymentId = null; - } - public ElasticsearchInternalServiceSettings( Integer numAllocations, int numThreads, String modelId, AdaptiveAllocationsSettings adaptiveAllocationsSettings, - String deploymentId + @Nullable String deploymentId ) { this.numAllocations = numAllocations; this.numThreads = numThreads; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java index da9164bf3f288..b94b9feb8a049 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java @@ -9,14 +9,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import java.io.IOException; -import java.util.Arrays; -import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86; public class ElserInternalServiceSettings extends ElasticsearchInternalServiceSettings { @@ -26,37 +26,26 @@ public static MinimalServiceSettings minimalServiceSettings() { return MinimalServiceSettings.sparseEmbedding(); } - public static Builder fromRequestMap(Map map) { - ValidationException validationException = 
new ValidationException(); - var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException); - - String modelId = baseSettings.getModelId(); - if (modelId != null && ElserModels.isValidModel(modelId) == false) { - var ve = new ValidationException(); - ve.addValidationError( - "Unknown ELSER model ID [" + modelId + "]. Valid models are " + Arrays.toString(ElserModels.VALID_ELSER_MODEL_IDS.toArray()) - ); - throw ve; - } - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return baseSettings; + public static ElserInternalServiceSettings defaultEndpointSettings(boolean useLinuxOptimizedModel) { + return new ElserInternalServiceSettings( + null, + 1, + useLinuxOptimizedModel ? ELSER_V2_MODEL_LINUX_X86 : ELSER_V2_MODEL, + new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32) + ); } public ElserInternalServiceSettings(ElasticsearchInternalServiceSettings other) { super(other); } - public ElserInternalServiceSettings( + private ElserInternalServiceSettings( Integer numAllocations, int numThreads, String modelId, AdaptiveAllocationsSettings adaptiveAllocationsSettings ) { - this(new ElasticsearchInternalServiceSettings(numAllocations, numThreads, modelId, adaptiveAllocationsSettings)); + this(new ElasticsearchInternalServiceSettings(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null)); } public ElserInternalServiceSettings(StreamInput in) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java index 317cc48172fca..45d52d3c8deaa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java @@ -18,6 +18,9 @@ import java.util.Arrays; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; + public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInternalServiceSettings { public static final String NAME = "multilingual_e5_small_service_settings"; @@ -29,17 +32,26 @@ public static MinimalServiceSettings minimalServiceSettings() { return MinimalServiceSettings.textEmbedding(DIMENSIONS, SIMILARITY, DenseVectorFieldMapper.ElementType.FLOAT); } + public static MultilingualE5SmallInternalServiceSettings defaultEndpointSettings(boolean useLinuxOptimizedModel) { + return new MultilingualE5SmallInternalServiceSettings( + null, + 1, + useLinuxOptimizedModel ? 
MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 : MULTILINGUAL_E5_SMALL_MODEL_ID, + new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32) + ); + } + public MultilingualE5SmallInternalServiceSettings(ElasticsearchInternalServiceSettings other) { super(other); } - public MultilingualE5SmallInternalServiceSettings( + MultilingualE5SmallInternalServiceSettings( Integer numAllocations, int numThreads, String modelId, AdaptiveAllocationsSettings adaptiveAllocationsSettings ) { - super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null); } public MultilingualE5SmallInternalServiceSettings(StreamInput in) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java index ebb9c964e4c9a..3b8ce3a7cc64c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java @@ -48,6 +48,7 @@ public static CustomElandInternalTextEmbeddingServiceSettings createRandom() { numThreads, modelId, null, + null, dims, similarityMeasure, elementType @@ -86,6 +87,7 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { modelId, null, null, + null, SimilarityMeasure.DOT_PRODUCT, DenseVectorFieldMapper.ElementType.FLOAT ) @@ -111,6 +113,7 @@ public void testFromMap_Request_DoesNotDefaultSimilarityElementType() { modelId, null, null, + null, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT ) @@ -152,6 +155,7 @@ public void 
testFromMap_Request_IgnoresDimensions() { modelId, null, null, + null, SimilarityMeasure.DOT_PRODUCT, DenseVectorFieldMapper.ElementType.FLOAT ) @@ -192,6 +196,7 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() { numThreads, modelId, null, + null, 1, SimilarityMeasure.DOT_PRODUCT, DenseVectorFieldMapper.ElementType.FLOAT @@ -206,6 +211,7 @@ public void testToXContent_WritesAllValues() throws IOException { 1, "model_id", null, + null, 100, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.BYTE diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java index 0db0a7669c8aa..4ec575420613f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java @@ -103,7 +103,7 @@ public void testFromMap() { ) ) ).build(); - assertEquals(new ElasticsearchInternalServiceSettings(1, 4, ".elser_model_1", null), serviceSettings); + assertEquals(new ElasticsearchInternalServiceSettings(1, 4, ".elser_model_1", null, null), serviceSettings); } public void testFromMapMissingOptions() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 3b634f45dc751..1615d46b349ea 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -351,7 +351,9 @@ public void testParseRequestConfig_elser() { ) ); - var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + var elserServiceSettings = new ElserInternalServiceSettings( + new ElasticsearchInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null, null) + ); service.parseRequestConfig( randomInferenceEntityId, @@ -381,7 +383,9 @@ public void testParseRequestConfig_elser() { ) ); - var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + var elserServiceSettings = new ElserInternalServiceSettings( + new ElasticsearchInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null, null) + ); String criticalWarning = "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is" @@ -450,7 +454,9 @@ public void testParseRequestConfig_elser() { ); config.put(ModelConfigurations.CHUNKING_SETTINGS, createRandomChunkingSettingsMap()); - var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + var elserServiceSettings = new ElserInternalServiceSettings( + new ElasticsearchInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null, null) + ); service.parseRequestConfig( randomInferenceEntityId, @@ -486,7 +492,9 @@ public void testParseRequestConfig_elser() { ) ); - var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + var elserServiceSettings = new ElserInternalServiceSettings( + new ElasticsearchInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null, null) + ); service.parseRequestConfig( randomInferenceEntityId, @@ -742,7 +750,16 @@ public void testParsePersistedConfig() { TaskType.TEXT_EMBEDDING, settings ); - var elandServiceSettings = new 
CustomElandInternalTextEmbeddingServiceSettings(1, 4, "invalid", null); + var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( + 1, + 4, + "invalid", + null, + null, + null, + SimilarityMeasure.COSINE, + DenseVectorFieldMapper.ElementType.FLOAT + ); assertEquals( new CustomElandEmbeddingModel( randomInferenceEntityId, @@ -933,7 +950,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int "foo", TaskType.SPARSE_EMBEDDING, "elasticsearch", - new ElasticsearchInternalServiceSettings(1, 1, "model-id", null), + new ElasticsearchInternalServiceSettings(1, 1, "model-id", null, null), chunkingSettings ); var service = createService(client); @@ -1003,7 +1020,7 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws Inte "foo", TaskType.SPARSE_EMBEDDING, "elasticsearch", - new ElserInternalServiceSettings(1, 1, "model-id", null), + new ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 1, "model-id", null, null)), new ElserMlNodeTaskSettings(), chunkingSettings ); @@ -1328,11 +1345,20 @@ private CustomElandModel getCustomElandModel(TaskType taskType) { randomInferenceEntityId, taskType, ElasticsearchInternalService.NAME, - new CustomElandInternalServiceSettings(1, 4, "custom-model", null), + new CustomElandInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 4, "custom-model", null, null)), RerankTaskSettings.DEFAULT_SETTINGS ); } else if (taskType == TaskType.TEXT_EMBEDDING) { - var serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "custom-model", null); + var serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings( + 1, + 4, + "custom-model", + null, + null, + null, + SimilarityMeasure.COSINE, + DenseVectorFieldMapper.ElementType.FLOAT + ); expectedModel = new CustomElandEmbeddingModel( randomInferenceEntityId, @@ -1346,7 +1372,7 @@ private CustomElandModel getCustomElandModel(TaskType taskType) { 
randomInferenceEntityId, taskType, ElasticsearchInternalService.NAME, - new CustomElandInternalServiceSettings(1, 4, "custom-model", null), + new CustomElandInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 4, "custom-model", null, null)), (ChunkingSettings) null ); } @@ -1438,6 +1464,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { 4, "custom-model", null, + null, 1, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT @@ -1463,6 +1490,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { "custom-model", null, null, + null, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT ), @@ -1511,7 +1539,7 @@ public void testEmbeddingTypeFromTaskTypeAndSettings() { EmbeddingRequestChunker.EmbeddingType.SPARSE, ElasticsearchInternalService.embeddingTypeFromTaskTypeAndSettings( TaskType.SPARSE_EMBEDDING, - new ElasticsearchInternalServiceSettings(1, 1, "foo", null) + new ElasticsearchInternalServiceSettings(1, 1, "foo", null, null) ) ); assertEquals( @@ -1526,7 +1554,7 @@ public void testEmbeddingTypeFromTaskTypeAndSettings() { ElasticsearchStatusException.class, () -> ElasticsearchInternalService.embeddingTypeFromTaskTypeAndSettings( TaskType.COMPLETION, - new ElasticsearchInternalServiceSettings(1, 1, "foo", null) + new ElasticsearchInternalServiceSettings(1, 1, "foo", null, null) ) ); assertThat(e1.getMessage(), containsString("Chunking is not supported for task type [completion]")); @@ -1535,7 +1563,7 @@ public void testEmbeddingTypeFromTaskTypeAndSettings() { ElasticsearchStatusException.class, () -> ElasticsearchInternalService.embeddingTypeFromTaskTypeAndSettings( TaskType.RERANK, - new ElasticsearchInternalServiceSettings(1, 1, "foo", null) + new ElasticsearchInternalServiceSettings(1, 1, "foo", null, null) ) ); assertThat(e2.getMessage(), containsString("Chunking is not supported for task type [rerank]")); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java index 96cd42efa42f5..5b21717ac03e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java @@ -16,7 +16,7 @@ public void testUpdateNumAllocation() { "foo", TaskType.SPARSE_EMBEDDING, ElasticsearchInternalService.NAME, - new ElserInternalServiceSettings(null, 1, "elser", null), + new ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(null, 1, "elser", null, null)), new ElserMlNodeTaskSettings(), null ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java index f4e97b2c2e5e0..dd4513db0d50a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java @@ -24,12 +24,12 @@ public static ElserInternalServiceSettings createRandom() { public void testBwcWrite() throws IOException { { - var settings = new ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null)); + var settings = new ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null, null)); var copy = copyInstance(settings, TransportVersions.V_8_12_0); assertEquals(settings, copy); } { - var settings = new 
ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null)); + var settings = new ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null, null)); var copy = copyInstance(settings, TransportVersions.V_8_11_X); assertEquals(settings, copy); } @@ -53,6 +53,7 @@ protected ElserInternalServiceSettings mutateInstance(ElserInternalServiceSettin instance.getNumAllocations() == null ? 1 : instance.getNumAllocations() + 1, instance.getNumThreads(), instance.modelId(), + null, null ) ); @@ -61,6 +62,7 @@ protected ElserInternalServiceSettings mutateInstance(ElserInternalServiceSettin instance.getNumAllocations(), instance.getNumThreads() + 1, instance.modelId(), + null, null ) ); @@ -72,6 +74,7 @@ yield new ElserInternalServiceSettings( instance.getNumAllocations(), instance.getNumThreads(), versions.iterator().next(), + null, null ) );