From 297745bad159de0559a779d3c3da83770082ca60 Mon Sep 17 00:00:00 2001 From: Rust Saiargaliev Date: Fri, 12 Mar 2021 14:36:22 +0100 Subject: [PATCH] Fix creation of model instances with related model fields (#164) * Allow id to be used as FK default (Fix #136) https://github.com/model-bakers/model_bakery/pull/117 introduced an issue for cases where `id` values is provided as a default for foreign keys. This PR attempts to fix that issue by checking if generated value for FK field is an instance of this model - if not it uses `_id` as a field name (https://docs.djangoproject.com/en/3.1/ref/models/fields/#database-representation). * Another fix idea: use default unless given custom argument * Remove old way, do some cleanup * Update CHANGELOG.md --- CHANGELOG.md | 3 ++- model_bakery/baker.py | 13 +++++++------ tests/generic/models.py | 28 ++++++++++++++++++++++++++++ tests/test_filling_fields.py | 31 +++++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d39afbd..597982ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Changed - Type hinting fixed for Recipe "_model" parameter [PR #124](https://github.com/model-bakers/model_bakery/pull/124) -- Modify `setup.py` to not import the whole module for package data, but get it from `__about__.py` [PR #142](https://github.com/model-bakers/model_bakery/pull/142) +- Fixed a bug (introduced in 1.2.1) that was breaking creation of model instances with related model fields [PR #164](https://github.com/model-bakers/model_bakery/pull/164) +- [dev] Modify `setup.py` to not import the whole module for package data, but get it from `__about__.py` [PR #142](https://github.com/model-bakers/model_bakery/pull/142) - [dev] Add Dependabot config file [PR #146](https://github.com/model-bakers/model_bakery/pull/146) - [dev] Update Dependabot config file to support GH Actions and auto-rebase [PR #160](https://github.com/model-bakers/model_bakery/pull/160) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index bf3af9fb..7e2257f2 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -531,10 +531,11 @@ def generate_value(self, field: Field, commit: bool = True) -> Any: """Call the associated generator with a field passing all required args. Generator Resolution Precedence Order: - -- attr_mapping - mapping per attribute name - -- choices -- mapping from avaiable field choices - -- type_mapping - mapping from user defined type associated generators - -- default_mapping - mapping from pre-defined type associated + -- `field.default` - model field default value, unless explicitly overwritten during baking + -- `attr_mapping` - mapping per attribute name + -- `choices` -- mapping from available field choices + -- `type_mapping` - mapping from user defined type associated generators + -- `default_mapping` - mapping from pre-defined type associated generators `attr_mapping` and `type_mapping` can be defined easily overwriting the @@ -543,8 +544,8 @@ def generate_value(self, field: Field, commit: bool = True) -> Any: is_content_type_fk = isinstance(field, ForeignKey) and issubclass( self._remote_field(field).model, contenttypes.models.ContentType ) - - if field.has_default(): + # we only use default unless the field is overwritten in `self.rel_fields` + if field.has_default() and field.name not in self.rel_fields: if callable(field.default): return field.default() return field.default diff --git a/tests/generic/models.py b/tests/generic/models.py index 29e4419b..06240114 100755 --- a/tests/generic/models.py +++ b/tests/generic/models.py @@ -181,6 +181,10 @@ class LonelyPerson(models.Model): only_friend = models.OneToOneField(Person, on_delete=models.CASCADE) +class Cake(models.Model): + name = models.CharField(max_length=64) + + class RelatedNamesModel(models.Model): name = models.CharField(max_length=256) one_to_one = models.OneToOneField( @@ -191,6 +195,30 @@ class RelatedNamesModel(models.Model): ) +def get_default_cake_id(): + instance, _ = Cake.objects.get_or_create(name="Muffin") + return instance.id + + +class RelatedNamesWithDefaultsModel(models.Model): + name = models.CharField(max_length=256, default="Bravo") + cake = models.ForeignKey( + Cake, + on_delete=models.CASCADE, + default=get_default_cake_id, + ) + + +class RelatedNamesWithEmptyDefaultsModel(models.Model): + name = models.CharField(max_length=256, default="Bravo") + cake = models.ForeignKey( + Cake, + on_delete=models.CASCADE, + null=True, + default=None, + ) + + class ModelWithOverridedSave(Dog): def save(self, *args, **kwargs): self.owner = kwargs.pop("owner") diff --git a/tests/test_filling_fields.py b/tests/test_filling_fields.py index f471a4d1..6822aefe 100644 --- a/tests/test_filling_fields.py +++ b/tests/test_filling_fields.py @@ -276,6 +276,37 @@ def test_filling_content_type_field(self): assert dummy.content_type.model_class() is not None +@pytest.mark.django_db +class TestFillingForeignKeyFieldWithDefaultFunctionReturningId: + def test_filling_foreignkey_with_default_id(self): + dummy = baker.make(models.RelatedNamesWithDefaultsModel) + assert dummy.cake.__class__.objects.count() == 1 + assert dummy.cake.id == models.get_default_cake_id() + assert dummy.cake.name == "Muffin" + + def test_filling_foreignkey_with_default_id_with_custom_arguments(self): + dummy = baker.make( + models.RelatedNamesWithDefaultsModel, cake__name="Baumkuchen" + ) + assert dummy.cake.__class__.objects.count() == 1 + assert dummy.cake.id == dummy.cake.__class__.objects.get().id + assert dummy.cake.name == "Baumkuchen" + + +@pytest.mark.django_db +class TestFillingOptionalForeignKeyField: + def test_not_filling_optional_foreignkey(self): + dummy = baker.make(models.RelatedNamesWithEmptyDefaultsModel) + assert dummy.cake is None + + def test_filling_optional_foreignkey_implicitly(self): + dummy = baker.make( + models.RelatedNamesWithEmptyDefaultsModel, cake__name="Carrot cake" + ) + assert dummy.cake.__class__.objects.count() == 1 + assert dummy.cake.name == "Carrot cake" + + @pytest.mark.django_db class TestsFillingFileField: def test_filling_file_field(self):