From 9876dfac11d3ac1402ab2c25a80ae31cb6b87a5d Mon Sep 17 00:00:00 2001 From: serfg Date: Wed, 15 May 2024 16:54:32 +0200 Subject: [PATCH] fix --- src/single_struct_calculator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/single_struct_calculator.py b/src/single_struct_calculator.py index 9e846b4..d28c2ea 100644 --- a/src/single_struct_calculator.py +++ b/src/single_struct_calculator.py @@ -43,6 +43,7 @@ def __init__(self, path_to_calc_folder, checkpoint="best_val_rmse_both_model", d self.model = model self.hypers = hypers self.all_species = all_species + self.device = device def forward(self, structure): @@ -52,6 +53,7 @@ def forward(self, structure): graph = molecule.get_graph(molecule.get_max_num(), self.all_species, molecule.get_num_k()) graph.batch = torch.zeros(graph.num_nodes, dtype = torch.long, device = graph.x.device) + graph = graph.to(self.device) prediction_energy, prediction_forces = self.model(graph, augmentation = False, create_graph = False) compositional_features = get_compositional_features([structure], self.all_species)[0]