
Commit

Merge pull request #40 from andrrizzi/neural
Neural spline transformer improvements
andrrizzi authored Sep 24, 2024
2 parents 69fbe87 + fe3647f commit 89f28a6
Showing 10 changed files with 689 additions and 255 deletions.
27 changes: 23 additions & 4 deletions tfep/app/mixedmaf.py

@@ -811,11 +811,30 @@ def _get_transformer(
         maf_periodic_dof_indices = maf_periodic_dof_indices - torch.searchsorted(
             maf_conditioning_dof_indices, maf_periodic_dof_indices)

-        return tfep.nn.transformers.NeuralSplineTransformer(
-            x0=x0.detach(),
-            xf=xf.detach(),
+        # Find all non-periodic DOFs (after filtering the conditioning ones).
+        mask = torch.full(x0.shape, fill_value=True)
+        mask[maf_periodic_dof_indices] = False
+        maf_nonperiodic_dof_indices = torch.tensor(range(len(x0)))[mask]
+
+        # Standard splines for non-periodic DOFs.
+        spline = tfep.nn.transformers.NeuralSplineTransformer(
+            x0=x0[maf_nonperiodic_dof_indices].detach(),
+            xf=xf[maf_nonperiodic_dof_indices].detach(),
             n_bins=5,
-            circular=maf_periodic_dof_indices.detach(),
+            circular=False,
         )
+
+        # Circular splines for periodic DOFs.
+        circular_spline = tfep.nn.transformers.NeuralSplineTransformer(
+            x0=x0[maf_periodic_dof_indices].detach(),
+            xf=xf[maf_periodic_dof_indices].detach(),
+            n_bins=5,
+            circular=True,
+        )
+
+        return tfep.nn.transformers.MixedTransformer(
+            transformers=[spline, circular_spline],
+            indices=[maf_nonperiodic_dof_indices, maf_periodic_dof_indices],
+        )
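
For readers skimming the change: instead of a single NeuralSplineTransformer that is told which DOFs are circular, the flow now builds one standard spline and one circular spline and lets a MixedTransformer route each index group to its own sub-transformer. The stand-alone sketch below mirrors only the index bookkeeping with plain PyTorch; standard, circular, and mixed_apply are hypothetical placeholders for the tfep splines and MixedTransformer, not the library's API.

import torch

# Toy setup: 6 DOFs, of which DOFs 1 and 4 are periodic (e.g. torsion angles).
x0 = torch.zeros(6)
periodic_indices = torch.tensor([1, 4])

# Same boolean-mask trick as in the diff: everything that is not periodic
# is handled by the standard spline.
mask = torch.full(x0.shape, fill_value=True)
mask[periodic_indices] = False
nonperiodic_indices = torch.arange(len(x0))[mask]   # tensor([0, 2, 3, 5])

def mixed_apply(x, transforms, indices):
    # MixedTransformer-style dispatch: each sub-transformer sees only its
    # own slice of the features, and the results are scattered back in place.
    y = torch.empty_like(x)
    for transform, idx in zip(transforms, indices):
        y[..., idx] = transform(x[..., idx])
    return y

# Placeholders standing in for the non-periodic and circular splines.
standard = lambda t: 2.0 * t
circular = lambda t: t % (2 * torch.pi)

x = torch.randn(2, 6)
y = mixed_apply(x, [standard, circular], [nonperiodic_indices, periodic_indices])
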
14 changes: 7 additions & 7 deletions tfep/nn/transformers/affine.py

@@ -42,7 +42,7 @@ class AffineTransformer(MAFTransformer):
     """
     # Number of parameters needed by the transformer for each input dimension.
-    n_parameters_per_input = 2
+    n_parameters_per_feature = 2

     def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor]:
         """Apply the affine transformation to the input.
@@ -112,7 +112,7 @@ def get_identity_parameters(self, n_features: int) -> torch.Tensor:
            Shape ``(2*n_features)``. The parameters for the identity.
         """
-        return torch.zeros(size=(self.n_parameters_per_input*n_features,))
+        return torch.zeros(size=(self.n_parameters_per_feature*n_features,))

     def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
         """Returns the degrees associated to the conditioner's output.
@@ -131,13 +131,13 @@ def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
            transformer as parameters.
         """
-        return degrees_in.tile((self.n_parameters_per_input,))
+        return degrees_in.tile((self.n_parameters_per_feature,))

     def _split_parameters(self, parameters):
         """Divide shift from log scale."""
         # From (batch, 2*n_features) to (batch, 2, n_features).
         batch_size = parameters.shape[0]
-        parameters = parameters.reshape(batch_size, self.n_parameters_per_input, -1)
+        parameters = parameters.reshape(batch_size, self.n_parameters_per_feature, -1)
         return parameters[:, 0], parameters[:, 1]


@@ -162,7 +162,7 @@ class VolumePreservingShiftTransformer(MAFTransformer):
     """
     # Number of parameters needed by the transformer for each input dimension.
-    n_parameters_per_input = 1
+    n_parameters_per_feature = 1

     def __init__(
         self,
@@ -252,7 +252,7 @@ def get_identity_parameters(self, n_features: int) -> torch.Tensor:
            Shape ``(n_features)``. The parameters for the identity.
         """
-        return torch.zeros(size=(self.n_parameters_per_input*n_features,))
+        return torch.zeros(size=(self.n_parameters_per_feature*n_features,))

     def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
         """Returns the degrees associated to the conditioner's output.
@@ -271,7 +271,7 @@ def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
            transformer as parameters.
         """
-        return degrees_in.tile((self.n_parameters_per_input,))
+        return degrees_in.tile((self.n_parameters_per_feature,))


 # =============================================================================
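
The renamed n_parameters_per_feature attribute is exactly what _split_parameters relies on to recover shift and log-scale from the conditioner's flat output. Below is a minimal sketch of that parameter layout; the affine map and log-determinant written out here are the generic y = x * exp(log_scale) + shift formula for illustration, not code copied from tfep.

import torch

n_features = 3
n_parameters_per_feature = 2   # shift and log-scale, as in AffineTransformer

# A conditioner emits one flat vector of 2*n_features parameters per sample.
batch_size = 4
parameters = torch.randn(batch_size, n_parameters_per_feature * n_features)

# Same reshape as _split_parameters: (batch, 2*n_features) -> (batch, 2, n_features).
parameters = parameters.reshape(batch_size, n_parameters_per_feature, -1)
shift, log_scale = parameters[:, 0], parameters[:, 1]

# Generic affine map and its log|det J| (illustrative only).
x = torch.randn(batch_size, n_features)
y = x * torch.exp(log_scale) + shift
log_det_jacobian = log_scale.sum(dim=-1)

# get_identity_parameters returns all zeros: shift = 0 and log_scale = 0
# make the map the identity, y == x.
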
12 changes: 6 additions & 6 deletions tfep/nn/transformers/moebius.py

@@ -54,7 +54,7 @@ class MoebiusTransformer(MAFTransformer):
     """
     # Number of parameters needed by the transformer for each input dimension.
-    n_parameters_per_input = 1
+    n_parameters_per_feature = 1

     def __init__(self, dimension: int, max_radius: float = 0.99, unit_sphere: bool = False):
         """Constructor.
@@ -170,7 +170,7 @@ def get_identity_parameters(self, n_features: int) -> torch.Tensor:
            vector to perform the identity function with a Moebius transformer.
         """
-        return torch.zeros(size=(self.n_parameters_per_input*n_features,))
+        return torch.zeros(size=(self.n_parameters_per_feature*n_features,))

     def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
         """Returns the degrees associated to the conditioner's output.
@@ -189,7 +189,7 @@ def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
            transformer as parameters.
         """
-        return degrees_in.tile((self.n_parameters_per_input,))
+        return degrees_in.tile((self.n_parameters_per_feature,))


 class SymmetrizedMoebiusTransformer(MAFTransformer):
@@ -220,7 +220,7 @@ class SymmetrizedMoebiusTransformer(MAFTransformer):
     """
     # Number of parameters needed by the transformer for each input dimension.
-    n_parameters_per_input = 1
+    n_parameters_per_feature = 1

     def __init__(self, dimension: int, max_radius: float = 0.99):
         """Constructor.
@@ -330,7 +330,7 @@ def get_identity_parameters(self, n_features: int) -> torch.Tensor:
            vector to perform the identity function with a Moebius transformer.
         """
-        return torch.zeros(size=(self.n_parameters_per_input*n_features,))
+        return torch.zeros(size=(self.n_parameters_per_feature*n_features,))

     def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
         """Returns the degrees associated to the conditioner's output.
@@ -349,7 +349,7 @@ def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
            transformer as parameters.
         """
-        return degrees_in.tile((self.n_parameters_per_input,))
+        return degrees_in.tile((self.n_parameters_per_feature,))


 # =============================================================================
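
Both Moebius transformers declare n_parameters_per_feature = 1, so get_degrees_out returns the conditioner degrees unchanged; with k parameters per feature the degrees are tiled k times so they line up with the flat (batch, k*n_features) parameter layout used above. A small illustration of the tiling:

import torch

degrees_in = torch.tensor([0, 1, 2])

# One parameter per feature (the Moebius transformers): degrees pass through.
print(degrees_in.tile((1,)))   # tensor([0, 1, 2])

# Two parameters per feature (e.g. the affine transformer's shift + log-scale):
# the degrees are repeated, matching the [all shifts, all log-scales] ordering
# produced by reshape(batch, n_parameters_per_feature, -1).
print(degrees_in.tile((2,)))   # tensor([0, 1, 2, 0, 1, 2])
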
8 changes: 4 additions & 4 deletions tfep/nn/transformers/sos.py

@@ -77,7 +77,7 @@ def parameters_per_polynomial(self):
         return self.degree_polynomials + 1

     @property
-    def n_parameters_per_input(self):
+    def n_parameters_per_feature(self):
         """Number of parameters needed by the transformer for each input dimension."""
         return self.parameters_per_polynomial * self.n_polynomials + 1

@@ -105,7 +105,7 @@ def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor]:
         """
         # From (batch, n_parameters*n_features) to (batch, n_parameters, n_features).
         batch_size = parameters.shape[0]
-        parameters = parameters.reshape(batch_size, self.n_parameters_per_input, -1)
+        parameters = parameters.reshape(batch_size, self.n_parameters_per_feature, -1)
         return sos_polynomial_transformer(x, parameters)

     def inverse(self, y: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor]:
@@ -131,7 +131,7 @@ def get_identity_parameters(self, n_features: int) -> torch.Tensor:
            and degree of the polynomials.
         """
-        id_conditioner = torch.zeros(size=(self.n_parameters_per_input, n_features))
+        id_conditioner = torch.zeros(size=(self.n_parameters_per_feature, n_features))
         # The sum of the squared linear parameters must be 1.
         id_conditioner[1::self.parameters_per_polynomial].fill_(np.sqrt(1 / self.n_polynomials))
         return id_conditioner.flatten()
@@ -153,7 +153,7 @@ def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
            transformer as parameters.
         """
-        return degrees_in.tile((self.n_parameters_per_input,))
+        return degrees_in.tile((self.n_parameters_per_feature,))


 # =============================================================================
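
The unchanged lines around get_identity_parameters show the identity initialization: one entry per polynomial is filled with sqrt(1/n_polynomials), so the squares of the filled entries sum to exactly 1. A quick check of that arithmetic with illustrative sizes (not tfep defaults):

import numpy as np
import torch

# Illustrative sizes: 2 polynomials of degree 3.
n_polynomials = 2
degree_polynomials = 3
parameters_per_polynomial = degree_polynomials + 1                         # 4
n_parameters_per_feature = parameters_per_polynomial * n_polynomials + 1  # 9

n_features = 5
id_conditioner = torch.zeros(size=(n_parameters_per_feature, n_features))

# As in get_identity_parameters: fill every parameters_per_polynomial-th row
# starting at row 1 (one entry per polynomial).
id_conditioner[1::parameters_per_polynomial].fill_(np.sqrt(1 / n_polynomials))

# The squares of the filled entries sum to exactly 1 for each feature.
filled = id_conditioner[1::parameters_per_polynomial, 0]
assert torch.isclose((filled ** 2).sum(), torch.tensor(1.0))
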
