diff --git a/einops/layers/_einmix.py b/einops/layers/_einmix.py
index ffef1de9..555441cb 100644
--- a/einops/layers/_einmix.py
+++ b/einops/layers/_einmix.py
@@ -1,7 +1,7 @@
 from typing import Any, List, Optional, Dict

 from einops import EinopsError
-from einops.parsing import ParsedExpression
+from einops.parsing import ParsedExpression, _ellipsis
 import warnings
 import string
 from ..einops import _product
@@ -17,21 +17,21 @@ def __init__(self, pattern: str, weight_shape: str, bias_shape: Optional[str] =
         """
         EinMix - Einstein summation with automated tensor management and axis packing/unpacking.

-        EinMix is an advanced tool, helpful tutorial:
+        EinMix is a combination of einops and MLP, see the tutorial:
         https://github.com/arogozhnikov/einops/blob/main/docs/3-einmix-layer.ipynb

         Imagine taking einsum with two arguments, one of each input, and one - tensor with weights
         >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight)

-        This layer manages weights for you, syntax highlights separate role of weight matrix
+        This layer manages weights for you, and the syntax highlights the special role of the weight matrix
         >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out')
-        But otherwise it is the same einsum under the hood.
+        But otherwise it is the same einsum under the hood, plus an einops rearrange for packing/unpacking axes.

-        Simple linear layer with bias term (you have one like that in your framework)
+        Simple linear layer with a bias term (you have one like that in your framework)
         >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20)
         There is no restriction to mix the last axis. Let's mix along height
         >>> EinMix('h w c-> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32)
-        Channel-wise multiplication (like one used in normalizations)
+        Example of channel-wise multiplication (like one used in normalizations)
         >>> EinMix('t b c -> t b c', weight_shape='c', c=128)
         Multi-head linear layer (each head is own linear layer):
         >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...)
@@ -42,14 +42,16 @@ def __init__(self, pattern: str, weight_shape: str, bias_shape: Optional[str] =
         - when channel dimension is not last, use EinMix, not transposition
         - patch/segment embeddings
         - when need only within-group connections to reduce number of weights and computations
-        - perfect as a part of sequential models
-        - next-gen MLPs (follow tutorial to learn more!)
+        - next-gen MLPs (follow the tutorial link above to learn more!)
+        - in general, any time you want to combine a linear layer and einops.rearrange

-        Uniform He initialization is applied to weight tensor. This accounts for number of elements mixed.
+        Uniform He initialization is applied to the weight tensor.
+        This accounts for the number of elements mixed and produced.

         Parameters
         :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output
         :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer
+            If bias_shape is not specified, bias is not created.
         :param bias_shape: axes of bias added to output. Weights of this shape are created and stored. If `None` (the default), no bias is added.
         :param axes_lengths: dimensions of weight tensor
         """
@@ -71,9 +73,13 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
             set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}),
             "Unrecognized identifiers on the right side of EinMix {}",
         )
-
-        if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis:
-            raise EinopsError("Ellipsis is not supported in EinMix (right now)")
+        if weight.has_ellipsis:
+            raise EinopsError("Ellipsis is not supported in weight, as its shape should be fully specified")
+        if left.has_ellipsis or right.has_ellipsis:
+            if not (left.has_ellipsis and right.has_ellipsis):
+                raise EinopsError(f"Ellipsis in EinMix should be on both sides, {pattern}")
+            if left.has_ellipsis_parenthesized:
+                raise EinopsError(f"Ellipsis on left side can't be in parenthesis, got {pattern}")
         if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]):
             raise EinopsError("Anonymous axes (numbers) are not allowed in EinMix")
         if "(" in weight_shape or ")" in weight_shape:
@@ -86,16 +92,18 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
             names: List[str] = []
             for group in left.composition:
                 names += group
+            names = [name if name != _ellipsis else "..." for name in names]
             composition = " ".join(names)
-            pre_reshape_pattern = f"{left_pattern}->{composition}"
+            pre_reshape_pattern = f"{left_pattern}-> {composition}"
             pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names}

-        if any(len(group) != 1 for group in right.composition):
+        if any(len(group) != 1 for group in right.composition) or right.has_ellipsis_parenthesized:
             names = []
             for group in right.composition:
                 names += group
+            names = [name if name != _ellipsis else "..." for name in names]
             composition = " ".join(names)
-            post_reshape_pattern = f"{composition}->{right_pattern}"
+            post_reshape_pattern = f"{composition} ->{right_pattern}"

         self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {})

@@ -116,22 +124,36 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
         # single output element is a combination of fan_in input elements
         _fan_in = _product([axes_lengths[axis] for (axis,) in weight.composition if axis not in right.identifiers])
         if bias_shape is not None:
+            # maybe I should put ellipsis in the beginning for simplicity?
             if not isinstance(bias_shape, str):
                 raise EinopsError("bias shape should be string specifying which axes bias depends on")
             bias = ParsedExpression(bias_shape)
-            _report_axes(set.difference(bias.identifiers, right.identifiers), "Bias axes {} not present in output")
+            _report_axes(
+                set.difference(bias.identifiers, right.identifiers),
+                "Bias axes {} not present in output",
+            )
             _report_axes(
                 set.difference(bias.identifiers, set(axes_lengths)),
                 "Sizes not provided for bias axes {}",
             )

             _bias_shape = []
+            used_non_trivial_size = False
             for axes in right.composition:
-                for axis in axes:
-                    if axis in bias.identifiers:
-                        _bias_shape.append(axes_lengths[axis])
-                    else:
-                        _bias_shape.append(1)
+                if axes == _ellipsis:
+                    if used_non_trivial_size:
+                        raise EinopsError("all bias dimensions should go after ellipsis in the output")
+                else:
+                    # ellipsis may also appear inside a parenthesized group on the right side
+                    for axis in axes:
+                        if axis == _ellipsis:
+                            if used_non_trivial_size:
+                                raise EinopsError("all bias dimensions should go after ellipsis in the output")
+                        elif axis in bias.identifiers:
+                            _bias_shape.append(axes_lengths[axis])
+                            used_non_trivial_size = True
+                        else:
+                            _bias_shape.append(1)
         else:
             _bias_shape = None

@@ -142,15 +164,26 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
         # rewrite einsum expression with single-letter latin identifiers so that
         # expression will be understood by any framework
        mapped_identifiers = {*left.identifiers, *right.identifiers, *weight.identifiers}
+        if _ellipsis in mapped_identifiers:
+            mapped_identifiers.remove(_ellipsis)
+        mapped_identifiers = list(sorted(mapped_identifiers))
         mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapped_identifiers)}
-
-        def write_flat(axes: list):
-            return "".join(mapping2letters[axis] for axis in axes)
+        mapping2letters[_ellipsis] = "..."  # preserve ellipsis
+
+        def write_flat_remapped(axes: ParsedExpression):
+            result = []
+            for composed_axis in axes.composition:
+                if isinstance(composed_axis, list):
+                    result.extend([mapping2letters[axis] for axis in composed_axis])
+                else:
+                    assert composed_axis == _ellipsis
+                    result.append("...")
+            return "".join(result)

         self.einsum_pattern: str = "{},{}->{}".format(
-            write_flat(left.flat_axes_order()),
-            write_flat(weight.flat_axes_order()),
-            write_flat(right.flat_axes_order()),
+            write_flat_remapped(left),
+            write_flat_remapped(weight),
+            write_flat_remapped(right),
         )

     def _create_rearrange_layers(
@@ -174,3 +207,23 @@ def __repr__(self):
         for axis, length in self.axes_lengths.items():
             params += ", {}={}".format(axis, length)
         return "{}({})".format(self.__class__.__name__, params)
+
+
+class _EinmixDebugger(_EinmixMixin):
+    """Used only to test mixin"""
+
+    def _create_rearrange_layers(
+        self,
+        pre_reshape_pattern: Optional[str],
+        pre_reshape_lengths: Optional[Dict],
+        post_reshape_pattern: Optional[str],
+        post_reshape_lengths: Optional[Dict],
+    ):
+        self.pre_reshape_pattern = pre_reshape_pattern
+        self.pre_reshape_lengths = pre_reshape_lengths
+        self.post_reshape_pattern = post_reshape_pattern
+        self.post_reshape_lengths = post_reshape_lengths
+
+    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
+        self.saved_weight_shape = weight_shape
+        self.saved_bias_shape = bias_shape
diff --git a/einops/tests/test_layers.py b/einops/tests/test_layers.py
index 83f18163..a70a5d67 100644
--- a/einops/tests/test_layers.py
+++ b/einops/tests/test_layers.py
@@ -4,7 +4,7 @@
 import numpy
 import pytest

-from einops import rearrange, reduce
+from einops import rearrange, reduce, EinopsError
 from einops.tests import collect_test_backends, is_backend_tested, FLOAT_REDUCTIONS as REDUCTIONS

 __author__ = "Alex Rogozhnikov"
@@ -343,3 +343,127 @@ def eval_at_point(params):
     # check serialization
     fbytes = flax.serialization.to_bytes(params)
     _loaded = flax.serialization.from_bytes(params, fbytes)
+
+
+def test_einmix_decomposition():
+    """
+    Testing that einmix correctly decomposes into smaller transformations.
+    """
+    from einops.layers._einmix import _EinmixDebugger
+
+    mixin1 = _EinmixDebugger(
+        "a b c d e -> e d c b a",
+        weight_shape="d a b",
+        d=2, a=3, b=5,
+    )  # fmt: off
+    assert mixin1.pre_reshape_pattern is None
+    assert mixin1.post_reshape_pattern is None
+    assert mixin1.einsum_pattern == "abcde,dab->edcba"
+    assert mixin1.saved_weight_shape == [2, 3, 5]
+    assert mixin1.saved_bias_shape is None
+
+    mixin2 = _EinmixDebugger(
+        "a b c d e -> e d c b a",
+        weight_shape="d a b",
+        bias_shape="a b c d e",
+        a=1, b=2, c=3, d=4, e=5,
+    )  # fmt: off
+    assert mixin2.pre_reshape_pattern is None
+    assert mixin2.post_reshape_pattern is None
+    assert mixin2.einsum_pattern == "abcde,dab->edcba"
+    assert mixin2.saved_weight_shape == [4, 1, 2]
+    assert mixin2.saved_bias_shape == [5, 4, 3, 2, 1]
+
+    mixin3 = _EinmixDebugger(
+        "... -> ...",
+        weight_shape="",
+        bias_shape="",
+    )  # fmt: off
+    assert mixin3.pre_reshape_pattern is None
+    assert mixin3.post_reshape_pattern is None
+    assert mixin3.einsum_pattern == "...,->..."
+    assert mixin3.saved_weight_shape == []
+    assert mixin3.saved_bias_shape == []
+
+    mixin4 = _EinmixDebugger(
+        "b a ... -> b c ...",
+        weight_shape="b a c",
+        a=1, b=2, c=3,
+    )  # fmt: off
+    assert mixin4.pre_reshape_pattern is None
+    assert mixin4.post_reshape_pattern is None
+    assert mixin4.einsum_pattern == "ba...,bac->bc..."
+    assert mixin4.saved_weight_shape == [2, 1, 3]
+    assert mixin4.saved_bias_shape is None
+
+    mixin5 = _EinmixDebugger(
+        "(b a) ... -> b c (...)",
+        weight_shape="b a c",
+        a=1, b=2, c=3,
+    )  # fmt: off
+    assert mixin5.pre_reshape_pattern == "(b a) ... -> b a ..."
+    assert mixin5.pre_reshape_lengths == dict(a=1, b=2)
+    assert mixin5.post_reshape_pattern == "b c ... -> b c (...)"
+    assert mixin5.einsum_pattern == "ba...,bac->bc..."
+    assert mixin5.saved_weight_shape == [2, 1, 3]
+    assert mixin5.saved_bias_shape is None
+
+    mixin6 = _EinmixDebugger(
+        "b ... (a c) -> b ... (a d)",
+        weight_shape="c d",
+        bias_shape="a d",
+        a=1, c=3, d=4,
+    )  # fmt: off
+    assert mixin6.pre_reshape_pattern == "b ... (a c) -> b ... a c"
+    assert mixin6.pre_reshape_lengths == dict(a=1, c=3)
+    assert mixin6.post_reshape_pattern == "b ... a d -> b ... (a d)"
+    assert mixin6.einsum_pattern == "b...ac,cd->b...ad"
+    assert mixin6.saved_weight_shape == [3, 4]
+    assert mixin6.saved_bias_shape == [1, 1, 4]  # (b) a d, ellipsis does not participate
+
+    mixin7 = _EinmixDebugger(
+        "a ... (b c) -> a (... d b)",
+        weight_shape="c d b",
+        bias_shape="d b",
+        b=2, c=3, d=4,
+    )  # fmt: off
+    assert mixin7.pre_reshape_pattern == "a ... (b c) -> a ... b c"
+    assert mixin7.pre_reshape_lengths == dict(b=2, c=3)
+    assert mixin7.post_reshape_pattern == "a ... d b -> a (... d b)"
+    assert mixin7.einsum_pattern == "a...bc,cdb->a...db"
+    assert mixin7.saved_weight_shape == [3, 4, 2]
+    assert mixin7.saved_bias_shape == [1, 4, 2]  # (a) d b, ellipsis does not participate
+
+
+def test_einmix_restrictions():
+    """
+    Testing that invalid EinMix patterns and shapes are rejected.
+    """
+    from einops.layers._einmix import _EinmixDebugger
+
+    with pytest.raises(EinopsError):
+        _EinmixDebugger(
+            "a b c d e -> e d c b a",
+            weight_shape="d a b",
+            d=2, a=3,  # missing b
+        )  # fmt: off
+
+    with pytest.raises(EinopsError):
+        _EinmixDebugger(
+            "a b c d e -> e d c b a",
+            weight_shape="w a b",
+            d=2, a=3, b=1  # axis w is not in the pattern and has no size
+        )  # fmt: off
+
+    with pytest.raises(EinopsError):
+        _EinmixDebugger(
+            "(...) a -> ... a",
+            weight_shape="a", a=1,  # parenthesized ellipsis on the left is not allowed
+        )  # fmt: off
+
+    with pytest.raises(EinopsError):
+        _EinmixDebugger(
+            "... a -> a ...",
+            weight_shape="a", a=1,  # bias axis precedes ellipsis on the right side
+            bias_shape='a',
+        )  # fmt: off
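
For reference, a minimal usage sketch of the ellipsis support introduced above. It mirrors the
mixin6 test case and assumes this patch is applied; the torch layer is used only for illustration.
Axes matched by `...` pass through untouched while each channel group is mixed:

import torch
from einops.layers.torch import EinMix

# per-group linear map: within each of the a groups, c input channels are mixed into d outputs;
# the leading b axis and everything matched by `...` are left untouched
mix = EinMix("b ... (a c) -> b ... (a d)", weight_shape="c d", bias_shape="a d", a=1, c=3, d=4)

x = torch.randn(2, 5, 7, 3)   # b=2, `...` covers (5, 7), last axis is a*c = 3
y = mix(x)
print(y.shape)                # torch.Size([2, 5, 7, 4]), last axis is a*d = 4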