Skip to content

Commit

Permalink
feat(rnd): Introduce Sub-Graph on Agent Server (#7693)
Browse files Browse the repository at this point in the history
### Background

This change brings the capability to decompose a graph into sub-graphs. The objective of this feature is to allow a user to build a visually modular, and easier-to-understand graph. Also, allowing you to import a graph into your existing graph, without decluttering your existing graph.

This feature will require more implementation on the UI side, to allow the grouping of subgraph to be represented as a node in the builder.

### Changes 🏗️

Introduced a subgraph functionality with the following property:

* Sub-graph is simply a set of nodes that are grouped together, making it representable as a node.
* Sub-graph input & output pins/schema are the `InputBlock` / `OutputBlock` nodes present in the subgraph.
* The previous point implies that connecting two nodes from different sub-graphs, other than input/output nodes, is not allowed.
* Graph can be nested, but defined flatly, e.g.: graph is now only represented by three components: nodes, links, and subgraphs (a set of list of nodes). A nested subgraph is simply connecting a node inside a subgraph into another `InputBlock` node of another subgraph.
  • Loading branch information
majdyz authored Aug 5, 2024
1 parent c7fdfa0 commit 4cf1dd3
Show file tree
Hide file tree
Showing 9 changed files with 253 additions and 64 deletions.
7 changes: 7 additions & 0 deletions rnd/autogpt_server/autogpt_server/data/db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from contextlib import asynccontextmanager
from uuid import uuid4

from dotenv import load_dotenv
Expand All @@ -23,6 +24,12 @@ async def disconnect():
await prisma.disconnect()


@asynccontextmanager
async def transaction():
async with prisma.tx() as tx:
yield tx


class BaseDbModel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))

Expand Down
194 changes: 149 additions & 45 deletions rnd/autogpt_server/autogpt_server/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from autogpt_server.blocks.basic import InputBlock, OutputBlock
from autogpt_server.data.block import BlockInput, get_block
from autogpt_server.data.db import BaseDbModel
from autogpt_server.data.db import BaseDbModel, transaction
from autogpt_server.util import json


Expand Down Expand Up @@ -89,6 +89,7 @@ def from_db(graph: AgentGraph):
class Graph(GraphMeta):
nodes: list[Node]
links: list[Link]
subgraphs: dict[str, list[str]] = {} # subgraph_id -> [node_id]

@property
def starting_nodes(self) -> list[Node]:
Expand All @@ -106,17 +107,63 @@ def starting_nodes(self) -> list[Node]:
def ending_nodes(self) -> list[Node]:
return [v for v in self.nodes if isinstance(get_block(v.block_id), OutputBlock)]

def validate_graph(self):
@property
def subgraph_map(self) -> dict[str, str]:
"""
Returns a mapping of node_id to subgraph_id.
A node in the main graph will be mapped to the graph's id.
"""
subgraph_map = {
node_id: subgraph_id
for subgraph_id, node_ids in self.subgraphs.items()
for node_id in node_ids
}
subgraph_map.update(
{node.id: self.id for node in self.nodes if node.id not in subgraph_map}
)
return subgraph_map

def reassign_ids(self):
"""
Reassigns all IDs in the graph to new UUIDs.
This method can be used before storing a new graph to the database.
"""
self.validate_graph()

id_map = {
self.id: str(uuid.uuid4()),
**{node.id: str(uuid.uuid4()) for node in self.nodes},
**{subgraph_id: str(uuid.uuid4()) for subgraph_id in self.subgraphs},
}

self.id = id_map[self.id]

for node in self.nodes:
node.id = id_map[node.id]

for link in self.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]

self.subgraphs = {
id_map[subgraph_id]: [id_map[node_id] for node_id in node_ids]
for subgraph_id, node_ids in self.subgraphs.items()
}

def validate_graph(self, for_run: bool = False):

def sanitize(name):
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]

# Check if all required fields are filled or connected, except for InputBlock.
# Nodes: required fields are filled or connected, except for InputBlock.
for node in self.nodes:
block = get_block(node.block_id)
if block is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")

if not for_run:
continue # Skip input completion validation, unless when executing.

provided_inputs = set(
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in node.input_links]
Expand All @@ -126,65 +173,100 @@ def sanitize(name):
raise ValueError(
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
node_map = {v.id: v for v in self.nodes}

def is_input_output_block(nid: str) -> bool:
bid = node_map[nid].block_id
b = get_block(bid)
return isinstance(b, InputBlock) or isinstance(b, OutputBlock)

# subgraphs: all nodes in subgraph must be present in the graph.
for subgraph_id, node_ids in self.subgraphs.items():
for node_id in node_ids:
if node_id not in node_map:
raise ValueError(f"Subgraph {subgraph_id}'s node {node_id} invalid")
subgraph_map = self.subgraph_map

# Check if all links are connected compatible pin data type.
# Links: links are connected and the connected pin data type are compatible.
for link in self.links:
source_id = link.source_id
sink_id = link.sink_id
suffix = f"Link {source_id}<->{sink_id}"

source_node = next((v for v in self.nodes if v.id == source_id), None)
if not source_node:
raise ValueError(f"{suffix}, {source_id} is invalid node.")
sink_node = next((v for v in self.nodes if v.id == sink_id), None)
if not sink_node:
raise ValueError(f"{suffix}, {sink_id} is invalid node.")

source_block = get_block(source_node.block_id)
if not source_block:
raise ValueError(f"{suffix}, {source_node.block_id} is invalid block.")
sink_block = get_block(sink_node.block_id)
if not sink_block:
raise ValueError(f"{suffix}, {sink_node.block_id} is invalid block.")

source_name = sanitize(link.source_name)
if source_name not in source_block.output_schema.get_fields():
raise ValueError(f"{suffix}, `{source_name}` is invalid output pin.")
sink_name = sanitize(link.sink_name)
if sink_name not in sink_block.input_schema.get_fields():
raise ValueError(f"{suffix}, `{sink_name}` is invalid input pin.")
source = (link.source_id, link.source_name)
sink = (link.sink_id, link.sink_name)
suffix = f"Link {source} <-> {sink}"

for i, (node_id, name) in enumerate([source, sink]):
node = node_map.get(node_id)
if not node:
raise ValueError(f"{suffix}, {node_id} is invalid node.")

block = get_block(node.block_id)
if not block:
raise ValueError(f"{suffix}, {node.block_id} is invalid block.")

sanitized_name = sanitize(name)
if i == 0:
fields = block.output_schema.get_fields()
else:
fields = block.input_schema.get_fields()
if sanitized_name not in fields:
raise ValueError(f"{suffix}, `{name}` invalid, fields: {fields}")

if (
subgraph_map.get(link.source_id) != subgraph_map.get(link.sink_id)
and not is_input_output_block(link.source_id)
and not is_input_output_block(link.sink_id)
):
raise ValueError(f"{suffix}, Connecting nodes from different subgraph.")

# TODO: Add type compatibility check here.

@staticmethod
def from_db(graph: AgentGraph):
nodes = [
*(graph.AgentNodes or []),
*(
node
for subgraph in graph.AgentSubGraphs or []
for node in subgraph.AgentNodes or []
),
]
return Graph(
**GraphMeta.from_db(graph).model_dump(),
nodes=[Node.from_db(node) for node in graph.AgentNodes or []],
nodes=[Node.from_db(node) for node in nodes],
links=list(
{
Link.from_db(link)
for node in graph.AgentNodes or []
for node in nodes
for link in (node.Input or []) + (node.Output or [])
}
),
subgraphs={
subgraph.id: [node.id for node in subgraph.AgentNodes or []]
for subgraph in graph.AgentSubGraphs or []
},
)


EXECUTION_NODE_INCLUDE = {
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
"Input": True,
"Output": True,
"AgentBlock": True,
}

__SUBGRAPH_INCLUDE = {"AgentNodes": {"include": AGENT_NODE_INCLUDE}}

AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
**__SUBGRAPH_INCLUDE,
"AgentSubGraphs": {"include": __SUBGRAPH_INCLUDE}, # type: ignore
}


# --------------------- Model functions --------------------- #


async def get_node(node_id: str) -> Node | None:
node = await AgentNode.prisma().find_unique_or_raise(
where={"id": node_id},
include=EXECUTION_NODE_INCLUDE, # type: ignore
include=AGENT_NODE_INCLUDE,
)
return Node.from_db(node) if node else None

