Skip to content

Commit

Permalink
Port to new module API
Browse files Browse the repository at this point in the history
This ports the auth provider to the new module API of Synapse 1.46+.

Docs: https://matrix-org.github.io/synapse/latest/modules/password_auth_provider_callbacks.html

Based on anishihara@6c29f4d by @anishihara

Fixes ma1uta#9
  • Loading branch information
davidmehren committed Nov 23, 2021
1 parent 092d8f6 commit 354a471
Showing 1 changed file with 38 additions and 3 deletions.
41 changes: 38 additions & 3 deletions rest_auth_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@
#

import logging
from typing import Tuple, Optional, Callable, Awaitable

from twisted.internet import defer
import requests
import json
import time
import synapse
from synapse import module_api

logger = logging.getLogger(__name__)


class RestAuthProvider(object):

def __init__(self, config, account_handler):
def __init__(self, config: dict, account_handler: module_api):
self.account_handler = account_handler

if not config.endpoint:
Expand All @@ -43,6 +47,36 @@ def __init__(self, config, account_handler):
logger.info('Endpoint: %s', self.endpoint)
logger.info('Enforce lowercase username during registration: %s', self.regLower)

# register an auth callback handler
# see https://matrix-org.github.io/synapse/latest/modules/password_auth_provider_callbacks.html
account_handler.register_password_provider_auth_callbacks(
auth_checkers={
("m.login.password", ("password",)): self.check_m_login_password
}
)

async def check_m_login_password(self, username: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict") -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
]
]:
if login_type != "m.login.password":
return None

# get the complete MXID
mxid = self.account_handler.get_qualified_user_id(username)

# check if the password is valid with the old function
password_valid = await self.check_password(mxid, login_dict.get("password"))

if password_valid:
return mxid
else:
return None

async def check_password(self, user_id, password):
logger.info("Got password check for " + user_id)
data = {'user': {'id': user_id, 'password': password}}
Expand Down Expand Up @@ -84,9 +118,10 @@ async def check_password(self, user_id, password):
try:
store = await self.account_handler._hs.get_profile_handler().store # for synapse >= 1.9.0
except AttributeError:
store = await self.account_handler.hs.get_profile_handler().store # for synapse < 1.9.0
store = await self.account_handler.hs.get_profile_handler().store # for synapse < 1.9.0

if "display_name" in profile and ((registration and self.config.setNameOnRegister) or (self.config.setNameOnLogin)):
if "display_name" in profile and (
(registration and self.config.setNameOnRegister) or (self.config.setNameOnLogin)):
display_name = profile["display_name"]
logger.info("Setting display name to '%s' based on profile data", display_name)
await store.set_profile_displayname(localpart, display_name)
Expand Down

0 comments on commit 354a471

Please sign in to comment.