From 63dfab6713fc9b09b5b49875313b04dfec07b202 Mon Sep 17 00:00:00 2001
From: Andrew Lamb <andrew@nerdnetworks.org>
Date: Fri, 8 Sep 2023 16:59:32 -0400
Subject: [PATCH] refactor: port get_scan_files to Ballista

---
 ballista/scheduler/src/cluster/kv.rs     |  3 +--
 ballista/scheduler/src/cluster/memory.rs |  3 +--
 ballista/scheduler/src/cluster/mod.rs    | 30 ++++++++++++++++++++++++
 3 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/ballista/scheduler/src/cluster/kv.rs b/ballista/scheduler/src/cluster/kv.rs
index 53372f8df..25bb8e754 100644
--- a/ballista/scheduler/src/cluster/kv.rs
+++ b/ballista/scheduler/src/cluster/kv.rs
@@ -17,7 +17,7 @@
 
 use crate::cluster::storage::{KeyValueStore, Keyspace, Lock, Operation, WatchEvent};
 use crate::cluster::{
-    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin,
+    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin, get_scan_files,
     is_skip_consistent_hash, BoundTask, ClusterState, ExecutorHeartbeatStream, ExecutorSlot,
     JobState, JobStateEvent, JobStateEventStream, JobStatus, TaskDistributionPolicy,
     TopologyNode,
 };
@@ -39,7 +39,6 @@ use ballista_core::serde::protobuf::{
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
 use ballista_core::serde::BallistaCodec;
 use dashmap::DashMap;
-use datafusion::datasource::physical_plan::get_scan_files;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionContext;
 use datafusion_proto::logical_plan::AsLogicalPlan;
diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs
index 03a9358d7..f2fe589a8 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::cluster::{
-    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin,
+    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin, get_scan_files,
     is_skip_consistent_hash, BoundTask, ClusterState, ExecutorSlot, JobState, JobStateEvent,
     JobStateEventStream, JobStatus, TaskDistributionPolicy, TopologyNode,
 };
@@ -42,7 +42,6 @@ use std::collections::{HashMap, HashSet};
 use std::ops::DerefMut;
 
 use ballista_core::consistent_hash::node::Node;
-use datafusion::datasource::physical_plan::get_scan_files;
 use datafusion::physical_plan::ExecutionPlan;
 use std::sync::Arc;
 use tokio::sync::{Mutex, MutexGuard};
diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs
index 12938aa12..793d3fc1f 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -21,7 +21,11 @@ use std::pin::Pin;
 use std::sync::Arc;
 
 use clap::ArgEnum;
+use datafusion::common::tree_node::TreeNode;
+use datafusion::common::tree_node::VisitRecursion;
 use datafusion::datasource::listing::PartitionedFile;
+use datafusion::datasource::physical_plan::{AvroExec, CsvExec, NdJsonExec, ParquetExec};
+use datafusion::error::DataFusionError;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionContext;
 use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -680,6 +684,32 @@ pub(crate) fn is_skip_consistent_hash(scan_files: &[Vec<Vec<PartitionedFile>>])
     scan_files.is_empty() || scan_files.len() > 1
 }
 
+/// Get all of the [`PartitionedFile`] to be scanned for an [`ExecutionPlan`]
+pub(crate) fn get_scan_files(
+    plan: Arc<dyn ExecutionPlan>,
+) -> std::result::Result<Vec<Vec<Vec<PartitionedFile>>>, DataFusionError> {
+    let mut collector: Vec<Vec<Vec<PartitionedFile>>> = vec![];
+    plan.apply(&mut |plan| {
+        let plan_any = plan.as_any();
+        let file_groups =
+            if let Some(parquet_exec) = plan_any.downcast_ref::<ParquetExec>() {
+                parquet_exec.base_config().file_groups.clone()
+            } else if let Some(avro_exec) = plan_any.downcast_ref::<AvroExec>() {
+                avro_exec.base_config().file_groups.clone()
+            } else if let Some(json_exec) = plan_any.downcast_ref::<NdJsonExec>() {
+                json_exec.base_config().file_groups.clone()
+            } else if let Some(csv_exec) = plan_any.downcast_ref::<CsvExec>() {
+                csv_exec.base_config().file_groups.clone()
+            } else {
+                return Ok(VisitRecursion::Continue);
+            };
+
+        collector.push(file_groups);
+        Ok(VisitRecursion::Skip)
+    })?;
+    Ok(collector)
+}
+
 #[derive(Clone)]
 pub struct TopologyNode {
     pub id: String,