Skip to content

Commit

Permalink
- rewrote path in test_multimodal_data.py by Path
Browse files Browse the repository at this point in the history
  • Loading branch information
andreygetmanov committed Aug 12, 2022
1 parent d7374e3 commit ff59b62
Showing 1 changed file with 18 additions and 21 deletions.
39 changes: 18 additions & 21 deletions test/unit/data/test_multimodal_data.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import os

import numpy as np
import pandas as pd
from pathlib import Path

from fedot.core.data.data import InputData
from fedot.core.data.multi_modal import MultiModalData
from fedot.core.repository.dataset_types import DataTypesEnum
from fedot.core.repository.tasks import Task, TaskTypesEnum
from fedot.core.utils import fedot_project_root


def test_multimodal_data_from_csv():
test_file_path = str(os.path.dirname(__file__))
file = '../../data/simple_multimodal_classification.csv'
task = Task(TaskTypesEnum.classification)
df = pd.read_csv(os.path.join(test_file_path, file))
file_path = 'test/data/simple_multimodal_classification.csv'
path = Path(fedot_project_root(), file_path)
df = pd.read_csv(path)
text_data = np.array(df['description'])
table_data = np.array(df.drop(columns=['id', 'description', 'variety']))
target = np.array(df['variety'])
Expand All @@ -26,33 +26,32 @@ def test_multimodal_data_from_csv():
idx=idx,
task=task,
data_type=DataTypesEnum.table).features
actual_data = MultiModalData.from_csv(os.path.join(test_file_path, file))
actual_data = MultiModalData.from_csv(path)
actual_text_features = actual_data['data_source_text/description'].features
actual_table_features = actual_data['data_source_table'].features
assert np.array_equal(expected_text_features, actual_text_features)
assert np.array_equal(expected_table_features, actual_table_features)


def test_multimodal_data_with_custom_target():
test_file_path = str(os.path.dirname(__file__))
file = '../../data/simple_multimodal_classification.csv'
file_custom = '../../data/simple_multimodal_classification_with_custom_target.csv'

file_data = MultiModalData.from_csv(os.path.join(test_file_path, file))
file_path = 'test/data/simple_multimodal_classification.csv'
path = Path(fedot_project_root(), file_path)
file_data = MultiModalData.from_csv(path)

expected_table_features = file_data['data_source_table'].features
expected_target = file_data.target

custom_file_data = MultiModalData.from_csv(os.path.join(test_file_path, file_custom))
file_custom_path = 'test/data/simple_multimodal_classification_with_custom_target.csv'
path_custom = Path(fedot_project_root(), file_custom_path)
custom_file_data = MultiModalData.from_csv(path_custom)
actual_table_features = custom_file_data['data_source_table'].features
actual_target = custom_file_data.target

assert not np.array_equal(expected_table_features, actual_table_features)
assert not np.array_equal(expected_target, actual_target)

custom_file_data = MultiModalData.from_csv(
os.path.join(test_file_path, file_custom),
columns_to_drop=['redundant'], target_columns='variety')
custom_file_data = MultiModalData.from_csv(path_custom,
columns_to_drop=['redundant'], target_columns='variety')

actual_table_features = custom_file_data['data_source_table'].features
actual_target = custom_file_data.target
Expand Down Expand Up @@ -93,9 +92,8 @@ def test_multi_modal_data():

def test_text_data_only():
# Case when there is no table data in csv, but MultiModalData.from_csv() is used
test_file_path = str(os.path.dirname(__file__))
file = '../../data/simple_multimodal_classification_text.csv'
path = os.path.join(test_file_path, file)
file_path = 'test/data/simple_multimodal_classification_text.csv'
path = Path(fedot_project_root(), file_path)

file_data = InputData.from_csv(path, data_type=DataTypesEnum.text)
file_mm_data = MultiModalData.from_csv(path)
Expand All @@ -108,9 +106,8 @@ def test_text_data_only():

def test_table_data_only():
# Case when there is no text data in csv, but MultiModalData.from_csv() is used
test_file_path = str(os.path.dirname(__file__))
file = '../../data/simple_classification.csv'
path = os.path.join(test_file_path, file)
file_path = 'test/data/simple_classification.csv'
path = Path(fedot_project_root(), file_path)

file_data = InputData.from_csv(path)
file_mm_data = MultiModalData.from_csv(path)
Expand Down

0 comments on commit ff59b62

Please sign in to comment.