From 63dfab6713fc9b09b5b49875313b04dfec07b202 Mon Sep 17 00:00:00 2001
From: Andrew Lamb <andrew@nerdnetworks.org>
Date: Fri, 8 Sep 2023 16:59:32 -0400
Subject: [PATCH] refactor: port get_scan_files to Ballista

---
 ballista/scheduler/src/cluster/kv.rs     |  3 +--
 ballista/scheduler/src/cluster/memory.rs |  3 +--
 ballista/scheduler/src/cluster/mod.rs    | 30 ++++++++++++++++++++++++
 3 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/ballista/scheduler/src/cluster/kv.rs b/ballista/scheduler/src/cluster/kv.rs
index 53372f8df..25bb8e754 100644
--- a/ballista/scheduler/src/cluster/kv.rs
+++ b/ballista/scheduler/src/cluster/kv.rs
@@ -17,7 +17,7 @@
 
 use crate::cluster::storage::{KeyValueStore, Keyspace, Lock, Operation, WatchEvent};
 use crate::cluster::{
-    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin,
+    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin, get_scan_files,
     is_skip_consistent_hash, BoundTask, ClusterState, ExecutorHeartbeatStream,
     ExecutorSlot, JobState, JobStateEvent, JobStateEventStream, JobStatus,
     TaskDistributionPolicy, TopologyNode,
@@ -39,7 +39,6 @@ use ballista_core::serde::protobuf::{
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
 use ballista_core::serde::BallistaCodec;
 use dashmap::DashMap;
-use datafusion::datasource::physical_plan::get_scan_files;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionContext;
 use datafusion_proto::logical_plan::AsLogicalPlan;
diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs
index 03a9358d7..f2fe589a8 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::cluster::{
-    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin,
+    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin, get_scan_files,
     is_skip_consistent_hash, BoundTask, ClusterState, ExecutorSlot, JobState,
     JobStateEvent, JobStateEventStream, JobStatus, TaskDistributionPolicy, TopologyNode,
 };
@@ -42,7 +42,6 @@ use std::collections::{HashMap, HashSet};
 use std::ops::DerefMut;
 
 use ballista_core::consistent_hash::node::Node;
-use datafusion::datasource::physical_plan::get_scan_files;
 use datafusion::physical_plan::ExecutionPlan;
 use std::sync::Arc;
 use tokio::sync::{Mutex, MutexGuard};
diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs
index 12938aa12..793d3fc1f 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -21,7 +21,11 @@ use std::pin::Pin;
 use std::sync::Arc;
 
 use clap::ArgEnum;
+use datafusion::common::tree_node::TreeNode;
+use datafusion::common::tree_node::VisitRecursion;
 use datafusion::datasource::listing::PartitionedFile;
+use datafusion::datasource::physical_plan::{AvroExec, CsvExec, NdJsonExec, ParquetExec};
+use datafusion::error::DataFusionError;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionContext;
 use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -680,6 +684,32 @@ pub(crate) fn is_skip_consistent_hash(scan_files: &[Vec<Vec<PartitionedFile>>])
     scan_files.is_empty() || scan_files.len() > 1
 }
 
+/// Get all of the [`PartitionedFile`]s to be scanned for an [`ExecutionPlan`]
+pub(crate) fn get_scan_files(
+    plan: Arc<dyn ExecutionPlan>,
+) -> std::result::Result<Vec<Vec<Vec<PartitionedFile>>>, DataFusionError> {
+    let mut collector: Vec<Vec<Vec<PartitionedFile>>> = vec![];
+    plan.apply(&mut |plan| {
+        let plan_any = plan.as_any();
+        let file_groups =
+            if let Some(parquet_exec) = plan_any.downcast_ref::<ParquetExec>() {
+                parquet_exec.base_config().file_groups.clone()
+            } else if let Some(avro_exec) = plan_any.downcast_ref::<AvroExec>() {
+                avro_exec.base_config().file_groups.clone()
+            } else if let Some(json_exec) = plan_any.downcast_ref::<NdJsonExec>() {
+                json_exec.base_config().file_groups.clone()
+            } else if let Some(csv_exec) = plan_any.downcast_ref::<CsvExec>() {
+                csv_exec.base_config().file_groups.clone()
+            } else {
+                return Ok(VisitRecursion::Continue);
+            };
+
+        collector.push(file_groups);
+        Ok(VisitRecursion::Skip)
+    })?;
+    Ok(collector)
+}
+
 #[derive(Clone)]
 pub struct TopologyNode {
     pub id: String,