Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to tests/rest/client #12108

Merged
merged 13 commits into from
Mar 2, 2022
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,8 @@ exclude = (?x)
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_account.py
|tests/rest/client/test_filter.py
|tests/rest/client/test_report_event.py
|tests/rest/client/test_rooms.py
|tests/rest/client/test_third_party_rules.py
|tests/rest/client/test_transactions.py
|tests/rest/client/test_typing.py
|tests/rest/key/v2/test_remote_key_resource.py
|tests/rest/media/v1/test_base.py
|tests/rest/media/v1/test_media_storage.py
Expand Down
4 changes: 2 additions & 2 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import itertools
import urllib.parse
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import patch

from twisted.test.proto_helpers import MemoryReactor
Expand Down Expand Up @@ -45,7 +45,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
]
hijack_auth = False

def default_config(self) -> dict:
def default_config(self) -> Dict[str, Any]:
# We need to enable msc1849 support for aggregations
config = super().default_config()
config["experimental_msc1849_support_enabled"] = True
Expand Down
27 changes: 16 additions & 11 deletions tests/rest/client/test_report_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@

import json

from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
from synapse.rest.client import login, report_event, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock

from tests import unittest

Expand All @@ -28,7 +33,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
report_event.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
Expand All @@ -42,35 +47,35 @@ def prepare(self, reactor, clock, hs):
self.event_id = resp["event_id"]
self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"

def test_reason_str_and_score_int(self):
def test_reason_str_and_score_int(self) -> None:
data = {"reason": "this makes me sad", "score": -100}
self._assert_status(200, data)

def test_no_reason(self):
data = {"score": 0}
def test_no_reason(self) -> None:
data: JsonDict = {"score": 0}
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self._assert_status(200, data)

def test_no_score(self):
def test_no_score(self) -> None:
data = {"reason": "this makes me sad"}
self._assert_status(200, data)

def test_no_reason_and_no_score(self):
data = {}
def test_no_reason_and_no_score(self) -> None:
data: JsonDict = {}
self._assert_status(200, data)

def test_reason_int_and_score_str(self):
def test_reason_int_and_score_str(self) -> None:
data = {"reason": 10, "score": "string"}
self._assert_status(400, data)

def test_reason_zero_and_score_blank(self):
def test_reason_zero_and_score_blank(self) -> None:
data = {"reason": 0, "score": ""}
self._assert_status(400, data)

def test_reason_and_score_null(self):
def test_reason_and_score_null(self) -> None:
data = {"reason": None, "score": None}
self._assert_status(400, data)

def _assert_status(self, response_status, data):
def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
"POST",
self.report_path,
Expand Down
100 changes: 65 additions & 35 deletions tests/rest/client/test_third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from unittest.mock import Mock

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, StateMap
from synapse.util import Clock
from synapse.util.frozenutils import unfreeze

from tests import unittest
Expand All @@ -34,40 +40,44 @@


class LegacyThirdPartyRulesTestModule:
def __init__(self, config: Dict, module_api: "ModuleApi"):
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
thread_local.rules_module = self
self.module_api = module_api

async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
):
) -> bool:
return True

async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
async def check_event_allowed(
self, event: EventBase, state: StateMap[EventBase]
) -> Union[bool, dict]:
return True

@staticmethod
def parse_config(config):
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config


class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: "ModuleApi"):
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
super().__init__(config, module_api)

def on_create_room(
def on_create_room( # type: ignore[override]
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self, requester: Requester, config: dict, is_requester_admin: bool
):
) -> bool:
return False


class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: "ModuleApi"):
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
super().__init__(config, module_api)

async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
async def check_event_allowed(
self, event: EventBase, state: StateMap[EventBase]
) -> JsonDict:
d = event.get_dict()
content = unfreeze(event.content)
content["foo"] = "bar"
Expand All @@ -82,7 +92,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
room.register_servlets,
]

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()

load_legacy_third_party_event_rules(hs)
Expand All @@ -92,22 +102,30 @@ def make_homeserver(self, reactor, clock):
# Note that these checks are not relevant to this test case.

# Have this homeserver auto-approve all event signature checking.
async def approve_all_signature_checking(_, pdu):
async def approve_all_signature_checking(
_: RoomVersion, pdu: EventBase
) -> EventBase:
return pdu

hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking # type: ignore[assignment]

# Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver.
async def _check_event_auth(origin, event, context, *args, **kwargs):
async def _check_event_auth(
origin: str,
event: EventBase,
context: EventContext,
*args: Any,
**kwargs: Any,
) -> EventContext:
return context

