Skip to content

Commit

Permalink
Change behavior of include to only include the requested fields
Browse files Browse the repository at this point in the history
  • Loading branch information
dgarros committed Jan 5, 2025
1 parent c3315c7 commit 74bad3b
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 61 deletions.
2 changes: 2 additions & 0 deletions changelog/192.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
When using include as part of a query, the default behavior has changed.
Only the fields defined in `include` will be returned now, where previously the Python SDK would add the fields provided to the default one.
12 changes: 5 additions & 7 deletions infrahub_sdk/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,9 +992,7 @@ def generate_query_data_init(
data["@filters"]["limit"] = limit

if include and exclude:
in_both, _, _ = compare_lists(include, exclude)
if in_both:
raise ValueError(f"{in_both} are part of both include and exclude")
raise ValueError("include and exclude are exclusive, they shouldn't be used together")

if partial_match:
data["@filters"]["partial_match"] = True
Expand Down Expand Up @@ -1244,7 +1242,7 @@ async def generate_query_data_node(
data: dict[str, Any] = {}

for attr_name in self._attributes:
if exclude and attr_name in exclude:
if (exclude and attr_name in exclude) or (include and attr_name not in include):
continue

attr: Attribute = getattr(self, attr_name)
Expand All @@ -1262,7 +1260,7 @@ async def generate_query_data_node(
data[attr_name] = {"@alias": f"__alias__{self._schema.kind}__{attr_name}"}

for rel_name in self._relationships:
if exclude and rel_name in exclude:
if (exclude and rel_name in exclude) or (include and rel_name not in include):
continue

rel_schema = self._schema.get_relationship(name=rel_name)
Expand Down Expand Up @@ -1749,7 +1747,7 @@ def generate_query_data_node(
data: dict[str, Any] = {}

for attr_name in self._attributes:
if exclude and attr_name in exclude:
if (exclude and attr_name in exclude) or (include and attr_name not in include):
continue

attr: Attribute = getattr(self, attr_name)
Expand All @@ -1767,7 +1765,7 @@ def generate_query_data_node(
data[attr_name] = {"@alias": f"__alias__{self._schema.kind}__{attr_name}"}

for rel_name in self._relationships:
if exclude and rel_name in exclude:
if (exclude and rel_name in exclude) or (include and rel_name not in include):
continue

rel_schema = self._schema.get_relationship(name=rel_name)
Expand Down
73 changes: 19 additions & 54 deletions tests/unit/sdk/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,10 +991,10 @@ async def test_query_data_generic_fragment(clients, mock_schema_query_02, client
async def test_query_data_include_property(client, location_schema: NodeSchemaAPI, client_type):
if client_type == "standard":
node = InfrahubNode(client=client, schema=location_schema)
data = await node.generate_query_data(include=["tags"], property=True)
data = await node.generate_query_data(include=["name", "type", "tags"], property=True)
else:
node = InfrahubNodeSync(client=client, schema=location_schema)
data = node.generate_query_data(include=["tags"], property=True)
data = node.generate_query_data(include=["name", "type", "tags"], property=True)

assert data == {
"BuiltinLocation": {
Expand Down Expand Up @@ -1023,23 +1023,6 @@ async def test_query_data_include_property(client, location_schema: NodeSchemaAP
},
"value": None,
},
"description": {
"is_default": None,
"is_from_profile": None,
"is_protected": None,
"is_visible": None,
"owner": {
"__typename": None,
"display_label": None,
"id": None,
},
"source": {
"__typename": None,
"display_label": None,
"id": None,
},
"value": None,
},
"type": {
"is_default": None,
"is_from_profile": None,
Expand All @@ -1057,28 +1040,6 @@ async def test_query_data_include_property(client, location_schema: NodeSchemaAP
},
"value": None,
},
"primary_tag": {
"properties": {
"is_protected": None,
"is_visible": None,
"owner": {
"__typename": None,
"display_label": None,
"id": None,
},
"source": {
"__typename": None,
"display_label": None,
"id": None,
},
},
"node": {
"id": None,
"hfid": None,
"display_label": None,
"__typename": None,
},
},
"tags": {
"count": None,
"edges": {
Expand Down Expand Up @@ -1113,10 +1074,10 @@ async def test_query_data_include_property(client, location_schema: NodeSchemaAP
async def test_query_data_include(client, location_schema: NodeSchemaAPI, client_type):
if client_type == "standard":
node = InfrahubNode(client=client, schema=location_schema)
data = await node.generate_query_data(include=["tags"])
data = await node.generate_query_data(include=["name", "type", "tags"])
else:
node = InfrahubNodeSync(client=client, schema=location_schema)
data = node.generate_query_data(include=["tags"])
data = node.generate_query_data(include=["name", "type", "tags"])

assert data == {
"BuiltinLocation": {
Expand All @@ -1131,20 +1092,9 @@ async def test_query_data_include(client, location_schema: NodeSchemaAPI, client
"name": {
"value": None,
},
"description": {
"value": None,
},
"type": {
"value": None,
},
"primary_tag": {
"node": {
"id": None,
"hfid": None,
"display_label": None,
"__typename": None,
},
},
"tags": {
"count": None,
"edges": {
Expand Down Expand Up @@ -1251,6 +1201,21 @@ async def test_query_data_exclude(client, location_schema: NodeSchemaAPI, client
}


@pytest.mark.parametrize("client_type", client_types)
async def test_query_data_include_exclude(client, location_schema: NodeSchemaAPI, client_type):
if client_type == "standard":
node = InfrahubNode(client=client, schema=location_schema)

with pytest.raises(ValueError) as exc:
await node.generate_query_data(include=["name", "type"], exclude=["description"], property=True)
assert "include and exclude are exclusive" in str(exc.value)
else:
node = InfrahubNodeSync(client=client, schema=location_schema)
with pytest.raises(ValueError) as exc:
node.generate_query_data(include=["name", "type", "tags"], exclude=["description"], property=True)
assert "include and exclude are exclusive" in str(exc.value)


@pytest.mark.parametrize("client_type", client_types)
async def test_create_input_data(client, location_schema: NodeSchemaAPI, client_type):
data = {"name": {"value": "JFK1"}, "description": {"value": "JFK Airport"}, "type": {"value": "SITE"}}
Expand Down

0 comments on commit 74bad3b

Please sign in to comment.