Expand Down Expand Up @@ -242,7 +324,7 @@ async def get_graph(

graph = await AgentGraph.prisma().find_first(
where=where_clause,
include={"AgentNodes": {"include": EXECUTION_NODE_INCLUDE}}, # type: ignore
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
return Graph.from_db(graph) if graph else None
Expand All @@ -267,7 +349,7 @@ async def get_graph_all_versions(graph_id: str) -> list[Graph]:
graph_versions = await AgentGraph.prisma().find_many(
where={"id": graph_id},
order={"version": "desc"},
include={"AgentNodes": {"include": EXECUTION_NODE_INCLUDE}}, # type: ignore
include=AGENT_GRAPH_INCLUDE,
)

if not graph_versions:
Expand All @@ -277,7 +359,17 @@ async def get_graph_all_versions(graph_id: str) -> list[Graph]:


async def create_graph(graph: Graph) -> Graph:
await AgentGraph.prisma().create(
async with transaction() as tx:
await __create_graph(tx, graph)

if created_graph := await get_graph(graph.id, graph.version, graph.is_template):
return created_graph

raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")


async def __create_graph(tx, graph: Graph):
await AgentGraph.prisma(tx).create(
data={
"id": graph.id,
"version": graph.version,
Expand All @@ -290,11 +382,30 @@ async def create_graph(graph: Graph) -> Graph:

await asyncio.gather(
*[
AgentNode.prisma().create(
AgentGraph.prisma(tx).create(
data={
"id": subgraph_id,
"agentGraphParentId": graph.id,
"version": graph.version,
"name": f"SubGraph of {graph.name}",
"description": f"Sub-Graph of {graph.id}",
"isTemplate": graph.is_template,
"isActive": graph.is_active,
}
)
for subgraph_id in graph.subgraphs
]
)

subgraph_map = graph.subgraph_map

await asyncio.gather(
*[
AgentNode.prisma(tx).create(
{
"id": node.id,
"agentBlockId": node.block_id,
"agentGraphId": graph.id,
"agentGraphId": subgraph_map.get(node.id, graph.id),
"agentGraphVersion": graph.version,
"constantInput": json.dumps(node.input_default),
"metadata": json.dumps(node.metadata),
Expand All @@ -306,7 +417,7 @@ async def create_graph(graph: Graph) -> Graph:

await asyncio.gather(
*[
AgentNodeLink.prisma().create(
AgentNodeLink.prisma(tx).create(
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
Expand All @@ -320,13 +431,6 @@ async def create_graph(graph: Graph) -> Graph:
]
)

if created_graph := await get_graph(
graph.id, graph.version, template=graph.is_template
):
return created_graph

raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")


# --------------------- Helper functions --------------------- #

Expand Down
2 changes: 1 addition & 1 deletion rnd/autogpt_server/autogpt_server/executor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def add_execution(self, graph_id: str, data: BlockInput) -> dict[Any, Any]:
graph: Graph | None = self.run_and_wait(get_graph(graph_id))
if not graph:
raise Exception(f"Graph #{graph_id} not found.")
graph.validate_graph()
graph.validate_graph(for_run=True)

nodes_input = []
for node in graph.starting_nodes:
Expand Down
20 changes: 2 additions & 18 deletions rnd/autogpt_server/autogpt_server/server/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import uuid
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import Annotated, Any, Dict
Expand Down Expand Up @@ -468,15 +467,7 @@ async def create_graph(

graph.is_template = is_template
graph.is_active = not is_template

id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}

for node in graph.nodes:
node.id = id_map[node.id]

for link in graph.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
graph.reassign_ids()

return await graph_db.create_graph(graph)

Expand All @@ -501,14 +492,7 @@ async def update_graph(cls, graph_id: str, graph: graph_db.Graph) -> graph_db.Gr
400, detail="Changing is_template on an existing graph is forbidden"
)
graph.is_active = not graph.is_template

# Assign new UUIDs to all nodes and links
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]
for link in graph.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
graph.reassign_ids()

new_graph_version = await graph_db.create_graph(graph)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
-- RedefineTables
PRAGMA foreign_keys=OFF;
CREATE TABLE "new_AgentGraph" (
"id" TEXT NOT NULL,
"version" INTEGER NOT NULL DEFAULT 1,
"name" TEXT,
"description" TEXT,
"isActive" BOOLEAN NOT NULL DEFAULT true,
"isTemplate" BOOLEAN NOT NULL DEFAULT false,
"agentGraphParentId" TEXT,

PRIMARY KEY ("id", "version"),
CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraph" ("description", "id", "isActive", "isTemplate", "name", "version") SELECT "description", "id", "isActive", "isTemplate", "name", "version" FROM "AgentGraph";
DROP TABLE "AgentGraph";
ALTER TABLE "new_AgentGraph" RENAME TO "AgentGraph";
PRAGMA foreign_key_check;
PRAGMA foreign_keys=ON;
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- AlterTable
ALTER TABLE "AgentGraph" ADD COLUMN "agentGraphParentId" TEXT;

-- AddForeignKey
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;
Loading

0 comments on commit 4cf1dd3

Please sign in to comment.