@@ -512,33 +512,20 @@ with mlflow.start_run(nested=True) as run:
512
512
# dataset is used for model training
513
513
mlflow.log_input(dataset, context = " training" )
514
514
515
- mlflow.log_param(" model_type" , " Darts_Pytorch_model" )
516
- mlflow.log_param(" input_chunk_length" , 24 )
517
- mlflow.log_param(" output_chunk_length" , 12 )
518
- mlflow.log_param(" n_epochs" , 500 )
519
- mlflow.log_param(" model_name" , ' NBEATS_MLflow' )
520
- mlflow.log_param(" log_tensorboard" , True )
521
- mlflow.log_param(" torch_metrics" , " torchmetrics.regression.MeanAbsolutePercentageError()" )
522
- mlflow.log_param(" nr_epochs_val_period" , 1 )
523
- mlflow.log_param(" pl_trainer_kwargs" , " {callbacks: [loss_logger]} " )
524
-
525
-
526
- from pytorch_lightning.callbacks import Callback
527
-
528
- class LossLogger (Callback ):
529
- def __init__ (self ):
530
- self .train_loss = []
531
- self .val_loss = []
532
-
533
- # will automatically be called at the end of each epoch
534
- def on_train_epoch_end (self , trainer : " pl.Trainer" , pl_module : " pl.LightningModule" ) -> None :
535
- self .train_loss.append(float (trainer.callback_metrics[" train_loss" ]))
536
-
537
- def on_validation_epoch_end (self , trainer : " pl.Trainer" , pl_module : " pl.LightningModule" ) -> None :
538
- self .val_loss.append(float (trainer.callback_metrics[" val_loss" ]))
539
-
540
-
541
- loss_logger = LossLogger()
515
+ # Define model hyperparameters to log
516
+ params = {
517
+ " model_type" : " Darts_Pytorch_model" ,
518
+ " input_chunk_length" : 24 ,
519
+ " output_chunk_length" : 12 ,
520
+ " n_epochs" : 500 ,
521
+ " model_name" : " NBEATS_MLflow" ,
522
+ " log_tensorboard" : True ,
523
+ " torch_metrics" : " torchmetrics.regression.MeanAbsolutePercentageError()" ,
524
+ " nr_epochs_val_period" : 1 ,
525
+ }
526
+
527
+ # Log hyperparameters
528
+ mlflow.log_params(params)
542
529
543
530
# create the model
544
531
model = NBEATSModel(
@@ -549,7 +536,7 @@ with mlflow.start_run(nested=True) as run:
549
536
log_tensorboard = True ,
550
537
torch_metrics = torch_metrics,
551
538
nr_epochs_val_period = 1 ,
552
- pl_trainer_kwargs = { " callbacks " : [loss_logger]} )
539
+ )
553
540
554
541
# use validation dataset
555
542
model.fit(
0 commit comments