-
Notifications
You must be signed in to change notification settings - Fork 87
/
Copy pathtask_assumptions.py
127 lines (101 loc) · 4.75 KB
/
task_assumptions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from abc import abstractmethod
from typing import List, Optional, Type

from fedot.api.api_utils.assumptions.operations_filter import OperationsFilter
from fedot.core.pipelines.pipeline_builder import PipelineBuilder
from fedot.core.repository.operation_types_repository import OperationTypesRepository
from fedot.core.repository.tasks import Task, TaskTypesEnum
from fedot.utilities.custom_errors import AbstractMethodNotImplementError
class TaskAssumptions:
    """ Abstracts task-specific pipeline assumptions.

    Subclasses provide, for one task type, the ensemble operation, the primary
    pipeline builders and a fallback builder. Use ``for_task`` to obtain the
    subclass instance matching a given task.

    Note: this class does not derive from ``abc.ABC``, so ``@abstractmethod`` is
    not enforced at instantiation time; the base implementations raise
    ``AbstractMethodNotImplementError`` when called instead.
    """

    def __init__(self, repository: OperationTypesRepository):
        # Stored so that subclasses can query operation metadata via ``self.repo``.
        self.repo = repository

    @staticmethod
    def for_task(task: Task, repository: OperationTypesRepository) -> 'TaskAssumptions':
        """ Creates the assumptions object that corresponds to ``task.task_type``.

        :param task: task whose ``task_type`` selects the assumptions class
        :param repository: operations repository handed to the selected class
        :return: instance of the matching ``TaskAssumptions`` subclass
        :raises NotImplementedError: if no assumptions class is registered for the task type
        """
        assumptions_by_task = {
            TaskTypesEnum.classification: ClassificationAssumptions,
            TaskTypesEnum.regression: RegressionAssumptions,
            TaskTypesEnum.ts_forecasting: TSForecastingAssumptions,
        }
        # The mapping holds classes (not instances); ``.get`` yields None for unknown task types.
        assumptions_cls: Optional[Type['TaskAssumptions']] = assumptions_by_task.get(task.task_type)
        if assumptions_cls is None:
            raise NotImplementedError(f"Don't have assumptions for task type: {task.task_type}")
        return assumptions_cls(repository)

    @abstractmethod
    def ensemble_operation(self) -> str:
        """ Suitable ensemble operation used for MultiModalData case. """
        raise AbstractMethodNotImplementError

    @abstractmethod
    def processing_builders(self) -> List[PipelineBuilder]:
        """ Returns alternatives of PipelineBuilders for core processing (without preprocessing). """
        raise AbstractMethodNotImplementError

    @abstractmethod
    def fallback_builder(self, operations_filter: OperationsFilter) -> PipelineBuilder:
        """
        Returns default PipelineBuilder for the case when primary alternatives are not valid.
        Has access to OperationsFilter for sampling available operations.
        """
        raise AbstractMethodNotImplementError
class TSForecastingAssumptions(TaskAssumptions):
    """ Static, dictionary-based pipeline assumptions for the time series forecasting task. """

    @property
    def builders(self):
        """ Maps assumption names to PipelineBuilders; the builders are created anew on every access. """
        lagged_ridge = PipelineBuilder().add_sequence('lagged', 'ridge')
        topological = (
            PipelineBuilder()
            .add_node('lagged')
            .add_node('topological_features')
            .add_node('lagged', branch_idx=1)
            .join_branches('ridge')
        )
        polyfit_ridge = (
            PipelineBuilder()
            .add_branch('polyfit', 'lagged')
            .grow_branches(None, 'ridge')
            .join_branches('ridge')
        )
        smoothing_ar = PipelineBuilder().add_sequence('smoothing', 'ar')

        return {
            'lagged_ridge': lagged_ridge,
            'topological': topological,
            'polyfit_ridge': polyfit_ridge,
            'smoothing_ar': smoothing_ar,
        }

    def ensemble_operation(self) -> str:
        """ Name of the operation used to ensemble pipelines for this task. """
        return 'ridge'

    def processing_builders(self) -> List[PipelineBuilder]:
        """ All builders from ``builders`` as a list, in dictionary order. """
        return [*self.builders.values()]

    def fallback_builder(self, operations_filter: OperationsFilter) -> PipelineBuilder:
        """
        Builds a pipeline around an operation sampled from ``operations_filter``.
        A 'lagged' node is put in front of the sampled operation unless the
        operation is tagged 'non_lagged' in the repository.
        """
        sampled_operation = operations_filter.sample()
        operation_tags = self.repo.operation_info_by_id(sampled_operation).tags

        builder = PipelineBuilder()
        if 'non_lagged' not in operation_tags:
            builder = builder.add_node('lagged')
        return builder.add_node(sampled_operation)
class RegressionAssumptions(TaskAssumptions):
    """ Static, dictionary-based pipeline assumptions for the regression task. """

    @property
    def builders(self):
        """ Maps each operation name to a single-node PipelineBuilder, created anew on every access. """
        operation_names = ('rfr', 'ridge')
        return {name: PipelineBuilder().add_node(name) for name in operation_names}

    def ensemble_operation(self) -> str:
        """ Name of the operation used to ensemble pipelines for this task. """
        return 'rfr'

    def processing_builders(self) -> List[PipelineBuilder]:
        """ All builders from ``builders`` as a list, in dictionary order. """
        return [*self.builders.values()]

    def fallback_builder(self, operations_filter: OperationsFilter) -> PipelineBuilder:
        """ Builds a single-node pipeline from an operation sampled from ``operations_filter``. """
        sampled_operation = operations_filter.sample()
        builder = PipelineBuilder()
        return builder.add_node(sampled_operation)
class ClassificationAssumptions(TaskAssumptions):
    """ Static, dictionary-based pipeline assumptions for the classification task. """

    @property
    def builders(self):
        """ Maps each operation name to a single-node PipelineBuilder, created anew on every access. """
        operation_names = ('rf', 'logit', 'catboost')
        return {name: PipelineBuilder().add_node(name) for name in operation_names}

    def ensemble_operation(self) -> str:
        """ Name of the operation used to ensemble pipelines for this task. """
        return 'rf'

    def processing_builders(self) -> List[PipelineBuilder]:
        """ All builders from ``builders`` as a list, in dictionary order. """
        return [*self.builders.values()]

    def fallback_builder(self, operations_filter: OperationsFilter) -> PipelineBuilder:
        """ Builds a single-node pipeline from an operation sampled from ``operations_filter``. """
        sampled_operation = operations_filter.sample()
        builder = PipelineBuilder()
        return builder.add_node(sampled_operation)