Skip to content

Commit

Permalink
topic embedder with rrf
Browse files Browse the repository at this point in the history
  • Loading branch information
zsristy43 committed Jan 16, 2025
1 parent 0dc68bf commit e59b0b0
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 8 deletions.
72 changes: 68 additions & 4 deletions local_req.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
with open("tests/request_data/no_clicks.json", "r") as req_file:
raw_json = req_file.read()

event_nrms = {
"body": raw_json,
"queryStringParameters": {"pipeline": "nrms"},
"isBase64Encoded": False,
}
event_static = {
"body": raw_json,
"queryStringParameters": {"pipeline": "nrms-topics-static"},
Expand All @@ -27,6 +32,24 @@
"queryStringParameters": {"pipeline": "nrms-topics-clicked"},
"isBase64Encoded": False,
}
event_hybrid = {
"body": raw_json,
"queryStringParameters": {"pipeline": "nrms-topics-hybrid"},
"isBase64Encoded": False,
}
event_rrf_static_candidate = {
"body": raw_json,
"queryStringParameters": {"pipeline": "nrms_rrf_static_candidate"},
"isBase64Encoded": False,
}
event_rrf_static_clicked = {
"body": raw_json,
"queryStringParameters": {"pipeline": "nrms_rrf_static_clicked"},
"isBase64Encoded": False,
}

response_nrms = generate_recs(event_nrms, {})
response_nrms = RecommendationResponse.model_validate_json(response_nrms["body"])

response_static = generate_recs(event_static, {})
response_static = RecommendationResponse.model_validate_json(response_static["body"])
Expand All @@ -37,27 +60,68 @@
response_clicked = generate_recs(event_clicked, {})
response_clicked = RecommendationResponse.model_validate_json(response_clicked["body"])

response_hybrid = generate_recs(event_hybrid, {})
response_hybrid = RecommendationResponse.model_validate_json(response_hybrid["body"])

response_rrf_static_candidate = generate_recs(event_rrf_static_candidate, {})
response_rrf_static_candidate = RecommendationResponse.model_validate_json(response_rrf_static_candidate["body"])

response_rrf_static_clicked = generate_recs(event_rrf_static_clicked, {})
response_rrf_static_clicked = RecommendationResponse.model_validate_json(response_rrf_static_clicked["body"])

for profile_id, recs in response_nrms.recommendations.items():
print("\n")
print(f"Recs for {profile_id}:")
print(f"{event_nrms['queryStringParameters']['pipeline']}")

for idx, article in enumerate(recs):
article_topics = extract_general_topics(article)
print(f"{idx + 1}. {article.headline} {article_topics}")

for profile_id, recs in response_static.recommendations.items():
print("\n")
print(f"{event_static['queryStringParameters']['pipeline']}")

for idx, article in enumerate(recs):
article_topics = extract_general_topics(article)
print(f"{idx+1}. {article.headline} {article_topics}")
print(f"{idx + 1}. {article.headline} {article_topics}")

for profile_id, recs in response_candidate.recommendations.items():
print("\n")
print(f"Recs for {profile_id}:")
print(f"{event_candidate['queryStringParameters']['pipeline']}")

for idx, article in enumerate(recs):
article_topics = extract_general_topics(article)
print(f"{idx+1}. {article.headline} {article_topics}")
print(f"{idx + 1}. {article.headline} {article_topics}")

for profile_id, recs in response_clicked.recommendations.items():
print("\n")
print(f"{event_clicked['queryStringParameters']['pipeline']}")

for idx, article in enumerate(recs):
article_topics = extract_general_topics(article)
print(f"{idx+1}. {article.headline} {article_topics}")
print(f"{idx + 1}. {article.headline} {article_topics}")

for profile_id, recs in response_hybrid.recommendations.items():
print("\n")
print(f"{event_hybrid['queryStringParameters']['pipeline']}")

for idx, article in enumerate(recs):
article_topics = extract_general_topics(article)
print(f"{idx + 1}. {article.headline} {article_topics}")

for profile_id, recs in response_rrf_static_candidate.recommendations.items():
print("\n")
print(f"{event_rrf_static_candidate['queryStringParameters']['pipeline']}")

for idx, article in enumerate(recs):
article_topics = extract_general_topics(article)
print(f"{idx + 1}. {article.headline} {article_topics}")

for profile_id, recs in response_rrf_static_clicked.recommendations.items():
print("\n")
print(f"{event_rrf_static_clicked['queryStringParameters']['pipeline']}")

