Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(View): add view support and relationship validation in SemanticLayerSchema #1534

Merged
merged 5 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 66 additions & 10 deletions docs/v3/semantic-layer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,22 @@ columns:
```

**Type**: `list[dict]`
- Each dictionary represents a column
- `name` (str): Name of the column
- `type` (str): Data type of the column
- "string": IDs, names, categories
- "integer": counts, whole numbers
- "float": prices, percentages
- "datetime": timestamps, dates
- "boolean": flags, true/false values
- `description` (str): Clear explanation of what the column represents
- Each dictionary represents a column.
- **Fields**:
- `name` (str): Name of the column.
- For tables: Use simple column names (e.g., `transaction_id`).
- `type` (str): Data type of the column.
- Supported types:
- `"string"`: IDs, names, categories.
- `"integer"`: Counts, whole numbers.
- `"float"`: Prices, percentages.
- `"datetime"`: Timestamps, dates.
- `"boolean"`: Flags, true/false values.
- `description` (str): Clear explanation of what the column represents.

**Constraints**:
1. Column names must be unique.
2. For views, all column names must be in the format `[table].[column]`.

#### transformations
Apply transformations to your data to clean, convert, or anonymize it.
Expand Down Expand Up @@ -350,4 +357,53 @@ Specify the maximum number of records to load.
**Type**: `int`

```yaml
limit: 1000
limit: 1000
```

### View Configuration

The following sections detail all available configurations for view options in your `schema.yaml` file. Similar to views in SQL, you can define multiple tables and the relationships between them.

#### Example Configuration

```yaml
name: table_heart
source:
type: postgres
connection:
host: localhost
port: 5432
database: test
user: test
password: test
view: true
columns:
- name: parents.id
- name: parents.name
- name: parents.age
- name: children.name
- name: children.age
relations:
- name: parent_to_children
description: Relation linking the parent to its children
from: parents.id
to: children.id
```

---

#### Constraints

1. **Mutual Exclusivity**:
- A schema cannot define both `table` and `view` simultaneously.
- If `source.view` is `true`, then the schema represents a view.

2. **Column Format**:
- For views:
- All columns must follow the format `[table].[column]`.
- `from` and `to` fields in `relations` must follow the `[table].[column]` format.
- Example: `parents.id`, `children.name`.

3. **Relationships for Views**:
- Each table referenced in `columns` must have at least one relationship defined in `relations`.
- Relationships must specify `from` and `to` attributes in the `[table].[column]` format.
85 changes: 80 additions & 5 deletions pandasai/data_loader/semantic_layer_schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import re
from functools import partial
from typing import Any, Dict, List, Optional, Union

import yaml
Expand Down Expand Up @@ -32,6 +34,17 @@ def is_column_type_supported(cls, type: str) -> str:
return type


class Relation(BaseModel):
name: Optional[str] = Field(None, description="Name of the relationship.")
description: Optional[str] = Field(
None, description="Description of the relationship."
)
from_: str = Field(
..., alias="from", description="Source column for the relationship."
)
to: str = Field(..., description="Target column for the relationship.")


class Transformation(BaseModel):
type: str = Field(..., description="Type of transformation to be applied.")
params: Optional[Dict[str, str]] = Field(
Expand All @@ -48,34 +61,38 @@ def is_transformation_type_supported(cls, type: str) -> str:

class Source(BaseModel):
type: str = Field(..., description="Type of the data source.")
path: Optional[str] = Field(None, description="Path of the local data source.")
connection: Optional[Dict[str, Union[str, int]]] = Field(
None, description="Connection object of the data source."
)
path: Optional[str] = Field(None, description="Path of the local data source.")
table: Optional[str] = Field(None, description="Table of the data source.")
view: Optional[bool] = Field(False, description="Whether table is a view")

@model_validator(mode="before")
@classmethod
def validate_type_and_fields(cls, values):
_type = values.get("type")
path = values.get("path")
table = values.get("table")
view = values.get("view")
connection = values.get("connection")

if _type in LOCAL_SOURCE_TYPES:
if not path:
raise ValueError(
f"For local source type '{_type}', 'path' must be defined."
)
if view:
raise ValueError("A view cannot be used with a local source type.")
elif _type in REMOTE_SOURCE_TYPES:
if not connection:
raise ValueError(
f"For remote source type '{_type}', 'connection' must be defined."
)
if not table:
raise ValueError(
f"For remote source type '{_type}', 'table' must be defined."
)
if table and view:
raise ValueError("Only one of 'table' or 'view' can be defined.")
if not table and not view:
raise ValueError("Either 'table' or 'view' must be defined.")
else:
raise ValueError(f"Unsupported source type: {_type}")

Expand Down Expand Up @@ -104,6 +121,9 @@ class SemanticLayerSchema(BaseModel):
columns: Optional[List[Column]] = Field(
None, description="Structure and metadata of your dataset’s columns"
)
relations: Optional[List[Relation]] = Field(
None, description="Relationships between columns and tables."
)
order_by: Optional[List[str]] = Field(
None, description="Ordering criteria for the dataset."
)
Expand All @@ -120,6 +140,61 @@ class SemanticLayerSchema(BaseModel):
None, description="Frequency of dataset updates."
)

@model_validator(mode="after")
def check_columns_relations(self):
column_re_check = r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$"
is_view_column_name = partial(re.match, column_re_check)

# unpack columns info
_columns = self.columns
_column_names = [col.name for col in _columns or ()]
_tables_names_in_columns = {
column_name.split(".")[0] for column_name in _column_names or ()
}

if len(_column_names) != len(set(_column_names)):
raise ValueError("Column names must be unique. Duplicate names found.")

if self.source.view:
# unpack relations info
_relations = self.relations
_column_names_in_relations = {
table
for relation in _relations or ()
for table in (relation.from_, relation.to)
}
_tables_names_in_relations = {
column_name.split(".")[0]
for column_name in _column_names_in_relations or ()
}

if not all(
is_view_column_name(column_name) for column_name in _column_names
):
raise ValueError(
"All columns in a view must be in the format '[table].[column]'."
)

if not all(
is_view_column_name(column_name)
for column_name in _column_names_in_relations
):
raise ValueError(
"All params 'from' and 'to' in the relations must be in the format '[table].[column]'."
)

if (
uncovered_tables := _tables_names_in_columns
- _tables_names_in_relations
):
raise ValueError(
f"No relations provided for the following tables {uncovered_tables}."
)

elif any(is_view_column_name(column_name) for column_name in _column_names):
raise ValueError("All columns in a table must be in the format '[column]'.")
return self

def to_dict(self) -> dict[str, Any]:
return self.model_dump(exclude_none=True)

Expand Down
79 changes: 79 additions & 0 deletions tests/unit_tests/dataframe/test_semantic_layer_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,29 @@ def mysql_schema(self):
},
}

@pytest.fixture
def mysql_view_schema(self):
return {
"name": "Users",
"columns": [
{"name": "parents.id"},
{"name": "parents.name"},
{"name": "children.name"},
],
"relations": [{"from": "parents.id", "to": "children.id"}],
"source": {
"type": "mysql",
"connection": {
"host": "localhost",
"port": 3306,
"database": "test_db",
"user": "test_user",
"password": "test_password",
},
"view": "true",
},
}

def test_valid_schema(self, sample_schema):
schema = SemanticLayerSchema(**sample_schema)

Expand All @@ -113,6 +136,14 @@ def test_valid_mysql_schema(self, mysql_schema):
assert len(schema.transformations) == 2
assert schema.source.type == "mysql"

def test_valid_mysql_view_schema(self, mysql_view_schema):
schema = SemanticLayerSchema(**mysql_view_schema)

assert schema.name == "Users"
assert len(schema.columns) == 3
assert schema.source.view == True
assert schema.source.type == "mysql"

def test_missing_source_path(self, sample_schema):
sample_schema["source"].pop("path")

Expand Down Expand Up @@ -203,3 +234,51 @@ def test_is_schema_source_same_false(self, mysql_schema, sample_schema):
schema2 = SemanticLayerSchema(**sample_schema)

assert is_schema_source_same(schema1, schema2) is False

def test_invalid_source_view_for_local_type(self, sample_schema):
sample_schema["source"]["view"] = True

with pytest.raises(ValidationError):
SemanticLayerSchema(**sample_schema)

def test_invalid_source_view_and_table(self, mysql_schema):
mysql_schema["source"]["view"] = True

with pytest.raises(ValidationError):
SemanticLayerSchema(**mysql_schema)

def test_invalid_source_missing_view_or_table(self, mysql_schema):
mysql_schema["source"].pop("table")

with pytest.raises(ValidationError):
SemanticLayerSchema(**mysql_schema)

def test_invalid_duplicated_columns(self, sample_schema):
sample_schema["columns"].append(sample_schema["columns"][0])

with pytest.raises(ValidationError):
SemanticLayerSchema(**sample_schema)

def test_invalid_wrong_column_format_in_view(self, mysql_view_schema):
mysql_view_schema["columns"][0]["name"] = "parentsid"

with pytest.raises(ValidationError):
SemanticLayerSchema(**mysql_view_schema)

def test_invalid_wrong_column_format(self, sample_schema):
sample_schema["columns"][0]["name"] = "parents.id"

with pytest.raises(ValidationError):
SemanticLayerSchema(**sample_schema)

def test_invalid_wrong_relation_format_in_view(self, mysql_view_schema):
mysql_view_schema["relations"][0]["to"] = "parentsid"

with pytest.raises(ValidationError):
SemanticLayerSchema(**mysql_view_schema)

def test_invalid_uncovered_columns_in_view(self, mysql_view_schema):
mysql_view_schema.pop("relations")

with pytest.raises(ValidationError):
SemanticLayerSchema(**mysql_view_schema)
Loading