diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6d4ea306af2..17ecc734857 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,10 +20,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Removed
 
+- Removed `rouge-score` as a dependency for the text package ([#443](https://github.com/PyTorchLightning/metrics/pull/443))
 
 ### Fixed
 
-- Fixed bug in the ranking of samples in `SpearmanCorrCoef` metric ([#448](https://github.com/PyTorchLightning/metrics/pull/448))
+- Fixed ranking of samples in `SpearmanCorrCoef` metric ([#448](https://github.com/PyTorchLightning/metrics/pull/448))
 
 
 ## [0.5.0] - 2021-08-09
diff --git a/requirements/test.txt b/requirements/test.txt
index 960f810cb89..7a1fbe57fdc 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -25,3 +25,6 @@ mir_eval>=0.6
 #pesq @ https://github.com/ludlows/python-pesq/archive/refs/heads/master.zip
 #SRMRpy @ https://github.com/jfsantos/SRMRpy/archive/refs/heads/master.zip
 speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip
+
+# text
+rouge-score>=0.0.4
diff --git a/requirements/text.txt b/requirements/text.txt
index d850ded158a..99408171a0d 100644
--- a/requirements/text.txt
+++ b/requirements/text.txt
@@ -1,4 +1,3 @@
 jiwer>=2.2.0
 nltk>=3.6
-rouge-score>=0.0.4
 bert-score==0.3.10
diff --git a/tests/text/test_rouge.py b/tests/text/test_rouge.py
index 8fe1ce0903b..497ecd8467c 100644
--- a/tests/text/test_rouge.py
+++ b/tests/text/test_rouge.py
@@ -16,7 +16,6 @@
 
 import pytest
 import torch
-from torch import tensor
 
 from torchmetrics.functional.text.rouge import rouge_score
 from torchmetrics.text.rouge import ROUGEScore
@@ -30,16 +29,13 @@
 
 ROUGE_KEYS = ("rouge1", "rouge2", "rougeL", "rougeLsum")
 
-PRECISION = 0
-RECALL = 1
-F_MEASURE = 2
-
 SINGLE_SENTENCE_EXAMPLE_PREDS = "The quick brown fox jumps over the lazy dog"
 SINGLE_SENTENCE_EXAMPLE_TARGET = "The quick brown dog jumps on the log."
 
 PREDS = "My name is John".split()
 TARGETS = "Is your name John".split()
 
+
 BATCHES_RS_PREDS = [SINGLE_SENTENCE_EXAMPLE_PREDS]
 BATCHES_RS_PREDS.extend(PREDS)
 BATCHES_RS_TARGETS = [SINGLE_SENTENCE_EXAMPLE_TARGET]
@@ -55,145 +51,139 @@ def _compute_rouge_score(preds: List[str], targets: List[str], use_stemmer: bool
     scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
     aggregator = BootstrapAggregator()
     for pred, target in zip(preds, targets):
-        aggregator.add_scores(scorer.score(pred, target))
+        aggregator.add_scores(scorer.score(target, pred))
     return aggregator.aggregate()
 
 
-@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
+@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
 @pytest.mark.parametrize(
-    ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
+    ["pl_rouge_metric_key", "use_stemmer"],
     [
-        pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
-        pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
-        pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
-        pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
-        pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
-        pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
-        pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
-        pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
-        pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
-        pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
-        pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
-        pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
+        pytest.param("rouge1_precision", True),
+        pytest.param("rouge1_recall", True),
+        pytest.param("rouge1_fmeasure", False),
+        pytest.param("rouge2_precision", False),
+        pytest.param("rouge2_recall", True),
+        pytest.param("rouge2_fmeasure", True),
+        pytest.param("rougeL_precision", False),
+        pytest.param("rougeL_recall", False),
+        pytest.param("rougeL_fmeasure", True),
+        pytest.param("rougeLsum_precision", True),
+        pytest.param("rougeLsum_recall", False),
+        pytest.param("rougeLsum_fmeasure", False),
     ],
 )
-def test_rouge_metric_functional_single_sentence(
-    pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
-):
-    scorer = RougeScorer(ROUGE_KEYS)
-    rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_PREDS, SINGLE_SENTENCE_EXAMPLE_TARGET)
-    rs_output = round(rs_scores[rouge_score_key][metric], decimal_places)
+def test_rouge_metric_functional_single_sentence(pl_rouge_metric_key, use_stemmer):
+    rouge_level, metric = pl_rouge_metric_key.split("_")
+
+    scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
+    rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_TARGET, SINGLE_SENTENCE_EXAMPLE_PREDS)
+    rs_result = torch.tensor(getattr(rs_scores[rouge_level], metric), dtype=torch.float32)
 
-    pl_output = rouge_score(
-        [SINGLE_SENTENCE_EXAMPLE_PREDS],
-        [SINGLE_SENTENCE_EXAMPLE_TARGET],
-        newline_sep=newline_sep,
-        use_stemmer=use_stemmer,
-        decimal_places=decimal_places,
-    )
+    pl_output = rouge_score([SINGLE_SENTENCE_EXAMPLE_PREDS], [SINGLE_SENTENCE_EXAMPLE_TARGET], use_stemmer=use_stemmer)
 
-    assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
+    assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)
 
 
-@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
+@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
 @pytest.mark.parametrize(
-    ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
+    ["pl_rouge_metric_key", "use_stemmer"],
     [
-        pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
-        pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
-        pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
-        pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
-        pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
-        pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
-        pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
-        pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
-        pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
-        pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
-        pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
-        pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
+        pytest.param("rouge1_precision", True),
+        pytest.param("rouge1_recall", True),
+        pytest.param("rouge1_fmeasure", False),
+        pytest.param("rouge2_precision", False),
+        pytest.param("rouge2_recall", True),
+        pytest.param("rouge2_fmeasure", True),
+        pytest.param("rougeL_precision", False),
+        pytest.param("rougeL_recall", False),
+        pytest.param("rougeL_fmeasure", True),
+        pytest.param("rougeLsum_precision", True),
+        pytest.param("rougeLsum_recall", False),
+        pytest.param("rougeLsum_fmeasure", False),
     ],
 )
-def test_rouge_metric_functional(
-    pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
-):
+def test_rouge_metric_functional(pl_rouge_metric_key, use_stemmer):
+    rouge_level, metric = pl_rouge_metric_key.split("_")
+
     rs_scores = _compute_rouge_score(PREDS, TARGETS, use_stemmer=use_stemmer)
-    rs_output = round(rs_scores[rouge_score_key].mid[metric], decimal_places)
+    rs_result = torch.tensor(getattr(rs_scores[rouge_level].mid, metric), dtype=torch.float32)
 
-    pl_output = rouge_score(
-        PREDS, TARGETS, newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places
-    )
+    pl_output = rouge_score(PREDS, TARGETS, use_stemmer=use_stemmer)
 
-    assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
+    assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)
 
 
-@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
+@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
 @pytest.mark.parametrize(
-    ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
+    ["pl_rouge_metric_key", "use_stemmer"],
     [
-        pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
-        pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
-        pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
-        pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
-        pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
-        pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
-        pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
-        pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
-        pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
-        pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
-        pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
-        pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
+        pytest.param("rouge1_precision", True),
+        pytest.param("rouge1_recall", True),
+        pytest.param("rouge1_fmeasure", False),
+        pytest.param("rouge2_precision", False),
+        pytest.param("rouge2_recall", True),
+        pytest.param("rouge2_fmeasure", True),
+        pytest.param("rougeL_precision", False),
+        pytest.param("rougeL_recall", False),
+        pytest.param("rougeL_fmeasure", True),
+        pytest.param("rougeLsum_precision", True),
+        pytest.param("rougeLsum_recall", False),
+        pytest.param("rougeLsum_fmeasure", False),
     ],
 )
-def test_rouge_metric_class(pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep):
-    scorer = RougeScorer(ROUGE_KEYS)
-    rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_PREDS, SINGLE_SENTENCE_EXAMPLE_TARGET)
-    rs_output = round(rs_scores[rouge_score_key][metric], decimal_places)
+def test_rouge_metric_class(pl_rouge_metric_key, use_stemmer):
+    rouge_level, metric = pl_rouge_metric_key.split("_")
+
+    scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
+    rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_TARGET, SINGLE_SENTENCE_EXAMPLE_PREDS)
+    rs_result = torch.tensor(getattr(rs_scores[rouge_level], metric), dtype=torch.float32)
 
-    rouge = ROUGEScore(newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places)
+    rouge = ROUGEScore(use_stemmer=use_stemmer)
     pl_output = rouge([SINGLE_SENTENCE_EXAMPLE_PREDS], [SINGLE_SENTENCE_EXAMPLE_TARGET])
 
-    assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
+    assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)
 
 
-@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
+@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
 @pytest.mark.parametrize(
-    ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
+    ["pl_rouge_metric_key", "use_stemmer"],
     [
-        pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
-        pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
-        pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
-        pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
-        pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
-        pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
-        pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
-        pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
-        pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
-        pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
-        pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
-        pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
+        pytest.param("rouge1_precision", True),
+        pytest.param("rouge1_recall", True),
+        pytest.param("rouge1_fmeasure", False),
+        pytest.param("rouge2_precision", False),
+        pytest.param("rouge2_recall", True),
+        pytest.param("rouge2_fmeasure", True),
+        pytest.param("rougeL_precision", False),
+        pytest.param("rougeL_recall", False),
+        pytest.param("rougeL_fmeasure", True),
+        pytest.param("rougeLsum_precision", True),
+        pytest.param("rougeLsum_recall", False),
+        pytest.param("rougeLsum_fmeasure", False),
     ],
 )
-def test_rouge_metric_class_batches(
-    pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
-):
+def test_rouge_metric_class_batches(pl_rouge_metric_key, use_stemmer):
+    rouge_level, metric = pl_rouge_metric_key.split("_")
+
     rs_scores = _compute_rouge_score(BATCHES_RS_PREDS, BATCHES_RS_TARGETS, use_stemmer=use_stemmer)
-    rs_output = round(rs_scores[rouge_score_key].mid[metric], decimal_places)
+    rs_result = torch.tensor(getattr(rs_scores[rouge_level].mid, metric), dtype=torch.float32)
 
-    rouge = ROUGEScore(newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places)
+    rouge = ROUGEScore(use_stemmer=use_stemmer)
     for batch in BATCHES:
         rouge.update(batch["preds"], batch["targets"])
     pl_output = rouge.compute()
 
-    assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
+    assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)
 
 
 def test_rouge_metric_raises_errors_and_warnings():
     """Test that expected warnings and errors are raised."""
-    if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE):
+    if not _NLTK_AVAILABLE:
         with pytest.raises(
             ValueError,
-            match="ROUGE metric requires that both nltk and rouge-score is installed."
-            "Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`",
+            match="ROUGE metric requires that nltk is installed."
+            "Either as `pip install torchmetrics[text]` or `pip install nltk`",
         ):
             ROUGEScore()
 
diff --git a/torchmetrics/functional/text/rouge.py b/torchmetrics/functional/text/rouge.py
index 91e61d93e74..688175d0c81 100644
--- a/torchmetrics/functional/text/rouge.py
+++ b/torchmetrics/functional/text/rouge.py
@@ -12,64 +12,150 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
-from typing import Dict, List, Tuple, Union
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-from torch import Tensor, tensor
-
-from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE
-
-if _ROUGE_SCORE_AVAILABLE:
-    from rouge_score.rouge_scorer import RougeScorer
-    from rouge_score.scoring import AggregateScore, BootstrapAggregator
-else:
-    RougeScorer, AggregateScore, BootstrapAggregator = object, object, object
-
-ALLOWED_ROUGE_KEYS = (
-    "rouge1",
-    "rouge2",
-    "rouge3",
-    "rouge4",
-    "rouge5",
-    "rouge6",
-    "rouge7",
-    "rouge8",
-    "rouge9",
-    "rougeL",
-    "rougeLsum",
-)
-
-
-def add_newline_to_end_of_each_sentence(x: str) -> str:
+from torch import Tensor
+
+from torchmetrics.utilities.imports import _NLTK_AVAILABLE
+
+ALLOWED_ROUGE_KEYS: Dict[str, Union[int, str]] = {
+    "rouge1": 1,
+    "rouge2": 2,
+    "rouge3": 3,
+    "rouge4": 4,
+    "rouge5": 5,
+    "rouge6": 6,
+    "rouge7": 7,
+    "rouge8": 8,
+    "rouge9": 9,
+    "rougeL": "L",
+    "rougeLsum": "Lsum",
+}
+
+
+def _add_newline_to_end_of_each_sentence(x: str) -> str:
     """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
-    if _NLTK_AVAILABLE:
-        import nltk
+    if not _NLTK_AVAILABLE:
+        raise ValueError("ROUGE-Lsum calculation requires that nltk is installed. Use `pip install nltk`.")
+    import nltk
 
-        nltk.download("punkt", quiet=True, force=False)
+    nltk.download("punkt", quiet=True, force=False)
 
     re.sub("<n>", "", x)  # remove pegasus newline char
-    assert nltk, "nltk must be installed to separate newlines between sentences. (pip install nltk)"
     return "\n".join(nltk.sent_tokenize(x))
 
 
-def format_rouge_results(result: Dict[str, AggregateScore], decimal_places: int = 4) -> Dict[str, Tensor]:
-    """Formats the computed (aggregated) rouge score to a dictionary of tensors format."""
-    flattened_result = {}
-    for rouge_key, rouge_aggregate_score in result.items():
-        for stat in ["precision", "recall", "fmeasure"]:
-            mid = rouge_aggregate_score.mid
-            score = round(getattr(mid, stat), decimal_places)
-            flattened_result[f"{rouge_key}_{stat}"] = tensor(score, dtype=torch.float)
-    return flattened_result
+def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> Dict[str, float]:
+    """This computes precision, recall and F1 score based on hits/lcs, and the length of lists of tokenizer
+    predicted and target sentences.
+
+    Args:
+        hits_or_lcs:
+            A number of matches or a length of the longest common subsequence.
+        pred_len:
+            A length of a tokenized predicted sentence.
+        target_len:
+            A length of a tokenized target sentence.
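+
+    Example (illustrative sketch; values are hand-computed and rounded, the doctest is skipped):
+        >>> _compute_metrics(hits_or_lcs=2, pred_len=4, target_len=3)  # doctest: +SKIP
+        {'precision': 0.5, 'recall': 0.6666..., 'fmeasure': 0.5714...}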
+    """
+    precision = hits_or_lcs / pred_len
+    recall = hits_or_lcs / target_len
+    if precision == recall == 0.0:
+        return dict(precision=0.0, recall=0.0, fmeasure=0.0)
+
+    fmeasure = 2 * precision * recall / (precision + recall)
+    return dict(precision=precision, recall=recall, fmeasure=fmeasure)
+
+
+def _lcs(pred_tokens: List[str], target_tokens: List[str]) -> int:
+    """Common DP algorithm to compute the length of the longest common subsequence.
+
+    Args:
+        pred_tokens:
+            A tokenized predicted sentence.
+        target_tokens:
+            A tokenized target sentence.
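+
+    Example (illustrative sketch; the doctest is skipped):
+        >>> _lcs("the quick brown fox".split(), "the brown dog".split())  # doctest: +SKIP
+        2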
+    """
+    LCS = [[0] * (len(pred_tokens) + 1) for _ in range(len(target_tokens) + 1)]
+    for i in range(1, len(target_tokens) + 1):
+        for j in range(1, len(pred_tokens) + 1):
+            if target_tokens[i - 1] == pred_tokens[j - 1]:
+                LCS[i][j] = LCS[i - 1][j - 1] + 1
+            else:
+                LCS[i][j] = max(LCS[i - 1][j], LCS[i][j - 1])
+    return LCS[-1][-1]
+
+
+def _normalize_text(text: str, stemmer: Optional[Any] = None) -> str:
+    """Rouge score should be calculated only over lowercased words and digits. Optionally, Porter stemmer can be
+    used to strip word suffixes to improve matching. The text normalization follows the implementation from
+    https://github.com/google-research/google-research/blob/master/rouge/tokenize.py.
+
+    Args:
+        text:
+            An input sentence.
+        stemmer:
+            Porter stemmer instance to strip word suffixes to improve matching.
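+
+    Example (illustrative sketch; no stemmer used, the doctest is skipped):
+        >>> _normalize_text("The quick brown FOX jumps!")  # doctest: +SKIP
+        'the quick brown fox jumps'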
+    """
+    text = re.sub(r"[^a-z0-9]+", " ", text.lower())
+    if stemmer:
+        text = " ".join(stemmer.stem(x) if len(x) > 3 else x for x in text.split())
+    return text.strip()  # to ensure there is no whitespace at the end of the sentence
+
+
+def _rouge_n_score(pred: str, target: str, n_gram: int) -> Dict[str, float]:
+    """This computes precision, recall and F1 score for the Rouge-N metric.
+
+    Args:
+        pred:
+            A predicted sentence.
+        target:
+            A target sentence.
+        n_gram:
+            Size of the N-grams used to compute the overlap.
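+
+    Example (illustrative sketch; values are hand-computed and rounded, the doctest is skipped):
+        >>> _rouge_n_score("the cat sat on the mat", "the cat was on a mat", n_gram=1)  # doctest: +SKIP
+        {'precision': 0.6666..., 'recall': 0.6666..., 'fmeasure': 0.6666...}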
+    """
+    pred_tokenized, target_tokenized = _tokenize(pred, n_gram), _tokenize(target, n_gram)
+    pred_len, target_len = len(pred_tokenized), len(target_tokenized)
+    if 0 in (pred_len, target_len):
+        return dict(precision=0.0, recall=0.0, fmeasure=0.0)
+
+    pred_counter: Dict[str, int] = defaultdict(int)
+    target_counter: Dict[str, int] = defaultdict(int)
+    for w in pred_tokenized:
+        pred_counter[w] += 1
+    for w in target_tokenized:
+        target_counter[w] += 1
+    # It is sufficient to take set(pred_tokenized) for the hits count as we consider the intersection of pred & target
+    hits = sum(min(pred_counter[w], target_counter[w]) for w in set(pred_tokenized))
+    return _compute_metrics(hits, pred_len, target_len)
+
+
+def _rouge_l_score(pred: str, target: str) -> Dict[str, float]:
+    """This computes precision, recall and F1 score for the Rouge-L or Rouge-LSum metric.
+
+    Args:
+        pred:
+            A predicted sentence.
+        target:
+            A target sentence.
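+
+    Example (illustrative sketch; values are hand-computed and rounded, the doctest is skipped):
+        >>> _rouge_l_score("the quick brown fox", "the fox jumped")  # doctest: +SKIP
+        {'precision': 0.5, 'recall': 0.6666..., 'fmeasure': 0.5714...}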
+    """
+    pred_tokenized, target_tokenized = _tokenize(pred, 1), _tokenize(target, 1)
+    pred_len, target_len = len(pred_tokenized), len(target_tokenized)
+    if 0 in (pred_len, target_len):
+        return dict(precision=0.0, recall=0.0, fmeasure=0.0)
+
+    lcs = _lcs(pred_tokenized, target_tokenized)
+    return _compute_metrics(lcs, pred_len, target_len)
 
 
 def _rouge_score_update(
     preds: List[str],
     targets: List[str],
-    scorer: RougeScorer,
-    aggregator: BootstrapAggregator,
-    newline_sep: bool = False,
-) -> None:
+    rouge_keys_values: List[Union[int, str]],
+    results: Optional[Dict[Union[int, str], List[Dict[str, float]]]] = None,
+    stemmer: Optional[Any] = None,
+) -> Dict[Union[int, str], List[Dict[str, float]]]:
     """Update the rouge score with the current set of predicted and target sentences.
 
     Args:
@@ -77,49 +163,69 @@ def _rouge_score_update(
             An iterable of predicted sentences.
         targets:
             An iterable of target sentences.
-        scorer:
-            An instance of the ``RougeScorer`` class from the ``rouge_score`` package.
-        aggregator:
-            An instance of the ``BootstrapAggregator`` from the from the ``rouge_score`` package.
-        newline_sep:
-            New line separate the inputs.
+        rouge_keys_values:
+            List of N-gram sizes and/or the 'L'/'Lsum' variants to compute.
+        results:
+            Dictionary of per-sentence scores accumulated so far; a new one is created if ``None``.
+        stemmer:
+            Porter stemmer instance to strip word suffixes to improve matching.
 
     Example:
         >>> targets = "Is your name John".split()
         >>> preds = "My name is John".split()
-        >>> aggregator = BootstrapAggregator()
-        >>> scorer = RougeScorer(rouge_types=("rouge1", "rouge2", "rougeL", "rougeLsum"), use_stemmer=False)
-        >>> _rouge_score_update(preds, targets, scorer=scorer, aggregator=aggregator, newline_sep=False)
+        >>> from pprint import pprint
+        >>> score = _rouge_score_update(preds, targets, rouge_keys_values=[1, 2, 3, 'L'])
+        >>> pprint(score)  # doctest: +NORMALIZE_WHITESPACE +SKIP
+        {1: [{'fmeasure': 0.0, 'precision': 0.0, 'recall': 0.0},
+             {'fmeasure': 0.0, 'precision': 0.0, 'recall': 0.0},
+             {'fmeasure': 0.0, 'precision': 0.0, 'recall': 0.0},
+             {'fmeasure': 1.0, 'precision': 1.0, 'recall': 1.0}],
+         2: [...], 3: [...], 'L': [...]}
     """
-    for pred, target in zip(preds, targets):
+    results = results if results is not None else {rouge_key: [] for rouge_key in rouge_keys_values}
+    for pred_raw, target_raw in zip(preds, targets):
+        pred, target = _normalize_text(pred_raw, stemmer), _normalize_text(target_raw, stemmer)
         # rougeLsum expects "\n" separated sentences within a summary
-        if newline_sep:
-            pred = add_newline_to_end_of_each_sentence(pred)
-            target = add_newline_to_end_of_each_sentence(target)
-        results = scorer.score(pred, target)
-        aggregator.add_scores(results)
+        if "Lsum" in rouge_keys_values:
+            pred_sum = _normalize_text(_add_newline_to_end_of_each_sentence(pred_raw), stemmer)
+            target_sum = _normalize_text(_add_newline_to_end_of_each_sentence(target_raw), stemmer)
+
+        for rouge_key in rouge_keys_values:
+            if isinstance(rouge_key, int):
+                score = _rouge_n_score(pred, target, rouge_key)
+            else:
+                score = _rouge_l_score(
+                    pred if rouge_key != "Lsum" else pred_sum,
+                    target if rouge_key != "Lsum" else target_sum,
+                )
+            results[rouge_key].append(score)
+    return results
 
 
-def _rouge_score_compute(aggregator: BootstrapAggregator, decimal_places: int = 4) -> Dict[str, Tensor]:
+def _rouge_score_compute(
+    sentence_results: Optional[Dict[Union[int, str], List[Dict[str, float]]]]
+) -> Dict[str, Tensor]:
     """Compute the combined ROUGE metric for all the input set of predicted and target sentences.
 
     Args:
-        aggregator:
-            An instance of the ``BootstrapAggregator`` from the from the ``rouge_score`` package.
-        decimal_places:
-            The number of digits to round the computed the values to.
+        sentence_results:
+            Rouge-N/Rouge-L/Rouge-LSum metrics calculated for each sentence pair.
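+
+    Example (illustrative sketch; per-sentence scores are simply averaged, the doctest is skipped):
+        >>> sentence_results = {1: [dict(precision=0.5, recall=1.0, fmeasure=2 / 3),
+        ...                         dict(precision=0.0, recall=0.0, fmeasure=0.0)]}
+        >>> _rouge_score_compute(sentence_results)  # doctest: +SKIP
+        {'rouge1_precision': tensor(0.2500), 'rouge1_recall': tensor(0.5000), 'rouge1_fmeasure': tensor(0.3333)}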
     """
-    result = aggregator.aggregate()
-    return format_rouge_results(result, decimal_places)
+    results: Dict[str, Tensor] = {}
+    # Obtain mean scores for individual rouge metrics
+    if sentence_results is None:
+        return results
+    for rouge_key, scores in sentence_results.items():
+        res = torch.tensor([(score["precision"], score["recall"], score["fmeasure"]) for score in scores]).mean(0)
+        results[f"rouge{rouge_key}_precision"] = res[0]
+        results[f"rouge{rouge_key}_recall"] = res[1]
+        results[f"rouge{rouge_key}_fmeasure"] = res[2]
+
+    return results
 
 
 def rouge_score(
     preds: Union[str, List[str]],
     targets: Union[str, List[str]],
-    newline_sep: bool = False,
     use_stemmer: bool = False,
     rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),  # type: ignore
-    decimal_places: int = 4,
 ) -> Dict[str, Tensor]:
     """Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_, used for automatic summarization.
 
@@ -128,15 +234,11 @@ def rouge_score(
             An iterable of predicted sentences.
         targets:
             An iterable of target sentences.
-        newline_sep:
-            New line separate the inputs.
         use_stemmer:
             Use Porter stemmer to strip word suffixes to improve matching.
         rouge_keys:
             A list of rouge types to calculate.
             Keys that are allowed are ``rougeL``, ``rougeLsum``, and ``rouge1`` through ``rouge9``.
-        decimal_places:
-            The number of digits to round the computed the values to.
 
     Return:
         Python dictionary of rouge scores for each input rouge key.
@@ -161,7 +263,7 @@ def rouge_score(
 
     Raises:
         ValueError:
-            If the python packages ``nltk`` or ``rouge-score`` are not installed.
+            If the python package ``nltk`` is not installed.
         ValueError:
             If any of the ``rouge_keys`` does not belong to the allowed set of keys.
 
@@ -169,17 +271,19 @@ def rouge_score(
         [1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin https://aclanthology.org/W04-1013/
     """
 
-    if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE):
-        raise ValueError(
-            "ROUGE metric requires that both nltk and rouge-score is installed."
-            " Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`"
-        )
+    if use_stemmer:
+        if not _NLTK_AVAILABLE:
+            raise ValueError("Stemmer requires that nltk is installed. Use `pip install nltk`.")
+        import nltk
+
+    stemmer = nltk.stem.porter.PorterStemmer() if use_stemmer else None
 
     if not isinstance(rouge_keys, tuple):
         rouge_keys = tuple([rouge_keys])
     for key in rouge_keys:
-        if key not in ALLOWED_ROUGE_KEYS:
-            raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {ALLOWED_ROUGE_KEYS}")
+        if key not in ALLOWED_ROUGE_KEYS.keys():
+            raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {list(ALLOWED_ROUGE_KEYS.keys())}")
+    rouge_keys_values = [ALLOWED_ROUGE_KEYS[key] for key in rouge_keys]
 
     if isinstance(preds, str):
         preds = [preds]
@@ -187,8 +291,19 @@ def rouge_score(
     if isinstance(targets, str):
         targets = [targets]
 
-    aggregator = BootstrapAggregator()
-    scorer = RougeScorer(rouge_keys, use_stemmer=use_stemmer)
+    sentence_results = _rouge_score_update(preds, targets, rouge_keys_values, stemmer=stemmer)
+    return _rouge_score_compute(sentence_results)
+
+
+def _tokenize(text: str, n_gram: int) -> List[str]:
+    """Retun the list of a tokenized input text, where tokens are represented by N-grams.
 
-    _rouge_score_update(preds, targets, scorer=scorer, aggregator=aggregator, newline_sep=newline_sep)
-    return _rouge_score_compute(aggregator=aggregator, decimal_places=decimal_places)
+    Args:
+        text:
+            An input sentence.
+        n_gram:
+            N-gram size to return.
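+
+    Example (illustrative sketch; the doctest is skipped):
+        >>> _tokenize("the quick brown fox", n_gram=2)  # doctest: +SKIP
+        ['the quick', 'quick brown', 'brown fox']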
+    """
+    tokens = re.split(r"\s+", text)
+    n_grams_list = [" ".join(tokens[i : i + n_gram]) for i in range(len(tokens) - n_gram + 1)]
+    return n_grams_list
diff --git a/torchmetrics/text/rouge.py b/torchmetrics/text/rouge.py
index 495f1232bda..b58793cfdff 100644
--- a/torchmetrics/text/rouge.py
+++ b/torchmetrics/text/rouge.py
@@ -11,27 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from torch import Tensor
 
 from torchmetrics import Metric
 from torchmetrics.functional.text.rouge import ALLOWED_ROUGE_KEYS, _rouge_score_compute, _rouge_score_update
-from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE
-
-if _ROUGE_SCORE_AVAILABLE:
-    from rouge_score.rouge_scorer import RougeScorer
-    from rouge_score.scoring import BootstrapAggregator
-else:
-    RougeScorer, BootstrapAggregator = object, object
+from torchmetrics.utilities.imports import _NLTK_AVAILABLE
 
 
 class ROUGEScore(Metric):
     """Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_, used for automatic summarization.
+    This implementation should imitate the behaviour of the `rouge-score` package:
+    https://pypi.org/project/rouge-score/
 
     Args:
         newline_sep:
             New line separate the inputs.
+            This argument is no longer in use. It is deprecated in v0.6 and will be removed in v0.7.
         use_stemmer:
             Use Porter stemmer to strip word suffixes to improve matching.
         rouge_keys:
@@ -39,6 +37,7 @@ class ROUGEScore(Metric):
             Keys that are allowed are ``rougeL``, ``rougeLsum``, and ``rouge1`` through ``rouge9``.
         decimal_places:
             The number of digits to round the computed the values to.
+            This argument is no longer in use. It is deprecated in v0.6 and will be removed in v0.7.
         compute_on_step:
             Forward only calls ``update()`` and returns None if this is set to False. default: True
         dist_sync_on_step:
@@ -72,7 +71,7 @@ class ROUGEScore(Metric):
 
     Raises:
         ValueError:
-            If the python packages ``nltk`` or ``rouge-score`` are not installed.
+            If the python package ``nltk`` is not installed.
         ValueError:
             If any of the ``rouge_keys`` does not belong to the allowed set of keys.
 
@@ -82,10 +81,10 @@ class ROUGEScore(Metric):
 
     def __init__(
         self,
-        newline_sep: bool = False,
+        newline_sep: Optional[bool] = None,  # remove in v0.7
         use_stemmer: bool = False,
         rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),  # type: ignore
-        decimal_places: int = 4,
+        decimal_places: Optional[int] = None,  # remove in v0.7
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
@@ -97,12 +96,15 @@ def __init__(
             process_group=process_group,
             dist_sync_fn=dist_sync_fn,
         )
+        if newline_sep is not None:
+            warnings.warn("Argument `newline_sep` is deprecated in v0.6 and will be removed in v0.7")
+        if decimal_places is not None:
+            warnings.warn("Argument `decimal_places` is deprecated in v0.6 and will be removed in v0.7")
 
-        if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE):
-            raise ValueError(
-                "ROUGE metric requires that both nltk and rouge-score is installed."
-                " Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`"
-            )
+        if use_stemmer or "rougeLsum" in rouge_keys:
+            if not _NLTK_AVAILABLE:
+                raise ValueError("Stemmer and/or `rougeLsum` requires that nltk is installed. Use `pip install nltk`.")
+            import nltk
 
         if not isinstance(rouge_keys, tuple):
             rouge_keys = tuple([rouge_keys])
@@ -111,11 +113,9 @@ def __init__(
                 raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {ALLOWED_ROUGE_KEYS}")
 
         self.rouge_keys = rouge_keys
-        self.newline_sep = newline_sep
-        self.use_stemmer = use_stemmer
-        self.aggregator = BootstrapAggregator()
-        self.scorer = RougeScorer(rouge_keys, use_stemmer=self.use_stemmer)
-        self.decimal_places = decimal_places
+        self.rouge_keys_values = [ALLOWED_ROUGE_KEYS[key] for key in rouge_keys]
+        self.stemmer = nltk.stem.porter.PorterStemmer() if use_stemmer else None
+        self.sentence_results: Optional[Dict[Union[int, str], List[Dict[str, float]]]] = None
 
     def update(self, preds: Union[str, List[str]], targets: Union[str, List[str]]) -> None:  # type: ignore
         """Compute rouge scores.
@@ -131,8 +131,8 @@ def update(self, preds: Union[str, List[str]], targets: Union[str, List[str]]) -
         if isinstance(targets, str):
             targets = [targets]
 
-        _rouge_score_update(
-            preds, targets, scorer=self.scorer, aggregator=self.aggregator, newline_sep=self.newline_sep
+        self.sentence_results = _rouge_score_update(
+            preds, targets, self.rouge_keys_values, self.sentence_results, self.stemmer
         )
 
     def compute(self) -> Dict[str, Tensor]:
@@ -141,7 +141,7 @@ def compute(self) -> Dict[str, Tensor]:
         Return:
             Python dictionary of rouge scores for each input rouge key.
         """
-        return _rouge_score_compute(aggregator=self.aggregator, decimal_places=self.decimal_places)
+        return _rouge_score_compute(self.sentence_results)
 
     def __hash__(self) -> int:
         # override to hash list objects.