Skip to content

Commit

Permalink
initial push (draft)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasrobiglio committed Oct 28, 2024
1 parent 5dc6f4f commit 48425d1
Showing 1 changed file with 34 additions and 25 deletions.
59 changes: 34 additions & 25 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
PatchCollection,
Poly3DCollection,
)
from matplotlib.collections import CircleCollection
from scipy.spatial import ConvexHull
import copy

from matplotlib.patches import Circle
from matplotlib.transforms import Affine2D
import mpl_toolkits.mplot3d.art3d as art3d

chaini = chain.from_iterable

Expand Down Expand Up @@ -1240,8 +1246,7 @@ def draw_multilayer(
node_fc="white",
node_ec="black",
node_lw=1,
node_size=5,
node_shape="o",
node_size=0.1,
node_fc_cmap="Reds",
vmin=None,
vmax=None,
Expand Down Expand Up @@ -1295,7 +1300,7 @@ def draw_multilayer(
Radius of the nodes in pixels. If int or float, use the same radius for all
nodes. If iterable or NodeStat, assume the radiuses are specified in the same
order as the nodes are found in H.nodes. Values are clipped below
and above by min_node_size and max_node_size, respectively. By default, 5.
and above by min_node_size and max_node_size, respectively. By default, 0.1.
node_shape : string, optional
The shape of the node. Specification is as matplotlib.scatter
marker. Default is "o".
Expand Down Expand Up @@ -1362,8 +1367,8 @@ def draw_multilayer(
**kwargs : optional args
Alternate default values. Values that can be overwritten are the following:
* "min_node_size" (default: 10)
* "max_node_size" (default: 30)
* "min_node_size" (default: 0.01)
* "max_node_size" (default: 5)
* "min_node_lw" (default: 2)
* "max_node_lw" (default: 10)
* "min_dyad_lw" (default: 1)
Expand All @@ -1381,8 +1386,8 @@ def draw_multilayer(
Collection containing the edges of size > 2
"""
settings = {
"min_node_size": 10,
"max_node_size": 30,
"min_node_size": 0.01,
"max_node_size": 0.5,
"min_dyad_lw": 2,
"max_dyad_lw": 10,
"min_node_lw": 1,
Expand Down Expand Up @@ -1574,26 +1579,30 @@ def draw_multilayer(
ax.add_collection3d(between_lines)

# draw nodes (last)
#create a pathpatch collection
for d in orders:

z = [sep * d] * H.num_nodes

node_collection = ax.scatter(
xs=xy[:, 0],
ys=xy[:, 1],
zs=z,
s=node_size,
marker=node_shape,
c=node_fc,
cmap=node_fc_cmap,
vmin=vmin,
vmax=vmax,
edgecolors=node_ec,
linewidths=node_lw,
zorder=max_order + 1,
plotnonfinite=True, # plot points with nonfinite color
alpha=1,
)
"""
# this works fine except for the fact that we do not return a collection
# for the nodes in the correct way :(
node_collection = []
for id in range(H.num_nodes):
# Create the Circle patch for each node
p = Circle(
xy=(xy[id, 0], xy[id, 1]),
radius=node_size if isinstance(node_size, (int, float)) else node_size[id],
facecolor=node_fc if isinstance(node_fc, str) else node_fc[id],
edgecolor=node_ec if isinstance(node_ec, str) else node_ec[id],
linewidth=node_lw if isinstance(node_lw, (int, float)) else node_lw[id],
)
ax.add_patch(p) # Add the 2D patch to the axis
art3d.pathpatch_2d_to_3d(p, z=z[id], zdir="z")
"""
circles = [Circle(xy=(xy[id, 0], xy[id, 1]), radius=node_size) for id in range(H.num_nodes)]

node_collection = CircleCollection(circles, offsets=xy, transOffset=ax.transData)
ax.add_collection(node_collection)
art3d.pathpatch_2d_to_3d(node_collection, z=z, zdir="z")

ax.view_init(h_angle, v_angle)
ax.set_ylim(np.min(ys) - ydiff * 0.1, np.max(ys) + ydiff * 0.1)
Expand Down

0 comments on commit 48425d1

Please sign in to comment.