use std::sync::Arc;

use super::demux::DemuxedStreamReceiver;
use super::{BatchSerializer, ObjectWriterBuilder};
use crate::file_compression_type::FileCompressionType;
use datafusion_common::error::Result;

use arrow::array::RecordBatch;
use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError};
use datafusion_common_runtime::{JoinSet, SpawnedTask};
use datafusion_execution::TaskContext;

use bytes::Bytes;
use futures::join;
use object_store::ObjectStore;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver};

type WriterType = Box<dyn AsyncWrite + Send + Unpin>;
type SerializerType = Arc<dyn BatchSerializer>;
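
/// Result of calling [`serialize_rb_stream_to_object_store`]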
pub(crate) enum SerializedRecordBatchResult {
Success {
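        /// The writer, handed back so the caller can finalize the upload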
writer: WriterType,
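        /// The number of rows successfully written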
row_count: usize,
},
Failure {
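        /// If the failure was an IO error on the ObjectStore writer itself,
        /// the writer is dropped and `None` is returned; otherwise the writer
        /// is handed back so the failed write can be aborted.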
writer: Option<WriterType>,
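        /// The error that caused the failure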
err: DataFusionError,
},
}
impl SerializedRecordBatchResult {
pub fn success(writer: WriterType, row_count: usize) -> Self {
Self::Success { writer, row_count }
}
pub fn failure(writer: Option<WriterType>, err: DataFusionError) -> Self {
Self::Failure { writer, err }
}
}
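
/// Serializes a single stream of [`RecordBatch`]es in parallel and writes the
/// resulting bytes to an ObjectStore concurrently. Data order is preserved.
///
/// On a non-IO error (one not involving the ObjectStore writer), the writer is
/// returned alongside the error so the failed write can be aborted. On an IO
/// error involving the writer, the writer is dropped to avoid calling further
/// methods on it which might panic.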
pub(crate) async fn serialize_rb_stream_to_object_store(
mut data_rx: Receiver<RecordBatch>,
serializer: Arc<dyn BatchSerializer>,
mut writer: WriterType,
) -> SerializedRecordBatchResult {
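    // Bounded channel of in-flight serialization tasks: batches are serialized
    // on separate tasks in parallel while this function writes the resulting
    // bytes to the object store in submission order.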
let (tx, mut rx) =
mpsc::channel::<SpawnedTask<Result<(usize, Bytes), DataFusionError>>>(100);
let serialize_task = SpawnedTask::spawn(async move {
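        // Some serializers (e.g. CSV with headers) treat the first batch
        // differently from subsequent ones, so track whether this is the first
        // batch handed to the serializer.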
let mut initial = true;
while let Some(batch) = data_rx.recv().await {
let serializer_clone = Arc::clone(&serializer);
let task = SpawnedTask::spawn(async move {
let num_rows = batch.num_rows();
let bytes = serializer_clone.serialize(batch, initial)?;
Ok((num_rows, bytes))
});
if initial {
initial = false;
}
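            // The receiver is only dropped once the write loop below has
            // already returned with an error, so this failure is rarely
            // surfaced to the caller.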
tx.send(task).await.map_err(|_| {
internal_datafusion_err!("Unknown error writing to object store")
})?;
}
Ok(())
});
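
    // Drain serialization results in order, writing each chunk of bytes to the
    // object store and accumulating the total row count.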
let mut row_count = 0;
while let Some(task) = rx.recv().await {
match task.join().await {
Ok(Ok((cnt, bytes))) => {
match writer.write_all(&bytes).await {
Ok(_) => (),
Err(e) => {
return SerializedRecordBatchResult::failure(
None,
DataFusionError::Execution(format!(
"Error writing to object store: {e}"
)),
)
}
};
row_count += cnt;
}
Ok(Err(e)) => {
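                // Serialization failed: return the writer along with the error
                // so the caller can abort the partially written file.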
return SerializedRecordBatchResult::failure(Some(writer), e);
}
Err(e) => {
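                // The serialization task panicked or was cancelled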
return SerializedRecordBatchResult::failure(
Some(writer),
DataFusionError::Execution(format!(
"Serialization task panicked or was cancelled: {e}"
)),
);
}
}
}
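
    // Make sure the task feeding serialization work also completed cleanly.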
match serialize_task.join().await {
Ok(Ok(_)) => (),
Ok(Err(e)) => return SerializedRecordBatchResult::failure(Some(writer), e),
Err(_) => {
return SerializedRecordBatchResult::failure(
Some(writer),
internal_datafusion_err!("Unknown error writing to object store"),
)
}
}
SerializedRecordBatchResult::success(writer, row_count)
}
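
/// Everything needed to write one output file: the stream of incoming
/// [`RecordBatch`]es, the serializer, and the destination writer.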
type FileWriteBundle = (Receiver<RecordBatch>, SerializerType, WriterType);
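/// Contains the common logic for serializing [`RecordBatch`]es and writing the
/// resulting bytes to an ObjectStore.
///
/// Serialization is assumed to be stateless, i.e. each [`RecordBatch`] can be
/// serialized without any dependency on the [`RecordBatch`]es before or after.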
pub(crate) async fn stateless_serialize_and_write_files(
mut rx: Receiver<FileWriteBundle>,
tx: tokio::sync::oneshot::Sender<u64>,
) -> Result<()> {
let mut row_count = 0;
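    // Error-tracking state: whether any writer failed, the error that triggered
    // the failure, and whether abort/cleanup itself failed (in which case
    // partially written data may remain in the object store).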
let mut any_errors = false;
let mut triggering_error = None;
let mut any_abort_errors = false;
let mut join_set = JoinSet::new();
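    // Spawn one serialize-and-write task per incoming file bundle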
while let Some((data_rx, serializer, writer)) = rx.recv().await {
join_set.spawn(async move {
serialize_rb_stream_to_object_store(data_rx, serializer, writer).await
});
}
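    // Collect writers back as their tasks complete so they can be finalized
    // (or their failures recorded) below.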
let mut finished_writers = Vec::new();
while let Some(result) = join_set.join_next().await {
match result {
Ok(res) => match res {
SerializedRecordBatchResult::Success {
writer,
row_count: cnt,
} => {
finished_writers.push(writer);
row_count += cnt;
}
SerializedRecordBatchResult::Failure { writer, err } => {
finished_writers.extend(writer);
any_errors = true;
triggering_error = Some(err);
}
},
Err(e) => {
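                // The task panicked or was cancelled, so ownership of its
                // writer never came back to this task and it cannot be cleaned
                // up (hence any_abort_errors).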
any_errors = true;
any_abort_errors = true;
triggering_error = Some(internal_datafusion_err!(
"Unexpected join error while serializing file {e}"
));
}
}
}
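
    // Finalize every writer that was handed back, completing its upload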
    for mut writer in finished_writers.into_iter() {
        writer.shutdown().await.map_err(|_| {
            internal_datafusion_err!(
                "Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"
            )
        })?;
    }
    if any_errors {
        if any_abort_errors {
            return internal_err!(
                "Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."
            );
        }
        return match triggering_error {
            Some(e) => Err(e),
            None => internal_err!(
                "Unknown Error encountered during writing to ObjectStore. All writers successfully aborted."
            ),
        };
    }
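
    // Report the total number of rows written back to the file sink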
tx.send(row_count as u64).map_err(|_| {
internal_datafusion_err!(
"Error encountered while sending row count back to file sink!"
)
})?;
Ok(())
}
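
/// Orchestrates a multipart put of a dynamic number of output files from a
/// single input stream, for any statelessly serialized file type, i.e. any
/// format for which each [`RecordBatch`] can be serialized independently of
/// all other [`RecordBatch`]es.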
pub async fn spawn_writer_tasks_and_join(
context: &Arc<TaskContext>,
serializer: Arc<dyn BatchSerializer>,
compression: FileCompressionType,
object_store: Arc<dyn ObjectStore>,
demux_task: SpawnedTask<Result<()>>,
mut file_stream_rx: DemuxedStreamReceiver,
) -> Result<u64> {
let rb_buffer_size = &context
.session_config()
.options()
.execution
.max_buffered_batches_per_output_file;
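
    // Size the file-bundle channel from the configured per-file batch buffer;
    // this bounds how many pending files can queue up for the write
    // coordinator.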
let (tx_file_bundle, rx_file_bundle) = mpsc::channel(rb_buffer_size / 2);
let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
let write_coordinator_task = SpawnedTask::spawn(async move {
stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await
});
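
    // For each demuxed output file, build an object store writer and hand the
    // (batch stream, serializer, writer) bundle to the write coordinator.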
while let Some((location, rb_stream)) = file_stream_rx.recv().await {
let writer =
ObjectWriterBuilder::new(compression, &location, Arc::clone(&object_store))
.with_buffer_size(Some(
context
.session_config()
.options()
.execution
.objectstore_writer_buffer_size,
))
.build()?;
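
        // If this send fails, the write coordinator task has already shut
        // down; keep draining the demuxed file streams and surface the
        // coordinator's error when the tasks are joined below.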
        let _ = tx_file_bundle
            .send((rb_stream, Arc::clone(&serializer), writer))
            .await;
}
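
    // Signal to the write coordinator that no more files are coming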
drop(tx_file_bundle);
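
    // Wait for both the demux task and the write coordinator, propagating
    // panics and errors from either.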
let (r1, r2) = join!(
write_coordinator_task.join_unwind(),
demux_task.join_unwind()
);
r1.map_err(DataFusionError::ExecutionJoin)??;
r2.map_err(DataFusionError::ExecutionJoin)??;
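
    // Return the total row count reported by the write coordinator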
rx_row_cnt.await.map_err(|_| {
internal_datafusion_err!("Did not receive row count from write coordinator")
})
}