Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 16, 2024
1 parent 86e51a7 commit ae6011a
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 81 deletions.
34 changes: 18 additions & 16 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,42 @@
from uuid import UUID

from beartype import beartype
from pydantic import BaseModel
from temporalio import activity

from ..clients import embed as embedder
from ..clients.cozo import get_cozo_client
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query

snippet_embed_instruction = "Encode this passage for retrieval: "

class EmbedDocsPayload(BaseModel):
developer_id: UUID
doc_id: UUID
content: list[str]
embed_instruction: str | None
title: str | None = None
include_title: bool = False # Need to be a separate parameter for the activity


@activity.defn
@beartype
async def embed_docs(
developer_id: UUID,
doc_id: UUID,
title: str,
content: list[str],
include_title: bool = True,
cozo_client=None,
) -> None:
indices, snippets = list(zip(*enumerate(content)))
async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
indices, snippets = list(zip(*enumerate(payload.content)))
embed_instruction: str = payload.embed_instruction or ""
title: str = payload.title or ""

embeddings = await embedder.embed(
[
{
"instruction": snippet_embed_instruction,
"text": (title + "\n\n" + snippet) if include_title else snippet,
}
(
embed_instruction + (title + "\n\n" + snippet) if title else snippet
).strip()
for snippet in snippets
]
)

