Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Move spam check to handlers.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Aug 14, 2020
1 parent 7a88534 commit bc84745
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 33 deletions.
11 changes: 10 additions & 1 deletion synapse/handlers/cas_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class CasHandler:
"""

def __init__(self, hs):
self.hs = hs
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
Expand Down Expand Up @@ -210,8 +211,16 @@ async def handle_ticket(

else:
if not registered_user_id:
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent", default=[b""]
)[0].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)

registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name
localpart=localpart,
default_display_name=user_display_name,
user_agent_ips=(user_agent, ip_address),
)

await self._auth_handler.complete_sso_login(
Expand Down
21 changes: 18 additions & 3 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class OidcHandler:
"""

def __init__(self, hs: "HomeServer"):
self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._client_auth = ClientAuth(
Expand Down Expand Up @@ -692,9 +693,17 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
self._render_error(request, "invalid_token", str(e))
return

# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)

# Call the mapper to register/login the user
try:
user_id = await self._map_userinfo_to_user(userinfo, token)
user_id = await self._map_userinfo_to_user(
userinfo, token, user_agent, ip_address
)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
Expand Down Expand Up @@ -831,7 +840,9 @@ def _verify_expiry(self, caveat: str) -> bool:
now = self._clock.time_msec()
return now < expiry

async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
async def _map_userinfo_to_user(
self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
) -> str:
"""Maps a UserInfo object to a mxid.
UserInfo should have a claim that uniquely identifies users. This claim
Expand All @@ -846,6 +857,8 @@ async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Raises:
MappingException: if there was an error while mapping some properties
Expand Down Expand Up @@ -902,7 +915,9 @@ async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=attributes["display_name"],
localpart=localpart,
default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address),
)

await self._datastore.record_user_external_id(
Expand Down
26 changes: 24 additions & 2 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ReplicationPostRegisterActionsServlet,
ReplicationRegisterServlet,
)
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester

Expand All @@ -52,6 +53,8 @@ def __init__(self, hs):
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid

self.spam_checker = hs.get_spam_checker()

if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
Expand Down Expand Up @@ -142,7 +145,7 @@ async def register_user(
address=None,
bind_emails=[],
by_admin=False,
shadow_banned=False,
user_agent_ips=None,
):
"""Registers a new client on the server.
Expand All @@ -160,14 +163,33 @@ async def register_user(
bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the
admin api, otherwise False.
shadow_banned (bool): Shadow-ban the created user.
user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
during the registration process.
Returns:
str: user_id
Raises:
SynapseError if there was a problem registering.
"""
self.check_registration_ratelimit(address)

result = self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
)

if result == RegistrationBehaviour.DENY:
logger.info(
"Blocked registration of %r", localpart,
)
# We return a 429 to make it not obvious that they've been
# denied.
raise SynapseError(429, "Rate limited")

shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
if shadow_banned:
logger.info(
"Shadow banning registration of %r", localpart,
)

# do not check_auth_blocking if the call is coming through the Admin API
if not by_admin:
await self.auth.check_auth_blocking(threepid=threepid)
Expand Down
18 changes: 16 additions & 2 deletions synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Saml2SessionData:

class SamlHandler:
def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
Expand Down Expand Up @@ -133,8 +134,14 @@ async def handle_saml_response(self, request: SynapseRequest) -> None:
# the dict.
self.expire_sessions()

# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)

user_id, current_session = await self._map_saml_response_to_user(
resp_bytes, relay_state
resp_bytes, relay_state, user_agent, ip_address
)

# Complete the interactive auth session or the login.
Expand All @@ -147,14 +154,20 @@ async def handle_saml_response(self, request: SynapseRequest) -> None:
await self._auth_handler.complete_sso_login(user_id, request, relay_state)

async def _map_saml_response_to_user(
self, resp_bytes: str, client_redirect_url: str
self,
resp_bytes: str,
client_redirect_url: str,
user_agent: str,
ip_address: str,
) -> Tuple[str, Optional[Saml2SessionData]]:
"""
Given a sample response, retrieve the cached session and user for it.
Args:
resp_bytes: The SAML response.
client_redirect_url: The redirect URL passed in by the client.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
Tuple of the user ID and SAML session associated with this response.
Expand Down Expand Up @@ -291,6 +304,7 @@ async def _map_saml_response_to_user(
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
user_agent_ips=(user_agent, ip_address),
)

await self._datastore.record_user_external_id(
Expand Down
23 changes: 1 addition & 22 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
parse_string,
)
from synapse.push.mailer import load_jinja2_templates
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
Expand Down Expand Up @@ -391,8 +390,6 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self._registration_enabled = self.hs.config.enable_registration

self.spam_checker = hs.get_spam_checker()

self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
)
Expand Down Expand Up @@ -617,31 +614,13 @@ async def on_POST(self, request):
session_id
)

result = self.spam_checker.check_registration_for_spam(
threepid, desired_username, entries,
)

if result == RegistrationBehaviour.DENY:
logger.info(
"Blocked registration of %r", desired_username,
)
# We return a 429 to make it not obvious that they've been
# denied.
raise SynapseError(429, "Rate limited")

shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
if shadow_banned:
logger.info(
"Shadow banning registration of %r", desired_username,
)

registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password_hash=password_hash,
guest_access_token=guest_access_token,
threepid=threepid,
address=client_addr,
shadow_banned=shadow_banned,
user_agent_ips=entries,
)
# Necessary due to auth checks prior to the threepid being
# written to the db
Expand Down
18 changes: 15 additions & 3 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,12 +374,16 @@ def test_callback(self):
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
request = Mock(spec=["args", "getCookie", "addCookie"])
request = Mock(
spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
)

code = "code"
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
Expand All @@ -392,14 +396,20 @@ def test_callback(self):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]

request.requestHeaders = Mock(spec=["getRawHeaders"])
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
request.getClientIP.return_value = ip_address

yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))

self.handler._auth_handler.complete_sso_login.assert_called_once_with(
user_id, request, client_redirect_url,
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
self.handler._map_userinfo_to_user.assert_called_once_with(
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()

Expand Down Expand Up @@ -431,7 +441,9 @@ def test_callback(self):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
self.handler._map_userinfo_to_user.assert_called_once_with(
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()

Expand Down

0 comments on commit bc84745

Please sign in to comment.