diff --git a/.gitignore b/.gitignore index fc56ef9..a2b0785 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ generated .vscode rendering_times.csv media/ -.coverage +venv +.coverage .python-version diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a7b28a6..604e6eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,14 +32,14 @@ repos: flake8-rst-docstrings==0.3.0, flake8-simplify==0.19.3, ] - - repo: local - hooks: - - id: pytest - name: pytest - entry: poetry run pytest -cov=src tests/ - language: system - pass_filenames: false - # alternatively you could `types: [python]` so it only runs when python files change - # though tests might be invalidated if you were to say change a data file - always_run: true - stages: [push] +# - repo: local +# hooks: +# - id: pytest +# name: pytest +# entry: poetry run pytest -cov=src tests/ +# language: system +# pass_filenames: false +# # alternatively you could `types: [python]` so it only runs when python files change +# # though tests might be invalidated if you were to say change a data file +# always_run: true +# stages: [push] diff --git a/src/manim_data_structures/m_tree.py b/src/manim_data_structures/m_tree.py new file mode 100644 index 0000000..625ac5a --- /dev/null +++ b/src/manim_data_structures/m_tree.py @@ -0,0 +1,144 @@ +import operator as op +import random +from collections import defaultdict +from copy import copy +from functools import partialmethod, reduce +from typing import Any, Callable, Dict, Hashable, List, Tuple + +import numpy as np +from manim import * +from manim import WHITE, Graph, Mobject, VMobject + + +class Tree(VMobject): + """Computer Science Tree Data Structure""" + + _graph: Graph + __layout_config: dict + __layout_scale: float + __layout: str | dict + __vertex_type: Callable[..., Mobject] + + # __parents: list + # __children: dict[Hashable, list] = defaultdict(list) + + def __init__( + self, + nodes: dict[int, Any], + edges: list[tuple[int, int]], + vertex_type: Callable[..., Mobject], + edge_buff=0.4, + layout="tree", + layout_config={"vertex_spacing": (-1, 1)}, + root_vertex=0, + **kwargs + ): + super().__init__(**kwargs) + vertex_mobjects = {k: vertex_type(v) for k, v in nodes.items()} + self.__layout_config = layout_config + self.__layout_scale = len(nodes) * 0.5 + self.__layout = layout + self.__vertex_type = vertex_type + self._graph = Graph( + list(nodes), + edges, + vertex_mobjects=vertex_mobjects, + layout=layout, + root_vertex=0, + layout_config=self.__layout_config, + layout_scale=len(nodes) * 0.5, + edge_config={"stroke_width": 1, "stroke_color": WHITE}, + ) + + def update_edges(graph: Graph): + """Updates edges of graph""" + for (u, v), edge in graph.edges.items(): + buff_vec = ( + edge_buff + * (graph[u].get_center() - graph[v].get_center()) + / np.linalg.norm(graph[u].get_center() - graph[v].get_center()) + ) + edge.put_start_and_end_on( + graph[u].get_center() - buff_vec, graph[v].get_center() + buff_vec + ) + + self._graph.updaters.clear() + self._graph.updaters.append(update_edges) + self.add(self._graph) + + def insert_node(self, node: Any, edge: tuple[Hashable, Hashable]): + """Inserts a node into the graph as (parent, node)""" + self._graph.add_vertices( + edge[1], vertex_mobjects={edge[1]: self.__vertex_type(node)} + ) + self._graph.add_edges(edge) + return self + + def insert_node2(self, node: Any, edge: tuple[Hashable, Hashable]): + """Inserts a node into the graph as (parent, node)""" + self._graph.change_layout( + self.__layout, + layout_scale=self.__layout_scale, + layout_config=self.__layout_config, + root_vertex=0, + ) + for mob in self.family_members_with_points(): + if (mob.get_center() == self._graph[edge[1]].get_center()).all(): + mob.points = mob.points.astype("float") + return self + + def insert_node3(self, node: Any, edge: tuple[Hashable, Hashable]): + """Inserts a node into the graph as (parent, node)""" + self.suspend_updating() + self.insert_node(node, edge) + # self.resume_updating() + self.insert_node2(node, edge) + + return self + + def remove_node(self, node: Hashable): + """Removes a node from the graph""" + self._graph.remove_vertices(node) + + # def insert_node2(self): + # """Shift by the given vectors. + # + # Parameters + # ---------- + # vectors + # Vectors to shift by. If multiple vectors are given, they are added + # together. + # + # Returns + # ------- + # :class:`Mobject` + # ``self`` + # + # See also + # -------- + # :meth:`move_to` + # """ + # + # total_vector = reduce(op.add, vectors) + # for mob in self.family_members_with_points(): + # mob.points = mob.points.astype("float") + # mob.points += total_vector + # + # return self + + +if __name__ == "__main__": + + class TestScene(Scene): + def construct(self): + # make a parent list for a tree + tree = Tree({0: 0, 1: 1, 2: 2, 3: 3}, [(0, 1), (0, 2), (1, 3)], Integer) + self.play(Create(tree)) + self.wait() + self.play(tree.animate.insert_node3(4, (2, 4)), run_time=0) + self.wait() + + config.preview = True + config.renderer = "cairo" + config.quality = "low_quality" + TestScene().render(preview=True) diff --git a/src/manim_data_structures/nary_tree.py b/src/manim_data_structures/nary_tree.py new file mode 100644 index 0000000..0c23737 --- /dev/null +++ b/src/manim_data_structures/nary_tree.py @@ -0,0 +1,118 @@ +from typing import Any, Callable, Hashable + +import networkx as nx +from m_tree import Tree +from manim import Mobject + + +def _nary_layout( + T: nx.classes.graph.Graph, + vertex_spacing: tuple | None = None, + n: int | None = None, +): + if not n: + raise ValueError("the n-ary tree layout requires the n parameter") + if not nx.is_tree(T): + raise ValueError("The tree layout must be used with trees") + + max_height = NaryTree.calc_loc(max(T), n)[1] + + def calc_pos(x, y): + """ + Scales the coordinates to the desired spacing + """ + return (x - (n**y - 1) / 2) * vertex_spacing[0] * n ** ( + max_height - y + ), y * vertex_spacing[1] + + return { + i: np.array([x, y, 0]) + for i, (x, y) in ((i, calc_pos(*NaryTree.calc_loc(i, n))) for i in T) + } + + +class NaryTree(Tree): + def __init__( + self, + nodes: dict[int, Any], + num_child: int, + vertex_type: Callable[..., Mobject], + edge_buff=0.4, + layout_config=None, + **kwargs + ): + if layout_config is None: + layout_config = {"vertex_spacing": (-1, 1)} + self.__layout_config = layout_config + self.num_child = num_child + + edges = [(self.get_parent(e), e) for e in nodes if e != 0] + super().__init__(nodes, edges, vertex_type, edge_buff, **kwargs) + dict_layout = _nary_layout(self._graph._graph, n=num_child, **layout_config) + self._graph.change_layout(dict_layout) + + @staticmethod + def calc_loc(i, n): + """ + Calculates the coordinates in terms of the shifted level order x position and level height + """ + if n == 1: + return 1, i + 1 + height = int(np.emath.logn(n, i * (n - 1) + 1)) + node_shift = (1 - n**height) // (1 - n) + return i - node_shift, height + + @staticmethod + def calc_idx(loc, n): + """ + Calculates the index from the coordinates + """ + x, y = loc + if n == 1: + return y - 1 + + return int(x + (1 - n**y) // (1 - n)) + + def get_parent(self, idx): + """ + Returns the index of the parent of the node at the given index + """ + x, y = NaryTree.calc_loc(idx, self.num_child) + new_loc = x // self.num_child, y - 1 + return NaryTree.calc_idx(new_loc, self.num_child) + + def insert_node(self, node: Any, index: Hashable): + """Inserts a node into the graph""" + res = super().insert_node(node, (self.get_parent(index), index)) + dict_layout = _nary_layout( + self._graph._graph, n=self.num_child, **self.__layout_config + ) + self._graph.change_layout(dict_layout) + self.update() + return res + + +if __name__ == "__main__": + from manim import * + + class TestScene(Scene): + def construct(self): + tree = NaryTree( + {0: 0, 1: 1, 4: 4}, + num_child=2, + vertex_type=Integer, + layout_config={"vertex_spacing": (1, -1)}, + ) + # tree._graph.change_layout(root_vertex=0, layout_config=tree._Tree__layout_config, + # layout_scale=tree._Tree__layout_scale) + self.play(Create(tree)) + self.wait() + tree.insert_node(1, 3) + self.wait() + tree.remove_node(4) + self.wait() + + config.preview = True + config.renderer = "cairo" + config.quality = "low_quality" + TestScene().render(preview=True) diff --git a/tests/test_mtree.py b/tests/test_mtree.py new file mode 100644 index 0000000..9843417 --- /dev/null +++ b/tests/test_mtree.py @@ -0,0 +1,11 @@ +# TODO: Fill with appropriate tests +def test_getitem(): + pass + + +def test_setitem(): + pass + + +def test_iteration(): + pass