Skip to content

Commit

Permalink
add renderer & parser whitelist setting #598
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Nov 10, 2021
1 parent af1cccd commit e793dca
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 7 deletions.
22 changes: 15 additions & 7 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
get_doc, get_type_hints, get_view_model, is_basic_serializer, is_basic_type, is_field,
is_list_serializer, is_patched_serializer, is_serializer, is_trivial_string_variation,
resolve_django_path_parameter, resolve_regex_path_parameter, resolve_type_hint, safe_ref,
sanitize_specification_extensions, warn,
sanitize_specification_extensions, warn, whitelisted,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
Expand Down Expand Up @@ -264,10 +264,7 @@ def get_auth(self):
auths = []

for authenticator in self.view.get_authenticators():
if (
spectacular_settings.AUTHENTICATION_WHITELIST
and authenticator.__class__ not in spectacular_settings.AUTHENTICATION_WHITELIST
):
if not whitelisted(authenticator, spectacular_settings.AUTHENTICATION_WHITELIST, True):
continue

scheme = OpenApiAuthenticationExtension.get_match(authenticator)
Expand Down Expand Up @@ -970,14 +967,25 @@ def get_paginated_name(self, serializer_name):
return f'Paginated{serializer_name}List'

def map_parsers(self):
return list(dict.fromkeys([p.media_type for p in self.view.get_parsers()]))
return list(dict.fromkeys([
p.media_type for p in self.view.get_parsers()
if whitelisted(p, spectacular_settings.PARSER_WHITELIST)
]))

def map_renderers(self, attribute):
assert attribute in ['media_type', 'format']

# Either use whitelist or default back to old behavior by excluding BrowsableAPIRenderer
def use_renderer(r):
if spectacular_settings.RENDERER_WHITELIST:
return whitelisted(r, spectacular_settings.RENDERER_WHITELIST)
else:
return not isinstance(r, renderers.BrowsableAPIRenderer)

return list(dict.fromkeys([
getattr(r, attribute).split(';')[0]
for r in self.view.get_renderers()
if not isinstance(r, renderers.BrowsableAPIRenderer) and getattr(r, attribute, None)
if use_renderer(r) and hasattr(r, attribute)
]))

def _get_serializer(self):
Expand Down
9 changes: 9 additions & 0 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,3 +1161,12 @@ def resolve_type_hint(hint):
return build_array_type(resolve_type_hint(args[0]))
else:
raise UnableToProceedError()


def whitelisted(obj: object, classes: List[Type[object]], exact=False):
if not classes:
return True
if exact:
return obj.__class__ in classes
else:
return isinstance(obj, tuple(classes))
7 changes: 7 additions & 0 deletions drf_spectacular/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@
# authentication classes that are not contained in the whitelist. Use full import paths
# like ['rest_framework.authentication.TokenAuthentication', ...]
'AUTHENTICATION_WHITELIST': [],
# Controls which parsers are exposed in the schema. Works analog to AUTHENTICATION_WHITELIST.
'PARSER_WHITELIST': [],
# Controls which renderers are exposed in the schema. Works analog to AUTHENTICATION_WHITELIST.
# rest_framework.renderers.BrowsableAPIRenderer is ignored by default if whitelist is empty
'RENDERER_WHITELIST': [],

# Option for turning off error and warn messages
'DISABLE_ERRORS_AND_WARNINGS': False,
Expand Down Expand Up @@ -197,6 +202,8 @@
'SORT_OPERATIONS',
'SORT_OPERATION_PARAMETERS',
'AUTHENTICATION_WHITELIST',
'RENDERER_WHITELIST',
'PARSER_WHITELIST',
]

spectacular_settings = APISettings(
Expand Down
19 changes: 19 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2813,3 +2813,22 @@ def view_func(request, format=None):

schema = generate_schema('/x/', view_function=view_func)
assert list(schema['paths']['/x/']['get']['responses'].keys()) == ['2XX', '401', '4XX']


@mock.patch('drf_spectacular.settings.spectacular_settings.RENDERER_WHITELIST', [renderers.MultiPartRenderer])
@mock.patch('drf_spectacular.settings.spectacular_settings.PARSER_WHITELIST', [parsers.MultiPartParser])
def test_renderer_parser_whitelist(no_warnings):
class XSerializer(serializers.Serializer):
field = serializers.CharField()

class XViewset(viewsets.ModelViewSet):
serializer_class = XSerializer
queryset = SimpleModel.objects.none()
renderer_classes = [renderers.MultiPartRenderer, renderers.JSONRenderer]
parser_classes = [parsers.MultiPartParser, parsers.JSONParser]

schema = generate_schema('/x', XViewset)
request_types = list(schema['paths']['/x/']['post']['requestBody']['content'].keys())
response_types = list(schema['paths']['/x/']['post']['responses']['201']['content'].keys())

assert response_types == request_types == ['multipart/form-data']

0 comments on commit e793dca

Please sign in to comment.