diff --git a/docs/changelog/130251.yaml b/docs/changelog/130251.yaml
new file mode 100644
index 0000000000000..2dd117bada717
--- /dev/null
+++ b/docs/changelog/130251.yaml
@@ -0,0 +1,5 @@
+pr: 130251
+summary: Speed up (filtered) KNN queries for flat vector fields
+area: Vector Search
+type: enhancement
+issues: []
diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
index 7dd6f2894a20a..c24d8cb7e94e2 100644
--- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
+++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
@@ -309,7 +309,7 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException
}
if (overSamplingFactor > 1f) {
// oversample the topK results to get more candidates for the final result
- knnQuery = new RescoreKnnVectorQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, knnQuery);
+ knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery);
}
QueryProfiler profiler = new QueryProfiler();
TopDocs docs = searcher.search(knnQuery, this.topK);
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
index 329e426be7f47..819d9608f2348 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
@@ -30,6 +30,8 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
@@ -77,6 +79,7 @@
import org.elasticsearch.search.lookup.Source;
import org.elasticsearch.search.vectors.DenseVectorQuery;
import org.elasticsearch.search.vectors.DiversifyingChildrenIVFKnnFloatVectorQuery;
+import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery;
import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery;
import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
@@ -1391,6 +1394,18 @@ public final boolean equals(Object other) {
public final int hashCode() {
return Objects.hash(type, doHashCode());
}
+
+ /**
+ * Indicates whether the underlying vector search is performed using a flat (exhaustive) approach.
+ * <p>
+ * When {@code true}, it means the search does not use any approximate nearest neighbor (ANN)
+ * acceleration structures such as HNSW or IVF. Instead, it performs a brute-force comparison
+ * against all candidate vectors. This information can be used by higher-level components
+ * to decide whether additional acceleration or optimization is necessary.
+ *
+ * @return {@code true} if the vector search is flat (exhaustive), {@code false} if it uses ANN structures
+ */
+ abstract boolean isFlat();
}
abstract static class QuantizedIndexOptions extends DenseVectorIndexOptions {
@@ -1762,6 +1777,11 @@ int doHashCode() {
return Objects.hash(confidenceInterval, rescoreVector);
}
+ @Override
+ boolean isFlat() {
+ return true;
+ }
+
@Override
public boolean updatableTo(DenseVectorIndexOptions update) {
return update.type.equals(this.type)
@@ -1810,6 +1830,11 @@ public boolean doEquals(DenseVectorIndexOptions o) {
public int doHashCode() {
return Objects.hash(type);
}
+
+ @Override
+ boolean isFlat() {
+ return true;
+ }
}
public static class Int4HnswIndexOptions extends QuantizedIndexOptions {
@@ -1860,6 +1885,11 @@ public int doHashCode() {
return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector);
}
+ @Override
+ boolean isFlat() {
+ return false;
+ }
+
@Override
public String toString() {
return "{type="
@@ -1931,6 +1961,11 @@ public int doHashCode() {
return Objects.hash(confidenceInterval, rescoreVector);
}
+ @Override
+ boolean isFlat() {
+ return true;
+ }
+
@Override
public String toString() {
return "{type=" + type + ", confidence_interval=" + confidenceInterval + ", rescore_vector=" + rescoreVector + "}";
@@ -1999,6 +2034,11 @@ public int doHashCode() {
return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector);
}
+ @Override
+ boolean isFlat() {
+ return false;
+ }
+
@Override
public String toString() {
return "{type="
@@ -2088,6 +2128,11 @@ public int doHashCode() {
return Objects.hash(m, efConstruction);
}
+ @Override
+ boolean isFlat() {
+ return false;
+ }
+
@Override
public String toString() {
return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}";
@@ -2126,6 +2171,11 @@ int doHashCode() {
return Objects.hash(m, efConstruction, rescoreVector);
}
+ @Override
+ boolean isFlat() {
+ return false;
+ }
+
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@@ -2179,6 +2229,11 @@ int doHashCode() {
return CLASS_NAME_HASH;
}
+ @Override
+ boolean isFlat() {
+ return true;
+ }
+
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@@ -2237,6 +2292,11 @@ int doHashCode() {
return Objects.hash(clusterSize, defaultNProbe, rescoreVector);
}
+ @Override
+ boolean isFlat() {
+ return false;
+ }
+
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@@ -2485,9 +2545,21 @@ private Query createKnnBitQuery(
KnnSearchStrategy searchStrategy
) {
elementType.checkDimensions(dims, queryVector.length);
- Query knnQuery = parentFilter != null
- ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
- : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
+ Query knnQuery;
+ if (indexOptions != null && indexOptions.isFlat()) {
+ var exactKnnQuery = parentFilter != null
+ ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnBitQuery(queryVector))
+ : createExactKnnBitQuery(queryVector);
+ knnQuery = filter == null
+ ? exactKnnQuery
+ : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
+ .add(filter, BooleanClause.Occur.FILTER)
+ .build();
+ } else {
+ knnQuery = parentFilter != null
+ ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
+ : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
+ }
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
@@ -2513,9 +2585,22 @@ private Query createKnnByteQuery(
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
- Query knnQuery = parentFilter != null
- ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
- : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
+
+ Query knnQuery;
+ if (indexOptions != null && indexOptions.isFlat()) {
+ var exactKnnQuery = parentFilter != null
+ ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnByteQuery(queryVector))
+ : createExactKnnByteQuery(queryVector);
+ knnQuery = filter == null
+ ? exactKnnQuery
+ : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
+ .add(filter, BooleanClause.Occur.FILTER)
+ .build();
+ } else {
+ knnQuery = parentFilter != null
+ ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
+ : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
+ }
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
@@ -2568,7 +2653,16 @@ && isNotUnitVector(squaredMagnitude)) {
numCands = Math.max(adjustedK, numCands);
}
Query knnQuery;
- if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) {
+ if (indexOptions != null && indexOptions.isFlat()) {
+ var exactKnnQuery = parentFilter != null
+ ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnFloatQuery(queryVector))
+ : createExactKnnFloatQuery(queryVector);
+ knnQuery = filter == null
+ ? exactKnnQuery
+ : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
+ .add(filter, BooleanClause.Occur.FILTER)
+ .build();
+ } else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) {
knnQuery = parentFilter != null
? new DiversifyingChildrenIVFKnnFloatVectorQuery(
name(),
@@ -2594,11 +2688,12 @@ && isNotUnitVector(squaredMagnitude)) {
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
}
if (rescore) {
- knnQuery = new RescoreKnnVectorQuery(
+ knnQuery = RescoreKnnVectorQuery.fromInnerQuery(
name(),
queryVector,
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),
k,
+ adjustedK,
knnQuery
);
}
@@ -2624,7 +2719,7 @@ ElementType getElementType() {
return elementType;
}
- public IndexOptions getIndexOptions() {
+ public DenseVectorIndexOptions getIndexOptions() {
return indexOptions;
}
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java
index 31e19b6784757..dfc84f5eaef6d 100644
--- a/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java
+++ b/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java
@@ -207,4 +207,5 @@ public int docID() {
return iterator.docID();
}
}
+
}
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java
new file mode 100644
index 0000000000000..052914697a873
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java
@@ -0,0 +1,195 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.ScorerSupplier;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.search.join.BitSetProducer;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * A Lucene query that selects the highest-scoring child document for each parent block.
+ * <p>
+ * Children are scored using the {@code innerQuery}, and for each parent (as defined by the
+ * {@code parentFilter}), the single best-scoring child is returned.
+ */
+public class DiversifyingParentBlockQuery extends Query {
+ private final BitSetProducer parentFilter;
+ private final Query innerQuery;
+
+ public DiversifyingParentBlockQuery(BitSetProducer parentFilter, Query innerQuery) {
+ this.parentFilter = Objects.requireNonNull(parentFilter);
+ this.innerQuery = Objects.requireNonNull(innerQuery);
+ }
+
+    @Override
+    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+        final Query rewrittenInner = innerQuery.rewrite(indexSearcher);
+        if (rewrittenInner == innerQuery) {
+            return this; // the inner query came back unchanged, so there is nothing to re-wrap
+        }
+        return new DiversifyingParentBlockQuery(parentFilter, rewrittenInner);
+    }
+
+ @Override
+ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
+ Weight innerWeight = innerQuery.createWeight(searcher, scoreMode, boost);
+ return new DiversifyingParentBlockWeight(this, innerWeight, parentFilter);
+ }
+
+ @Override
+ public String toString(String field) {
+ return "DiversifyingBlockQuery(inner=" + innerQuery.toString(field) + ")";
+ }
+
+ @Override
+ public void visit(QueryVisitor visitor) {
+ innerQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ DiversifyingParentBlockQuery that = (DiversifyingParentBlockQuery) o;
+ return Objects.equals(innerQuery, that.innerQuery) && parentFilter == that.parentFilter;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(innerQuery, parentFilter);
+ }
+
+ private static class DiversifyingParentBlockWeight extends Weight {
+ private final Weight innerWeight;
+ private final BitSetProducer parentFilter;
+
+ DiversifyingParentBlockWeight(Query query, Weight innerWeight, BitSetProducer parentFilter) {
+ super(query);
+ this.innerWeight = innerWeight;
+ this.parentFilter = parentFilter;
+ }
+
+ @Override
+ public Explanation explain(LeafReaderContext context, int doc) throws IOException {
+ return innerWeight.explain(context, doc);
+ }
+
+ @Override
+ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
+ var innerSupplier = innerWeight.scorerSupplier(context);
+ var parentBits = parentFilter.getBitSet(context);
+ if (parentBits == null || innerSupplier == null) {
+ return null;
+ }
+
+ return new ScorerSupplier() {
+ @Override
+ public Scorer get(long leadCost) throws IOException {
+ var innerScorer = innerSupplier.get(leadCost);
+ var innerIterator = innerScorer.iterator();
+ return new Scorer() {
+ int currentDoc = -1;
+ float currentScore = Float.NaN;
+
+ @Override
+ public int docID() {
+ return currentDoc;
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return new DocIdSetIterator() {
+ boolean exhausted = false;
+
+ @Override
+ public int docID() {
+ return currentDoc;
+ }
+
+ @Override
+ public int nextDoc() throws IOException {
+ return advance(currentDoc + 1);
+ }
+
+                                @Override
+                                public int advance(int target) throws IOException {
+                                    if (exhausted) {
+                                        // DocIdSetIterator contract: docID() must also report NO_MORE_DOCS once exhausted.
+                                        return currentDoc = NO_MORE_DOCS;
+                                    }
+                                    // The previous call's look-ahead may have left the inner iterator at or past target.
+                                    boolean behindTarget = currentDoc == -1 || innerIterator.docID() < target;
+                                    if (behindTarget && innerIterator.advance(target) == NO_MORE_DOCS) {
+                                        exhausted = true;
+                                        return currentDoc = NO_MORE_DOCS;
+                                    }
+                                    int bestChild = innerIterator.docID();
+                                    float bestScore = innerScorer.score();
+                                    int parent = parentBits.nextSetBit(bestChild); // first parent bit at or after this child
+                                    // Scan the remaining matches before that parent, keeping the highest-scoring one.
+                                    int innerDoc;
+                                    while ((innerDoc = innerIterator.nextDoc()) < parent) {
+                                        float score = innerScorer.score();
+                                        if (score > bestScore) { // strict: on ties the earliest child wins
+                                            bestChild = innerDoc;
+                                            bestScore = score;
+                                        }
+                                    }
+                                    if (innerDoc == NO_MORE_DOCS) {
+                                        exhausted = true; // the current best is still returned; the next call ends iteration
+                                    }
+                                    currentScore = bestScore;
+                                    return currentDoc = bestChild;
+                                }
+
+ @Override
+ public long cost() {
+ return innerIterator.cost();
+ }
+ };
+ }
+
+ @Override
+ public float score() throws IOException {
+ return currentScore;
+ }
+
+ @Override
+ public float getMaxScore(int upTo) throws IOException {
+ return innerScorer.getMaxScore(upTo);
+ }
+ };
+ }
+
+ @Override
+ public long cost() {
+ return innerSupplier.cost();
+ }
+ };
+ }
+
+ @Override
+ public boolean isCacheable(LeafReaderContext ctx) {
+ return false;
+ }
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java
index 99568a507ffb9..c7346bb9edd75 100644
--- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java
+++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java
@@ -12,8 +12,9 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.FunctionScoreQuery;
import org.apache.lucene.search.BooleanClause;
-import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnByteVectorQuery;
+import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.TopDocs;
@@ -25,17 +26,28 @@
import java.util.Objects;
/**
- * Wraps an internal query to rescore the results using a similarity function over the original, non-quantized vectors of a vector field
+ * A Lucene {@link Query} that applies vector-based rescoring to an inner query's results.
+ * <p>
+ * Depending on the nature of the {@code innerQuery}, this class dynamically selects between two rescoring strategies:
+ * <ul>
+ * <li><b>Inline rescoring:</b>
+ * Used when the inner query is already a top-N vector query with {@code rescoreK} results.
+ * The vector similarity is applied inline using a {@link FunctionScoreQuery} without an additional
+ * filtering pass.</li>
+ * <li><b>Late rescoring:</b> Used when the inner query is not a top-N vector query or does not return
+ * {@code rescoreK} results. The top {@code rescoreK} documents are first collected, and then rescoring is applied
+ * separately to select the final top {@code k}.</li>
+ * </ul>
*/
-public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider {
- private final String fieldName;
- private final float[] floatTarget;
- private final VectorSimilarityFunction vectorSimilarityFunction;
- private final int k;
- private final Query innerQuery;
- private long vectorOperations = 0;
-
- public RescoreKnnVectorQuery(
+public abstract class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider {
+ protected final String fieldName;
+ protected final float[] floatTarget;
+ protected final VectorSimilarityFunction vectorSimilarityFunction;
+ protected final int k;
+ protected final Query innerQuery;
+ protected long vectorOperations = 0;
+
+ private RescoreKnnVectorQuery(
String fieldName,
float[] floatTarget,
VectorSimilarityFunction vectorSimilarityFunction,
@@ -49,16 +61,31 @@ public RescoreKnnVectorQuery(
this.innerQuery = innerQuery;
}
- @Override
- public Query rewrite(IndexSearcher searcher) throws IOException {
- DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
- FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
- Query query = searcher.rewrite(functionScoreQuery);
-
- // Retrieve top k documents from the rescored query
- TopDocs topDocs = searcher.search(query, k);
- vectorOperations = topDocs.totalHits.value();
- return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
+ /**
+ * Selects and returns the appropriate {@link RescoreKnnVectorQuery} strategy based on the nature of the {@code innerQuery}.
+ *
+ * @param fieldName the name of the field containing the vector
+ * @param floatTarget the target vector to compare against
+ * @param vectorSimilarityFunction the similarity function to apply
+ * @param k the number of top documents to return after rescoring
+ * @param rescoreK the number of top documents to consider for rescoring
+ * @param innerQuery the original Lucene query to rescore
+ */
+ public static RescoreKnnVectorQuery fromInnerQuery(
+ String fieldName,
+ float[] floatTarget,
+ VectorSimilarityFunction vectorSimilarityFunction,
+ int k,
+ int rescoreK,
+ Query innerQuery
+ ) {
+ if ((innerQuery instanceof KnnFloatVectorQuery fQuery && fQuery.getK() == rescoreK)
+ || (innerQuery instanceof KnnByteVectorQuery bQuery && bQuery.getK() == rescoreK)
+ || (innerQuery instanceof AbstractIVFKnnVectorQuery ivfQuery && ivfQuery.k == rescoreK)) {
+ // Queries that return only the top `k` results and do not require reduction before re-scoring.
+ return new InlineRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
+ }
+ return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery);
}
public Query innerQuery() {
@@ -102,7 +129,8 @@ public int hashCode() {
@Override
public String toString(String field) {
- return "KnnRescoreVectorQuery{"
+ return getClass().getSimpleName()
+ + "{"
+ "fieldName='"
+ fieldName
+ '\''
@@ -117,4 +145,82 @@ public String toString(String field) {
+ innerQuery
+ '}';
}
+
+ private static class InlineRescoreQuery extends RescoreKnnVectorQuery {
+ private InlineRescoreQuery(
+ String fieldName,
+ float[] floatTarget,
+ VectorSimilarityFunction vectorSimilarityFunction,
+ int k,
+ Query innerQuery
+ ) {
+ super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
+ }
+
+        @Override
+        public Query rewrite(IndexSearcher searcher) throws IOException {
+            // Re-score the inner query's hits with the vector similarity function and keep the best k.
+            final var similaritySource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
+            final TopDocs rescored = searcher.search(new FunctionScoreQuery(innerQuery, similaritySource), k);
+            // The hit count of the rescoring search is what gets reported as vector operations.
+            vectorOperations = rescored.totalHits.value();
+            return new KnnScoreDocQuery(rescored.scoreDocs, searcher.getIndexReader());
+        }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ return super.equals(o);
+ }
+
+ @Override
+ public int hashCode() {
+ return super.hashCode();
+ }
+ }
+
+ private static class LateRescoreQuery extends RescoreKnnVectorQuery {
+ final int rescoreK;
+
+ private LateRescoreQuery(
+ String fieldName,
+ float[] floatTarget,
+ VectorSimilarityFunction vectorSimilarityFunction,
+ int k,
+ int rescoreK,
+ Query innerQuery
+ ) {
+ super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
+ this.rescoreK = rescoreK;
+ }
+
+        @Override
+        public Query rewrite(IndexSearcher searcher) throws IOException {
+            // Phase 1: collect the rescoreK best candidates as ranked by the inner query.
+            final TopDocs candidates = searcher.search(innerQuery, rescoreK);
+            vectorOperations = candidates.totalHits.value();
+            final var reader = searcher.getIndexReader();
+
+            // Phase 2: re-score only those candidates with the vector similarity function and keep the best k.
+            final Query candidatesOnly = new KnnScoreDocQuery(candidates.scoreDocs, reader);
+            final var similaritySource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
+            final Query rescoring = new FunctionScoreQuery(candidatesOnly, similaritySource).rewrite(searcher);
+            final TopDocs rescored = searcher.search(rescoring, k);
+            return new KnnScoreDocQuery(rescored.scoreDocs, reader);
+        }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ var that = (RescoreKnnVectorQuery.LateRescoreQuery) o;
+ return super.equals(o) && that.rescoreK == rescoreK;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), rescoreK);
+ }
+ }
}
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java
index 5c3c02a58d4d1..3a07e37159233 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java
@@ -992,16 +992,10 @@ private void doTestDefaultDenseVectorMappings(DocumentMapper mapper, XContentBui
assertThat(((FieldMapper) update.getRoot().getMapper("mapsToFloatTooBig")).fieldType().typeName(), equalTo("float"));
assertThat(((FieldMapper) update.getRoot().getMapper("mapsToInt8HnswDenseVector")).fieldType().typeName(), equalTo("dense_vector"));
DenseVectorFieldMapper int8DVFieldMapper = ((DenseVectorFieldMapper) update.getRoot().getMapper("mapsToInt8HnswDenseVector"));
- assertThat(
- ((DenseVectorFieldMapper.DenseVectorIndexOptions) int8DVFieldMapper.fieldType().getIndexOptions()).getType().getName(),
- equalTo("int8_hnsw")
- );
+ assertThat(int8DVFieldMapper.fieldType().getIndexOptions().getType().getName(), equalTo("int8_hnsw"));
assertThat(((FieldMapper) update.getRoot().getMapper("mapsToBBQHnswDenseVector")).fieldType().typeName(), equalTo("dense_vector"));
DenseVectorFieldMapper bbqDVFieldMapper = ((DenseVectorFieldMapper) update.getRoot().getMapper("mapsToBBQHnswDenseVector"));
- assertThat(
- ((DenseVectorFieldMapper.DenseVectorIndexOptions) bbqDVFieldMapper.fieldType().getIndexOptions()).getType().getName(),
- equalTo("bbq_hnsw")
- );
+ assertThat(bbqDVFieldMapper.fieldType().getIndexOptions().getType().getName(), equalTo("bbq_hnsw"));
}
public void testDefaultDenseVectorMappingsObject() throws IOException {
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java
index 6c7964bbf773b..4a933efa9516a 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java
@@ -25,6 +25,7 @@
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.vectors.DenseVectorQuery;
+import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery;
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery;
import org.elasticsearch.search.vectors.RescoreKnnVectorQuery;
@@ -238,7 +239,11 @@ public void testCreateNestedKnnQuery() {
if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) {
query = rescoreKnnVectorQuery.innerQuery();
}
- assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class));
+ if (field.getIndexOptions().isFlat()) {
+ assertThat(query, instanceOf(DiversifyingParentBlockQuery.class));
+ } else {
+ assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class));
+ }
}
{
DenseVectorFieldType field = new DenseVectorFieldType(
@@ -269,7 +274,11 @@ public void testCreateNestedKnnQuery() {
producer,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
- assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
+ if (field.getIndexOptions().isFlat()) {
+ assertThat(query, instanceOf(DenseVectorQuery.class));
+ } else {
+ assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
+ }
vectorData = new VectorData(floatQueryVector, null);
query = field.createKnnQuery(
@@ -282,7 +291,11 @@ public void testCreateNestedKnnQuery() {
producer,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
- assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
+ if (field.getIndexOptions().isFlat()) {
+ assertThat(query, instanceOf(DenseVectorQuery.class));
+ } else {
+ assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
+ }
}
}
@@ -445,7 +458,11 @@ public void testCreateKnnQueryMaxDims() {
if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) {
query = rescoreKnnVectorQuery.innerQuery();
}
- assertThat(query, instanceOf(KnnFloatVectorQuery.class));
+ if (fieldWith4096dims.getIndexOptions().isFlat()) {
+ assertThat(query, instanceOf(DenseVectorQuery.Floats.class));
+ } else {
+ assertThat(query, instanceOf(KnnFloatVectorQuery.class));
+ }
}
{ // byte type with 4096 dims
@@ -475,7 +492,11 @@ public void testCreateKnnQueryMaxDims() {
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
- assertThat(query, instanceOf(KnnByteVectorQuery.class));
+ if (fieldWith4096dims.getIndexOptions().isFlat()) {
+ assertThat(query, instanceOf(DenseVectorQuery.Bytes.class));
+ } else {
+ assertThat(query, instanceOf(KnnByteVectorQuery.class));
+ }
}
}
@@ -574,13 +595,21 @@ public void testRescoreOversampleUsedWithoutQuantization() {
);
if (elementType == BYTE) {
- ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery;
- assertThat(esKnnQuery.getK(), is(100));
- assertThat(esKnnQuery.kParam(), is(10));
+ if (nonQuantizedField.getIndexOptions().isFlat()) {
+ assertThat(knnQuery, instanceOf(DenseVectorQuery.Bytes.class));
+ } else {
+ ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery;
+ assertThat(esKnnQuery.getK(), is(100));
+ assertThat(esKnnQuery.kParam(), is(10));
+ }
} else {
- ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery;
- assertThat(esKnnQuery.getK(), is(100));
- assertThat(esKnnQuery.kParam(), is(10));
+ if (nonQuantizedField.getIndexOptions().isFlat()) {
+ assertThat(knnQuery, instanceOf(DenseVectorQuery.Floats.class));
+ } else {
+ ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery;
+ assertThat(esKnnQuery.getK(), is(100));
+ assertThat(esKnnQuery.kParam(), is(10));
+ }
}
}
@@ -628,7 +657,11 @@ public void testRescoreOversampleQueryOverrides() {
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
- assertTrue(query instanceof ESKnnFloatVectorQuery);
+ if (fieldType.getIndexOptions().isFlat()) {
+ assertThat(query, instanceOf(DenseVectorQuery.Floats.class));
+ } else {
+ assertThat(query, instanceOf(ESKnnFloatVectorQuery.class));
+ }
// verify we can override a `0` to a positive number
fieldType = new DenseVectorFieldType(
diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java
index 0e295fb02eaaa..a15372bc1e8ef 100644
--- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java
+++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java
@@ -51,6 +51,7 @@
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DEFAULT_OVERSAMPLE;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT;
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
+import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@@ -203,8 +204,14 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
}
}
switch (elementType()) {
- case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery);
- case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery);
+ case FLOAT -> assertThat(
+ query,
+ anyOf(instanceOf(ESKnnFloatVectorQuery.class), instanceOf(DenseVectorQuery.Floats.class), instanceOf(BooleanQuery.class))
+ );
+ case BYTE -> assertThat(
+ query,
+ anyOf(instanceOf(ESKnnByteVectorQuery.class), instanceOf(DenseVectorQuery.Bytes.class), instanceOf(BooleanQuery.class))
+ );
}
BooleanQuery.Builder builder = new BooleanQuery.Builder();
@@ -244,10 +251,34 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
expectedStrategy
);
};
+
+ Query bruteForceVectorQueryBuilt = switch (elementType()) {
+ case BIT, BYTE -> {
+ if (filterQuery != null) {
+ yield new BooleanQuery.Builder().add(
+ new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD),
+ BooleanClause.Occur.SHOULD
+ ).add(filterQuery, BooleanClause.Occur.FILTER).build();
+ } else {
+ yield new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD);
+ }
+ }
+ case FLOAT -> {
+ if (filterQuery != null) {
+ yield new BooleanQuery.Builder().add(
+ new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD),
+ BooleanClause.Occur.SHOULD
+ ).add(filterQuery, BooleanClause.Occur.FILTER).build();
+ } else {
+ yield new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD);
+ }
+ }
+ };
+
if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) {
query = vectorSimilarityQuery.getInnerKnnQuery();
}
- assertEquals(query, knnVectorQueryBuilt);
+ assertThat(query, anyOf(equalTo(knnVectorQueryBuilt), equalTo(bruteForceVectorQueryBuilt)));
}
public void testWrongDimension() {
diff --git a/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java
new file mode 100644
index 0000000000000..320b3efc4924c
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java
@@ -0,0 +1,162 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.join.ScoreMode;
+import org.apache.lucene.search.join.ToParentBlockJoinQuery;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.index.mapper.MapperServiceTestCase;
+import org.elasticsearch.index.mapper.ParsedDocument;
+import org.elasticsearch.index.mapper.SourceToParse;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentType;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.Set;
+import java.util.TreeMap;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+
+/**
+ * Checks that a kNN query on a {@code flat} {@code dense_vector} field that lives under a nested object is built as a
+ * {@link DiversifyingParentBlockQuery} and that, once joined to the parent documents, the hits are ordered by the best
+ * score any nested vector of the parent achieves.
+ */
+public class DiversifyingParentBlockQueryTests extends MapperServiceTestCase {
+    /**
+     * Builds the JSON mapping used by the test: a stored keyword {@code id} field and a {@code nested} object holding an
+     * {@code emb} dense_vector field with {@code l2_norm} similarity and {@code flat} index options.
+     *
+     * @param dim value written to the {@code dims} setting of the vector field
+     * @return the mapping as a JSON string
+     */
+    private static String getMapping(int dim) {
+        return String.format(Locale.ROOT, """
+            {
+                "_doc": {
+                    "properties": {
+                        "id": {
+                            "type": "keyword",
+                            "store": true
+                        },
+                        "nested": {
+                            "type": "nested",
+                            "properties": {
+                                "emb": {
+                                    "type": "dense_vector",
+                                    "dims": %d,
+                                    "similarity": "l2_norm",
+                                    "index_options": {
+                                        "type": "flat"
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            """, dim);
+    }
+
+    /**
+     * Indexes a random number of parent documents, each carrying zero to five nested vectors, and verifies for a few
+     * random query vectors that every returned hit matches, in order, the parent with the next-highest expected score:
+     * the maximum Euclidean similarity between the query and any of the parent's vectors.
+     */
+    public void testRandom() throws IOException {
+        int dims = randomIntBetween(3, 10);
+        var mapperService = createMapperService(getMapping(dims));
+        var fieldType = (DenseVectorFieldMapper.DenseVectorFieldType) mapperService.fieldType("nested.emb");
+        var nestedParent = mapperService.mappingLookup().nestedLookup().getNestedMappers().get("nested");
+
+        int numQueries = randomIntBetween(1, 3);
+        float[][] queries = new float[numQueries][];
+        // One map per query: best score of a parent -> parent id, ordered from the highest score down.
+        List<TreeMap<Float, String>> expectedTopDocs = new ArrayList<>();
+        for (int i = 0; i < numQueries; i++) {
+            queries[i] = randomVector(dims);
+            expectedTopDocs.add(new TreeMap<>((o1, o2) -> Float.compare(o2, o1)));
+        }
+
+        withLuceneIndex(mapperService, iw -> {
+            int numDocs = randomIntBetween(10, 50);
+            for (int i = 0; i < numDocs; i++) {
+                int numVectors = randomIntBetween(0, 5);
+                float[][] vectors = new float[numVectors][];
+                for (int j = 0; j < numVectors; j++) {
+                    vectors[j] = randomVector(dims);
+                }
+
+                // A parent without vectors has nothing to score, so it is kept out of the expectations.
+                if (numVectors > 0) {
+                    for (int k = 0; k < numQueries; k++) {
+                        // Negative infinity is the identity for max; Float.MIN_VALUE would be the smallest positive float.
+                        float maxScore = Float.NEGATIVE_INFINITY;
+                        for (float[] vector : vectors) {
+                            maxScore = Math.max(maxScore, EUCLIDEAN.compare(vector, queries[k]));
+                        }
+                        expectedTopDocs.get(k).put(maxScore, Integer.toString(i));
+                    }
+                }
+
+                SourceToParse source = randomSource(Integer.toString(i), vectors);
+                ParsedDocument doc = mapperService.documentMapper().parse(source);
+                iw.addDocuments(doc.docs());
+
+                // Randomly interleave parents that carry no vectors at all; they are absent from the expectations.
+                if (randomBoolean()) {
+                    int numEmpty = randomIntBetween(1, 3);
+                    for (int l = 0; l < numEmpty; l++) {
+                        source = randomSource(randomAlphaOfLengthBetween(5, 10), new float[0][]);
+                        doc = mapperService.documentMapper().parse(source);
+                        iw.addDocuments(doc.docs());
+                    }
+                }
+            }
+        }, ir -> {
+            var storedFields = ir.storedFields();
+            var searcher = new IndexSearcher(wrapInMockESDirectoryReader(ir));
+            var context = createSearchExecutionContext(mapperService);
+            var bitSetProducer = context.bitsetFilter(nestedParent.parentTypeFilter());
+            for (int i = 0; i < numQueries; i++) {
+                var knnQuery = fieldType.createKnnQuery(
+                    VectorData.fromFloats(queries[i]),
+                    10,
+                    10,
+                    null,
+                    null,
+                    null,
+                    bitSetProducer,
+                    DenseVectorFieldMapper.FilterHeuristic.ACORN
+                );
+                assertThat(knnQuery, instanceOf(DiversifyingParentBlockQuery.class));
+                // NOTE(review): the expectations are per-parent maxima while the join sums child scores (ScoreMode.Total);
+                // presumably at most one child per parent reaches the join - confirm.
+                var nestedQuery = new ToParentBlockJoinQuery(knnQuery, bitSetProducer, ScoreMode.Total);
+                var topDocs = searcher.search(nestedQuery, 10);
+                var expected = expectedTopDocs.get(i);
+                for (var doc : topDocs.scoreDocs) {
+                    // Hits arrive best-first, so each one must match the highest expectation not yet consumed.
+                    var entry = expected.pollFirstEntry();
+                    assertNotNull(entry);
+                    assertThat(doc.score, equalTo(entry.getKey()));
+                    var storedDoc = storedFields.document(doc.doc, Set.of("id"));
+                    assertThat(storedDoc.getField("id").binaryValue().utf8ToString(), equalTo(entry.getValue()));
+                }
+            }
+        });
+    }
+
+    /**
+     * Serializes a parent document as JSON: an {@code id} field plus a {@code nested} array with one object per vector,
+     * each holding the vector under {@code emb}.
+     *
+     * @param id      value of the {@code id} field, also passed as the id of the returned source
+     * @param vectors vectors to nest under the parent; an empty array produces an empty {@code nested} array
+     * @return the JSON source wrapped for the document mapper
+     * @throws IOException propagated from the content builder
+     */
+    private SourceToParse randomSource(String id, float[][] vectors) throws IOException {
+        try (var builder = XContentBuilder.builder(XContentType.JSON.xContent())) {
+            builder.startObject();
+            builder.field("id", id);
+            builder.startArray("nested");
+            for (float[] vector : vectors) {
+                builder.startObject();
+                builder.field("emb", vector);
+                builder.endObject();
+            }
+            builder.endArray();
+            builder.endObject();
+            return new SourceToParse(id, BytesReference.bytes(builder), XContentType.JSON);
+        }
+    }
+
+    /**
+     * Creates a vector whose components each come from {@code randomFloat()}.
+     *
+     * @param dim number of components
+     * @return a newly allocated vector of length {@code dim}
+     */
+    private float[] randomVector(int dim) {
+        float[] vector = new float[dim];
+        for (int i = 0; i < vector.length; i++) {
+            vector[i] = randomFloat();
+        }
+        return vector;
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java
index 05b7bc9ef4f82..c2ba061ad38bc 100644
--- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java
+++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java
@@ -18,7 +18,10 @@
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.FunctionScoreQuery;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DoubleValuesSource;
+import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
@@ -41,6 +44,8 @@
import java.io.IOException;
import java.io.UnsupportedEncodingException;
+import java.util.ArrayList;
+import java.util.List;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
@@ -54,53 +59,57 @@ public void testRescoreDocs() throws Exception {
int numDims = randomIntBetween(5, 100);
int k = randomIntBetween(1, numDocs - 1);
+ var queryVector = randomVector(numDims);
+ List innerQueries = new ArrayList<>();
+ innerQueries.add(new KnnFloatVectorQuery(FIELD_NAME, randomVector(numDims), (int) (k * randomFloatBetween(1.0f, 10.0f, true))));
+ innerQueries.add(
+ new BooleanQuery.Builder().add(new DenseVectorQuery.Floats(queryVector, FIELD_NAME), BooleanClause.Occur.SHOULD)
+ .add(new FieldExistsQuery(FIELD_NAME), BooleanClause.Occur.FILTER)
+ .build()
+ );
+ innerQueries.add(new MatchAllDocsQuery());
+
try (Directory d = newDirectory()) {
addRandomDocuments(numDocs, d, numDims);
try (IndexReader reader = DirectoryReader.open(d)) {
- // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query
- // and thus we're rescoring the top k docs.
- float[] queryVector = randomVector(numDims);
- Query innerQuery;
- if (randomBoolean()) {
- innerQuery = new KnnFloatVectorQuery(FIELD_NAME, queryVector, (int) (k * randomFloatBetween(1.0f, 10.0f, true)));
- } else {
- innerQuery = new MatchAllDocsQuery();
- }
- RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
- FIELD_NAME,
- queryVector,
- VectorSimilarityFunction.COSINE,
- k,
- innerQuery
- );
-
- IndexSearcher searcher = newSearcher(reader, true, false);
- TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs);
- assertThat(rescoredDocs.scoreDocs.length, equalTo(k));
-
- // Get real scores
- DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(
- FIELD_NAME,
- queryVector,
- VectorSimilarityFunction.COSINE
- );
- FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource);
- TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs);
-
- int i = 0;
- ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs;
- for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) {
- // There are docs that won't be found in the rescored search, but every doc found must be in the same order
- // and have the same score
- while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) {
- i++;
- }
- if (i >= realScoreDocs.length) {
- fail("Rescored doc not found in real score docs");
+ for (var innerQuery : innerQueries) {
+ RescoreKnnVectorQuery rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery(
+ FIELD_NAME,
+ queryVector,
+ VectorSimilarityFunction.COSINE,
+ k,
+ k,
+ innerQuery
+ );
+
+ IndexSearcher searcher = newSearcher(reader, true, false);
+ TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs);
+ assertThat(rescoredDocs.scoreDocs.length, equalTo(k));
+
+ // Get real scores
+ DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(
+ FIELD_NAME,
+ queryVector,
+ VectorSimilarityFunction.COSINE
+ );
+ FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource);
+ TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs);
+
+ int i = 0;
+ ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs;
+ for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) {
+ // There are docs that won't be found in the rescored search, but every doc found must be in the same order
+ // and have the same score
+ while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) {
+ i++;
+ }
+ if (i >= realScoreDocs.length) {
+ fail("Rescored doc not found in real score docs");
+ }
+ assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score));
}
- assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score));
}
}
}
@@ -124,11 +133,12 @@ public void testProfiling() throws Exception {
}
private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader reader, Query innerQuery) throws IOException {
- RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
+ var rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery(
FIELD_NAME,
queryVector,
VectorSimilarityFunction.COSINE,
k,
+ k,
innerQuery
);
IndexSearcher searcher = newSearcher(reader, true, false);