From 98f955ceedf483f7db26529b06a1674cb024d6d6 Mon Sep 17 00:00:00 2001 From: bbimber Date: Thu, 30 Jan 2025 19:53:34 -0800 Subject: [PATCH] Attempt to fix pytorch/anndata issue --- inst/scripts/PredictScTourPseudotime.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/inst/scripts/PredictScTourPseudotime.py b/inst/scripts/PredictScTourPseudotime.py index 9399057f..a7803f5d 100644 --- a/inst/scripts/PredictScTourPseudotime.py +++ b/inst/scripts/PredictScTourPseudotime.py @@ -3,11 +3,10 @@ import numpy as np import pandas as pd import torch +import anndata from anndata import AnnData def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file): - torch.serialization.add_safe_globals([AnnData, "anndata._core.file_backing.AnnDataFileManager"]) - #read count data and variable genes adataObj = sc.read_10x_h5(GEXfile) adataObj.X = round(adataObj.X).astype(np.float32) @@ -17,7 +16,7 @@ def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file): print('AnnData object is a csr matrix, converting to dense because scipy depreciated the .A shorthand') adataObj.X = adataObj.X.toarray() - with torch.serialization.add_safe_globals([AnnData]): + with torch.serialization.add_safe_globals([AnnData, anndata._core.file_backing.AnnDataFileManager]): checkpoint = torch.load(model_file, map_location=torch.device('cpu')) model_adata = checkpoint['adata'] @@ -31,7 +30,7 @@ def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file): #subset to genes found in the pretrained model. adataObj = adataObj[:, genes_in_model] #initalize a trainer and pull a previously saved model from model_file - with torch.serialization.safe_globals([AnnData, "anndata._core.file_backing.AnnDataFileManager"]): + with torch.serialization.safe_globals([AnnData, anndata._core.file_backing.AnnDataFileManager]): tnode = sct.predict.load_model(model_file) pred_t = sct.predict.predict_time(adata = adataObj, model = tnode) adataObj.obs['ptime'] = pred_t