for idx, article in enumerate(recs):
article_topics = extract_general_topics(article)
print(f"{idx + 1}. {article.headline} {article_topics}")
12 changes: 12 additions & 0 deletions src/poprox_recommender/components/embedders/topic_wise_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ def __call__(
**embeddings_from_candidates,
**embeddings_from_clicked,
}
elif self.embedding_source == "hybrid":
all_topic_uuids = (
set(embeddings_from_definitions) | set(embeddings_from_candidates) | set(embeddings_from_clicked)
)
topic_embeddings_by_uuid = {}
for topic_uuid in all_topic_uuids:
def_emb = embeddings_from_definitions.get(topic_uuid, th.zeros(768, device=self.device))
cand_emb = embeddings_from_candidates.get(topic_uuid, th.zeros(768, device=self.device))
clicked_emb = embeddings_from_clicked.get(topic_uuid, th.zeros(768, device=self.device))

avg_emb = 0.6 * def_emb + 0.3 * cand_emb + 0.1 * clicked_emb
topic_embeddings_by_uuid[topic_uuid] = avg_emb
else:
raise ValueError(f"Unknown embedding source: {self.embedding_source}")

Expand Down
3 changes: 2 additions & 1 deletion src/poprox_recommender/components/joiners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from poprox_recommender.components.joiners.concat import Concatenate
from poprox_recommender.components.joiners.fill import Fill
from poprox_recommender.components.joiners.interleave import Interleave
from poprox_recommender.components.joiners.rrf import ReciprocalRankFusion

__all__ = ["Concatenate", "Fill", "Interleave"]
__all__ = ["Concatenate", "Fill", "Interleave", "ReciprocalRankFusion"]
90 changes: 88 additions & 2 deletions src/poprox_recommender/recommenders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from poprox_recommender.components.embedders import NRMSArticleEmbedder, NRMSUserEmbedder
from poprox_recommender.components.embedders.topic_wise_user import TopicUserEmbedder
from poprox_recommender.components.filters import TopicFilter
from poprox_recommender.components.joiners import Fill
from poprox_recommender.components.joiners import Fill, ReciprocalRankFusion
from poprox_recommender.components.rankers.topk import TopkRanker
from poprox_recommender.components.samplers import SoftmaxSampler, UniformSampler
from poprox_recommender.components.scorers import ArticleScorer
Expand Down Expand Up @@ -105,6 +105,12 @@ def build_pipelines(num_slots: int, device: str) -> dict[str, Pipeline]:
embedding_source="static",
topic_embedding="avg",
)
topic_user_embedder_hybrid = TopicUserEmbedder(
model_file_path("nrms-mind/user_encoder.safetensors"),
device,
embedding_source="hybrid",
topic_embedding="avg",
)

topk_ranker = TopkRanker(num_slots=num_slots)
mmr = MMRDiversifier(num_slots=num_slots)
Expand Down Expand Up @@ -145,6 +151,14 @@ def build_pipelines(num_slots: int, device: str) -> dict[str, Pipeline]:
num_slots=num_slots,
)

nrms_onboarding_pipe_hybrid = build_pipeline(
"plain-NRMS-with-onboarding-topics",
article_embedder=article_embedder,
user_embedder=topic_user_embedder_hybrid,
ranker=topk_ranker,
num_slots=num_slots,
)

mmr_pipe = build_pipeline(
"NRMS+MMR",
article_embedder=article_embedder,
Expand Down Expand Up @@ -185,16 +199,37 @@ def build_pipelines(num_slots: int, device: str) -> dict[str, Pipeline]:
num_slots=num_slots,
)

nrms_rrf_static_candidate = build_RRF_pipeline(
"NRMS+RRF",
article_embedder=article_embedder,
user_embedder=topic_user_embedder_static,
user_embedder2=topic_user_embedder_candidate,
ranker=topk_ranker,
num_slots=num_slots,
)

nrms_rrf_static_clicked = build_RRF_pipeline(
"NRMS+RRF",
article_embedder=article_embedder,
user_embedder=topic_user_embedder_static,
user_embedder2=topic_user_embedder_clicked,
ranker=topk_ranker,
num_slots=num_slots,
)

return {
"nrms": nrms_pipe,
"nrms-topics-candidate": nrms_onboarding_pipe_cadidate,
"nrms-topics-clicked": nrms_onboarding_pipe_clicked,
"nrms-topics-static": nrms_onboarding_pipe_static,
"nrms-topics-hybrid": nrms_onboarding_pipe_hybrid,
"mmr": mmr_pipe,
"pfar": pfar_pipe,
"topic-cali": topic_cali_pipe,
"locality-cali": locality_cali_pipe,
"softmax": softmax_pipe,
"nrms_rrf_static_candidate": nrms_rrf_static_candidate,
"nrms_rrf_static_clicked": nrms_rrf_static_clicked,
}