embed_snippets_query(
developer_id=developer_id,
doc_id=doc_id,
developer_id=payload.developer_id,
doc_id=payload.doc_id,
snippet_indices=indices,
embeddings=embeddings,
client=cozo_client or get_cozo_client(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@activity.defn
@beartype
async def transition_step(
context: StepContext[None],
context: StepContext,
transition_info: CreateTransitionRequest,
) -> None:
need_to_wait = transition_info.type == "wait"
Expand Down
12 changes: 10 additions & 2 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,18 @@
@activity.defn
@beartype
async def yield_step(context: StepContext[YieldStep]) -> StepOutcome[dict[str, Any]]:
workflow = context.definition.workflow
exprs = context.definition.arguments
all_workflows = context.execution_input.task.workflows
workflow = context.current_step.workflow

assert workflow in [
wf.name for wf in all_workflows
], f"Workflow {workflow} not found in task"

# Evaluate the expressions in the arguments
exprs = context.current_step.arguments
arguments = simple_eval_dict(exprs, values=context.model_dump())

# Transition to the first step of that workflow
transition_target = TransitionTarget(
workflow=workflow,
step=0,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ruff: noqa: F401, F403, F405
from typing import Annotated, Generic, Self, Type, TypeVar
from typing import Annotated, Generic, Literal, Self, Type, TypeVar
from uuid import UUID

from litellm.utils import _select_tokenizer as select_tokenizer
Expand Down
12 changes: 9 additions & 3 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,16 @@ def get_active_agent(self) -> Agent:
def merge_settings(self, chat_input: ChatInput) -> ChatSettings:
request_settings = chat_input.model_dump(exclude_unset=True)
active_agent = self.get_active_agent()
default_settings = active_agent.default_settings

default_settings: AgentDefaultSettings | None = active_agent.default_settings
default_settings: dict = (
default_settings and default_settings.model_dump() or {}
)

self.settings = settings = ChatSettings(
**{
"model": active_agent.model,
**default_settings.model_dump(),
**default_settings,
**request_settings,
}
)
Expand All @@ -102,13 +106,15 @@ def get_chat_environment(self) -> dict[str, dict | list[dict]]:
"""
current_agent = self.get_active_agent()
tools = self.get_active_tools()
settings: ChatSettings | None = self.settings
settings: dict = settings and settings.model_dump() or {}

return {
"session": self.session.model_dump(),
"agents": [agent.model_dump() for agent in self.agents],
"current_agent": current_agent.model_dump(),
"users": [user.model_dump() for user in self.users],
"settings": self.settings.model_dump(),
"settings": settings,
"tools": [tool.model_dump() for tool in tools],
}

Expand Down
5 changes: 3 additions & 2 deletions agents-api/agents_api/models/entry/delete_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ def delete_entries_for_session(
verify_developer_owns_resource_query(
developer_id, "sessions", session_id=session_id
),
mark_session_as_updated
and mark_session_updated_query(developer_id, session_id),
mark_session_updated_query(developer_id, session_id)
if mark_session_as_updated
else "",
delete_query,
]

Expand Down
13 changes: 12 additions & 1 deletion agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from starlette.status import HTTP_201_CREATED
from temporalio.client import Client as TemporalClient

from ...activities.embed_docs import EmbedDocsPayload
from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse
from ...clients import temporal
from ...dependencies.developer_id import get_developer_id
Expand All @@ -15,6 +16,8 @@


async def run_embed_docs_task(
*,
developer_id: UUID,
doc_id: UUID,
title: str,
content: list[str],
Expand All @@ -31,9 +34,17 @@ async def run_embed_docs_task(
if testing:
return None

embed_payload = EmbedDocsPayload(
developer_id=developer_id,
doc_id=doc_id,
content=content,
title=title,
embed_instruction=None,
)

handle = await client.start_workflow(
EmbedDocsWorkflow.run,
args=[str(doc_id), title, content],
embed_payload,
task_queue=temporal_task_queue,
id=str(job_id),
)
Expand Down
8 changes: 4 additions & 4 deletions agents-api/agents_api/workflows/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from temporalio import workflow

with workflow.unsafe.imports_passed_through():
from ..activities.embed_docs import embed_docs
from ..activities.embed_docs import EmbedDocsPayload, embed_docs


@workflow.defn
class EmbedDocsWorkflow:
@workflow.run
async def run(self, doc_id: str, title: str, content: list[str]) -> None:
return await workflow.execute_activity(
async def run(self, embed_payload: EmbedDocsPayload) -> None:
await workflow.execute_activity(
embed_docs,
args=[doc_id, title, content],
embed_payload,
schedule_to_close_timeout=timedelta(seconds=600),
)
8 changes: 4 additions & 4 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@

from temporalio import workflow

from agents_api.autogen.Executions import TransitionTarget
from agents_api.autogen.openapi_model import CreateTransitionRequest, TransitionType

with workflow.unsafe.imports_passed_through():
from ..activities.task_steps import (
evaluate_step,
Expand All @@ -18,11 +15,14 @@
yield_step,
)
from ..autogen.openapi_model import (
CreateTransitionRequest,
ErrorWorkflowStep,
EvaluateStep,
IfElseWorkflowStep,
PromptStep,
ToolCallStep,
TransitionTarget,
TransitionType,
# WaitForInputStep,
# WorkflowStep,
YieldStep,
Expand Down Expand Up @@ -156,7 +156,7 @@ async def run(

# 4. Closing
# End if the last step
if context.is_last_step:
if transition_type in ("finish", "cancelled"):
return final_output

# Otherwise, recurse to the next step
Expand Down
68 changes: 23 additions & 45 deletions agents-api/pytype.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

[tool.pytype]

# Space-separated list of files or directories to exclude.
exclude = [
]

# Space-separated list of files or directories to process.
inputs = [
'agents_api',
'tests',
'.',
]

# Keep going past errors to analyze as many files as possible.
Expand All @@ -27,60 +30,35 @@ pythonpath = '.'
# Python version (major.minor) of the target code.
python_version = '3.11'

# Bind 'self' in methods with non-transparent decorators. This flag is temporary
# and will be removed once this behavior is enabled by default.
bind_decorated_methods = true

# Don't allow None to match bool. This flag is temporary and will be removed
# once this behavior is enabled by default.
none_is_not_bool = true

# Enable parameter count checks for overriding methods with renamed arguments.
# This flag is temporary and will be removed once this behavior is enabled by
# default.
overriding_renamed_parameter_count_checks = true

# Variables initialized as None retain their None binding. This flag is
# temporary and will be removed once this behavior is enabled by default.
strict_none_binding = true

# Support the third-party fiddle library. This flag is temporary and will be
# removed once this behavior is enabled by default.
use_fiddle_overlay = true
# Space-separated list of error names to ignore.
disable = [
'pyi-error',
]

# --------------
# Optional flags
# --------------

# Bind 'self' in methods with non-transparent decorators. This flag is temporary
# and will be removed once this behavior is enabled by default.
bind_decorated_methods = false

# Enable parameter count checks for overriding methods with renamed arguments.
# This flag is temporary and will be removed once this behavior is enabled by
# default.
overriding_renamed_parameter_count_checks = false

# Opt-in: Do not allow Any as a return type.
no_return_any = false

# Opt-in: Require decoration with @typing.override when overriding a method or
# nested class attribute of a parent class.
require_override_decorator = false

# Experimental: Infer precise return types even for invalid function calls.
precise_return = true

# Experimental: Solve unknown types to label with structural types.
protocols = true

# Experimental: Only load submodules that are explicitly imported.
strict_import = true

# Experimental: Enable exhaustive checking of function parameter types.
strict_parameter_checks = true

# Experimental: Emit errors for comparisons between incompatible primitive
# types.
strict_primitive_comparisons = true

# Experimental: Check that variables are defined in all possible code paths.
strict_undefined_checks = true

# Experimental: FOR TESTING ONLY. Use pytype/rewrite/.
use_rewrite = false

# Space-separated list of error names to ignore.
disable = [
'pyi-error',
]

# Don't report errors.
report_errors = true
require_override_decorator = false
4 changes: 2 additions & 2 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ async def temporal_worker():

kill_signal = worker.shutdown()
worker_task.cancel()
await asyncio.wait(
[kill_signal, worker_task],
await asyncio.gather(
*[kill_signal, worker_task],
return_when=asyncio.FIRST_COMPLETED,
)

Expand Down

0 comments on commit ae6011a

Please sign in to comment.