Skip to content

[ML] SageMaker Elastic Payload #129413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 23, 2025
5 changes: 5 additions & 0 deletions docs/changelog/129413.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129413
summary: '`SageMaker` Elastic Payload'
area: Machine Learning
type: enhancement
issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48);
public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49);
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_51);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -298,6 +299,7 @@ static TransportVersion def(int id) {
public static final TransportVersion HEAP_USAGE_IN_CLUSTER_INFO = def(9_096_0_00);
public static final TransportVersion NONE_CHUNKING_STRATEGY = def(9_097_0_00);
public static final TransportVersion PROJECT_DELETION_GLOBAL_BLOCK = def(9_098_0_00);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC = def(9_099_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ public static Params withMaxCompletionTokensTokens(String modelId, Params params
);
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
*/
public static Params withMaxCompletionTokensTokens(Params params) {
return new DelegatingMapParams(Map.of(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD), params);
}

public sealed interface Content extends NamedWriteable, ToXContent permits ContentObjects, ContentString {}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public int hashCode() {
}

public record Result(String delta) implements ChunkedToXContent, Writeable {
private static final String RESULT = "delta";
public static final String RESULT = "delta";

private Result(StreamInput in) throws IOException {
this(in.readString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;

public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

Expand All @@ -31,13 +30,8 @@ public static void init() {
}

public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(24));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to check the list size anymore, containsInAnyOrder does that and will print out the missing element


var providers = providers(services);

assertThat(
providers,
allProviders(),
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
Expand Down Expand Up @@ -69,6 +63,10 @@ public void testGetServicesWithoutTaskType() throws IOException {
);
}

private Iterable<String> allProviders() throws IOException {
return providers(getAllServices());
}

@SuppressWarnings("unchecked")
private Iterable<String> providers(List<Object> services) {
return services.stream().map(service -> {
Expand All @@ -78,13 +76,8 @@ private Iterable<String> providers(List<Object> services) {
}

public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(17));

var providers = providers(services);

assertThat(
providers,
providersFor(TaskType.TEXT_EMBEDDING),
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
Expand All @@ -109,14 +102,13 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
);
}

public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(9));

var providers = providers(services);
private Iterable<String> providersFor(TaskType taskType) throws IOException {
return providers(getServices(taskType));
}

public void testGetServicesWithRerankTaskType() throws IOException {
assertThat(
providers,
providersFor(TaskType.RERANK),
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
Expand All @@ -127,20 +119,16 @@ public void testGetServicesWithRerankTaskType() throws IOException {
"jinaai",
"test_reranking_service",
"voyageai",
"hugging_face"
"hugging_face",
"amazon_sagemaker"
).toArray()
)
);
}

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(14));

var providers = providers(services);

assertThat(
providers,
providersFor(TaskType.COMPLETION),
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
Expand All @@ -164,13 +152,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
}

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(8));

var providers = providers(services);

assertThat(
providers,
providersFor(TaskType.CHAT_COMPLETION),
containsInAnyOrder(
List.of(
"deepseek",
Expand All @@ -187,13 +170,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
}

public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
assertThat(services.size(), equalTo(7));

var providers = providers(services);

assertThat(
providers,
providersFor(TaskType.SPARSE_EMBEDDING),
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
Expand All @@ -202,7 +180,8 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
"elasticsearch",
"hugging_face",
"streaming_completion_test_service",
"test_service"
"test_service",
"amazon_sagemaker"
).toArray()
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,14 @@ public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>
return Stream.empty();
}

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
return parse(parserConfig, event.data());
}

public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
XContentParserConfiguration parserConfig,
String data
) throws IOException {
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public InvokeEndpointRequest request(SageMakerModel model, SageMakerInferenceReq
throw e;
} catch (Exception e) {
throw new ElasticsearchStatusException(
"Failed to create SageMaker request for [%s]",
"Failed to create SageMaker request for [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
e,
model.getInferenceEntityId()
Expand Down Expand Up @@ -98,7 +98,7 @@ public InferenceServiceResults response(SageMakerModel model, InvokeEndpointResp
throw e;
} catch (Exception e) {
throw new ElasticsearchStatusException(
"Failed to translate SageMaker response for [%s]",
"Failed to translate SageMaker response for [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
e,
model.getInferenceEntityId()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticCompletionPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticRerankPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticSparseEmbeddingPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticTextEmbeddingPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;

Expand Down Expand Up @@ -41,7 +45,14 @@ public class SageMakerSchemas {
/*
* Add new model API to the register call.
*/
schemas = register(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload());
schemas = register(
new OpenAiTextEmbeddingPayload(),
new OpenAiCompletionPayload(),
new ElasticTextEmbeddingPayload(),
new ElasticSparseEmbeddingPayload(),
new ElasticCompletionPayload(),
new ElasticRerankPayload()
);

streamSchemas = schemas.entrySet()
.stream()
Expand Down
Loading
Loading