Skip to content

Commit

Permalink
Global hook implementation (aws#275)
Browse files Browse the repository at this point in the history
* Global hook implementation for PyTorch

* Remove todo

* Renaming

* Address comments

* Disable tests instead of deleting

* Abstract get_hook() and set_hook() methods for each framework

* Fix comment

* Address Rahul's comments

* Fix copy-paste import error
  • Loading branch information
jarednielsen authored and rahul003 committed Oct 14, 2019
1 parent ce31b26 commit 4a9f80b
Show file tree
Hide file tree
Showing 17 changed files with 184 additions and 22 deletions.
3 changes: 3 additions & 0 deletions tests/core/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from tornasole.core.utils import is_s3, check_dir_exists
from tornasole.core.json_config import DEFAULT_SAGEMAKER_TORNASOLE_PATH, collect_tornasole_config_params
from tornasole.core.collection_manager import CollectionManager
Expand Down Expand Up @@ -51,6 +53,7 @@ def test_check_dir_exists_no():
except Exception as e:
pass

@pytest.mark.skip(reason="If no config file is found, then SM doesn't want a TornasoleHook")
def test_collect_tornasole_config_params():
tornasole_params = collect_tornasole_config_params(collection_manager=CollectionManager())
assert(tornasole_params["out_dir"] == DEFAULT_SAGEMAKER_TORNASOLE_PATH)
2 changes: 2 additions & 0 deletions tests/mxnet/test_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
import shutil
from tornasole.core.access_layer.utils import has_training_ended
import pytest
import os
from tornasole.core.json_config import TORNASOLE_CONFIG_FILE_PATH_ENV_STR, DEFAULT_SAGEMAKER_TORNASOLE_PATH

Expand Down Expand Up @@ -39,6 +40,7 @@ def test_hook_from_json_config_full():
run_mnist_gluon_model(hook=hook, num_steps_train=10, num_steps_eval=10, register_to_loss_block=True)
shutil.rmtree(out_dir, True)

@pytest.mark.skip(reason="If no config file is found, then SM doesn't want a TornasoleHook")
def test_default_hook():
reset_collections()
shutil.rmtree('/opt/ml/output/tensors', ignore_errors=True)
Expand Down
1 change: 1 addition & 0 deletions tests/xgboost/test_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_hook_from_json_config_full(tmpdir, monkeypatch):
run_xgboost_model(hook=hook)


@pytest.mark.skip(reason="If no config file is found, then SM doesn't want a TornasoleHook")
def test_default_hook(monkeypatch):
reset_collections()
shutil.rmtree('/opt/ml/output/tensors', ignore_errors=True)
Expand Down
2 changes: 1 addition & 1 deletion tornasole/core/config_constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
TORNASOLE_CONFIG_DEFAULT_WORKER_NAME = "worker_0"
TORNASOLE_CONFIG_FILE_PATH_ENV_STR = "TORNASOLE_CONFIG_FILE_PATH"
DEFAULT_CONFIG_FILE_PATH = "/opt/ml/input/data/tornasole-config/tornasole-hook-config.json"
DEFAULT_CONFIG_FILE_PATH = "/opt/ml/input/config/debughookconfig.json"
TORNASOLE_CONFIG_REDUCTION_CONFIGS_KEY = "reduction_configs"
TORNASOLE_CONFIG_SAVE_CONFIGS_KEY = "save_configs"
TORNASOLE_CONFIG_OUTDIR_KEY = "LocalPath"
Expand Down
32 changes: 19 additions & 13 deletions tornasole/core/json_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,27 @@
TORNASOLE_CONFIG_OUTDIR_KEY, TORNASOLE_CONFIG_RDN_CFG_KEY, TORNASOLE_CONFIG_INCLUDE_REGEX_KEY, \
TORNASOLE_CONFIG_SAVE_ALL_KEY, DEFAULT_SAGEMAKER_TORNASOLE_PATH

def get_json_config_as_dict(json_config_path) -> Dict:
    """Load the hook configuration JSON file and return its contents.

    The path is resolved in this order: the explicit `json_config_path`
    argument, then the environment variable named by
    TORNASOLE_CONFIG_FILE_PATH_ENV_STR, then DEFAULT_CONFIG_FILE_PATH.

    Args:
        json_config_path: Path to the JSON config file, or None to fall back
            to the environment variable / default path.

    Returns:
        The parsed JSON content.

    Raises:
        FileNotFoundError: If no file exists at the resolved path.
    """
    if json_config_path is not None:
        path = json_config_path
    else:
        path = os.getenv(TORNASOLE_CONFIG_FILE_PATH_ENV_STR, DEFAULT_CONFIG_FILE_PATH)

    # JSON is UTF-8 by specification; do not depend on the process locale.
    with open(path, encoding="utf-8") as json_config_file:
        return json.load(json_config_file)

def create_hook_from_json_config(hook_cls, collection_manager, json_config_path):
"""Returns a TornasoleHook object corresponding to either TF, PT, or MXNet.
If json_config_path is None, the path is read from the TORNASOLE_CONFIG_FILE_PATH environment variable, falling back to the default config file path.
Here we compare HookParameters with CollectionConfiguration and set all the defaults.
"""
tornasole_params = collect_tornasole_config_params(collection_manager)
tornasole_params = collect_tornasole_config_params(collection_manager, json_config_path=json_config_path)
if "collections" in tornasole_params:
include_collections = []
for obj in tornasole_params["collections"].values():
Expand All @@ -84,23 +98,15 @@ def create_hook_from_json_config(hook_cls, collection_manager):
)


