Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 8 additions & 9 deletions datafusion/core/src/physical_plan/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,19 @@ impl ExecutionPlan for AnalyzeExec {
)));
}

let (tx, rx) = tokio::sync::mpsc::channel(input_partitions);
let mut builder =
RecordBatchReceiverStream::builder(self.schema(), input_partitions);
let tx = builder.tx();

let captured_input = self.input.clone();
let mut input_stream = captured_input.execute(0, context).await?;
let captured_schema = self.schema.clone();
let verbose = self.verbose;

// Task reads batches the input and when complete produce a
// RecordBatch with a report that is written to `tx` when done
let join_handle = tokio::task::spawn(async move {
// Task reads batches from the input and when complete produces
// a RecordBatch with a report that is written to `tx` when
// done. Panics from this task are propagated via the builder.
builder.spawn(async move {
let start = Instant::now();
let mut total_rows = 0;

Expand Down Expand Up @@ -201,11 +204,7 @@ impl ExecutionPlan for AnalyzeExec {
tx.send(maybe_batch).await.ok();
});

Ok(RecordBatchReceiverStream::create(
&self.schema,
rx,
join_handle,
))
Ok(builder.build())
}

fn fmt_as(
Expand Down
81 changes: 26 additions & 55 deletions datafusion/core/src/physical_plan/coalesce_partitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,20 @@

use std::any::Any;
use std::sync::Arc;
use std::task::Poll;

use futures::channel::mpsc;
use futures::Stream;

use async_trait::async_trait;

use arrow::record_batch::RecordBatch;
use arrow::{datatypes::SchemaRef, error::Result as ArrowResult};
use arrow::datatypes::SchemaRef;

use super::common::AbortOnDropMany;
use super::expressions::PhysicalSortExpr;
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::{RecordBatchStream, Statistics};
use super::stream::{ObservedStream, RecordBatchReceiverStream};
use super::Statistics;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning};

use super::SendableRecordBatchStream;
use crate::execution::context::TaskContext;
use crate::physical_plan::common::spawn_execution;
use pin_project_lite::pin_project;

/// Merge execution plan executes partitions in parallel and combines them into a single
/// partition. No guarantees are made about the order of the resulting partition.
Expand Down Expand Up @@ -134,27 +127,17 @@ impl ExecutionPlan for CoalescePartitionsExec {
// use a stream that allows each sender to put in at
// least one result in an attempt to maximize
// parallelism.
let (sender, receiver) =
mpsc::channel::<ArrowResult<RecordBatch>>(input_partitions);
let mut builder =
RecordBatchReceiverStream::builder(self.schema(), input_partitions);

// spawn independent tasks whose resulting streams (of batches)
// are sent to the channel for consumption.
let mut join_handles = Vec::with_capacity(input_partitions);
for part_i in 0..input_partitions {
join_handles.push(spawn_execution(
self.input.clone(),
sender.clone(),
part_i,
context.clone(),
));
builder.run_input(self.input.clone(), part_i, context.clone());
}

Ok(Box::pin(MergeStream {
input: receiver,
schema: self.schema(),
baseline_metrics,
drop_helper: AbortOnDropMany(join_handles),
}))
let stream = builder.build();
Ok(Box::pin(ObservedStream::new(stream, baseline_metrics)))
}
}
}
Expand All @@ -180,35 +163,6 @@ impl ExecutionPlan for CoalescePartitionsExec {
}
}

// Stream that forwards batches received from the per-partition tasks
// spawned in `CoalescePartitionsExec::execute`, recording each poll in the
// operator's baseline metrics.
pin_project! {
struct MergeStream {
// Schema handed back by `RecordBatchStream::schema`.
schema: SchemaRef,
// Receiving end of the channel the spawned partition tasks send into;
// marked `#[pin]` so `poll_next` can poll it through the projection.
#[pin]
input: mpsc::Receiver<ArrowResult<RecordBatch>>,
// Every poll result is passed through `record_poll` on this value.
baseline_metrics: BaselineMetrics,
// Wraps the join handles of the spawned tasks.
// NOTE(review): presumably aborts those tasks when the stream is
// dropped (inferred from the type name) — confirm in `common`.
drop_helper: AbortOnDropMany<()>,
}
}

