Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FAQ rewriting pipeline #191

Merged
merged 24 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/common/PipelineEnum.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ class PipelineEnum(str, Enum):
IRIS_SUMMARY_PIPELINE = "IRIS_SUMMARY_PIPELINE"
IRIS_LECTURE_RETRIEVAL_PIPELINE = "IRIS_LECTURE_RETRIEVAL_PIPELINE"
IRIS_LECTURE_INGESTION = "IRIS_LECTURE_INGESTION"
IRIS_REWRITING_PIPELINE = "IRIS_REWRITING_PIPELINE"
NOT_SET = "NOT_SET"
11 changes: 11 additions & 0 deletions app/domain/rewriting_pipeline_execution_dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import List

from pydantic import Field, BaseModel

from . import PipelineExecutionDTO
from .data.competency_dto import CompetencyTaxonomy, Competency


class RewritingPipelineExecutionDTO(BaseModel):
execution: PipelineExecutionDTO
to_be_rewritten : str = Field(alias="toBeRewritten")
6 changes: 6 additions & 0 deletions app/domain/status/rewriting_status_update_dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from app.domain.data.competency_dto import Competency
from app.domain.status.status_update_dto import StatusUpdateDTO


class RewritingStatusUpdateDTO(StatusUpdateDTO):
result: str = ""
15 changes: 15 additions & 0 deletions app/pipeline/prompts/faq_rewriting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
system_prompt_faq = """
You are a excellent tutor with expertise in computer science and practical applications teaching an university course. Your task is to proofread and refine the given text of an FAQ. Specifically, you should:

1. Correct all spelling and grammatical errors.
2. Ensure the text is written in simple and clear language, making it easy to understand for students.
3. Preserve the original meaning and intent of the text.
4. Ensure that the response is always written in complete sentences. If you are given a list of bullet points, convert them into complete sentences.
5. Make sure to use the original language of the input text
6. Avoid repeating any information that is already present in the text.
7. Make sure to keep the markdown formatting intact and add formatting for the most important information

{rewritten_text}

Respond with a single string containing only the improved version of the text. Your output will be used as a frequently asked question (FAQ) on the Artemis platform, so make sure it is clear and concise.
"""
68 changes: 68 additions & 0 deletions app/pipeline/rewriting_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import logging
from typing import Optional

from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
)

from app.common.PipelineEnum import PipelineEnum
from app.common.pyris_message import PyrisMessage, IrisMessageRole
from app.domain.data.text_message_content_dto import TextMessageContentDTO
from app.domain.data.competency_dto import Competency
from app.domain.rewriting_pipeline_execution_dto import RewritingPipelineExecutionDTO
from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments
from app.pipeline import Pipeline
from app.pipeline.prompts.faq_rewriting import system_prompt_faq
from app.web.status.status_update import RewritingCallback

logger = logging.getLogger(__name__)


class RewritingPipeline(Pipeline):
callback: RewritingCallback
request_handler: CapabilityRequestHandler
output_parser: PydanticOutputParser

def __init__(self, callback: Optional[RewritingCallback] = None):
super().__init__(
implementation_id="rewriting_pipeline_reference_impl"
)
self.callback = callback
self.request_handler = CapabilityRequestHandler(
requirements=RequirementList(
gpt_version_equivalent=4.5,
context_length=16385,
)
)
self.output_parser = PydanticOutputParser(pydantic_object=Competency)
self.tokens = []

def __call__(
self,
dto: RewritingPipelineExecutionDTO,
prompt: Optional[ChatPromptTemplate] = None,
**kwargs,
):
if not dto.to_be_rewritten:
raise ValueError("You need to provide a text to rewrite")

#
prompt = system_prompt_faq.format(
rewritten_text=dto.to_be_rewritten,
)
prompt = PyrisMessage(
sender=IrisMessageRole.SYSTEM,
contents=[TextMessageContentDTO(text_content=prompt)],
)

