Skip to content

Commit

Permalink
Torch issue
Browse files Browse the repository at this point in the history
  • Loading branch information
bbimber committed Jan 31, 2025
1 parent 98f955c commit 5697b3c
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions inst/scripts/PredictScTourPseudotime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import anndata
from anndata import AnnData

torch.serialization.add_safe_globals([AnnData, anndata._core.file_backing.AnnDataFileManager]):

def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file):
#read count data and variable genes
adataObj = sc.read_10x_h5(GEXfile)
Expand All @@ -16,8 +18,7 @@ def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file):
print('AnnData object is a csr matrix, converting to dense because scipy depreciated the .A shorthand')
adataObj.X = adataObj.X.toarray()

with torch.serialization.add_safe_globals([AnnData, anndata._core.file_backing.AnnDataFileManager]):
checkpoint = torch.load(model_file, map_location=torch.device('cpu'))
checkpoint = torch.load(model_file, map_location=torch.device('cpu'), weights_only = True)
model_adata = checkpoint['adata']

genes_in_model = model_adata.var.index.values.tolist()
Expand All @@ -30,8 +31,8 @@ def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file):
#subset to genes found in the pretrained model.
adataObj = adataObj[:, genes_in_model]
#initalize a trainer and pull a previously saved model from model_file
with torch.serialization.safe_globals([AnnData, anndata._core.file_backing.AnnDataFileManager]):
tnode = sct.predict.load_model(model_file)

tnode = sct.predict.load_model(model_file)
pred_t = sct.predict.predict_time(adata = adataObj, model = tnode)
adataObj.obs['ptime'] = pred_t
mix_zs, zs, pred_zs = sct.predict.predict_latentsp(adata = adataObj, model = tnode)
Expand Down

0 comments on commit 5697b3c

Please sign in to comment.