diff --git a/docs/changelog/129693.yaml b/docs/changelog/129693.yaml new file mode 100644 index 0000000000000..8edab59b7d03f --- /dev/null +++ b/docs/changelog/129693.yaml @@ -0,0 +1,5 @@ +pr: 129693 +summary: Add top level normalizer for linear retriever +area: Search +type: enhancement +issues: [] diff --git a/docs/reference/elasticsearch/rest-apis/retrievers/linear-retriever.md b/docs/reference/elasticsearch/rest-apis/retrievers/linear-retriever.md index 5008831b72acd..b4f6e15433b37 100644 --- a/docs/reference/elasticsearch/rest-apis/retrievers/linear-retriever.md +++ b/docs/reference/elasticsearch/rest-apis/retrievers/linear-retriever.md @@ -91,3 +91,47 @@ The `linear` retriever supports the following normalizers: score = (score - min) / (max - min) ``` * `l2_norm`: Normalizes scores using the L2 norm of the score values {applies_to}`stack: ga 9.1` + +## Examples [linear-retriever-examples] + +### Top-level normalizer example + +This example shows how to use a top-level normalizer that applies to all sub-retrievers: + +```console +GET my_index/_search +{ + "retriever": { + "linear": { + "retrievers": [ + { + "retriever": { + "standard": { + "query": { + "match": { + "title": "elasticsearch" + } + } + } + }, + "weight": 1.0 + }, + { + "retriever": { + "knn": { + "field": "title_vector", + "query_vector": [0.1, 0.2, 0.3], + "k": 10, + "num_candidates": 100 + } + }, + "weight": 2.0 + } + ], + "normalizer": "minmax" + } + } +} +``` + +In this example, the `minmax` normalizer is applied to both the standard retriever and the kNN retriever. The top-level normalizer serves as a default that can be overridden by individual sub-retrievers. When using the multi-field query format, the top-level normalizer is applied to all generated inner retrievers. diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index c1a3f7d174487..f7dd499e66ef4 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -43,6 +43,7 @@ import static org.elasticsearch.action.ValidateActions.addValidationError; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.rank.RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED; +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_NORMALIZER; import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT; /** @@ -118,7 +119,7 @@ private static float[] getDefaultWeight(List innerRetrievers) { private static ScoreNormalizer[] getDefaultNormalizers(List innerRetrievers) { int size = innerRetrievers != null ? innerRetrievers.size() : 0; ScoreNormalizer[] normalizers = new ScoreNormalizer[size]; - Arrays.fill(normalizers, IdentityScoreNormalizer.INSTANCE); + Arrays.fill(normalizers, DEFAULT_NORMALIZER); return normalizers; } @@ -167,7 +168,14 @@ public LinearRetrieverBuilder( this.query = query; this.normalizer = normalizer; this.weights = weights; - this.normalizers = normalizers; + this.normalizers = new ScoreNormalizer[normalizers.length]; + for (int i = 0; i < normalizers.length; i++) { + if (normalizers[i] == null || normalizers[i].equals(DEFAULT_NORMALIZER)) { + this.normalizers[i] = normalizer != null ? normalizer : DEFAULT_NORMALIZER; + } else { + this.normalizers[i] = normalizers[i]; + } + } } public LinearRetrieverBuilder( @@ -221,19 +229,7 @@ public ActionRequestValidationException validate( ), validationException ); - } else if (innerRetrievers.isEmpty() == false && normalizer != null) { - validationException = addValidationError( - String.format( - Locale.ROOT, - "[%s] [%s] cannot be provided when [%s] is specified", - getName(), - NORMALIZER_FIELD.getPreferredName(), - RETRIEVERS_FIELD.getPreferredName() - ), - validationException - ); } - return validationException; } @@ -410,7 +406,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept if (query != null) { builder.field(QUERY_FIELD.getPreferredName(), query); } - if (normalizer != null) { + if (normalizer != null && normalizer.equals(DEFAULT_NORMALIZER) == false) { builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer.getName()); } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java index 34b7277498218..d760ffa878aaa 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java @@ -9,6 +9,8 @@ import org.apache.lucene.search.ScoreDoc; +import java.util.Objects; + /** * A no-op {@link ScoreNormalizer} that does not modify the scores. */ @@ -31,4 +33,21 @@ public static ScoreNormalizer valueOf(String normalizer) { public abstract String getName(); public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs); + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ScoreNormalizer that = (ScoreNormalizer) obj; + return Objects.equals(getName(), that.getName()); + } + + @Override + public int hashCode() { + return Objects.hash(getName()); + } } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index 74e18bf12fffc..525234fc8e7c8 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -67,7 +67,11 @@ protected LinearRetrieverBuilder createTestInstance() { new CompoundRetrieverBuilder.RetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), null) ); weights[i] = randomFloat(); - normalizers[i] = randomScoreNormalizer(); + if (normalizer != null && randomBoolean()) { + normalizers[i] = normalizer; + } else { + normalizers[i] = randomScoreNormalizer(); + } } return new LinearRetrieverBuilder(innerRetrievers, fields, query, normalizer, rankWindowSize, weights, normalizers); diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java index c211440d10bae..f87816fbe4427 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.test.ESTestCase; @@ -326,4 +327,63 @@ public int hashCode() { return Objects.hash(retriever, weight, normalizer); } } + + public void testTopLevelNormalizerWithRetrieversArray() { + StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(new MatchQueryBuilder("title", "elasticsearch")); + KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder( + "title_vector", + new float[] { 0.1f, 0.2f, 0.3f }, + null, + 10, + 100, + null, + null + ); + + LinearRetrieverBuilder retriever = new LinearRetrieverBuilder( + List.of( + CompoundRetrieverBuilder.RetrieverSource.from(standardRetriever), + CompoundRetrieverBuilder.RetrieverSource.from(knnRetriever) + ), + null, // fields + null, // query + MinMaxScoreNormalizer.INSTANCE, // top-level normalizer + DEFAULT_RANK_WINDOW_SIZE, + new float[] { 1.0f, 2.0f }, + new ScoreNormalizer[] { null, null } + ); + + assertEquals(MinMaxScoreNormalizer.INSTANCE, retriever.getNormalizers()[0]); + assertEquals(MinMaxScoreNormalizer.INSTANCE, retriever.getNormalizers()[1]); + } + + public void testTopLevelNormalizerWithPerRetrieverOverrides() { + StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(new MatchQueryBuilder("title", "elasticsearch")); + KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder( + "title_vector", + new float[] { 0.1f, 0.2f, 0.3f }, + null, + 10, + 100, + null, + null + ); + + LinearRetrieverBuilder retriever = new LinearRetrieverBuilder( + List.of( + CompoundRetrieverBuilder.RetrieverSource.from(standardRetriever), + CompoundRetrieverBuilder.RetrieverSource.from(knnRetriever) + ), + null, // fields + null, // query + MinMaxScoreNormalizer.INSTANCE, // top-level normalizer + DEFAULT_RANK_WINDOW_SIZE, + new float[] { 1.0f, 2.0f }, + new ScoreNormalizer[] { L2ScoreNormalizer.INSTANCE, null } + ); + + assertEquals(L2ScoreNormalizer.INSTANCE, retriever.getNormalizers()[0]); + assertEquals(MinMaxScoreNormalizer.INSTANCE, retriever.getNormalizers()[1]); + } + } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml index f62c7e4987046..b9e55aa85e05b 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml @@ -1333,3 +1333,217 @@ setup: - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "1" } - close_to: { hits.hits.0._score: { value: 1.0, error: 0.001} } + +--- +"linear retriever with top-level normalizer - minmax": + - do: + search: + index: test + body: + retriever: + linear: + normalizer: minmax + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { constant_score: { filter: { term: { keyword: { value: "one" } } }, boost: 10.0 } }, + { constant_score: { filter: { term: { keyword: { value: "two" } } }, boost: 5.0 } } + ] + } + } + } + }, + weight: 1.0 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [4], + k: 2, + num_candidates: 10 + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 2.0, error: 0.01} } + - match: { hits.hits.1._id: "1" } + - close_to: { hits.hits.1._score: { value: 1.0, error: 0.01} } + - match: { hits.hits.2._id: "2" } + - close_to: { hits.hits.2._score: { value: 0.0, error: 0.01} } + +--- +"linear retriever with top-level normalizer - l2_norm": + - requires: + cluster_features: [ "linear_retriever.l2_norm" ] + reason: "Requires l2_norm normalization support in linear retriever" + - do: + search: + index: test + body: + retriever: + linear: + normalizer: l2_norm + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { constant_score: { filter: { term: { keyword: { value: "one" } } }, boost: 3.0 } }, + { constant_score: { filter: { term: { keyword: { value: "two" } } }, boost: 4.0 } } + ] + } + } + } + }, + weight: 1.0 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [4], + k: 2, + num_candidates: 10 + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 1.79, error: 0.01} } + - match: { hits.hits.1._id: "3" } + - close_to: { hits.hits.1._score: { value: 0.89, error: 0.01} } + - match: { hits.hits.2._id: "2" } + - close_to: { hits.hits.2._score: { value: 0.8, error: 0.01} } + +--- +"linear retriever with top-level normalizer and per-retriever override": + - do: + search: + index: test + body: + retriever: + linear: + normalizer: minmax + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { constant_score: { filter: { term: { keyword: { value: "one" } } }, boost: 10.0 } }, + { constant_score: { filter: { term: { keyword: { value: "two" } } }, boost: 5.0 } } + ] + } + } + } + }, + weight: 1.0, + normalizer: l2_norm + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [4], + k: 2, + num_candidates: 10 + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 2.0, error: 0.01} } + - match: { hits.hits.1._id: "1" } + - close_to: { hits.hits.1._score: { value: 0.89, error: 0.01} } + - match: { hits.hits.2._id: "2" } + - close_to: { hits.hits.2._score: { value: 0.45, error: 0.01} } + +--- +"linear retriever with top-level normalizer - multi-field format": + - do: + search: + index: test + body: + retriever: + linear: + normalizer: minmax + query: one + fields: [keyword, text] + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "1" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.001} } + +--- +"linear retriever with top-level normalizer - validation test": + - do: + catch: /Unknown normalizer \[invalid\]/ + search: + index: test + body: + retriever: + linear: + normalizer: invalid + retrievers: [ + { + retriever: { + standard: { + query: { + term: { + keyword: { + value: "one" + } + } + } + } + }, + weight: 1.0 + } + ] + +--- +"linear retriever with top-level normalizer - empty results": + - do: + search: + index: test + body: + retriever: + linear: + normalizer: minmax + retrievers: [ + { + retriever: { + standard: { + query: { + term: { + keyword: { + value: "nonexistent" + } + } + } + } + }, + weight: 1.0 + } + ] + + - match: { hits.total.value: 0 } + - length: { hits.hits: 0 } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever_normalizers.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever_normalizers.yml new file mode 100644 index 0000000000000..d99cda3185027 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever_normalizers.yml @@ -0,0 +1,247 @@ +setup: + - requires: + cluster_features: [ "linear_retriever_supported", "linear_retriever.l2_norm" ] + reason: "Support for linear retriever and L2 normalization" + test_runner_features: close_to + + - do: + indices.create: + index: test + body: + mappings: + properties: + vector: + type: dense_vector + dims: 1 + index: true + similarity: l2_norm + index_options: + type: flat + keyword: + type: keyword + other_keyword: + type: keyword + timestamp: + type: date + + - do: + bulk: + refresh: true + index: test + body: + - '{"index": {"_id": 1 }}' + - '{"vector": [1], "keyword": "one", "other_keyword": "other", "timestamp": "2021-01-01T00:00:00"}' + - '{"index": {"_id": 2 }}' + - '{"vector": [2], "keyword": "two", "timestamp": "2022-01-01T00:00:00"}' + - '{"index": {"_id": 3 }}' + - '{"vector": [3], "keyword": "three", "timestamp": "2023-01-01T00:00:00"}' + - '{"index": {"_id": 4 }}' + - '{"vector": [4], "keyword": "four", "other_keyword": "other", "timestamp": "2024-01-01T00:00:00"}' + +--- +"Linear retriever with top-level L2 normalization": + - do: + search: + index: test + body: + retriever: + linear: + normalizer: l2_norm + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 5.0 + } + } + } + }, + weight: 1.0 + }, + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "four" + } + } + }, + boost: 12.0 + } + } + } + }, + weight: 1.0 + } + ] + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "4" } # Doc 4 should rank higher with normalized scores + - match: { hits.hits.1._id: "1" } + # With L2 normalization: [5.0, 12.0] becomes [5.0/13.0, 12.0/13.0] + - close_to: { hits.hits.0._score: { value: 0.923, error: 0.01} } # 12.0/13.0 + - close_to: { hits.hits.1._score: { value: 0.385, error: 0.01} } # 5.0/13.0 + +--- +"Linear retriever with per-retriever L2 normalization": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 5.0 + } + } + } + }, + weight: 1.0, + normalizer: l2_norm + }, + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "four" + } + } + }, + boost: 12.0 + } + } + } + }, + weight: 1.0, + normalizer: l2_norm + } + ] + + - match: { hits.total.value: 2 } + # With per-retriever L2 normalization, both scores would be normalized to 1.0 + # So final score = 1.0 * weight1 + 1.0 * weight2 = 2.0 for each doc + # Then sorting is done by _doc (or some other tiebreaker) + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.01} } + - close_to: { hits.hits.1._score: { value: 1.0, error: 0.01} } + +--- +"Linear retriever with mixed normalization (top-level and per-retriever with same normalizer)": + - do: + search: + index: test + body: + retriever: + linear: + normalizer: l2_norm + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 5.0 + } + } + } + }, + weight: 1.0 + }, + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "four" + } + } + }, + boost: 12.0 + } + } + } + }, + weight: 1.0, + normalizer: l2_norm + } + ] + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.1._id: "1" } + # With L2 normalization: [5.0, 12.0] becomes [5.0/13.0, 12.0/13.0] + - close_to: { hits.hits.0._score: { value: 0.923, error: 0.01} } + - close_to: { hits.hits.1._score: { value: 0.385, error: 0.01} } + +--- +"Linear retriever with mismatched normalizers (should fail)": + - do: + catch: bad_request + search: + index: test + body: + retriever: + linear: + normalizer: l2_norm + retrievers: [ + { + retriever: { + standard: { + query: { + match_all: {} + } + } + } + }, + { + retriever: { + standard: { + query: { + match_all: {} + } + } + }, + normalizer: minmax + } + ] + + - match: { error.root_cause.0.type: "illegal_argument_exception" } + - match: { error.root_cause.0.reason: /.*All per-retriever normalizers must match the top-level normalizer.*/ }