Skip to content

Commit b1c838a

Browse files
committed
Added residual layer example. Fixed some bugs in the process.
1 parent 2b21261 commit b1c838a

11 files changed

+382
-18
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from manim import *
2+
3+
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
4+
from manim_ml.neural_network.layers.math_operation_layer import MathOperationLayer
5+
from manim_ml.neural_network.neural_network import NeuralNetwork
6+
7+
# Make the specific scene
8+
config.pixel_height = 1200
9+
config.pixel_width = 1900
10+
config.frame_height = 6.0
11+
config.frame_width = 6.0
12+
13+
def make_code_snippet():
14+
code_str = """
15+
nn = NeuralNetwork({
16+
"feed_forward_1": FeedForwardLayer(3),
17+
"feed_forward_2": FeedForwardLayer(3, activation_function="ReLU"),
18+
"feed_forward_3": FeedForwardLayer(3),
19+
"sum_operation": MathOperationLayer("+", activation_function="ReLU"),
20+
})
21+
nn.add_connection("feed_forward_1", "sum_operation")
22+
self.play(nn.make_forward_pass_animation())
23+
"""
24+
25+
code = Code(
26+
code=code_str,
27+
tab_width=4,
28+
background_stroke_width=1,
29+
background_stroke_color=WHITE,
30+
insert_line_no=False,
31+
style="monokai",
32+
# background="window",
33+
language="py",
34+
)
35+
code.scale(0.38)
36+
37+
return code
38+
39+
40+
class CombinedScene(ThreeDScene):
41+
def construct(self):
42+
# Add the network
43+
nn = NeuralNetwork({
44+
"feed_forward_1": FeedForwardLayer(3),
45+
"feed_forward_2": FeedForwardLayer(3, activation_function="ReLU"),
46+
"feed_forward_3": FeedForwardLayer(3),
47+
"sum_operation": MathOperationLayer("+", activation_function="ReLU"),
48+
},
49+
layer_spacing=0.38
50+
)
51+
# Make connections
52+
input_blank_dot = Dot(
53+
nn.input_layers_dict["feed_forward_1"].get_left() - np.array([0.65, 0.0, 0.0])
54+
)
55+
nn.add_connection(input_blank_dot, "feed_forward_1", arc_direction="straight")
56+
nn.add_connection("feed_forward_1", "sum_operation")
57+
output_blank_dot = Dot(
58+
nn.input_layers_dict["sum_operation"].get_right() + np.array([0.65, 0.0, 0.0])
59+
)
60+
nn.add_connection("sum_operation", output_blank_dot, arc_direction="straight")
61+
# Center the nn
62+
nn.move_to(ORIGIN)
63+
self.add(nn)
64+
# Make code snippet
65+
code = make_code_snippet()
66+
code.next_to(nn, DOWN)
67+
self.add(code)
68+
# Group it all
69+
group = Group(nn, code)
70+
group.move_to(ORIGIN)
71+
# Play animation
72+
forward_pass = nn.make_forward_pass_animation()
73+
self.wait(1)
74+
self.play(forward_pass)

manim_ml/neural_network/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@
3939
from manim_ml.neural_network.layers.triplet_to_feed_forward import TripletToFeedForward
4040
from manim_ml.neural_network.layers.triplet import TripletLayer
4141
from manim_ml.neural_network.layers.vector import VectorLayer
42+
from manim_ml.neural_network.layers.math_operation_layer import MathOperationLayer

manim_ml/neural_network/layers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from .paired_query import PairedQueryLayer
3232
from .paired_query_to_feed_forward import PairedQueryToFeedForward
3333
from .max_pooling_2d import MaxPooling2DLayer
34+
from .feed_forward_to_math_operation import FeedForwardToMathOperation
3435

3536
connective_layers_list = (
3637
EmbeddingToFeedForward,
@@ -48,4 +49,5 @@
4849
Convolutional2DToMaxPooling2D,
4950
MaxPooling2DToConvolutional2D,
5051
MaxPooling2DToFeedForward,
52+
FeedForwardToMathOperation
5153
)

manim_ml/neural_network/layers/convolutional_2d.py

+10
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,16 @@ def get_height(self):
272272
"""Overrides get height function"""
273273
return self.feature_maps.length_over_dim(1)
274274

275+
def move_to(self, mobject_or_point):
276+
"""Moves the center of the layer to the given mobject or point"""
277+
layer_center = self.feature_maps.get_center()
278+
if isinstance(mobject_or_point, Mobject):
279+
target_center = mobject_or_point.get_center()
280+
else:
281+
target_center = mobject_or_point
282+
283+
self.shift(target_center - layer_center)
284+
275285
@override_animation(Create)
276286
def _create_override(self, **kwargs):
277287
return FadeIn(self.feature_maps)

