Skip to content

Commit

Permalink
Shanbady/chunk size dropdown (#1989)
Browse files Browse the repository at this point in the history
* adding drop down

* add collection name as parameter

* passing collection name through

* regenerate spec

* adding simple test for passing down of collection name

* fixing tests
  • Loading branch information
shanbady authored Jan 27, 2025
1 parent b013ee3 commit a9bd2e6
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 1 deletion.
9 changes: 9 additions & 0 deletions ai_chat/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__( # noqa: PLR0913
save_history: Optional[bool] = False,
cache_key: Optional[str] = None,
cache_timeout: Optional[int] = None,
collection_name: Optional[str] = None,
):
"""Initialize the AI chat agent service"""
self.assistant_name = name
Expand All @@ -70,6 +71,7 @@ def __init__( # noqa: PLR0913
self.save_history = save_history
self.temperature = temperature or DEFAULT_TEMPERATURE
self.instructions = instructions or self.INSTRUCTIONS
self.collection_name = collection_name
self.user_id = user_id
if settings.AI_PROXY_CLASS:
self.proxy = import_string(f"ai_chat.proxy.{settings.AI_PROXY_CLASS}")()
Expand Down Expand Up @@ -355,6 +357,7 @@ def __init__( # noqa: PLR0913
save_history: Optional[bool] = False,
cache_key: Optional[str] = None,
cache_timeout: Optional[int] = None,
collection_name: Optional[str] = None,
):
"""Initialize the AI search agent service"""
super().__init__(
Expand All @@ -366,6 +369,7 @@ def __init__( # noqa: PLR0913
user_id=user_id,
cache_key=cache_key,
cache_timeout=cache_timeout or settings.AI_CACHE_TIMEOUT,
collection_name=collection_name,
)
self.search_parameters = []
self.search_results = []
Expand Down Expand Up @@ -520,6 +524,7 @@ def __init__( # noqa: PLR0913
save_history: Optional[bool] = False,
cache_key: Optional[str] = None,
cache_timeout: Optional[int] = None,
collection_name: Optional[str] = None,
):
"""Initialize the AI search agent service"""
super().__init__(
Expand All @@ -531,9 +536,11 @@ def __init__( # noqa: PLR0913
user_id=user_id,
cache_key=cache_key,
cache_timeout=cache_timeout or settings.AI_CACHE_TIMEOUT,
collection_name=collection_name,
)
self.search_parameters = []
self.search_results = []
self.collection_name = collection_name
self.agent = self.create_agent()
self.create_agent()

Expand All @@ -550,6 +557,8 @@ def search_content_files(self) -> str:
"resource_readable_id": self.readable_id,
"limit": 20,
}
if self.collection_name:
params["collection_name"] = self.collection_name
self.search_parameters.append(params)
try:
response = requests.get(url, params=params, timeout=30)
Expand Down
26 changes: 25 additions & 1 deletion ai_chat/agents_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from llama_index.core.base.llms.types import MessageRole
from llama_index.core.constants import DEFAULT_TEMPERATURE

from ai_chat.agents import SearchAgent
from ai_chat.agents import SearchAgent, SyllabusAgent
from ai_chat.factories import ChatMessageFactory
from learning_resources.factories import LearningResourceFactory
from learning_resources.serializers import (
Expand Down Expand Up @@ -268,3 +268,27 @@ def test_get_completion(mocker):
},
)
assert "".join([str(value) for value in expected_return_value]) in results


@pytest.mark.django_db
def test_collection_name_param(settings, mocker):
    """The collection name should be passed through to the contentfile search."""
    settings.AI_MIT_SEARCH_LIMIT = 5
    settings.AI_MIT_SYLLABUS_URL = "https://test.com/api/v0/contentfiles/"
    # The agent fetches content files via requests.get; the mock is named
    # accordingly (it was previously, misleadingly, named `mock_post`).
    mock_get = mocker.patch(
        "ai_chat.agents.requests.get",
        return_value=mocker.Mock(json=mocker.Mock(return_value={})),
    )
    search_agent = SyllabusAgent("test agent", collection_name="content_files_512")
    search_agent.get_completion("I want to learn physics", readable_id="test")
    search_agent.search_content_files()
    # collection_name must be forwarded as a query parameter alongside the
    # standard q/resource_readable_id/limit params.
    mock_get.assert_called_once_with(
        settings.AI_MIT_SYLLABUS_URL,
        params={
            "q": "I want to learn physics",
            "resource_readable_id": "test",
            "limit": 20,
            "collection_name": "content_files_512",
        },
        timeout=30,
    )
1 change: 1 addition & 0 deletions ai_chat/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ class SyllabusChatRequestSerializer(ChatRequestSerializer):
"""DRF serializer for syllabus chatbot requests"""

readable_id = serializers.CharField(required=True)
collection_name = serializers.CharField(required=False)
2 changes: 2 additions & 0 deletions ai_chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ def post(self, request: Request) -> StreamingHttpResponse:
user_id = request.user.email if request.user.is_authenticated else "anonymous"
message = serializer.validated_data.pop("message", "")
readable_id = (serializer.validated_data.pop("readable_id"),)
collection_name = (serializer.validated_data.pop("collection_name"),)
clear_history = serializer.validated_data.pop("clear_history", False)
agent = SyllabusAgent(
"Learning Resource Search AI Assistant",
user_id=user_id,
cache_key=f"{cache_id}_search_chat_history",
save_history=True,
collection_name=collection_name,
**serializer.validated_data,
)
if clear_history:
Expand Down
6 changes: 6 additions & 0 deletions frontends/api/src/generated/v0/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5027,6 +5027,12 @@ export interface SyllabusChatRequestRequest {
* @memberof SyllabusChatRequestRequest
*/
readable_id: string
/**
*
* @type {string}
* @memberof SyllabusChatRequestRequest
*/
collection_name?: string
}
/**
* * `0-to-5-hours` - <5 hours/week * `5-to-10-hours` - 5-10 hours/week * `10-to-20-hours` - 10-20 hours/week * `20-to-30-hours` - 20-30 hours/week * `30-plus-hours` - 30+ hours/week
Expand Down
14 changes: 14 additions & 0 deletions frontends/main/src/app-pages/ChatSyllabusPage/ChatSyllabusPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const StyledDebugPre = styled.pre({
const ChatSyllabusPage = () => {
const botEnabled = useFeatureFlagEnabled(FeatureFlags.RecommendationBot)
const [readableId, setReadableId] = useState("18.06SC+fall_2011")
const [collectionName, setCollectionName] = useState("content_files")
const [debugInfo, setDebugInfo] = useState("")

const parseContent = (content: string | unknown) => {
Expand Down Expand Up @@ -105,6 +106,18 @@ const ChatSyllabusPage = () => {
Introduction to Differential Equations(EdX)
</MenuItem>
</Select>
<InputLabel>Contentfile Chunk Size</InputLabel>
<Select
label="Contentfile Chunk Size"
value={collectionName}
onChange={(e) => setCollectionName(e.target.value)}
>
<MenuItem value="content_files">
Default (model dependant - 8191 for OpenAI)
</MenuItem>
<MenuItem value="content_files_512">512</MenuItem>
<MenuItem value="content_files_1024">1024</MenuItem>
</Select>
</div>
</FormContainer>
</form>
Expand All @@ -124,6 +137,7 @@ const ChatSyllabusPage = () => {
return {
message: messages[messages.length - 1].content,
readable_id: readableId,
collection_name: collectionName,
}
},
}}
Expand Down
3 changes: 3 additions & 0 deletions openapi/specs/v0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5057,6 +5057,9 @@ components:
readable_id:
type: string
minLength: 1
collection_name:
type: string
minLength: 1
required:
- message
- readable_id
Expand Down

0 comments on commit a9bd2e6

Please sign in to comment.