From 4ab30411f767b008f8d7b9b24ed2d674a910e87f Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Fri, 27 Jun 2025 19:47:57 +0100 Subject: [PATCH 01/15] Speed up (filtered) KNN queries for flat vector fields MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For dense vector fields using the `flat` index, we already know a brute-force search will be used—so there’s no need to go through the codec’s approximate KNN logic. This change skips that step and builds the brute-force query directly, making things faster and simpler. I tested this on a setup with **10 million random vectors**, each with **1596 dimensions** and **17,500 partitions**, using the `random_vector` track. The results: ### Performance Comparison | Metric | Before | After | Change | | ----------------- | --------- | ---------- | --------- | | **Throughput** | 221 ops/s | 2762 ops/s | 🟢 +1149% | | **Latency (p50)** | 29.2 ms | 1.6 ms | 🔻 -94.4% | | **Latency (p99)** | 81.6 ms | 3.5 ms | 🔻 -95.7% | Filtered KNN queries on flat vectors are now over 10x faster on my laptop! --- .../elasticsearch/test/knn/KnnSearcher.java | 2 +- .../vectors/DenseVectorFieldMapper.java | 114 ++++++++++++-- .../search/vectors/DenseVectorQuery.java | 1 + .../search/vectors/RescoreKnnVectorQuery.java | 149 +++++++++++++++--- .../index/mapper/DynamicMappingTests.java | 10 +- .../vectors/DenseVectorFieldTypeTests.java | 57 +++++-- ...AbstractKnnVectorQueryBuilderTestCase.java | 37 ++++- .../vectors/RescoreKnnVectorQueryTests.java | 94 ++++++----- 8 files changed, 366 insertions(+), 98 deletions(-) diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index 7dd6f2894a20a..c24d8cb7e94e2 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -309,7 +309,7 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException } if (overSamplingFactor > 1f) { // oversample the topK results to get more candidates for the final result - knnQuery = new RescoreKnnVectorQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, knnQuery); + knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery); } QueryProfiler profiler = new QueryProfiler(); TopDocs docs = searcher.search(knnQuery, this.topK); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 329e426be7f47..0350cc0c18b2a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -30,9 +30,13 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.BytesRef; @@ -1391,6 +1395,18 @@ public final boolean equals(Object other) { public final int hashCode() { return Objects.hash(type, doHashCode()); } + + /** + * Indicates whether the underlying vector search is performed using a flat (exhaustive) approach. + *

