From 6adbfc748409165a663943a0191db89d7841527d Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Mon, 9 Jun 2025 12:25:24 -0400 Subject: [PATCH 1/6] [ML] SageMaker Elastic Payload Send the Elastic API Payload to a SageMaker endpoint, and parse the response as if it were an Elastic API response. - SageMaker now supports all task types in the Elastic API format. - Streaming is supported using the SageMaker client/server rpc, rather than SSE. Payloads must be in a complete and valid JSON structure. - Task Settings can be used for additional passthrough settings, but they will not be saved alongside the model. Elastic cannot make guarantees on the structure or contents of this payload, so Elastic will treat it like the other input payloads and only allow them during inference. --- .../org/elasticsearch/TransportVersions.java | 2 + .../inference/UnifiedCompletionRequest.java | 8 + .../results/ChatCompletionResults.java | 1 - .../StreamingChatCompletionResults.java | 2 +- .../OpenAiUnifiedStreamingProcessor.java | 9 +- .../sagemaker/schema/SageMakerSchema.java | 4 +- .../sagemaker/schema/SageMakerSchemas.java | 13 +- .../elastic/ElasticCompletionPayload.java | 167 ++++++++++ .../schema/elastic/ElasticPayload.java | 110 +++++++ .../schema/elastic/ElasticRerankPayload.java | 106 ++++++ .../ElasticSparseEmbeddingPayload.java | 91 ++++++ .../elastic/ElasticTextEmbeddingPayload.java | 302 ++++++++++++++++++ .../elastic/SageMakerElasticTaskSettings.java | 65 ++++ .../openai/OpenAiCompletionPayload.java | 4 +- .../SageMakerSchemaPayloadTestCase.java | 7 +- .../schema/SageMakerSchemasTests.java | 18 +- .../ElasticCompletionPayloadTests.java | 149 +++++++++ .../elastic/ElasticPayloadTestCase.java | 122 +++++++ .../elastic/ElasticRerankPayloadTests.java | 108 +++++++ .../ElasticSparseEmbeddingPayloadTests.java | 64 ++++ .../ElasticTextEmbeddingPayloadTests.java | 114 +++++++ .../SageMakerElasticTaskSettingsTests.java | 39 +++ ...sticTextEmbeddingServiceSettingsTests.java | 51 +++ 
.../openai/OpenAiCompletionPayloadTests.java | 7 +- 24 files changed, 1546 insertions(+), 17 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticRerankPayload.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayload.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayloadTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticRerankPayloadTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayloadTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettingsTests.java 
create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTextEmbeddingServiceSettingsTests.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 8bf8a94fccfe0..904bb37a7fa2a 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -196,6 +196,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48); public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49); public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50); + public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_51); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -298,6 +299,7 @@ static TransportVersion def(int id) { public static final TransportVersion HEAP_USAGE_IN_CLUSTER_INFO = def(9_096_0_00); public static final TransportVersion NONE_CHUNKING_STRATEGY = def(9_097_0_00); public static final TransportVersion PROJECT_DELETION_GLOBAL_BLOCK = def(9_098_0_00); + public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC = def(9_099_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 8e934486be7e1..db31aafc8c190 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -128,6 +128,14 @@ public static Params withMaxCompletionTokensTokens(String modelId, Params params ); } + /** + * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values: + * - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()} + */ + public static Params withMaxCompletionTokensTokens(Params params) { + return new DelegatingMapParams(Map.of(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD), params); + } + public sealed interface Content extends NamedWriteable, ToXContent permits ContentObjects, ContentString {} @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResults.java index 346d7416f9dc5..f1a01296c78c8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResults.java @@ -43,7 +43,6 @@ public record ChatCompletionResults(List results) implements InferenceSe public static final String NAME = "chat_completion_service_results"; public static final String COMPLETION = TaskType.COMPLETION.name().toLowerCase(Locale.ROOT); - public ChatCompletionResults(StreamInput in) throws IOException { this(in.readCollectionAsList(Result::new)); } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java index 7657ad498cadf..e377228f04ba9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java @@ -80,7 +80,7 @@ public int hashCode() { } public record Result(String delta) implements ChunkedToXContent, Writeable { - private static final String RESULT = "delta"; + public static final String RESULT = "delta"; private Result(StreamInput in) throws IOException { this(in.readString()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java index 86b4a0a65ef2c..3120f1ff92e48 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java @@ -97,7 +97,14 @@ public static Stream return Stream.empty(); } - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) { + return parse(parserConfig, event.data()); + } + + public static Stream parse( + XContentParserConfiguration parserConfig, + String data + ) throws IOException { + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) { moveToFirstToken(jsonParser); XContentParser.Token token = jsonParser.currentToken(); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java index 3a39bb804e235..91f3e2c05f0a3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java @@ -67,7 +67,7 @@ public InvokeEndpointRequest request(SageMakerModel model, SageMakerInferenceReq throw e; } catch (Exception e) { throw new ElasticsearchStatusException( - "Failed to create SageMaker request for [%s]", + "Failed to create SageMaker request for [{}]", RestStatus.INTERNAL_SERVER_ERROR, e, model.getInferenceEntityId() @@ -98,7 +98,7 @@ public InferenceServiceResults response(SageMakerModel model, InvokeEndpointResp throw e; } catch (Exception e) { throw new ElasticsearchStatusException( - "Failed to translate SageMaker response for [%s]", + "Failed to translate SageMaker response for [{}]", RestStatus.INTERNAL_SERVER_ERROR, e, model.getInferenceEntityId() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java index 3ecd0388796c4..832d5eaa7e1e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java @@ -12,6 +12,10 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import 
org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticCompletionPayload; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticRerankPayload; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticSparseEmbeddingPayload; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticTextEmbeddingPayload; import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload; import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload; @@ -41,7 +45,14 @@ public class SageMakerSchemas { /* * Add new model API to the register call. */ - schemas = register(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload()); + schemas = register( + new OpenAiTextEmbeddingPayload(), + new OpenAiCompletionPayload(), + new ElasticTextEmbeddingPayload(), + new ElasticSparseEmbeddingPayload(), + new ElasticCompletionPayload(), + new ElasticRerankPayload() + ); streamSchemas = schemas.entrySet() .stream() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java new file mode 100644 index 0000000000000..a8c4f7c57796b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayload.java @@ -0,0 +1,167 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic; + +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.core.inference.DequeUtils; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedStreamingProcessor; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchemaPayload; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.json.JsonXContent.jsonXContent; + +/** + * Streaming payloads are expected to be in the exact format of the Elastic API. This does *not* use the Server-Sent Event transport + * protocol, rather this expects the SageMaker client and the implemented Endpoint to use AWS's transport protocol to deliver entire chunks. + * Each chunk should be in a valid JSON format, as that is the format the Elastic API uses. 
+ */ +public class ElasticCompletionPayload implements SageMakerStreamSchemaPayload, ElasticPayload { + private static final XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + + /** + * { + * "completion": [ + * { + * "result": "some result 1" + * }, + * { + * "result": "some result 2" + * } + * ] + * } + */ + @Override + public ChatCompletionResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { + try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) { + return Completion.PARSER.apply(p, null); + } + } + + /** + * { + * "completion": [ + * { + * "delta": "some result 1" + * }, + * { + * "delta": "some result 2" + * } + * ] + * } + */ + @Override + public StreamingChatCompletionResults.Results streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception { + try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.asInputStream())) { + return StreamCompletion.PARSER.apply(p, null); + } + } + + @Override + public SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) { + return SdkBytes.fromUtf8String(Strings.toString((builder, params) -> { + request.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokensTokens(params)); + return builder; + })); + } + + @Override + public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody(SageMakerModel model, SdkBytes response) { + var responseData = response.asUtf8String(); + try { + var results = OpenAiUnifiedStreamingProcessor.parse(parserConfig, responseData) + .collect( + () -> new ArrayDeque(), + ArrayDeque::offer, + ArrayDeque::addAll + ); + return new StreamingUnifiedChatCompletionResults.Results(results); + } catch (Exception e) { + throw 
OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), responseData, e); + } + } + + private static class Completion { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ChatCompletionResults.class.getSimpleName(), + IGNORE_UNKNOWN_FIELDS, + args -> new ChatCompletionResults((List) args[0]) + ); + private static final ConstructingObjectParser RESULT_PARSER = new ConstructingObjectParser<>( + ChatCompletionResults.Result.class.getSimpleName(), + IGNORE_UNKNOWN_FIELDS, + args -> new ChatCompletionResults.Result((String) args[0]) + ); + + static { + RESULT_PARSER.declareString(constructorArg(), new ParseField(ChatCompletionResults.Result.RESULT)); + PARSER.declareObjectArray(constructorArg(), RESULT_PARSER::apply, new ParseField(ChatCompletionResults.COMPLETION)); + } + } + + private static class StreamCompletion { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + StreamingChatCompletionResults.Results.class.getSimpleName(), + IGNORE_UNKNOWN_FIELDS, + args -> new StreamingChatCompletionResults.Results((Deque) args[0]) + ); + private static final ConstructingObjectParser RESULT_PARSER = + new ConstructingObjectParser<>( + StreamingChatCompletionResults.Result.class.getSimpleName(), + IGNORE_UNKNOWN_FIELDS, + args -> new StreamingChatCompletionResults.Result((String) args[0]) + ); + + static { + RESULT_PARSER.declareString(constructorArg(), new ParseField(StreamingChatCompletionResults.Result.RESULT)); + PARSER.declareField(constructorArg(), (p, c) -> { + var currentToken = p.currentToken(); + + // ES allows users to send single-value strings instead of an array of one value + if (currentToken.isValue() + || currentToken == XContentParser.Token.VALUE_NULL + || currentToken == XContentParser.Token.START_OBJECT) { + return DequeUtils.of(RESULT_PARSER.apply(p, c)); + } + + var deque = new 
ArrayDeque(); + XContentParser.Token token; + while ((token = p.nextToken()) != XContentParser.Token.END_ARRAY) { + if (token.isValue() || token == XContentParser.Token.VALUE_NULL || token == XContentParser.Token.START_OBJECT) { + deque.offer(RESULT_PARSER.apply(p, c)); + } else { + throw new IllegalStateException("expected value but got [" + token + "]"); + } + } + return deque; + }, new ParseField(ChatCompletionResults.COMPLETION), ObjectParser.ValueType.OBJECT_ARRAY); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java new file mode 100644 index 0000000000000..46c5a9eb30a9a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic; + +import software.amazon.awssdk.core.SdkBytes; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayload; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema; + +import java.util.Map; +import java.util.stream.Stream; + +import static org.elasticsearch.xcontent.json.JsonXContent.jsonXContent; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.INPUT; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.INPUT_TYPE; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.QUERY; + +interface ElasticPayload extends SageMakerSchemaPayload { + String API = "elastic"; + String APPLICATION_JSON = jsonXContent.type().mediaTypeWithoutParameters(); + /** + * If Elastic receives an element in the response that it does not recognize, it will fail. 
+ */ + boolean IGNORE_UNKNOWN_FIELDS = false; + + @Override + default String api() { + return API; + } + + @Override + default String accept(SageMakerModel model) { + return APPLICATION_JSON; + } + + @Override + default String contentType(SageMakerModel model) { + return APPLICATION_JSON; + } + + @Override + default SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception { + if (model.apiTaskSettings() instanceof SageMakerElasticTaskSettings elasticTaskSettings) { + return SdkBytes.fromUtf8String(Strings.toString((builder, params) -> { + if (request.input().size() > 1) { + builder.field(INPUT.getPreferredName(), request.input()); + } else { + builder.field(INPUT.getPreferredName(), request.input().get(0)); + } + if (InputType.isSpecified(request.inputType())) { + builder.field(INPUT_TYPE.getPreferredName(), switch (request.inputType()) { + case INGEST, INTERNAL_INGEST -> InputType.INGEST; + case SEARCH, INTERNAL_SEARCH -> InputType.SEARCH; + default -> request.inputType(); + }); + } + if (request.query() != null) { + builder.field(QUERY.getPreferredName(), request.query()); + } + if (elasticTaskSettings.isEmpty() == false) { + builder.field(InferenceAction.Request.TASK_SETTINGS.getPreferredName()); + if (elasticTaskSettings.isFragment()) { + builder.startObject(); + } + builder.value(elasticTaskSettings); + if (elasticTaskSettings.isFragment()) { + builder.endObject(); + } + } + return builder; + })); + } else { + throw createUnsupportedSchemaException(model); + } + } + + @Override + default SageMakerElasticTaskSettings apiTaskSettings(Map taskSettings, ValidationException validationException) { + if (taskSettings != null && (taskSettings.isEmpty() == false)) { + validationException.addValidationError( + InferenceAction.Request.TASK_SETTINGS.getPreferredName() + + " is only supported during the inference request and cannot be stored in the inference endpoint." 
+ ); + } + return SageMakerElasticTaskSettings.empty(); + } + + @Override + default Stream namedWriteables() { + return Stream.of( + new NamedWriteableRegistry.Entry( + SageMakerStoredTaskSchema.class, + SageMakerElasticTaskSettings.NAME, + SageMakerElasticTaskSettings::new + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticRerankPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticRerankPayload.java new file mode 100644 index 0000000000000..d06541cc6958b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticRerankPayload.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic; + +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; + +import java.util.EnumSet; + +import static org.elasticsearch.xcontent.json.JsonXContent.jsonXContent; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.INPUT; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.QUERY; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.RETURN_DOCUMENTS; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.TOP_N; + +public class ElasticRerankPayload implements ElasticPayload { + + private static final EnumSet SUPPORTED_TASKS = EnumSet.of(TaskType.RERANK); + private static final ConstructingObjectParser PARSER = RankedDocsResults.createParser(IGNORE_UNKNOWN_FIELDS); + + @Override + public EnumSet supportedTasks() { + return SUPPORTED_TASKS; + } + + /** + * { + * "input": "single string or list", + * "query": "string", + * "return_documents": "boolean", + * "top_n": "integer", + * "task_settings": { + * "additional": "settings" + * } + * } + */ + @Override + public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception { + if (model.apiTaskSettings() instanceof SageMakerElasticTaskSettings elasticTaskSettings) { + return 
SdkBytes.fromUtf8String(Strings.toString((builder, params) -> { + if (request.input().size() > 1) { + builder.field(INPUT.getPreferredName(), request.input()); + } else { + builder.field(INPUT.getPreferredName(), request.input().get(0)); + } + + assert request.query() != null : "InferenceAction.Request will validate that rerank requests have a query field"; + builder.field(QUERY.getPreferredName(), request.query()); + + if (request.returnDocuments() != null) { + builder.field(RETURN_DOCUMENTS.getPreferredName(), request.returnDocuments()); + } + + if (request.topN() != null) { + builder.field(TOP_N.getPreferredName(), request.topN()); + } + + if (elasticTaskSettings.isEmpty() == false) { + builder.field(InferenceAction.Request.TASK_SETTINGS.getPreferredName()); + if (elasticTaskSettings.isFragment()) { + builder.startObject(); + } + builder.value(elasticTaskSettings); + if (elasticTaskSettings.isFragment()) { + builder.endObject(); + } + } + return builder; + })); + } else { + throw createUnsupportedSchemaException(model); + } + } + + /** + * { + * "rerank": [ + * { + * "index": 0, + * "relevance_score": 1.0, + * "text": "hello, world" + * } + * ] + * } + */ + @Override + public RankedDocsResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { + try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) { + return PARSER.apply(p, null); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayload.java new file mode 100644 index 0000000000000..56e50a79231a0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayload.java @@ -0,0 +1,91 @@ +/* + * 
Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic; + +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.WeightedToken; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.json.JsonXContent.jsonXContent; + +public class ElasticSparseEmbeddingPayload implements ElasticPayload { + + private static final EnumSet SUPPORTED_TASKS = EnumSet.of(TaskType.SPARSE_EMBEDDING); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + SparseEmbeddingResults.class.getSimpleName(), + IGNORE_UNKNOWN_FIELDS, + args -> new SparseEmbeddingResults((List) args[0]) + ); + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser EMBEDDINGS_PARSER = + new ConstructingObjectParser<>( + SparseEmbeddingResults.class.getSimpleName(), + IGNORE_UNKNOWN_FIELDS, + args -> SparseEmbeddingResults.Embedding.create((List) args[0], (boolean) args[1]) + ); + + static { + EMBEDDINGS_PARSER.declareObject( + constructorArg(), + (p, c) -> p.map(HashMap::new, XContentParser::floatValue) + 
.entrySet() + .stream() + .map(entry -> new WeightedToken(entry.getKey(), entry.getValue())) + .toList(), + new ParseField("embedding") + ); + EMBEDDINGS_PARSER.declareBoolean(constructorArg(), new ParseField("is_truncated")); + PARSER.declareObjectArray(constructorArg(), EMBEDDINGS_PARSER::apply, new ParseField("sparse_embedding")); + } + + @Override + public EnumSet supportedTasks() { + return SUPPORTED_TASKS; + } + + /** + * Reads sparse embeddings format + * { + * "sparse_embedding" : [ + * { + * "is_truncated" : false, + * "embedding" : { + * "token" : 0.1 + * } + * }, + * { + * "is_truncated" : false, + * "embedding" : { + * "token2" : 0.2, + * "token3" : 0.3 + * } + * } + * ] + * } + */ + @Override + public SparseEmbeddingResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { + try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) { + return PARSER.apply(p, null); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java new file mode 100644 index 0000000000000..34cab3bef4caa --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java @@ -0,0 +1,302 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic; + +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.json.JsonXContent.jsonXContent; +import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredEnum;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
+
+/**
+ * TextEmbedding needs to differentiate between Bit, Byte, and Float types. Users must specify the
+ * {@link org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType} in the Service Settings,
+ * and Elastic will use that to parse the request/response. {@link SimilarityMeasure} and Dimensions are also needed, though Dimensions can
+ * be guessed and set during the validation call.
+ * At the very least, Service Settings must look like:
+ * {
+ *     "element_type": "bit|byte|float",
+ *     "similarity": "cosine|dot_product|l2_norm"
+ * }
+ */
+public class ElasticTextEmbeddingPayload implements ElasticPayload {
+    private static final EnumSet<TaskType> SUPPORTED_TASKS = EnumSet.of(TaskType.TEXT_EMBEDDING);
+    private static final ParseField EMBEDDING = new ParseField("embedding");
+
+    @Override
+    public EnumSet<TaskType> supportedTasks() {
+        return SUPPORTED_TASKS;
+    }
+
+    /**
+     * Builds the text-embedding service settings from the user-supplied map; problems are recorded in
+     * {@code validationException} rather than thrown.
+     */
+    @Override
+    public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
+        return ApiServiceSettings.fromMap(serviceSettings, validationException);
+    }
+
+    /**
+     * Delegates to the default Elastic request body, but only for models whose service settings are {@link ApiServiceSettings}.
+     */
+    @Override
+    public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception {
+        if (model.apiServiceSettings() instanceof ApiServiceSettings) {
+            return ElasticPayload.super.requestBytes(model, request);
+        } else {
+            throw createUnsupportedSchemaException(model);
+        }
+    }
+
+    @Override
+    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
+        return Stream.concat(
+            ElasticPayload.super.namedWriteables(),
+            Stream.of(
+                new NamedWriteableRegistry.Entry(SageMakerStoredServiceSchema.class, ApiServiceSettings.NAME, ApiServiceSettings::new)
+            )
+        );
+    }
+
+    /**
+     * Parses the response with the parser matching the element type stored in the model's {@link ApiServiceSettings}.
+     * Like {@link #requestBytes}, any other kind of service settings is rejected with the unsupported-schema exception
+     * instead of being dereferenced blindly.
+     */
+    @Override
+    public TextEmbeddingResults<?> responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
+        if (model.apiServiceSettings() instanceof ApiServiceSettings settings) {
+            try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
+                return switch (settings.elementType()) {
+                    case BIT -> TextEmbeddingBinary.PARSER.apply(p, null);
+                    case BYTE -> TextEmbeddingBytes.PARSER.apply(p, null);
+                    case FLOAT -> TextEmbeddingFloat.PARSER.apply(p, null);
+                };
+            }
+        }
+        throw createUnsupportedSchemaException(model);
+    }
+
+    /**
+     * Reads binary format (it says bytes, but the lengths are different)
+     * {
+     *     "text_embedding_bits": [
+     *         {
+     *             "embedding": [
+     *                 23
+     *             ]
+     *         },
+     *         {
+     *             "embedding": [
+     *                 -23
+     *             ]
+     *         }
+     *     ]
+     * }
+     */
+    private static class TextEmbeddingBinary {
+        private static final ParseField TEXT_EMBEDDING_BITS = new ParseField(TextEmbeddingBitResults.TEXT_EMBEDDING_BITS);
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<TextEmbeddingBitResults, Void> PARSER = new ConstructingObjectParser<>(
+            TextEmbeddingBitResults.class.getSimpleName(),
+            IGNORE_UNKNOWN_FIELDS,
+            args -> new TextEmbeddingBitResults((List<TextEmbeddingByteResults.Embedding>) args[0])
+        );
+
+        static {
+            // each element is parsed with the byte element parser
+            PARSER.declareObjectArray(constructorArg(), TextEmbeddingBytes.BYTE_PARSER::apply, TEXT_EMBEDDING_BITS);
+        }
+    }
+
+    /**
+     * Reads byte format from
+     * {
+     *     "text_embedding_bytes": [
+     *         {
+     *             "embedding": [
+     *                 23
+     *             ]
+     *         },
+     *         {
+     *             "embedding": [
+     *                 -23
+     *             ]
+     *         }
+     *     ]
+     * }
+     */
+    private static class TextEmbeddingBytes {
+        private static final ParseField TEXT_EMBEDDING_BYTES = new ParseField("text_embedding_bytes");
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<TextEmbeddingByteResults, Void> PARSER = new ConstructingObjectParser<>(
+            TextEmbeddingByteResults.class.getSimpleName(),
+            IGNORE_UNKNOWN_FIELDS,
+            args -> new TextEmbeddingByteResults((List<TextEmbeddingByteResults.Embedding>) args[0])
+        );
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<TextEmbeddingByteResults.Embedding, Void> BYTE_PARSER =
+            new ConstructingObjectParser<>(
+                TextEmbeddingByteResults.Embedding.class.getSimpleName(),
+                IGNORE_UNKNOWN_FIELDS,
+                args -> TextEmbeddingByteResults.Embedding.of((List<Byte>) args[0])
+            );
+
+        static {
+            BYTE_PARSER.declareObjectArray(constructorArg(), (p, c) -> {
+                // read as a short first so out-of-range values are rejected instead of silently narrowed by the cast
+                var byteVal = p.shortValue();
+                if (byteVal < Byte.MIN_VALUE || byteVal > Byte.MAX_VALUE) {
+                    throw new IllegalArgumentException("Value [" + byteVal + "] is out of range for a byte");
+                }
+                return (byte) byteVal;
+            }, EMBEDDING);
+            PARSER.declareObjectArray(constructorArg(), BYTE_PARSER::apply, TEXT_EMBEDDING_BYTES);
+        }
+    }
+
+    /**
+     * Reads float format from
+     * {
+     *     "text_embedding": [
+     *         {
+     *             "embedding": [
+     *                 0.1
+     *             ]
+     *         },
+     *         {
+     *             "embedding": [
+     *                 0.2
+     *             ]
+     *         }
+     *     ]
+     * }
+     */
+    private static class TextEmbeddingFloat {
+        private static final ParseField TEXT_EMBEDDING_FLOAT = new ParseField("text_embedding");
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<TextEmbeddingFloatResults, Void> PARSER = new ConstructingObjectParser<>(
+            TextEmbeddingFloatResults.class.getSimpleName(),
+            IGNORE_UNKNOWN_FIELDS,
+            args -> new TextEmbeddingFloatResults((List<TextEmbeddingFloatResults.Embedding>) args[0])
+        );
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<TextEmbeddingFloatResults.Embedding, Void> FLOAT_PARSER =
+            new ConstructingObjectParser<>(
+                TextEmbeddingFloatResults.Embedding.class.getSimpleName(),
+                IGNORE_UNKNOWN_FIELDS,
+                args -> TextEmbeddingFloatResults.Embedding.of((List<Float>) args[0])
+            );
+
+        static {
+            FLOAT_PARSER.declareFloatArray(constructorArg(), EMBEDDING);
+            PARSER.declareObjectArray(constructorArg(), FLOAT_PARSER::apply, TEXT_EMBEDDING_FLOAT);
+        }
+    }
+
+    /**
+     * Element Type is required. It is used to disambiguate between binary embeddings and byte embeddings.
+     *
+     * @param dimensions optional number of dimensions
+     * @param dimensionsSetByUser whether {@code dimensions} came from the user; written to the stream as a plain boolean,
+     *                            so it must not be null
+     * @param similarity optional similarity measure
+     * @param elementType element type that selects the response parser
+     */
+    record ApiServiceSettings(
+        @Nullable Integer dimensions,
+        Boolean dimensionsSetByUser,
+        @Nullable SimilarityMeasure similarity,
+        DenseVectorFieldMapper.ElementType elementType
+    ) implements SageMakerStoredServiceSchema {
+
+        private static final String NAME = "sagemaker_elastic_text_embeddings_service_settings";
+        private static final String DIMENSIONS_FIELD = "dimensions";
+        private static final String DIMENSIONS_SET_BY_USER_FIELD = "dimensions_set_by_user";
+        private static final String SIMILARITY_FIELD = "similarity";
+        private static final String ELEMENT_TYPE_FIELD = "element_type";
+
+        /** Reads the fields in the same order {@link #writeTo} writes them. */
+        ApiServiceSettings(StreamInput in) throws IOException {
+            this(
+                in.readOptionalInt(),
+                in.readBoolean(),
+                in.readOptionalEnum(SimilarityMeasure.class),
+                in.readEnum(DenseVectorFieldMapper.ElementType.class)
+            );
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+
+        @Override
+        public TransportVersion getMinimalSupportedVersion() {
+            return TransportVersions.ML_INFERENCE_SAGEMAKER_ELASTIC;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeOptionalInt(dimensions);
+            out.writeBoolean(dimensionsSetByUser);
+            out.writeOptionalEnum(similarity);
+            out.writeEnum(elementType);
+        }
+
+        /** Writes the settings as fields of the current object; the two nullable fields are omitted when null. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            if (dimensions != null) {
+                builder.field(DIMENSIONS_FIELD, dimensions);
+            }
+            builder.field(DIMENSIONS_SET_BY_USER_FIELD, dimensionsSetByUser);
+            if (similarity != null) {
+                builder.field(SIMILARITY_FIELD, similarity);
+            }
+            builder.field(ELEMENT_TYPE_FIELD, elementType);
+            return builder;
+        }
+
+        /** Returns a copy with the given dimensions, marked as not set by the user. */
+        @Override
+        public ApiServiceSettings updateModelWithEmbeddingDetails(Integer dimensions) {
+            return new ApiServiceSettings(dimensions, false, similarity, elementType);
+        }
+
+        /**
+         * Extracts the settings from {@code serviceSettings}, recording problems in {@code validationException}.
+         * A missing {@code dimensions_set_by_user} is treated as {@code false}.
+         */
+        static ApiServiceSettings fromMap(Map<String, Object> serviceSettings, ValidationException validationException) {
+            var dimensions = extractOptionalPositiveInteger(
+                serviceSettings,
+                DIMENSIONS_FIELD,
+                ModelConfigurations.SERVICE_SETTINGS,
+                validationException
+            );
+            var dimensionsSetByUser = extractOptionalBoolean(serviceSettings, DIMENSIONS_SET_BY_USER_FIELD, validationException);
+            var similarity = extractSimilarity(serviceSettings, ModelConfigurations.SERVICE_SETTINGS, validationException);
+            var elementType = extractRequiredEnum(
+                serviceSettings,
+                ELEMENT_TYPE_FIELD,
+                ModelConfigurations.SERVICE_SETTINGS,
+                DenseVectorFieldMapper.ElementType::fromString,
+                EnumSet.allOf(DenseVectorFieldMapper.ElementType.class),
+                validationException
+            );
+            return new ApiServiceSettings(dimensions, dimensionsSetByUser != null && dimensionsSetByUser, similarity, elementType);
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java
new file mode 100644
index 0000000000000..3cdcbb35ffdc9
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * SageMaker + Elastic task settings are different in that they will not be stored within the index because
+ * they will not be verified. Instead, these task settings will only exist as additional input to SageMaker.
+ *
+ * @param passthroughSettings the settings to pass through, may be null
+ */
+record SageMakerElasticTaskSettings(@Nullable Map<String, Object> passthroughSettings) implements SageMakerStoredTaskSchema {
+    static final String NAME = "sagemaker_elastic_task_settings";
+
+    /** Returns task settings backed by an empty map. */
+    static SageMakerElasticTaskSettings empty() {
+        return new SageMakerElasticTaskSettings(Map.of());
+    }
+
+    SageMakerElasticTaskSettings(StreamInput in) throws IOException {
+        this(in.readGenericMap());
+    }
+
+    /** True when there is nothing to pass through, i.e. the map is null or has no entries. */
+    @Override
+    public boolean isEmpty() {
+        return passthroughSettings == null || passthroughSettings.isEmpty();
+    }
+
+    /** Returns new settings wrapping {@code newSettings}; the current settings are not merged in. */
+    @Override
+    public SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings) {
+        return new SageMakerElasticTaskSettings(newSettings);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_SAGEMAKER_ELASTIC;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeGenericMap(passthroughSettings);
+    }
+
+    /** Writes the passthrough entries as fields of the current object; writes nothing when empty. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        return isEmpty() ? builder : builder.mapContents(passthroughSettings);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java
index 64b42f00d2d5b..03e1941df6938 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java
@@ -15,11 +15,11 @@
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.core.Nullable;
-import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.UnifiedCompletionRequest;
 import org.elasticsearch.xcontent.XContent;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xcontent.json.JsonXContent;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
@@ -154,7 +154,7 @@ public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest req
     }

     @Override
-    public InferenceServiceResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
+    public ChatCompletionResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
         return OpenAiChatCompletionResponseEntity.fromResponse(response.body().asByteArray());
     }

diff --git
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java index ebef5b6eefd9e..adffbb366fb02 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.sagemaker.schema; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.VersionedNamedWriteable; @@ -66,7 +67,7 @@ public final void testApiServiceSettings() throws IOException { validationException.throwIfValidationErrorsExist(); } - public final void testApiTaskSettings() throws IOException { + public void testApiTaskSettings() throws IOException { var validationException = new ValidationException(); var expectedApiTaskSettings = randomApiTaskSettings(); var actualApiTaskSettings = payload.apiTaskSettings(toMap(expectedApiTaskSettings), validationException); @@ -158,4 +159,8 @@ protected static void assertSdkBytes(SdkBytes sdkBytes, String expectedValue) { protected static void assertJsonSdkBytes(SdkBytes sdkBytes, String expectedValue) throws IOException { assertThat(sdkBytes.asUtf8String(), equalTo(XContentHelper.stripWhitespace(expectedValue))); } + + protected static InvokeEndpointResponse invokeEndpointResponse(String responseJson) { + return InvokeEndpointResponse.builder().body(SdkBytes.fromUtf8String(responseJson)).build(); + } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java index d306fc2713077..6461c490cf037 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java @@ -11,6 +11,10 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticCompletionPayload; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticRerankPayload; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticSparseEmbeddingPayload; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticTextEmbeddingPayload; import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload; import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload; @@ -43,7 +47,13 @@ public static SageMakerSchema mockSchema() { public void testSupportedTaskTypes() { assertThat( schemas.supportedTaskTypes(), - containsInAnyOrder(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION) + containsInAnyOrder( + TaskType.TEXT_EMBEDDING, + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION, + TaskType.SPARSE_EMBEDDING, + TaskType.RERANK + ) ); } @@ -111,7 +121,11 @@ public void testMissingStreamSchemaThrowsException() { public void testNamedWriteables() { var namedWriteables = Stream.of( new OpenAiTextEmbeddingPayload().namedWriteables(), - new OpenAiCompletionPayload().namedWriteables() 
+ new OpenAiCompletionPayload().namedWriteables(), + new ElasticCompletionPayload().namedWriteables(), + new ElasticSparseEmbeddingPayload().namedWriteables(), + new ElasticTextEmbeddingPayload().namedWriteables(), + new ElasticRerankPayload().namedWriteables() ); var expectedNamedWriteables = Stream.concat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayloadTests.java new file mode 100644 index 0000000000000..a662a405b35a8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticCompletionPayloadTests.java @@ -0,0 +1,149 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic;
+
+import software.amazon.awssdk.core.SdkBytes;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.json.JsonXContent;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Set;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests {@link ElasticCompletionPayload}: completion responses (streaming and non-streaming) and the chat completion
+ * request and response bodies.
+ */
+public class ElasticCompletionPayloadTests extends ElasticPayloadTestCase<ElasticCompletionPayload> {
+    @Override
+    protected ElasticCompletionPayload payload() {
+        return new ElasticCompletionPayload();
+    }
+
+    @Override
+    protected Set<TaskType> expectedSupportedTaskTypes() {
+        return Set.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
+    }
+
+    public void testNonStreamingResponse() throws Exception {
+        var responseJson = """
+            {
+                "completion": [
+                    {
+                        "result": "hello"
+                    }
+                ]
+            }
+            """;
+
+        var chatCompletionResults = payload.responseBody(mockModel(), invokeEndpointResponse(responseJson));
+
+        assertThat(chatCompletionResults.getResults().size(), is(1));
+        assertThat(chatCompletionResults.getResults().get(0).content(), is("hello"));
+    }
+
+    public void testStreamingResponse() throws Exception {
+        var responseJson = """
+            {
+                "completion": [
+                    {
+                        "delta": "hola"
+                    }
+                ]
+            }
+            """;
+
+        var chatCompletionResults = payload.streamResponseBody(mockModel(), SdkBytes.fromUtf8String(responseJson));
+
+        assertThat(chatCompletionResults.results().size(), is(1));
+        assertThat(chatCompletionResults.results().iterator().next().delta(), is("hola"));
+    }
+
+    // the second constructor argument ("i am ignored") must not show up in the serialized request
+    public void testChatCompletionRequest() throws Exception {
+        var message = new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello, world!"), "user", null, null);
+        var unifiedRequest = new UnifiedCompletionRequest(
+            List.of(message),
+            "i am ignored",
+            10L,
+            List.of("right meow"),
+            1.0F,
+            null,
+            null,
+            null
+        );
+        var sdkBytes = payload.chatCompletionRequestBytes(mockModel(), unifiedRequest);
+        assertJsonSdkBytes(sdkBytes, """
+            {
+                "messages": [
+                    {
+                        "content": "Hello, world!",
+                        "role": "user"
+                    }
+                ],
+                "stop": [
+                    "right meow"
+                ],
+                "temperature": 1.0,
+                "max_completion_tokens": 10
+            }
+            """);
+    }
+
+    // round trip: the parsed response must serialize back to the same JSON
+    public void testChatCompletionResponse() throws Exception {
+        var responseJson = """
+            {
+                "id": "chunk1",
+                "choices": [
+                    {
+                        "delta": {
+                            "content": "example_content",
+                            "refusal": "example_refusal",
+                            "role": "assistant",
+                            "tool_calls": [
+                                {
+                                    "index": 1,
+                                    "id": "tool1",
+                                    "function": {
+                                        "arguments": "example_arguments",
+                                        "name": "example_function"
+                                    },
+                                    "type": "function"
+                                }
+                            ]
+                        },
+                        "finish_reason": "example_reason",
+                        "index": 0
+                    }
+                ],
+                "model": "example_model",
+                "object": "example_object",
+                "usage": {
+                    "completion_tokens": 10,
+                    "prompt_tokens": 5,
+                    "total_tokens": 15
+                }
+            }
+            """;
+
+        var chatCompletionResponse = payload.chatCompletionResponseBody(mockModel(), SdkBytes.fromUtf8String(responseJson));
+
+        XContentBuilder builder = JsonXContent.contentBuilder();
+        chatCompletionResponse.toXContentChunked(null).forEachRemaining(xContent -> {
+            try {
+                xContent.toXContent(builder, null);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        });
+
+        assertEquals(XContentHelper.stripWhitespace(responseJson), Strings.toString(builder).trim());
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java
new file mode 100644
index 0000000000000..65dcd62bb149a
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java
@@ -0,0 +1,122 @@
+/*
+ * Copyright
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
+import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
+import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase;
+import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
+import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Shared tests for the payloads of the "elastic" API: task-settings handling and the common request body.
+ * Subclasses may override the request tests when their payload serializes the request differently.
+ */
+public abstract class ElasticPayloadTestCase<T extends ElasticPayload> extends SageMakerSchemaPayloadTestCase<T> {
+
+    @Override
+    protected String expectedApi() {
+        return "elastic";
+    }
+
+    @Override
+    protected SageMakerStoredServiceSchema randomApiServiceSettings() {
+        return SageMakerStoredServiceSchema.NO_OP;
+    }
+
+    @Override
+    protected SageMakerStoredTaskSchema randomApiTaskSettings() {
+        return SageMakerElasticTaskSettingsTests.randomInstance();
+    }
+
+    /** A mocked model with empty task settings. */
+    protected SageMakerModel mockModel() {
+        return mockModel(SageMakerElasticTaskSettings.empty());
+    }
+
+    /** A mocked model that returns {@code taskSettings} from {@code apiTaskSettings()}. */
+    protected SageMakerModel mockModel(SageMakerElasticTaskSettings taskSettings) {
+        SageMakerModel model = mock();
+        when(model.apiTaskSettings()).thenReturn(taskSettings);
+        return model;
+    }
+
+    // replaces the base-class round-trip test: null and empty maps are accepted, anything else is a validation error
+    @Override
+    public void testApiTaskSettings() {
+        {
+            var validationException = new ValidationException();
+            var actualApiTaskSettings = payload.apiTaskSettings(null, validationException);
+            assertTrue(actualApiTaskSettings.isEmpty());
+            assertTrue(validationException.validationErrors().isEmpty());
+        }
+        {
+            var validationException = new ValidationException();
+            var actualApiTaskSettings = payload.apiTaskSettings(Map.of(), validationException);
+            assertTrue(actualApiTaskSettings.isEmpty());
+            assertTrue(validationException.validationErrors().isEmpty());
+        }
+        {
+            var validationException = new ValidationException();
+            var actualApiTaskSettings = payload.apiTaskSettings(Map.of("hello", "world"), validationException);
+            assertTrue(actualApiTaskSettings.isEmpty());
+            assertFalse(validationException.validationErrors().isEmpty());
+            assertThat(
+                validationException.validationErrors().get(0),
+                is(equalTo("task_settings is only supported during the inference request and cannot be stored in the inference endpoint."))
+            );
+        }
+    }
+
+    public void testRequestWithRequiredFields() throws Exception {
+        var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), false, InputType.UNSPECIFIED);
+        var sdkBytes = payload.requestBytes(mockModel(), request);
+        assertJsonSdkBytes(sdkBytes, """
+            {
+                "input": "hello"
+            }""");
+    }
+
+    public void testRequestWithInternalFields() throws Exception {
+        var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), false, InputType.INTERNAL_SEARCH);
+        var sdkBytes = payload.requestBytes(mockModel(), request);
+        assertJsonSdkBytes(sdkBytes, """
+            {
+                "input": "hello",
+                "input_type": "search"
+            }""");
+    }
+
+    public void testRequestWithMultipleInput() throws Exception {
+        var request = new SageMakerInferenceRequest(null, null, null, List.of("hello", "there"), false, InputType.UNSPECIFIED);
+        var sdkBytes = payload.requestBytes(mockModel(), request);
+        assertJsonSdkBytes(sdkBytes, """
+            {
+                "input": [
+                    "hello",
+                    "there"
+                ]
+            }""");
+    }
+
+    public void testRequestWithOptionalFields() throws Exception {
+        var request = new SageMakerInferenceRequest("test", null, null, List.of("hello"), false, InputType.INGEST);
+        var sdkBytes = payload.requestBytes(mockModel(new SageMakerElasticTaskSettings(Map.of("more", "args"))), request);
+        assertJsonSdkBytes(sdkBytes, """
+            {
+                "input": "hello",
+                "input_type": "ingest",
+                "query": "test",
+                "task_settings": {
+                    "more": "args"
+                }
+            }""");
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticRerankPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticRerankPayloadTests.java
new file mode 100644
index 0000000000000..4f7fd0daf594f
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticRerankPayloadTests.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic;
+
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests {@link ElasticRerankPayload}. The shared request tests are overridden because rerank requests carry a query
+ * and the rerank options, and do not send {@code input_type}.
+ */
+public class ElasticRerankPayloadTests extends ElasticPayloadTestCase<ElasticRerankPayload> {
+    @Override
+    protected ElasticRerankPayload payload() {
+        return new ElasticRerankPayload();
+    }
+
+    @Override
+    protected Set<TaskType> expectedSupportedTaskTypes() {
+        return Set.of(TaskType.RERANK);
+    }
+
+    @Override
+    public void testRequestWithRequiredFields() throws Exception {
+        var request = new SageMakerInferenceRequest("is this a greeting?", null, null, List.of("hello"), false, InputType.UNSPECIFIED);
+        var sdkBytes = payload.requestBytes(mockModel(), request);
+        assertJsonSdkBytes(sdkBytes, """
+            {
+                "input": "hello",
+                "query": "is this a greeting?"
+            }""");
+    }
+
+    // input_type is ignored for rerank
+    @Override
+    public void testRequestWithInternalFields() throws Exception {
+        var request = new SageMakerInferenceRequest("is this a greeting?", null, null, List.of("hello"), false, InputType.INTERNAL_SEARCH);
+        var sdkBytes = payload.requestBytes(mockModel(), request);
+        assertJsonSdkBytes(sdkBytes, """
+            {
+                "input": "hello",
+                "query": "is this a greeting?"
+            }""");
+    }
+
+    @Override
+    public void testRequestWithMultipleInput() throws Exception {
+        var request = new SageMakerInferenceRequest(
+            "is this a greeting?",
+            null,
+            null,
+            List.of("hello", "there"),
+            false,
+            InputType.UNSPECIFIED
+        );
+        var sdkBytes = payload.requestBytes(mockModel(), request);
+        assertJsonSdkBytes(sdkBytes, """
+            {
+                "input": [
+                    "hello",
+                    "there"
+                ],
+                "query": "is this a greeting?"
+            }""");
+    }
+
+    @Override
+    public void testRequestWithOptionalFields() throws Exception {
+        var request = new SageMakerInferenceRequest("is this a greeting?", true, 5, List.of("hello"), false, InputType.INGEST);
+        var sdkBytes = payload.requestBytes(mockModel(new SageMakerElasticTaskSettings(Map.of("more", "args"))), request);
+        assertJsonSdkBytes(sdkBytes, """
+            {
+                "input": "hello",
+                "query": "is this a greeting?",
+                "return_documents": true,
+                "top_n": 5,
+                "task_settings": {
+                    "more": "args"
+                }
+            }""");
+    }
+
+    public void testResponse() throws Exception {
+        var responseJson = """
+            {
+                "rerank": [
+                    {
+                        "index": 0,
+                        "relevance_score": 1.0,
+                        "text": "hello, world"
+                    }
+                ]
+            }
+            """;
+
+        var rankedDocsResults = payload.responseBody(mockModel(), invokeEndpointResponse(responseJson));
+        assertThat(rankedDocsResults.getRankedDocs().size(), is(1));
+        var rankedDoc = rankedDocsResults.getRankedDocs().get(0);
+        assertThat(rankedDoc.index(), is(0));
+        assertThat(rankedDoc.relevanceScore(), is(1.0F));
+        assertThat(rankedDoc.text(), is("hello, world"));
+    }
+
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayloadTests.java
new file mode 100644
index 0000000000000..7235ad5a94ebd
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayloadTests.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.WeightedToken; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; + +import java.util.List; +import java.util.Set; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.is; + +public class ElasticSparseEmbeddingPayloadTests extends ElasticPayloadTestCase { + @Override + protected ElasticSparseEmbeddingPayload payload() { + return new ElasticSparseEmbeddingPayload(); + } + + @Override + protected Set expectedSupportedTaskTypes() { + return Set.of(TaskType.SPARSE_EMBEDDING); + } + + public void testParseResponse() throws Exception { + var responseJson = """ + { + "sparse_embedding" : [ + { + "is_truncated" : false, + "embedding" : { + "token" : 0.1 + } + }, + { + "is_truncated" : false, + "embedding" : { + "token2" : 0.2, + "token3" : 0.3 + } + } + ] + } + """; + + var results = payload.responseBody(mockModel(), invokeEndpointResponse(responseJson)); + + assertThat(results.embeddings().size(), is(2)); + + assertThat( + results.embeddings(), + containsInAnyOrder( + new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1F)), false), + new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token2", 0.2F), new WeightedToken("token3", 0.3F)), false) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java new file mode 100644 index 0000000000000..ed0ee43266ab5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayloadTests.java @@ -0,0 +1,114 @@ +/* + * 
Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic; + +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; + +import java.util.Set; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.when; + +public class ElasticTextEmbeddingPayloadTests extends ElasticPayloadTestCase { + @Override + protected ElasticTextEmbeddingPayload payload() { + return new ElasticTextEmbeddingPayload(); + } + + @Override + protected Set expectedSupportedTaskTypes() { + return Set.of(TaskType.TEXT_EMBEDDING); + } + + @Override + protected SageMakerStoredServiceSchema randomApiServiceSettings() { + return SageMakerElasticTextEmbeddingServiceSettingsTests.randomInstance(); + } + + @Override + protected SageMakerModel mockModel(SageMakerElasticTaskSettings taskSettings) { + var model = super.mockModel(taskSettings); + when(model.apiServiceSettings()).thenReturn(randomApiServiceSettings()); + return model; + } + + protected SageMakerModel mockModel(DenseVectorFieldMapper.ElementType elementType) { + var model = super.mockModel(SageMakerElasticTaskSettings.empty()); + when(model.apiServiceSettings()).thenReturn(SageMakerElasticTextEmbeddingServiceSettingsTests.randomInstance(elementType)); + return model; + } + + public void testBitResponse() throws 
Exception { + var responseJson = """ + { + "text_embedding_bits": [ + { + "embedding": [ + 23 + ] + } + ] + } + """; + + var bitResults = payload.responseBody(mockModel(DenseVectorFieldMapper.ElementType.BIT), invokeEndpointResponse(responseJson)); + + assertThat(bitResults.embeddings().size(), is(1)); + var embedding = bitResults.embeddings().get(0); + assertThat(embedding, isA(TextEmbeddingByteResults.Embedding.class)); + assertThat(((TextEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 })); + } + + public void testByteResponse() throws Exception { + var responseJson = """ + { + "text_embedding_bytes": [ + { + "embedding": [ + 23 + ] + } + ] + } + """; + + var byteResults = payload.responseBody(mockModel(DenseVectorFieldMapper.ElementType.BYTE), invokeEndpointResponse(responseJson)); + + assertThat(byteResults.embeddings().size(), is(1)); + var embedding = byteResults.embeddings().get(0); + assertThat(embedding, isA(TextEmbeddingByteResults.Embedding.class)); + assertThat(((TextEmbeddingByteResults.Embedding) embedding).values(), is(new byte[] { 23 })); + } + + public void testFloatResponse() throws Exception { + var responseJson = """ + { + "text_embedding": [ + { + "embedding": [ + 0.1 + ] + } + ] + } + """; + + var byteResults = payload.responseBody(mockModel(DenseVectorFieldMapper.ElementType.FLOAT), invokeEndpointResponse(responseJson)); + + assertThat(byteResults.embeddings().size(), is(1)); + var embedding = byteResults.embeddings().get(0); + assertThat(embedding, isA(TextEmbeddingFloatResults.Embedding.class)); + assertThat(((TextEmbeddingFloatResults.Embedding) embedding).values(), is(new float[] { 0.1F })); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettingsTests.java new file mode 
100644 index 0000000000000..60f188675955c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettingsTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase; + +import java.util.Map; + +public class SageMakerElasticTaskSettingsTests extends InferenceSettingsTestCase { + @Override + protected SageMakerElasticTaskSettings fromMutableMap(Map mutableMap) { + return new SageMakerElasticTaskSettings(mutableMap); + } + + @Override + protected Writeable.Reader instanceReader() { + return SageMakerElasticTaskSettings::new; + } + + @Override + protected SageMakerElasticTaskSettings createTestInstance() { + return randomInstance(); + } + + static SageMakerElasticTaskSettings randomInstance() { + return randomBoolean() + ? 
SageMakerElasticTaskSettings.empty() + : new SageMakerElasticTaskSettings( + randomMap(1, 3, () -> Tuple.tuple(randomAlphanumericOfLength(4), randomAlphanumericOfLength(4))) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTextEmbeddingServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTextEmbeddingServiceSettingsTests.java new file mode 100644 index 0000000000000..1361ea40b1e80 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTextEmbeddingServiceSettingsTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase; + +import java.util.Map; + +public class SageMakerElasticTextEmbeddingServiceSettingsTests extends InferenceSettingsTestCase< + ElasticTextEmbeddingPayload.ApiServiceSettings> { + + @Override + protected ElasticTextEmbeddingPayload.ApiServiceSettings fromMutableMap(Map mutableMap) { + var validationException = new ValidationException(); + var settings = ElasticTextEmbeddingPayload.ApiServiceSettings.fromMap(mutableMap, validationException); + validationException.throwIfValidationErrorsExist(); + return settings; + } + + @Override + protected Writeable.Reader instanceReader() { + return ElasticTextEmbeddingPayload.ApiServiceSettings::new; + } + + @Override + protected ElasticTextEmbeddingPayload.ApiServiceSettings createTestInstance() { + return randomInstance(); + } + + static ElasticTextEmbeddingPayload.ApiServiceSettings randomInstance() { + return randomInstance(randomFrom(DenseVectorFieldMapper.ElementType.values())); + } + + static ElasticTextEmbeddingPayload.ApiServiceSettings randomInstance(DenseVectorFieldMapper.ElementType elementType) { + return new ElasticTextEmbeddingPayload.ApiServiceSettings( + randomIntBetween(1, 100), + randomBoolean(), + randomFrom(SimilarityMeasure.values()), + elementType + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayloadTests.java index 24be845e5ab66..c04e99053b1c8 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayloadTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayloadTests.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai; import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.XContentHelper; @@ -17,7 +16,6 @@ import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase; @@ -163,10 +161,7 @@ public void testResponse() throws Exception { } """; - var chatCompletionResults = (ChatCompletionResults) payload.responseBody( - mockModel(), - InvokeEndpointResponse.builder().body(SdkBytes.fromUtf8String(responseJson)).build() - ); + var chatCompletionResults = payload.responseBody(mockModel(), invokeEndpointResponse(responseJson)); assertThat(chatCompletionResults.getResults().size(), is(1)); assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); From eecce82ec0eb7f22005c037b19ed7793b2cbb23d Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Fri, 13 Jun 2025 09:21:09 -0400 Subject: [PATCH 2/6] Update docs/changelog/129413.yaml --- docs/changelog/129413.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/129413.yaml diff --git a/docs/changelog/129413.yaml 
b/docs/changelog/129413.yaml new file mode 100644 index 0000000000000..505b627c42b16 --- /dev/null +++ b/docs/changelog/129413.yaml @@ -0,0 +1,5 @@ +pr: 129413 +summary: '`SageMaker` Elastic Payload' +area: Machine Learning +type: enhancement +issues: [] From b4b1f3fc2588aecda6214a239c95251b45a83851 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 13 Jun 2025 13:29:48 +0000 Subject: [PATCH 3/6] [CI] Auto commit changes from spotless --- .../xpack/core/inference/results/ChatCompletionResults.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResults.java index f1a01296c78c8..346d7416f9dc5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChatCompletionResults.java @@ -43,6 +43,7 @@ public record ChatCompletionResults(List results) implements InferenceSe public static final String NAME = "chat_completion_service_results"; public static final String COMPLETION = TaskType.COMPLETION.name().toLowerCase(Locale.ROOT); + public ChatCompletionResults(StreamInput in) throws IOException { this(in.readCollectionAsList(Result::new)); } From fa1909cb7a173b11208b9f40001a4cb08a9b2479 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Fri, 13 Jun 2025 09:34:15 -0400 Subject: [PATCH 4/6] Add sagemaker to IT --- .../inference/InferenceGetServicesIT.java | 57 ++++++------------- 1 file changed, 18 insertions(+), 39 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 
ecf89dff104a0..68e7f51e10f2e 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -20,7 +20,6 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated; import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.equalTo; public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { @@ -31,13 +30,8 @@ public static void init() { } public void testGetServicesWithoutTaskType() throws IOException { - List services = getAllServices(); - assertThat(services.size(), equalTo(24)); - - var providers = providers(services); - assertThat( - providers, + allProviders(), containsInAnyOrder( List.of( "alibabacloud-ai-search", @@ -69,6 +63,10 @@ public void testGetServicesWithoutTaskType() throws IOException { ); } + private Iterable allProviders() throws IOException { + return providers(getAllServices()); + } + @SuppressWarnings("unchecked") private Iterable providers(List services) { return services.stream().map(service -> { @@ -78,13 +76,8 @@ private Iterable providers(List services) { } public void testGetServicesWithTextEmbeddingTaskType() throws IOException { - List services = getServices(TaskType.TEXT_EMBEDDING); - assertThat(services.size(), equalTo(17)); - - var providers = providers(services); - assertThat( - providers, + providersFor(TaskType.TEXT_EMBEDDING), containsInAnyOrder( List.of( "alibabacloud-ai-search", @@ -109,14 +102,13 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { ); } - public void testGetServicesWithRerankTaskType() throws IOException { - List services = getServices(TaskType.RERANK); - assertThat(services.size(), equalTo(9)); - - var providers = providers(services); + private Iterable providersFor(TaskType 
taskType) throws IOException { + return providers(getServices(taskType)); + } + public void testGetServicesWithRerankTaskType() throws IOException { assertThat( - providers, + providersFor(TaskType.RERANK), containsInAnyOrder( List.of( "alibabacloud-ai-search", @@ -127,20 +119,16 @@ public void testGetServicesWithRerankTaskType() throws IOException { "jinaai", "test_reranking_service", "voyageai", - "hugging_face" + "hugging_face", + "amazon_sagemaker" ).toArray() ) ); } public void testGetServicesWithCompletionTaskType() throws IOException { - List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(14)); - - var providers = providers(services); - assertThat( - providers, + providersFor(TaskType.COMPLETION), containsInAnyOrder( List.of( "alibabacloud-ai-search", @@ -164,13 +152,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException { } public void testGetServicesWithChatCompletionTaskType() throws IOException { - List services = getServices(TaskType.CHAT_COMPLETION); - assertThat(services.size(), equalTo(8)); - - var providers = providers(services); - assertThat( - providers, + providersFor(TaskType.CHAT_COMPLETION), containsInAnyOrder( List.of( "deepseek", @@ -187,13 +170,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { } public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { - List services = getServices(TaskType.SPARSE_EMBEDDING); - assertThat(services.size(), equalTo(7)); - - var providers = providers(services); - assertThat( - providers, + providersFor(TaskType.SPARSE_EMBEDDING), containsInAnyOrder( List.of( "alibabacloud-ai-search", @@ -202,7 +180,8 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { "elasticsearch", "hugging_face", "streaming_completion_test_service", - "test_service" + "test_service", + "amazon_sagemaker" ).toArray() ) ); From 76525d3dbf79d8eb2b6d861c5a54a26d99b22f74 Mon Sep 17 00:00:00 2001 From: Pat 
Whelan Date: Fri, 13 Jun 2025 10:22:49 -0400 Subject: [PATCH 5/6] Fix copyright header --- .../elastic/ElasticSparseEmbeddingPayloadTests.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayloadTests.java index 7235ad5a94ebd..2bad469b6abad 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayloadTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticSparseEmbeddingPayloadTests.java @@ -1,8 +1,8 @@ /* - Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - or more contributor license agreements. Licensed under the Elastic License - 2.0; you may not use this file except in compliance with the Elastic License - 2.0. + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
*/ package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic; From 0d9e76b8a84ea118bbd683698a7c8a5e3a3a250a Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Mon, 23 Jun 2025 11:11:31 -0400 Subject: [PATCH 6/6] Use vint --- .../sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java index 34cab3bef4caa..cf9d24a86dcc3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java @@ -236,7 +236,7 @@ record ApiServiceSettings( ApiServiceSettings(StreamInput in) throws IOException { this( - in.readOptionalInt(), + in.readOptionalVInt(), in.readBoolean(), in.readOptionalEnum(SimilarityMeasure.class), in.readEnum(DenseVectorFieldMapper.ElementType.class) @@ -255,7 +255,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalInt(dimensions); + out.writeOptionalVInt(dimensions); out.writeBoolean(dimensionsSetByUser); out.writeOptionalEnum(similarity); out.writeEnum(elementType);