Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: filter only the columns that are provided in the schema #1562

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pandasai/data_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def load(
self.dataset_path = self.schema.source.path

df = self._load_from_local_source()
df = self._filter_columns(df)
df = self._apply_transformations(df)

# Convert to pandas DataFrame while preserving internal data
Expand Down Expand Up @@ -202,6 +203,23 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
f"Failed to execute query for '{source_type}' with: {formatted_query}"
) from e

def _filter_columns(self, df: pd.DataFrame) -> pd.DataFrame:
"""Filter DataFrame columns based on schema columns if specified.

Args:
df (pd.DataFrame): Input DataFrame to filter

Returns:
pd.DataFrame: DataFrame with only columns specified in schema
"""
if not self.schema or not self.schema.columns:
return df

schema_columns = [col.name for col in self.schema.columns]
df_columns = df.columns.tolist()
columns_to_keep = [col for col in df_columns if col in schema_columns]
return df[columns_to_keep]

def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
if not self.schema.transformations:
return df
Expand Down
53 changes: 53 additions & 0 deletions tests/unit_tests/dataframe/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,56 @@ def test_load_with_schema_and_path(self, sample_schema):
match="Provide only one of 'dataset_path' or 'schema', not both.",
):
result = loader.load("test/users", sample_schema)

def test_filter_columns_with_schema_columns(self, sample_schema):
"""Test that columns are filtered correctly when schema columns are specified."""
loader = DatasetLoader()
loader.schema = sample_schema

# Create a DataFrame with extra columns
df = pd.DataFrame(
{
"email": ["[email protected]"],
"first_name": ["John"],
"timestamp": ["2023-01-01"],
"extra_col": ["extra"], # This column should be filtered out
}
)

filtered_df = loader._filter_columns(df)
assert list(filtered_df.columns) == ["email", "first_name", "timestamp"]
assert "extra_col" not in filtered_df.columns

def test_filter_columns_without_schema_columns(self):
"""Test that all columns are kept when no schema columns are specified."""
loader = DatasetLoader()
# Create schema without columns
loader.schema = SemanticLayerSchema(
**{"name": "Users", "source": {"type": "csv", "path": "users.csv"}}
)

df = pd.DataFrame({"col1": [1], "col2": [2], "col3": [3]})

filtered_df = loader._filter_columns(df)
assert list(filtered_df.columns) == ["col1", "col2", "col3"]

def test_filter_columns_with_non_matching_columns(self, sample_schema):
"""Test filtering when schema columns don't match DataFrame columns."""
loader = DatasetLoader()
loader.schema = sample_schema

# Create DataFrame with none of the schema columns
df = pd.DataFrame({"different_col1": [1], "different_col2": [2]})

filtered_df = loader._filter_columns(df)
assert len(filtered_df.columns) == 0 # Should return empty DataFrame

def test_filter_columns_without_schema(self):
"""Test that all columns are kept when no schema is set."""
loader = DatasetLoader()
loader.schema = None

df = pd.DataFrame({"col1": [1], "col2": [2]})

filtered_df = loader._filter_columns(df)
assert list(filtered_df.columns) == ["col1", "col2"]
Loading