diff --git a/ai_chat/agents.py b/ai_chat/agents.py
index 586363ab2d..9f544dcb55 100644
--- a/ai_chat/agents.py
+++ b/ai_chat/agents.py
@@ -62,6 +62,7 @@ def __init__( # noqa: PLR0913
save_history: Optional[bool] = False,
cache_key: Optional[str] = None,
cache_timeout: Optional[int] = None,
+ collection_name: Optional[str] = None,
):
"""Initialize the AI chat agent service"""
self.assistant_name = name
@@ -70,6 +71,7 @@ def __init__( # noqa: PLR0913
self.save_history = save_history
self.temperature = temperature or DEFAULT_TEMPERATURE
self.instructions = instructions or self.INSTRUCTIONS
+ self.collection_name = collection_name
self.user_id = user_id
if settings.AI_PROXY_CLASS:
self.proxy = import_string(f"ai_chat.proxy.{settings.AI_PROXY_CLASS}")()
@@ -355,6 +357,7 @@ def __init__( # noqa: PLR0913
save_history: Optional[bool] = False,
cache_key: Optional[str] = None,
cache_timeout: Optional[int] = None,
+ collection_name: Optional[str] = None,
):
"""Initialize the AI search agent service"""
super().__init__(
@@ -366,6 +369,7 @@ def __init__( # noqa: PLR0913
user_id=user_id,
cache_key=cache_key,
cache_timeout=cache_timeout or settings.AI_CACHE_TIMEOUT,
+ collection_name=collection_name,
)
self.search_parameters = []
self.search_results = []
@@ -520,6 +524,7 @@ def __init__( # noqa: PLR0913
save_history: Optional[bool] = False,
cache_key: Optional[str] = None,
cache_timeout: Optional[int] = None,
+ collection_name: Optional[str] = None,
):
"""Initialize the AI search agent service"""
super().__init__(
@@ -531,9 +536,11 @@ def __init__( # noqa: PLR0913
user_id=user_id,
cache_key=cache_key,
cache_timeout=cache_timeout or settings.AI_CACHE_TIMEOUT,
+ collection_name=collection_name,
)
self.search_parameters = []
self.search_results = []
+ self.collection_name = collection_name
self.agent = self.create_agent()
self.create_agent()
@@ -550,6 +557,8 @@ def search_content_files(self) -> str:
"resource_readable_id": self.readable_id,
"limit": 20,
}
+ if self.collection_name:
+ params["collection_name"] = self.collection_name
self.search_parameters.append(params)
try:
response = requests.get(url, params=params, timeout=30)
diff --git a/ai_chat/agents_test.py b/ai_chat/agents_test.py
index 309f03183c..5d186713cb 100644
--- a/ai_chat/agents_test.py
+++ b/ai_chat/agents_test.py
@@ -8,7 +8,7 @@
from llama_index.core.base.llms.types import MessageRole
from llama_index.core.constants import DEFAULT_TEMPERATURE
-from ai_chat.agents import SearchAgent
+from ai_chat.agents import SearchAgent, SyllabusAgent
from ai_chat.factories import ChatMessageFactory
from learning_resources.factories import LearningResourceFactory
from learning_resources.serializers import (
@@ -268,3 +268,27 @@ def test_get_completion(mocker):
},
)
assert "".join([str(value) for value in expected_return_value]) in results
+
+
+@pytest.mark.django_db
+def test_collection_name_param(settings, mocker):
+ """The collection name should be passed through to the contentfile search"""
+ settings.AI_MIT_SEARCH_LIMIT = 5
+ settings.AI_MIT_SYLLABUS_URL = "https://test.com/api/v0/contentfiles/"
+ mock_post = mocker.patch(
+ "ai_chat.agents.requests.get",
+ return_value=mocker.Mock(json=mocker.Mock(return_value={})),
+ )
+ search_agent = SyllabusAgent("test agent", collection_name="content_files_512")
+ search_agent.get_completion("I want to learn physics", readable_id="test")
+ search_agent.search_content_files()
+ mock_post.assert_called_once_with(
+ settings.AI_MIT_SYLLABUS_URL,
+ params={
+ "q": "I want to learn physics",
+ "resource_readable_id": "test",
+ "limit": 20,
+ "collection_name": "content_files_512",
+ },
+ timeout=30,
+ )
diff --git a/ai_chat/serializers.py b/ai_chat/serializers.py
index 076a2d6294..0f0e569773 100644
--- a/ai_chat/serializers.py
+++ b/ai_chat/serializers.py
@@ -41,3 +41,4 @@ class SyllabusChatRequestSerializer(ChatRequestSerializer):
"""DRF serializer for syllabus chatbot requests"""
readable_id = serializers.CharField(required=True)
+ collection_name = serializers.CharField(required=False)
diff --git a/ai_chat/views.py b/ai_chat/views.py
index 5dc13a589a..9ac6b86d37 100644
--- a/ai_chat/views.py
+++ b/ai_chat/views.py
@@ -104,12 +104,14 @@ def post(self, request: Request) -> StreamingHttpResponse:
user_id = request.user.email if request.user.is_authenticated else "anonymous"
message = serializer.validated_data.pop("message", "")
readable_id = (serializer.validated_data.pop("readable_id"),)
+ collection_name = (serializer.validated_data.pop("collection_name"),)
clear_history = serializer.validated_data.pop("clear_history", False)
agent = SyllabusAgent(
"Learning Resource Search AI Assistant",
user_id=user_id,
cache_key=f"{cache_id}_search_chat_history",
save_history=True,
+ collection_name=collection_name,
**serializer.validated_data,
)
if clear_history:
diff --git a/frontends/api/src/generated/v0/api.ts b/frontends/api/src/generated/v0/api.ts
index c61e6d5546..3dc7b32264 100644
--- a/frontends/api/src/generated/v0/api.ts
+++ b/frontends/api/src/generated/v0/api.ts
@@ -5027,6 +5027,12 @@ export interface SyllabusChatRequestRequest {
* @memberof SyllabusChatRequestRequest
*/
readable_id: string
+ /**
+ *
+ * @type {string}
+ * @memberof SyllabusChatRequestRequest
+ */
+ collection_name?: string
}
/**
* * `0-to-5-hours` - <5 hours/week * `5-to-10-hours` - 5-10 hours/week * `10-to-20-hours` - 10-20 hours/week * `20-to-30-hours` - 20-30 hours/week * `30-plus-hours` - 30+ hours/week
diff --git a/frontends/main/src/app-pages/ChatSyllabusPage/ChatSyllabusPage.tsx b/frontends/main/src/app-pages/ChatSyllabusPage/ChatSyllabusPage.tsx
index 5219aa6ceb..f8bc3c6402 100644
--- a/frontends/main/src/app-pages/ChatSyllabusPage/ChatSyllabusPage.tsx
+++ b/frontends/main/src/app-pages/ChatSyllabusPage/ChatSyllabusPage.tsx
@@ -51,6 +51,7 @@ const StyledDebugPre = styled.pre({
const ChatSyllabusPage = () => {
const botEnabled = useFeatureFlagEnabled(FeatureFlags.RecommendationBot)
const [readableId, setReadableId] = useState("18.06SC+fall_2011")
+ const [collectionName, setCollectionName] = useState("content_files")
const [debugInfo, setDebugInfo] = useState("")
const parseContent = (content: string | unknown) => {
@@ -105,6 +106,18 @@ const ChatSyllabusPage = () => {
Introduction to Differential Equations(EdX)
+ Contentfile Chunk Size
+
@@ -124,6 +137,7 @@ const ChatSyllabusPage = () => {
return {
message: messages[messages.length - 1].content,
readable_id: readableId,
+ collection_name: collectionName,
}
},
}}
diff --git a/openapi/specs/v0.yaml b/openapi/specs/v0.yaml
index 39e44bc3e7..9f7e004d7f 100644
--- a/openapi/specs/v0.yaml
+++ b/openapi/specs/v0.yaml
@@ -5057,6 +5057,9 @@ components:
readable_id:
type: string
minLength: 1
+ collection_name:
+ type: string
+ minLength: 1
required:
- message
- readable_id