diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index f8d7a2ffe3a23..3961586f4946a 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -20,8 +20,8 @@ import torch from torch import Tensor +from torchmetrics import Metric -from pytorch_lightning.metrics import Metric from pytorch_lightning.utilities.distributed import sync_ddp_if_available diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 343e979dd3e0c..367c9b029d841 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.accuracy import _accuracy_compute, _accuracy_update -from pytorch_lightning.metrics.metric import Metric class Accuracy(Metric): diff --git a/pytorch_lightning/metrics/classification/auc.py b/pytorch_lightning/metrics/classification/auc.py index 6c5a29173d20a..76c1959a8603a 100644 --- a/pytorch_lightning/metrics/classification/auc.py +++ b/pytorch_lightning/metrics/classification/auc.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.auc import _auc_compute, _auc_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/auroc.py b/pytorch_lightning/metrics/classification/auroc.py index 6b9b5ae9f021f..7d8ba7368e45d 100644 --- a/pytorch_lightning/metrics/classification/auroc.py +++ b/pytorch_lightning/metrics/classification/auroc.py @@ -15,9 +15,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index f9c7bde158383..adcdd86ed1ca8 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -14,9 +14,9 @@ from typing import Any, List, Optional, Union import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.average_precision import _average_precision_compute, _average_precision_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py index c3defc82bc92d..112fb4940e6e2 100644 --- a/pytorch_lightning/metrics/classification/confusion_matrix.py +++ b/pytorch_lightning/metrics/classification/confusion_matrix.py @@ -14,9 +14,9 @@ from typing import Any, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update -from pytorch_lightning.metrics.metric import Metric class ConfusionMatrix(Metric): diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index ae01b80966868..a46b01a1aa8b7 100644 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -14,9 +14,9 @@ from typing import Any, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index adf1086f3c85f..dceb90c0a4ca9 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_compute, _hamming_distance_update -from pytorch_lightning.metrics.metric import Metric class HammingDistance(Metric): diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index 5a02a99ed17fd..ccf821d829d78 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -14,12 +14,12 @@ from typing import Any, List, Optional, Tuple, Union import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.precision_recall_curve import ( _precision_recall_curve_compute, _precision_recall_curve_update, ) -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 598646cde3861..30ca0b4fe6925 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -14,9 +14,9 @@ from typing import Any, List, Optional, Tuple, Union import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.roc import _roc_compute, _roc_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index 3807d7079b508..672b0f41c6fc5 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional, Tuple import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_compute, _stat_scores_update -from pytorch_lightning.metrics.metric import Metric class StatScores(Metric): diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py index 5961714209d40..975b8280f77d5 100644 --- a/pytorch_lightning/metrics/compositional.py +++ b/pytorch_lightning/metrics/compositional.py @@ -21,10 +21,9 @@ class CompositionalMetric(__CompositionalMetric): - r""" - This implementation refers to :class:`~torchmetrics.metric.CompositionalMetric`. - - .. warning:: This metric is deprecated, use ``torchmetrics.metric.CompositionalMetric``. Will be removed in v1.5.0. + """ + .. deprecated:: + Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0. """ def __init__( @@ -34,7 +33,7 @@ def __init__( metric_b: Union[Metric, int, float, torch.Tensor, None], ): rank_zero_warn( - "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`." - " It will be removed in v1.5.0", DeprecationWarning + "This `CompositionalMetric` was deprecated since v1.3.0 in favor of" + " `torchmetrics.metric.CompositionalMetric`. It will be removed in v1.5.0", DeprecationWarning ) super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 145a13a251250..918c92049846e 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -13,17 +13,17 @@ # limitations under the License. from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from torchmetrics import Metric as __Metric -from torchmetrics import MetricCollection as __MetricCollection +from torchmetrics import Metric as _Metric +from torchmetrics.collections import MetricCollection as _MetricCollection +from pytorch_lightning.utilities.deprecation import deprecated from pytorch_lightning.utilities.distributed import rank_zero_warn -class Metric(__Metric): +class Metric(_Metric): r""" - This implementation refers to :class:`~torchmetrics.Metric`. - - .. warning:: This metric is deprecated, use ``torchmetrics.Metric``. Will be removed in v1.5.0. + .. deprecated:: + Use :class:`torchmetrics.Metric`. Will be removed in v1.5.0. """ def __init__( @@ -45,16 +45,12 @@ def __init__( ) -class MetricCollection(__MetricCollection): - r""" - This implementation refers to :class:`~torchmetrics.MetricCollection`. - - .. warning:: This metric is deprecated, use ``torchmetrics.MetricCollection``. Will be removed in v1.5.0. +class MetricCollection(_MetricCollection): + """ + .. deprecated:: + Use :class:`torchmetrics.MetricCollection`. Will be removed in v1.5.0. """ + @deprecated(target=_MetricCollection, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): - rank_zero_warn( - "This `MetricCollection` was deprecated since v1.3.0 in favor of `torchmetrics.MetricCollection`." - " It will be removed in v1.5.0", DeprecationWarning - ) - super().__init__(metrics=metrics) + pass diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index fc033fcd16759..8b0259694ef4c 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.explained_variance import ( _explained_variance_compute, _explained_variance_update, ) -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index ca184daf736b8..484ccbe83284e 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.mean_absolute_error import ( _mean_absolute_error_compute, _mean_absolute_error_update, ) -from pytorch_lightning.metrics.metric import Metric class MeanAbsoluteError(Metric): diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index 09f275ded8638..c26371514e7cd 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.mean_squared_error import ( _mean_squared_error_compute, _mean_squared_error_update, ) -from pytorch_lightning.metrics.metric import Metric class MeanSquaredError(Metric): diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index 18105e687b0b1..caaf09a3663ff 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.mean_squared_log_error import ( _mean_squared_log_error_compute, _mean_squared_log_error_update, ) -from pytorch_lightning.metrics.metric import Metric class MeanSquaredLogError(Metric): diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py index 8a38bf515ebca..746ff1e52d574 100644 --- a/pytorch_lightning/metrics/regression/psnr.py +++ b/pytorch_lightning/metrics/regression/psnr.py @@ -14,10 +14,10 @@ from typing import Any, Optional, Sequence, Tuple, Union import torch +from torchmetrics import Metric from pytorch_lightning import utilities from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update -from pytorch_lightning.metrics.metric import Metric class PSNR(Metric): diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index 40d9d24711375..8156b8bc72d48 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.r2score import _r2score_compute, _r2score_update -from pytorch_lightning.metrics.metric import Metric class R2Score(Metric): diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py index 09b55fb2bb456..a3bbab938ffad 100644 --- a/pytorch_lightning/metrics/regression/ssim.py +++ b/pytorch_lightning/metrics/regression/ssim.py @@ -14,9 +14,9 @@ from typing import Any, Optional, Sequence import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.ssim import _ssim_compute, _ssim_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/retrieval/retrieval_metric.py b/pytorch_lightning/metrics/retrieval/retrieval_metric.py index 29f02555dad69..6f9088d00083c 100644 --- a/pytorch_lightning/metrics/retrieval/retrieval_metric.py +++ b/pytorch_lightning/metrics/retrieval/retrieval_metric.py @@ -2,8 +2,8 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric -from pytorch_lightning.metrics import Metric from pytorch_lightning.metrics.utils import get_group_indexes #: get_group_indexes is used to group predictions belonging to the same query diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 63c6892cb2987..b758e317c6c8d 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -14,38 +14,32 @@ from typing import List, Optional import torch -from torchmetrics.utilities.data import dim_zero_cat as __dim_zero_cat -from torchmetrics.utilities.data import dim_zero_mean as __dim_zero_mean -from torchmetrics.utilities.data import dim_zero_sum as __dim_zero_sum -from torchmetrics.utilities.data import get_num_classes as __get_num_classes -from torchmetrics.utilities.data import select_topk as __select_topk -from torchmetrics.utilities.data import to_categorical as __to_categorical -from torchmetrics.utilities.data import to_onehot as __to_onehot -from torchmetrics.utilities.distributed import class_reduce as __class_reduce -from torchmetrics.utilities.distributed import reduce as __reduce +from torchmetrics.utilities.data import dim_zero_cat as _dim_zero_cat +from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean +from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum +from torchmetrics.utilities.data import get_num_classes as _get_num_classes +from torchmetrics.utilities.data import select_topk as _select_topk +from torchmetrics.utilities.data import to_categorical as _to_categorical +from torchmetrics.utilities.data import to_onehot as _to_onehot +from torchmetrics.utilities.distributed import class_reduce as _class_reduce +from torchmetrics.utilities.distributed import reduce as _reduce -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_cat(x): - rank_zero_warn( - "This `dim_zero_cat` was deprecated since v1.3.0 and it will be removed in v1.5.0", DeprecationWarning - ) - return __dim_zero_cat(x) + pass +@deprecated(target=_dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_sum(x): - rank_zero_warn( - "This `dim_zero_sum` was deprecated since v1.3.0 and it will be removed in v1.5.0", DeprecationWarning - ) - return __dim_zero_sum(x) + pass +@deprecated(target=_dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0") def dim_zero_mean(x): - rank_zero_warn( - "This `dim_zero_mean` was deprecated since v1.3.0 and it will be removed in v1.5.0", DeprecationWarning - ) - return __dim_zero_mean(x) + pass def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]: @@ -77,79 +71,51 @@ def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]: return [torch.tensor(x, dtype=torch.int64) for x in indexes.values()] +@deprecated(target=_to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0") def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: - r""" - .. warning:: This function is deprecated, use ``torchmetrics.utilities.data.to_onehot``. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `to_onehot` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.to_onehot`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __to_onehot(label_tensor=label_tensor, num_classes=num_classes) + .. deprecated:: + Use :func:`torchmetrics.utilities.data.to_onehot`. Will be removed in v1.5.0. + """ +@deprecated(target=_select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0") def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: - r""" - .. warning:: - - This function is deprecated, use ``torchmetrics.utilities.data.select_topk``. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `select_topk` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.select_topk`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __select_topk(prob_tensor=prob_tensor, topk=topk, dim=dim) + .. deprecated:: + Use :func:`torchmetrics.utilities.data.select_topk`. Will be removed in v1.5.0. + """ +@deprecated(target=_to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0") def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: - r""" - .. warning:: - - This function is deprecated, use ``torchmetrics.utilities.data.to_categorical``. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `to_categorical` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.to_categorical`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __to_categorical(tensor=tensor, argmax_dim=argmax_dim) + .. deprecated:: + Use :func:`torchmetrics.utilities.data.to_categorical`. Will be removed in v1.5.0. + """ +@deprecated(target=_get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0") def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: - r""" - .. warning:: - - This function is deprecated, use ``torchmetrics.utilities.data.get_num_classes``. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `get_num_classes` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.get_num_classes`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __get_num_classes(pred=pred, target=target, num_classes=num_classes) + .. deprecated:: + Use :func:`torchmetrics.utilities.data.get_num_classes`. Will be removed in v1.5.0. + """ +@deprecated(target=_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: - r""" - .. warning:: - - This function is deprecated, use ``torchmetrics.utilities.reduce``. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `reduce` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.reduce`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __reduce(to_reduce=to_reduce, reduction=reduction) + .. deprecated:: + Use :func:`torchmetrics.utilities.reduce`. Will be removed in v1.5.0. + """ +@deprecated(target=_class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") def class_reduce( num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" ) -> torch.Tensor: - r""" - .. warning:: - - This function is deprecated, use ``torchmetrics.utilities.class_reduce``. Will be removed in v1.5.0. """ - rank_zero_warn( - "This `class_reduce` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.class_reduce`." - " It will be removed in v1.5.0", DeprecationWarning - ) - return __class_reduce(num=num, denom=denom, weights=weights, class_reduction=class_reduction) + .. deprecated:: + Use :func:`torchmetrics.utilities.class_reduce`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 82f328a927485..554f1d3faf9ed 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -15,8 +15,7 @@ from typing import Any import torch - -from pytorch_lightning.metrics.metric import Metric +from torchmetrics import Metric class MetricsHolder: diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py new file mode 100644 index 0000000000000..3e2034c6a0453 --- /dev/null +++ b/pytorch_lightning/utilities/deprecation.py @@ -0,0 +1,73 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from functools import wraps +from typing import Any, Callable, List, Tuple + +from pytorch_lightning.utilities import rank_zero_warn + + +def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]]: + """Parse function arguments, types and default values + + Example: + >>> get_func_arguments_and_types(get_func_arguments_and_types) + [('func', typing.Callable, )] + """ + func_default_params = inspect.signature(func).parameters + name_type_default = [] + for arg in func_default_params: + arg_type = func_default_params[arg].annotation + arg_default = func_default_params[arg].default + name_type_default.append((arg, arg_type, arg_default)) + return name_type_default + + +def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable: + """ + Decorate a function or class ``__init__`` with warning message + and pass all arguments directly to the target class/method. + """ + + def inner_function(func): + + @wraps(func) + def wrapped_fn(*args, **kwargs): + is_class = inspect.isclass(target) + target_func = target.__init__ if is_class else target + # warn user only once in lifetime + if not getattr(inner_function, 'warned', False): + target_str = f'{target.__module__}.{target.__name__}' + func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__ + rank_zero_warn( + f"The `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." + f" It will be removed in v{ver_remove}.", DeprecationWarning + ) + inner_function.warned = True + + if args: # in case any args passed move them to kwargs + # parse only the argument names + cls_arg_names = [arg[0] for arg in get_func_arguments_and_types(func)] + # convert args to kwargs + kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) + + target_args = [arg[0] for arg in get_func_arguments_and_types(target_func)] + assert all(arg in target_args for arg in kwargs), \ + "Failed mapping, arguments missing in target func: %s" % [arg not in target_args for arg in kwargs] + # all args were already moved to kwargs + return target_func(**kwargs) + + return wrapped_fn + + return inner_function diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 9d31688d9bcc0..0b797dff0e42f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -15,10 +15,10 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp +from torchmetrics import Metric import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result -from pytorch_lightning.metrics import Metric from tests.helpers.runif import RunIf diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index b2fa4f69f74b9..7c8c9ad296416 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -16,6 +16,7 @@ import pytest import torch +from pytorch_lightning.metrics import Accuracy, MetricCollection from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -34,3 +35,14 @@ def test_v1_5_0_metrics_utils(): x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) with pytest.deprecated_call(match="It will be removed in v1.5.0"): assert torch.equal(to_categorical(x), torch.Tensor([1, 0]).to(int)) + + +def test_v1_5_0_metrics_collection(): + target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) + preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) + with pytest.deprecated_call( + match="The `MetricCollection` was deprecated since v1.3.0 in favor" + " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0" + ): + metrics = MetricCollection([Accuracy()]) + assert metrics(preds, target) == {'Accuracy': torch.Tensor([0.1250])[0]} diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index f13c1ebe26d3e..c9e5467414832 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -5,9 +5,10 @@ import pytest import torch from sklearn.metrics import precision_score, recall_score +from torchmetrics import Metric from torchmetrics.classification.checks import _input_format_classification -from pytorch_lightning.metrics import Metric, Precision, Recall +from pytorch_lightning.metrics import Precision, Recall from pytorch_lightning.metrics.functional import precision, precision_recall, recall from tests.metrics.classification.inputs import _input_binary, _input_binary_prob from tests.metrics.classification.inputs import _input_multiclass as _input_mcls diff --git a/tests/metrics/retrieval/test_map.py b/tests/metrics/retrieval/test_map.py index aa6eeb6424a33..fe43f19b20eb6 100644 --- a/tests/metrics/retrieval/test_map.py +++ b/tests/metrics/retrieval/test_map.py @@ -6,9 +6,9 @@ import pytest import torch from sklearn.metrics import average_precision_score as sk_average_precision +from torchmetrics import Metric from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 895305fa9da7e..2e040a881d49f 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -1,7 +1,7 @@ import torch +from torchmetrics import Metric, MetricCollection from pytorch_lightning import Trainer -from pytorch_lightning.metrics import Metric, MetricCollection from tests.helpers.boring_model import BoringModel diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4bd6608ce3fcf..f1f17d0624936 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -8,8 +8,7 @@ import pytest import torch from torch.multiprocessing import Pool, set_start_method - -from pytorch_lightning.metrics import Metric +from torchmetrics import Metric try: set_start_method("spawn") diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py new file mode 100644 index 0000000000000..7c653c07ad168 --- /dev/null +++ b/tests/utilities/test_deprecation.py @@ -0,0 +1,37 @@ +import pytest + +from pytorch_lightning.utilities.deprecation import deprecated +from tests.helpers.utils import no_warning_call + + +def my_sum(a, b=3): + return a + b + + +@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") +def dep_sum(a, b): + pass + + +@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") +def dep2_sum(a, b): + pass + + +def test_deprecated_func(): + with pytest.deprecated_call( + match='The `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' + ' It will be removed in v0.5.' + ): + assert dep_sum(2, b=5) == 7 + + # check that the warning is raised only once per function + with no_warning_call(DeprecationWarning): + assert dep_sum(2, b=5) == 7 + + # and does not affect other functions + with pytest.deprecated_call( + match='The `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' + ' It will be removed in v0.5.' + ): + assert dep2_sum(2) == 5