Skip to content

[ML] Move to the Cohere V2 API for new inference endpoints #129884

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/129884.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129884
summary: Move to the Cohere V2 API for new inference endpoints
area: Machine Learning
type: enhancement
issues: []
2 changes: 0 additions & 2 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@ tests:
- class: org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT
method: test {yaml=analysis-common/40_token_filters/stemmer_override file access}
issue: https://github.com/elastic/elasticsearch/issues/121625
- class: org.elasticsearch.xpack.application.CohereServiceUpgradeIT
issue: https://github.com/elastic/elasticsearch/issues/121537
- class: org.elasticsearch.test.rest.ClientYamlTestSuiteIT
method: test {yaml=snapshot.delete/10_basic/Delete a snapshot asynchronously}
issue: https://github.com/elastic/elasticsearch/issues/122102
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ static TransportVersion def(int id) {
public static final TransportVersion STREAMS_LOGS_SUPPORT_8_19 = def(8_841_0_54);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_55);
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);

public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION_8_19 = def(8_841_0_57);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -311,6 +311,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00);
public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE = def(9_105_0_00);
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_106_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

import com.carrotsearch.randomizedtesting.annotations.Name;

import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockRequest;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
Expand All @@ -24,6 +26,7 @@

import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
Expand All @@ -36,10 +39,16 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
// TODO: replace with proper test features
private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0";
private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0";
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2";

private static MockWebServer cohereEmbeddingsServer;
private static MockWebServer cohereRerankServer;

/**
 * Version of the Cohere REST API a request is expected to target.
 * The test maps {@code V1} to a {@code /v1/} path prefix and {@code V2} to a {@code /v2/} path prefix
 * when checking the path of requests received by the mock server.
 */
private enum ApiVersion {
V1,
V2
}

/**
 * Creates the test instance and forwards the parameter to the superclass constructor.
 *
 * @param upgradedNodes value injected under the name "upgradedNodes"; presumably the number of nodes
 *                      already upgraded at this stage of the test (confirm against the superclass)
 */
public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
super(upgradedNodes);
}
Expand All @@ -62,15 +71,18 @@ public static void shutdown() {
@SuppressWarnings("unchecked")
public void testCohereEmbeddings() throws IOException {
var embeddingsSupported = oldClusterHasFeature(COHERE_EMBEDDINGS_ADDED_TEST_FEATURE);
String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
assumeTrue("Cohere embedding service supported", embeddingsSupported);

String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;

final String oldClusterIdInt8 = "old-cluster-embeddings-int8";
final String oldClusterIdFloat = "old-cluster-embeddings-float";

var testTaskType = TaskType.TEXT_EMBEDDING;

if (isOldCluster()) {

// queue a response as PUT will call the service
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
Expand Down Expand Up @@ -128,13 +140,17 @@ public void testCohereEmbeddings() throws IOException {

// Inference on old cluster models
assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", oldClusterApiVersion);
assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", oldClusterApiVersion);

{
final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte";

// new endpoints use the V2 API
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), testTaskType);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);

configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdByte).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
Expand All @@ -146,34 +162,70 @@ public void testCohereEmbeddings() throws IOException {
{
final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8";

// new endpoints use the V2 API
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);

configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdInt8).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte

assertEmbeddingInference(upgradedClusterIdInt8, CohereEmbeddingType.INT8);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
delete(upgradedClusterIdInt8);
}
{
final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float";
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);

configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdFloat).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("embedding_type", "float"));

assertEmbeddingInference(upgradedClusterIdFloat, CohereEmbeddingType.FLOAT);
assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2);
delete(upgradedClusterIdFloat);
}
{
// new endpoints use the V2 API, which requires the model to be set
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
var jsonBody = Strings.format("""
{
"service": "cohere",
"service_settings": {
"url": "%s",
"api_key": "XXXX",
"embedding_type": "int8"
}
}
""", getUrl(cohereEmbeddingsServer));

var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
assertThat(
e.getMessage(),
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
);
}

delete(oldClusterIdFloat);
delete(oldClusterIdInt8);
}
}

/**
 * Asserts that a request captured by the mock server was sent to the path of the expected API version.
 *
 * @param request    the captured request whose URI path is checked
 * @param endpoint   the endpoint name that follows the version prefix, e.g. {@code "embed"} or {@code "rerank"}
 * @param apiVersion the API version the request is expected to target
 */
private void assertVersionInPath(MockRequest request, String endpoint, ApiVersion apiVersion) {
    // Pick the version segment first, then make a single assertion on the full path.
    final String versionPrefix = switch (apiVersion) {
        case V1 -> "/v1/";
        case V2 -> "/v2/";
    };
    assertEquals(versionPrefix + endpoint, request.getUri().getPath());
}

