Skip to content

Commit

Permalink
- refactoring structure of test_text_data_only
Browse files Browse the repository at this point in the history
  • Loading branch information
andreygetmanov committed Aug 30, 2022
1 parent faca601 commit 87e78dd
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions test/unit/data/test_multimodal_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
import pytest
from pathlib import Path

from fedot.api.main import Fedot
Expand Down Expand Up @@ -91,32 +92,25 @@ def test_multi_modal_data():
assert np.array_equal(multi_modal.target, new_target)


def test_text_data_only():
# Case when there is no table data in csv, but MultiModalData.from_csv() is used
file_path = 'test/data/simple_multimodal_classification_text.csv'
path = Path(fedot_project_root(), file_path)
@pytest.mark.parametrize('data_type', [DataTypesEnum.text, DataTypesEnum.table])
def test_text_data_only(data_type):
if data_type is DataTypesEnum.text:
# Case when there is no table data in csv, but MultiModalData.from_csv() is used
file_path = 'test/data/simple_multimodal_classification_text.csv'
data_source_name = 'data_source_text/description'
elif data_type is DataTypesEnum.table:
# Case when there is no text data in csv, but MultiModalData.from_csv() is used
file_path = 'test/data/simple_classification.csv'
data_source_name = 'data_source_table'

file_data = InputData.from_csv(path, data_type=DataTypesEnum.text)
file_mm_data = MultiModalData.from_csv(path)

assert len(file_mm_data) == 1
assert file_mm_data['data_source_text/description'].data_type is DataTypesEnum.text
assert file_mm_data['data_source_text/description'].features.all() == file_data.features.all()
assert file_mm_data['data_source_text/description'].target.all() == file_data.target.all()


def test_table_data_only():
# Case when there is no text data in csv, but MultiModalData.from_csv() is used
file_path = 'test/data/simple_classification.csv'
path = Path(fedot_project_root(), file_path)

file_data = InputData.from_csv(path)
file_data = InputData.from_csv(path, data_type=DataTypesEnum.text)
file_mm_data = MultiModalData.from_csv(path)

assert len(file_mm_data) == 1
assert file_mm_data['data_source_table'].data_type is DataTypesEnum.table
assert file_mm_data['data_source_table'].features.all() == file_data.features.all()
assert file_mm_data['data_source_table'].target.all() == file_data.target.all()
assert file_mm_data[data_source_name].data_type is data_type
assert file_mm_data[data_source_name].features.all() == file_data.features.all()
assert file_mm_data[data_source_name].target.all() == file_data.target.all()


def test_multimodal_data_with_complicated_types():
Expand Down

0 comments on commit 87e78dd

Please sign in to comment.