diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8c17cdc06cc19..bd60b96c8f106 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -70,6 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `Pytorch Geometric` integration example with Lightning ([#4568](https://github.com/PyTorchLightning/pytorch-lightning/pull/4568))
 
 
+- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4707](https://github.com/PyTorchLightning/pytorch-lightning/pull/4707))
+
+
 ### Changed
 
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst
index c390db8d7537e..04ac191458c1e 100644
--- a/docs/source/trainer.rst
+++ b/docs/source/trainer.rst
@@ -148,6 +148,19 @@ So you can run it like so:
 
 ------------
 
+Validation
+----------
+You can perform an evaluation epoch over the validation set, outside of the training loop,
+using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This can be useful
+if you want to collect new metrics from a model that has just been initialized
+or that has already been trained.
+
+.. code-block:: python
+
+    trainer.validate(val_dataloaders=val_dataloaders)
+
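+You can also pass the model explicitly. A minimal sketch (``MyLightningModule`` stands in
+for your own LightningModule):
+
+.. code-block:: python
+
+    model = MyLightningModule()
+    trainer = Trainer()
+    trainer.validate(model, val_dataloaders=val_dataloaders)
+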
+------------
+
 Testing
 -------
 Once you're done training, feel free to run the test set!
@@ -155,7 +168,7 @@ Once you're done training, feel free to run the test set!
 
 .. code-block:: python
 
-    trainer.test(test_dataloader=test_dataloader)
+    trainer.test(test_dataloaders=test_dataloaders)
 
 ------------
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 931a39e07af89..75e46dbce83dd 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -59,9 +59,9 @@ def barrier(self, name: Optional[str] = None):
     def broadcast(self, obj, src=0):
         return obj
 
-    def train_or_test(self):
-        if self.trainer.testing:
-            results = self.trainer.run_test()
+    def train_or_evaluate(self):
+        if self.trainer.evaluating:
+            results = self.trainer.run_test_or_validate()
         else:
             results = self.trainer.train()
         return results
@@ -160,7 +160,7 @@ def early_stopping_should_stop(self, pl_module):
         return self.trainer.should_stop
 
     def setup_optimizers(self, model):
-        if self.trainer.testing is True:
+        if self.trainer.evaluating:
             return
 
         optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py
index fe0ab59fb554f..279b6327bba5a 100644
--- a/pytorch_lightning/accelerators/cpu_accelerator.py
+++ b/pytorch_lightning/accelerators/cpu_accelerator.py
@@ -57,8 +57,8 @@ def train(self):
         # set up training routine
         self.trainer.train_loop.setup_training(model)
 
-        # train or test
-        results = self.train_or_test()
+        # train or evaluate
+        results = self.train_or_evaluate()
         return results
 
     def training_step(self, args):
diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py
index f43866881cabb..0acc5d6b65339 100644
--- a/pytorch_lightning/accelerators/ddp2_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -181,8 +181,8 @@ def ddp_train(self, process_idx, mp_queue, model):
         # set up training routine
         self.trainer.train_loop.setup_training(model)
 
-        # train or test
-        results = self.train_or_test()
+        # train or evaluate
+        results = self.train_or_evaluate()
 
         # clean up memory
         torch.cuda.empty_cache()
diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py
index 687b5c21874fb..90347a60a4566 100644
--- a/pytorch_lightning/accelerators/ddp_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_accelerator.py
@@ -275,8 +275,8 @@ def ddp_train(self, process_idx, model):
         self.barrier('ddp_setup')
         self.trainer.train_loop.setup_training(model)
 
-        # train or test
-        results = self.train_or_test()
+        # train or evaluate
+        results = self.train_or_evaluate()
 
         # clean up memory
         torch.cuda.empty_cache()
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
index 982da2f53216b..879ad3cdb8b74 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -145,8 +145,8 @@ def ddp_train(self, process_idx, mp_queue, model):
         # set up training routine
         self.trainer.train_loop.setup_training(model)
 
-        # train or test
-        results = self.train_or_test()
+        # train or evaluate
+        results = self.train_or_evaluate()
 
         # get original model
         model = self.trainer.get_model()
diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
index 28817c6845f5b..316fac61ca732 100644
--- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -174,8 +174,8 @@ def ddp_train(self, process_idx, model):
         # set up training routine
         self.trainer.train_loop.setup_training(model)
 
-        # train or test
-        results = self.train_or_test()
+        # train or evaluate
+        results = self.train_or_evaluate()
 
         # clean up memory
         torch.cuda.empty_cache()
diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
index a06d0b82d6d15..b871f6cbf0c6d 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -157,8 +157,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # set up training routine
         self.trainer.train_loop.setup_training(model)
 
-        # train or test
-        results = self.train_or_test()
+        # train or evaluate
+        results = self.train_or_evaluate()
 
         # get original model
         model = self.trainer.get_model()
diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py
index 4b4e1eac8a66c..214b4d88f03aa 100644
--- a/pytorch_lightning/accelerators/dp_accelerator.py
+++ b/pytorch_lightning/accelerators/dp_accelerator.py
@@ -106,8 +106,8 @@ def train(self):
         # set up training routine
         self.trainer.train_loop.setup_training(model)
 
-        # train or test
-        results = self.train_or_test()
+        # train or evaluate
+        results = self.train_or_evaluate()
 
         return results
 
diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py
index b12d275c8ac26..e3f0fb9890809 100644
--- a/pytorch_lightning/accelerators/gpu_accelerator.py
+++ b/pytorch_lightning/accelerators/gpu_accelerator.py
@@ -62,8 +62,9 @@ def train(self):
         # set up training routine
         self.trainer.train_loop.setup_training(model)
 
-        # train or test
-        results = self.train_or_test()
+        # train or evaluate
+        results = self.train_or_evaluate()
+
         return results
 
     def training_step(self, args):
diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py
index b2cec906178f9..d4027c772e061 100644
--- a/pytorch_lightning/accelerators/horovod_accelerator.py
+++ b/pytorch_lightning/accelerators/horovod_accelerator.py
@@ -111,8 +111,8 @@ def train(self):
             # set up training routine
             self.trainer.train_loop.setup_training(self.trainer.model)
 
-            # train or test
-            results = self.train_or_test()
+            # train or evaluate
+            results = self.train_or_evaluate()
 
         # Make sure all workers have finished training before returning to the user
         hvd.join()
diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py
index 30cf6c9dbf169..303066c5e5310 100644
--- a/pytorch_lightning/accelerators/tpu_accelerator.py
+++ b/pytorch_lightning/accelerators/tpu_accelerator.py
@@ -129,8 +129,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
         # set up training routine
         self.trainer.train_loop.setup_training(model)
 
-        # train or test
-        results = self.train_or_test()
+        # train or evaluate
+        results = self.train_or_evaluate()
 
         # save weights at the end of training
         self.__save_end_of_training_weights(model, trainer)
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 3f6b4ffe9622a..8ca0ef301c260 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -28,11 +28,11 @@ class Callback(abc.ABC):
     """
 
     def setup(self, trainer, pl_module, stage: str):
-        """Called when fit or test begins"""
+        """Called when fit, validate, or test begins"""
         pass
 
     def teardown(self, trainer, pl_module, stage: str):
-        """Called when fit or test ends"""
+        """Called when fit, validate, or test ends"""
         pass
 
     def on_init_start(self, trainer):
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 005a3f8cde4ad..3a2b5c2a57259 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -134,13 +134,13 @@ def on_load_checkpoint(self, checkpointed_state):
         self.patience = checkpointed_state['patience']
 
     def on_validation_end(self, trainer, pl_module):
-        if trainer.running_sanity_check:
+        if trainer.running_sanity_check or trainer.evaluating:
             return
 
         self._run_early_stopping_check(trainer, pl_module)
 
     def on_validation_epoch_end(self, trainer, pl_module):
-        if trainer.running_sanity_check:
+        if trainer.running_sanity_check or trainer.evaluating:
             return
 
         if self._validate_condition_metric(trainer.logger_connector.callback_metrics):
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index d41928cd55aea..0efaef9c660b7 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -220,6 +220,7 @@ def save_checkpoint(self, trainer, pl_module):
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or trainer.running_sanity_check  # don't save anything during sanity check
+            or trainer.evaluating  # don't save anything during evaluation: might delete the checkpoint being evaluated
             or self.last_global_step_saved == global_step  # already saved at the last step
         ):
             return
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index 6582f16fd27be..b00dca548671f 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -282,9 +282,13 @@ def init_train_tqdm(self) -> tqdm:
 
     def init_validation_tqdm(self) -> tqdm:
         """ Override this to customize the tqdm bar for validation. """
+
+        # The main progress bar doesn't exist in trainer.validate(...)
+        has_main_bar = int(self.main_progress_bar is not None)
+
         bar = tqdm(
             desc='Validating',
-            position=(2 * self.process_position + 1),
+            position=(2 * self.process_position + has_main_bar),
             disable=self.is_disabled,
             leave=False,
             dynamic_ncols=True,
@@ -341,7 +345,10 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
     def on_validation_start(self, trainer, pl_module):
         super().on_validation_start(trainer, pl_module)
         if not trainer.running_sanity_check:
-            self._update_bar(self.main_progress_bar)  # fill up remaining
+            # The main progress bar doesn't exist in trainer.validate(...)
+            if self.main_progress_bar is not None:
+                self._update_bar(self.main_progress_bar)  # fill up remaining
+
             self.val_progress_bar = self.init_validation_tqdm()
             self.val_progress_bar.total = convert_inf(self.total_val_batches)
 
@@ -349,11 +356,18 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
         super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
         if self._should_update(self.val_batch_idx, self.total_val_batches):
             self._update_bar(self.val_progress_bar)
-            self._update_bar(self.main_progress_bar)
+
+            # The main progress bar doesn't exist in trainer.validate(...)
+            if self.main_progress_bar is not None:
+                self._update_bar(self.main_progress_bar)
 
     def on_validation_end(self, trainer, pl_module):
         super().on_validation_end(trainer, pl_module)
-        self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
+
+        # The main progress bar doesn't exist in trainer.validate(...)
+        if self.main_progress_bar is not None:
+            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
+
         self.val_progress_bar.close()
 
     def on_train_end(self, trainer, pl_module):
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index fe81d641c86d6..3ff9f4cf889d4 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -76,13 +76,16 @@ def wrapped_fn(*args, **kwargs):
         if fn.__name__ == "setup":
 
             # Get stage either by grabbing from args or checking kwargs.
-            # If not provided, set call status of 'fit' and 'test' to True.
+            # If not provided, set call status of 'fit', 'validation', and 'test' to True.
             # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
             stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
 
             if stage == "fit" or stage is None:
                 obj._has_setup_fit = True
 
+            if stage == "validation" or stage is None:
+                obj._has_setup_validation = True
+
             if stage == "test" or stage is None:
                 obj._has_setup_test = True
 
@@ -155,6 +158,7 @@ def __init__(
         # Private attrs to keep track of whether or not data hooks have been called yet
         self._has_prepared_data = False
         self._has_setup_fit = False
+        self._has_setup_validation = False
         self._has_setup_test = False
 
     @property
@@ -230,6 +234,15 @@ def has_setup_fit(self):
         """
         return self._has_setup_fit
 
+    @property
+    def has_setup_validation(self):
+        """Return bool letting you know if datamodule.setup('validation') has been called or not.
+
+        Returns:
+            bool: True if datamodule.setup('validation') has been called. False by default.
+        """
+        return self._has_setup_validation
+
     @property
     def has_setup_test(self):
         """Return bool letting you know if datamodule.setup('test') has been called or not.
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 57979b73f2cb6..a4251484991f2 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -26,12 +26,12 @@ class ModelHooks:
     """Hooks to be used in LightningModule."""
     def setup(self, stage: str):
         """
-        Called at the beginning of fit and test.
+        Called at the beginning of fit (training + validation), validation, and test.
         This is a good hook when you need to build models dynamically or adjust something about them.
         This hook is called on every process when using DDP.
 
         Args:
-            stage: either 'fit' or 'test'
+            stage: either 'fit', 'validation', or 'test'
 
         Example::
 
@@ -54,10 +54,10 @@ def setup(stage):
 
     def teardown(self, stage: str):
         """
-        Called at the end of fit and test.
+        Called at the end of fit (training + validation), validation, and test.
 
         Args:
-            stage: either 'fit' or 'test'
+            stage: either 'fit', 'validation', or 'test'
         """
 
     def on_fit_start(self):
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 01c0119e857ec..23967dc1bc2a9 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -31,12 +31,12 @@ def verify_loop_configurations(self, model: LightningModule):
             model: The model to check the configuration.
 
         """
-        if not self.trainer.testing:
+        if not self.trainer.evaluating:
             self.__verify_train_loop_configuration(model)
             self.__verify_eval_loop_configuration(model, 'validation')
         else:
-            # check test loop configuration
-            self.__verify_eval_loop_configuration(model, 'test')
+            # check evaluation loop configurations
+            self.__verify_eval_loop_configuration(model, self.trainer.evaluating)
 
     def __verify_train_loop_configuration(self, model):
         # -----------------------------------
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index cab08edd58531..33ff30380eabb 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -265,7 +265,7 @@ def prepare_eval_loop_results(self):
         for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
             self.add_to_eval_loop_results(dl_idx, has_been_initialized)
 
-    def get_evaluate_epoch_results(self, test_mode):
+    def get_evaluate_epoch_results(self):
         if not self.trainer.running_sanity_check:
             # log all the metrics as a single dict
             metrics_to_log = self.cached_results.get_epoch_log_metrics()
@@ -274,11 +274,11 @@ def get_evaluate_epoch_results(self, test_mode):
 
         self.prepare_eval_loop_results()
 
-        # log results of test
-        if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
+        # log results of evaluation
+        if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate:
             print('-' * 80)
             for result_idx, results in enumerate(self.eval_loop_results):
-                print(f'DATALOADER:{result_idx} TEST RESULTS')
+                print(f'DATALOADER:{result_idx} {self.trainer.evaluating.upper()} RESULTS')
                 pprint(results)
                 print('-' * 80)
 
diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py
index c5a8c48357b44..c665ee971b885 100644
--- a/pytorch_lightning/trainer/connectors/model_connector.py
+++ b/pytorch_lightning/trainer/connectors/model_connector.py
@@ -36,7 +36,10 @@ def copy_trainer_model_properties(self, model):
             m.use_ddp2 = self.trainer.use_ddp2
             m.use_ddp = self.trainer.use_ddp
             m.use_amp = self.trainer.amp_backend is not None
-            m.testing = self.trainer.testing
+            # TODO: m.testing only appears to be used in DDP, where it distinguishes
+            #  test from validation, whereas Trainer.evaluating distinguishes
+            #  evaluation from fit. Still need to confirm this is correct.
+            m.testing = self.trainer.evaluating == 'test'
             m.use_single_gpu = self.trainer.use_single_gpu
             m.use_tpu = self.trainer.use_tpu
             m.tpu_local_core_rank = self.trainer.tpu_local_core_rank
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 097727a6bed78..11da428b83453 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import torch
 
+import pytorch_lightning as pl
 from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.trainer.supporters import PredictionCollection
 from pytorch_lightning.utilities.distributed import rank_zero_warn
@@ -22,7 +23,7 @@
 
 
 class EvaluationLoop(object):
-    def __init__(self, trainer):
+    def __init__(self, trainer: 'pl.Trainer'):
         self.trainer = trainer
         self.testing = False
         self.outputs = []
@@ -39,13 +40,15 @@ def on_trainer_init(self):
         self.trainer.test_dataloaders = None
         self.trainer.val_dataloaders = None
         self.trainer.running_sanity_check = False
-        self.trainer.testing = False
 
-        # when .test() is called, it sets this
-        self.trainer.tested_ckpt_path = None
+        # .validate() sets this to 'validation' and .test() sets this to 'test'
+        self.trainer.evaluating = None
 
-        # when true, prints test results
-        self.trainer.verbose_test = True
+        # .validate() and .test() set this when they load a checkpoint
+        self.trainer.evaluated_ckpt_path = None
+
+        # when true, print evaluation results in .validate() and .test()
+        self.trainer.verbose_evaluate = True
 
     def get_evaluation_dataloaders(self, max_batches):
         # select dataloaders
@@ -216,7 +219,7 @@ def evaluation_epoch_end(self):
 
     def log_epoch_metrics_on_evaluation_end(self):
         # get the final loop results
-        eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results(self.testing)
+        eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results()
         return eval_loop_results
 
     def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index fd715988ef370..1bad441eb0083 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -441,10 +441,6 @@ def fit(
         # hook
         self.data_connector.prepare_data(model)
 
-        # bookkeeping
-        # we reuse fit in .test() but change its behavior using this flag
-        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
-
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
@@ -659,11 +655,15 @@ def track_output_for_epoch_end(self, outputs, output):
             outputs.append(output)
         return outputs
 
-    def run_test(self):
+    def run_test_or_validate(self):
         # only load test dataloader for testing
         # self.reset_test_dataloader(ref_model)
-        with self.profiler.profile("run_test_evaluation"):
-            eval_loop_results, _ = self.run_evaluation(test_mode=True)
+        if self.evaluating == 'test':
+            with self.profiler.profile("run_test_evaluation"):
+                eval_loop_results, _ = self.run_evaluation(test_mode=True)
+        else:
+            with self.profiler.profile("run_validate_evaluation"):
+                eval_loop_results, _ = self.run_evaluation(test_mode=False)
 
         if len(eval_loop_results) == 0:
             return 1
@@ -711,42 +711,90 @@ def run_sanity_check(self, ref_model):
             self.on_sanity_check_end()
             self.running_sanity_check = False
 
-    def test(
+    def validate(
         self,
         model: Optional[LightningModule] = None,
-        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         ckpt_path: Optional[str] = 'best',
         verbose: bool = True,
         datamodule: Optional[LightningDataModule] = None,
     ):
         r"""
-
-        Separates from fit to make sure you never run on your test set until you want to.
+        Perform one evaluation epoch over the validation set.
 
         Args:
-            ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
-                If ``None``, use the weights from the last epoch to test. Default to ``best``.
-
+            ckpt_path: Either ``best`` or the path to the checkpoint you wish to validate.
+                If ``None``, use the current weights of the model. Defaults to ``best``.
             datamodule: A instance of :class:`LightningDataModule`.
+            model: The model to evaluate.
+            val_dataloaders: Either a single PyTorch DataLoader or a list of them,
+                specifying validation samples.
+            verbose: If True, prints the validation results.
+
+        Returns:
+            The dictionary with final validation results returned by validation_epoch_end.
+            If validation_epoch_end is not defined, the output is a list of the dictionaries
+            returned by validation_step.
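+
+        Example (a minimal sketch; ``LitModel`` and ``val_loader`` are placeholder names
+        for your own LightningModule and DataLoader)::
+
+            model = LitModel()
+            trainer = Trainer()
+            trainer.validate(model, val_dataloaders=val_loader)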
+        """
+        # --------------------
+        # SETUP HOOK
+        # --------------------
+        self.verbose_evaluate = verbose
+
+        self.logger_connector.set_stage("validation")
+
+        # If you supply a datamodule you can't supply val_dataloaders
+        if val_dataloaders and datamodule:
+            raise MisconfigurationException(
+                'You cannot pass val_dataloaders to trainer.validate if you supply a datamodule'
+            )
+
+        # Attach datamodule to get setup/prepare_data added to model before the call to it below
+        self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'validation')
+
+        if model is not None:
+            results = self.__evaluate_given_model(model, val_dataloaders, 'validation')
+        else:
+            results = self.__evaluate_using_best_weights(ckpt_path, val_dataloaders, 'validation')
+
+        self.teardown('validation')
 
-            model: The model to test.
+        return results
 
-            test_dataloaders: Either a single
-                Pytorch Dataloader or a list of them, specifying validation samples.
+    def test(
+        self,
+        model: Optional[LightningModule] = None,
+        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        ckpt_path: Optional[str] = 'best',
+        verbose: bool = True,
+        datamodule: Optional[LightningDataModule] = None,
+    ):
+        r"""
+        Perform one evaluation epoch over the test set. It's separated from
+        fit to make sure you never run on your test set until you want to.
 
-            verbose: If True, prints the test results
+        Args:
+            ckpt_path: Either ``best`` or the path to the checkpoint you wish to test.
+                If ``None``, use the current weights of the model. Defaults to ``best``.
+            datamodule: An instance of :class:`LightningDataModule`.
+            model: The model to evaluate.
+            test_dataloaders: Either a single PyTorch DataLoader or a list of them,
+                specifying test samples.
+            verbose: If True, prints the test results.
 
         Returns:
-            The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries
+            The dictionary with final test results returned by test_epoch_end.
+            If test_epoch_end is not defined, the output is a list of the dictionaries
+            returned by test_step.
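+
+        Example (a minimal sketch; ``LitModel`` and ``test_loader`` are placeholder names
+        for your own LightningModule and DataLoader)::
+
+            model = LitModel()
+            trainer = Trainer()
+            trainer.test(model, test_dataloaders=test_loader)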
         """
         # --------------------
         # SETUP HOOK
         # --------------------
-        self.verbose_test = verbose
+        self.verbose_evaluate = verbose
 
         self.logger_connector.set_stage("test")
 
-        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
+        # If you supply a datamodule you can't supply test_dataloaders
         if test_dataloaders and datamodule:
             raise MisconfigurationException(
                 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
@@ -756,15 +804,15 @@ def test(
         self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')
 
         if model is not None:
-            results = self.__test_given_model(model, test_dataloaders)
+            results = self.__evaluate_given_model(model, test_dataloaders, 'test')
         else:
-            results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
+            results = self.__evaluate_using_best_weights(ckpt_path, test_dataloaders, 'test')
 
         self.teardown('test')
 
         return results
 
-    def __test_using_best_weights(self, ckpt_path, test_dataloaders):
+    def __evaluate_using_best_weights(self, ckpt_path, dataloaders, stage: str):
         model = self.get_model()
 
         # if user requests the best checkpoint but we don't have it, error
@@ -792,44 +840,62 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
             model.load_state_dict(ckpt['state_dict'])
 
         # attach dataloaders
-        if test_dataloaders is not None:
-            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
+        if dataloaders is not None:
+            kwargs = {'test_dataloaders' if stage == 'test' else 'val_dataloaders': dataloaders}
+            self.data_connector.attach_dataloaders(model, **kwargs)
 
         # run tests
-        self.tested_ckpt_path = ckpt_path
-        self.testing = True
-        os.environ['PL_TESTING_MODE'] = '1'
+        self.evaluating = stage
+        self.evaluated_ckpt_path = ckpt_path
         self.model = model
         results = self.fit(model)
-        self.testing = False
-        del os.environ['PL_TESTING_MODE']
+        self.evaluating = None
 
         # teardown
         if self.is_function_implemented('teardown'):
             model_ref = self.get_model()
-            model_ref.teardown('test')
+            model_ref.teardown(stage)
 
         return results
 
-    def __test_given_model(self, model, test_dataloaders):
+    def __evaluate_given_model(self, model, dataloaders, stage: str):
 
         # attach data
-        if test_dataloaders is not None:
-            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
+        if dataloaders is not None:
+            kwargs = {'test_dataloaders' if stage == 'test' else 'val_dataloaders': dataloaders}
+            self.data_connector.attach_dataloaders(model, **kwargs)
 
         # run test
         # sets up testing so we short circuit to eval
-        self.testing = True
+        self.evaluating = stage
         self.model = model
         results = self.fit(model)
-        self.testing = False
+        self.evaluating = None
 
         # teardown
         if self.is_function_implemented('teardown'):
-            model.teardown('test')
+            model.teardown(stage)
 
         return results
 
+    @property
+    def testing(self):
+        warnings.warn(
+            'Trainer.testing has been deprecated in v1.1 and will be removed '
+            'in v1.3. Use Trainer.evaluating instead.',
+            DeprecationWarning, stacklevel=2
+        )
+        return bool(self.evaluating)
+
+    @property
+    def tested_ckpt_path(self):
+        warnings.warn(
+            'Trainer.tested_ckpt_path has been renamed to Trainer.evaluated_ckpt_path '
+            'in v1.1 and will be removed in v1.3.',
+            DeprecationWarning, stacklevel=2
+        )
+        return self.evaluated_ckpt_path
+
     def tune(
         self,
         model: LightningModule,
@@ -856,11 +922,18 @@ def tune(
 
     def call_setup_hook(self, model):
         # call setup after the ddp process has connected
-        stage_name = 'test' if self.testing else 'fit'
+        stage_name = self.evaluating or 'fit'
+
         if self.datamodule is not None:
-            called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
+            called = {
+                None: self.datamodule.has_setup_fit,
+                'validation': self.datamodule.has_setup_validation,
+                'test': self.datamodule.has_setup_test,
+            }[self.evaluating]
+
             if not called:
                 self.datamodule.setup(stage_name)
+
         self.setup(model, stage_name)
         model.setup(stage_name)
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 9a4f324033d39..ff19b9b8a9858 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -161,7 +161,7 @@ def setup_training(self, model: LightningModule):
             ref_model.on_pretrain_routine_start()
 
         # print model summary
-        if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing:
+        if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.evaluating:
             if self.trainer.weights_summary in ModelSummary.MODES:
                 ref_model.summarize(mode=self.trainer.weights_summary)
             else:
diff --git a/tests/backends/test_dp.py b/tests/backends/test_dp.py
index c051b442cb7a7..b697440280f80 100644
--- a/tests/backends/test_dp.py
+++ b/tests/backends/test_dp.py
@@ -67,7 +67,7 @@ def test_multi_gpu_model_dp(tmpdir):
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-def test_dp_test(tmpdir):
+def test_dp_evaluate(tmpdir):
     tutils.set_random_master_port()
 
     import os
@@ -84,6 +84,22 @@ def test_dp_test(tmpdir):
     )
     trainer.fit(model)
     assert 'ckpt' in trainer.checkpoint_callback.best_model_path
+
+    # validate
+    results = trainer.validate()
+    assert 'val_acc' in results[0]
+
+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
+    results = trainer.validate(model)
+    assert 'val_acc' in results[0]
+
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
+
+    # test
     results = trainer.test()
     assert 'test_acc' in results[0]
 
diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py
index e4d0b4bff89d7..94e4ba9c1efe9 100644
--- a/tests/base/datamodules.py
+++ b/tests/base/datamodules.py
@@ -33,7 +33,7 @@ def prepare_data(self):
 
     def setup(self, stage: Optional[str] = None):
 
-        if stage == "fit" or stage is None:
+        if stage != 'test':
             mnist_full = TrialMNIST(
                 root=self.data_dir, train=True, num_samples=64, download=True
             )
@@ -88,7 +88,7 @@ def setup(self, stage: Optional[str] = None):
 
         # Assign train/val datasets for use in dataloaders
         # TODO: need to split using random_split once updated to torch >= 1.6
-        if stage == "fit" or stage is None:
+        if stage != 'test':
             self.mnist_train = MNIST(
                 self.data_dir, train=True, normalize=(0.1307, 0.3081)
             )
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index bb740b1dcbb1c..6f427afef7728 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -101,6 +101,28 @@ def test_trainer_callback_system(torch_save):
         call.teardown(trainer, model, 'fit'),
     ]
 
+    callback_mock.reset_mock()
+    trainer = Trainer(**trainer_options)
+    trainer.validate(model)
+
+    assert callback_mock.method_calls == [
+        call.on_init_start(trainer),
+        call.on_init_end(trainer),
+        call.setup(trainer, model, 'validation'),
+        call.on_fit_start(trainer, model),
+        call.on_pretrain_routine_start(trainer, model),
+        call.on_pretrain_routine_end(trainer, model),
+        call.on_validation_start(trainer, model),
+        call.on_validation_epoch_start(trainer, model),
+        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
+        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
+        call.on_validation_epoch_end(trainer, model),
+        call.on_validation_end(trainer, model),
+        call.on_fit_end(trainer, model),
+        call.teardown(trainer, model, 'fit'),
+        call.teardown(trainer, model, 'validation'),
+    ]
+
     callback_mock.reset_mock()
     trainer = Trainer(**trainer_options)
     trainer.test(model)
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 3c19748765e52..988da6f233dd2 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -84,7 +84,7 @@ def test_progress_bar_totals(tmpdir):
         limit_val_batches=1.0,
         max_epochs=1,
     )
-    bar = trainer.progress_bar_callback
+    bar: ProgressBar = trainer.progress_bar_callback
     assert 0 == bar.total_train_batches
     assert 0 == bar.total_val_batches
     assert 0 == bar.total_test_batches
@@ -113,6 +113,17 @@ def test_progress_bar_totals(tmpdir):
     assert 0 == bar.total_test_batches
     assert bar.test_progress_bar is None
 
+    trainer.validate(model)
+
+    # check validation progress bar total
+    k = bar.total_val_batches
+    assert sum(len(loader) for loader in trainer.val_dataloaders) == k
+    assert bar.val_progress_bar.total == k
+
+    # validation progress bar should have reached the end
+    assert bar.val_progress_bar.n == k
+    assert bar.val_batch_idx == k
+
     trainer.test(model)
 
     # check test progress bar total
@@ -135,7 +146,7 @@ def test_progress_bar_fast_dev_run(tmpdir):
 
     trainer.fit(model)
 
-    progress_bar = trainer.progress_bar_callback
+    progress_bar: ProgressBar = trainer.progress_bar_callback
     assert 1 == progress_bar.total_train_batches
     # total val batches are known only after val dataloaders have reloaded
 
@@ -150,6 +161,13 @@ def test_progress_bar_fast_dev_run(tmpdir):
     assert 2 == progress_bar.main_progress_bar.total
     assert 2 == progress_bar.main_progress_bar.n
 
+    trainer.validate(model)
+
+    # the validation progress bar should display 1 batch
+    assert 1 == progress_bar.val_batch_idx
+    assert 1 == progress_bar.val_progress_bar.total
+    assert 1 == progress_bar.val_progress_bar.n
+
     trainer.test(model)
 
     # the test progress bar should display 1 batch
@@ -207,8 +225,16 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
     trainer.fit(model)
     assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches
     assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps
+    assert progress_bar.test_batches_seen == 0
+
+    trainer.validate(model)
+    assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches
+    assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps
+    assert progress_bar.test_batches_seen == 0
 
     trainer.test(model)
+    assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches
+    assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps
     assert progress_bar.test_batches_seen == progress_bar.total_test_batches
 
 
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 33bc19a894d8f..e3e6dfe4ceddc 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -797,6 +797,9 @@ def get_model():
     assert trainer.current_epoch == epochs - 1
     assert_checkpoint_log_dir(0)
 
+    trainer.validate(model)
+    assert trainer.current_epoch == epochs - 1
+
     trainer.test(model)
     assert trainer.current_epoch == epochs - 1
 
@@ -817,6 +820,11 @@ def get_model():
         )
         assert_trainer_init(trainer)
 
+        trainer.validate(model)
+        assert not trainer.checkpoint_connector.has_trained
+        assert trainer.global_step == epochs * limit_train_batches
+        assert trainer.current_epoch == epochs
+
         trainer.test(model)
         assert not trainer.checkpoint_connector.has_trained
         assert trainer.global_step == epochs * limit_train_batches
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 3e683025e8867..32f4aebe445d4 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -111,6 +111,7 @@ def test_base_datamodule_with_verbose_setup(tmpdir):
     dm = TrialMNISTDataModule()
     dm.prepare_data()
     dm.setup('fit')
+    dm.setup('validation')
     dm.setup('test')
 
 
@@ -118,16 +119,19 @@ def test_data_hooks_called(tmpdir):
     dm = TrialMNISTDataModule()
     assert dm.has_prepared_data is False
     assert dm.has_setup_fit is False
+    assert dm.has_setup_validation is False
     assert dm.has_setup_test is False
 
     dm.prepare_data()
     assert dm.has_prepared_data is True
     assert dm.has_setup_fit is False
+    assert dm.has_setup_validation is False
     assert dm.has_setup_test is False
 
     dm.setup()
     assert dm.has_prepared_data is True
     assert dm.has_setup_fit is True
+    assert dm.has_setup_validation is True
     assert dm.has_setup_test is True
 
 
@@ -135,21 +139,31 @@ def test_data_hooks_called_verbose(tmpdir):
     dm = TrialMNISTDataModule()
     assert dm.has_prepared_data is False
     assert dm.has_setup_fit is False
+    assert dm.has_setup_validation is False
     assert dm.has_setup_test is False
 
     dm.prepare_data()
     assert dm.has_prepared_data is True
     assert dm.has_setup_fit is False
+    assert dm.has_setup_validation is False
     assert dm.has_setup_test is False
 
     dm.setup('fit')
     assert dm.has_prepared_data is True
     assert dm.has_setup_fit is True
+    assert dm.has_setup_validation is False
+    assert dm.has_setup_test is False
+
+    dm.setup('validation')
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_validation is True
     assert dm.has_setup_test is False
 
     dm.setup('test')
     assert dm.has_prepared_data is True
     assert dm.has_setup_fit is True
+    assert dm.has_setup_validation is True
     assert dm.has_setup_test is True
 
 
@@ -160,10 +174,17 @@ def test_data_hooks_called_with_stage_kwarg(tmpdir):
 
     dm.setup(stage='fit')
     assert dm.has_setup_fit is True
+    assert dm.has_setup_validation is False
+    assert dm.has_setup_test is False
+
+    dm.setup(stage='validation')
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_validation is True
     assert dm.has_setup_test is False
 
     dm.setup(stage='test')
     assert dm.has_setup_fit is True
+    assert dm.has_setup_validation is True
     assert dm.has_setup_test is True
 
 
@@ -254,6 +275,21 @@ def test_dm_checkpoint_save(tmpdir):
     assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__
 
 
+def test_validate_loop_only(tmpdir):
+    reset_seed()
+
+    dm = TrialMNISTDataModule(tmpdir)
+
+    model = EvalModelTemplate()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=3,
+        weights_summary=None,
+    )
+    trainer.validate(model, datamodule=dm)
+
+
 def test_test_loop_only(tmpdir):
     reset_seed()
 
@@ -287,6 +323,11 @@ def test_full_loop(tmpdir):
     result = trainer.fit(model, dm)
     assert result == 1
 
+    # validate
+    result = trainer.validate(datamodule=dm)
+    result = result[0]
+    assert result['val_acc'] > 0.8
+
     # test
     result = trainer.test(datamodule=dm)
     result = result[0]
@@ -312,6 +353,11 @@ def test_trainer_attached_to_dm(tmpdir):
     assert result == 1
     assert dm.trainer is not None
 
+    # validate
+    result = trainer.validate(datamodule=dm)
+    result = result[0]
+    assert dm.trainer is not None
+
     # test
     result = trainer.test(datamodule=dm)
     result = result[0]
@@ -338,6 +384,11 @@ def test_full_loop_single_gpu(tmpdir):
     result = trainer.fit(model, dm)
     assert result == 1
 
+    # validate
+    result = trainer.validate(datamodule=dm)
+    result = result[0]
+    assert result['val_acc'] > 0.8
+
     # test
     result = trainer.test(datamodule=dm)
     result = result[0]
@@ -365,6 +416,11 @@ def test_full_loop_dp(tmpdir):
     result = trainer.fit(model, dm)
     assert result == 1
 
+    # validate
+    result = trainer.validate(datamodule=dm)
+    result = result[0]
+    assert result['val_acc'] > 0.8
+
     # test
     result = trainer.test(datamodule=dm)
     result = result[0]
diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py
index 1ab97304f2338..b724fc8587e24 100755
--- a/tests/trainer/test_config_validator.py
+++ b/tests/trainer/test_config_validator.py
@@ -19,9 +19,6 @@
 from tests.base import EvalModelTemplate
 
 
-# TODO: add matching messages
-
-
 def test_wrong_train_setting(tmpdir):
     """
     * Test that an error is thrown when no `train_dataloader()` is defined
@@ -31,12 +28,12 @@ def test_wrong_train_setting(tmpdir):
     hparams = EvalModelTemplate.get_default_hparams()
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
 
-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(MisconfigurationException, match=r'No `train_dataloader\(\)` method defined.'):
         model = EvalModelTemplate(**hparams)
         model.train_dataloader = None
         trainer.fit(model)
 
-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(MisconfigurationException, match=r'No `training_step\(\)` method defined.'):
         model = EvalModelTemplate(**hparams)
         model.training_step = None
         trainer.fit(model)
@@ -47,7 +44,7 @@ def test_wrong_configure_optimizers(tmpdir):
     tutils.reset_seed()
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
 
-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(MisconfigurationException, match=r'No `configure_optimizers\(\)` method defined.'):
         model = EvalModelTemplate()
         model.configure_optimizers = None
         trainer.fit(model)
@@ -62,13 +59,13 @@ def test_val_loop_config(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
 
     # no val data has val loop
-    with pytest.warns(UserWarning):
+    with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'):
         model = EvalModelTemplate(**hparams)
         model.validation_step = None
         trainer.fit(model)
 
     # has val loop but no val data
-    with pytest.warns(UserWarning):
+    with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'):
         model = EvalModelTemplate(**hparams)
         model.val_dataloader = None
         trainer.fit(model)
@@ -82,13 +79,33 @@ def test_test_loop_config(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
 
     # has test loop but no test data
-    with pytest.warns(UserWarning):
+    with pytest.warns(UserWarning, match=r'you defined a test_step but have no test_dataloader'):
         model = EvalModelTemplate(**hparams)
         model.test_dataloader = None
         trainer.test(model)
 
     # has test data but no test loop
-    with pytest.warns(UserWarning):
+    with pytest.warns(UserWarning, match=r'you passed in a test_dataloader but have no test_step'):
         model = EvalModelTemplate(**hparams)
         model.test_step = None
         trainer.test(model, test_dataloaders=model.dataloader(train=False))
+
+
+def test_validation_loop_config(tmpdir):
+    """"
+    When either validation loop or validation data are missing
+    """
+    hparams = EvalModelTemplate.get_default_hparams()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+
+    # has val loop but no val data
+    with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'):
+        model = EvalModelTemplate(**hparams)
+        model.val_dataloader = None
+        trainer.validate(model)
+
+    # has val data but no val loop
+    with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'):
+        model = EvalModelTemplate(**hparams)
+        model.validation_step = None
+        trainer.validate(model, val_dataloaders=model.dataloader(train=False))
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index f16ef22faa507..d0b838b5fbf45 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -172,6 +172,48 @@ def test_step(self, batch, batch_idx, *args, **kwargs):
     trainer.test(ckpt_path=ckpt_path)
 
 
+@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
+def test_multiple_validate_dataloader(tmpdir, ckpt_path):
+    """Verify multiple val_dataloaders."""
+
+    model_template = EvalModelTemplate()
+
+    class MultipleValDataloaderModel(EvalModelTemplate):
+        def val_dataloader(self):
+            return model_template.val_dataloader__multiple()
+
+        def validation_step(self, batch, batch_idx, *args, **kwargs):
+            return model_template.validation_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs)
+
+        def validation_epoch_end(self, outputs):
+            return model_template.validation_epoch_end__multiple_dataloaders(outputs)
+
+    model = MultipleValDataloaderModel()
+
+    # fit model
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_val_batches=0.1,
+        limit_train_batches=0.2,
+    )
+    trainer.fit(model)
+    if ckpt_path == 'specific':
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+    trainer.validate(ckpt_path=ckpt_path)
+
+    # verify there are 2 val loaders
+    assert len(trainer.val_dataloaders) == 2, \
+        'Multiple val_dataloaders not initiated properly'
+
+    # make sure predictions are good for each val set
+    for dataloader in trainer.val_dataloaders:
+        tpipes.run_prediction(dataloader, trainer.model)
+
+    # run the validate method
+    trainer.validate(ckpt_path=ckpt_path)
+
+
 def test_train_dataloader_passed_to_fit(tmpdir):
     """Verify that train dataloader can be passed to fit """
 
diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py
index 2e76192836740..27f0bcda66926 100644
--- a/tests/trainer/test_optimizers.py
+++ b/tests/trainer/test_optimizers.py
@@ -337,6 +337,24 @@ def test_init_optimizers_during_testing(tmpdir):
     assert len(trainer.optimizer_frequencies) == 0
 
 
+def test_init_optimizers_during_validation(tmpdir):
+    """
+    Test that optimizers is an empty list during validation.
+    """
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_val_batches=10
+    )
+    trainer.validate(model, ckpt_path=None)
+
+    assert len(trainer.lr_schedulers) == 0
+    assert len(trainer.optimizers) == 0
+    assert len(trainer.optimizer_frequencies) == 0
+
+
 def test_multiple_optimizers_callbacks(tmpdir):
     """
     Tests that multiple optimizers can be used with callbacks
diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py
index 0244f654227a2..f6e29b7187d61 100644
--- a/tests/trainer/test_states.py
+++ b/tests/trainer/test_states.py
@@ -23,7 +23,7 @@ class StateSnapshotCallback(Callback):
 
     def __init__(self, snapshot_method: str):
         super().__init__()
-        assert snapshot_method in ['on_batch_start', 'on_test_batch_start']
+        assert snapshot_method in ['on_batch_start', 'on_validation_batch_start', 'on_test_batch_start']
         self.snapshot_method = snapshot_method
         self.trainer_state = None
 
@@ -31,6 +31,10 @@ def on_batch_start(self, trainer, pl_module):
         if self.snapshot_method == 'on_batch_start':
             self.trainer_state = trainer.state
 
+    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        if self.snapshot_method == 'on_validation_batch_start':
+            self.trainer_state = trainer.state
+
     def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         if self.snapshot_method == 'on_test_batch_start':
             self.trainer_state = trainer.state
@@ -191,6 +195,40 @@ def test_finished_state_after_test(tmpdir):
     assert trainer.state == TrainerState.FINISHED
 
 
+def test_running_state_during_validation(tmpdir):
+    """ Tests that state is set to RUNNING during test """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    snapshot_callback = StateSnapshotCallback(snapshot_method='on_validation_batch_start')
+
+    trainer = Trainer(
+        callbacks=[snapshot_callback],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.validate(model)
+
+    assert snapshot_callback.trainer_state == TrainerState.RUNNING
+
+
+def test_finished_state_after_validation(tmpdir):
+    """ Tests that state is FINISHED after fit """
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.validate(model)
+
+    assert trainer.state == TrainerState.FINISHED
+
+
 @pytest.mark.parametrize("extra_params", [
     pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
     pytest.param(dict(max_steps=1), id='Single-Step'),
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 085d361952844..5cf1bb17218a2 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -728,12 +728,12 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
                 trainer.test(ckpt_path=ckpt_path)
         else:
             trainer.test(ckpt_path=ckpt_path)
-            assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path
+            assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path
     elif ckpt_path is None:
         # ckpt_path is None, meaning we don't load any checkpoints and
         # use the weights from the end of training
         trainer.test(ckpt_path=ckpt_path)
-        assert trainer.tested_ckpt_path is None
+        assert trainer.evaluated_ckpt_path is None
     else:
         # specific checkpoint, pick one from saved ones
         if save_top_k == 0:
@@ -746,7 +746,48 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
                 ].absolute()
             )
             trainer.test(ckpt_path=ckpt_path)
-            assert trainer.tested_ckpt_path == ckpt_path
+            assert trainer.evaluated_ckpt_path == ckpt_path
+
+
+@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
+@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
+def test_validate_checkpoint_path(tmpdir, ckpt_path, save_top_k):
+    hparams = EvalModelTemplate.get_default_hparams()
+
+    model = EvalModelTemplate(**hparams)
+    trainer = Trainer(
+        max_epochs=2,
+        progress_bar_refresh_rate=0,
+        default_root_dir=tmpdir,
+        checkpoint_callback=ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k),
+    )
+    trainer.fit(model)
+    if ckpt_path == "best":
+        # ckpt_path is 'best', meaning we load the best weights
+        if save_top_k == 0:
+            with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"):
+                trainer.validate(ckpt_path=ckpt_path)
+        else:
+            trainer.validate(ckpt_path=ckpt_path)
+            assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path
+    elif ckpt_path is None:
+        # ckpt_path is None, meaning we don't load any checkpoints and
+        # use the weights from the end of training
+        trainer.validate(ckpt_path=ckpt_path)
+        assert trainer.evaluated_ckpt_path is None
+    else:
+        # specific checkpoint, pick one from saved ones
+        if save_top_k == 0:
+            with pytest.raises(FileNotFoundError):
+                trainer.validate(ckpt_path="random.ckpt")
+        else:
+            ckpt_path = str(
+                list((Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints").iterdir())[
+                    0
+                ].absolute()
+            )
+            trainer.validate(ckpt_path=ckpt_path)
+            assert trainer.evaluated_ckpt_path == ckpt_path
 
 
 def test_disabled_training(tmpdir):
@@ -1450,6 +1491,10 @@ def setup(self, model, stage):
     assert trainer.stage == "test"
     assert trainer.get_model().stage == "test"
 
+    trainer.validate(ckpt_path=None)
+    assert trainer.stage == "validation"
+    assert trainer.get_model().stage == "validation"
+
 
 @pytest.mark.parametrize(
     "train_batches, max_steps, log_interval",
diff --git a/tests/trainer/test_trainer_validate_loop.py b/tests/trainer/test_trainer_validate_loop.py
new file mode 100644
index 0000000000000..a2205a4b50dc2
--- /dev/null
+++ b/tests/trainer/test_trainer_validate_loop.py
@@ -0,0 +1,76 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+
+import pytorch_lightning as pl
+import tests.base.develop_utils as tutils
+from tests.base import EvalModelTemplate
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
+def test_single_gpu_validate(tmpdir):
+    tutils.set_random_master_port()
+
+    model = EvalModelTemplate()
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        gpus=[0],
+    )
+    trainer.fit(model)
+    assert 'ckpt' in trainer.checkpoint_callback.best_model_path
+    results = trainer.validate()
+    assert 'val_acc' in results[0]
+
+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
+    results = trainer.validate(model)
+    assert 'val_acc' in results[0]
+
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_ddp_spawn_validate(tmpdir):
+    tutils.set_random_master_port()
+
+    model = EvalModelTemplate()
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        gpus=[0, 1],
+        distributed_backend='ddp_spawn',
+    )
+    trainer.fit(model)
+    assert 'ckpt' in trainer.checkpoint_callback.best_model_path
+    results = trainer.validate()
+    assert 'val_acc' in results[0]
+
+    old_weights = model.c_d1.weight.clone().detach().cpu()
+
+    results = trainer.validate(model)
+    assert 'val_acc' in results[0]
+
+    # make sure weights didn't change
+    new_weights = model.c_d1.weight.clone().detach().cpu()
+
+    assert torch.all(torch.eq(old_weights, new_weights))