diff --git a/lale/expressions.py b/lale/expressions.py index 58f31d1bc..746d25271 100644 --- a/lale/expressions.py +++ b/lale/expressions.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Optional, Union import astunparse +from six.moves import cStringIO AstLits = (ast.Num, ast.Str, ast.List, ast.Tuple, ast.Set, ast.Dict) AstLit = Union[ast.Num, ast.Str, ast.List, ast.Tuple, ast.Set, ast.Dict] @@ -48,6 +49,27 @@ ] +# !! WORKAROUND !! +# There is a bug with astunparse and Python 3.8. +# https://github.com/simonpercivall/astunparse/issues/43 +# Until it is fixed (which may be never), here is a workaround, +# based on the workaround found in https://github.com/juanlao7/codeclose +class FixUnparser(astunparse.Unparser): + def _Constant(self, t): + if not hasattr(t, "kind"): + setattr(t, "kind", None) + + super()._Constant(t) + + +# !! WORKAROUND !! +# This method should be called instead of astunparse.unparse +def fixedUnparse(tree): + v = cStringIO() + FixUnparser(tree, file=v) + return v.getvalue() + + class Expr: _expr: AstExpr @@ -104,7 +126,7 @@ def __getitem__(self, key: Union[int, str, slice]) -> "Expr": return Expr(subscript) def __str__(self) -> str: - result = astunparse.unparse(self._expr).strip() + result = fixedUnparse(self._expr).strip() if isinstance(self._expr, (ast.UnaryOp, ast.BinOp, ast.Compare, ast.BoolOp)): if result.startswith("(") and result.endswith(")"): result = result[1:-1] diff --git a/test/test_core_misc.py b/test/test_core_misc.py index 0c4205b52..b34b9cb51 100644 --- a/test/test_core_misc.py +++ b/test/test_core_misc.py @@ -62,6 +62,19 @@ def test_transformers(self): self.assertNotIn("MLPClassifier", ops_names) +class TestUnparseExpr(unittest.TestCase): + def test_unparse_const38(self): + import lale.expressions + from lale.expressions import it + + test_expr = it.hello["hi"] + # This fails on 3.8 with some versions of the library + # which is why we use the fixed version + # import astunparse + # astunparse.unparse(he._expr) + str(lale.expressions.fixedUnparse(test_expr._expr)) + + class TestOperatorWithoutSchema(unittest.TestCase): def test_trainable_pipe_left(self): from sklearn.decomposition import PCA