Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: Add more type hints #471

Merged
merged 1 commit into from
Sep 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions waffle/admin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Any, Dict, Tuple

from django.contrib import admin
from django.contrib.admin.models import LogEntry, CHANGE, DELETION
from django.contrib.admin.widgets import ManyToManyRawIdWidget
from django.contrib.contenttypes.models import ContentType
from django.http import HttpRequest
from django.utils.html import escape
from django.utils.translation import gettext_lazy as _

Expand All @@ -12,7 +15,7 @@
class BaseAdmin(admin.ModelAdmin):
search_fields = ('name', 'note')

def get_actions(self, request):
def get_actions(self, request: HttpRequest) -> Dict[str, Any]:
actions = super().get_actions(request)
if 'delete_selected' in actions:
del actions['delete_selected']
Expand Down Expand Up @@ -70,7 +73,7 @@ class InformativeManyToManyRawIdWidget(ManyToManyRawIdWidget):
Will display the names of the users in a parenthesised list after the
input field. This widget works with all models that have a "name" field.
"""
def label_and_url_for_value(self, values):
def label_and_url_for_value(self, values: Any) -> Tuple[str, str]:
names = []
key = self.rel.get_related_field().name
for value in values:
Expand Down
19 changes: 13 additions & 6 deletions waffle/decorators.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from functools import wraps, WRAPPER_ASSIGNMENTS
from typing import Any, Callable, Optional, Union

from django.http import Http404
from django.http import Http404, HttpRequest, HttpResponse, HttpResponsePermanentRedirect, HttpResponseRedirect
from django.shortcuts import redirect
from django.urls import reverse, NoReverseMatch

from waffle import flag_is_active, switch_is_active


def waffle_flag(flag_name, redirect_to=None):
def decorator(view):
def waffle_flag(
flag_name: str, redirect_to: Optional[Union[Callable, str]] = None,
) -> Callable[[Callable[[HttpRequest], HttpResponse]], Callable[[HttpRequest], HttpResponse]]:
def decorator(view: Callable[[HttpRequest], HttpResponse]) -> Callable[[HttpRequest], HttpResponse]:
@wraps(view, assigned=WRAPPER_ASSIGNMENTS)
def _wrapped_view(request, *args, **kwargs):
if flag_name.startswith('!'):
Expand All @@ -28,8 +31,10 @@ def _wrapped_view(request, *args, **kwargs):
return decorator


def waffle_switch(switch_name, redirect_to=None):
def decorator(view):
def waffle_switch(
switch_name: str, redirect_to: Optional[Union[Callable, str]] = None,
) -> Callable[[Callable[[HttpRequest], HttpResponse]], Callable[[HttpRequest], HttpResponse]]:
def decorator(view: Callable[[HttpRequest], HttpResponse]) -> Callable[[HttpRequest], HttpResponse]:
@wraps(view, assigned=WRAPPER_ASSIGNMENTS)
def _wrapped_view(request, *args, **kwargs):
if switch_name.startswith('!'):
Expand All @@ -49,7 +54,9 @@ def _wrapped_view(request, *args, **kwargs):
return decorator


def get_response_to_redirect(view, *args, **kwargs):
def get_response_to_redirect(
view: Optional[Union[Callable, str]], *args: Any, **kwargs: Any,
) -> Optional[Union[HttpResponseRedirect, HttpResponsePermanentRedirect]]:
try:
return redirect(reverse(view, args=args, kwargs=kwargs)) if view else None
except NoReverseMatch:
Expand Down
2 changes: 1 addition & 1 deletion waffle/management/commands/waffle_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from waffle import get_waffle_switch_model


def on_off_bool(string):
def on_off_bool(string: str) -> bool:
if string not in ['on', 'off']:
raise ArgumentTypeError("invalid choice: %r (choose from 'on', "
"'off')" % string)
Expand Down
20 changes: 14 additions & 6 deletions waffle/managers.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,37 @@
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from django.db import models

from waffle.utils import get_setting, get_cache


class BaseManager(models.Manager):
if TYPE_CHECKING:
from waffle.models import _BaseModelType, AbstractBaseFlag, AbstractBaseSample, AbstractBaseSwitch
else:
_BaseModelType = TypeVar("_BaseModelType")


class BaseManager(models.Manager, Generic[_BaseModelType]):
KEY_SETTING = ''

def get_by_natural_key(self, name):
def get_by_natural_key(self, name: str) -> _BaseModelType:
return self.get(name=name)

def create(self, *args, **kwargs):
def create(self, *args: Any, **kwargs: Any) -> _BaseModelType:
cache = get_cache()
ret = super().create(*args, **kwargs)
cache_key = get_setting(self.KEY_SETTING)
cache.delete(cache_key)
return ret


class FlagManager(BaseManager):
class FlagManager(BaseManager['AbstractBaseFlag']):
KEY_SETTING = 'ALL_FLAGS_CACHE_KEY'


class SwitchManager(BaseManager):
class SwitchManager(BaseManager['AbstractBaseSwitch']):
KEY_SETTING = 'ALL_SWITCHES_CACHE_KEY'


class SampleManager(BaseManager):
class SampleManager(BaseManager['AbstractBaseSample']):
KEY_SETTING = 'ALL_SAMPLES_CACHE_KEY'
6 changes: 3 additions & 3 deletions waffle/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import random
from decimal import Decimal
from typing import Any, List, Optional, Set, Tuple, Type, TypeVar
from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar

from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser, Group
Expand Down Expand Up @@ -102,7 +102,7 @@ def flush(self) -> None:
]
cache.delete_many(keys)

def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
self.modified = timezone.now()
ret = super().save(*args, **kwargs)
if hasattr(transaction, 'on_commit'):
Expand All @@ -111,7 +111,7 @@ def save(self, *args, **kwargs):
self.flush()
return ret

def delete(self, *args, **kwargs):
def delete(self, *args: Any, **kwargs: Any) -> Tuple[int, Dict[str, int]]:
ret = super().delete(*args, **kwargs)
if hasattr(transaction, 'on_commit'):
transaction.on_commit(self.flush)
Expand Down
38 changes: 21 additions & 17 deletions waffle/testutils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Generic, Optional, TypeVar, Union

from django.test.utils import TestContextDecorator

from waffle import (
Expand All @@ -10,37 +12,39 @@

__all__ = ['override_flag', 'override_sample', 'override_switch']

_T = TypeVar("_T")


class _overrider(TestContextDecorator):
def __init__(self, name, active):
class _overrider(TestContextDecorator, Generic[_T]):
def __init__(self, name: str, active: _T):
super().__init__()
self.name = name
self.active = active

def get(self):
def get(self) -> None:
self.obj, self.created = self.cls.objects.get_or_create(name=self.name)

def update(self, active):
def update(self, active: _T) -> None:
raise NotImplementedError

def get_value(self):
def get_value(self) -> _T:
raise NotImplementedError

def enable(self):
def enable(self) -> None:
self.get()
self.old_value = self.get_value()
if self.old_value != self.active:
self.update(self.active)

def disable(self):
def disable(self) -> None:
if self.created:
self.obj.delete()
self.obj.flush()
else:
self.update(self.old_value)


class override_switch(_overrider):
class override_switch(_overrider[bool]):
"""
override_switch is a contextmanager for easier testing of switches.

