Skip to content

Commit

Permalink
[ENH] Add PAG class (#9)
Browse files Browse the repository at this point in the history
- Add PAG class
- add error checks to CPDAG and PAG when adding edges and constructing graph
- add algorithm to check validity of a CPDAG and / or PAG class
- Update poetry to networkx version on PR branch `mixededge` by Adam Li
  • Loading branch information
adam2392 authored Aug 30, 2022
1 parent a328c0f commit ffd9af2
Show file tree
Hide file tree
Showing 15 changed files with 692 additions and 19 deletions.
22 changes: 22 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,25 @@ build-docs:
make -C docs/ clean
make -C docs/ html-noplot
cd docs/ && make view

clean-pyc:
find . -name "*.pyc" | xargs rm -f

clean-so:
find . -name "*.so" | xargs rm -f
find . -name "*.pyd" | xargs rm -f

clean-build:
rm -rf _build build dist mne_icalabel.egg-info

clean-ctags:
rm -f tags

clean-cache:
find . -name "__pycache__" | xargs rm -rf

clean-test:
rm -rf .pytest_cache .mypy_cache .ipynb_checkpoints
rm junit-results.xml

clean: clean-build clean-pyc clean-so clean-ctags clean-cache clean-test
3 changes: 2 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ graphs encountered in the literature.
.. autosummary::
:toctree: generated/

CPDAG
ADMG
CPDAG
PAG
11 changes: 6 additions & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@
"arguments",
"no",
"attributes", "dictionary",
"DAG", "causal", "CPDAG", "PAG", "ADMG",
# pywhy-graphs
"causal",
'circular', 'endpoint',
# networkx
"node",
"nodes",
Expand Down Expand Up @@ -142,7 +144,6 @@
"ADMG": "pywhy_graphs.ADMG",
"PAG": "pywhy_graphs.PAG",
"CPDAG": "pywhy_graphs.CPDAG",
"DAG": "pywhy_graphs.DAG",
# joblib
"joblib.Parallel": "joblib.Parallel",
# pandas
Expand Down Expand Up @@ -172,13 +173,13 @@
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/devdocs", None),
"scipy": ("https://scipy.github.io/devdocs", None),
"networkx": ("https://networkx.org/documentation/latest/", None),
"nx-guides": ("https://networkx.org/nx-guides/", None),
"matplotlib": ("https://matplotlib.org/stable", None),
"pandas": ("https://pandas.pydata.org/pandas-docs/dev", None),
"pgmpy": ("https://pgmpy.org", None),
"sklearn": ("https://scikit-learn.org/stable", None),
"joblib": ("https://joblib.readthedocs.io/en/latest", None),
"networkx": ("https://networkx.org/documentation/latest/", None),
"nx-guides": ("https://networkx.org/nx-guides/", None),
"matplotlib": ("https://matplotlib.org/stable", None),
}
intersphinx_timeout = 5

Expand Down
1 change: 1 addition & 0 deletions docs/whats_new/v0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Changelog
---------

- |Feature| Implement and test the :class:`pywhy_graphs.CPDAG` for CPDAGs, by `Adam Li`_ (:pr:`6`)
- |Feature| Implement and test the :class:`pywhy_graphs.PAG` for PAGs, by `Adam Li`_ (:pr:`9`)

Code and Documentation Contributors
-----------------------------------
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ profile = 'black'
multi_line_output = 3
line_length = 100
py_version = 38
extend_skip_glob = ['setup.py', 'docs/*', 'examples/*']
extend_skip_glob = ['setup.py', 'docs/*', 'examples/*', 'pywhy_graphs/__init__.py']

[tool.pydocstyle]
convention = 'numpy'
Expand Down
3 changes: 2 additions & 1 deletion pywhy_graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from ._version import __version__ # noqa: F401
from .classes import ADMG, CPDAG
from .classes import ADMG, CPDAG, PAG
from .algorithms import is_valid_mec_graph
1 change: 1 addition & 0 deletions pywhy_graphs/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .graph import * # noqa: F403
49 changes: 49 additions & 0 deletions pywhy_graphs/algorithms/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from itertools import combinations
from typing import Union

import networkx as nx

import pywhy_graphs as pgraph


def is_valid_mec_graph(G: Union[pgraph.PAG, pgraph.CPDAG], on_error: str = "raise") -> bool:
"""Check G is a valid PAG.
A valid CPDAG/PAG is one where each pair of nodes have
at most one edge between them.
Parameters
----------
G : pgraph.PAG | pgraph.CPDAG
The PAG or CPDAG.
on_error : str
Whether to raise an error if the graph is non-compliant. Default is 'raise'.
Other options are 'ignore'.
Returns
-------
bool
Whether G is a valid PAG or CPDAG.
"""
for node1, node2 in combinations(G.nodes, 2):
n_edges = 0
names = []
for name, graph in G.get_graphs().items():
if (node1, node2) in graph.edges or (node2, node1) in graph.edges:
n_edges += 1
names.append(name)
if n_edges > 1:
if on_error == "raise":
raise RuntimeError(
f"There is more than one edge between ({node1}, {node2}) in the "
f"edge types: {names}. Please fix the construction of the PAG."
)
return False

# the directed edges should not form cycles
if not nx.is_directed_acyclic_graph(G.sub_directed_graph()):
if on_error == "raise":
raise RuntimeError(f"{G} is not a DAG, which it should be.")
return False

return True
1 change: 1 addition & 0 deletions pywhy_graphs/classes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .admg import ADMG
from .cpdag import CPDAG
from .pag import PAG
2 changes: 1 addition & 1 deletion pywhy_graphs/classes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import networkx as nx

