diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index ebfa3a193ad66..acec6c2672cd8 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -43,6 +43,8 @@ AwaitableBackgroundReader, AwaitableBackgroundWriter, RayDAGArgs, + CompositeChannel, + IntraProcessChannel, ) from ray.util.annotations import DeveloperAPI @@ -287,6 +289,8 @@ def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"): self.output_channels: List[ChannelInterface] = [] self.output_idxs: List[Optional[Union[int, str]]] = [] self.arg_type_hints: List["ChannelOutputType"] = [] + # idxs of possible ClassMethodOutputNodes if they exist, used for visualization + self.output_node_idxs: List[int] = [] @property def args(self) -> Tuple[Any]: @@ -843,6 +847,8 @@ def __init__( self._max_finished_execution_index: int = -1 # execution_index -> {channel_index -> result} self._result_buffer: Dict[int, Dict[int, Any]] = defaultdict(dict) + # channel to possible inner channel + self._channel_dict: Dict[ChannelInterface, ChannelInterface] = {} def _create_proxy_actor() -> "ray.actor.ActorHandle": # Creates the driver actor on the same node as the driver. @@ -1384,6 +1390,7 @@ def _get_or_compile( output_idx = downstream_node.output_idx task.output_channels.append(output_channel) task.output_idxs.append(output_idx) + task.output_node_idxs.append(self.dag_node_to_idx[downstream_node]) actor_handle = task.dag_node._get_actor_handle() assert actor_handle is not None self.actor_refs.add(actor_handle) @@ -1530,7 +1537,6 @@ def _get_or_compile( # Dict from original channel to the channel to be used in execution. # The value of this dict is either the original channel or a newly # created CachedChannel (if the original channel is read more than once). - channel_dict: Dict[ChannelInterface, ChannelInterface] = {} for arg, consumers in arg_to_consumers.items(): arg_idx = self.dag_node_to_idx[arg] upstream_task = self.idx_to_task[arg_idx] @@ -1538,12 +1544,12 @@ def _get_or_compile( arg_channel = upstream_task.output_channels[0] assert arg_channel is not None if len(consumers) > 1: - channel_dict[arg_channel] = CachedChannel( + self._channel_dict[arg_channel] = CachedChannel( len(consumers), arg_channel, ) else: - channel_dict[arg_channel] = arg_channel + self._channel_dict[arg_channel] = arg_channel # Step 3: create executable tasks for the actor executable_tasks = [] @@ -1556,7 +1562,7 @@ def _get_or_compile( assert len(upstream_task.output_channels) == 1 arg_channel = upstream_task.output_channels[0] assert arg_channel is not None - arg_channel = channel_dict[arg_channel] + arg_channel = self._channel_dict[arg_channel] resolved_args.append(arg_channel) else: # Constant arg @@ -2234,8 +2240,45 @@ async def execute_async( self._execution_index += 1 return fut + def get_channel_details( + self, channel: ChannelInterface, downstream_actor_id: str + ) -> str: + """ + Get details about outer and inner channel types and channel ids + based on the channel and the downstream actor ID. + Used for graph visualization. + Args: + channel: The channel to get details for. + downstream_actor_id: The downstream actor ID. + Returns: + A string with details about the channel based on its connection + to the actor provided. + """ + channel_details = type(channel).__name__ + # get outer channel + if channel in self._channel_dict and self._channel_dict[channel] != channel: + channel = self._channel_dict[channel] + channel_details += f"\n{type(channel).__name__}" + if type(channel) == CachedChannel: + channel_details += f", {channel._channel_id[:6]}..." + # get inner channel + if ( + type(channel) == CompositeChannel + and downstream_actor_id in channel._channel_dict + ): + inner_channel = channel._channel_dict[downstream_actor_id] + channel_details += f"\n{type(inner_channel).__name__}" + if type(inner_channel) == IntraProcessChannel: + channel_details += f", {inner_channel._channel_id[:6]}..." + return channel_details + def visualize( - self, filename="compiled_graph", format="png", view=False, return_dot=False + self, + filename="compiled_graph", + format="png", + view=False, + return_dot=False, + channel_details=False, ): """ Visualize the compiled graph using Graphviz. @@ -2249,13 +2292,20 @@ def visualize( format: The format of the output file (e.g., 'png', 'pdf'). view: Whether to open the file with the default viewer. return_dot: If True, returns the DOT source as a string instead of figure. + show_channel_details: If True, adds channel details to edges. Raises: ValueError: If the graph is empty or not properly compiled. ImportError: If the `graphviz` package is not installed. """ - import graphviz + try: + import graphviz + except ImportError: + raise ImportError( + "Please install graphviz to visualize the compiled graph. " + "You can install it by running `pip install graphviz`." + ) from ray.dag import ( InputAttributeNode, InputNode, @@ -2281,11 +2331,14 @@ def visualize( # Dot file for debuging dot = graphviz.Digraph(name="compiled_graph", format=format) - + # Give every actor a unique color, colors between 24k -> 40k tested as readable + # other colors may be too dark, especially when wrapping back around to 0 + actor_id_to_color = defaultdict( + lambda: f"#{((len(actor_id_to_color) * 2000 + 24000) % 0xFFFFFF):06X}" + ) # Add nodes with task information for idx, task in self.idx_to_task.items(): dag_node = task.dag_node - # Initialize the label and attributes label = f"Task {idx}\n" shape = "oval" # Default shape @@ -2313,10 +2366,11 @@ def visualize( if actor_handle: actor_id = actor_handle._actor_id.hex() label += f"Actor: {actor_id[:6]}...\nMethod: {method_name}" + fillcolor = actor_id_to_color[actor_id] else: label += f"Method: {method_name}" + fillcolor = "lightgreen" shape = "oval" - fillcolor = "lightgreen" elif dag_node.is_class_method_output: # Class Method Output Node label += f"ClassMethodOutputNode[{dag_node.output_idx}]" @@ -2335,28 +2389,45 @@ def visualize( # Add the node to the graph with attributes dot.node(str(idx), label, shape=shape, style=style, fillcolor=fillcolor) - - # Add edges with type hints based on argument mappings - for idx, task in self.idx_to_task.items(): - current_task_idx = idx - - for arg_index, arg in enumerate(task.dag_node.get_args()): - if isinstance(arg, DAGNode): - # Get the upstream task index - upstream_task_idx = self.dag_node_to_idx[arg] - - # Get the type hint for this argument - if arg_index < len(task.arg_type_hints): - type_hint = type(task.arg_type_hints[arg_index]).__name__ - else: - type_hint = "UnknownType" - - # Draw an edge from the upstream task to the - # current task with the type hint - dot.edge( - str(upstream_task_idx), str(current_task_idx), label=type_hint - ) - + channel_type_str = ( + type(dag_node.type_hint).__name__ + if dag_node.type_hint + else "UnknownType" + ) + "\n" + + # This logic is built on the assumption that there will only be multiple + # output channels if the task has multiple returns + # case: task with one output + if len(task.output_channels) == 1: + for downstream_node in task.dag_node._downstream_nodes: + downstream_idx = self.dag_node_to_idx[downstream_node] + edge_label = channel_type_str + if channel_details: + edge_label += self.get_channel_details( + task.output_channels[0], + ( + downstream_node._get_actor_handle()._actor_id.hex() + if type(downstream_node) == ClassMethodNode + else self._proxy_actor._actor_id.hex() + ), + ) + dot.edge(str(idx), str(downstream_idx), label=edge_label) + # case: multi return, output channels connect to class method output nodes + elif len(task.output_channels) > 1: + assert len(task.output_idxs) == len(task.output_channels) + for output_channel, downstream_idx in zip( + task.output_channels, task.output_node_idxs + ): + edge_label = channel_type_str + if channel_details: + edge_label += self.get_channel_details( + output_channel, + task.dag_node._get_actor_handle()._actor_id.hex(), + ) + dot.edge(str(idx), str(downstream_idx), label=edge_label) + if type(task.dag_node) == InputAttributeNode: + # Add an edge from the InputAttributeNode to the InputNode + dot.edge(str(self.input_task_idx), str(idx)) if return_dot: return dot.source else: diff --git a/python/ray/dag/tests/experimental/test_dag_visualization.py b/python/ray/dag/tests/experimental/test_dag_visualization.py index dee5096c92b3d..d5fdb4ce09c74 100644 --- a/python/ray/dag/tests/experimental/test_dag_visualization.py +++ b/python/ray/dag/tests/experimental/test_dag_visualization.py @@ -165,6 +165,70 @@ def echo(self, x): compiled_dag.teardown() +def test_visualize_multi_input_nodes(ray_start_regular): + """ + Expect output or dot_source: + MultiOutputNode" fillcolor=yellow shape=rectangle style=filled] + 0 -> 1 + 0 -> 2 + 0 -> 3 + 1 -> 4 + 2 -> 5 + 3 -> 6 + 4 -> 7 + 5 -> 7 + 6 -> 7 + """ + + @ray.remote + class Actor: + def echo(self, x): + return x + + actor = Actor.remote() + + with InputNode() as inp: + o1 = actor.echo.bind(inp.x) + o2 = actor.echo.bind(inp.y) + o3 = actor.echo.bind(inp.z) + dag = MultiOutputNode([o1, o2, o3]) + + compiled_dag = dag.experimental_compile() + + # Get the DOT source + dot_source = compiled_dag.visualize(return_dot=True) + + graphs = pydot.graph_from_dot_data(dot_source) + graph = graphs[0] + + node_names = {node.get_name() for node in graph.get_nodes()} + edge_pairs = { + (edge.get_source(), edge.get_destination()) for edge in graph.get_edges() + } + + expected_nodes = {"0", "1", "2", "3", "4", "5", "6", "7"} + assert expected_nodes.issubset( + node_names + ), f"Expected nodes {expected_nodes} not found." + + expected_edges = { + ("0", "1"), + ("0", "2"), + ("0", "3"), + ("1", "4"), + ("2", "5"), + ("3", "6"), + ("4", "7"), + ("5", "7"), + ("6", "7"), + } + assert expected_edges.issubset( + edge_pairs + ), f"Expected edges {expected_edges} not found." + + compiled_dag.teardown() + + if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))