embed course metadata as contentfile #2050

Draft
wants to merge 11 commits into main
169 changes: 147 additions & 22 deletions vector_search/utils.py
@@ -1,6 +1,8 @@
import datetime
import logging
import uuid

from dateutil import parser as date_parser
from django.conf import settings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
@@ -74,7 +76,6 @@ def create_qdrant_collections(force_recreate):
    client = qdrant_client()
    resources_collection_name = RESOURCES_COLLECTION_NAME
    content_files_collection_name = CONTENT_FILES_COLLECTION_NAME

    encoder = dense_encoder()
    if (
        not client.collection_exists(collection_name=resources_collection_name)
@@ -161,25 +162,6 @@ def vector_point_id(readable_id):
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, readable_id))


def _process_resource_embeddings(serialized_resources):
    docs = []
    metadata = []
    ids = []
    encoder = dense_encoder()
    vector_name = encoder.model_short_name()
    for doc in serialized_resources:
        vector_point_key = doc["readable_id"]
        metadata.append(doc)
        ids.append(vector_point_id(vector_point_key))
        docs.append(
            f"{doc.get('title')} {doc.get('description')} {doc.get('full_description')}"
        )
    if len(docs) > 0:
        embeddings = encoder.embed_documents(docs)
        return points_generator(ids, metadata, embeddings, vector_name)
    return None


def _chunk_documents(encoder, texts, metadatas):
    # chunk the documents. use semantic chunking if enabled
    chunk_params = {
@@ -214,7 +196,146 @@ def _chunk_documents(encoder, texts, metadatas):
    return recursive_splitter.create_documents(texts=texts, metadatas=metadatas)


def _process_resource_embeddings(serialized_resources):
    docs = []
    metadata = []
    ids = []
    encoder = dense_encoder()
    vector_name = encoder.model_short_name()
    for doc in serialized_resources:
        vector_point_key = doc["readable_id"]
        metadata.append(doc)
        ids.append(vector_point_id(vector_point_key))
        docs.append(
            f"{doc.get('title')} {doc.get('description')} {doc.get('full_description')}"
        )
    if len(docs) > 0:
        embeddings = encoder.embed_documents(docs)
        return points_generator(ids, metadata, embeddings, vector_name)
    return None


def generate_metadata_document(serialized_resource):
    """
    Generate a plain-text info document to embed in the contentfile collection
    """
    title = serialized_resource.get("title", "")
    description = (
        f"{serialized_resource.get('description', '')}"
        f" {serialized_resource.get('full_description', '')}"
    )
    offered_by = serialized_resource.get("offered_by", {}).get("name")
    price = (
        f"${serialized_resource['prices'][0]}"
        if serialized_resource.get("prices")
        else "Free"
    )
    certification = serialized_resource.get("certification_type", {}).get("name")
    topics = ", ".join(topic["name"] for topic in serialized_resource.get("topics", []))
    # process course runs
    runs = []
    for run in serialized_resource.get("runs", []):
        start_date = run.get("start_date")
        formatted_date = (
            date_parser.parse(start_date)
            .replace(tzinfo=datetime.UTC)
            .strftime("%B %d, %Y")
            if start_date
            else ""
        )
        location = run.get("location") or "Online"
        duration = run.get("duration")
        delivery_modes = (
            ", ".join(delivery["name"] for delivery in run.get("delivery", []))
            or "Not specified"
        )
        instructors = ", ".join(
            instructor["full_name"]
            for instructor in run.get("instructors", [])
            if "full_name" in instructor
        )
        runs.append(
            f" - Start Date: {formatted_date}, Location: {location}, "
            f"Duration: {duration}, Format: {delivery_modes},"
            f" Instructors: {instructors}"
        )
    runs_text = "\n".join(runs) if runs else ""
    # Extract languages
    languages = []
    for run in serialized_resource.get("runs", []):
        if run.get("languages"):
            languages.extend(run["languages"])
    unique_languages = ", ".join(set(languages))
    # Extract levels
    levels = []
    for run in serialized_resource.get("runs", []):
        if run.get("level"):
            levels.extend(lvl["name"] for lvl in run["level"])
    unique_levels = ", ".join(set(levels))
    display_info = {
        "Course Title": title,
        "Description": description,
        "Offered By": offered_by,
        "Price": price,
        "Certification": certification,
        "Topics": topics,
        "Level": unique_levels,
        "Languages": unique_languages,
        "Course Runs": runs_text,
    }
    rendered_info = "Information about this course:\n"
    for section, display_text in display_info.items():
        if display_text:
            if len(display_text.strip().split("\n")) > 1:
                rendered_info += f"{section} -\n{display_text}\n"
            else:
                rendered_info += f"{section} - {display_text}\n"
    return rendered_info
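
# For reference, a document rendered by generate_metadata_document might look
# like the following (hypothetical course values, not from this diff; two runs
# are shown so the multi-line "Course Runs" branch is exercised):
#
#   Information about this course:
#   Course Title - Introduction to Data Science
#   Description - Learn the basics. A longer full description of the course.
#   Offered By - MITx
#   Price - $49.00
#   Certification - Certificate of Completion
#   Topics - Data Science, Statistics
#   Level - Introductory
#   Languages - en, es
#   Course Runs -
#    - Start Date: January 01, 2025, Location: Online, Duration: 7 - 9 weeks, Format: Online, Instructors: Jane Doe
#    - Start Date: June 01, 2025, Location: Cambridge, MA, Duration: 7 - 9 weeks, Format: In person, Instructors: Jane Doe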


def _embed_course_metadata_as_contentfile(serialized_resources):
    """
    Embed general course info as a document in the contentfile collection
    """
    client = qdrant_client()
    encoder = dense_encoder()
    vector_name = encoder.model_short_name()
    metadata = []
    ids = []
    docs = []
    for doc in serialized_resources:
        readable_id = doc["readable_id"]
        # vector_point_id is a deterministic uuid5 of the readable_id, so
        # re-embedding a course overwrites its previous metadata document
        # instead of creating a duplicate point
        resource_vector_point_id = vector_point_id(readable_id)
        ids.append(resource_vector_point_id)
        course_info_document = generate_metadata_document(doc)
        metadata.append(
            {
                "resource_point_id": resource_vector_point_id,
                "resource_readable_id": readable_id,
                "chunk_number": 0,
                "file_extension": ".md",
                "file_type": "text/markdown",
                "chunk_content": course_info_document,
                **{
                    key: doc[key]
                    for key in [
                        "platform",
                        "offered_by",
                    ]
                },
            }
        )
        docs.append(course_info_document)
    if len(docs) > 0:
        embeddings = encoder.embed_documents(docs)
        points = points_generator(ids, metadata, embeddings, vector_name)
        client.upload_points(CONTENT_FILES_COLLECTION_NAME, points=points, wait=False)


def _process_content_embeddings(serialized_content):
    """
    Chunk and embed content file documents
    """
    embeddings = []
    metadata = []
    ids = []
@@ -329,7 +450,9 @@ def embed_learning_resources(ids, resource_type, overwrite):
        ]

        collection_name = RESOURCES_COLLECTION_NAME

        points = _process_resource_embeddings(serialized_resources)
        _embed_course_metadata_as_contentfile(serialized_resources)
    else:
        serialized_resources = list(serialize_bulk_content_files(ids))
        collection_name = CONTENT_FILES_COLLECTION_NAME
@@ -394,7 +517,7 @@ def vector_search(
"""

client = qdrant_client()

encoder = dense_encoder()
qdrant_conditions = qdrant_query_conditions(
params, collection_name=search_collection
)
@@ -405,7 +528,6 @@
        ]
    )
    if query_string:
        encoder = dense_encoder()
        search_result = client.query_points(
            collection_name=search_collection,
            using=encoder.model_short_name(),
@@ -506,6 +628,9 @@ def filter_existing_qdrant_points(
    lookup_field="readable_id",
    collection_name=RESOURCES_COLLECTION_NAME,
):
    """
    Return only values that don't exist in Qdrant
    """
    client = qdrant_client()
    results = client.scroll(
        collection_name=collection_name,
93 changes: 91 additions & 2 deletions vector_search/utils_test.py
@@ -1,18 +1,29 @@
from decimal import Decimal

import pytest
from django.conf import settings
from qdrant_client import models
from qdrant_client.models import PointStruct

from learning_resources.factories import ContentFileFactory, LearningResourceFactory
from learning_resources.factories import (
    ContentFileFactory,
    LearningResourceFactory,
    LearningResourcePriceFactory,
    LearningResourceRunFactory,
)
from learning_resources.models import LearningResource
from learning_resources_search.serializers import serialize_bulk_content_files
from learning_resources_search.serializers import (
    serialize_bulk_content_files,
    serialize_bulk_learning_resources,
)
from vector_search.constants import (
    CONTENT_FILES_COLLECTION_NAME,
    RESOURCES_COLLECTION_NAME,
)
from vector_search.encoders.utils import dense_encoder
from vector_search.utils import (
    _chunk_documents,
    _embed_course_metadata_as_contentfile,
    create_qdrant_collections,
    embed_learning_resources,
    filter_existing_qdrant_points,
@@ -334,3 +345,81 @@ def test_text_splitter_chunk_size_override(mocker):
    settings.CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = None
    _chunk_documents(encoder, ["this is a test document"], [{}])
    assert "chunk_size" not in mocked_splitter.mock_calls[0].kwargs


def test_course_metadata_indexed_with_learning_resources(mocker):
    # test that we embed a metadata document when embedding learning resources
    resources = LearningResourceFactory.create_batch(5)

    mock_qdrant = mocker.patch("qdrant_client.QdrantClient")
    mock_embed_course_metadata_as_contentfile = mocker.patch(
        "vector_search.utils._embed_course_metadata_as_contentfile"
    )
    mocker.patch(
        "vector_search.utils.qdrant_client",
        return_value=mock_qdrant,
    )

    mocker.patch(
        "vector_search.utils.filter_existing_qdrant_points",
        return_value=[r.readable_id for r in resources],
    )
    embed_learning_resources(
        [resource.id for resource in resources], "course", overwrite=True
    )
    mock_embed_course_metadata_as_contentfile.assert_called()


def test_course_metadata_document_contents(mocker):
    # test the contents of the metadata document
    resource = LearningResourceFactory.create()

    run = LearningResourceRunFactory.create(
        learning_resource=resource,
        published=True,
        prices=[Decimal("0.00"), Decimal("50.00")],
        resource_prices=LearningResourcePriceFactory.create_batch(
            2, amount=Decimal("1.00")
        ),
        location="Portland, OR",
        duration="7 - 9 weeks",
        min_weeks=7,
        max_weeks=9,
        languages=["en", "es"],
        time_commitment="8 - 9 hours per week",
        min_weekly_hours=8,
        max_weekly_hours=9,
    )
    resource.prices = [Decimal("1.00"), Decimal("3.00")]
    resource.resource_prices.set(
        LearningResourcePriceFactory.create_batch(2, amount=Decimal("1.00"))
    )
    resource.save()

    mock_qdrant = mocker.patch("qdrant_client.QdrantClient")

    mocker.patch(
        "vector_search.utils.qdrant_client",
        return_value=mock_qdrant,
    )

    serialized_resource = next(serialize_bulk_learning_resources([resource.id]))

    _embed_course_metadata_as_contentfile([serialized_resource])
    point = next(mock_qdrant.upload_points.mock_calls[0].kwargs["points"])
    course_metadata_content = point.payload["chunk_content"]
    prices = (
        f"${serialized_resource['prices'][0]}"
        if serialized_resource.get("prices")
        else "Free"
    )
    assert course_metadata_content.startswith("Information about this course:")
    assert resource.title in course_metadata_content
    assert resource.description in course_metadata_content
    assert resource.full_description in course_metadata_content
    assert prices in course_metadata_content
    for topic in resource.topics.all():
        assert topic.name in course_metadata_content
    for run in serialized_resource["runs"]:
        for level in run["level"]:
            assert level["name"] in course_metadata_content
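
A minimal usage sketch of the flow this PR changes (the resource ids below are hypothetical; embed_learning_resources, the "course" resource type, and the collection behavior are from this diff):

# Hypothetical example, not part of the diff: embedding course resources now
# also uploads one metadata document per course into the contentfile collection.
from vector_search.utils import embed_learning_resources

# Embeds each course into the resources collection and, via
# _embed_course_metadata_as_contentfile, a rendered metadata document
# (chunk_number=0, file_type "text/markdown") into the contentfile collection.
embed_learning_resources([1, 2, 3], "course", overwrite=True)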