Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cyantangerine committed Nov 27, 2024
1 parent 3d9af8f commit 1ee0b35
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 11 deletions.
4 changes: 3 additions & 1 deletion sdgx/data_models/inspectors/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):

# Iterate over all columns and determine the final data type
for col in raw_data.columns:
if raw_data[col].dtype in ["int64", "float64"]:
if (pd.api.types.is_integer_dtype(raw_data[col].dtype)
or pd.api.types.is_float_dtype(raw_data[col].dtype)):
# series type may be 32/64bit.
# float or int
if self._is_int_column(raw_data[col]):
self.int_columns.add(col)
Expand Down
9 changes: 9 additions & 0 deletions sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ def check_categorical_threshold(self, num_categories):
else False
)

def get_column_encoder(self, column_name):
    """Return the categorical encoder configured for ``column_name``.

    Looks the column up in ``self.categorical_encoder``.

    Args:
        column_name: Key to look up in ``self.categorical_encoder``.

    Returns:
        The value stored under ``column_name``, or ``None`` when
        ``self.categorical_encoder`` is unset/empty or has no entry for
        the column.
    """
    encoders = self.categorical_encoder
    # An unset (None) or empty mapping means no per-column override exists.
    if not encoders:
        return None
    # Single lookup; yields None for columns without a configured encoder.
    return encoders.get(column_name)

@property
def tag_fields(self) -> Iterable[str]:
"""
Expand Down
9 changes: 2 additions & 7 deletions sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None):
num_categories = len(encoder.dummies)
activate_fn = "softmax"

checked = self.metadata.check_categorical_threshold(num_categories)
checked = self.metadata.check_categorical_threshold(num_categories) if self.metadata else False
if encoder_type == "onehot" or not checked:
pass
elif encoder_type == "label":
Expand Down Expand Up @@ -127,12 +127,7 @@ def fit(self, data_loader: DataLoader, discrete_columns=()):
if column_name in discrete_columns:
# or column_name in self.metadata.label_columns
logger.debug(f"Fitting discrete column {column_name}...")
encoder_type = None
if (
self.metadata.categorical_encoder
and column_name in self.metadata.categorical_encoder
):
encoder_type = self.metadata.categorical_encoder[column_name]
encoder_type = self.metadata.get_column_encoder(column_name) if self.metadata else None
column_transform_info = self._fit_discrete(data_loader[[column_name]], encoder_type)
else:
logger.debug(f"Fitting continuous column {column_name}...")
Expand Down
1 change: 0 additions & 1 deletion tests/data_models/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from pathlib import Path

import pytest

from sdgx.data_connectors.csv_connector import CsvConnector
Expand Down
3 changes: 1 addition & 2 deletions tests/test_ctgan_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def demo_single_table_data_pos_neg():

@pytest.fixture
def demo_single_table_data_pos_neg_metadata(demo_single_table_data_pos_neg):
yield Metadata.from_dataframe(demo_single_table_data_pos_neg)
yield Metadata.from_dataframe(demo_single_table_data_pos_neg.copy(), check=True)


@pytest.fixture
Expand Down Expand Up @@ -82,7 +82,6 @@ def test_ctgan_synthesizer_with_pos_neg(
demo_single_table_data_pos_neg,
):
original_data = demo_single_table_data_pos_neg
metadata = demo_single_table_data_pos_neg_metadata

# Train the CTGAN model
ctgan_synthesizer.fit(demo_single_table_data_pos_neg_metadata)
Expand Down

0 comments on commit 1ee0b35

Please sign in to comment.