From 5278744a4e8dc2aefa4f12b9f5b9a5d3d316ff4b Mon Sep 17 00:00:00 2001
From: Andy Grove <andygrove73@gmail.com>
Date: Sat, 15 Oct 2022 11:52:09 -0600
Subject: [PATCH] benchmark: look for table path with and without extension

---
 benchmarks/src/bin/tpch.rs | 65 ++++++++++++++++++++++++++++++--------
 1 file changed, 52 insertions(+), 13 deletions(-)

diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index ece296fb5..68d350be2 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -368,7 +368,7 @@ async fn benchmark_ballista(opt: BallistaBenchmarkOpt) -> Result<()> {
     let path = opt.path.to_str().unwrap();
     let file_format = opt.file_format.as_str();
 
-    register_tables(path, file_format, &ctx).await;
+    register_tables(path, file_format, &ctx, opt.debug).await?;
 
     let mut millis = vec![];
 
@@ -474,7 +474,7 @@ async fn loadtest_ballista(opt: BallistaLoadtestOpt) -> Result<()> {
     let sql_path = opt.sql_path.to_str().unwrap().to_string();
 
     for ctx in &clients {
-        register_tables(path, file_format, ctx).await;
+        register_tables(path, file_format, ctx, opt.debug).await?;
     }
 
     let request_per_thread = request_amount.div(concurrency);
@@ -552,44 +552,83 @@ fn get_query_sql_by_path(query: usize, mut sql_path: String) -> Result<String> {
     }
 }
 
-async fn register_tables(path: &str, file_format: &str, ctx: &BallistaContext) {
+async fn register_tables(
+    path: &str,
+    file_format: &str,
+    ctx: &BallistaContext,
+    debug: bool,
+) -> Result<()> {
     for table in TABLES {
         match file_format {
             // dbgen creates .tbl ('|' delimited) files without header
             "tbl" => {
-                let path = format!("{}/{}.tbl", path, table);
+                let path = find_path(path, table, "tbl")?;
                 let schema = get_schema(table);
                 let options = CsvReadOptions::new()
                     .schema(&schema)
                     .delimiter(b'|')
                     .has_header(false)
                     .file_extension(".tbl");
+                if debug {
+                    println!(
+                        "Registering table '{}' using TBL files at path {}",
+                        table, path
+                    );
+                }
                 ctx.register_csv(table, &path, options)
                     .await
-                    .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))
-                    .unwrap();
+                    .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?;
             }
             "csv" => {
-                let path = format!("{}/{}", path, table);
+                let path = find_path(path, table, "csv")?;
                 let schema = get_schema(table);
                 let options = CsvReadOptions::new().schema(&schema).has_header(true);
+                if debug {
+                    println!(
+                        "Registering table '{}' using CSV files at path {}",
+                        table, path
+                    );
+                }
                 ctx.register_csv(table, &path, options)
                     .await
-                    .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))
-                    .unwrap();
+                    .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?;
             }
             "parquet" => {
-                let path = format!("{}/{}", path, table);
+                let path = find_path(path, table, "parquet")?;
+                if debug {
+                    println!(
+                        "Registering table '{}' using Parquet files at path {}",
+                        table, path
+                    );
+                }
                 ctx.register_parquet(table, &path, ParquetReadOptions::default())
                     .await
-                    .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))
-                    .unwrap();
+                    .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?;
             }
             other => {
-                unimplemented!("Invalid file format '{}'", other);
+                return Err(DataFusionError::Plan(format!(
+                    "Invalid file format '{}'",
+                    other
+                )))
             }
         }
     }
+    Ok(())
+}
+
+fn find_path(path: &str, table: &str, ext: &str) -> Result<String> {
+    let path1 = format!("{}/{}.{}", path, table, ext);
+    let path2 = format!("{}/{}", path, table);
+    if Path::new(&path1).exists() {
+        Ok(path1)
+    } else if Path::new(&path2).exists() {
+        Ok(path2)
+    } else {
+        Err(DataFusionError::Plan(format!(
+            "Could not find {} files at {} or {}",
+            ext, path1, path2
+        )))
+    }
 }
 
 /// Get the SQL statements from the specified query file