Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generate fixture for PostGenerationMethodCall declaration #103

Merged
merged 11 commits into from
May 13, 2022
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ Changelog

Unreleased
----------
- Add support for ``factory.PostGenerationMethodCall`` `#103 <https://github.com/pytest-dev/pytest-factoryboy/pull/103>`_ `#87 <https://github.com/pytest-dev/pytest-factoryboy/issues/87>`_.


2.2.1
----------
Expand Down
91 changes: 45 additions & 46 deletions pytest_factoryboy/fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import factory.declarations
import factory.enums
import inflection
from factory.declarations import NotProvided

from .codegen import FixtureDef, make_fixture_model_module
from .compat import PostGenerationContext
Expand Down Expand Up @@ -96,57 +97,52 @@ def register_(factory_class: F) -> F:
args = []
attr_name = SEPARATOR.join((model_name, attr))

if isinstance(value, factory.declarations.PostGeneration):
value = kwargs.get(attr, None)
if isinstance(value, LazyFixture):
args = value.args
if isinstance(value, (factory.SubFactory, factory.RelatedFactory)):
value = kwargs.get(attr, value)
subfactory_class = value.get_factory()
subfactory_deps = get_deps(subfactory_class, factory_class)

args = list(subfactory_deps)
if isinstance(value, factory.RelatedFactory):
related_model = get_model_name(subfactory_class)
args.append(related_model)
related.append(related_model)
related.append(attr_name)
related.extend(subfactory_deps)

if isinstance(value, factory.SubFactory):
args.append(inflection.underscore(subfactory_class._meta.model.__name__))

fixture_defs.append(
FixtureDef(
name=attr_name,
function_name="attr_fixture",
function_kwargs={"value": value},
function_name="subfactory_fixture",
function_kwargs={"factory_class": subfactory_class},
deps=args,
)
)
continue

if isinstance(value, factory.PostGeneration):
default_value = None
elif isinstance(value, factory.PostGenerationMethodCall):
default_value = value.method_arg
else:
value = kwargs.get(attr, value)
default_value = value

if isinstance(value, (factory.SubFactory, factory.RelatedFactory)):
subfactory_class = value.get_factory()
subfactory_deps = get_deps(subfactory_class, factory_class)

args = list(subfactory_deps)
if isinstance(value, factory.RelatedFactory):
related_model = get_model_name(subfactory_class)
args.append(related_model)
related.append(related_model)
related.append(attr_name)
related.extend(subfactory_deps)

if isinstance(value, factory.SubFactory):
args.append(inflection.underscore(subfactory_class._meta.model.__name__))

fixture_defs.append(
FixtureDef(
name=attr_name,
function_name="subfactory_fixture",
function_kwargs={"factory_class": subfactory_class},
deps=args,
)
)
else:
if isinstance(value, LazyFixture):
args = value.args

fixture_defs.append(
FixtureDef(
name=attr_name,
function_name="attr_fixture",
function_kwargs={"value": value},
deps=args,
)
)
value = kwargs.get(attr, default_value)

if isinstance(value, LazyFixture):
args = value.args

fixture_defs.append(
FixtureDef(
name=attr_name,
function_name="attr_fixture",
function_kwargs={"value": value},
deps=args,
)
)

