diff --git a/pyemits/core/ml/regression/trainer.py b/pyemits/core/ml/regression/trainer.py index 74f61f1..3a52fcf 100644 --- a/pyemits/core/ml/regression/trainer.py +++ b/pyemits/core/ml/regression/trainer.py @@ -1,10 +1,12 @@ +import torch from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, AdaBoostRegressor from sklearn.neural_network import MLPRegressor from sklearn.linear_model import ElasticNet, Ridge, Lasso, BayesianRidge, HuberRegressor from xgboost import XGBRegressor from lightgbm import LGBMRegressor from pyemits.core.ml.base import BaseTrainer, BaseWrapper, NeuralNetworkWrapperBase -from pyemits.common.config_model import BaseConfig, KerasSequentialConfig +from pyemits.core.ml.regression.nn import TorchLightningWrapper +from pyemits.common.config_model import BaseConfig, KerasSequentialConfig, TorchLightningSequentialConfig from pyemits.common.data_model import RegressionDataModel from pyemits.common.py_native_dtype import SliceableDeque from pyemits.common.validation import raise_if_value_not_contains @@ -42,6 +44,14 @@ def __init__(self, raw_data_model: RegressionDataModel, other_config: Dict[str, Union[List, BaseConfig, Any]] = {}): """ + universal class for regression model training, + all-in-one training including sklearn, xgboost, lightgbm, keras, pytorch_lightning + + you are not required to fill in the algo config if you have idea on algo_config + the algo config is designed for people to configure their model based on the configuration that is provided in config_model + so that people can easily configure their model during creation + + for pytorch_lightning users, please configure your model before using this. 
at that moment, no algo_config is Parameters ---------- @@ -111,6 +121,11 @@ def fill_algo_config_clf(self, clf_or_wrapper.model_obj.add(i) clf_or_wrapper.model_obj.compile(**algo_config.compile) return clf_or_wrapper + elif isinstance(algo_config, TorchLightningSequentialConfig): + clf_or_wrapper: TorchLightningWrapper + for nos, layer in enumerate(algo_config.layer, 1): + clf_or_wrapper.add_layer2blank_model(str(nos), layer) + return clf_or_wrapper # not support pytorch, mxnet model right now raise TypeError('now only support KerasSequentialConfig') @@ -129,15 +144,28 @@ def fill_fit_config_clf(self, y, fit_config: Optional[Union[BaseConfig, Dict]] = None, ): + from pyemits.core.ml.regression.nn import torchlighting_data_helper # nn wrapper if isinstance(clf_or_wrapper, NeuralNetworkWrapperBase): + dl_train, dl_val = torchlighting_data_helper(X, y) + if fit_config is None: + # pytorch_lightning path + if isinstance(clf_or_wrapper, TorchLightningWrapper): + return clf_or_wrapper.fit(dl_train, dl_val) + # keras path return clf_or_wrapper.fit(X, y) if isinstance(fit_config, BaseConfig): + if isinstance(clf_or_wrapper, TorchLightningWrapper): + return clf_or_wrapper.fit(dl_train, dl_val, **dict(fit_config)) + # keras path return clf_or_wrapper.fit(X, y, **dict(fit_config)) elif isinstance(fit_config, Dict): + if isinstance(clf_or_wrapper, TorchLightningWrapper): + return clf_or_wrapper.fit(dl_train, dl_val, **fit_config) + # keras path return clf_or_wrapper.fit(X, y, **fit_config) # sklearn/xgboost/lightgbm clf @@ -204,7 +232,6 @@ def _get_fitted_trainer(algo: List, algo_, algo_config_ in zip(self._algo, self._algo_config)) - # self.clf_models = [obj.clf_models for obj in out] for obj in out: self.clf_models.append(obj.clf_models) return @@ -214,6 +241,12 @@ def fit(self): class MultiOutputRegTrainer(RegTrainer): + """ + machine learning based multioutput regression trainer + bring forecasting power into machine learning model, + forecasting is not only the power of 
deep learning + """ + def __init__(self, algo: List[Union[str, Any]], algo_config: List[Optional[BaseConfig]],