Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLI] Shorthand notation to instantiate callbacks [3/3] #8815

Merged
merged 88 commits into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
d7f00be
add registries
tchaton Aug 9, 2021
a07d305
simplify LightningCLI with defaults
tchaton Aug 9, 2021
ce39c47
cleanup
tchaton Aug 9, 2021
3081475
update
tchaton Aug 9, 2021
51f82d5
updates
tchaton Aug 10, 2021
7197d6e
cleanup
tchaton Aug 10, 2021
9a6e81e
update on comments
tchaton Aug 10, 2021
41f5d78
update
tchaton Aug 10, 2021
06e4999
cleanup
tchaton Aug 10, 2021
e91ea47
update on comments
tchaton Aug 10, 2021
3f35ecd
Merge branch 'master' into lightning_cli_registries
tchaton Aug 10, 2021
705c0bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2021
78a4398
add docs
tchaton Aug 10, 2021
e96dc28
doc updates
tchaton Aug 10, 2021
631aa72
update
tchaton Aug 10, 2021
2fc3c0a
update
tchaton Aug 10, 2021
43dd8b4
resolve comments
tchaton Aug 10, 2021
c6ae669
comment
tchaton Aug 10, 2021
5c21b1c
add comment
tchaton Aug 10, 2021
b370deb
typo
tchaton Aug 10, 2021
f8e7ca7
update on comments
tchaton Aug 11, 2021
e428d2f
resolve bug
tchaton Aug 11, 2021
3e97905
typo
tchaton Aug 11, 2021
3b1bdb6
update
tchaton Aug 11, 2021
0d1db29
resolve comments
tchaton Aug 11, 2021
3d35c82
add unittesting
tchaton Aug 11, 2021
4c0f960
resolve tests
tchaton Aug 11, 2021
d3a62ca
resolve comments
tchaton Aug 12, 2021
39781a1
update on comments
tchaton Aug 13, 2021
68c03de
doc updates
tchaton Aug 13, 2021
b01828b
update
tchaton Aug 13, 2021
d213c73
Merge branch 'master' into lightning_cli_registries
tchaton Aug 13, 2021
5935ec4
update on comments
tchaton Aug 17, 2021
b6616f0
Merge branch 'lightning_cli_registries' of https://github.com/PyTorch…
tchaton Aug 17, 2021
0d89423
Merge branch 'master' into lightning_cli_registries
carmocca Aug 19, 2021
37fd679
Fix mypy
carmocca Aug 19, 2021
f16db3d
Revert unrelated change which had broken mypy
carmocca Aug 19, 2021
572488c
Convert to staticmethod
carmocca Aug 19, 2021
2fc4608
Replace context managers for functional static transformations
carmocca Aug 19, 2021
9f383dc
Split tests
carmocca Aug 19, 2021
2a7dfa8
Refactor optimizer tests
carmocca Aug 19, 2021
423ab7b
Cleaning tests
carmocca Aug 19, 2021
7c2e39e
Delete broken test
carmocca Aug 19, 2021
048e159
Docs improvements
carmocca Aug 19, 2021
86fce55
Docs improvements
carmocca Aug 19, 2021
624b0d8
Restructure docs
carmocca Aug 19, 2021
2cc0dc5
Docs for callbacks
carmocca Aug 19, 2021
f9b49fe
Add reload test when add_optimizer_args is added by the user
carmocca Aug 19, 2021
afcc4ba
Add failing config test - needs to be fixed
carmocca Aug 19, 2021
9f41b88
Merge branch 'master' into lightning_cli_registries
carmocca Aug 28, 2021
0ed4ae8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2021
4dd0732
Use property
carmocca Aug 19, 2021
e0fae4f
Fixes after merge
carmocca Aug 28, 2021
4f053bb
Merge branch 'master' into lightning_cli_registries
carmocca Sep 15, 2021
a22fdb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2021
160b3f6
Update jsonargparse version
carmocca Sep 15, 2021
f185c2d
Use properties in registry
carmocca Sep 15, 2021
803385c
Keep hacks together
carmocca Sep 15, 2021
8eb8b05
Add FIXMEs
carmocca Sep 15, 2021
9d84127
add_class_choices
carmocca Sep 15, 2021
33ff2f4
Merge branch 'master' into lightning_cli_registries
carmocca Sep 15, 2021
cf82e1a
Remove contains registry. Avoid nested_key clash for optimizers and l…
carmocca Sep 15, 2021
b1cd083
Remove sanitize argv
carmocca Sep 15, 2021
95d31a7
Better support for new callback format
carmocca Sep 16, 2021
231e0ed
Avoid evaluating
carmocca Sep 16, 2021
2af596f
Minor cleaning
carmocca Sep 16, 2021
6add619
Mark argv as private
carmocca Sep 16, 2021
525358a
Fix mypy
carmocca Sep 16, 2021
84b8120
Fix mypy
carmocca Sep 16, 2021
7e48c0e
Fix mypy
carmocca Sep 16, 2021
40ce3c7
Merge branch 'master' into lightning_cli_registries
carmocca Sep 16, 2021
3e77e8e
Support shorthand notation to instantiate optimizers and learning rat…
carmocca Sep 16, 2021
1512a80
Update CHANGELOG
carmocca Sep 16, 2021
c6b86b1
Fix install
carmocca Sep 16, 2021
6f1600c
Fix install
carmocca Sep 16, 2021
a3a791f
Use release
carmocca Sep 16, 2021
f67a90f
Merge branch 'feat/cli-shorthand-optimizers' into lightning_cli_regis…
carmocca Sep 16, 2021
fedae46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 16, 2021
ee7a068
Introduce set_choices
carmocca Sep 16, 2021
6e67617
Undo change
carmocca Sep 16, 2021
e7f6d61
Replace add_class_choices with set_choices
carmocca Sep 16, 2021
8e87359
Replace add_class_choices with set_choices
carmocca Sep 16, 2021
c74426b
Merge
carmocca Sep 16, 2021
66cdb52
Docstrings
carmocca Sep 16, 2021
9217304
Merge branch 'master' into lightning_cli_registries
carmocca Sep 17, 2021
7b50401
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2021
1406be9
Fix mypy
carmocca Sep 17, 2021
a000446
Undo change
carmocca Sep 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update on comments
  • Loading branch information
