diff --git a/botocore/__init__.py b/botocore/__init__.py index bbc83ce369..7fa5fc601c 100644 --- a/botocore/__init__.py +++ b/botocore/__init__.py @@ -28,6 +28,7 @@ def emit(self, record): log = logging.getLogger('botocore') log.addHandler(NullHandler()) +_INITIALIZERS = [] _first_cap_regex = re.compile('(.)([A-Z][a-z]+)') _end_cap_regex = re.compile('([a-z0-9])([A-Z])') @@ -97,3 +98,42 @@ def xform_name(name, sep='_', _xform_cache=_xform_cache): transformed = _end_cap_regex.sub(r'\1' + sep + r'\2', s1).lower() _xform_cache[key] = transformed return _xform_cache[key] + + +def register_initializer(callback): + """Register an initializer function for session creation. + + This initializer function will be invoked whenever a new + `botocore.session.Session` is instantiated. + + :type callback: callable + :param callback: A callable that accepts a single argument + of type `botocore.session.Session`. + + """ + _INITIALIZERS.append(callback) + + +def unregister_initializer(callback): + """Unregister an initializer function. + + :type callback: callable + :param callback: A callable that was previously registered + with `botocore.register_initializer`. + + :raises ValueError: If a callback is provided that is not currently + registered as an initializer. + + """ + _INITIALIZERS.remove(callback) + + +def invoke_initializers(session): + """Invoke all initializers for a session. + + :type session: botocore.session.Session + :param session: The session to initialize. + + """ + for initializer in _INITIALIZERS: + initializer(session) diff --git a/botocore/session.py b/botocore/session.py index 729c2b0e94..856157ab73 100644 --- a/botocore/session.py +++ b/botocore/session.py @@ -30,6 +30,7 @@ UNSIGNED, __version__, handlers, + invoke_initializers, monitoring, paginate, retryhandler, @@ -148,6 +149,7 @@ def __init__( self.session_var_map = SessionVarDict(self, self.SESSION_VARIABLES) if session_vars is not None: self.session_var_map.update(session_vars) + invoke_initializers(self) def _register_components(self): self._register_credential_provider() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1c90ce9f64..153098ccbd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -23,7 +23,12 @@ import botocore.exceptions import botocore.loaders import botocore.session -from botocore import UNSIGNED, client +from botocore import ( + UNSIGNED, + client, + register_initializer, + unregister_initializer, +) from botocore.configprovider import ConfigChainFactory from botocore.hooks import HierarchicalEmitter from botocore.model import ServiceModel @@ -993,3 +998,32 @@ def test_new_session_with_none_region(self): s3_client = self.session.create_client('s3', region_name=None) self.assertIsInstance(s3_client, client.BaseClient) self.assertTrue(s3_client.meta.region_name is not None) + + +class TestInitializationHooks(BaseSessionTest): + def test_can_register_init_hook(self): + call_args = [] + + def init_hook(session): + call_args.append(session) + + register_initializer(init_hook) + self.addCleanup(unregister_initializer, init_hook) + session = create_session() + self.assertEqual(call_args, [session]) + + def test_can_unregister_hook(self): + call_args = [] + + def init_hook(session): + call_args.append(session) + + register_initializer(init_hook) + unregister_initializer(init_hook) + create_session() + self.assertEqual(call_args, []) + + def test_unregister_hook_raises_value_error(self): + not_registered = lambda session: None + with self.assertRaises(ValueError): + self.assertRaises(unregister_initializer(not_registered))