diff --git a/doc/requirements.txt b/doc/requirements.txt index 7bc001d..245e831 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,2 +1,3 @@ sphinx -sphinx_book_theme \ No newline at end of file +sphinx_book_theme +numpydoc \ No newline at end of file diff --git a/doc/source/conf.py b/doc/source/conf.py index 66c040e..d82c505 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -21,7 +21,13 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = ["sphinx.ext.todo", "sphinx.ext.viewcode", "sphinx.ext.autodoc"] +extensions = [ + "sphinx.ext.todo", + "sphinx.ext.viewcode", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "numpydoc", +] templates_path = ["_templates"] exclude_patterns = [] diff --git a/nlgm/autoencoder.py b/nlgm/autoencoder.py index cfd3ee0..9f47eaa 100644 --- a/nlgm/autoencoder.py +++ b/nlgm/autoencoder.py @@ -8,14 +8,26 @@ class Encoder(nn.Module): - def __init__(self, hidden_dim=20, latent_dim=2): - """ - Encoder class for the geometric autoencoder. + """ + Encoder class for the geometric autoencoder. - Args: - hidden_dim (int): Number of hidden dimensions. - latent_dim (int): Number of latent dimensions. - """ + Parameters + ---------- + hidden_dim : int + Number of hidden dimensions. + latent_dim : int + Number of latent dimensions. + + Methods + ------- + forward + + Attributes + ---------- + encoder + """ + + def __init__(self, hidden_dim=20, latent_dim=2): super(Encoder, self).__init__() self.encoder = nn.Sequential( @@ -45,25 +57,41 @@ def forward(self, x): """ Forward pass of the encoder. - Args: - x (torch.Tensor): Input tensor. + Parameters + ---------- + x : torch.Tensor + Input tensor. - Returns: - torch.Tensor: Encoded output tensor. + Returns + ------- + tensor : torch.Tensor + Encoded output tensor. """ z = self.encoder(x) return z class Decoder(nn.Module): - def __init__(self, hidden_dim=20, latent_dim=2): - """ - Decoder class for the geometric autoencoder. + """ + Decoder class for the geometric autoencoder. - Args: - hidden_dim (int): Number of hidden dimensions. - latent_dim (int): Number of latent dimensions. - """ + Parameters + ---------- + hidden_dim : int + Number of hidden dimensions. + latent_dim : int + Number of latent dimensions. + + Methods + ------- + forward + + Attributes + ---------- + decoder + """ + + def __init__(self, hidden_dim=20, latent_dim=2): super(Decoder, self).__init__() self.decoder = nn.Sequential( @@ -88,26 +116,45 @@ def forward(self, z): """ Forward pass of the decoder. - Args: - z (torch.Tensor): Encoded input tensor. + Parameters + ---------- + z : torch.Tensor + Encoded input tensor. - Returns: - torch.Tensor: Decoded output tensor. + Returns + ------- + tensor : torch.Tensor + Decoded output tensor. """ x_recon = self.decoder(z) return x_recon class GeometricAutoencoder(nn.Module): - def __init__(self, signature, hidden_dim=20, latent_dim=2): - """ - Geometric Autoencoder class. + """ + Geometric Autoencoder class. + + Parameters + ---------- + signature : list + List of signature dimensions. + hidden_dim : int + Number of hidden dimensions. + latent_dim : int + Number of latent dimensions. + + Methods + ------- + forward + + Attributes + ---------- + geometry + encoder + decoder + """ - Args: - signature (list): List of signature dimensions. - hidden_dim (int): Number of hidden dimensions. - latent_dim (int): Number of latent dimensions. - """ + def __init__(self, signature, hidden_dim=20, latent_dim=2): super(GeometricAutoencoder, self).__init__() self.geometry = ProductManifold(signature) self.encoder = Encoder(hidden_dim, latent_dim) @@ -117,11 +164,15 @@ def forward(self, x): """ Forward pass of the geometric autoencoder. - Args: - x (torch.Tensor): Input tensor. + Parameters + ---------- + x : torch.Tensor + Input tensor. - Returns: - torch.Tensor: Decoded output tensor. + Returns + ------- + tensor : torch.Tensor + Decoded output tensor. """ z = self.encoder(x) z = self.geometry.exponential_map(z)