1
+ from manim import *
2
+
3
+ from manim_ml .neural_network .activation_functions import get_activation_function_by_name
4
+ from manim_ml .neural_network .activation_functions .activation_function import (
5
+ ActivationFunction ,
6
+ )
7
+ from manim_ml .neural_network .layers .parent_layers import VGroupNeuralNetworkLayer
8
+
9
+ class MathOperationLayer (VGroupNeuralNetworkLayer ):
10
+ """Handles rendering a layer for a neural network"""
11
+ valid_operations = ["+" , "-" , "*" , "/" ]
12
+
13
+ def __init__ (
14
+ self ,
15
+ operation_type : str ,
16
+ node_radius = 0.5 ,
17
+ node_color = BLUE ,
18
+ node_stroke_width = 2.0 ,
19
+ active_color = ORANGE ,
20
+ activation_function = None ,
21
+ font_size = 20 ,
22
+ ** kwargs
23
+ ):
24
+ super (VGroupNeuralNetworkLayer , self ).__init__ (** kwargs )
25
+ # Ensure operation type is valid
26
+ assert operation_type in MathOperationLayer .valid_operations
27
+ self .operation_type = operation_type
28
+ self .node_radius = node_radius
29
+ self .node_color = node_color
30
+ self .node_stroke_width = node_stroke_width
31
+ self .active_color = active_color
32
+ self .font_size = font_size
33
+ self .activation_function = activation_function
34
+
35
+ def construct_layer (
36
+ self ,
37
+ input_layer : "NeuralNetworkLayer" ,
38
+ output_layer : "NeuralNetworkLayer" ,
39
+ ** kwargs
40
+ ):
41
+ """Creates the neural network layer"""
42
+ # Draw the operation
43
+ self .operation_text = Text (
44
+ self .operation_type ,
45
+ font_size = self .font_size
46
+ )
47
+ self .add (self .operation_text )
48
+ # Make the surrounding circle
49
+ self .surrounding_circle = Circle (
50
+ color = self .node_color ,
51
+ stroke_width = self .node_stroke_width
52
+ ).surround (self .operation_text )
53
+ self .add (self .surrounding_circle )
54
+ # Make the activation function
55
+ self .construct_activation_function ()
56
+ super ().construct_layer (input_layer , output_layer , ** kwargs )
57
+
58
+ def construct_activation_function (self ):
59
+ """Construct the activation function"""
60
+ # Add the activation function
61
+ if not self .activation_function is None :
62
+ # Check if it is a string
63
+ if isinstance (self .activation_function , str ):
64
+ activation_function = get_activation_function_by_name (
65
+ self .activation_function
66
+ )()
67
+ else :
68
+ assert isinstance (self .activation_function , ActivationFunction )
69
+ activation_function = self .activation_function
70
+ # Plot the function above the rest of the layer
71
+ self .activation_function = activation_function
72
+ self .add (self .activation_function )
73
+
74
+ def make_forward_pass_animation (self , layer_args = {}, ** kwargs ):
75
+ """Makes the forward pass animation
76
+
77
+ Parameters
78
+ ----------
79
+ layer_args : dict, optional
80
+ layer specific arguments, by default {}
81
+
82
+ Returns
83
+ -------
84
+ AnimationGroup
85
+ Forward pass animation
86
+ """
87
+ # Make highlight animation
88
+ succession = Succession (
89
+ ApplyMethod (
90
+ self .surrounding_circle .set_color ,
91
+ self .active_color ,
92
+ run_time = 0.25
93
+ ),
94
+ Wait (1.0 ),
95
+ ApplyMethod (
96
+ self .surrounding_circle .set_color ,
97
+ self .node_color ,
98
+ run_time = 0.25
99
+ ),
100
+ )
101
+ # Animate the activation function
102
+ if not self .activation_function is None :
103
+ animation_group = AnimationGroup (
104
+ succession ,
105
+ self .activation_function .make_evaluate_animation (),
106
+ lag_ratio = 0.0 ,
107
+ )
108
+ return animation_group
109
+ else :
110
+ return succession
111
+
112
+ def get_center (self ):
113
+ return self .surrounding_circle .get_center ()
114
+
115
+ def get_left (self ):
116
+ return self .surrounding_circle .get_left ()
117
+
118
+ def get_right (self ):
119
+ return self .surrounding_circle .get_right ()
120
+
121
+ def move_to (self , mobject_or_point ):
122
+ """Moves the center of the layer to the given mobject or point"""
123
+ layer_center = self .surrounding_circle .get_center ()
124
+ if isinstance (mobject_or_point , Mobject ):
125
+ target_center = mobject_or_point .get_center ()
126
+ else :
127
+ target_center = mobject_or_point
128
+
129
+ self .shift (target_center - layer_center )
0 commit comments