Skip to content

Commit

Permalink
pep8 fix
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanKharkovskoy committed Dec 19, 2023
1 parent 3583238 commit 793619c
Showing 1 changed file with 27 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from catboost import CatBoostClassifier, CatBoostRegressor, Pool
from matplotlib import pyplot as plt
from xgboost import XGBClassifier, XGBRegressor, DMatrix
import xgboost

from fedot.core.data.data import InputData
from fedot.core.data.data_split import train_test_data_setup
Expand All @@ -20,19 +19,21 @@ class FedotXGBoostImplementation(ModelImplementation):

def __init__(self, params: Optional[OperationParameters] = None):
super().__init__(params)

self.check_and_update_params()

self.model_params = {k: v for k, v in self.params.to_dict().items() if k not in self.__operation_params}

self.model_params = {k: v for k, v in self.params.to_dict(
).items() if k not in self.__operation_params}
self.model = None

def fit(self, input_data: InputData):
input_data = input_data.get_not_encoded_data()

if self.params.get('use_eval_set'):
train_input, eval_input = train_test_data_setup(input_data)

self.model.fit(X=train_input.features, y=train_input.target, eval_set=[(eval_input.features, eval_input.target)])
self.model.fit(X=train_input.features, y=train_input.target, eval_set=[
(eval_input.features, eval_input.target)])

else:

Expand All @@ -41,23 +42,34 @@ def fit(self, input_data: InputData):
return self.model

def predict(self, input_data: InputData):
prediction = self.model.predict(input_data.get_not_encoded_data().features)
prediction = self.model.predict(
input_data.get_not_encoded_data().features)

return prediction

def check_and_update_params(self):
n_jobs = self.params.get('n_jobs')
self.params.update(nthread=n_jobs)
self.params.update(thread_count=n_jobs)

use_best_model = self.params.get('use_best_model')
early_stopping_rounds = self.params.get('early_stopping_rounds')
use_eval_set = self.params.get('use_eval_set')

if use_best_model or early_stopping_rounds and not use_eval_set:
self.params.update(use_best_model=False,
early_stopping_rounds=False)

@staticmethod
def convert_to_dmatrix(data: Optional[InputData]):
return DMatrix(
data=data.features,
label=data.target,
enable_categorical=True,
feature_names=data.features_names.tolist().tolist() if data.features_names is not None else None
feature_names=data.features_names.tolist().tolist(
) if data.features_names is not None else None
)


class FedotXGBoostClassificationImplementation(FedotXGBoostImplementation):
def __init__(self, params: Optional[OperationParameters] = None):
super().__init__(params)
Expand All @@ -77,7 +89,7 @@ def predict_proba(self, input_data: InputData):
class FedotXGBoostRegressionImplementation(FedotXGBoostImplementation):
def __init__(self, params: Optional[OperationParameters] = None):
super().__init__(params)
self.model_params['objective'] = 'reg:squarederror'
self.model = XGBRegressor(**self.model_params)


class FedotCatBoostImplementation(ModelImplementation):
Expand All @@ -88,7 +100,8 @@ def __init__(self, params: Optional[OperationParameters] = None):

self.check_and_update_params()

self.model_params = {k: v for k, v in self.params.to_dict().items() if k not in self.__operation_params}
self.model_params = {k: v for k, v in self.params.to_dict(
).items() if k not in self.__operation_params}
self.model = None

def fit(self, input_data: InputData):
Expand Down Expand Up @@ -125,7 +138,8 @@ def check_and_update_params(self):
use_eval_set = self.params.get('use_eval_set')

if use_best_model or early_stopping_rounds and not use_eval_set:
self.params.update(use_best_model=False, early_stopping_rounds=False)
self.params.update(use_best_model=False,
early_stopping_rounds=False)

@staticmethod
def convert_to_pool(data: Optional[InputData]):
Expand Down

0 comments on commit 793619c

Please sign in to comment.