forked from aws/amazon-sagemaker-examples
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Global hook implementation (aws#275)
* Global hook implementation for PyTorch * Remove todo * Renaming * Address comments * Disable tests instead of deleting * Abstract get_hook() and set_hook() methods for each framework * Fix comment * Address Rahul's comments * Fix copy-paste import error
- Loading branch information
1 parent
ce31b26
commit 4a9f80b
Showing
17 changed files
with
184 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
""" | ||
Easy-to-use methods for getting the singleton TornasoleHook. | ||
Sample usage: | ||
import tornasole.(pytorch | tensorflow | mxnet) as ts | ||
hook = ts.hook() | ||
""" | ||
|
||
import logging | ||
import os | ||
|
||
_ts_hook = None | ||
|
||
def get_hook(json_config_path, tornasole_hook_class) -> 'TornasoleHook': | ||
"""Return a singleton TornasoleHook or None. | ||
If the singleton hook exists, we return it. No questions asked, `json_config_path` is a no-op. | ||
Otherwise return hook_from_config(). | ||
""" | ||
global _ts_hook | ||
|
||
# If global hook exists, return it | ||
if _ts_hook: | ||
if json_config_path is not None: | ||
logging.error( | ||
f"`json_config_path` was passed, but TornasoleHook already exists. " | ||
f"Using the existing hook." | ||
) | ||
return _ts_hook | ||
# Otherwise return hook_from_config | ||
else: | ||
# Either returns a hook or None | ||
try: | ||
set_hook(custom_hook=tornasole_hook_class.hook_from_config(json_config_path=json_config_path)) | ||
except FileNotFoundError: | ||
pass | ||
|
||
return _ts_hook | ||
|
||
def set_hook(custom_hook: 'TornasoleHook') -> None: | ||
"""Overwrite the current hook with the passed hook.""" | ||
from tornasole.core.hook import BaseHook # prevent circular imports | ||
|
||
if not isinstance(custom_hook, BaseHook): | ||
raise TypeError(f"custom_hook={custom_hook} must be type TornasoleHook") | ||
|
||
global _ts_hook | ||
_ts_hook = custom_hook | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .hook import TornasoleHook | ||
from .mxnet_collection import Collection, CollectionManager | ||
from .mxnet_collection import get_collections, get_collection, get_collection_manager, load_collections, add_to_collection, add_to_default_collection, reset_collections | ||
from .singleton_utils import get_hook, set_hook | ||
from tornasole import SaveConfig, SaveConfigMode, ReductionConfig | ||
from tornasole import modes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
Easy-to-use methods for getting the singleton TornasoleHook. | ||
This is abstracted into its own module to prevent circular import problems. | ||
Sample usage (in AWS-MXNet repo): | ||
import tornasole.mxnet as ts | ||
hook = ts.hook() | ||
""" | ||
|
||
import tornasole.core.singleton_utils as sutils | ||
from tornasole.core.singleton_utils import set_hook | ||
|
||
def get_hook(json_config_path=None) -> 'TornasoleHook': | ||
from tornasole.mxnet.hook import TornasoleHook | ||
return sutils.get_hook(json_config_path=json_config_path, tornasole_hook_class=TornasoleHook) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
Easy-to-use methods for getting the singleton TornasoleHook. | ||
This is abstracted into its own module to prevent circular import problems. | ||
Sample usage (in AWS-PyTorch repo): | ||
import tornasole.pytorch as ts | ||
hook = ts.hook() | ||
""" | ||
|
||
import tornasole.core.singleton_utils as sutils | ||
from tornasole.core.singleton_utils import set_hook | ||
|
||
def get_hook(json_config_path=None) -> 'TornasoleHook': | ||
from tornasole.pytorch.hook import TornasoleHook | ||
return sutils.get_hook(json_config_path=json_config_path, tornasole_hook_class=TornasoleHook) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
Easy-to-use methods for getting the singleton TornasoleHook. | ||
This is abstracted into its own module to prevent circular import problems. | ||
Sample usage (in AWS-TensorFlow repo): | ||
import tornasole.tensorflow as ts | ||
hook = ts.hook() | ||
""" | ||
|
||
import tornasole.core.singleton_utils as sutils | ||
from tornasole.core.singleton_utils import set_hook | ||
|
||
def get_hook(json_config_path=None) -> 'TornasoleHook': | ||
from tornasole.tensorflow.hook import TornasoleHook | ||
return sutils.get_hook(json_config_path=json_config_path, tornasole_hook_class=TornasoleHook) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
Easy-to-use methods for getting the singleton TornasoleHook. | ||
This is abstracted into its own module to prevent circular import problems. | ||
Sample usage: | ||
import tornasole.xgboost as ts | ||
hook = ts.hook() | ||
""" | ||
|
||
import tornasole.core.singleton_utils as sutils | ||
from tornasole.core.singleton_utils import set_hook | ||
|
||
def get_hook(json_config_path=None) -> 'TornasoleHook': | ||
from tornasole.xgboost.hook import TornasoleHook | ||
return sutils.get_hook(json_config_path=json_config_path, tornasole_hook_class=TornasoleHook) |