Skip to content

Commit

Permalink
[ML] Fix infer on an elasticsearch service endpoint created with a d…
Browse files Browse the repository at this point in the history
…eployment id (elastic#121428)

Fixes a bug where the deployment ID was lost when creating the text embedding
model configuration.
  • Loading branch information
davidkyle committed Jan 31, 2025
1 parent 5fe99a1 commit 3001cab
Show file tree
Hide file tree
Showing 14 changed files with 133 additions and 102 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/121428.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 121428
summary: Fix infer on an elasticsearch service endpoint created with a deployment
id
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ public void testAttachToDeployment() throws IOException {
is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", "attach_to_deployment", "deployment_id", "existing_deployment"))
);

var getModel = getModel(inferenceId);
serviceSettings = getModel.get("service_settings");
assertThat(
getModel.toString(),
serviceSettings,
is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", "attach_to_deployment", "deployment_id", "existing_deployment"))
);

var results = infer(inferenceId, List.of("washing machine"));
assertNotNull(results.get("sparse_embedding"));

Expand Down Expand Up @@ -80,6 +88,25 @@ public void testAttachWithModelId() throws IOException {
)
);

var getModel = getModel(inferenceId);
serviceSettings = getModel.get("service_settings");
assertThat(
getModel.toString(),
serviceSettings,
is(
Map.of(
"num_allocations",
1,
"num_threads",
1,
"model_id",
"attach_with_model_id",
"deployment_id",
"existing_deployment_with_model_id"
)
)
);

var results = infer(inferenceId, List.of("washing machine"));
assertNotNull(results.get("sparse_embedding"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;

import java.io.IOException;

Expand All @@ -20,15 +19,6 @@ public CustomElandInternalServiceSettings(ElasticsearchInternalServiceSettings o
super(other);
}

public CustomElandInternalServiceSettings(
Integer numAllocations,
int numThreads,
String modelId,
AdaptiveAllocationsSettings adaptiveAllocationsSettings
) {
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings);
}

public CustomElandInternalServiceSettings(StreamInput in) throws IOException {
super(in);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SimilarityMeasure;
Expand Down Expand Up @@ -105,33 +106,17 @@ private static CommonFields commonFieldsFromMap(Map<String, Object> map, Validat
private final SimilarityMeasure similarityMeasure;
private final DenseVectorFieldMapper.ElementType elementType;

public CustomElandInternalTextEmbeddingServiceSettings(
int numAllocations,
int numThreads,
String modelId,
AdaptiveAllocationsSettings adaptiveAllocationsSettings
) {
this(
numAllocations,
numThreads,
modelId,
adaptiveAllocationsSettings,
null,
SimilarityMeasure.COSINE,
DenseVectorFieldMapper.ElementType.FLOAT
);
}

public CustomElandInternalTextEmbeddingServiceSettings(
CustomElandInternalTextEmbeddingServiceSettings(
int numAllocations,
int numThreads,
String modelId,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
@Nullable String deploymentId,
Integer dimensions,
SimilarityMeasure similarityMeasure,
DenseVectorFieldMapper.ElementType elementType
) {
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings);
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, deploymentId);
this.dimensions = dimensions;
this.similarityMeasure = Objects.requireNonNull(similarityMeasure);
this.elementType = Objects.requireNonNull(elementType);
Expand Down Expand Up @@ -159,7 +144,8 @@ private CustomElandInternalTextEmbeddingServiceSettings(CommonFields commonField
commonFields.internalServiceSettings.getNumAllocations(),
commonFields.internalServiceSettings.getNumThreads(),
commonFields.internalServiceSettings.modelId(),
commonFields.internalServiceSettings.getAdaptiveAllocationsSettings()
commonFields.internalServiceSettings.getAdaptiveAllocationsSettings(),
commonFields.internalServiceSettings.getDeploymentId()
);
this.dimensions = dimensions;
similarityMeasure = commonFields.similarityMeasure;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,27 @@
import java.io.IOException;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID;

public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings {

public static final String NAME = "elastic_reranker_service_settings";

public static ElasticRerankerServiceSettings defaultEndpointSettings() {
return new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32));
}

public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) {
super(other);
}

public ElasticRerankerServiceSettings(
private ElasticRerankerServiceSettings(
Integer numAllocations,
int numThreads,
String modelId,
AdaptiveAllocationsSettings adaptiveAllocationsSettings
) {
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings);
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null);
}

public ElasticRerankerServiceSettings(StreamInput in) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
Expand Down Expand Up @@ -562,6 +561,7 @@ private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomE
model.getServiceSettings().getNumThreads(),
model.getServiceSettings().modelId(),
model.getServiceSettings().getAdaptiveAllocationsSettings(),
model.getServiceSettings().getDeploymentId(),
embeddingSize,
model.getServiceSettings().similarity(),
model.getServiceSettings().elementType()
Expand Down Expand Up @@ -902,32 +902,22 @@ private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {
DEFAULT_ELSER_ID,
TaskType.SPARSE_EMBEDDING,
NAME,
new ElserInternalServiceSettings(
null,
1,
useLinuxOptimizedModel ? ELSER_V2_MODEL_LINUX_X86 : ELSER_V2_MODEL,
new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)
),
ElserInternalServiceSettings.defaultEndpointSettings(useLinuxOptimizedModel),
ElserMlNodeTaskSettings.DEFAULT,
ChunkingSettingsBuilder.DEFAULT_SETTINGS
);
var defaultE5 = new MultilingualE5SmallModel(
DEFAULT_E5_ID,
TaskType.TEXT_EMBEDDING,
NAME,
new MultilingualE5SmallInternalServiceSettings(
null,
1,
useLinuxOptimizedModel ? MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 : MULTILINGUAL_E5_SMALL_MODEL_ID,
new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)
),
MultilingualE5SmallInternalServiceSettings.defaultEndpointSettings(useLinuxOptimizedModel),
ChunkingSettingsBuilder.DEFAULT_SETTINGS
);
var defaultRerank = new ElasticRerankerModel(
DEFAULT_RERANK_ID,
TaskType.RERANK,
NAME,
new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)),
ElasticRerankerServiceSettings.defaultEndpointSettings(),
RerankTaskSettings.DEFAULT_SETTINGS
);
return List.of(defaultElser, defaultE5, defaultRerank);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.ToXContentObject;
Expand Down Expand Up @@ -108,25 +109,12 @@ protected static ElasticsearchInternalServiceSettings.Builder fromMap(
.setDeploymentId(deploymentId);
}

