fix: Make the Parquet Sink properly phase aware (#21499)
coastalwhite authored Feb 27, 2025
1 parent 823bfc9 commit c1e6be9
Showing 6 changed files with 286 additions and 179 deletions.
15 changes: 4 additions & 11 deletions crates/polars-stream/src/nodes/io_sinks/csv.rs
@@ -66,9 +66,7 @@ impl SinkNode for CsvSinkNode {
let mut allocation_size = DEFAULT_ALLOCATION_SIZE;
let options = options.clone();

while let Ok((phase_consume_token, outcome, mut receiver, mut sender)) =
rx_receiver.recv().await
{
while let Ok((outcome, mut receiver, mut sender)) = rx_receiver.recv().await {
while let Ok(morsel) = receiver.recv().await {
let (df, seq, _, consume_token) = morsel.into_inner();

@@ -97,8 +95,7 @@ impl SinkNode for CsvSinkNode {
// backpressure.
}

outcome.stop();
drop(phase_consume_token);
outcome.stopped();
}

PolarsResult::Ok(())
@@ -136,15 +133,11 @@ impl SinkNode for CsvSinkNode {
file = tokio::fs::File::from_std(std_file);
}

while let Ok((phase_consume_token, outcome, mut linearizer)) =
recv_linearizer.recv().await
{
while let Ok((outcome, mut linearizer)) = recv_linearizer.recv().await {
while let Some(Priority(_, buffer)) = linearizer.get().await {
file.write_all(&buffer).await?;
}

outcome.stop();
drop(phase_consume_token);
outcome.stopped();
}

PolarsResult::Ok(())
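
The shape this hunk (and the NDJSON sink further down) moves to — one phase handle received together with the port, and a single `stopped()` call at the end of the phase instead of `outcome.stop()` plus dropping a separate wait token — can be modeled with plain standard-library types. A minimal, self-contained sketch; the `PhaseOutcome`/`PhaseToken` below are hypothetical stand-ins, not the polars-stream implementations:

// Hypothetical mock of the phase handshake, not the real polars-stream types.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{mpsc, Arc};
use std::thread;

/// Worker side: call `stopped()` when the phase ended but the sink can take more
/// phases; dropping it without calling `stopped()` means the sink is done for good.
struct PhaseOutcome {
    stopped_flag: Arc<AtomicBool>,
    _wait: mpsc::Sender<()>, // dropping this is what releases the coordinator
}

impl PhaseOutcome {
    fn stopped(self) {
        self.stopped_flag.store(true, Ordering::SeqCst);
    }
}

/// Coordinator side: inspected after every worker has released its wait handle.
struct PhaseToken {
    stopped_flag: Arc<AtomicBool>,
}

impl PhaseToken {
    fn did_finish(&self) -> bool {
        // "Finished" = the worker dropped its outcome without calling `stopped()`.
        !self.stopped_flag.load(Ordering::SeqCst)
    }
}

fn new_phase(wait: mpsc::Sender<()>) -> (PhaseToken, PhaseOutcome) {
    let flag = Arc::new(AtomicBool::new(false));
    (
        PhaseToken { stopped_flag: flag.clone() },
        PhaseOutcome { stopped_flag: flag, _wait: wait },
    )
}

fn main() {
    for phase in 0..3u32 {
        let (wait_tx, wait_rx) = mpsc::channel::<()>();
        let (token, outcome) = new_phase(wait_tx);

        thread::spawn(move || {
            // ... drain this phase's morsels here ...
            if phase < 2 {
                outcome.stopped(); // phase over, ready for the next one
            } // else: drop without `stopped()` => sink finished
        });

        // Mimic WaitGroup::wait(): block until every wait handle is dropped.
        while wait_rx.recv().is_ok() {}

        if token.did_finish() {
            println!("sink finished during phase {phase}");
            break;
        }
        println!("phase {phase} stopped cleanly; starting the next phase");
    }
}
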
7 changes: 3 additions & 4 deletions crates/polars-stream/src/nodes/io_sinks/ipc.rs
@@ -108,9 +108,8 @@ impl SinkNode for IpcSinkNode {
.filter_map(|(i, f)| f.contains_dictionary().then_some(i))
.collect::<Vec<_>>();

while let Ok(input) = recv_ports_recv.recv().await {
let mut receiver = input.port.serial();

while let Ok((outcome, port)) = recv_ports_recv.recv().await {
let mut receiver = port.serial();
while let Ok(morsel) = receiver.recv().await {
let df = morsel.into_df();
// @NOTE: This also performs schema validation.
@@ -152,7 +151,7 @@
}
}

input.outcome.stop();
outcome.stopped();
}

// Flush the remaining rows.
12 changes: 4 additions & 8 deletions crates/polars-stream/src/nodes/io_sinks/json.rs
@@ -52,9 +52,7 @@ impl SinkNode for NDJsonSinkNode {
// allocations, we adjust to that over time.
let mut allocation_size = DEFAULT_ALLOCATION_SIZE;

while let Ok((phase_consume_token, outcome, mut rx, mut sender)) =
rx_receiver.recv().await
{
while let Ok((outcome, mut rx, mut sender)) = rx_receiver.recv().await {
while let Ok(morsel) = rx.recv().await {
let (df, seq, _, consume_token) = morsel.into_inner();

@@ -69,8 +67,7 @@ impl SinkNode for NDJsonSinkNode {
// backpressure.
}

outcome.stop();
drop(phase_consume_token);
outcome.stopped();
}

PolarsResult::Ok(())
@@ -94,12 +91,11 @@ impl SinkNode for NDJsonSinkNode {
.await
.map_err(|err| polars_utils::_limit_path_len_io_err(path.as_path(), err))?;

while let Ok((consume_token, outcome, mut linearizer)) = rx_linearizer.recv().await {
while let Ok((outcome, mut linearizer)) = rx_linearizer.recv().await {
while let Some(Priority(_, buffer)) = linearizer.get().await {
file.write_all(&buffer).await?;
}
outcome.stop();
drop(consume_token);
outcome.stopped();
}

PolarsResult::Ok(())
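
Both text sinks also pull a `consume_token` out of every morsel and only let it go once the data has been handed off, which is what keeps backpressure on the producer (the deliberately tiny buffer sizes in mod.rs below serve the same goal). A rough, self-contained illustration of that token mechanism with std channels; `Morsel` and its fields here are stand-ins, not the real types:

// Hypothetical sketch of token-based backpressure, not the real morsel plumbing.
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

// Stand-in morsel: a payload plus a token the consumer drops once it is done with it.
struct Morsel {
    df: String,
    consume_token: mpsc::Sender<()>,
}

fn main() {
    let (morsel_tx, morsel_rx) = mpsc::sync_channel::<Morsel>(1);

    // Consumer (the "sink"): pretend to serialize, then drop the token to signal
    // that the producer may emit the next morsel.
    let sink = thread::spawn(move || {
        while let Ok(morsel) = morsel_rx.recv() {
            thread::sleep(Duration::from_millis(50)); // simulate a slow write
            println!("wrote {}", morsel.df);
            drop(morsel.consume_token); // <- this is the backpressure release
        }
    });

    // Producer (the "source"): only one morsel in flight at a time.
    for i in 0..3 {
        let (token_tx, token_rx) = mpsc::channel::<()>();
        morsel_tx
            .send(Morsel { df: format!("df-{i}"), consume_token: token_tx })
            .unwrap();
        // Block until the sink dropped the token (recv() errors on disconnect).
        let _ = token_rx.recv();
    }
    drop(morsel_tx);
    sink.join().unwrap();
}
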
167 changes: 117 additions & 50 deletions crates/polars-stream/src/nodes/io_sinks/mod.rs
@@ -4,13 +4,16 @@ use polars_core::config;
use polars_error::PolarsResult;
use polars_expr::state::ExecutionState;

use super::io_sources::PhaseOutcomeToken;
use super::{ComputeNode, JoinHandle, Morsel, PortState, RecvPort, SendPort, TaskScope};
use super::{
ComputeNode, JoinHandle, Morsel, PhaseOutcome, PortState, RecvPort, SendPort, TaskScope,
};
use crate::async_executor::{spawn, AbortOnDropHandle};
use crate::async_primitives::connector::{connector, Receiver, Sender};
use crate::async_primitives::distributor_channel;
use crate::async_primitives::linearizer::{Inserter, Linearizer};
use crate::async_primitives::wait_group::{WaitGroup, WaitToken};
use crate::async_primitives::wait_group::WaitGroup;
use crate::nodes::TaskPriority;
use crate::DEFAULT_LINEARIZER_BUFFER_SIZE;

#[cfg(feature = "csv")]
pub mod csv;
@@ -23,24 +26,16 @@ pub mod parquet;

// This needs to be low to increase the backpressure.
const DEFAULT_SINK_LINEARIZER_BUFFER_SIZE: usize = 1;
const DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE: usize = 1;

pub enum SinkInputPort {
Serial(Receiver<Morsel>),
Parallel(Vec<Receiver<Morsel>>),
}

pub struct SinkInput {
pub outcome: PhaseOutcomeToken,
pub port: SinkInputPort,

#[allow(unused)]
/// Dropping this indicates that the phase is done.
wait_token: WaitToken,
}

pub struct SinkRecvPort {
num_pipelines: usize,
recv: Receiver<SinkInput>,
recv: Receiver<(PhaseOutcome, SinkInputPort)>,
}

impl SinkInputPort {
@@ -69,79 +64,150 @@ impl SinkRecvPort {
mut self,
) -> (
JoinHandle<PolarsResult<()>>,
Vec<Receiver<(WaitToken, PhaseOutcomeToken, Receiver<Morsel>, Inserter<T>)>>,
Receiver<(WaitToken, PhaseOutcomeToken, Linearizer<T>)>,
Vec<Receiver<(PhaseOutcome, Receiver<Morsel>, Inserter<T>)>>,
Receiver<(PhaseOutcome, Linearizer<T>)>,
) {
let (mut rx_senders, rx_receivers) = (0..self.num_pipelines)
.map(|_| connector())
.collect::<(Vec<_>, Vec<_>)>();
let (mut tx_linearizer, rx_linearizer) = connector();
let handle = spawn(TaskPriority::High, async move {
let mut outcomes = Vec::with_capacity(self.num_pipelines + 1);
let wg = WaitGroup::default();

while let Ok(input) = self.recv.recv().await {
let inputs = input.port.parallel();
while let Ok((phase_outcome, port)) = self.recv.recv().await {
let inputs = port.parallel();

let (linearizer, senders) =
Linearizer::<T>::new(self.num_pipelines, DEFAULT_SINK_LINEARIZER_BUFFER_SIZE);

let mut outcomes = Vec::with_capacity(inputs.len());
for ((input, rx_sender), sender) in
inputs.into_iter().zip(rx_senders.iter_mut()).zip(senders)
{
let outcome = PhaseOutcomeToken::new();
if rx_sender
.send((wg.token(), outcome.clone(), input, sender))
.await
.is_err()
{
let (token, outcome) = PhaseOutcome::new_shared_wait(wg.token());
if rx_sender.send((outcome, input, sender)).await.is_err() {
return Ok(());
}
outcomes.push(outcome);
outcomes.push(token);
}
let outcome = PhaseOutcomeToken::new();
let (token, outcome) = PhaseOutcome::new_shared_wait(wg.token());
if tx_linearizer.send((outcome, linearizer)).await.is_err() {
return Ok(());
}
outcomes.push(token);

wg.wait().await;
for outcome in &outcomes {
if outcome.did_finish() {
return Ok(());
}
}

phase_outcome.stopped();
outcomes.clear();
}

Ok(())
});

(handle, rx_receivers, rx_linearizer)
}

/// Receive the [`RecvPort`] serially that distributes amongst workers then [`Linearize`] again
/// to the end.
///
/// This is useful for sinks that process incoming [`Morsel`]s column-wise as the processing
/// of the columns can be done in parallel.
#[allow(clippy::type_complexity)]
pub fn serial_into_distribute<D, L>(
mut self,
) -> (
JoinHandle<PolarsResult<()>>,
Receiver<(
PhaseOutcome,
Option<Receiver<Morsel>>,
distributor_channel::Sender<D>,
)>,
Vec<Receiver<(PhaseOutcome, distributor_channel::Receiver<D>, Inserter<L>)>>,
Receiver<(PhaseOutcome, Linearizer<L>)>,
)
where
D: Send + Sync + 'static,
L: Send + Sync + Ord + 'static,
{
let (mut tx_linearizer, rx_linearizer) = connector();
let (mut rx_senders, rx_receivers) = (0..self.num_pipelines)
.map(|_| connector())
.collect::<(Vec<_>, Vec<_>)>();
let (mut tx_end, rx_end) = connector();
let handle = spawn(TaskPriority::High, async move {
let mut outcomes = Vec::with_capacity(self.num_pipelines + 2);
let wg = WaitGroup::default();

let mut stop = false;
while !stop {
let input = self.recv.recv().await;
stop |= input.is_err(); // We want to send one last message without receiver when
// the channel is dropped. This allows us to flush buffers.
let (phase_outcome, receiver) = match input {
Ok((outcome, port)) => (Some(outcome), Some(port.serial())),
Err(()) => (None, None),
};

let (dist_tx, dist_rxs) = distributor_channel::distributor_channel::<D>(
self.num_pipelines,
DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE,
);
let (linearizer, senders) =
Linearizer::<L>::new(self.num_pipelines, DEFAULT_LINEARIZER_BUFFER_SIZE);

let (token, outcome) = PhaseOutcome::new_shared_wait(wg.token());
if tx_linearizer
.send((wg.token(), outcome.clone(), linearizer))
.send((outcome, receiver, dist_tx))
.await
.is_err()
{
return Ok(());
}
outcomes.push(outcome);
outcomes.push(token);
for ((dist_rx, rx_sender), sender) in
dist_rxs.into_iter().zip(rx_senders.iter_mut()).zip(senders)
{
let (token, outcome) = PhaseOutcome::new_shared_wait(wg.token());
if rx_sender.send((outcome, dist_rx, sender)).await.is_err() {
return Ok(());
}
outcomes.push(token);
}
let (token, outcome) = PhaseOutcome::new_shared_wait(wg.token());
if tx_end.send((outcome, linearizer)).await.is_err() {
return Ok(());
}
outcomes.push(token);

wg.wait().await;
for outcome in outcomes {
for outcome in &outcomes {
if outcome.did_finish() {
return Ok(());
}
}
input.outcome.stop();

if let Some(outcome) = phase_outcome {
outcome.stopped()
}
outcomes.clear();
}

Ok(())
});

(handle, rx_receivers, rx_linearizer)
(handle, rx_linearizer, rx_receivers, rx_end)
}
fn serial(self) -> Receiver<SinkInput> {
pub fn serial(self) -> Receiver<(PhaseOutcome, SinkInputPort)> {
self.recv
}
}

impl SinkInput {
pub fn from_port(port: SinkInputPort) -> (PhaseOutcomeToken, WaitGroup, Self) {
let outcome = PhaseOutcomeToken::new();
let wait_group = WaitGroup::default();

let input = Self {
outcome: outcome.clone(),
wait_token: wait_group.token(),
port,
};
(outcome, wait_group, input)
}
}

pub trait SinkNode {
fn name(&self) -> &str;
fn is_sink_input_parallel(&self) -> bool;
@@ -156,7 +222,7 @@ pub trait SinkNode {

/// The state needed to manage a spawned [`SinkNode`].
struct StartedSinkComputeNode {
input_send: Sender<SinkInput>,
input_send: Sender<(PhaseOutcome, SinkInputPort)>,
join_handles: FuturesUnordered<AbortOnDropHandle<PolarsResult<()>>>,
}

@@ -251,18 +317,19 @@ impl ComputeNode for SinkComputeNode {
}
});

let wait_group = WaitGroup::default();
let recv = recv_ports[0].take().unwrap();
let sink_input = if self.sink.is_sink_input_parallel() {
SinkInputPort::Parallel(recv.parallel())
} else {
SinkInputPort::Serial(recv.serial())
};
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
let (outcome, wait_group, sink_input) = SinkInput::from_port(sink_input);
if started.input_send.send(sink_input).await.is_ok() {
let (token, outcome) = PhaseOutcome::new_shared_wait(wait_group.token());
if started.input_send.send((outcome, sink_input)).await.is_ok() {
// Wait for the phase to finish.
wait_group.wait().await;
if !outcome.did_finish() {
if !token.did_finish() {
return Ok(());
}

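The new `serial_into_distribute` helper above wires three stages together: a serial receiver feeding a distributor, `num_pipelines` parallel workers, and a linearizer that restores morsel order for a single writer — the shape a column-wise sink such as the Parquet sink needs. A self-contained, synchronous approximation of that topology (hypothetical names; the real version is async, bounds its channels for backpressure, and additionally sends one last message with no receiver after the input closes so the sink can flush whatever it still has buffered):

// Hypothetical synchronous sketch of the distribute-then-linearize topology.
use std::sync::mpsc;
use std::thread;

fn main() {
    let num_pipelines: usize = 3;

    // Distributor: the serial stage fans chunks out to the workers round-robin.
    let (dist_txs, dist_rxs): (Vec<_>, Vec<_>) = (0..num_pipelines)
        .map(|_| mpsc::channel::<(usize, String)>())
        .unzip();

    // "Linearizer": workers send (sequence, payload); the writer restores order.
    let (lin_tx, lin_rx) = mpsc::channel::<(usize, String)>();

    // Parallel workers: encode their chunk and forward it with its sequence number.
    let workers: Vec<_> = dist_rxs
        .into_iter()
        .map(|rx| {
            let lin_tx = lin_tx.clone();
            thread::spawn(move || {
                for (seq, chunk) in rx {
                    let encoded = format!("encoded({chunk})"); // pretend column encoding
                    lin_tx.send((seq, encoded)).unwrap();
                }
            })
        })
        .collect();
    drop(lin_tx);

    // Serial stage: receive morsels in order and distribute them.
    for (seq, morsel) in ["a", "b", "c", "d", "e"].iter().enumerate() {
        dist_txs[seq % num_pipelines].send((seq, morsel.to_string())).unwrap();
    }
    drop(dist_txs);
    for w in workers {
        w.join().unwrap();
    }

    // Writer: order is restored here by sorting, purely for brevity.
    let mut out: Vec<_> = lin_rx.into_iter().collect();
    out.sort_by_key(|(seq, _)| *seq);
    for (seq, buf) in out {
        println!("write chunk {seq}: {buf}");
    }
}

The final sort is only a shortcut for this sketch; the Linearizer used by the sinks above streams buffers one at a time in priority order (`linearizer.get().await`), so the writer never has to hold everything in memory.
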
[Diffs for the remaining changed files did not load.]
