18
18
//! TopK: Combination of Sort / LIMIT
19
19
20
20
use arrow:: {
21
- array:: Array ,
22
- compute:: interleave_record_batch,
21
+ array:: { Array , AsArray } ,
22
+ compute:: { interleave_record_batch, prep_null_mask_filter , FilterBuilder } ,
23
23
row:: { RowConverter , Rows , SortField } ,
24
24
} ;
25
25
use datafusion_expr:: { ColumnarValue , Operator } ;
@@ -203,7 +203,7 @@ impl TopK {
203
203
let baseline = self . metrics . baseline . clone ( ) ;
204
204
let _timer = baseline. elapsed_compute ( ) . timer ( ) ;
205
205
206
- let sort_keys: Vec < ArrayRef > = self
206
+ let mut sort_keys: Vec < ArrayRef > = self
207
207
. expr
208
208
. iter ( )
209
209
. map ( |expr| {
@@ -212,15 +212,56 @@ impl TopK {
212
212
} )
213
213
. collect :: < Result < Vec < _ > > > ( ) ?;
214
214
215
+ let mut selected_rows = None ;
216
+
217
+ if let Some ( filter) = self . filter . as_ref ( ) {
218
+ // If a filter is provided, evaluate it against the batch to select candidate rows
219
+ let filter = filter. current ( ) ?;
220
+ let filtered = filter. evaluate ( & batch) ?;
221
+ let num_rows = batch. num_rows ( ) ;
222
+ let array = filtered. into_array ( num_rows) ?;
223
+ let mut filter = array. as_boolean ( ) . clone ( ) ;
224
+ let true_count = filter. true_count ( ) ;
225
+ if true_count == 0 {
226
+ // no rows pass the filter, so there is nothing to insert from this batch
227
+ return Ok ( ( ) ) ;
228
+ }
229
+ // only update the keys / rows if the filter does not match all rows
230
+ if true_count < num_rows {
231
+ // `set_indices` (used below) reads only the filter's value bits, so null entries
232
+ // must first be turned into `false` for the indices to be correct. Note this is
233
+ // also done in the `FilterBuilder`, so there is no overhead to do this here.
234
+ if filter. nulls ( ) . is_some ( ) {
235
+ filter = prep_null_mask_filter ( & filter) ;
236
+ }
237
+
238
+ let filter_predicate = FilterBuilder :: new ( & filter) ;
239
+ let filter_predicate = if sort_keys. len ( ) > 1 {
240
+ // Optimize filter when it has multiple sort keys
241
+ filter_predicate. optimize ( ) . build ( )
242
+ } else {
243
+ filter_predicate. build ( )
244
+ } ;
245
+ selected_rows = Some ( filter) ;
246
+ sort_keys = sort_keys
247
+ . iter ( )
248
+ . map ( |key| filter_predicate. filter ( key) . map_err ( |x| x. into ( ) ) )
249
+ . collect :: < Result < Vec < _ > > > ( ) ?;
250
+ }
251
+ } ;
215
252
// reuse existing `Rows` to avoid reallocations
216
253
let rows = & mut self . scratch_rows ;
217
254
rows. clear ( ) ;
218
255
self . row_converter . append ( rows, & sort_keys) ?;
219
256
220
257
let mut batch_entry = self . heap . register_batch ( batch. clone ( ) ) ;
221
258
222
- let replacements =
223
- self . find_new_topk_items ( 0 ..sort_keys[ 0 ] . len ( ) , & mut batch_entry) ;
259
+ let replacements = match selected_rows {
260
+ Some ( filter) => {
261
+ self . find_new_topk_items ( filter. values ( ) . set_indices ( ) , & mut batch_entry)
262
+ }
263
+ None => self . find_new_topk_items ( 0 ..sort_keys[ 0 ] . len ( ) , & mut batch_entry) ,
264
+ } ;
224
265
225
266
if replacements > 0 {
226
267
self . metrics . row_replacements . add ( replacements) ;
0 commit comments