Reference, fix readme and add [all] install option (#435)
* grouping in Chlog

* update readme

* rc0

Co-authored-by: Jirka <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Aug 9, 2021
1 parent 05faaf6 commit 292233c
Showing 4 changed files with 32 additions and 10 deletions.
31 changes: 25 additions & 6 deletions README.md
@@ -72,6 +72,7 @@ Extra dependencies for specialized metrics:
```bash
pip install torchmetrics[image]
pip install torchmetrics[text]
pip install torchmetrics[all] # install all of the above
```

</details>
@@ -116,11 +117,15 @@ import torchmetrics
# initialize metric
metric = torchmetrics.Accuracy()

# move the metric to device you want computations to take place
device = "cuda" if torch.cuda.is_available() else "cpu"
metric.to(device)

n_batches = 10
for i in range(n_batches):
# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))
preds = torch.randn(10, 5).softmax(dim=-1).to(device)
target = torch.randint(5, (10,)).to(device)

# metric on current batch
acc = metric(preds, target)
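The diff truncates the rest of this README example; as a point of reference, here is a minimal, self-contained sketch of how the pattern typically completes under the 0.5-era API shown in the hunk (per-batch forward calls, then `compute()` for the epoch-level value and `reset()` before reuse):

```python
import torch
import torchmetrics

# sketch only: mirrors the pattern above, not the exact README continuation
metric = torchmetrics.Accuracy()  # 0.5-era API; newer releases also require a `task` argument
device = "cuda" if torch.cuda.is_available() else "cpu"
metric.to(device)

for i in range(10):
    preds = torch.randn(10, 5).softmax(dim=-1).to(device)
    target = torch.randint(5, (10,)).to(device)
    batch_acc = metric(preds, target)  # forward call updates state and returns the batch accuracy

epoch_acc = metric.compute()  # accuracy accumulated over all batches seen so far
metric.reset()                # clear internal state before the next epoch
```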
@@ -262,6 +267,7 @@ We currently have implemented metrics within the following domains:
[SI_SDR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-sdr),
[SI_SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-snr),
[SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#snr)
and [1 more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#audio-metrics)
)
- Classification (
[Accuracy](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#accuracy),
@@ -291,6 +297,7 @@ We currently have implemented metrics within the following domains:
[BleuScore](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#bleuscore),
[RougeScore](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#rougescore),
[WER](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#wer)
and [1 more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#text)
)

In total, torchmetrics contains 50+ metrics!
@@ -308,11 +315,23 @@ to get help becoming a contributor!

For help or questions, join our huge community on [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)!

## Citations
## Citation

We’re excited to continue the strong legacy of open source software and have been inspired
over the years by Caffe, Theano, Keras, PyTorch, torchbearer, ignite, sklearn and fast.ai.

We’re excited to continue the strong legacy of open source software and have been inspired over the years by
Caffe, Theano, Keras, PyTorch, torchbearer, ignite, sklearn and fast.ai. When/if a paper is written about this,
we’ll be happy to cite these frameworks and the corresponding authors.
If you want to cite this framework feel free to use this (but only if you loved it 😊):

```bibtex
@misc{torchmetrics,
  author = {PyTorchLightning Team},
  title = {Torchmetrics: Machine learning metrics for distributed, scalable PyTorch applications},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/PyTorchLightning/metrics}},
}
```

## License

3 changes: 3 additions & 0 deletions setup.py
@@ -29,6 +29,9 @@ def _prepare_extras():
"image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="image.txt"),
"text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="text.txt"),
}
# create an 'all' keyword that installs all possible dependencies
extras["all"] = [package for extra in extras.values() for package in extra]

return extras
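
For context on how the new `all` extra is consumed: an `extras` mapping like the one built here is typically handed to setuptools via `extras_require`. A minimal sketch with hypothetical package names (not the project's actual `setup()` call):

```python
from setuptools import setup

# hypothetical extras; the real values are loaded from requirements/*.txt
extras = {
    "image": ["torch-fidelity", "scipy"],
    "text": ["nltk", "jiwer"],
}
# same flattening as in the diff above: every package from every extra ends up in 'all'
extras["all"] = [package for extra in extras.values() for package in extra]

setup(
    name="example-package",   # placeholder name
    version="0.0.1",
    extras_require=extras,    # enables `pip install example-package[all]`
)
```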


2 changes: 1 addition & 1 deletion torchmetrics/__about__.py
@@ -1,4 +1,4 @@
__version__ = "0.5.0dev"
__version__ = "0.5.0rc0"
__author__ = "PyTorchLightning et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
6 changes: 3 additions & 3 deletions torchmetrics/utilities/imports.py
@@ -71,11 +71,11 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
_TORCH_LOWER_1_6: Optional[bool] = _compare_version("torch", operator.lt, "1.6.0")
_TORCH_GREATER_EQUAL_1_6: Optional[bool] = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7: Optional[bool] = _compare_version("torch", operator.ge, "1.7.0")
_LIGHTNING_AVAILABLE: bool = _module_available("pytorch_lightning")

_LIGHTNING_AVAILABLE: bool = _module_available("pytorch_lightning")
_JIWER_AVAILABLE: bool = _module_available("jiwer")
_NLTK_AVAILABLE: bool = _module_available("nltk")
_ROUGE_SCORE_AVAILABLE: bool = _module_available("rouge_score")
_BERTSCORE_AVAILABLE: bool = _module_available("bert_score")
_NLTK_AVAILABLE = _module_available("nltk")
_ROUGE_SCORE_AVAILABLE = _module_available("rouge_score")
_SCIPY_AVAILABLE: bool = _module_available("scipy")
_TORCH_FIDELITY_AVAILABLE: bool = _module_available("torch_fidelity")
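
The `_*_AVAILABLE` flags above follow a common pattern: probe for an optional dependency once at import time and gate optional code paths on the result. A minimal sketch of how such a helper is often written (the project's actual `_module_available` may differ in detail):

```python
from importlib.util import find_spec


def module_available(module_name: str) -> bool:
    """Return True if the given top-level module can be located in the current environment."""
    try:
        return find_spec(module_name) is not None
    except (ImportError, ModuleNotFoundError):
        return False


_NLTK_AVAILABLE: bool = module_available("nltk")

if _NLTK_AVAILABLE:
    import nltk  # only imported when the optional dependency is installed
```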