tchaton committed Aug 10, 2021
commit 9a6e81e56003c14866ab78ccc16c9edf644f6ba5
16 changes: 10 additions & 6 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,13 @@ def link_optimizers_and_lr_schedulers(self) -> None:
if any(
True for v in sys.argv for optim_name in OPTIMIZER_REGISTRIES.keys() if f"--optimizer={optim_name}" in v
):
self.parser.add_optimizer_args(self.optimizer_registered)
if "optimizer" not in self.parser.groups:
self.parser.add_optimizer_args(self.optimizer_registered)

if any(True for v in sys.argv for sch_name in SCHEDULER_REGISTRIES.keys() if f"-lr_scheduler={sch_name}" in v):
if any(True for v in sys.argv for sch_name in SCHEDULER_REGISTRIES.keys() if f"--lr_scheduler={sch_name}" in v):
lr_schdulers = tuple(v for v in SCHEDULER_REGISTRIES.values())
self.parser.add_lr_scheduler_args(lr_schdulers)
if "lr_scheduler" not in self.parser.groups:
self.parser.add_lr_scheduler_args(lr_schdulers)

for key, (class_type, link_to) in self.parser.optimizers_and_lr_schedulers.items():
if link_to == "AUTOMATIC":
Expand Down Expand Up @@ -471,10 +473,12 @@ def prepare_class_list_from_registry(self, pattern: str, registry: Registry):

def parse_arguments(self, parser: LightningArgumentParser) -> None:
"""Parses command line arguments and stores it in ``self.config``."""
with self.prepare_from_registry(OPTIMIZER_REGISTRIES), self.prepare_from_registry(
SCHEDULER_REGISTRIES
), self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRIES):
# fmt: off
with self.prepare_from_registry(OPTIMIZER_REGISTRIES), \
self.prepare_from_registry(SCHEDULER_REGISTRIES), \
self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRIES):
self.config = parser.parse_args()
# fmt: on

def before_instantiate_classes(self) -> None:
"""Implement to run some code before instantiating the classes."""
Expand Down
14 changes: 8 additions & 6 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,9 @@ def add_arguments_to_parser(self, parser):
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir):
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1")
parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2")
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler")
parser.add_optimizer_args(self.optimizer_registered, nested_key="optim1", link_to="model.optim1")
parser.add_optimizer_args(torch.optim.SGD, nested_key="optim2", link_to="model.optim2")
parser.add_lr_scheduler_args(self.lr_scheduler_registered, link_to="model.scheduler")

class TestModel(BoringModel):
def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
Expand All @@ -677,9 +677,11 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
cli_args = [
f"--trainer.default_root_dir={tmpdir}",
"--trainer.max_epochs=1",
"--optim2.class_path=torch.optim.SGD",
"--optim2.init_args.lr=0.01",
"--lr_scheduler.gamma=0.2",
"--optim1=Adam",
"--optim1.weight_decay=0.001",
"--optim2.lr=0.005",
"--lr_scheduler=ExponentialLR",
"--lr_scheduler.gamma=0.1",
]

with mock.patch("sys.argv", ["any.py"] + cli_args):
Expand Down