From ae6011ae351f591d41d35b5601af66df10515b1f Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Fri, 16 Aug 2024 15:29:51 -0400 Subject: [PATCH] wip Signed-off-by: Diwank Tomer --- .../agents_api/activities/embed_docs.py | 34 +++++----- .../activities/task_steps/transition_step.py | 2 +- .../activities/task_steps/yield_step.py | 12 +++- .../agents_api/autogen/openapi_model.py | 2 +- .../agents_api/common/protocol/sessions.py | 12 +++- .../agents_api/models/entry/delete_entries.py | 5 +- .../agents_api/routers/docs/create_doc.py | 13 +++- agents-api/agents_api/workflows/embed_docs.py | 8 +-- .../agents_api/workflows/task_execution.py | 8 +-- agents-api/pytype.toml | 68 +++++++------------ agents-api/tests/fixtures.py | 4 +- 11 files changed, 87 insertions(+), 81 deletions(-) diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py index 7198a7e54..12b389dc5 100644 --- a/agents-api/agents_api/activities/embed_docs.py +++ b/agents-api/agents_api/activities/embed_docs.py @@ -1,40 +1,42 @@ from uuid import UUID from beartype import beartype +from pydantic import BaseModel from temporalio import activity from ..clients import embed as embedder from ..clients.cozo import get_cozo_client from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query -snippet_embed_instruction = "Encode this passage for retrieval: " + +class EmbedDocsPayload(BaseModel): + developer_id: UUID + doc_id: UUID + content: list[str] + embed_instruction: str | None + title: str | None = None + include_title: bool = False # Need to be a separate parameter for the activity @activity.defn @beartype -async def embed_docs( - developer_id: UUID, - doc_id: UUID, - title: str, - content: list[str], - include_title: bool = True, - cozo_client=None, -) -> None: - indices, snippets = list(zip(*enumerate(content))) +async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None: + indices, snippets = list(zip(*enumerate(payload.content))) + embed_instruction: str = payload.embed_instruction or "" + title: str = payload.title or "" embeddings = await embedder.embed( [ - { - "instruction": snippet_embed_instruction, - "text": (title + "\n\n" + snippet) if include_title else snippet, - } + ( + embed_instruction + (title + "\n\n" + snippet) if title else snippet + ).strip() for snippet in snippets ] ) embed_snippets_query( - developer_id=developer_id, - doc_id=doc_id, + developer_id=payload.developer_id, + doc_id=payload.doc_id, snippet_indices=indices, embeddings=embeddings, client=cozo_client or get_cozo_client(), diff --git a/agents-api/agents_api/activities/task_steps/transition_step.py b/agents-api/agents_api/activities/task_steps/transition_step.py index e83acd59a..b06c19657 100644 --- a/agents-api/agents_api/activities/task_steps/transition_step.py +++ b/agents-api/agents_api/activities/task_steps/transition_step.py @@ -15,7 +15,7 @@ @activity.defn @beartype async def transition_step( - context: StepContext[None], + context: StepContext, transition_info: CreateTransitionRequest, ) -> None: need_to_wait = transition_info.type == "wait" diff --git a/agents-api/agents_api/activities/task_steps/yield_step.py b/agents-api/agents_api/activities/task_steps/yield_step.py index 1dca81168..f94c8715d 100644 --- a/agents-api/agents_api/activities/task_steps/yield_step.py +++ b/agents-api/agents_api/activities/task_steps/yield_step.py @@ -18,10 +18,18 @@ @activity.defn @beartype async def yield_step(context: StepContext[YieldStep]) -> StepOutcome[dict[str, Any]]: - workflow = context.definition.workflow - exprs = context.definition.arguments + all_workflows = context.execution_input.task.workflows + workflow = context.current_step.workflow + + assert workflow in [ + wf.name for wf in all_workflows + ], f"Workflow {workflow} not found in task" + + # Evaluate the expressions in the arguments + exprs = context.current_step.arguments arguments = simple_eval_dict(exprs, values=context.model_dump()) + # Transition to the first step of that workflow transition_target = TransitionTarget( workflow=workflow, step=0, diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index df0d44342..fdaadd30a 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -1,5 +1,5 @@ # ruff: noqa: F401, F403, F405 -from typing import Annotated, Generic, Self, Type, TypeVar +from typing import Annotated, Generic, Literal, Self, Type, TypeVar from uuid import UUID from litellm.utils import _select_tokenizer as select_tokenizer diff --git a/agents-api/agents_api/common/protocol/sessions.py b/agents-api/agents_api/common/protocol/sessions.py index d4db30a7d..1e98f7f12 100644 --- a/agents-api/agents_api/common/protocol/sessions.py +++ b/agents-api/agents_api/common/protocol/sessions.py @@ -73,12 +73,16 @@ def get_active_agent(self) -> Agent: def merge_settings(self, chat_input: ChatInput) -> ChatSettings: request_settings = chat_input.model_dump(exclude_unset=True) active_agent = self.get_active_agent() - default_settings = active_agent.default_settings + + default_settings: AgentDefaultSettings | None = active_agent.default_settings + default_settings: dict = ( + default_settings and default_settings.model_dump() or {} + ) self.settings = settings = ChatSettings( **{ "model": active_agent.model, - **default_settings.model_dump(), + **default_settings, **request_settings, } ) @@ -102,13 +106,15 @@ def get_chat_environment(self) -> dict[str, dict | list[dict]]: """ current_agent = self.get_active_agent() tools = self.get_active_tools() + settings: ChatSettings | None = self.settings + settings: dict = settings and settings.model_dump() or {} return { "session": self.session.model_dump(), "agents": [agent.model_dump() for agent in self.agents], "current_agent": current_agent.model_dump(), "users": [user.model_dump() for user in self.users], - "settings": self.settings.model_dump(), + "settings": settings, "tools": [tool.model_dump() for tool in tools], } diff --git a/agents-api/agents_api/models/entry/delete_entries.py b/agents-api/agents_api/models/entry/delete_entries.py index 5bf34c721..1d7cf0386 100644 --- a/agents-api/agents_api/models/entry/delete_entries.py +++ b/agents-api/agents_api/models/entry/delete_entries.py @@ -81,8 +81,9 @@ def delete_entries_for_session( verify_developer_owns_resource_query( developer_id, "sessions", session_id=session_id ), - mark_session_as_updated - and mark_session_updated_query(developer_id, session_id), + mark_session_updated_query(developer_id, session_id) + if mark_session_as_updated + else "", delete_query, ] diff --git a/agents-api/agents_api/routers/docs/create_doc.py b/agents-api/agents_api/routers/docs/create_doc.py index 548ee29d5..229163dd5 100644 --- a/agents-api/agents_api/routers/docs/create_doc.py +++ b/agents-api/agents_api/routers/docs/create_doc.py @@ -6,6 +6,7 @@ from starlette.status import HTTP_201_CREATED from temporalio.client import Client as TemporalClient +from ...activities.embed_docs import EmbedDocsPayload from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse from ...clients import temporal from ...dependencies.developer_id import get_developer_id @@ -15,6 +16,8 @@ async def run_embed_docs_task( + *, + developer_id: UUID, doc_id: UUID, title: str, content: list[str], @@ -31,9 +34,17 @@ async def run_embed_docs_task( if testing: return None + embed_payload = EmbedDocsPayload( + developer_id=developer_id, + doc_id=doc_id, + content=content, + title=title, + embed_instruction=None, + ) + handle = await client.start_workflow( EmbedDocsWorkflow.run, - args=[str(doc_id), title, content], + embed_payload, task_queue=temporal_task_queue, id=str(job_id), ) diff --git a/agents-api/agents_api/workflows/embed_docs.py b/agents-api/agents_api/workflows/embed_docs.py index e52921ed8..556297134 100644 --- a/agents-api/agents_api/workflows/embed_docs.py +++ b/agents-api/agents_api/workflows/embed_docs.py @@ -6,15 +6,15 @@ from temporalio import workflow with workflow.unsafe.imports_passed_through(): - from ..activities.embed_docs import embed_docs + from ..activities.embed_docs import EmbedDocsPayload, embed_docs @workflow.defn class EmbedDocsWorkflow: @workflow.run - async def run(self, doc_id: str, title: str, content: list[str]) -> None: - return await workflow.execute_activity( + async def run(self, embed_payload: EmbedDocsPayload) -> None: + await workflow.execute_activity( embed_docs, - args=[doc_id, title, content], + embed_payload, schedule_to_close_timeout=timedelta(seconds=600), ) diff --git a/agents-api/agents_api/workflows/task_execution.py b/agents-api/agents_api/workflows/task_execution.py index 6215a6416..9fb74c4c0 100644 --- a/agents-api/agents_api/workflows/task_execution.py +++ b/agents-api/agents_api/workflows/task_execution.py @@ -5,9 +5,6 @@ from temporalio import workflow -from agents_api.autogen.Executions import TransitionTarget -from agents_api.autogen.openapi_model import CreateTransitionRequest, TransitionType - with workflow.unsafe.imports_passed_through(): from ..activities.task_steps import ( evaluate_step, @@ -18,11 +15,14 @@ yield_step, ) from ..autogen.openapi_model import ( + CreateTransitionRequest, ErrorWorkflowStep, EvaluateStep, IfElseWorkflowStep, PromptStep, ToolCallStep, + TransitionTarget, + TransitionType, # WaitForInputStep, # WorkflowStep, YieldStep, @@ -156,7 +156,7 @@ async def run( # 4. Closing # End if the last step - if context.is_last_step: + if transition_type in ("finish", "cancelled"): return final_output # Otherwise, recurse to the next step diff --git a/agents-api/pytype.toml b/agents-api/pytype.toml index edd07e7d4..2371cea58 100644 --- a/agents-api/pytype.toml +++ b/agents-api/pytype.toml @@ -2,10 +2,13 @@ [tool.pytype] +# Space-separated list of files or directories to exclude. +exclude = [ +] + # Space-separated list of files or directories to process. inputs = [ - 'agents_api', - 'tests', + '.', ] # Keep going past errors to analyze as many files as possible. @@ -27,60 +30,35 @@ pythonpath = '.' # Python version (major.minor) of the target code. python_version = '3.11' -# Bind 'self' in methods with non-transparent decorators. This flag is temporary -# and will be removed once this behavior is enabled by default. -bind_decorated_methods = true - # Don't allow None to match bool. This flag is temporary and will be removed # once this behavior is enabled by default. none_is_not_bool = true -# Enable parameter count checks for overriding methods with renamed arguments. -# This flag is temporary and will be removed once this behavior is enabled by -# default. -overriding_renamed_parameter_count_checks = true - # Variables initialized as None retain their None binding. This flag is # temporary and will be removed once this behavior is enabled by default. strict_none_binding = true -# Support the third-party fiddle library. This flag is temporary and will be -# removed once this behavior is enabled by default. -use_fiddle_overlay = true +# Space-separated list of error names to ignore. +disable = [ + 'pyi-error', +] + +# -------------- +# Optional flags +# -------------- + +# Bind 'self' in methods with non-transparent decorators. This flag is temporary +# and will be removed once this behavior is enabled by default. +bind_decorated_methods = false + +# Enable parameter count checks for overriding methods with renamed arguments. +# This flag is temporary and will be removed once this behavior is enabled by +# default. +overriding_renamed_parameter_count_checks = false # Opt-in: Do not allow Any as a return type. no_return_any = false # Opt-in: Require decoration with @typing.override when overriding a method or # nested class attribute of a parent class. -require_override_decorator = false - -# Experimental: Infer precise return types even for invalid function calls. -precise_return = true - -# Experimental: Solve unknown types to label with structural types. -protocols = true - -# Experimental: Only load submodules that are explicitly imported. -strict_import = true - -# Experimental: Enable exhaustive checking of function parameter types. -strict_parameter_checks = true - -# Experimental: Emit errors for comparisons between incompatible primitive -# types. -strict_primitive_comparisons = true - -# Experimental: Check that variables are defined in all possible code paths. -strict_undefined_checks = true - -# Experimental: FOR TESTING ONLY. Use pytype/rewrite/. -use_rewrite = false - -# Space-separated list of error names to ignore. -disable = [ - 'pyi-error', -] - -# Don't report errors. -report_errors = true +require_override_decorator = false \ No newline at end of file diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index b1e461557..216b33077 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -73,8 +73,8 @@ async def temporal_worker(): kill_signal = worker.shutdown() worker_task.cancel() - await asyncio.wait( - [kill_signal, worker_task], + await asyncio.gather( + *[kill_signal, worker_task], return_when=asyncio.FIRST_COMPLETED, )