Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 27, 2024
1 parent 1ee0b35 commit c4805f1
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
5 changes: 3 additions & 2 deletions sdgx/data_models/inspectors/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,9 @@ def fit(self, raw_data: pd.DataFrame, *args, **kwargs):

# Iterate all columns and determine the final data type
for col in raw_data.columns:
if (pd.api.types.is_integer_dtype(raw_data[col].dtype)
or pd.api.types.is_float_dtype(raw_data[col].dtype)):
if pd.api.types.is_integer_dtype(raw_data[col].dtype) or pd.api.types.is_float_dtype(
raw_data[col].dtype
):
# series type may be 32/64bit.
# float or int
if self._is_int_column(raw_data[col]):
Expand Down
5 changes: 1 addition & 4 deletions sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,7 @@ def check_categorical_threshold(self, num_categories):

def get_column_encoder(self, column_name):
encoder_type = None
if (
self.categorical_encoder
and column_name in self.categorical_encoder
):
if self.categorical_encoder and column_name in self.categorical_encoder:
encoder_type = self.categorical_encoder[column_name]
return encoder_type

Expand Down
8 changes: 6 additions & 2 deletions sdgx/models/components/optimize/sdv_ctgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None):
num_categories = len(encoder.dummies)
activate_fn = "softmax"

checked = self.metadata.check_categorical_threshold(num_categories) if self.metadata else False
checked = (
self.metadata.check_categorical_threshold(num_categories) if self.metadata else False
)
if encoder_type == "onehot" or not checked:
pass
elif encoder_type == "label":
Expand Down Expand Up @@ -127,7 +129,9 @@ def fit(self, data_loader: DataLoader, discrete_columns=()):
if column_name in discrete_columns:
# or column_name in self.metadata.label_columns
logger.debug(f"Fitting discrete column {column_name}...")
encoder_type = self.metadata.get_column_encoder(column_name) if self.metadata else None
encoder_type = (
self.metadata.get_column_encoder(column_name) if self.metadata else None
)
column_transform_info = self._fit_discrete(data_loader[[column_name]], encoder_type)
else:
logger.debug(f"Fitting continuous column {column_name}...")
Expand Down
1 change: 1 addition & 0 deletions tests/data_models/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from pathlib import Path

import pytest

from sdgx.data_connectors.csv_connector import CsvConnector
Expand Down

0 comments on commit c4805f1

Please sign in to comment.