Skip to content

Commit

Permalink
- MultiModalAssumptionsBuilder refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
andreygetmanov committed Jul 14, 2022
1 parent 7281b0e commit dd00c9c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
45 changes: 27 additions & 18 deletions fedot/api/api_utils/assumptions/assumptions_builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union, Optional
from typing import List, Union, Optional, Set

from fedot.api.api_utils.assumptions.operations_filter import OperationsFilter, WhitelistOperationsFilter
from fedot.api.api_utils.assumptions.preprocessing_builder import PreprocessingBuilder
Expand Down Expand Up @@ -38,7 +38,7 @@ def get(data: Union[InputData, MultiModalData], repository_name: Optional[str] =
raise NotImplementedError(f"Can't build assumptions for data type: {type(data).__name__}")
return cls(data, repository_name=repository_name)

def from_operations(self, available_ops: List[str]):
def from_operations(self, available_operations: List[str]):
raise NotImplementedError('abstract')

def to_builders(self, initial_node: Optional[Node] = None) -> List[PipelineBuilder]:
Expand All @@ -64,9 +64,9 @@ def __init__(self, data: Union[InputData, MultiModalData],

def from_operations(self, available_operations: Optional[List[str]]):
if available_operations:
_check_available_operations(self.data.task.task_type, available_operations)
operations_for_the_task, _ = self.repo.suitable_operation(self.data.task.task_type, self.data_type)
operations_to_choose_from = set(operations_for_the_task).intersection(available_operations)
operations_for_task_and_data, _ = self.repo.suitable_operation(self.data.task.task_type, self.data_type)
operations_to_choose_from = set(operations_for_task_and_data).intersection(available_operations)
_check_operations_to_choose_from(self.data, self.data_type, operations_to_choose_from)
if operations_to_choose_from:
self.ops_filter = WhitelistOperationsFilter(available_operations, operations_to_choose_from)
else:
Expand All @@ -84,7 +84,7 @@ def to_builders(self, initial_node: Optional[Node] = None) -> List[PipelineBuild
valid_builders = []
for processing in self.assumptions_generator.processing_builders():
candidate_builder = preprocessing.merge_with(processing)
if self.ops_filter.satisfies(candidate_builder.to_pipeline(), self.data_type):
if self.ops_filter.satisfies(candidate_builder.to_pipeline()):
valid_builders.append(candidate_builder)
return valid_builders or [self.assumptions_generator.fallback_builder(self.ops_filter)]

Expand All @@ -93,15 +93,17 @@ class MultiModalAssumptionsBuilder(AssumptionsBuilder):
def __init__(self, data: MultiModalData, repository_name: str = "model"):
super().__init__(data, repository_name)
_subbuilders = []
for data_type, (data_source_name, values) in zip(data.data_type, data.items()):
# TODO: can have specific Builder for each particular data column, eg construct InputData
_subbuilders.append((data_source_name, UniModalAssumptionsBuilder(data, data_type)))
for data_type, (data_source_name, values) in zip(self.data.data_type, self.data.items()):
# Performs specific filter on image data operations
if data_type is DataTypesEnum.image:
available_operations = ['data_source_img', 'cnn']
_subbuilders.append((data_source_name, UniModalAssumptionsBuilder(self.data, data_type)
.from_operations(available_operations)))
else:
_subbuilders.append((data_source_name, UniModalAssumptionsBuilder(self.data, data_type)))
self._subbuilders = tuple(_subbuilders)

# TODO: in principle, each data column in MultiModalData can have its own available_ops
def from_operations(self, available_ops: List[str]):
self.logger.info("Available operations are not taken into account when "
"forming the initial assumption for multi-modal data")
def from_operations(self, available_operations: Optional[List[str]]):
return self

def to_builders(self, initial_node: Optional[Node] = None) -> List[PipelineBuilder]:
Expand All @@ -122,8 +124,15 @@ def to_builders(self, initial_node: Optional[Node] = None) -> List[PipelineBuild
return ensemble_builders


def _check_available_operations(task_type: TaskTypesEnum, available_operations: List[str]):
"""Since it is impossible to form a valid pipeline for the time series
without 'lagged' operation, it is added to the list of available operations."""
if task_type is TaskTypesEnum.ts_forecasting and 'lagged' not in available_operations:
available_operations.append('lagged')
def _check_operations_to_choose_from(data, data_type: DataTypesEnum, operations_to_choose_from: Set[str]):
"""Since it is sometimes impossible to form a valid pipeline without some operations,
they are added to the set of operations for current task and data."""
if isinstance(data, MultiModalData):
if data_type is DataTypesEnum.image and 'data_source_img' not in operations_to_choose_from:
operations_to_choose_from.add('data_source_img')
if data_type is DataTypesEnum.text and 'data_source_text' not in operations_to_choose_from:
operations_to_choose_from.add('data_source_text')
if data_type is DataTypesEnum.table and 'data_source_table' not in operations_to_choose_from:
operations_to_choose_from.add('data_source_table')
if data_type is DataTypesEnum.image and 'cnn' not in operations_to_choose_from:
operations_to_choose_from.add('cnn')
7 changes: 2 additions & 5 deletions fedot/api/api_utils/assumptions/operations_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
from typing import Optional, Iterable

from fedot.core.pipelines.pipeline import Pipeline
from fedot.core.repository.dataset_types import DataTypesEnum


class OperationsFilter:
def satisfies(self, pipeline: Optional[Pipeline], data_type: DataTypesEnum) -> bool:
def satisfies(self, pipeline: Optional[Pipeline]) -> bool:
""" Checks if all operations in a Pipeline satisfy this filter. """
if data_type is DataTypesEnum.image and 'cnn' not in [node.operation.operation_type for node in pipeline.nodes]:
return False
return True

def sample(self) -> str:
Expand All @@ -26,7 +23,7 @@ def __init__(self, available_operations: Iterable[str], available_task_operation
self._whitelist = tuple(available_operations)
self._choice_operations = tuple(available_task_operations) if available_task_operations else self._whitelist

def satisfies(self, pipeline: Optional[Pipeline], data_type: DataTypesEnum) -> bool:
def satisfies(self, pipeline: Optional[Pipeline]) -> bool:
def node_ok(node):
return node.operation.operation_type in self._whitelist

Expand Down

0 comments on commit dd00c9c

Please sign in to comment.