Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement batch key argument in sc.pp.highly_variable_genes #28

Merged
merged 6 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions panpipes/panpipes/pipeline_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def rna_preprocess(adata_obj, log_file):
cmd += " --exclude %(hvg_exclude)s"
if PARAMS['hvg_flavor'] is not None:
cmd += " --flavor %(hvg_flavor)s"
if PARAMS['hvg_batch_key'] is not None:
cmd += " --hvg_batch_key %(hvg_batch_key)s"
if PARAMS['hvg_n_top_genes'] is not None:
cmd += " --n_top_genes %(hvg_n_top_genes)s"
if PARAMS['hvg_min_mean'] is not None:
Expand Down
5 changes: 5 additions & 0 deletions panpipes/panpipes/pipeline_preprocess/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ hvg:
exclude_file: resources/qc_genelist_1.0.csv
exclude: exclude # this is the variable that defines the genes to be excluded in the above file
flavor: seurat_v3 # "seurat", "cell_ranger", "seurat_v3"
# If batch key is specified, highly-variable genes are selected within each batch separately and merged.
# details: https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.highly_variable_genes.html#:~:text=or%20return%20them.-,batch_key,-%3A%20Optional%5B
# If you want to use more than one obs column as covariates, list them as covariate1,covariate2 (comma-separated list)
# Leave blank for no batch (default)
batch_key:
n_top_genes: 2000
min_mean:
max_mean:
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/batch_correct_bbknn.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

nnb = int(args.neighbors_within_batch)
# bbknn can't integrate on 2+ variables, so create a fake column with combined information
columns = [x.replace(" ", "") for x in args.integration_col.split(",")]
columns = [x.strip() for x in args.integration_col.split(",")]

if len(columns) > 1:
L.info("using 2 columns to integrate on more variables")
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/batch_correct_combat.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
adata = mdata.mod[args.modality]

# combat can't integrate on 2+ variables, so create a fake column with combined information
columns = [x.replace(" ", "") for x in args.integration_col.split(",")]
columns = [x.strip() for x in args.integration_col.split(",")]
if len(columns) > 1:
L.info("using 2 columns to integrate on more variables")
#comb_columns = "_".join(columns)
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/batch_correct_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

# Harmony can integrate on 2+ variables,
# but for consistency with other approaches create a fake column with combined information
columns = [x.replace(" " ,"") for x in args.integration_col.split(",")]
columns = [x.strip() for x in args.integration_col.split(",")]

if len(columns)>1:
L.info("using 2 columns to integrate on more variables")
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/batch_correct_mofa.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@

if params['multimodal']['mofa']['modalities'] is not None:
modalities= params['multimodal']['mofa']['modalities']
modalities = [x.replace(" ", "") for x in modalities.split(",")]
modalities = [x.strip() for x in modalities.split(",")]
L.info(f"using modalities :{modalities}")
removed_mods = None
if all(x in modalities for x in mdata.mod.keys()):
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/batch_correct_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
# MultiVI integrates by modality, to use batch correction you need a batch covariate to specify in
# categorical_covariate_keys
if args.integration_col_categorical is not None :
cols = [x.replace(" ", "") for x in args.integration_col_categorical.split(",")]
cols = [x.strip() for x in args.integration_col_categorical.split(",")]
columns = []
for cc in cols:
if cc in rna_cols:
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/batch_correct_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@



columns = [x.replace (" " ,"") for x in args.integration_col.split(",")]
columns = [x.strip() for x in args.integration_col.split(",")]
if len(columns)>1:
comb_columns = "|".join(columns)
adata.obs[comb_columns] = adata.obs[columns].apply(lambda x: '|'.join(x), axis=1)
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/batch_correct_scanorama.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
bcs = adata.obs_names.tolist()

# scanorama can't integrate on 2+ variables, so create a fake column with combined information
columns = [x.replace(" ", "") for x in args.integration_col.split(",")]
columns = [x.strip() for x in args.integration_col.split(",")]
if len(columns) > 1:
L.info("using 2 columns to integrate on more variables")
# comb_columns = "_".join(columns)
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/batch_correct_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@


# in case of more than 1 variable, create a fake column with combined information
columns = [x.replace(" ", "") for x in args.integration_col.split(",")]
columns = [x.strip() for x in args.integration_col.split(",")]
if len(columns) > 1:
L.info("using 2 columns to integrate on more variables")
# bc_batch = "_".join(columns)
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/batch_correct_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
kwargs={}
# in case of more than 1 variable, create a fake column with combined information
if args.integration_col_categorical is not None :
columns = [x.replace(" ", "") for x in args.integration_col_categorical.split(",")]
columns = [x.strip() for x in args.integration_col_categorical.split(",")]
if len(columns) > 1:
L.info("using 2 columns to integrate on more variables")
# bc_batch = "_".join(columns)
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/batch_correct_wnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

