From 53881170f9595f7917d60aad013912362eb56876 Mon Sep 17 00:00:00 2001 From: Mihaela Duta Date: Wed, 6 Nov 2024 15:35:05 +0000 Subject: [PATCH] Add type to VGAE.embed. --- l2gv2/embedding/embeddings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/l2gv2/embedding/embeddings.py b/l2gv2/embedding/embeddings.py index c155844..b400d77 100644 --- a/l2gv2/embedding/embeddings.py +++ b/l2gv2/embedding/embeddings.py @@ -1,7 +1,7 @@ """ Module for embedding patches using the VGAE model """ import torch import torch_geometric as tg -from l2gv2.models import speye, VGAEconv +from l2gv2.models import speye, VGAEconv, Patch device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -56,7 +56,7 @@ def embed( dim: int=100, hidden_dim: int=32, decoder=None - ): + ) -> tuple[list[Patch], list[torch.nn.Module]]: """ TODO: docstring for `embed` Args: