diff --git a/docs/changelog/130251.yaml b/docs/changelog/130251.yaml new file mode 100644 index 0000000000000..2dd117bada717 --- /dev/null +++ b/docs/changelog/130251.yaml @@ -0,0 +1,5 @@ +pr: 130251 +summary: Speed up (filtered) KNN queries for flat vector fields +area: Vector Search +type: enhancement +issues: [] diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index 7dd6f2894a20a..c24d8cb7e94e2 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -309,7 +309,7 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException } if (overSamplingFactor > 1f) { // oversample the topK results to get more candidates for the final result - knnQuery = new RescoreKnnVectorQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, knnQuery); + knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery); } QueryProfiler profiler = new QueryProfiler(); TopDocs docs = searcher.search(knnQuery, this.topK); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 329e426be7f47..819d9608f2348 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -30,6 +30,8 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; @@ -77,6 +79,7 @@ import org.elasticsearch.search.lookup.Source; import org.elasticsearch.search.vectors.DenseVectorQuery; import org.elasticsearch.search.vectors.DiversifyingChildrenIVFKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery; import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery; import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; @@ -1391,6 +1394,18 @@ public final boolean equals(Object other) { public final int hashCode() { return Objects.hash(type, doHashCode()); } + + /** + * Indicates whether the underlying vector search is performed using a flat (exhaustive) approach. + *

+ * When {@code true}, it means the search does not use any approximate nearest neighbor (ANN) + * acceleration structures such as HNSW or IVF. Instead, it performs a brute-force comparison + * against all candidate vectors. This information can be used by higher-level components + * to decide whether additional acceleration or optimization is necessary. + * + * @return {@code true} if the vector search is flat (exhaustive), {@code false} if it uses ANN structures + */ + abstract boolean isFlat(); } abstract static class QuantizedIndexOptions extends DenseVectorIndexOptions { @@ -1762,6 +1777,11 @@ int doHashCode() { return Objects.hash(confidenceInterval, rescoreVector); } + @Override + boolean isFlat() { + return true; + } + @Override public boolean updatableTo(DenseVectorIndexOptions update) { return update.type.equals(this.type) @@ -1810,6 +1830,11 @@ public boolean doEquals(DenseVectorIndexOptions o) { public int doHashCode() { return Objects.hash(type); } + + @Override + boolean isFlat() { + return true; + } } public static class Int4HnswIndexOptions extends QuantizedIndexOptions { @@ -1860,6 +1885,11 @@ public int doHashCode() { return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); } + @Override + boolean isFlat() { + return false; + } + @Override public String toString() { return "{type=" @@ -1931,6 +1961,11 @@ public int doHashCode() { return Objects.hash(confidenceInterval, rescoreVector); } + @Override + boolean isFlat() { + return true; + } + @Override public String toString() { return "{type=" + type + ", confidence_interval=" + confidenceInterval + ", rescore_vector=" + rescoreVector + "}"; @@ -1999,6 +2034,11 @@ public int doHashCode() { return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); } + @Override + boolean isFlat() { + return false; + } + @Override public String toString() { return "{type=" @@ -2088,6 +2128,11 @@ public int doHashCode() { return Objects.hash(m, efConstruction); } + @Override + boolean isFlat() { + return false; + } + @Override public String toString() { return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}"; @@ -2126,6 +2171,11 @@ int doHashCode() { return Objects.hash(m, efConstruction, rescoreVector); } + @Override + boolean isFlat() { + return false; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -2179,6 +2229,11 @@ int doHashCode() { return CLASS_NAME_HASH; } + @Override + boolean isFlat() { + return true; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -2237,6 +2292,11 @@ int doHashCode() { return Objects.hash(clusterSize, defaultNProbe, rescoreVector); } + @Override + boolean isFlat() { + return false; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -2485,9 +2545,21 @@ private Query createKnnBitQuery( KnnSearchStrategy searchStrategy ) { elementType.checkDimensions(dims, queryVector.length); - Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + Query knnQuery; + if (indexOptions != null && indexOptions.isFlat()) { + var exactKnnQuery = parentFilter != null + ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnBitQuery(queryVector)) + : createExactKnnBitQuery(queryVector); + knnQuery = filter == null + ? exactKnnQuery + : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD) + .add(filter, BooleanClause.Occur.FILTER) + .build(); + } else { + knnQuery = parentFilter != null + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + } if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2513,9 +2585,22 @@ private Query createKnnByteQuery( float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } - Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + + Query knnQuery; + if (indexOptions != null && indexOptions.isFlat()) { + var exactKnnQuery = parentFilter != null + ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnByteQuery(queryVector)) + : createExactKnnByteQuery(queryVector); + knnQuery = filter == null + ? exactKnnQuery + : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD) + .add(filter, BooleanClause.Occur.FILTER) + .build(); + } else { + knnQuery = parentFilter != null + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); + } if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2568,7 +2653,16 @@ && isNotUnitVector(squaredMagnitude)) { numCands = Math.max(adjustedK, numCands); } Query knnQuery; - if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) { + if (indexOptions != null && indexOptions.isFlat()) { + var exactKnnQuery = parentFilter != null + ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnFloatQuery(queryVector)) + : createExactKnnFloatQuery(queryVector); + knnQuery = filter == null + ? exactKnnQuery + : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD) + .add(filter, BooleanClause.Occur.FILTER) + .build(); + } else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) { knnQuery = parentFilter != null ? new DiversifyingChildrenIVFKnnFloatVectorQuery( name(), @@ -2594,11 +2688,12 @@ && isNotUnitVector(squaredMagnitude)) { : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy); } if (rescore) { - knnQuery = new RescoreKnnVectorQuery( + knnQuery = RescoreKnnVectorQuery.fromInnerQuery( name(), queryVector, similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), k, + adjustedK, knnQuery ); } @@ -2624,7 +2719,7 @@ ElementType getElementType() { return elementType; } - public IndexOptions getIndexOptions() { + public DenseVectorIndexOptions getIndexOptions() { return indexOptions; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java index 31e19b6784757..dfc84f5eaef6d 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java @@ -207,4 +207,5 @@ public int docID() { return iterator.docID(); } } + } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java new file mode 100644 index 0000000000000..052914697a873 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java @@ -0,0 +1,195 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; + +import java.io.IOException; +import java.util.Objects; + +/** + * A Lucene query that selects the highest-scoring child document for each parent block. + *

+ * Children are scored using the {@code innerQuery}, and for each parent (as defined by the + * {@code parentFilter}), the single best-scoring child is returned. + */ +public class DiversifyingParentBlockQuery extends Query { + private final BitSetProducer parentFilter; + private final Query innerQuery; + + public DiversifyingParentBlockQuery(BitSetProducer parentFilter, Query innerQuery) { + this.parentFilter = Objects.requireNonNull(parentFilter); + this.innerQuery = Objects.requireNonNull(innerQuery); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query rewritten = innerQuery.rewrite(indexSearcher); + if (rewritten != innerQuery) { + return new DiversifyingParentBlockQuery(parentFilter, rewritten); + } + return this; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + Weight innerWeight = innerQuery.createWeight(searcher, scoreMode, boost); + return new DiversifyingParentBlockWeight(this, innerWeight, parentFilter); + } + + @Override + public String toString(String field) { + return "DiversifyingBlockQuery(inner=" + innerQuery.toString(field) + ")"; + } + + @Override + public void visit(QueryVisitor visitor) { + innerQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DiversifyingParentBlockQuery that = (DiversifyingParentBlockQuery) o; + return Objects.equals(innerQuery, that.innerQuery) && parentFilter == that.parentFilter; + } + + @Override + public int hashCode() { + return Objects.hash(innerQuery, parentFilter); + } + + private static class DiversifyingParentBlockWeight extends Weight { + private final Weight innerWeight; + private final BitSetProducer parentFilter; + + DiversifyingParentBlockWeight(Query query, Weight innerWeight, BitSetProducer parentFilter) { + super(query); + this.innerWeight = innerWeight; + this.parentFilter = parentFilter; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + return innerWeight.explain(context, doc); + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + var innerSupplier = innerWeight.scorerSupplier(context); + var parentBits = parentFilter.getBitSet(context); + if (parentBits == null || innerSupplier == null) { + return null; + } + + return new ScorerSupplier() { + @Override + public Scorer get(long leadCost) throws IOException { + var innerScorer = innerSupplier.get(leadCost); + var innerIterator = innerScorer.iterator(); + return new Scorer() { + int currentDoc = -1; + float currentScore = Float.NaN; + + @Override + public int docID() { + return currentDoc; + } + + @Override + public DocIdSetIterator iterator() { + return new DocIdSetIterator() { + boolean exhausted = false; + + @Override + public int docID() { + return currentDoc; + } + + @Override + public int nextDoc() throws IOException { + return advance(currentDoc + 1); + } + + @Override + public int advance(int target) throws IOException { + if (exhausted) { + return NO_MORE_DOCS; + } + if (currentDoc == -1 || innerIterator.docID() < target) { + if (innerIterator.advance(target) == NO_MORE_DOCS) { + exhausted = true; + return currentDoc = NO_MORE_DOCS; + } + } + + int bestChild = innerIterator.docID(); + float bestScore = innerScorer.score(); + int parent = parentBits.nextSetBit(bestChild); + + int innerDoc; + while ((innerDoc = innerIterator.nextDoc()) < parent) { + float score = innerScorer.score(); + if (score > bestScore) { + bestChild = innerIterator.docID(); + bestScore = score; + } + } + if (innerDoc == NO_MORE_DOCS) { + exhausted = true; + } + currentScore = bestScore; + return currentDoc = bestChild; + } + + @Override + public long cost() { + return innerIterator.cost(); + } + }; + } + + @Override + public float score() throws IOException { + return currentScore; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return innerScorer.getMaxScore(upTo); + } + }; + } + + @Override + public long cost() { + return innerSupplier.cost(); + } + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 99568a507ffb9..c7346bb9edd75 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -12,8 +12,9 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionScoreQuery; import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TopDocs; @@ -25,17 +26,28 @@ import java.util.Objects; /** - * Wraps an internal query to rescore the results using a similarity function over the original, non-quantized vectors of a vector field + * A Lucene {@link Query} that applies vector-based rescoring to an inner query's results. + *

+ * Depending on the nature of the {@code innerQuery}, this class dynamically selects between two rescoring strategies: + *

*/ -public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider { - private final String fieldName; - private final float[] floatTarget; - private final VectorSimilarityFunction vectorSimilarityFunction; - private final int k; - private final Query innerQuery; - private long vectorOperations = 0; - - public RescoreKnnVectorQuery( +public abstract class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider { + protected final String fieldName; + protected final float[] floatTarget; + protected final VectorSimilarityFunction vectorSimilarityFunction; + protected final int k; + protected final Query innerQuery; + protected long vectorOperations = 0; + + private RescoreKnnVectorQuery( String fieldName, float[] floatTarget, VectorSimilarityFunction vectorSimilarityFunction, @@ -49,16 +61,31 @@ public RescoreKnnVectorQuery( this.innerQuery = innerQuery; } - @Override - public Query rewrite(IndexSearcher searcher) throws IOException { - DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); - FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); - Query query = searcher.rewrite(functionScoreQuery); - - // Retrieve top k documents from the rescored query - TopDocs topDocs = searcher.search(query, k); - vectorOperations = topDocs.totalHits.value(); - return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); + /** + * Selects and returns the appropriate {@link RescoreKnnVectorQuery} strategy based on the nature of the {@code innerQuery}. + * + * @param fieldName the name of the field containing the vector + * @param floatTarget the target vector to compare against + * @param vectorSimilarityFunction the similarity function to apply + * @param k the number of top documents to return after rescoring + * @param rescoreK the number of top documents to consider for rescoring + * @param innerQuery the original Lucene query to rescore + */ + public static RescoreKnnVectorQuery fromInnerQuery( + String fieldName, + float[] floatTarget, + VectorSimilarityFunction vectorSimilarityFunction, + int k, + int rescoreK, + Query innerQuery + ) { + if ((innerQuery instanceof KnnFloatVectorQuery fQuery && fQuery.getK() == rescoreK) + || (innerQuery instanceof KnnByteVectorQuery bQuery && bQuery.getK() == rescoreK) + || (innerQuery instanceof AbstractIVFKnnVectorQuery ivfQuery && ivfQuery.k == rescoreK)) { + // Queries that return only the top `k` results and do not require reduction before re-scoring. + return new InlineRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery); + } + return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery); } public Query innerQuery() { @@ -102,7 +129,8 @@ public int hashCode() { @Override public String toString(String field) { - return "KnnRescoreVectorQuery{" + return getClass().getSimpleName() + + "{" + "fieldName='" + fieldName + '\'' @@ -117,4 +145,82 @@ public String toString(String field) { + innerQuery + '}'; } + + private static class InlineRescoreQuery extends RescoreKnnVectorQuery { + private InlineRescoreQuery( + String fieldName, + float[] floatTarget, + VectorSimilarityFunction vectorSimilarityFunction, + int k, + Query innerQuery + ) { + super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery); + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + var functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); + // Retrieve top k documents from the function score query + var topDocs = searcher.search(functionScoreQuery, k); + vectorOperations = topDocs.totalHits.value(); + return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return super.equals(o); + } + + @Override + public int hashCode() { + return super.hashCode(); + } + } + + private static class LateRescoreQuery extends RescoreKnnVectorQuery { + final int rescoreK; + + private LateRescoreQuery( + String fieldName, + float[] floatTarget, + VectorSimilarityFunction vectorSimilarityFunction, + int k, + int rescoreK, + Query innerQuery + ) { + super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery); + this.rescoreK = rescoreK; + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + final TopDocs topDocs; + // Retrieve top `rescoreK` documents from the inner query + topDocs = searcher.search(innerQuery, rescoreK); + vectorOperations = topDocs.totalHits.value(); + + // Retrieve top `k` documents from the top `rescoreK` query + var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader()); + var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + var rescoreQuery = new FunctionScoreQuery(topDocsQuery, valueSource); + var rescoreTopDocs = searcher.search(rescoreQuery.rewrite(searcher), k); + return new KnnScoreDocQuery(rescoreTopDocs.scoreDocs, searcher.getIndexReader()); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + var that = (RescoreKnnVectorQuery.LateRescoreQuery) o; + return super.equals(o) && that.rescoreK == rescoreK; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), rescoreK); + } + } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java index 5c3c02a58d4d1..3a07e37159233 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java @@ -992,16 +992,10 @@ private void doTestDefaultDenseVectorMappings(DocumentMapper mapper, XContentBui assertThat(((FieldMapper) update.getRoot().getMapper("mapsToFloatTooBig")).fieldType().typeName(), equalTo("float")); assertThat(((FieldMapper) update.getRoot().getMapper("mapsToInt8HnswDenseVector")).fieldType().typeName(), equalTo("dense_vector")); DenseVectorFieldMapper int8DVFieldMapper = ((DenseVectorFieldMapper) update.getRoot().getMapper("mapsToInt8HnswDenseVector")); - assertThat( - ((DenseVectorFieldMapper.DenseVectorIndexOptions) int8DVFieldMapper.fieldType().getIndexOptions()).getType().getName(), - equalTo("int8_hnsw") - ); + assertThat(int8DVFieldMapper.fieldType().getIndexOptions().getType().getName(), equalTo("int8_hnsw")); assertThat(((FieldMapper) update.getRoot().getMapper("mapsToBBQHnswDenseVector")).fieldType().typeName(), equalTo("dense_vector")); DenseVectorFieldMapper bbqDVFieldMapper = ((DenseVectorFieldMapper) update.getRoot().getMapper("mapsToBBQHnswDenseVector")); - assertThat( - ((DenseVectorFieldMapper.DenseVectorIndexOptions) bbqDVFieldMapper.fieldType().getIndexOptions()).getType().getName(), - equalTo("bbq_hnsw") - ); + assertThat(bbqDVFieldMapper.fieldType().getIndexOptions().getType().getName(), equalTo("bbq_hnsw")); } public void testDefaultDenseVectorMappingsObject() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 6c7964bbf773b..4a933efa9516a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.vectors.DenseVectorQuery; +import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; @@ -238,7 +239,11 @@ public void testCreateNestedKnnQuery() { if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { query = rescoreKnnVectorQuery.innerQuery(); } - assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); + if (field.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(DiversifyingParentBlockQuery.class)); + } else { + assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); + } } { DenseVectorFieldType field = new DenseVectorFieldType( @@ -269,7 +274,11 @@ public void testCreateNestedKnnQuery() { producer, randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); - assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); + if (field.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(DenseVectorQuery.class)); + } else { + assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); + } vectorData = new VectorData(floatQueryVector, null); query = field.createKnnQuery( @@ -282,7 +291,11 @@ public void testCreateNestedKnnQuery() { producer, randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); - assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); + if (field.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(DenseVectorQuery.class)); + } else { + assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); + } } } @@ -445,7 +458,11 @@ public void testCreateKnnQueryMaxDims() { if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) { query = rescoreKnnVectorQuery.innerQuery(); } - assertThat(query, instanceOf(KnnFloatVectorQuery.class)); + if (fieldWith4096dims.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(DenseVectorQuery.Floats.class)); + } else { + assertThat(query, instanceOf(KnnFloatVectorQuery.class)); + } } { // byte type with 4096 dims @@ -475,7 +492,11 @@ public void testCreateKnnQueryMaxDims() { null, randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); - assertThat(query, instanceOf(KnnByteVectorQuery.class)); + if (fieldWith4096dims.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(DenseVectorQuery.Bytes.class)); + } else { + assertThat(query, instanceOf(KnnByteVectorQuery.class)); + } } } @@ -574,13 +595,21 @@ public void testRescoreOversampleUsedWithoutQuantization() { ); if (elementType == BYTE) { - ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery; - assertThat(esKnnQuery.getK(), is(100)); - assertThat(esKnnQuery.kParam(), is(10)); + if (nonQuantizedField.getIndexOptions().isFlat()) { + assertThat(knnQuery, instanceOf(DenseVectorQuery.Bytes.class)); + } else { + ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } } else { - ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; - assertThat(esKnnQuery.getK(), is(100)); - assertThat(esKnnQuery.kParam(), is(10)); + if (nonQuantizedField.getIndexOptions().isFlat()) { + assertThat(knnQuery, instanceOf(DenseVectorQuery.Floats.class)); + } else { + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } } } @@ -628,7 +657,11 @@ public void testRescoreOversampleQueryOverrides() { null, randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) ); - assertTrue(query instanceof ESKnnFloatVectorQuery); + if (fieldType.getIndexOptions().isFlat()) { + assertThat(query, instanceOf(DenseVectorQuery.Floats.class)); + } else { + assertThat(query, instanceOf(ESKnnFloatVectorQuery.class)); + } // verify we can override a `0` to a positive number fieldType = new DenseVectorFieldType( diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 0e295fb02eaaa..a15372bc1e8ef 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -51,6 +51,7 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DEFAULT_OVERSAMPLE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -203,8 +204,14 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que } } switch (elementType()) { - case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery); - case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery); + case FLOAT -> assertThat( + query, + anyOf(instanceOf(ESKnnFloatVectorQuery.class), instanceOf(DenseVectorQuery.Floats.class), instanceOf(BooleanQuery.class)) + ); + case BYTE -> assertThat( + query, + anyOf(instanceOf(ESKnnByteVectorQuery.class), instanceOf(DenseVectorQuery.Bytes.class), instanceOf(BooleanQuery.class)) + ); } BooleanQuery.Builder builder = new BooleanQuery.Builder(); @@ -244,10 +251,34 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que expectedStrategy ); }; + + Query bruteForceVectorQueryBuilt = switch (elementType()) { + case BIT, BYTE -> { + if (filterQuery != null) { + yield new BooleanQuery.Builder().add( + new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD), + BooleanClause.Occur.SHOULD + ).add(filterQuery, BooleanClause.Occur.FILTER).build(); + } else { + yield new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD); + } + } + case FLOAT -> { + if (filterQuery != null) { + yield new BooleanQuery.Builder().add( + new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD), + BooleanClause.Occur.SHOULD + ).add(filterQuery, BooleanClause.Occur.FILTER).build(); + } else { + yield new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD); + } + } + }; + if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) { query = vectorSimilarityQuery.getInnerKnnQuery(); } - assertEquals(query, knnVectorQueryBuilt); + assertThat(query, anyOf(equalTo(knnVectorQueryBuilt), equalTo(bruteForceVectorQueryBuilt))); } public void testWrongDimension() { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java new file mode 100644 index 0000000000000..320b3efc4924c --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.search.join.ToParentBlockJoinQuery; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.index.mapper.MapperServiceTestCase; +import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.SourceToParse; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.TreeMap; + +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class DiversifyingParentBlockQueryTests extends MapperServiceTestCase { + private static String getMapping(int dim) { + return String.format(Locale.ROOT, """ + { + "_doc": { + "properties": { + "id": { + "type": "keyword", + "store": true + }, + "nested": { + "type": "nested", + "properties": { + "emb": { + "type": "dense_vector", + "dims": %d, + "similarity": "l2_norm", + "index_options": { + "type": "flat" + } + } + } + } + } + } + } + } + """, dim); + } + + public void testRandom() throws IOException { + int dims = randomIntBetween(3, 10); + var mapperService = createMapperService(getMapping(dims)); + var fieldType = (DenseVectorFieldMapper.DenseVectorFieldType) mapperService.fieldType("nested.emb"); + var nestedParent = mapperService.mappingLookup().nestedLookup().getNestedMappers().get("nested"); + + int numQueries = randomIntBetween(1, 3); + float[][] queries = new float[numQueries][]; + List> expectedTopDocs = new ArrayList<>(); + for (int i = 0; i < numQueries; i++) { + queries[i] = randomVector(dims); + expectedTopDocs.add(new TreeMap<>((o1, o2) -> -Float.compare(o1, o2))); + } + + withLuceneIndex(mapperService, iw -> { + int numDocs = randomIntBetween(10, 50); + for (int i = 0; i < numDocs; i++) { + int numVectors = randomIntBetween(0, 5); + float[][] vectors = new float[numVectors][]; + for (int j = 0; j < numVectors; j++) { + vectors[j] = randomVector(dims); + } + + for (int k = 0; k < numQueries; k++) { + float maxScore = Float.MIN_VALUE; + for (int j = 0; j < numVectors; j++) { + float score = EUCLIDEAN.compare(vectors[j], queries[k]); + maxScore = Math.max(score, maxScore); + } + expectedTopDocs.get(k).put(maxScore, Integer.toString(i)); + } + + SourceToParse source = randomSource(Integer.toString(i), vectors); + ParsedDocument doc = mapperService.documentMapper().parse(source); + iw.addDocuments(doc.docs()); + + if (randomBoolean()) { + int numEmpty = randomIntBetween(1, 3); + for (int l = 0; l < numEmpty; l++) { + source = randomSource(randomAlphaOfLengthBetween(5, 10), new float[0][]); + doc = mapperService.documentMapper().parse(source); + iw.addDocuments(doc.docs()); + } + } + } + }, ir -> { + var storedFields = ir.storedFields(); + var searcher = new IndexSearcher(wrapInMockESDirectoryReader(ir)); + var context = createSearchExecutionContext(mapperService); + var bitSetproducer = context.bitsetFilter(nestedParent.parentTypeFilter()); + for (int i = 0; i < numQueries; i++) { + var knnQuery = fieldType.createKnnQuery( + VectorData.fromFloats(queries[i]), + 10, + 10, + null, + null, + null, + bitSetproducer, + DenseVectorFieldMapper.FilterHeuristic.ACORN + ); + assertThat(knnQuery, instanceOf(DiversifyingParentBlockQuery.class)); + var nestedQuery = new ToParentBlockJoinQuery(knnQuery, bitSetproducer, ScoreMode.Total); + var topDocs = searcher.search(nestedQuery, 10); + for (var doc : topDocs.scoreDocs) { + var entry = expectedTopDocs.get(i).pollFirstEntry(); + assertNotNull(entry); + assertThat(doc.score, equalTo(entry.getKey())); + var storedDoc = storedFields.document(doc.doc, Set.of("id")); + assertThat(storedDoc.getField("id").binaryValue().utf8ToString(), equalTo(entry.getValue())); + } + } + }); + } + + private SourceToParse randomSource(String id, float[][] vectors) throws IOException { + try (var builder = XContentBuilder.builder(XContentType.JSON.xContent())) { + builder.startObject(); + builder.field("id", id); + builder.startArray("nested"); + for (int i = 0; i < vectors.length; i++) { + builder.startObject(); + builder.field("emb", vectors[i]); + builder.endObject(); + } + builder.endArray(); + builder.endObject(); + return new SourceToParse(id, BytesReference.bytes(builder), XContentType.JSON); + } + } + + private float[] randomVector(int dim) { + float[] vector = new float[dim]; + for (int i = 0; i < vector.length; i++) { + vector[i] = randomFloat(); + } + return vector; + } +} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 05b7bc9ef4f82..c2ba061ad38bc 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -18,7 +18,10 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionScoreQuery; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchAllDocsQuery; @@ -41,6 +44,8 @@ import java.io.IOException; import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.List; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -54,53 +59,57 @@ public void testRescoreDocs() throws Exception { int numDims = randomIntBetween(5, 100); int k = randomIntBetween(1, numDocs - 1); + var queryVector = randomVector(numDims); + List innerQueries = new ArrayList<>(); + innerQueries.add(new KnnFloatVectorQuery(FIELD_NAME, randomVector(numDims), (int) (k * randomFloatBetween(1.0f, 10.0f, true)))); + innerQueries.add( + new BooleanQuery.Builder().add(new DenseVectorQuery.Floats(queryVector, FIELD_NAME), BooleanClause.Occur.SHOULD) + .add(new FieldExistsQuery(FIELD_NAME), BooleanClause.Occur.FILTER) + .build() + ); + innerQueries.add(new MatchAllDocsQuery()); + try (Directory d = newDirectory()) { addRandomDocuments(numDocs, d, numDims); try (IndexReader reader = DirectoryReader.open(d)) { - // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query - // and thus we're rescoring the top k docs. - float[] queryVector = randomVector(numDims); - Query innerQuery; - if (randomBoolean()) { - innerQuery = new KnnFloatVectorQuery(FIELD_NAME, queryVector, (int) (k * randomFloatBetween(1.0f, 10.0f, true))); - } else { - innerQuery = new MatchAllDocsQuery(); - } - RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( - FIELD_NAME, - queryVector, - VectorSimilarityFunction.COSINE, - k, - innerQuery - ); - - IndexSearcher searcher = newSearcher(reader, true, false); - TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs); - assertThat(rescoredDocs.scoreDocs.length, equalTo(k)); - - // Get real scores - DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource( - FIELD_NAME, - queryVector, - VectorSimilarityFunction.COSINE - ); - FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource); - TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs); - - int i = 0; - ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs; - for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) { - // There are docs that won't be found in the rescored search, but every doc found must be in the same order - // and have the same score - while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) { - i++; - } - if (i >= realScoreDocs.length) { - fail("Rescored doc not found in real score docs"); + for (var innerQuery : innerQueries) { + RescoreKnnVectorQuery rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE, + k, + k, + innerQuery + ); + + IndexSearcher searcher = newSearcher(reader, true, false); + TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs); + assertThat(rescoredDocs.scoreDocs.length, equalTo(k)); + + // Get real scores + DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE + ); + FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource); + TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs); + + int i = 0; + ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs; + for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) { + // There are docs that won't be found in the rescored search, but every doc found must be in the same order + // and have the same score + while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) { + i++; + } + if (i >= realScoreDocs.length) { + fail("Rescored doc not found in real score docs"); + } + assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score)); } - assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score)); } } } @@ -124,11 +133,12 @@ public void testProfiling() throws Exception { } private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { - RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( + var rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery( FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, k, + k, innerQuery ); IndexSearcher searcher = newSearcher(reader, true, false);