From 70fce103dfddce647847d41db46ee7b5ad2fcfbd Mon Sep 17 00:00:00 2001 From: stefan Date: Tue, 3 Dec 2024 16:15:57 +1030 Subject: [PATCH 01/10] added matplotlib to test requirements --- requirements/_tests.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements/_tests.txt b/requirements/_tests.txt index 0a8cb73eaab..26b8e14ff98 100644 --- a/requirements/_tests.txt +++ b/requirements/_tests.txt @@ -20,3 +20,5 @@ cloudpickle >1.3, <=3.1.0 scikit-learn ==1.2.*; python_version < "3.9" scikit-learn ==1.5.*; python_version > "3.8" # we do not use `> =` because of oldest replcement cachier ==3.1.2 + +matplotlib==3.9.3 From fc80fa41eb398bff979375f2a24e67bd2f3adada Mon Sep 17 00:00:00 2001 From: stefan Date: Tue, 3 Dec 2024 16:16:10 +1030 Subject: [PATCH 02/10] added new test for plotting in multilabel classifier --- .../classification/test_confusion_matrix.py | 51 ++++++++++++------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 4d27dfc2069..40716724f99 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -243,28 +243,28 @@ def test_multiclass_confusion_matrix_dtype_gpu(self, inputs, dtype): ("preds", "target", "ignore_index", "error_message"), [ ( - torch.randint(NUM_CLASSES + 1, (100,)), - torch.randint(NUM_CLASSES, (100,)), - None, - f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES}.*", + torch.randint(NUM_CLASSES + 1, (100,)), + torch.randint(NUM_CLASSES, (100,)), + None, + f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES}.*", ), ( - torch.randint(NUM_CLASSES, (100,)), - torch.randint(NUM_CLASSES + 1, (100,)), - None, - f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES}.*", + torch.randint(NUM_CLASSES, (100,)), + torch.randint(NUM_CLASSES + 1, (100,)), + None, + f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES}.*", ), ( - torch.randint(NUM_CLASSES + 2, (100,)), - torch.randint(NUM_CLASSES, (100,)), - 1, - f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES + 1}.*", + torch.randint(NUM_CLASSES + 2, (100,)), + torch.randint(NUM_CLASSES, (100,)), + 1, + f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES + 1}.*", ), ( - torch.randint(NUM_CLASSES, (100,)), - torch.randint(NUM_CLASSES + 2, (100,)), - 1, - f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES + 1}.*", + torch.randint(NUM_CLASSES, (100,)), + torch.randint(NUM_CLASSES + 2, (100,)), + 1, + f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES + 1}.*", ), ], ) @@ -369,7 +369,7 @@ def test_multilabel_confusion_matrix_dtype_cpu(self, inputs, dtype): if (preds < 0).any() and dtype == torch.half: pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") - self.run_precision_test_cpu( + self.run_plot_test_cpu( preds=preds, target=target, metric_module=MultilabelConfusionMatrix, @@ -392,6 +392,19 @@ def test_multilabel_confusion_matrix_dtype_gpu(self, inputs, dtype): dtype=dtype, ) + def test_multilabel_confusion_matrix_plot(self, inputs): + """Test multilabel cm plots.""" + multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=2) + preds = target = torch.ones(1, 2).int() + multi_label_confusion_matrix.update(preds, target) + multi_label_confusion_matrix.plot() + + multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=NUM_CLASSES) + preds = target = torch.ones(1, NUM_CLASSES).int() + print(preds.shape, target.shape) + multi_label_confusion_matrix.update(preds, target) + multi_label_confusion_matrix.plot() + def test_warning_on_nan(): """Test that a warning is given if division by zero happens during normalization of confusion matrix.""" @@ -399,8 +412,8 @@ def test_warning_on_nan(): target = torch.randint(3, size=(20,)) with pytest.warns( - UserWarning, - match=".* NaN values found in confusion matrix have been replaced with zeros.", + UserWarning, + match=".* NaN values found in confusion matrix have been replaced with zeros.", ): multiclass_confusion_matrix(preds, target, num_classes=5, normalize="true") From 945f8f90d88065d2f27a300a1ef9dc8663d585dc Mon Sep 17 00:00:00 2001 From: stefan Date: Tue, 3 Dec 2024 16:16:17 +1030 Subject: [PATCH 03/10] adde bugfix --- src/torchmetrics/utilities/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 4d14349b7f9..2bd9ec819e0 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -270,7 +270,7 @@ def plot_confusion_matrix( fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax) axs = trim_axs(axs, nb) for i in range(nb): - ax = axs[i] if rows != 1 and cols != 1 else axs + ax = axs[i] if rows != 1 or cols != 1 else axs if fig_label is not None: ax.set_title(f"Label {fig_label[i]}", fontsize=15) im = ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach(), cmap=cmap) From 64e99ee36069c60b3c28214ef41bc575c94abc69 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 05:55:37 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../classification/test_confusion_matrix.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 40716724f99..775c3641a2e 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -243,28 +243,28 @@ def test_multiclass_confusion_matrix_dtype_gpu(self, inputs, dtype): ("preds", "target", "ignore_index", "error_message"), [ ( - torch.randint(NUM_CLASSES + 1, (100,)), - torch.randint(NUM_CLASSES, (100,)), - None, - f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES}.*", + torch.randint(NUM_CLASSES + 1, (100,)), + torch.randint(NUM_CLASSES, (100,)), + None, + f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES}.*", ), ( - torch.randint(NUM_CLASSES, (100,)), - torch.randint(NUM_CLASSES + 1, (100,)), - None, - f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES}.*", + torch.randint(NUM_CLASSES, (100,)), + torch.randint(NUM_CLASSES + 1, (100,)), + None, + f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES}.*", ), ( - torch.randint(NUM_CLASSES + 2, (100,)), - torch.randint(NUM_CLASSES, (100,)), - 1, - f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES + 1}.*", + torch.randint(NUM_CLASSES + 2, (100,)), + torch.randint(NUM_CLASSES, (100,)), + 1, + f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES + 1}.*", ), ( - torch.randint(NUM_CLASSES, (100,)), - torch.randint(NUM_CLASSES + 2, (100,)), - 1, - f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES + 1}.*", + torch.randint(NUM_CLASSES, (100,)), + torch.randint(NUM_CLASSES + 2, (100,)), + 1, + f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES + 1}.*", ), ], ) @@ -412,8 +412,8 @@ def test_warning_on_nan(): target = torch.randint(3, size=(20,)) with pytest.warns( - UserWarning, - match=".* NaN values found in confusion matrix have been replaced with zeros.", + UserWarning, + match=".* NaN values found in confusion matrix have been replaced with zeros.", ): multiclass_confusion_matrix(preds, target, num_classes=5, normalize="true") From 4d1d3c6ecd4373a033c2e9b06ab1105d81bd8831 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 3 Dec 2024 14:39:28 +0100 Subject: [PATCH 05/10] Update src/torchmetrics/utilities/plot.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/torchmetrics/utilities/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 2bd9ec819e0..d5f8f373c7b 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -270,7 +270,7 @@ def plot_confusion_matrix( fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax) axs = trim_axs(axs, nb) for i in range(nb): - ax = axs[i] if rows != 1 or cols != 1 else axs + ax = axs[i] if (rows != 1 or cols != 1) else axs if fig_label is not None: ax.set_title(f"Label {fig_label[i]}", fontsize=15) im = ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach(), cmap=cmap) From b62e379ecd52438a1753554434f558dca84a3cb8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 3 Dec 2024 14:44:54 +0100 Subject: [PATCH 06/10] fix errors --- requirements/_tests.txt | 2 -- .../unittests/classification/test_confusion_matrix.py | 11 +++++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/requirements/_tests.txt b/requirements/_tests.txt index 26b8e14ff98..0a8cb73eaab 100644 --- a/requirements/_tests.txt +++ b/requirements/_tests.txt @@ -20,5 +20,3 @@ cloudpickle >1.3, <=3.1.0 scikit-learn ==1.2.*; python_version < "3.9" scikit-learn ==1.5.*; python_version > "3.8" # we do not use `> =` because of oldest replcement cachier ==3.1.2 - -matplotlib==3.9.3 diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 775c3641a2e..718fdbd37c8 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -369,7 +369,7 @@ def test_multilabel_confusion_matrix_dtype_cpu(self, inputs, dtype): if (preds < 0).any() and dtype == torch.half: pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") - self.run_plot_test_cpu( + self.run_precision_test_cpu( preds=preds, target=target, metric_module=MultilabelConfusionMatrix, @@ -397,13 +397,16 @@ def test_multilabel_confusion_matrix_plot(self, inputs): multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=2) preds = target = torch.ones(1, 2).int() multi_label_confusion_matrix.update(preds, target) - multi_label_confusion_matrix.plot() + fig, ax = multi_label_confusion_matrix.plot() + assert fig is not None + assert ax is not None multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=NUM_CLASSES) preds = target = torch.ones(1, NUM_CLASSES).int() - print(preds.shape, target.shape) multi_label_confusion_matrix.update(preds, target) - multi_label_confusion_matrix.plot() + fig, ax = multi_label_confusion_matrix.plot() + assert fig is not None + assert ax is not None def test_warning_on_nan(): From 3550c919dcc2bbc8cb79ee900a239e62cf239a0a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 3 Dec 2024 14:45:49 +0100 Subject: [PATCH 07/10] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a253f90f6ef..6898fb2e847 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed plotting of multilabel confusion matrix ([#2858](https://github.com/PyTorchLightning/metrics/pull/2858)) --- From 30bab9af5263f6a15fb55e9a4de7ee7662d6db8b Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 17 Dec 2024 20:36:25 +0900 Subject: [PATCH 08/10] Apply suggestions from code review --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e6821af0b5..b3402b4a064 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed plotting of multilabel confusion matrix ([#2858](https://github.com/PyTorchLightning/metrics/pull/2858)) + + - Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848)) From 9dad49b7cc657c43b1a1214acaabf9ee3a5286d0 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 17 Dec 2024 20:38:51 +0900 Subject: [PATCH 09/10] test --- .../classification/test_confusion_matrix.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 718fdbd37c8..d5912955b63 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -392,17 +392,11 @@ def test_multilabel_confusion_matrix_dtype_gpu(self, inputs, dtype): dtype=dtype, ) - def test_multilabel_confusion_matrix_plot(self, inputs): + @pytest.mark.parametrize("num_labels", [2, NUM_CLASSES]) + def test_multilabel_confusion_matrix_plot(self, num_labels): """Test multilabel cm plots.""" - multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=2) - preds = target = torch.ones(1, 2).int() - multi_label_confusion_matrix.update(preds, target) - fig, ax = multi_label_confusion_matrix.plot() - assert fig is not None - assert ax is not None - - multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=NUM_CLASSES) - preds = target = torch.ones(1, NUM_CLASSES).int() + multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=num_labels) + preds = target = torch.ones(1, num_labels).int() multi_label_confusion_matrix.update(preds, target) fig, ax = multi_label_confusion_matrix.plot() assert fig is not None From c7b5c760575192c4ebbe1f6c0dd4907164ba5854 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Sat, 21 Dec 2024 20:02:08 +0900 Subject: [PATCH 10/10] inputs --- tests/unittests/classification/test_confusion_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 42b41df76d4..7d7a5f28cb0 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -394,7 +394,7 @@ def test_multilabel_confusion_matrix_dtype_gpu(self, inputs, dtype): ) @pytest.mark.parametrize("num_labels", [2, NUM_CLASSES]) - def test_multilabel_confusion_matrix_plot(self, num_labels): + def test_multilabel_confusion_matrix_plot(self, num_labels, inputs): """Test multilabel cm plots.""" multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=num_labels) preds = target = torch.ones(1, num_labels).int()