+ * When {@code true}, it means the search does not use any approximate nearest neighbor (ANN) + * acceleration structures such as HNSW or IVF. Instead, it performs a brute-force comparison + * against all candidate vectors. This information can be used by higher-level components + * to decide whether additional acceleration or optimization is necessary. + * + * @return {@code true} if the vector search is flat (exhaustive), {@code false} if it uses ANN structures + */ + abstract boolean isFlat(); } abstract static class QuantizedIndexOptions extends DenseVectorIndexOptions { @@ -1762,6 +1778,11 @@ int doHashCode() { return Objects.hash(confidenceInterval, rescoreVector); } + @Override + boolean isFlat() { + return true; + } + @Override public boolean updatableTo(DenseVectorIndexOptions update) { return update.type.equals(this.type) @@ -1810,6 +1831,11 @@ public boolean doEquals(DenseVectorIndexOptions o) { public int doHashCode() { return Objects.hash(type); } + + @Override + boolean isFlat() { + return true; + } } public static class Int4HnswIndexOptions extends QuantizedIndexOptions { @@ -1860,6 +1886,11 @@ public int doHashCode() { return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); } + @Override + boolean isFlat() { + return false; + } + @Override public String toString() { return "{type=" @@ -1931,6 +1962,11 @@ public int doHashCode() { return Objects.hash(confidenceInterval, rescoreVector); } + @Override + boolean isFlat() { + return true; + } + @Override public String toString() { return "{type=" + type + ", confidence_interval=" + confidenceInterval + ", rescore_vector=" + rescoreVector + "}"; @@ -1999,6 +2035,11 @@ public int doHashCode() { return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); } + @Override + boolean isFlat() { + return false; + } + @Override public String toString() { return "{type=" @@ -2088,6 +2129,11 @@ public int doHashCode() { return Objects.hash(m, efConstruction); } + @Override + boolean isFlat() { + return false; + } + @Override public String toString() { return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}"; @@ -2126,6 +2172,11 @@ int doHashCode() { return Objects.hash(m, efConstruction, rescoreVector); } + @Override + boolean isFlat() { + return false; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -2179,6 +2230,11 @@ int doHashCode() { return CLASS_NAME_HASH; } + @Override + boolean isFlat() { + return true; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -2237,6 +2293,11 @@ int doHashCode() { return Objects.hash(clusterSize, defaultNProbe, rescoreVector); } + @Override + boolean isFlat() { + return false; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -2485,9 +2546,21 @@ private Query createKnnBitQuery( KnnSearchStrategy searchStrategy ) { elementType.checkDimensions(dims, queryVector.length); - Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + Query knnQuery; + if (indexOptions.isFlat()) { + knnQuery = filter == null + ? createExactKnnBitQuery(queryVector) + : new BooleanQuery.Builder().add(createExactKnnBitQuery(queryVector), BooleanClause.Occur.SHOULD) + .add(filter, BooleanClause.Occur.FILTER) + .build(); + if (parentFilter != null) { + knnQuery = new ToParentBlockJoinQuery(knnQuery, parentFilter, ScoreMode.Max); + } + } else { + knnQuery = parentFilter != null + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + } if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2513,9 +2586,22 @@ private Query createKnnByteQuery( float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } - Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + + Query knnQuery; + if (indexOptions.isFlat()) { + knnQuery = filter == null + ? createExactKnnByteQuery(queryVector) + : new BooleanQuery.Builder().add(createExactKnnByteQuery(queryVector), BooleanClause.Occur.SHOULD) + .add(filter, BooleanClause.Occur.FILTER) + .build(); + if (parentFilter != null) { + knnQuery = new ToParentBlockJoinQuery(knnQuery, parentFilter, ScoreMode.Max); + } + } else { + knnQuery = parentFilter != null + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + } if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2568,7 +2654,16 @@ && isNotUnitVector(squaredMagnitude)) { numCands = Math.max(adjustedK, numCands); } Query knnQuery; - if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) { + if (indexOptions.isFlat()) { + knnQuery = filter == null + ? createExactKnnFloatQuery(queryVector) + : new BooleanQuery.Builder().add(createExactKnnFloatQuery(queryVector), BooleanClause.Occur.SHOULD) + .add(filter, BooleanClause.Occur.FILTER) + .build(); + if (parentFilter != null) { + knnQuery = new ToParentBlockJoinQuery(knnQuery, parentFilter, ScoreMode.Max); + } + } else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) { knnQuery = parentFilter != null ? new DiversifyingChildrenIVFKnnFloatVectorQuery( name(), @@ -2594,11 +2689,12 @@ && isNotUnitVector(squaredMagnitude)) { : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy); } if (rescore) { - knnQuery = new RescoreKnnVectorQuery( + knnQuery = RescoreKnnVectorQuery.fromInnerQuery( name(), queryVector, similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), k, + adjustedK, knnQuery ); } @@ -2624,7 +2720,7 @@ ElementType getElementType() { return elementType; } - public IndexOptions getIndexOptions() { + public DenseVectorIndexOptions getIndexOptions() { return indexOptions; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java index 31e19b6784757..dfc84f5eaef6d 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java @@ -207,4 +207,5 @@ public int docID() { return iterator.docID(); } } + } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 99568a507ffb9..fa6240c56eda2 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -12,11 +12,11 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionScoreQuery; import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.TopDocs; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; @@ -25,17 +25,28 @@ import java.util.Objects; /** - * Wraps an internal query to rescore the results using a similarity function over the original, non-quantized vectors of a vector field + * A Lucene {@link Query} that applies vector-based rescoring to an inner query's results. + *

+ * Depending on the nature of the {@code innerQuery}, this class dynamically selects between two rescoring strategies: + *

*/ -public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider { - private final String fieldName; - private final float[] floatTarget; - private final VectorSimilarityFunction vectorSimilarityFunction; - private final int k; - private final Query innerQuery; - private long vectorOperations = 0; - - public RescoreKnnVectorQuery( +public abstract class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider { + protected final String fieldName; + protected final float[] floatTarget; + protected final VectorSimilarityFunction vectorSimilarityFunction; + protected final int k; + protected final Query innerQuery; + protected long vectorOperations = 0; + + private RescoreKnnVectorQuery( String fieldName, float[] floatTarget, VectorSimilarityFunction vectorSimilarityFunction, @@ -49,16 +60,30 @@ public RescoreKnnVectorQuery( this.innerQuery = innerQuery; } - @Override - public Query rewrite(IndexSearcher searcher) throws IOException { - DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); - FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); - Query query = searcher.rewrite(functionScoreQuery); - - // Retrieve top k documents from the rescored query - TopDocs topDocs = searcher.search(query, k); - vectorOperations = topDocs.totalHits.value(); - return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); + /** + * Selects and returns the appropriate {@link RescoreKnnVectorQuery} strategy based on the nature of the {@code innerQuery}. + * + * @param fieldName the name of the field containing the vector + * @param floatTarget the target vector to compare against + * @param vectorSimilarityFunction the similarity function to apply + * @param k the number of top documents to return after rescoring + * @param rescoreK the number of top documents to consider for rescoring + * @param innerQuery the original Lucene query to rescore + */ + public static RescoreKnnVectorQuery fromInnerQuery( + String fieldName, + float[] floatTarget, + VectorSimilarityFunction vectorSimilarityFunction, + int k, + int rescoreK, + Query innerQuery + ) { + if ((innerQuery instanceof KnnFloatVectorQuery fQuery && fQuery.getK() == rescoreK) + || (innerQuery instanceof KnnByteVectorQuery bQuery && bQuery.getK() == rescoreK) + || (innerQuery instanceof AbstractIVFKnnVectorQuery ivfQuery && ivfQuery.k == rescoreK)) { + return new InlineRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery); + } + return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery); } public Query innerQuery() { @@ -102,7 +127,8 @@ public int hashCode() { @Override public String toString(String field) { - return "KnnRescoreVectorQuery{" + return getClass().getSimpleName() + + "{" + "fieldName='" + fieldName + '\'' @@ -117,4 +143,81 @@ public String toString(String field) { + innerQuery + '}'; } + + private static class InlineRescoreQuery extends RescoreKnnVectorQuery { + private InlineRescoreQuery( + String fieldName, + float[] floatTarget, + VectorSimilarityFunction vectorSimilarityFunction, + int k, + Query innerQuery + ) { + super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery); + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + var functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); + // Retrieve top k documents from the function score query + var topDocs = searcher.search(functionScoreQuery, k); + vectorOperations = topDocs.totalHits.value(); + return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return super.equals(o); + } + + @Override + public int hashCode() { + return super.hashCode(); + } + } + + private static class LateRescoreQuery extends RescoreKnnVectorQuery { + final int rescoreK; + + private LateRescoreQuery( + String fieldName, + float[] floatTarget, + VectorSimilarityFunction vectorSimilarityFunction, + int k, + int rescoreK, + Query innerQuery + ) { + super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery); + this.rescoreK = rescoreK; + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + // Retrieve top rescoreK documents from the inner query + var topDocs = searcher.search(innerQuery, rescoreK); + vectorOperations = topDocs.totalHits.value(); + + // Retrieve top k documents from the top rescoreK query + var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); + var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + var rescoreQuery = new FunctionScoreQuery(topDocsQuery, valueSource); + var rescoreTopDocs = searcher.search(rescoreQuery.rewrite(searcher), k); + return new KnnScoreDocQuery(rescoreTopDocs.scoreDocs, searcher.getIndexReader()); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + var that = (RescoreKnnVectorQuery.LateRescoreQuery) o; + return super.equals(o) && that.rescoreK == rescoreK; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), rescoreK); + } + } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java index 5c3c02a58d4d1..3a07e37159233 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java @@ -992,16 +992,10 @@ private void doTestDefaultDenseVectorMappings(DocumentMapper mapper, XContentBui assertThat(((FieldMapper) update.getRoot().getMapper("mapsToFloatTooBig")).fieldType().typeName(), equalTo("float")); assertThat(((FieldMapper) update.getRoot().getMapper("mapsToInt8HnswDenseVector")).fieldType().typeName(), equalTo("dense_vector")); DenseVectorFieldMapper int8DVFieldMapper = ((DenseVectorFieldMapper) update.getRoot().getMapper("mapsToInt8HnswDenseVector")); - assertThat( - ((DenseVectorFieldMapper.DenseVectorIndexOptions) int8DVFieldMapper.fieldType().getIndexOptions()).getType().getName(), - equalTo("int8_hnsw") - ); + assertThat(int8DVFieldMapper.fieldType().getIndexOptions().getType().getName(), equalTo("int8_hnsw")); assertThat(((FieldMapper) update.getRoot().getMapper("mapsToBBQHnswDenseVector")).fieldType().typeName(), equalTo("dense_vector")); DenseVectorFieldMapper bbqDVFieldMapper = ((DenseVectorFieldMapper) update.getRoot().getMapper("mapsToBBQHnswDenseVector")); - assertThat( - ((DenseVectorFieldMapper.DenseVectorIndexOptions) bbqDVFieldMapper.fieldType().getIndexOptions()).getType().getName(), - equalTo("bbq_hnsw") - ); + assertThat(bbqDVFieldMapper.fieldType().getIndexOptions().getType().getName(), equalTo("bbq_hnsw")); } public void testDefaultDenseVectorMappingsObject() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 6c7964bbf773b..da0f1f4189c42 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -15,6 +15,7 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; @@ -238,7 +239,11 @@ public void testCreateNestedKnnQuery() { if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { query = rescoreKnnVectorQuery.innerQuery(); } - assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); + if (field.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(ToParentBlockJoinQuery.class)); + } else { + assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); + } } { DenseVectorFieldType field = new DenseVectorFieldType( @@ -269,7 +274,11 @@ public void testCreateNestedKnnQuery() { producer, randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); - assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); + if (field.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(ToParentBlockJoinQuery.class)); + } else { + assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); + } vectorData = new VectorData(floatQueryVector, null); query = field.createKnnQuery( @@ -282,7 +291,11 @@ public void testCreateNestedKnnQuery() { producer, randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); - assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); + if (field.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(ToParentBlockJoinQuery.class)); + } else { + assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); + } } } @@ -445,7 +458,11 @@ public void testCreateKnnQueryMaxDims() { if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { query = rescoreKnnVectorQuery.innerQuery(); } - assertThat(query, instanceOf(KnnFloatVectorQuery.class)); + if (fieldWith4096dims.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(DenseVectorQuery.Floats.class)); + } else { + assertThat(query, instanceOf(KnnFloatVectorQuery.class)); + } } { // byte type with 4096 dims @@ -475,7 +492,11 @@ public void testCreateKnnQueryMaxDims() { null, randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); - assertThat(query, instanceOf(KnnByteVectorQuery.class)); + if (fieldWith4096dims.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(DenseVectorQuery.Bytes.class)); + } else { + assertThat(query, instanceOf(KnnByteVectorQuery.class)); + } } } @@ -574,13 +595,21 @@ public void testRescoreOversampleUsedWithoutQuantization() { ); if (elementType == BYTE) { - ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery; - assertThat(esKnnQuery.getK(), is(100)); - assertThat(esKnnQuery.kParam(), is(10)); + if (nonQuantizedField.getIndexOptions().isFlat()) { + assertThat(knnQuery, instanceOf(DenseVectorQuery.Bytes.class)); + } else { + ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } } else { - ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; - assertThat(esKnnQuery.getK(), is(100)); - assertThat(esKnnQuery.kParam(), is(10)); + if (nonQuantizedField.getIndexOptions().isFlat()) { + assertThat(knnQuery, instanceOf(DenseVectorQuery.Floats.class)); + } else { + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } } } @@ -628,7 +657,11 @@ public void testRescoreOversampleQueryOverrides() { null, randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); - assertTrue(query instanceof ESKnnFloatVectorQuery); + if (fieldType.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(DenseVectorQuery.Floats.class)); + } else { + assertThat(query, instanceOf(ESKnnFloatVectorQuery.class)); + } // verify we can override a `0` to a positive number fieldType = new DenseVectorFieldType( diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 0e295fb02eaaa..a15372bc1e8ef 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -51,6 +51,7 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DEFAULT_OVERSAMPLE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -203,8 +204,14 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que } } switch (elementType()) { - case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery); - case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery); + case FLOAT -> assertThat( + query, + anyOf(instanceOf(ESKnnFloatVectorQuery.class), instanceOf(DenseVectorQuery.Floats.class), instanceOf(BooleanQuery.class)) + ); + case BYTE -> assertThat( + query, + anyOf(instanceOf(ESKnnByteVectorQuery.class), instanceOf(DenseVectorQuery.Bytes.class), instanceOf(BooleanQuery.class)) + ); } BooleanQuery.Builder builder = new BooleanQuery.Builder(); @@ -244,10 +251,34 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que expectedStrategy ); }; + + Query bruteForceVectorQueryBuilt = switch (elementType()) { + case BIT, BYTE -> { + if (filterQuery != null) { + yield new BooleanQuery.Builder().add( + new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD), + BooleanClause.Occur.SHOULD + ).add(filterQuery, BooleanClause.Occur.FILTER).build(); + } else { + yield new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD); + } + } + case FLOAT -> { + if (filterQuery != null) { + yield new BooleanQuery.Builder().add( + new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD), + BooleanClause.Occur.SHOULD + ).add(filterQuery, BooleanClause.Occur.FILTER).build(); + } else { + yield new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD); + } + } + }; + if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) { query = vectorSimilarityQuery.getInnerKnnQuery(); } - assertEquals(query, knnVectorQueryBuilt); + assertThat(query, anyOf(equalTo(knnVectorQueryBuilt), equalTo(bruteForceVectorQueryBuilt))); } public void testWrongDimension() { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 05b7bc9ef4f82..c2ba061ad38bc 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -18,7 +18,10 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionScoreQuery; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchAllDocsQuery; @@ -41,6 +44,8 @@ import java.io.IOException; import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.List; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -54,53 +59,57 @@ public void testRescoreDocs() throws Exception { int numDims = randomIntBetween(5, 100); int k = randomIntBetween(1, numDocs - 1); + var queryVector = randomVector(numDims); + List innerQueries = new ArrayList<>(); + innerQueries.add(new KnnFloatVectorQuery(FIELD_NAME, randomVector(numDims), (int) (k * randomFloatBetween(1.0f, 10.0f, true)))); + innerQueries.add( + new BooleanQuery.Builder().add(new DenseVectorQuery.Floats(queryVector, FIELD_NAME), BooleanClause.Occur.SHOULD) + .add(new FieldExistsQuery(FIELD_NAME), BooleanClause.Occur.FILTER) + .build() + ); + innerQueries.add(new MatchAllDocsQuery()); + try (Directory d = newDirectory()) { addRandomDocuments(numDocs, d, numDims); try (IndexReader reader = DirectoryReader.open(d)) { - // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query - // and thus we're rescoring the top k docs. - float[] queryVector = randomVector(numDims); - Query innerQuery; - if (randomBoolean()) { - innerQuery = new KnnFloatVectorQuery(FIELD_NAME, queryVector, (int) (k * randomFloatBetween(1.0f, 10.0f, true))); - } else { - innerQuery = new MatchAllDocsQuery(); - } - RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( - FIELD_NAME, - queryVector, - VectorSimilarityFunction.COSINE, - k, - innerQuery - ); - - IndexSearcher searcher = newSearcher(reader, true, false); - TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs); - assertThat(rescoredDocs.scoreDocs.length, equalTo(k)); - - // Get real scores - DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource( - FIELD_NAME, - queryVector, - VectorSimilarityFunction.COSINE - ); - FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource); - TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs); - - int i = 0; - ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs; - for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) { - // There are docs that won't be found in the rescored search, but every doc found must be in the same order - // and have the same score - while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) { - i++; - } - if (i >= realScoreDocs.length) { - fail("Rescored doc not found in real score docs"); + for (var innerQuery : innerQueries) { + RescoreKnnVectorQuery rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE, + k, + k, + innerQuery + ); + + IndexSearcher searcher = newSearcher(reader, true, false); + TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs); + assertThat(rescoredDocs.scoreDocs.length, equalTo(k)); + + // Get real scores + DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE + ); + FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource); + TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs); + + int i = 0; + ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs; + for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) { + // There are docs that won't be found in the rescored search, but every doc found must be in the same order + // and have the same score + while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) { + i++; + } + if (i >= realScoreDocs.length) { + fail("Rescored doc not found in real score docs"); + } + assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score)); } - assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score)); } } } @@ -124,11 +133,12 @@ public void testProfiling() throws Exception { } private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { - RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( + var rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery( FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, k, + k, innerQuery ); IndexSearcher searcher = newSearcher(reader, true, false); From 99fb2bd27de05148738f7a4145ce6d801cd8cb2e Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Fri, 27 Jun 2025 19:49:57 +0100 Subject: [PATCH 02/15] Update docs/changelog/130251.yaml --- docs/changelog/130251.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/130251.yaml diff --git a/docs/changelog/130251.yaml b/docs/changelog/130251.yaml new file mode 100644 index 0000000000000..2dd117bada717 --- /dev/null +++ b/docs/changelog/130251.yaml @@ -0,0 +1,5 @@ +pr: 130251 +summary: Speed up (filtered) KNN queries for flat vector fields +area: Vector Search +type: enhancement +issues: [] From 1ae65f994fdd60c251d9eec0e71e7e6ce808c6b8 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Fri, 27 Jun 2025 20:13:16 +0100 Subject: [PATCH 03/15] handle null index options --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 0350cc0c18b2a..bfc8efe15158a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2547,7 +2547,7 @@ private Query createKnnBitQuery( ) { elementType.checkDimensions(dims, queryVector.length); Query knnQuery; - if (indexOptions.isFlat()) { + if (indexOptions != null && indexOptions.isFlat()) { knnQuery = filter == null ? createExactKnnBitQuery(queryVector) : new BooleanQuery.Builder().add(createExactKnnBitQuery(queryVector), BooleanClause.Occur.SHOULD) @@ -2588,7 +2588,7 @@ private Query createKnnByteQuery( } Query knnQuery; - if (indexOptions.isFlat()) { + if (indexOptions != null && indexOptions.isFlat()) { knnQuery = filter == null ? createExactKnnByteQuery(queryVector) : new BooleanQuery.Builder().add(createExactKnnByteQuery(queryVector), BooleanClause.Occur.SHOULD) @@ -2654,7 +2654,7 @@ && isNotUnitVector(squaredMagnitude)) { numCands = Math.max(adjustedK, numCands); } Query knnQuery; - if (indexOptions.isFlat()) { + if (indexOptions != null && indexOptions.isFlat()) { knnQuery = filter == null ? createExactKnnFloatQuery(queryVector) : new BooleanQuery.Builder().add(createExactKnnFloatQuery(queryVector), BooleanClause.Occur.SHOULD) From 55c5cc7fa74e55d586e256f6c7e4338f870bd052 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Fri, 27 Jun 2025 22:13:24 +0100 Subject: [PATCH 04/15] handle nested fields --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index bfc8efe15158a..2bae8b257d419 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -36,6 +36,7 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BitUtil; @@ -2551,7 +2552,7 @@ private Query createKnnBitQuery( knnQuery = filter == null ? createExactKnnBitQuery(queryVector) : new BooleanQuery.Builder().add(createExactKnnBitQuery(queryVector), BooleanClause.Occur.SHOULD) - .add(filter, BooleanClause.Occur.FILTER) + .add(parentFilter != null ? new ToChildBlockJoinQuery(filter, parentFilter) : filter, BooleanClause.Occur.FILTER) .build(); if (parentFilter != null) { knnQuery = new ToParentBlockJoinQuery(knnQuery, parentFilter, ScoreMode.Max); @@ -2592,7 +2593,7 @@ private Query createKnnByteQuery( knnQuery = filter == null ? createExactKnnByteQuery(queryVector) : new BooleanQuery.Builder().add(createExactKnnByteQuery(queryVector), BooleanClause.Occur.SHOULD) - .add(filter, BooleanClause.Occur.FILTER) + .add(parentFilter != null ? new ToChildBlockJoinQuery(filter, parentFilter) : filter, BooleanClause.Occur.FILTER) .build(); if (parentFilter != null) { knnQuery = new ToParentBlockJoinQuery(knnQuery, parentFilter, ScoreMode.Max); @@ -2658,7 +2659,7 @@ && isNotUnitVector(squaredMagnitude)) { knnQuery = filter == null ? createExactKnnFloatQuery(queryVector) : new BooleanQuery.Builder().add(createExactKnnFloatQuery(queryVector), BooleanClause.Occur.SHOULD) - .add(filter, BooleanClause.Occur.FILTER) + .add(parentFilter != null ? new ToChildBlockJoinQuery(filter, parentFilter) : filter, BooleanClause.Occur.FILTER) .build(); if (parentFilter != null) { knnQuery = new ToParentBlockJoinQuery(knnQuery, parentFilter, ScoreMode.Max); From 379f7f949e72daab622573289f6dda7afb764348 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Fri, 27 Jun 2025 22:19:20 +0100 Subject: [PATCH 05/15] nested fields should return docs at the nested level --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 9 --------- 1 file changed, 9 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 2bae8b257d419..04100dfbcc353 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2554,9 +2554,6 @@ private Query createKnnBitQuery( : new BooleanQuery.Builder().add(createExactKnnBitQuery(queryVector), BooleanClause.Occur.SHOULD) .add(parentFilter != null ? new ToChildBlockJoinQuery(filter, parentFilter) : filter, BooleanClause.Occur.FILTER) .build(); - if (parentFilter != null) { - knnQuery = new ToParentBlockJoinQuery(knnQuery, parentFilter, ScoreMode.Max); - } } else { knnQuery = parentFilter != null ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) @@ -2595,9 +2592,6 @@ private Query createKnnByteQuery( : new BooleanQuery.Builder().add(createExactKnnByteQuery(queryVector), BooleanClause.Occur.SHOULD) .add(parentFilter != null ? new ToChildBlockJoinQuery(filter, parentFilter) : filter, BooleanClause.Occur.FILTER) .build(); - if (parentFilter != null) { - knnQuery = new ToParentBlockJoinQuery(knnQuery, parentFilter, ScoreMode.Max); - } } else { knnQuery = parentFilter != null ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) @@ -2661,9 +2655,6 @@ && isNotUnitVector(squaredMagnitude)) { : new BooleanQuery.Builder().add(createExactKnnFloatQuery(queryVector), BooleanClause.Occur.SHOULD) .add(parentFilter != null ? new ToChildBlockJoinQuery(filter, parentFilter) : filter, BooleanClause.Occur.FILTER) .build(); - if (parentFilter != null) { - knnQuery = new ToParentBlockJoinQuery(knnQuery, parentFilter, ScoreMode.Max); - } } else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) { knnQuery = parentFilter != null ? new DiversifyingChildrenIVFKnnFloatVectorQuery( From 61df4fa7a3813f467dfd7bbaafceeaa1423b6c98 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 27 Jun 2025 21:26:46 +0000 Subject: [PATCH 06/15] [CI] Auto commit changes from spotless --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 04100dfbcc353..127a875e410a6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -35,9 +35,7 @@ import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.search.join.ToChildBlockJoinQuery; -import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.BytesRef; From 3fdc47f6d54073e559b4a316334b363920239037 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Fri, 27 Jun 2025 23:43:44 +0100 Subject: [PATCH 07/15] First pass at diversifying results when the field is nested --- .../elasticsearch/test/knn/KnnSearcher.java | 2 +- .../vectors/DenseVectorFieldMapper.java | 5 +- .../search/vectors/RescoreKnnVectorQuery.java | 116 +++++++++++++++++- .../vectors/RescoreKnnVectorQueryTests.java | 6 +- 4 files changed, 117 insertions(+), 12 deletions(-) diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index c24d8cb7e94e2..62bc2fdaae4ef 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -309,7 +309,7 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException } if (overSamplingFactor > 1f) { // oversample the topK results to get more candidates for the final result - knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery); + knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery, null); } QueryProfiler profiler = new QueryProfiler(); TopDocs docs = searcher.search(knnQuery, this.topK); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 04100dfbcc353..a0a69016ad3d7 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -35,9 +35,7 @@ import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.search.join.ToChildBlockJoinQuery; -import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.BytesRef; @@ -2687,7 +2685,8 @@ && isNotUnitVector(squaredMagnitude)) { similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), k, adjustedK, - knnQuery + knnQuery, + parentFilter ); } if (similarityThreshold != null) { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index fa6240c56eda2..5b6e1809ceb53 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -9,6 +9,7 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionScoreQuery; import org.apache.lucene.search.BooleanClause; @@ -17,11 +18,23 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.grouping.GroupSelector; +import org.apache.lucene.search.grouping.GroupingSearch; +import org.apache.lucene.search.grouping.SearchGroup; +import org.apache.lucene.search.grouping.TopGroups; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.util.BitSet; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; import java.util.Arrays; +import java.util.Collection; import java.util.Objects; /** @@ -69,6 +82,9 @@ private RescoreKnnVectorQuery( * @param k the number of top documents to return after rescoring * @param rescoreK the number of top documents to consider for rescoring * @param innerQuery the original Lucene query to rescore + * @param parentFilterProducer A filter used to produce the BitSet identifying parent documents, + * required when the field is nested to determine the parent of each child. + * If {@code null}, the query is assumed to match parent documents directly. */ public static RescoreKnnVectorQuery fromInnerQuery( String fieldName, @@ -76,14 +92,16 @@ public static RescoreKnnVectorQuery fromInnerQuery( VectorSimilarityFunction vectorSimilarityFunction, int k, int rescoreK, - Query innerQuery + Query innerQuery, + @Nullable BitSetProducer parentFilterProducer ) { if ((innerQuery instanceof KnnFloatVectorQuery fQuery && fQuery.getK() == rescoreK) || (innerQuery instanceof KnnByteVectorQuery bQuery && bQuery.getK() == rescoreK) || (innerQuery instanceof AbstractIVFKnnVectorQuery ivfQuery && ivfQuery.k == rescoreK)) { + // We ignore the nested context here since the knn query already handles the parent diversification. return new InlineRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery); } - return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery); + return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery, parentFilterProducer); } public Query innerQuery() { @@ -180,6 +198,7 @@ public int hashCode() { private static class LateRescoreQuery extends RescoreKnnVectorQuery { final int rescoreK; + final BitSetProducer parentFilter; private LateRescoreQuery( String fieldName, @@ -187,19 +206,36 @@ private LateRescoreQuery( VectorSimilarityFunction vectorSimilarityFunction, int k, int rescoreK, - Query innerQuery + Query innerQuery, + @Nullable BitSetProducer parentFilter ) { super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery); this.rescoreK = rescoreK; + this.parentFilter = parentFilter; } + record DocAndParent(int parent, int doc) {} + @Override public Query rewrite(IndexSearcher searcher) throws IOException { - // Retrieve top rescoreK documents from the inner query - var topDocs = searcher.search(innerQuery, rescoreK); + final TopDocs topDocs; + if (parentFilter != null) { + // We're dealing with a nested field, so we need to search at the child level. + // We retrieve the top `rescoreK` child documents, but collapse them so that only + // the best child per parent is kept. + var groupSearch = new GroupingSearch(new ParentSelector(parentFilter)); + TopGroups topGroups = groupSearch.search(searcher, innerQuery, 0, rescoreK); + var scoreDocs = Arrays.stream(topGroups.groups) + .map(g -> new ScoreDoc(g.groupValue().doc, g.score())) + .toArray(ScoreDoc[]::new); + topDocs = new TopDocs(new TotalHits(topGroups.totalHitCount, TotalHits.Relation.EQUAL_TO), scoreDocs); + } else { + // Retrieve top `rescoreK` documents from the inner query + topDocs = searcher.search(innerQuery, rescoreK); + } vectorOperations = topDocs.totalHits.value(); - // Retrieve top k documents from the top rescoreK query + // Retrieve top `k` documents from the top `rescoreK` query var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); var rescoreQuery = new FunctionScoreQuery(topDocsQuery, valueSource); @@ -220,4 +256,72 @@ public int hashCode() { return Objects.hash(super.hashCode(), rescoreK); } } + + private static class ParentAndDoc { + int parent; + int doc; + + ParentAndDoc(int parent, int doc) { + this.parent = parent; + this.doc = doc; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ParentAndDoc that = (ParentAndDoc) o; + return parent == that.parent; + } + + @Override + public int hashCode() { + return Objects.hash(parent); + } + } + + private static class ParentSelector extends GroupSelector { + final BitSetProducer parentFilter; + BitSet parentBitSet; + int currentParent = -1; + ParentAndDoc current; + + ParentSelector(BitSetProducer parentFilter) { + this.parentFilter = parentFilter; + } + + @Override + public void setNextReader(LeafReaderContext readerContext) throws IOException { + parentBitSet = parentFilter.getBitSet(readerContext); + } + + @Override + public void setScorer(Scorable scorer) throws IOException {} + + @Override + public State advanceTo(int doc) throws IOException { + if (parentBitSet == null) { + return State.SKIP; + } + if (doc > currentParent) { + currentParent = parentBitSet.nextSetBit(doc); + current.parent = currentParent; + } + current.doc = doc; + return State.ACCEPT; + } + + @Override + public ParentAndDoc currentValue() throws IOException { + return current; + } + + @Override + public ParentAndDoc copyValue() throws IOException { + return new ParentAndDoc(current.parent, current.doc); + } + + @Override + public void setGroups(Collection> searchGroups) {} + } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index c2ba061ad38bc..24a2583e2263c 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -81,7 +81,8 @@ public void testRescoreDocs() throws Exception { VectorSimilarityFunction.COSINE, k, k, - innerQuery + innerQuery, + null ); IndexSearcher searcher = newSearcher(reader, true, false); @@ -139,7 +140,8 @@ private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader VectorSimilarityFunction.COSINE, k, k, - innerQuery + innerQuery, + null ); IndexSearcher searcher = newSearcher(reader, true, false); searcher.search(rescoreKnnVectorQuery, numDocs); From be25897649a2b92a9d56041cbf896b6675a5cd16 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Sat, 28 Jun 2025 00:05:34 +0100 Subject: [PATCH 08/15] collapse using a single pass --- .../search/vectors/RescoreKnnVectorQuery.java | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 5b6e1809ceb53..4215deb6163e8 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -20,12 +20,12 @@ import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.grouping.FirstPassGroupingCollector; import org.apache.lucene.search.grouping.GroupSelector; -import org.apache.lucene.search.grouping.GroupingSearch; import org.apache.lucene.search.grouping.SearchGroup; -import org.apache.lucene.search.grouping.TopGroups; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.util.BitSet; import org.elasticsearch.core.Nullable; @@ -217,18 +217,22 @@ private LateRescoreQuery( record DocAndParent(int parent, int doc) {} @Override + @SuppressWarnings("unchecked") public Query rewrite(IndexSearcher searcher) throws IOException { final TopDocs topDocs; if (parentFilter != null) { // We're dealing with a nested field, so we need to search at the child level. // We retrieve the top `rescoreK` child documents, but collapse them so that only // the best child per parent is kept. - var groupSearch = new GroupingSearch(new ParentSelector(parentFilter)); - TopGroups topGroups = groupSearch.search(searcher, innerQuery, 0, rescoreK); - var scoreDocs = Arrays.stream(topGroups.groups) - .map(g -> new ScoreDoc(g.groupValue().doc, g.score())) - .toArray(ScoreDoc[]::new); - topDocs = new TopDocs(new TotalHits(topGroups.totalHitCount, TotalHits.Relation.EQUAL_TO), scoreDocs); + FirstPassGroupingCollector groupingCollector = new FirstPassGroupingCollector<>( + new ParentSelector(parentFilter), + Sort.RELEVANCE, + rescoreK + ); + searcher.search(innerQuery, groupingCollector); + var groups = groupingCollector.getTopGroups(0); + var scoreDocs = groups.stream().map(g -> new ScoreDoc(g.groupValue.doc, (float) g.sortValues[0])).toArray(ScoreDoc[]::new); + topDocs = new TopDocs(new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO), scoreDocs); } else { // Retrieve top `rescoreK` documents from the inner query topDocs = searcher.search(innerQuery, rescoreK); From ed9daad09c1fc1bb9df19810878eb6a8f7447d83 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Sat, 28 Jun 2025 00:06:28 +0100 Subject: [PATCH 09/15] unused --- .../org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 4215deb6163e8..f0046683edcb5 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -214,8 +214,6 @@ private LateRescoreQuery( this.parentFilter = parentFilter; } - record DocAndParent(int parent, int doc) {} - @Override @SuppressWarnings("unchecked") public Query rewrite(IndexSearcher searcher) throws IOException { From 9a1b722dae845f6df828ccb8fc7ffaac2bc2ea95 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Sat, 28 Jun 2025 01:55:15 +0100 Subject: [PATCH 10/15] iter --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index a0a69016ad3d7..f2bad988e9f81 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -35,7 +35,6 @@ import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.BytesRef; @@ -2550,7 +2549,7 @@ private Query createKnnBitQuery( knnQuery = filter == null ? createExactKnnBitQuery(queryVector) : new BooleanQuery.Builder().add(createExactKnnBitQuery(queryVector), BooleanClause.Occur.SHOULD) - .add(parentFilter != null ? new ToChildBlockJoinQuery(filter, parentFilter) : filter, BooleanClause.Occur.FILTER) + .add(filter, BooleanClause.Occur.FILTER) .build(); } else { knnQuery = parentFilter != null @@ -2588,7 +2587,7 @@ private Query createKnnByteQuery( knnQuery = filter == null ? createExactKnnByteQuery(queryVector) : new BooleanQuery.Builder().add(createExactKnnByteQuery(queryVector), BooleanClause.Occur.SHOULD) - .add(parentFilter != null ? new ToChildBlockJoinQuery(filter, parentFilter) : filter, BooleanClause.Occur.FILTER) + .add(filter, BooleanClause.Occur.FILTER) .build(); } else { knnQuery = parentFilter != null @@ -2651,7 +2650,7 @@ && isNotUnitVector(squaredMagnitude)) { knnQuery = filter == null ? createExactKnnFloatQuery(queryVector) : new BooleanQuery.Builder().add(createExactKnnFloatQuery(queryVector), BooleanClause.Occur.SHOULD) - .add(parentFilter != null ? new ToChildBlockJoinQuery(filter, parentFilter) : filter, BooleanClause.Occur.FILTER) + .add(filter, BooleanClause.Occur.FILTER) .build(); } else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) { knnQuery = parentFilter != null From 2371ec333f968779bbc51828fe7b79965bcb8baa Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Sat, 28 Jun 2025 11:43:30 +0100 Subject: [PATCH 11/15] always apply diversification when nested and flat --- .../elasticsearch/test/knn/KnnSearcher.java | 2 +- .../vectors/DenseVectorFieldMapper.java | 19 +- .../vectors/DiversifyingParentBlockQuery.java | 195 ++++++++++++++++++ .../search/vectors/RescoreKnnVectorQuery.java | 115 +---------- .../vectors/RescoreKnnVectorQueryTests.java | 4 +- 5 files changed, 217 insertions(+), 118 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index 62bc2fdaae4ef..c24d8cb7e94e2 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -309,7 +309,7 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException } if (overSamplingFactor > 1f) { // oversample the topK results to get more candidates for the final result - knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery, null); + knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery); } QueryProfiler profiler = new QueryProfiler(); TopDocs docs = searcher.search(knnQuery, this.topK); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index f2bad988e9f81..d389b48c7bf77 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -79,6 +79,7 @@ import org.elasticsearch.search.lookup.Source; import org.elasticsearch.search.vectors.DenseVectorQuery; import org.elasticsearch.search.vectors.DiversifyingChildrenIVFKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery; import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery; import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; @@ -2546,9 +2547,12 @@ private Query createKnnBitQuery( elementType.checkDimensions(dims, queryVector.length); Query knnQuery; if (indexOptions != null && indexOptions.isFlat()) { + var exactKnnQuery = parentFilter != null + ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnBitQuery(queryVector)) + : createExactKnnBitQuery(queryVector); knnQuery = filter == null ? createExactKnnBitQuery(queryVector) - : new BooleanQuery.Builder().add(createExactKnnBitQuery(queryVector), BooleanClause.Occur.SHOULD) + : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD) .add(filter, BooleanClause.Occur.FILTER) .build(); } else { @@ -2584,9 +2588,12 @@ private Query createKnnByteQuery( Query knnQuery; if (indexOptions != null && indexOptions.isFlat()) { + var exactKnnQuery = parentFilter != null + ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnByteQuery(queryVector)) + : createExactKnnByteQuery(queryVector); knnQuery = filter == null ? createExactKnnByteQuery(queryVector) - : new BooleanQuery.Builder().add(createExactKnnByteQuery(queryVector), BooleanClause.Occur.SHOULD) + : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD) .add(filter, BooleanClause.Occur.FILTER) .build(); } else { @@ -2647,9 +2654,12 @@ && isNotUnitVector(squaredMagnitude)) { } Query knnQuery; if (indexOptions != null && indexOptions.isFlat()) { + var exactKnnQuery = parentFilter != null + ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnFloatQuery(queryVector)) + : createExactKnnFloatQuery(queryVector); knnQuery = filter == null ? createExactKnnFloatQuery(queryVector) - : new BooleanQuery.Builder().add(createExactKnnFloatQuery(queryVector), BooleanClause.Occur.SHOULD) + : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD) .add(filter, BooleanClause.Occur.FILTER) .build(); } else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) { @@ -2684,8 +2694,7 @@ && isNotUnitVector(squaredMagnitude)) { similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), k, adjustedK, - knnQuery, - parentFilter + knnQuery ); } if (similarityThreshold != null) { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java new file mode 100644 index 0000000000000..052914697a873 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java @@ -0,0 +1,195 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; + +import java.io.IOException; +import java.util.Objects; + +/** + * A Lucene query that selects the highest-scoring child document for each parent block. + *

+ * Children are scored using the {@code innerQuery}, and for each parent (as defined by the + * {@code parentFilter}), the single best-scoring child is returned. + */ +public class DiversifyingParentBlockQuery extends Query { + private final BitSetProducer parentFilter; + private final Query innerQuery; + + public DiversifyingParentBlockQuery(BitSetProducer parentFilter, Query innerQuery) { + this.parentFilter = Objects.requireNonNull(parentFilter); + this.innerQuery = Objects.requireNonNull(innerQuery); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query rewritten = innerQuery.rewrite(indexSearcher); + if (rewritten != innerQuery) { + return new DiversifyingParentBlockQuery(parentFilter, rewritten); + } + return this; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + Weight innerWeight = innerQuery.createWeight(searcher, scoreMode, boost); + return new DiversifyingParentBlockWeight(this, innerWeight, parentFilter); + } + + @Override + public String toString(String field) { + return "DiversifyingBlockQuery(inner=" + innerQuery.toString(field) + ")"; + } + + @Override + public void visit(QueryVisitor visitor) { + innerQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DiversifyingParentBlockQuery that = (DiversifyingParentBlockQuery) o; + return Objects.equals(innerQuery, that.innerQuery) && parentFilter == that.parentFilter; + } + + @Override + public int hashCode() { + return Objects.hash(innerQuery, parentFilter); + } + + private static class DiversifyingParentBlockWeight extends Weight { + private final Weight innerWeight; + private final BitSetProducer parentFilter; + + DiversifyingParentBlockWeight(Query query, Weight innerWeight, BitSetProducer parentFilter) { + super(query); + this.innerWeight = innerWeight; + this.parentFilter = parentFilter; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + return innerWeight.explain(context, doc); + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + var innerSupplier = innerWeight.scorerSupplier(context); + var parentBits = parentFilter.getBitSet(context); + if (parentBits == null || innerSupplier == null) { + return null; + } + + return new ScorerSupplier() { + @Override + public Scorer get(long leadCost) throws IOException { + var innerScorer = innerSupplier.get(leadCost); + var innerIterator = innerScorer.iterator(); + return new Scorer() { + int currentDoc = -1; + float currentScore = Float.NaN; + + @Override + public int docID() { + return currentDoc; + } + + @Override + public DocIdSetIterator iterator() { + return new DocIdSetIterator() { + boolean exhausted = false; + + @Override + public int docID() { + return currentDoc; + } + + @Override + public int nextDoc() throws IOException { + return advance(currentDoc + 1); + } + + @Override + public int advance(int target) throws IOException { + if (exhausted) { + return NO_MORE_DOCS; + } + if (currentDoc == -1 || innerIterator.docID() < target) { + if (innerIterator.advance(target) == NO_MORE_DOCS) { + exhausted = true; + return currentDoc = NO_MORE_DOCS; + } + } + + int bestChild = innerIterator.docID(); + float bestScore = innerScorer.score(); + int parent = parentBits.nextSetBit(bestChild); + + int innerDoc; + while ((innerDoc = innerIterator.nextDoc()) < parent) { + float score = innerScorer.score(); + if (score > bestScore) { + bestChild = innerIterator.docID(); + bestScore = score; + } + } + if (innerDoc == NO_MORE_DOCS) { + exhausted = true; + } + currentScore = bestScore; + return currentDoc = bestChild; + } + + @Override + public long cost() { + return innerIterator.cost(); + } + }; + } + + @Override + public float score() throws IOException { + return currentScore; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return innerScorer.getMaxScore(upTo); + } + }; + } + + @Override + public long cost() { + return innerSupplier.cost(); + } + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index f0046683edcb5..c7346bb9edd75 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -9,7 +9,6 @@ package org.elasticsearch.search.vectors; -import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionScoreQuery; import org.apache.lucene.search.BooleanClause; @@ -18,23 +17,12 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.Scorable; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; -import org.apache.lucene.search.grouping.FirstPassGroupingCollector; -import org.apache.lucene.search.grouping.GroupSelector; -import org.apache.lucene.search.grouping.SearchGroup; -import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.util.BitSet; -import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; import java.util.Arrays; -import java.util.Collection; import java.util.Objects; /** @@ -82,9 +70,6 @@ private RescoreKnnVectorQuery( * @param k the number of top documents to return after rescoring * @param rescoreK the number of top documents to consider for rescoring * @param innerQuery the original Lucene query to rescore - * @param parentFilterProducer A filter used to produce the BitSet identifying parent documents, - * required when the field is nested to determine the parent of each child. - * If {@code null}, the query is assumed to match parent documents directly. */ public static RescoreKnnVectorQuery fromInnerQuery( String fieldName, @@ -92,16 +77,15 @@ public static RescoreKnnVectorQuery fromInnerQuery( VectorSimilarityFunction vectorSimilarityFunction, int k, int rescoreK, - Query innerQuery, - @Nullable BitSetProducer parentFilterProducer + Query innerQuery ) { if ((innerQuery instanceof KnnFloatVectorQuery fQuery && fQuery.getK() == rescoreK) || (innerQuery instanceof KnnByteVectorQuery bQuery && bQuery.getK() == rescoreK) || (innerQuery instanceof AbstractIVFKnnVectorQuery ivfQuery && ivfQuery.k == rescoreK)) { - // We ignore the nested context here since the knn query already handles the parent diversification. + // Queries that return only the top `k` results and do not require reduction before re-scoring. return new InlineRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery); } - return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery, parentFilterProducer); + return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery); } public Query innerQuery() { @@ -198,7 +182,6 @@ public int hashCode() { private static class LateRescoreQuery extends RescoreKnnVectorQuery { final int rescoreK; - final BitSetProducer parentFilter; private LateRescoreQuery( String fieldName, @@ -206,35 +189,17 @@ private LateRescoreQuery( VectorSimilarityFunction vectorSimilarityFunction, int k, int rescoreK, - Query innerQuery, - @Nullable BitSetProducer parentFilter + Query innerQuery ) { super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery); this.rescoreK = rescoreK; - this.parentFilter = parentFilter; } @Override - @SuppressWarnings("unchecked") public Query rewrite(IndexSearcher searcher) throws IOException { final TopDocs topDocs; - if (parentFilter != null) { - // We're dealing with a nested field, so we need to search at the child level. - // We retrieve the top `rescoreK` child documents, but collapse them so that only - // the best child per parent is kept. - FirstPassGroupingCollector groupingCollector = new FirstPassGroupingCollector<>( - new ParentSelector(parentFilter), - Sort.RELEVANCE, - rescoreK - ); - searcher.search(innerQuery, groupingCollector); - var groups = groupingCollector.getTopGroups(0); - var scoreDocs = groups.stream().map(g -> new ScoreDoc(g.groupValue.doc, (float) g.sortValues[0])).toArray(ScoreDoc[]::new); - topDocs = new TopDocs(new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO), scoreDocs); - } else { - // Retrieve top `rescoreK` documents from the inner query - topDocs = searcher.search(innerQuery, rescoreK); - } + // Retrieve top `rescoreK` documents from the inner query + topDocs = searcher.search(innerQuery, rescoreK); vectorOperations = topDocs.totalHits.value(); // Retrieve top `k` documents from the top `rescoreK` query @@ -258,72 +223,4 @@ public int hashCode() { return Objects.hash(super.hashCode(), rescoreK); } } - - private static class ParentAndDoc { - int parent; - int doc; - - ParentAndDoc(int parent, int doc) { - this.parent = parent; - this.doc = doc; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ParentAndDoc that = (ParentAndDoc) o; - return parent == that.parent; - } - - @Override - public int hashCode() { - return Objects.hash(parent); - } - } - - private static class ParentSelector extends GroupSelector { - final BitSetProducer parentFilter; - BitSet parentBitSet; - int currentParent = -1; - ParentAndDoc current; - - ParentSelector(BitSetProducer parentFilter) { - this.parentFilter = parentFilter; - } - - @Override - public void setNextReader(LeafReaderContext readerContext) throws IOException { - parentBitSet = parentFilter.getBitSet(readerContext); - } - - @Override - public void setScorer(Scorable scorer) throws IOException {} - - @Override - public State advanceTo(int doc) throws IOException { - if (parentBitSet == null) { - return State.SKIP; - } - if (doc > currentParent) { - currentParent = parentBitSet.nextSetBit(doc); - current.parent = currentParent; - } - current.doc = doc; - return State.ACCEPT; - } - - @Override - public ParentAndDoc currentValue() throws IOException { - return current; - } - - @Override - public ParentAndDoc copyValue() throws IOException { - return new ParentAndDoc(current.parent, current.doc); - } - - @Override - public void setGroups(Collection> searchGroups) {} - } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 24a2583e2263c..8da81171160c1 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -81,8 +81,7 @@ public void testRescoreDocs() throws Exception { VectorSimilarityFunction.COSINE, k, k, - innerQuery, - null + innerQuery ); IndexSearcher searcher = newSearcher(reader, true, false); @@ -140,7 +139,6 @@ private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader VectorSimilarityFunction.COSINE, k, k, - innerQuery, null ); IndexSearcher searcher = newSearcher(reader, true, false); From d7dc4238729fe629ef3a3678a24a9bba25e43d19 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Sat, 28 Jun 2025 12:20:51 +0100 Subject: [PATCH 12/15] fix uts --- .../index/mapper/vectors/DenseVectorFieldTypeTests.java | 3 ++- .../search/vectors/RescoreKnnVectorQueryTests.java | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index da0f1f4189c42..6a053de71e8a2 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.vectors.DenseVectorQuery; +import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; @@ -240,7 +241,7 @@ public void testCreateNestedKnnQuery() { query = rescoreKnnVectorQuery.innerQuery(); } if (field.getIndexOptions().isFlat()) { - assertThat(query, instanceOf(ToParentBlockJoinQuery.class)); + assertThat(query, instanceOf(DiversifyingParentBlockQuery.class)); } else { assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 8da81171160c1..c2ba061ad38bc 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -139,7 +139,7 @@ private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader VectorSimilarityFunction.COSINE, k, k, - null + innerQuery ); IndexSearcher searcher = newSearcher(reader, true, false); searcher.search(rescoreKnnVectorQuery, numDocs); From e1abb0c548318070114d4cbe51cf19aa91f87da3 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Sat, 28 Jun 2025 17:44:37 +0100 Subject: [PATCH 13/15] fix ut --- .../index/mapper/vectors/DenseVectorFieldTypeTests.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 6a053de71e8a2..4a933efa9516a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -15,7 +15,6 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; -import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; @@ -276,7 +275,7 @@ public void testCreateNestedKnnQuery() { randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); if (field.getIndexOptions().isFlat()) { - assertThat(query, instanceOf(ToParentBlockJoinQuery.class)); + assertThat(query, instanceOf(DenseVectorQuery.class)); } else { assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } @@ -293,7 +292,7 @@ public void testCreateNestedKnnQuery() { randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); if (field.getIndexOptions().isFlat()) { - assertThat(query, instanceOf(ToParentBlockJoinQuery.class)); + assertThat(query, instanceOf(DenseVectorQuery.class)); } else { assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } From bda1c7329dc0b5b307f0c0b89c70c0e648bcd403 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Sun, 29 Jun 2025 22:45:28 +0100 Subject: [PATCH 14/15] add more tests --- .../vectors/DenseVectorFieldMapper.java | 6 +- .../vectors/DenseVectorFieldTypeTests.java | 1 + .../DiversifyingParentBlockQueryTests.java | 162 ++++++++++++++++++ 3 files changed, 166 insertions(+), 3 deletions(-) create mode 100644 server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index d389b48c7bf77..819d9608f2348 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2551,7 +2551,7 @@ private Query createKnnBitQuery( ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnBitQuery(queryVector)) : createExactKnnBitQuery(queryVector); knnQuery = filter == null - ? createExactKnnBitQuery(queryVector) + ? exactKnnQuery : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD) .add(filter, BooleanClause.Occur.FILTER) .build(); @@ -2592,7 +2592,7 @@ private Query createKnnByteQuery( ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnByteQuery(queryVector)) : createExactKnnByteQuery(queryVector); knnQuery = filter == null - ? createExactKnnByteQuery(queryVector) + ? exactKnnQuery : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD) .add(filter, BooleanClause.Occur.FILTER) .build(); @@ -2658,7 +2658,7 @@ && isNotUnitVector(squaredMagnitude)) { ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnFloatQuery(queryVector)) : createExactKnnFloatQuery(queryVector); knnQuery = filter == null - ? createExactKnnFloatQuery(queryVector) + ? exactKnnQuery : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD) .add(filter, BooleanClause.Occur.FILTER) .build(); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 4a933efa9516a..6bf1eb2ef0aa5 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.vectors.DenseVectorQuery; import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery; +import org.elasticsearch.search.vectors.DiversifyingParentBlockQueryTests; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; diff --git a/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java new file mode 100644 index 0000000000000..320b3efc4924c --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.search.join.ToParentBlockJoinQuery; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.index.mapper.MapperServiceTestCase; +import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.SourceToParse; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.TreeMap; + +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class DiversifyingParentBlockQueryTests extends MapperServiceTestCase { + private static String getMapping(int dim) { + return String.format(Locale.ROOT, """ + { + "_doc": { + "properties": { + "id": { + "type": "keyword", + "store": true + }, + "nested": { + "type": "nested", + "properties": { + "emb": { + "type": "dense_vector", + "dims": %d, + "similarity": "l2_norm", + "index_options": { + "type": "flat" + } + } + } + } + } + } + } + } + """, dim); + } + + public void testRandom() throws IOException { + int dims = randomIntBetween(3, 10); + var mapperService = createMapperService(getMapping(dims)); + var fieldType = (DenseVectorFieldMapper.DenseVectorFieldType) mapperService.fieldType("nested.emb"); + var nestedParent = mapperService.mappingLookup().nestedLookup().getNestedMappers().get("nested"); + + int numQueries = randomIntBetween(1, 3); + float[][] queries = new float[numQueries][]; + List> expectedTopDocs = new ArrayList<>(); + for (int i = 0; i < numQueries; i++) { + queries[i] = randomVector(dims); + expectedTopDocs.add(new TreeMap<>((o1, o2) -> -Float.compare(o1, o2))); + } + + withLuceneIndex(mapperService, iw -> { + int numDocs = randomIntBetween(10, 50); + for (int i = 0; i < numDocs; i++) { + int numVectors = randomIntBetween(0, 5); + float[][] vectors = new float[numVectors][]; + for (int j = 0; j < numVectors; j++) { + vectors[j] = randomVector(dims); + } + + for (int k = 0; k < numQueries; k++) { + float maxScore = Float.MIN_VALUE; + for (int j = 0; j < numVectors; j++) { + float score = EUCLIDEAN.compare(vectors[j], queries[k]); + maxScore = Math.max(score, maxScore); + } + expectedTopDocs.get(k).put(maxScore, Integer.toString(i)); + } + + SourceToParse source = randomSource(Integer.toString(i), vectors); + ParsedDocument doc = mapperService.documentMapper().parse(source); + iw.addDocuments(doc.docs()); + + if (randomBoolean()) { + int numEmpty = randomIntBetween(1, 3); + for (int l = 0; l < numEmpty; l++) { + source = randomSource(randomAlphaOfLengthBetween(5, 10), new float[0][]); + doc = mapperService.documentMapper().parse(source); + iw.addDocuments(doc.docs()); + } + } + } + }, ir -> { + var storedFields = ir.storedFields(); + var searcher = new IndexSearcher(wrapInMockESDirectoryReader(ir)); + var context = createSearchExecutionContext(mapperService); + var bitSetproducer = context.bitsetFilter(nestedParent.parentTypeFilter()); + for (int i = 0; i < numQueries; i++) { + var knnQuery = fieldType.createKnnQuery( + VectorData.fromFloats(queries[i]), + 10, + 10, + null, + null, + null, + bitSetproducer, + DenseVectorFieldMapper.FilterHeuristic.ACORN + ); + assertThat(knnQuery, instanceOf(DiversifyingParentBlockQuery.class)); + var nestedQuery = new ToParentBlockJoinQuery(knnQuery, bitSetproducer, ScoreMode.Total); + var topDocs = searcher.search(nestedQuery, 10); + for (var doc : topDocs.scoreDocs) { + var entry = expectedTopDocs.get(i).pollFirstEntry(); + assertNotNull(entry); + assertThat(doc.score, equalTo(entry.getKey())); + var storedDoc = storedFields.document(doc.doc, Set.of("id")); + assertThat(storedDoc.getField("id").binaryValue().utf8ToString(), equalTo(entry.getValue())); + } + } + }); + } + + private SourceToParse randomSource(String id, float[][] vectors) throws IOException { + try (var builder = XContentBuilder.builder(XContentType.JSON.xContent())) { + builder.startObject(); + builder.field("id", id); + builder.startArray("nested"); + for (int i = 0; i < vectors.length; i++) { + builder.startObject(); + builder.field("emb", vectors[i]); + builder.endObject(); + } + builder.endArray(); + builder.endObject(); + return new SourceToParse(id, BytesReference.bytes(builder), XContentType.JSON); + } + } + + private float[] randomVector(int dim) { + float[] vector = new float[dim]; + for (int i = 0; i < vector.length; i++) { + vector[i] = randomFloat(); + } + return vector; + } +} From 49d29ba1331230b2bf7f0b00ef6e29acc31f05f9 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Sun, 29 Jun 2025 21:52:44 +0000 Subject: [PATCH 15/15] [CI] Auto commit changes from spotless --- .../index/mapper/vectors/DenseVectorFieldTypeTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 6bf1eb2ef0aa5..4a933efa9516a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -26,7 +26,6 @@ import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.vectors.DenseVectorQuery; import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery; -import org.elasticsearch.search.vectors.DiversifyingParentBlockQueryTests; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; import org.elasticsearch.search.vectors.RescoreKnnVectorQuery;