void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException {
switch (type) {
case INT8:
Expand All @@ -191,9 +243,11 @@ void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) thro
@SuppressWarnings("unchecked")
public void testRerank() throws IOException {
var rerankSupported = oldClusterHasFeature(COHERE_RERANK_ADDED_TEST_FEATURE);
String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
assumeTrue("Cohere rerank service supported", rerankSupported);

String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models";
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;

final String oldClusterId = "old-cluster-rerank";
final String upgradedClusterId = "upgraded-cluster-rerank";

Expand All @@ -216,7 +270,6 @@ public void testRerank() throws IOException {
assertThat(taskSettings, hasEntry("top_n", 3));

assertRerank(oldClusterId);

} else if (isUpgradedCluster()) {
// check old cluster model
var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get("endpoints");
Expand All @@ -227,6 +280,7 @@ public void testRerank() throws IOException {
assertThat(taskSettings, hasEntry("top_n", 3));

assertRerank(oldClusterId);
assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", oldClusterApiVersion);

// New endpoint
cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
Expand All @@ -235,6 +289,27 @@ public void testRerank() throws IOException {
assertThat(configs, hasSize(1));

assertRerank(upgradedClusterId);
assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", ApiVersion.V2);

{
// new endpoints use the V2 API, which requires the model_id to be set
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
var jsonBody = Strings.format("""
{
"service": "cohere",
"service_settings": {
"url": "%s",
"api_key": "XXXX"
}
}
""", getUrl(cohereEmbeddingsServer));

var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
assertThat(
e.getMessage(),
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
);
}

delete(oldClusterId);
delete(upgradedClusterId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class InferenceFeatures implements FeatureSpecification {
"test_rule_retriever.with_indices_that_dont_return_rank_docs"
);
private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter");
private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2");

@Override
public Set<NodeFeature> getTestFeatures() {
Expand Down Expand Up @@ -64,7 +65,8 @@ public Set<NodeFeature> getTestFeatures() {
SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG,
SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER,
SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
SEMANTIC_TEXT_INDEX_OPTIONS
SEMANTIC_TEXT_INDEX_OPTIONS,
COHERE_V2_API
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
private final Boolean returnDocuments;
private final Integer topN;

/**
 * Convenience constructor that takes only the query and the chunks.
 * Delegates to the five-argument constructor, passing {@code null}, {@code null} and {@code false}
 * for the remaining arguments.
 *
 * @param query  the query text
 * @param chunks the chunks passed along with the query
 */
public QueryAndDocsInputs(String query, List<String> chunks) {
this(query, chunks, null, null, false);
}

public QueryAndDocsInputs(
String query,
List<String> chunks,
Expand All @@ -45,6 +41,10 @@ public QueryAndDocsInputs(
this.topN = topN;
}

/**
 * Convenience constructor that takes only the query and the chunks.
 * Delegates to the five-argument constructor, passing {@code null}, {@code null} and {@code false}
 * for the remaining arguments.
 *
 * @param query  the query text
 * @param chunks the chunks passed along with the query
 */
public QueryAndDocsInputs(String query, List<String> chunks) {
this(query, chunks, null, null, false);
}

public String getQuery() {
return query;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,35 @@

package org.elasticsearch.xpack.inference.services.cohere;

import org.elasticsearch.common.CheckedSupplier;
import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;

public record CohereAccount(URI uri, SecureString apiKey) {

public static CohereAccount of(CohereModel model, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
var uri = buildUri(model.uri(), "Cohere", uriBuilder);

return new CohereAccount(uri, model.apiKey());
public record CohereAccount(URI baseUri, SecureString apiKey) {

/**
 * Builds the account for a model, using the model's base URI when one is set and otherwise
 * an {@code https} URI built from {@link CohereUtils#HOST}.
 *
 * @param model the model supplying the optional base URI and the API key
 * @return the account holding the resolved base URI and the model's API key
 * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} if the default URI cannot be built
 */
public static CohereAccount of(CohereModel model) {
    // A URI configured on the model takes precedence over the default host.
    if (model.baseUri() != null) {
        return new CohereAccount(model.baseUri(), model.apiKey());
    }

    try {
        var defaultUri = new URIBuilder().setScheme("https").setHost(CohereUtils.HOST).build();
        return new CohereAccount(defaultUri, model.apiKey());
    } catch (URISyntaxException e) {
        // using bad request here so that potentially sensitive URL information does not get logged
        var message = Strings.format("Failed to construct %s URL", CohereService.NAME);
        throw new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, e);
    }
}

public CohereAccount {
Objects.requireNonNull(uri);
Objects.requireNonNull(baseUri);
Objects.requireNonNull(apiKey);
}
}

This file was deleted.

Loading
Loading