diff --git a/app/domain/rewriting_pipeline_execution_dto.py b/app/domain/rewriting_pipeline_execution_dto.py index 9af08b5a..b8560899 100644 --- a/app/domain/rewriting_pipeline_execution_dto.py +++ b/app/domain/rewriting_pipeline_execution_dto.py @@ -8,4 +8,4 @@ class RewritingPipelineExecutionDTO(BaseModel): execution: PipelineExecutionDTO - to_be_rewritten : str = Field(alias="toBeRewritten") + to_be_rewritten: str = Field(alias="toBeRewritten") diff --git a/app/pipeline/chat/course_chat_pipeline.py b/app/pipeline/chat/course_chat_pipeline.py index 9e3306f9..f0904116 100644 --- a/app/pipeline/chat/course_chat_pipeline.py +++ b/app/pipeline/chat/course_chat_pipeline.py @@ -100,14 +100,16 @@ def __init__( requirements=RequirementList( gpt_version_equivalent=4.5, ) - ), completion_args=completion_args + ), + completion_args=completion_args, ) self.llm_small = IrisLangchainChatModel( request_handler=CapabilityRequestHandler( requirements=RequirementList( gpt_version_equivalent=4.25, ) - ), completion_args=completion_args + ), + completion_args=completion_args, ) self.callback = callback diff --git a/app/pipeline/chat/exercise_chat_agent_pipeline.py b/app/pipeline/chat/exercise_chat_agent_pipeline.py index ff9e86da..676f96c6 100644 --- a/app/pipeline/chat/exercise_chat_agent_pipeline.py +++ b/app/pipeline/chat/exercise_chat_agent_pipeline.py @@ -533,7 +533,9 @@ def lecture_content_retrieval() -> str: ] ) - guide_response = (self.prompt | self.llm_small | StrOutputParser()).invoke( + guide_response = ( + self.prompt | self.llm_small | StrOutputParser() + ).invoke( { "response": out, } diff --git a/app/pipeline/prompts/faq_rewriting.py b/app/pipeline/prompts/faq_rewriting.py index af74085c..553036b7 100644 --- a/app/pipeline/prompts/faq_rewriting.py +++ b/app/pipeline/prompts/faq_rewriting.py @@ -12,4 +12,4 @@ {rewritten_text} Respond with a single string containing only the improved version of the text. Your output will be used as a frequently asked question (FAQ) on the Artemis platform, so make sure it is clear and concise. -""" \ No newline at end of file +""" diff --git a/app/pipeline/rewriting_pipeline.py b/app/pipeline/rewriting_pipeline.py index 95c1b503..10af525a 100644 --- a/app/pipeline/rewriting_pipeline.py +++ b/app/pipeline/rewriting_pipeline.py @@ -14,7 +14,7 @@ from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments from app.pipeline import Pipeline from app.pipeline.prompts.faq_rewriting import system_prompt_faq -from app.web.status.status_update import RewritingCallback +from app.web.status.status_update import RewritingCallback logger = logging.getLogger(__name__) @@ -25,9 +25,7 @@ class RewritingPipeline(Pipeline): output_parser: PydanticOutputParser def __init__(self, callback: Optional[RewritingCallback] = None): - super().__init__( - implementation_id="rewriting_pipeline_reference_impl" - ) + super().__init__(implementation_id="rewriting_pipeline_reference_impl") self.callback = callback self.request_handler = CapabilityRequestHandler( requirements=RequirementList( @@ -59,9 +57,7 @@ def __call__( response = self.request_handler.chat( [prompt], CompletionArguments(temperature=0.4) ) - self._append_tokens( - response.token_usage, PipelineEnum.IRIS_REWRITING_PIPELINE - ) + self._append_tokens(response.token_usage, PipelineEnum.IRIS_REWRITING_PIPELINE) response = response.contents[0].text_content final_result = response logging.info(f"Final rewritten text: {final_result}") diff --git a/app/pipeline/shared/citation_pipeline.py b/app/pipeline/shared/citation_pipeline.py index 22e13360..fc71016b 100644 --- a/app/pipeline/shared/citation_pipeline.py +++ b/app/pipeline/shared/citation_pipeline.py @@ -57,7 +57,8 @@ def create_formatted_string(self, paragraphs): paragraph.get(LectureSchema.LECTURE_NAME.value), paragraph.get(LectureSchema.LECTURE_UNIT_NAME.value), paragraph.get(LectureSchema.PAGE_NUMBER.value), - paragraph.get(LectureSchema.LECTURE_UNIT_LINK.value) or "No link available", + paragraph.get(LectureSchema.LECTURE_UNIT_LINK.value) + or "No link available", paragraph.get(LectureSchema.PAGE_TEXT_CONTENT.value), ) formatted_string += lct diff --git a/app/web/routers/pipelines.py b/app/web/routers/pipelines.py index 9fd2b465..8c9dd6c0 100644 --- a/app/web/routers/pipelines.py +++ b/app/web/routers/pipelines.py @@ -22,7 +22,8 @@ ExerciseChatStatusCallback, CourseChatStatusCallback, CompetencyExtractionCallback, - LectureChatCallback, RewritingCallback, + LectureChatCallback, + RewritingCallback, ) from app.pipeline.chat.course_chat_pipeline import CourseChatPipeline from app.dependencies import TokenValidator @@ -220,10 +221,8 @@ def run_competency_extraction_pipeline_worker( callback.error("Fatal error.", exception=e) -def run_rewriting_pipeline_worker( - dto: RewritingPipelineExecutionDTO, _variant: str -): - try: +def run_rewriting_pipeline_worker(dto: RewritingPipelineExecutionDTO, _variant: str): + try: callback = RewritingCallback( run_id=dto.execution.settings.authentication_token, base_url=dto.execution.settings.artemis_base_url, @@ -243,6 +242,7 @@ def run_rewriting_pipeline_worker( logger.error(traceback.format_exc()) callback.error("Fatal error.", exception=e) + @router.post( "/competency-extraction/{variant}/run", status_code=status.HTTP_202_ACCEPTED, @@ -257,19 +257,14 @@ def run_competency_extraction_pipeline( thread.start() - @router.post( "/rewriting/{variant}/run", status_code=status.HTTP_202_ACCEPTED, dependencies=[Depends(TokenValidator())], ) -def run_rewriting_pipeline( - variant: str, dto: RewritingPipelineExecutionDTO -): +def run_rewriting_pipeline(variant: str, dto: RewritingPipelineExecutionDTO): logger.info(f"Rewriting pipeline started with variant: {variant} and dto: {dto}") - thread = Thread( - target=run_rewriting_pipeline_worker, args=(dto, variant) - ) + thread = Thread(target=run_rewriting_pipeline_worker, args=(dto, variant)) thread.start() diff --git a/app/web/status/status_update.py b/app/web/status/status_update.py index 28b8338f..56b744e4 100644 --- a/app/web/status/status_update.py +++ b/app/web/status/status_update.py @@ -275,6 +275,7 @@ def __init__( stage = stages[-1] super().__init__(url, run_id, status, stage, len(stages) - 1) + class RewritingCallback(StatusCallback): def __init__( self,