Skip to content

Commit

Permalink
Fix SCIM search API sort and pagination
Browse files Browse the repository at this point in the history
  • Loading branch information
rhysyngsun committed Feb 24, 2025
1 parent aeb59d3 commit 0c9ddb3
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 24 deletions.
7 changes: 4 additions & 3 deletions scim/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

ol_scim_urls = (
[
re_path("^Bulk$", views.BulkView.as_view(), name="bulk"),
re_path(r"^Bulk$", views.BulkView.as_view(), name="bulk"),
re_path(r"^\.search$", views.SearchView.as_view(), name="users-search"),
],
"ol-scim",
)

urlpatterns = [
re_path("^scim/v2/", include(ol_scim_urls)),
re_path("^scim/v2/", include("django_scim.urls", namespace="scim")),
re_path(r"^scim/v2/", include(ol_scim_urls)),
re_path(r"^scim/v2/", include("django_scim.urls", namespace="scim")),
]
56 changes: 54 additions & 2 deletions scim/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import json
import logging
from http import HTTPStatus
from urllib.parse import urlparse
from urllib.parse import urljoin, urlparse

from django.http import HttpRequest, HttpResponse
from django.urls import Resolver404, resolve
from django.urls import Resolver404, resolve, reverse
from django_scim import constants as djs_constants
from django_scim import exceptions
from django_scim import views as djs_views
from django_scim.utils import get_base_scim_location_getter

from scim import constants

Expand Down Expand Up @@ -158,3 +159,54 @@ def _operation_error(self, method, bulk_id, status_code, detail):
"detail": detail,
},
}


class SearchView(djs_views.UserSearchView):
"""
View for /.search endpoint
"""

def post(self, request, *args, **kwargs): # noqa: ARG002
body = self.load_body(request.body)
if body.get("schemas") != [djs_constants.SchemaURI.SERACH_REQUEST]:
msg = "Invalid schema uri. Must be SearchRequest."
raise exceptions.BadRequestError(msg)

start = body.get("startIndex", 1)
count = body.get("count", 50)
sort_by = body.get("sortBy", None)
sort_order = body.get("sortOrder", "ascending")
query = body.get("filter", None)

if sort_by is not None and sort_by not in ("email", "username"):
msg = "Sorting only supports email or username"
raise exceptions.BadRequestError(msg)

if sort_order is not None and sort_order not in ("ascending", "descending"):
msg = "Sorting only supports ascending or descending"
raise exceptions.BadRequestError(msg)

if not query:
msg = "No filter query specified"
raise exceptions.BadRequestError(msg)

try:
qs = self.__class__.parser_getter().search(query, request)
except ValueError as e:
msg = "Invalid filter/search query: " + str(e)
raise exceptions.BadRequestError(msg) from e

if sort_by is not None:
qs = qs.order_by(sort_by)

if sort_order == "descending":
qs = qs.reverse()

response = self._build_response(request, qs, start, count)

path = reverse(self.scim_adapter.url_name)
url = urljoin(get_base_scim_location_getter()(request=request), path).rstrip(
"/"
)
response["Location"] = url + "/.search"
return response
96 changes: 77 additions & 19 deletions scim/views_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,28 +415,86 @@ def test_bulk_post(scim_client, bulk_test_data):
assert actual_value == expected_value


def test_user_search(scim_client):
@pytest.mark.parametrize(
("sort_by", "sort_order"),
[
(None, None),
("email", None),
("email", "ascending"),
("email", "descending"),
("username", None),
("username", "ascending"),
("username", "descending"),
],
)
@pytest.mark.parametrize("count", [None, 100, 500])
def test_user_search(scim_client, sort_by, sort_order, count):
"""Test the user search endpoint"""
users = UserFactory.create_batch(1500)
emails = [user.email for user in users[:1000]]
large_user_set = UserFactory.create_batch(1100)
search_users = large_user_set[:1000]
emails = [user.email for user in search_users]

resp = scim_client.post(
f"{reverse('scim:users-search')}?count={len(emails)}",
content_type="application/scim+json",
data=json.dumps(
{
"schemas": [djs_constants.SchemaURI.SERACH_REQUEST],
"filter": " OR ".join([f'email EQ "{email}"' for email in emails]),
}
),
)
expected = search_users

assert resp.status_code == 200
effective_count = count or 50
effective_sort_order = sort_order or "ascending"

data = resp.json()
if sort_by is not None:
expected = sorted(
expected,
# postgres sort is case-insensitive
key=lambda user: getattr(user, sort_by).lower(),
reverse=effective_sort_order == "descending",
)

assert data["totalResults"] == len(emails)
assert len(data["Resources"]) == len(emails)
for page in range(int(len(emails) / effective_count)):
start_index = page * effective_count # zero based index
resp = scim_client.post(
reverse("ol-scim:users-search"),
content_type="application/scim+json",
data=json.dumps(
{
"schemas": [djs_constants.SchemaURI.SERACH_REQUEST],
"filter": " OR ".join([f'email EQ "{email}"' for email in emails]),
"startIndex": start_index + 1, # SCIM API is 1-based index
**({"sortBy": sort_by} if sort_by is not None else {}),
**({"sortOrder": sort_order} if sort_order is not None else {}),
**({"count": count} if count is not None else {}),
}
),
)

for resource in data["Resources"]:
assert resource["emails"][0]["value"] in emails
expected_in_resp = expected[start_index : start_index + effective_count]

assert resp.status_code == 200, f"Got error: {resp.content}"
assert resp.json() == {
"totalResults": len(emails),
"itemsPerPage": effective_count,
"startIndex": start_index + 1,
"schemas": [djs_constants.SchemaURI.LIST_RESPONSE],
"Resources": [
{
"id": user.profile.scim_id,
"active": user.is_active,
"userName": user.username,
"displayName": user.profile.name,
"emails": [{"value": user.email, "primary": True}],
"externalId": str(user.profile.scim_external_id),
"name": {
"givenName": user.first_name,
"familyName": user.last_name,
},
"meta": {
"resourceType": "User",
"location": f"https://localhost/scim/v2/Users/{user.profile.scim_id}",
"lastModified": user.profile.updated_at.isoformat(
timespec="milliseconds"
),
"created": user.date_joined.isoformat(timespec="milliseconds"),
},
"groups": [],
"schemas": [djs_constants.SchemaURI.USER],
}
for user in expected_in_resp
],
}

0 comments on commit 0c9ddb3

Please sign in to comment.