use crate::error::{DataFusionError, Result};
use crate::physical_plan::stream::RecordBatchStreamAdapter;
use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
use crate::scheduler::{
is_worker, plan::PipelinePlan, spawn_local, spawn_local_fifo, RoutablePipeline,
Spawner,
};
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use futures::channel::mpsc;
use futures::task::ArcWake;
use futures::{ready, Stream, StreamExt};
use log::{debug, trace};
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
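
/// Spawns a [`PipelinePlan`] using the provided [`Spawner`], returning an
/// [`ExecutionResults`] from which the query's output can be streamed
///
/// A minimal usage sketch, assuming a `plan` and `spawner` have already been
/// obtained from the scheduler:
///
/// ```ignore
/// let results = spawn_plan(plan, spawner);
/// let mut stream = results.stream();
/// while let Some(batch) = stream.next().await {
///     println!("got {} rows", batch?.num_rows());
/// }
/// ```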
pub fn spawn_plan(plan: PipelinePlan, spawner: Spawner) -> ExecutionResults {
debug!("Spawning pipeline plan: {:#?}", plan);
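    // Create a channel for each output partition; the senders are stored in
    // the shared context and the receivers become the query's output streams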
let (senders, receivers) = (0..plan.output_partitions)
.map(|_| mpsc::unbounded())
.unzip::<_, _, Vec<_>, Vec<_>>();
let context = Arc::new(ExecutionContext {
spawner,
pipelines: plan.pipelines,
schema: plan.schema,
output: senders,
});
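
    // Spawn an initial task for every output partition of every pipeline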
for (pipeline_idx, query_pipeline) in context.pipelines.iter().enumerate() {
for partition in 0..query_pipeline.pipeline.output_partitions() {
context.spawner.spawn(Task {
context: context.clone(),
waker: Arc::new(TaskWaker {
context: Arc::downgrade(&context),
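                    // Start the wake count at 1 to mark the task as scheduled,
                    // so wakeups arriving before the first poll are deduplicated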
wake_count: AtomicUsize::new(1),
pipeline: pipeline_idx,
partition,
}),
});
}
}
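
    // Each output stream holds an `Arc` to the context, keeping the query
    // alive for as long as any consumer of its output remains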
let partitions = receivers
.into_iter()
.map(|receiver| ExecutionResultStream {
            receiver,
context: context.clone(),
})
.collect();
ExecutionResults {
streams: partitions,
context,
}
}
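
/// A [`Task`] identifies an output partition of a pipeline that may be able
/// to make progress, and is what the scheduler's workers execute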
pub struct Task {
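    /// The [`ExecutionContext`] of the query, used to route this partition's
    /// output to its destination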
context: Arc<ExecutionContext>,
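
    /// An [`ArcWake`] used to re-schedule this task when the pipeline signals
    /// it may be able to make further progress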
waker: Arc<TaskWaker>,
}

impl std::fmt::Debug for Task {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let output = &self.context.pipelines[self.waker.pipeline].output;
f.debug_struct("Task")
.field("pipeline", &self.waker.pipeline)
.field("partition", &self.waker.partition)
.field("output", &output)
.finish()
}
}

impl Task {
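    /// Sends `error` to the query output for `partition` and, if this
    /// pipeline feeds a downstream pipeline, closes the corresponding input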
fn handle_error(
&self,
partition: usize,
routable: &RoutablePipeline,
error: DataFusionError,
) {
self.context.send_query_output(partition, Err(error));
        if let Some(link) = routable.output {
            trace!(
                "Closing pipeline: {:?}, partition: {}, due to error",
                link,
                partition,
            );
            self.context.pipelines[link.pipeline]
                .pipeline
                .close(link.child, partition);
        }
}
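
    /// Polls one partition of this task's pipeline, attempting to make
    /// progress, and routes any output before re-scheduling or finishing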
pub fn do_work(self) {
assert!(is_worker(), "Task::do_work called outside of worker pool");
if self.context.is_cancelled() {
return;
}
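
        // Capture the wake count before polling so that wakeups received
        // whilst polling can be detected by the compare_exchange below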
let wake_count = self.waker.wake_count.load(Ordering::SeqCst);
let node = self.waker.pipeline;
let partition = self.waker.partition;
let waker = futures::task::waker_ref(&self.waker);
let mut cx = Context::from_waker(&*waker);
let pipelines = &self.context.pipelines;
let routable = &pipelines[node];
match routable.pipeline.poll_partition(&mut cx, partition) {
Poll::Ready(Some(Ok(batch))) => {
trace!("Poll {:?}: Ok: {}", self, batch.num_rows());
match routable.output {
Some(link) => {
trace!(
"Publishing batch to pipeline {:?} partition {}",
link,
partition
);
let r = pipelines[link.pipeline]
.pipeline
.push(batch, link.child, partition);
if let Err(e) = r {
self.handle_error(partition, routable, e);
return;
}
}
None => {
trace!("Publishing batch to output");
self.context.send_query_output(partition, Ok(batch))
}
}
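
                // Re-schedule this task with FIFO ordering so work on the batch
                // just produced is prioritised. This must happen after the batch
                // has been routed, otherwise the re-scheduled task could run
                // before the output reaches its destination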
spawn_local_fifo(self);
}
Poll::Ready(Some(Err(e))) => {
trace!("Poll {:?}: Error: {:?}", self, e);
self.handle_error(partition, routable, e)
}
Poll::Ready(None) => {
trace!("Poll {:?}: None", self);
match routable.output {
Some(link) => {
trace!("Closing pipeline: {:?}, partition: {}", link, partition);
pipelines[link.pipeline]
.pipeline
.close(link.child, partition)
}
None => self.context.finish(partition),
}
}
Poll::Pending => {
trace!("Poll {:?}: Pending", self);
let reset = self.waker.wake_count.compare_exchange(
wake_count,
0,
Ordering::SeqCst,
Ordering::SeqCst,
);
if reset.is_err() {
trace!("Wakeup triggered whilst polling: {:?}", self);
spawn_local(self);
}
}
}
}
}
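
/// The output of a spawned query, comprising one stream per output partition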
pub struct ExecutionResults {
streams: Vec<ExecutionResultStream>,
context: Arc<ExecutionContext>,
}

impl ExecutionResults {
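    /// Returns a single [`SendableRecordBatchStream`] that merges the
    /// output partitions of this query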
pub fn stream(self) -> SendableRecordBatchStream {
let schema = self.context.schema.clone();
Box::pin(RecordBatchStreamAdapter::new(
schema,
futures::stream::select_all(self.streams),
))
}
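
    /// Returns a [`SendableRecordBatchStream`] for each output partition
    /// of this query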
pub fn stream_partitioned(self) -> Vec<SendableRecordBatchStream> {
self.streams.into_iter().map(|s| Box::pin(s) as _).collect()
}
}
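
/// A [`Stream`] of the output of a single query partition
///
/// Holds an [`Arc`] to the [`ExecutionContext`] so the query is kept alive,
/// and not reported as cancelled, whilst any output stream exists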
struct ExecutionResultStream {
receiver: mpsc::UnboundedReceiver<Option<Result<RecordBatch>>>,
context: Arc<ExecutionContext>,
}

impl Stream for ExecutionResultStream {
type Item = arrow::error::Result<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let opt = ready!(self.receiver.poll_next_unpin(cx)).flatten();
Poll::Ready(opt.map(|r| r.map_err(|e| ArrowError::ExternalError(Box::new(e)))))
}
}

impl RecordBatchStream for ExecutionResultStream {
fn schema(&self) -> SchemaRef {
self.context.schema.clone()
}
}
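
/// State shared by all [`Task`]s spawned from the same [`PipelinePlan`]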
#[derive(Debug)]
struct ExecutionContext {
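    /// The [`Spawner`] used to schedule tasks for this query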
spawner: Spawner,
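
    /// The pipelines of this query, addressed by their index in this list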
pipelines: Vec<RoutablePipeline>,
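
    /// The schema of this query's output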
pub schema: SchemaRef,
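
    /// One sender per output partition; a `None` message marks a partition
    /// as finished, and a closed sender means its receiver was dropped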
output: Vec<mpsc::UnboundedSender<Option<Result<RecordBatch>>>>,
}

impl Drop for ExecutionContext {
fn drop(&mut self) {
debug!("ExecutionContext dropped");
}
}

impl ExecutionContext {
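    /// Returns `true` when every output receiver has been dropped,
    /// indicating the query has been cancelled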
fn is_cancelled(&self) -> bool {
self.output.iter().all(|x| x.is_closed())
}
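
    /// Sends `output` to the query output stream for `partition`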
fn send_query_output(&self, partition: usize, output: Result<RecordBatch>) {
let _ = self.output[partition].unbounded_send(Some(output));
}
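
    /// Marks `partition` of the query output as finished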
fn finish(&self, partition: usize) {
let _ = self.output[partition].unbounded_send(None);
}
}
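
/// An [`ArcWake`] that re-schedules the associated [`Task`] when woken,
/// ignoring wakeups for tasks that are already queued or running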
struct TaskWaker {
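    /// A weak reference to the [`ExecutionContext`], avoiding a reference
    /// cycle if this waker is stored within one of the query's pipelines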
context: Weak<ExecutionContext>,
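
    /// The number of wakeups received; a value greater than zero implies the
    /// task is already scheduled for execution or currently executing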
wake_count: AtomicUsize,
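
    /// The index of the pipeline to poll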
pipeline: usize,
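
    /// The output partition of the pipeline to poll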
partition: usize,
}

impl ArcWake for TaskWaker {
fn wake(self: Arc<Self>) {
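        // Only the wakeup that increments the count from zero schedules the
        // task; any other previous value means it is already queued or running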
if self.wake_count.fetch_add(1, Ordering::SeqCst) != 0 {
trace!("Ignoring duplicate wakeup");
return;
}
if let Some(context) = self.context.upgrade() {
let task = Task {
context,
waker: self.clone(),
};
trace!("Wakeup {:?}", task);
            // Worker threads push onto their local queue; wakeups from other
            // threads must go through the query's spawner
            if is_worker() {
                spawn_local(task)
            } else {
                task.context.spawner.clone().spawn(task)
            }
} else {
trace!("Dropped wakeup");
}
}
fn wake_by_ref(s: &Arc<Self>) {
ArcWake::wake(s.clone())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::error::Result;
use crate::scheduler::{pipeline::Pipeline, plan::RoutablePipeline, Scheduler};
use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use futures::{channel::oneshot, ready, FutureExt, StreamExt};
use parking_lot::Mutex;
use std::fmt::Debug;
use std::time::Duration;
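
    /// A [`Pipeline`] that produces one [`RecordBatch`] asynchronously on a
    /// tokio runtime, exercising wakeups delivered from non-worker threads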
#[derive(Debug)]
struct TokioPipeline {
handle: tokio::runtime::Handle,
state: Mutex<State>,
}
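
    /// The polling state of [`TokioPipeline`]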
#[derive(Debug)]
enum State {
Init,
Wait(oneshot::Receiver<Result<RecordBatch>>),
Finished,
}

impl Default for State {
fn default() -> Self {
Self::Init
}
}

impl Pipeline for TokioPipeline {
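        // This pipeline is a source: it has no inputs, so `push` is never called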
fn push(
&self,
_input: RecordBatch,
_child: usize,
_partition: usize,
) -> Result<()> {
unreachable!()
}
fn close(&self, _child: usize, _partition: usize) {}
fn output_partitions(&self) -> usize {
1
}
fn poll_partition(
&self,
cx: &mut Context<'_>,
_partition: usize,
) -> Poll<Option<Result<RecordBatch>>> {
let mut state = self.state.lock();
loop {
match &mut *state {
State::Init => {
let (sender, receiver) = oneshot::channel();
self.handle.spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
let array = Int32Array::from_iter_values([1, 2, 3]);
sender.send(
RecordBatch::try_from_iter([(
"int",
Arc::new(array) as ArrayRef,
)])
.map_err(DataFusionError::ArrowError),
)
});
*state = State::Wait(receiver)
}
State::Wait(r) => {
let v = ready!(r.poll_unpin(cx)).ok();
*state = State::Finished;
return Poll::Ready(v);
}
State::Finished => return Poll::Ready(None),
}
}
}
}
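
    /// Verify that a wakeup delivered from a tokio runtime thread, rather than
    /// a scheduler worker, correctly re-schedules the task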
#[test]
fn test_tokio_waker() {
let scheduler = Scheduler::new(2);
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_time()
.build()
.unwrap();
let pipeline = TokioPipeline {
handle: runtime.handle().clone(),
state: Default::default(),
};
let plan = PipelinePlan {
schema: Arc::new(Schema::new(vec![Field::new(
"int",
DataType::Int32,
false,
)])),
output_partitions: 1,
pipelines: vec![RoutablePipeline {
pipeline: Box::new(pipeline),
output: None,
}],
};
let mut receiver = scheduler.schedule_plan(plan).stream();
runtime.block_on(async move {
let batch = receiver.next().await.unwrap().unwrap();
assert_eq!(batch.num_rows(), 3);
assert!(receiver.next().await.is_none());
})
}
}