From 71b26e0d196152729501b8119f7e2d0957452b07 Mon Sep 17 00:00:00 2001 From: Kevin Qiu Date: Mon, 8 Jul 2024 18:36:29 -0700 Subject: [PATCH] Minor fix to enable CPU/GPU usage for the example. --- examples/example.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/example.py b/examples/example.py index 6d78317..d1d44c1 100644 --- a/examples/example.py +++ b/examples/example.py @@ -6,6 +6,8 @@ from nlgm.train import train_and_evaluate from nlgm.searchspace import construct_graph_search_space +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Define the data transforms transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] @@ -55,7 +57,7 @@ def objective_function(signature): latent_dim = len(signature) * 2 model = GeometricAutoencoder(signature, latent_dim=latent_dim) train_losses, test_loss = train_and_evaluate( - model, train_loader, test_loader, epochs=epochs, device=torch.device("cuda") + model, train_loader, test_loader, epochs=epochs, device=device ) return train_losses, test_loss