From 7f4231021774251aebec0705dfb46b07ec34f95c Mon Sep 17 00:00:00 2001 From: T Date: Thu, 19 Sep 2024 11:57:05 +0800 Subject: [PATCH] test: add tests for the several modules --- tests/conftest.py | 55 +++++++++++++ .../test_causality/test_causal_analysis.py | 55 +++++++++++++ .../test_classification_model.py | 71 ++++++++++++++++ .../test_explainability.py | 76 ++++++++++++++++++ .../test_categorical_test.py | 68 ++++++++++++++++ .../test_statistical_test/test_correlation.py | 46 +++++++++++ .../test_numerical_test.py | 51 ++++++++++++ .../test_utils_data_type.py | 30 +++++++ tests/test_sdqc_check/test_utils.py | 35 ++++++++ .../test_sdqc_integration/test_sequential.py | 80 +++++++++++++++++++ .../test_sdv_synthesizer.py | 54 +++++++++++++ .../test_ydata_synthesizer.py | 64 +++++++++++++++ 12 files changed, 685 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/test_sdqc_check/test_causality/test_causal_analysis.py create mode 100644 tests/test_sdqc_check/test_classification/test_classification_model.py create mode 100644 tests/test_sdqc_check/test_explainability/test_explainability.py create mode 100644 tests/test_sdqc_check/test_statistical_test/test_categorical_test.py create mode 100644 tests/test_sdqc_check/test_statistical_test/test_correlation.py create mode 100644 tests/test_sdqc_check/test_statistical_test/test_numerical_test.py create mode 100644 tests/test_sdqc_check/test_statistical_test/test_utils_data_type.py create mode 100644 tests/test_sdqc_check/test_utils.py create mode 100644 tests/test_sdqc_integration/test_sequential.py create mode 100644 tests/test_sdqc_synthesize/test_sdv_synthesizer.py create mode 100644 tests/test_sdqc_synthesize/test_ydata_synthesizer.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ba3759c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,55 @@ +import pytest +import pandas as pd +import numpy as np + + +@pytest.fixture +def raw_data(): + return pd.DataFrame({ + 'A': np.random.randint(0, 100, 100), + 'B': np.random.choice([0, 1, 2], 100), + 'C': np.random.uniform(0, 1, 100), + 'D': np.random.choice([0, 1], 100) + }) + + +@pytest.fixture +def synthetic_data(): + return pd.DataFrame({ + 'A': np.random.randint(0, 100, 100), + 'B': np.random.choice([0, 1, 2], 100), + 'C': np.random.uniform(0, 1, 100), + 'D': np.random.choice([0, 1], 100) + }) + + +@pytest.fixture +def col_dtypes(): + return { + 'categorical': ['B', 'D'], + 'numerical': ['A', 'C'] + } + + +@pytest.fixture +def sample_categorical_data1(): + return pd.Series(['A'] * 40 + ['B'] * 30 + ['C'] * 20 + ['D'] * 10) + + +@pytest.fixture +def sample_categorical_data2(): + return pd.Series( + np.random.choice( + ['A', 'B', 'C', 'D'], 100, p=[0.4, 0.3, 0.2, 0.1] + ) + ) + + +@pytest.fixture +def sample_numerical_data1(): + return pd.Series([1, 2, 3, 4, 5] * 20) + + +@pytest.fixture +def sample_numerical_data2(): + return pd.Series([1, 2, 3, 4, 5, 5, 4, 3, 2, 1] * 10) diff --git a/tests/test_sdqc_check/test_causality/test_causal_analysis.py b/tests/test_sdqc_check/test_causality/test_causal_analysis.py new file mode 100644 index 0000000..d90206d --- /dev/null +++ b/tests/test_sdqc_check/test_causality/test_causal_analysis.py @@ -0,0 +1,55 @@ +import pytest +import castle +import numpy as np +import pandas as pd +from sdqc_check import CausalAnalysis +from sdqc_data import read_data + + +@pytest.fixture +def sample_data(raw_data, synthetic_data): + return raw_data, synthetic_data + + +def test_causal_analysis_initialization(sample_data): + raw_data, synthetic_data = sample_data + ca = CausalAnalysis(raw_data, synthetic_data) + assert isinstance(ca, CausalAnalysis) + assert np.array_equal(ca.raw_data, raw_data.to_numpy()) + assert np.array_equal(ca.synthetic_data, synthetic_data.to_numpy()) + assert ca.model_name == 'dlg' + assert ca.random_seed == 17 + assert ca.device_type == 'cpu' + assert ca.device_id == 0 + + +def test_causal_analysis_invalid_model(): + with pytest.raises(ValueError): + CausalAnalysis( + pd.DataFrame(), pd.DataFrame(), model_name='invalid_model' + ) + + +@pytest.mark.parametrize('model_name', ['dlg', 'notears', 'golem', 'grandag', 'gae']) +def test_get_model(sample_data, model_name): + raw_data, synthetic_data = sample_data + ca = CausalAnalysis(raw_data, synthetic_data, model_name=model_name) + model = ca._get_model(model_name) + assert isinstance(model, castle.common.BaseLearner) + + +def test_compute_causal_matrices(sample_data): + raw_data, synthetic_data = sample_data + ca = CausalAnalysis(raw_data, synthetic_data) + raw_matrix, synthetic_matrix = ca.compute_causal_matrices() + assert isinstance(raw_matrix, np.ndarray) + assert isinstance(synthetic_matrix, np.ndarray) + assert raw_matrix.shape == (4, 4) + assert synthetic_matrix.shape == (4, 4) + +def test_compare_adjacency_matrices(): + raw_data = read_data('3_raw') + synthetic_data = read_data('3_synth') + ca = CausalAnalysis(raw_data, synthetic_data) + mt = ca.compare_adjacency_matrices() + assert isinstance(mt, castle.MetricsDAG) diff --git a/tests/test_sdqc_check/test_classification/test_classification_model.py b/tests/test_sdqc_check/test_classification/test_classification_model.py new file mode 100644 index 0000000..645c306 --- /dev/null +++ b/tests/test_sdqc_check/test_classification/test_classification_model.py @@ -0,0 +1,71 @@ +import pytest +import pandas as pd +from sdqc_check import ClassificationModel + + +@pytest.fixture +def sample_data(raw_data, synthetic_data): + return raw_data, synthetic_data + + +def test_classification_model_initialization(sample_data): + raw_data, synthetic_data = sample_data + cm = ClassificationModel(raw_data, synthetic_data) + assert isinstance(cm, ClassificationModel) + assert cm.model_name == 'rf' + assert cm.test_size == 0.2 + assert cm.random_seed == 17 + + +def test_classification_model_invalid_model(): + with pytest.raises(ValueError): + ClassificationModel( + pd.DataFrame(), pd.DataFrame(), model_name='invalid_model' + ) + + +@pytest.mark.parametrize('model_name', ['svm', 'rf', 'xgb', 'lgbm']) +def test_classification_model_single_model(sample_data, model_name): + raw_data, synthetic_data = sample_data + cm = ClassificationModel(raw_data, synthetic_data, model_name=model_name) + metrics, models = cm.train_and_evaluate_models() + assert isinstance(metrics, pd.DataFrame) + assert len(models) == 1 + assert metrics['Model'].iloc[0] == model_name + + +def test_classification_model_multiple_models(sample_data): + raw_data, synthetic_data = sample_data + model_names = ['svm', 'rf', 'xgb', 'lgbm'] + cm = ClassificationModel(raw_data, synthetic_data, model_name=model_names) + metrics, models = cm.train_and_evaluate_models() + assert isinstance(metrics, pd.DataFrame) + assert len(models) == len(model_names) + assert set(metrics['Model']) == set(model_names) + + +def test_classification_model_custom_params(sample_data): + raw_data, synthetic_data = sample_data + custom_params = { + 'rf': {'n_estimators': 100, 'max_depth': 5}, + 'xgb': {'n_estimators': 100, 'max_depth': 5} + } + cm = ClassificationModel( + raw_data, synthetic_data, + model_name=['rf', 'xgb'], + model_params=custom_params + ) + metrics, models = cm.train_and_evaluate_models() + assert isinstance(metrics, pd.DataFrame) + assert len(models) == 2 + assert set(metrics['Model']) == {'rf', 'xgb'} + + +def test_classification_model_metrics(sample_data): + raw_data, synthetic_data = sample_data + cm = ClassificationModel(raw_data, synthetic_data) + metrics, _ = cm.train_and_evaluate_models() + expected_columns = [ + 'Model', 'Accuracy', 'Precision', 'Recall', 'F1', 'AUC' + ] + assert set(metrics.columns) == set(expected_columns) diff --git a/tests/test_sdqc_check/test_explainability/test_explainability.py b/tests/test_sdqc_check/test_explainability/test_explainability.py new file mode 100644 index 0000000..fcd9a1e --- /dev/null +++ b/tests/test_sdqc_check/test_explainability/test_explainability.py @@ -0,0 +1,76 @@ +import pytest +import pandas as pd +from sklearn.datasets import make_classification +from sklearn.model_selection import train_test_split +from sklearn.ensemble import RandomForestClassifier +from sklearn.svm import SVC +from sdqc_check import ShapFeatureImportance, PFIFeatureImportance + + +@pytest.fixture +def sample_data(): + X, y = make_classification( + n_samples=100, n_features=10, n_informative=5, random_state=17 + ) + X = pd.DataFrame(X, columns=[f'feature_{i + 1}' for i in range(10)]) + y = pd.Series(y) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=17 + ) + model = RandomForestClassifier(random_state=17) + model.fit(X_train, y_train) + return model, X_train, X_test, y_train, y_test + + +@pytest.mark.parametrize("FeatureImportance", [ShapFeatureImportance, PFIFeatureImportance]) +def test_feature_importance(sample_data, FeatureImportance): + model, X_train, X_test, y_train, y_test = sample_data + importance = FeatureImportance(model, X_train, X_test, y_test) + result = importance.compute_feature_importance() + + assert isinstance(result, pd.DataFrame) + assert len(result) == 10 + assert set(result.columns) == {'feature', 'importance'} + assert result['feature'].nunique() == 10 + assert all( + f'feature_{i + 1}' in result['feature'].values for i in range(10) + ) + assert result['importance'].is_monotonic_decreasing + + +@pytest.mark.parametrize("FeatureImportance", [ShapFeatureImportance, PFIFeatureImportance]) +def test_feature_importance_random_seed(sample_data, FeatureImportance): + model, X_train, X_test, y_train, y_test = sample_data + importance1 = FeatureImportance( + model, X_train, X_test, y_test, random_seed=17 + ) + importance2 = FeatureImportance( + model, X_train, X_test, y_test, random_seed=17 + ) + + scores1 = importance1.compute_feature_importance() + scores2 = importance2.compute_feature_importance() + + pd.testing.assert_frame_equal(scores1, scores2) + + +@pytest.mark.parametrize("model_class", [RandomForestClassifier, SVC]) +def test_shap_feature_importance_models(sample_data, model_class): + _, X_train, X_test, y_train, y_test = sample_data + if model_class == SVC: + model = model_class(random_state=17, probability=True) + else: + model = model_class(random_state=17) + model.fit(X_train, y_train) + + shap_importance = ShapFeatureImportance(model, X_train, X_test, y_test) + importance = shap_importance.compute_feature_importance() + + assert isinstance(importance, pd.DataFrame) + assert len(importance) == 10 + assert set(importance.columns) == {'feature', 'importance'} + assert importance['feature'].nunique() == 10 + assert all( + f'feature_{i + 1}' in importance['feature'].values for i in range(10) + ) + assert importance['importance'].is_monotonic_decreasing diff --git a/tests/test_sdqc_check/test_statistical_test/test_categorical_test.py b/tests/test_sdqc_check/test_statistical_test/test_categorical_test.py new file mode 100644 index 0000000..bd260f0 --- /dev/null +++ b/tests/test_sdqc_check/test_statistical_test/test_categorical_test.py @@ -0,0 +1,68 @@ +import pytest +import pandas as pd +from sdqc_check import CategoricalTest + + +@pytest.fixture +def categorical_test(): + return CategoricalTest() + + +def test_basis(categorical_test, sample_categorical_data1): + result = categorical_test.basis(sample_categorical_data1) + assert isinstance(result, pd.Series) + assert result['count'] == 100 + assert result['missing'] == 0 + assert result['unique'] == 4 + assert result['top'] == 'A' + assert result['freq'] == 40 + + +def test_distribution( + categorical_test, + sample_categorical_data1, + sample_categorical_data2 +): + jaccard, p_value = categorical_test.distribution( + sample_categorical_data1, sample_categorical_data2 + ) + assert isinstance(jaccard, float) + assert isinstance(p_value, float) + assert 0 <= jaccard <= 1 + assert 0 <= p_value <= 1 + + +def test_jaccard_index_identical(categorical_test): + data = pd.Series(['A', 'B', 'C']) + jaccard, _ = categorical_test.distribution(data, data) + assert jaccard == 1 + + +def test_jaccard_index_disjoint(categorical_test): + data1 = pd.Series(['A', 'B', 'C']) + data2 = pd.Series(['D', 'E', 'F']) + jaccard, _ = categorical_test.distribution(data1, data2) + assert jaccard == 0 + + +def test_jaccard_index_different(categorical_test): + data1 = pd.Series(['A', 'B', 'A', 'C', 'B', 'A', 'C', 'A', 'B', 'C']) + data2 = pd.Series(['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'C', 'D']) + jaccard, p_value = categorical_test.distribution(data1, data2) + assert jaccard < 1 + assert p_value == 0 + + +def test_chi_square_identical(categorical_test): + data1 = pd.Series(['A'] * 4 + ['B'] * 2 + ['C'] * 3 + ['D'] * 1) + data2 = pd.Series(['D'] * 1 + ['C'] * 3 + ['B'] * 2 + ['A'] * 4) + + _, p_value = categorical_test.distribution(data1, data2) + assert p_value == 1 + + +def test_chi_square_different(categorical_test): + data1 = pd.Series(['A', 'B', 'A', 'C', 'B', 'A', 'D', 'A', 'B', 'C']) + data2 = pd.Series(['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'C', 'D']) + _, p_value = categorical_test.distribution(data1, data2) + assert p_value < 1 diff --git a/tests/test_sdqc_check/test_statistical_test/test_correlation.py b/tests/test_sdqc_check/test_statistical_test/test_correlation.py new file mode 100644 index 0000000..1b9cc3a --- /dev/null +++ b/tests/test_sdqc_check/test_statistical_test/test_correlation.py @@ -0,0 +1,46 @@ +import pandas as pd +import numpy as np +from sdqc_check import data_corr + + +def test_data_corr(raw_data, col_dtypes): + result = data_corr(raw_data, col_dtypes) + + assert isinstance(result, pd.DataFrame) + assert set(result.columns) == {'column1', + 'column2', 'method', 'corr_coefficient'} + + expected_pairs = [ + ('A', 'B'), ('A', 'C'), ('A', 'D'), + ('B', 'C'), ('B', 'D'), ('C', 'D') + ] + assert set(zip(result['column1'], result['column2']) + ) == set(expected_pairs) + + assert set(result['method']) == {"Cramer's V", 'Pearson', 'Eta'} + assert (result['corr_coefficient'] >= - + 1).all() and (result['corr_coefficient'] <= 1).all() + + +def test_data_corr_perfect_correlation(): + perfect_data = pd.DataFrame({ + 'cat1': ['A', 'B'] * 50, + 'cat2': ['X', 'Y'] * 50, + 'num1': list(range(100)), + 'num2': list(range(0, 200, 2)) + }) + + col_dtypes = { + 'categorical': ['cat1', 'cat2'], + 'numerical': ['num1', 'num2'] + } + + result = data_corr(perfect_data, col_dtypes) + + cramer_v = result[(result['column1'] == 'cat1') & ( + result['column2'] == 'cat2')]['corr_coefficient'].values[0] + assert np.isclose(cramer_v, 1.0, atol=0.02) + + pearson = result[(result['column1'] == 'num1') & ( + result['column2'] == 'num2')]['corr_coefficient'].values[0] + assert np.isclose(pearson, 1.0) diff --git a/tests/test_sdqc_check/test_statistical_test/test_numerical_test.py b/tests/test_sdqc_check/test_statistical_test/test_numerical_test.py new file mode 100644 index 0000000..457d66f --- /dev/null +++ b/tests/test_sdqc_check/test_statistical_test/test_numerical_test.py @@ -0,0 +1,51 @@ +import pytest +import pandas as pd +from sdqc_check import NumericalTest + + +@pytest.fixture +def numerical_test(): + return NumericalTest() + + +def test_basis(numerical_test, sample_numerical_data1): + result = numerical_test.basis(sample_numerical_data1) + assert isinstance(result, pd.Series) + assert result['count'] == 100 + assert result['missing'] == 0 + assert result['min'] == 1 + assert result['max'] == 5 + assert result['mean'] == 3 + assert round(result['var']) == 2 + assert round(result['cv'], 1) == 0.5 + assert result['skew'] == 0 + assert round(result['kurt']) == -1 + + +def test_distribution( + numerical_test, + sample_numerical_data1, + sample_numerical_data2 +): + wasserstein, hellinger = numerical_test.distribution( + sample_numerical_data1, sample_numerical_data2 + ) + assert isinstance(wasserstein, float) + assert isinstance(hellinger, float) + assert wasserstein >= 0 + assert 0 <= hellinger <= 1 + + +def test_distances_identical(numerical_test): + data = pd.Series([1, 2, 3, 4, 5]) + wasserstein, hellinger = numerical_test.distribution(data, data) + assert wasserstein == 0 + assert hellinger == 0 + + +def test_distances_different(numerical_test): + data1 = pd.Series([1, 2, 3, 4, 5]) + data2 = pd.Series([6, 7, 8, 9, 10]) + wasserstein, hellinger = numerical_test.distribution(data1, data2) + assert wasserstein > 0 + assert hellinger > 0 diff --git a/tests/test_sdqc_check/test_statistical_test/test_utils_data_type.py b/tests/test_sdqc_check/test_statistical_test/test_utils_data_type.py new file mode 100644 index 0000000..d0c5ad8 --- /dev/null +++ b/tests/test_sdqc_check/test_statistical_test/test_utils_data_type.py @@ -0,0 +1,30 @@ +import pytest +import pandas as pd +from typing import Dict +from sdqc_check.statistical_test.utils import identify_data_types + + +@pytest.fixture +def sample_data(): + return pd.DataFrame({ + 'bool_col': [True, False] * 10, + 'int_col': range(1, 21), + 'float_col': [i * 1.1 for i in range(1, 21)], + 'cat_col': ['A', 'B', 'C', 'D'] * 5, + 'cat_problem_col': [str(i) for i in range(20)], + 'na_problem_col': [1, 2, 3] + [pd.NA] * 17 + }) + + +def test_identify_data_types(sample_data): + result = identify_data_types(sample_data) + + assert isinstance(result, Dict) + assert set(result.keys()) == {'categorical', 'numerical', 'problem'} + + assert 'bool_col' in result['categorical'] + assert 'cat_col' in result['categorical'] + assert 'int_col' in result['numerical'] + assert 'float_col' in result['numerical'] + assert 'cat_problem_col' in result['problem'] + assert 'na_problem_col' in result['problem'] diff --git a/tests/test_sdqc_check/test_utils.py b/tests/test_sdqc_check/test_utils.py new file mode 100644 index 0000000..75d2158 --- /dev/null +++ b/tests/test_sdqc_check/test_utils.py @@ -0,0 +1,35 @@ +import pandas as pd +import numpy as np +from sdqc_check.utils import combine_data_and_labels, set_seed + + +def test_combine_data_and_labels(raw_data, synthetic_data): + X, y = combine_data_and_labels(raw_data, synthetic_data) + + assert isinstance(X, pd.DataFrame) + assert len(X) == len(raw_data) + len(synthetic_data) + assert isinstance(y, np.ndarray) + assert len(y) == len(X) + assert np.sum(y == 1) == len(raw_data) + assert np.sum(y == 0) == len(synthetic_data) + + +def test_set_seed(): + seed = 17 + set_seed(seed) + + random_1 = np.random.rand() + set_seed(seed) + random_2 = np.random.rand() + assert random_1 == random_2 + + import os + assert os.environ['PYTHONHASHSEED'] == str(seed) + + import torch + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + random_tensor_1 = torch.rand(1) + set_seed(seed) + random_tensor_2 = torch.rand(1) + assert torch.all(random_tensor_1.eq(random_tensor_2)) diff --git a/tests/test_sdqc_integration/test_sequential.py b/tests/test_sdqc_integration/test_sequential.py new file mode 100644 index 0000000..6ad2731 --- /dev/null +++ b/tests/test_sdqc_integration/test_sequential.py @@ -0,0 +1,80 @@ +import pytest +import pandas as pd +import numpy as np +from typing import List, Dict +from sdqc_integration import SequentialAnalysis +from sdqc_data import read_data + + +@pytest.fixture +def sample_data(): + raw_data = read_data('3_raw') + synthetic_data = read_data('3_synth') + return raw_data, synthetic_data + + +def test_sequential_analysis_initialization(sample_data): + raw_data, synthetic_data = sample_data + analysis = SequentialAnalysis(raw_data, synthetic_data) + assert isinstance(analysis, SequentialAnalysis) + assert analysis.raw_data.equals(raw_data) + assert analysis.synthetic_data.equals(synthetic_data) + + +def test_statistical_test_step(sample_data): + raw_data, synthetic_data = sample_data + analysis = SequentialAnalysis(raw_data, synthetic_data) + result = analysis.statistical_test_step() + assert isinstance(result, dict) + assert 'column_types' in result + assert 'raw_correlation' in result + assert 'synthetic_correlation' in result + assert 'results' in result + + +def test_classification_step(sample_data): + raw_data, synthetic_data = sample_data + analysis = SequentialAnalysis(raw_data, synthetic_data) + metrics, models = analysis.classification_step() + assert isinstance(metrics, pd.DataFrame) + assert isinstance(models, List) + assert len(models) > 0 + + +def test_causal_analysis_step(sample_data): + raw_data, synthetic_data = sample_data + analysis = SequentialAnalysis(raw_data, synthetic_data) + metrics, raw_matrix, synthetic_matrix = analysis.causal_analysis_step() + assert isinstance(metrics, Dict) + assert isinstance(raw_matrix, np.ndarray) + assert isinstance(synthetic_matrix, np.ndarray) + + +@pytest.mark.parametrize('explainability_algorithm', ['shap', 'pfi']) +def test_explainability_step(sample_data, explainability_algorithm): + raw_data, synthetic_data = sample_data + analysis = SequentialAnalysis( + raw_data, synthetic_data, explainability_algorithm=explainability_algorithm + ) + analysis.run() + assert isinstance(analysis.results['Explainability'], pd.DataFrame) + + +def test_run(sample_data): + raw_data, synthetic_data = sample_data + analysis = SequentialAnalysis(raw_data, synthetic_data) + results = analysis.run() + assert isinstance(results, Dict) + assert 'Statistical Test' in results + assert 'Classification' in results + assert 'Explainability' in results + assert 'Causal Analysis' in results + + +def test_visualize_html(sample_data, tmp_path): + raw_data, synthetic_data = sample_data + analysis = SequentialAnalysis(raw_data, synthetic_data) + analysis.run() + output_path = tmp_path / 'test_report.html' + analysis.visualize_html(str(output_path)) + assert output_path.exists() diff --git a/tests/test_sdqc_synthesize/test_sdv_synthesizer.py b/tests/test_sdqc_synthesize/test_sdv_synthesizer.py new file mode 100644 index 0000000..c1e80ef --- /dev/null +++ b/tests/test_sdqc_synthesize/test_sdv_synthesizer.py @@ -0,0 +1,54 @@ +import pandas as pd +from typing import Dict +from sdqc_synthesize import SDVSynthesizer + + +def test_sdv_synthesizer_initialization(raw_data): + synthesizer = SDVSynthesizer(data=raw_data) + assert isinstance(synthesizer, SDVSynthesizer) + assert synthesizer.data.equals(raw_data) + assert synthesizer.model_name == 'tvae' + assert synthesizer.random_seed == 17 + + +def test_sdv_synthesizer_fit(raw_data): + synthesizer = SDVSynthesizer(data=raw_data) + fitted_model = synthesizer.fit() + assert fitted_model is not None + + +def test_sdv_synthesizer_generate(raw_data): + synthesizer = SDVSynthesizer(data=raw_data) + synthetic_data = synthesizer.generate() + assert isinstance(synthetic_data, pd.DataFrame) + assert synthetic_data.shape[0] == raw_data.shape[0] + assert set(synthetic_data.columns) == set(raw_data.columns) + + +def test_sdv_synthesizer_multiple_models(raw_data): + synthesizer = SDVSynthesizer( + data=raw_data, model_name=['gaussiancopula', 'ctgan'] + ) + results = synthesizer.generate() + assert isinstance(results, Dict) + assert set(results.keys()) == {'gaussiancopula', 'ctgan'} + for _, synthetic_data in results.items(): + assert isinstance(synthetic_data, pd.DataFrame) + assert synthetic_data.shape[0] == raw_data.shape[0] + assert set(synthetic_data.columns) == set(raw_data.columns) + + +def test_sdv_synthesizer_custom_num_rows(raw_data): + num_rows = 10 + synthesizer = SDVSynthesizer(data=raw_data, num_rows=num_rows) + synthetic_data = synthesizer.generate() + assert isinstance(synthetic_data, pd.DataFrame) + assert synthetic_data.shape[0] == num_rows + + +def test_sdv_synthesizer_custom_model_args(raw_data): + model_args = {'epochs': 5} + synthesizer = SDVSynthesizer(data=raw_data, model_args=model_args) + assert synthesizer.model_args == model_args + synthetic_data = synthesizer.generate() + assert isinstance(synthetic_data, pd.DataFrame) diff --git a/tests/test_sdqc_synthesize/test_ydata_synthesizer.py b/tests/test_sdqc_synthesize/test_ydata_synthesizer.py new file mode 100644 index 0000000..14dfc9a --- /dev/null +++ b/tests/test_sdqc_synthesize/test_ydata_synthesizer.py @@ -0,0 +1,64 @@ +import pandas as pd +from typing import Dict +from sdqc_synthesize import YDataSynthesizer + + +def test_ydata_synthesizer_initialization(raw_data): + synthesizer = YDataSynthesizer(data=raw_data) + assert isinstance(synthesizer, YDataSynthesizer) + assert synthesizer.data.equals(raw_data) + assert synthesizer.model_name == 'fast' + assert synthesizer.random_seed == 17 + + +def test_ydata_synthesizer_fit(raw_data): + synthesizer = YDataSynthesizer(data=raw_data) + fitted_model = synthesizer.fit() + assert fitted_model is not None + + +def test_ydata_synthesizer_generate(raw_data): + synthesizer = YDataSynthesizer(data=raw_data) + synthetic_data = synthesizer.generate() + assert isinstance(synthetic_data, pd.DataFrame) + assert synthetic_data.shape[0] == raw_data.shape[0] + assert set(synthetic_data.columns) == set(raw_data.columns) + + +def test_ydata_synthesizer_multiple_models(raw_data): + synthesizer = YDataSynthesizer(data=raw_data, model_name=['gan', 'wgan']) + results = synthesizer.generate() + assert isinstance(results, Dict) + assert set(results.keys()) == {'gan', 'wgan'} + for _, synthetic_data in results.items(): + assert isinstance(synthetic_data, pd.DataFrame) + assert synthetic_data.shape[0] == raw_data.shape[0] + assert set(synthetic_data.columns) == set(raw_data.columns) + + +def test_ydata_synthesizer_custom_num_rows(raw_data): + num_rows = 10 + synthesizer = YDataSynthesizer(data=raw_data, num_rows=num_rows) + synthetic_data = synthesizer.generate() + assert isinstance(synthetic_data, pd.DataFrame) + assert synthetic_data.shape[0] == num_rows + + +def test_ydata_synthesizer_custom_model_args(raw_data): + model_args = {'lr': 1e-3} + synthesizer = YDataSynthesizer( + data=raw_data, model_name='gan', model_args=model_args + ) + assert synthesizer.model_args['lr'] == 1e-3 + synthetic_data = synthesizer.generate() + assert isinstance(synthetic_data, pd.DataFrame) + + +def test_ydata_synthesizer_custom_train_args(raw_data): + train_args = {'epochs': 50} + synthesizer = YDataSynthesizer( + data=raw_data, model_name='gan', train_args=train_args + ) + assert synthesizer.train_args['epochs'] == 50 + synthetic_data = synthesizer.generate() + assert isinstance(synthetic_data, pd.DataFrame)