Skip to content

Commit

Permalink
Corrected OpenAPI schema type for DecimalField (#7254)
Browse files Browse the repository at this point in the history
  • Loading branch information
clintonb authored Apr 9, 2020
1 parent 41f27c3 commit 603aac7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
20 changes: 14 additions & 6 deletions rest_framework/schemas/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rest_framework import exceptions, renderers, serializers
from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings

from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
Expand Down Expand Up @@ -446,11 +447,17 @@ def _map_field(self, field):
content['format'] = field.protocol
return content

# DecimalField has multipleOf based on decimal_places
if isinstance(field, serializers.DecimalField):
content = {
'type': 'number'
}
if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
content = {
'type': 'string',
'format': 'decimal',
}
else:
content = {
'type': 'number'
}

if field.decimal_places:
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
if field.max_whole_digits:
Expand All @@ -461,7 +468,7 @@ def _map_field(self, field):

if isinstance(field, serializers.FloatField):
content = {
'type': 'number'
'type': 'number',
}
self._map_min_max(field, content)
return content
Expand Down Expand Up @@ -560,7 +567,8 @@ def _map_field_validators(self, field, schema):
schema['maximum'] = v.limit_value
elif isinstance(v, MinValueValidator):
schema['minimum'] = v.limit_value
elif isinstance(v, DecimalValidator):
elif isinstance(v, DecimalValidator) and \
not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
if v.decimal_places:
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
if v.max_digits:
Expand Down
10 changes: 10 additions & 0 deletions tests/schemas/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,16 @@ def test_serializer_validators(self):
assert properties['decimal2']['type'] == 'number'
assert properties['decimal2']['multipleOf'] == .0001

assert properties['decimal3'] == {
'type': 'string', 'format': 'decimal', 'maximum': 1000000, 'minimum': -1000000, 'multipleOf': 0.01
}
assert properties['decimal4'] == {
'type': 'string', 'format': 'decimal', 'maximum': 1000000, 'minimum': -1000000, 'multipleOf': 0.01
}
assert properties['decimal5'] == {
'type': 'string', 'format': 'decimal', 'maximum': 10000, 'minimum': -10000, 'multipleOf': 0.01
}

assert properties['email']['type'] == 'string'
assert properties['email']['format'] == 'email'
assert properties['email']['default'] == '[email protected]'
Expand Down
8 changes: 6 additions & 2 deletions tests/schemas/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,13 @@ class ExampleValidatedSerializer(serializers.Serializer):
MinLengthValidator(limit_value=2),
)
)
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2)
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0,
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2, coerce_to_string=False)
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, coerce_to_string=False,
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
decimal3 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True)
decimal4 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True,
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
decimal5 = serializers.DecimalField(max_digits=6, decimal_places=2)
email = serializers.EmailField(default='[email protected]')
url = serializers.URLField(default='http://www.example.com', allow_null=True)
uuid = serializers.UUIDField()
Expand Down

0 comments on commit 603aac7

Please sign in to comment.