Skip to content

Commit 8b71674

Browse files
committed
chore
1 parent aa79feb commit 8b71674

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

datafusion/physical-plan/src/repartition/on_demand_repartition.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ impl OnDemandRepartitionExec {
420420
async fn process_input(
421421
input: Arc<dyn ExecutionPlan>,
422422
partition: usize,
423-
buffer_tx: Sender<RecordBatch>,
423+
buffer_tx: tokio::sync::mpsc::Sender<RecordBatch>,
424424
context: Arc<TaskContext>,
425425
fetch_time: metrics::Time,
426426
send_buffer_time: metrics::Time,
@@ -476,7 +476,7 @@ impl OnDemandRepartitionExec {
476476
context: Arc<TaskContext>,
477477
) -> Result<()> {
478478
// initialize buffer channel so that we can pre-fetch from input
479-
let (buffer_tx, buffer_rx) = async_channel::bounded::<RecordBatch>(2);
479+
let (buffer_tx, mut buffer_rx) = tokio::sync::mpsc::channel(2);
480480
// execute the child operator in a separate task
481481
// that pushes batches into buffer channel with limited capacity
482482
let processing_task = SpawnedTask::spawn(Self::process_input(
@@ -491,12 +491,6 @@ impl OnDemandRepartitionExec {
491491
let mut batches_until_yield = partitioning.partition_count();
492492
// When the input is done, break the loop
493493
while !output_channels.is_empty() {
494-
// Fetch the batch from the buffer, ideally this should reduce the time gap between the requester and the input stream
495-
let batch = match buffer_rx.recv().await {
496-
Ok(batch) => batch,
497-
_ => break,
498-
};
499-
500494
// Wait until a partition is requested, then get the output partition information
501495
let partition = output_partition_rx.recv().await.map_err(|e| {
502496
internal_datafusion_err!(
@@ -505,6 +499,25 @@ impl OnDemandRepartitionExec {
505499
)
506500
})?;
507501

502+
// Fetch the batch from the buffer, ideally this should reduce the time gap between the requester and the input stream
503+
let batch_opt = loop {
504+
match buffer_rx.try_recv() {
505+
Ok(batch) => break Some(batch),
506+
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {
507+
tokio::task::yield_now().await;
508+
}
509+
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
510+
break None
511+
}
512+
}
513+
};
514+
515+
let batch = if let Some(batch) = batch_opt {
516+
batch
517+
} else {
518+
break;
519+
};
520+
508521
let size = batch.get_array_memory_size();
509522

510523
let timer = metrics.send_time[partition].timer();

0 commit comments

Comments
 (0)