Skip to content

Commit

Permalink
depth=2 for summary and self.trainer attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
AnFreTh committed Jul 25, 2024
1 parent 257acfb commit be1879d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 9 deletions.
11 changes: 8 additions & 3 deletions mambular/models/sklearn_base_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
import numpy as np
from lightning.pytorch.callbacks import ModelSummary


class SklearnBaseClassifier(BaseEstimator):
Expand Down Expand Up @@ -367,12 +368,16 @@ def fit(
)

# Initialize the trainer and train the model
trainer = pl.Trainer(
self.trainer = pl.Trainer(
max_epochs=max_epochs,
callbacks=[early_stop_callback, checkpoint_callback],
callbacks=[
early_stop_callback,
checkpoint_callback,
ModelSummary(max_depth=2),
],
**trainer_kwargs
)
trainer.fit(self.model, self.data_module)
self.trainer.fit(self.model, self.data_module)

best_model_path = checkpoint_callback.best_model_path
if best_model_path:
Expand Down
11 changes: 8 additions & 3 deletions mambular/models/sklearn_base_lss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PoissonDistribution,
StudentTDistribution,
)
from lightning.pytorch.callbacks import ModelSummary


class SklearnBaseLSS(BaseEstimator):
Expand Down Expand Up @@ -409,12 +410,16 @@ def fit(
)

# Initialize the trainer and train the model
trainer = pl.Trainer(
self.trainer = pl.Trainer(
max_epochs=max_epochs,
callbacks=[early_stop_callback, checkpoint_callback],
callbacks=[
early_stop_callback,
checkpoint_callback,
ModelSummary(max_depth=2),
],
**trainer_kwargs
)
trainer.fit(self.model, self.data_module)
self.trainer.fit(self.model, self.data_module)

best_model_path = checkpoint_callback.best_model_path
if best_model_path:
Expand Down
11 changes: 8 additions & 3 deletions mambular/models/sklearn_base_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..base_models.lightning_wrapper import TaskModel
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
from lightning.pytorch.callbacks import ModelSummary


class SklearnBaseRegressor(BaseEstimator):
Expand Down Expand Up @@ -356,12 +357,16 @@ def fit(
)

# Initialize the trainer and train the model
trainer = pl.Trainer(
self.trainer = pl.Trainer(
max_epochs=max_epochs,
callbacks=[early_stop_callback, checkpoint_callback],
callbacks=[
early_stop_callback,
checkpoint_callback,
ModelSummary(max_depth=2),
],
**trainer_kwargs
)
trainer.fit(self.model, self.data_module)
self.trainer.fit(self.model, self.data_module)

best_model_path = checkpoint_callback.best_model_path
if best_model_path:
Expand Down

0 comments on commit be1879d

Please sign in to comment.