Merge pull request #287 from WenjieDu/dev
Making PyPOTS able to save all models during training, checking if d_model=n_heads*d_k for SAITS and Transformer
WenjieDu authored Dec 26, 2023
2 parents bae6dd4 + 23567e1 commit e6e1880
Showing 34 changed files with 141 additions and 99 deletions.
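The dimension check named in the commit title is not among the files rendered below, but the constraint itself is standard for multi-head attention: the model dimension must equal the number of heads times the per-head key dimension, otherwise the projections cannot be split into heads. A minimal sketch of such a check, using generic parameter names rather than PyPOTS' actual signature:

```python
# Illustrative sketch only: parameter names follow common Transformer
# conventions (d_model, n_heads, d_k) and are not taken from PyPOTS' code.
def check_attention_dims(d_model: int, n_heads: int, d_k: int) -> None:
    """Raise if the attention dimensions are inconsistent."""
    if d_model != n_heads * d_k:
        raise ValueError(
            f"d_model must equal n_heads * d_k, but got "
            f"d_model={d_model}, n_heads={n_heads}, d_k={d_k} "
            f"(n_heads * d_k = {n_heads * d_k})"
        )


check_attention_dims(d_model=256, n_heads=4, d_k=64)    # consistent, passes silently
# check_attention_dims(d_model=256, n_heads=4, d_k=32)  # would raise ValueError
```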
20 changes: 10 additions & 10 deletions pypots/base.py
@@ -170,7 +170,7 @@ def _setup_path(self, saving_path) -> None:
logger.info(f"Tensorboard file will be saved to {tb_saving_path}")
else:
logger.warning(
-     "saving_path not given. Model files and tensorboard file will not be saved."
+     "‼️ saving_path not given. Model files and tensorboard file will not be saved."
)

def _send_model_to_given_device(self) -> None:
@@ -221,30 +221,30 @@ def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None

def _auto_save_model_if_necessary(
self,
- training_finished: bool = True,
+ confirm_saving: bool = True,
saving_name: str = None,
) -> None:
"""Automatically save the current model into a file if in need.
Parameters
----------
- training_finished :
-     Whether the training is already finished when invoke this function.
-     The saving_strategy "better" only works when training_finished is False.
-     The saving_strategy "best" only works when training_finished is True.
+ confirm_saving :
+     One more condition to confirm saving the model.
saving_name :
The file name of the saved model.
"""
if self.saving_path is not None and self.model_saving_strategy is not None:
# construct the saving path
name = self.__class__.__name__ if saving_name is None else saving_name
saving_path = os.path.join(self.saving_path, name)

if self.model_saving_strategy == "all":
self.save(saving_path)
- elif not training_finished and self.model_saving_strategy == "better":
+ elif self.model_saving_strategy == "better" and confirm_saving:
self.save(saving_path)
- elif training_finished and self.model_saving_strategy == "best":
+ elif self.model_saving_strategy == "best" and confirm_saving:
self.save(saving_path)
else:
pass
@@ -280,10 +280,10 @@ def save(
if os.path.exists(saving_path):
if overwrite:
logger.warning(
-     f"File {saving_path} exists. Argument `overwrite` is True. Overwriting now..."
+     f"‼️ File {saving_path} exists. Argument `overwrite` is True. Overwriting now..."
)
else:
- logger.error(f"File {saving_path} exists. Saving operation aborted.")
+ logger.error(f"File {saving_path} exists. Saving operation aborted.")

try:
create_dir_if_not_exist(saving_dir)
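The base-class change above replaces the old `training_finished` flag with `confirm_saving`, so the saver no longer infers intent from the training state; each call site states its own saving condition. A simplified, self-contained sketch of how the three `model_saving_strategy` values interact with that flag (the class and its `save` method are stand-ins, not PyPOTS' real API):

```python
import os


class SaverSketch:
    """Stand-in reproducing the strategy logic from the diff above in isolation."""

    def __init__(self, saving_path: str, model_saving_strategy: str = "best"):
        self.saving_path = saving_path
        self.model_saving_strategy = model_saving_strategy  # None, "all", "better", or "best"

    def save(self, saving_path: str) -> None:
        print(f"would save model to {saving_path}")  # placeholder for the real checkpoint write

    def _auto_save_model_if_necessary(self, confirm_saving: bool = True, saving_name: str = None) -> None:
        if self.saving_path is None or self.model_saving_strategy is None:
            return  # saving disabled
        name = self.__class__.__name__ if saving_name is None else saving_name
        saving_path = os.path.join(self.saving_path, name)
        if self.model_saving_strategy == "all":
            self.save(saving_path)  # "all": save on every call
        elif self.model_saving_strategy in ("better", "best") and confirm_saving:
            # the diff keeps "better" and "best" as separate branches; they are
            # folded together here for brevity since both defer to confirm_saving
            self.save(saving_path)


# Per-epoch usage as in the training loops changed by this PR: the caller
# passes the "is this an improvement?" decision explicitly.
saver = SaverSketch(saving_path="runs/exp1", model_saving_strategy="better")
mean_loss, best_loss = 0.12, 0.15
saver._auto_save_model_if_necessary(
    confirm_saving=mean_loss < best_loss,
    saving_name=f"Model_epoch3_loss{mean_loss}",
)
```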
13 changes: 7 additions & 6 deletions pypots/classification/base.py
@@ -340,14 +340,15 @@ def _train_model(
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
- # save the model if necessary
- self._auto_save_model_if_necessary(
-     training_finished=False,
-     saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
- )
else:
self.patience -= 1

+ # save the model if necessary
+ self._auto_save_model_if_necessary(
+     confirm_saving=mean_loss < self.best_loss,
+     saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ )

if os.getenv("enable_tuning", False):
nni.report_intermediate_result(mean_loss)
if epoch == self.epochs - 1 or self.patience == 0:
@@ -360,7 +361,7 @@ def _train_model(
break

except Exception as e:
- logger.error(f"Exception: {e}")
+ logger.error(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not trained. Please investigate the error printed above."
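The training loop in this hunk pairs the relocated saving call with the existing patience-based early stopping: each epoch the loss either improves (reset patience, confirm saving) or it does not (burn patience, eventually stop). A compact, self-contained sketch of that bookkeeping, with made-up losses and a print in place of the checkpoint write; note it evaluates the improvement test before updating best_loss:

```python
# Sketch of the patience bookkeeping used by these training loops; the losses
# are invented and the "save" is a print, not PyPOTS' real checkpointing.
best_loss = float("inf")
original_patience = 3
patience = original_patience
epoch_losses = [0.9, 0.7, 0.72, 0.71, 0.69, 0.70, 0.71, 0.72]

for epoch, mean_loss in enumerate(epoch_losses, start=1):
    improved = mean_loss < best_loss
    if improved:
        best_loss = mean_loss
        patience = original_patience   # reset patience on improvement
    else:
        patience -= 1                  # burn one unit of patience

    # corresponds to _auto_save_model_if_necessary(confirm_saving=improved, ...)
    if improved:
        print(f"epoch {epoch}: checkpoint saved (loss={mean_loss})")

    if patience == 0:
        print(f"epoch {epoch}: patience exhausted, stopping early")
        break
```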
2 changes: 1 addition & 1 deletion pypots/classification/brits/model.py
@@ -241,7 +241,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(training_finished=True)
+ self._auto_save_model_if_necessary(confirm_saving=True)

def predict(
self,
2 changes: 1 addition & 1 deletion pypots/classification/grud/model.py
@@ -212,7 +212,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(training_finished=True)
+ self._auto_save_model_if_necessary(confirm_saving=True)

def predict(
self,
2 changes: 1 addition & 1 deletion pypots/classification/raindrop/model.py
@@ -256,7 +256,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(training_finished=True)
+ self._auto_save_model_if_necessary(confirm_saving=True)

def predict(
self,
4 changes: 2 additions & 2 deletions pypots/classification/raindrop/modules/core.py
@@ -24,14 +24,14 @@
from torch_geometric.nn.inits import glorot
except ImportError as e:
logger.error(
-     f"{e}\n"
+     f"{e}\n"
"Note torch_geometric is missing, please install it with "
"'pip install torch_geometric torch_scatter torch_sparse' or "
"'conda install -c pyg pyg pytorch-scatter pytorch-sparse'"
)
except NameError as e:
logger.error(
-     f"{e}\n"
+     f"{e}\n"
"Note torch_geometric is missing, please install it with "
"'pip install torch_geometric torch_scatter torch_sparse' or "
"'conda install -c pyg pyg pytorch-scatter pytorch-sparse'"
2 changes: 1 addition & 1 deletion pypots/classification/raindrop/modules/submodules.py
@@ -33,7 +33,7 @@
from torch_sparse import SparseTensor
except ImportError as e:
logger.error(
-     f"{e}\n"
+     f"{e}\n"
"Note torch_geometric is missing, please install it with "
"'pip install torch_geometric torch_scatter torch_sparse' or "
"'conda install -c pyg pyg pytorch-scatter pytorch-sparse'"
2 changes: 1 addition & 1 deletion pypots/cli/tuning.py
@@ -24,7 +24,7 @@
import nni
except ImportError:
logger.error(
-     "Hyperparameter tuning mode needs NNI (https://github.com/microsoft/nni) installed, "
+     "Hyperparameter tuning mode needs NNI (https://github.com/microsoft/nni) installed, "
"but is missing in the current environment."
)

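Several of the touched modules (Raindrop's core and submodules, the tuning CLI) wrap optional imports in try/except and log installation instructions instead of failing at import time. A generic sketch of that pattern; the logger setup, package, and message below are illustrative, not PyPOTS' exact code:

```python
import logging

logger = logging.getLogger("optional-import-sketch")  # stand-in for pypots.utils.logging.logger

try:
    import torch_geometric  # optional dependency, as in the hunks above
except ImportError as e:
    torch_geometric = None
    logger.error(
        f"{e}\n"
        "Note torch_geometric is missing, install it with "
        "'pip install torch_geometric' if you need the Raindrop model."
    )

# Downstream code can check for None instead of importing again.
if torch_geometric is None:
    logger.warning("Raindrop-specific features are disabled in this environment.")
```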
2 changes: 1 addition & 1 deletion pypots/clustering/base.py
@@ -354,7 +354,7 @@ def _train_model(
break

except Exception as e:
- logger.error(f"Exception: {e}")
+ logger.error(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not trained. Please investigate the error printed above."
15 changes: 8 additions & 7 deletions pypots/clustering/crli/model.py
@@ -306,11 +306,6 @@ def _train_model(
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
- # save the model if necessary
- self._auto_save_model_if_necessary(
-     training_finished=False,
-     saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
- )
else:
self.patience -= 1

@@ -319,14 +314,20 @@
if epoch == self.epochs - 1 or self.patience == 0:
nni.report_final_result(self.best_loss)

+ # save the model if necessary
+ self._auto_save_model_if_necessary(
+     confirm_saving=mean_loss < self.best_loss,
+     saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
+ )

if self.patience == 0:
logger.info(
"Exceeded the training patience. Terminating the training procedure..."
)
break

except Exception as e:
- logger.error(f"Exception: {e}")
+ logger.error(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not trained. Please investigate the error printed above."
@@ -376,7 +377,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(training_finished=True)
+ self._auto_save_model_if_necessary(confirm_saving=True)

def predict(
self,
23 changes: 12 additions & 11 deletions pypots/clustering/vader/model.py
@@ -221,16 +221,16 @@ def _train_model(
gmm.fit(samples)
flag = 1
except ValueError as e:
- logger.error(e)
+ logger.error(f"❌ Exception: {e}")
logger.warning(
-     "Met with ValueError, double `reg_covar` to re-train the GMM model."
+     "‼️ Met with ValueError, double `reg_covar` to re-train the GMM model."
)

flag -= 1
if flag == -5:
logger.error(
-     f"Doubled `reg_covar` for 4 times, whose current value is {reg_covar}, but still failed.\n"
-     "Now quit to let you check your model training.\n"
+     f"Doubled `reg_covar` for 4 times, its current value is {reg_covar}, but still failed.\n"
+     f"Now quit to let you check your model training.\n"
"Please raise an issue https://github.com/WenjieDu/PyPOTS/issues if you have questions."
)
raise RuntimeError
@@ -321,14 +321,15 @@ def _train_model(
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
- # save the model if necessary
- self._auto_save_model_if_necessary(
-     training_finished=False,
-     saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
- )
else:
self.patience -= 1

+ # save the model if necessary
+ self._auto_save_model_if_necessary(
+     confirm_saving=mean_loss < self.best_loss,
+     saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ )

if os.getenv("enable_tuning", False):
nni.report_intermediate_result(mean_loss)
if epoch == self.epochs - 1 or self.patience == 0:
@@ -341,7 +342,7 @@
break

except Exception as e:
- logger.error(f"Exception: {e}")
+ logger.error(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not trained. Please investigate the error printed above."
@@ -391,7 +392,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(training_finished=True)
+ self._auto_save_model_if_necessary(confirm_saving=True)

def predict(
self,
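The VaDER hunk above keeps the retry loop that doubles `reg_covar` whenever the GMM fit raises a `ValueError`, giving up after a handful of attempts. A self-contained sketch of that retry pattern using scikit-learn's `GaussianMixture`; the data, component count, and retry limit here are illustrative, not PyPOTS' actual values:

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
samples = rng.normal(size=(200, 8))   # stand-in for the pretrained latent samples

reg_covar = 1e-6      # regularization added to covariance diagonals
max_retries = 5       # the diff above aborts after doubling a few times

for attempt in range(max_retries):
    try:
        gmm = GaussianMixture(n_components=3, reg_covar=reg_covar, random_state=0)
        gmm.fit(samples)
        break                          # fit succeeded
    except ValueError:
        reg_covar *= 2                 # ill-conditioned covariance: regularize harder and retry
else:
    raise RuntimeError(
        f"GMM fitting still failed after {max_retries} attempts (reg_covar={reg_covar})."
    )
```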
31 changes: 23 additions & 8 deletions pypots/forecasting/base.py
@@ -15,6 +15,7 @@

from ..base import BaseModel, BaseNNModel
from ..utils.logging import logger
+ from ..utils.metrics.error import calc_mse

try:
import nni
@@ -262,7 +263,6 @@ def _train_model(
# each training starts from the very beginning, so reset the loss and model dict here
self.best_loss = float("inf")
self.best_model_dict = None

try:
training_step = 0
for epoch in range(1, self.epochs + 1):
@@ -273,6 +273,7 @@
inputs = self._assemble_input_for_training(data)
self.optimizer.zero_grad()
results = self.model.forward(inputs)
+ # use sum() before backward() in case of multi-gpu training
results["loss"].sum().backward()
self.optimizer.step()
epoch_train_loss_collector.append(results["loss"].sum().item())
@@ -286,21 +287,29 @@

if val_loader is not None:
self.model.eval()
- epoch_val_loss_collector = []
+ forecasting_loss_collector = []
with torch.no_grad():
for idx, data in enumerate(val_loader):
inputs = self._assemble_input_for_validating(data)
- results = self.model.forward(inputs)
- epoch_val_loss_collector.append(
-     results["loss"].sum().item()
+ results = self.model.forward(inputs, training=False)
+ forecasting_mse = (
+     calc_mse(
+         results["forecasting_data"],
+         inputs["X_ori"],
+         inputs["indicating_mask"],
+     )
+     .sum()
+     .detach()
+     .item()
+ )
+ forecasting_loss_collector.append(forecasting_mse)

- mean_val_loss = np.mean(epoch_val_loss_collector)
+ mean_val_loss = np.mean(forecasting_loss_collector)

# save validating loss logs into the tensorboard file for every epoch if in need
if self.summary_writer is not None:
val_loss_dict = {
-     "imputation_loss": mean_val_loss,
+     "forecasting_loss": mean_val_loss,
}
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

@@ -328,6 +337,12 @@
else:
self.patience -= 1

+ # save the model if necessary
+ self._auto_save_model_if_necessary(
+     confirm_saving=mean_loss < self.best_loss,
+     saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ )

if os.getenv("enable_tuning", False):
nni.report_intermediate_result(mean_loss)
if epoch == self.epochs - 1 or self.patience == 0:
@@ -340,7 +355,7 @@
break

except Exception as e:
- logger.error(f"Exception: {e}")
+ logger.error(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not trained. Please investigate the error printed above."
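With this change the forecasting base class scores validation epochs by a masked MSE over the held-back values (`X_ori` at positions flagged by `indicating_mask`) instead of the model's own training loss. A small sketch of that metric, assuming `calc_mse` follows the usual masked-error definition, i.e. squared error summed over indicated entries and divided by their count:

```python
import numpy as np


def masked_mse(predictions: np.ndarray, targets: np.ndarray, mask: np.ndarray) -> float:
    """Masked MSE as assumed for calc_mse: only entries where mask == 1 are scored."""
    mask = mask.astype(float)
    squared_error = ((predictions - targets) ** 2) * mask
    return float(squared_error.sum() / max(mask.sum(), 1e-12))


forecast = np.array([[1.0, 2.0], [3.0, 4.0]])
ground_truth = np.array([[1.0, 0.0], [2.5, 4.0]])
indicating_mask = np.array([[1, 0], [1, 1]])  # 1 marks the held-back values to evaluate

print(masked_mse(forecast, ground_truth, indicating_mask))  # 0.25 / 3 ≈ 0.0833
```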
11 changes: 6 additions & 5 deletions pypots/imputation/base.py
@@ -337,14 +337,15 @@ def _train_model(
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
- # save the model if necessary
- self._auto_save_model_if_necessary(
-     training_finished=False,
-     saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
- )
else:
self.patience -= 1

+ # save the model if necessary
+ self._auto_save_model_if_necessary(
+     confirm_saving=mean_loss < self.best_loss,
+     saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ )

if os.getenv("enable_tuning", False):
nni.report_intermediate_result(mean_loss)
if epoch == self.epochs - 1 or self.patience == 0:
2 changes: 1 addition & 1 deletion pypots/imputation/brits/model.py
@@ -233,7 +233,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(training_finished=True)
+ self._auto_save_model_if_necessary(confirm_saving=True)

def predict(
self,
