diff --git a/ai_chat/agents.py b/ai_chat/agents.py index 586363ab2d..9f544dcb55 100644 --- a/ai_chat/agents.py +++ b/ai_chat/agents.py @@ -62,6 +62,7 @@ def __init__( # noqa: PLR0913 save_history: Optional[bool] = False, cache_key: Optional[str] = None, cache_timeout: Optional[int] = None, + collection_name: Optional[str] = None, ): """Initialize the AI chat agent service""" self.assistant_name = name @@ -70,6 +71,7 @@ def __init__( # noqa: PLR0913 self.save_history = save_history self.temperature = temperature or DEFAULT_TEMPERATURE self.instructions = instructions or self.INSTRUCTIONS + self.collection_name = collection_name self.user_id = user_id if settings.AI_PROXY_CLASS: self.proxy = import_string(f"ai_chat.proxy.{settings.AI_PROXY_CLASS}")() @@ -355,6 +357,7 @@ def __init__( # noqa: PLR0913 save_history: Optional[bool] = False, cache_key: Optional[str] = None, cache_timeout: Optional[int] = None, + collection_name: Optional[str] = None, ): """Initialize the AI search agent service""" super().__init__( @@ -366,6 +369,7 @@ def __init__( # noqa: PLR0913 user_id=user_id, cache_key=cache_key, cache_timeout=cache_timeout or settings.AI_CACHE_TIMEOUT, + collection_name=collection_name, ) self.search_parameters = [] self.search_results = [] @@ -520,6 +524,7 @@ def __init__( # noqa: PLR0913 save_history: Optional[bool] = False, cache_key: Optional[str] = None, cache_timeout: Optional[int] = None, + collection_name: Optional[str] = None, ): """Initialize the AI search agent service""" super().__init__( @@ -531,9 +536,11 @@ def __init__( # noqa: PLR0913 user_id=user_id, cache_key=cache_key, cache_timeout=cache_timeout or settings.AI_CACHE_TIMEOUT, + collection_name=collection_name, ) self.search_parameters = [] self.search_results = [] + self.collection_name = collection_name self.agent = self.create_agent() self.create_agent() @@ -550,6 +557,8 @@ def search_content_files(self) -> str: "resource_readable_id": self.readable_id, "limit": 20, } + if self.collection_name: + params["collection_name"] = self.collection_name self.search_parameters.append(params) try: response = requests.get(url, params=params, timeout=30) diff --git a/ai_chat/agents_test.py b/ai_chat/agents_test.py index 309f03183c..5d186713cb 100644 --- a/ai_chat/agents_test.py +++ b/ai_chat/agents_test.py @@ -8,7 +8,7 @@ from llama_index.core.base.llms.types import MessageRole from llama_index.core.constants import DEFAULT_TEMPERATURE -from ai_chat.agents import SearchAgent +from ai_chat.agents import SearchAgent, SyllabusAgent from ai_chat.factories import ChatMessageFactory from learning_resources.factories import LearningResourceFactory from learning_resources.serializers import ( @@ -268,3 +268,27 @@ def test_get_completion(mocker): }, ) assert "".join([str(value) for value in expected_return_value]) in results + + +@pytest.mark.django_db +def test_collection_name_param(settings, mocker): + """The collection name should be passed through to the contentfile search""" + settings.AI_MIT_SEARCH_LIMIT = 5 + settings.AI_MIT_SYLLABUS_URL = "https://test.com/api/v0/contentfiles/" + mock_post = mocker.patch( + "ai_chat.agents.requests.get", + return_value=mocker.Mock(json=mocker.Mock(return_value={})), + ) + search_agent = SyllabusAgent("test agent", collection_name="content_files_512") + search_agent.get_completion("I want to learn physics", readable_id="test") + search_agent.search_content_files() + mock_post.assert_called_once_with( + settings.AI_MIT_SYLLABUS_URL, + params={ + "q": "I want to learn physics", + "resource_readable_id": "test", + "limit": 20, + "collection_name": "content_files_512", + }, + timeout=30, + ) diff --git a/ai_chat/serializers.py b/ai_chat/serializers.py index 076a2d6294..0f0e569773 100644 --- a/ai_chat/serializers.py +++ b/ai_chat/serializers.py @@ -41,3 +41,4 @@ class SyllabusChatRequestSerializer(ChatRequestSerializer): """DRF serializer for syllabus chatbot requests""" readable_id = serializers.CharField(required=True) + collection_name = serializers.CharField(required=False) diff --git a/ai_chat/views.py b/ai_chat/views.py index 5dc13a589a..9ac6b86d37 100644 --- a/ai_chat/views.py +++ b/ai_chat/views.py @@ -104,12 +104,14 @@ def post(self, request: Request) -> StreamingHttpResponse: user_id = request.user.email if request.user.is_authenticated else "anonymous" message = serializer.validated_data.pop("message", "") readable_id = (serializer.validated_data.pop("readable_id"),) + collection_name = (serializer.validated_data.pop("collection_name"),) clear_history = serializer.validated_data.pop("clear_history", False) agent = SyllabusAgent( "Learning Resource Search AI Assistant", user_id=user_id, cache_key=f"{cache_id}_search_chat_history", save_history=True, + collection_name=collection_name, **serializer.validated_data, ) if clear_history: diff --git a/frontends/api/src/generated/v0/api.ts b/frontends/api/src/generated/v0/api.ts index c61e6d5546..3dc7b32264 100644 --- a/frontends/api/src/generated/v0/api.ts +++ b/frontends/api/src/generated/v0/api.ts @@ -5027,6 +5027,12 @@ export interface SyllabusChatRequestRequest { * @memberof SyllabusChatRequestRequest */ readable_id: string + /** + * + * @type {string} + * @memberof SyllabusChatRequestRequest + */ + collection_name?: string } /** * * `0-to-5-hours` - <5 hours/week * `5-to-10-hours` - 5-10 hours/week * `10-to-20-hours` - 10-20 hours/week * `20-to-30-hours` - 20-30 hours/week * `30-plus-hours` - 30+ hours/week diff --git a/frontends/main/src/app-pages/ChatSyllabusPage/ChatSyllabusPage.tsx b/frontends/main/src/app-pages/ChatSyllabusPage/ChatSyllabusPage.tsx index 5219aa6ceb..f8bc3c6402 100644 --- a/frontends/main/src/app-pages/ChatSyllabusPage/ChatSyllabusPage.tsx +++ b/frontends/main/src/app-pages/ChatSyllabusPage/ChatSyllabusPage.tsx @@ -51,6 +51,7 @@ const StyledDebugPre = styled.pre({ const ChatSyllabusPage = () => { const botEnabled = useFeatureFlagEnabled(FeatureFlags.RecommendationBot) const [readableId, setReadableId] = useState("18.06SC+fall_2011") + const [collectionName, setCollectionName] = useState("content_files") const [debugInfo, setDebugInfo] = useState("") const parseContent = (content: string | unknown) => { @@ -105,6 +106,18 @@ const ChatSyllabusPage = () => { Introduction to Differential Equations(EdX) + Contentfile Chunk Size + @@ -124,6 +137,7 @@ const ChatSyllabusPage = () => { return { message: messages[messages.length - 1].content, readable_id: readableId, + collection_name: collectionName, } }, }} diff --git a/openapi/specs/v0.yaml b/openapi/specs/v0.yaml index 39e44bc3e7..9f7e004d7f 100644 --- a/openapi/specs/v0.yaml +++ b/openapi/specs/v0.yaml @@ -5057,6 +5057,9 @@ components: readable_id: type: string minLength: 1 + collection_name: + type: string + minLength: 1 required: - message - readable_id