def collect_tornasole_config_params(collection_manager) -> Dict:
def collect_tornasole_config_params(collection_manager, json_config_path) -> Dict:
"""Read the config file (from json_config_path, else the environment variable or the default path) and return a dictionary.
Return a dictionary, example keys:
dict_keys(['reduction_configs', 'save_configs', 'collections', 'out_dir', 'reduction_config', 'save_config',
'include_regex', 'config_name', 's3_path'])
"""
# Build params dictionary if given a json file, otherwise leave it empty
params_dict = {}
json_config_file_path = os.getenv(TORNASOLE_CONFIG_FILE_PATH_ENV_STR, DEFAULT_CONFIG_FILE_PATH)
if os.path.exists(json_config_file_path):
with open(json_config_file_path) as json_config_file:
params_dict = json.load(json_config_file)
else:
get_logger().info(
f"json config file path {json_config_file_path} doesn't exist. Creating a default hook."
)
# Build params dictionary from the json file
params_dict = get_json_config_as_dict(json_config_path=json_config_path)

# Declare defaults
tornasole_params_dict = {
Expand Down
51 changes: 51 additions & 0 deletions tornasole/core/singleton_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Easy-to-use methods for getting the singleton TornasoleHook.
Sample usage:
import tornasole.(pytorch | tensorflow | mxnet) as ts
hook = ts.get_hook()
"""

import logging
import os

_ts_hook = None

def get_hook(json_config_path, tornasole_hook_class) -> 'TornasoleHook':
    """Return the singleton TornasoleHook, building it from a JSON config if needed.

    If a singleton hook already exists it is returned unchanged and
    `json_config_path` is ignored (an error is logged if one was passed).
    Otherwise the hook is built with `tornasole_hook_class.hook_from_config()`
    and stored as the singleton.

    Args:
        json_config_path: Path to a JSON config file, or None; forwarded to
            `hook_from_config`.
        tornasole_hook_class: Framework-specific hook class that provides a
            `hook_from_config(json_config_path=...)` constructor.

    Returns:
        The singleton hook, or None when no hook exists and
        `hook_from_config` raised FileNotFoundError.
    """
    global _ts_hook

    # Compare against None instead of relying on truthiness, so an existing
    # hook is reused even if its class defines __bool__ or __len__.
    if _ts_hook is not None:
        if json_config_path is not None:
            logging.error(
                "`json_config_path` was passed, but TornasoleHook already exists. "
                "Using the existing hook."
            )
        return _ts_hook

    try:
        set_hook(custom_hook=tornasole_hook_class.hook_from_config(json_config_path=json_config_path))
    except FileNotFoundError:
        # A missing config file is an expected situation, not a failure:
        # leave the singleton untouched and leave a trace for debugging.
        logging.debug("TornasoleHook config file not found; no hook was created from config.")

    return _ts_hook

def set_hook(custom_hook: 'TornasoleHook') -> None:
    """Install `custom_hook` as the module-wide singleton, replacing any existing hook.

    Raises TypeError if `custom_hook` is not a BaseHook instance.
    """
    global _ts_hook

    # Imported lazily: a module-level import would be circular.
    from tornasole.core.hook import BaseHook

    if isinstance(custom_hook, BaseHook):
        _ts_hook = custom_hook
        return
    raise TypeError(f"custom_hook={custom_hook} must be type TornasoleHook")


1 change: 1 addition & 0 deletions tornasole/mxnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .hook import TornasoleHook
from .mxnet_collection import Collection, CollectionManager
from .mxnet_collection import get_collections, get_collection, get_collection_manager, load_collections, add_to_collection, add_to_default_collection, reset_collections
from .singleton_utils import get_hook, set_hook
from tornasole import SaveConfig, SaveConfigMode, ReductionConfig
from tornasole import modes
9 changes: 7 additions & 2 deletions tornasole/mxnet/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tornasole.core.json_config import TORNASOLE_CONFIG_DEFAULT_WORKER_NAME, \
create_hook_from_json_config
from tornasole.mxnet.mxnet_collection import get_collection_manager
from tornasole.mxnet.singleton_utils import set_hook
from tornasole.mxnet.utils import get_reduction_of_data, make_numpy_array
# from tornasole.mxnet.graph import _net2pb

Expand Down Expand Up @@ -42,18 +43,22 @@ def __init__(self,
if CollectionKeys.LOSSES not in self.include_collections:
self.include_collections.append(CollectionKeys.LOSSES)
self.last_block = None

self.model = None
self.exported_model = False

set_hook(self)


def get_worker_name(self):
return TORNASOLE_CONFIG_DEFAULT_WORKER_NAME

def get_num_workers(self):
return 1

@classmethod
def hook_from_config(cls):
return create_hook_from_json_config(cls, get_collection_manager())
def hook_from_config(cls, json_config_path=None):
    """Build a hook of this class from a JSON config file.

    `json_config_path` is forwarded to `create_hook_from_json_config`
    together with the MXNet collection manager.
    """
    collection_manager = get_collection_manager()
    return create_hook_from_json_config(
        cls,
        collection_manager,
        json_config_path=json_config_path,
    )

def _cleanup(self):
# Write the gradients of the past step if the writer is still available.
Expand Down
16 changes: 16 additions & 0 deletions tornasole/mxnet/singleton_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Easy-to-use methods for getting the singleton TornasoleHook.
This is abstracted into its own module to prevent circular import problems.
Sample usage (in AWS-MXNet repo):
import tornasole.mxnet as ts
hook = ts.get_hook()
"""

import tornasole.core.singleton_utils as sutils
from tornasole.core.singleton_utils import set_hook

def get_hook(json_config_path=None) -> 'TornasoleHook':
    """Return the singleton hook for MXNet via the framework-agnostic core helper.

    `json_config_path` is forwarded unchanged to the core `get_hook`.
    """
    # Imported inside the function because tornasole.mxnet.hook imports this module.
    from tornasole.mxnet.hook import TornasoleHook as hook_cls

    return sutils.get_hook(
        json_config_path=json_config_path,
        tornasole_hook_class=hook_cls,
    )
1 change: 1 addition & 0 deletions tornasole/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from .collection import get_collections, get_collection, \
load_collections, \
add_to_collection, add_to_default_collection, reset_collections
from .singleton_utils import get_hook, set_hook
from tornasole import SaveConfig, SaveConfigMode, ReductionConfig
from tornasole import modes
23 changes: 21 additions & 2 deletions tornasole/pytorch/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tornasole.core.hook import CallbackHook
from tornasole.core.collection import CollectionKeys
from tornasole.pytorch.collection import get_collection_manager
from tornasole.pytorch.singleton_utils import set_hook
from tornasole.pytorch.utils import get_reduction_of_data, make_numpy_array
# from tornasole.pytorch._pytorch_graph import graph as create_graph

Expand Down Expand Up @@ -43,9 +44,13 @@ def __init__(self,
# mapping of module objects to their names,
# useful in forward hook for logging input/output of modules
self.module_maps = dict()

self.model = None
self.exported_model = False

set_hook(self)


def get_num_workers(self):
"""Check horovod and torch.distributed."""
# Try torch.distributed
Expand Down Expand Up @@ -81,8 +86,22 @@ def get_worker_name(self):
return TORNASOLE_CONFIG_DEFAULT_WORKER_NAME

@classmethod
def hook_from_config(cls):
return create_hook_from_json_config(cls, get_collection_manager())
def hook_from_config(cls, json_config_path=None):
    """Build a hook of this class from a JSON config file.

    If `json_config_path` is not None, that file is used. Otherwise the path
    is taken from the TORNASOLE_CONFIG_FILE_PATH environment variable,
    falling back to the default config file path.

    Raises:
        FileNotFoundError: If the resolved config file does not exist. This
            method does not return None in that case; callers that want a
            None result must catch the error themselves.
    """
    return create_hook_from_json_config(
        cls,
        get_collection_manager(),
        json_config_path=json_config_path
    )

def log_params(self, module):
module_name = module._get_name()
Expand Down
16 changes: 16 additions & 0 deletions tornasole/pytorch/singleton_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Easy-to-use methods for getting the singleton TornasoleHook.
This is abstracted into its own module to prevent circular import problems.
Sample usage (in AWS-PyTorch repo):
import tornasole.pytorch as ts
hook = ts.get_hook()
"""

import tornasole.core.singleton_utils as sutils
from tornasole.core.singleton_utils import set_hook

def get_hook(json_config_path=None) -> 'TornasoleHook':
    """Return the singleton hook for PyTorch via the framework-agnostic core helper.

    `json_config_path` is forwarded unchanged to the core `get_hook`.
    """
    # Imported inside the function because tornasole.pytorch.hook imports this module.
    from tornasole.pytorch.hook import TornasoleHook as hook_cls

    return sutils.get_hook(
        json_config_path=json_config_path,
        tornasole_hook_class=hook_cls,
    )
1 change: 1 addition & 0 deletions tornasole/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
add_to_collection, add_to_default_collection, reset_collections

from .optimizer import TornasoleOptimizer
from .singleton_utils import get_hook, set_hook
from tornasole import modes
from tornasole.core.collection import CollectionKeys
from tornasole import SaveConfig, SaveConfigMode, ReductionConfig
9 changes: 7 additions & 2 deletions tornasole/tensorflow/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from tornasole.core.hook import BaseHook
from tornasole.core.reductions import get_reduction_tensor_name
from tornasole.core.json_config import TORNASOLE_CONFIG_DEFAULT_WORKER_NAME, create_hook_from_json_config
from tornasole.tensorflow.singleton_utils import set_hook


DEFAULT_INCLUDE_COLLECTIONS = [CollectionKeys.WEIGHTS,
CollectionKeys.GRADIENTS,
Expand Down Expand Up @@ -82,6 +84,9 @@ def __init__(self, out_dir=None,
self.graph = None
self.tensors_to_save_this_step = None

set_hook(self)


def get_worker_name(self):
try:
import horovod.tensorflow as hvd
Expand All @@ -99,8 +104,8 @@ def get_num_workers(self):
return 1

@classmethod
def hook_from_config(cls):
return create_hook_from_json_config(cls, get_collection_manager())
def hook_from_config(cls, json_config_path=None):
    """Build a hook of this class from a JSON config file.

    `json_config_path` is forwarded to `create_hook_from_json_config`
    together with the collection manager.
    """
    collection_manager = get_collection_manager()
    return create_hook_from_json_config(
        cls,
        collection_manager,
        json_config_path=json_config_path,
    )


def _prepare_tensors(self):
Expand Down
16 changes: 16 additions & 0 deletions tornasole/tensorflow/singleton_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Easy-to-use methods for getting the singleton TornasoleHook.
This is abstracted into its own module to prevent circular import problems.
Sample usage (in AWS-TensorFlow repo):
import tornasole.tensorflow as ts
hook = ts.get_hook()
"""

import tornasole.core.singleton_utils as sutils
from tornasole.core.singleton_utils import set_hook

def get_hook(json_config_path=None) -> 'TornasoleHook':
    """Return the singleton hook for TensorFlow via the framework-agnostic core helper.

    `json_config_path` is forwarded unchanged to the core `get_hook`.
    """
    # Imported inside the function because tornasole.tensorflow.hook imports this module.
    from tornasole.tensorflow.hook import TornasoleHook as hook_cls

    return sutils.get_hook(
        json_config_path=json_config_path,
        tornasole_hook_class=hook_cls,
    )
7 changes: 5 additions & 2 deletions tornasole/xgboost/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from tornasole.core.tfevent.util import make_numpy_array
from tornasole.core.access_layer.utils import training_has_ended
from tornasole.core.json_config import create_hook_from_json_config
from tornasole.xgboost.singleton_utils import set_hook


from .collection import get_collection_manager
from .utils import validate_data_file_path, get_content_type, get_dmatrix
Expand Down Expand Up @@ -90,6 +92,7 @@ def __init__(
self.validation_data = self._validate_data(validation_data)
# as we do cleanup ourselves at end of job
atexit.unregister(self._cleanup)
set_hook(self)

def __call__(self, env: CallbackEnv) -> None:
self._callback(env)
Expand All @@ -103,8 +106,8 @@ def get_worker_name(self):
pass

@classmethod
def hook_from_config(cls):
return create_hook_from_json_config(cls, get_collection_manager())
def hook_from_config(cls, json_config_path=None):
    """Build a hook of this class from a JSON config file.

    `json_config_path` is forwarded to `create_hook_from_json_config`
    together with the XGBoost collection manager.
    """
    collection_manager = get_collection_manager()
    return create_hook_from_json_config(
        cls,
        collection_manager,
        json_config_path=json_config_path,
    )

def _cleanup(self):
# todo: this second export should go
Expand Down
16 changes: 16 additions & 0 deletions tornasole/xgboost/singleton_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Easy-to-use methods for getting the singleton TornasoleHook.
This is abstracted into its own module to prevent circular import problems.
Sample usage:
from tornasole.xgboost.singleton_utils import get_hook
hook = get_hook()
"""

import tornasole.core.singleton_utils as sutils
from tornasole.core.singleton_utils import set_hook

def get_hook(json_config_path=None) -> 'TornasoleHook':
    """Return the singleton hook for XGBoost via the framework-agnostic core helper.

    `json_config_path` is forwarded unchanged to the core `get_hook`.
    """
    # Imported inside the function because tornasole.xgboost.hook imports this module.
    from tornasole.xgboost.hook import TornasoleHook as hook_cls

    return sutils.get_hook(
        json_config_path=json_config_path,
        tornasole_hook_class=hook_cls,
    )

0 comments on commit 4a9f80b

Please sign in to comment.