Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove legacy Result parameters #6016

Merged
merged 20 commits into from
Mar 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))


- Removed legacy references for magic keys in the `Result` object ([#6016](https://github.com/PyTorchLightning/pytorch-lightning/pull/6016))


- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))


Expand Down
1 change: 0 additions & 1 deletion docs/source/common/lightning_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,6 @@ For cases like production, you might want to iterate different models inside a L
loss = F.cross_entropy(y_hat, y)
acc = FM.accuracy(y_hat, y)

# loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
metrics = {'val_acc': acc, 'val_loss': loss}
self.log_dict(metrics)
return metrics
Expand Down
8 changes: 1 addition & 7 deletions docs/source/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1478,15 +1478,9 @@ with the hidden
def training_step(self, batch, batch_idx, hiddens):
# hiddens are the hiddens from the previous truncated backprop step
out, hiddens = self.lstm(data, hiddens)

# remember to detach() hiddens.
# If you don't, you will get a RuntimeError: Trying to backward through
# the graph a second time...
# Using hiddens.detach() allows each split to be disconnected.

return {
"loss": ...,
"hiddens": hiddens # remember to detach() this
"hiddens": hiddens
}

To modify how the batch is split,
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def __init__(
self.wait_count = 0
self.stopped_epoch = 0
self.mode = mode
self.warned_result_obj = False

if self.mode not in self.mode_dict:
raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def __init__(
self.best_model_path = ""
self.last_model_path = ""
self.save_function = None
self.warned_result_obj = False

self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
Expand Down
79 changes: 15 additions & 64 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""[Train, Eval]Result for easier logging, checkpointing, early stopping, epoch-wise reduction."""
"""Result class for easier logging and epoch-wise reduction."""

import numbers
import os
from copy import copy
from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union

Expand All @@ -27,33 +26,14 @@

class Result(Dict):

def __init__(
self,
minimize: Optional[Tensor] = None,
early_stop_on: Optional[Tensor] = None,
checkpoint_on: Optional[Union[Tensor, bool]] = None,
hiddens: Optional[Tensor] = None,
):

def __init__(self, minimize: Optional[Tensor] = None):
super().__init__()

# temporary until dict results are deprecated
os.environ['PL_USING_RESULT_OBJ'] = '1'

if early_stop_on is not None:
self.early_stop_on = early_stop_on
if checkpoint_on is not None and checkpoint_on:
self.checkpoint_on = checkpoint_on
if hiddens is not None:
self.hiddens = hiddens.detach()
if minimize is not None:
err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end'
self._assert_grad_tensor_metric('minimize', minimize, err)
self.minimize = minimize

if minimize is not None and checkpoint_on is None:
self.checkpoint_on = minimize.detach()

self['meta'] = {'_internal': {'_reduce_on_epoch': False, 'batch_sizes': []}}

def __getitem__(self, key: Union[str, Any]) -> Any:
Expand All @@ -64,9 +44,7 @@ def __getitem__(self, key: Union[str, Any]) -> Any:

def __getattr__(self, key: str) -> Any:
try:
if key == 'callback_metrics':
return self.get_callback_metrics()
elif key == 'batch_log_metrics':
if key == 'batch_log_metrics':
return self.get_batch_log_metrics()
elif key == 'batch_pbar_metrics':
return self.get_batch_pbar_metrics()
Expand All @@ -80,16 +58,9 @@ def __getattr__(self, key: str) -> Any:
return None

def __setattr__(self, key: str, val: Union[Tensor, Any]):
# ensure reserve keys are tensors and detached
if key in {'checkpoint_on', 'early_stop_on'}:
self._assert_tensor_metric(key, val)
if val is not None and isinstance(val, torch.Tensor):
val = val.detach()

# ensure anything else that is a tensor is detached
elif isinstance(val, torch.Tensor) and key != 'minimize':
# ensure tensors are detached
if isinstance(val, torch.Tensor) and key != 'minimize':
val = val.detach()

self[key] = val

def __getstate__(self):
Expand All @@ -98,11 +69,6 @@ def __getstate__(self):
def __setstate__(self, d):
self.update(d)

def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
if potential_metric is not None and not isinstance(potential_metric, bool):
if not isinstance(potential_metric, Tensor):
raise TypeError(f'{name} must be a torch.Tensor')

def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''):
if x is not None:
if not isinstance(x, Tensor):
Expand Down Expand Up @@ -272,11 +238,6 @@ def get_batch_sizes(self):
meta = self['meta']
return torch.tensor(meta['_internal']['batch_sizes'])

def get_callback_metrics(self) -> dict:
result = {'early_stop_on': self.early_stop_on, 'checkpoint_on': self.checkpoint_on}

return result

def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str:
if dataloader_idx is not None and add_dataloader_idx:
return f"{k}/dataloader_idx_{dataloader_idx}"
Expand Down Expand Up @@ -495,25 +456,22 @@ def padded_gather(cls, outputs):
# find the padding used for other values
default_padding_idx = 0
for name, value in result.items():
if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}:
default_padding_idx = meta[name]['tbptt_pad_token']
break
if (
name != 'minimize' and isinstance(value, list) and len(value) > 0
and isinstance(value[0], torch.Tensor)
):
default_padding_idx = meta[name]['tbptt_pad_token']
break

# pad across each key individually
for name, value in result.items():
is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'}
if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):

if is_reserved:
padding_key = default_padding_idx
else:
padding_key = meta[name]['tbptt_pad_token']
if (isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor)):
padding_key = default_padding_idx if name == 'minimize' else meta[name]['tbptt_pad_token']
padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
result[name] = padded

# also update the result
if meta and not is_reserved:
if meta and name != "minimize":
meta[name]['value'] = padded
if meta:
result['meta'] = meta
Expand Down Expand Up @@ -581,10 +539,7 @@ def reduce_across_time(cls, time_outputs):
continue

# pick the reduce fx
if k in ['checkpoint_on', 'early_stop_on', 'minimize']:
tbptt_reduce_fx = torch.mean
else:
tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
tbptt_reduce_fx = torch.mean if k == "minimize" else meta[k]['tbptt_reduce_fx']

if isinstance(value, list):
value = torch.tensor(value)
Expand Down Expand Up @@ -612,10 +567,6 @@ def dp_reduce(self):
def should_reduce_on_epoch_end(self) -> bool:
return self['meta']['_internal']['_reduce_on_epoch']

def drop_hiddens(self):
if 'hiddens' in self:
del self['hiddens']

def rename_keys(self, map_dict: dict):
"""
Maps key values to the target values. Useful when renaming variables in mass.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,6 @@ def _track_callback_metrics(self, eval_results):
elif isinstance(eval_result, dict):
flat = flatten_dict(eval_result)

# removing val_loss magic word to map to checkpoint + ES callback
if 'val_loss' in flat:
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']
self.trainer.logger_connector.callback_metrics.update(flat)
if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
Expand All @@ -331,11 +327,6 @@ def _track_callback_metrics(self, eval_results):
else:
flat = flatten_dict(eval_results)

# removing val_loss magic word to map to checkpoint + ES callback
if 'val_loss' in flat:
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']

self.trainer.logger_connector.callback_metrics.update(flat)
if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
Expand Down Expand Up @@ -370,26 +361,13 @@ def on_train_epoch_end(self):
# inform cached logger connector epoch finished
self.cached_results.has_batch_loop_finished = True

def log_train_epoch_end_metrics(
self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers
):
def log_train_epoch_end_metrics(self, epoch_output, num_optimizers):
# epoch output is a list. Each item in that list has all the outputs per optimizer
# epoch_output[optimizer_idx][training_step_idx][tbptt_index]
# remember that not using truncated backprop is equivalent with truncated back prop of len(1)

model = self.trainer.lightning_module

epoch_callback_metrics = {}

# -----------------------
# Calculate epoch callback values if given
# -----------------------
if checkpoint_accumulator.num_values > 0:
epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()

if early_stopping_accumulator.num_values > 0:
epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()

# ------------------------
# determine if using a result obj
# ------------------------
Expand Down Expand Up @@ -437,9 +415,6 @@ def log_train_epoch_end_metrics(
self.log_metrics(epoch_log_metrics, {})
self._callback_metrics.update(epoch_log_metrics)

# add metrics to callbacks
self._callback_metrics.update(epoch_callback_metrics)

# add metrics to progress_bar and callbacks
if len(epoch_progress_bar_metrics) > 0:
self.add_progress_bar_metrics(epoch_progress_bar_metrics)
Expand Down
9 changes: 0 additions & 9 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,6 @@ def __gather_epoch_end_eval_results(self, outputs):
eval_results = []
for epoch_output in outputs:
result = epoch_output[0].__class__.gather(epoch_output)
if 'checkpoint_on' in result:
result.checkpoint_on = result.checkpoint_on.mean()
if 'early_stop_on' in result:
result.early_stop_on = result.early_stop_on.mean()

eval_results.append(result)

# with 1 dataloader don't pass in a list
Expand All @@ -269,10 +264,6 @@ def __auto_reduce_result_objs(self, outputs):
for dl_output in outputs:
result = dl_output[0]
result = result.__class__.reduce_on_epoch_end(dl_output)
if 'checkpoint_on' in result:
result.checkpoint_on = result.checkpoint_on.mean()
if 'early_stop_on' in result:
result.early_stop_on = result.early_stop_on.mean()
eval_results.append(result)

return eval_results
Expand Down
9 changes: 4 additions & 5 deletions pytorch_lightning/trainer/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import inspect
from abc import ABC
from typing import Mapping
from collections import Mapping

import torch

Expand Down Expand Up @@ -76,10 +76,7 @@ def process_dict_result(self, output, train=False):
# --------------------------
# single scalar returned from a xx_step
if isinstance(output, torch.Tensor):
progress_bar_metrics = {}
log_metrics = {}
hiddens = None
return output, progress_bar_metrics, log_metrics, hiddens
return output, {}, {}, None

# ---------------
# EXTRACT PROGRESS BAR KEYS
Expand Down Expand Up @@ -140,6 +137,8 @@ def process_dict_result(self, output, train=False):
# EXTRACT HIDDEN
# ---------------
hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None
if hiddens is not None:
hiddens = hiddens.detach()

# detach all metrics for callbacks to prevent memory leaks
# no .item() because it will slow things down
Expand Down
15 changes: 0 additions & 15 deletions pytorch_lightning/trainer/supporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,6 @@ def _agg_memory(self, how: str):
return getattr(self.memory[:self.current_idx], how)()


class Accumulator(object):
carmocca marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self):
self.num_values = 0
self.total = 0

def accumulate(self, x):
with torch.no_grad():
self.total += x
self.num_values += 1

def mean(self):
return self.total / self.num_values


class PredictionCollection(object):

def __init__(self, global_rank: int, world_size: int):
Expand Down
Loading