From d8bbb22af4b7ee9f3532bb6aef889ef6dfdbf0c1 Mon Sep 17 00:00:00 2001 From: Nick Pope Date: Thu, 7 Oct 2021 20:11:43 +0100 Subject: [PATCH] Fix inheritance bugs with @extend_schema_view(). When creating a copy of a method from a parent class we now: - Ensure that `__qualname__` is defined correctly - i.e. `Child.method` instead of `Parent.method`. - This isn't essential but helps diagnosing issues when debugging. - Move application of the decorator to the last moment. - Deep copy the existing schema extensions before applying decorator. This fixes #218 where two child classes with @extend_schema_view affect each other - schema extensions are applied to the parent such that the second child overwrites the changes applied to the first child. This also fixes my case where a child with @extend_schema_view clobbered the schema extensions of the parent which also used @extend_schema_view. --- drf_spectacular/utils.py | 16 ++++++--- tests/test_extend_schema_view.py | 19 ++++++++-- tests/test_extend_schema_view.yml | 59 +++++++++++++++++++++++++++++++ tests/test_regressions.py | 35 ++++++++++++++++++ 4 files changed, 123 insertions(+), 6 deletions(-) diff --git a/drf_spectacular/utils.py b/drf_spectacular/utils.py index 62b43628..dae473b6 100644 --- a/drf_spectacular/utils.py +++ b/drf_spectacular/utils.py @@ -1,6 +1,7 @@ import functools import inspect import sys +from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union from rest_framework.fields import Field, empty @@ -472,13 +473,20 @@ def extend_schema_view(**kwargs) -> Callable[[F], F]: :param kwargs: method names as argument names and :func:`@extend_schema <.extend_schema>` calls as values """ - def wrapping_decorator(method_decorator, method): - @method_decorator + def wrapping_decorator(method_decorator, view, method): @functools.wraps(method) def wrapped_method(self, request, *args, **kwargs): return method(self, request, *args, **kwargs) - return wrapped_method + # Construct a new __qualname__ based on the __name__ of the target view. + wrapped_method.__qualname__ = f'{view.__name__}.{method.__name__}' + + # Clone the extended schema if the source method has it. + if hasattr(method, 'kwargs'): + wrapped_method.kwargs = deepcopy(method.kwargs) + + # Finally apply any additional schema extensions applied to the target view. + return method_decorator(wrapped_method) def decorator(view): view_methods = {m.__name__: m for m in get_view_methods(view)} @@ -498,7 +506,7 @@ def decorator(view): if method_name in view.__dict__: method_decorator(method) else: - setattr(view, method_name, wrapping_decorator(method_decorator, method)) + setattr(view, method_name, wrapping_decorator(method_decorator, view, method)) return view return decorator diff --git a/tests/test_extend_schema_view.py b/tests/test_extend_schema_view.py index 16291abe..f00e7df8 100644 --- a/tests/test_extend_schema_view.py +++ b/tests/test_extend_schema_view.py @@ -28,7 +28,7 @@ class Meta: extended_action=extend_schema(description='view extended action description'), raw_action=extend_schema(description='view raw action description'), ) -class XViewset(mixins.ListModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet): +class XViewSet(mixins.ListModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet): queryset = ESVModel.objects.all() serializer_class = ESVSerializer @@ -52,9 +52,24 @@ class YViewSet(viewsets.ModelViewSet): queryset = ESVModel.objects.all() +# view to make sure that schema applied to a subclass does not affect its parent. +@extend_schema_view( + list=extend_schema(exclude=True), + retrieve=extend_schema(description='overridden description for child only'), + extended_action=extend_schema(responses={200: {'type': 'string', 'pattern': r'^[0-9]{4}(?:-[0-9]{2}){2}$'}}), + raw_action=extend_schema(summary="view raw action summary"), +) +class ZViewSet(XViewSet): + @extend_schema(tags=['child-tag']) + @action(detail=False, methods=['GET']) + def raw_action(self, request): + return Response('2019-03-01') + + router = routers.SimpleRouter() -router.register('x', XViewset) +router.register('x', XViewSet) router.register('y', YViewSet) +router.register('z', ZViewSet) urlpatterns = router.urls diff --git a/tests/test_extend_schema_view.yml b/tests/test_extend_schema_view.yml index 8597e892..4228fcb8 100644 --- a/tests/test_extend_schema_view.yml +++ b/tests/test_extend_schema_view.yml @@ -232,6 +232,65 @@ paths: responses: '204': description: No response body + /z/{id}/: + get: + operationId: z_retrieve + description: overridden description for child only + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this esv model. + required: true + tags: + - custom-retrieve-tag + security: + - cookieAuth: [] + - basicAuth: [] + - {} + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/ESV' + description: '' + /z/extended_action/: + get: + operationId: z_extended_action_retrieve + description: view extended action description + tags: + - global-tag + security: + - cookieAuth: [] + - basicAuth: [] + - {} + responses: + '200': + content: + application/json: + schema: + type: string + pattern: ^[0-9]{4}(?:-[0-9]{2}){2}$ + description: '' + /z/raw_action/: + get: + operationId: z_raw_action_retrieve + summary: view raw action summary + tags: + - child-tag + security: + - cookieAuth: [] + - basicAuth: [] + - {} + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/ESV' + description: '' components: schemas: ESV: diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 7e895f0d..fb7e2da1 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -2569,3 +2569,38 @@ def custom_action(self): schema = generate_schema('x', viewset=XViewSet) schema['paths']['/x/{id}/custom_action/']['get']['summary'] == 'A custom action!' + + +def test_extend_schema_view_isolation(no_warnings): + + class Animal(models.Model): + pass + + class AnimalSerializer(serializers.ModelSerializer): + class Meta: + model = Animal + fields = '__all__' + + class AnimalViewSet(viewsets.GenericViewSet): + serializer_class = AnimalSerializer + queryset = Animal.objects.all() + + @action(detail=False) + def notes(self, request): + pass # pragma: no cover + + @extend_schema_view(notes=extend_schema(summary='List mammals.')) + class MammalViewSet(AnimalViewSet): + pass + + @extend_schema_view(notes=extend_schema(summary='List insects.')) + class InsectViewSet(AnimalViewSet): + pass + + router = routers.SimpleRouter() + router.register('api/mammals', MammalViewSet) + router.register('api/insects', InsectViewSet) + + schema = generate_schema(None, patterns=router.urls) + assert schema['paths']['/api/mammals/notes/']['get']['summary'] == 'List mammals.' + assert schema['paths']['/api/insects/notes/']['get']['summary'] == 'List insects.'