Skip to content

Commit

Permalink
Allow fusion of some TF kernels: tf.nn.bias_add as a special case of tf.add (
Browse files Browse the repository at this point in the history
keras-team#20386)

* tf.nn.bias_add as special case of tf.add

* More comments
  • Loading branch information
shkarupa-alex authored Oct 24, 2024
1 parent 8dc19d2 commit eb5c5ae
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 5 deletions.
23 changes: 23 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,29 @@ def add(x1, x2):
)
x1 = convert_to_tensor(x1, dtype)
x2 = convert_to_tensor(x2, dtype)

# Special case of `tf.add`: `tf.nn.bias_add`
# `BiasAdd` can be fused with `MatMul` and `Conv*` kernels
# Expecting `x1` to be `inputs` and `x2` to be `bias` (no swapping)
x2_squeeze_shape = [d for d in x2.shape if d is None or d > 1]
if (
# `x2` looks like bias (can be squeezed to vector)
1 == len(x2_squeeze_shape)
# `x1` looks like input tensor (rank >= 2)
and len(x1.shape) > 1
# `x2` non-squeezable dimension defined
and x2_squeeze_shape[0] is not None
# `x2` non-squeezable dimension match `x1` channel dimension
and x2_squeeze_shape[0] in {x1.shape[1], x1.shape[-1]}
):
if x1.shape[-1] == x2_squeeze_shape[0]:
data_format = "NHWC"
else:
data_format = "NCHW"
if len(x2.shape) > 1:
x2 = tf.squeeze(x2)
return tf.nn.bias_add(x1, x2, data_format=data_format)

return tf.add(x1, x2)


Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/base_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def call(self, inputs):
else:
bias_shape = (1, self.filters) + (1,) * self.rank
bias = ops.reshape(self.bias, bias_shape)
outputs += bias
outputs = ops.add(outputs, bias)

if self.activation is not None:
return self.activation(outputs)
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/base_conv_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def call(self, inputs):
else:
bias_shape = (1, self.filters) + (1,) * self.rank
bias = ops.reshape(self.bias, bias_shape)
outputs += bias
outputs = ops.add(outputs, bias)

if self.activation is not None:
return self.activation(outputs)
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/base_depthwise_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def call(self, inputs):
1,
) * self.rank
bias = ops.reshape(self.bias, bias_shape)
outputs += bias
outputs = ops.add(outputs, bias)

if self.activation is not None:
return self.activation(outputs)
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/base_separable_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def call(self, inputs):
else:
bias_shape = (1, self.filters) + (1,) * self.rank
bias = ops.reshape(self.bias, bias_shape)
outputs += bias
outputs = ops.add(outputs, bias)

if self.activation is not None:
return self.activation(outputs)
Expand Down
2 changes: 1 addition & 1 deletion keras/src/layers/convolutional/conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def call(self, inputs):
else:
bias_shape = (1, self.filters) + (1,) * self.rank
bias = ops.reshape(self.bias, bias_shape)
outputs += bias
outputs = ops.add(outputs, bias)

if self.activation is not None:
return self.activation(outputs)
Expand Down

0 comments on commit eb5c5ae

Please sign in to comment.