response = self.request_handler.chat(
[prompt], CompletionArguments(temperature=0.4)
)
self._append_tokens(
response.token_usage, PipelineEnum.IRIS_REWRITING_PIPELINE
)
response = response.contents[0].text_content
final_result = response
logging.info(f"Final rewritten text: {final_result}")
self.callback.done(final_result=final_result, tokens=self.tokens)
53 changes: 52 additions & 1 deletion app/web/routers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
CourseChatPipelineExecutionDTO,
CompetencyExtractionPipelineExecutionDTO,
)
from app.domain.rewriting_pipeline_execution_dto import RewritingPipelineExecutionDTO
from app.pipeline.chat.exercise_chat_agent_pipeline import ExerciseChatAgentPipeline
from app.domain.chat.lecture_chat.lecture_chat_pipeline_execution_dto import (
LectureChatPipelineExecutionDTO,
)
from app.pipeline.chat.lecture_chat_pipeline import LectureChatPipeline
from app.pipeline.rewriting_pipeline import RewritingPipeline
from app.web.status.status_update import (
ExerciseChatStatusCallback,
CourseChatStatusCallback,
CompetencyExtractionCallback,
LectureChatCallback,
LectureChatCallback, RewritingCallback,
)
from app.pipeline.chat.course_chat_pipeline import CourseChatPipeline
from app.dependencies import TokenValidator
Expand Down Expand Up @@ -218,6 +220,29 @@ def run_competency_extraction_pipeline_worker(
callback.error("Fatal error.", exception=e)


def run_rewriting_pipeline_worker(
dto: RewritingPipelineExecutionDTO, _variant: str
):
try:
callback = RewritingCallback(
run_id=dto.execution.settings.authentication_token,
base_url=dto.execution.settings.artemis_base_url,
initial_stages=dto.execution.initial_stages,
)
pipeline = RewritingPipeline(callback=callback)
except Exception as e:
logger.error(f"Error preparing rewriting pipeline: {e}")
logger.error(traceback.format_exc())
capture_exception(e)
return

try:
pipeline(dto=dto)
except Exception as e:
logger.error(f"Error running rewriting extraction pipeline: {e}")
logger.error(traceback.format_exc())
callback.error("Fatal error.", exception=e)

@router.post(
"/competency-extraction/{variant}/run",
status_code=status.HTTP_202_ACCEPTED,
Expand All @@ -232,6 +257,22 @@ def run_competency_extraction_pipeline(
thread.start()



@router.post(
"/rewriting/{variant}/run",
status_code=status.HTTP_202_ACCEPTED,
dependencies=[Depends(TokenValidator())],
)
def run_rewriting_pipeline(
variant: str, dto: RewritingPipelineExecutionDTO
):
logger.info(f"Rewriting pipeline started with variant: {variant} and dto: {dto}")
thread = Thread(
target=run_rewriting_pipeline_worker, args=(dto, variant)
)
thread.start()


@router.get("/{feature}/variants")
def get_pipeline(feature: str):
"""
Expand Down Expand Up @@ -294,5 +335,15 @@ def get_pipeline(feature: str):
description="Default lecture chat variant.",
)
]

case "REWRITING":
return [
FeatureDTO(
id="default",
name="Default Variant",
description="Default rewriting variant.",
)
]

case _:
return Response(status_code=status.HTTP_400_BAD_REQUEST)
21 changes: 21 additions & 0 deletions app/web/status/status_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from app.domain.status.lecture_chat_status_update_dto import (
LectureChatStatusUpdateDTO,
)
from app.domain.status.rewriting_status_update_dto import RewritingStatusUpdateDTO
from app.domain.status.stage_state_dto import StageStateEnum
from app.domain.status.stage_dto import StageDTO
from app.domain.status.text_exercise_chat_status_update_dto import (
Expand Down Expand Up @@ -274,6 +275,26 @@ def __init__(
stage = stages[-1]
super().__init__(url, run_id, status, stage, len(stages) - 1)

class RewritingCallback(StatusCallback):
def __init__(
self,
run_id: str,
base_url: str,
initial_stages: List[StageDTO],
):
url = f"{base_url}/api/public/pyris/pipelines/rewriting/runs/{run_id}/status"
stages = initial_stages or []
stages.append(
StageDTO(
weight=10,
state=StageStateEnum.NOT_STARTED,
name="Generating Rewritting",
)
)
status = RewritingStatusUpdateDTO(stages=stages)
stage = stages[-1]
super().__init__(url, run_id, status, stage, len(stages) - 1)


class LectureChatCallback(StatusCallback):
def __init__(
Expand Down
2 changes: 2 additions & 0 deletions docker/pyris-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ services:
- ../llm_config.local.yml:/config/llm_config.yml:ro
networks:
- pyris
ports:
- 8000:8000

weaviate:
extends:
Expand Down
Loading