Skip to content

Commit

Permalink
Update tests and CHANGELOG
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Jun 30, 2021
1 parent 5e0631d commit 3d748cb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 19 deletions.
10 changes: 4 additions & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Fault-tolerant training
* Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
* Added `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
* Added `{,load_}state_dict` to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197))


- Add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
- Added `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))


- Add `metric_attribute` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))
- Added `metric_attribute` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))


- Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))
Expand Down Expand Up @@ -123,9 +124,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added XLA Profiler ([#8014](https://github.com/PyTorchLightning/pytorch-lightning/pull/8014))


- Added `state_dict` and `load_state_dict` function to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197))


- Added `should_raise_exception` parameter to `parse_gpu_ids`, `parse_tpu_cores` and `_sanitize_gpu_ids` utility functions ([#8194](https://github.com/PyTorchLightning/pytorch-lightning/pull/8194))


Expand Down
22 changes: 9 additions & 13 deletions tests/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import ANY

import pytest

Expand All @@ -21,31 +20,28 @@


def test_loops_state_dict_structure():

fit_loop = FitLoop()
state_dict = fit_loop.state_dict()
with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
fit_loop.connect(object())
fit_loop.connect(object()) # noqa

fit_loop.connect(Trainer())
state_dict = fit_loop.state_dict()
expected = {'epoch_loop': {'batch_loop': ANY, 'val_loop': ANY}}
assert state_dict == expected

fit_loop.load_state_dict(state_dict)
new_fit_loop = FitLoop()
new_fit_loop.load_state_dict(state_dict)
assert fit_loop.state_dict() == new_fit_loop.state_dict()


def test_loops_state_dict_structure_with_trainer():

trainer = Trainer()
state_dict = trainer.get_loops_state_dict()
expected = {
"fit_loop": {
'epoch_loop': {
'batch_loop': ANY,
'val_loop': ANY
'batch_loop': {},
'val_loop': {},
}
},
"validate_loop": ANY,
"test_loop": ANY
"validate_loop": {},
"test_loop": {},
}
assert state_dict == expected

0 comments on commit 3d748cb

Please sign in to comment.