Skip to content

Commit f798244

Browse files
committed
Making log more concise
Making the log parameters method more concise as suggested.
1 parent ce635bc commit f798244

File tree

1 file changed

+15
-28
lines changed

1 file changed

+15
-28
lines changed

docs/userguide/torch_forecasting_models.md

+15-28
Original file line numberDiff line numberDiff line change
@@ -512,33 +512,20 @@ with mlflow.start_run(nested=True) as run:
512512
# dataset is used for model training
513513
mlflow.log_input(dataset, context="training")
514514

515-
mlflow.log_param("model_type", "Darts_Pytorch_model")
516-
mlflow.log_param("input_chunk_length", 24)
517-
mlflow.log_param("output_chunk_length", 12)
518-
mlflow.log_param("n_epochs", 500)
519-
mlflow.log_param("model_name", 'NBEATS_MLflow')
520-
mlflow.log_param("log_tensorboard", True)
521-
mlflow.log_param("torch_metrics", "torchmetrics.regression.MeanAbsolutePercentageError()")
522-
mlflow.log_param("nr_epochs_val_period", 1)
523-
mlflow.log_param("pl_trainer_kwargs", "{callbacks: [loss_logger]}")
524-
525-
526-
from pytorch_lightning.callbacks import Callback
527-
528-
class LossLogger(Callback):
529-
def __init__(self):
530-
self.train_loss = []
531-
self.val_loss = []
532-
533-
# will automatically be called at the end of each epoch
534-
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
535-
self.train_loss.append(float(trainer.callback_metrics["train_loss"]))
536-
537-
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
538-
self.val_loss.append(float(trainer.callback_metrics["val_loss"]))
539-
540-
541-
loss_logger = LossLogger()
515+
# Define model hyperparameters to log
516+
params = {
517+
"model_type": "Darts_Pytorch_model",
518+
"input_chunk_length": 24,
519+
"output_chunk_length": 12,
520+
"n_epochs": 500,
521+
"model_name": "NBEATS_MLflow",
522+
"log_tensorboard": True,
523+
"torch_metrics": "torchmetrics.regression.MeanAbsolutePercentageError()",
524+
"nr_epochs_val_period": 1,
525+
}
526+
527+
# Log hyperparameters
528+
mlflow.log_params(params)
542529

543530
# create the model
544531
model = NBEATSModel(
@@ -549,7 +536,7 @@ with mlflow.start_run(nested=True) as run:
549536
log_tensorboard=True,
550537
torch_metrics=torch_metrics,
551538
nr_epochs_val_period=1,
552-
pl_trainer_kwargs={"callbacks": [loss_logger]})
539+
)
553540

554541
# use validation dataset
555542
model.fit(

0 commit comments

Comments
 (0)