Skip to content

Speed up (filtered) KNN queries for flat vector fields #130251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/130251.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 130251
summary: Speed up (filtered) KNN queries for flat vector fields
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException
}
if (overSamplingFactor > 1f) {
// oversample the topK results to get more candidates for the final result
knnQuery = new RescoreKnnVectorQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, knnQuery);
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery);
}
QueryProfiler profiler = new QueryProfiler();
TopDocs docs = searcher.search(knnQuery, this.topK);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
Expand Down Expand Up @@ -77,6 +79,7 @@
import org.elasticsearch.search.lookup.Source;
import org.elasticsearch.search.vectors.DenseVectorQuery;
import org.elasticsearch.search.vectors.DiversifyingChildrenIVFKnnFloatVectorQuery;
import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery;
import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery;
import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
Expand Down Expand Up @@ -1391,6 +1394,18 @@ public final boolean equals(Object other) {
public final int hashCode() {
return Objects.hash(type, doHashCode());
}

/**
* Indicates whether the underlying vector search is performed using a flat (exhaustive) approach.
* <p>
* When {@code true}, it means the search does not use any approximate nearest neighbor (ANN)
* acceleration structures such as HNSW or IVF. Instead, it performs a brute-force comparison
* against all candidate vectors. This information can be used by higher-level components
* to decide whether additional acceleration or optimization is necessary.
*
* @return {@code true} if the vector search is flat (exhaustive), {@code false} if it uses ANN structures
*/
abstract boolean isFlat();
}

abstract static class QuantizedIndexOptions extends DenseVectorIndexOptions {
Expand Down Expand Up @@ -1762,6 +1777,11 @@ int doHashCode() {
return Objects.hash(confidenceInterval, rescoreVector);
}

@Override
boolean isFlat() {
return true;
}

@Override
public boolean updatableTo(DenseVectorIndexOptions update) {
return update.type.equals(this.type)
Expand Down Expand Up @@ -1810,6 +1830,11 @@ public boolean doEquals(DenseVectorIndexOptions o) {
public int doHashCode() {
return Objects.hash(type);
}

@Override
boolean isFlat() {
return true;
}
}

public static class Int4HnswIndexOptions extends QuantizedIndexOptions {
Expand Down Expand Up @@ -1860,6 +1885,11 @@ public int doHashCode() {
return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector);
}

@Override
boolean isFlat() {
return false;
}

@Override
public String toString() {
return "{type="
Expand Down Expand Up @@ -1931,6 +1961,11 @@ public int doHashCode() {
return Objects.hash(confidenceInterval, rescoreVector);
}

@Override
boolean isFlat() {
return true;
}

@Override
public String toString() {
return "{type=" + type + ", confidence_interval=" + confidenceInterval + ", rescore_vector=" + rescoreVector + "}";
Expand Down Expand Up @@ -1999,6 +2034,11 @@ public int doHashCode() {
return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector);
}

@Override
boolean isFlat() {
return false;
}

@Override
public String toString() {
return "{type="
Expand Down Expand Up @@ -2088,6 +2128,11 @@ public int doHashCode() {
return Objects.hash(m, efConstruction);
}

@Override
boolean isFlat() {
return false;
}

@Override
public String toString() {
return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}";
Expand Down Expand Up @@ -2126,6 +2171,11 @@ int doHashCode() {
return Objects.hash(m, efConstruction, rescoreVector);
}

@Override
boolean isFlat() {
return false;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down Expand Up @@ -2179,6 +2229,11 @@ int doHashCode() {
return CLASS_NAME_HASH;
}

@Override
boolean isFlat() {
return true;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down Expand Up @@ -2237,6 +2292,11 @@ int doHashCode() {
return Objects.hash(clusterSize, defaultNProbe, rescoreVector);
}

@Override
boolean isFlat() {
return false;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down Expand Up @@ -2485,9 +2545,21 @@ private Query createKnnBitQuery(
KnnSearchStrategy searchStrategy
) {
elementType.checkDimensions(dims, queryVector.length);
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
Query knnQuery;
if (indexOptions != null && indexOptions.isFlat()) {
var exactKnnQuery = parentFilter != null
? new DiversifyingParentBlockQuery(parentFilter, createExactKnnBitQuery(queryVector))
: createExactKnnBitQuery(queryVector);
knnQuery = filter == null
? createExactKnnBitQuery(queryVector)
: new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
.add(filter, BooleanClause.Occur.FILTER)
.build();
} else {
knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
}
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
Expand All @@ -2513,9 +2585,22 @@ private Query createKnnByteQuery(
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);

Query knnQuery;
if (indexOptions != null && indexOptions.isFlat()) {
var exactKnnQuery = parentFilter != null
? new DiversifyingParentBlockQuery(parentFilter, createExactKnnByteQuery(queryVector))
: createExactKnnByteQuery(queryVector);
knnQuery = filter == null
? createExactKnnByteQuery(queryVector)
: new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
.add(filter, BooleanClause.Occur.FILTER)
.build();
} else {
knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
}
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
Expand Down Expand Up @@ -2568,7 +2653,16 @@ && isNotUnitVector(squaredMagnitude)) {
numCands = Math.max(adjustedK, numCands);
}
Query knnQuery;
if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) {
if (indexOptions != null && indexOptions.isFlat()) {
var exactKnnQuery = parentFilter != null
? new DiversifyingParentBlockQuery(parentFilter, createExactKnnFloatQuery(queryVector))
: createExactKnnFloatQuery(queryVector);
knnQuery = filter == null
? createExactKnnFloatQuery(queryVector)
: new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
.add(filter, BooleanClause.Occur.FILTER)
.build();
} else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) {
knnQuery = parentFilter != null
? new DiversifyingChildrenIVFKnnFloatVectorQuery(
name(),
Expand All @@ -2594,11 +2688,12 @@ && isNotUnitVector(squaredMagnitude)) {
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
}
if (rescore) {
knnQuery = new RescoreKnnVectorQuery(
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(
name(),
queryVector,
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),
k,
adjustedK,
knnQuery
);
}
Expand All @@ -2624,7 +2719,7 @@ ElementType getElementType() {
return elementType;
}

public IndexOptions getIndexOptions() {
public DenseVectorIndexOptions getIndexOptions() {
return indexOptions;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,5 @@ public int docID() {
return iterator.docID();
}
}

}
Loading
Loading