Skip to content

Commit 2e566c9

Browse files
[ML] CustomService adding template validation prior to request flow (#129591)
* Adding template validation prior to request flow * Fixing tests * Narrowing exception
1 parent 0f98676 commit 2e566c9

File tree

5 files changed

+94
-5
lines changed

5 files changed

+94
-5
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutor.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@ private static void ensureNoMorePlaceholdersExist(String substitutedString, Stri
5151
Matcher matcher = VARIABLE_PLACEHOLDER_PATTERN.matcher(substitutedString);
5252
if (matcher.find()) {
5353
throw new IllegalStateException(
54-
Strings.format("Found placeholder [%s] in field [%s] after replacement call", matcher.group(), field)
54+
Strings.format(
55+
"Found placeholder [%s] in field [%s] after replacement call, "
56+
+ "please check that all templates have a corresponding field definition.",
57+
matcher.group(),
58+
field
59+
)
5560
);
5661
}
5762
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.elasticsearch.xpack.inference.services.SenderService;
4040
import org.elasticsearch.xpack.inference.services.ServiceComponents;
4141
import org.elasticsearch.xpack.inference.services.ServiceUtils;
42+
import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
4243

4344
import java.util.EnumSet;
4445
import java.util.HashMap;
@@ -55,6 +56,7 @@
5556
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
5657

5758
public class CustomService extends SenderService {
59+
5860
public static final String NAME = "custom";
5961
private static final String SERVICE_NAME = "Custom";
6062

@@ -101,12 +103,32 @@ public void parseRequestConfig(
101103
throwIfNotEmptyMap(serviceSettingsMap, NAME);
102104
throwIfNotEmptyMap(taskSettingsMap, NAME);
103105

106+
validateConfiguration(model);
107+
104108
parsedModelListener.onResponse(model);
105109
} catch (Exception e) {
106110
parsedModelListener.onFailure(e);
107111
}
108112
}
109113

114+
/**
115+
* This does some initial validation with mock inputs to determine if any templates are missing a field to fill them.
116+
*/
117+
private static void validateConfiguration(CustomModel model) {
118+
String query = null;
119+
if (model.getTaskType() == TaskType.RERANK) {
120+
query = "test query";
121+
}
122+
123+
try {
124+
new CustomRequest(query, List.of("test input"), model).createHttpRequest();
125+
} catch (IllegalStateException e) {
126+
var validationException = new ValidationException();
127+
validationException.addValidationError(Strings.format("Failed to validate model configuration: %s", e.getMessage()));
128+
throw validationException;
129+
}
130+
}
131+
110132
private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
111133
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
112134
return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutorTests.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,38 @@ public void testReplace_ThrowsException_WhenPlaceHolderStillExists() {
3737
var sub = new ValidatingSubstitutor(Map.of("some_key", "value", "key2", "value2"), "${", "}");
3838
var exception = expectThrows(IllegalStateException.class, () -> sub.replace("super:${key}", "setting"));
3939

40-
assertThat(exception.getMessage(), is("Found placeholder [${key}] in field [setting] after replacement call"));
40+
assertThat(
41+
exception.getMessage(),
42+
is(
43+
"Found placeholder [${key}] in field [setting] after replacement call, "
44+
+ "please check that all templates have a corresponding field definition."
45+
)
46+
);
4147
}
4248
// only reports the first placeholder pattern
4349
{
4450
var sub = new ValidatingSubstitutor(Map.of("some_key", "value", "some_key2", "value2"), "${", "}");
4551
var exception = expectThrows(IllegalStateException.class, () -> sub.replace("super, ${key}, ${key2}", "setting"));
4652

47-
assertThat(exception.getMessage(), is("Found placeholder [${key}] in field [setting] after replacement call"));
53+
assertThat(
54+
exception.getMessage(),
55+
is(
56+
"Found placeholder [${key}] in field [setting] after replacement call, "
57+
+ "please check that all templates have a corresponding field definition."
58+
)
59+
);
4860
}
4961
{
5062
var sub = new ValidatingSubstitutor(Map.of("some_key", "value", "key2", "value2"), "${", "}");
5163
var exception = expectThrows(IllegalStateException.class, () -> sub.replace("super:${ \\/\tkey\"}", "setting"));
5264

53-
assertThat(exception.getMessage(), is("Found placeholder [${ \\/\tkey\"}] in field [setting] after replacement call"));
65+
assertThat(
66+
exception.getMessage(),
67+
is(
68+
"Found placeholder [${ \\/\tkey\"}] in field [setting] after replacement call,"
69+
+ " please check that all templates have a corresponding field definition."
70+
)
71+
);
5472
}
5573
}
5674
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.support.PlainActionFuture;
1212
import org.elasticsearch.common.Strings;
13+
import org.elasticsearch.common.ValidationException;
1314
import org.elasticsearch.common.settings.SecureString;
1415
import org.elasticsearch.core.Nullable;
1516
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
@@ -52,6 +53,7 @@
5253
import java.util.Map;
5354

5455
import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
56+
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
5557
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
5658
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
5759
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -611,6 +613,42 @@ public void testInfer_HandlesSparseEmbeddingRequest_Alibaba_Format() throws IOEx
611613
}
612614
}
613615

616+
public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNotFillTemplate() throws Exception {
617+
try (var service = createService(threadPool, clientManager)) {
618+
619+
var settingsMap = new HashMap<>(
620+
Map.of(
621+
CustomServiceSettings.URL,
622+
"http://www.abc.com",
623+
CustomServiceSettings.HEADERS,
624+
Map.of("key", "value"),
625+
QueryParameters.QUERY_PARAMETERS,
626+
List.of(List.of("key", "value")),
627+
CustomServiceSettings.REQUEST,
628+
"request body ${some_template}",
629+
CustomServiceSettings.RESPONSE,
630+
new HashMap<>(Map.of(CustomServiceSettings.JSON_PARSER, createResponseParserMap(TaskType.COMPLETION)))
631+
)
632+
);
633+
634+
var config = getRequestConfigMap(settingsMap, createTaskSettingsMap(), createSecretSettingsMap());
635+
636+
var listener = new PlainActionFuture<Model>();
637+
service.parseRequestConfig("id", TaskType.COMPLETION, config, listener);
638+
639+
var exception = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT));
640+
641+
assertThat(
642+
exception.getMessage(),
643+
is(
644+
"Validation Failed: 1: Failed to validate model configuration: Found placeholder "
645+
+ "[${some_template}] in field [request] after replacement call, please check that all "
646+
+ "templates have a corresponding field definition.;"
647+
)
648+
);
649+
}
650+
}
651+
614652
public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
615653
var model = createInternalEmbeddingModel(
616654
SimilarityMeasure.DOT_PRODUCT,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,13 @@ public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IO
264264

265265
var request = new CustomRequest(null, List.of("abc", "123"), model);
266266
var exception = expectThrows(IllegalStateException.class, request::createHttpRequest);
267-
assertThat(exception.getMessage(), is("Found placeholder [${task.key}] in field [header.Accept] after replacement call"));
267+
assertThat(
268+
exception.getMessage(),
269+
is(
270+
"Found placeholder [${task.key}] in field [header.Accept] after replacement call, "
271+
+ "please check that all templates have a corresponding field definition."
272+
)
273+
);
268274
}
269275

270276
public void testCreateRequest_ThrowsException_ForInvalidUrl() {

0 commit comments

Comments
 (0)