
Commit c8df7fe

refactor: Use tokio mpsc for OnDemandRepartition
1 parent f716c27 commit c8df7fe

File tree

3 files changed: +290 -136 lines

datafusion/physical-plan/src/repartition/distributor_channels.rs

Lines changed: 24 additions & 0 deletions
@@ -79,6 +79,17 @@ pub fn channels<T>(
     (senders, receivers)
 }
 
+pub fn tokio_channels<T>(
+    n: usize,
+) -> (
+    Vec<tokio::sync::mpsc::Sender<T>>,
+    Vec<tokio::sync::mpsc::Receiver<T>>,
+) {
+    // only used for OnDemandRepartitionExec, so no need for unbounded capacity
+    let (senders, receivers) = (0..n).map(|_| tokio::sync::mpsc::channel(2)).unzip();
+    (senders, receivers)
+}
+
 type PartitionAwareSenders<T> = Vec<Vec<DistributionSender<T>>>;
 type PartitionAwareReceivers<T> = Vec<Vec<DistributionReceiver<T>>>;
 
@@ -92,6 +103,19 @@ pub fn partition_aware_channels<T>(
     (0..n_in).map(|_| channels(n_out)).unzip()
 }
 
+type OnDemandPartitionAwareSenders<T> = Vec<Vec<tokio::sync::mpsc::Sender<T>>>;
+type OnDemandPartitionAwareReceivers<T> = Vec<Vec<tokio::sync::mpsc::Receiver<T>>>;
+
+pub fn on_demand_partition_aware_channels<T>(
+    n_in: usize,
+    n_out: usize,
+) -> (
+    OnDemandPartitionAwareSenders<T>,
+    OnDemandPartitionAwareReceivers<T>,
+) {
+    (0..n_in).map(|_| tokio_channels(n_out)).unzip()
+}
+
 /// Erroring during [send](DistributionSender::send).
 ///
 /// This occurs when the [receiver](DistributionReceiver) is gone.
datafusion/physical-plan/src/repartition/mod.rs

Lines changed: 34 additions & 109 deletions
@@ -43,7 +43,6 @@ use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Stat
 use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
 use arrow::compute::take_arrays;
 use arrow::datatypes::{SchemaRef, UInt32Type};
-use async_channel::Receiver;
 use datafusion_common::utils::transpose;
 use datafusion_common::HashMap;
 use datafusion_common::{not_impl_err, DataFusionError, Result};
@@ -56,7 +55,6 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering;
 use futures::stream::Stream;
 use futures::{ready, FutureExt, StreamExt, TryStreamExt};
 use log::trace;
-use on_demand_repartition::{OnDemandRepartitionExec, OnDemandRepartitionMetrics};
 use parking_lot::Mutex;
 
 mod distributor_channels;
@@ -66,68 +64,6 @@ type MaybeBatch = Option<Result<RecordBatch>>;
 type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
 type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
 
-struct RepartitionExecStateBuilder {
-    /// Whether to enable pull based execution.
-    enable_pull_based: bool,
-    partition_receivers: Option<Vec<Receiver<usize>>>,
-}
-
-impl RepartitionExecStateBuilder {
-    fn new() -> Self {
-        Self {
-            enable_pull_based: false,
-            partition_receivers: None,
-        }
-    }
-    fn enable_pull_based(mut self, enable_pull_based: bool) -> Self {
-        self.enable_pull_based = enable_pull_based;
-        self
-    }
-    fn partition_receivers(mut self, partition_receivers: Vec<Receiver<usize>>) -> Self {
-        self.partition_receivers = Some(partition_receivers);
-        self
-    }
-
-    fn build(
-        &self,
-        input: Arc<dyn ExecutionPlan>,
-        partitioning: Partitioning,
-        metrics: ExecutionPlanMetricsSet,
-        preserve_order: bool,
-        name: String,
-        context: Arc<TaskContext>,
-    ) -> RepartitionExecState {
-        RepartitionExecState::new(
-            input,
-            partitioning,
-            metrics,
-            preserve_order,
-            name,
-            context,
-            self.enable_pull_based,
-            self.partition_receivers.clone(),
-        )
-    }
-}
-
-/// Inner state of [`RepartitionExec`].
-#[derive(Debug)]
-struct RepartitionExecState {
-    /// Channels for sending batches from input partitions to output partitions.
-    /// Key is the partition number.
-    channels: HashMap<
-        usize,
-        (
-            InputPartitionsToCurrentPartitionSender,
-            InputPartitionsToCurrentPartitionReceiver,
-            SharedMemoryReservation,
-        ),
-    >,
-
-    /// Helper that ensures that that background job is killed once it is no longer needed.
-    abort_helper: Arc<Vec<SpawnedTask<()>>>,
-}
-
 /// create channels for sending batches from input partitions to output partitions.
 fn create_repartition_channels(
     preserve_order: bool,
@@ -185,17 +121,33 @@ fn create_partition_channels_hashmap(
 
     channels
 }
+
+/// Inner state of [`RepartitionExec`].
+#[derive(Debug)]
+struct RepartitionExecState {
+    /// Channels for sending batches from input partitions to output partitions.
+    /// Key is the partition number.
+    channels: HashMap<
+        usize,
+        (
+            InputPartitionsToCurrentPartitionSender,
+            InputPartitionsToCurrentPartitionReceiver,
+            SharedMemoryReservation,
+        ),
+    >,
+
+    /// Helper that ensures that that background job is killed once it is no longer needed.
+    abort_helper: Arc<Vec<SpawnedTask<()>>>,
+}
+
 impl RepartitionExecState {
-    #[allow(clippy::too_many_arguments)]
     fn new(
         input: Arc<dyn ExecutionPlan>,
         partitioning: Partitioning,
         metrics: ExecutionPlanMetricsSet,
         preserve_order: bool,
         name: String,
         context: Arc<TaskContext>,
-        enable_pull_based: bool,
-        partition_receivers: Option<Vec<Receiver<usize>>>,
     ) -> Self {
         let num_input_partitions = input.output_partitioning().partition_count();
         let num_output_partitions = partitioning.partition_count();
@@ -219,42 +171,16 @@ impl RepartitionExecState {
             })
             .collect();
 
-            let input_task = if enable_pull_based {
-                let partition_rx = if preserve_order {
-                    partition_receivers.clone().expect(
-                        "partition_receivers must be provided when preserve_order is enabled",
-                    )[i]
-                    .clone()
-                } else {
-                    partition_receivers.clone().expect(
-                        "partition_receivers must be provided when preserve_order is disabled",
-                    )[0].clone()
-                };
-                let r_metrics =
-                    OnDemandRepartitionMetrics::new(i, num_output_partitions, &metrics);
-
-                SpawnedTask::spawn(OnDemandRepartitionExec::pull_from_input(
-                    Arc::clone(&input),
-                    i,
-                    txs.clone(),
-                    partitioning.clone(),
-                    partition_rx,
-                    r_metrics,
-                    Arc::clone(&context),
-                ))
-            } else {
-                let r_metrics =
-                    RepartitionMetrics::new(i, num_output_partitions, &metrics);
-
-                SpawnedTask::spawn(RepartitionExec::pull_from_input(
-                    Arc::clone(&input),
-                    i,
-                    txs.clone(),
-                    partitioning.clone(),
-                    r_metrics,
-                    Arc::clone(&context),
-                ))
-            };
+            let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics);
+
+            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
+                Arc::clone(&input),
+                i,
+                txs.clone(),
+                partitioning.clone(),
+                r_metrics,
+                Arc::clone(&context),
+            ));
 
             // In a separate task, wait for each input to be done
             // (and pass along any errors, including panic!s)
@@ -268,7 +194,6 @@ impl RepartitionExecState {
 
             spawned_tasks.push(wait_for_task);
         }
-
        Self {
            channels,
            abort_helper: Arc::new(spawned_tasks),
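The `abort_helper` field keeps the `SpawnedTask` handles of the per-partition input tasks alive; dropping the `Arc<Vec<SpawnedTask<()>>>` is what cancels the background work. Below is a rough standalone sketch of that abort-on-drop idea, assuming (as I understand DataFusion's `SpawnedTask` to behave) that the wrapped tokio task is aborted when its handle is dropped; the wrapper type here is illustrative, not DataFusion's:

// Standalone sketch: a handle whose Drop impl aborts the underlying tokio
// task, which is the property that lets `abort_helper` act as a kill switch
// for the background repartition tasks.
use std::time::Duration;
use tokio::task::JoinHandle;

struct AbortOnDrop<T>(JoinHandle<T>);

impl<T> Drop for AbortOnDrop<T> {
    fn drop(&mut self) {
        // Cancel the background work as soon as the handle goes away.
        self.0.abort();
    }
}

#[tokio::main]
async fn main() {
    let handle = AbortOnDrop(tokio::spawn(async {
        loop {
            println!("pulling from input ...");
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    }));

    // The loop runs while the handle is held ...
    tokio::time::sleep(Duration::from_millis(200)).await;
    // ... and is aborted here, when `handle` is dropped.
    drop(handle);
}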
@@ -467,8 +392,6 @@ pub struct RepartitionExecBase {
     preserve_order: bool,
     /// Cache holding plan properties like equivalences, output partitioning etc.
     cache: PlanProperties,
-    /// Inner state that is initialized when the first output stream is created.
-    state: LazyState,
 }
 
 impl RepartitionExecBase {
@@ -611,6 +534,8 @@ impl RepartitionExecBase {
 pub struct RepartitionExec {
     /// Common fields for all repartitioning executors
     base: RepartitionExecBase,
+    /// Inner state that is initialized when the first output stream is created.
+    state: LazyState,
 }
 
 #[derive(Debug, Clone)]
@@ -776,7 +701,7 @@ impl ExecutionPlan for RepartitionExec {
            partition
        );
 
-        let lazy_state = Arc::clone(&self.base.state);
+        let lazy_state = Arc::clone(&self.state);
         let input = Arc::clone(&self.base.input);
         let partitioning = self.partitioning().clone();
         let metrics = self.base.metrics.clone();
@@ -797,7 +722,7 @@ impl ExecutionPlan for RepartitionExec {
         let context_captured = Arc::clone(&context);
         let state = lazy_state
             .get_or_init(|| async move {
-                Mutex::new(RepartitionExecStateBuilder::new().build(
+                Mutex::new(RepartitionExecState::new(
                     input_captured,
                     partitioning.clone(),
                     metrics_captured,
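`execute` still initializes the shared state lazily through `lazy_state.get_or_init(...)`; only the field has moved from `RepartitionExecBase` to `RepartitionExec`, and the removed `RepartitionExecStateBuilder` is replaced by a direct `RepartitionExecState::new` call. A minimal sketch of that first-caller-wins initialization pattern follows, assuming `LazyState` is an `Arc` around a `tokio::sync::OnceCell` (which is what the `get_or_init(|| async move { ... })` call shape suggests); the `State` type here is a stand-in, not the DataFusion struct:

// Sketch: lazy, first-use initialization of shared state via
// tokio::sync::OnceCell. Every output stream clones the Arc, but the async
// init closure runs at most once.
use std::sync::Arc;
use tokio::sync::{Mutex, OnceCell};

#[derive(Debug, Default)]
struct State {
    streams_seen: usize,
}

type LazyState = Arc<OnceCell<Mutex<State>>>;

#[tokio::main]
async fn main() {
    let lazy_state: LazyState = Arc::new(OnceCell::new());

    for partition in 0..3 {
        let lazy = Arc::clone(&lazy_state);
        // Only the first `get_or_init` call builds the state; later callers
        // just get a reference to the already-initialized value.
        let state = lazy
            .get_or_init(|| async move {
                println!("initializing shared state (first stream wins)");
                Mutex::new(State::default())
            })
            .await;
        let mut guard = state.lock().await;
        guard.streams_seen += 1;
        println!("partition {partition} sees shared state: {:?}", *guard);
    }
}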
@@ -945,11 +870,11 @@ impl RepartitionExec {
         Ok(RepartitionExec {
             base: RepartitionExecBase {
                 input,
-                state: Default::default(),
                 metrics: ExecutionPlanMetricsSet::new(),
                 preserve_order,
                 cache,
             },
+            state: Default::default(),
         })
     }
 
