Skip to content

Commit

Permalink
v2: Remove Python SQL connection logic (#524)
Browse files Browse the repository at this point in the history
* Remove Python SQL connection / SQL dataset

* Remove Rust Python connection logic

* Fix Python tests

* fmt
  • Loading branch information
jonmmease committed Nov 16, 2024
1 parent 21494a2 commit 7b3f5b6
Show file tree
Hide file tree
Showing 13 changed files with 121 additions and 1,770 deletions.
440 changes: 0 additions & 440 deletions vegafusion-python/src/connection.rs

This file was deleted.

147 changes: 35 additions & 112 deletions vegafusion-python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
pub mod connection;

use lazy_static::lazy_static;
use pyo3;
use pyo3::exceptions::PyValueError;
Expand All @@ -23,7 +21,6 @@ use vegafusion_core::runtime::GrpcVegaFusionRuntime;

use vegafusion_runtime::task_graph::runtime::VegaFusionRuntime;

use crate::connection::{PySqlConnection, PySqlDataset};
use env_logger::{Builder, Target};
use pythonize::{depythonize, pythonize};
use serde_json::json;
Expand Down Expand Up @@ -151,41 +148,22 @@ impl PyChartState {
#[pyclass]
struct PyVegaFusionRuntime {
runtime: Arc<dyn VegaFusionRuntimeTrait>,
tokio_runtime_connection: Arc<Runtime>,
tokio_runtime_current_thread: Arc<Runtime>,
tokio_runtime: Arc<Runtime>,
}

impl PyVegaFusionRuntime {
fn process_inline_datasets(
&self,
inline_datasets: Option<&Bound<PyDict>>,
) -> PyResult<(HashMap<String, VegaFusionDataset>, bool)> {
let mut any_main_thread = false;
) -> PyResult<HashMap<String, VegaFusionDataset>> {
if let Some(inline_datasets) = inline_datasets {
Python::with_gil(|py| -> PyResult<_> {
let vegafusion_dataset_module = PyModule::import_bound(py, "vegafusion.dataset")?;
let sql_dataset_type = vegafusion_dataset_module.getattr("SqlDataset")?;
let imported_datasets = inline_datasets
.iter()
.map(|(name, inline_dataset)| {
let inline_dataset = inline_dataset.to_object(py);
let inline_dataset = inline_dataset.bind(py);
let dataset = if inline_dataset.is_instance(&sql_dataset_type)? {
let main_thread = inline_dataset
.call_method0("main_thread")?
.extract::<bool>()?;
any_main_thread = any_main_thread || main_thread;
let sql_dataset = PySqlDataset::new(inline_dataset.to_object(py))?;
let rt = if main_thread {
&self.tokio_runtime_current_thread
} else {
&self.tokio_runtime_connection
};
let df = py.allow_threads(|| {
rt.block_on(sql_dataset.scan_table(&sql_dataset.table_name))
})?;
VegaFusionDataset::DataFrame(df)
} else if inline_dataset.hasattr("__arrow_c_stream__")? {
let dataset = if inline_dataset.hasattr("__arrow_c_stream__")? {
// Import via Arrow PyCapsule Interface
let (table, hash) =
VegaFusionTable::from_pyarrow_with_hash(py, inline_dataset)?;
Expand All @@ -202,58 +180,42 @@ impl PyVegaFusionRuntime {
Ok((name.to_string(), dataset))
})
.collect::<PyResult<HashMap<_, _>>>()?;
Ok((imported_datasets, any_main_thread))
Ok(imported_datasets)
})
} else {
Ok((Default::default(), false))
Ok(Default::default())
}
}
}

#[pymethods]
impl PyVegaFusionRuntime {
#[staticmethod]
#[pyo3(signature = (max_capacity=None, memory_limit=None, worker_threads=None, connection=None))]
#[pyo3(signature = (max_capacity=None, memory_limit=None, worker_threads=None))]
pub fn new_embedded(
max_capacity: Option<usize>,
memory_limit: Option<usize>,
worker_threads: Option<i32>,
connection: Option<PyObject>,
) -> PyResult<Self> {
initialize_logging();

let (conn, mut tokio_runtime_builder) = if let Some(pyconnection) = connection {
// Use Python connection and single-threaded tokio runtime (this avoids deadlocking the Python interpreter)
let conn = Arc::new(PySqlConnection::new(pyconnection)?) as Arc<dyn Connection>;
(conn, tokio::runtime::Builder::new_current_thread())
} else {
// Use DataFusion connection and multi-threaded tokio runtime
let conn = Arc::new(DataFusionConnection::default()) as Arc<dyn Connection>;
let mut builder = tokio::runtime::Builder::new_multi_thread();
if let Some(worker_threads) = worker_threads {
builder.worker_threads(worker_threads.max(1) as usize);
}
(conn, builder)
};
// Use DataFusion connection and multi-threaded tokio runtime
let conn = Arc::new(DataFusionConnection::default()) as Arc<dyn Connection>;
let mut builder = tokio::runtime::Builder::new_multi_thread();
if let Some(worker_threads) = worker_threads {
builder.worker_threads(worker_threads.max(1) as usize);
}

// Build the tokio runtime
let tokio_runtime_connection = tokio_runtime_builder
.enable_all()
.thread_stack_size(TOKIO_THREAD_STACK_SIZE)
.build()
.external("Failed to create Tokio thread pool")?;

// Create current thread runtime
let tokio_runtime_current_thread = tokio::runtime::Builder::new_current_thread()
let tokio_runtime_connection = builder
.enable_all()
.thread_stack_size(TOKIO_THREAD_STACK_SIZE)
.build()
.external("Failed to create Tokio thread pool")?;

Ok(Self {
runtime: Arc::new(VegaFusionRuntime::new(conn, max_capacity, memory_limit)),
tokio_runtime_connection: Arc::new(tokio_runtime_connection),
tokio_runtime_current_thread: Arc::new(tokio_runtime_current_thread),
tokio_runtime: Arc::new(tokio_runtime_connection),
})
}

Expand All @@ -278,8 +240,7 @@ impl PyVegaFusionRuntime {

Ok(Self {
runtime: Arc::new(runtime),
tokio_runtime_connection: tokio_runtime.clone(),
tokio_runtime_current_thread: tokio_runtime.clone(),
tokio_runtime: tokio_runtime.clone(),
})
}

Expand All @@ -299,21 +260,12 @@ impl PyVegaFusionRuntime {
default_input_tz: default_input_tz.clone(),
};

let (inline_datasets, any_main_thread_sources) =
self.process_inline_datasets(inline_datasets)?;

// Get runtime based on whether there were any Python data sources that require running
// on the main thread. In this case we need to use the current thread tokio runtime
let tokio_runtime = if any_main_thread_sources {
&self.tokio_runtime_current_thread
} else {
&self.tokio_runtime_connection
};
let inline_datasets = self.process_inline_datasets(inline_datasets)?;

py.allow_threads(|| {
PyChartState::try_new(
self.runtime.clone(),
tokio_runtime.clone(),
self.tokio_runtime.clone(),
spec,
inline_datasets,
tz_config,
Expand All @@ -336,8 +288,7 @@ impl PyVegaFusionRuntime {
keep_signals: Option<Vec<(String, Vec<u32>)>>,
keep_datasets: Option<Vec<(String, Vec<u32>)>>,
) -> PyResult<(PyObject, PyObject)> {
let (inline_datasets, any_main_thread_sources) =
self.process_inline_datasets(inline_datasets)?;
let inline_datasets = self.process_inline_datasets(inline_datasets)?;

let spec = parse_json_spec(spec)?;
let preserve_interactivity = preserve_interactivity.unwrap_or(false);
Expand All @@ -351,16 +302,8 @@ impl PyVegaFusionRuntime {
keep_variables.push((Variable::new_data(&name), scope))
}

// Get runtime based on whether there were any Python data sources that require running
// on the main thread. In this case we need to use the current thread tokio runtime
let rt = if any_main_thread_sources {
&self.tokio_runtime_current_thread
} else {
&self.tokio_runtime_connection
};

let (spec, warnings) = py.allow_threads(|| {
rt.block_on(
self.tokio_runtime.block_on(
self.runtime.pre_transform_spec(
&spec,
&inline_datasets,
Expand Down Expand Up @@ -405,8 +348,7 @@ impl PyVegaFusionRuntime {
row_limit: Option<u32>,
inline_datasets: Option<&Bound<PyDict>>,
) -> PyResult<(PyObject, PyObject)> {
let (inline_datasets, any_main_thread_sources) =
self.process_inline_datasets(inline_datasets)?;
let inline_datasets = self.process_inline_datasets(inline_datasets)?;
let spec = parse_json_spec(spec)?;

// Build variables
Expand All @@ -419,16 +361,8 @@ impl PyVegaFusionRuntime {
})
.collect();

// Get runtime based on whether there were any Python data sources that require running
// on the main thread. In this case we need to use the current thread tokio runtime
let rt = if any_main_thread_sources {
&self.tokio_runtime_current_thread
} else {
&self.tokio_runtime_connection
};

let (values, warnings) = py.allow_threads(|| {
rt.block_on(
self.tokio_runtime.block_on(
self.runtime.pre_transform_values(
&spec,
&inline_datasets,
Expand Down Expand Up @@ -495,21 +429,12 @@ impl PyVegaFusionRuntime {
keep_signals: Option<Vec<(String, Vec<u32>)>>,
keep_datasets: Option<Vec<(String, Vec<u32>)>>,
) -> PyResult<(PyObject, Vec<PyObject>, PyObject)> {
let (inline_datasets, any_main_thread_sources) =
self.process_inline_datasets(inline_datasets)?;
let inline_datasets = self.process_inline_datasets(inline_datasets)?;
let spec = parse_json_spec(spec)?;
let preserve_interactivity = preserve_interactivity.unwrap_or(true);
let extract_threshold = extract_threshold.unwrap_or(20);
let extracted_format = extracted_format.unwrap_or_else(|| "pyarrow".to_string());

// Get runtime based on whether there were any Python data sources that require running
// on the main thread. In this case we need to use the current thread tokio runtime
let rt = if any_main_thread_sources {
&self.tokio_runtime_current_thread
} else {
&self.tokio_runtime_connection
};

// Build keep_variables
let mut keep_variables: Vec<PreTransformVariable> = Vec::new();
for (name, scope) in keep_signals.unwrap_or_default() {
Expand All @@ -526,17 +451,18 @@ impl PyVegaFusionRuntime {
}

let (tx_spec, datasets, warnings) = py.allow_threads(|| {
rt.block_on(self.runtime.pre_transform_extract(
&spec,
&inline_datasets,
&PreTransformExtractOpts {
local_tz,
default_input_tz,
preserve_interactivity,
extract_threshold: extract_threshold as i32,
keep_variables,
},
))
self.tokio_runtime
.block_on(self.runtime.pre_transform_extract(
&spec,
&inline_datasets,
&PreTransformExtractOpts {
local_tz,
default_input_tz,
preserve_interactivity,
extract_threshold: extract_threshold as i32,
keep_variables,
},
))
})?;

let warnings: Vec<_> = warnings
Expand Down Expand Up @@ -604,9 +530,7 @@ impl PyVegaFusionRuntime {

pub fn clear_cache(&self) -> PyResult<()> {
if let Some(runtime) = self.runtime.as_any().downcast_ref::<VegaFusionRuntime>() {
Ok(self
.tokio_runtime_current_thread
.block_on(runtime.clear_cache()))
Ok(self.tokio_runtime.block_on(runtime.clear_cache()))
} else {
Err(PyValueError::new_err(
"Current Runtime does not support clear_cache",
Expand Down Expand Up @@ -722,7 +646,6 @@ pub fn build_pre_transform_spec_plan(
fn _vegafusion(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<PyVegaFusionRuntime>()?;
m.add_class::<PyChartState>()?;
m.add_class::<PySqlConnection>()?;
m.add_function(wrap_pyfunction!(get_column_usage, m)?)?;
m.add_function(wrap_pyfunction!(build_pre_transform_spec_plan, m)?)?;
m.add_function(wrap_pyfunction!(get_virtual_memory, m)?)?;
Expand Down
Loading

0 comments on commit 7b3f5b6

Please sign in to comment.