
Merge pull request #39 from andrrizzi/mixed
MixedTransformer
andrrizzi authored Sep 19, 2024
2 parents 96cdc26 + 2e6b62e commit 69fbe87
Showing 6 changed files with 434 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yaml
@@ -83,6 +83,7 @@ jobs:
        if: contains(matrix.os, 'ubuntu')
        uses: codecov/codecov-action@v3
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          file: ./coverage.xml
          flags: unittests
          name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}
1 change: 1 addition & 0 deletions tfep/nn/transformers/__init__.py
@@ -13,6 +13,7 @@
    AffineTransformer, affine_transformer, affine_transformer_inverse,
    VolumePreservingShiftTransformer, volume_preserving_shift_transformer, volume_preserving_shift_transformer_inverse,
)
from tfep.nn.transformers.mixed import MixedTransformer
from tfep.nn.transformers.moebius import (
    MoebiusTransformer, moebius_transformer)
from tfep.nn.transformers.sos import (
186 changes: 186 additions & 0 deletions tfep/nn/transformers/mixed.py
@@ -0,0 +1,186 @@
#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
A transformer applying different transformers to different features.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from collections.abc import Sequence

import torch

from tfep.nn.transformers.transformer import MAFTransformer
from tfep.utils.misc import ensure_tensor_sequence


# =============================================================================
# MIXED TRANSFORMER
# =============================================================================

class MixedTransformer(MAFTransformer):
    """A transformer applying different transformers to different features."""

    def __init__(
            self,
            transformers: Sequence[MAFTransformer],
            indices: Sequence[Sequence[int]],
    ):
        """Constructor.
        Parameters
        ----------
        transformers : Sequence[MAFTransformer]
            The transformers to mix.
        indices : Sequence[Sequence[int]]
            A list of length ``len(transformers)``. ``indices[i]`` is another
            list containing the indices of the input features for the ``i``-th
            transformer. The sum of all the lengths must equal the number of
            features.
        """
        super().__init__()

        # Input checking.
        if len(transformers) < 2:
            raise ValueError('The number of transformers must be greater than 1.')
        if len(transformers) != len(indices):
            raise ValueError('The number of elements in indices must equal that in transformers.')

        self._transformers = transformers

        # Save the indices into buffers.
        for idx, ind in enumerate(indices):
            self.register_buffer(f'_indices{idx}', ensure_tensor_sequence(ind))

        # Cache the starting and ending indices to split the parameters.
        par_lengths = [len(transformer.get_identity_parameters(len(ind)))
                       for transformer, ind in zip(transformers, indices)]
        split_indices = torch.cumsum(torch.tensor(par_lengths[:-1]), dim=0)
        self.register_buffer('_parameters_split_indices', split_indices)

    def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the transformation.
        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, n_parameters)``. The parameters for the
            transformers, grouped by transformer (i.e., first all the
            parameters for the first transformer, then those of the second
            one, etc.).
        Returns
        -------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed vectors.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The logarithm of the absolute value of
            the Jacobian determinant ``dy / dx``.
        """
        return self._run(x, parameters, inverse=False)

    def inverse(self, y: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reverse the transformation.
        Parameters
        ----------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, n_parameters)``. The parameters for the
            transformers, grouped by transformer (i.e., first all the
            parameters for the first transformer, then those of the second
            one, etc.).
        Returns
        -------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed vectors.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The logarithm of the absolute value of
            the Jacobian determinant ``dx / dy``.
        """
        return self._run(y, parameters, inverse=True)

    def get_identity_parameters(self, n_features: int) -> torch.Tensor:
        """Return the value of the parameters that makes this the identity function.
        This can be used to initialize the normalizing flow to perform the identity
        transformation.
        Parameters
        ----------
        n_features : int
            The dimension of the input vector passed to the transformer.
        Returns
        -------
        parameters : torch.Tensor
            Shape ``(n_parameters,)``. The parameters for the identity function.
        """
        parameters = [transformer.get_identity_parameters(len(indices))
                      for transformer, indices in zip(self._transformers, self._indices)]
        return torch.cat(parameters, dim=-1)

    def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
        """Returns the degrees associated to the conditioner's output.
        Parameters
        ----------
        degrees_in : torch.Tensor
            Shape ``(n_transformed_features,)``. The autoregressive degrees
            associated to the features provided as input to the transformer.
        Returns
        -------
        degrees_out : torch.Tensor
            Shape ``(n_parameters,)``. The autoregressive degrees associated
            to each output of the conditioner that will be fed to the
            transformer as parameters.
        """
        degrees_out = [transformer.get_degrees_out(degrees_in[indices])
                       for transformer, indices in zip(self._transformers, self._indices)]
        return torch.cat(degrees_out, dim=-1)

    @property
    def _indices(self):
        """Return the index buffers as a list."""
        indices = []
        for idx, transformer in enumerate(self._transformers):
            indices.append(getattr(self, f'_indices{idx}'))
        return indices

    def _run(self, x, parameters, inverse):
        """Execute the transformation."""
        # Build a new tensor for the result to avoid modifying the input in place.
        y = torch.empty_like(x)
        cumulative_log_det_J = 0.0

        # Split the parameters by transformer.
        parameters = torch.tensor_split(parameters, self._parameters_split_indices, dim=1)

        # Run transformers.
        for idx, (transformer, par) in enumerate(zip(self._transformers, parameters)):
            indices = getattr(self, f'_indices{idx}')
            if inverse:
                y[:, indices], log_det_J = transformer.inverse(x[:, indices], par)
            else:
                y[:, indices], log_det_J = transformer(x[:, indices], par)
            cumulative_log_det_J = cumulative_log_det_J + log_det_J

        return y, cumulative_log_det_J
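
For context, a minimal usage sketch (not part of this commit) based on the docstrings above; the choice of transformers, the feature split, and the use of get_identity_parameters to build a valid parameter tensor are illustrative assumptions rather than an excerpt from the library:

import torch

from tfep.nn.transformers import (
    AffineTransformer, MixedTransformer, SOSPolynomialTransformer,
)

# Affine map on features 0 and 2; sum-of-squares polynomial map on 1, 3, and 4.
mixed = MixedTransformer(
    transformers=[AffineTransformer(), SOSPolynomialTransformer(3)],
    indices=[[0, 2], [1, 3, 4]],
)

batch_size, n_features = 8, 5

# In a MAF the conditioner emits the parameters; here we use the identity
# values, so y should equal x and log_det_J should be zero. The parameters
# are grouped by transformer: first all affine parameters, then the SOS ones.
parameters = mixed.get_identity_parameters(n_features)
parameters = parameters.unsqueeze(0).expand(batch_size, -1)

x = torch.randn(batch_size, n_features, dtype=parameters.dtype)
y, log_det_J = mixed(x, parameters)
x_back, log_det_J_inv = mixed.inverse(y, parameters)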
49 changes: 45 additions & 4 deletions tfep/tests/nn/flows/test_maf.py
@@ -14,12 +14,14 @@
# GLOBAL IMPORTS
# =============================================================================

import numpy as np
import pytest
import torch

from tfep.nn.transformers import (
    AffineTransformer, SOSPolynomialTransformer,
    NeuralSplineTransformer, MoebiusTransformer
    NeuralSplineTransformer, MoebiusTransformer,
    MixedTransformer,
)
from tfep.nn.conditioners.made import generate_degrees
from tfep.nn.embeddings.mafembed import PeriodicEmbedding
@@ -103,7 +105,11 @@ def create_input(batch_size, dimension_in, limits=None, periodic_indices=None, s
    SOSPolynomialTransformer(2),
    SOSPolynomialTransformer(3),
    NeuralSplineTransformer(x0=torch.tensor(-2., dtype=torch.double), xf=torch.tensor(2., dtype=torch.double), n_bins=3),
    MoebiusTransformer(dimension=3)
    MoebiusTransformer(dimension=3),
    MixedTransformer(
        transformers=[AffineTransformer(), SOSPolynomialTransformer(3)],
        indices=[[0, 2], [1, 3, 4]],
    )
])
def test_identity_initialization_MAF(hidden_layers, conditioning_indices, periodic_indices,
                                     degrees_in_order, weight_norm, transformer):
@@ -118,7 +124,7 @@ def test_identity_initialization_MAF(hidden_layers, conditioning_indices, period
    # Must be equal to the NeuralSplineTransformer limits.
    limits = [-2., 2.]

    # Periodic indices with MoebiusTransformer doens't make sense.
    # Periodic indices with MoebiusTransformer doesn't make sense.
    if periodic_indices is None:
        embedding = None
    elif isinstance(transformer, MoebiusTransformer):
@@ -138,6 +144,18 @@ def test_identity_initialization_MAF(hidden_layers, conditioning_indices, period
    else:
        repeats = 1

    # Remove the conditioning indices from the MixedTransformer.
    if isinstance(transformer, MixedTransformer) and len(conditioning_indices) > 0:
        indices = [transformer._indices0.tolist(), transformer._indices1.tolist()]
        indices = [[i for i in ind if i not in conditioning_indices] for ind in indices]
        # Shift the indices to account for the removed conditioning indices.
        new_indices = np.argsort(indices[0] + indices[1])
        new_indices = [
            torch.from_numpy(new_indices[:len(indices[0])]),
            torch.from_numpy(new_indices[len(indices[0]):]),
        ]
        transformer = MixedTransformer(transformer._transformers, new_indices)

    # Create MAF.
    maf = MAF(
        degrees_in=generate_degrees(
@@ -184,7 +202,18 @@ def test_identity_initialization_MAF(hidden_layers, conditioning_indices, period
@pytest.mark.parametrize('degrees_in_order', ['ascending', 'descending', 'random'])
@pytest.mark.parametrize('transformer', [
    AffineTransformer(),
    MoebiusTransformer(dimension=3)
    MoebiusTransformer(dimension=3),
    MixedTransformer(
        transformers=[
            AffineTransformer(),
            NeuralSplineTransformer(
                x0=torch.tensor(-10., dtype=torch.double),
                xf=torch.tensor(10., dtype=torch.double),
                n_bins=3,
            ),
        ],
        indices=[[2, 0, 3], [4, 1]],
    )
])
@pytest.mark.parametrize('weight_norm', [False, True])
def test_maf_autoregressive_round_trip(conditioning_indices, periodic_indices, degrees_in_order, weight_norm, transformer):
@@ -213,6 +242,18 @@ def test_maf_autoregressive_round_trip(conditioning_indices, periodic_indices, d
    else:
        repeats = 1

    # Remove the conditioning indices from the MixedTransformer.
    if isinstance(transformer, MixedTransformer) and len(conditioning_indices) > 0:
        indices = [transformer._indices0.tolist(), transformer._indices1.tolist()]
        indices = [[i for i in ind if i not in conditioning_indices] for ind in indices]
        # Shift the indices to account for the removed conditioning indices.
        new_indices = np.argsort(indices[0] + indices[1])
        new_indices = [
            torch.from_numpy(new_indices[:len(indices[0])]),
            torch.from_numpy(new_indices[len(indices[0]):]),
        ]
        transformer = MixedTransformer(transformer._transformers, new_indices)

    # Input degrees.
    degrees_in = generate_degrees(
        n_features=n_features,
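
As a side note on the index bookkeeping in the two test functions above (a sketch, not part of the commit, with hypothetical values mirroring the second parametrized case): after the conditioning indices are dropped, the per-transformer feature indices no longer form a contiguous range starting at 0, and np.argsort of their concatenation yields a permutation of 0..n-1 that is split back into groups of the original sizes, so the rebuilt MixedTransformer still acts on disjoint, compact sets of the remaining features:

import numpy as np
import torch

indices = [[2, 0, 3], [4, 1]]  # original feature indices per transformer
conditioning_indices = [0]     # features consumed only by the conditioner
kept = [[i for i in ind if i not in conditioning_indices] for ind in indices]
# kept == [[2, 3], [4, 1]]: still valid features, but not a compact 0..3 range.
new = np.argsort(kept[0] + kept[1])  # a permutation of [0, 1, 2, 3]
new_indices = [
    torch.from_numpy(new[:len(kept[0])]),
    torch.from_numpy(new[len(kept[0]):]),
]
# new_indices == [tensor([3, 0]), tensor([1, 2])]: disjoint groups of the
# original sizes that together cover indices 0..3.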