From a99d3ab2125fd4dc841e53de7b29cb499a9e8ca0 Mon Sep 17 00:00:00 2001 From: Subhiiiiii Date: Mon, 31 Mar 2025 14:06:07 +0530 Subject: [PATCH 1/2] Fix issue #45 --- jraph/_src/models.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/jraph/_src/models.py b/jraph/_src/models.py index e6a2562..dd31752 100644 --- a/jraph/_src/models.py +++ b/jraph/_src/models.py @@ -547,19 +547,26 @@ def _ApplyGCN(graph): nodes = update_node_fn(nodes) # Equivalent to jnp.sum(n_node), but jittable total_num_nodes = tree.tree_leaves(nodes)[0].shape[0] + + # Handle None senders and receivers by initializing empty arrays + if senders is None: + senders = jnp.array([], dtype=jnp.int32) + if receivers is None: + receivers = jnp.array([], dtype=jnp.int32) + if add_self_edges: - # We add self edges to the senders and receivers so that each node - # includes itself in aggregation. - # In principle, a `GraphsTuple` should partition by n_edge, but in - # this case it is not required since a GCN is agnostic to whether - # the `GraphsTuple` is a batch of graphs or a single large graph. - conv_receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), + # We add self edges to the senders and receivers so that each node + # includes itself in aggregation. + # In principle, a `GraphsTuple` should partition by n_edge, but in + # this case it is not required since a GCN is agnostic to whether + # the `GraphsTuple` is a batch of graphs or a single large graph. + conv_receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), axis=0) - conv_senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), + conv_senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), axis=0) else: - conv_senders = senders - conv_receivers = receivers + conv_senders = senders + conv_receivers = receivers # pylint: disable=g-long-lambda if symmetric_normalization: From f6202548e5b56916bdc6d3e5711eab32cfaa8b47 Mon Sep 17 00:00:00 2001 From: Subhiiiiii Date: Mon, 31 Mar 2025 14:38:03 +0530 Subject: [PATCH 2/2] Fix #46 --- jraph/_src/models.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/jraph/_src/models.py b/jraph/_src/models.py index dd31752..7135642 100644 --- a/jraph/_src/models.py +++ b/jraph/_src/models.py @@ -22,6 +22,7 @@ import jax.tree_util as tree from jraph._src import graph as gn_graph from jraph._src import utils +from contextlib import nullcontext # As of 04/2020 pytype doesn't support recursive types. # pytype: disable=not-supported-yet @@ -547,13 +548,6 @@ def _ApplyGCN(graph): nodes = update_node_fn(nodes) # Equivalent to jnp.sum(n_node), but jittable total_num_nodes = tree.tree_leaves(nodes)[0].shape[0] - - # Handle None senders and receivers by initializing empty arrays - if senders is None: - senders = jnp.array([], dtype=jnp.int32) - if receivers is None: - receivers = jnp.array([], dtype=jnp.int32) - if add_self_edges: # We add self edges to the senders and receivers so that each node # includes itself in aggregation. @@ -601,3 +595,17 @@ def _ApplyGCN(graph): return graph._replace(nodes=nodes) return _ApplyGCN + + +def random_graph(device=None): + """Returns a random graph with 10 nodes and 20 edges. + + Args: + device: Optional device to place the arrays on. If None, uses current device. + """ + n_node = 10 + n_edge = 20 + with jax.device(device) if device else nullcontext(): + senders = jnp.random.randint(0, n_node, size=n_edge) + receivers = jnp.random.randint(0, n_node, size=n_edge) + # ...