diff --git a/example/plugins/microservices/disco_to_target_issuer.yaml.example b/example/plugins/microservices/disco_to_target_issuer.yaml.example new file mode 100644 index 000000000..5d5d0100c --- /dev/null +++ b/example/plugins/microservices/disco_to_target_issuer.yaml.example @@ -0,0 +1,6 @@ +module: satosa.micro_services.disco.DiscoToTargetIssuer +name: DiscoToTargetIssuer +config: + # the regex that will intercept http requests to be handled with this microservice + disco_endpoints: + - ".*/disco" diff --git a/example/plugins/microservices/target_based_routing.yaml.example b/example/plugins/microservices/target_based_routing.yaml.example new file mode 100644 index 000000000..55e699c53 --- /dev/null +++ b/example/plugins/microservices/target_based_routing.yaml.example @@ -0,0 +1,8 @@ +module: satosa.micro_services.custom_routing.DecideBackendByTargetIssuer +name: TargetRouter +config: + default_backend: Saml2 + + target_mapping: + "http://idpspid.testunical.it:8088": "spidSaml2" # map SAML entity with entity id 'target_id' to backend name + "http://eidas.testunical.it:8081/saml2/metadata": "eidasSaml2" diff --git a/src/satosa/micro_services/custom_routing.py b/src/satosa/micro_services/custom_routing.py index d903502be..541b824f1 100644 --- a/src/satosa/micro_services/custom_routing.py +++ b/src/satosa/micro_services/custom_routing.py @@ -2,6 +2,8 @@ from base64 import urlsafe_b64encode from satosa.context import Context +from satosa.internal import InternalData + from .base import RequestMicroService from ..exception import SATOSAConfigurationError from ..exception import SATOSAError @@ -10,6 +12,52 @@ logger = logging.getLogger(__name__) +class CustomRoutingError(SATOSAError): + """SATOSA exception raised by CustomRouting rules""" + pass + + +class DecideBackendByTargetIssuer(RequestMicroService): + """ + Select target backend based on the target issuer. + """ + + def __init__(self, config:dict, *args, **kwargs): + """ + Constructor. + + :param config: microservice configuration loaded from yaml file + :type config: Dict[str, Dict[str, str]] + """ + super().__init__(*args, **kwargs) + + self.target_mapping = config['target_mapping'] + self.default_backend = config['default_backend'] + + def process(self, context:Context, data:InternalData): + """Set context.target_backend based on the target issuer""" + + target_issuer = context.get_decoration(Context.KEY_TARGET_ENTITYID) + if not target_issuer: + logger.info('skipping backend decision because no target_issuer was found') + return super().process(context, data) + + target_backend = ( + self.target_mapping.get(target_issuer) + or self.default_backend + ) + + report = { + 'msg': 'decided target backend by target issuer', + 'target_issuer': target_issuer, + 'target_backend': target_backend, + } + logger.info(report) + + context.target_backend = target_backend + return super().process(context, data) + + class DecideBackendByRequester(RequestMicroService): """ Select which backend should be used based on who the requester is. diff --git a/src/satosa/micro_services/disco.py b/src/satosa/micro_services/disco.py new file mode 100644 index 000000000..274f18780 --- /dev/null +++ b/src/satosa/micro_services/disco.py @@ -0,0 +1,58 @@ +from satosa.context import Context +from satosa.internal import InternalData + +from .base import RequestMicroService +from ..exception import SATOSAError + + +class DiscoToTargetIssuerError(SATOSAError): + """SATOSA exception raised by CustomRouting rules""" + + +class DiscoToTargetIssuer(RequestMicroService): + def __init__(self, config:dict, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.disco_endpoints = config['disco_endpoints'] + if not isinstance(self.disco_endpoints, list) or not self.disco_endpoints: + raise DiscoToTargetIssuerError('disco_endpoints must be a list of str') + + def process(self, context:Context, data:InternalData): + context.state[self.name] = { + 'target_frontend': context.target_frontend, + 'internal_data': data.to_dict(), + } + return super().process(context, data) + + def register_endpoints(self): + """ + URL mapping of additional endpoints this micro service needs to register for callbacks. + + Example of a mapping from the url path '/callback' to the callback() method of a micro service: + reg_endp = [ + ('^/callback1$', self.callback), + ] + + :rtype List[Tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]]] + + :return: A list with functions and args bound to a specific endpoint url, + [(regexp, Callable[[satosa.context.Context], satosa.response.Response]), ...] + """ + + return [ + (path , self._handle_disco_response) + for path in self.disco_endpoints + ] + + def _handle_disco_response(self, context:Context): + target_issuer = context.request.get('entityID') + if not target_issuer: + raise DiscoToTargetIssuerError('no valid entity_id in the disco response') + + target_frontend = context.state.get(self.name, {}).get('target_frontend') + data_serialized = context.state.get(self.name, {}).get('internal_data', {}) + data = InternalData.from_dict(data_serialized) + + context.target_frontend = target_frontend + context.decorate(Context.KEY_TARGET_ENTITYID, target_issuer) + return super().process(context, data) diff --git a/tests/satosa/micro_services/test_custom_routing.py b/tests/satosa/micro_services/test_custom_routing.py index 7a5227250..d2022bc3e 100644 --- a/tests/satosa/micro_services/test_custom_routing.py +++ b/tests/satosa/micro_services/test_custom_routing.py @@ -1,11 +1,16 @@ from base64 import urlsafe_b64encode +from unittest import TestCase import pytest from satosa.context import Context -from satosa.exception import SATOSAError, SATOSAConfigurationError +from satosa.state import State +from satosa.exception import SATOSAError, SATOSAConfigurationError, SATOSAStateError from satosa.internal import InternalData from satosa.micro_services.custom_routing import DecideIfRequesterIsAllowed +from satosa.micro_services.custom_routing import DecideBackendByTargetIssuer +from satosa.micro_services.custom_routing import CustomRoutingError + TARGET_ENTITY = "entity1" @@ -156,3 +161,45 @@ def test_missing_target_entity_id_from_context(self, context): req = InternalData(requester="test_requester") with pytest.raises(SATOSAError): decide_service.process(context, req) + + +class TestDecideBackendByTargetIssuer(TestCase): + def setUp(self): + context = Context() + context.state = State() + + config = { + 'default_backend': 'default_backend', + 'target_mapping': { + 'mapped_idp.example.org': 'mapped_backend', + }, + } + + plugin = DecideBackendByTargetIssuer( + config=config, + name='test_decide_service', + base_url='https://satosa.example.org', + ) + plugin.next = lambda ctx, data: (ctx, data) + + self.config = config + self.context = context + self.plugin = plugin + + def test_when_target_is_not_set_do_skip(self): + data = InternalData(requester='test_requester') + newctx, newdata = self.plugin.process(self.context, data) + assert not newctx.target_backend + + def test_when_target_is_not_mapped_choose_default_backend(self): + self.context.decorate(Context.KEY_TARGET_ENTITYID, 'idp.example.org') + data = InternalData(requester='test_requester') + newctx, newdata = self.plugin.process(self.context, data) + assert newctx.target_backend == 'default_backend' + + def test_when_target_is_mapped_choose_mapping_backend(self): + self.context.decorate(Context.KEY_TARGET_ENTITYID, 'mapped_idp.example.org') + data = InternalData(requester='test_requester') + data.requester = 'somebody else' + newctx, newdata = self.plugin.process(self.context, data) + assert newctx.target_backend == 'mapped_backend' diff --git a/tests/satosa/micro_services/test_disco.py b/tests/satosa/micro_services/test_disco.py new file mode 100644 index 000000000..ac2c3c5c2 --- /dev/null +++ b/tests/satosa/micro_services/test_disco.py @@ -0,0 +1,44 @@ +from unittest import TestCase + +import pytest + +from satosa.context import Context +from satosa.state import State +from satosa.micro_services.disco import DiscoToTargetIssuer +from satosa.micro_services.disco import DiscoToTargetIssuerError + + +class TestDiscoToTargetIssuer(TestCase): + def setUp(self): + context = Context() + context.state = State() + + config = { + 'disco_endpoints': [ + '.*/disco', + ], + } + + plugin = DiscoToTargetIssuer( + config=config, + name='test_disco_to_target_issuer', + base_url='https://satosa.example.org', + ) + plugin.next = lambda ctx, data: (ctx, data) + + self.config = config + self.context = context + self.plugin = plugin + + def test_when_entity_id_is_not_set_raise_error(self): + self.context.request = {} + with pytest.raises(DiscoToTargetIssuerError): + self.plugin._handle_disco_response(self.context) + + def test_when_entity_id_is_set_target_issuer_is_set(self): + entity_id = 'idp.example.org' + self.context.request = { + 'entityID': entity_id, + } + newctx, newdata = self.plugin._handle_disco_response(self.context) + assert newctx.get_decoration(Context.KEY_TARGET_ENTITYID) == entity_id