from pywhy_graphs.typing import Node
from ..typing import Node


class GraphMixinProtocol(Protocol):
Expand Down
81 changes: 81 additions & 0 deletions pywhy_graphs/classes/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from enum import Enum, EnumMeta


class MetaEnum(EnumMeta):
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True


class EdgeType(Enum, metaclass=MetaEnum):
"""Enumeration of different causal edge endpoints.
Categories
----------
directed : str
Signifies arrowhead ("->") edges.
circle : str
Signifies a circle ("*-o") endpoint. That is an uncertain edge,
which is either circle with directed edge (``o->``),
circle with undirected edge (``o-``), or
circle with circle edge (``o-o``).
undirected : str
Signifies an undirected ("-") edge. That is an undirected edge (``-``),
or circle with circle edge (``-o``).
Notes
-----
The possible edges between two nodes thus are:
->, <-, <->, o->, <-o, o-o
In general, among all possible causal graphs, arrowheads depict
non-descendant relationships. In DAGs, arrowheads depict direct
causal relationships (i.e. parents/children). In ADMGs, arrowheads
can come from directed edges, or bidirected edges
"""

ALL = "all"
DIRECTED = "directed"
BIDIRECTED = "bidirected"
CIRCLE = "circle"
UNDIRECTED = "undirected"


class EndPoint(Enum, metaclass=MetaEnum):
"""Enumeration of different causal edge endpoints.
Categories
----------
arrow : str
Signifies arrowhead (">") endpoint. That is a normal
directed edge (``->``), bidirected arrow (``<->``),
or circle with directed edge (``o->``).
circle : str
Signifies a circle ("o") endpoint. That is an uncertain edge,
which is either circle with directed edge (``o->``),
circle with undirected edge (``o-``), or
circle with circle edge (``o-o``).
tail : str
Signifies a tail ("-") endpoint. That is either
a directed edge (``->``), or an undirected edge (``-``), or
circle with circle edge (``-o``).
Notes
-----
The possible edges between two nodes thus are:
->, <-, <->, o->, <-o, o-o
In general, among all possible causal graphs, arrowheads depict
non-descendant relationships. In DAGs, arrowheads depict direct
causal relationships (i.e. parents/children). In ADMGs, arrowheads
can come from directed edges, or bidirected edges
"""

arrow = "arrow"
circle = "circle"
tail = "tail"
58 changes: 55 additions & 3 deletions pywhy_graphs/classes/cpdag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import networkx as nx

from pywhy_graphs.typing import Node
import pywhy_graphs

from ..typing import Node
from .base import AncestralMixin, ConservativeMixin
from .config import EdgeType


class CPDAG(nx.MixedEdgeGraph, AncestralMixin, ConservativeMixin):
Expand Down Expand Up @@ -62,8 +64,8 @@ def __init__(
self._directed_name = directed_edge_name
self._undirected_name = undirected_edge_name

if not nx.is_directed_acyclic_graph(self.sub_directed_graph()):
raise RuntimeError(f"{self} is not a DAG, which it should be.")
# check that construction of PAG was valid
pywhy_graphs.is_valid_mec_graph(self)

@property
def undirected_edge_name(self) -> str:
Expand Down Expand Up @@ -154,3 +156,53 @@ def possible_parents(self, n: Node) -> Iterator[Node]:
An iterator of the parents of node 'n'.
"""
return self.sub_undirected_graph().neighbors(n)

def add_edge(self, u_of_edge, v_of_edge, edge_type="all", **attr):
_check_adding_cpdag_edge(
self, u_of_edge=u_of_edge, v_of_edge=v_of_edge, edge_type=edge_type
)
return super().add_edge(u_of_edge, v_of_edge, edge_type, **attr)

def add_edges_from(self, ebunch_to_add, edge_type, **attr):
for u_of_edge, v_of_edge in ebunch_to_add:
_check_adding_cpdag_edge(
self, u_of_edge=u_of_edge, v_of_edge=v_of_edge, edge_type=edge_type
)
return super().add_edges_from(ebunch_to_add, edge_type, **attr)


def _check_adding_cpdag_edge(graph: CPDAG, u_of_edge: Node, v_of_edge: Node, edge_type: EdgeType):
"""Check compatibility among internal graphs when adding an edge of a certain type.
Parameters
----------
u_of_edge : node
The start node.
v_of_edge : node
The end node.
edge_type : EdgeType
The edge type that is being added.
"""
raise_error = False
if edge_type == EdgeType.DIRECTED:
# there should not be a circle edge, or a bidirected edge
if graph.has_edge(u_of_edge, v_of_edge, graph.undirected_edge_name):
raise_error = True
if graph.has_edge(v_of_edge, u_of_edge, graph.directed_edge_name):
raise RuntimeError(
f"There is an existing {v_of_edge} -> {u_of_edge}. You are "
f"trying to add a directed edge from {u_of_edge} -> {v_of_edge}. "
f"If your intention is to create a bidirected edge, first remove the "
f"edge and then explicitly add the bidirected edge."
)
elif edge_type == EdgeType.UNDIRECTED:
# there should not be any type of edge between the two
if graph.has_edge(u_of_edge, v_of_edge):
raise_error = True

if raise_error:
raise RuntimeError(
f"There is already an existing edge between {u_of_edge} and {v_of_edge}. "
f"Adding a {edge_type} edge is not possible. Please remove the existing "
f"edge first."
)
Loading

0 comments on commit ffd9af2

Please sign in to comment.