impl Stream for MergeStream {
    type Item = ArrowResult<RecordBatch>;

    /// Polls the underlying channel receiver for the next batch and hands
    /// the poll result to the baseline metrics' `record_poll`, whose return
    /// value is what the caller sees.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        let next = projected.input.poll_next(cx);
        projected.baseline_metrics.record_poll(next)
    }
}

impl RecordBatchStream for MergeStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}

#[cfg(test)]
mod tests {

Expand All @@ -220,7 +174,9 @@ mod tests {
use crate::physical_plan::file_format::{CsvExec, FileScanConfig};
use crate::physical_plan::{collect, common};
use crate::prelude::SessionContext;
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
use crate::test::exec::{
assert_strong_count_converges_to_zero, BlockingExec, PanicExec,
};
use crate::test::{self, assert_is_pending};
use crate::test_util;

Expand Down Expand Up @@ -288,4 +244,19 @@ mod tests {

Ok(())
}

/// Collecting the output of a `CoalescePartitionsExec` whose input is a
/// `PanicExec` must panic with the message "PanickingStream did panic".
#[tokio::test]
#[should_panic(expected = "PanickingStream did panic")]
async fn test_panic() {
    let session_ctx = SessionContext::new();
    let task_ctx = session_ctx.task_ctx();

    // Single nullable Float32 column named "a".
    let field = Field::new("a", DataType::Float32, true);
    let schema = Arc::new(Schema::new(vec![field]));

    // NOTE(review): the `2` is presumably the number of partitions of the
    // panicking input — confirm against `PanicExec::new`.
    let input = Arc::new(PanicExec::new(Arc::clone(&schema), 2));
    let plan = Arc::new(CoalescePartitionsExec::new(input));

    collect(plan, task_ctx).await.unwrap();
}
}
112 changes: 30 additions & 82 deletions datafusion/core/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ use std::task::{Context, Poll};
use std::vec;

use ahash::RandomState;
use futures::{
stream::{Stream, StreamExt},
Future,
};
use futures::stream::{Stream, StreamExt};

use crate::error::Result;
use crate::physical_plan::hash_utils::create_hashes;
use crate::physical_plan::stream::RecordBatchReceiverStream;
use crate::physical_plan::{
Accumulator, AggregateExpr, DisplayFormatType, Distribution, ExecutionPlan,
Partitioning, PhysicalExpr,
Expand All @@ -39,19 +37,17 @@ use crate::scalar::ScalarValue;
use arrow::{array::ArrayRef, compute, compute::cast};
use arrow::{
array::{Array, UInt32Builder},
error::{ArrowError, Result as ArrowResult},
error::Result as ArrowResult,
};
use arrow::{
datatypes::{Field, Schema, SchemaRef},
record_batch::RecordBatch,
};
use hashbrown::raw::RawTable;
use pin_project_lite::pin_project;

use crate::execution::context::TaskContext;
use async_trait::async_trait;

use super::common::AbortOnDropSingle;
use super::expressions::PhysicalSortExpr;
use super::metrics::{
self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
Expand Down Expand Up @@ -356,14 +352,9 @@ Example: average
* Once all N record batches arrive, `merge` is performed, which builds a RecordBatch with N rows and 2 columns.
* Finally, `get_value` returns an array with one entry computed from the state
*/
pin_project! {
struct GroupedHashAggregateStream {
schema: SchemaRef,
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
finished: bool,
drop_helper: AbortOnDropSingle<()>,
}
// Stream for the grouped hash aggregation: a thin wrapper around the
// receiver stream built in `new`, to which `poll_next` delegates.
struct GroupedHashAggregateStream {
// Output schema, stored from the `schema` value used in `new`.
schema: SchemaRef,
// Stream returned by `RecordBatchReceiverStream::builder(..).build()`;
// the task spawned on that builder sends the aggregate result into it.
stream: SendableRecordBatchStream,
}

fn group_aggregate_batch(
Expand Down Expand Up @@ -570,12 +561,16 @@ impl GroupedHashAggregateStream {
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
) -> Self {
let (tx, rx) = futures::channel::oneshot::channel();
// Use the panic-propagating builder so that panics in the
// compute task are re-raised on the consumer side instead of
// being reported as a closed channel.
let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1);
let tx = builder.tx();

let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();

let join_handle = tokio::spawn(async move {
builder.spawn(async move {
let result = compute_grouped_hash_aggregate(
mode,
schema_clone,
Expand All @@ -588,14 +583,12 @@ impl GroupedHashAggregateStream {
.record_output(&baseline_metrics);

// failing here is OK, the receiver is gone and does not care about the result
tx.send(result).ok();
tx.send(result).await.ok();
});

Self {
schema,
output: rx,
finished: false,
drop_helper: AbortOnDropSingle::new(join_handle),
stream: builder.build(),
}
}
}
Expand Down Expand Up @@ -647,31 +640,10 @@ impl Stream for GroupedHashAggregateStream {
type Item = ArrowResult<RecordBatch>;

fn poll_next(
self: std::pin::Pin<&mut Self>,
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.finished {
return Poll::Ready(None);
}

// is the output ready?
let this = self.project();
let output_poll = this.output.poll(cx);

match output_poll {
Poll::Ready(result) => {
*this.finished = true;

// check for error in receiving channel and unwrap actual result
let result = match result {
Err(e) => Err(ArrowError::ExternalError(Box::new(e))), // error receiving
Ok(result) => result,
};

Poll::Ready(Some(result))
}
Poll::Pending => Poll::Pending,
}
self.stream.poll_next_unpin(cx)
}
}

Expand Down Expand Up @@ -748,15 +720,10 @@ fn aggregate_expressions(
}
}

pin_project! {
/// stream struct for hash aggregation
pub struct HashAggregateStream {
schema: SchemaRef,
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
finished: bool,
drop_helper: AbortOnDropSingle<()>,
}
/// Stream struct for hash aggregation.
///
/// Wraps the receiver stream built in `new`; `poll_next` delegates to it.
pub struct HashAggregateStream {
// Output schema, stored from the `schema` value used in `new`.
schema: SchemaRef,
// Stream returned by `RecordBatchReceiverStream::builder(..).build()`;
// the task spawned on that builder sends the aggregate result into it.
stream: SendableRecordBatchStream,
}

/// Special case aggregate with no groups
Expand Down Expand Up @@ -799,11 +766,15 @@ impl HashAggregateStream {
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
) -> Self {
let (tx, rx) = futures::channel::oneshot::channel();
// Use the panic-propagating builder so that panics in the
// compute task are re-raised on the consumer side instead of
// being reported as a closed channel.
let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1);
let tx = builder.tx();

let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
let join_handle = tokio::spawn(async move {
builder.spawn(async move {
let result = compute_hash_aggregate(
mode,
schema_clone,
Expand All @@ -815,14 +786,12 @@ impl HashAggregateStream {
.record_output(&baseline_metrics);

// failing here is OK, the receiver is gone and does not care about the result
tx.send(result).ok();
tx.send(result).await.ok();
});

Self {
schema,
output: rx,
finished: false,
drop_helper: AbortOnDropSingle::new(join_handle),
stream: builder.build(),
}
}
}
Expand Down Expand Up @@ -863,31 +832,10 @@ impl Stream for HashAggregateStream {
type Item = ArrowResult<RecordBatch>;

fn poll_next(
self: std::pin::Pin<&mut Self>,
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.finished {
return Poll::Ready(None);
}

// is the output ready?
let this = self.project();
let output_poll = this.output.poll(cx);

match output_poll {
Poll::Ready(result) => {
*this.finished = true;

// check for error in receiving channel and unwrap actual result
let result = match result {
Err(e) => Err(ArrowError::ExternalError(Box::new(e))), // error receiving
Ok(result) => result,
};

Poll::Ready(Some(result))
}
Poll::Pending => Poll::Pending,
}
self.stream.poll_next_unpin(cx)
}
}

Expand Down
Loading
Loading