Skip to content

Commit

Permalink
- refactoring of test_regression.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
andreygetmanov committed Jan 24, 2024
1 parent f6624fe commit 936d8f9
Showing 1 changed file with 27 additions and 35 deletions.
62 changes: 27 additions & 35 deletions test/unit/tasks/test_regression.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -56,37 +56,29 @@ def get_synthetic_regression_data(n_samples=1000, n_features=10, random_state=No
return input_data


def get_regression_data_from_numpy():
def get_regression_data(source: str) -> InputData:
test_file_path = str(os.path.dirname(__file__))
file = '../../data/regression/simple_regression.npy'
numpy_data = np.load(os.path.join(test_file_path, file))
features_array = numpy_data[:, :-1]
target_array = numpy_data[:, -1]
input_data = InputData.from_numpy(features_array=features_array,
target_array=target_array,
task='regression')
return input_data


def get_regression_data_from_df():
test_file_path = str(os.path.dirname(__file__))
file = '../../data/regression/simple_regression.csv'
df_data = pd.read_csv(os.path.join(test_file_path, file))
features_df = df_data.iloc[:, :-1]
target_df = df_data.iloc[:, -1]
input_data = InputData.from_dataframe(features_df=features_df,
target_df=target_df,
if source == 'numpy':
file = '../../data/regression/simple_regression.npy'
numpy_data = np.load(os.path.join(test_file_path, file))
features_array = numpy_data[:, :-1]
target_array = numpy_data[:, -1]
return InputData.from_numpy(features_array=features_array,
target_array=target_array,
task='regression')
return input_data


def get_regression_data_from_csv():
test_file_path = str(os.path.dirname(__file__))
file = '../../data/regression/simple_regression.csv'
input_data = InputData.from_csv(
os.path.join(test_file_path, file),
task='regression')
return input_data
elif source == 'dataframe':
file = '../../data/regression/simple_regression.csv'
df_data = pd.read_csv(os.path.join(test_file_path, file))
features_df = df_data.iloc[:, :-1]
target_df = df_data.iloc[:, -1]
return InputData.from_dataframe(features_df=features_df,
target_df=target_df,
task='regression')
elif source == 'csv':
file = '../../data/regression/simple_regression.csv'
return InputData.from_csv(
os.path.join(test_file_path, file),
task='regression')


def get_rmse_value(pipeline: Pipeline, train_data: InputData, test_data: InputData) -> (float, float):
Expand All @@ -98,19 +90,19 @@ def get_rmse_value(pipeline: Pipeline, train_data: InputData, test_data: InputDa
return rmse_value_train, rmse_value_test


REGRESSION_DATA_SOURCES = [get_regression_data_from_numpy,
get_regression_data_from_df,
get_regression_data_from_csv,
REGRESSION_DATA_SOURCES = ['numpy',
'dataframe',
'csv',
# 'from_image',
# 'from_text_meta_file',
# 'from_text_files',
# 'from_json_files',
]


@pytest.mark.parametrize('get_regression_data', REGRESSION_DATA_SOURCES)
def test_regression_pipeline_fit_predict_correct(get_regression_data: Callable):
data = get_regression_data()
@pytest.mark.parametrize('source', REGRESSION_DATA_SOURCES)
def test_regression_pipeline_fit_predict_correct(source: str):
data = get_regression_data(source)
pipeline = generate_pipeline()
train_data, test_data = train_test_data_setup(data, shuffle=True)

Expand Down

0 comments on commit 936d8f9

Please sign in to comment.