diff --git a/inst/scripts/PredictScTourPseudotime.py b/inst/scripts/PredictScTourPseudotime.py index a7803f5d..147089b5 100644 --- a/inst/scripts/PredictScTourPseudotime.py +++ b/inst/scripts/PredictScTourPseudotime.py @@ -6,6 +6,8 @@ import anndata from anndata import AnnData +torch.serialization.add_safe_globals([AnnData, anndata._core.file_backing.AnnDataFileManager]): + def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file): #read count data and variable genes adataObj = sc.read_10x_h5(GEXfile) @@ -16,8 +18,7 @@ def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file): print('AnnData object is a csr matrix, converting to dense because scipy depreciated the .A shorthand') adataObj.X = adataObj.X.toarray() - with torch.serialization.add_safe_globals([AnnData, anndata._core.file_backing.AnnDataFileManager]): - checkpoint = torch.load(model_file, map_location=torch.device('cpu')) + checkpoint = torch.load(model_file, map_location=torch.device('cpu'), weights_only = True) model_adata = checkpoint['adata'] genes_in_model = model_adata.var.index.values.tolist() @@ -30,8 +31,8 @@ def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file): #subset to genes found in the pretrained model. adataObj = adataObj[:, genes_in_model] #initalize a trainer and pull a previously saved model from model_file - with torch.serialization.safe_globals([AnnData, anndata._core.file_backing.AnnDataFileManager]): - tnode = sct.predict.load_model(model_file) + + tnode = sct.predict.load_model(model_file) pred_t = sct.predict.predict_time(adata = adataObj, model = tnode) adataObj.obs['ptime'] = pred_t mix_zs, zs, pred_zs = sct.predict.predict_latentsp(adata = adataObj, model = tnode)