diff --git a/xgi/drawing/draw.py b/xgi/drawing/draw.py index 3b0d1d28..b29f5d06 100644 --- a/xgi/drawing/draw.py +++ b/xgi/drawing/draw.py @@ -1238,6 +1238,13 @@ def draw_hyperedge_labels( return text_items +def create_circle(x, y, z, radius, num_points=30): + theta = np.linspace(0, 2 * np.pi, num_points) + circle_x = x + radius * np.cos(theta) + circle_y = y + radius * np.sin(theta) + circle_z = np.full_like(circle_x, z) + return np.array([circle_x, circle_y, circle_z]).T + def draw_multilayer( H, @@ -1579,20 +1586,32 @@ def draw_multilayer( ax.add_collection3d(between_lines) # draw nodes (last) - #create a pathpatch collection - node_collection = [] # to be implemented still for d in orders: z = [sep * d] * H.num_nodes + patches = [] for id in range(H.num_nodes): - p = Circle( - xy=(xy[id, 0], xy[id, 1]), - radius=node_size if isinstance(node_size, (int, float)) else node_size[id], - facecolor=node_fc if isinstance(node_fc, str) else node_fc[id], - edgecolor=node_ec if isinstance(node_ec, str) else node_ec[id], - linewidth=node_lw if isinstance(node_lw, (int, float)) else node_lw[id], - ) - ax.add_patch(p) # Add the 2D patch to the axis - art3d.pathpatch_2d_to_3d(p, z=z[id], zdir="z") + # Set radius based on node_size + radius = node_size if isinstance(node_size, (int, float)) else node_size[id] + # Create a circle at specified coordinates + node = create_circle(xs[id], ys[id], z[id], radius=radius) + patches.append(node) + + # Create Poly3DCollection with customization options + node_collection = Poly3DCollection( + patches, + facecolors=node_fc, + edgecolor=node_ec, + cmap=node_fc_cmap, + linewidths=node_lw, + alpha=1, + zorder=max_order + 2 + ) + + if vmin is not None and vmax is not None: + node_collection.set_clim(vmin, vmax) + + ax.add_collection3d(node_collection) + ax.view_init(h_angle, v_angle) ax.set_ylim(np.min(ys) - ydiff * 0.1, np.max(ys) + ydiff * 0.1)