20
20
use ballista:: context:: BallistaContext ;
21
21
use ballista:: prelude:: {
22
22
BallistaConfig , BALLISTA_DEFAULT_BATCH_SIZE , BALLISTA_DEFAULT_SHUFFLE_PARTITIONS ,
23
+ BALLISTA_JOB_NAME ,
23
24
} ;
24
25
use datafusion:: datasource:: file_format:: csv:: DEFAULT_CSV_EXTENSION ;
25
26
use datafusion:: datasource:: file_format:: parquet:: DEFAULT_PARQUET_EXTENSION ;
26
27
use datafusion:: datasource:: listing:: ListingTableUrl ;
27
28
use datafusion:: datasource:: { MemTable , TableProvider } ;
28
29
use datafusion:: error:: { DataFusionError , Result } ;
30
+ use datafusion:: execution:: context:: SessionState ;
29
31
use datafusion:: logical_plan:: LogicalPlan ;
30
32
use datafusion:: parquet:: basic:: Compression ;
31
33
use datafusion:: parquet:: file:: properties:: WriterProperties ;
@@ -272,6 +274,7 @@ async fn main() -> Result<()> {
272
274
}
273
275
}
274
276
277
+ #[ allow( clippy:: await_holding_lock) ]
275
278
async fn benchmark_datafusion ( opt : DataFusionBenchmarkOpt ) -> Result < Vec < RecordBatch > > {
276
279
println ! ( "Running benchmarks with the following options: {:?}" , opt) ;
277
280
let mut benchmark_run = BenchmarkRun :: new ( opt. query ) ;
@@ -282,12 +285,17 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result<Vec<RecordB
282
285
283
286
// register tables
284
287
for table in TABLES {
285
- let table_provider = get_table (
286
- opt. path . to_str ( ) . unwrap ( ) ,
287
- table,
288
- opt. file_format . as_str ( ) ,
289
- opt. partitions ,
290
- ) ?;
288
+ let table_provider = {
289
+ let mut session_state = ctx. state . write ( ) ;
290
+ get_table (
291
+ & mut session_state,
292
+ opt. path . to_str ( ) . unwrap ( ) ,
293
+ table,
294
+ opt. file_format . as_str ( ) ,
295
+ opt. partitions ,
296
+ )
297
+ . await ?
298
+ } ;
291
299
if opt. mem_table {
292
300
println ! ( "Loading table '{}' into memory" , table) ;
293
301
let start = Instant :: now ( ) ;
@@ -343,6 +351,10 @@ async fn benchmark_ballista(opt: BallistaBenchmarkOpt) -> Result<()> {
343
351
BALLISTA_DEFAULT_SHUFFLE_PARTITIONS ,
344
352
& format ! ( "{}" , opt. partitions) ,
345
353
)
354
+ . set (
355
+ BALLISTA_JOB_NAME ,
356
+ & format ! ( "Query derived from TPC-H q{}" , opt. query) ,
357
+ )
346
358
. set ( BALLISTA_DEFAULT_BATCH_SIZE , & format ! ( "{}" , opt. batch_size) )
347
359
. build ( )
348
360
. map_err ( |e| DataFusionError :: Execution ( format ! ( "{:?}" , e) ) ) ?;
@@ -375,6 +387,10 @@ async fn benchmark_ballista(opt: BallistaBenchmarkOpt) -> Result<()> {
375
387
. await
376
388
. map_err ( |e| DataFusionError :: Plan ( format ! ( "{:?}" , e) ) )
377
389
. unwrap ( ) ;
390
+ let plan = df. to_logical_plan ( ) ?;
391
+ if opt. debug {
392
+ println ! ( "=== Optimized logical plan ===\n {:?}\n " , plan) ;
393
+ }
378
394
batches = df
379
395
. collect ( )
380
396
. await
@@ -718,7 +734,8 @@ async fn convert_tbl(opt: ConvertOpt) -> Result<()> {
718
734
Ok ( ( ) )
719
735
}
720
736
721
- fn get_table (
737
+ async fn get_table (
738
+ ctx : & mut SessionState ,
722
739
path : & str ,
723
740
table : & str ,
724
741
table_format : & str ,
@@ -765,9 +782,13 @@ fn get_table(
765
782
} ;
766
783
767
784
let url = ListingTableUrl :: parse ( path) ?;
768
- let config = ListingTableConfig :: new ( url)
769
- . with_listing_options ( options)
770
- . with_schema ( schema) ;
785
+ let config = ListingTableConfig :: new ( url) . with_listing_options ( options) ;
786
+
787
+ let config = if table_format == "parquet" {
788
+ config. infer_schema ( ctx) . await ?
789
+ } else {
790
+ config. with_schema ( schema)
791
+ } ;
771
792
772
793
Ok ( Arc :: new ( ListingTable :: try_new ( config) ?) )
773
794
}
0 commit comments