Skip to content

Commit

Permalink
Merge pull request #84 from basf/trainer_fix
Browse files Browse the repository at this point in the history
Trainer fix
  • Loading branch information
AnFreTh authored Jul 25, 2024
2 parents 6da2cf8 + be1879d commit b16af74
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
4 changes: 2 additions & 2 deletions mambular/base_models/lightning_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
lss=False,
family=None,
loss_fct: callable = None,
**kwargs
**kwargs,
):
super().__init__()
self.num_classes = num_classes
Expand Down Expand Up @@ -300,7 +300,7 @@ def configure_optimizers(self):
A dictionary containing the optimizer and lr_scheduler configurations.
"""
optimizer = torch.optim.Adam(
self.parameters(),
self.model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
Expand Down
11 changes: 8 additions & 3 deletions mambular/models/sklearn_base_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
import numpy as np
from lightning.pytorch.callbacks import ModelSummary


class SklearnBaseClassifier(BaseEstimator):
Expand Down Expand Up @@ -367,12 +368,16 @@ def fit(
)

# Initialize the trainer and train the model
trainer = pl.Trainer(
self.trainer = pl.Trainer(
max_epochs=max_epochs,
callbacks=[early_stop_callback, checkpoint_callback],
callbacks=[
early_stop_callback,
checkpoint_callback,
ModelSummary(max_depth=2),
],
**trainer_kwargs
)
trainer.fit(self.model, self.data_module)
self.trainer.fit(self.model, self.data_module)

best_model_path = checkpoint_callback.best_model_path
if best_model_path:
Expand Down
11 changes: 8 additions & 3 deletions mambular/models/sklearn_base_lss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PoissonDistribution,
StudentTDistribution,
)
from lightning.pytorch.callbacks import ModelSummary


class SklearnBaseLSS(BaseEstimator):
Expand Down Expand Up @@ -409,12 +410,16 @@ def fit(
)

# Initialize the trainer and train the model
trainer = pl.Trainer(
self.trainer = pl.Trainer(
max_epochs=max_epochs,
callbacks=[early_stop_callback, checkpoint_callback],
callbacks=[
early_stop_callback,
checkpoint_callback,
ModelSummary(max_depth=2),
],
**trainer_kwargs
)
trainer.fit(self.model, self.data_module)
self.trainer.fit(self.model, self.data_module)

best_model_path = checkpoint_callback.best_model_path
if best_model_path:
Expand Down
11 changes: 8 additions & 3 deletions mambular/models/sklearn_base_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..base_models.lightning_wrapper import TaskModel
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
from lightning.pytorch.callbacks import ModelSummary


class SklearnBaseRegressor(BaseEstimator):
Expand Down Expand Up @@ -356,12 +357,16 @@ def fit(
)

# Initialize the trainer and train the model
trainer = pl.Trainer(
self.trainer = pl.Trainer(
max_epochs=max_epochs,
callbacks=[early_stop_callback, checkpoint_callback],
callbacks=[
early_stop_callback,
checkpoint_callback,
ModelSummary(max_depth=2),
],
**trainer_kwargs
)
trainer.fit(self.model, self.data_module)
self.trainer.fit(self.model, self.data_module)

best_model_path = checkpoint_callback.best_model_path
if best_model_path:
Expand Down

0 comments on commit b16af74

Please sign in to comment.