Expand All @@ -64,41 +68,41 @@ def test_happy_mode_enabled():
"""
cls = get_waffle_switch_model()

def update(self, active):
def update(self, active: bool) -> None:
obj = self.cls.objects.get(pk=self.obj.pk)
obj.active = active
obj.save()
obj.flush()

def get_value(self):
def get_value(self) -> bool:
return self.obj.active


class override_flag(_overrider):
class override_flag(_overrider[Optional[bool]]):
cls = get_waffle_flag_model()

def update(self, active):
def update(self, active: Optional[bool]) -> None:
obj = self.cls.objects.get(pk=self.obj.pk)
obj.everyone = active
obj.save()
obj.flush()

def get_value(self):
def get_value(self) -> Optional[bool]:
return self.obj.everyone


class override_sample(_overrider):
class override_sample(_overrider[Union[bool, float]]):
cls = get_waffle_sample_model()

def get(self):
def get(self) -> None:
try:
self.obj = self.cls.objects.get(name=self.name)
self.created = False
except self.cls.DoesNotExist:
self.obj = self.cls.objects.create(name=self.name, percent='0.0')
self.created = True

def update(self, active):
def update(self, active: Union[bool, float]) -> None:
if active is True:
p = 100.0
elif active is False:
Expand All @@ -110,7 +114,7 @@ def update(self, active):
obj.save()
obj.flush()

def get_value(self):
def get_value(self) -> Union[bool, float]:
p = self.obj.percent
if p == 100.0:
return True
Expand Down
8 changes: 5 additions & 3 deletions waffle/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from django.http import HttpResponse, JsonResponse
from typing import Any, Dict

from django.http import HttpRequest, HttpResponse, JsonResponse
from django.template import loader
from django.views.decorators.cache import never_cache

Expand All @@ -12,7 +14,7 @@ def wafflejs(request):
content_type='application/x-javascript')


def _generate_waffle_js(request):
def _generate_waffle_js(request: HttpRequest) -> str:
flags = get_waffle_flag_model().get_all()
flag_values = [(f.name, f.is_active(request)) for f in flags]

Expand All @@ -37,7 +39,7 @@ def waffle_json(request):
return JsonResponse(_generate_waffle_json(request))


def _generate_waffle_json(request):
def _generate_waffle_json(request: HttpRequest) -> Dict[str, Dict[str, Any]]:
flags = get_waffle_flag_model().get_all()
flag_values = {
f.name: {
Expand Down