diff --git a/channels/migrations/0013_alter_channel_channel_type.py b/channels/migrations/0013_alter_channel_channel_type.py new file mode 100644 index 0000000000..2ae9c01b06 --- /dev/null +++ b/channels/migrations/0013_alter_channel_channel_type.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.14 on 2024-07-17 12:46 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("channels", "0012_alter_channelunitdetail_unit"), + ] + + operations = [ + migrations.AlterField( + model_name="channel", + name="channel_type", + field=models.CharField( + choices=[ + ("topic", "Topic"), + ("department", "Department"), + ("unit", "Unit"), + ("pathway", "Pathway"), + ], + db_index=True, + max_length=100, + ), + ), + ] diff --git a/channels/models.py b/channels/models.py index 6a7b6e8547..c15a3032f8 100644 --- a/channels/models.py +++ b/channels/models.py @@ -89,7 +89,9 @@ class Channel(TimestampedModel): null=True, blank=True, max_length=2083, upload_to=banner_uri ) about = JSONField(blank=True, null=True) - channel_type = models.CharField(max_length=100, choices=ChannelType.as_tuple()) + channel_type = models.CharField( + max_length=100, choices=ChannelType.as_tuple(), db_index=True + ) configuration = models.JSONField(null=True, default=dict, blank=True) search_filter = models.CharField(max_length=2048, blank=True, default="") public_description = models.TextField(blank=True, default="") diff --git a/channels/serializers.py b/channels/serializers.py index 2a55d44238..2fd8c92037 100644 --- a/channels/serializers.py +++ b/channels/serializers.py @@ -9,7 +9,7 @@ from rest_framework import serializers from rest_framework.exceptions import ValidationError -from channels.api import add_user_role, is_moderator +from channels.api import add_user_role from channels.constants import CHANNEL_ROLE_MODERATORS, ChannelType from channels.models import ( Channel, @@ -78,9 +78,10 @@ class ChannelAppearanceMixin(serializers.Serializer): def get_is_moderator(self, instance) -> bool: """Return true if user is a moderator for the channel""" request = self.context.get("request") - if request and is_moderator(request.user, instance.id): - return True - return False + if not request or not request.user or not instance: + return False + moderated_channel_ids = self.context.get("moderated_channel_ids", []) + return request.user.is_staff or instance.id in moderated_channel_ids def get_avatar(self, channel) -> str | None: """Get the avatar image URL""" diff --git a/channels/views.py b/channels/views.py index 41ee1f8e2b..92d6a32a0c 100644 --- a/channels/views.py +++ b/channels/views.py @@ -14,7 +14,7 @@ from channels.api import get_group_role_name, remove_user_role from channels.constants import CHANNEL_ROLE_MODERATORS -from channels.models import Channel, ChannelList +from channels.models import Channel, ChannelGroupRole, ChannelList from channels.permissions import ChannelModeratorPermissions, HasChannelPermission from channels.serializers import ( ChannelCreateSerializer, @@ -22,7 +22,7 @@ ChannelSerializer, ChannelWriteSerializer, ) -from learning_resources.views import LargePagination +from learning_resources.views import DefaultPagination from main.constants import VALID_HTTP_METHODS from main.permissions import AnonymousAccessReadonlyPermission @@ -64,7 +64,7 @@ class ChannelViewSet( or organizations at MIT and are a high-level categorization of content. """ - pagination_class = LargePagination + pagination_class = DefaultPagination permission_classes = (HasChannelPermission,) http_method_names = VALID_HTTP_METHODS lookup_field = "id" @@ -72,6 +72,19 @@ class ChannelViewSet( filter_backends = [DjangoFilterBackend] filterset_fields = ["channel_type"] + def get_serializer_context(self): + context = super().get_serializer_context() + """Return the context data""" + moderated_channel_ids = [] + if self.request.user and self.request.user.is_authenticated: + moderated_channel_ids = ( + ChannelGroupRole.objects.select_related("group") + .filter(role=CHANNEL_ROLE_MODERATORS, group__user=self.request.user) + .values_list("channel_id", flat=True) + ) + context["moderated_channel_ids"] = moderated_channel_ids + return context + def get_queryset(self): """Return a queryset""" return ( @@ -90,7 +103,11 @@ def get_queryset(self): ) .annotate_channel_url() .select_related( - "featured_list", "topic_detail", "department_detail", "unit_detail" + "featured_list", + "topic_detail", + "department_detail", + "unit_detail", + "pathway_detail", ) .all() ) diff --git a/channels/views_test.py b/channels/views_test.py index 024c917279..d1ad6d1b72 100644 --- a/channels/views_test.py +++ b/channels/views_test.py @@ -1,6 +1,7 @@ """Tests for channels.views""" import os +from math import ceil import pytest from django.contrib.auth.models import Group, User @@ -30,8 +31,8 @@ def test_list_channels(user_client): channels = sorted(ChannelFactory.create_batch(15), key=lambda f: f.id) url = reverse("channels:v0:channels_api-list") channel_list = sorted(user_client.get(url).json()["results"], key=lambda f: f["id"]) - assert len(channel_list) == len(channels) - for idx, channel in enumerate(channels): + assert len(channel_list) == 10 + for idx, channel in enumerate(channels[:10]): assert channel_list[idx] == ChannelSerializer(instance=channel).data @@ -372,13 +373,15 @@ def test_delete_moderator_forbidden(channel, user_client): @pytest.mark.parametrize("related_count", [1, 5, 10]) -def test_no_excess_queries(user_client, django_assert_num_queries, related_count): +def test_no_excess_detail_queries( + user_client, django_assert_num_queries, related_count +): """ There should be a constant number of queries made, independent of number of sub_channels / lists. """ # This isn't too important; we care it does not scale with number of related items - expected_query_count = 10 + expected_query_count = 9 topic_channel = ChannelFactory.create(is_topic=True) ChannelListFactory.create_batch(related_count, channel=topic_channel) @@ -392,6 +395,28 @@ def test_no_excess_queries(user_client, django_assert_num_queries, related_count user_client.get(url) +@pytest.mark.parametrize("channel_count", [2, 20, 200]) +def test_no_excess_list_queries(client, user, django_assert_num_queries, channel_count): + """ + There should be a constant number of queries made (based on number of + related models), regardless of number of channel results returned. + """ + ChannelFactory.create_batch(channel_count, is_pathway=True) + + assert Channel.objects.count() == channel_count + + client.force_login(user) + for page in range(ceil(channel_count / 10)): + with django_assert_num_queries(6): + results = client.get( + reverse("channels:v0:channels_api-list"), + data={"limit": 10, "offset": page * 10}, + ) + assert len(results.data["results"]) == min(channel_count, 10) + for result in results.data["results"]: + assert result["channel_url"] is not None + + def test_channel_configuration_is_not_editable(client, channel): """Test that the 'configuration' object is read-only""" url = reverse(