Skip to content

Commit

Permalink
Make SessionContext members private (#4698)
Browse files Browse the repository at this point in the history
* Make SessionContext members private

* Fix benchmarks
  • Loading branch information
tustvold authored Dec 22, 2022
1 parent d7f4dd3 commit 4917235
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 21 deletions.
15 changes: 7 additions & 8 deletions benchmarks/src/bin/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ use datafusion_benchmarks::tpch::*;
use datafusion::datasource::file_format::csv::DEFAULT_CSV_EXTENSION;
use datafusion::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION;
use datafusion::datasource::listing::ListingTableUrl;
use datafusion::execution::context::SessionState;
use datafusion::scheduler::Scheduler;
use futures::TryStreamExt;
use serde::Serialize;
Expand Down Expand Up @@ -270,16 +269,14 @@ async fn benchmark_query(
Ok((benchmark_run, result))
}

#[allow(clippy::await_holding_lock)]
async fn register_tables(
opt: &DataFusionBenchmarkOpt,
ctx: &SessionContext,
) -> Result<()> {
for table in TPCH_TABLES {
let table_provider = {
let mut session_state = ctx.state.write();
get_table(
&mut session_state,
ctx,
opt.path.to_str().unwrap(),
table,
opt.file_format.as_str(),
Expand Down Expand Up @@ -368,12 +365,14 @@ async fn execute_query(
}

async fn get_table(
ctx: &mut SessionState,
ctx: &SessionContext,
path: &str,
table: &str,
table_format: &str,
target_partitions: usize,
) -> Result<Arc<dyn TableProvider>> {
// Obtain a snapshot of the SessionState
let state = ctx.state();
let (format, path, extension): (Arc<dyn FileFormat>, String, &'static str) =
match table_format {
// dbgen creates .tbl ('|' delimited) files without header
Expand All @@ -396,7 +395,7 @@ async fn get_table(
}
"parquet" => {
let path = format!("{}/{}", path, table);
let format = ParquetFormat::new(ctx.config_options())
let format = ParquetFormat::new(state.config_options())
.with_enable_pruning(Some(true));

(Arc::new(format), path, DEFAULT_PARQUET_EXTENSION)
Expand All @@ -410,13 +409,13 @@ async fn get_table(
let options = ListingOptions::new(format)
.with_file_extension(extension)
.with_target_partitions(target_partitions)
.with_collect_stat(ctx.config.collect_statistics());
.with_collect_stat(state.config.collect_statistics());

let table_path = ListingTableUrl::parse(path)?;
let config = ListingTableConfig::new(table_path).with_listing_options(options);

let config = if table_format == "parquet" {
config.infer_schema(ctx).await?
config.infer_schema(&state).await?
} else {
config.with_schema(schema)
};
Expand Down
17 changes: 9 additions & 8 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ pub struct SessionContext {
/// Uuid for the session
session_id: String,
/// Session start time
pub session_start_time: DateTime<Utc>,
session_start_time: DateTime<Utc>,
/// Shared session state for the session
pub state: Arc<RwLock<SessionState>>,
state: Arc<RwLock<SessionState>>,
}

impl Default for SessionContext {
Expand Down Expand Up @@ -203,22 +203,23 @@ impl SessionContext {
/// Creates a new session context using the provided configuration and RuntimeEnv.
pub fn with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
let state = SessionState::with_config_rt(config, runtime);
Self {
session_id: state.session_id.clone(),
session_start_time: chrono::Utc::now(),
state: Arc::new(RwLock::new(state)),
}
Self::with_state(state)
}

/// Creates a new session context using the provided session state.
pub fn with_state(state: SessionState) -> Self {
Self {
session_id: state.session_id.clone(),
session_start_time: chrono::Utc::now(),
session_start_time: Utc::now(),
state: Arc::new(RwLock::new(state)),
}
}

/// Returns the time this session was created
pub fn session_start_time(&self) -> DateTime<Utc> {
self.session_start_time
}

/// Registers the [`RecordBatch`] as the specified table name
pub fn register_batch(
&self,
Expand Down
5 changes: 1 addition & 4 deletions datafusion/core/tests/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,7 @@ fn register_aggregate(ctx: &mut SessionContext) {
);

// register the selector as "first"
ctx.state
.write()
.aggregate_functions
.insert(name.to_string(), Arc::new(first));
ctx.register_udaf(first)
}

/// This structureg models a specialized timeseries aggregate function
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ impl AsLogicalPlan for LogicalPlanNode {
};

let file_type = create_extern_table.file_type.as_str();
let env = &ctx.state.as_ref().read().runtime_env;
let env = ctx.runtime_env();
if !env.table_factories.contains_key(file_type) {
Err(DataFusionError::Internal(format!(
"No TableProvider for file type: {}",
Expand Down

0 comments on commit 4917235

Please sign in to comment.