Skip to content

Commit

Permalink
Change attr_type from list to str for MultipleChoiceFilter (#17638)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremystretch authored Oct 3, 2024
1 parent 648aeaa commit f11dc00
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 9 deletions.
2 changes: 1 addition & 1 deletion netbox/dcim/filtersets.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ class LocationFilterSet(TenancyFilterSet, ContactModelFilterSet, OrganizationalM

class Meta:
model = Location
fields = ('id', 'name', 'slug', 'status', 'facility', 'description')
fields = ('id', 'name', 'slug', 'facility', 'description')

def search(self, queryset, name, value):
if not value.strip():
Expand Down
8 changes: 4 additions & 4 deletions netbox/netbox/graphql/filter_mixins.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from functools import partial, partialmethod, wraps
from functools import partialmethod
from typing import List

import django_filters
import strawberry
import strawberry_django
from django.core.exceptions import FieldDoesNotExist, ValidationError
from django.core.exceptions import FieldDoesNotExist
from strawberry import auto

from ipam.fields import ASNField
from netbox.graphql.scalars import BigInt
from utilities.fields import ColorField, CounterCacheField
Expand Down Expand Up @@ -108,8 +109,7 @@ def map_strawberry_type(field):
elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter):
pass
elif issubclass(type(field), django_filters.MultipleChoiceFilter):
should_create_function = True
attr_type = List[str] | None
attr_type = str | None
elif issubclass(type(field), django_filters.TypedChoiceFilter):
pass
elif issubclass(type(field), django_filters.ChoiceFilter):
Expand Down
41 changes: 37 additions & 4 deletions netbox/netbox/tests/test_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from rest_framework import status

from core.models import ObjectType
from dcim.choices import LocationStatusChoices
from dcim.models import Site, Location
from ipam.models import ASN, RIR
from users.models import ObjectPermission
from utilities.testing import disable_warnings, APITestCase, TestCase

Expand Down Expand Up @@ -53,10 +53,27 @@ def test_graphql_filter_objects(self):
sites = (
Site(name='Site 1', slug='site-1'),
Site(name='Site 2', slug='site-2'),
Site(name='Site 3', slug='site-3'),
)
Site.objects.bulk_create(sites)
Location.objects.create(site=sites[0], name='Location 1', slug='location-1'),
Location.objects.create(site=sites[1], name='Location 2', slug='location-2'),
Location.objects.create(
site=sites[0],
name='Location 1',
slug='location-1',
status=LocationStatusChoices.STATUS_PLANNED
),
Location.objects.create(
site=sites[1],
name='Location 2',
slug='location-2',
status=LocationStatusChoices.STATUS_STAGING
),
Location.objects.create(
site=sites[1],
name='Location 3',
slug='location-3',
status=LocationStatusChoices.STATUS_ACTIVE
),

# Add object-level permission
obj_perm = ObjectPermission(
Expand All @@ -68,8 +85,9 @@ def test_graphql_filter_objects(self):
obj_perm.object_types.add(ObjectType.objects.get_for_model(Location))
obj_perm.object_types.add(ObjectType.objects.get_for_model(Site))

# A valid request should return the filtered list
url = reverse('graphql')

# A valid request should return the filtered list
query = '{location_list(filters: {site_id: "' + str(sites[0].pk) + '"}) {id site {id}}}'
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
Expand All @@ -78,6 +96,21 @@ def test_graphql_filter_objects(self):
self.assertEqual(len(data['data']['location_list']), 1)
self.assertIsNotNone(data['data']['location_list'][0]['site'])

# Test OR logic
query = """{
location_list( filters: {
status: \"""" + LocationStatusChoices.STATUS_PLANNED + """\",
OR: {status: \"""" + LocationStatusChoices.STATUS_STAGING + """\"}
}) {
id site {id}
}
}"""
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content)
self.assertNotIn('errors', data)
self.assertEqual(len(data['data']['location_list']), 2)

# An invalid request should return an empty list
query = '{location_list(filters: {site_id: "99999"}) {id site {id}}}' # Invalid site ID
response = self.client.post(url, data={'query': query}, format="json", **self.header)
Expand Down

0 comments on commit f11dc00

Please sign in to comment.