
Commit

Update trainer.py
support pytorch_lightning
thompson0012 committed Oct 28, 2021
1 parent a925e05 commit 4670dc0
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions pyemits/core/ml/regression/trainer.py
@@ -1,10 +1,12 @@
import torch
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, AdaBoostRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.linear_model import ElasticNet, Ridge, Lasso, BayesianRidge, HuberRegressor
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from pyemits.core.ml.base import BaseTrainer, BaseWrapper, NeuralNetworkWrapperBase
from pyemits.common.config_model import BaseConfig, KerasSequentialConfig
from pyemits.core.ml.regression.nn import TorchLightningWrapper
from pyemits.common.config_model import BaseConfig, KerasSequentialConfig, TorchLightningSequentialConfig
from pyemits.common.data_model import RegressionDataModel
from pyemits.common.py_native_dtype import SliceableDeque
from pyemits.common.validation import raise_if_value_not_contains
@@ -42,6 +44,14 @@ def __init__(self,
raw_data_model: RegressionDataModel,
other_config: Dict[str, Union[List, BaseConfig, Any]] = {}):
"""
Universal class for regression model training:
all-in-one training covering sklearn, xgboost, lightgbm, keras, and pytorch_lightning.
Filling in algo_config is optional; it is designed so that users can configure
their model at creation time via the configuration classes provided in config_model.
For pytorch_lightning users, please configure your model before using this
class; in that case, no algo_config is needed.
Parameters
----------
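
A minimal usage sketch of the constructor described above. Hedged: the string alias 'XGBoost', passing None to request a default config, and the RegressionDataModel call signature are illustrative assumptions, not confirmed by this diff.

    # hypothetical usage sketch, not part of the commit
    import numpy as np
    from pyemits.common.data_model import RegressionDataModel
    from pyemits.core.ml.regression.trainer import RegTrainer

    X, y = np.random.rand(100, 8), np.random.rand(100)
    data_model = RegressionDataModel(X, y)          # assumed signature
    trainer = RegTrainer(algo=['XGBoost'],          # assumed string alias
                         algo_config=[None],        # None -> library defaults (assumed)
                         raw_data_model=data_model)
    trainer.fit()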
@@ -111,6 +121,11 @@ def fill_algo_config_clf(self,
clf_or_wrapper.model_obj.add(i)
clf_or_wrapper.model_obj.compile(**algo_config.compile)
return clf_or_wrapper
elif isinstance(algo_config, TorchLightningSequentialConfig):
clf_or_wrapper: TorchLightningWrapper
for nos, layer in enumerate(algo_config.layer, 1):
clf_or_wrapper.add_layer2blank_model(str(nos), layer)
return clf_or_wrapper
# vanilla pytorch and mxnet models are not supported yet
raise TypeError('only KerasSequentialConfig and TorchLightningSequentialConfig are supported')
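
The new TorchLightningSequentialConfig branch registers layers one by one under the names '1', '2', '3', ... A self-contained sketch of the same pattern in plain torch.nn, where add_module plays the role of add_layer2blank_model (the layer list is illustrative):

    import torch
    import torch.nn as nn

    layers = [nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1)]
    model = nn.Sequential()
    for nos, layer in enumerate(layers, 1):
        model.add_module(str(nos), layer)  # names mirror str(nos) in the diff
    out = model(torch.randn(4, 8))         # forward-pass sanity check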

@@ -129,15 +144,28 @@ def fill_fit_config_clf(self,
y,
fit_config: Optional[Union[BaseConfig, Dict]] = None,
):
from pyemits.core.ml.regression.nn import torchlighting_data_helper
# nn wrapper
if isinstance(clf_or_wrapper, NeuralNetworkWrapperBase):
dl_train, dl_val = torchlighting_data_helper(X, y)  # consumed only by the pytorch_lightning paths below

if fit_config is None:
# pytorch_lightning path
if isinstance(clf_or_wrapper, TorchLightningWrapper):
return clf_or_wrapper.fit(dl_train, dl_val)
# keras path
return clf_or_wrapper.fit(X, y)

if isinstance(fit_config, BaseConfig):
if isinstance(clf_or_wrapper, TorchLightningWrapper):
return clf_or_wrapper.fit(dl_train, dl_val, **dict(fit_config))
# keras path
return clf_or_wrapper.fit(X, y, **dict(fit_config))

elif isinstance(fit_config, Dict):
if isinstance(clf_or_wrapper, TorchLightningWrapper):
return clf_or_wrapper.fit(dl_train, dl_val, **fit_config)
# keras path
return clf_or_wrapper.fit(X, y, **fit_config)

# sklearn/xgboost/lightgbm clf
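
torchlighting_data_helper is imported above but its body is not shown in this diff. A hedged sketch of what such a helper typically does, wrapping arrays into train/validation DataLoaders (batch size and split ratio are illustrative assumptions):

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split

    def make_dataloaders(X, y, batch_size=32, val_frac=0.2):
        # wrap numpy arrays as float32 tensors
        dataset = TensorDataset(torch.as_tensor(X, dtype=torch.float32),
                                torch.as_tensor(y, dtype=torch.float32))
        n_val = int(len(dataset) * val_frac)
        train_ds, val_ds = random_split(dataset, [len(dataset) - n_val, n_val])
        return (DataLoader(train_ds, batch_size=batch_size, shuffle=True),
                DataLoader(val_ds, batch_size=batch_size))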
@@ -204,7 +232,6 @@ def _get_fitted_trainer(algo: List,
algo_, algo_config_ in
zip(self._algo, self._algo_config))

# self.clf_models = [obj.clf_models for obj in out]
for obj in out:
self.clf_models.append(obj.clf_models)
return
@@ -214,6 +241,12 @@ def fit(self):


class MultiOutputRegTrainer(RegTrainer):
"""
machine learning based multioutput regression trainer
bring forecasting power into machine learning model,
forecasting is not only the power of deep learning
"""

def __init__(self,
algo: List[Union[str, Any]],
algo_config: List[Optional[BaseConfig]],
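For context on multi-output regression as used by MultiOutputRegTrainer, a generic sketch with sklearn's MultiOutputRegressor; this shows the technique in general, not necessarily how the trainer implements it (shapes are illustrative):

    import numpy as np
    from sklearn.multioutput import MultiOutputRegressor
    from xgboost import XGBRegressor

    # 200 samples, 10 lag features -> predict the next 3 horizon steps at once
    X = np.random.rand(200, 10)
    Y = np.random.rand(200, 3)
    model = MultiOutputRegressor(XGBRegressor(n_estimators=50))
    model.fit(X, Y)
    preds = model.predict(X[:5])  # shape (5, 3): one column per horizon step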
