From 65c7533acb0a2f3608ab172528f38e588181d66c Mon Sep 17 00:00:00 2001 From: Stefan Date: Sun, 16 Feb 2025 13:39:11 +0100 Subject: [PATCH 1/2] feat: py bindings refactoring --- bindings/python/pyproject.toml | 1 + bindings/python/src/content.rs | 43 ++++++++++++++++++++ bindings/python/src/decision.rs | 32 +++++---------- bindings/python/src/engine.rs | 66 +++++++++++++------------------ bindings/python/src/expression.rs | 62 ++++++++++------------------- bindings/python/src/lib.rs | 3 ++ bindings/python/src/loader.rs | 64 +++++++++++++++++++++--------- bindings/python/src/types.rs | 8 ++-- bindings/python/src/value.rs | 29 ++++++++++++-- bindings/python/src/variable.rs | 35 ++++++++++++++-- bindings/python/test_async.py | 2 +- bindings/python/zen.pyi | 38 ++++++++++++------ 12 files changed, 240 insertions(+), 143 deletions(-) create mode 100644 bindings/python/src/content.rs diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index ef9b8542..1ad34c42 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -5,6 +5,7 @@ build-backend = "maturin" [project] name = "zen-engine" requires-python = ">=3.7" +version = "0.40.0" classifiers = [ "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", diff --git a/bindings/python/src/content.rs b/bindings/python/src/content.rs new file mode 100644 index 00000000..fc51ef41 --- /dev/null +++ b/bindings/python/src/content.rs @@ -0,0 +1,43 @@ +use anyhow::Context; +use pyo3::prelude::{PyAnyMethods, PyStringMethods}; +use pyo3::types::PyString; +use pyo3::{pyclass, pymethods, Bound, FromPyObject, PyAny, PyResult}; +use pythonize::depythonize; +use std::sync::Arc; +use zen_engine::model::DecisionContent; + +#[pyclass] +#[pyo3(name = "ZenDecisionContent")] +pub struct PyZenDecisionContent(pub Arc); + +#[pymethods] +impl PyZenDecisionContent { + #[new] + pub fn new(data: &str) -> PyResult { + let content = serde_json::from_str(data).context("Failed to parse JSON")?; + Ok(Self(Arc::new(content))) + } +} + +pub struct PyZenDecisionContentJson(pub PyZenDecisionContent); + +impl<'py> FromPyObject<'py> for PyZenDecisionContentJson { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if let Ok(s) = ob.downcast::() { + let borrow_ref = s.borrow(); + let content = borrow_ref.0.clone(); + + return Ok(Self(PyZenDecisionContent(content))); + } + + if let Ok(b) = ob.downcast::() { + let str = b.to_str()?; + let content = serde_json::from_str(str).context("Invalid JSON")?; + + return Ok(Self(PyZenDecisionContent(Arc::new(content)))); + } + + let content = depythonize(ob)?; + Ok(Self(PyZenDecisionContent(Arc::new(content)))) + } +} diff --git a/bindings/python/src/decision.rs b/bindings/python/src/decision.rs index e88268ec..92cecfee 100644 --- a/bindings/python/src/decision.rs +++ b/bindings/python/src/decision.rs @@ -5,16 +5,14 @@ use crate::engine::PyZenEvaluateOptions; use crate::loader::PyDecisionLoader; use crate::mt::worker_pool; use crate::value::PyValue; +use crate::variable::PyVariable; use anyhow::{anyhow, Context}; -use pyo3::types::PyDict; -use pyo3::{pyclass, pymethods, Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python}; +use pyo3::{pyclass, pymethods, IntoPyObjectExt, Py, PyAny, PyResult, Python}; use pyo3_async_runtimes::tokio; use pyo3_async_runtimes::tokio::get_current_locals; use pyo3_async_runtimes::tokio::re_exports::runtime::Runtime; -use pythonize::depythonize; use serde_json::Value; use zen_engine::{Decision, EvaluationOptions}; -use zen_expression::Variable; #[pyclass] #[pyo3(name = "ZenDecision")] @@ -32,22 +30,16 @@ impl PyZenDecision { pub fn evaluate( &self, py: Python, - ctx: &Bound<'_, PyDict>, - opts: Option<&Bound<'_, PyDict>>, + ctx: PyVariable, + opts: Option, ) -> PyResult> { - let context: Variable = depythonize(ctx).context("Failed to convert dict")?; - let options: PyZenEvaluateOptions = if let Some(op) = opts { - depythonize(op).context("Failed to convert dict")? - } else { - Default::default() - }; - + let options = opts.unwrap_or_default(); let decision = self.0.clone(); let rt = Runtime::new()?; let result = rt .block_on(decision.evaluate_with_opts( - context, + ctx.into_inner(), EvaluationOptions { max_depth: options.max_depth, trace: options.trace, @@ -65,15 +57,11 @@ impl PyZenDecision { pub fn async_evaluate<'py>( &'py self, py: Python<'py>, - ctx: &Bound<'_, PyDict>, - opts: Option<&Bound<'_, PyDict>>, + ctx: PyValue, + opts: Option, ) -> PyResult> { - let context: Value = depythonize(ctx).context("Failed to convert dict")?; - let options: PyZenEvaluateOptions = if let Some(op) = opts { - depythonize(op).context("Failed to convert dict")? - } else { - Default::default() - }; + let context: Value = ctx.0; + let options = opts.unwrap_or_default(); let decision = self.0.clone(); let result = tokio::future_into_py_with_locals(py, get_current_locals(py)?, async move { diff --git a/bindings/python/src/engine.rs b/bindings/python/src/engine.rs index ab0f7d68..1ec9e7b6 100644 --- a/bindings/python/src/engine.rs +++ b/bindings/python/src/engine.rs @@ -1,22 +1,21 @@ use std::sync::Arc; +use crate::content::PyZenDecisionContentJson; use crate::custom_node::PyCustomNode; use crate::decision::PyZenDecision; use crate::loader::PyDecisionLoader; use crate::mt::{block_on, worker_pool}; use crate::value::PyValue; +use crate::variable::PyVariable; use anyhow::{anyhow, Context}; use pyo3::prelude::PyDictMethods; use pyo3::types::PyDict; -use pyo3::{pyclass, pymethods, Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python}; +use pyo3::{pyclass, pymethods, Bound, FromPyObject, IntoPyObjectExt, Py, PyAny, PyResult, Python}; use pyo3_async_runtimes::tokio::get_current_locals; use pyo3_async_runtimes::{tokio, TaskLocals}; -use pythonize::depythonize; use serde::{Deserialize, Serialize}; use serde_json::Value; -use zen_engine::model::DecisionContent; use zen_engine::{DecisionEngine, EvaluationOptions}; -use zen_expression::Variable; #[pyclass] #[pyo3(name = "ZenEngine")] @@ -24,7 +23,7 @@ pub struct PyZenEngine { engine: Arc>, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, FromPyObject)] pub struct PyZenEvaluateOptions { pub trace: Option, pub max_depth: Option, @@ -60,24 +59,26 @@ impl PyZenEngine { }; let loader = match options.get_item("loader")? { - Some(loader) => Some(Python::with_gil(|py| loader.into_py_any(py))?), + Some(loader) => Some(loader.into_py_any(py)?), None => None, }; let custom_node = match options.get_item("customHandler")? { - Some(custom_node) => Some(Python::with_gil(|py| custom_node.into_py_any(py))?), + Some(custom_node) => Some(custom_node.into_py_any(py)?), None => None, }; - let task_locals = TaskLocals::with_running_loop(py) - .ok() - .map(|s| s.copy_context(py).ok()) - .flatten(); + let make_locals = || { + TaskLocals::with_running_loop(py) + .ok() + .map(|s| s.copy_context(py).ok()) + .flatten() + }; Ok(Self { engine: Arc::new(DecisionEngine::new( - Arc::new(PyDecisionLoader::from(loader)), - Arc::new(PyCustomNode::new(custom_node, task_locals)), + Arc::new(PyDecisionLoader::new(loader, make_locals())), + Arc::new(PyCustomNode::new(custom_node, make_locals())), )), }) } @@ -86,20 +87,14 @@ impl PyZenEngine { pub fn evaluate( &self, py: Python, - key: String, - ctx: &Bound<'_, PyDict>, - opts: Option<&Bound<'_, PyDict>>, + key: &str, + ctx: PyVariable, + opts: Option, ) -> PyResult> { - let context: Variable = depythonize(ctx).context("Failed to convert dict")?; - let options: PyZenEvaluateOptions = if let Some(op) = opts { - depythonize(op).context("Failed to convert dict")? - } else { - Default::default() - }; - + let options = opts.unwrap_or_default(); let result = block_on(self.engine.evaluate_with_opts( key, - context, + ctx.into_inner(), EvaluationOptions { max_depth: options.max_depth, trace: options.trace, @@ -118,15 +113,11 @@ impl PyZenEngine { &'py self, py: Python<'py>, key: String, - ctx: &Bound<'_, PyDict>, - opts: Option<&Bound<'_, PyDict>>, + ctx: PyValue, + opts: Option, ) -> PyResult> { - let context: Value = depythonize(ctx).context("Failed to convert dict")?; - let options: PyZenEvaluateOptions = if let Some(op) = opts { - depythonize(op).context("Failed to convert dict")? - } else { - Default::default() - }; + let context: Value = ctx.0; + let options: PyZenEvaluateOptions = opts.unwrap_or_default(); let engine = self.engine.clone(); let result = tokio::future_into_py_with_locals(py, get_current_locals(py)?, async move { @@ -157,16 +148,13 @@ impl PyZenEngine { Ok(result.unbind()) } - pub fn create_decision(&self, content: String) -> PyResult { - let decision_content: DecisionContent = - serde_json::from_str(&content).context("Failed to serialize decision content")?; - - let decision = self.engine.create_decision(decision_content.into()); + pub fn create_decision(&self, content: PyZenDecisionContentJson) -> PyResult { + let decision = self.engine.create_decision(content.0 .0); Ok(PyZenDecision::from(decision)) } - pub fn get_decision<'py>(&'py self, _py: Python<'py>, key: String) -> PyResult { - let decision = block_on(self.engine.get_decision(&key)) + pub fn get_decision<'py>(&'py self, _py: Python<'py>, key: &str) -> PyResult { + let decision = block_on(self.engine.get_decision(key)) .context("Failed to find decision with given key")?; Ok(PyZenDecision::from(decision)) diff --git a/bindings/python/src/expression.rs b/bindings/python/src/expression.rs index 2cd3b31a..c77a3fc7 100644 --- a/bindings/python/src/expression.rs +++ b/bindings/python/src/expression.rs @@ -1,15 +1,14 @@ use crate::variable::PyVariable; -use anyhow::{anyhow, Context}; +use anyhow::anyhow; use either::Either; -use pyo3::types::PyDict; -use pyo3::{pyclass, pyfunction, pymethods, Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python}; -use pythonize::{depythonize, pythonize}; +use pyo3::{pyclass, pyfunction, pymethods, IntoPyObjectExt, Py, PyAny, PyResult, Python}; +use pythonize::pythonize; use zen_expression::expression::{Standard, Unary}; use zen_expression::{Expression, Variable}; #[pyfunction] -pub fn compile_expression(expression: String) -> PyResult { - let expr = zen_expression::compile_expression(expression.as_str()) +pub fn compile_expression(expression: &str) -> PyResult { + let expr = zen_expression::compile_expression(expression) .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; Ok(PyExpression { @@ -18,8 +17,8 @@ pub fn compile_expression(expression: String) -> PyResult { } #[pyfunction] -pub fn compile_unary_expression(expression: String) -> PyResult { - let expr = zen_expression::compile_unary_expression(expression.as_str()) +pub fn compile_unary_expression(expression: &str) -> PyResult { + let expr = zen_expression::compile_unary_expression(expression) .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; Ok(PyExpression { @@ -31,42 +30,28 @@ pub fn compile_unary_expression(expression: String) -> PyResult { #[pyo3(signature = (expression, ctx=None))] pub fn evaluate_expression( py: Python, - expression: String, - ctx: Option<&Bound<'_, PyDict>>, + expression: &str, + ctx: Option, ) -> PyResult> { - let context = ctx - .map(|ctx| depythonize(ctx)) - .transpose() - .context("Failed to convert context")? - .unwrap_or(Variable::Null); + let context = ctx.map(|c| c.into_inner()).unwrap_or(Variable::Null); - let result = zen_expression::evaluate_expression(expression.as_str(), context) + let result = zen_expression::evaluate_expression(expression, context) .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; PyVariable(result).into_py_any(py) } #[pyfunction] -pub fn evaluate_unary_expression(expression: String, ctx: &Bound<'_, PyDict>) -> PyResult { - let context: Variable = depythonize(ctx).context("Failed to convert context")?; - - let result = zen_expression::evaluate_unary_expression(expression.as_str(), context) +pub fn evaluate_unary_expression(expression: &str, ctx: PyVariable) -> PyResult { + let result = zen_expression::evaluate_unary_expression(expression, ctx.into_inner()) .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; Ok(result) } #[pyfunction] -pub fn render_template( - py: Python, - template: String, - ctx: &Bound<'_, PyDict>, -) -> PyResult> { - let context: Variable = depythonize(ctx) - .context("Failed to convert context") - .unwrap_or(Variable::Null); - - let result = zen_tmpl::render(template.as_str(), context) +pub fn render_template(py: Python, template: &str, ctx: PyVariable) -> PyResult> { + let result = zen_tmpl::render(template, ctx.into_inner()) .map_err(|e| anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())))?; PyVariable(result).into_py_any(py) @@ -79,13 +64,8 @@ pub struct PyExpression { #[pymethods] impl PyExpression { #[pyo3(signature = (ctx=None))] - pub fn evaluate(&self, py: Python, ctx: Option<&Bound<'_, PyDict>>) -> PyResult> { - let context = ctx - .map(|ctx| depythonize(ctx)) - .transpose() - .context("Failed to convert context")? - .unwrap_or(Variable::Null); - + pub fn evaluate(&self, py: Python, ctx: Option) -> PyResult> { + let context = ctx.map(|c| c.into_inner()).unwrap_or(Variable::Null); let maybe_result = match &self.expression { Either::Left(standard) => standard.evaluate(context), Either::Right(unary) => unary.evaluate(context).map(Variable::Bool), @@ -99,8 +79,8 @@ impl PyExpression { } #[pyfunction] -pub fn validate_expression(py: Python, expression: String) -> PyResult>> { - let Err(error) = zen_expression::validate::validate_expression(expression.as_str()) else { +pub fn validate_expression(py: Python, expression: &str) -> PyResult>> { + let Err(error) = zen_expression::validate::validate_expression(expression) else { return Ok(None); }; @@ -108,8 +88,8 @@ pub fn validate_expression(py: Python, expression: String) -> PyResult PyResult>> { - let Err(error) = zen_expression::validate::validate_expression(expression.as_str()) else { +pub fn validate_unary_expression(py: Python, expression: &str) -> PyResult>> { + let Err(error) = zen_expression::validate::validate_expression(expression) else { return Ok(None); }; diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 2895ec21..3d262149 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -1,3 +1,4 @@ +use crate::content::PyZenDecisionContent; use crate::decision::PyZenDecision; use crate::engine::PyZenEngine; use crate::expression::{ @@ -8,6 +9,7 @@ use pyo3::prelude::PyModuleMethods; use pyo3::types::PyModule; use pyo3::{pymodule, wrap_pyfunction, Bound, PyResult, Python}; +mod content; mod custom_node; mod decision; mod engine; @@ -23,6 +25,7 @@ fn zen(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(evaluate_expression, m)?)?; m.add_function(wrap_pyfunction!(evaluate_unary_expression, m)?)?; m.add_function(wrap_pyfunction!(render_template, m)?)?; diff --git a/bindings/python/src/loader.rs b/bindings/python/src/loader.rs index 2d949e05..6ee6b553 100644 --- a/bindings/python/src/loader.rs +++ b/bindings/python/src/loader.rs @@ -1,46 +1,72 @@ use std::future::Future; use std::sync::Arc; +use crate::content::PyZenDecisionContentJson; use anyhow::anyhow; -use pyo3::{Py, PyAny, PyObject, Python}; - +use either::Either; +use pyo3::{IntoPyObjectExt, Py, PyAny, PyResult, Python}; +use pyo3_async_runtimes::TaskLocals; use zen_engine::loader::{DecisionLoader, LoaderError, LoaderResponse}; use zen_engine::model::DecisionContent; #[derive(Default)] -pub(crate) struct PyDecisionLoader(Option>); - -impl From for PyDecisionLoader { - fn from(value: PyObject) -> Self { - Self(Some(value)) - } +pub(crate) struct PyDecisionLoader { + callback: Option>, + task_locals: Option, } -impl From>> for PyDecisionLoader { - fn from(value: Option>) -> Self { - Self(value) +impl PyDecisionLoader { + pub fn new(callback: Option>, task_locals: Option) -> Self { + Self { + callback, + task_locals, + } } } impl PyDecisionLoader { - fn load_element(&self, key: &str) -> Result, anyhow::Error> { - let Some(object) = &self.0 else { + async fn load_element(&self, key: &str) -> Result, anyhow::Error> { + let Some(callable) = &self.callback else { return Err(anyhow!("Loader is not defined")); }; - let content = Python::with_gil(|py| { - let result = object.call1(py, (key,))?; - result.extract::(py) - })?; + let maybe_result: PyResult<_> = Python::with_gil(|py| { + let result = callable.call1(py, (key,))?; + let is_coroutine = result.getattr(py, "__await__").is_ok(); + if !is_coroutine { + return Ok(Either::Left( + result.extract::(py)?, + )); + } - Ok(serde_json::from_str::(&content)?.into()) + let Some(task_locals) = &self.task_locals else { + Err(anyhow!("Task locals are required in async context"))? + }; + + let result_future = pyo3_async_runtimes::into_future_with_locals( + task_locals, + result.into_bound_py_any(py)?, + )?; + + Ok(Either::Right(result_future)) + }); + + match maybe_result? { + Either::Left(result) => Ok(result.0 .0), + Either::Right(future) => { + let result = future.await?; + let content = + Python::with_gil(|py| result.extract::(py))?; + Ok(content.0 .0) + } + } } } impl DecisionLoader for PyDecisionLoader { fn load<'a>(&'a self, key: &'a str) -> impl Future + 'a { async move { - self.load_element(key).map_err(|e| { + self.load_element(key).await.map_err(|e| { LoaderError::Internal { source: e, key: key.to_string(), diff --git a/bindings/python/src/types.rs b/bindings/python/src/types.rs index 6d296abf..9537ae14 100644 --- a/bindings/python/src/types.rs +++ b/bindings/python/src/types.rs @@ -60,11 +60,11 @@ impl PyNodeRequest { #[pymethods] impl PyNodeRequest { - fn get_field(&self, py: Python, path: String) -> PyResult> { + fn get_field(&self, py: Python, path: &str) -> PyResult> { let node_config = &self.inner_node.config; let selected_value: Value = node_config - .dot_get(path.as_str()) + .dot_get(path) .ok() .flatten() .context("Failed to find JSON path")?; @@ -78,11 +78,11 @@ impl PyNodeRequest { PyVariable(template_value).into_py_any(py) } - fn get_field_raw(&self, py: Python, path: String) -> PyResult> { + fn get_field_raw(&self, py: Python, path: &str) -> PyResult> { let node_config = &self.inner_node.config; let selected_value: Value = node_config - .dot_get(path.as_str()) + .dot_get(path) .ok() .flatten() .context("Failed to find JSON path")?; diff --git a/bindings/python/src/value.rs b/bindings/python/src/value.rs index 13bd030d..e3bf5728 100644 --- a/bindings/python/src/value.rs +++ b/bindings/python/src/value.rs @@ -1,6 +1,8 @@ -use pyo3::prelude::{PyDictMethods, PyListMethods}; -use pyo3::types::{PyDict, PyList}; -use pyo3::{Bound, IntoPyObject, IntoPyObjectExt, PyAny, PyErr, PyResult, Python}; +use anyhow::Context; +use pyo3::prelude::{PyAnyMethods, PyBytesMethods, PyDictMethods, PyListMethods, PyStringMethods}; +use pyo3::types::{PyBytes, PyDict, PyList, PyString}; +use pyo3::{Bound, FromPyObject, IntoPyObject, IntoPyObjectExt, PyAny, PyErr, PyResult, Python}; +use pythonize::depythonize; use serde_json::Value; #[repr(transparent)] @@ -46,3 +48,24 @@ impl<'py> IntoPyObject<'py> for PyValue { value_to_object(py, &self.0) } } + +impl<'py> FromPyObject<'py> for PyValue { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if let Ok(s) = ob.downcast::() { + let str_slice = s.to_str()?; + + let var = serde_json::from_str(str_slice).context("Invalid JSON")?; + return Ok(PyValue(var)); + } + + if let Ok(b) = ob.downcast::() { + let bytes = b.as_bytes(); + + let var = serde_json::from_slice(bytes).context("Invalid JSON")?; + return Ok(PyValue(var)); + } + + let var = depythonize(ob)?; + Ok(PyValue(var)) + } +} diff --git a/bindings/python/src/variable.rs b/bindings/python/src/variable.rs index 6f86c23a..ffae4536 100644 --- a/bindings/python/src/variable.rs +++ b/bindings/python/src/variable.rs @@ -1,6 +1,8 @@ -use pyo3::prelude::{PyDictMethods, PyListMethods}; -use pyo3::types::{PyDict, PyList}; -use pyo3::{Bound, IntoPyObject, IntoPyObjectExt, PyAny, PyErr, PyResult, Python}; +use anyhow::Context; +use pyo3::prelude::{PyAnyMethods, PyBytesMethods, PyDictMethods, PyListMethods, PyStringMethods}; +use pyo3::types::{PyBytes, PyDict, PyList, PyString}; +use pyo3::{Bound, FromPyObject, IntoPyObject, IntoPyObjectExt, PyAny, PyErr, PyResult, Python}; +use pythonize::depythonize; use rust_decimal::prelude::ToPrimitive; use zen_expression::Variable; @@ -8,6 +10,12 @@ use zen_expression::Variable; #[derive(Clone, Debug)] pub struct PyVariable(pub Variable); +impl PyVariable { + pub fn into_inner(self) -> Variable { + self.0 + } +} + pub fn variable_to_object<'py>(py: Python<'py>, val: &Variable) -> PyResult> { match val { Variable::Null => py.None().into_bound_py_any(py), @@ -49,3 +57,24 @@ impl<'py> IntoPyObject<'py> for PyVariable { variable_to_object(py, &self.0) } } + +impl<'py> FromPyObject<'py> for PyVariable { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if let Ok(s) = ob.downcast::() { + let str_slice = s.to_str()?; + + let var = serde_json::from_str(str_slice).context("Invalid JSON")?; + return Ok(PyVariable(var)); + } + + if let Ok(b) = ob.downcast::() { + let bytes = b.as_bytes(); + + let var = serde_json::from_slice(bytes).context("Invalid JSON")?; + return Ok(PyVariable(var)); + } + + let var = depythonize(ob)?; + Ok(PyVariable(var)) + } +} diff --git a/bindings/python/test_async.py b/bindings/python/test_async.py index 65496dfb..eeddd29a 100644 --- a/bindings/python/test_async.py +++ b/bindings/python/test_async.py @@ -8,7 +8,7 @@ import zen -def loader(key): +async def loader(key): with open("../../test-data/" + key, "r") as f: return f.read() diff --git a/bindings/python/zen.pyi b/bindings/python/zen.pyi index 0d5e3a04..ecf448a2 100644 --- a/bindings/python/zen.pyi +++ b/bindings/python/zen.pyi @@ -1,5 +1,10 @@ from collections.abc import Awaitable -from typing import Any, Optional, TypedDict, Literal +from typing import Any, Optional, TypedDict, Literal, TypeAlias, Union + + +class DecisionEvaluateOptions(TypedDict, total=False): + max_depth: int + trace: bool class EvaluateResponse(TypedDict): @@ -8,33 +13,44 @@ class EvaluateResponse(TypedDict): trace: dict +ZenContext: TypeAlias = Union[str, bytes, dict] +ZenDecisionContentInput: TypeAlias = Union[str, ZenDecisionContent] + + class ZenEngine: - def __init__(self, maybe_options: Optional[dict] = None) -> None: ... + def __init__(self, options: Optional[dict] = None) -> None: ... - def evaluate(self, key: str, ctx: dict, opts: Optional[dict] = None) -> EvaluateResponse: ... + def evaluate(self, key: str, context: ZenContext, + options: Optional[DecisionEvaluateOptions] = None) -> EvaluateResponse: ... - def async_evaluate(self, key: str, ctx: dict, opts: Optional[dict] = None) -> Awaitable[EvaluateResponse]: ... + def async_evaluate(self, key: str, context: ZenContext, options: Optional[DecisionEvaluateOptions] = None) -> \ + Awaitable[EvaluateResponse]: ... - def create_decision(self, content: str) -> ZenDecision: ... + def create_decision(self, content: ZenDecisionContentInput) -> ZenDecision: ... def get_decision(self, key: str) -> ZenDecision: ... +class ZenDecisionContent: + def __init__(self, decision_content: str) -> None: ... + + class ZenDecision: - def evaluate(self, ctx: dict, opts: Optional[dict] = None) -> EvaluateResponse: ... + def evaluate(self, context: ZenContext, options: Optional[DecisionEvaluateOptions] = None) -> EvaluateResponse: ... - def async_evaluate(self, ctx: dict, opts: Optional[dict] = None) -> Awaitable[EvaluateResponse]: ... + def async_evaluate(self, context: ZenContext, options: Optional[DecisionEvaluateOptions] = None) -> Awaitable[ + EvaluateResponse]: ... def validate(self) -> None: ... -def evaluate_expression(expression: str, ctx: Optional[dict] = None) -> Any: ... +def evaluate_expression(expression: str, context: Optional[ZenContext] = None) -> Any: ... -def evaluate_unary_expression(expression: str, ctx: dict) -> bool: ... +def evaluate_unary_expression(expression: str, context: ZenContext) -> bool: ... -def render_template(template: str, ctx: dict) -> Any: ... +def render_template(template: str, context: ZenContext) -> Any: ... def compile_expression(expression: str) -> Expression: ... @@ -44,7 +60,7 @@ def compile_unary_expression(expression: str) -> Expression: ... class Expression: - def evaluate(self, ctx: Optional[dict] = None) -> Any: ... + def evaluate(self, context: Optional[ZenContext] = None) -> Any: ... def validate_expression(expression: str) -> Optional[ValidationResponse]: ... From 4dd811c7bbcb445d12cc1108136c5e1cd0b54ab8 Mon Sep 17 00:00:00 2001 From: Stefan Date: Sun, 16 Feb 2025 18:14:23 +0100 Subject: [PATCH 2/2] unset version in pyproject.toml --- bindings/python/pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index 1ad34c42..ef9b8542 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -5,7 +5,6 @@ build-backend = "maturin" [project] name = "zen-engine" requires-python = ">=3.7" -version = "0.40.0" classifiers = [ "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython",