Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type hints for apnstruncate, apnspushkin, notifications, and gcmpushkin #264

Merged
merged 19 commits into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/264.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve static type checking.
94 changes: 61 additions & 33 deletions sygnal/apnspushkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import logging
import os
from datetime import timezone
from typing import Dict
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import uuid4

import aioapns
from aioapns import APNs, NotificationRequest
from aioapns.common import NotificationResult
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate
from opentracing import logs, tags
from opentracing import Span, logs, tags
from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.defer import Deferred

Expand All @@ -38,9 +39,18 @@
TemporaryNotificationDispatchException,
)
from sygnal.helper.proxy.proxy_asyncio import ProxyingEventLoopWrapper
from sygnal.notifications import ConcurrencyLimitedPushkin
from sygnal.notifications import (
ConcurrencyLimitedPushkin,
Device,
Notification,
NotificationContext,
)
from sygnal.utils import NotificationLoggerAdapter, twisted_sleep

if TYPE_CHECKING:
from sygnal.sygnal import Sygnal


logger = logging.getLogger(__name__)

SEND_TIME_HISTOGRAM = Histogram(
Expand Down Expand Up @@ -89,7 +99,7 @@ class ApnsPushkin(ConcurrencyLimitedPushkin):
"topic",
} | ConcurrencyLimitedPushkin.UNDERSTOOD_CONFIG_FIELDS

def __init__(self, name, sygnal, config):
def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None:
super().__init__(name, sygnal, config)

nonunderstood = set(self.cfg.keys()).difference(self.UNDERSTOOD_CONFIG_FIELDS)
Expand All @@ -99,16 +109,16 @@ def __init__(self, name, sygnal, config):
nonunderstood,
)

platform = self.get_config("platform")
platform = self.get_config("platform", str)
if not platform or platform == "production" or platform == "prod":
self.use_sandbox = False
elif platform == "sandbox":
self.use_sandbox = True
else:
raise PushkinSetupException(f"Invalid platform: {platform}")

certfile = self.get_config("certfile")
keyfile = self.get_config("keyfile")
certfile = self.get_config("certfile", str)
keyfile = self.get_config("keyfile", str)
if not certfile and not keyfile:
raise PushkinSetupException(
"You must provide a path to an APNs certificate, or an APNs token."
Expand All @@ -119,17 +129,17 @@ def __init__(self, name, sygnal, config):
raise PushkinSetupException(
f"The APNs certificate '{certfile}' does not exist."
)
else:
elif keyfile:
# keyfile
if not os.path.exists(keyfile):
raise PushkinSetupException(
f"The APNs key file '{keyfile}' does not exist."
)
if not self.get_config("key_id"):
if not self.get_config("key_id", str):
raise PushkinSetupException("You must supply key_id.")
if not self.get_config("team_id"):
if not self.get_config("team_id", str):
raise PushkinSetupException("You must supply team_id.")
if not self.get_config("topic"):
if not self.get_config("topic", str):
raise PushkinSetupException("You must supply topic.")

# use the Sygnal global proxy configuration
Expand Down Expand Up @@ -157,10 +167,10 @@ def __init__(self, name, sygnal, config):
# additional connection attempts, so =0 means try once only
# (we will retry at a higher level so not worth doing more here)
self.apns_client = APNs(
key=self.get_config("keyfile"),
key_id=self.get_config("key_id"),
team_id=self.get_config("team_id"),
topic=self.get_config("topic"),
key=self.get_config("keyfile", str),
key_id=self.get_config("key_id", str),
team_id=self.get_config("team_id", str),
topic=self.get_config("topic", str),
use_sandbox=self.use_sandbox,
max_connection_attempts=0,
loop=loop,
Expand All @@ -169,7 +179,7 @@ def __init__(self, name, sygnal, config):
# without this, aioapns will retry every second forever.
self.apns_client.pool.max_connection_attempts = 3

def _report_certificate_expiration(self, certfile):
def _report_certificate_expiration(self, certfile: str) -> None:
"""Export the epoch time that the certificate expires as a metric."""
with open(certfile, "rb") as f:
cert_bytes = f.read()
Expand All @@ -180,7 +190,14 @@ def _report_certificate_expiration(self, certfile):
cert.not_valid_after.replace(tzinfo=timezone.utc).timestamp()
)

async def _dispatch_request(self, log, span, device, shaved_payload, prio):
async def _dispatch_request(
self,
log: NotificationLoggerAdapter,
span: Span,
device: Device,
shaved_payload: Dict[str, Any],
prio: int,
) -> List[str]:
"""
Actually attempts to dispatch the notification once.
"""
Expand Down Expand Up @@ -237,7 +254,9 @@ async def _dispatch_request(self, log, span, device, shaved_payload, prio):
f"{response.status} {response.description}"
)

async def _dispatch_notification_unlimited(self, n, device, context):
async def _dispatch_notification_unlimited(
self, n: Notification, device: Device, context: NotificationContext
) -> List[str]:
log = NotificationLoggerAdapter(logger, {"request_id": context.request_id})

# The pushkey is kind of secret because you can use it to send push
Expand All @@ -250,14 +269,16 @@ async def _dispatch_notification_unlimited(self, n, device, context):
) as span_parent:

if n.event_id and not n.type:
payload = self._get_payload_event_id_only(n, device)
payload: Optional[Dict[str, Any]] = self._get_payload_event_id_only(
n, device
)
else:
payload = self._get_payload_full(n, device, log)

if payload is None:
# Nothing to do
span_parent.log_kv({logs.EVENT: "apns_no_payload"})
return
return []
prio = 10
if n.prio == "low":
prio = 5
Expand All @@ -273,6 +294,7 @@ async def _dispatch_notification_unlimited(self, n, device, context):
with self.sygnal.tracer.start_span(
"apns_dispatch_try", tags=span_tags, child_of=span_parent
) as span:
assert shaved_payload is not None
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
return await self._dispatch_request(
log, span, device, shaved_payload, prio
)
Expand All @@ -290,22 +312,21 @@ async def _dispatch_notification_unlimited(self, n, device, context):
span_parent.log_kv(
{"event": "temporary_fail", "retrying_in": retry_delay}
)

if retry_number == self.MAX_TRIES - 1:
raise NotificationDispatchException(
"Retried too many times."
) from exc
else:
if retry_number < self.MAX_TRIES - 1:
await twisted_sleep(
retry_delay, twisted_reactor=self.sygnal.reactor
)

def _get_payload_event_id_only(self, n, device):
raise NotificationDispatchException("Retried too many times.")

def _get_payload_event_id_only(
self, n: Notification, device: Device
) -> Dict[str, Any]:
"""
Constructs a payload for a notification where we know only the event ID.
Args:
n: The notification to construct a payload for.
device (Device): Device information to which the constructed payload
device: Device information to which the constructed payload
will be sent.

Returns:
Expand All @@ -328,21 +349,26 @@ def _get_payload_event_id_only(self, n, device):

return payload

def _get_payload_full(self, n, device, log):
def _get_payload_full(
self, n: Notification, device: Device, log: NotificationLoggerAdapter
) -> Optional[Dict[str, Any]]:
"""
Constructs a payload for a notification.
Args:
n: The notification to construct a payload for.
device (Device): Device information to which the constructed payload
device: Device information to which the constructed payload
will be sent.
log: A logger.

Returns:
The APNs payload as nested dicts.
"""
from_display = n.sender
if n.sender_display_name is not None:
if not n.sender and not n.sender_display_name:
from_display = " "
elif n.sender_display_name is not None:
from_display = n.sender_display_name
elif n.sender is not None:
from_display = n.sender
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
from_display = from_display[0 : self.MAX_FIELD_LENGTH]

loc_key = None
Expand Down Expand Up @@ -471,7 +497,9 @@ def _get_payload_full(self, n, device, log):

return payload

async def _send_notification(self, request):
async def _send_notification(
self, request: NotificationRequest
) -> NotificationResult:
return await Deferred.fromFuture(
asyncio.ensure_future(self.apns_client.send_notification(request))
)
27 changes: 20 additions & 7 deletions sygnal/apnstruncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@
# Copied and adapted from
# https://raw.githubusercontent.com/matrix-org/pushbaby/master/pushbaby/truncate.py
import json
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from typing_extensions import Literal

def json_encode(payload: Any) -> bytes:
Choppable = Union[
Tuple[Literal["alert", "alert.body"]], Tuple[Literal["alert.loc-args"], int]
]


def json_encode(payload) -> bytes:
return json.dumps(payload, ensure_ascii=False).encode()


Expand Down Expand Up @@ -85,8 +91,8 @@ def truncate(payload: Dict[str, Any], max_length: int = 2048) -> Dict[str, Any]:
return payload


def _choppables_for_aps(aps):
ret: List[Union[Tuple[str], Tuple[str, int]]] = []
def _choppables_for_aps(aps: Dict[str, Any]) -> List[Choppable]:
ret: List[Choppable] = []
if "alert" not in aps:
return ret

Expand All @@ -102,7 +108,10 @@ def _choppables_for_aps(aps):
return ret


def _choppable_get(aps, choppable):
def _choppable_get(
aps: Dict[str, Any],
choppable: Choppable,
):
if choppable[0] == "alert":
return aps["alert"]
elif choppable[0] == "alert.body":
Expand All @@ -111,7 +120,11 @@ def _choppable_get(aps, choppable):
return aps["alert"]["loc-args"][choppable[1]]


def _choppable_put(aps, choppable, val):
def _choppable_put(
aps: Dict[str, Any],
choppable: Choppable,
val: str,
) -> None:
if choppable[0] == "alert":
aps["alert"] = val
elif choppable[0] == "alert.body":
Expand All @@ -120,7 +133,7 @@ def _choppable_put(aps, choppable, val):
aps["alert"]["loc-args"][choppable[1]] = val


def _longest_choppable(aps):
def _longest_choppable(aps: Dict[str, Any]) -> Optional[Choppable]:
longest = None
length_of_longest = 0
for c in _choppables_for_aps(aps):
Expand Down
Loading