diff --git a/.gitignore b/.gitignore index 51294f38..b21e3ad7 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,8 @@ docs/_build # ignore specific kinds of files like all PDFs *.pdf *.ipynb +*TODO + +# ignore executable and config files +*.sh +*.yml \ No newline at end of file diff --git a/README.md b/README.md index 0df4a316..5c3e3b2f 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,8 @@ algorithms, PyPOTS is going to have unified APIs together with detailed document algorithms as tutorials. 🤗 **Please** star this repo to help others notice PyPOTS if you think it is a useful toolkit. -**Please** kindly [cite PyPOTS](https://github.com/WenjieDu/PyPOTS#-citing-pypots) in your publications if it helps with your research. +**Please** kindly [cite PyPOTS](https://github.com/WenjieDu/PyPOTS#-citing-pypots) in your publications if it helps with +your research. This really means a lot to our open-source research. Thank you! The rest of this readme file is organized as follows: @@ -98,9 +99,9 @@ corresponding task (note that models will be continuously updated in the future currently supported. Stay tuned❗️). 🌟 Since **v0.2**, all neural-network models in PyPOTS has got hyperparameter-optimization support. -This functionality is implemented with the [Microsoft NNI](https://github.com/microsoft/nni) framework. You may want to refer to our time-series -imputation survey repo [Awesome_Imputation](https://github.com/WenjieDu/Awesome_Imputation) to see how to config and -tune the hyperparameters. +This functionality is implemented with the [Microsoft NNI](https://github.com/microsoft/nni) framework. You may want to +refer to our time-series imputation survey repo [Awesome_Imputation](https://github.com/WenjieDu/Awesome_Imputation) +to see how to config and tune the hyperparameters. 🔥 Note that all models whose name with `🧑🔧` in the table (e.g. Transformer, iTransformer, Informer etc.) are not originally proposed as algorithms for POTS data in their papers, and they cannot directly accept time series with @@ -225,7 +226,9 @@ for a guideline with more details. PyPOTS is available on both [PyPI](https://pypi.python.org/pypi/pypots) and [Anaconda](https://anaconda.org/conda-forge/pypots). -You can install PyPOTS like below as well as TSDB and PyGrinder: +You can install PyPOTS like below as well as +[TSDB](https://github.com/WenjieDu/TSDB),[PyGrinder](https://github.com/WenjieDu/PyGrinder), +[BenchPOTS](https://github.com/WenjieDu/BenchPOTS), and [AI4TS](https://github.com/WenjieDu/AI4TS): ``` bash # via pip @@ -235,8 +238,8 @@ pip install pypots --upgrade # update pypots to the latest version pip install https://github.com/WenjieDu/PyPOTS/archive/main.zip # via conda -conda install -c conda-forge pypots # the first time installation -conda update -c conda-forge pypots # update pypots to the latest version +conda install conda-forge::pypots # the first time installation +conda update conda-forge::pypots # update pypots to the latest version ``` ## ❖ Usage diff --git a/README_zh.md b/README_zh.md index ed376405..35657587 100644 --- a/README_zh.md +++ b/README_zh.md @@ -63,19 +63,19 @@
-⦿ `开发背景`: 由于传感器故障、通信异常以及不可预见的未知原因,在现实环境中收集的时间序列数据普遍存在缺失值, -这使得部分观测时间序列(partially-observed time series,简称为POTS)成为现实世界数据的建模中普遍存在的问题。 -数据缺失会严重阻碍数据的高级分析、建模、与后续应用,所以如何直接面向POTS建模成为一个亟需解决的问题。 -尽管关于在POTS上进行不同任务的机器学习算法已经有了不少的研究,但当前没有专门针对POTS建模开发的工具箱。 -因此,旨在填补该领域空白的“PyPOTS”工具箱应运而生。 +⦿ `开发背景`: 由于传感器故障、通信异常以及不可预见的未知原因, 在现实环境中收集的时间序列数据普遍存在缺失值, +这使得部分观测时间序列(partially-observed time series, 简称为POTS)成为现实世界数据的建模中普遍存在的问题. +数据缺失会严重阻碍数据的高级分析、建模、与后续应用, 所以如何直接面向POTS建模成为一个亟需解决的问题. +尽管关于在POTS上进行不同任务的机器学习算法已经有了不少的研究, 但当前没有专门针对POTS建模开发的工具箱. +因此, 旨在填补该领域空白的"PyPOTS"应运而生. -⦿ `应用意义`: PyPOTS(发音为"Pie Pots")是一个易上手的工具箱,工程师和研究人员可以通过PyPOTS轻松地处理POTS数据建模问题, -进而将注意力更多地聚焦在要解决的核心问题上。PyPOTS会持续不断的更新关于部分观测多变量时间序列的经典算法和先进算法。 -除此之外,PyPOTS还提供了统一的应用程序接口,详细的算法学习指南和应用示例。 +⦿ `应用意义`: PyPOTS(发音为"Pie Pots")是一个易上手的工具箱, 工程师和研究人员可以通过PyPOTS轻松地处理POTS数据建模问题, +进而将注意力更多地聚焦在要解决的核心问题上. PyPOTS会持续不断的更新关于部分观测多变量时间序列的经典算法和先进算法. +除此之外, PyPOTS还提供了统一的应用程序接口,详细的算法学习指南和应用示例. -🤗 如果你认为PyPOTS有用,请星标🌟该项目来帮助更多人注意到PyPOTS的存在。 -如果PyPOTS对你的研究有帮助,请在你的研究中[引用PyPOTS](#-引用pypots)。 -这是对我们开源研究工作的最大支持,谢谢! +🤗 如果你认为PyPOTS有用, 请星标🌟该项目来帮助更多人注意到PyPOTS的存在. +如果PyPOTS对你的研究有帮助, 请在你的研究中[引用PyPOTS](#-引用pypots). +这是对我们开源研究工作的最大支持, 谢谢! 该说明文档的后续内容如下: [**❖ 支持的算法**](#-支持的算法), @@ -84,23 +84,23 @@ [**❖ 使用案例**](#-使用案例), [**❖ 引用PyPOTS**](#-引用pypots), [**❖ 贡献声明**](#-贡献声明), -[**❖ 社区组织**](#-社区组织)。 +[**❖ 社区组织**](#-社区组织). ## ❖ 支持的算法 -PyPOTS当前支持多变量POTS数据的插补,预测,分类,聚类以及异常检测五类任务。下表描述了当前PyPOTS中所集成的算法以及对应不同任务的可用性。 -符号`✅`表示该算法当前可用于相应的任务(注意,目前模型尚不支持的任务在未来版本中可能会逐步添加,敬请关注!)。 -算法的参考文献以及论文链接在该文档底部可以找到。 +PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异常检测五类任务. 下表描述了当前PyPOTS中所集成的算法以及对应不同任务的可用性. +符号`✅`表示该算法当前可用于相应的任务(注意, 目前模型尚不支持的任务在未来版本中可能会逐步添加, 敬请关注!). +算法的参考文献以及论文链接在该文档底部可以找到. -🌟 自**v0.2**版本开始, PyPOTS中所有神经网络模型都支持超参数调优。该功能基于[微软的NNI](https://github.com/microsoft/nni) -框架实现。 +🌟 自**v0.2**版本开始, PyPOTS中所有神经网络模型都支持超参数调优. 该功能基于[微软的NNI](https://github.com/microsoft/nni) +框架实现. 你可以通过参考我们的时间序列插补综述项目的代码[Awesome_Imputation](https://github.com/WenjieDu/Awesome_Imputation) -来了解如何使用PyPOTS调优模型的超参。 +来了解如何使用PyPOTS调优模型的超参. -🔥 请注意: 表格中名称带有`🧑🔧`的模型(例如Transformer, iTransformer, Informer等)在它们的原始论文中并非作为可以处理POTS数据的算法提出, -所以这些模型的输入中不能带有缺失值,无法接受POTS数据作为输入,更加不是插补算法。 -**为了使上述模型能够适用于POTS数据,我们采用了与[SAITS论文](https://arxiv.org/pdf/2202.08516)[^1] -中相同的embedding策略和训练方法(ORT+MIT)对它们进行改进**。 +🔥 请注意: 表格中名称带有`🧑🔧`的模型(例如Transformer, iTransformer, Informer等)在它们的原始论文中并非作为可以处理POTS数据的算法提出, +所以这些模型的输入中不能带有缺失值, 无法接受POTS数据作为输入, 更加不是插补算法. +**为了使上述模型能够适用于POTS数据, 我们采用了与[SAITS论文](https://arxiv.org/pdf/2202.08516)[^1] +中相同的embedding策略和训练方法(ORT+MIT)对它们进行改进**. | **类型** | **算法** | **插补** | **预测** | **分类** | **聚类** | **异常检测** | **年份 - 刊物** | |:--------------|:---------------------------------------------------------------------------------------------------------------------------------|:------:|:------:|:------:|:------:|:--------:|:-------------------| @@ -149,50 +149,53 @@ PyPOTS当前支持多变量POTS数据的插补,预测,分类,聚类以及 💯 现在贡献你的模型来增加你的研究影响力!PyPOTS的下载量正在迅速增长 (**[目前PyPI上总共超过30万次且每日超1000的下载](https://www.pepy.tech/projects/pypots)**), -你的工作将被社区广泛使用和引用。请参阅[贡献指南](#-%E8%B4%A1%E7%8C%AE%E5%A3%B0%E6%98%8E) -,了解如何将模型包含在PyPOTS中。 +你的工作将被社区广泛使用和引用. 请参阅[贡献指南](#-%E8%B4%A1%E7%8C%AE%E5%A3%B0%E6%98%8E) +, 了解如何将模型包含在PyPOTS中. ## ❖ PyPOTS生态系统 -在PyPOTS生态系统中,一切都与我们熟悉的咖啡息息相关,甚至可以将其视为一杯咖啡的诞生过程! -如你所见,PyPOTS的标志中有一个咖啡壶。除此之外还需要什么呢?请接着看下去、 +在PyPOTS生态系统中, 一切都与我们熟悉的咖啡息息相关, 甚至可以将其视为一杯咖啡的诞生过程! +如你所见, PyPOTS的标志中有一个咖啡壶. 除此之外还需要什么呢?请接着看下去、
@@ -204,20 +207,22 @@ PyGrinder支持以上所有模式并提供与缺失相关的其他功能函数
## ❖ 安装教程
-你可以参考PyPOTS文档中的 [安装说明](https://docs.pypots.com/en/latest/install.html) 以获取更详细的指南。
-PyPOTS可以在 [PyPI](https://pypi.python.org/pypi/pypots) 和 [Anaconda](https://anaconda.org/conda-forge/pypots) 上安装。
-你可以按照以下方式安装PyPOTS(同样适用于TSDB以及PyGrinder):
+你可以参考PyPOTS文档中的 [安装说明](https://docs.pypots.com/en/latest/install.html) 以获取更详细的指南.
+PyPOTS可以在 [PyPI](https://pypi.python.org/pypi/pypots) 和 [Anaconda](https://anaconda.org/conda-forge/pypots) 上安装.
+你可以按照以下方式安装PyPOTS(同样适用于
+[TSDB](https://github.com/WenjieDu/TSDB), [PyGrinder](https://github.com/WenjieDu/PyGrinder),
+[BenchPOTS](https://github.com/WenjieDu/BenchPOTS), 和[AI4TS](https://github.com/WenjieDu/AI4TS):):
```bash
# 通过pip安装
pip install pypots # 首次安装
pip install pypots --upgrade # 更新为最新版本
-# 利用最新源代码安装最新版本,可能带有尚未正式发布的最新功能
+# 利用最新源代码安装最新版本, 可能带有尚未正式发布的最新功能
pip install https://github.com/WenjieDu/PyPOTS/archive/main.zip
# 通过conda安装
-conda install -c conda-forge pypots # 首次安装
-conda update -c conda-forge pypots # 更新为最新版本
+conda install conda-forge::pypots # 首次安装
+conda update conda-forge::pypots # 更新为最新版本
```
## ❖ 使用案例
@@ -225,16 +230,16 @@ conda update -c conda-forge pypots # 更新为最新版本
除了[BrewPOTS](https://github.com/WenjieDu/BrewPOTS)之外, 你还可以在Google Colab
-上找到一个简单且快速的入门教程。如果你有其他问题,请参考[PyPOTS文档](https://docs.pypots.com)。
-你也可以在我们的[社区](#-community)中提问,或直接[发起issue](https://github.com/WenjieDu/PyPOTS/issues)。
+上找到一个简单且快速的入门教程. 如果你有其他问题, 请参考[PyPOTS文档](https://docs.pypots.com).
+你也可以在我们的[社区](#-community)中提问, 或直接[发起issue](https://github.com/WenjieDu/PyPOTS/issues).
-下面,我们为你演示使用PyPOTS进行POTS数据插补的示例:
+下面, 我们为你演示使用PyPOTS进行POTS数据插补的示例:
点击此处查看 SAITS 模型应用于 PhysioNet2012 数据集插补任务的简单案例:
``` python
-# 数据预处理,使用PyPOTS生态帮助完成繁琐的数据预处理
+# 数据预处理, 使用PyPOTS生态帮助完成繁琐的数据预处理
import numpy as np
from sklearn.preprocessing import StandardScaler
from pygrinder import mcar
@@ -246,15 +251,15 @@ X = X.drop(['RecordID', 'Time'], axis = 1)
X = StandardScaler().fit_transform(X.to_numpy())
X = X.reshape(num_samples, 48, -1)
X_ori = X # keep X_ori for validation
-X = mcar(X, 0.1) # 随机掩盖观测值的10%,作为基准数据
+X = mcar(X, 0.1) # 随机掩盖观测值的10%, 作为基准数据
dataset = {"X": X} # X用于模型输入
-print(X.shape) # X的形状为(11988, 48, 37), 即11988个样本,每个样本有48个步长(time steps)和37个特征(features)
+print(X.shape) # X的形状为(11988, 48, 37), 即11988个样本, 每个样本有48个步长(time steps)和37个特征(features)
-# 模型训练。PyPOTS的好戏上演了!
+# 模型训练. PyPOTS的好戏上演了!
from pypots.imputation import SAITS
from pypots.utils.metrics import calc_mae
saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, n_heads=4, d_k=64, d_v=64, d_ffn=128, dropout=0.1, epochs=10)
-# 因为基准数据对模型不可知,将整个数据集作为训练集, 也可以把数据集分为训练/验证/测试集
+# 因为基准数据对模型不可知, 将整个数据集作为训练集, 也可以把数据集分为训练/验证/测试集
saits.fit(dataset) # 基于数据集训练模型
imputation = saits.impute(dataset) # 插补数据集中原始缺失部分和我们上面人为遮蔽缺失的基准数据部分
indicating_mask = np.isnan(X) ^ np.isnan(X_ori) # 用于计算插补误差的掩码矩阵
@@ -270,22 +275,22 @@ saits.load("save_it_here/saits_physionet2012.pypots") # 你随时可以重新
> [!TIP]
> **[2024年6月更新]** 😎
> 第一个全面的时间序列插补基准论文[TSI-Bench: Benchmarking Time Series Imputation](https://arxiv.org/abs/2406.12747)
-> 现在来了。
+> 现在来了.
> 所有代码开源在[Awesome_Imputation](https://github.com/WenjieDu/Awesome_Imputation)
-> 仓库中。通过近35,000个实验,我们对28种imputation方法,3种缺失模式(点,序列,块),各种缺失率,和8个真实数据集进行了全面的基准研究。
+> 仓库中. 通过近35,000个实验, 我们对28种imputation方法, 3种缺失模式(点, 序列, 块), 各种缺失率, 和8个真实数据集进行了全面的基准研究.
>
> **[2024年2月更新]** 🎉
> 我们的综述论文[Deep Learning for Multivariate Time Series Imputation: A Survey](https://arxiv.org/abs/2402.04059)
-> 已在 arXiv 上发布。我们全面调研总结了最新基于深度学习的时间序列插补方法文献并对现有的方法进行分类,此外,还讨论了该领域当前的挑战和未来发展方向。
+> 已在 arXiv 上发布. 我们全面调研总结了最新基于深度学习的时间序列插补方法文献并对现有的方法进行分类, 此外,
+> 还讨论了该领域当前的挑战和未来发展方向.
-PyPOTS的论文可以[在arXiv上获取](https://arxiv.org/abs/2305.18811),其5页的短版论文已被第9届SIGKDD international workshop
-on
-Mining and Learning from Time Series ([MiLeTS'23](https://kdd-milets.github.io/milets2023/))收录,与此同时,
-PyPOTS也已被纳入[PyTorch Ecosystem](https://pytorch.org/ecosystem/)。我们正在努力将其发表在更具影响力的学术刊物上,
-如JMLR (track for [Machine Learning Open Source Software](https://www.jmlr.org/mloss/))。
-如果你在工作中使用了PyPOTS,请按照以下格式引用我们的论文并为将项目设为星标🌟,以便让更多人关注到它,对此我们深表感谢🤗。
+PyPOTS的论文可以[在arXiv上获取](https://arxiv.org/abs/2305.18811), 其5页的短版论文已被第9届SIGKDD international workshop
+on Mining and Learning from Time Series ([MiLeTS'23](https://kdd-milets.github.io/milets2023/))收录, 与此同时,
+PyPOTS也已被纳入[PyTorch Ecosystem](https://pytorch.org/ecosystem/). 我们正在努力将其发表在更具影响力的学术刊物上,
+如JMLR (track for [Machine Learning Open Source Software](https://www.jmlr.org/mloss/)).
+如果你在工作中使用了PyPOTS, 请按照以下格式引用我们的论文并为将项目设为星标🌟, 以便让更多人关注到它, 对此我们深表感谢🤗.
-据不完全统计,该[列表](https://scholar.google.com/scholar?as_ylo=2022&q=%E2%80%9CPyPOTS%E2%80%9D&hl=en>)
+据不完全统计, 该[列表](https://scholar.google.com/scholar?as_ylo=2022&q=%E2%80%9CPyPOTS%E2%80%9D&hl=en>)
为当前使用PyPOTS并在其论文中引用PyPOTS的科学研究项目
```bibtex
@@ -306,17 +311,17 @@ PyPOTS也已被纳入[PyTorch Ecosystem](https://pytorch.org/ecosystem/)。我
非常欢迎你为这个激动人心的项目做出贡献!
-通过提交你的代码,你将:
+通过提交你的代码, 你将:
-1. 把你开发完善的模型直接提供给PyPOTS的所有用户使用,让你的工作更加广为人知。
- 请查看我们的[收录标准](https://docs.pypots.com/en/latest/faq.html#inclusion-criteria)。
+1. 把你开发完善的模型直接提供给PyPOTS的所有用户使用, 让你的工作更加广为人知.
+ 请查看我们的[收录标准](https://docs.pypots.com/en/latest/faq.html#inclusion-criteria).
你也可以利用项目文件中的模板`template`(如:
[pypots/imputation/template](https://github.com/WenjieDu/PyPOTS/tree/main/pypots/imputation/template))快速启动你的开发;
-2. 成为[PyPOTS贡献者](https://github.com/WenjieDu/PyPOTS/graphs/contributors)之一,
+2. 成为[PyPOTS贡献者](https://github.com/WenjieDu/PyPOTS/graphs/contributors)之一,
并在[PyPOTS网站](https://pypots.com/about/#volunteer-developers)上被列为志愿开发者;
3. 在PyPOTS发布新版本的[更新日志](https://github.com/WenjieDu/PyPOTS/releases)中被提及;
-你也可以通过为该项目设置星标🌟,帮助更多人关注它。你的星标🌟既是对PyPOTS的认可,也是对PyPOTS发展所做出的重要贡献!
+你也可以通过为该项目设置星标🌟, 帮助更多人关注它. 你的星标🌟既是对PyPOTS的认可, 也是对PyPOTS发展所做出的重要贡献!
@@ -338,15 +343,15 @@ PyPOTS也已被纳入[PyTorch Ecosystem](https://pytorch.org/ecosystem/)。我
## ❖ 社区组织
-我们非常关心用户的反馈,因此我们正在建立PyPOTS社区:
+我们非常关心用户的反馈, 因此我们正在建立PyPOTS社区:
- [Slack](https://join.slack.com/t/pypots-org/shared_invite/zt-1gq6ufwsi-p0OZdW~e9UW_IA4_f1OfxA):
你可以在这里进行日常讨论、问答以及与我们的开发团队交流;
- [领英](https://www.linkedin.com/company/pypots):你可以在这里获取官方公告和新闻;
- [微信公众号](https://mp.weixin.qq.com/s/X3ukIgL1QpNH8ZEXq1YifA):你可以关注官方公众号并加入微信群聊参与讨论以及获取最新动态;
-如果你有任何建议、想法、或打算分享与时间序列相关的论文,欢迎加入我们!
-PyPOTS社区是一个开放、透明、友好的社区,让我们共同努力建设并改进PyPOTS!
+如果你有任何建议、想法、或打算分享与时间序列相关的论文, 欢迎加入我们!
+PyPOTS社区是一个开放、透明、友好的社区, 让我们共同努力建设并改进PyPOTS!
[//]: # (Use APA reference style below)
diff --git a/docs/install.rst b/docs/install.rst
index 8e682807..6032efd6 100644
--- a/docs/install.rst
+++ b/docs/install.rst
@@ -14,8 +14,8 @@ It is recommended to use **pip** or **conda** for PyPOTS installation as shown b
pip install https://github.com/WenjieDu/PyPOTS/archive/main.zip
# via conda
- conda install -c conda-forge pypots # the first time installation
- conda update -c conda-forge pypots # update pypots to the latest version
+ conda install conda-forge::pypots # the first time installation
+ conda update conda-forge::pypots # update pypots to the latest version
Required Dependencies
diff --git a/docs/references.bib b/docs/references.bib
index 08222b54..171d63c5 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -765,8 +765,17 @@ @article{bai2018tcn
}
@article{zhan2024tefn,
- title={Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting},
- author={Zhan, Tianxiang and He, Yuanpeng and Li, Zhen and Deng, Yong},
- journal={arXiv preprint arXiv:2405.06419},
- year={2024}
-}
\ No newline at end of file
+title={Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting},
+author={Zhan, Tianxiang and He, Yuanpeng and Li, Zhen and Deng, Yong},
+journal={arXiv preprint arXiv:2405.06419},
+year={2024}
+}
+
+@inproceedings{jin2024timellm,
+title={Time-{LLM}: Time Series Forecasting by Reprogramming Large Language Models},
+author={Ming Jin and Shiyu Wang and Lintao Ma and Zhixuan Chu and James Y. Zhang and Xiaoming Shi and Pin-Yu Chen and Yuxuan Liang and Yuan-Fang Li and Shirui Pan and Qingsong Wen},
+booktitle={The Twelfth International Conference on Learning Representations},
+year={2024},
+url={https://openreview.net/forum?id=Unb5CVPtae}
+}
+
diff --git a/pypots/classification/base.py b/pypots/classification/base.py
index e1848602..75a3a3bb 100644
--- a/pypots/classification/base.py
+++ b/pypots/classification/base.py
@@ -347,8 +347,8 @@ def _train_model(
# save the model if necessary
self._auto_save_model_if_necessary(
- confirm_saving=self.best_epoch == epoch,
- saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
+ saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)
if os.getenv("enable_tuning", False):
diff --git a/pypots/classification/brits/model.py b/pypots/classification/brits/model.py
index 85ddd798..e4719f05 100644
--- a/pypots/classification/brits/model.py
+++ b/pypots/classification/brits/model.py
@@ -228,7 +228,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/classification/grud/model.py b/pypots/classification/grud/model.py
index 5fb84671..a8b1ed50 100644
--- a/pypots/classification/grud/model.py
+++ b/pypots/classification/grud/model.py
@@ -204,7 +204,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/classification/raindrop/model.py b/pypots/classification/raindrop/model.py
index f599b204..aafac455 100644
--- a/pypots/classification/raindrop/model.py
+++ b/pypots/classification/raindrop/model.py
@@ -248,7 +248,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py
index f1838af3..39a18bdc 100644
--- a/pypots/clustering/crli/model.py
+++ b/pypots/clustering/crli/model.py
@@ -295,7 +295,7 @@ def _train_model(
# save the model if necessary
self._auto_save_model_if_necessary(
- confirm_saving=self.best_epoch == epoch,
+ confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)
@@ -354,7 +354,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py
index 0a6e6418..8e14b93f 100644
--- a/pypots/clustering/vader/model.py
+++ b/pypots/clustering/vader/model.py
@@ -303,8 +303,8 @@ def _train_model(
# save the model if necessary
self._auto_save_model_if_necessary(
- confirm_saving=self.best_epoch == epoch,
- saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
+ saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)
if os.getenv("enable_tuning", False):
@@ -367,7 +367,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py
index 5113876d..0931c791 100644
--- a/pypots/forecasting/base.py
+++ b/pypots/forecasting/base.py
@@ -346,8 +346,8 @@ def _train_model(
# save the model if necessary
self._auto_save_model_if_necessary(
- confirm_saving=self.best_epoch == epoch,
- saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
+ saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)
if os.getenv("enable_tuning", False):
diff --git a/pypots/forecasting/csdi/model.py b/pypots/forecasting/csdi/model.py
index 8492f87b..ea7d5856 100644
--- a/pypots/forecasting/csdi/model.py
+++ b/pypots/forecasting/csdi/model.py
@@ -307,8 +307,8 @@ def _train_model(
# save the model if necessary
self._auto_save_model_if_necessary(
- confirm_saving=self.best_epoch == epoch,
- saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
+ saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)
if os.getenv("enable_tuning", False):
@@ -379,7 +379,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/autoformer/model.py b/pypots/imputation/autoformer/model.py
index 38a044e5..01213d63 100644
--- a/pypots/imputation/autoformer/model.py
+++ b/pypots/imputation/autoformer/model.py
@@ -234,7 +234,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py
index 1a20dc72..0f43bc25 100644
--- a/pypots/imputation/base.py
+++ b/pypots/imputation/base.py
@@ -346,8 +346,8 @@ def _train_model(
# save the model if necessary
self._auto_save_model_if_necessary(
- confirm_saving=self.best_epoch == epoch,
- saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
+ saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)
if os.getenv("enable_tuning", False):
diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py
index 06ec6f4e..6311e321 100644
--- a/pypots/imputation/brits/model.py
+++ b/pypots/imputation/brits/model.py
@@ -216,7 +216,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/crossformer/model.py b/pypots/imputation/crossformer/model.py
index 5e8c3016..3b37c849 100644
--- a/pypots/imputation/crossformer/model.py
+++ b/pypots/imputation/crossformer/model.py
@@ -240,7 +240,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py
index 19c3ecfd..b6d22c7f 100644
--- a/pypots/imputation/csdi/model.py
+++ b/pypots/imputation/csdi/model.py
@@ -287,8 +287,8 @@ def _train_model(
# save the model if necessary
self._auto_save_model_if_necessary(
- confirm_saving=self.best_epoch == epoch,
- saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
+ saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)
if os.getenv("enable_tuning", False):
@@ -363,7 +363,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/dlinear/model.py b/pypots/imputation/dlinear/model.py
index ea65df87..28809888 100644
--- a/pypots/imputation/dlinear/model.py
+++ b/pypots/imputation/dlinear/model.py
@@ -211,7 +211,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/etsformer/model.py b/pypots/imputation/etsformer/model.py
index 7ecb0c03..3851c741 100644
--- a/pypots/imputation/etsformer/model.py
+++ b/pypots/imputation/etsformer/model.py
@@ -234,7 +234,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/fedformer/model.py b/pypots/imputation/fedformer/model.py
index 05d8e7cd..4a07104e 100644
--- a/pypots/imputation/fedformer/model.py
+++ b/pypots/imputation/fedformer/model.py
@@ -248,7 +248,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/film/model.py b/pypots/imputation/film/model.py
index 1f505e64..389f527c 100644
--- a/pypots/imputation/film/model.py
+++ b/pypots/imputation/film/model.py
@@ -228,7 +228,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/frets/model.py b/pypots/imputation/frets/model.py
index 0fc730b7..5ff772b8 100644
--- a/pypots/imputation/frets/model.py
+++ b/pypots/imputation/frets/model.py
@@ -210,7 +210,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py
index f8ff2193..28e73e61 100644
--- a/pypots/imputation/gpvae/model.py
+++ b/pypots/imputation/gpvae/model.py
@@ -312,8 +312,8 @@ def _train_model(
# save the model if necessary
self._auto_save_model_if_necessary(
- confirm_saving=self.best_epoch == epoch,
- saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
+ saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)
if os.getenv("enable_tuning", False):
@@ -377,7 +377,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/grud/model.py b/pypots/imputation/grud/model.py
index 269888d0..391ad7b3 100644
--- a/pypots/imputation/grud/model.py
+++ b/pypots/imputation/grud/model.py
@@ -201,7 +201,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/imputeformer/model.py b/pypots/imputation/imputeformer/model.py
index 92daf873..1b313686 100644
--- a/pypots/imputation/imputeformer/model.py
+++ b/pypots/imputation/imputeformer/model.py
@@ -253,7 +253,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/informer/model.py b/pypots/imputation/informer/model.py
index 07788534..26a23c50 100644
--- a/pypots/imputation/informer/model.py
+++ b/pypots/imputation/informer/model.py
@@ -228,7 +228,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/itransformer/model.py b/pypots/imputation/itransformer/model.py
index 46774670..a3c179ec 100644
--- a/pypots/imputation/itransformer/model.py
+++ b/pypots/imputation/itransformer/model.py
@@ -255,7 +255,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/koopa/model.py b/pypots/imputation/koopa/model.py
index 60cbc482..6e5285a8 100644
--- a/pypots/imputation/koopa/model.py
+++ b/pypots/imputation/koopa/model.py
@@ -241,7 +241,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/micn/model.py b/pypots/imputation/micn/model.py
index edfa8d3d..cfa925f4 100644
--- a/pypots/imputation/micn/model.py
+++ b/pypots/imputation/micn/model.py
@@ -222,7 +222,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/moderntcn/model.py b/pypots/imputation/moderntcn/model.py
index e408f5eb..5ba2790b 100644
--- a/pypots/imputation/moderntcn/model.py
+++ b/pypots/imputation/moderntcn/model.py
@@ -252,7 +252,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py
index 40f8dcac..bcd63861 100644
--- a/pypots/imputation/mrnn/model.py
+++ b/pypots/imputation/mrnn/model.py
@@ -218,7 +218,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/nonstationary_transformer/model.py b/pypots/imputation/nonstationary_transformer/model.py
index 814cff3d..1866a96b 100644
--- a/pypots/imputation/nonstationary_transformer/model.py
+++ b/pypots/imputation/nonstationary_transformer/model.py
@@ -244,7 +244,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/patchtst/model.py b/pypots/imputation/patchtst/model.py
index 81d09fc7..f61d5328 100644
--- a/pypots/imputation/patchtst/model.py
+++ b/pypots/imputation/patchtst/model.py
@@ -264,7 +264,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/pyraformer/model.py b/pypots/imputation/pyraformer/model.py
index 576e7c87..48dbb26f 100644
--- a/pypots/imputation/pyraformer/model.py
+++ b/pypots/imputation/pyraformer/model.py
@@ -240,7 +240,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/reformer/model.py b/pypots/imputation/reformer/model.py
index 76b23cb4..5b795d26 100644
--- a/pypots/imputation/reformer/model.py
+++ b/pypots/imputation/reformer/model.py
@@ -241,7 +241,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/revinscinet/model.py b/pypots/imputation/revinscinet/model.py
index 20a78807..4612bb73 100644
--- a/pypots/imputation/revinscinet/model.py
+++ b/pypots/imputation/revinscinet/model.py
@@ -246,7 +246,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py
index cecb3cbe..930a4f84 100644
--- a/pypots/imputation/saits/model.py
+++ b/pypots/imputation/saits/model.py
@@ -268,7 +268,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/scinet/model.py b/pypots/imputation/scinet/model.py
index 86caceb8..515dfc2e 100644
--- a/pypots/imputation/scinet/model.py
+++ b/pypots/imputation/scinet/model.py
@@ -248,7 +248,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/stemgnn/model.py b/pypots/imputation/stemgnn/model.py
index 743ed3d5..1bc2ef5f 100644
--- a/pypots/imputation/stemgnn/model.py
+++ b/pypots/imputation/stemgnn/model.py
@@ -222,7 +222,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/tcn/model.py b/pypots/imputation/tcn/model.py
index 8c01981f..9c83a37d 100644
--- a/pypots/imputation/tcn/model.py
+++ b/pypots/imputation/tcn/model.py
@@ -216,7 +216,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/tefn/model.py b/pypots/imputation/tefn/model.py
index ff30eca5..b4c082e4 100644
--- a/pypots/imputation/tefn/model.py
+++ b/pypots/imputation/tefn/model.py
@@ -188,7 +188,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/tide/model.py b/pypots/imputation/tide/model.py
index 949b15fe..e5d644c3 100644
--- a/pypots/imputation/tide/model.py
+++ b/pypots/imputation/tide/model.py
@@ -228,7 +228,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/timemixer/model.py b/pypots/imputation/timemixer/model.py
index 5e274d7f..09a3f74e 100644
--- a/pypots/imputation/timemixer/model.py
+++ b/pypots/imputation/timemixer/model.py
@@ -253,7 +253,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/timesnet/model.py b/pypots/imputation/timesnet/model.py
index e3029e93..5ac1aec6 100644
--- a/pypots/imputation/timesnet/model.py
+++ b/pypots/imputation/timesnet/model.py
@@ -224,7 +224,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py
index 33eefee1..ccc7cf27 100644
--- a/pypots/imputation/transformer/model.py
+++ b/pypots/imputation/transformer/model.py
@@ -256,7 +256,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,
diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py
index e329fdf0..aadb3703 100644
--- a/pypots/imputation/usgan/model.py
+++ b/pypots/imputation/usgan/model.py
@@ -324,8 +324,8 @@ def _train_model(
# save the model if necessary
self._auto_save_model_if_necessary(
- confirm_saving=self.best_epoch == epoch,
- saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+ confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
+ saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)
if os.getenv("enable_tuning", False):
@@ -389,7 +389,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.
# Step 3: save the model if necessary
- self._auto_save_model_if_necessary(confirm_saving=True)
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
def predict(
self,