diff --git a/datafusion/physical-plan/src/repartition/on_demand_repartition.rs b/datafusion/physical-plan/src/repartition/on_demand_repartition.rs index 70a93a92d251..f6c72228ee5f 100644 --- a/datafusion/physical-plan/src/repartition/on_demand_repartition.rs +++ b/datafusion/physical-plan/src/repartition/on_demand_repartition.rs @@ -420,7 +420,7 @@ impl OnDemandRepartitionExec { async fn process_input( input: Arc, partition: usize, - buffer_tx: Sender, + buffer_tx: tokio::sync::mpsc::Sender, context: Arc, fetch_time: metrics::Time, send_buffer_time: metrics::Time, @@ -476,7 +476,7 @@ impl OnDemandRepartitionExec { context: Arc, ) -> Result<()> { // initialize buffer channel so that we can pre-fetch from input - let (buffer_tx, buffer_rx) = async_channel::bounded::(2); + let (buffer_tx, mut buffer_rx) = tokio::sync::mpsc::channel(2); // execute the child operator in a separate task // that pushes batches into buffer channel with limited capacity let processing_task = SpawnedTask::spawn(Self::process_input( @@ -491,12 +491,6 @@ impl OnDemandRepartitionExec { let mut batches_until_yield = partitioning.partition_count(); // When the input is done, break the loop while !output_channels.is_empty() { - // Fetch the batch from the buffer, ideally this should reduce the time gap between the requester and the input stream - let batch = match buffer_rx.recv().await { - Ok(batch) => batch, - _ => break, - }; - // Wait until a partition is requested, then get the output partition information let partition = output_partition_rx.recv().await.map_err(|e| { internal_datafusion_err!( @@ -505,6 +499,25 @@ impl OnDemandRepartitionExec { ) })?; + // Fetch the batch from the buffer, ideally this should reduce the time gap between the requester and the input stream + let batch_opt = loop { + match buffer_rx.try_recv() { + Ok(batch) => break Some(batch), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { + tokio::task::yield_now().await; + } + Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => { + break None + } + } + }; + + let batch = if let Some(batch) = batch_opt { + batch + } else { + break; + }; + let size = batch.get_array_memory_size(); let timer = metrics.send_time[partition].timer();