Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Shanbady/chunk size dropdown #1989

Merged
merged 6 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ai_chat/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__( # noqa: PLR0913
save_history: Optional[bool] = False,
cache_key: Optional[str] = None,
cache_timeout: Optional[int] = None,
collection_name: Optional[str] = None,
):
"""Initialize the AI chat agent service"""
self.assistant_name = name
Expand All @@ -70,6 +71,7 @@ def __init__( # noqa: PLR0913
self.save_history = save_history
self.temperature = temperature or DEFAULT_TEMPERATURE
self.instructions = instructions or self.INSTRUCTIONS
self.collection_name = collection_name
self.user_id = user_id
if settings.AI_PROXY_CLASS:
self.proxy = import_string(f"ai_chat.proxy.{settings.AI_PROXY_CLASS}")()
Expand Down Expand Up @@ -355,6 +357,7 @@ def __init__( # noqa: PLR0913
save_history: Optional[bool] = False,
cache_key: Optional[str] = None,
cache_timeout: Optional[int] = None,
collection_name: Optional[str] = None,
):
"""Initialize the AI search agent service"""
super().__init__(
Expand All @@ -366,6 +369,7 @@ def __init__( # noqa: PLR0913
user_id=user_id,
cache_key=cache_key,
cache_timeout=cache_timeout or settings.AI_CACHE_TIMEOUT,
collection_name=collection_name,
)
self.search_parameters = []
self.search_results = []
Expand Down Expand Up @@ -520,6 +524,7 @@ def __init__( # noqa: PLR0913
save_history: Optional[bool] = False,
cache_key: Optional[str] = None,
cache_timeout: Optional[int] = None,
collection_name: Optional[str] = None,
):
"""Initialize the AI search agent service"""
super().__init__(
Expand All @@ -531,9 +536,11 @@ def __init__( # noqa: PLR0913
user_id=user_id,
cache_key=cache_key,
cache_timeout=cache_timeout or settings.AI_CACHE_TIMEOUT,
collection_name=collection_name,
)
self.search_parameters = []
self.search_results = []
self.collection_name = collection_name
self.agent = self.create_agent()
self.create_agent()

Expand All @@ -550,6 +557,8 @@ def search_content_files(self) -> str:
"resource_readable_id": self.readable_id,
"limit": 20,
}
if self.collection_name:
params["collection_name"] = self.collection_name
self.search_parameters.append(params)
try:
response = requests.get(url, params=params, timeout=30)
Expand Down
26 changes: 25 additions & 1 deletion ai_chat/agents_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from llama_index.core.base.llms.types import MessageRole
from llama_index.core.constants import DEFAULT_TEMPERATURE

from ai_chat.agents import SearchAgent
from ai_chat.agents import SearchAgent, SyllabusAgent
from ai_chat.factories import ChatMessageFactory
from learning_resources.factories import LearningResourceFactory
from learning_resources.serializers import (
Expand Down Expand Up @@ -268,3 +268,27 @@ def test_get_completion(mocker):
},
)
assert "".join([str(value) for value in expected_return_value]) in results


@pytest.mark.django_db
def test_collection_name_param(settings, mocker):
"""The collection name should be passed through to the contentfile search"""
settings.AI_MIT_SEARCH_LIMIT = 5
settings.AI_MIT_SYLLABUS_URL = "https://test.com/api/v0/contentfiles/"
mock_post = mocker.patch(
"ai_chat.agents.requests.get",
return_value=mocker.Mock(json=mocker.Mock(return_value={})),
)
search_agent = SyllabusAgent("test agent", collection_name="content_files_512")
search_agent.get_completion("I want to learn physics", readable_id="test")
search_agent.search_content_files()
mock_post.assert_called_once_with(
settings.AI_MIT_SYLLABUS_URL,
params={
"q": "I want to learn physics",
"resource_readable_id": "test",
"limit": 20,
"collection_name": "content_files_512",
},
timeout=30,
)
1 change: 1 addition & 0 deletions ai_chat/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ class SyllabusChatRequestSerializer(ChatRequestSerializer):
"""DRF serializer for syllabus chatbot requests"""

readable_id = serializers.CharField(required=True)
collection_name = serializers.CharField(required=False)
2 changes: 2 additions & 0 deletions ai_chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ def post(self, request: Request) -> StreamingHttpResponse:
user_id = request.user.email if request.user.is_authenticated else "anonymous"
message = serializer.validated_data.pop("message", "")
readable_id = (serializer.validated_data.pop("readable_id"),)
collection_name = (serializer.validated_data.pop("collection_name"),)
clear_history = serializer.validated_data.pop("clear_history", False)
agent = SyllabusAgent(
"Learning Resource Search AI Assistant",
user_id=user_id,
cache_key=f"{cache_id}_search_chat_history",
save_history=True,
collection_name=collection_name,
**serializer.validated_data,
)
if clear_history:
Expand Down
6 changes: 6 additions & 0 deletions frontends/api/src/generated/v0/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5027,6 +5027,12 @@ export interface SyllabusChatRequestRequest {
* @memberof SyllabusChatRequestRequest
*/
readable_id: string
/**
*
* @type {string}
* @memberof SyllabusChatRequestRequest
*/
collection_name?: string
}
/**
* * `0-to-5-hours` - <5 hours/week * `5-to-10-hours` - 5-10 hours/week * `10-to-20-hours` - 10-20 hours/week * `20-to-30-hours` - 20-30 hours/week * `30-plus-hours` - 30+ hours/week
Expand Down
14 changes: 14 additions & 0 deletions frontends/main/src/app-pages/ChatSyllabusPage/ChatSyllabusPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const StyledDebugPre = styled.pre({
const ChatSyllabusPage = () => {
const botEnabled = useFeatureFlagEnabled(FeatureFlags.RecommendationBot)
const [readableId, setReadableId] = useState("18.06SC+fall_2011")
const [collectionName, setCollectionName] = useState("content_files")
const [debugInfo, setDebugInfo] = useState("")

const parseContent = (content: string | unknown) => {
Expand Down Expand Up @@ -105,6 +106,18 @@ const ChatSyllabusPage = () => {
Introduction to Differential Equations(EdX)
</MenuItem>
</Select>
<InputLabel>Contentfile Chunk Size</InputLabel>
<Select
label="Contentfile Chunk Size"
value={collectionName}
onChange={(e) => setCollectionName(e.target.value)}
>
<MenuItem value="content_files">
Default (model dependent - 8191 for OpenAI)
</MenuItem>
<MenuItem value="content_files_512">512</MenuItem>
<MenuItem value="content_files_1024">1024</MenuItem>
</Select>
</div>
</FormContainer>
</form>
Expand All @@ -124,6 +137,7 @@ const ChatSyllabusPage = () => {
return {
message: messages[messages.length - 1].content,
readable_id: readableId,
collection_name: collectionName,
}
},
}}
Expand Down
3 changes: 3 additions & 0 deletions openapi/specs/v0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5057,6 +5057,9 @@ components:
readable_id:
type: string
minLength: 1
collection_name:
type: string
minLength: 1
required:
- message
- readable_id
Expand Down
Loading