Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Mar 16, 2021
1 parent 32ca511 commit 7deb8ae
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torchmetrics import Metric as _Metric
from torchmetrics.collections import MetricCollection as _MetricCollection

from pytorch_lightning.utilities.deprecation import _deprecated
from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.utilities.distributed import rank_zero_warn


Expand Down Expand Up @@ -51,6 +51,6 @@ class MetricCollection(_MetricCollection):
Use :class:`torchmetrics.MetricCollection`. Will be removed in v1.5.0.
"""

@_deprecated(target=_MetricCollection, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated(target=_MetricCollection, ver_deprecate="1.3.0", ver_remove="1.5.0")
def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
pass
20 changes: 10 additions & 10 deletions pytorch_lightning/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@
from torchmetrics.utilities.distributed import class_reduce as __class_reduce
from torchmetrics.utilities.distributed import reduce as __reduce

from pytorch_lightning.utilities.deprecation import _deprecated
from pytorch_lightning.utilities.deprecation import deprecated


@_deprecated(target=__dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated(target=__dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0")
def dim_zero_cat(x):
pass


@_deprecated(target=__dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated(target=__dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0")
def dim_zero_sum(x):
pass


@_deprecated(target=__dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated(target=__dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0")
def dim_zero_mean(x):
pass

Expand Down Expand Up @@ -71,47 +71,47 @@ def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]:
return [torch.tensor(x, dtype=torch.int64) for x in indexes.values()]


@_deprecated(target=__to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated(target=__to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0")
def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.utilities.data.to_onehot`. Will be removed in v1.5.0.
"""


@_deprecated(target=__select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated(target=__select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0")
def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.utilities.data.select_topk`. Will be removed in v1.5.0.
"""


@_deprecated(target=__to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated(target=__to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0")
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.utilities.data.to_categorical`. Will be removed in v1.5.0.
"""


@_deprecated(target=__get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated(target=__get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0")
def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int:
"""
.. deprecated::
Use :func:`torchmetrics.utilities.data.get_num_classes`. Will be removed in v1.5.0.
"""


@_deprecated(target=__reduce, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated(target=__reduce, ver_deprecate="1.3.0", ver_remove="1.5.0")
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
"""
.. deprecated::
Use :func:`torchmetrics.utilities.reduce`. Will be removed in v1.5.0.
"""


@_deprecated(target=__class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated(target=__class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0")
def class_reduce(
num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none"
) -> torch.Tensor:
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/utilities/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]]
return name_type_default


def _deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable:
def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable:
"""
Decorate a function or class ``__init__`` with warning message
and pass all arguments directly to the target class/method.
Expand All @@ -51,7 +51,7 @@ def wrapped_fn(*args, **kwargs):
target_str = f'{target.__module__}.{target.__name__}'
func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__
rank_zero_warn(
f"This `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`."
f"The `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`."
f" It will be removed in v{ver_remove}.", DeprecationWarning
)
inner_function.warned = True
Expand Down
2 changes: 1 addition & 1 deletion tests/deprecated_api/test_remove_1-5_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_v1_5_0_metrics_collection():
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
with pytest.deprecated_call(
match="This `MetricCollection` was deprecated since v1.3.0 in favor"
match="The `MetricCollection` was deprecated since v1.3.0 in favor"
" of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0"
):
metrics = MetricCollection([Accuracy()])
Expand Down
6 changes: 3 additions & 3 deletions tests/utilities/test_deprecation.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import pytest

from pytorch_lightning.utilities.deprecation import _deprecated
from pytorch_lightning.utilities.deprecation import deprecated


def my_sum(a, b=3):
return a + b


@_deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5")
@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5")
def dep_sum(a, b):
pass


@_deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5")
@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5")
def dep2_sum(a, b):
pass

Expand Down

0 comments on commit 7deb8ae

Please sign in to comment.