Skip to content

Commit

Permalink
Attempt to fix pytorch/anndata issue
Browse files Browse the repository at this point in the history
  • Loading branch information
bbimber committed Jan 31, 2025
1 parent 4c4ff8b commit 98f955c
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions inst/scripts/PredictScTourPseudotime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import numpy as np
import pandas as pd
import torch
import anndata
from anndata import AnnData

def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file):
torch.serialization.add_safe_globals([AnnData, "anndata._core.file_backing.AnnDataFileManager"])

#read count data and variable genes
adataObj = sc.read_10x_h5(GEXfile)
adataObj.X = round(adataObj.X).astype(np.float32)
Expand All @@ -17,7 +16,7 @@ def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file):
print('AnnData object is a csr matrix, converting to dense because scipy depreciated the .A shorthand')
adataObj.X = adataObj.X.toarray()

with torch.serialization.add_safe_globals([AnnData]):
with torch.serialization.add_safe_globals([AnnData, anndata._core.file_backing.AnnDataFileManager]):
checkpoint = torch.load(model_file, map_location=torch.device('cpu'))
model_adata = checkpoint['adata']

Expand All @@ -31,7 +30,7 @@ def PredictPseudotime(GEXfile, model_file, ptime_out_file, embedding_out_file):
#subset to genes found in the pretrained model.
adataObj = adataObj[:, genes_in_model]
#initalize a trainer and pull a previously saved model from model_file
with torch.serialization.safe_globals([AnnData, "anndata._core.file_backing.AnnDataFileManager"]):
with torch.serialization.safe_globals([AnnData, anndata._core.file_backing.AnnDataFileManager]):
tnode = sct.predict.load_model(model_file)
pred_t = sct.predict.predict_time(adata = adataObj, model = tnode)
adataObj.obs['ptime'] = pred_t
Expand Down

0 comments on commit 98f955c

Please sign in to comment.