-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: add tests for the several modules
- Loading branch information
Showing
12 changed files
with
685 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import pytest | ||
import pandas as pd | ||
import numpy as np | ||
|
||
|
||
@pytest.fixture | ||
def raw_data(): | ||
return pd.DataFrame({ | ||
'A': np.random.randint(0, 100, 100), | ||
'B': np.random.choice([0, 1, 2], 100), | ||
'C': np.random.uniform(0, 1, 100), | ||
'D': np.random.choice([0, 1], 100) | ||
}) | ||
|
||
|
||
@pytest.fixture | ||
def synthetic_data(): | ||
return pd.DataFrame({ | ||
'A': np.random.randint(0, 100, 100), | ||
'B': np.random.choice([0, 1, 2], 100), | ||
'C': np.random.uniform(0, 1, 100), | ||
'D': np.random.choice([0, 1], 100) | ||
}) | ||
|
||
|
||
@pytest.fixture | ||
def col_dtypes(): | ||
return { | ||
'categorical': ['B', 'D'], | ||
'numerical': ['A', 'C'] | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def sample_categorical_data1(): | ||
return pd.Series(['A'] * 40 + ['B'] * 30 + ['C'] * 20 + ['D'] * 10) | ||
|
||
|
||
@pytest.fixture | ||
def sample_categorical_data2(): | ||
return pd.Series( | ||
np.random.choice( | ||
['A', 'B', 'C', 'D'], 100, p=[0.4, 0.3, 0.2, 0.1] | ||
) | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def sample_numerical_data1(): | ||
return pd.Series([1, 2, 3, 4, 5] * 20) | ||
|
||
|
||
@pytest.fixture | ||
def sample_numerical_data2(): | ||
return pd.Series([1, 2, 3, 4, 5, 5, 4, 3, 2, 1] * 10) |
55 changes: 55 additions & 0 deletions
55
tests/test_sdqc_check/test_causality/test_causal_analysis.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import pytest | ||
import castle | ||
import numpy as np | ||
import pandas as pd | ||
from sdqc_check import CausalAnalysis | ||
from sdqc_data import read_data | ||
|
||
|
||
@pytest.fixture | ||
def sample_data(raw_data, synthetic_data): | ||
return raw_data, synthetic_data | ||
|
||
|
||
def test_causal_analysis_initialization(sample_data): | ||
raw_data, synthetic_data = sample_data | ||
ca = CausalAnalysis(raw_data, synthetic_data) | ||
assert isinstance(ca, CausalAnalysis) | ||
assert np.array_equal(ca.raw_data, raw_data.to_numpy()) | ||
assert np.array_equal(ca.synthetic_data, synthetic_data.to_numpy()) | ||
assert ca.model_name == 'dlg' | ||
assert ca.random_seed == 17 | ||
assert ca.device_type == 'cpu' | ||
assert ca.device_id == 0 | ||
|
||
|
||
def test_causal_analysis_invalid_model(): | ||
with pytest.raises(ValueError): | ||
CausalAnalysis( | ||
pd.DataFrame(), pd.DataFrame(), model_name='invalid_model' | ||
) | ||
|
||
|
||
@pytest.mark.parametrize('model_name', ['dlg', 'notears', 'golem', 'grandag', 'gae']) | ||
def test_get_model(sample_data, model_name): | ||
raw_data, synthetic_data = sample_data | ||
ca = CausalAnalysis(raw_data, synthetic_data, model_name=model_name) | ||
model = ca._get_model(model_name) | ||
assert isinstance(model, castle.common.BaseLearner) | ||
|
||
|
||
def test_compute_causal_matrices(sample_data): | ||
raw_data, synthetic_data = sample_data | ||
ca = CausalAnalysis(raw_data, synthetic_data) | ||
raw_matrix, synthetic_matrix = ca.compute_causal_matrices() | ||
assert isinstance(raw_matrix, np.ndarray) | ||
assert isinstance(synthetic_matrix, np.ndarray) | ||
assert raw_matrix.shape == (4, 4) | ||
assert synthetic_matrix.shape == (4, 4) | ||
|
||
def test_compare_adjacency_matrices(): | ||
raw_data = read_data('3_raw') | ||
synthetic_data = read_data('3_synth') | ||
ca = CausalAnalysis(raw_data, synthetic_data) | ||
mt = ca.compare_adjacency_matrices() | ||
assert isinstance(mt, castle.MetricsDAG) |
71 changes: 71 additions & 0 deletions
71
tests/test_sdqc_check/test_classification/test_classification_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import pytest | ||
import pandas as pd | ||
from sdqc_check import ClassificationModel | ||
|
||
|
||
@pytest.fixture | ||
def sample_data(raw_data, synthetic_data): | ||
return raw_data, synthetic_data | ||
|
||
|
||
def test_classification_model_initialization(sample_data): | ||
raw_data, synthetic_data = sample_data | ||
cm = ClassificationModel(raw_data, synthetic_data) | ||
assert isinstance(cm, ClassificationModel) | ||
assert cm.model_name == 'rf' | ||
assert cm.test_size == 0.2 | ||
assert cm.random_seed == 17 | ||
|
||
|
||
def test_classification_model_invalid_model(): | ||
with pytest.raises(ValueError): | ||
ClassificationModel( | ||
pd.DataFrame(), pd.DataFrame(), model_name='invalid_model' | ||
) | ||
|
||
|
||
@pytest.mark.parametrize('model_name', ['svm', 'rf', 'xgb', 'lgbm']) | ||
def test_classification_model_single_model(sample_data, model_name): | ||
raw_data, synthetic_data = sample_data | ||
cm = ClassificationModel(raw_data, synthetic_data, model_name=model_name) | ||
metrics, models = cm.train_and_evaluate_models() | ||
assert isinstance(metrics, pd.DataFrame) | ||
assert len(models) == 1 | ||
assert metrics['Model'].iloc[0] == model_name | ||
|
||
|
||
def test_classification_model_multiple_models(sample_data): | ||
raw_data, synthetic_data = sample_data | ||
model_names = ['svm', 'rf', 'xgb', 'lgbm'] | ||
cm = ClassificationModel(raw_data, synthetic_data, model_name=model_names) | ||
metrics, models = cm.train_and_evaluate_models() | ||
assert isinstance(metrics, pd.DataFrame) | ||
assert len(models) == len(model_names) | ||
assert set(metrics['Model']) == set(model_names) | ||
|
||
|
||
def test_classification_model_custom_params(sample_data): | ||
raw_data, synthetic_data = sample_data | ||
custom_params = { | ||
'rf': {'n_estimators': 100, 'max_depth': 5}, | ||
'xgb': {'n_estimators': 100, 'max_depth': 5} | ||
} | ||
cm = ClassificationModel( | ||
raw_data, synthetic_data, | ||
model_name=['rf', 'xgb'], | ||
model_params=custom_params | ||
) | ||
metrics, models = cm.train_and_evaluate_models() | ||
assert isinstance(metrics, pd.DataFrame) | ||
assert len(models) == 2 | ||
assert set(metrics['Model']) == {'rf', 'xgb'} | ||
|
||
|
||
def test_classification_model_metrics(sample_data): | ||
raw_data, synthetic_data = sample_data | ||
cm = ClassificationModel(raw_data, synthetic_data) | ||
metrics, _ = cm.train_and_evaluate_models() | ||
expected_columns = [ | ||
'Model', 'Accuracy', 'Precision', 'Recall', 'F1', 'AUC' | ||
] | ||
assert set(metrics.columns) == set(expected_columns) |
76 changes: 76 additions & 0 deletions
76
tests/test_sdqc_check/test_explainability/test_explainability.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import pytest | ||
import pandas as pd | ||
from sklearn.datasets import make_classification | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.svm import SVC | ||
from sdqc_check import ShapFeatureImportance, PFIFeatureImportance | ||
|
||
|
||
@pytest.fixture | ||
def sample_data(): | ||
X, y = make_classification( | ||
n_samples=100, n_features=10, n_informative=5, random_state=17 | ||
) | ||
X = pd.DataFrame(X, columns=[f'feature_{i + 1}' for i in range(10)]) | ||
y = pd.Series(y) | ||
X_train, X_test, y_train, y_test = train_test_split( | ||
X, y, test_size=0.2, random_state=17 | ||
) | ||
model = RandomForestClassifier(random_state=17) | ||
model.fit(X_train, y_train) | ||
return model, X_train, X_test, y_train, y_test | ||
|
||
|
||
@pytest.mark.parametrize("FeatureImportance", [ShapFeatureImportance, PFIFeatureImportance]) | ||
def test_feature_importance(sample_data, FeatureImportance): | ||
model, X_train, X_test, y_train, y_test = sample_data | ||
importance = FeatureImportance(model, X_train, X_test, y_test) | ||
result = importance.compute_feature_importance() | ||
|
||
assert isinstance(result, pd.DataFrame) | ||
assert len(result) == 10 | ||
assert set(result.columns) == {'feature', 'importance'} | ||
assert result['feature'].nunique() == 10 | ||
assert all( | ||
f'feature_{i + 1}' in result['feature'].values for i in range(10) | ||
) | ||
assert result['importance'].is_monotonic_decreasing | ||
|
||
|
||
@pytest.mark.parametrize("FeatureImportance", [ShapFeatureImportance, PFIFeatureImportance]) | ||
def test_feature_importance_random_seed(sample_data, FeatureImportance): | ||
model, X_train, X_test, y_train, y_test = sample_data | ||
importance1 = FeatureImportance( | ||
model, X_train, X_test, y_test, random_seed=17 | ||
) | ||
importance2 = FeatureImportance( | ||
model, X_train, X_test, y_test, random_seed=17 | ||
) | ||
|
||
scores1 = importance1.compute_feature_importance() | ||
scores2 = importance2.compute_feature_importance() | ||
|
||
pd.testing.assert_frame_equal(scores1, scores2) | ||
|
||
|
||
@pytest.mark.parametrize("model_class", [RandomForestClassifier, SVC]) | ||
def test_shap_feature_importance_models(sample_data, model_class): | ||
_, X_train, X_test, y_train, y_test = sample_data | ||
if model_class == SVC: | ||
model = model_class(random_state=17, probability=True) | ||
else: | ||
model = model_class(random_state=17) | ||
model.fit(X_train, y_train) | ||
|
||
shap_importance = ShapFeatureImportance(model, X_train, X_test, y_test) | ||
importance = shap_importance.compute_feature_importance() | ||
|
||
assert isinstance(importance, pd.DataFrame) | ||
assert len(importance) == 10 | ||
assert set(importance.columns) == {'feature', 'importance'} | ||
assert importance['feature'].nunique() == 10 | ||
assert all( | ||
f'feature_{i + 1}' in importance['feature'].values for i in range(10) | ||
) | ||
assert importance['importance'].is_monotonic_decreasing |
68 changes: 68 additions & 0 deletions
68
tests/test_sdqc_check/test_statistical_test/test_categorical_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import pytest | ||
import pandas as pd | ||
from sdqc_check import CategoricalTest | ||
|
||
|
||
@pytest.fixture | ||
def categorical_test(): | ||
return CategoricalTest() | ||
|
||
|
||
def test_basis(categorical_test, sample_categorical_data1): | ||
result = categorical_test.basis(sample_categorical_data1) | ||
assert isinstance(result, pd.Series) | ||
assert result['count'] == 100 | ||
assert result['missing'] == 0 | ||
assert result['unique'] == 4 | ||
assert result['top'] == 'A' | ||
assert result['freq'] == 40 | ||
|
||
|
||
def test_distribution( | ||
categorical_test, | ||
sample_categorical_data1, | ||
sample_categorical_data2 | ||
): | ||
jaccard, p_value = categorical_test.distribution( | ||
sample_categorical_data1, sample_categorical_data2 | ||
) | ||
assert isinstance(jaccard, float) | ||
assert isinstance(p_value, float) | ||
assert 0 <= jaccard <= 1 | ||
assert 0 <= p_value <= 1 | ||
|
||
|
||
def test_jaccard_index_identical(categorical_test): | ||
data = pd.Series(['A', 'B', 'C']) | ||
jaccard, _ = categorical_test.distribution(data, data) | ||
assert jaccard == 1 | ||
|
||
|
||
def test_jaccard_index_disjoint(categorical_test): | ||
data1 = pd.Series(['A', 'B', 'C']) | ||
data2 = pd.Series(['D', 'E', 'F']) | ||
jaccard, _ = categorical_test.distribution(data1, data2) | ||
assert jaccard == 0 | ||
|
||
|
||
def test_jaccard_index_different(categorical_test): | ||
data1 = pd.Series(['A', 'B', 'A', 'C', 'B', 'A', 'C', 'A', 'B', 'C']) | ||
data2 = pd.Series(['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'C', 'D']) | ||
jaccard, p_value = categorical_test.distribution(data1, data2) | ||
assert jaccard < 1 | ||
assert p_value == 0 | ||
|
||
|
||
def test_chi_square_identical(categorical_test): | ||
data1 = pd.Series(['A'] * 4 + ['B'] * 2 + ['C'] * 3 + ['D'] * 1) | ||
data2 = pd.Series(['D'] * 1 + ['C'] * 3 + ['B'] * 2 + ['A'] * 4) | ||
|
||
_, p_value = categorical_test.distribution(data1, data2) | ||
assert p_value == 1 | ||
|
||
|
||
def test_chi_square_different(categorical_test): | ||
data1 = pd.Series(['A', 'B', 'A', 'C', 'B', 'A', 'D', 'A', 'B', 'C']) | ||
data2 = pd.Series(['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'C', 'D']) | ||
_, p_value = categorical_test.distribution(data1, data2) | ||
assert p_value < 1 |
46 changes: 46 additions & 0 deletions
46
tests/test_sdqc_check/test_statistical_test/test_correlation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import pandas as pd | ||
import numpy as np | ||
from sdqc_check import data_corr | ||
|
||
|
||
def test_data_corr(raw_data, col_dtypes): | ||
result = data_corr(raw_data, col_dtypes) | ||
|
||
assert isinstance(result, pd.DataFrame) | ||
assert set(result.columns) == {'column1', | ||
'column2', 'method', 'corr_coefficient'} | ||
|
||
expected_pairs = [ | ||
('A', 'B'), ('A', 'C'), ('A', 'D'), | ||
('B', 'C'), ('B', 'D'), ('C', 'D') | ||
] | ||
assert set(zip(result['column1'], result['column2']) | ||
) == set(expected_pairs) | ||
|
||
assert set(result['method']) == {"Cramer's V", 'Pearson', 'Eta'} | ||
assert (result['corr_coefficient'] >= - | ||
1).all() and (result['corr_coefficient'] <= 1).all() | ||
|
||
|
||
def test_data_corr_perfect_correlation(): | ||
perfect_data = pd.DataFrame({ | ||
'cat1': ['A', 'B'] * 50, | ||
'cat2': ['X', 'Y'] * 50, | ||
'num1': list(range(100)), | ||
'num2': list(range(0, 200, 2)) | ||
}) | ||
|
||
col_dtypes = { | ||
'categorical': ['cat1', 'cat2'], | ||
'numerical': ['num1', 'num2'] | ||
} | ||
|
||
result = data_corr(perfect_data, col_dtypes) | ||
|
||
cramer_v = result[(result['column1'] == 'cat1') & ( | ||
result['column2'] == 'cat2')]['corr_coefficient'].values[0] | ||
assert np.isclose(cramer_v, 1.0, atol=0.02) | ||
|
||
pearson = result[(result['column1'] == 'num1') & ( | ||
result['column2'] == 'num2')]['corr_coefficient'].values[0] | ||
assert np.isclose(pearson, 1.0) |
Oops, something went wrong.