Expand All @@ -204,7 +239,6 @@ def build_pipeline(name, article_embedder, user_embedder, ranker, num_slots):
sampler = UniformSampler(num_slots=num_slots)
fill = Fill(num_slots=num_slots)
topk_ranker = TopkRanker(num_slots=num_slots)

pipeline = Pipeline(name=name)

# Define pipeline inputs
Expand Down Expand Up @@ -237,3 +271,55 @@ def build_pipeline(name, article_embedder, user_embedder, ranker, num_slots):
pipeline.add_component("recommender", fill, candidates1=o_rank, candidates2=o_sampled)

return pipeline


def build_RRF_pipeline(name, article_embedder, user_embedder, user_embedder2, ranker, num_slots):
    """Build a pipeline that ranks candidates twice and joins both rankings.

    The candidate and clicked articles are embedded once with
    ``article_embedder``.  Two independent branches are then wired, one per
    user embedder; each branch embeds the user, scores the candidates against
    that embedding and ranks them.  The two ranked outputs are handed to a
    ``ReciprocalRankFusion`` component registered as ``"recommender"``.

    Args:
        name: Name given to the ``Pipeline``.
        article_embedder: Component used for both the candidate and the
            clicked article sets.
        user_embedder: User embedder for the first branch (nodes
            ``user-embedder``, ``scorer``, ``ranker``, ``reranker``).
        user_embedder2: User embedder for the second branch (same node names
            with a ``2`` suffix).
        ranker: Ranking component applied to the scored candidates of each
            branch.
        num_slots: Passed as ``num_slots`` to the ``ReciprocalRankFusion``
            joiner and to the internal ``TopkRanker``.

    Returns:
        The assembled ``Pipeline``.
    """
    article_scorer = ArticleScorer()
    rrf = ReciprocalRankFusion(num_slots=num_slots)
    topk_ranker = TopkRanker(num_slots=num_slots)

    pipeline = Pipeline(name=name)

    # Define pipeline inputs
    candidates = pipeline.create_input("candidate", ArticleSet)
    clicked = pipeline.create_input("clicked", ArticleSet)
    profile = pipeline.create_input("profile", InterestProfile)

    # Article embeddings are computed once and shared by both branches
    e_cand = pipeline.add_component("candidate-embedder", article_embedder, article_set=candidates)
    e_click = pipeline.add_component("history-embedder", article_embedder, article_set=clicked)

    def add_ranked_branch(suffix, embedder):
        """Wire one user-embedder -> scorer -> ranker chain; return its ranked output node."""
        e_user = pipeline.add_component(
            f"user-embedder{suffix}",
            embedder,
            candidate_articles=candidates,
            clicked_articles=e_click,
            interest_profile=profile,
        )
        o_scored = pipeline.add_component(
            f"scorer{suffix}", article_scorer, candidate_articles=e_cand, interest_profile=e_user
        )
        o_topk = pipeline.add_component(
            f"ranker{suffix}", topk_ranker, candidate_articles=o_scored, interest_profile=e_user
        )
        # NOTE(review): `topk_ranker` is created inside this function, so a
        # caller-supplied `ranker` is not expected to be the same object and
        # the shortcut below looks unreachable; the `reranker` node is then
        # always added, even when `ranker` is itself a TopkRanker. Kept as-is
        # to preserve the existing pipeline topology -- confirm intent.
        if ranker is topk_ranker:
            return o_topk
        return pipeline.add_component(
            f"reranker{suffix}", ranker, candidate_articles=o_scored, interest_profile=e_user
        )

    # First ranking, driven by `user_embedder`
    o_rank_1 = add_ranked_branch("", user_embedder)
    # Second ranking, driven by `user_embedder2` (a parallel branch, not a fallback)
    o_rank_2 = add_ranked_branch("2", user_embedder2)

    # Join the two rankings into the final recommendation list
    pipeline.add_component("recommender", rrf, candidates1=o_rank_1, candidates2=o_rank_2)

    return pipeline
2 changes: 1 addition & 1 deletion tests/request_data/no_clicks.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
{
"entity_id": "1e813fd6-0998-43fb-9839-75fa96b69b32",
"entity_name": "Science",
"preference": 1
"preference": 5
},
{
"entity_id": "5f6de24a-9a1b-4863-ab01-1ecacf4c54b7",
Expand Down

0 comments on commit e59b0b0

Please sign in to comment.