Commit

- added pipeline tuning to multi_modal_genre_prediction.py
- stopwords and other nltk packages are no longer downloaded when they are already installed
- keras.Input changed to the recommended keras.layers.InputLayer
- test_multi_modal.py moved to the multimodal folder
andreygetmanov committed Apr 6, 2022
1 parent 6cd90a3 commit 3337a68
Showing 4 changed files with 22 additions and 7 deletions.
17 changes: 14 additions & 3 deletions cases/multi_modal_genre_prediction.py
@@ -1,4 +1,7 @@
 import datetime
+
+from sklearn.metrics import f1_score
+
 from examples.advanced.multi_modal_pipeline import calculate_validation_metric, \
     generate_initial_pipeline_and_data, prepare_multi_modal_data
 from fedot.core.composer.composer_builder import ComposerBuilder
@@ -7,6 +10,7 @@
 from fedot.core.optimisers.gp_comp.gp_optimiser import GPGraphOptimiserParameters, GeneticSchemeTypesEnum
 from fedot.core.repository.operation_types_repository import get_operations_for_task
 from fedot.core.repository.quality_metrics_repository import ClassificationMetricsEnum
+from fedot.core.pipelines.tuning.unified import PipelineTuner
 from fedot.core.repository.tasks import Task, TaskTypesEnum
 
 
@@ -49,13 +53,20 @@ def run_multi_modal_case(files_path, is_visualise=True, timeout=datetime.timedel
     # the optimal pipeline generation by composition - the most time-consuming task
     pipeline_evo_composed = composer.compose_pipeline(data=fit_data,
                                                       is_visualise=True)
+    pipeline_evo_composed.print_structure()
 
-    pipeline_evo_composed.fit(input_data=fit_data)
+    # tuning of the composed pipeline
+    pipeline_tuner = PipelineTuner(pipeline=pipeline_evo_composed, task=task, iterations=15)
+    tuned_pipeline = pipeline_tuner.tune_pipeline(input_data=fit_data,
+                                                  loss_function=f1_score,
+                                                  loss_params={'average': 'micro'})
+    tuned_pipeline.print_structure()
+    tuned_pipeline.fit(input_data=fit_data)
 
     if is_visualise:
-        pipeline_evo_composed.show()
+        tuned_pipeline.show()
 
-    prediction = pipeline_evo_composed.predict(predict_data, output_mode='labels')
+    prediction = tuned_pipeline.predict(predict_data, output_mode='labels')
     err = calculate_validation_metric(predict_data, prediction)
 
     print(f'F1 micro for validation sample is {err}')
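Pulled out of diff form, the tuning flow added above amounts to the sketch below. It is illustrative only: it assumes a pipeline already composed on `fit_data` (here called `composed_pipeline`, a hypothetical name), plus the same `task`, `fit_data`, `predict_data` and F1-micro objective as in the case script.

```python
from sklearn.metrics import f1_score

from fedot.core.pipelines.tuning.unified import PipelineTuner

# composed_pipeline, task, fit_data and predict_data are assumed to exist,
# e.g. produced by composer.compose_pipeline as in the case script above
pipeline_tuner = PipelineTuner(pipeline=composed_pipeline, task=task, iterations=15)
tuned_pipeline = pipeline_tuner.tune_pipeline(input_data=fit_data,
                                              loss_function=f1_score,
                                              loss_params={'average': 'micro'})

# the tuned pipeline is fitted again before prediction, exactly as in the diff
tuned_pipeline.fit(input_data=fit_data)
prediction = tuned_pipeline.predict(predict_data, output_mode='labels')
```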
@@ -61,11 +61,16 @@ def transform(self, input_data, is_fit_pipeline_stage: Optional[bool]):

     @staticmethod
     def _download_nltk_resources():
-        for resource in ['punkt', 'stopwords', 'wordnet', 'omw-1.4']:
+        for resource in ['punkt']:
             try:
                 nltk.data.find(f'tokenizers/{resource}')
             except LookupError:
                 nltk.download(f'{resource}')
+        for resource in ['stopwords', 'wordnet', 'omw-1.4']:
+            try:
+                nltk.data.find(f'corpora/{resource}')
+            except LookupError:
+                nltk.download(f'{resource}')
 
     @staticmethod
     def _word_vectorize(text):
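Why the loop above had to be split: `punkt` is stored under `tokenizers/`, while `stopwords`, `wordnet` and `omw-1.4` are stored under `corpora/`, so the old lookup `nltk.data.find(f'tokenizers/{resource}')` always raised `LookupError` for the corpora and re-triggered `nltk.download` on every run even when the data was already present. A minimal standalone sketch of the corrected check-before-download pattern (the function name is illustrative, not from the repository):

```python
import nltk


def ensure_nltk_resources():
    # tokenizer models are looked up under tokenizers/
    for resource in ['punkt']:
        try:
            nltk.data.find(f'tokenizers/{resource}')
        except LookupError:
            nltk.download(resource)
    # stopwords, wordnet and omw-1.4 are corpora, looked up under corpora/
    for resource in ['stopwords', 'wordnet', 'omw-1.4']:
        try:
            nltk.data.find(f'corpora/{resource}')
        except LookupError:
            nltk.download(resource)
```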
@@ -32,7 +32,7 @@ def create_deep_cnn(input_shape: tuple,
                      num_classes: int):
     model = tf.keras.Sequential(
         [
-            tf.keras.Input(shape=input_shape),
+            tf.keras.layers.InputLayer(input_shape=input_shape),
             tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
             tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
             tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
@@ -51,7 +51,7 @@ def create_simple_cnn(input_shape: tuple,
                        num_classes: int):
     model = tf.keras.Sequential(
         [
-            tf.keras.Input(shape=input_shape),
+            tf.keras.layers.InputLayer(input_shape=input_shape),
             tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
             tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
             tf.keras.layers.Flatten(),
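For context on the `keras.Input` to `keras.layers.InputLayer` swap: as the first element of a `Sequential` model the two declare the same input specification, `InputLayer` being the explicit layer object and `Input` the symbolic-tensor shortcut. A small sketch under an assumed `(28, 28, 1)` input shape (not taken from the repository):

```python
import tensorflow as tf

INPUT_SHAPE = (28, 28, 1)  # illustrative shape, not from the FEDOT code

# explicit layer form, as used after this commit
model_new = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=INPUT_SHAPE),
    tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
])

# symbolic-tensor form, as used before this commit
model_old = tf.keras.Sequential([
    tf.keras.Input(shape=INPUT_SHAPE),
    tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
])

# both models expose the same input shape, (None, 28, 28, 1)
assert model_new.input_shape == model_old.input_shape
```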
@@ -1,7 +1,6 @@
import os

from examples.advanced.multi_modal_pipeline import (prepare_multi_modal_data)
from fedot.core.data.multi_modal import MultiModalData
from fedot.core.pipelines.node import PrimaryNode, SecondaryNode
from fedot.core.pipelines.pipeline import Pipeline
from fedot.core.repository.tasks import Task, TaskTypesEnum