hs.get_federation_event_handler()._check_event_auth = _check_event_auth
hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]

return hs

def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Create some users and a room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
self.invitee = self.register_user("invitee", "hackme")
Expand All @@ -119,13 +137,15 @@ def prepare(self, reactor, clock, homeserver):
except Exception:
pass

def test_third_party_rules(self):
def test_third_party_rules(self) -> None:
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
# patch the rules module with a Mock which will return False for some event
# types
async def check(ev, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
return ev.type != "foo.bar.forbidden", None

callback = Mock(spec=[], side_effect=check)
Expand Down Expand Up @@ -159,7 +179,7 @@ async def check(ev, state):
)
self.assertEqual(channel.result["code"], b"403", channel.result)

def test_third_party_rules_workaround_synapse_errors_pass_through(self):
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
"""
Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
is functional: that SynapseErrors are passed through from check_event_allowed
Expand All @@ -170,7 +190,7 @@ def test_third_party_rules_workaround_synapse_errors_pass_through(self):
"""

class NastyHackException(SynapseError):
def error_dict(self):
def error_dict(self) -> JsonDict:
"""
This overrides SynapseError's `error_dict` to nastily inject
JSON into the error response.
Expand All @@ -180,7 +200,9 @@ def error_dict(self):
return result

# add a callback that will raise our hacky exception
async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
raise NastyHackException(429, "message")

self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
Expand All @@ -200,11 +222,13 @@ async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
{"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
)

def test_cannot_modify_event(self):
def test_cannot_modify_event(self) -> None:
"""cannot accidentally modify an event before it is persisted"""

# first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
ev.content = {"x": "y"}
return True, None

Expand All @@ -221,10 +245,12 @@ async def check(ev: EventBase, state):
# 500 Internal Server Error
self.assertEqual(channel.code, 500, channel.result)

def test_modify_event(self):
def test_modify_event(self) -> None:
"""The module can return a modified version of the event"""
# first patch the event checker so that it will modify the event
async def check(ev: EventBase, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
d = ev.get_dict()
d["content"] = {"x": "y"}
return True, d
Expand All @@ -251,10 +277,12 @@ async def check(ev: EventBase, state):
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")

def test_message_edit(self):
def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages."""
# first patch the event checker so that it will modify the event
async def check(ev: EventBase, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
d = ev.get_dict()
d["content"] = {
"msgtype": "m.text",
Expand Down Expand Up @@ -313,7 +341,7 @@ async def check(ev: EventBase, state):
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY")

def test_send_event(self):
def test_send_event(self) -> None:
"""Tests that a module can send an event into a room via the module api"""
content = {
"msgtype": "m.text",
Expand Down Expand Up @@ -342,7 +370,7 @@ def test_send_event(self):
}
}
)
def test_legacy_check_event_allowed(self):
def test_legacy_check_event_allowed(self) -> None:
"""Tests that the wrapper for legacy check_event_allowed callbacks works
correctly.
"""
Expand Down Expand Up @@ -377,13 +405,13 @@ def test_legacy_check_event_allowed(self):
}
}
)
def test_legacy_on_create_room(self):
def test_legacy_on_create_room(self) -> None:
"""Tests that the wrapper for legacy on_create_room callbacks works
correctly.
"""
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)

def test_sent_event_end_up_in_room_state(self):
def test_sent_event_end_up_in_room_state(self) -> None:
"""Tests that a state event sent by a module while processing another state event
doesn't get dropped from the state of the room. This is to guard against a bug
where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
Expand All @@ -398,7 +426,9 @@ def test_sent_event_end_up_in_room_state(self):
api = self.hs.get_module_api()

# Define a callback that sends a custom event on power levels update.
async def test_fn(event: EventBase, state_events):
async def test_fn(
event: EventBase, state_events: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
if event.is_state and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room(
{
Expand Down Expand Up @@ -434,7 +464,7 @@ async def test_fn(event: EventBase, state_events):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["i"], i)

def test_on_new_event(self):
def test_on_new_event(self) -> None:
"""Test that the on_new_event callback is called on new events"""
on_new_event = Mock(make_awaitable(None))
self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
Expand Down Expand Up @@ -499,7 +529,7 @@ def _send_event_over_federation(self) -> None:

self.assertEqual(channel.code, 200, channel.result)

def _update_power_levels(self, event_default: int = 0):
def _update_power_levels(self, event_default: int = 0) -> None:
"""Updates the room's power levels.

Args:
Expand Down
Loading