manim_ml/neural_network/layers/feed_forward.py

+22
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,25 @@ def _create_override(self, **kwargs):
153153

154154
animation_group = AnimationGroup(*animations, lag_ratio=0.0)
155155
return animation_group
156+
157+
def get_height(self):
158+
return self.surrounding_rectangle.get_height()
159+
160+
def get_center(self):
161+
return self.surrounding_rectangle.get_center()
162+
163+
def get_left(self):
164+
return self.surrounding_rectangle.get_left()
165+
166+
def get_right(self):
167+
return self.surrounding_rectangle.get_right()
168+
169+
def move_to(self, mobject_or_point):
170+
"""Moves the center of the layer to the given mobject or point"""
171+
layer_center = self.surrounding_rectangle.get_center()
172+
if isinstance(mobject_or_point, Mobject):
173+
target_center = mobject_or_point.get_center()
174+
else:
175+
target_center = mobject_or_point
176+
177+
self.shift(target_center - layer_center)

manim_ml/neural_network/layers/feed_forward_to_feed_forward.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def make_forward_pass_animation(
9090
if self.passing_flash:
9191
copy_edge = edge.copy()
9292
anim = ShowPassingFlash(
93-
copy_edge.set_color(self.animation_dot_color), time_width=0.2
93+
copy_edge.set_color(self.animation_dot_color), time_width=0.3
9494
)
9595
else:
9696
anim = MoveAlongPath(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from manim import *
2+
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
3+
from manim_ml.neural_network.layers.parent_layers import ConnectiveLayer
4+
from manim_ml.neural_network.layers.math_operation_layer import MathOperationLayer
5+
from manim_ml.utils.mobjects.connections import NetworkConnection
6+
7+
class FeedForwardToMathOperation(ConnectiveLayer):
8+
"""Image Layer to FeedForward layer"""
9+
10+
input_class = FeedForwardLayer
11+
output_class = MathOperationLayer
12+
13+
def __init__(
14+
self,
15+
input_layer,
16+
output_layer,
17+
active_color=ORANGE,
18+
**kwargs
19+
):
20+
self.active_color = active_color
21+
super().__init__(input_layer, output_layer, **kwargs)
22+
23+
def construct_layer(
24+
self,
25+
input_layer: "NeuralNetworkLayer",
26+
output_layer: "NeuralNetworkLayer",
27+
**kwargs
28+
):
29+
# Draw an arrow from the output of the feed forward layer to the
30+
# input of the math operation layer
31+
self.connection = NetworkConnection(
32+
self.input_layer,
33+
self.output_layer,
34+
arc_direction="straight",
35+
buffer=0.05
36+
)
37+
self.add(self.connection)
38+
39+
return super().construct_layer(input_layer, output_layer, **kwargs)
40+
41+
def make_forward_pass_animation(self, layer_args={}, **kwargs):
42+
"""Makes dots diverge from the given location and move to the feed forward nodes decoder"""
43+
# Make flashing pass animation on arrow
44+
passing_flash = ShowPassingFlash(
45+
self.connection.copy().set_color(self.active_color)
46+
)
47+
48+
return passing_flash
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from manim import *
2+
3+
from manim_ml.neural_network.activation_functions import get_activation_function_by_name
4+
from manim_ml.neural_network.activation_functions.activation_function import (
5+
ActivationFunction,
6+
)
7+
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
8+
9+
class MathOperationLayer(VGroupNeuralNetworkLayer):
10+
"""Handles rendering a layer for a neural network"""
11+
valid_operations = ["+", "-", "*", "/"]
12+
13+
def __init__(
14+
self,
15+
operation_type: str,
16+
node_radius=0.5,
17+
node_color=BLUE,
18+
node_stroke_width=2.0,
19+
active_color=ORANGE,
20+
activation_function=None,
21+
font_size=20,
22+
**kwargs
23+
):
24+
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
25+
# Ensure operation type is valid
26+
assert operation_type in MathOperationLayer.valid_operations
27+
self.operation_type = operation_type
28+
self.node_radius = node_radius
29+
self.node_color = node_color
30+
self.node_stroke_width = node_stroke_width
31+
self.active_color = active_color
32+
self.font_size = font_size
33+
self.activation_function = activation_function
34+
35+
def construct_layer(
36+
self,
37+
input_layer: "NeuralNetworkLayer",
38+
output_layer: "NeuralNetworkLayer",
39+
**kwargs
40+
):
41+
"""Creates the neural network layer"""
42+
# Draw the operation
43+
self.operation_text = Text(
44+
self.operation_type,
45+
font_size=self.font_size
46+
)
47+
self.add(self.operation_text)
48+
# Make the surrounding circle
49+
self.surrounding_circle = Circle(
50+
color=self.node_color,
51+
stroke_width=self.node_stroke_width
52+
).surround(self.operation_text)
53+
self.add(self.surrounding_circle)
54+
# Make the activation function
55+
self.construct_activation_function()
56+
super().construct_layer(input_layer, output_layer, **kwargs)
57+
58+
def construct_activation_function(self):
59+
"""Construct the activation function"""
60+
# Add the activation function
61+
if not self.activation_function is None:
62+
# Check if it is a string
63+
if isinstance(self.activation_function, str):
64+
activation_function = get_activation_function_by_name(
65+
self.activation_function
66+
)()
67+
else:
68+
assert isinstance(self.activation_function, ActivationFunction)
69+
activation_function = self.activation_function
70+
# Plot the function above the rest of the layer
71+
self.activation_function = activation_function
72+
self.add(self.activation_function)
73+
74+
def make_forward_pass_animation(self, layer_args={}, **kwargs):
75+
"""Makes the forward pass animation
76+
77+
Parameters
78+
----------
79+
layer_args : dict, optional
80+
layer specific arguments, by default {}
81+
82+
Returns
83+
-------
84+
AnimationGroup
85+
Forward pass animation
86+
"""
87+
# Make highlight animation
88+
succession = Succession(
89+
ApplyMethod(
90+
self.surrounding_circle.set_color,
91+
self.active_color,
92+
run_time=0.25
93+
),
94+
Wait(1.0),
95+
ApplyMethod(
96+
self.surrounding_circle.set_color,
97+
self.node_color,
98+
run_time=0.25
99+
),
100+
)
101+
# Animate the activation function
102+
if not self.activation_function is None:
103+
animation_group = AnimationGroup(
104+
succession,
105+
self.activation_function.make_evaluate_animation(),
106+
lag_ratio=0.0,
107+
)
108+
return animation_group
109+
else:
110+
return succession
111+
112+
def get_center(self):
113+
return self.surrounding_circle.get_center()
114+
115+
def get_left(self):
116+
return self.surrounding_circle.get_left()
117+
118+
def get_right(self):
119+
return self.surrounding_circle.get_right()
120+
121+
def move_to(self, mobject_or_point):
122+
"""Moves the center of the layer to the given mobject or point"""
123+
layer_center = self.surrounding_circle.get_center()
124+
if isinstance(mobject_or_point, Mobject):
125+
target_center = mobject_or_point.get_center()
126+
else:
127+
target_center = mobject_or_point
128+
129+
self.shift(target_center - layer_center)

manim_ml/neural_network/neural_network.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,31 @@ def make_input_layers_dict(self, input_layers):
9393

9494
def add_connection(
9595
self,
96-
start_layer_name,
97-
end_layer_name,
96+
start_mobject_or_name,
97+
end_mobject_or_name,
9898
connection_style="default",
9999
connection_position="bottom",
100+
arc_direction="down"
100101
):
101102
"""Add connection from start layer to end layer"""
102103
assert connection_style in ["default"]
103104
if connection_style == "default":
104105
# Make arrow connection from start layer to end layer
105106
# Add the connection
107+
if isinstance(start_mobject_or_name, Mobject):
108+
input_mobject = start_mobject_or_name
109+
else:
110+
input_mobject = self.input_layers_dict[start_mobject_or_name]
111+
if isinstance(end_mobject_or_name, Mobject):
112+
output_mobject = end_mobject_or_name
113+
else:
114+
output_mobject = self.input_layers_dict[end_mobject_or_name]
115+
106116
connection = NetworkConnection(
107-
self.input_layers_dict[start_layer_name],
108-
self.input_layers_dict[end_layer_name],
109-
arc_direction="down", # TODO generalize this more
117+
input_mobject,
118+
output_mobject,
119+
arc_direction=arc_direction,
120+
buffer=0.05
110121
)
111122
self.connections.append(connection)
112123
self.add(connection)
@@ -243,7 +254,7 @@ def make_forward_pass_animation(
243254
):
244255
"""Generates an animation for feed forward propagation"""
245256
all_animations = []
246-
per_layer_animations = {}
257+
per_layer_animation_map = {}
247258
per_layer_runtime = (
248259
run_time / len(self.all_layers) if not run_time is None else None
249260
)
@@ -297,11 +308,11 @@ def make_forward_pass_animation(
297308
)
298309
all_animations.append(layer_forward_pass)
299310
# Add the animation to per layer animation
300-
per_layer_animations[layer] = layer_forward_pass
311+
per_layer_animation_map[layer] = layer_forward_pass
301312
# Make the animation group
302313
animation_group = Succession(*all_animations, lag_ratio=1.0)
303314
if per_layer_animations:
304-
return per_layer_animations
315+
return per_layer_animation_map
305316
else:
306317
return animation_group
307318

0 commit comments

Comments
 (0)