public ElasticsearchInternalServiceSettings(
Integer numAllocations,
int numThreads,
String modelId,
AdaptiveAllocationsSettings adaptiveAllocationsSettings
) {
this.numAllocations = numAllocations;
this.numThreads = numThreads;
this.modelId = Objects.requireNonNull(modelId);
this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
this.deploymentId = null;
}

public ElasticsearchInternalServiceSettings(
Integer numAllocations,
int numThreads,
String modelId,
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
String deploymentId
@Nullable String deploymentId
) {
this.numAllocations = numAllocations;
this.numThreads = numThreads;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86;

public class ElserInternalServiceSettings extends ElasticsearchInternalServiceSettings {

Expand All @@ -26,37 +26,26 @@ public static MinimalServiceSettings minimalServiceSettings() {
return MinimalServiceSettings.sparseEmbedding();
}

public static Builder fromRequestMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();
var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);

String modelId = baseSettings.getModelId();
if (modelId != null && ElserModels.isValidModel(modelId) == false) {
var ve = new ValidationException();
ve.addValidationError(
"Unknown ELSER model ID [" + modelId + "]. Valid models are " + Arrays.toString(ElserModels.VALID_ELSER_MODEL_IDS.toArray())
);
throw ve;
}

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return baseSettings;
public static ElserInternalServiceSettings defaultEndpointSettings(boolean useLinuxOptimizedModel) {
return new ElserInternalServiceSettings(
null,
1,
useLinuxOptimizedModel ? ELSER_V2_MODEL_LINUX_X86 : ELSER_V2_MODEL,
new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)
);
}

public ElserInternalServiceSettings(ElasticsearchInternalServiceSettings other) {
super(other);
}

public ElserInternalServiceSettings(
private ElserInternalServiceSettings(
Integer numAllocations,
int numThreads,
String modelId,
AdaptiveAllocationsSettings adaptiveAllocationsSettings
) {
this(new ElasticsearchInternalServiceSettings(numAllocations, numThreads, modelId, adaptiveAllocationsSettings));
this(new ElasticsearchInternalServiceSettings(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null));
}

public ElserInternalServiceSettings(StreamInput in) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import java.util.Arrays;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;

public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInternalServiceSettings {

public static final String NAME = "multilingual_e5_small_service_settings";
Expand All @@ -29,17 +32,26 @@ public static MinimalServiceSettings minimalServiceSettings() {
return MinimalServiceSettings.textEmbedding(DIMENSIONS, SIMILARITY, DenseVectorFieldMapper.ElementType.FLOAT);
}

public static MultilingualE5SmallInternalServiceSettings defaultEndpointSettings(boolean useLinuxOptimizedModel) {
return new MultilingualE5SmallInternalServiceSettings(
null,
1,
useLinuxOptimizedModel ? MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 : MULTILINGUAL_E5_SMALL_MODEL_ID,
new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)
);
}

public MultilingualE5SmallInternalServiceSettings(ElasticsearchInternalServiceSettings other) {
super(other);
}

public MultilingualE5SmallInternalServiceSettings(
MultilingualE5SmallInternalServiceSettings(
Integer numAllocations,
int numThreads,
String modelId,
AdaptiveAllocationsSettings adaptiveAllocationsSettings
) {
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings);
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null);
}

public MultilingualE5SmallInternalServiceSettings(StreamInput in) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public static CustomElandInternalTextEmbeddingServiceSettings createRandom() {
numThreads,
modelId,
null,
null,
dims,
similarityMeasure,
elementType
Expand Down Expand Up @@ -86,6 +87,7 @@ public void testFromMap_Request_CreatesSettingsCorrectly() {
modelId,
null,
null,
null,
SimilarityMeasure.DOT_PRODUCT,
DenseVectorFieldMapper.ElementType.FLOAT
)
Expand All @@ -111,6 +113,7 @@ public void testFromMap_Request_DoesNotDefaultSimilarityElementType() {
modelId,
null,
null,
null,
SimilarityMeasure.COSINE,
DenseVectorFieldMapper.ElementType.FLOAT
)
Expand Down Expand Up @@ -152,6 +155,7 @@ public void testFromMap_Request_IgnoresDimensions() {
modelId,
null,
null,
null,
SimilarityMeasure.DOT_PRODUCT,
DenseVectorFieldMapper.ElementType.FLOAT
)
Expand Down Expand Up @@ -192,6 +196,7 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() {
numThreads,
modelId,
null,
null,
1,
SimilarityMeasure.DOT_PRODUCT,
DenseVectorFieldMapper.ElementType.FLOAT
Expand All @@ -206,6 +211,7 @@ public void testToXContent_WritesAllValues() throws IOException {
1,
"model_id",
null,
null,
100,
SimilarityMeasure.COSINE,
DenseVectorFieldMapper.ElementType.BYTE
Expand Down
Loading

0 comments on commit 3001cab

Please sign in to comment.