if params['multimodal']['WNN']['modalities'] is not None:
modalities= params['multimodal']['WNN']['modalities']
modalities = [x.replace(" ", "") for x in modalities.split(",")]
modalities = [x.strip() for x in modalities.split(",")]
L.info(f"using modalities :{modalities}")

L.info("running with batch corrections:")
Expand Down
38 changes: 25 additions & 13 deletions panpipes/python_scripts/run_preprocess_rna.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
parser.add_argument('--max_mean', default=3)
parser.add_argument('--min_disp', default=0.5)
parser.add_argument("--filter_by_hvg", default=False, type=check_for_bool)
parser.add_argument('--hvg_batch_key', default=None)
# regress out options
parser.add_argument('--regress_out', default=None)
# scale options
Expand Down Expand Up @@ -92,9 +93,21 @@
mdata = mu.MuData({'rna': mdata})

adata = mdata['rna']
# Resolve the batch key used for highly-variable-gene selection.
# --hvg_batch_key may name one obs column or a comma-separated list of obs
# columns. When several columns are given they are merged into a single
# synthetic categorical obs column named 'hvg_batch_key', whose values are
# the per-cell column values joined with '|'.
# Result: `hvg_batch_key` is None (no batch), the single column name, or
# the literal "hvg_batch_key" for the merged column.
if args.hvg_batch_key is not None:
    # Parse the HVG-specific argument here, not the integration column of
    # the batch-correction scripts.
    columns = [x.strip() for x in args.hvg_batch_key.split(",")]
    if len(columns) > 1:
        L.info("combining batch comlumns into one column 'hvg_batch_key'")
        # cast to str first so that non-string obs columns can be joined
        adata.obs["hvg_batch_key"] = adata.obs[columns].astype(str).apply(lambda x: '|'.join(x), axis=1)
        # make sure that batch is a categorical
        adata.obs["hvg_batch_key"] = adata.obs["hvg_batch_key"].astype("category")
        hvg_batch_key = "hvg_batch_key"
    else:
        hvg_batch_key = columns[0]
else:
    hvg_batch_key = None


# save raw counts
# save raw counts as a layer
if X_is_raw(adata):
adata.layers['raw_counts'] = adata.X.copy()
elif "raw_counts" in adata.layers :
Expand All @@ -105,27 +118,28 @@
sys.exit("X is not raw data and raw_counts layer not found")



# Normalise to depth 10k, store raw data, assess and drop highly variable genes, regress mitochondria and count

# sc.pp.highly variabel genes Expects logarithmized data, except when flavor='seurat_v3' in which count data is expected.
# sc.pp.highly_variable_genes expects logarithmized data,
# except when flavor='seurat_v3' in which count data is expected.
# change the order accordingly
L.info("normalise, log and calucalte highly variable genes")
if args.flavor == "seurat_v3":
if args.n_top_genes is None:
raise ValueError("if seurat_v3 is used you must give a n_top_genes value")
# sc.pp.highly_variable_genes(adata, flavor="seurat_v3",)
else:
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=int(args.n_top_genes))
sc.pp.highly_variable_genes(adata, flavor="seurat_v3",
n_top_genes=int(args.n_top_genes),
batch_key=hvg_batch_key)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
else:
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
L.debug(adata.uns['log1p'])
sc.pp.highly_variable_genes(adata, flavor=args.flavor,
min_mean=float(args.min_mean), max_mean=float(args.max_mean),
min_disp=float(args.min_disp))
min_mean=float(args.min_mean),
max_mean=float(args.max_mean),
min_disp=float(args.min_disp),
batch_key=hvg_batch_key)
L.debug(adata.uns['log1p'])

sc.pl.highly_variable_genes(adata,show=False, save ="_genes_highlyvar.png")
Expand All @@ -145,8 +159,6 @@
cat_dic[cc] = customgenes.loc[customgenes["group"] == cc,"feature"].tolist()
exclude_action = str(args.exclude)
excl = cat_dic[exclude_action]


L.info(len(excl))
L.info("number of hvgs prior to filtering")
L.info(adata.var.highly_variable.sum())
Expand All @@ -161,7 +173,7 @@
L.info(adata.var.highly_variable.sum())
sc.pl.highly_variable_genes(adata,show=False, save ="_exclude_genes_highlyvar.png")
else:
sys.exit("exclusion file %s not found, check the path andn try again" % args.exclude_file)
sys.exit("exclusion file %s not found, check the path and try again" % args.exclude_file)

if isinstance(mdata, mu.MuData):
mdata.update()
Expand Down