Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH, VIZ] updating draw() for using a graph layout that allows us to fix node positions #26

Merged
merged 27 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
729d4fd
updating draw() function for taking into account a graph layout and f…
siebert-julien Dec 16, 2022
aa1a3a0
updating docstring and cleaning code #24
siebert-julien Dec 16, 2022
2c5437e
Revert poetry action bump
adam2392 Dec 17, 2022
6ff94ba
Update contributing doc to show how to make changes
adam2392 Dec 17, 2022
38ec251
Update examples/draw_and_compare_graphs_with_same_layout.py
siebert-julien Dec 19, 2022
d30ba7b
Update examples/draw_and_compare_graphs_with_same_layout.py
siebert-julien Dec 19, 2022
e81e829
Update examples/draw_and_compare_graphs_with_same_layout.py
siebert-julien Dec 19, 2022
b60ad8a
Update examples/draw_and_compare_graphs_with_same_layout.py
siebert-julien Dec 19, 2022
0ad2cf0
Update pywhy_graphs/viz/draw.py
siebert-julien Dec 19, 2022
a9de57b
Update tests/test_draw.py
siebert-julien Dec 19, 2022
a8567c1
Update tests/test_draw.py
siebert-julien Dec 19, 2022
6360de9
Update tests/test_draw.py
siebert-julien Dec 19, 2022
5f3e64d
Update tests/test_draw.py
siebert-julien Dec 19, 2022
dd1215c
Formatting files #24
siebert-julien Dec 19, 2022
d58cdfd
Bump abatilo/actions-poetry from 2.1.6 to 2.2.0 (#27)
dependabot[bot] Dec 21, 2022
0430d10
Update examples/draw_and_compare_graphs_with_same_layout.py
adam2392 Dec 21, 2022
e72c524
fixing style issues, updating doc #24
siebert-julien Dec 21, 2022
d0d27d3
fixing documentation #24
siebert-julien Dec 21, 2022
450ffa4
fixing documentation url syntax #24
siebert-julien Dec 21, 2022
8b0a933
fixing documentation url syntax #24
siebert-julien Dec 21, 2022
2ba45a1
fixing style issues #24
siebert-julien Dec 21, 2022
af67624
Clean up and add contributing instructions
adam2392 Dec 21, 2022
c31061f
Merge branch 'main' into draw_with_position_layout
adam2392 Dec 21, 2022
8749de9
Update docs/whats_new/v0.1.rst
adam2392 Dec 21, 2022
1828a90
Add contributor for Julien
adam2392 Dec 21, 2022
9cb4eef
Get some docs fixed
adam2392 Dec 21, 2022
cb89339
Upgrade poetry version to 1.3
adam2392 Dec 21, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,9 @@ We use [Sphinx](https://www.sphinx-doc.org/en/master/index.html) to build our AP
of public classes and methods. All docstrings should adhere to the [Numpy styling convention](https://www.sphinx-doc.org/en/master/usage/extensions/example_numpy.html).

### Testing Changes Locally With Poetry
With poetry installed, we have included a few convenience functions to check your code. These checks must pass and will be checked by the PR's continuous integration services. You can install the various different developer dependencies with poetry:

With poetry installed, we have included a few convenience functions to check your code. These checks must pass and will be checked by the PR's continuous integration services.
poetry install --with style, docs, test

Check code formatting with black:

Expand Down
12 changes: 12 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,15 @@ and corresponding causal graphs in pywhy-graphs.
graph_to_arr
clearn_arr_to_graph

Visualization of causal graphs
==============================
Visualization of causal graphs is different compared to networkx because causal graphs
can consist of mixed-edges. We implement an API that wraps ``graphviz`` and ``pygraphviz``
to perform modular visualization of nodes and edges.

.. currentmodule:: pywhy_graphs.viz

.. autosummary::
:toctree: generated/

draw
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@
"n_estimated_nodes",
"n_samples",
"n_variables",
# graphviz
"graphviz",
"Digraph",
}
numpydoc_xref_aliases = {
# Networkx
Expand Down
3 changes: 2 additions & 1 deletion docs/whats_new/v0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ Changelog
- |Feature| Implement and test various PAG algorithms :func:`pywhy_graphs.algorithms.possible_ancestors`, :func:`pywhy_graphs.algorithms.possible_descendants`, :func:`pywhy_graphs.algorithms.discriminating_path`, :func:`pywhy_graphs.algorithms.pds`, :func:`pywhy_graphs.algorithms.pds_path`, and :func:`pywhy_graphs.algorithms.uncovered_pd_path`, by `Adam Li`_ (:pr:`10`)
- |Feature| Implement an array API wrapper to convert between causal graphs in pywhy-graphs and causal graphs in ``causal-learn``, by `Adam Li`_ (:pr:`16`)
- |Feature| Implement an acyclification algorithm for converting cyclic graphs to acyclic with :func:`pywhy_graphs.algorithms.acyclification`, by `Adam Li`_ (:pr:`17`)

- |Feature| Adding a layout for the nodes positions in the :func:`pywhy_graphs.viz.draw` function, by Julien Siebert (:pr:`26`)
adam2392 marked this conversation as resolved.
Show resolved Hide resolved

Code and Documentation Contributors
-----------------------------------

Expand Down
60 changes: 60 additions & 0 deletions examples/draw_and_compare_graphs_with_same_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
.. _ex-draw-graphs:

=============================================================
Drawing graphs and setting their layout for visual comparison
=============================================================

One can draw a graph without setting the ``pos`` argument,
in that case graphviz will choose how to place the nodes.

In this example, we demonstrate how to visualize various different graphs with a fixed layout
for the nodes, so that they are easily comparable.


This examples shows how to create a position layout for all the nodes (using networkx)
and pass this to other graphs so that the nodes positions
are the same for the nodes with the same labels.

Alternatively, one can create their own positions manually, or using software,
such as `Dagitty <http://dagitty.net>`_.
"""

import networkx as nx

import pywhy_graphs
from pywhy_graphs import CPDAG, PAG
from pywhy_graphs.viz import draw

# create some dummy graphs: G, admg, cpdag, and pag
# this code is borrowed from the other example: intro_causal_graphs.py ;)
G = nx.DiGraph([("x", "y"), ("z", "y"), ("z", "w"), ("xy", "x"), ("xy", "y")])
admg = pywhy_graphs.set_nodes_as_latent_confounders(G, ["xy"])
cpdag = CPDAG()
cpdag.add_edges_from(G.edges, cpdag.undirected_edge_name)
cpdag.orient_uncertain_edge("x", "y")
cpdag.orient_uncertain_edge("xy", "y")
cpdag.orient_uncertain_edge("z", "y")
pag = PAG()
pag.add_edges_from(G.edges, cpdag.undirected_edge_name)

# get the layout position for the graph G using networkx
pos_G = nx.spring_layout(G, k=10)

siebert-julien marked this conversation as resolved.
Show resolved Hide resolved
# let us inspect the positions.
# Notice that networkx and graphviz related software store positions as
# a dictionary keyed by node with (x, y) coordinates as values.
print(pos_G)

# draw the graphs (i.e., generate a graphviz object that can be rendered)
# each time we call draw() we pass the layout position of G
dot_G = draw(G, pos=pos_G)
dot_admg = draw(admg, pos=pos_G)
dot_cpdag = draw(cpdag, pos=pos_G)
dot_pag = draw(pag, pos=pos_G)

# render the graphs using graphviz render() function
dot_G.render(outfile="G.png", view=True, engine="neato")
dot_admg.render(outfile="admg.png", view=True, engine="neato")
dot_cpdag.render(outfile="cpdag.png", view=True, engine="neato")
dot_pag.render(outfile="pag.png", view=True, engine="neato")
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ numpy = "^1.23.0"
scipy = "^1.9.0"
networkx = { git = "https://github.com/adam2392/networkx.git", branch = 'mixededge' }
importlib-resources = { version = "*", python = "<3.9" }
dynetx = { version = "*", optional = true }
pygraphviz = { version = "*", optional = true }

[tool.poetry.group.test]
Expand Down Expand Up @@ -90,7 +89,6 @@ tqdm = { version = "^4.64.0" } # needed in dowhy's package
typing-extensions = { version = "*" } # needed in dowhy's package

[tool.poetry.extras]
ts = ['dynetx']
viz = ['pygraphviz']

[build-system]
Expand Down
16 changes: 11 additions & 5 deletions pywhy_graphs/viz/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import networkx as nx


def draw(G: nx.MixedEdgeGraph, direction: Optional[str] = None):
def draw(G: nx.MixedEdgeGraph, direction: Optional[str] = None, pos: Optional[dict] = None):
"""Visualize the graph.

Parameters
Expand All @@ -12,11 +12,15 @@ def draw(G: nx.MixedEdgeGraph, direction: Optional[str] = None):
The mixed edge graph.
direction : str, optional
The direction, by default None.
pos : dict, optional
The positions of the nodes keyed by node with (x, y) coordinates as values.
By default None, which will
use the default layout from graphviz.

Returns
-------
dot : Digraph
dot language representation of the graph.
dot : graphviz Digraph
DOT language representation of the graph.
"""
from graphviz import Digraph

Expand Down Expand Up @@ -48,8 +52,10 @@ def draw(G: nx.MixedEdgeGraph, direction: Optional[str] = None):

for v in G.nodes:
child = str(v)

dot.node(child, shape=shape, height=".5", width=".5")
if pos and pos.get(v) is not None:
dot.node(child, shape=shape, height=".5", width=".5", pos=f"{pos[v][0]},{pos[v][1]}!")
else:
dot.node(child, shape=shape, height=".5", width=".5")

for parent in G.predecessors(v):
# memoize if we have seen the bidirected circular edge before
Expand Down
77 changes: 77 additions & 0 deletions tests/test_draw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import re

import networkx as nx

from pywhy_graphs.viz import draw


def test_draw_pos_is_fully_given():
"""
Ensure the Graphviz pos="x,y!" attribute is generated by the draw function
when pos is given for all nodes.
"""
# create a dummy graph x --> y <-- z and z --> x
graph = nx.DiGraph([("x", "y"), ("z", "y"), ("z", "x")])
# create a graph layout manually
pos = {"x": [0, 0], "y": [1, 0], "z": [0.5, 0.7]}
# draw the graphs
dot = draw(graph, pos=pos)
# get the graph description in textual form
dot_body_text = "".join(dot.body)
# assert that the produced graph contains the right pos argument for all nodes
assert re.search(r"\tx \[.* pos=\"0,0!\"", dot_body_text) is not None
assert re.search(r"\ty \[.* pos=\"1,0!\"", dot_body_text) is not None
assert re.search(r"\tz \[.* pos=\"0.5,0.7!\"", dot_body_text) is not None


def test_draw_pos_is_partially_given():
"""
Ensure the Graphviz pos="x,y!" attribute is generated by the draw function
when pos is given for some nodes but not all.
"""
# create a dummy graph x --> y <-- z and z --> x
graph = nx.DiGraph([("x", "y"), ("z", "y"), ("z", "x")])
# create a graph layout manually
pos = {"x": [0, 0], "y": [1, 0]}
# draw the graphs
dot = draw(graph, pos=pos)
# get the graph description in textual form
dot_body_text = "".join(dot.body)
# assert that the produced graph contains the right pos argument for nodes x and y but not for z
assert re.search(r"\tx \[.* pos=\"0,0!\"", dot_body_text) is not None
assert re.search(r"\ty \[.* pos=\"1,0!\"", dot_body_text) is not None
assert "pos=" not in re.search(r"\tz \[(.*)\]", dot_body_text).groups()[0]


def test_draw_pos_is_not_given():
"""
Ensure the Graphviz pos="x,y!" attribute is not generated by the draw function
when pos is not given.
"""
# create a dummy graph x --> y <-- z and z --> x
graph = nx.DiGraph([("x", "y"), ("z", "y"), ("z", "x")])
# draw the graphs
dot = draw(graph)
# get the graph description in textual form
dot_body_text = "".join(dot.body)
# assert that the produced graph does not contain any pos argument for the nodes
assert "pos=" not in dot_body_text


def test_draw_pos_contains_more_nodes():
"""
Ensure the Graphviz pos="x,y!" attribute is generated by the draw function
when pos is given for some nodes but not all.
"""
# create a dummy graph x --> y <-- z and z --> x
graph = nx.DiGraph([("x", "y"), ("z", "y"), ("z", "x")])
# create a graph layout manually
pos = {"x": [0, 0], "y": [1, 0], "t": [1, 2], "w": [3, 4]}
# draw the graphs
dot = draw(graph, pos=pos)
# get the graph description in textual form
dot_body_text = "".join(dot.body)
# assert that the produced graph contains the right pos argument for nodes x and y but not for z
assert re.search(r"\tx \[.* pos=\"0,0!\"", dot_body_text) is not None
assert re.search(r"\ty \[.* pos=\"1,0!\"", dot_body_text) is not None
assert "pos=" not in re.search(r"\tz \[(.*)\]", dot_body_text).groups()[0]