Skip to content

Commit

Permalink
Pass module API to OIDC mapping provider
Browse files Browse the repository at this point in the history
  • Loading branch information
MatMaul committed Mar 1, 2024
1 parent 274f289 commit 4ac1b21
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
1 change: 1 addition & 0 deletions changelog.d/16974.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
As done for SAML mapping provider, let's pass the module API to the OIDC one so the mapper can do more logic in its code.
4 changes: 3 additions & 1 deletion docs/sso_mapping_providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ comment these options out and use those specified by the module instead.

A custom mapping provider must specify the following methods:

* `def __init__(self, parsed_config)`
* `def __init__(self, parsed_config, module_api)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
stable API available for extension modules.
* `def parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
Expand Down
17 changes: 14 additions & 3 deletions synapse/handlers/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
Expand Down Expand Up @@ -421,9 +422,19 @@ def __init__(
# from the IdP's jwks_uri, if required.
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)

self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
user_mapping_provider_init_method = (
provider.user_mapping_provider_class.__init__
)
if len(inspect.signature(user_mapping_provider_init_method).parameters) == 3:
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
)
else:
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config,
)

self._skip_verification = provider.skip_verification
self._allow_existing_users = provider.allow_existing_users

Expand Down Expand Up @@ -1583,7 +1594,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
This is the default mapping provider.
"""

def __init__(self, config: JinjaOidcMappingConfig):
def __init__(self, config: JinjaOidcMappingConfig, module_api: ModuleApi):
self._config = config

@staticmethod
Expand Down

0 comments on commit 4ac1b21

Please sign in to comment.