diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index 56c6ab6a1..005173cb3 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -1,4 +1,5 @@ import ast +import copy import logging import re import traceback @@ -452,19 +453,24 @@ def _check_is_df_declaration(self, node: ast.AST): and value.func.attr == "DataFrame" ) - def _extract_fix_dataframe_redeclarations(self, node: ast.AST) -> ast.AST: + def _extract_fix_dataframe_redeclarations( + self, node: ast.AST, code_lines: list[str] + ) -> ast.AST: if isinstance(node, ast.Assign): target_names, is_slice, target = self._get_target_names(node.targets) if target_names and self._check_is_df_declaration(node): - value = node.value - # Construct dataframe from node - dataframe_code = compile( - ast.Expression(body=value), filename="", mode="eval" + code = "\n".join(code_lines) + env = self._get_environment() + env["dfs"] = copy.deepcopy(self._dfs) + exec(code, env) + + df_generated = ( + env[target_names[0]][target.slice.value] + if is_slice + else env[target_names[0]] ) - result = eval(dataframe_code) - df_generated = result # check if exists in provided dfs for index, df in enumerate(self._dfs): @@ -506,6 +512,8 @@ def _clean_code(self, code: str, context: CodeExecutionContext) -> str: # Clear recent optional dependencies self._additional_dependencies = [] + clean_code_lines = [] + tree = ast.parse(code) # Check for imports and the node where analyze_data is defined @@ -548,7 +556,12 @@ def _clean_code(self, code: str, context: CodeExecutionContext) -> str: self.find_function_calls(node, context) - new_body.append(self._extract_fix_dataframe_redeclarations(node) or node) + clean_code_lines.append(astor.to_source(node)) + + new_body.append( + self._extract_fix_dataframe_redeclarations(node, clean_code_lines) + or node + ) # Enforcing use of execute_sql_query via Error Prompt Pipeline if self._config.direct_sql and not execute_sql_query_used: diff --git a/pandasai/helpers/output_validator.py b/pandasai/helpers/output_validator.py index 7b2f31f54..8eb68ce23 100644 --- a/pandasai/helpers/output_validator.py +++ b/pandasai/helpers/output_validator.py @@ -1,6 +1,8 @@ import re from typing import Any, Iterable +import numpy as np + import pandasai.pandas as pd from pandasai.exceptions import InvalidOutputValueMismatch @@ -64,7 +66,7 @@ def validate_value(self, expected_type: str) -> bool: @staticmethod def validate_result(result: dict) -> bool: - if not isinstance(result, dict): + if not isinstance(result, dict) or "type" not in result: raise InvalidOutputValueMismatch( "Result must be in the format of dictionary of type and value" ) @@ -73,7 +75,7 @@ def validate_result(result: dict) -> bool: return False elif result["type"] == "number": - return isinstance(result["value"], (int, float)) + return isinstance(result["value"], (int, float, np.int64)) elif result["type"] == "string": return isinstance(result["value"], str) elif result["type"] == "dataframe": diff --git a/tests/unit_tests/test_codemanager.py b/tests/unit_tests/test_codemanager.py index 9a870576b..b445d4b28 100644 --- a/tests/unit_tests/test_codemanager.py +++ b/tests/unit_tests/test_codemanager.py @@ -695,7 +695,11 @@ def test_fix_dataframe_redeclarations( """ tree = ast.parse(python_code) - output = code_manager._extract_fix_dataframe_redeclarations(tree.body[0]) + clean_code = ["df1 = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})"] + + output = code_manager._extract_fix_dataframe_redeclarations( + tree.body[0], clean_code + ) assert isinstance(output, ast.Assign) @@ -721,8 +725,12 @@ def test_fix_dataframe_multiline_redeclarations( print(df1) """ tree = ast.parse(python_code) + clean_codes = [ + "df1 = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})", + ] + outputs = [ - code_manager._extract_fix_dataframe_redeclarations(node) + code_manager._extract_fix_dataframe_redeclarations(node, clean_codes) for node in tree.body ] @@ -750,7 +758,11 @@ def test_fix_dataframe_no_redeclarations( """ tree = ast.parse(python_code) - output = code_manager._extract_fix_dataframe_redeclarations(tree.body[0]) + code_list = ["df1 = dfs[0]"] + + output = code_manager._extract_fix_dataframe_redeclarations( + tree.body[0], code_list + ) assert output is None @@ -773,7 +785,83 @@ def test_fix_dataframe_redeclarations_with_subscript( """ tree = ast.parse(python_code) - output = code_manager._extract_fix_dataframe_redeclarations(tree.body[0]) + code_list = ["dfs[0] = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})"] + + output = code_manager._extract_fix_dataframe_redeclarations( + tree.body[0], code_list + ) + + assert isinstance(output, ast.Assign) + + @patch("pandasai.connectors.pandas.PandasConnector.head") + def test_fix_dataframe_redeclarations_with_subscript_and_data_variable( + self, + mock_head, + exec_context: MagicMock, + config_with_direct_sql: Config, + logger: Logger, + ): + data = { + "country": ["China", "United States", "Japan", "Germany", "United Kingdom"], + "sales": [8000, 6000, 4000, 3500, 3000], + } + df = pd.DataFrame(data) + mock_head.return_value = df + pandas_connector = PandasConnector({"original_df": df}) + + code_manager = CodeManager([pandas_connector], config_with_direct_sql, logger) + + python_code = """ +data = {'country': ['China', 'United States', 'Japan', 'Germany', 'United Kingdom'], + 'sales': [8000, 6000, 4000, 3500, 3000]} +dfs[0] = pd.DataFrame(data) +""" + tree = ast.parse(python_code) + + code_list = [ + "data = {'country': ['China', 'United States', 'Japan', 'Germany', 'United Kingdom'],'sales': [8000, 6000, 4000, 3500, 3000]}", + "dfs[0] = pd.DataFrame(data)", + ] + + output = code_manager._extract_fix_dataframe_redeclarations( + tree.body[1], code_list + ) + + assert isinstance(output, ast.Assign) + + @patch("pandasai.connectors.pandas.PandasConnector.head") + def test_fix_dataframe_redeclarations_and_data_variable( + self, + mock_head, + exec_context: MagicMock, + config_with_direct_sql: Config, + logger: Logger, + ): + data = { + "country": ["China", "United States", "Japan", "Germany", "United Kingdom"], + "sales": [8000, 6000, 4000, 3500, 3000], + } + df = pd.DataFrame(data) + mock_head.return_value = df + pandas_connector = PandasConnector({"original_df": df}) + + code_manager = CodeManager([pandas_connector], config_with_direct_sql, logger) + + python_code = """ +data = {'country': ['China', 'United States', 'Japan', 'Germany', 'United Kingdom'], + 'sales': [8000, 6000, 4000, 3500, 3000]} +df = pd.DataFrame(data) +""" + tree = ast.parse(python_code) + + code_list = [ + "data = {'country': ['China', 'United States', 'Japan', 'Germany', 'United Kingdom'],'sales': [8000, 6000, 4000, 3500, 3000]}", + "df = pd.DataFrame(data)", + ] + + output = code_manager._extract_fix_dataframe_redeclarations( + tree.body[1], code_list + ) assert isinstance(output, ast.Assign)