if factory_name not in _caller_locals:
fixture_defs.append(
Expand Down Expand Up @@ -227,7 +223,7 @@ def is_dep(value: Any) -> bool:
return False
if isinstance(value, factory.SubFactory) and get_model_name(value.get_factory()) == parent_model_name:
return False
if isinstance(value, factory.declarations.PostGeneration):
if isinstance(value, factory.declarations.PostGenerationDeclaration):
# Dependency on extracted value
return True

Expand Down Expand Up @@ -305,10 +301,13 @@ class Factory(factory_class):
extra[k] = evaluate(request, request.getfixturevalue(post_attr))
else:
extra[k] = v

# Handle special case for ``PostGenerationMethodCall`` where
# `attr_fixture` value is equal to ``NotProvided``, which mean
# that `value_provided` should be falsy
postgen_value = evaluate(request, request.getfixturevalue(argname))
postgen_context = PostGenerationContext(
value_provided=True,
value=evaluate(request, request.getfixturevalue(argname)),
value_provided=(postgen_value is not NotProvided),
value=postgen_value,
extra=extra,
)
deferred.append(
Expand Down
75 changes: 68 additions & 7 deletions tests/test_postgen_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import factory
import pytest
from factory.declarations import NotProvided

from pytest_factoryboy import register

Expand All @@ -19,6 +20,14 @@
class Foo:
value: int
expected: int
secret: str = ""
number: int = 4

def set_secret(self, new_secret: str) -> None:
self.secret = new_secret

def set_number(self, new_number: int = 987) -> None:
self.number = new_number

bar: Bar | None = None

Expand Down Expand Up @@ -47,22 +56,29 @@ class Meta:

@register
class FooFactory(factory.Factory):

"""Foo factory."""

class Meta:
model = Foo

value = 0
#: Value that is expected at the constructor
expected = 0
"""Value that is expected at the constructor."""

secret = factory.PostGenerationMethodCall("set_secret", "super secret")
number = factory.PostGenerationMethodCall("set_number")

@factory.post_generation
def set1(foo: Foo, create: bool, value: Any, **kwargs: Any) -> str:
foo.value = 1
return "set to 1"

baz = factory.RelatedFactory(BazFactory, "foo")
baz = factory.RelatedFactory(BazFactory, factory_related_name="foo")

@factory.post_generation
def set2(foo, create, value, **kwargs):
if create and value:
foo.value = value

@classmethod
def _after_postgeneration(cls, obj: Foo, create: bool, results: dict[str, Any] | None = None) -> None:
Expand All @@ -71,7 +87,6 @@ def _after_postgeneration(cls, obj: Foo, create: bool, results: dict[str, Any] |


class BarFactory(factory.Factory):

"""Bar factory."""

foo = factory.SubFactory(FooFactory)
Expand Down Expand Up @@ -107,15 +122,61 @@ def test_getfixturevalue(request, factoryboy_request: Request):
foo = request.getfixturevalue("foo")
assert not factoryboy_request.deferred
assert foo.value == 1
assert foo.secret == "super secret"
assert foo.number == 987


def test_postgenerationmethodcall_getfixturevalue(request, factoryboy_request):
"""Test default fixture value generated for ``PostGenerationMethodCall``."""
secret = request.getfixturevalue("foo__secret")
number = request.getfixturevalue("foo__number")
assert not factoryboy_request.deferred
assert secret == "super secret"
assert number is NotProvided


def test_postgeneration_getfixturevalue(request, factoryboy_request):
"""Ensure default fixture value generated for ``PostGeneration`` is `None`."""
set1 = request.getfixturevalue("foo__set1")
set2 = request.getfixturevalue("foo__set2")
assert not factoryboy_request.deferred
assert set1 is None
assert set2 is None


def test_after_postgeneration(foo: Foo):
"""Test _after_postgeneration is called."""
assert foo._create is True

foo._postgeneration_results["set1"] == "set to 1"
foo._postgeneration_results["baz"].foo is foo
assert len(foo._postgeneration_results) == 2
assert foo._postgeneration_results["set1"] == "set to 1"
assert foo._postgeneration_results["set2"] is None
assert foo._postgeneration_results["secret"] is None
assert foo._postgeneration_results["number"] is None


@pytest.mark.xfail(reason="This test has been broken for a long time, we only discovered it recently")
def test_postgen_related(foo: Foo):
"""Test that the initiating object `foo` is passed to the RelatedFactory `BazFactory`."""
baz = foo._postgeneration_results["baz"]
assert baz.foo is foo


@pytest.mark.parametrize("foo__set2", [123])
def test_postgeneration_fixture(foo: Foo):
"""Test fixture for ``PostGeneration`` declaration."""
assert foo.value == 123


@pytest.mark.parametrize(
("foo__secret", "foo__number"),
[
("test secret", 456),
],
)
Comment on lines +170 to +175
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can make 2 different calls to parametrize. It should look more neat

def test_postgenerationmethodcall_fixture(foo: Foo):
"""Test fixture for ``PostGenerationMethodCall`` declaration."""
assert foo.secret == "test secret"
assert foo.number == 456


@dataclass
Expand Down