diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ef5e9d4..f81785df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Changed - Fixed _model parameter annotations [PR #115](https://github.com/model-bakers/model_bakery/pull/115) +- Fixes bug when field has callable `default` [PR #117](https://github.com/model-bakers/model_bakery/pull/117) - [dev] Drop Python 3.5 support as it is retired (https://www.python.org/downloads/release/python-3510/) ### Removed diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 7d6a00e4..ffbf5827 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -515,7 +515,11 @@ def generate_value(self, field: Field, commit: bool = True) -> Any: self._remote_field(field).model, contenttypes.models.ContentType ) - if field.name in self.attr_mapping: + if field.has_default(): + if callable(field.default): + return field.default() + return field.default + elif field.name in self.attr_mapping: generator = self.attr_mapping[field.name] elif getattr(field, "choices"): generator = random_gen.gen_from_choices(field.choices) @@ -525,8 +529,6 @@ def generate_value(self, field: Field, commit: bool = True) -> Any: generator = generators.get(field.__class__) elif field.__class__ in self.type_mapping: generator = self.type_mapping[field.__class__] - elif field.has_default(): - return field.default else: raise TypeError("%s is not supported by baker." % field.__class__) diff --git a/tests/generic/models.py b/tests/generic/models.py index 3b3d81d8..bfc170ff 100755 --- a/tests/generic/models.py +++ b/tests/generic/models.py @@ -10,6 +10,7 @@ from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation from django.contrib.contenttypes.models import ContentType from django.core.files.storage import FileSystemStorage +from django.utils.timezone import now from model_bakery.gis import BAKER_GIS from model_bakery.timezone import smart_datetime as datetime @@ -265,6 +266,10 @@ class DummyBlankFieldsModel(models.Model): blank_text_field = models.TextField(max_length=300, blank=True) +class ExtendedDefaultField(models.IntegerField): + pass + + class DummyDefaultFieldsModel(models.Model): default_id = models.AutoField(primary_key=True) default_char_field = models.CharField(max_length=50, default="default") @@ -279,6 +284,9 @@ class DummyDefaultFieldsModel(models.Model): ) default_email_field = models.EmailField(default="foo@bar.org") default_slug_field = models.SlugField(default="a-slug") + default_unknown_class_field = ExtendedDefaultField(default=42) + default_callable_int_field = models.IntegerField(default=lambda: 12) + default_callable_datetime_field = models.DateTimeField(default=now) class DummyFileFieldModel(models.Model): diff --git a/tests/test_baker.py b/tests/test_baker.py index f38ed43d..89c3211c 100644 --- a/tests/test_baker.py +++ b/tests/test_baker.py @@ -542,6 +542,15 @@ def test_fill_optional_with_integer(self): with pytest.raises(TypeError): baker.make(models.DummyBlankFieldsModel, _fill_optional=1) + def test_fill_optional_with_default(self): + dummy = baker.make(models.DummyDefaultFieldsModel, _fill_optional=True) + assert dummy.default_callable_int_field == 12 + assert isinstance(dummy.default_callable_datetime_field, datetime.datetime) + + def test_fill_optional_with_default_unknown_class(self): + dummy = baker.make(models.DummyDefaultFieldsModel, _fill_optional=True) + assert dummy.default_unknown_class_field == 42 + @pytest.mark.django_db class TestFillAutoFieldsTestCase: @@ -572,6 +581,9 @@ def test_skip_fields_with_default(self): assert dummy.default_decimal_field == Decimal("0") assert dummy.default_email_field == "foo@bar.org" assert dummy.default_slug_field == "a-slug" + assert dummy.default_unknown_class_field == 42 + assert dummy.default_callable_int_field == 12 + assert isinstance(dummy.default_callable_datetime_field, datetime.datetime) @pytest.mark.django_db diff --git a/tests/test_extending_bakery.py b/tests/test_extending_bakery.py index a266f201..f4c36181 100644 --- a/tests/test_extending_bakery.py +++ b/tests/test_extending_bakery.py @@ -28,6 +28,7 @@ class SadPeopleBaker(baker.Baker): attr_mapping = { "enjoy_jards_macale": gen_opposite, "like_metal_music": gen_opposite, + "name": gen_opposite, # Use a field without `default` } @@ -53,8 +54,13 @@ def test_string_to_generator_required(self): like_metal_music_field = Person._meta.get_field("like_metal_music") sad_people_factory = SadPeopleBaker(Person) person = sad_people_factory.make() - assert person.enjoy_jards_macale is not enjoy_jards_macale_field.default - assert person.like_metal_music is not like_metal_music_field.default + assert person.enjoy_jards_macale is enjoy_jards_macale_field.default + assert person.like_metal_music is like_metal_music_field.default + + def test_kwarg_used_over_attr_mapping_generator(self): + sad_people_factory = SadPeopleBaker(Person) + person = sad_people_factory.make(name="test") + assert person.name == "test" @pytest.mark.parametrize("value", [18, 18.5, [], {}, True]) def test_fail_pass_non_string_to_generator_required(self, value):