diff --git a/changelog/count-method-filters.changed.md b/changelog/count-method-filters.changed.md new file mode 100644 index 0000000..13a15c5 --- /dev/null +++ b/changelog/count-method-filters.changed.md @@ -0,0 +1 @@ +Added possibility to use filters for the SDK client's count method diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 839f1ec..d90b1b3 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -547,8 +547,10 @@ async def count( at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, + **kwargs: Any, ) -> int: """Return the number of nodes of a given kind.""" + filters = kwargs schema = await self.schema.get(kind=kind, branch=branch) branch = branch or self.default_branch @@ -556,7 +558,10 @@ async def count( at = Timestamp(at) response = await self.execute_graphql( - query=Query(query={schema.kind: {"count": None}}).render(), branch_name=branch, at=at, timeout=timeout + query=Query(query={schema.kind: {"count": None, "@filters": filters}}).render(), + branch_name=branch, + at=at, + timeout=timeout, ) return int(response.get(schema.kind, {}).get("count", 0)) @@ -1651,8 +1656,10 @@ def count( at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, + **kwargs: Any, ) -> int: """Return the number of nodes of a given kind.""" + filters = kwargs schema = self.schema.get(kind=kind, branch=branch) branch = branch or self.default_branch @@ -1660,7 +1667,10 @@ def count( at = Timestamp(at) response = self.execute_graphql( - query=Query(query={schema.kind: {"count": None}}).render(), branch_name=branch, at=at, timeout=timeout + query=Query(query={schema.kind: {"count": None, "@filters": filters}}).render(), + branch_name=branch, + at=at, + timeout=timeout, ) return int(response.get(schema.kind, {}).get("count", 0)) diff --git a/tests/integration/test_infrahub_client.py b/tests/integration/test_infrahub_client.py index c18bc5f..3c60751 100644 --- a/tests/integration/test_infrahub_client.py +++ b/tests/integration/test_infrahub_client.py @@ -142,6 +142,14 @@ async def test_create_branch_async(self, client: InfrahubClient, base_dataset): task_id = await client.branch.create(branch_name="new-branch-2", wait_until_completion=False) assert isinstance(task_id, str) + async def test_count(self, client: InfrahubClient, base_dataset): + count = await client.count(kind=TESTING_PERSON) + assert count == 3 + + async def test_count_with_filter(self, client: InfrahubClient, base_dataset): + count = await client.count(kind=TESTING_PERSON, name__values=["Liam Walker", "Ethan Carter"]) + assert count == 2 + # async def test_get_generic_filter_source(self, client: InfrahubClient, base_dataset): # admin = await client.get(kind="CoreAccount", name__value="admin") diff --git a/tests/unit/sdk/test_client.py b/tests/unit/sdk/test_client.py index 3f08780..f8624c1 100644 --- a/tests/unit/sdk/test_client.py +++ b/tests/unit/sdk/test_client.py @@ -83,6 +83,16 @@ async def test_method_count(clients, mock_query_repository_count, client_type): assert count == 5 +@pytest.mark.parametrize("client_type", client_types) +async def test_method_count_with_filter(clients, mock_query_repository_count, client_type): # pylint: disable=unused-argument + if client_type == "standard": + count = await clients.standard.count(kind="CoreRepository", name__value="test") + else: + count = clients.sync.count(kind="CoreRepository", name__value="test") + + assert count == 5 + + @pytest.mark.parametrize("client_type", client_types) async def test_method_get_version(clients, mock_query_infrahub_version, client_type): # pylint: disable=unused-argument if client_type == "standard":