From a857464aadaecc635e4b769c2a45315645660fd6 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Sun, 13 Oct 2024 20:17:44 +0200 Subject: [PATCH 1/2] run-precommit From 0903fe88056da45f31f7d38ff84999287e782778 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 13 Oct 2024 18:18:42 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../\342\236\225-performance-improvement.md" | 5 +- .../\360\237\220\233-bug-report.md" | 6 +- .../\360\237\232\200-feature-request.md" | 7 +- .github/workflows/python-publish.yml | 33 +- .github/workflows/static.yml | 2 +- .pre-commit-config.yaml | 32 +- .scripts/create_dataset.py | 44 +- .scripts/predict.py | 24 +- .scripts/train_model.py | 36 +- docs/notebooks/benchmark_bc.py | 114 ++- docs/source/conf.py | 15 +- docs/user_guide/data_creation.md | 41 +- docs/user_guide/training.md | 12 +- scripts/create_data_sample.py | 23 +- scripts/predict_model_sample.py | 33 +- scripts/train_model.py | 62 +- scripts/train_model_sample.py | 26 +- src/segger/__init__.py | 2 +- src/segger/cli/cli.py | 4 +- src/segger/cli/configs/train/default.yaml | 6 +- src/segger/cli/create_dataset.py | 75 +- src/segger/cli/create_dataset_fast.py | 52 +- src/segger/cli/predict.py | 86 ++- src/segger/cli/train_model.py | 56 +- src/segger/cli/utils.py | 34 +- src/segger/data/README.md | 89 +-- src/segger/data/__init__.py | 33 +- src/segger/data/constants.py | 7 +- src/segger/data/io.py | 549 +++++++------- src/segger/data/parquet/_experimental.py | 36 +- src/segger/data/parquet/_ndtree.py | 37 +- src/segger/data/parquet/_settings/xenium.yaml | 10 +- src/segger/data/parquet/_utils.py | 87 ++- src/segger/data/parquet/pyg_dataset.py | 24 +- src/segger/data/parquet/sample.py | 250 +++---- .../data/parquet/transcript_embedding.py | 24 +- src/segger/data/utils.py | 259 ++++--- src/segger/models/README.md | 20 +- src/segger/models/__init__.py | 4 +- src/segger/models/segger_model.py | 29 +- src/segger/prediction/__init__.py | 5 +- src/segger/prediction/predict.py | 228 +++--- src/segger/training/README.md | 6 + src/segger/training/segger_data_module.py | 6 +- src/segger/training/train.py | 40 +- src/segger/validation/__init__.py | 2 +- src/segger/validation/utils.py | 678 +++++++++--------- src/segger/validation/xenium_explorer.py | 282 +++++--- tests/test_data.py | 63 +- tests/test_model.py | 25 +- tests/test_prediction.py | 14 +- tests/test_training.py | 25 +- 52 files changed, 1855 insertions(+), 1807 deletions(-) diff --git "a/.github/ISSUE_TEMPLATE/\342\236\225-performance-improvement.md" "b/.github/ISSUE_TEMPLATE/\342\236\225-performance-improvement.md" index b281b2f..9b5cb1f 100644 --- "a/.github/ISSUE_TEMPLATE/\342\236\225-performance-improvement.md" +++ "b/.github/ISSUE_TEMPLATE/\342\236\225-performance-improvement.md" @@ -1,10 +1,9 @@ --- name: "➕ Performance Improvement" about: Suggest an improvement in the performance -title: '' -labels: '' +title: "" +labels: "" assignees: andrewmoorman, EliHei2 - --- **Describe the issue with the current implementation** diff --git "a/.github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" "b/.github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" index b899e5a..5809219 100644 --- "a/.github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" +++ "b/.github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" @@ -2,12 +2,12 @@ name: "\U0001F41B Bug Report" about: Create a report to help us improve title: "[BUG]" 
-labels: '' +labels: "" assignees: andrewmoorman, EliHei2 - --- --- + name: Bug Report about: Report a bug or unexpected behavior title: "[BUG] " @@ -21,6 +21,7 @@ A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: + 1. Go to '...' 2. Click on '....' 3. Scroll down to '....' @@ -33,6 +34,7 @@ A clear and concise description of what you expected to happen. If applicable, add screenshots or logs to help explain your problem. **Environment (please complete the following information):** + - OS: [e.g. macOS, Windows, Linux] - Python version: [e.g. 3.9] - Package version: [e.g. 1.2.3] diff --git "a/.github/ISSUE_TEMPLATE/\360\237\232\200-feature-request.md" "b/.github/ISSUE_TEMPLATE/\360\237\232\200-feature-request.md" index 08679f6..67644f2 100644 --- "a/.github/ISSUE_TEMPLATE/\360\237\232\200-feature-request.md" +++ "b/.github/ISSUE_TEMPLATE/\360\237\232\200-feature-request.md" @@ -1,10 +1,9 @@ --- name: "\U0001F680 Feature Request" about: Suggest an idea for this project -title: '' -labels: '' -assignees: '' - +title: "" +labels: "" +assignees: "" --- **Is your feature request related to a problem? Please describe.** diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index b7a704b..c16ebea 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -17,23 +17,22 @@ permissions: jobs: deploy: - runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v3 - with: - python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install build - - name: Build package - run: python -m build - - name: Publish package - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 - with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.x" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build + - name: Build package + run: python -m build + - name: Publish package + uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml index b6a7e3a..146ad51 100644 --- a/.github/workflows/static.yml +++ b/.github/workflows/static.yml @@ -25,7 +25,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: "3.10" - name: Install package and documentation dependencies run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 781c996..a1e1760 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,22 +1,22 @@ fail_fast: false default_language_version: - python: python3 + python: python3 default_stages: - - commit - - push + - commit + - push minimum_pre_commit_version: 2.16.0 ci: - skip: [] + skip: [] repos: - - repo: https://github.com/psf/black - rev: 24.8.0 - hooks: - - id: black - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 - hooks: - - id: prettier - - repo: https://github.com/asottile/blacken-docs - rev: 1.18.0 - hooks: - - id: blacken-docs + - repo: https://github.com/psf/black + rev: 24.8.0 + hooks: + - id: black + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v4.0.0-alpha.8 + hooks: + - id: prettier + - repo: 
https://github.com/asottile/blacken-docs + rev: 1.18.0 + hooks: + - id: blacken-docs diff --git a/.scripts/create_dataset.py b/.scripts/create_dataset.py index 27de8af..91ca6db 100644 --- a/.scripts/create_dataset.py +++ b/.scripts/create_dataset.py @@ -30,9 +30,7 @@ def main(args): download_file(transcripts_url, transcripts_path) download_file(nuclei_url, nuclei_path) - xs = XeniumSample().load_transcripts( - path=transcripts_path, min_qv=args.min_qv - ) + xs = XeniumSample().load_transcripts(path=transcripts_path, min_qv=args.min_qv) xs.load_nuclei(path=nuclei_path) if args.parallel: @@ -83,9 +81,7 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Create dataset from Xenium Human Pancreatic data." - ) + parser = argparse.ArgumentParser(description="Create dataset from Xenium Human Pancreatic data.") parser.add_argument( "--raw_data_dir", type=str, @@ -104,9 +100,7 @@ def main(args): required=True, help="URL for transcripts data.", ) - parser.add_argument( - "--nuclei_url", type=str, required=True, help="URL for nuclei data." - ) + parser.add_argument("--nuclei_url", type=str, required=True, help="URL for nuclei data.") parser.add_argument( "--min_qv", type=int, @@ -125,21 +119,11 @@ def main(args): default=180, help="Step size in y direction for tiles.", ) - parser.add_argument( - "--x_size", type=int, default=200, help="Width of each tile." - ) - parser.add_argument( - "--y_size", type=int, default=200, help="Height of each tile." - ) - parser.add_argument( - "--margin_x", type=int, default=None, help="Margin in x direction." - ) - parser.add_argument( - "--margin_y", type=int, default=None, help="Margin in y direction." - ) - parser.add_argument( - "--r_tx", type=int, default=3, help="Radius for building the graph." - ) + parser.add_argument("--x_size", type=int, default=200, help="Width of each tile.") + parser.add_argument("--y_size", type=int, default=200, help="Height of each tile.") + parser.add_argument("--margin_x", type=int, default=None, help="Margin in x direction.") + parser.add_argument("--margin_y", type=int, default=None, help="Margin in y direction.") + parser.add_argument("--r_tx", type=int, default=3, help="Radius for building the graph.") parser.add_argument( "--val_prob", type=float, @@ -158,9 +142,7 @@ def main(args): default=3, help="Number of nearest neighbors for nuclei.", ) - parser.add_argument( - "--dist_nc", type=int, default=10, help="Distance threshold for nuclei." - ) + parser.add_argument("--dist_nc", type=int, default=10, help="Distance threshold for nuclei.") parser.add_argument( "--k_tx", type=int, @@ -179,12 +161,8 @@ def main(args): default=True, help="Whether to compute edge labels.", ) - parser.add_argument( - "--sampling_rate", type=float, default=1, help="Rate of sampling tiles." - ) - parser.add_argument( - "--parallel", action="store_true", help="Use parallel processing." 
- ) + parser.add_argument("--sampling_rate", type=float, default=1, help="Rate of sampling tiles.") + parser.add_argument("--parallel", action="store_true", help="Use parallel processing.") parser.add_argument( "--num_workers", type=int, diff --git a/.scripts/predict.py b/.scripts/predict.py index f812822..9a095f4 100644 --- a/.scripts/predict.py +++ b/.scripts/predict.py @@ -30,9 +30,7 @@ def main(args: argparse.Namespace) -> None: if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Predict using the Segger model" - ) + parser = argparse.ArgumentParser(description="Predict using the Segger model") parser.add_argument( "--dataset_path", type=str, @@ -51,24 +49,16 @@ def main(args: argparse.Namespace) -> None: required=True, help="Path to the model checkpoint", ) - parser.add_argument( - "--init_emb", type=int, default=8, help="Initial embedding size" - ) + parser.add_argument("--init_emb", type=int, default=8, help="Initial embedding size") parser.add_argument( "--hidden_channels", type=int, default=64, help="Number of hidden channels", ) - parser.add_argument( - "--out_channels", type=int, default=16, help="Number of output channels" - ) - parser.add_argument( - "--heads", type=int, default=4, help="Number of attention heads" - ) - parser.add_argument( - "--aggr", type=str, default="sum", help="Aggregation method" - ) + parser.add_argument("--out_channels", type=int, default=16, help="Number of output channels") + parser.add_argument("--heads", type=int, default=4, help="Number of attention heads") + parser.add_argument("--aggr", type=str, default="sum", help="Aggregation method") parser.add_argument( "--score_cut", type=float, @@ -81,9 +71,7 @@ def main(args: argparse.Namespace) -> None: default=4, help="Number of nearest neighbors for nuclei", ) - parser.add_argument( - "--dist_nc", type=int, default=20, help="Distance threshold for nuclei" - ) + parser.add_argument("--dist_nc", type=int, default=20, help="Distance threshold for nuclei") parser.add_argument( "--k_tx", type=int, diff --git a/.scripts/train_model.py b/.scripts/train_model.py index 8a6ee85..2515a71 100644 --- a/.scripts/train_model.py +++ b/.scripts/train_model.py @@ -95,39 +95,21 @@ def main(args): default=4, help="Batch size for validation", ) - parser.add_argument( - "--init_emb", type=int, default=8, help="Initial embedding size" - ) + parser.add_argument("--init_emb", type=int, default=8, help="Initial embedding size") parser.add_argument( "--hidden_channels", type=int, default=64, help="Number of hidden channels", ) - parser.add_argument( - "--out_channels", type=int, default=16, help="Number of output channels" - ) - parser.add_argument( - "--heads", type=int, default=4, help="Number of attention heads" - ) - parser.add_argument( - "--aggr", type=str, default="sum", help="Aggregation method" - ) - parser.add_argument( - "--accelerator", type=str, default="cuda", help="Type of accelerator" - ) - parser.add_argument( - "--strategy", type=str, default="auto", help="Training strategy" - ) - parser.add_argument( - "--precision", type=str, default="16-mixed", help="Precision mode" - ) - parser.add_argument( - "--devices", type=int, default=4, help="Number of devices" - ) - parser.add_argument( - "--epochs", type=int, default=100, help="Number of epochs" - ) + parser.add_argument("--out_channels", type=int, default=16, help="Number of output channels") + parser.add_argument("--heads", type=int, default=4, help="Number of attention heads") + parser.add_argument("--aggr", type=str, default="sum", 
help="Aggregation method") + parser.add_argument("--accelerator", type=str, default="cuda", help="Type of accelerator") + parser.add_argument("--strategy", type=str, default="auto", help="Training strategy") + parser.add_argument("--precision", type=str, default="16-mixed", help="Precision mode") + parser.add_argument("--devices", type=int, default=4, help="Number of devices") + parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") parser.add_argument( "--default_root_dir", type=str, diff --git a/docs/notebooks/benchmark_bc.py b/docs/notebooks/benchmark_bc.py index 31ac0bd..8b9a3fc 100644 --- a/docs/notebooks/benchmark_bc.py +++ b/docs/notebooks/benchmark_bc.py @@ -8,54 +8,54 @@ from segger.validation.utils import * # Define paths and output directories -benchmarks_path = Path('/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc') -output_path = benchmarks_path / 'results+' -figures_path = output_path / 'figures' +benchmarks_path = Path("/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc") +output_path = benchmarks_path / "results+" +figures_path = output_path / "figures" figures_path.mkdir(parents=True, exist_ok=True) # Ensure the figures directory exists # Define colors for segmentation methods method_colors = { - 'segger': '#D55E00', - 'segger_n0': '#E69F00', - 'segger_n1': '#F0E442', - 'Baysor': '#0072B2', - '10X': '#009E73', - '10X-nucleus': '#CC79A7', - 'BIDCell': '#8B008B' + "segger": "#D55E00", + "segger_n0": "#E69F00", + "segger_n1": "#F0E442", + "Baysor": "#0072B2", + "10X": "#009E73", + "10X-nucleus": "#CC79A7", + "BIDCell": "#8B008B", } # Define colors for cell types major_colors = { - 'B-cells': '#d8f55e', - 'CAFs': '#532C8A', - 'Cancer Epithelial': '#C72228', - 'Endothelial': '#9e6762', - 'Myeloid': '#ffe012', - 'T-cells': '#3cb44b', - 'Normal Epithelial': '#0F4A9C', - 'PVL': '#c09d9a', - 'Plasmablasts': '#000075' + "B-cells": "#d8f55e", + "CAFs": "#532C8A", + "Cancer Epithelial": "#C72228", + "Endothelial": "#9e6762", + "Myeloid": "#ffe012", + "T-cells": "#3cb44b", + "Normal Epithelial": "#0F4A9C", + "PVL": "#c09d9a", + "Plasmablasts": "#000075", } # Define segmentation file paths segmentation_paths = { - 'segger': benchmarks_path / 'adata_segger.h5ad', - 'Baysor': benchmarks_path / 'adata_baysor.h5ad', - '10X': benchmarks_path / 'adata_10X.h5ad', - '10X-nucleus': benchmarks_path / 'adata_10X_nuc.h5ad', - 'BIDCell': benchmarks_path / 'adata_BIDCell.h5ad' + "segger": benchmarks_path / "adata_segger.h5ad", + "Baysor": benchmarks_path / "adata_baysor.h5ad", + "10X": benchmarks_path / "adata_10X.h5ad", + "10X-nucleus": benchmarks_path / "adata_10X_nuc.h5ad", + "BIDCell": benchmarks_path / "adata_BIDCell.h5ad", } # Load the segmentations and the scRNAseq data segmentations_dict = load_segmentations(segmentation_paths) segmentations_dict = {k: segmentations_dict[k] for k in method_colors.keys() if k in segmentations_dict} -scRNAseq_adata = sc.read(benchmarks_path / 'scRNAseq.h5ad') +scRNAseq_adata = sc.read(benchmarks_path / "scRNAseq.h5ad") # Generate general statistics plots plot_general_statistics_plots(segmentations_dict, figures_path, method_colors) # Find markers for scRNAseq data -markers = find_markers(scRNAseq_adata, cell_type_column='celltype_major', pos_percentile=30, neg_percentile=5) +markers = find_markers(scRNAseq_adata, cell_type_column="celltype_major", pos_percentile=30, neg_percentile=5) # Annotate spatial segmentations with scRNAseq reference data for method in 
segmentation_paths.keys(): @@ -68,9 +68,7 @@ # Find mutually exclusive genes based on scRNAseq data exclusive_gene_pairs = find_mutually_exclusive_genes( - adata=scRNAseq_adata, - markers=markers, - cell_type_column='celltype_major' + adata=scRNAseq_adata, markers=markers, cell_type_column="celltype_major" ) # Compute MECR for each segmentation method @@ -83,14 +81,12 @@ quantized_mecr_counts = {} for method in segmentations_dict.keys(): - if 'cell_area' in segmentations_dict[method].obs.columns: + if "cell_area" in segmentations_dict[method].obs.columns: quantized_mecr_area[method] = compute_quantized_mecr_area( - adata=segmentations_dict[method], - gene_pairs=exclusive_gene_pairs + adata=segmentations_dict[method], gene_pairs=exclusive_gene_pairs ) quantized_mecr_counts[method] = compute_quantized_mecr_counts( - adata=segmentations_dict[method], - gene_pairs=exclusive_gene_pairs + adata=segmentations_dict[method], gene_pairs=exclusive_gene_pairs ) # Plot MECR results @@ -99,26 +95,30 @@ plot_quantized_mecr_counts(quantized_mecr_counts, output_path=figures_path, palette=method_colors) # Filter segmentation methods for contamination analysis -new_segmentations_dict = {k: v for k, v in segmentations_dict.items() if k in ['segger', 'Baysor', '10X', '10X-nucleus', 'BIDCell']} +new_segmentations_dict = { + k: v for k, v in segmentations_dict.items() if k in ["segger", "Baysor", "10X", "10X-nucleus", "BIDCell"] +} # Compute contamination results contamination_results = {} for method, adata in new_segmentations_dict.items(): - if 'cell_centroid_x' in adata.obs.columns and 'cell_centroid_y' in adata.obs.columns: + if "cell_centroid_x" in adata.obs.columns and "cell_centroid_y" in adata.obs.columns: contamination_results[method] = calculate_contamination( adata=adata, markers=markers, # Assuming you have a dictionary of markers for cell types radius=15, n_neighs=20, - celltype_column='celltype_major', - num_cells=10000 + celltype_column="celltype_major", + num_cells=10000, ) # Prepare contamination data for boxplots boxplot_data = [] for method, df in contamination_results.items(): - melted_df = df.reset_index().melt(id_vars=['Source Cell Type'], var_name='Target Cell Type', value_name='Contamination') - melted_df['Segmentation Method'] = method + melted_df = df.reset_index().melt( + id_vars=["Source Cell Type"], var_name="Target Cell Type", value_name="Contamination" + ) + melted_df["Segmentation Method"] = method boxplot_data.append(melted_df) # Concatenate all contamination dataframes into one @@ -129,13 +129,13 @@ plot_contamination_boxplots(boxplot_data, output_path=figures_path, palette=method_colors) # Separate Segger into nucleus-positive and nucleus-negative cells -segmentations_dict['segger_n1'] = segmentations_dict['segger'][segmentations_dict['segger'].obs.has_nucleus] -segmentations_dict['segger_n0'] = segmentations_dict['segger'][~segmentations_dict['segger'].obs.has_nucleus] +segmentations_dict["segger_n1"] = segmentations_dict["segger"][segmentations_dict["segger"].obs.has_nucleus] +segmentations_dict["segger_n0"] = segmentations_dict["segger"][~segmentations_dict["segger"].obs.has_nucleus] # Compute clustering scores for all segmentation methods clustering_scores = {} for method, adata in segmentations_dict.items(): - ch_score, sh_score = compute_clustering_scores(adata, cell_type_column='celltype_major') + ch_score, sh_score = compute_clustering_scores(adata, cell_type_column="celltype_major") clustering_scores[method] = (ch_score, sh_score) # Plot UMAPs with clustering scores in 
the title @@ -143,20 +143,22 @@ # Compute neighborhood metrics for methods with spatial data for method, adata in segmentations_dict.items(): - if 'spatial' in list(adata.obsm.keys()): - compute_neighborhood_metrics(adata, radius=15, celltype_column='celltype_major') + if "spatial" in list(adata.obsm.keys()): + compute_neighborhood_metrics(adata, radius=15, celltype_column="celltype_major") # Prepare neighborhood entropy data for boxplots entropy_boxplot_data = [] for method, adata in segmentations_dict.items(): - if 'neighborhood_entropy' in adata.obs.columns: - entropy_df = pd.DataFrame({ - 'Cell Type': adata.obs['celltype_major'], - 'Neighborhood Entropy': adata.obs['neighborhood_entropy'], - 'Segmentation Method': method - }) + if "neighborhood_entropy" in adata.obs.columns: + entropy_df = pd.DataFrame( + { + "Cell Type": adata.obs["celltype_major"], + "Neighborhood Entropy": adata.obs["neighborhood_entropy"], + "Segmentation Method": method, + } + ) # Filter out NaN values, keeping only the subsetted cells - entropy_df = entropy_df.dropna(subset=['Neighborhood Entropy']) + entropy_df = entropy_df.dropna(subset=["Neighborhood Entropy"]) entropy_boxplot_data.append(entropy_df) # Concatenate all entropy dataframes into one @@ -166,7 +168,7 @@ plot_entropy_boxplots(entropy_boxplot_data, figures_path, palette=method_colors) # Find markers for sensitivity calculation -purified_markers = find_markers(scRNAseq_adata, 'celltype_major', pos_percentile=20, percentage=75) +purified_markers = find_markers(scRNAseq_adata, "celltype_major", pos_percentile=20, percentage=75) # Calculate sensitivity for each segmentation method sensitivity_results_per_method = {} @@ -178,11 +180,7 @@ sensitivity_boxplot_data = [] for method, sensitivity_results in sensitivity_results_per_method.items(): for cell_type, sensitivities in sensitivity_results.items(): - method_df = pd.DataFrame({ - 'Cell Type': cell_type, - 'Sensitivity': sensitivities, - 'Segmentation Method': method - }) + method_df = pd.DataFrame({"Cell Type": cell_type, "Sensitivity": sensitivities, "Segmentation Method": method}) sensitivity_boxplot_data.append(method_df) # Concatenate all sensitivity dataframes into one diff --git a/docs/source/conf.py b/docs/source/conf.py index 0cc5f81..17e7f88 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,23 +6,22 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'segger' -copyright = '2024, Elyas Heidari' -author = 'Elyas Heidari' -release = '0.01' +project = "segger" +copyright = "2024, Elyas Heidari" +author = "Elyas Heidari" +release = "0.01" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [] -templates_path = ['_templates'] +templates_path = ["_templates"] exclude_patterns = [] - # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'alabaster' -html_static_path = ['_static'] +html_theme = "alabaster" +html_static_path = ["_static"] diff --git a/docs/user_guide/data_creation.md b/docs/user_guide/data_creation.md index 571f9ef..8d27140 100644 --- a/docs/user_guide/data_creation.md +++ b/docs/user_guide/data_creation.md @@ -140,19 +140,22 @@ from pathlib import Path import scanpy as sc # Set up the file paths 
-raw_data_dir = Path('/path/to/xenium_output') -processed_data_dir = Path('path/to/processed_files') +raw_data_dir = Path("/path/to/xenium_output") +processed_data_dir = Path("path/to/processed_files") sample_tag = "sample/tag" # Load scRNA-seq data using Scanpy and subsample for efficiency -scRNAseq_path = 'path/to/scRNAseq.h5ad' +scRNAseq_path = "path/to/scRNAseq.h5ad" scRNAseq = sc.read(scRNAseq_path) sc.pp.subsample(scRNAseq, fraction=0.1) # Calculate gene cell type abundance embedding from scRNA-seq data from segger.utils import calculate_gene_celltype_abundance_embedding -celltype_column = 'celltype_column' -gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(scRNAseq, celltype_column) + +celltype_column = "celltype_column" +gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding( + scRNAseq, celltype_column +) # Create a XeniumSample instance for spatial transcriptomics processing xenium_sample = XeniumSample() @@ -161,9 +164,9 @@ xenium_sample = XeniumSample() xenium_sample.load_transcripts( base_path=raw_data_dir, sample=sample_tag, - transcripts_filename='transcripts.parquet', + transcripts_filename="transcripts.parquet", file_format="parquet", - additional_embeddings={"cell_type_abundance": gene_celltype_abundance_embedding} + additional_embeddings={"cell_type_abundance": gene_celltype_abundance_embedding}, ) # Set the embedding to "cell_type_abundance" to use it in further processing @@ -171,7 +174,7 @@ xenium_sample.set_embedding("cell_type_abundance") # Load nuclei data to define boundaries nuclei_path = raw_data_dir / sample_tag / "nucleus_boundaries.parquet" -xenium_sample.load_boundaries(path=nuclei_path, file_format='parquet') +xenium_sample.load_boundaries(path=nuclei_path, file_format="parquet") # Build PyTorch Geometric (PyG) data from a tile of the dataset tile_pyg_data = xenium_sample.build_pyg_data_from_tile( @@ -180,7 +183,7 @@ tile_pyg_data = xenium_sample.build_pyg_data_from_tile( r_tx=20, k_tx=20, use_precomputed=False, - workers=1 + workers=1, ) # Save dataset in processed format for segmentation @@ -199,7 +202,7 @@ xenium_sample.save_dataset_for_segger( test_prob=0.2, neg_sampling_ratio_approx=5, sampling_rate=1, - num_workers=1 + num_workers=1, ) ``` @@ -210,8 +213,8 @@ from segger.data import MerscopeSample from pathlib import Path # Set up the file paths -raw_data_dir = Path('path/to/merscope_outputs') -processed_data_dir = Path('path/to/processed_files') +raw_data_dir = Path("path/to/merscope_outputs") +processed_data_dir = Path("path/to/processed_files") sample_tag = "sample_tag" # Create a MerscopeSample instance for spatial transcriptomics processing @@ -221,16 +224,18 @@ merscope_sample = MerscopeSample() merscope_sample.load_transcripts( base_path=raw_data_dir, sample=sample_tag, - transcripts_filename='transcripts.csv', - file_format='csv' + transcripts_filename="transcripts.csv", + file_format="csv", ) # Optionally load cell boundaries cell_boundaries_path = raw_data_dir / sample_tag / "cell_boundaries.parquet" -merscope_sample.load_boundaries(path=cell_boundaries_path, file_format='parquet') +merscope_sample.load_boundaries(path=cell_boundaries_path, file_format="parquet") # Filter transcripts based on specific criteria -filtered_transcripts = merscope_sample.filter_transcripts(merscope_sample.transcripts_df) +filtered_transcripts = merscope_sample.filter_transcripts( + merscope_sample.transcripts_df +) # Build PyTorch Geometric (PyG) data from a tile of the dataset tile_pyg_data = 
merscope_sample.build_pyg_data_from_tile( @@ -239,7 +244,7 @@ tile_pyg_data = merscope_sample.build_pyg_data_from_tile( r_tx=15, k_tx=15, use_precomputed=True, - workers=2 + workers=2, ) # Save dataset in processed format for segmentation @@ -258,6 +263,6 @@ merscope_sample.save_dataset_for_segger( test_prob=0.2, neg_sampling_ratio_approx=3, sampling_rate=1, - num_workers=2 + num_workers=2, ) ``` diff --git a/docs/user_guide/training.md b/docs/user_guide/training.md index 151fc66..8b78f0c 100644 --- a/docs/user_guide/training.md +++ b/docs/user_guide/training.md @@ -69,12 +69,12 @@ To instantiate and run the `segger` model: ```python model = segger( - num_tx_tokens=5000, # Number of unique 'tx' tokens - init_emb=32, # Initial embedding dimension - hidden_channels=64, # Number of hidden channels - num_mid_layers=2, # Number of middle layers - out_channels=128, # Number of output channels - heads=4 # Number of attention heads + num_tx_tokens=5000, # Number of unique 'tx' tokens + init_emb=32, # Initial embedding dimension + hidden_channels=64, # Number of hidden channels + num_mid_layers=2, # Number of middle layers + out_channels=128, # Number of output channels + heads=4, # Number of attention heads ) output = model(x, edge_index) diff --git a/scripts/create_data_sample.py b/scripts/create_data_sample.py index 8cdb137..3e37d23 100644 --- a/scripts/create_data_sample.py +++ b/scripts/create_data_sample.py @@ -8,16 +8,20 @@ from lightning.pytorch.plugins.environments import LightningEnvironment from matplotlib import pyplot as plt import seaborn as sns + # import pandas as pd from segger.data.utils import calculate_gene_celltype_abundance_embedding import scanpy as sc import os + # import Dask.DataFrame as dd -os.environ['DASK_DAEMON'] = 'False' +os.environ["DASK_DAEMON"] = "False" -xenium_data_dir = Path('/omics/odcf/analysis/OE0606_projects/oncolgy_data_exchange/20230831-pan-cns-TMA-Xenium/output-XETG00078__0010722__TMA_AKSI__20230831__151713/') -segger_data_dir = Path('./data_tidy/pyg_datasets/pan_cns_AKSI') +xenium_data_dir = Path( + "/omics/odcf/analysis/OE0606_projects/oncolgy_data_exchange/20230831-pan-cns-TMA-Xenium/output-XETG00078__0010722__TMA_AKSI__20230831__151713/" +) +segger_data_dir = Path("./data_tidy/pyg_datasets/pan_cns_AKSI") # models_dir = Path('./models/bc_embedding_1001') # scRNAseq_path = '/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad' @@ -31,13 +35,11 @@ # gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(scRNAseq, celltype_column) - - # Setup Xenium sample to create dataset -xs = XeniumSample(verbose=False) # , embedding_df=gene_celltype_abundance_embedding) +xs = XeniumSample(verbose=False) # , embedding_df=gene_celltype_abundance_embedding) xs.set_file_paths( - transcripts_path=xenium_data_dir / 'transcripts.parquet', - boundaries_path=xenium_data_dir / 'nucleus_boundaries.parquet', + transcripts_path=xenium_data_dir / "transcripts.parquet", + boundaries_path=xenium_data_dir / "nucleus_boundaries.parquet", ) # dd.read_parquet(transcripts_path[0]) @@ -59,8 +61,7 @@ k_tx=5, val_prob=0.3, test_prob=0.1, - num_workers=6 + num_workers=6, ) except AssertionError as err: - print(f'Dataset already exists at {segger_data_dir}') - + print(f"Dataset already exists at {segger_data_dir}") diff --git a/scripts/predict_model_sample.py b/scripts/predict_model_sample.py index 11c5e89..ef013ae 100644 --- a/scripts/predict_model_sample.py +++ b/scripts/predict_model_sample.py @@ -8,21 +8,22 @@ import dask.dataframe as dd 
import pandas as pd from pathlib import Path -os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import cupy as cp from dask.distributed import Client, LocalCluster from dask_cuda import LocalCUDACluster import dask.dataframe as dd -segger_data_dir = Path('./data_tidy/pyg_datasets/bc_embedding_1001') -models_dir = Path('./models/bc_embedding_1001_small') -benchmarks_dir = Path('/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc') -transcripts_file = 'data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet' +segger_data_dir = Path("./data_tidy/pyg_datasets/bc_embedding_1001") +models_dir = Path("./models/bc_embedding_1001_small") +benchmarks_dir = Path("/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc") +transcripts_file = "data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet" # Initialize the Lightning data module dm = SeggerDataModule( data_dir=segger_data_dir, - batch_size=1, - num_workers=1, + batch_size=1, + num_workers=1, ) dm.setup() @@ -31,22 +32,22 @@ model_version = 0 # Load in latest checkpoint -model_path = models_dir / 'lightning_logs' / f'version_{model_version}' -model = load_model(model_path / 'checkpoints') +model_path = models_dir / "lightning_logs" / f"version_{model_version}" +model = load_model(model_path / "checkpoints") -receptive_field = {'k_bd': 4, 'dist_bd': 12,'k_tx': 5, 'dist_tx': 5} +receptive_field = {"k_bd": 4, "dist_bd": 12, "k_tx": 5, "dist_tx": 5} segment( model, dm, save_dir=benchmarks_dir, - seg_tag='segger_embedding_1001_0.5_cc', + seg_tag="segger_embedding_1001_0.5_cc", transcript_file=transcripts_file, - file_format='anndata', - receptive_field = receptive_field, + file_format="anndata", + receptive_field=receptive_field, min_transcripts=5, # max_transcripts=1500, - cell_id_col='segger_cell_id', + cell_id_col="segger_cell_id", use_cc=True, - knn_method='cuda' -) \ No newline at end of file + knn_method="cuda", +) diff --git a/scripts/train_model.py b/scripts/train_model.py index 7c25bc3..d94eda3 100644 --- a/scripts/train_model.py +++ b/scripts/train_model.py @@ -15,18 +15,20 @@ os.environ["PYTORCH_USE_CUDA_DSA"] = "1" os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + def check_and_create_raw_folder(directory): - raw_dir = directory / 'raw' + raw_dir = directory / "raw" if not raw_dir.exists(): raw_dir.mkdir(parents=True, exist_ok=True) - warnings.warn(f"'{raw_dir}' does not exist. Creating this dummy folder because SpatialTranscriptomicsDataset requires it.") + warnings.warn( + f"'{raw_dir}' does not exist. Creating this dummy folder because SpatialTranscriptomicsDataset requires it." 
+ ) + def main(args): # CONFIG - - - - sys.path.insert(0, os.path.abspath('../..')) + + sys.path.insert(0, os.path.abspath("../..")) # Paths TRAIN_DIR = Path(args.train_dir) @@ -47,9 +49,9 @@ def main(args): hidden_channels=args.hidden_channels, out_channels=args.out_channels, heads=args.heads, - num_mid_layers=args.mid_layers # mid_layers is now included + num_mid_layers=args.mid_layers, # mid_layers is now included ) - model = to_hetero(model, (['tx', 'bd'], [('tx', 'belongs', 'bd'), ('tx', 'neighbors', 'tx')]), aggr=args.aggr) + model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr=args.aggr) batch = train_ds[0] model.forward(batch.x_dict, batch.edge_index_dict) @@ -73,25 +75,35 @@ def main(args): # Train the model trainer.fit(litsegger, train_loader, val_loader) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train the Segger model") - parser.add_argument('--train_dir', type=str, required=True, help='Path to the training data directory') - parser.add_argument('--val_dir', type=str, required=True, help='Path to the validation data directory') - parser.add_argument('--batch_size_train', type=int, default=4, help='Batch size for training') - parser.add_argument('--batch_size_val', type=int, default=4, help='Batch size for validation') - parser.add_argument('--num_tx_tokens', type=int, default=500, help='Number of unique tx tokens for embedding') # num_tx_tokens default 500 - parser.add_argument('--init_emb', type=int, default=8, help='Initial embedding size') - parser.add_argument('--hidden_channels', type=int, default=64, help='Number of hidden channels') - parser.add_argument('--out_channels', type=int, default=16, help='Number of output channels') - parser.add_argument('--heads', type=int, default=4, help='Number of attention heads') - parser.add_argument('--mid_layers', type=int, default=1, help='Number of middle layers in the model') # mid_layers default 1 - parser.add_argument('--aggr', type=str, default='sum', help='Aggregation method') - parser.add_argument('--accelerator', type=str, default='cuda', help='Type of accelerator') - parser.add_argument('--strategy', type=str, default='auto', help='Training strategy') - parser.add_argument('--precision', type=str, default='16-mixed', help='Precision mode') - parser.add_argument('--devices', type=int, default=4, help='Number of devices') - parser.add_argument('--epochs', type=int, default=100, help='Number of epochs') - parser.add_argument('--default_root_dir', type=str, default='./models/pancreas', help='Default root directory for logs and checkpoints') + parser.add_argument("--train_dir", type=str, required=True, help="Path to the training data directory") + parser.add_argument("--val_dir", type=str, required=True, help="Path to the validation data directory") + parser.add_argument("--batch_size_train", type=int, default=4, help="Batch size for training") + parser.add_argument("--batch_size_val", type=int, default=4, help="Batch size for validation") + parser.add_argument( + "--num_tx_tokens", type=int, default=500, help="Number of unique tx tokens for embedding" + ) # num_tx_tokens default 500 + parser.add_argument("--init_emb", type=int, default=8, help="Initial embedding size") + parser.add_argument("--hidden_channels", type=int, default=64, help="Number of hidden channels") + parser.add_argument("--out_channels", type=int, default=16, help="Number of output channels") + parser.add_argument("--heads", type=int, default=4, help="Number of attention heads") + 
parser.add_argument( + "--mid_layers", type=int, default=1, help="Number of middle layers in the model" + ) # mid_layers default 1 + parser.add_argument("--aggr", type=str, default="sum", help="Aggregation method") + parser.add_argument("--accelerator", type=str, default="cuda", help="Type of accelerator") + parser.add_argument("--strategy", type=str, default="auto", help="Training strategy") + parser.add_argument("--precision", type=str, default="16-mixed", help="Precision mode") + parser.add_argument("--devices", type=int, default=4, help="Number of devices") + parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") + parser.add_argument( + "--default_root_dir", + type=str, + default="./models/pancreas", + help="Default root directory for logs and checkpoints", + ) args = parser.parse_args() main(args) diff --git a/scripts/train_model_sample.py b/scripts/train_model_sample.py index ec3611a..8b834cc 100644 --- a/scripts/train_model_sample.py +++ b/scripts/train_model_sample.py @@ -8,19 +8,20 @@ from lightning.pytorch.plugins.environments import LightningEnvironment from matplotlib import pyplot as plt import seaborn as sns + # import pandas as pd from segger.data.utils import calculate_gene_celltype_abundance_embedding import scanpy as sc import os -segger_data_dir = Path('./data_tidy/pyg_datasets/bc_embedding_1001') -models_dir = Path('./models/bc_embedding_1001_small') +segger_data_dir = Path("./data_tidy/pyg_datasets/bc_embedding_1001") +models_dir = Path("./models/bc_embedding_1001_small") dm = SeggerDataModule( data_dir=segger_data_dir, - batch_size=4, - num_workers=2, + batch_size=4, + num_workers=2, ) dm.setup() @@ -33,17 +34,17 @@ out_channels=8, heads=2, num_mid_layers=2, - aggr='sum', + aggr="sum", metadata=metadata, ) # Initialize the Lightning trainer trainer = Trainer( - accelerator='cuda', - strategy='auto', - precision='16-mixed', - devices=4, - max_epochs=200, + accelerator="cuda", + strategy="auto", + precision="16-mixed", + devices=4, + max_epochs=200, default_root_dir=models_dir, logger=CSVLogger(models_dir), ) @@ -52,7 +53,4 @@ ls.forward(batch) -trainer.fit( - model=ls, - datamodule=dm -) \ No newline at end of file +trainer.fit(model=ls, datamodule=dm) diff --git a/src/segger/__init__.py b/src/segger/__init__.py index b186ac0..e7f4c44 100644 --- a/src/segger/__init__.py +++ b/src/segger/__init__.py @@ -5,4 +5,4 @@ from .models import * from .prediction import * from .training import * -from .validation import * \ No newline at end of file +from .validation import * diff --git a/src/segger/cli/cli.py b/src/segger/cli/cli.py index 9332d08..18715ee 100644 --- a/src/segger/cli/cli.py +++ b/src/segger/cli/cli.py @@ -3,12 +3,14 @@ from segger.cli.predict import predict import click + # Setup main CLI command @click.group(help="Command line interface for the Segger segmentation package") def segger(): pass + # Add sub-commands to main CLI commands segger.add_command(create_dataset) segger.add_command(train) -segger.add_command(predict) \ No newline at end of file +segger.add_command(predict) diff --git a/src/segger/cli/configs/train/default.yaml b/src/segger/cli/configs/train/default.yaml index b685eac..cf27fc1 100644 --- a/src/segger/cli/configs/train/default.yaml +++ b/src/segger/cli/configs/train/default.yaml @@ -44,7 +44,7 @@ num_workers: help: Number of workers for data loading. accelerator: type: str - default: 'cuda' + default: "cuda" help: Device type to use for training (e.g., "cuda", "cpu"). 
max_epochs: type: int @@ -56,9 +56,9 @@ devices: help: Number of devices (GPUs) to use. strategy: type: str - default: 'auto' + default: "auto" help: Training strategy for the trainer. precision: type: str - default: '16-mixed' + default: "16-mixed" help: Precision for training. diff --git a/src/segger/cli/create_dataset.py b/src/segger/cli/create_dataset.py index b22e1d7..f82e85b 100644 --- a/src/segger/cli/create_dataset.py +++ b/src/segger/cli/create_dataset.py @@ -8,37 +8,56 @@ import time # Path to default YAML configuration file -data_yml = Path(__file__).parent / 'configs' / 'create_dataset' / 'default.yaml' +data_yml = Path(__file__).parent / "configs" / "create_dataset" / "default.yaml" # CLI command to create a Segger dataset help_msg = "Create Segger dataset from spatial transcriptomics data (Xenium or MERSCOPE)" + + @click.command(name="create_dataset", help=help_msg) @add_options(config_path=data_yml) -@click.option('--dataset_dir', type=Path, required=True, help='Directory containing the raw dataset.') -@click.option('--data_dir', type=Path, required=True, help='Directory to save the processed Segger dataset.') -@click.option('--sample_tag', type=str, required=True, help='Sample tag for the dataset.') -@click.option('--transcripts_file', type=str, required=True, help='Name of the transcripts file.') -@click.option('--boundaries_file', type=str, required=True, help='Name of the boundaries file.') -@click.option('--x_size', type=int, default=300, help='Size of each tile in x-direction.') -@click.option('--y_size', type=int, default=300, help='Size of each tile in y-direction.') -@click.option('--d_x', type=int, default=280, help='Tile overlap in x-direction.') -@click.option('--d_y', type=int, default=280, help='Tile overlap in y-direction.') -@click.option('--margin_x', type=int, default=10, help='Margin in x-direction.') -@click.option('--margin_y', type=int, default=10, help='Margin in y-direction.') -@click.option('--r_tx', type=int, default=5, help='Radius for computing neighborhood graph.') -@click.option('--k_tx', type=int, default=5, help='Number of nearest neighbors for the neighborhood graph.') -@click.option('--val_prob', type=float, default=0.1, help='Validation data split proportion.') -@click.option('--test_prob', type=float, default=0.2, help='Test data split proportion.') -@click.option('--neg_sampling_ratio', type=float, default=5, help='Ratio for negative sampling.') -@click.option('--sampling_rate', type=float, default=1, help='Sampling rate for the dataset.') -@click.option('--workers', type=int, default=1, help='Number of workers for parallel processing.') -@click.option('--gpu', is_flag=True, default=False, help='Use GPU if available.') -def create_dataset(args: Namespace, dataset_dir: Path, data_dir: Path, sample_tag: str, - transcripts_file: str, boundaries_file: str, x_size: int, y_size: int, - d_x: int, d_y: int, margin_x: int, margin_y: int, r_tx: int, k_tx: int, - val_prob: float, test_prob: float, neg_sampling_ratio: float, - sampling_rate: float, workers: int, gpu: bool): - +@click.option("--dataset_dir", type=Path, required=True, help="Directory containing the raw dataset.") +@click.option("--data_dir", type=Path, required=True, help="Directory to save the processed Segger dataset.") +@click.option("--sample_tag", type=str, required=True, help="Sample tag for the dataset.") +@click.option("--transcripts_file", type=str, required=True, help="Name of the transcripts file.") +@click.option("--boundaries_file", type=str, required=True, help="Name of 
the boundaries file.") +@click.option("--x_size", type=int, default=300, help="Size of each tile in x-direction.") +@click.option("--y_size", type=int, default=300, help="Size of each tile in y-direction.") +@click.option("--d_x", type=int, default=280, help="Tile overlap in x-direction.") +@click.option("--d_y", type=int, default=280, help="Tile overlap in y-direction.") +@click.option("--margin_x", type=int, default=10, help="Margin in x-direction.") +@click.option("--margin_y", type=int, default=10, help="Margin in y-direction.") +@click.option("--r_tx", type=int, default=5, help="Radius for computing neighborhood graph.") +@click.option("--k_tx", type=int, default=5, help="Number of nearest neighbors for the neighborhood graph.") +@click.option("--val_prob", type=float, default=0.1, help="Validation data split proportion.") +@click.option("--test_prob", type=float, default=0.2, help="Test data split proportion.") +@click.option("--neg_sampling_ratio", type=float, default=5, help="Ratio for negative sampling.") +@click.option("--sampling_rate", type=float, default=1, help="Sampling rate for the dataset.") +@click.option("--workers", type=int, default=1, help="Number of workers for parallel processing.") +@click.option("--gpu", is_flag=True, default=False, help="Use GPU if available.") +def create_dataset( + args: Namespace, + dataset_dir: Path, + data_dir: Path, + sample_tag: str, + transcripts_file: str, + boundaries_file: str, + x_size: int, + y_size: int, + d_x: int, + d_y: int, + margin_x: int, + margin_y: int, + r_tx: int, + k_tx: int, + val_prob: float, + test_prob: float, + neg_sampling_ratio: float, + sampling_rate: float, + workers: int, + gpu: bool, +): + # Setup logging ch = logging.StreamHandler() ch.setLevel(logging.INFO) @@ -47,9 +66,9 @@ def create_dataset(args: Namespace, dataset_dir: Path, data_dir: Path, sample_ta # Initialize the appropriate sample class based on dataset type logging.info("Initializing sample...") - if args.dataset_type == 'xenium': + if args.dataset_type == "xenium": sample = XeniumSample() - elif args.dataset_type == 'merscope': + elif args.dataset_type == "merscope": sample = MerscopeSample() else: raise ValueError("Unsupported dataset type. 
Please choose 'xenium' or 'merscope'.") diff --git a/src/segger/cli/create_dataset_fast.py b/src/segger/cli/create_dataset_fast.py index 8e6e9ee..33a3a63 100644 --- a/src/segger/cli/create_dataset_fast.py +++ b/src/segger/cli/create_dataset_fast.py @@ -9,29 +9,42 @@ import time # Path to default YAML configuration file -data_yml = Path(__file__).parent / 'configs' / 'create_dataset' / 'default_fast.yaml' +data_yml = Path(__file__).parent / "configs" / "create_dataset" / "default_fast.yaml" # CLI command to create a Segger dataset help_msg = "Create Segger dataset from spatial transcriptomics data (Xenium or MERSCOPE)" + + @click.command(name="create_dataset", help=help_msg) @add_options(config_path=data_yml) -@click.option('--base_dir', type=Path, required=True, help='Directory containing the raw dataset.') -@click.option('--data_dir', type=Path, required=True, help='Directory to save the processed Segger dataset.') -@click.option('--sample_type', type=str, default=None, help='The sample type of the raw data, e.g., "xenium" or "merscope".') -@click.option('--k_bd', type=int, default=3, help='Number of nearest neighbors for boundary nodes.') -@click.option('--dist_bd', type=float, default=15., help='Maximum distance for boundary neighbors.') -@click.option('--k_tx', type=int, default=3, help='Number of nearest neighbors for transcript nodes.') -@click.option('--dist_tx', type=float, default=5., help='Maximum distance for transcript neighbors.') -@click.option('--tile_size', type=int, default=None, help='If provided, specifies the size of the tile. Overrides `tile_width` and `tile_height`.') -@click.option('--tile_width', type=int, default=None, help='Width of the tiles in pixels. Ignored if `tile_size` is provided.') -@click.option('--tile_height', type=int, default=None, help='Height of the tiles in pixels. Ignored if `tile_size` is provided.') -@click.option('--neg_sampling_ratio', type=float, default=5., help='Ratio of negative samples.') -@click.option('--frac', type=float, default=1., help='Fraction of the dataset to process.') -@click.option('--val_prob', type=float, default=0.1, help='Proportion of data for use for validation split.') -@click.option('--test_prob', type=float, default=0.2, help='Proportion of data for use for test split.') -@click.option('--n_workers', type=int, default=1, help='Number of workers for parallel processing.') +@click.option("--base_dir", type=Path, required=True, help="Directory containing the raw dataset.") +@click.option("--data_dir", type=Path, required=True, help="Directory to save the processed Segger dataset.") +@click.option( + "--sample_type", type=str, default=None, help='The sample type of the raw data, e.g., "xenium" or "merscope".' +) +@click.option("--k_bd", type=int, default=3, help="Number of nearest neighbors for boundary nodes.") +@click.option("--dist_bd", type=float, default=15.0, help="Maximum distance for boundary neighbors.") +@click.option("--k_tx", type=int, default=3, help="Number of nearest neighbors for transcript nodes.") +@click.option("--dist_tx", type=float, default=5.0, help="Maximum distance for transcript neighbors.") +@click.option( + "--tile_size", + type=int, + default=None, + help="If provided, specifies the size of the tile. Overrides `tile_width` and `tile_height`.", +) +@click.option( + "--tile_width", type=int, default=None, help="Width of the tiles in pixels. Ignored if `tile_size` is provided." +) +@click.option( + "--tile_height", type=int, default=None, help="Height of the tiles in pixels. 
Ignored if `tile_size` is provided." +) +@click.option("--neg_sampling_ratio", type=float, default=5.0, help="Ratio of negative samples.") +@click.option("--frac", type=float, default=1.0, help="Fraction of the dataset to process.") +@click.option("--val_prob", type=float, default=0.1, help="Proportion of data for use for validation split.") +@click.option("--test_prob", type=float, default=0.2, help="Proportion of data for use for test split.") +@click.option("--n_workers", type=int, default=1, help="Number of workers for parallel processing.") def create_dataset(args: Namespace): - + # Setup logging ch = logging.StreamHandler() ch.setLevel(logging.INFO) @@ -67,5 +80,6 @@ def create_dataset(args: Namespace): logging.info(f"Time to save dataset: {end_time - start_time} seconds") logging.info("Dataset saved successfully.") -if __name__ == '__main__': - create_dataset() \ No newline at end of file + +if __name__ == "__main__": + create_dataset() diff --git a/src/segger/cli/predict.py b/src/segger/cli/predict.py index eca5a4b..2cfe83e 100644 --- a/src/segger/cli/predict.py +++ b/src/segger/cli/predict.py @@ -5,34 +5,47 @@ import logging import os -os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + @click.command(name="run_segmentation", help="Run the Segger segmentation model.") -@click.option('--segger_data_dir', type=Path, required=True, help='Directory containing the processed Segger dataset.') -@click.option('--models_dir', type=Path, required=True, help='Directory containing the trained models.') -@click.option('--benchmarks_dir', type=Path, required=True, help='Directory to save the segmentation results.') -@click.option('--transcripts_file', type=str, required=True, help='Path to the transcripts file.') -@click.option('--batch_size', type=int, default=1, help='Batch size for processing.') -@click.option('--num_workers', type=int, default=1, help='Number of workers for data loading.') -@click.option('--model_version', type=int, default=0, help='Model version to load.') -@click.option('--save_tag', type=str, default='segger_embedding_1001_0.5', help='Tag for saving segmentation results.') -@click.option('--min_transcripts', type=int, default=5, help='Minimum number of transcripts for segmentation.') -@click.option('--cell_id_col', type=str, default='segger_cell_id', help='Column name for cell IDs.') -@click.option('--use_cc', is_flag=True, default=False, help='Use connected components if specified.') -@click.option('--knn_method', type=str, default='cuda', help='Method for KNN computation.') -@click.option('--file_format', type=str, default='anndata', help='File format for output data.') -@click.option('--k_bd', type=int, default=4, help='K value for boundary computation.') -@click.option('--dist_bd', type=int, default=12, help='Distance for boundary computation.') -@click.option('--k_tx', type=int, default=5, help='K value for transcript computation.') -@click.option('--dist_tx', type=int, default=5, help='Distance for transcript computation.') -def run_segmentation(segger_data_dir: Path, models_dir: Path, benchmarks_dir: Path, - transcripts_file: str, batch_size: int = 1, num_workers: int = 1, - model_version: int = 0, save_tag: str = 'segger_embedding_1001_0.5', - min_transcripts: int = 5, cell_id_col: str = 'segger_cell_id', - use_cc: bool = False, knn_method: str = 'cuda', - file_format: str = 'anndata', k_bd: int = 4, dist_bd: int = 12, - k_tx: int = 5, dist_tx: int = 5): - 
+@click.option("--segger_data_dir", type=Path, required=True, help="Directory containing the processed Segger dataset.") +@click.option("--models_dir", type=Path, required=True, help="Directory containing the trained models.") +@click.option("--benchmarks_dir", type=Path, required=True, help="Directory to save the segmentation results.") +@click.option("--transcripts_file", type=str, required=True, help="Path to the transcripts file.") +@click.option("--batch_size", type=int, default=1, help="Batch size for processing.") +@click.option("--num_workers", type=int, default=1, help="Number of workers for data loading.") +@click.option("--model_version", type=int, default=0, help="Model version to load.") +@click.option("--save_tag", type=str, default="segger_embedding_1001_0.5", help="Tag for saving segmentation results.") +@click.option("--min_transcripts", type=int, default=5, help="Minimum number of transcripts for segmentation.") +@click.option("--cell_id_col", type=str, default="segger_cell_id", help="Column name for cell IDs.") +@click.option("--use_cc", is_flag=True, default=False, help="Use connected components if specified.") +@click.option("--knn_method", type=str, default="cuda", help="Method for KNN computation.") +@click.option("--file_format", type=str, default="anndata", help="File format for output data.") +@click.option("--k_bd", type=int, default=4, help="K value for boundary computation.") +@click.option("--dist_bd", type=int, default=12, help="Distance for boundary computation.") +@click.option("--k_tx", type=int, default=5, help="K value for transcript computation.") +@click.option("--dist_tx", type=int, default=5, help="Distance for transcript computation.") +def run_segmentation( + segger_data_dir: Path, + models_dir: Path, + benchmarks_dir: Path, + transcripts_file: str, + batch_size: int = 1, + num_workers: int = 1, + model_version: int = 0, + save_tag: str = "segger_embedding_1001_0.5", + min_transcripts: int = 5, + cell_id_col: str = "segger_cell_id", + use_cc: bool = False, + knn_method: str = "cuda", + file_format: str = "anndata", + k_bd: int = 4, + dist_bd: int = 12, + k_tx: int = 5, + dist_tx: int = 5, +): + # Setup logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -41,16 +54,16 @@ def run_segmentation(segger_data_dir: Path, models_dir: Path, benchmarks_dir: Pa # Initialize the Lightning data module dm = SeggerDataModule( data_dir=segger_data_dir, - batch_size=batch_size, - num_workers=num_workers, + batch_size=batch_size, + num_workers=num_workers, ) - + dm.setup() - + logger.info("Loading the model...") # Load in the latest checkpoint - model_path = models_dir / 'lightning_logs' / f'version_{model_version}' - model = load_model(model_path / 'checkpoints') + model_path = models_dir / "lightning_logs" / f"version_{model_version}" + model = load_model(model_path / "checkpoints") logger.info("Running segmentation...") segment( @@ -59,15 +72,16 @@ def run_segmentation(segger_data_dir: Path, models_dir: Path, benchmarks_dir: Pa save_dir=benchmarks_dir, seg_tag=save_tag, transcript_file=transcripts_file, - file_format=file_format, - receptive_field={'k_bd': k_bd, 'dist_bd': dist_bd, 'k_tx': k_tx, 'dist_tx': dist_tx}, + file_format=file_format, + receptive_field={"k_bd": k_bd, "dist_bd": dist_bd, "k_tx": k_tx, "dist_tx": dist_tx}, min_transcripts=min_transcripts, cell_id_col=cell_id_col, use_cc=use_cc, knn_method=knn_method, ) - + logger.info("Segmentation completed.") -if __name__ == '__main__': + +if __name__ == "__main__": 
run_segmentation() diff --git a/src/segger/cli/train_model.py b/src/segger/cli/train_model.py index 78bd5d7..a3a23a7 100644 --- a/src/segger/cli/train_model.py +++ b/src/segger/cli/train_model.py @@ -7,27 +7,33 @@ from argparse import Namespace # Path to default YAML configuration file -train_yml = Path(__file__).parent / 'configs' / 'train' / 'default.yaml' +train_yml = Path(__file__).parent / "configs" / "train" / "default.yaml" help_msg = "Train the Segger segmentation model." + + @click.command(name="train_model", help=help_msg) @add_options(config_path=train_yml) -@click.option('--dataset_dir', type=Path, required=True, help='Directory containing the processed Segger dataset.') -@click.option('--models_dir', type=Path, required=True, help='Directory to save the trained model and the training logs.') -@click.option('--sample_tag', type=str, required=True, help='Sample tag for the dataset.') -@click.option('--init_emb', type=int, default=8, help='Size of the embedding layer.') -@click.option('--hidden_channels', type=int, default=32, help='Size of hidden channels in the model.') -@click.option('--num_tx_tokens', type=int, default=500, help='Number of transcript tokens.') -@click.option('--out_channels', type=int, default=8, help='Number of output channels.') -@click.option('--heads', type=int, default=2, help='Number of attention heads.') -@click.option('--num_mid_layers', type=int, default=2, help='Number of mid layers in the model.') -@click.option('--batch_size', type=int, default=4, help='Batch size for training.') -@click.option('--num_workers', type=int, default=2, help='Number of workers for data loading.') -@click.option('--accelerator', type=str, default='cuda', help='Device type to use for training (e.g., "cuda", "cpu").') # Ask for accelerator -@click.option('--max_epochs', type=int, default=200, help='Number of epochs for training.') -@click.option('--devices', type=int, default=4, help='Number of devices (GPUs) to use.') -@click.option('--strategy', type=str, default='auto', help='Training strategy for the trainer.') -@click.option('--precision', type=str, default='16-mixed', help='Precision for training.') +@click.option("--dataset_dir", type=Path, required=True, help="Directory containing the processed Segger dataset.") +@click.option( + "--models_dir", type=Path, required=True, help="Directory to save the trained model and the training logs." +) +@click.option("--sample_tag", type=str, required=True, help="Sample tag for the dataset.") +@click.option("--init_emb", type=int, default=8, help="Size of the embedding layer.") +@click.option("--hidden_channels", type=int, default=32, help="Size of hidden channels in the model.") +@click.option("--num_tx_tokens", type=int, default=500, help="Number of transcript tokens.") +@click.option("--out_channels", type=int, default=8, help="Number of output channels.") +@click.option("--heads", type=int, default=2, help="Number of attention heads.") +@click.option("--num_mid_layers", type=int, default=2, help="Number of mid layers in the model.") +@click.option("--batch_size", type=int, default=4, help="Batch size for training.") +@click.option("--num_workers", type=int, default=2, help="Number of workers for data loading.") +@click.option( + "--accelerator", type=str, default="cuda", help='Device type to use for training (e.g., "cuda", "cpu").' 
+) # Ask for accelerator +@click.option("--max_epochs", type=int, default=200, help="Number of epochs for training.") +@click.option("--devices", type=int, default=4, help="Number of devices (GPUs) to use.") +@click.option("--strategy", type=str, default="auto", help="Training strategy for the trainer.") +@click.option("--precision", type=str, default="16-mixed", help="Precision for training.") def train_model(args: Namespace): # Setup logging @@ -43,6 +49,7 @@ def train_model(args: Namespace): from segger.training.segger_data_module import SeggerDataModule from lightning.pytorch.loggers import CSVLogger from pytorch_lightning import Trainer + logging.info("Done.") # Load datasets @@ -66,7 +73,7 @@ def train_model(args: Namespace): out_channels=args.out_channels, # Hard-coded value heads=args.heads, # Hard-coded value num_mid_layers=args.num_mid_layers, # Hard-coded value - aggr='sum', # Hard-coded value + aggr="sum", # Hard-coded value metadata=metadata, ) @@ -80,15 +87,12 @@ def train_model(args: Namespace): default_root_dir=args.models_dir, logger=CSVLogger(args.models_dir), ) - + logging.info("Done.") # Train model logging.info("Training model...") - trainer.fit( - model=ls, - datamodule=dm - ) + trainer.fit(model=ls, datamodule=dm) logging.info("Done.") @@ -97,11 +101,13 @@ def train_model(args: Namespace): def train_slurm(args): train_model(args) + @click.group(help="Train the Segger model") def train(): pass + train.add_command(train_slurm) -if __name__ == '__main__': - train_model() \ No newline at end of file +if __name__ == "__main__": + train_model() diff --git a/src/segger/cli/utils.py b/src/segger/cli/utils.py index 2a38610..df6e816 100644 --- a/src/segger/cli/utils.py +++ b/src/segger/cli/utils.py @@ -12,11 +12,11 @@ def add_options( show_default: bool = True, ): """ - A decorator to add command-line options to a Click command from a YAML + A decorator to add command-line options to a Click command from a YAML configuration file. Parameters: - config_path (os.PathLike): The path to the YAML configuration file + config_path (os.PathLike): The path to the YAML configuration file containing the options. show_default (bool): Whether to show default values in help. @@ -26,7 +26,7 @@ def add_options( The YAML configuration file should have the following format: ``` option_name: - type: "type_name" # Optional, the type of the option + type: "type_name" # Optional, the type of the option (e.g., "str", "int") help: "description" # Optional, the help text for the option default: value # Optional, the default value for the option @@ -52,24 +52,23 @@ def greet(args): click.echo(f"Hello, {args.name}! 
You are {args.age} years old.") ``` """ - def decorator( - function: typing.Callable - ): + + def decorator(function: typing.Callable): # Wrap the original function to convert kwargs to a Namespace object def wrapper(**kwargs): args_namespace = Namespace(**kwargs) return function(args_namespace) - + # Load the YAML configuration file - with open(config_path, 'r') as file: + with open(config_path, "r") as file: config = yaml.safe_load(file.read()) # Decorate function with all options for name, kwargs in reversed(config.items()): - kwargs['show_default'] = show_default - if 'type' in kwargs: - kwargs['type'] = locate(kwargs['type']) - wrapper = click.option(f'--{name}', **kwargs)(wrapper) + kwargs["show_default"] = show_default + if "type" in kwargs: + kwargs["type"] = locate(kwargs["type"]) + wrapper = click.option(f"--{name}", **kwargs)(wrapper) return wrapper @@ -87,31 +86,32 @@ class CustomFormatter(logging.Formatter): bold_red (str): ANSI escape code for bold red color. reset (str): ANSI escape code to reset color. format (str): The format string for log messages. - FORMATS (dict): A dictionary mapping log levels to their respective + FORMATS (dict): A dictionary mapping log levels to their respective color-coded format strings. Methods: format(record): - Format the specified record as text, applying color codes based on the + Format the specified record as text, applying color codes based on the log level. """ + grey = "\x1b[38;20m" green = "\x1b[32;20m" yellow = "\x1b[33;20m" red = "\x1b[31;20m" bold_red = "\x1b[31;1m" reset = "\x1b[0m" - format='%(asctime)s %(levelname)s: %(message)s' + format = "%(asctime)s %(levelname)s: %(message)s" FORMATS = { logging.DEBUG: grey + format + reset, logging.INFO: green + format + reset, logging.WARNING: yellow + format + reset, logging.ERROR: red + format + reset, - logging.CRITICAL: bold_red + format + reset + logging.CRITICAL: bold_red + format + reset, } def format(self, record): log_fmt = self.FORMATS.get(record.levelno) formatter = logging.Formatter(log_fmt) - return formatter.format(record) \ No newline at end of file + return formatter.format(record) diff --git a/src/segger/data/README.md b/src/segger/data/README.md index df6e979..28d7df0 100644 --- a/src/segger/data/README.md +++ b/src/segger/data/README.md @@ -1,6 +1,6 @@ # segger - Data Preparation for Cell Segmentation -The `segger` package provides a comprehensive data preparation module for handling and processing spatial transcriptomics data, specifically designed to support **Xenium** and **Merscope** datasets. This module facilitates the creation of datasets for cell segmentation and subsequent graph-based deep learning tasks by leveraging scalable and efficient processing tools. +The `segger` package provides a comprehensive data preparation module for handling and processing spatial transcriptomics data, specifically designed to support **Xenium** and **Merscope** datasets. This module facilitates the creation of datasets for cell segmentation and subsequent graph-based deep learning tasks by leveraging scalable and efficient processing tools. ## Module Overview @@ -48,7 +48,6 @@ These classes inherit from `SpatialTranscriptomicsSample` and implement dataset- - **`XeniumSample`**: Tailored for **Xenium** datasets, it includes specific filtering rules to exclude unwanted transcripts based on naming patterns (e.g., `NegControlProbe_`, `BLANK_`). - **`MerscopeSample`**: Designed for **Merscope** datasets, allowing for custom filtering and processing logic as needed. 
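To make the dataset-specific filtering concrete, the sketch below shows the kind of rule `XeniumSample` applies: a transcript is kept only if its quality value passes a threshold and its feature name does not start with a control-codeword prefix. This is a minimal pandas illustration, not the package's implementation; the helper name `filter_xenium_like` and the literal column names `"feature_name"` and `"qv"` are assumptions for the example, whereas the real `filter_transcripts` method operates on Dask DataFrames and resolves column names through `XeniumKeys`.

```python
import pandas as pd

# Prefixes of control codewords to exclude (mirrors the Xenium filtering rules above).
CONTROL_PREFIXES = (
    "NegControlProbe_",
    "antisense_",
    "NegControlCodeword_",
    "BLANK_",
    "DeprecatedCodeword_",
    "UnassignedCodeword_",
)


def filter_xenium_like(transcripts: pd.DataFrame, min_qv: float = 20.0) -> pd.DataFrame:
    """Keep transcripts with quality >= min_qv whose names are not control codewords."""
    good_quality = transcripts["qv"] >= min_qv
    # Series.str.startswith accepts a tuple of prefixes (pandas >= 1.4).
    is_control = transcripts["feature_name"].str.startswith(CONTROL_PREFIXES)
    return transcripts[good_quality & ~is_control]
```

`MerscopeSample` follows the same pattern but currently applies only the quality-value filter, leaving room for platform-specific rules.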
- ## Workflow The dataset creation and processing workflow involves several key steps, each ensuring that the spatial transcriptomics data is appropriately prepared for downstream machine learning tasks. @@ -61,39 +60,42 @@ The dataset creation and processing workflow involves several key steps, each en ### Step 2: Tiling - **Spatial Segmentation**: The dataset is divided into smaller, manageable tiles of size $$x_{\text{size}} \times y_{\text{size}}$$, defined by their top-left corner coordinates $$(x_i, y_j)$$. - + $$ n_x = \left\lfloor \frac{x_{\text{max}} - x_{\text{min}}}{d_x} \right\rfloor, \quad n_y = \left\lfloor \frac{y_{\text{max}} - y_{\text{min}}}{d_y} \right\rfloor $$ - - Where: - - $$x_{\text{min}}, y_{\text{min}}$$: Minimum spatial coordinates. - - $$x_{\text{max}}, y_{\text{max}}$$: Maximum spatial coordinates. - - $$d_x, d_y$$: Step sizes along the $$x$$- and $$y$$-axes, respectively. + +Where: + +- $$x_{\text{min}}, y_{\text{min}}$$: Minimum spatial coordinates. +- $$x_{\text{max}}, y_{\text{max}}$$: Maximum spatial coordinates. +- $$d_x, d_y$$: Step sizes along the $$x$$- and $$y$$-axes, respectively. - **Transcript and Boundary Inclusion**: For each tile, transcripts and boundaries within the spatial bounds (with optional margins) are included: - -$$ -x_i - \text{margin}_x \leq x_t < x_i + x_{\text{size}} + \text{margin}_x, \quad y_j - \text{margin}_y \leq y_t < y_j + y_{\text{size}} + \text{margin}_y + +$$ +x_i - \text{margin}_x \leq x_t < x_i + x_{\text{size}} + \text{margin}_x, \quad y_j - \text{margin}_y \leq y_t < y_j + y_{\text{size}} + \text{margin}_y $$ - - Where: - - $$x_t, y_t$$: Transcript coordinates. - - $$\text{margin}_x, \text{margin}_y$$: Optional margins to include contextual data. + +Where: + +- $$x_t, y_t$$: Transcript coordinates. +- $$\text{margin}_x, \text{margin}_y$$: Optional margins to include contextual data. ### Step 3: Graph Construction For each tile, a graph $$G$$ is constructed with: - **Nodes ($$V$$)**: + - **Transcripts**: Represented by their spatial coordinates $$(x_t, y_t)$$ and feature vectors $$\mathbf{f}_t$$. - **Boundaries**: Represented by centroid coordinates $$(x_b, y_b)$$ and associated properties (e.g., area). - **Edges ($$E$$)**: - Created based on spatial proximity using methods like KD-Tree or FAISS. - Defined by a distance threshold $$d$$ and the number of nearest neighbors $$k$$: - -$$ + +$$ E = \{ (v_i, v_j) \mid \text{dist}(v_i, v_j) < d, \, v_i \in V, \, v_j \in V \} $$ @@ -102,7 +104,7 @@ $$ If enabled, edges can be labeled based on relationships, such as whether a transcript belongs to a boundary: $$ -\text{label}(t, b) = +\text{label}(t, b) = \begin{cases} 1 & \text{if } t \text{ belongs to } b \\ 0 & \text{otherwise} @@ -123,7 +125,6 @@ Each tile is randomly assigned to one of these sets according to the specified p The final output consists of a set of tiles, each containing a graph representation of the spatial transcriptomics data. These tiles are stored in designated directories (`train_tiles`, `val_tiles`, `test_tiles`) and are ready for integration into machine learning pipelines. - ## Example Usage Below are examples demonstrating how to utilize the `segger` data preparation module for both Xenium and Merscope datasets. 
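Before the dataset-specific examples, here is a small numeric sketch of the tiling arithmetic from the workflow section above. All values are hypothetical and chosen only to illustrate the formulas; in practice they correspond to the `x_size`, `y_size`, `d_x`, `d_y`, and margin arguments passed to `save_dataset_for_segger`.

```python
import math

# Hypothetical spatial extent and tiling parameters (illustrative values only).
x_min, x_max = 0.0, 10_000.0
y_min, y_max = 0.0, 8_000.0
x_size, y_size = 1_000.0, 1_000.0  # tile width and height
d_x, d_y = 500.0, 500.0            # step sizes along x and y
margin_x = margin_y = 100.0        # optional context margin around each tile

# Number of tile origins along each axis, following the formulas above.
n_x = math.floor((x_max - x_min) / d_x)
n_y = math.floor((y_max - y_min) / d_y)
print(n_x, n_y)  # 20 16


def in_tile(x_t: float, y_t: float, x_i: float, y_j: float) -> bool:
    """Inclusion test for a transcript at (x_t, y_t) against the tile anchored at (x_i, y_j)."""
    return (x_i - margin_x <= x_t < x_i + x_size + margin_x) and (
        y_j - margin_y <= y_t < y_j + y_size + margin_y
    )


print(in_tile(1_050.0, 980.0, x_i=1_000.0, y_j=500.0))  # True: inside the tile plus its margin
```

Each included tile is then turned into a graph and written to one of the `train_tiles`, `val_tiles`, or `test_tiles` directories, as described above.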
@@ -137,25 +138,29 @@ from segger.data.utils import calculate_gene_celltype_abundance_embedding import scanpy as sc import os -xenium_data_dir = Path('./data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1') -segger_data_dir = Path('./data_tidy/pyg_datasets/bc_embedding_0919') -models_dir = Path('./models/bc_embedding_0919') +xenium_data_dir = Path("./data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1") +segger_data_dir = Path("./data_tidy/pyg_datasets/bc_embedding_0919") +models_dir = Path("./models/bc_embedding_0919") -scRNAseq_path = '/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad' +scRNAseq_path = ( + "/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad" +) scRNAseq = sc.read(scRNAseq_path) sc.pp.subsample(scRNAseq, 0.1) # Step 1: Calculate the gene cell type abundance embedding -celltype_column = 'celltype_minor' -gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(scRNAseq, celltype_column) +celltype_column = "celltype_minor" +gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding( + scRNAseq, celltype_column +) # Setup Xenium sample to create dataset -xs = XeniumSample(verbose=False , embedding_df=gene_celltype_abundance_embedding) +xs = XeniumSample(verbose=False, embedding_df=gene_celltype_abundance_embedding) xs.set_file_paths( - transcripts_path=xenium_data_dir / 'transcripts.parquet', - boundaries_path=xenium_data_dir / 'nucleus_boundaries.parquet', + transcripts_path=xenium_data_dir / "transcripts.parquet", + boundaries_path=xenium_data_dir / "nucleus_boundaries.parquet", ) xs.set_metadata() @@ -164,7 +169,7 @@ xenium_sample.set_embedding("cell_type_abundance") # Load nuclei data to define boundaries nuclei_path = raw_data_dir / sample_tag / "nucleus_boundaries.parquet" -xenium_sample.load_boundaries(path=nuclei_path, file_format='parquet') +xenium_sample.load_boundaries(path=nuclei_path, file_format="parquet") # Build PyTorch Geometric (PyG) data from a tile of the dataset tile_pyg_data = xenium_sample.build_pyg_data_from_tile( @@ -173,7 +178,7 @@ tile_pyg_data = xenium_sample.build_pyg_data_from_tile( r_tx=20, k_tx=20, use_precomputed=False, - workers=1 + workers=1, ) @@ -191,10 +196,10 @@ try: k_tx=10, val_prob=0.4, test_prob=0.1, - num_workers=6 + num_workers=6, ) except AssertionError as err: - print(f'Dataset already exists at {segger_data_dir}') + print(f"Dataset already exists at {segger_data_dir}") ``` ### Merscope Data @@ -204,8 +209,8 @@ from segger.data import MerscopeSample from pathlib import Path # Set up the file paths -raw_data_dir = Path('data_raw/merscope/') -processed_data_dir = Path('data_tidy/pyg_datasets') +raw_data_dir = Path("data_raw/merscope/") +processed_data_dir = Path("data_tidy/pyg_datasets") sample_tag = "Merscope_Sample_1" # Create a MerscopeSample instance for spatial transcriptomics processing @@ -215,16 +220,18 @@ merscope_sample = MerscopeSample() merscope_sample.load_transcripts( base_path=raw_data_dir, sample=sample_tag, - transcripts_filename='transcripts.csv', - file_format='csv' + transcripts_filename="transcripts.csv", + file_format="csv", ) # Optionally load cell boundaries cell_boundaries_path = raw_data_dir / sample_tag / "cell_boundaries.parquet" -merscope_sample.load_boundaries(path=cell_boundaries_path, file_format='parquet') +merscope_sample.load_boundaries(path=cell_boundaries_path, file_format="parquet") # Filter transcripts based on specific criteria -filtered_transcripts = 
merscope_sample.filter_transcripts(merscope_sample.transcripts_df) +filtered_transcripts = merscope_sample.filter_transcripts( + merscope_sample.transcripts_df +) # Build PyTorch Geometric (PyG) data from a tile of the dataset tile_pyg_data = merscope_sample.build_pyg_data_from_tile( @@ -233,12 +240,12 @@ tile_pyg_data = merscope_sample.build_pyg_data_from_tile( r_tx=15, k_tx=15, use_precomputed=True, - workers=2 + workers=2, ) # Save dataset in processed format for segmentation merscope_sample.save_dataset_for_segger( - processed_dir=processed_data_dir / 'embedding', + processed_dir=processed_data_dir / "embedding", x_size=360, y_size=360, d_x=180, @@ -252,6 +259,6 @@ merscope_sample.save_dataset_for_segger( test_prob=0.2, neg_sampling_ratio_approx=3, sampling_rate=1, - num_workers=2 + num_workers=2, ) ``` diff --git a/src/segger/data/__init__.py b/src/segger/data/__init__.py index 1d60059..380a815 100644 --- a/src/segger/data/__init__.py +++ b/src/segger/data/__init__.py @@ -5,35 +5,30 @@ """ __all__ = [ - "XeniumSample", - "MerscopeSample", - "SpatialTranscriptomicsDataset", - "filter_transcripts", - "create_anndata", - "compute_transcript_metrics", + "XeniumSample", + "MerscopeSample", + "SpatialTranscriptomicsDataset", + "filter_transcripts", + "create_anndata", + "compute_transcript_metrics", "SpatialTranscriptomicsSample", "calculate_gene_celltype_abundance_embedding", "get_edge_index", ] from .utils import ( - filter_transcripts, - create_anndata, - compute_transcript_metrics, - get_edge_index, + filter_transcripts, + create_anndata, + compute_transcript_metrics, + get_edge_index, calculate_gene_celltype_abundance_embedding, - SpatialTranscriptomicsDataset + SpatialTranscriptomicsDataset, ) from .io import ( - XeniumSample, - MerscopeSample, + XeniumSample, + MerscopeSample, SpatialTranscriptomicsSample, ) -from .constants import ( - SpatialTranscriptomicsKeys, - XeniumKeys, - MerscopeKeys -) - +from .constants import SpatialTranscriptomicsKeys, XeniumKeys, MerscopeKeys diff --git a/src/segger/data/constants.py b/src/segger/data/constants.py index 7cd1fb6..b48350f 100644 --- a/src/segger/data/constants.py +++ b/src/segger/data/constants.py @@ -1,5 +1,6 @@ from enum import Enum, auto + class SpatialTranscriptomicsKeys(Enum): """Unified keys for spatial transcriptomics data, supporting multiple platforms.""" @@ -7,11 +8,11 @@ class SpatialTranscriptomicsKeys(Enum): TRANSCRIPTS_FILE = auto() BOUNDARIES_FILE = auto() CELL_METADATA_FILE = auto() - + # Cell identifiers CELL_ID = auto() TRANSCRIPTS_ID = auto() - + # Coordinates and locations TRANSCRIPTS_X = auto() TRANSCRIPTS_Y = auto() @@ -19,7 +20,7 @@ class SpatialTranscriptomicsKeys(Enum): BOUNDARIES_VERTEX_Y = auto() GLOBAL_X = auto() GLOBAL_Y = auto() - + # Metadata METADATA_CELL_KEY = auto() COUNTS_CELL_KEY = auto() diff --git a/src/segger/data/io.py b/src/segger/data/io.py index a369b9f..fdfa059 100644 --- a/src/segger/data/io.py +++ b/src/segger/data/io.py @@ -30,11 +30,8 @@ import logging import warnings -for msg in [ - r".*Geometry is in a geographic CRS.*", - r".*You did not provide metadata.*" -]: - warnings.filterwarnings('ignore', category=UserWarning, message=msg) +for msg in [r".*Geometry is in a geographic CRS.*", r".*You did not provide metadata.*"]: + warnings.filterwarnings("ignore", category=UserWarning, message=msg) class SpatialTranscriptomicsSample(ABC): @@ -60,10 +57,8 @@ def __init__( self.boundaries_graph = boundaries_graph self.keys = keys self.embedding_df = embedding_df - self.current_embedding = 'token' + 
self.current_embedding = "token" self.verbose = verbose - - @abstractmethod def filter_transcripts(self, transcripts_df: pd.DataFrame, min_qv: float = 20.0) -> pd.DataFrame: @@ -78,8 +73,7 @@ def filter_transcripts(self, transcripts_df: pd.DataFrame, min_qv: float = 20.0) pd.DataFrame: The filtered dataframe. """ pass - - + def set_file_paths(self, transcripts_path: Path, boundaries_path: Path) -> None: """ Set the paths for the transcript and boundary files. @@ -90,10 +84,11 @@ def set_file_paths(self, transcripts_path: Path, boundaries_path: Path) -> None: """ self.transcripts_path = transcripts_path self.boundaries_path = boundaries_path - - if self.verbose: print(f"Set transcripts file path to {transcripts_path}") - if self.verbose: print(f"Set boundaries file path to {boundaries_path}") + if self.verbose: + print(f"Set transcripts file path to {transcripts_path}") + if self.verbose: + print(f"Set boundaries file path to {boundaries_path}") def load_transcripts( self, @@ -153,22 +148,22 @@ def load_transcripts( self.keys.TRANSCRIPTS_X.value, self.keys.TRANSCRIPTS_Y.value, self.keys.FEATURE_NAME.value, - self.keys.CELL_ID.value + self.keys.CELL_ID.value, ] # Check if the QUALITY_VALUE key exists in the dataset, and add it to the columns list if present if self.keys.QUALITY_VALUE.value in available_columns: columns_to_read.append(self.keys.QUALITY_VALUE.value) - + if self.keys.OVERLAPS_BOUNDARY.value in available_columns: columns_to_read.append(self.keys.OVERLAPS_BOUNDARY.value) # Use filters to only load data within the specified bounding box (x_min, x_max, y_min, y_max) filters = [ - (self.keys.TRANSCRIPTS_X.value, '>=', x_min), - (self.keys.TRANSCRIPTS_X.value, '<=', x_max), - (self.keys.TRANSCRIPTS_Y.value, '>=', y_min), - (self.keys.TRANSCRIPTS_Y.value, '<=', y_max) + (self.keys.TRANSCRIPTS_X.value, ">=", x_min), + (self.keys.TRANSCRIPTS_X.value, "<=", x_max), + (self.keys.TRANSCRIPTS_Y.value, ">=", y_min), + (self.keys.TRANSCRIPTS_Y.value, "<=", y_max), ] # Load the dataset lazily with filters applied for the bounding box @@ -195,27 +190,27 @@ def load_transcripts( # Lazily count the number of rows in the DataFrame before filtering initial_count = delayed(lambda df: df.shape[0])(transcripts_df) # Filter the DataFrame lazily based on valid genes from embeddings - transcripts_df = transcripts_df[ - transcripts_df[self.keys.FEATURE_NAME.value].isin(valid_genes) - ] + transcripts_df = transcripts_df[transcripts_df[self.keys.FEATURE_NAME.value].isin(valid_genes)] final_count = delayed(lambda df: df.shape[0])(transcripts_df) - if self.verbose: print(f"Dropped {initial_count - final_count} transcripts not found in {key} embedding.") + if self.verbose: + print(f"Dropped {initial_count - final_count} transcripts not found in {key} embedding.") # Ensure that the 'OVERLAPS_BOUNDARY' column is boolean if it exists if self.keys.OVERLAPS_BOUNDARY.value in transcripts_df.columns: - transcripts_df[self.keys.OVERLAPS_BOUNDARY.value] = transcripts_df[self.keys.OVERLAPS_BOUNDARY.value].astype(bool) + transcripts_df[self.keys.OVERLAPS_BOUNDARY.value] = transcripts_df[ + self.keys.OVERLAPS_BOUNDARY.value + ].astype(bool) return transcripts_df - def load_boundaries( - self, - path: Path, - file_format: str = "parquet", - x_min: float = None, - x_max: float = None, - y_min: float = None, - y_max: float = None + self, + path: Path, + file_format: str = "parquet", + x_min: float = None, + x_max: float = None, + y_min: float = None, + y_max: float = None, ) -> dd.DataFrame: """ Load boundaries data lazily 
using Dask, filtering by the specified bounding box. @@ -233,7 +228,7 @@ def load_boundaries( """ if file_format != "parquet": raise ValueError(f"Unsupported file format: {file_format}") - + self.boundaries_path = path # Use bounding box values from set_metadata if not explicitly provided @@ -246,15 +241,15 @@ def load_boundaries( columns_to_read = [ self.keys.BOUNDARIES_VERTEX_X.value, self.keys.BOUNDARIES_VERTEX_Y.value, - self.keys.CELL_ID.value + self.keys.CELL_ID.value, ] # Use filters to only load data within the specified bounding box (x_min, x_max, y_min, y_max) filters = [ - (self.keys.BOUNDARIES_VERTEX_X.value, '>=', x_min), - (self.keys.BOUNDARIES_VERTEX_X.value, '<=', x_max), - (self.keys.BOUNDARIES_VERTEX_Y.value, '>=', y_min), - (self.keys.BOUNDARIES_VERTEX_Y.value, '<=', y_max) + (self.keys.BOUNDARIES_VERTEX_X.value, ">=", x_min), + (self.keys.BOUNDARIES_VERTEX_X.value, "<=", x_max), + (self.keys.BOUNDARIES_VERTEX_Y.value, ">=", y_min), + (self.keys.BOUNDARIES_VERTEX_Y.value, "<=", y_max), ] # Load the dataset lazily with filters applied for the bounding box @@ -265,17 +260,15 @@ def load_boundaries( lambda x: str(x) if pd.notnull(x) else None ) - if self.verbose: print(f"Loaded boundaries from '{path}' within bounding box ({x_min}, {x_max}, {y_min}, {y_max}).") + if self.verbose: + print(f"Loaded boundaries from '{path}' within bounding box ({x_min}, {x_max}, {y_min}, {y_max}).") return boundaries_df - - - def set_metadata(self) -> None: """ Set metadata for the transcript dataset, including bounding box limits and unique gene names, - without reading the entire Parquet file. Additionally, return integer tokens for unique gene names + without reading the entire Parquet file. Additionally, return integer tokens for unique gene names instead of one-hot encodings and store the lookup table for later mapping. 
""" # Load the Parquet file metadata @@ -287,7 +280,7 @@ def set_metadata(self) -> None: feature_col = self.keys.FEATURE_NAME.value # Initialize variables to track min/max values for X and Y - x_min, x_max, y_min, y_max = float('inf'), float('-inf'), float('inf'), float('-inf') + x_min, x_max, y_min, y_max = float("inf"), float("-inf"), float("inf"), float("-inf") # Extract unique gene names and ensure they're strings gene_set = set() @@ -299,7 +292,7 @@ def set_metadata(self) -> None: "NegControlCodeword_", "BLANK_", "DeprecatedCodeword_", - "UnassignedCodeword_" + "UnassignedCodeword_", ) # Iterate over row groups to extract statistics and unique gene names @@ -316,8 +309,12 @@ def set_metadata(self) -> None: y_max = max(y_max, y_values.max()) # Convert feature values (gene names) to strings and filter out unwanted codewords - feature_values = row_group_table[feature_col].to_pandas().apply( - lambda x: x.decode('utf-8') if isinstance(x, bytes) else str(x), + feature_values = ( + row_group_table[feature_col] + .to_pandas() + .apply( + lambda x: x.decode("utf-8") if isinstance(x, bytes) else str(x), + ) ) # Filter out unwanted codewords @@ -332,11 +329,15 @@ def set_metadata(self) -> None: self.y_min = y_min self.y_max = y_max - if self.verbose: print(f"Bounding box limits set: x_min={self.x_min}, x_max={self.x_max}, y_min={self.y_min}, y_max={self.y_max}") + if self.verbose: + print( + f"Bounding box limits set: x_min={self.x_min}, x_max={self.x_max}, y_min={self.y_min}, y_max={self.y_max}" + ) # Convert the set of unique genes into a sorted list for consistent ordering self.unique_genes = sorted(gene_set) - if self.verbose: print(f"Extracted {len(self.unique_genes)} unique gene names for integer tokenization.") + if self.verbose: + print(f"Extracted {len(self.unique_genes)} unique gene names for integer tokenization.") # Initialize a LabelEncoder to convert unique genes into integer tokens self.tx_encoder = LabelEncoder() @@ -345,18 +346,19 @@ def set_metadata(self) -> None: self.tx_encoder.fit(self.unique_genes) # Store the integer tokens mapping to gene names - self.gene_to_token_map = dict(zip(self.tx_encoder.classes_, self.tx_encoder.transform(self.tx_encoder.classes_))) - + self.gene_to_token_map = dict( + zip(self.tx_encoder.classes_, self.tx_encoder.transform(self.tx_encoder.classes_)) + ) - if self.verbose: print("Integer tokens have been computed and stored based on unique gene names.") + if self.verbose: + print("Integer tokens have been computed and stored based on unique gene names.") # Optional: Create a reverse mapping for lookup purposes (token to gene) self.token_to_gene_map = {v: k for k, v in self.gene_to_token_map.items()} + if self.verbose: + print("Lookup tables (gene_to_token_map and token_to_gene_map) have been created.") - if self.verbose: print("Lookup tables (gene_to_token_map and token_to_gene_map) have been created.") - - def set_embedding(self, embedding_name: str) -> None: """ Set the current embedding type for the transcripts. @@ -370,8 +372,7 @@ def set_embedding(self, embedding_name: str) -> None: self.current_embedding = embedding_name else: raise ValueError(f"Embedding {embedding_name} not found in embeddings_dict.") - - + @staticmethod def create_scaled_polygon(group: pd.DataFrame, scale_factor: float, keys) -> gpd.GeoDataFrame: """ @@ -386,9 +387,9 @@ def create_scaled_polygon(group: pd.DataFrame, scale_factor: float, keys) -> gpd gpd.GeoDataFrame: A GeoDataFrame containing the scaled Polygon and cell_id. 
""" # Extract coordinates and cell ID from the group using keys - x_coords = group[keys['vertex_x']] - y_coords = group[keys['vertex_y']] - cell_id = group[keys['cell_id']].iloc[0] + x_coords = group[keys["vertex_x"]] + y_coords = group[keys["vertex_y"]] + cell_id = group[keys["cell_id"]].iloc[0] # Ensure there are at least 3 points to form a polygon if len(x_coords) >= 3: @@ -398,19 +399,13 @@ def create_scaled_polygon(group: pd.DataFrame, scale_factor: float, keys) -> gpd # Scale the polygon by the provided factor scaled_polygon = polygon.buffer(scale_factor) if scaled_polygon.is_valid and not scaled_polygon.is_empty: - return gpd.GeoDataFrame({ - 'geometry': [scaled_polygon], - keys['cell_id']: [cell_id] - }, geometry='geometry', crs="EPSG:4326") + return gpd.GeoDataFrame( + {"geometry": [scaled_polygon], keys["cell_id"]: [cell_id]}, geometry="geometry", crs="EPSG:4326" + ) # Return an empty GeoDataFrame if no valid polygon is created - return gpd.GeoDataFrame({ - 'geometry': [None], - keys['cell_id']: [cell_id] - }, geometry='geometry', crs="EPSG:4326") - - def generate_and_scale_polygons( - self, boundaries_df: dd.DataFrame, scale_factor: float = 1.0 - ) -> dgpd.GeoDataFrame: + return gpd.GeoDataFrame({"geometry": [None], keys["cell_id"]: [cell_id]}, geometry="geometry", crs="EPSG:4326") + + def generate_and_scale_polygons(self, boundaries_df: dd.DataFrame, scale_factor: float = 1.0) -> dgpd.GeoDataFrame: """ Generate and scale polygons from boundary coordinates using Dask. Keeps class structure intact by using static method for the core polygon generation. @@ -428,40 +423,39 @@ def generate_and_scale_polygons( cell_id_column = self.keys.CELL_ID.value vertex_x_column = self.keys.BOUNDARIES_VERTEX_X.value vertex_y_column = self.keys.BOUNDARIES_VERTEX_Y.value - + create_polygon = self.create_scaled_polygon # Use a lambda to wrap the static method call and avoid passing the function object directly to Dask polygons_ddf = boundaries_df.groupby(cell_id_column).apply( lambda group: create_polygon( - group=group, scale_factor=scale_factor, + group=group, + scale_factor=scale_factor, keys={ # Pass keys as a dict for the lambda function - 'vertex_x': vertex_x_column, - 'vertex_y': vertex_y_column, - 'cell_id': cell_id_column - } + "vertex_x": vertex_x_column, + "vertex_y": vertex_y_column, + "cell_id": cell_id_column, + }, ) ) - + # Lazily compute centroids for each polygon - if self.verbose: print("Adding centroids to the polygons...") - polygons_ddf['centroid_x'] = polygons_ddf.geometry.centroid.x - polygons_ddf['centroid_y'] = polygons_ddf.geometry.centroid.y - + if self.verbose: + print("Adding centroids to the polygons...") + polygons_ddf["centroid_x"] = polygons_ddf.geometry.centroid.x + polygons_ddf["centroid_y"] = polygons_ddf.geometry.centroid.y + polygons_ddf = polygons_ddf.drop_duplicates() # polygons_ddf = polygons_ddf.to_crs("EPSG:3857") return polygons_ddf - - - def compute_transcript_overlap_with_boundaries( self, transcripts_df: dd.DataFrame, boundaries_df: dd.DataFrame = None, polygons_gdf: dgpd.GeoDataFrame = None, - scale_factor: float = 1.0 - ) -> dd.DataFrame: + scale_factor: float = 1.0, + ) -> dd.DataFrame: """ Computes the overlap of transcript locations with scaled boundary polygons and assigns corresponding cell IDs to the transcripts using Dask. @@ -479,15 +473,16 @@ def compute_transcript_overlap_with_boundaries( if polygons_gdf is None: if boundaries_df is None: raise ValueError("Both boundaries_df and polygons_gdf cannot be None. 
Provide at least one.") - + # Generate polygons from boundaries_df if polygons_gdf is None # if self.verbose: print(f"No precomputed polygons provided. Computing polygons from boundaries with a scale factor of {scale_factor}.") polygons_gdf = self.generate_and_scale_polygons(boundaries_df, scale_factor) - + if polygons_gdf.empty(): raise ValueError("No valid polygons were generated from the boundaries.") else: - if self.verbose: print(f"Polygons are available. Proceeding with overlap computation.") + if self.verbose: + print(f"Polygons are available. Proceeding with overlap computation.") # Create a delayed function to check if a point is within any polygon def check_overlap(transcript, polygons_gdf): @@ -508,11 +503,14 @@ def check_overlap(transcript, polygons_gdf): return overlap, cell_id # Apply the check_overlap function in parallel to each row using Dask's map_partitions - if self.verbose: print(f"Starting overlap computation for transcripts with the boundary polygons.") + if self.verbose: + print(f"Starting overlap computation for transcripts with the boundary polygons.") transcripts_df = transcripts_df.map_partitions( lambda df: df.assign( **{ - self.keys.OVERLAPS_BOUNDARY.value: df.apply(lambda row: delayed(check_overlap)(row, polygons_gdf)[0], axis=1), + self.keys.OVERLAPS_BOUNDARY.value: df.apply( + lambda row: delayed(check_overlap)(row, polygons_gdf)[0], axis=1 + ), self.keys.CELL_ID.value: df.apply(lambda row: delayed(check_overlap)(row, polygons_gdf)[1], axis=1), } ) @@ -520,9 +518,6 @@ def check_overlap(transcript, polygons_gdf): return transcripts_df - - - def compute_boundaries_geometries( self, boundaries_df: dd.DataFrame = None, @@ -552,38 +547,47 @@ def compute_boundaries_geometries( if polygons_gdf is None: if boundaries_df is None: raise ValueError("Both boundaries_df and polygons_gdf cannot be None. Provide at least one.") - + # Generate polygons from boundaries_df if polygons_gdf is None - if self.verbose: print(f"No precomputed polygons provided. Computing polygons from boundaries with a scale factor of {scale_factor}.") + if self.verbose: + print( + f"No precomputed polygons provided. Computing polygons from boundaries with a scale factor of {scale_factor}." + ) polygons_gdf = self.generate_and_scale_polygons(boundaries_df, scale_factor) - + # Check if the generated polygons_gdf is empty if polygons_gdf.shape[0] == 0: raise ValueError("No valid polygons were generated from the boundaries.") else: - if self.verbose: print(f"Polygons are available. Proceeding with geometrical computations.") - + if self.verbose: + print(f"Polygons are available. 
Proceeding with geometrical computations.") + # Compute additional geometrical properties polygons = polygons_gdf.geometry # Compute additional geometrical properties if area: - if self.verbose: print("Computing area...") - polygons_gdf['area'] = polygons.area + if self.verbose: + print("Computing area...") + polygons_gdf["area"] = polygons.area if convexity: - if self.verbose: print("Computing convexity...") - polygons_gdf['convexity'] = polygons.convex_hull.area / polygons.area + if self.verbose: + print("Computing convexity...") + polygons_gdf["convexity"] = polygons.convex_hull.area / polygons.area if elongation: - if self.verbose: print("Computing elongation...") + if self.verbose: + print("Computing elongation...") r = polygons.minimum_rotated_rectangle() - polygons_gdf['elongation'] = (r.length * r.length) / r.area + polygons_gdf["elongation"] = (r.length * r.length) / r.area if circularity: - if self.verbose: print("Computing circularity...") + if self.verbose: + print("Computing circularity...") r = polygons_gdf.minimum_bounding_radius() - polygons_gdf['circularity'] = polygons.area / (r * r) + polygons_gdf["circularity"] = polygons.area / (r * r) + + if self.verbose: + print("Geometrical computations completed.") - if self.verbose: print("Geometrical computations completed.") - return polygons_gdf.reset_index(drop=True) def save_dataset_for_segger( @@ -604,9 +608,9 @@ def save_dataset_for_segger( sampling_rate: float = 1, num_workers: int = 1, scale_boundaries: float = 1.0, - method: str = 'kd_tree', + method: str = "kd_tree", gpu: bool = False, - workers: int = 1 + workers: int = 1, ) -> None: """ Saves the dataset for Segger in a processed format using Dask for parallel and lazy processing. @@ -631,49 +635,64 @@ def save_dataset_for_segger( method (str, optional): Method for computing edge indices (e.g., 'kd_tree', 'faiss'). gpu (bool, optional): Whether to use GPU acceleration for edge index computation. workers (int, optional): Number of workers to use to compute the neighborhood graph (per tile). 
- + """ # Prepare directories for storing processed tiles self._prepare_directories(processed_dir) - + # Get x and y coordinate ranges for tiling x_range, y_range = self._get_ranges(d_x, d_y) - + # Generate parameters for each tile tile_params = self._generate_tile_params( - x_range, y_range, x_size, y_size, margin_x, margin_y, compute_labels, - r_tx, k_tx, val_prob, test_prob, neg_sampling_ratio_approx, sampling_rate, - processed_dir, scale_boundaries, method, gpu, workers + x_range, + y_range, + x_size, + y_size, + margin_x, + margin_y, + compute_labels, + r_tx, + k_tx, + val_prob, + test_prob, + neg_sampling_ratio_approx, + sampling_rate, + processed_dir, + scale_boundaries, + method, + gpu, + workers, ) # Process each tile using Dask to parallelize the task - if self.verbose: print("Starting tile processing...") + if self.verbose: + print("Starting tile processing...") tasks = [delayed(self._process_tile)(params) for params in tile_params] - + with ProgressBar(): - # Use Dask to process all tiles in parallel + # Use Dask to process all tiles in parallel dask.compute(*tasks, num_workers=num_workers) - if self.verbose: print("Tile processing completed.") - + if self.verbose: + print("Tile processing completed.") def _prepare_directories(self, processed_dir: Path) -> None: """Prepares directories for saving tiles.""" processed_dir = Path(processed_dir) # by default, convert to Path object - for data_type in ['train', 'test', 'val']: - for data_stage in ['raw', 'processed']: - tile_dir = processed_dir / f'{data_type}_tiles' / data_stage + for data_type in ["train", "test", "val"]: + for data_stage in ["raw", "processed"]: + tile_dir = processed_dir / f"{data_type}_tiles" / data_stage tile_dir.mkdir(parents=True, exist_ok=True) if os.listdir(tile_dir): msg = f"Directory '{tile_dir}' must be empty." raise AssertionError(msg) - def _get_ranges(self, d_x: float, d_y: float) -> Tuple[np.ndarray, np.ndarray]: """Generates ranges for tiling.""" x_range = np.arange(self.x_min // 1000 * 1000, self.x_max, d_x) y_range = np.arange(self.y_min // 1000 * 1000, self.y_max, d_y) return x_range, y_range - + def _generate_tile_params( self, x_range: np.ndarray, @@ -693,7 +712,7 @@ def _generate_tile_params( scale_boundaries: float, method: str, gpu: bool, - workers: int + workers: int, ) -> List[Tuple]: """ Generates parameters for processing tiles using the bounding box approach. @@ -707,22 +726,36 @@ def _generate_tile_params( # Generate tile parameters based on ranges and margins tile_params = [ ( - i, j, x_size, y_size, x_range[i], y_range[j], margin_x, margin_y, - compute_labels, r_tx, k_tx, neg_sampling_ratio_approx, val_prob, - test_prob, processed_dir, scale_boundaries, sampling_rate, - method, gpu, workers + i, + j, + x_size, + y_size, + x_range[i], + y_range[j], + margin_x, + margin_y, + compute_labels, + r_tx, + k_tx, + neg_sampling_ratio_approx, + val_prob, + test_prob, + processed_dir, + scale_boundaries, + sampling_rate, + method, + gpu, + workers, ) - for i in range(len(x_range)) + for i in range(len(x_range)) for j in range(len(y_range)) ] return tile_params - - # def _process_tiles(self, tile_params: List[Tuple], num_workers: int) -> None: # """ # Processes the tiles using Dask's parallelization utilities. 
- + # Parameters: # ----------- # tile_params : List[Tuple] @@ -741,7 +774,6 @@ def _generate_tile_params( # if self.verbose: print("Tile processing completed.") - def _process_tile(self, tile_params: Tuple) -> None: """ Process a single tile using Dask for parallelism and lazy evaluation, and save the data. @@ -751,33 +783,54 @@ def _process_tile(self, tile_params: Tuple) -> None: Parameters for the tile processing. """ ( - i, j, x_size, y_size, x_loc, y_loc, margin_x, margin_y, compute_labels, - r_tx, k_tx, neg_sampling_ratio_approx, val_prob, test_prob, processed_dir, - scale_boundaries, sampling_rate, method, gpu, workers + i, + j, + x_size, + y_size, + x_loc, + y_loc, + margin_x, + margin_y, + compute_labels, + r_tx, + k_tx, + neg_sampling_ratio_approx, + val_prob, + test_prob, + processed_dir, + scale_boundaries, + sampling_rate, + method, + gpu, + workers, ) = tile_params - if self.verbose: print(f"Processing tile at location (x_min: {x_loc}, y_min: {y_loc}), size (width: {x_size}, height: {y_size})") + if self.verbose: + print( + f"Processing tile at location (x_min: {x_loc}, y_min: {y_loc}), size (width: {x_size}, height: {y_size})" + ) # Sampling rate to decide if the tile should be processed if random.random() > sampling_rate: - if self.verbose: print(f"Skipping tile at (x_min: {x_loc}, y_min: {y_loc}) due to sampling rate.") + if self.verbose: + print(f"Skipping tile at (x_min: {x_loc}, y_min: {y_loc}) due to sampling rate.") return # Read only the required boundaries and transcripts for this tile using delayed loading boundaries_df = delayed(self.load_boundaries)( path=self.boundaries_path, - x_min=x_loc - margin_x, - x_max=x_loc + x_size + margin_x, - y_min=y_loc - margin_y, - y_max=y_loc + y_size + margin_y + x_min=x_loc - margin_x, + x_max=x_loc + x_size + margin_x, + y_min=y_loc - margin_y, + y_max=y_loc + y_size + margin_y, ).compute() - + transcripts_df = delayed(self.load_transcripts)( path=self.transcripts_path, x_min=x_loc - margin_x, - x_max=x_loc + x_size , + x_max=x_loc + x_size, y_min=y_loc - margin_y, - y_max=y_loc + y_size + y_max=y_loc + y_size, ).compute() # If no data is found in transcripts or boundaries, skip the tile @@ -788,62 +841,78 @@ def _process_tile(self, tile_params: Tuple) -> None: # If the number of transcripts is less than 20 or the number of nuclei is less than 2, skip the tile if transcripts_df_count < 20 or boundaries_df_count < 2: - if self.verbose: print(f"Dropping tile (x_min: {x_loc}, y_min: {y_loc}) due to insufficient data (transcripts: {transcripts_df_count}, boundaries: {boundaries_df_count}).") + if self.verbose: + print( + f"Dropping tile (x_min: {x_loc}, y_min: {y_loc}) due to insufficient data (transcripts: {transcripts_df_count}, boundaries: {boundaries_df_count})." 
+ ) return # Build PyG data structure from tile-specific data - if self.verbose: print(f"Building PyG data for tile at (x_min: {x_loc}, y_min: {y_loc})...") + if self.verbose: + print(f"Building PyG data for tile at (x_min: {x_loc}, y_min: {y_loc})...") data = delayed(self.build_pyg_data_from_tile)( - boundaries_df, transcripts_df, r_tx=r_tx, k_tx=k_tx, method=method, gpu=gpu, workers=workers, scale_boundaries=scale_boundaries + boundaries_df, + transcripts_df, + r_tx=r_tx, + k_tx=k_tx, + method=method, + gpu=gpu, + workers=workers, + scale_boundaries=scale_boundaries, ) - + data = data.compute() - if self.verbose: print(data) + if self.verbose: + print(data) try: # Probability to assign to train-val-test split prob = random.random() if compute_labels and (prob > test_prob): - if self.verbose: print(f"Computing labels for tile at (x_min: {x_loc}, y_min: {y_loc})...") + if self.verbose: + print(f"Computing labels for tile at (x_min: {x_loc}, y_min: {y_loc})...") transform = RandomLinkSplit( - num_val=0, num_test=0, is_undirected=True, edge_types=[('tx', 'belongs', 'bd')], + num_val=0, + num_test=0, + is_undirected=True, + edge_types=[("tx", "belongs", "bd")], neg_sampling_ratio=neg_sampling_ratio_approx * 2, ) data = delayed(transform)(data).compute()[0] - + # if self.verbose: print(data) # Save the tile data to the appropriate directory based on split - if self.verbose: print(f"Saving data for tile at (x_min: {x_loc}, y_min: {y_loc})...") + if self.verbose: + print(f"Saving data for tile at (x_min: {x_loc}, y_min: {y_loc})...") filename = f"tiles_x={x_loc}_y={y_loc}_w={x_size}_h={y_size}.pt" if prob > val_prob + test_prob: - torch.save(data, processed_dir / 'train_tiles' / 'processed' / filename) + torch.save(data, processed_dir / "train_tiles" / "processed" / filename) elif prob > test_prob: - torch.save(data, processed_dir / 'val_tiles' / 'processed' / filename) + torch.save(data, processed_dir / "val_tiles" / "processed" / filename) else: - torch.save(data, processed_dir / 'test_tiles' / 'processed' / filename) + torch.save(data, processed_dir / "test_tiles" / "processed" / filename) # Use Dask to save the file in parallel # save_task.compute() - if self.verbose: print(f"Tile at (x_min: {x_loc}, y_min: {y_loc}) processed and saved successfully.") + if self.verbose: + print(f"Tile at (x_min: {x_loc}, y_min: {y_loc}) processed and saved successfully.") except Exception as e: - if self.verbose: print(f"Error processing tile at (x_min: {x_loc}, y_min: {y_loc}): {e}") - - + if self.verbose: + print(f"Error processing tile at (x_min: {x_loc}, y_min: {y_loc}): {e}") def build_pyg_data_from_tile( - self, - boundaries_df: dd.DataFrame, - transcripts_df: dd.DataFrame, - r_tx: float = 5.0, - k_tx: int = 3, - method: str = 'kd_tree', - gpu: bool = False, + self, + boundaries_df: dd.DataFrame, + transcripts_df: dd.DataFrame, + r_tx: float = 5.0, + k_tx: int = 3, + method: str = "kd_tree", + gpu: bool = False, workers: int = 1, - scale_boundaries: float = 1.0 - + scale_boundaries: float = 1.0, ) -> HeteroData: """ Builds PyG data from a tile of boundaries and transcripts data using Dask utilities for efficient processing. @@ -857,7 +926,7 @@ def build_pyg_data_from_tile( gpu (bool, optional): Whether to use GPU acceleration for edge index computation. workers (int, optional): Number of workers to use for parallel processing. scale_boundaries (float, optional): The factor by which to scale the boundary polygons. Default is 1.0. - + Returns: HeteroData: PyG Heterogeneous Data object. 
""" @@ -865,100 +934,93 @@ def build_pyg_data_from_tile( data = HeteroData() # Lazily compute boundaries geometries using Dask - if self.verbose: print("Computing boundaries geometries...") + if self.verbose: + print("Computing boundaries geometries...") bd_gdf = self.compute_boundaries_geometries(boundaries_df, scale_factor=scale_boundaries) - bd_gdf = bd_gdf[bd_gdf['geometry'].notnull()] - + bd_gdf = bd_gdf[bd_gdf["geometry"].notnull()] + # Add boundary node data to PyG HeteroData lazily - data['bd'].id = bd_gdf[self.keys.CELL_ID.value].values - data['bd'].pos = torch.as_tensor(bd_gdf[['centroid_x', 'centroid_y']].values.astype(float)) - - if data['bd'].pos.isnan().any(): - raise ValueError(data['bd'].id[data['bd'].pos.isnan().any(1)]) - - bd_x = bd_gdf.iloc[:, 4:] - data['bd'].x = torch.as_tensor(bd_x.to_numpy(), dtype=torch.float32) + data["bd"].id = bd_gdf[self.keys.CELL_ID.value].values + data["bd"].pos = torch.as_tensor(bd_gdf[["centroid_x", "centroid_y"]].values.astype(float)) + + if data["bd"].pos.isnan().any(): + raise ValueError(data["bd"].id[data["bd"].pos.isnan().any(1)]) + bd_x = bd_gdf.iloc[:, 4:] + data["bd"].x = torch.as_tensor(bd_x.to_numpy(), dtype=torch.float32) # Extract the transcript coordinates lazily - if self.verbose: print("Preparing transcript features and positions...") + if self.verbose: + print("Preparing transcript features and positions...") x_xyz = transcripts_df[[self.keys.TRANSCRIPTS_X.value, self.keys.TRANSCRIPTS_Y.value]].to_numpy() - data['tx'].id = torch.as_tensor(transcripts_df[self.keys.TRANSCRIPTS_ID.value].values.astype(int)) - data['tx'].pos = torch.tensor(x_xyz, dtype=torch.float32) + data["tx"].id = torch.as_tensor(transcripts_df[self.keys.TRANSCRIPTS_ID.value].values.astype(int)) + data["tx"].pos = torch.tensor(x_xyz, dtype=torch.float32) - - - # Lazily prepare transcript embeddings (if available) - if self.verbose: print("Preparing transcript embeddings..") + if self.verbose: + print("Preparing transcript embeddings..") token_encoding = self.tx_encoder.transform(transcripts_df[self.keys.FEATURE_NAME.value]) - transcripts_df['token'] = token_encoding # Store the integer tokens in the 'token' column - data['tx'].token = torch.as_tensor(token_encoding).int() + transcripts_df["token"] = token_encoding # Store the integer tokens in the 'token' column + data["tx"].token = torch.as_tensor(token_encoding).int() # Handle additional embeddings lazily as well if not self.embedding_df.empty: - embeddings = delayed(lambda df: self.embedding_df.loc[ - df[self.keys.FEATURE_NAME.value].values - ].values)(transcripts_df) - else: + embeddings = delayed(lambda df: self.embedding_df.loc[df[self.keys.FEATURE_NAME.value].values].values)( + transcripts_df + ) + else: embeddings = token_encoding embeddings = embeddings.compute() x_features = torch.as_tensor(embeddings).int() - data['tx'].x = x_features + data["tx"].x = x_features # Check if the overlap column exists, if not, compute it lazily using Dask if self.keys.OVERLAPS_BOUNDARY.value not in transcripts_df.columns: - if self.verbose: print(f"Computing overlaps for transcripts...") - transcripts_df = self.compute_transcript_overlap_with_boundaries( - transcripts_df, bd_gdf, scale_factor=1.0 - ) + if self.verbose: + print(f"Computing overlaps for transcripts...") + transcripts_df = self.compute_transcript_overlap_with_boundaries(transcripts_df, bd_gdf, scale_factor=1.0) # Connect transcripts with their corresponding boundaries (e.g., nuclei, cells) - if self.verbose: print("Connecting transcripts with 
boundaries...") + if self.verbose: + print("Connecting transcripts with boundaries...") overlaps = transcripts_df[self.keys.OVERLAPS_BOUNDARY.value].values valid_cell_ids = bd_gdf[self.keys.CELL_ID.value].values - ind = np.where( - overlaps & transcripts_df[self.keys.CELL_ID.value].isin(valid_cell_ids) - )[0] - tx_bd_edge_index = np.column_stack(( - ind, - np.searchsorted( - valid_cell_ids, - transcripts_df.iloc[ind][self.keys.CELL_ID.value] - ) - )) + ind = np.where(overlaps & transcripts_df[self.keys.CELL_ID.value].isin(valid_cell_ids))[0] + tx_bd_edge_index = np.column_stack( + (ind, np.searchsorted(valid_cell_ids, transcripts_df.iloc[ind][self.keys.CELL_ID.value])) + ) # Add transcript-boundary edge index to PyG HeteroData - data['tx', 'belongs', 'bd'].edge_index = torch.as_tensor(tx_bd_edge_index.T, dtype=torch.long) + data["tx", "belongs", "bd"].edge_index = torch.as_tensor(tx_bd_edge_index.T, dtype=torch.long) # Compute transcript-to-transcript (tx-tx) edges using Dask (lazy computation) - if self.verbose: print("Computing tx-tx edges...") + if self.verbose: + print("Computing tx-tx edges...") tx_positions = transcripts_df[[self.keys.TRANSCRIPTS_X.value, self.keys.TRANSCRIPTS_Y.value]].values delayed_tx_edge_index = delayed(get_edge_index)( - tx_positions, - tx_positions, - k=k_tx, - dist=r_tx, - method=method, - gpu=gpu, - workers=workers + tx_positions, tx_positions, k=k_tx, dist=r_tx, method=method, gpu=gpu, workers=workers ) tx_edge_index = delayed_tx_edge_index.compute() # Add the tx-tx edge index to the PyG HeteroData object - data['tx', 'neighbors', 'tx'].edge_index = torch.as_tensor(tx_edge_index.T, dtype=torch.long) - - - if self.verbose: print("Finished building PyG data for the tile.") - return data - - - + data["tx", "neighbors", "tx"].edge_index = torch.as_tensor(tx_edge_index.T, dtype=torch.long) + if self.verbose: + print("Finished building PyG data for the tile.") + return data class XeniumSample(SpatialTranscriptomicsSample): - def __init__(self, transcripts_df: dd.DataFrame = None, transcripts_radius: int = 10, boundaries_graph: bool = False, embedding_df: pd.DataFrame = None, verbose: bool = True): - super().__init__(transcripts_df, transcripts_radius, boundaries_graph, embedding_df, XeniumKeys, verbose=verbose) + def __init__( + self, + transcripts_df: dd.DataFrame = None, + transcripts_radius: int = 10, + boundaries_graph: bool = False, + embedding_df: pd.DataFrame = None, + verbose: bool = True, + ): + super().__init__( + transcripts_df, transcripts_radius, boundaries_graph, embedding_df, XeniumKeys, verbose=verbose + ) def filter_transcripts(self, transcripts_df: dd.DataFrame, min_qv: float = 20.0) -> dd.DataFrame: """ @@ -977,14 +1039,14 @@ def filter_transcripts(self, transcripts_df: dd.DataFrame, min_qv: float = 20.0) "NegControlCodeword_", "BLANK_", "DeprecatedCodeword_", - "UnassignedCodeword_" + "UnassignedCodeword_", ) # Ensure FEATURE_NAME is a string type for proper filtering (compatible with Dask) # Handle potential bytes to string conversion for Dask DataFrame if pd.api.types.is_object_dtype(transcripts_df[self.keys.FEATURE_NAME.value]): transcripts_df[self.keys.FEATURE_NAME.value] = transcripts_df[self.keys.FEATURE_NAME.value].apply( - lambda x: x.decode('utf-8') if isinstance(x, bytes) else x + lambda x: x.decode("utf-8") if isinstance(x, bytes) else x ) # Apply the quality value filter using Dask @@ -1001,7 +1063,14 @@ def filter_transcripts(self, transcripts_df: dd.DataFrame, min_qv: float = 20.0) class 
MerscopeSample(SpatialTranscriptomicsSample): - def __init__(self, transcripts_df: dd.DataFrame = None, transcripts_radius: int = 10, boundaries_graph: bool = False, embedding_df: pd.DataFrame = None, verbose: bool = True): + def __init__( + self, + transcripts_df: dd.DataFrame = None, + transcripts_radius: int = 10, + boundaries_graph: bool = False, + embedding_df: pd.DataFrame = None, + verbose: bool = True, + ): super().__init__(transcripts_df, transcripts_radius, boundaries_graph, embedding_df, MerscopeKeys) def filter_transcripts(self, transcripts_df: dd.DataFrame, min_qv: float = 20.0) -> dd.DataFrame: @@ -1021,5 +1090,3 @@ def filter_transcripts(self, transcripts_df: dd.DataFrame, min_qv: float = 20.0) # Add custom Merscope-specific filtering logic if needed # For now, apply only the quality value filter return transcripts_df[transcripts_df[self.keys.QUALITY_VALUE.value] >= min_qv] - - diff --git a/src/segger/data/parquet/_experimental.py b/src/segger/data/parquet/_experimental.py index f8af0f1..739ff23 100644 --- a/src/segger/data/parquet/_experimental.py +++ b/src/segger/data/parquet/_experimental.py @@ -1,9 +1,9 @@ - from typing import TYPE_CHECKING -if TYPE_CHECKING: # False at runtime +if TYPE_CHECKING: # False at runtime import dask, cudf, dask_cudf, pandas as pd + class BackendHandler: """ A class to handle different DataFrame backends for reading and processing @@ -19,15 +19,15 @@ class BackendHandler: Methods ------- read_parquet(): - Returns the function to read Parquet files according to the selected + Returns the function to read Parquet files according to the selected backend. """ _valid_backends = { - 'pandas', - 'dask', - 'cudf', - 'dask_cudf', + "pandas", + "dask", + "cudf", + "dask_cudf", } def __init__(self, backend): @@ -35,31 +35,31 @@ def __init__(self, backend): if backend in self._valid_backends: self.backend = backend else: - valid = ', '.join(map(lambda o: f"'{o}'", self._valid_backends)) + valid = ", ".join(map(lambda o: f"'{o}'", self._valid_backends)) msg = f"Unsupported backend: {backend}. Valid options are {valid}." 
raise ValueError(msg) # Dynamically import packages only if requested - if self.backend == 'pandas': + if self.backend == "pandas": import pandas as pd - elif self.backend == 'dask': + elif self.backend == "dask": import dask - elif self.backend == 'cudf': + elif self.backend == "cudf": import cudf - elif self.backend == 'dask_cudf': + elif self.backend == "dask_cudf": import dask_cudf else: - raise ValueError('Internal Error') + raise ValueError("Internal Error") @property def read_parquet(self): - if self.backend == 'pandas': + if self.backend == "pandas": return pd.read_parquet - elif self.backend == 'dask': + elif self.backend == "dask": return dask.dataframe.read_parquet - elif self.backend == 'cudf': + elif self.backend == "cudf": return cudf.read_parquet - elif self.backend == 'dask_cudf': + elif self.backend == "dask_cudf": return dask_cudf.read_parquet else: - raise ValueError('Internal Error') \ No newline at end of file + raise ValueError("Internal Error") diff --git a/src/segger/data/parquet/_ndtree.py b/src/segger/data/parquet/_ndtree.py index cc68ef0..bad3ee5 100644 --- a/src/segger/data/parquet/_ndtree.py +++ b/src/segger/data/parquet/_ndtree.py @@ -3,10 +3,11 @@ import numpy as np import math -class NDTree(): + +class NDTree: """ - NDTree is a data structure for recursively splitting multi-dimensional data - into smaller regions until each leaf node contains less than or equal to a + NDTree is a data structure for recursively splitting multi-dimensional data + into smaller regions until each leaf node contains less than or equal to a specified number of points. It stores these regions in a balanced binary tree. @@ -19,7 +20,7 @@ class NDTree(): idx : np.ndarray The indices of the input data points. boxes : list - A list to store the bounding boxes (as shapely polygons) of each region + A list to store the bounding boxes (as shapely polygons) of each region in the tree. rect : Rectangle The bounding box of the entire input data space. @@ -46,7 +47,8 @@ def __init__(self, data, n): self.rect = Rectangle(data.min(0), data.max(0)) self.tree = innernode(self.n, self.idx, self.rect, self) -class innernode(): + +class innernode: """ Represents a node in the NDTree. Each node either stores a bounding box for the data it contains (leaf nodes) or splits the data into two child nodes. @@ -66,7 +68,7 @@ class innernode(): split_point : float The value along the split dimension used to divide the data. less : innernode - The child node containing data points less than or equal to the split + The child node containing data points less than or equal to the split point. greater : innernode The child node containing data points greater than the split point. @@ -85,10 +87,10 @@ def __init__(self, n, idx, rect, tree): else: box = shapely.box(*self.rect.mins, *self.rect.maxes) self.tree.boxes.append(box) - + def split(self): """ - Recursively splits the node's data into two child nodes along the + Recursively splits the node's data into two child nodes along the dimension with the largest spread. 
""" less = math.floor(self.n // 2) @@ -98,19 +100,6 @@ def split(self): data = data[:, self.split_dim] self.split_point = np.quantile(data, less / (less + greater)) mask = data <= self.split_point - less_rect, greater_rect = self.rect.split( - self.split_dim, - self.split_point - ) - self.less = innernode( - less, - self.idx[mask], - less_rect, - self.tree - ) - self.greater = innernode( - greater, - self.idx[~mask], - greater_rect, - self.tree - ) \ No newline at end of file + less_rect, greater_rect = self.rect.split(self.split_dim, self.split_point) + self.less = innernode(less, self.idx[mask], less_rect, self.tree) + self.greater = innernode(greater, self.idx[~mask], greater_rect, self.tree) diff --git a/src/segger/data/parquet/_settings/xenium.yaml b/src/segger/data/parquet/_settings/xenium.yaml index 7304aa7..6c5333e 100644 --- a/src/segger/data/parquet/_settings/xenium.yaml +++ b/src/segger/data/parquet/_settings/xenium.yaml @@ -13,14 +13,14 @@ transcripts: - "BLANK_" - "DeprecatedCodeword_" - "UnassignedCodeword_" - xy: + xy: - "x_location" - "y_location" - xyz: + xyz: - "x_location" - "y_location" - "z_location" - columns: + columns: - "x_location" - "y_location" - "z_location" @@ -36,10 +36,10 @@ boundaries: y: "vertex_y" id: "cell_id" label: "cell_id" - xy: + xy: - "vertex_x" - "vertex_y" - columns: + columns: - "vertex_x" - "vertex_y" - "cell_id" diff --git a/src/segger/data/parquet/_utils.py b/src/segger/data/parquet/_utils.py index 6f29cec..8c3ffec 100644 --- a/src/segger/data/parquet/_utils.py +++ b/src/segger/data/parquet/_utils.py @@ -10,6 +10,7 @@ from pathlib import Path import yaml + def get_xy_extents( filepath, x: str, @@ -50,6 +51,7 @@ def get_xy_extents( bounds = shapely.box(x_min, y_min, x_max, y_max) return bounds + def read_parquet_region( filepath, x: str, @@ -89,14 +91,17 @@ def read_parquet_region( # Find bounds of full file if not supplied if bounds is None: bounds = get_xy_bounds(filepath, x, y) - + # Load pre-filtered data from Parquet file - filters = [[ - (x, '>', bounds.bounds[0]), - (y, '>', bounds.bounds[1]), - (x, '<', bounds.bounds[2]), - (y, '<', bounds.bounds[3]), - ] + extra_filters] + filters = [ + [ + (x, ">", bounds.bounds[0]), + (y, ">", bounds.bounds[1]), + (x, "<", bounds.bounds[2]), + (y, "<", bounds.bounds[3]), + ] + + extra_filters + ] columns = list({x, y} | set(extra_columns)) @@ -107,6 +112,7 @@ def read_parquet_region( ) return region + def get_polygons_from_xy( boundaries: pd.DataFrame, x: str, @@ -114,13 +120,13 @@ def get_polygons_from_xy( label: str, ) -> gpd.GeoSeries: """ - Convert boundary coordinates from a cuDF DataFrame to a GeoSeries of + Convert boundary coordinates from a cuDF DataFrame to a GeoSeries of polygons. Parameters ---------- boundaries : pd.DataFrame - A DataFrame containing the boundary data with x and y coordinates + A DataFrame containing the boundary data with x and y coordinates and identifiers. x : str The name of the column representing the x-coordinate. @@ -133,7 +139,7 @@ def get_polygons_from_xy( Returns ------- gpd.GeoSeries - A GeoSeries containing the polygons created from the boundary + A GeoSeries containing the polygons created from the boundary coordinates. 
""" # Polygon offsets in coords @@ -152,6 +158,7 @@ def get_polygons_from_xy( return gs + def filter_boundaries( boundaries: pd.DataFrame, inset: shapely.Polygon, @@ -161,13 +168,13 @@ def filter_boundaries( label: str, ): """ - Filter boundary polygons based on their overlap with specified inset and + Filter boundary polygons based on their overlap with specified inset and outset regions. Parameters ---------- boundaries : cudf.DataFrame - A DataFrame containing the boundary data with x and y coordinates and + A DataFrame containing the boundary data with x and y coordinates and identifiers. inset : shapely.Polygon A polygon representing the inner region to filter the boundaries. @@ -187,43 +194,46 @@ def filter_boundaries( Notes ----- - The function determines overlaps of boundary polygons with the specified - inset and outset regions. It creates boolean masks for overlaps with the - top, left, right, and bottom sides of the outset region, as well as the - center region defined by the inset polygon. The filtering logic includes + The function determines overlaps of boundary polygons with the specified + inset and outset regions. It creates boolean masks for overlaps with the + top, left, right, and bottom sides of the outset region, as well as the + center region defined by the inset polygon. The filtering logic includes polygons that: - Are completely within the center region. - Overlap with the center and the left side, but not the bottom side. - Overlap with the center and the top side, but not the right side. """ + # Determine overlaps of boundary polygons def in_region(region): in_x = boundaries[x].between(region.bounds[0], region.bounds[2]) in_y = boundaries[y].between(region.bounds[1], region.bounds[3]) return in_x & in_y + x1, y1, x4, y4 = outset.bounds x2, y2, x3, y3 = inset.bounds - boundaries['top'] = in_region(shapely.box(x1, y1, x4, y2)) - boundaries['left'] = in_region(shapely.box(x1, y1, x2, y4)) - boundaries['right'] = in_region(shapely.box(x3, y1, x4, y4)) - boundaries['bottom'] = in_region(shapely.box(x1, y3, x4, y4)) - boundaries['center'] = in_region(inset) + boundaries["top"] = in_region(shapely.box(x1, y1, x4, y2)) + boundaries["left"] = in_region(shapely.box(x1, y1, x2, y4)) + boundaries["right"] = in_region(shapely.box(x3, y1, x4, y4)) + boundaries["bottom"] = in_region(shapely.box(x1, y3, x4, y4)) + boundaries["center"] = in_region(inset) # Filter boundary polygons # Include overlaps with top and left, not bottom and right gb = boundaries.groupby(label, sort=False) - total = gb['center'].transform('size') - in_top = gb['top'].transform('sum') - in_left = gb['left'].transform('sum') - in_right = gb['right'].transform('sum') - in_bottom = gb['bottom'].transform('sum') - in_center = gb['center'].transform('sum') + total = gb["center"].transform("size") + in_top = gb["top"].transform("sum") + in_left = gb["left"].transform("sum") + in_right = gb["right"].transform("sum") + in_bottom = gb["bottom"].transform("sum") + in_center = gb["center"].transform("sum") keep = in_center == total - keep |= ((in_center > 0) & (in_left > 0) & (in_bottom == 0)) - keep |= ((in_center > 0) & (in_top > 0) & (in_right == 0)) + keep |= (in_center > 0) & (in_left > 0) & (in_bottom == 0) + keep |= (in_center > 0) & (in_top > 0) & (in_right == 0) inset_boundaries = boundaries.loc[keep] return inset_boundaries + def filter_transcripts( transcripts_df: pd.DataFrame, label: Optional[str] = None, @@ -256,9 +266,10 @@ def filter_transcripts( mask &= transcripts_df["qv"].ge(min_qv) return 
transcripts_df[mask] + def load_settings(sample_type: str) -> SimpleNamespace: """ - Loads a matching YAML file from the _settings/ directory and converts its + Loads a matching YAML file from the _settings/ directory and converts its contents into a SimpleNamespace. Parameters @@ -276,25 +287,23 @@ def load_settings(sample_type: str) -> SimpleNamespace: ValueError If `sample_type` does not match any filenames. """ - settings_dir = Path(__file__).parent.resolve() / '_settings' + settings_dir = Path(__file__).parent.resolve() / "_settings" # Get a list of YAML filenames (without extensions) in the _settings dir - filenames = [file.stem for file in settings_dir.glob('*.yaml')] + filenames = [file.stem for file in settings_dir.glob("*.yaml")] # Convert sample_type to lowercase and check if it matches any filename sample_type = sample_type.lower() if sample_type not in filenames: - msg = ( - f"Sample type '{sample_type}' not found in settings. " - f"Available options: {', '.join(filenames)}" - ) + msg = f"Sample type '{sample_type}' not found in settings. " f"Available options: {', '.join(filenames)}" raise FileNotFoundError(msg) # Load the matching YAML file yaml_file_path = settings_dir / f"{sample_type}.yaml" - with yaml_file_path.open('r') as file: + with yaml_file_path.open("r") as file: data = yaml.safe_load(file) - + # Convert the YAML data into a SimpleNamespace recursively return _dict_to_namespace(data) + def _dict_to_namespace(d): """ Recursively converts a dictionary to a SimpleNamespace. @@ -302,4 +311,4 @@ def _dict_to_namespace(d): if isinstance(d, dict): d = {k: _dict_to_namespace(v) for k, v in d.items()} return SimpleNamespace(**d) - return d \ No newline at end of file + return d diff --git a/src/segger/data/parquet/pyg_dataset.py b/src/segger/data/parquet/pyg_dataset.py index 5599cb3..d64b9e6 100644 --- a/src/segger/data/parquet/pyg_dataset.py +++ b/src/segger/data/parquet/pyg_dataset.py @@ -5,17 +5,19 @@ from pathlib import Path import torch + class STPyGDataset(InMemoryDataset): """ - An in-memory dataset class for handling training using spatial + An in-memory dataset class for handling training using spatial transcriptomics data. """ + def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, - pre_filter: Optional[Callable] = None + pre_filter: Optional[Callable] = None, ): super().__init__(root, transform, pre_transform, pre_filter) @@ -37,7 +39,7 @@ def processed_file_names(self) -> List[str]: Returns: List[str]: List of processed file names. 
""" - paths = glob.glob(f'{self.processed_dir}/tiles_x*_y*_*_*.pt') + paths = glob.glob(f"{self.processed_dir}/tiles_x*_y*_*_*.pt") # paths = paths.append(paths = glob.glob(f'{self.processed_dir}/tiles_x*_y*_*_*.pt')) file_names = list(map(os.path.basename, paths)) return file_names @@ -63,13 +65,13 @@ def get(self, idx: int) -> Data: """ filepath = Path(self.processed_dir) / self.processed_file_names[idx] data = torch.load(filepath) - data['tx'].x = data['tx'].x.to_dense() - if data['tx'].x.dim() == 1: - data['tx'].x = data['tx'].x.unsqueeze(1) - assert data['tx'].x.dim() == 2 + data["tx"].x = data["tx"].x.to_dense() + if data["tx"].x.dim() == 1: + data["tx"].x = data["tx"].x.unsqueeze(1) + assert data["tx"].x.dim() == 2 # this is an issue in PyG's RandomLinkSplit, dimensions are not consistent if there is only one edge in the graph - if data['tx', 'belongs', 'bd'].edge_label_index.dim() == 1: - data['tx', 'belongs', 'bd'].edge_label_index = data['tx', 'belongs', 'bd'].edge_label_index.unsqueeze(1) - data['tx', 'belongs', 'bd'].edge_label = data['tx', 'belongs', 'bd'].edge_label.unsqueeze(0) - assert data['tx', 'belongs', 'bd'].edge_label_index.dim() == 2 + if data["tx", "belongs", "bd"].edge_label_index.dim() == 1: + data["tx", "belongs", "bd"].edge_label_index = data["tx", "belongs", "bd"].edge_label_index.unsqueeze(1) + data["tx", "belongs", "bd"].edge_label = data["tx", "belongs", "bd"].edge_label.unsqueeze(0) + assert data["tx", "belongs", "bd"].edge_label_index.dim() == 2 return data diff --git a/src/segger/data/parquet/sample.py b/src/segger/data/parquet/sample.py index 5e21366..6937a72 100644 --- a/src/segger/data/parquet/sample.py +++ b/src/segger/data/parquet/sample.py @@ -21,12 +21,12 @@ # TODO: Add documentation for settings -class STSampleParquet(): +class STSampleParquet: """ A class to manage spatial transcriptomics data stored in parquet files. - This class provides methods for loading, processing, and saving data related - to ST samples. It supports parallel processing and efficient handling of + This class provides methods for loading, processing, and saving data related + to ST samples. It supports parallel processing and efficient handling of transcript and boundary data. """ @@ -51,7 +51,7 @@ def __init__( Raises ------ FileNotFoundError - If the base directory does not exist or the required files are + If the base directory does not exist or the required files are missing. """ # Setup paths and resource constraints @@ -65,18 +65,17 @@ def __init__( # Setup logging logging.basicConfig(level=logging.INFO) - self.logger = logging.Logger(f'STSample@{base_dir}') + self.logger = logging.Logger(f"STSample@{base_dir}") # Internal caches self._extents = None self._transcripts_metadata = None self._boundaries_metadata = None - # Setup default embedding for transcripts - classes = self.transcripts_metadata['feature_names'] + # Setup default embedding for transcripts + classes = self.transcripts_metadata["feature_names"] self._transcript_embedding = TranscriptEmbedding(np.array(classes)) - @classmethod def _get_parquet_metadata( cls, @@ -91,7 +90,7 @@ def _get_parquet_metadata( filepath : os.PathLike The path to the parquet file. columns : Optional[List[str]], default None - List of columns to extract metadata for. If None, all columns + List of columns to extract metadata for. If None, all columns are used. 
Returns @@ -109,13 +108,13 @@ def _get_parquet_metadata( """ # Size in bytes of field dtypes size_map = { - 'BOOLEAN': 1, - 'INT32': 4, - 'FLOAT': 4, - 'INT64': 8, - 'DOUBLE': 8, - 'BYTE_ARRAY': 8, - 'INT96': 12, + "BOOLEAN": 1, + "INT32": 4, + "FLOAT": 4, + "INT64": 8, + "DOUBLE": 8, + "BYTE_ARRAY": 8, + "INT96": 12, } # Read in metadata @@ -129,21 +128,20 @@ def _get_parquet_metadata( # Grab important fields from metadata summary = dict() - summary['n_rows'] = metadata.num_rows - summary['n_columns'] = len(columns) - summary['column_sizes'] = dict() + summary["n_rows"] = metadata.num_rows + summary["n_columns"] = len(columns) + summary["column_sizes"] = dict() for c in columns: # Error where 10X saved BOOLEAN field as INT32 in schema - if c == 'overlaps_nucleus': - dtype = 'BOOLEAN' + if c == "overlaps_nucleus": + dtype = "BOOLEAN" else: i = metadata.schema.names.index(c) dtype = metadata.schema[i].physical_type - summary['column_sizes'][c] = size_map[dtype] + summary["column_sizes"][c] = size_map[dtype] return summary - @cached_property def transcripts_metadata(self) -> dict: """ @@ -152,7 +150,7 @@ def transcripts_metadata(self) -> dict: Returns ------- dict - Metadata dictionary for transcripts including column sizes and + Metadata dictionary for transcripts including column sizes and feature names. Raises @@ -169,13 +167,12 @@ def transcripts_metadata(self) -> dict: # Get filtered unique feature names table = pq.read_table(self._transcripts_filepath) names = pc.unique(table[self.settings.transcripts.label]) - pattern = '|'.join(self.settings.transcripts.filter_substrings) + pattern = "|".join(self.settings.transcripts.filter_substrings) mask = pc.invert(pc.match_substring_regex(names, pattern)) - metadata['feature_names'] = pc.filter(names, mask).tolist() + metadata["feature_names"] = pc.filter(names, mask).tolist() self._transcripts_metadata = metadata return self._transcripts_metadata - @cached_property def boundaries_metadata(self) -> dict: """ @@ -199,7 +196,6 @@ def boundaries_metadata(self) -> dict: self._boundaries_metadata = metadata return self._boundaries_metadata - @property def n_transcripts(self) -> int: """ @@ -210,8 +206,7 @@ def n_transcripts(self) -> int: int The number of transcripts. """ - return self.transcripts_metadata['n_rows'] - + return self.transcripts_metadata["n_rows"] @cached_property def extents(self) -> shapely.Polygon: @@ -236,7 +231,6 @@ def extents(self) -> shapely.Polygon: return self._extents - def _get_balanced_regions( self, ) -> List[shapely.Polygon]: @@ -252,10 +246,10 @@ def _get_balanced_regions( # If no. workers is 1, return full extents if self.n_workers == 1: return [self.extents] - + # Otherwise, split based on boundary distribution which is much smaller # than transcripts DataFrame. - # Note: Assumes boundaries are distributed similarly to transcripts at + # Note: Assumes boundaries are distributed similarly to transcripts at # a coarse level. data = pd.read_parquet( self._boundaries_filepath, @@ -265,7 +259,6 @@ def _get_balanced_regions( return ndtree.boxes - @staticmethod def _setup_directory( data_dir: os.PathLike, @@ -273,8 +266,8 @@ def _setup_directory( """ Sets up the directory structure for saving processed tiles. - Ensures that the necessary subdirectories for 'train', 'test', and - 'val' are created under the provided base directory. If any of these + Ensures that the necessary subdirectories for 'train', 'test', and + 'val' are created under the provided base directory. 
If any of these subdirectories already exist and are not empty, an error is raised. Directory structure created: @@ -298,15 +291,14 @@ def _setup_directory( If any of the 'processed' directories already contain files. """ data_dir = Path(data_dir) # by default, convert to Path object - for tile_type in ['train_tiles', 'test_tiles', 'val_tiles']: - for stage in ['raw', 'processed']: + for tile_type in ["train_tiles", "test_tiles", "val_tiles"]: + for stage in ["raw", "processed"]: tile_dir = data_dir / tile_type / stage tile_dir.mkdir(parents=True, exist_ok=True) if os.listdir(tile_dir): msg = f"Directory '{tile_dir}' must be empty." raise AssertionError(msg) - def set_transcript_embedding(self, weights: pd.DataFrame): """ Sets the transcript embedding for the sample. @@ -319,33 +311,32 @@ def set_transcript_embedding(self, weights: pd.DataFrame): Raises ------ ValueError - If the provided weights do not match the number of transcript + If the provided weights do not match the number of transcript features. """ - classes = self._transcripts_metadata['feature_names'] + classes = self._transcripts_metadata["feature_names"] self._transcript_embedding = TranscriptEmbedding(classes, weights) - def save( self, data_dir: os.PathLike, k_bd: int = 3, - dist_bd: float = 15., + dist_bd: float = 15.0, k_tx: int = 3, - dist_tx: float = 5., + dist_tx: float = 5.0, tile_size: Optional[int] = None, tile_width: Optional[float] = None, tile_height: Optional[float] = None, - neg_sampling_ratio: float = 5., - frac: float = 1., + neg_sampling_ratio: float = 5.0, + frac: float = 1.0, val_prob: float = 0.1, test_prob: float = 0.2, ): """ - Saves the tiles of the sample as PyTorch geometric datasets. See + Saves the tiles of the sample as PyTorch geometric datasets. See documentation for 'STTile' for more information on dataset contents. - Note: This function requires either 'tile_size' OR both 'tile_width' and + Note: This function requires either 'tile_size' OR both 'tile_width' and 'tile_height' to be provided. Parameters @@ -361,7 +352,7 @@ def save( dist_tx : float, optional, default 5.0 Maximum distance for transcript neighbors. tile_size : int, optional - If provided, specifies the size of the tile. Overrides `tile_width` + If provided, specifies the size of the tile. Overrides `tile_width` and `tile_height`. tile_width : int, optional Width of the tiles in pixels. Ignored if `tile_size` is provided. @@ -379,7 +370,7 @@ def save( Raises ------ ValueError - If the 'frac' parameter is greater than 1.0 or if the calculated + If the 'frac' parameter is greater than 1.0 or if the calculated number of tiles is zero. AssertionError If the specified directory structure is not properly set up. 
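For orientation, a minimal sketch of the dataset-creation call described by the `save()` docstring above; the constructor arguments beyond the base directory and all paths are placeholders rather than values taken from this patch:

```python
# Hypothetical end-to-end call; only save()'s keyword arguments come from the
# signature above, everything else is a placeholder.
from segger.data.parquet.sample import STSampleParquet

sample = STSampleParquet(base_dir="data/xenium_sample")  # assumed constructor usage
sample.save(
    data_dir="data/segger_tiles",     # train_tiles/, test_tiles/, val_tiles/ are created here
    k_bd=3, dist_bd=15.0,             # boundary-neighbor graph parameters
    k_tx=3, dist_tx=5.0,              # transcript-neighbor graph parameters
    tile_width=200, tile_height=200,  # alternatively pass tile_size=...
    neg_sampling_ratio=5.0,
    val_prob=0.1, test_prob=0.2,
)
```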
@@ -412,7 +403,7 @@ def func(region): for tile in tiles: # Choose training, test, or validation datasets data_type = np.random.choice( - a=['train_tiles', 'test_tiles', 'val_tiles'], + a=["train_tiles", "test_tiles", "val_tiles"], p=[1 - (test_prob + val_prob), test_prob, val_prob], ) xt = STTile(dataset=xm, extents=tile) @@ -425,9 +416,9 @@ def func(region): ) if pyg_data is not None: if pyg_data["tx", "belongs", "bd"].edge_index.numel() == 0: - # this tile is only for testing - data_type = 'test_tiles' - filepath = data_dir / data_type / 'processed' / f'{xt.uid}.pt' + # this tile is only for testing + data_type = "test_tiles" + filepath = data_dir / data_type / "processed" / f"{xt.uid}.pt" torch.save(pyg_data, filepath) # TODO: Add Dask backend @@ -436,12 +427,12 @@ def func(region): # TODO: Add documentation for settings -class STInMemoryDataset(): +class STInMemoryDataset: """ A class for handling in-memory representations of ST data. This class is used to load and manage ST sample data from parquet files, - filter boundaries and transcripts, and provide spatial tiling for further + filter boundaries and transcripts, and provide spatial tiling for further analysis. The class also pre-loads KDTrees for efficient spatial queries. Parameters @@ -467,7 +458,7 @@ class STInMemoryDataset(): The filtered boundaries within the dataset extents. kdtree_tx : KDTree The KDTree for fast spatial queries on the transcripts. - + Raises ------ ValueError @@ -482,7 +473,7 @@ def __init__( ): """ Initializes the STInMemoryDataset instance by loading transcripts - and boundaries from parquet files and pre-loading a KDTree for fast + and boundaries from parquet files and pre-loading a KDTree for fast spatial queries. Parameters @@ -505,11 +496,7 @@ def __init__( self._load_boundaries(self.sample._boundaries_filepath) # Pre-load KDTrees - self.kdtree_tx = KDTree( - self.transcripts[self.settings.transcripts.xy], - leafsize=100 - ) - + self.kdtree_tx = KDTree(self.transcripts[self.settings.transcripts.xy], leafsize=100) def _load_transcripts(self, path: os.PathLike, min_qv: float = 30.0): """ @@ -528,7 +515,7 @@ def _load_transcripts(self, path: os.PathLike, min_qv: float = 30.0): If the transcripts dataframe cannot be loaded or filtered. """ # Load and filter transcripts dataframe - bounds = self.extents.buffer(self.margin, join_style='mitre') + bounds = self.extents.buffer(self.margin, join_style="mitre") transcripts = utils.read_parquet_region( path, x=self.settings.transcripts.x, @@ -542,11 +529,10 @@ def _load_transcripts(self, path: os.PathLike, min_qv: float = 30.0): self.settings.transcripts.filter_substrings, min_qv, ) - + # Only set object properties once everything finishes successfully self.transcripts = transcripts - def _load_boundaries(self, path: os.PathLike): """ Loads and filters the boundaries dataframe for the dataset. @@ -562,7 +548,7 @@ def _load_boundaries(self, path: os.PathLike): If the boundaries dataframe cannot be loaded or filtered. 
""" # Load and filter boundaries dataframe - outset = self.extents.buffer(self.margin, join_style='mitre') + outset = self.extents.buffer(self.margin, join_style="mitre") boundaries = utils.read_parquet_region( path, x=self.settings.boundaries.x, @@ -580,7 +566,6 @@ def _load_boundaries(self, path: os.PathLike): ) self.boundaries = boundaries - def _get_rectangular_tile_bounds( self, tile_width: float, @@ -607,7 +592,7 @@ def _get_rectangular_tile_bounds( x_coords = np.append(x_coords, x_max) y_coords = np.arange(y_min, y_max, tile_height) y_coords = np.append(y_coords, y_max) - + # Generate tiles from grid points tiles = [] for x_min, x_max in zip(x_coords[:-1], x_coords[1:]): @@ -616,7 +601,6 @@ def _get_rectangular_tile_bounds( return tiles - def _get_balanced_tile_bounds( self, max_size: Optional[int], @@ -657,14 +641,14 @@ def recurse(node, bounds): bounds = Rectangle(self.kdtree_tx.mins, self.kdtree_tx.maxes) return recurse(node, bounds) - - def _tile(self, + def _tile( + self, width: Optional[float] = None, height: Optional[float] = None, max_size: Optional[int] = None, - ) -> List[shapely.Polygon]: + ) -> List[shapely.Polygon]: """ - Generates tiles based on either fixed dimensions or balanced + Generates tiles based on either fixed dimensions or balanced partitioning. Parameters @@ -674,7 +658,7 @@ def _tile(self, height : Optional[float] The height of each tile. Required if `max_size` is not provided. max_size : Optional[int] - The maximum number of points in each tile. Required if `width` and + The maximum number of points in each tile. Required if `width` and `height` are not provided. Returns @@ -685,7 +669,7 @@ def _tile(self, Raises ------ ValueError - If both `width`/`height` and `max_size` are provided or none are + If both `width`/`height` and `max_size` are provided or none are provided. """ # Square tiling kwargs provided @@ -697,11 +681,8 @@ def _tile(self, # Bad set of kwargs else: args = list(compress(locals().keys(), locals().values())) - args.remove('self') - msg = ( - "Function requires either 'max_size' or both " - f"'width' and 'height'. Found: {', '.join(args)}." - ) + args.remove("self") + msg = "Function requires either 'max_size' or both " f"'width' and 'height'. Found: {', '.join(args)}." logging.error(msg) raise ValueError @@ -740,9 +721,9 @@ def __init__( Notes ----- - The `boundaries` and `transcripts` attributes are cached to avoid the - overhead of filtering when tiles are instantiated. This is particularly - useful in multiprocessing settings where generating tiles in parallel + The `boundaries` and `transcripts` attributes are cached to avoid the + overhead of filtering when tiles are instantiated. This is particularly + useful in multiprocessing settings where generating tiles in parallel could lead to high overhead. Internal Attributes @@ -761,22 +742,21 @@ def __init__( self._boundaries = None self._transcripts = None - @property def uid(self) -> str: """ - Generates a unique identifier for the tile based on its extents. This - UID is particularly useful for saving or indexing tiles in distributed + Generates a unique identifier for the tile based on its extents. This + UID is particularly useful for saving or indexing tiles in distributed processing environments. The UID is constructed using the minimum and maximum x and y coordinates - of the tile's bounding box, representing its position and size in the + of the tile's bounding box, representing its position and size in the sample. 
Returns ------- str - A unique identifier string in the format + A unique identifier string in the format 'x=_y=_w=_h=' where: - ``: Minimum x-coordinate of the tile's extents. - ``: Minimum y-coordinate of the tile's extents. @@ -790,52 +770,49 @@ def uid(self) -> str: 'x=100_y=200_w=50_h=50' """ x_min, y_min, x_max, y_max = map(int, self.extents.bounds) - uid = f'tiles_x={x_min}_y={y_min}_w={x_max-x_min}_h={y_max-y_min}' + uid = f"tiles_x={x_min}_y={y_min}_w={x_max-x_min}_h={y_max-y_min}" return uid - @cached_property def boundaries(self) -> pd.DataFrame: """ Returns the filtered boundaries within the tile extents, cached for efficiency. - The boundaries are computed only once and cached. If the boundaries - have not been computed yet, they are computed using + The boundaries are computed only once and cached. If the boundaries + have not been computed yet, they are computed using `get_filtered_boundaries()`. Returns ------- pd.DataFrame - A DataFrame containing the filtered boundaries within the tile + A DataFrame containing the filtered boundaries within the tile extents. """ if self._boundaries is None: self._boundaries = self.get_filtered_boundaries() return self._boundaries - @cached_property def transcripts(self) -> pd.DataFrame: """ Returns the filtered transcripts within the tile extents, cached for efficiency. - The transcripts are computed only once and cached. If the transcripts - have not been computed yet, they are computed using + The transcripts are computed only once and cached. If the transcripts + have not been computed yet, they are computed using `get_filtered_transcripts()`. Returns ------- pd.DataFrame - A DataFrame containing the filtered transcripts within the tile + A DataFrame containing the filtered transcripts within the tile extents. """ if self._transcripts is None: self._transcripts = self.get_filtered_transcripts() return self._transcripts - def get_filtered_boundaries(self) -> pd.DataFrame: """ Filters the boundaries in the sample to include only those within @@ -844,20 +821,19 @@ def get_filtered_boundaries(self) -> pd.DataFrame: Returns ------- pd.DataFrame - A DataFrame containing the filtered boundaries within the tile + A DataFrame containing the filtered boundaries within the tile extents. """ filtered_boundaries = utils.filter_boundaries( boundaries=self.dataset.boundaries, inset=self.extents, - outset=self.extents.buffer(self.margin, join_style='mitre'), + outset=self.extents.buffer(self.margin, join_style="mitre"), x=self.settings.boundaries.x, y=self.settings.boundaries.y, label=self.settings.boundaries.label, ) return filtered_boundaries - def get_filtered_transcripts(self) -> pd.DataFrame: """ Filters the transcripts in the sample to include only those within @@ -866,13 +842,13 @@ def get_filtered_transcripts(self) -> pd.DataFrame: Returns ------- pd.DataFrame - A DataFrame containing the filtered transcripts within the tile + A DataFrame containing the filtered transcripts within the tile extents. """ # Buffer tile bounds to include transcripts around boundary - outset = self.extents.buffer(self.margin, join_style='mitre') - xmin, ymin, xmax, ymax = outset.bounds + outset = self.extents.buffer(self.margin, join_style="mitre") + xmin, ymin, xmax, ymax = outset.bounds # Get transcripts inside buffered region x, y = self.settings.transcripts.xy @@ -882,7 +858,6 @@ def get_filtered_transcripts(self) -> pd.DataFrame: return filtered_transcripts - def get_transcript_props(self) -> torch.Tensor: """ Encodes transcript features in a sparse format. 
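A tiny standalone illustration of the `uid` construction in the hunk above; note that the implementation prefixes the identifier with `tiles_`, which the abbreviated format string in the docstring omits:

```python
# Reproduces the f-string from STTile.uid on a toy bounding box.
import shapely

extents = shapely.box(100, 200, 150, 250)
x_min, y_min, x_max, y_max = map(int, extents.bounds)
uid = f"tiles_x={x_min}_y={y_min}_w={x_max - x_min}_h={y_max - y_min}"
# -> "tiles_x=100_y=200_w=50_h=50"
```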
@@ -894,9 +869,9 @@ def get_transcript_props(self) -> torch.Tensor: Notes ----- - The intention is for this function to simplify testing new strategies + The intention is for this function to simplify testing new strategies for 'tx' node representations. For example, the encoder can be any type - of encoder that transforms the transcript labels into a numerical + of encoder that transforms the transcript labels into a numerical matrix (in sparse format). """ # Encode transcript features in sparse format @@ -906,7 +881,6 @@ def get_transcript_props(self) -> torch.Tensor: return props - @staticmethod def get_polygon_props( polygons: gpd.GeoSeries, @@ -938,18 +912,17 @@ def get_polygon_props( """ props = pd.DataFrame(index=polygons.index, dtype=float) if area: - props['area'] = polygons.area + props["area"] = polygons.area if convexity: - props['convexity'] = polygons.convex_hull.area / polygons.area + props["convexity"] = polygons.convex_hull.area / polygons.area if elongation: rects = polygons.minimum_rotated_rectangle() - props['elongation'] = rects.area / polygons.envelope.area + props["elongation"] = rects.area / polygons.envelope.area if circularity: r = polygons.minimum_bounding_radius() - props["circularity"] = polygons.area / r ** 2 - - return props + props["circularity"] = polygons.area / r**2 + return props @staticmethod def get_kdtree_edge_index( @@ -993,7 +966,6 @@ def get_kdtree_edge_index( return edge_index - def get_boundary_props( self, area: bool = True, @@ -1007,29 +979,29 @@ def get_boundary_props( Parameters ---------- area : bool, optional - If True, compute the area of each boundary polygon (default is + If True, compute the area of each boundary polygon (default is True). convexity : bool, optional - If True, compute the convexity of each boundary polygon (default is + If True, compute the convexity of each boundary polygon (default is True). elongation : bool, optional If True, compute the elongation of each boundary polygon (default is True). circularity : bool, optional - If True, compute the circularity of each boundary polygon (default + If True, compute the circularity of each boundary polygon (default is True). Returns ------- torch.Tensor - A tensor containing the computed properties for each boundary + A tensor containing the computed properties for each boundary polygon. Notes ----- - The intention is for this function to simplify testing new strategies + The intention is for this function to simplify testing new strategies for 'bd' node representations. You can just change the function body to - return another torch.Tensor without worrying about changes to the rest + return another torch.Tensor without worrying about changes to the rest of the code. """ # Get polygons from coordinates @@ -1045,10 +1017,9 @@ def get_boundary_props( return props - def to_pyg_dataset( self, - #train: bool, + # train: bool, neg_sampling_ratio: float = 5, k_bd: int = 3, dist_bd: float = 15, @@ -1066,7 +1037,7 @@ def to_pyg_dataset( Parameters ---------- train: bool - Whether a sample is part of the training dataset. If True, add + Whether a sample is part of the training dataset. If True, add negative edges to dataset. k_bd : int, optional The number of nearest neighbors for the 'bd' nodes (default is 4). @@ -1142,7 +1113,7 @@ def to_pyg_dataset( Edge indices in COO format between transcripts and boundaries 3. ("tx", "neighbors", "tx") - Represents the relationship where a transcript is nearby another + Represents the relationship where a transcript is nearby another transcript. 
Attributes @@ -1154,15 +1125,15 @@ def to_pyg_dataset( pyg_data = HeteroData() # Set up Transcript nodes - pyg_data['tx'].id = torch.tensor( + pyg_data["tx"].id = torch.tensor( self.transcripts[self.settings.transcripts.id].values.astype(int), dtype=torch.int, ) - pyg_data['tx'].pos = torch.tensor( + pyg_data["tx"].pos = torch.tensor( self.transcripts[self.settings.transcripts.xyz].values, dtype=torch.float32, ) - pyg_data['tx'].x = self.get_transcript_props() + pyg_data["tx"].x = self.get_transcript_props() # Set up Transcript-Transcript neighbor edges nbrs_edge_idx = self.get_kdtree_edge_index( @@ -1187,11 +1158,9 @@ def to_pyg_dataset( self.settings.boundaries.label, ) centroids = polygons.centroid.get_coordinates() - pyg_data['bd'].id = polygons.index.to_numpy() - pyg_data['bd'].pos = torch.tensor(centroids.values, dtype=torch.float32) - pyg_data['bd'].x = self.get_boundary_props( - area, convexity, elongation, circularity - ) + pyg_data["bd"].id = polygons.index.to_numpy() + pyg_data["bd"].pos = torch.tensor(centroids.values, dtype=torch.float32) + pyg_data["bd"].x = self.get_boundary_props(area, convexity, elongation, circularity) # Set up Boundary-Transcript neighbor edges dist = np.sqrt(polygons.area.max()) * 10 # heuristic distance @@ -1208,16 +1177,14 @@ def to_pyg_dataset( logging.warning(f"No tx-neighbors-bd edges found in tile {self.uid}.") pyg_data["tx", "belongs", "bd"].edge_index = torch.tensor([], dtype=torch.long) return pyg_data - + # Now we identify and split the tx-belongs-bd edges - edge_type = ('tx', 'belongs', 'bd') + edge_type = ("tx", "belongs", "bd") # Find nuclear transcripts tx_cell_ids = self.transcripts[self.settings.boundaries.id] cell_ids_map = {idx: i for (i, idx) in enumerate(polygons.index)} - is_nuclear = self.transcripts[ - self.settings.transcripts.nuclear - ].astype(bool) + is_nuclear = self.transcripts[self.settings.transcripts.nuclear].astype(bool) is_nuclear &= tx_cell_ids.isin(polygons.index) # Set up overlap edges @@ -1242,11 +1209,10 @@ def to_pyg_dataset( ) pyg_data, _, _ = transform(pyg_data) - # Refilter negative edges to include only transcripts in the + # Refilter negative edges to include only transcripts in the # original positive edges (still need a memory-efficient solution) edges = pyg_data[edge_type] - mask = edges.edge_label_index[0].unsqueeze(1) == \ - edges.edge_index[0].unsqueeze(0) + mask = edges.edge_label_index[0].unsqueeze(1) == edges.edge_index[0].unsqueeze(0) mask = torch.nonzero(torch.any(mask, 1)).squeeze() edges.edge_label_index = edges.edge_label_index[:, mask] edges.edge_label = edges.edge_label[mask] diff --git a/src/segger/data/parquet/transcript_embedding.py b/src/segger/data/parquet/transcript_embedding.py index 2f8085c..8abeebc 100644 --- a/src/segger/data/parquet/transcript_embedding.py +++ b/src/segger/data/parquet/transcript_embedding.py @@ -6,14 +6,15 @@ from numpy.typing import ArrayLike import pandas as pd + # TODO: Add documentation class TranscriptEmbedding(torch.nn.Module): - ''' + """ Utility class to handle transcript embeddings in PyTorch so that they are optionally learnable in the future. - + Default behavior is to use the index of gene names. - ''' + """ # TODO: Add documentation @staticmethod @@ -23,26 +24,17 @@ def _check_inputs( ): # Classes is a 1D array if len(classes.shape) > 1: - msg = ( - "'classes' should be a 1D array, got an array of shape " - f"{classes.shape} instead." - ) + msg = "'classes' should be a 1D array, got an array of shape " f"{classes.shape} instead." 
raise ValueError(msg) # Items appear exactly once if len(classes) != len(set(classes)): - msg = ( - "All embedding classes must be unique. One or more items in " - "'classes' appears twice." - ) + msg = "All embedding classes must be unique. One or more items in " "'classes' appears twice." raise ValueError(msg) # All classes have an entry in weights elif weights is not None: missing = set(classes).difference(weights.index) if len(missing) > 0: - msg = ( - f"Index of 'weights' DataFrame is missing {len(missing)} " - "entries compared to classes." - ) + msg = f"Index of 'weights' DataFrame is missing {len(missing)} " "entries compared to classes." raise ValueError(msg) # TODO: Add documentation @@ -66,6 +58,6 @@ def embed(self, classes: ArrayLike): indices = LongTensor(self._encoder.transform(classes)) # Default, one-hot encoding if self._weights is None: - return indices #F.one_hot(indices, len(self._encoder.classes_)) + return indices # F.one_hot(indices, len(self._encoder.classes_)) else: return F.embedding(indices, self._weights) diff --git a/src/segger/data/utils.py b/src/segger/data/utils.py index 3abd5b1..b673a87 100644 --- a/src/segger/data/utils.py +++ b/src/segger/data/utils.py @@ -5,6 +5,7 @@ def try_import(module_name): except ImportError: print(f"Warning: {module_name} is not installed. Please install it to use this functionality.") + # Standard imports import pandas as pd import numpy as np @@ -20,6 +21,7 @@ def try_import(module_name): from torch_geometric.nn import radius_graph import os from scipy.spatial import cKDTree + # import hnswlib from shapely.geometry import Polygon from shapely.affinity import scale @@ -28,10 +30,10 @@ def try_import(module_name): import sys # Attempt to import specific modules with try_import function -try_import('multiprocessing') -try_import('joblib') -try_import('faiss') -try_import('cuvs') +try_import("multiprocessing") +try_import("joblib") +try_import("faiss") +try_import("cuvs") try: import cupy as cp from cuvs.neighbors import cagra @@ -42,8 +44,6 @@ def try_import(module_name): from datetime import timedelta - - def filter_transcripts( transcripts_df: pd.DataFrame, min_qv: float = 20.0, @@ -64,7 +64,7 @@ def filter_transcripts( "NegControlCodeword_", "BLANK_", "DeprecatedCodeword_", - "UnassignedCodeword_" + "UnassignedCodeword_", ) mask = transcripts_df["qv"].ge(min_qv) mask &= ~transcripts_df["feature_name"].str.startswith(filter_codewords) @@ -72,9 +72,7 @@ def filter_transcripts( def compute_transcript_metrics( - df: pd.DataFrame, - qv_threshold: float = 30, - cell_id_col: str = 'cell_id' + df: pd.DataFrame, qv_threshold: float = 30, cell_id_col: str = "cell_id" ) -> Dict[str, Any]: """ Computes various metrics for a given dataframe of transcript data filtered by quality value threshold. @@ -92,44 +90,48 @@ def compute_transcript_metrics( - 'percent_non_assigned_cytoplasmic' (float): The percentage of non-assigned cytoplasmic transcripts. - 'gene_metrics' (pd.DataFrame): A dataframe containing gene-level metrics. 
""" - df_filtered = df[df['qv'] > qv_threshold] + df_filtered = df[df["qv"] > qv_threshold] total_transcripts = len(df_filtered) assigned_transcripts = df_filtered[df_filtered[cell_id_col] != -1] - percent_assigned = len(assigned_transcripts) / (total_transcripts+1) * 100 - cytoplasmic_transcripts = assigned_transcripts[assigned_transcripts['overlaps_nucleus'] != 1] - percent_cytoplasmic = len(cytoplasmic_transcripts) / (len(assigned_transcripts) + 1)* 100 + percent_assigned = len(assigned_transcripts) / (total_transcripts + 1) * 100 + cytoplasmic_transcripts = assigned_transcripts[assigned_transcripts["overlaps_nucleus"] != 1] + percent_cytoplasmic = len(cytoplasmic_transcripts) / (len(assigned_transcripts) + 1) * 100 percent_nucleus = 100 - percent_cytoplasmic non_assigned_transcripts = df_filtered[df_filtered[cell_id_col] == -1] - non_assigned_cytoplasmic = non_assigned_transcripts[non_assigned_transcripts['overlaps_nucleus'] != 1] - percent_non_assigned_cytoplasmic = len(non_assigned_cytoplasmic) / (len(non_assigned_transcripts)+1) * 100 - gene_group_assigned = assigned_transcripts.groupby('feature_name') - gene_group_all = df_filtered.groupby('feature_name') - gene_percent_assigned = (gene_group_assigned.size() / (gene_group_all.size()+1) * 100).reset_index(names='percent_assigned') - cytoplasmic_gene_group = cytoplasmic_transcripts.groupby('feature_name') - gene_percent_cytoplasmic = (cytoplasmic_gene_group.size() / (len(cytoplasmic_transcripts)+1) * 100).reset_index(name='percent_cytoplasmic') - gene_metrics = pd.merge(gene_percent_assigned, gene_percent_cytoplasmic, on='feature_name', how='outer').fillna(0) + non_assigned_cytoplasmic = non_assigned_transcripts[non_assigned_transcripts["overlaps_nucleus"] != 1] + percent_non_assigned_cytoplasmic = len(non_assigned_cytoplasmic) / (len(non_assigned_transcripts) + 1) * 100 + gene_group_assigned = assigned_transcripts.groupby("feature_name") + gene_group_all = df_filtered.groupby("feature_name") + gene_percent_assigned = (gene_group_assigned.size() / (gene_group_all.size() + 1) * 100).reset_index( + names="percent_assigned" + ) + cytoplasmic_gene_group = cytoplasmic_transcripts.groupby("feature_name") + gene_percent_cytoplasmic = (cytoplasmic_gene_group.size() / (len(cytoplasmic_transcripts) + 1) * 100).reset_index( + name="percent_cytoplasmic" + ) + gene_metrics = pd.merge(gene_percent_assigned, gene_percent_cytoplasmic, on="feature_name", how="outer").fillna(0) results = { - 'percent_assigned': percent_assigned, - 'percent_cytoplasmic': percent_cytoplasmic, - 'percent_nucleus': percent_nucleus, - 'percent_non_assigned_cytoplasmic': percent_non_assigned_cytoplasmic, - 'gene_metrics': gene_metrics + "percent_assigned": percent_assigned, + "percent_cytoplasmic": percent_cytoplasmic, + "percent_nucleus": percent_nucleus, + "percent_non_assigned_cytoplasmic": percent_non_assigned_cytoplasmic, + "gene_metrics": gene_metrics, } return results def create_anndata( - df: pd.DataFrame, - panel_df: Optional[pd.DataFrame] = None, - min_transcripts: int = 5, - cell_id_col: str = 'cell_id', - qv_threshold: float = 30, - min_cell_area: float = 10.0, - max_cell_area: float = 1000.0 + df: pd.DataFrame, + panel_df: Optional[pd.DataFrame] = None, + min_transcripts: int = 5, + cell_id_col: str = "cell_id", + qv_threshold: float = 30, + min_cell_area: float = 10.0, + max_cell_area: float = 1000.0, ) -> ad.AnnData: """ Generates an AnnData object from a dataframe of segmented transcriptomics data. 
- + Parameters: df (pd.DataFrame): The dataframe containing segmented transcriptomics data. panel_df (Optional[pd.DataFrame]): The dataframe containing panel information. @@ -138,24 +140,23 @@ def create_anndata( qv_threshold (float): The quality value threshold for filtering transcripts. min_cell_area (float): The minimum cell area to include a cell. max_cell_area (float): The maximum cell area to include a cell. - + Returns: ad.AnnData: The generated AnnData object containing the transcriptomics data and metadata. """ # df_filtered = filter_transcripts(df, min_qv=qv_threshold) df_filtered = df # metrics = compute_transcript_metrics(df_filtered, qv_threshold, cell_id_col) - df_filtered = df_filtered[df_filtered[cell_id_col].astype(str) != '-1'] - pivot_df = df_filtered.rename(columns={ - cell_id_col: "cell", - "feature_name": "gene" - })[['cell', 'gene']].pivot_table(index='cell', columns='gene', aggfunc='size', fill_value=0) + df_filtered = df_filtered[df_filtered[cell_id_col].astype(str) != "-1"] + pivot_df = df_filtered.rename(columns={cell_id_col: "cell", "feature_name": "gene"})[["cell", "gene"]].pivot_table( + index="cell", columns="gene", aggfunc="size", fill_value=0 + ) pivot_df = pivot_df[pivot_df.sum(axis=1) >= min_transcripts] cell_summary = [] for cell_id, cell_data in df_filtered.groupby(cell_id_col): if len(cell_data) < min_transcripts: continue - cell_convex_hull = ConvexHull(cell_data[['x_location', 'y_location']], qhull_options='QJ') + cell_convex_hull = ConvexHull(cell_data[["x_location", "y_location"]], qhull_options="QJ") cell_area = cell_convex_hull.area if cell_area < min_cell_area or cell_area > max_cell_area: continue @@ -167,47 +168,50 @@ def create_anndata( # nucleus_convex_hull = ConvexHull(nucleus_data[['x_location', 'y_location']]) # else: # nucleus_convex_hull = None - cell_summary.append({ - "cell": cell_id, - "cell_centroid_x": cell_data['x_location'].mean(), - "cell_centroid_y": cell_data['y_location'].mean(), - "cell_area": cell_area, - # "nucleus_centroid_x": nucleus_data['x_location'].mean() if len(nucleus_data) > 0 else cell_data['x_location'].mean(), - # "nucleus_centroid_y": nucleus_data['x_location'].mean() if len(nucleus_data) > 0 else cell_data['x_location'].mean(), - # "nucleus_area": nucleus_convex_hull.area if nucleus_convex_hull else 0, - # "percent_cytoplasmic": len(cell_data[cell_data['overlaps_nucleus'] != 1]) / len(cell_data) * 100, - # "has_nucleus": len(nucleus_data) > 0 - }) + cell_summary.append( + { + "cell": cell_id, + "cell_centroid_x": cell_data["x_location"].mean(), + "cell_centroid_y": cell_data["y_location"].mean(), + "cell_area": cell_area, + # "nucleus_centroid_x": nucleus_data['x_location'].mean() if len(nucleus_data) > 0 else cell_data['x_location'].mean(), + # "nucleus_centroid_y": nucleus_data['x_location'].mean() if len(nucleus_data) > 0 else cell_data['x_location'].mean(), + # "nucleus_area": nucleus_convex_hull.area if nucleus_convex_hull else 0, + # "percent_cytoplasmic": len(cell_data[cell_data['overlaps_nucleus'] != 1]) / len(cell_data) * 100, + # "has_nucleus": len(nucleus_data) > 0 + } + ) cell_summary = pd.DataFrame(cell_summary).set_index("cell") if panel_df is not None: - panel_df = panel_df.sort_values('gene') - genes = panel_df['gene'].values + panel_df = panel_df.sort_values("gene") + genes = panel_df["gene"].values for gene in genes: if gene not in pivot_df: pivot_df[gene] = 0 pivot_df = pivot_df[genes.tolist()] if panel_df is None: - var_df = pd.DataFrame([{ - "gene": i, - "feature_types": 'Gene 
Expression', - 'genome': 'Unknown' - } for i in np.unique(pivot_df.columns.values)]).set_index('gene') + var_df = pd.DataFrame( + [ + {"gene": i, "feature_types": "Gene Expression", "genome": "Unknown"} + for i in np.unique(pivot_df.columns.values) + ] + ).set_index("gene") else: - var_df = panel_df[['gene', 'ensembl']].rename(columns={'ensembl':'gene_ids'}) - var_df['feature_types'] = 'Gene Expression' - var_df['genome'] = 'Unknown' - var_df = var_df.set_index('gene') + var_df = panel_df[["gene", "ensembl"]].rename(columns={"ensembl": "gene_ids"}) + var_df["feature_types"] = "Gene Expression" + var_df["genome"] = "Unknown" + var_df = var_df.set_index("gene") # gene_metrics = metrics['gene_metrics'].set_index('feature_name') # var_df = var_df.join(gene_metrics, how='left').fillna(0) cells = list(set(pivot_df.index) & set(cell_summary.index)) - pivot_df = pivot_df.loc[cells,:] - cell_summary = cell_summary.loc[cells,:] + pivot_df = pivot_df.loc[cells, :] + cell_summary = cell_summary.loc[cells, :] adata = ad.AnnData(pivot_df.values) adata.var = var_df - adata.obs['transcripts'] = pivot_df.sum(axis=1).values - adata.obs['unique_transcripts'] = (pivot_df > 0).sum(axis=1).values + adata.obs["transcripts"] = pivot_df.sum(axis=1).values + adata.obs["unique_transcripts"] = (pivot_df > 0).sum(axis=1).values adata.obs_names = pivot_df.index.values.tolist() - adata.obs = pd.merge(adata.obs, cell_summary.loc[adata.obs_names,:], left_index=True, right_index=True) + adata.obs = pd.merge(adata.obs, cell_summary.loc[adata.obs_names, :], left_index=True, right_index=True) # adata.uns['metrics'] = { # 'percent_assigned': metrics['percent_assigned'], # 'percent_cytoplasmic': metrics['percent_cytoplasmic'], @@ -216,10 +220,9 @@ def create_anndata( # } return adata - def calculate_gene_celltype_abundance_embedding(adata: ad.AnnData, celltype_column: str) -> pd.DataFrame: - """Calculate the cell type abundance embedding for each gene based on the percentage of cells in each cell type + """Calculate the cell type abundance embedding for each gene based on the percentage of cells in each cell type that express the gene (non-zero expression). Parameters: @@ -227,9 +230,9 @@ def calculate_gene_celltype_abundance_embedding(adata: ad.AnnData, celltype_colu celltype_column (str): The column name in `adata.obs` that contains the cell type information. Returns: - pd.DataFrame: A DataFrame where rows are genes and columns are cell types, with each value representing + pd.DataFrame: A DataFrame where rows are genes and columns are cell types, with each value representing the percentage of cells in that cell type expressing the gene. - + Example: >>> adata = AnnData(...) 
# Load your scRNA-seq AnnData object >>> celltype_column = 'celltype_major' @@ -255,13 +258,21 @@ def calculate_gene_celltype_abundance_embedding(adata: ad.AnnData, celltype_colu abundance = gene_expression_df[cell_type_mask].mean(axis=0) * 100 cell_type_abundance_list.append(abundance) # Create a DataFrame for the cell type abundance with gene names as rows and cell types as columns - cell_type_abundance_df = pd.DataFrame(cell_type_abundance_list, - columns=adata.var_names, - index=encoder.categories_[0]).T + cell_type_abundance_df = pd.DataFrame( + cell_type_abundance_list, columns=adata.var_names, index=encoder.categories_[0] + ).T return cell_type_abundance_df -def get_edge_index(coords_1: np.ndarray, coords_2: np.ndarray, k: int = 5, dist: int = 10, method: str = 'kd_tree', - gpu: bool = False, workers: int = 1) -> torch.Tensor: + +def get_edge_index( + coords_1: np.ndarray, + coords_2: np.ndarray, + k: int = 5, + dist: int = 10, + method: str = "kd_tree", + gpu: bool = False, + workers: int = 1, +) -> torch.Tensor: """ Computes edge indices using various methods (KD-Tree, FAISS, RAPIDS::cuvs+cupy (cuda)). @@ -276,23 +287,21 @@ def get_edge_index(coords_1: np.ndarray, coords_2: np.ndarray, k: int = 5, dist: Returns: torch.Tensor: Edge indices. """ - if method == 'kd_tree': + if method == "kd_tree": return get_edge_index_kdtree(coords_1, coords_2, k=k, dist=dist, workers=workers) - elif method == 'faiss': + elif method == "faiss": return get_edge_index_faiss(coords_1, coords_2, k=k, dist=dist, gpu=gpu) - elif method == 'cuda': + elif method == "cuda": # pass return get_edge_index_cuda(coords_1, coords_2, k=k, dist=dist) else: - msg = ( - f"Unknown method {method}. Valid methods include: 'kd_tree', " - "'faiss', and 'cuda'." - ) + msg = f"Unknown method {method}. Valid methods include: 'kd_tree', " "'faiss', and 'cuda'." raise ValueError() - -def get_edge_index_kdtree(coords_1: np.ndarray, coords_2: np.ndarray, k: int = 5, dist: int = 10, workers: int = 1) -> torch.Tensor: +def get_edge_index_kdtree( + coords_1: np.ndarray, coords_2: np.ndarray, k: int = 5, dist: int = 10, workers: int = 1 +) -> torch.Tensor: """ Computes edge indices using KDTree. @@ -313,15 +322,15 @@ def get_edge_index_kdtree(coords_1: np.ndarray, coords_2: np.ndarray, k: int = 5 for idx, valid in enumerate(valid_mask): valid_indices = idx_out[idx][valid] if valid_indices.size > 0: - edges.append( - np.vstack((np.full(valid_indices.shape, idx), valid_indices)).T - ) + edges.append(np.vstack((np.full(valid_indices.shape, idx), valid_indices)).T) edge_index = torch.tensor(np.vstack(edges), dtype=torch.long).contiguous() return edge_index -def get_edge_index_faiss(coords_1: np.ndarray, coords_2: np.ndarray, k: int = 5, dist: int = 10, gpu: bool = False) -> torch.Tensor: +def get_edge_index_faiss( + coords_1: np.ndarray, coords_2: np.ndarray, k: int = 5, dist: int = 10, gpu: bool = False +) -> torch.Tensor: """ Computes edge indices using FAISS. 
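A hedged usage sketch of the `get_edge_index` dispatcher defined above; the coordinate arrays and their sizes are made up for illustration:

```python
# The KD-tree backend links each point in coords_2 to up to k points in
# coords_1 that lie within `dist` units, returning the index pairs as a
# torch.LongTensor (see get_edge_index_kdtree above).
import numpy as np
from segger.data.utils import get_edge_index

coords_bd = np.random.rand(100, 2) * 100    # e.g. boundary centroids (placeholder)
coords_tx = np.random.rand(1000, 2) * 100   # e.g. transcript locations (placeholder)

edges = get_edge_index(coords_bd, coords_tx, k=5, dist=10, method="kd_tree")
```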
@@ -344,30 +353,28 @@ def get_edge_index_faiss(coords_1: np.ndarray, coords_2: np.ndarray, k: int = 5, else: index = faiss.IndexFlatL2(d) - index.add(coords_1.astype('float32')) - D, I = index.search(coords_2.astype('float32'), k) + index.add(coords_1.astype("float32")) + D, I = index.search(coords_2.astype("float32"), k) - valid_mask = D < dist ** 2 + valid_mask = D < dist**2 edges = [] for idx, valid in enumerate(valid_mask): valid_indices = I[idx][valid] if valid_indices.size > 0: - edges.append( - np.vstack((np.full(valid_indices.shape, idx), valid_indices)).T - ) + edges.append(np.vstack((np.full(valid_indices.shape, idx), valid_indices)).T) edge_index = torch.tensor(np.vstack(edges), dtype=torch.long).contiguous() return edge_index def get_edge_index_cuda( - coords_1: torch.Tensor, - coords_2: torch.Tensor, - k: int = 10, + coords_1: torch.Tensor, + coords_2: torch.Tensor, + k: int = 10, dist: float = 10.0, metric: str = "sqeuclidean", - nn_descent_niter: int = 100 + nn_descent_niter: int = 100, ) -> torch.Tensor: """ Computes edge indices using RAPIDS cuVS with cagra for vector similarity search, @@ -382,11 +389,14 @@ def get_edge_index_cuda( Returns: torch.Tensor: Edge indices as a PyTorch tensor on CUDA. """ + def cupy_to_torch(cupy_array): return torch.from_dlpack((cupy_array.toDlpack())) + # gg def torch_to_cupy(tensor): return cp.fromDlpack(dlpack.to_dlpack(tensor)) + # Convert PyTorch tensors (CUDA) to CuPy arrays using DLPack cp_coords_1 = torch_to_cupy(coords_1).astype(cp.float32) cp_coords_2 = torch_to_cupy(coords_2).astype(cp.float32) @@ -394,14 +404,16 @@ def torch_to_cupy(tensor): cp_dist = cp.float32(dist) # IndexParams and SearchParams for cagra # compression_params = cagra.CompressionParams(pq_bits=pq_bits) - index_params = cagra.IndexParams(metric=metric,nn_descent_niter=nn_descent_niter) #, compression=compression_params) + index_params = cagra.IndexParams( + metric=metric, nn_descent_niter=nn_descent_niter + ) # , compression=compression_params) search_params = cagra.SearchParams() # Build index using CuPy coords index = cagra.build_index(index_params, cp_coords_1) # Perform search to get distances and indices (still in CuPy) D, I = cagra.search(search_params, index, cp_coords_2, k) # Boolean mask for filtering distances below the squared threshold (all in CuPy) - valid_mask = cp.asarray(D < cp_dist ** 2) + valid_mask = cp.asarray(D < cp_dist**2) # Vectorized operations for row and valid indices (all in CuPy) repeats = valid_mask.sum(axis=1).tolist() row_indices = cp.repeat(cp.arange(len(cp_coords_2)), repeats) @@ -412,6 +424,7 @@ def torch_to_cupy(tensor): edge_index = cupy_to_torch(edges).long().contiguous() return edge_index + class SpatialTranscriptomicsDataset(InMemoryDataset): """A dataset class for handling SpatialTranscriptomics spatial transcriptomics data. @@ -421,7 +434,10 @@ class SpatialTranscriptomicsDataset(InMemoryDataset): pre_transform (callable): A function/transform that takes in a Data object and returns a transformed version. pre_filter (callable): A function that takes in a Data object and returns a boolean indicating whether to keep it. """ - def __init__(self, root: str, transform: Callable = None, pre_transform: Callable = None, pre_filter: Callable = None): + + def __init__( + self, root: str, transform: Callable = None, pre_transform: Callable = None, pre_filter: Callable = None + ): """Initialize the SpatialTranscriptomicsDataset. 
Args: @@ -448,16 +464,14 @@ def processed_file_names(self) -> List[str]: Returns: List[str]: List of processed file names. """ - return [x for x in os.listdir(self.processed_dir) if 'tiles' in x] + return [x for x in os.listdir(self.processed_dir) if "tiles" in x] def download(self) -> None: - """Download the raw data. This method should be overridden if you need to download the data. - """ + """Download the raw data. This method should be overridden if you need to download the data.""" pass def process(self) -> None: - """Process the raw data and save it to the processed directory. This method should be overridden if you need to process the data. - """ + """Process the raw data and save it to the processed directory. This method should be overridden if you need to process the data.""" pass def len(self) -> int: @@ -478,7 +492,7 @@ def get(self, idx: int) -> Data: Data: The processed data object. """ data = torch.load(os.path.join(self.processed_dir, self.processed_file_names[idx])) - data['tx'].x = data['tx'].x.to_dense() + data["tx"].x = data["tx"].x.to_dense() return data @@ -531,8 +545,7 @@ def coo_to_dense_adj( # Check COO format if not edge_index.shape[0] == 2: msg = ( - "Edge index is not in COO format. First dimension should have " - f"size 2, but found {edge_index.shape[0]}." + "Edge index is not in COO format. First dimension should have " f"size 2, but found {edge_index.shape[0]}." ) raise ValueError(msg) @@ -547,39 +560,23 @@ def coo_to_dense_adj( # Fill matrix with neighbors nbr_idx = torch.full((num_nodes, num_nbrs), -1) for i, nbrs in zip(uniques, torch.split(edge_index[1], counts)): - nbr_idx[i, :len(nbrs)] = nbrs + nbr_idx[i, : len(nbrs)] = nbrs return nbr_idx - - - def format_time(elapsed: float) -> str: """ Format elapsed time to h:m:s. - + Parameters: ---------- elapsed : float Elapsed time in seconds. - + Returns: ------- str Formatted time in h:m:s. """ return str(timedelta(seconds=int(elapsed))) - - - - - - - - - - - - - diff --git a/src/segger/models/README.md b/src/segger/models/README.md index 1f872b3..033e545 100644 --- a/src/segger/models/README.md +++ b/src/segger/models/README.md @@ -1,4 +1,3 @@ - # segger: Graph Neural Network Model The `segger` model is a graph neural network designed to handle heterogeneous graphs with two primary node types: **transcripts** and **nuclei or cell boundaries**. It leverages attention-based convolutional layers to compute node embeddings and relationships in spatial transcriptomics data. The architecture includes an initial embedding layer for node feature transformation, multiple graph attention layers (GATv2Conv), and residual linear connections. 
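As a complement to the README introduction above, a hedged sketch of the heterogeneous input the model consumes, mirroring the `tx`/`bd` node types and edge types built in `sample.py` earlier in this patch; the feature dimensions and edge counts are placeholders:

```python
# Toy HeteroData with the node/edge types used by segger; sizes are arbitrary.
import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data["tx"].x = torch.randn(1000, 16)   # transcript features (placeholder dimension)
data["bd"].x = torch.randn(50, 4)      # boundary features, e.g. area/convexity/elongation/circularity
data["tx", "neighbors", "tx"].edge_index = torch.randint(0, 1000, (2, 4000))
data["tx", "belongs", "bd"].edge_index = torch.stack(
    [torch.arange(200), torch.randint(0, 50, (200,))]
)
```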
@@ -32,7 +31,8 @@ The `segger` model is a graph neural network designed to handle heterogeneous gr $$ where: - - \( \alpha_{ij} \) is the attention coefficient between node \( i \) and node \( j \), computed as: + + - \( \alpha\_{ij} \) is the attention coefficient between node \( i \) and node \( j \), computed as: $$ \alpha_{ij} = \frac{\exp\left( \text{LeakyReLU}\left( \mathbf{a}^{\top} [\mathbf{W}^{(l)} \mathbf{h}_{i}^{(l)} || \mathbf{W}^{(l)} \mathbf{h}_{j}^{(l)}] \right)\right)}{\sum_{k \in \mathcal{N}(i)} \exp\left( \text{LeakyReLU}\left( \mathbf{a}^{\top} [\mathbf{W}^{(l)} \mathbf{h}_{i}^{(l)} || \mathbf{W}^{(l)} \mathbf{h}_{k}^{(l)}] \right)\right)} @@ -47,7 +47,7 @@ The `segger` model is a graph neural network designed to handle heterogeneous gr \mathbf{h}_{i}^{(l+1)} = \text{ReLU}\left( \mathbf{h}_{i}^{(l+1)} + \mathbf{W}_{res} \mathbf{h}_{i}^{(l)} \right) $$ - where \( \mathbf{W}_{res} \) is a residual weight matrix. + where \( \mathbf{W}\_{res} \) is a residual weight matrix. 4. **L2 Normalization**: Finally, the embeddings are normalized using L2 normalization: @@ -62,23 +62,21 @@ The `segger` model is a graph neural network designed to handle heterogeneous gr In the next step, the `segger` model is transformed into a **heterogeneous graph neural network** using PyTorch Geometric's `to_hetero` function. This transformation enables the model to handle distinct node and edge types (transcripts and nuclei or cell boundaries) with separate mechanisms for modeling their relationships. - ## Usage To instantiate and run the segger model: ```python model = segger( - num_tx_tokens=5000, # Number of unique 'tx' tokens - init_emb=32, # Initial embedding dimension - hidden_channels=64, # Number of hidden channels - num_mid_layers=2, # Number of middle layers - out_channels=128, # Number of output channels - heads=4 # Number of attention heads + num_tx_tokens=5000, # Number of unique 'tx' tokens + init_emb=32, # Initial embedding dimension + hidden_channels=64, # Number of hidden channels + num_mid_layers=2, # Number of middle layers + out_channels=128, # Number of output channels + heads=4, # Number of attention heads ) output = model(x, edge_index) ``` Once transformed to a heterogeneous model and trained using PyTorch Lightning, the model can efficiently learn relationships between transcripts and nuclei or cell boundaries. - diff --git a/src/segger/models/__init__.py b/src/segger/models/__init__.py index 1271af3..0a66407 100644 --- a/src/segger/models/__init__.py +++ b/src/segger/models/__init__.py @@ -4,8 +4,6 @@ Contains the implementation of the Segger model using Graph Neural Networks. """ -__all__ = [ - "Segger" - ] +__all__ = ["Segger"] from .segger_model import * diff --git a/src/segger/models/segger_model.py b/src/segger/models/segger_model.py index d2e13ad..6943dab 100644 --- a/src/segger/models/segger_model.py +++ b/src/segger/models/segger_model.py @@ -3,10 +3,20 @@ from torch.nn import Embedding from torch import Tensor from typing import Union -#from torch_sparse import SparseTensor + +# from torch_sparse import SparseTensor + class Segger(torch.nn.Module): - def __init__(self, num_tx_tokens: int, init_emb: int = 16, hidden_channels: int = 32, num_mid_layers: int = 3, out_channels: int = 32, heads: int = 3): + def __init__( + self, + num_tx_tokens: int, + init_emb: int = 16, + hidden_channels: int = 32, + num_mid_layers: int = 3, + out_channels: int = 32, + heads: int = 3, + ): """ Initializes the Segger model. 
@@ -54,27 +64,26 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: Returns: Tensor: Output node embeddings. """ - x = torch.nan_to_num(x, nan = 0) + x = torch.nan_to_num(x, nan=0) is_one_dim = (x.ndim == 1) * 1 - # x = x[:, None] - x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim) + # x = x[:, None] + x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim) # First layer x = x.relu() - x = self.conv_first(x, edge_index) # + self.lin_first(x) + x = self.conv_first(x, edge_index) # + self.lin_first(x) x = x.relu() # Middle layers if self.num_mid_layers > 0: - for conv_mid in self.conv_mid_layers: - x = conv_mid(x, edge_index) # + lin_mid(x) + for conv_mid in self.conv_mid_layers: + x = conv_mid(x, edge_index) # + lin_mid(x) x = x.relu() # Last layer - x = self.conv_last(x, edge_index) # + self.lin_last(x) + x = self.conv_last(x, edge_index) # + self.lin_last(x) return x - def decode(self, z: Tensor, edge_index: Union[Tensor]) -> Tensor: """ Decode the node embeddings to predict edge values. diff --git a/src/segger/prediction/__init__.py b/src/segger/prediction/__init__.py index abc96d9..f82a9cc 100644 --- a/src/segger/prediction/__init__.py +++ b/src/segger/prediction/__init__.py @@ -4,9 +4,6 @@ Contains prediction scripts and utilities for the Segger model. """ -__all__ = [ - "load_model", - "predict" - ] +__all__ = ["load_model", "predict"] from .predict import load_model, predict diff --git a/src/segger/prediction/predict.py b/src/segger/prediction/predict.py index cf73116..337a3c1 100644 --- a/src/segger/prediction/predict.py +++ b/src/segger/prediction/predict.py @@ -40,8 +40,8 @@ from cupyx.scipy.sparse import find # To find non-zero elements in sparse matrix from scipy.sparse.csgraph import connected_components as cc from scipy.sparse import coo_matrix as scipy_coo_matrix -# Setup Dask cluster with 3 workers +# Setup Dask cluster with 3 workers # CONFIG @@ -57,7 +57,7 @@ def load_model(checkpoint_path: str) -> LitSegger: Parameters ---------- checkpoint_path : str - Specific checkpoint file to load, or directory where the model checkpoints are stored. + Specific checkpoint file to load, or directory where the model checkpoints are stored. If directory, the latest checkpoint is loaded. 
Returns @@ -75,13 +75,15 @@ def load_model(checkpoint_path: str) -> LitSegger: # Get last checkpoint if directory is provided if os.path.isdir(checkpoint_path): - checkpoints = glob.glob(str(checkpoint_path / '*.ckpt')) + checkpoints = glob.glob(str(checkpoint_path / "*.ckpt")) if len(checkpoints) == 0: raise FileNotFoundError(msg) + # Sort checkpoints by epoch and step def sort_order(c): - match = re.match(r'.*epoch=(\d+)-step=(\d+).ckpt', c) + match = re.match(r".*epoch=(\d+)-step=(\d+).ckpt", c) return int(match[1]), int(match[2]) + checkpoint_path = Path(sorted(checkpoints, key=sort_order)[-1]) elif not checkpoint_path.exists(): raise FileExistsError(msg) @@ -94,16 +96,11 @@ def sort_order(c): return lit_segger - def get_similarity_scores( - model: torch.nn.Module, - batch: Batch, - from_type: str, - to_type: str, - receptive_field: dict + model: torch.nn.Module, batch: Batch, from_type: str, to_type: str, receptive_field: dict ) -> coo_matrix: """ - Compute similarity scores between embeddings for 'from_type' and 'to_type' nodes + Compute similarity scores between embeddings for 'from_type' and 'to_type' nodes using sparse matrix multiplication with CuPy and the 'sees' edge relation. Args: @@ -113,7 +110,7 @@ def get_similarity_scores( to_type (str): The type of node to which the similarity is computed. Returns: - coo_matrix: A sparse matrix containing the similarity scores between + coo_matrix: A sparse matrix containing the similarity scores between 'from_type' and 'to_type' nodes. """ # Step 1: Get embeddings from the model @@ -122,21 +119,21 @@ def get_similarity_scores( edge_index = get_edge_index( batch[to_type].pos[:, :2], # 'tx' positions batch[from_type].pos[:, :2], # 'bd' positions - k=receptive_field[f'k_{to_type}'], - dist=receptive_field[f'dist_{to_type}'], - method='cuda' + k=receptive_field[f"k_{to_type}"], + dist=receptive_field[f"dist_{to_type}"], + method="cuda", ) edge_index = coo_to_dense_adj( - edge_index.T, - num_nodes=shape[0], - num_nbrs=receptive_field[f'k_{to_type}'], + edge_index.T, + num_nodes=shape[0], + num_nbrs=receptive_field[f"k_{to_type}"], ) - + with torch.no_grad(): embeddings = model(batch.x_dict, batch.edge_index_dict) del batch - + # print(edge_index) # print(embeddings) @@ -144,19 +141,19 @@ def sparse_multiply(embeddings, edge_index, shape) -> coo_matrix: m = torch.nn.ZeroPad2d((0, 0, 0, 1)) # pad bottom with zeros similarity = torch.bmm( - m(embeddings[to_type])[edge_index], # 'to' x 'from' neighbors x embed - embeddings[from_type].unsqueeze(-1) # 'to' x embed x 1 - ) # -> 'to' x 'from' neighbors x 1 + m(embeddings[to_type])[edge_index], # 'to' x 'from' neighbors x embed + embeddings[from_type].unsqueeze(-1), # 'to' x embed x 1 + ) # -> 'to' x 'from' neighbors x 1 del embeddings # Sigmoid to get most similar 'to_type' neighbor similarity[similarity == 0] = -torch.inf # ensure zero stays zero similarity = F.sigmoid(similarity) # Neighbor-filtered similarity scores # shape = batch[from_type].x.shape[0], batch[to_type].x.shape[0] - indices = torch.argwhere(edge_index != -1).T + indices = torch.argwhere(edge_index != -1).T indices[1] = edge_index[edge_index != -1] - rows = cp.fromDlpack(to_dlpack(indices[0,:].to('cuda'))) - columns = cp.fromDlpack(to_dlpack(indices[1,:].to('cuda'))) + rows = cp.fromDlpack(to_dlpack(indices[0, :].to("cuda"))) + columns = cp.fromDlpack(to_dlpack(indices[1, :].to("cuda"))) # print(rows) del indices values = similarity[edge_index != -1].flatten() @@ -164,7 +161,6 @@ def sparse_multiply(embeddings, edge_index, shape) 
-> coo_matrix: return sparse_result # Free GPU memory after computation - # Call the sparse multiply function sparse_similarity = sparse_multiply(embeddings, edge_index, shape) gc.collect() @@ -175,38 +171,37 @@ def sparse_multiply(embeddings, edge_index, shape) -> coo_matrix: return sparse_similarity - - def predict_batch( lit_segger: torch.nn.Module, batch: Batch, score_cut: float, receptive_field: Dict[str, float], use_cc: bool = True, - knn_method: str = 'cuda' + knn_method: str = "cuda", ) -> pd.DataFrame: """ Predict cell assignments for a batch of transcript data using a segmentation model. - Adds a 'bound' column to indicate if the transcript is assigned to a cell (bound=1) + Adds a 'bound' column to indicate if the transcript is assigned to a cell (bound=1) or unassigned (bound=0). Args: lit_segger (torch.nn.Module): The lightning module wrapping the segmentation model. batch (Batch): A batch of transcript and cell data. score_cut (float): The threshold for assigning transcripts to cells based on similarity scores. - receptive_field (Dict[str, float]): Dictionary defining the receptive field for transcript-cell + receptive_field (Dict[str, float]): Dictionary defining the receptive field for transcript-cell and transcript-transcript relations. - use_cc (bool, optional): If True, perform connected components analysis for unassigned transcripts. + use_cc (bool, optional): If True, perform connected components analysis for unassigned transcripts. Defaults to True. knn_method (str, optional): The method to use for nearest neighbors. Defaults to 'cuda'. Returns: - pd.DataFrame: A DataFrame containing the transcript IDs, similarity scores, + pd.DataFrame: A DataFrame containing the transcript IDs, similarity scores, assigned cell IDs, and 'bound' column. 
""" + def _get_id(): """Generate a random Xenium-style ID.""" - return ''.join(np.random.choice(list('abcdefghijklmnopqrstuvwxyz'), 8)) + '-nx' + return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx" # Use CuPy with GPU context with cp.cuda.Device(0): @@ -214,10 +209,10 @@ def _get_id(): batch = batch.to("cuda") # Extract transcript IDs and initialize assignments DataFrame - transcript_id = cp.asnumpy(batch['tx'].id) - assignments = pd.DataFrame({'transcript_id': transcript_id}) + transcript_id = cp.asnumpy(batch["tx"].id) + assignments = pd.DataFrame({"transcript_id": transcript_id}) - if len(batch['bd'].pos) >= 10: + if len(batch["bd"].pos) >= 10: # Compute similarity scores between 'tx' and 'bd' scores = get_similarity_scores(lit_segger.model, batch, "tx", "bd", receptive_field) torch.cuda.empty_cache() @@ -227,48 +222,47 @@ def _get_id(): cp.get_default_memory_pool().free_all_blocks() # Free CuPy memory # Get direct assignments from similarity matrix belongs = cp.max(dense_scores, axis=1) # Max score per transcript - assignments['score'] = cp.asnumpy(belongs) # Move back to CPU + assignments["score"] = cp.asnumpy(belongs) # Move back to CPU - mask = assignments['score'] > score_cut - all_ids = np.concatenate(batch['bd'].id) # Keep IDs as NumPy array - assignments['segger_cell_id'] = None # Initialize as None + mask = assignments["score"] > score_cut + all_ids = np.concatenate(batch["bd"].id) # Keep IDs as NumPy array + assignments["segger_cell_id"] = None # Initialize as None max_indices = cp.argmax(dense_scores, axis=1).get() - assignments['segger_cell_id'][mask] = all_ids[max_indices[mask]] # Assign IDs - + assignments["segger_cell_id"][mask] = all_ids[max_indices[mask]] # Assign IDs + del dense_scores # Remove from memory cp.get_default_memory_pool().free_all_blocks() # Free CuPy memory torch.cuda.empty_cache() -# Move back to CPU - assignments['bound'] = 0 - assignments['bound'][mask] = 1 - - + # Move back to CPU + assignments["bound"] = 0 + assignments["bound"][mask] = 1 + if use_cc: # Compute similarity scores between 'tx' and 'tx' scores_tx = get_similarity_scores(lit_segger.model, batch, "tx", "tx", receptive_field) - # Convert to dense NumPy array - data_cpu = scores_tx.data.get() # Transfer data to CPU (NumPy) - row_cpu = scores_tx.row.get() # Transfer row indices to CPU (NumPy) - col_cpu = scores_tx.col.get() # Transfer column indices to CPU (NumPy) + # Convert to dense NumPy array + data_cpu = scores_tx.data.get() # Transfer data to CPU (NumPy) + row_cpu = scores_tx.row.get() # Transfer row indices to CPU (NumPy) + col_cpu = scores_tx.col.get() # Transfer column indices to CPU (NumPy) # dense_scores_tx = scores_tx.toarray().astype(cp.float16) # Rebuild the matrix on CPU using SciPy dense_scores_tx = scipy_coo_matrix((data_cpu, (row_cpu, col_cpu)), shape=scores_tx.shape).toarray() np.fill_diagonal(dense_scores_tx, 0) # Ignore self-similarity - + del scores_tx # Remove from memory cp.get_default_memory_pool().free_all_blocks() # Free CuPy memory # Assign unassigned transcripts using connected components - no_id = assignments['segger_cell_id'].isna() + no_id = assignments["segger_cell_id"].isna() if np.any(no_id): # Only compute if there are unassigned transcripts no_id_scores = dense_scores_tx[no_id][:, no_id] del dense_scores_tx # Remove from memory no_id_scores[no_id_scores < score_cut] = 0 n, comps = cc(no_id_scores, connection="weak", directed=False) new_ids = np.array([_get_id() for _ in range(n)]) - assignments['segger_cell_id'][no_id] = 
new_ids[comps] + assignments["segger_cell_id"][no_id] = new_ids[comps] # Perform memory cleanup to avoid OOM issues cp.get_default_memory_pool().free_all_blocks() @@ -276,9 +270,6 @@ def _get_id(): return assignments - - - def predict( lit_segger: LitSegger, @@ -286,7 +277,7 @@ def predict( score_cut: float, receptive_field: dict, use_cc: bool = True, - knn_method: str = 'cuda' + knn_method: str = "cuda", ) -> pd.DataFrame: # Change return type to Dask DataFrame if applicable """ Optimized prediction for multiple batches of transcript data. @@ -296,7 +287,7 @@ def predict( for batch in data_loader: assignments = predict_batch(lit_segger, batch, score_cut, receptive_field, use_cc, knn_method) all_assignments.append(dd.from_pandas(assignments, npartitions=1)) - + cp.get_default_memory_pool().free_all_blocks() torch.cuda.empty_cache() @@ -304,26 +295,26 @@ def predict( final_assignments = dd.concat(all_assignments, ignore_index=True) # Sort the Dask DataFrame by 'transcript_id' before setting it as an index - final_assignments = final_assignments.sort_values(by='transcript_id') + final_assignments = final_assignments.sort_values(by="transcript_id") # Set a unique index for Dask DataFrame - final_assignments = final_assignments.set_index('transcript_id', sorted=True) + final_assignments = final_assignments.set_index("transcript_id", sorted=True) # Max score selection logic - max_bound_idx = final_assignments[final_assignments['bound'] == 1].groupby('transcript_id')['score'].idxmax() - max_unbound_idx = final_assignments[final_assignments['bound'] == 0].groupby('transcript_id')['score'].idxmax() + max_bound_idx = final_assignments[final_assignments["bound"] == 1].groupby("transcript_id")["score"].idxmax() + max_unbound_idx = final_assignments[final_assignments["bound"] == 0].groupby("transcript_id")["score"].idxmax() # Combine indices, prioritizing bound=1 scores final_idx = max_bound_idx.combine_first(max_unbound_idx).compute() # Ensure it's computed # Now use the computed final_idx for indexing - result = final_assignments.loc[final_idx].compute().reset_index(names=['transcript_id']) - + result = final_assignments.loc[final_idx].compute().reset_index(names=["transcript_id"]) + # result = results.reset_index() # Handle cases where there's only one entry per 'segger_cell_id' # single_entry_mask = result.groupby('segger_cell_id').size() == 1 -# Handle cases where there's only one entry per 'segger_cell_id' + # Handle cases where there's only one entry per 'segger_cell_id' # single_entry_counts = result['segger_cell_id'].value_counts() # Count occurrences of each ID # single_entry_mask = single_entry_counts[single_entry_counts == 1].index # Get IDs with a count of 1 @@ -331,27 +322,26 @@ def predict( # for segger_id in single_entry_mask: # result.loc[result['segger_cell_id'] == segger_id, 'segger_cell_id'] = 'floating' - return result def segment( - model: LitSegger, - dm: SeggerDataModule, - save_dir: Union[str, Path], - seg_tag: str, - transcript_file: Union[str, Path], - score_cut: float = .5, + model: LitSegger, + dm: SeggerDataModule, + save_dir: Union[str, Path], + seg_tag: str, + transcript_file: Union[str, Path], + score_cut: float = 0.5, use_cc: bool = True, - file_format: str = 'anndata', - receptive_field: dict = {'k_bd': 4, 'dist_bd': 10, 'k_tx': 5, 'dist_tx': 3}, - knn_method: str = 'kd_tree', + file_format: str = "anndata", + receptive_field: dict = {"k_bd": 4, "dist_bd": 10, "k_tx": 5, "dist_tx": 3}, + knn_method: str = "kd_tree", verbose: bool = False, - **anndata_kwargs + 
**anndata_kwargs, ) -> None: """ Perform segmentation using the model, merge segmentation results with transcripts_df, and save in the specified format. - + Parameters: ---------- model : LitSegger @@ -388,22 +378,22 @@ def segment( # Step 1: Prediction step_start_time = time.time() - + train_dataloader = dm.train_dataloader() - test_dataloader = dm.test_dataloader() - val_dataloader = dm.val_dataloader() - + test_dataloader = dm.test_dataloader() + val_dataloader = dm.val_dataloader() + segmentation_train = predict(model, train_dataloader, score_cut, receptive_field, use_cc, knn_method) torch.cuda.empty_cache() cp.get_default_memory_pool().free_all_blocks() gc.collect() - - segmentation_val = predict(model, val_dataloader, score_cut, receptive_field, use_cc, knn_method) + + segmentation_val = predict(model, val_dataloader, score_cut, receptive_field, use_cc, knn_method) torch.cuda.empty_cache() cp.get_default_memory_pool().free_all_blocks() gc.collect() - - segmentation_test = predict(model, test_dataloader, score_cut, receptive_field, use_cc, knn_method) + + segmentation_test = predict(model, test_dataloader, score_cut, receptive_field, use_cc, knn_method) torch.cuda.empty_cache() cp.get_default_memory_pool().free_all_blocks() gc.collect() @@ -422,7 +412,7 @@ def segment( # print(seg_combined.columns) # print(transcripts_df.id) # Drop any unassigned rows - seg_final = seg_combined.dropna(subset=['segger_cell_id']).reset_index(drop=True) + seg_final = seg_combined.dropna(subset=["segger_cell_id"]).reset_index(drop=True) if verbose: elapsed_time = format_time(time.time() - step_start_time) @@ -440,7 +430,7 @@ def segment( seg_final_dd = dd.from_pandas(seg_final, npartitions=transcripts_df.npartitions) # Merge the segmentation results with the transcript data (still as Dask DataFrame) - transcripts_df_filtered = transcripts_df.merge(seg_final_dd, on='transcript_id', how='inner') + transcripts_df_filtered = transcripts_df.merge(seg_final_dd, on="transcript_id", how="inner") if verbose: elapsed_time = format_time(time.time() - step_start_time) @@ -448,18 +438,18 @@ def segment( # Step 4: Save the merged result step_start_time = time.time() - + if verbose: print(f"Saving results in {file_format} format...") - if file_format == 'csv': - save_path = save_dir / f'{seg_tag}_segmentation.csv' + if file_format == "csv": + save_path = save_dir / f"{seg_tag}_segmentation.csv" transcripts_df_filtered.compute().to_csv(save_path, index=False) # Use pandas after computing - elif file_format == 'parquet': - save_path = save_dir / f'{seg_tag}_segmentation.parquet' + elif file_format == "parquet": + save_path = save_dir / f"{seg_tag}_segmentation.parquet" transcripts_df_filtered.to_parquet(save_path, index=False) # Dask handles Parquet fine - elif file_format == 'anndata': - save_path = save_dir / f'{seg_tag}_segmentation.h5ad' + elif file_format == "anndata": + save_path = save_dir / f"{seg_tag}_segmentation.h5ad" segger_adata = create_anndata(transcripts_df_filtered.compute(), **anndata_kwargs) # Compute for AnnData segger_adata.write(save_path) else: @@ -479,9 +469,6 @@ def segment( torch.cuda.empty_cache() gc.collect() - - - # def predict( # lit_segger: LitSegger, @@ -493,7 +480,7 @@ def segment( # ) -> dd.DataFrame: # """ # Optimized prediction for multiple batches of transcript data using Dask and delayed processing with progress bar. - + # Args: # lit_segger (LitSegger): The lightning module wrapping the segmentation model. 
# data_loader (DataLoader): A data loader providing batches of transcript and cell data. @@ -539,7 +526,7 @@ def segment( # # Handle cases where there's only one entry per 'segger_cell_id' # single_entry_mask = result.groupby('segger_cell_id').size() == 1 # result.loc[single_entry_mask, 'segger_cell_id'] = 'floating' - + # return result # # Map the logic over each partition using Dask @@ -548,14 +535,11 @@ def segment( # # Trigger garbage collection and free GPU memory # torch.cuda.empty_cache() # gc.collect() - -# final_assignments = final_assignments.compute() - - -# return final_assignments +# final_assignments = final_assignments.compute() +# return final_assignments # # def predict( @@ -568,7 +552,7 @@ def segment( # # ) -> dd.DataFrame: # # """ # # Optimized prediction for multiple batches of transcript data using Dask and delayed processing with progress bar. - + # # Args: # # lit_segger (LitSegger): The lightning module wrapping the segmentation model. # # data_loader (DataLoader): A data loader providing batches of transcript and cell data. @@ -596,7 +580,7 @@ def segment( # # delayed(predict_batch)(lit_segger, batch, score_cut, receptive_field, use_cc, knn_method) # # for batch in data_loader # # ] - + # # # Build the Dask DataFrame from the delayed assignments # # assignments_dd = dd.from_delayed(delayed_assignments, meta=meta) @@ -612,7 +596,7 @@ def segment( # # # Handle cases where there's only one entry per 'segger_cell_id' # # single_entry_mask = result.groupby('segger_cell_id').size() == 1 # # result.loc[single_entry_mask, 'segger_cell_id'] = 'floating' - + # # return result # # # Map the logic over each partition using Dask @@ -627,22 +611,22 @@ def segment( # def segment( -# model: LitSegger, -# dm: SeggerDataModule, -# save_dir: Union[str, Path], -# seg_tag: str, -# transcript_file: Union[str, Path], +# model: LitSegger, +# dm: SeggerDataModule, +# save_dir: Union[str, Path], +# seg_tag: str, +# transcript_file: Union[str, Path], # score_cut: float = .25, # use_cc: bool = True, -# file_format: str = 'anndata', +# file_format: str = 'anndata', # receptive_field: dict = {'k_bd': 4, 'dist_bd': 10, 'k_tx': 5, 'dist_tx': 3}, # knn_method: str = 'kd_tree', # verbose: bool = False, # **anndata_kwargs # ) -> None: # """ -# Perform segmentation using the model, merge segmentation results with transcripts_df, -# and save in the specified format. Memory is managed efficiently using Dask and GPU +# Perform segmentation using the model, merge segmentation results with transcripts_df, +# and save in the specified format. Memory is managed efficiently using Dask and GPU # memory optimizations. 
# Args: @@ -674,15 +658,15 @@ def segment( # # Step 1: Prediction # step_start_time = time.time() - + # train_dataloader = dm.train_dataloader() # test_dataloader = dm.test_dataloader() # val_dataloader = dm.val_dataloader() - + # # delayed_train = predict(model, test_dataloader, score_cut=score_cut, receptive_field=receptive_field, use_cc=use_cc, knn_method=knn_method) # # delayed_val = predict(model, test_dataloader, score_cut=score_cut, receptive_field=receptive_field, use_cc=use_cc, knn_method=knn_method) # delayed_test = predict(model, test_dataloader, score_cut=score_cut, receptive_field=receptive_field, use_cc=use_cc, knn_method=knn_method) - + # delayed_test = delayed_test.compute() # # Compute all predictions at once using Dask # # with ProgressBar(): @@ -726,7 +710,7 @@ def segment( # # Step 4: Save the merged result # step_start_time = time.time() - + # if verbose: # print(f"Saving results in {file_format} format...") diff --git a/src/segger/training/README.md b/src/segger/training/README.md index ff7e04d..958cd20 100644 --- a/src/segger/training/README.md +++ b/src/segger/training/README.md @@ -7,20 +7,24 @@ The training module makes use of **PyTorch Lightning** for efficient and scalabl ## Key Components ### 1. **SpatialTranscriptomicsDataset** + The `SpatialTranscriptomicsDataset` class is used to load and manage spatial transcriptomics data stored in the format of PyTorch Geometric `Data` objects. It inherits from `InMemoryDataset` to load preprocessed datasets, ensuring efficient in-memory data handling for training and validation phases. - **Root Path**: The root directory contains the dataset, which is expected to have separate folders for training, validation, and test sets. - **Raw and Processed Data**: The module expects datasets in the form of processed PyTorch files, and the dataset class is responsible for loading them efficiently. ### 2. **Segger Model** + The `Segger` model is a custom graph neural network designed to work with heterogeneous graph data. It takes both **transcript (tx)** and **boundary (bd)** nodes, utilizing attention mechanisms for better feature aggregation. Key parameters such as `num_tx_tokens`, `init_emb`, `hidden_channels`, `out_channels`, and `heads` allow the user to control the model's architecture and initial embedding sizes. - **Heterogeneous Graph Support**: The model is converted to handle different node types using `to_hetero` from PyTorch Geometric. The transformation allows the model to handle multiple relations like `belongs` (tx to bd) and `neighbors` (tx to tx). ### 3. **LitSegger** + `LitSegger` is the PyTorch Lightning wrapper around the Segger model, which handles training, validation, and optimization. This wrapper facilitates the integration with Lightning’s trainer, allowing easy multi-GPU and distributed training. ### 4. **Training Pipeline** + The module provides an easily configurable pipeline for training the Segger model: - **Datasets**: Training and validation datasets are loaded using `SpatialTranscriptomicsDataset` with paths provided via arguments. @@ -30,6 +34,7 @@ The module provides an easily configurable pipeline for training the Segger mode ## Usage and Configuration ### Command-Line Arguments + The module accepts various command-line arguments that allow for flexible configuration: - `--train_dir`: Path to the training data directory. This directory should include `processed` and `raw` subdirectories. The direcotry `processed` should include the `pyg` `HeteroData` objects. 
@@ -51,6 +56,7 @@ The module accepts various command-line arguments that allow for flexible config - `--default_root_dir`: Directory where logs, checkpoints, and models will be saved. ### Example Training Command + The module can be executed from the command line as follows: ```bash diff --git a/src/segger/training/segger_data_module.py b/src/segger/training/segger_data_module.py index c1be43d..3feadef 100644 --- a/src/segger/training/segger_data_module.py +++ b/src/segger/training/segger_data_module.py @@ -21,9 +21,9 @@ def __init__( # TODO: Add documentation def setup(self, stage=None): - self.train = STPyGDataset(root=self.data_dir / 'train_tiles') - self.test = STPyGDataset(root=self.data_dir / 'test_tiles') - self.val = STPyGDataset(root=self.data_dir / 'val_tiles') + self.train = STPyGDataset(root=self.data_dir / "train_tiles") + self.test = STPyGDataset(root=self.data_dir / "test_tiles") + self.val = STPyGDataset(root=self.data_dir / "val_tiles") self.loader_kwargs = dict( batch_size=self.batch_size, num_workers=self.num_workers, diff --git a/src/segger/training/train.py b/src/segger/training/train.py index a3cf471..68adbb3 100644 --- a/src/segger/training/train.py +++ b/src/segger/training/train.py @@ -60,7 +60,17 @@ def __init__(self, **kwargs): self.validation_step_outputs = [] self.criterion = torch.nn.BCEWithLogitsLoss() - def from_new(self, num_tx_tokens: int, init_emb: int, hidden_channels: int, out_channels: int, heads: int, num_mid_layers: int, aggr: str, metadata: Union[Tuple, Metadata]): + def from_new( + self, + num_tx_tokens: int, + init_emb: int, + hidden_channels: int, + out_channels: int, + heads: int, + num_mid_layers: int, + aggr: str, + metadata: Union[Tuple, Metadata], + ): """ Initializes the LitSegger module with new parameters. @@ -124,7 +134,7 @@ def forward(self, batch: SpatialTranscriptomicsDataset) -> torch.Tensor: The output of the model. 
""" z = self.model(batch.x_dict, batch.edge_index_dict) - output = torch.matmul(z['tx'], z['bd'].t()) # Example for bipartite graph + output = torch.matmul(z["tx"], z["bd"].t()) # Example for bipartite graph return output def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: @@ -145,16 +155,16 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: """ # Forward pass to get the logits z = self.model(batch.x_dict, batch.edge_index_dict) - output = torch.matmul(z['tx'], z['bd'].t()) + output = torch.matmul(z["tx"], z["bd"].t()) # Get edge labels and logits - edge_label_index = batch['tx', 'belongs', 'bd'].edge_label_index + edge_label_index = batch["tx", "belongs", "bd"].edge_label_index out_values = output[edge_label_index[0], edge_label_index[1]] - edge_label = batch['tx', 'belongs', 'bd'].edge_label - + edge_label = batch["tx", "belongs", "bd"].edge_label + # Compute binary cross-entropy loss with logits (no sigmoid here) loss = self.criterion(out_values, edge_label) - + # Log the training loss self.log("train_loss", loss, prog_bar=True, batch_size=batch.num_graphs) return loss @@ -177,31 +187,31 @@ def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor: """ # Forward pass to get the logits z = self.model(batch.x_dict, batch.edge_index_dict) - output = torch.matmul(z['tx'], z['bd'].t()) + output = torch.matmul(z["tx"], z["bd"].t()) # Get edge labels and logits - edge_label_index = batch['tx', 'belongs', 'bd'].edge_label_index + edge_label_index = batch["tx", "belongs", "bd"].edge_label_index out_values = output[edge_label_index[0], edge_label_index[1]] - edge_label = batch['tx', 'belongs', 'bd'].edge_label - + edge_label = batch["tx", "belongs", "bd"].edge_label + # Compute binary cross-entropy loss with logits (no sigmoid here) loss = self.criterion(out_values, edge_label) - + # Apply sigmoid to logits for AUROC and F1 metrics out_values_prob = torch.sigmoid(out_values) # Compute metrics auroc = torchmetrics.AUROC(task="binary") auroc_res = auroc(out_values_prob, edge_label) - + f1 = F1Score(task="binary").to(self.device) f1_res = f1(out_values_prob, edge_label) - + # Log validation metrics self.log("validation_loss", loss, batch_size=batch.num_graphs) self.log("validation_auroc", auroc_res, prog_bar=True, batch_size=batch.num_graphs) self.log("validation_f1", f1_res, prog_bar=True, batch_size=batch.num_graphs) - + return loss def configure_optimizers(self) -> torch.optim.Optimizer: diff --git a/src/segger/validation/__init__.py b/src/segger/validation/__init__.py index 220150b..bfc7689 100644 --- a/src/segger/validation/__init__.py +++ b/src/segger/validation/__init__.py @@ -1,3 +1,3 @@ from .utils import * -from .xenium_explorer import * \ No newline at end of file +from .xenium_explorer import * diff --git a/src/segger/validation/utils.py b/src/segger/validation/utils.py index b283b00..72a5438 100644 --- a/src/segger/validation/utils.py +++ b/src/segger/validation/utils.py @@ -11,22 +11,20 @@ from matplotlib.backends.backend_pdf import PdfPages import matplotlib.pyplot as plt import dask -dask.config.set({'dataframe.query-planning': False}) + +dask.config.set({"dataframe.query-planning": False}) import squidpy as sq from sklearn.metrics import calinski_harabasz_score, silhouette_score, f1_score from pathlib import Path import seaborn as sns - - - def find_markers( - adata: ad.AnnData, - cell_type_column: str, - pos_percentile: float = 5, - neg_percentile: float = 10, - percentage: float = 50 + adata: ad.AnnData, + cell_type_column: str, + 
pos_percentile: float = 5, + neg_percentile: float = 10, + percentage: float = 50, ) -> Dict[str, Dict[str, List[str]]]: """Identify positive and negative markers for each cell type based on gene expression and filter by expression percentage. @@ -62,17 +60,12 @@ def find_markers( valid_pos_indices = pos_indices[expr_frac >= (percentage / 100)] positive_markers = genes[valid_pos_indices] negative_markers = genes[neg_indices] - markers[cell_type] = { - 'positive': list(positive_markers), - 'negative': list(negative_markers) - } + markers[cell_type] = {"positive": list(positive_markers), "negative": list(negative_markers)} return markers def find_mutually_exclusive_genes( - adata: ad.AnnData, - markers: Dict[str, Dict[str, List[str]]], - cell_type_column: str + adata: ad.AnnData, markers: Dict[str, Dict[str, List[str]]], cell_type_column: str ) -> List[Tuple[str, str]]: """Identify mutually exclusive genes based on expression criteria. @@ -94,7 +87,7 @@ def find_mutually_exclusive_genes( all_exclusive = [] gene_expression = adata.to_df() for cell_type, marker_sets in markers.items(): - positive_markers = marker_sets['positive'] + positive_markers = marker_sets["positive"] exclusive_genes[cell_type] = [] for gene in positive_markers: gene_expr = adata[:, gene].X @@ -104,7 +97,9 @@ def find_mutually_exclusive_genes( exclusive_genes[cell_type].append(gene) all_exclusive.append(gene) unique_genes = list({gene for i in exclusive_genes.keys() for gene in exclusive_genes[i] if gene in all_exclusive}) - filtered_exclusive_genes = {i: [gene for gene in exclusive_genes[i] if gene in unique_genes] for i in exclusive_genes.keys()} + filtered_exclusive_genes = { + i: [gene for gene in exclusive_genes[i] if gene in unique_genes] for i in exclusive_genes.keys() + } mutually_exclusive_gene_pairs = [ (gene1, gene2) for key1, key2 in combinations(filtered_exclusive_genes.keys(), 2) @@ -114,10 +109,7 @@ def find_mutually_exclusive_genes( return mutually_exclusive_gene_pairs -def compute_MECR( - adata: ad.AnnData, - gene_pairs: List[Tuple[str, str]] -) -> Dict[Tuple[str, str], float]: +def compute_MECR(adata: ad.AnnData, gene_pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], float]: """Compute the Mutually Exclusive Co-expression Rate (MECR) for each gene pair in an AnnData object. Args: @@ -143,9 +135,7 @@ def compute_MECR( def compute_quantized_mecr_area( - adata: sc.AnnData, - gene_pairs: List[Tuple[str, str]], - quantiles: int = 10 + adata: sc.AnnData, gene_pairs: List[Tuple[str, str]], quantiles: int = 10 ) -> pd.DataFrame: """Compute the average MECR, variance of MECR, and average cell area for quantiles of cell areas. @@ -161,28 +151,28 @@ def compute_quantized_mecr_area( - quantized_data: pd.DataFrame DataFrame containing quantile information, average MECR, variance of MECR, average area, and number of cells. 
""" - adata.obs['quantile'] = pd.qcut(adata.obs['cell_area'], quantiles, labels=False) + adata.obs["quantile"] = pd.qcut(adata.obs["cell_area"], quantiles, labels=False) quantized_data = [] for quantile in range(quantiles): - cells_in_quantile = adata.obs['quantile'] == quantile + cells_in_quantile = adata.obs["quantile"] == quantile mecr = compute_MECR(adata[cells_in_quantile, :], gene_pairs) average_mecr = np.mean([i for i in mecr.values()]) variance_mecr = np.var([i for i in mecr.values()]) - average_area = adata.obs.loc[cells_in_quantile, 'cell_area'].mean() - quantized_data.append({ - 'quantile': quantile / quantiles, - 'average_mecr': average_mecr, - 'variance_mecr': variance_mecr, - 'average_area': average_area, - 'num_cells': cells_in_quantile.sum() - }) + average_area = adata.obs.loc[cells_in_quantile, "cell_area"].mean() + quantized_data.append( + { + "quantile": quantile / quantiles, + "average_mecr": average_mecr, + "variance_mecr": variance_mecr, + "average_area": average_area, + "num_cells": cells_in_quantile.sum(), + } + ) return pd.DataFrame(quantized_data) def compute_quantized_mecr_counts( - adata: sc.AnnData, - gene_pairs: List[Tuple[str, str]], - quantiles: int = 10 + adata: sc.AnnData, gene_pairs: List[Tuple[str, str]], quantiles: int = 10 ) -> pd.DataFrame: """Compute the average MECR, variance of MECR, and average transcript counts for quantiles of transcript counts. @@ -198,28 +188,28 @@ def compute_quantized_mecr_counts( - quantized_data: pd.DataFrame DataFrame containing quantile information, average MECR, variance of MECR, average counts, and number of cells. """ - adata.obs['quantile'] = pd.qcut(adata.obs['transcripts'], quantiles, labels=False) + adata.obs["quantile"] = pd.qcut(adata.obs["transcripts"], quantiles, labels=False) quantized_data = [] for quantile in range(quantiles): - cells_in_quantile = adata.obs['quantile'] == quantile + cells_in_quantile = adata.obs["quantile"] == quantile mecr = compute_MECR(adata[cells_in_quantile, :], gene_pairs) average_mecr = np.mean([i for i in mecr.values()]) variance_mecr = np.var([i for i in mecr.values()]) - average_counts = adata.obs.loc[cells_in_quantile, 'transcripts'].mean() - quantized_data.append({ - 'quantile': quantile / quantiles, - 'average_mecr': average_mecr, - 'variance_mecr': variance_mecr, - 'average_counts': average_counts, - 'num_cells': cells_in_quantile.sum() - }) + average_counts = adata.obs.loc[cells_in_quantile, "transcripts"].mean() + quantized_data.append( + { + "quantile": quantile / quantiles, + "average_mecr": average_mecr, + "variance_mecr": variance_mecr, + "average_counts": average_counts, + "num_cells": cells_in_quantile.sum(), + } + ) return pd.DataFrame(quantized_data) def annotate_query_with_reference( - reference_adata: ad.AnnData, - query_adata: ad.AnnData, - transfer_column: str + reference_adata: ad.AnnData, query_adata: ad.AnnData, transfer_column: str ) -> ad.AnnData: """Annotate query AnnData object using a scRNA-seq reference atlas. 
@@ -238,25 +228,25 @@ def annotate_query_with_reference( common_genes = list(set(reference_adata.var_names) & set(query_adata.var_names)) reference_adata = reference_adata[:, common_genes] query_adata = query_adata[:, common_genes] - query_adata.layers['raw'] = query_adata.raw.X if query_adata.raw else query_adata.X - query_adata.var['raw_counts'] = query_adata.layers['raw'].sum(axis=0) + query_adata.layers["raw"] = query_adata.raw.X if query_adata.raw else query_adata.X + query_adata.var["raw_counts"] = query_adata.layers["raw"].sum(axis=0) sc.pp.normalize_total(query_adata, target_sum=1e4) sc.pp.log1p(query_adata) sc.pp.pca(reference_adata) sc.pp.neighbors(reference_adata) sc.tl.umap(reference_adata) sc.tl.ingest(query_adata, reference_adata, obs=transfer_column) - query_adata.obsm['X_umap'] = query_adata.obsm['X_umap'] + query_adata.obsm["X_umap"] = query_adata.obsm["X_umap"] return query_adata def calculate_contamination( - adata: ad.AnnData, - markers: Dict[str, Dict[str, List[str]]], - radius: float = 15, - n_neighs: int = 10, - celltype_column: str = 'celltype_major', - num_cells: int = 10000 + adata: ad.AnnData, + markers: Dict[str, Dict[str, List[str]]], + radius: float = 15, + n_neighs: int = 10, + celltype_column: str = "celltype_major", + num_cells: int = 10000, ) -> pd.DataFrame: """Calculate normalized contamination from neighboring cells of different cell types based on positive markers. @@ -282,11 +272,11 @@ def calculate_contamination( """ if celltype_column not in adata.obs: raise ValueError("Column celltype_column must be present in adata.obs.") - positive_markers = {ct: markers[ct]['positive'] for ct in markers} + positive_markers = {ct: markers[ct]["positive"] for ct in markers} adata.obsm["spatial"] = adata.obs[["cell_centroid_x", "cell_centroid_y"]].copy().to_numpy() - sq.gr.spatial_neighbors(adata, radius=radius, n_neighs=n_neighs, coord_type='generic') - neighbors = adata.obsp['spatial_connectivities'].tolil() - raw_counts = adata[:, adata.var_names].layers['raw'].toarray() + sq.gr.spatial_neighbors(adata, radius=radius, n_neighs=n_neighs, coord_type="generic") + neighbors = adata.obsp["spatial_connectivities"].tolil() + raw_counts = adata[:, adata.var_names].layers["raw"].toarray() cell_types = adata.obs[celltype_column] selected_cells = np.random.choice(adata.n_obs, size=min(num_cells, adata.n_obs), replace=False) contamination = {ct: {ct2: 0 for ct2 in positive_markers.keys()} for ct in positive_markers.keys()} @@ -309,19 +299,19 @@ def calculate_contamination( if marker in adata.var_names: marker_counts_in_neighbor = raw_counts[neighbor_idx, adata.var_names.get_loc(marker)] if total_counts_in_neighborhood > 0: - contamination[cell_type][neighbor_type] += marker_counts_in_neighbor / total_counts_in_neighborhood + contamination[cell_type][neighbor_type] += ( + marker_counts_in_neighbor / total_counts_in_neighborhood + ) negighborings[cell_type][neighbor_type] += 1 contamination_df = pd.DataFrame(contamination).T negighborings_df = pd.DataFrame(negighborings).T - contamination_df.index.name = 'Source Cell Type' - contamination_df.columns.name = 'Target Cell Type' + contamination_df.index.name = "Source Cell Type" + contamination_df.columns.name = "Target Cell Type" return contamination_df / (negighborings_df + 1) def calculate_sensitivity( - adata: ad.AnnData, - purified_markers: Dict[str, List[str]], - max_cells_per_type: int = 1000 + adata: ad.AnnData, purified_markers: Dict[str, List[str]], max_cells_per_type: int = 1000 ) -> Dict[str, List[float]]: 
"""Calculate the sensitivity of the purified markers for each cell type. @@ -339,8 +329,8 @@ def calculate_sensitivity( """ sensitivity_results = {cell_type: [] for cell_type in purified_markers.keys()} for cell_type, markers in purified_markers.items(): - markers = markers['positive'] - subset = adata[adata.obs['celltype_major'] == cell_type] + markers = markers["positive"] + subset = adata[adata.obs["celltype_major"] == cell_type] if subset.n_obs > max_cells_per_type: cell_indices = np.random.choice(subset.n_obs, max_cells_per_type, replace=False) subset = subset[cell_indices] @@ -352,9 +342,7 @@ def calculate_sensitivity( def compute_clustering_scores( - adata: ad.AnnData, - cell_type_column: str = 'celltype_major', - use_pca: bool = True + adata: ad.AnnData, cell_type_column: str = "celltype_major", use_pca: bool = True ) -> Tuple[float, float]: """Compute the Calinski-Harabasz and Silhouette scores for an AnnData object based on the assigned cell types. @@ -384,11 +372,11 @@ def compute_clustering_scores( def compute_neighborhood_metrics( - adata: ad.AnnData, - radius: float = 10, - celltype_column: str = 'celltype_major', + adata: ad.AnnData, + radius: float = 10, + celltype_column: str = "celltype_major", n_neighs: int = 20, - subset_size: int = 10000 + subset_size: int = 10000, ) -> None: """Compute neighborhood entropy and number of neighbors for each cell in the AnnData object. @@ -418,8 +406,8 @@ def compute_neighborhood_metrics( # Randomly select a subset of cells subset_indices = np.random.choice(adata.n_obs, subset_size, replace=False) # Compute spatial neighbors for the entire dataset - sq.gr.spatial_neighbors(adata, radius=radius, coord_type='generic', n_neighs=n_neighs) - neighbors = adata.obsp['spatial_distances'].tolil().rows + sq.gr.spatial_neighbors(adata, radius=radius, coord_type="generic", n_neighs=n_neighs) + neighbors = adata.obsp["spatial_distances"].tolil().rows entropies = [] num_neighbors = [] # Calculate entropy and number of neighbors only for the selected subset @@ -441,8 +429,8 @@ def compute_neighborhood_metrics( neighbors_full = np.full(adata.n_obs, np.nan) entropy_full[subset_indices] = entropies neighbors_full[subset_indices] = num_neighbors - adata.obs['neighborhood_entropy'] = entropy_full - adata.obs['number_of_neighbors'] = neighbors_full + adata.obs["neighborhood_entropy"] = entropy_full + adata.obs["number_of_neighbors"] = neighbors_full def compute_transcript_density(adata: ad.AnnData) -> None: @@ -453,15 +441,15 @@ def compute_transcript_density(adata: ad.AnnData) -> None: Annotated data object containing transcript and cell area information. 
""" try: - transcript_counts = adata.obs['transcript_counts'] + transcript_counts = adata.obs["transcript_counts"] except: - transcript_counts = adata.obs['transcripts'] - cell_areas = adata.obs['cell_area'] - adata.obs['transcript_density'] = transcript_counts / cell_areas + transcript_counts = adata.obs["transcripts"] + cell_areas = adata.obs["cell_area"] + adata.obs["transcript_density"] = transcript_counts / cell_areas # def compute_celltype_f1_purity( -# adata: ad.AnnData, +# adata: ad.AnnData, # marker_genes: Dict[str, Dict[str, List[str]]] # ) -> Dict[str, float]: # """ @@ -497,7 +485,7 @@ def compute_transcript_density(adata: ad.AnnData) -> None: # def average_log_normalized_expression( -# adata: ad.AnnData, +# adata: ad.AnnData, # celltype_column: str # ) -> pd.DataFrame: # """ @@ -516,18 +504,8 @@ def compute_transcript_density(adata: ad.AnnData) -> None: # return adata.to_df().groupby(adata.obs[celltype_column]).mean() - - - - def plot_metric_comparison( - ax: plt.Axes, - data: pd.DataFrame, - metric: str, - label: str, - method1: str, - method2: str, - output_path: Path + ax: plt.Axes, data: pd.DataFrame, metric: str, label: str, method1: str, method2: str, output_path: Path ) -> None: """Plot a comparison of a specific metric between two methods and save the comparison data. @@ -547,25 +525,22 @@ def plot_metric_comparison( - output_path: Path Path to save the merged DataFrame as a CSV. """ - subset1 = data[data['method'] == method1] - subset2 = data[data['method'] == method2] - merged_data = pd.merge(subset1, subset2, on='celltype_major', suffixes=(f'_{method1}', f'_{method2}')) - + subset1 = data[data["method"] == method1] + subset2 = data[data["method"] == method2] + merged_data = pd.merge(subset1, subset2, on="celltype_major", suffixes=(f"_{method1}", f"_{method2}")) + # Save the merged data used in the plot to CSV - merged_data.to_csv(output_path / f'metric_comparison_{metric}_{method1}_vs_{method2}.csv', index=False) - - for cell_type in merged_data['celltype_major'].unique(): - cell_data = merged_data[merged_data['celltype_major'] == cell_type] - ax.scatter(cell_data[f'{metric}_{method1}'], cell_data[f'{metric}_{method2}'], - label=cell_type) - - max_value = max(merged_data[f'{metric}_{method1}'].max(), merged_data[f'{metric}_{method2}'].max()) - ax.plot([0, max_value], [0, max_value], 'k--', alpha=0.5) - ax.set_xlabel(f'{label} ({method1})') - ax.set_ylabel(f'{label} ({method2})') - ax.set_title(f'{label}: {method1} vs {method2}') + merged_data.to_csv(output_path / f"metric_comparison_{metric}_{method1}_vs_{method2}.csv", index=False) + for cell_type in merged_data["celltype_major"].unique(): + cell_data = merged_data[merged_data["celltype_major"] == cell_type] + ax.scatter(cell_data[f"{metric}_{method1}"], cell_data[f"{metric}_{method2}"], label=cell_type) + max_value = max(merged_data[f"{metric}_{method1}"].max(), merged_data[f"{metric}_{method2}"].max()) + ax.plot([0, max_value], [0, max_value], "k--", alpha=0.5) + ax.set_xlabel(f"{label} ({method1})") + ax.set_ylabel(f"{label} ({method2})") + ax.set_title(f"{label}: {method1} vs {method2}") def load_segmentations(segmentation_paths: Dict[str, Path]) -> Dict[str, sc.AnnData]: @@ -581,16 +556,15 @@ def load_segmentations(segmentation_paths: Dict[str, Path]) -> Dict[str, sc.AnnD for method, path in segmentation_paths.items(): adata = sc.read(path) # Special handling for 'segger' to separate into 'segger_n0' and 'segger_n1' - if method == 'segger': - cells_n1 = [i for i in adata.obs_names if not i.endswith('-nx')] - 
cells_n0 = [i for i in adata.obs_names if i.endswith('-nx')] - segmentations_dict['segger_n1'] = adata[cells_n1, :] - segmentations_dict['segger_n0'] = adata[cells_n0, :] + if method == "segger": + cells_n1 = [i for i in adata.obs_names if not i.endswith("-nx")] + cells_n0 = [i for i in adata.obs_names if i.endswith("-nx")] + segmentations_dict["segger_n1"] = adata[cells_n1, :] + segmentations_dict["segger_n0"] = adata[cells_n0, :] segmentations_dict[method] = adata return segmentations_dict - def plot_cell_counts(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None: """Plot the number of cells per segmentation method and save the cell count data as a CSV. @@ -600,33 +574,37 @@ def plot_cell_counts(segmentations_dict: Dict[str, sc.AnnData], output_path: Pat """ # Calculate the number of cells in each segmentation method cell_counts = {method: seg.n_obs for method, seg in segmentations_dict.items()} - + # Create a DataFrame for the bar plot - df = pd.DataFrame(cell_counts, index=['Number of Cells']).T - + df = pd.DataFrame(cell_counts, index=["Number of Cells"]).T + # Save the DataFrame to CSV - df.to_csv(output_path / 'cell_counts_data.csv', index=True) - + df.to_csv(output_path / "cell_counts_data.csv", index=True) + # Generate the bar plot - ax = df.plot(kind='bar', stacked=False, color=[palette.get(key, '#333333') for key in df.index], figsize=(3, 6), width=0.9) - + ax = df.plot( + kind="bar", stacked=False, color=[palette.get(key, "#333333") for key in df.index], figsize=(3, 6), width=0.9 + ) + # Add a dashed line for the 10X baseline - if '10X' in cell_counts: - baseline_height = cell_counts['10X'] - ax.axhline(y=baseline_height, color='gray', linestyle='--', linewidth=1.5, label='10X Baseline') - + if "10X" in cell_counts: + baseline_height = cell_counts["10X"] + ax.axhline(y=baseline_height, color="gray", linestyle="--", linewidth=1.5, label="10X Baseline") + # Set plot titles and labels - plt.title('Number of Cells per Segmentation Method') - plt.xlabel('Segmentation Method') - plt.ylabel('Number of Cells') - plt.legend(title='', bbox_to_anchor=(1.05, 1), loc='upper left') - + plt.title("Number of Cells per Segmentation Method") + plt.xlabel("Segmentation Method") + plt.ylabel("Number of Cells") + plt.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left") + # Save the figure as a PDF - plt.savefig(output_path / 'cell_counts_bar_plot.pdf', bbox_inches='tight') + plt.savefig(output_path / "cell_counts_bar_plot.pdf", bbox_inches="tight") plt.show() -def plot_percent_assigned(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None: +def plot_percent_assigned( + segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str] +) -> None: """Plot the percentage of assigned transcripts (normalized) for each segmentation method. 
Args: @@ -646,43 +624,38 @@ def plot_percent_assigned(segmentations_dict: Dict[str, sc.AnnData], output_path percent_assigned_normalized = total_counts_per_gene.divide(max_counts_per_gene, axis=0) * 100 # Prepare the data for the violin plot - violin_data = pd.DataFrame({ - 'Segmentation Method': [], - 'Percent Assigned (Normalized)': [] - }) - - + violin_data = pd.DataFrame({"Segmentation Method": [], "Percent Assigned (Normalized)": []}) # Add normalized percent_assigned data for each method for method in segmentations_dict.keys(): method_data = percent_assigned_normalized[method].dropna() - method_df = pd.DataFrame({ - 'Segmentation Method': [method] * len(method_data), - 'Percent Assigned (Normalized)': method_data.values - }) + method_df = pd.DataFrame( + {"Segmentation Method": [method] * len(method_data), "Percent Assigned (Normalized)": method_data.values} + ) violin_data = pd.concat([violin_data, method_df], axis=0) - - violin_data.to_csv(output_path / 'percent_assigned_normalized.csv', index=True) + + violin_data.to_csv(output_path / "percent_assigned_normalized.csv", index=True) # Plot the violin plots plt.figure(figsize=(12, 8)) - ax = sns.violinplot(x='Segmentation Method', y='Percent Assigned (Normalized)', data=violin_data, palette=palette) + ax = sns.violinplot(x="Segmentation Method", y="Percent Assigned (Normalized)", data=violin_data, palette=palette) # Add a dashed line for the 10X baseline - if '10X' in segmentations_dict: - baseline_height = percent_assigned_normalized['10X'].mean() - ax.axhline(y=baseline_height, color='gray', linestyle='--', linewidth=1.5, label='10X Baseline') + if "10X" in segmentations_dict: + baseline_height = percent_assigned_normalized["10X"].mean() + ax.axhline(y=baseline_height, color="gray", linestyle="--", linewidth=1.5, label="10X Baseline") # Set plot titles and labels - plt.title('') - plt.xlabel('Segmentation Method') - plt.ylabel('Percent Assigned (Normalized)') - plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + plt.title("") + plt.xlabel("Segmentation Method") + plt.ylabel("Percent Assigned (Normalized)") + plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") # Save the figure as a PDF - plt.savefig(output_path / 'percent_assigned_normalized_violin_plot.pdf', bbox_inches='tight') + plt.savefig(output_path / "percent_assigned_normalized_violin_plot.pdf", bbox_inches="tight") plt.show() + def plot_gene_counts(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None: """Plot the normalized gene counts for each segmentation method. 
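Putting the loading and per-gene comparison helpers together might look like the sketch below; the segmentation paths, output directory, and palette colours are placeholders.

```python
# Hypothetical driver for the per-gene comparison plots; all paths and colours are placeholders.
from pathlib import Path

segmentation_paths = {
    "segger": Path("results/segger_segmentation.h5ad"),
    "10X": Path("results/10x_segmentation.h5ad"),
}
palette = {"segger": "#1f77b4", "segger_n0": "#aec7e8", "segger_n1": "#084594", "10X": "#7f7f7f"}
output_path = Path("figures")
output_path.mkdir(parents=True, exist_ok=True)

segmentations = load_segmentations(segmentation_paths)      # adds 'segger_n0'/'segger_n1' splits
plot_cell_counts(segmentations, output_path, palette)
plot_percent_assigned(segmentations, output_path, palette)
plot_gene_counts(segmentations, output_path, palette)
```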
@@ -703,40 +676,37 @@ def plot_gene_counts(segmentations_dict: Dict[str, sc.AnnData], output_path: Pat normalized_counts_per_gene = total_counts_per_gene.divide(max_counts_per_gene, axis=0) # Prepare the data for the box plot - boxplot_data = pd.DataFrame({ - 'Segmentation Method': [], - 'Normalized Counts': [] - }) + boxplot_data = pd.DataFrame({"Segmentation Method": [], "Normalized Counts": []}) for method in segmentations_dict.keys(): method_counts = normalized_counts_per_gene[method] - method_df = pd.DataFrame({ - 'Segmentation Method': [method] * len(method_counts), - 'Normalized Counts': method_counts.values - }) + method_df = pd.DataFrame( + {"Segmentation Method": [method] * len(method_counts), "Normalized Counts": method_counts.values} + ) boxplot_data = pd.concat([boxplot_data, method_df], axis=0) - - boxplot_data.to_csv(output_path / 'gene_counts_normalized_data.csv', index=True) + + boxplot_data.to_csv(output_path / "gene_counts_normalized_data.csv", index=True) # Plot the box plots plt.figure(figsize=(3, 6)) - ax = sns.boxplot(x='Segmentation Method', y='Normalized Counts', data=boxplot_data, palette=palette, width=0.9) + ax = sns.boxplot(x="Segmentation Method", y="Normalized Counts", data=boxplot_data, palette=palette, width=0.9) # Add a dashed line for the 10X baseline - if '10X' in normalized_counts_per_gene: - baseline_height = normalized_counts_per_gene['10X'].mean() - plt.axhline(y=baseline_height, color='gray', linestyle='--', linewidth=1.5, label='10X Baseline') + if "10X" in normalized_counts_per_gene: + baseline_height = normalized_counts_per_gene["10X"].mean() + plt.axhline(y=baseline_height, color="gray", linestyle="--", linewidth=1.5, label="10X Baseline") # Set plot titles and labels - plt.title('') - plt.xlabel('Segmentation Method') - plt.ylabel('Normalized Counts') + plt.title("") + plt.xlabel("Segmentation Method") + plt.ylabel("Normalized Counts") plt.xticks(rotation=0) # Save the figure as a PDF - plt.savefig(output_path / 'gene_counts_normalized_boxplot_by_method.pdf', bbox_inches='tight') + plt.savefig(output_path / "gene_counts_normalized_boxplot_by_method.pdf", bbox_inches="tight") plt.show() + def plot_counts_per_cell(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None: """Plot the counts per cell (log2) for each segmentation method. @@ -745,36 +715,33 @@ def plot_counts_per_cell(segmentations_dict: Dict[str, sc.AnnData], output_path: output_path (Path): Path to the directory where the plot will be saved. 
""" # Prepare the data for the violin plot - violin_data = pd.DataFrame({ - 'Segmentation Method': [], - 'Counts per Cell (log2)': [] - }) + violin_data = pd.DataFrame({"Segmentation Method": [], "Counts per Cell (log2)": []}) for method, adata in segmentations_dict.items(): - method_counts = adata.obs['transcripts'] + 1 - method_df = pd.DataFrame({ - 'Segmentation Method': [method] * len(method_counts), - 'Counts per Cell (log2)': method_counts.values - }) + method_counts = adata.obs["transcripts"] + 1 + method_df = pd.DataFrame( + {"Segmentation Method": [method] * len(method_counts), "Counts per Cell (log2)": method_counts.values} + ) violin_data = pd.concat([violin_data, method_df], axis=0) - - violin_data.to_csv(output_path / 'counts_per_cell_data.csv', index=True) + + violin_data.to_csv(output_path / "counts_per_cell_data.csv", index=True) # Plot the violin plots plt.figure(figsize=(4, 6)) - ax = sns.violinplot(x='Segmentation Method', y='Counts per Cell (log2)', data=violin_data, palette=palette) + ax = sns.violinplot(x="Segmentation Method", y="Counts per Cell (log2)", data=violin_data, palette=palette) ax.set(ylim=(5, 300)) # Add a dashed line for the 10X-nucleus median - if '10X-nucleus' in segmentations_dict: - median_10X_nucleus = np.median(segmentations_dict['10X-nucleus'].obs['transcripts'] + 1) - ax.axhline(y=median_10X_nucleus, color='gray', linestyle='--', linewidth=1.5, label='10X-nucleus Median') + if "10X-nucleus" in segmentations_dict: + median_10X_nucleus = np.median(segmentations_dict["10X-nucleus"].obs["transcripts"] + 1) + ax.axhline(y=median_10X_nucleus, color="gray", linestyle="--", linewidth=1.5, label="10X-nucleus Median") # Set plot titles and labels - plt.title('') - plt.xlabel('Segmentation Method') - plt.ylabel('Counts per Cell (log2)') + plt.title("") + plt.xlabel("Segmentation Method") + plt.ylabel("Counts per Cell (log2)") plt.xticks(rotation=0) # Save the figure as a PDF - plt.savefig(output_path / 'counts_per_cell_violin_plot.pdf', bbox_inches='tight') + plt.savefig(output_path / "counts_per_cell_violin_plot.pdf", bbox_inches="tight") plt.show() + def plot_cell_area(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None: """Plot the cell area (log2) for each segmentation method. @@ -783,37 +750,36 @@ def plot_cell_area(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, output_path (Path): Path to the directory where the plot will be saved. 
""" # Prepare the data for the violin plot - violin_data = pd.DataFrame({ - 'Segmentation Method': [], - 'Cell Area (log2)': [] - }) + violin_data = pd.DataFrame({"Segmentation Method": [], "Cell Area (log2)": []}) for method in segmentations_dict.keys(): - if 'cell_area' in segmentations_dict[method].obs.columns: - method_area = segmentations_dict[method].obs['cell_area'] + 1 - method_df = pd.DataFrame({ - 'Segmentation Method': [method] * len(method_area), - 'Cell Area (log2)': method_area.values - }) + if "cell_area" in segmentations_dict[method].obs.columns: + method_area = segmentations_dict[method].obs["cell_area"] + 1 + method_df = pd.DataFrame( + {"Segmentation Method": [method] * len(method_area), "Cell Area (log2)": method_area.values} + ) violin_data = pd.concat([violin_data, method_df], axis=0) - violin_data.to_csv(output_path / 'cell_area_log2_data.csv', index=True) + violin_data.to_csv(output_path / "cell_area_log2_data.csv", index=True) # Plot the violin plots plt.figure(figsize=(4, 6)) - ax = sns.violinplot(x='Segmentation Method', y='Cell Area (log2)', data=violin_data, palette=palette) + ax = sns.violinplot(x="Segmentation Method", y="Cell Area (log2)", data=violin_data, palette=palette) ax.set(ylim=(5, 100)) # Add a dashed line for the 10X-nucleus median - if '10X-nucleus' in segmentations_dict: - median_10X_nucleus_area = np.median(segmentations_dict['10X-nucleus'].obs['cell_area'] + 1) - ax.axhline(y=median_10X_nucleus_area, color='gray', linestyle='--', linewidth=1.5, label='10X-nucleus Median') + if "10X-nucleus" in segmentations_dict: + median_10X_nucleus_area = np.median(segmentations_dict["10X-nucleus"].obs["cell_area"] + 1) + ax.axhline(y=median_10X_nucleus_area, color="gray", linestyle="--", linewidth=1.5, label="10X-nucleus Median") # Set plot titles and labels - plt.title('') - plt.xlabel('Segmentation Method') - plt.ylabel('Cell Area (log2)') + plt.title("") + plt.xlabel("Segmentation Method") + plt.ylabel("Cell Area (log2)") plt.xticks(rotation=0) # Save the figure as a PDF - plt.savefig(output_path / 'cell_area_log2_violin_plot.pdf', bbox_inches='tight') + plt.savefig(output_path / "cell_area_log2_violin_plot.pdf", bbox_inches="tight") plt.show() -def plot_transcript_density(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None: + +def plot_transcript_density( + segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str] +) -> None: """Plot the transcript density (log2) for each segmentation method. Args: @@ -821,43 +787,53 @@ def plot_transcript_density(segmentations_dict: Dict[str, sc.AnnData], output_pa output_path (Path): Path to the directory where the plot will be saved. 
""" # Prepare the data for the violin plot - violin_data = pd.DataFrame({ - 'Segmentation Method': [], - 'Transcript Density (log2)': [] - }) + violin_data = pd.DataFrame({"Segmentation Method": [], "Transcript Density (log2)": []}) for method in segmentations_dict.keys(): - if 'cell_area' in segmentations_dict[method].obs.columns: - method_density = segmentations_dict[method].obs['transcripts'] / segmentations_dict[method].obs['cell_area'] + if "cell_area" in segmentations_dict[method].obs.columns: + method_density = segmentations_dict[method].obs["transcripts"] / segmentations_dict[method].obs["cell_area"] method_density_log2 = np.log2(method_density + 1) - method_df = pd.DataFrame({ - 'Segmentation Method': [method] * len(method_density_log2), - 'Transcript Density (log2)': method_density_log2.values - }) + method_df = pd.DataFrame( + { + "Segmentation Method": [method] * len(method_density_log2), + "Transcript Density (log2)": method_density_log2.values, + } + ) violin_data = pd.concat([violin_data, method_df], axis=0) - - violin_data.to_csv(output_path / 'transcript_density_log2_data.csv', index=True) + + violin_data.to_csv(output_path / "transcript_density_log2_data.csv", index=True) # Plot the violin plots plt.figure(figsize=(4, 6)) - ax = sns.violinplot(x='Segmentation Method', y='Transcript Density (log2)', data=violin_data, palette=palette) + ax = sns.violinplot(x="Segmentation Method", y="Transcript Density (log2)", data=violin_data, palette=palette) # Add a dashed line for the 10X-nucleus median - if '10X-nucleus' in segmentations_dict: - median_10X_nucleus_density_log2 = np.median(np.log2(segmentations_dict['10X-nucleus'].obs['transcripts'] / segmentations_dict['10X-nucleus'].obs['cell_area'] + 1)) - ax.axhline(y=median_10X_nucleus_density_log2, color='gray', linestyle='--', linewidth=1.5, label='10X-nucleus Median') + if "10X-nucleus" in segmentations_dict: + median_10X_nucleus_density_log2 = np.median( + np.log2( + segmentations_dict["10X-nucleus"].obs["transcripts"] + / segmentations_dict["10X-nucleus"].obs["cell_area"] + + 1 + ) + ) + ax.axhline( + y=median_10X_nucleus_density_log2, color="gray", linestyle="--", linewidth=1.5, label="10X-nucleus Median" + ) # Set plot titles and labels - plt.title('') - plt.xlabel('Segmentation Method') - plt.ylabel('Transcript Density (log2)') + plt.title("") + plt.xlabel("Segmentation Method") + plt.ylabel("Transcript Density (log2)") plt.xticks(rotation=0) # Save the figure as a PDF - plt.savefig(output_path / 'transcript_density_log2_violin_plot.pdf', bbox_inches='tight') + plt.savefig(output_path / "transcript_density_log2_violin_plot.pdf", bbox_inches="tight") plt.show() -def plot_general_statistics_plots(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None: + +def plot_general_statistics_plots( + segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str] +) -> None: """Create a summary plot with all the general statistics subplots. 
Args: @@ -884,11 +860,13 @@ def plot_general_statistics_plots(segmentations_dict: Dict[str, sc.AnnData], out plot_transcript_density(segmentations_dict, output_path, palette=palette) plt.tight_layout() - plt.savefig(output_path / 'general_statistics_plots.pdf', bbox_inches='tight') + plt.savefig(output_path / "general_statistics_plots.pdf", bbox_inches="tight") plt.show() -def plot_mecr_results(mecr_results: Dict[str, Dict[Tuple[str, str], float]], output_path: Path, palette: Dict[str, str]) -> None: +def plot_mecr_results( + mecr_results: Dict[str, Dict[Tuple[str, str], float]], output_path: Path, palette: Dict[str, str] +) -> None: """Plot the MECR (Mutually Exclusive Co-expression Rate) results for each segmentation method. Args: @@ -900,26 +878,25 @@ def plot_mecr_results(mecr_results: Dict[str, Dict[Tuple[str, str], float]], out plot_data = [] for method, mecr_dict in mecr_results.items(): for gene_pair, mecr_value in mecr_dict.items(): - plot_data.append({ - 'Segmentation Method': method, - 'Gene Pair': f"{gene_pair[0]} - {gene_pair[1]}", - 'MECR': mecr_value - }) + plot_data.append( + {"Segmentation Method": method, "Gene Pair": f"{gene_pair[0]} - {gene_pair[1]}", "MECR": mecr_value} + ) df = pd.DataFrame(plot_data) - df.to_csv(output_path / 'mcer_box.csv', index=True) + df.to_csv(output_path / "mcer_box.csv", index=True) plt.figure(figsize=(3, 6)) - sns.boxplot(x='Segmentation Method', y='MECR', data=df, palette=palette) - plt.title('Mutually Exclusive Co-expression Rate (MECR)') - plt.xlabel('Segmentation Method') - plt.ylabel('MECR') - plt.xticks(rotation=45, ha='right') + sns.boxplot(x="Segmentation Method", y="MECR", data=df, palette=palette) + plt.title("Mutually Exclusive Co-expression Rate (MECR)") + plt.xlabel("Segmentation Method") + plt.ylabel("MECR") + plt.xticks(rotation=45, ha="right") plt.tight_layout() - plt.savefig(output_path / 'mecr_results_boxplot.pdf', bbox_inches='tight') + plt.savefig(output_path / "mecr_results_boxplot.pdf", bbox_inches="tight") plt.show() - -def plot_quantized_mecr_counts(quantized_mecr_counts: Dict[str, pd.DataFrame], output_path: Path, palette: Dict[str, str]) -> None: +def plot_quantized_mecr_counts( + quantized_mecr_counts: Dict[str, pd.DataFrame], output_path: Path, palette: Dict[str, str] +) -> None: """Plot the quantized MECR values against transcript counts for each segmentation method, with point size proportional to the variance of MECR. Args: @@ -927,38 +904,40 @@ def plot_quantized_mecr_counts(quantized_mecr_counts: Dict[str, pd.DataFrame], o output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes. 
""" - quantized_mecr_counts.to_csv(output_path / 'quantized_mecr_counts.csv', index=True) + quantized_mecr_counts.to_csv(output_path / "quantized_mecr_counts.csv", index=True) plt.figure(figsize=(9, 6)) for method, df in quantized_mecr_counts.items(): plt.plot( - df['average_counts'], - df['average_mecr'], - marker='o', - linestyle='-', - color=palette.get(method, '#333333'), + df["average_counts"], + df["average_mecr"], + marker="o", + linestyle="-", + color=palette.get(method, "#333333"), label=method, - markersize=0 # No markers, only lines + markersize=0, # No markers, only lines ) plt.scatter( - df['average_counts'], - df['average_mecr'], - s=df['variance_mecr'] * 1e5, # Size of points based on the variance of MECR - color=palette.get(method, '#333333'), + df["average_counts"], + df["average_mecr"], + s=df["variance_mecr"] * 1e5, # Size of points based on the variance of MECR + color=palette.get(method, "#333333"), alpha=0.7, # Slight transparency for overlapping points - edgecolor='w', # White edge color for better visibility - linewidth=0.5 # Thin edge line + edgecolor="w", # White edge color for better visibility + linewidth=0.5, # Thin edge line ) - plt.title('Quantized MECR by Transcript Counts') - plt.xlabel('Average Transcript Counts') - plt.ylabel('Average MECR') + plt.title("Quantized MECR by Transcript Counts") + plt.xlabel("Average Transcript Counts") + plt.ylabel("Average MECR") # Place the legend outside the plot on the top right - plt.legend(title='', bbox_to_anchor=(1.05, 1), loc='upper left') + plt.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left") plt.tight_layout() - plt.savefig(output_path / 'quantized_mecr_counts_plot.pdf', bbox_inches='tight') + plt.savefig(output_path / "quantized_mecr_counts_plot.pdf", bbox_inches="tight") plt.show() - - -def plot_quantized_mecr_area(quantized_mecr_area: Dict[str, pd.DataFrame], output_path: Path, palette: Dict[str, str]) -> None: + + +def plot_quantized_mecr_area( + quantized_mecr_area: Dict[str, pd.DataFrame], output_path: Path, palette: Dict[str, str] +) -> None: """Plot the quantized MECR values against cell areas for each segmentation method, with point size proportional to the variance of MECR. Args: @@ -966,40 +945,41 @@ def plot_quantized_mecr_area(quantized_mecr_area: Dict[str, pd.DataFrame], outpu output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes. 
""" - quantized_mecr_area.to_csv(output_path / 'quantized_mecr_area.csv', index=True) + quantized_mecr_area.to_csv(output_path / "quantized_mecr_area.csv", index=True) plt.figure(figsize=(6, 4)) for method, df in quantized_mecr_area.items(): plt.plot( - df['average_area'], - df['average_mecr'], - marker='o', + df["average_area"], + df["average_mecr"], + marker="o", # s=df['variance_mecr'] * 1e5, - linestyle='-', - color=palette.get(method, '#333333'), + linestyle="-", + color=palette.get(method, "#333333"), label=method, - markersize=0 + markersize=0, ) plt.scatter( - df['average_area'], - df['average_mecr'], - s=df['variance_mecr'] * 1e5, # Size of points based on the variance of MECR - color=palette.get(method, '#333333'), + df["average_area"], + df["average_mecr"], + s=df["variance_mecr"] * 1e5, # Size of points based on the variance of MECR + color=palette.get(method, "#333333"), alpha=0.7, # Slight transparency for overlapping points - edgecolor='w', # White edge color for better visibility - linewidth=0.5 # Thin edge line + edgecolor="w", # White edge color for better visibility + linewidth=0.5, # Thin edge line ) - plt.title('Quantized MECR by Cell Area') - plt.xlabel('Average Cell Area') - plt.ylabel('Average MECR') + plt.title("Quantized MECR by Cell Area") + plt.xlabel("Average Cell Area") + plt.ylabel("Average MECR") # Place the legend outside the plot on the top right - plt.legend(title='', bbox_to_anchor=(1.05, 1), loc='upper left') + plt.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left") plt.tight_layout() - plt.savefig(output_path / 'quantized_mecr_area_plot.pdf', bbox_inches='tight') + plt.savefig(output_path / "quantized_mecr_area_plot.pdf", bbox_inches="tight") plt.show() - -def plot_contamination_results(contamination_results: Dict[str, pd.DataFrame], output_path: Path, palette: Dict[str, str]) -> None: +def plot_contamination_results( + contamination_results: Dict[str, pd.DataFrame], output_path: Path, palette: Dict[str, str] +) -> None: """Plot contamination results for each segmentation method. Args: @@ -1007,18 +987,18 @@ def plot_contamination_results(contamination_results: Dict[str, pd.DataFrame], o output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes. """ - contamination_results.to_csv(output_path / 'contamination_results.csv', index=True) + contamination_results.to_csv(output_path / "contamination_results.csv", index=True) for method, df in contamination_results.items(): plt.figure(figsize=(10, 6)) - sns.heatmap(df, annot=True, cmap='coolwarm', linewidths=0.5) - plt.title(f'Contamination Matrix for {method}') - plt.xlabel('Target Cell Type') - plt.ylabel('Source Cell Type') + sns.heatmap(df, annot=True, cmap="coolwarm", linewidths=0.5) + plt.title(f"Contamination Matrix for {method}") + plt.xlabel("Target Cell Type") + plt.ylabel("Source Cell Type") plt.tight_layout() - plt.savefig(output_path / f'{method}_contamination_matrix.pdf', bbox_inches='tight') + plt.savefig(output_path / f"{method}_contamination_matrix.pdf", bbox_inches="tight") plt.show() - - + + def plot_contamination_boxplots(boxplot_data: pd.DataFrame, output_path: Path, palette: Dict[str, str]) -> None: """Plot boxplots for contamination values across different segmentation methods. @@ -1027,31 +1007,25 @@ def plot_contamination_boxplots(boxplot_data: pd.DataFrame, output_path: Path, p output_path (Path): Path to the directory where the plot will be saved. 
palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes. """ - boxplot_data.to_csv(output_path / 'contamination_box_results.csv', index=True) + boxplot_data.to_csv(output_path / "contamination_box_results.csv", index=True) plt.figure(figsize=(14, 8)) - sns.boxplot( - x='Source Cell Type', - y='Contamination', - hue='Segmentation Method', - data=boxplot_data, - palette=palette - ) - plt.title('Neighborhood Contamination') - plt.xlabel('Source Cell Type') - plt.ylabel('Contamination') - plt.legend(title='', bbox_to_anchor=(1.05, 1), loc='upper left') - plt.xticks(rotation=45, ha='right') - + sns.boxplot(x="Source Cell Type", y="Contamination", hue="Segmentation Method", data=boxplot_data, palette=palette) + plt.title("Neighborhood Contamination") + plt.xlabel("Source Cell Type") + plt.ylabel("Contamination") + plt.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left") + plt.xticks(rotation=45, ha="right") + plt.tight_layout() - plt.savefig(output_path / 'contamination_boxplots.pdf', bbox_inches='tight') + plt.savefig(output_path / "contamination_boxplots.pdf", bbox_inches="tight") plt.show() - - + + def plot_umaps_with_scores( - segmentations_dict: Dict[str, sc.AnnData], - clustering_scores: Dict[str, Tuple[float, float]], - output_path: Path, - palette: Dict[str, str] + segmentations_dict: Dict[str, sc.AnnData], + clustering_scores: Dict[str, Tuple[float, float]], + output_path: Path, + palette: Dict[str, str], ) -> None: """Plot UMAPs colored by cell type for each segmentation method and display clustering scores in the title. Args: @@ -1069,17 +1043,15 @@ def plot_umaps_with_scores( plt.figure(figsize=(8, 6)) sc.pp.neighbors(adata_copy, n_neighbors=5) sc.tl.umap(adata_copy, spread=5) - sc.pl.umap(adata_copy, color='celltype_major', palette=palette, show=False) + sc.pl.umap(adata_copy, color="celltype_major", palette=palette, show=False) # Add clustering scores to the title - ch_score, sh_score = compute_clustering_scores(adata_copy, cell_type_column='celltype_major') + ch_score, sh_score = compute_clustering_scores(adata_copy, cell_type_column="celltype_major") plt.title(f"{method} - UMAP\nCalinski-Harabasz: {ch_score:.2f}, Silhouette: {sh_score:.2f}") # Save the figure - plt.savefig(output_path / f'{method}_umap_with_scores.pdf', bbox_inches='tight') + plt.savefig(output_path / f"{method}_umap_with_scores.pdf", bbox_inches="tight") plt.show() - - def plot_entropy_boxplots(entropy_boxplot_data: pd.DataFrame, output_path: Path, palette: Dict[str, str]) -> None: """Plot boxplots for neighborhood entropy across different segmentation methods by cell type. 
@@ -1090,45 +1062,37 @@ def plot_entropy_boxplots(entropy_boxplot_data: pd.DataFrame, output_path: Path, """ plt.figure(figsize=(14, 8)) sns.boxplot( - x='Cell Type', - y='Neighborhood Entropy', - hue='Segmentation Method', - data=entropy_boxplot_data, - palette=palette + x="Cell Type", y="Neighborhood Entropy", hue="Segmentation Method", data=entropy_boxplot_data, palette=palette ) - plt.title('Neighborhood Entropy') - plt.xlabel('Cell Type') - plt.ylabel('Neighborhood Entropy') - plt.legend(title='', bbox_to_anchor=(1.05, 1), loc='upper left') - plt.xticks(rotation=45, ha='right') + plt.title("Neighborhood Entropy") + plt.xlabel("Cell Type") + plt.ylabel("Neighborhood Entropy") + plt.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left") + plt.xticks(rotation=45, ha="right") plt.tight_layout() - plt.savefig(output_path / 'neighborhood_entropy_boxplots.pdf', bbox_inches='tight') + plt.savefig(output_path / "neighborhood_entropy_boxplots.pdf", bbox_inches="tight") plt.show() - - -def plot_sensitivity_boxplots(sensitivity_boxplot_data: pd.DataFrame, output_path: Path, palette: Dict[str, str]) -> None: +def plot_sensitivity_boxplots( + sensitivity_boxplot_data: pd.DataFrame, output_path: Path, palette: Dict[str, str] +) -> None: """Plot boxplots for sensitivity across different segmentation methods by cell type. Args: sensitivity_boxplot_data (pd.DataFrame): DataFrame containing sensitivity data for all segmentation methods. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes. """ - sensitivity_boxplot_data.to_csv(output_path / 'sensitivity_results.csv', index=True) + sensitivity_boxplot_data.to_csv(output_path / "sensitivity_results.csv", index=True) plt.figure(figsize=(14, 8)) sns.boxplot( - x='Cell Type', - y='Sensitivity', - hue='Segmentation Method', - data=sensitivity_boxplot_data, - palette=palette + x="Cell Type", y="Sensitivity", hue="Segmentation Method", data=sensitivity_boxplot_data, palette=palette ) - plt.title('Sensitivity Score') - plt.xlabel('Cell Type') - plt.ylabel('Sensitivity') - plt.legend(title='', bbox_to_anchor=(1.05, 1), loc='upper left') - plt.xticks(rotation=45, ha='right') + plt.title("Sensitivity Score") + plt.xlabel("Cell Type") + plt.ylabel("Sensitivity") + plt.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left") + plt.xticks(rotation=45, ha="right") plt.tight_layout() - plt.savefig(output_path / 'sensitivity_boxplots.pdf', bbox_inches='tight') - plt.show() \ No newline at end of file + plt.savefig(output_path / "sensitivity_boxplots.pdf", bbox_inches="tight") + plt.show() diff --git a/src/segger/validation/xenium_explorer.py b/src/segger/validation/xenium_explorer.py index d4edcd9..6ad5fc9 100644 --- a/src/segger/validation/xenium_explorer.py +++ b/src/segger/validation/xenium_explorer.py @@ -10,7 +10,6 @@ from typing import Dict, Any, Optional, List, Tuple - def str_to_uint32(cell_id_str: str) -> Tuple[int, int]: """Convert a string cell ID back to uint32 format. @@ -20,18 +19,31 @@ def str_to_uint32(cell_id_str: str) -> Tuple[int, int]: Returns: Tuple[int, int]: The cell ID in uint32 format and the dataset suffix. 
""" - prefix, suffix = cell_id_str.split('-') + prefix, suffix = cell_id_str.split("-") str_to_hex_mapping = { - 'a': '0', 'b': '1', 'c': '2', 'd': '3', - 'e': '4', 'f': '5', 'g': '6', 'h': '7', - 'i': '8', 'j': '9', 'k': 'a', 'l': 'b', - 'm': 'c', 'n': 'd', 'o': 'e', 'p': 'f' + "a": "0", + "b": "1", + "c": "2", + "d": "3", + "e": "4", + "f": "5", + "g": "6", + "h": "7", + "i": "8", + "j": "9", + "k": "a", + "l": "b", + "m": "c", + "n": "d", + "o": "e", + "p": "f", } - hex_prefix = ''.join([str_to_hex_mapping[char] for char in prefix]) + hex_prefix = "".join([str_to_hex_mapping[char] for char in prefix]) cell_id_uint32 = int(hex_prefix, 16) dataset_suffix = int(suffix) return cell_id_uint32, dataset_suffix + def get_indices_indptr(input_array: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Get the indices and indptr arrays for sparse matrix representation. @@ -47,13 +59,14 @@ def get_indices_indptr(input_array: np.ndarray) -> Tuple[np.ndarray, np.ndarray] for cluster in clusters: cluster_indices = np.where(input_array == cluster)[0] - indptr[cluster-1] = len(indices) + indptr[cluster - 1] = len(indices) indices.extend(cluster_indices) indices.extend(-np.zeros(len(input_array[input_array == 0]))) indices = np.array(indices, dtype=np.int32).astype(np.uint32) return indices, indptr + def save_cell_clustering(merged: pd.DataFrame, zarr_path: str, columns: List[str]) -> None: """Save cell clustering information to a Zarr file. @@ -64,35 +77,38 @@ def save_cell_clustering(merged: pd.DataFrame, zarr_path: str, columns: List[str """ import zarr - new_zarr = zarr.open(zarr_path, mode='w') - new_zarr.create_group('/cell_groups') + new_zarr = zarr.open(zarr_path, mode="w") + new_zarr.create_group("/cell_groups") mappings = [] for index, column in enumerate(columns): - new_zarr['cell_groups'].create_group(index) + new_zarr["cell_groups"].create_group(index) classes = list(np.unique(merged[column].astype(str))) - mapping_dict = {key: i for i, key in zip(range(1, len(classes)), [k for k in classes if k != 'nan'])} - mapping_dict['nan'] = 0 + mapping_dict = {key: i for i, key in zip(range(1, len(classes)), [k for k in classes if k != "nan"])} + mapping_dict["nan"] = 0 clusters = merged[column].astype(str).replace(mapping_dict).values.astype(int) indices, indptr = get_indices_indptr(clusters) - new_zarr['cell_groups'][index].create_dataset('indices', data=indices) - new_zarr['cell_groups'][index].create_dataset('indptr', data=indptr) + new_zarr["cell_groups"][index].create_dataset("indices", data=indices) + new_zarr["cell_groups"][index].create_dataset("indptr", data=indptr) mappings.append(mapping_dict) - new_zarr['cell_groups'].attrs.update({ - "major_version": 1, - "minor_version": 0, - "number_groupings": len(columns), - "grouping_names": columns, - "group_names": [ - [k for k, v in sorted(mapping_dict.items(), key=lambda item: item[1])][1:] for mapping_dict in mappings - ] - }) + new_zarr["cell_groups"].attrs.update( + { + "major_version": 1, + "minor_version": 0, + "number_groupings": len(columns), + "grouping_names": columns, + "group_names": [ + [k for k, v in sorted(mapping_dict.items(), key=lambda item: item[1])][1:] for mapping_dict in mappings + ], + } + ) new_zarr.store.close() -def draw_umap(adata, column: str = 'leiden') -> None: + +def draw_umap(adata, column: str = "leiden") -> None: """Draw UMAP plots for the given AnnData object. 
Args: @@ -102,12 +118,13 @@ def draw_umap(adata, column: str = 'leiden') -> None: sc.pl.umap(adata, color=[column]) plt.show() - sc.pl.umap(adata, color=['KRT5', 'KRT7'], vmax='p95') + sc.pl.umap(adata, color=["KRT5", "KRT7"], vmax="p95") plt.show() - sc.pl.umap(adata, color=['ACTA2', 'PTPRC'], vmax='p95') + sc.pl.umap(adata, color=["ACTA2", "PTPRC"], vmax="p95") plt.show() + def get_leiden_umap(adata, draw: bool = False): """Perform Leiden clustering and UMAP visualization on the given AnnData object. @@ -123,12 +140,9 @@ def get_leiden_umap(adata, draw: bool = False): gene_names = adata.var_names mean_expression_values = adata.X.mean(axis=0) - gene_mean_expression_df = pd.DataFrame({ - 'gene_name': gene_names, - 'mean_expression': mean_expression_values - }) - top_genes = gene_mean_expression_df.sort_values(by='mean_expression', ascending=False).head(30) - top_gene_names = top_genes['gene_name'].tolist() + gene_mean_expression_df = pd.DataFrame({"gene_name": gene_names, "mean_expression": mean_expression_values}) + top_genes = gene_mean_expression_df.sort_values(by="mean_expression", ascending=False).head(30) + top_gene_names = top_genes["gene_name"].tolist() sc.pp.normalize_total(adata) sc.pp.log1p(adata) @@ -137,11 +151,12 @@ def get_leiden_umap(adata, draw: bool = False): sc.tl.leiden(adata) if draw: - draw_umap(adata, 'leiden') + draw_umap(adata, "leiden") return adata -def get_median_expression_table(adata, column: str = 'leiden') -> pd.DataFrame: + +def get_median_expression_table(adata, column: str = "leiden") -> pd.DataFrame: """Get the median expression table for the given AnnData object. Args: @@ -151,7 +166,23 @@ def get_median_expression_table(adata, column: str = 'leiden') -> pd.DataFrame: Returns: pd.DataFrame: The median expression table. 
""" - top_genes = ['GATA3', 'ACTA2', 'KRT7', 'KRT8', 'KRT5', 'AQP1', 'SERPINA3', 'PTGDS', 'CXCR4', 'SFRP1', 'ENAH', 'MYH11', 'SVIL', 'KRT14', 'CD4'] + top_genes = [ + "GATA3", + "ACTA2", + "KRT7", + "KRT8", + "KRT5", + "AQP1", + "SERPINA3", + "PTGDS", + "CXCR4", + "SFRP1", + "ENAH", + "MYH11", + "SVIL", + "KRT14", + "CD4", + ] top_gene_indices = [adata.var_names.get_loc(gene) for gene in top_genes] clusters = adata.obs[column] @@ -160,26 +191,29 @@ def get_median_expression_table(adata, column: str = 'leiden') -> pd.DataFrame: for cluster in clusters.unique(): cluster_cells = adata[clusters == cluster].X cluster_expression = cluster_cells[:, top_gene_indices] - gene_medians = [pd.Series(cluster_expression[:, gene_idx]).median() for gene_idx in range(len(top_gene_indices))] - cluster_data[f'Cluster_{cluster}'] = gene_medians + gene_medians = [ + pd.Series(cluster_expression[:, gene_idx]).median() for gene_idx in range(len(top_gene_indices)) + ] + cluster_data[f"Cluster_{cluster}"] = gene_medians cluster_expression_df = pd.DataFrame(cluster_data, index=top_genes) sorted_columns = sorted(cluster_expression_df.columns.values, key=lambda x: int(x.split("_")[-1])) cluster_expression_df = cluster_expression_df[sorted_columns] - return cluster_expression_df.T.style.background_gradient(cmap='Greens') + return cluster_expression_df.T.style.background_gradient(cmap="Greens") + def seg2explorer( - seg_df: pd.DataFrame, - source_path: str, - output_dir: str, - cells_filename: str = 'seg_cells', - analysis_filename: str = "seg_analysis", + seg_df: pd.DataFrame, + source_path: str, + output_dir: str, + cells_filename: str = "seg_cells", + analysis_filename: str = "seg_analysis", xenium_filename: str = "seg_experiment.xenium", - analysis_df: Optional[pd.DataFrame] = None, - draw: bool = False, - cell_id_columns: str = 'seg_cell_id', + analysis_df: Optional[pd.DataFrame] = None, + draw: bool = False, + cell_id_columns: str = "seg_cell_id", area_low: float = 10, - area_high: float = 100 + area_high: float = 100, ) -> None: """Convert seg output to a format compatible with Xenium explorer. 
@@ -214,8 +248,8 @@ def seg2explorer( for cell_incremental_id, (seg_cell_id, seg_cell) in tqdm(enumerate(grouped_by), total=len(grouped_by)): if len(seg_cell) < 5: continue - - cell_convex_hull = ConvexHull(seg_cell[['x_location', 'y_location']]) + + cell_convex_hull = ConvexHull(seg_cell[["x_location", "y_location"]]) if cell_convex_hull.area > area_high: continue if cell_convex_hull.area < area_low: @@ -224,25 +258,31 @@ def seg2explorer( uint_cell_id = cell_incremental_id + 1 cell_id2old_id[uint_cell_id] = seg_cell_id - seg_nucleous = seg_cell[seg_cell['overlaps_nucleus'] == 1] + seg_nucleous = seg_cell[seg_cell["overlaps_nucleus"] == 1] if len(seg_nucleous) >= 3: - nucleus_convex_hull = ConvexHull(seg_nucleous[['x_location', 'y_location']]) + nucleus_convex_hull = ConvexHull(seg_nucleous[["x_location", "y_location"]]) cell_id.append(uint_cell_id) - cell_summary.append({ - "cell_centroid_x": seg_cell['x_location'].mean(), - "cell_centroid_y": seg_cell['y_location'].mean(), - "cell_area": cell_convex_hull.area, - "nucleus_centroid_x": seg_cell['x_location'].mean(), - "nucleus_centroid_y": seg_cell['y_location'].mean(), - "nucleus_area": cell_convex_hull.area, - "z_level": (seg_cell.z_location.mean() // 3).round(0) * 3 - }) - - polygon_num_vertices[0].append(len(cell_convex_hull.vertices)) + cell_summary.append( + { + "cell_centroid_x": seg_cell["x_location"].mean(), + "cell_centroid_y": seg_cell["y_location"].mean(), + "cell_area": cell_convex_hull.area, + "nucleus_centroid_x": seg_cell["x_location"].mean(), + "nucleus_centroid_y": seg_cell["y_location"].mean(), + "nucleus_area": cell_convex_hull.area, + "z_level": (seg_cell.z_location.mean() // 3).round(0) * 3, + } + ) + + polygon_num_vertices[0].append(len(cell_convex_hull.vertices)) polygon_num_vertices[1].append(len(nucleus_convex_hull.vertices) if len(seg_nucleous) >= 3 else 0) - polygon_vertices[0].append(seg_cell[['x_location', 'y_location']].values[cell_convex_hull.vertices]) - polygon_vertices[1].append(seg_nucleous[['x_location', 'y_location']].values[nucleus_convex_hull.vertices] if len(seg_nucleous) >= 3 else np.array([[], []]).T) + polygon_vertices[0].append(seg_cell[["x_location", "y_location"]].values[cell_convex_hull.vertices]) + polygon_vertices[1].append( + seg_nucleous[["x_location", "y_location"]].values[nucleus_convex_hull.vertices] + if len(seg_nucleous) >= 3 + else np.array([[], []]).T + ) seg_mask_value.append(cell_incremental_id + 1) cell_polygon_vertices = get_flatten_version(polygon_vertices[0], max_value=21) @@ -251,66 +291,80 @@ def seg2explorer( cells = { "cell_id": np.array([np.array(cell_id), np.ones(len(cell_id))], dtype=np.uint32).T, "cell_summary": pd.DataFrame(cell_summary).values.astype(np.float64), - "polygon_num_vertices": np.array([ - [min(x+1, x+1) for x in polygon_num_vertices[1]], - [min(x+1, x+1) for x in polygon_num_vertices[0]] - ], dtype=np.int32), + "polygon_num_vertices": np.array( + [ + [min(x + 1, x + 1) for x in polygon_num_vertices[1]], + [min(x + 1, x + 1) for x in polygon_num_vertices[0]], + ], + dtype=np.int32, + ), "polygon_vertices": np.array([nucl_polygon_vertices, cell_polygon_vertices]).astype(np.float32), - "seg_mask_value": np.array(seg_mask_value, dtype=np.int32) + "seg_mask_value": np.array(seg_mask_value, dtype=np.int32), } - - existing_store = zarr.open(source_path / 'cells.zarr.zip', mode='r') - new_store = zarr.open(storage / f'{cells_filename}.zarr.zip', mode='w') - - new_store['cell_id'] = cells['cell_id'] - new_store['polygon_num_vertices'] = 
cells['polygon_num_vertices'] - new_store['polygon_vertices'] = cells['polygon_vertices'] - new_store['seg_mask_value'] = cells['seg_mask_value'] - + + existing_store = zarr.open(source_path / "cells.zarr.zip", mode="r") + new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") + + new_store["cell_id"] = cells["cell_id"] + new_store["polygon_num_vertices"] = cells["polygon_num_vertices"] + new_store["polygon_vertices"] = cells["polygon_vertices"] + new_store["seg_mask_value"] = cells["seg_mask_value"] + new_store.attrs.update(existing_store.attrs) - new_store.attrs['number_cells'] = len(cells['cell_id']) + new_store.attrs["number_cells"] = len(cells["cell_id"]) new_store.store.close() - + if analysis_df is None: analysis_df = pd.DataFrame([cell_id2old_id[i] for i in cell_id], columns=[cell_id_columns]) - analysis_df['default'] = 'seg' - + analysis_df["default"] = "seg" + zarr_df = pd.DataFrame([cell_id2old_id[i] for i in cell_id], columns=[cell_id_columns]) - clustering_df = pd.merge(zarr_df, analysis_df, how='left', on=cell_id_columns) + clustering_df = pd.merge(zarr_df, analysis_df, how="left", on=cell_id_columns) clusters_names = [i for i in analysis_df.columns if i != cell_id_columns] - clusters_dict = {cluster: {j: i for i, j in zip(range(1, len(sorted(np.unique(clustering_df[cluster].dropna()))) + 1), sorted(np.unique(clustering_df[cluster].dropna())))} for cluster in clusters_names} + clusters_dict = { + cluster: { + j: i + for i, j in zip( + range(1, len(sorted(np.unique(clustering_df[cluster].dropna()))) + 1), + sorted(np.unique(clustering_df[cluster].dropna())), + ) + } + for cluster in clusters_names + } - new_zarr = zarr.open(storage / (analysis_filename + ".zarr.zip"), mode='w') - new_zarr.create_group('/cell_groups') + new_zarr = zarr.open(storage / (analysis_filename + ".zarr.zip"), mode="w") + new_zarr.create_group("/cell_groups") clusters = [[clusters_dict[cluster].get(x, 0) for x in list(clustering_df[cluster])] for cluster in clusters_names] for i in range(len(clusters)): - new_zarr['cell_groups'].create_group(i) + new_zarr["cell_groups"].create_group(i) indices, indptr = get_indices_indptr(np.array(clusters[i])) - new_zarr['cell_groups'][i].create_dataset('indices', data=indices) - new_zarr['cell_groups'][i].create_dataset('indptr', data=indptr) - - new_zarr['cell_groups'].attrs.update({ - "major_version": 1, - "minor_version": 0, - "number_groupings": len(clusters_names), - "grouping_names": clusters_names, - "group_names": [ - [x[0] for x in sorted(clusters_dict[cluster].items(), key=lambda x: x[1])] - for cluster in clusters_names - ] - }) + new_zarr["cell_groups"][i].create_dataset("indices", data=indices) + new_zarr["cell_groups"][i].create_dataset("indptr", data=indptr) + + new_zarr["cell_groups"].attrs.update( + { + "major_version": 1, + "minor_version": 0, + "number_groupings": len(clusters_names), + "grouping_names": clusters_names, + "group_names": [ + [x[0] for x in sorted(clusters_dict[cluster].items(), key=lambda x: x[1])] for cluster in clusters_names + ], + } + ) new_zarr.store.close() generate_experiment_file( - template_path=source_path / 'experiment.xenium', + template_path=source_path / "experiment.xenium", output_path=storage / xenium_filename, cells_name=cells_filename, - analysis_name=analysis_filename + analysis_name=analysis_filename, ) + def get_flatten_version(polygons: List[np.ndarray], max_value: int = 21) -> np.ndarray: """Get the flattened version of polygon vertices. 
@@ -326,23 +380,21 @@ def get_flatten_version(polygons: List[np.ndarray], max_value: int = 21) -> np.n for i, polygon in tqdm(enumerate(polygons), total=len(polygons)): num_points = len(polygon) if num_points == 0: - result[i] = np.zeros(n*2) + result[i] = np.zeros(n * 2) continue elif num_points < max_value: repeated_points = np.tile(polygon[0], (n - num_points, 1)) padded_polygon = np.concatenate((polygon, repeated_points), axis=0) else: padded_polygon = np.zeros((n, 2)) - padded_polygon[:min(num_points, n)] = polygon[:min(num_points, n)] + padded_polygon[: min(num_points, n)] = polygon[: min(num_points, n)] padded_polygon[-1] = polygon[0] result[i] = padded_polygon.flatten() return result + def generate_experiment_file( - template_path: str, - output_path: str, - cells_name: str = "seg_cells", - analysis_name: str = 'seg_analysis' + template_path: str, output_path: str, cells_name: str = "seg_cells", analysis_name: str = "seg_analysis" ) -> None: """Generate the experiment file for Xenium. @@ -357,12 +409,12 @@ def generate_experiment_file( with open(template_path) as f: experiment = json.load(f) - experiment['images'].pop('morphology_filepath') - experiment['images'].pop('morphology_focus_filepath') + experiment["images"].pop("morphology_filepath") + experiment["images"].pop("morphology_focus_filepath") - experiment['xenium_explorer_files']['cells_zarr_filepath'] = f"{cells_name}.zarr.zip" - experiment['xenium_explorer_files'].pop('cell_features_zarr_filepath') - experiment['xenium_explorer_files']['analysis_zarr_filepath'] = f"{analysis_name}.zarr.zip" + experiment["xenium_explorer_files"]["cells_zarr_filepath"] = f"{cells_name}.zarr.zip" + experiment["xenium_explorer_files"].pop("cell_features_zarr_filepath") + experiment["xenium_explorer_files"]["analysis_zarr_filepath"] = f"{analysis_name}.zarr.zip" - with open(output_path, 'w') as f: + with open(output_path, "w") as f: json.dump(experiment, f, indent=2) diff --git a/tests/test_data.py b/tests/test_data.py index 3f3c5b8..7712067 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -7,46 +7,44 @@ import unittest import pandas as pd + class TestDataUtils(unittest.TestCase): def test_filter_transcripts(self): - data = { - 'qv': [30, 10, 25], - 'feature_name': ['gene1', 'NegControlProbe_gene2', 'gene3'] - } + data = {"qv": [30, 10, 25], "feature_name": ["gene1", "NegControlProbe_gene2", "gene3"]} df = pd.DataFrame(data) filtered_df = filter_transcripts(df, min_qv=20) self.assertEqual(len(filtered_df), 2) - self.assertTrue('gene1' in filtered_df['feature_name'].values) - self.assertTrue('gene3' in filtered_df['feature_name'].values) + self.assertTrue("gene1" in filtered_df["feature_name"].values) + self.assertTrue("gene3" in filtered_df["feature_name"].values) def test_compute_transcript_metrics(self): data = { - 'qv': [40, 40, 25, 25], - 'feature_name': ['gene1', 'gene2', 'gene1', 'gene2'], - 'cell_id': [1, 1, -1, 2], - 'overlaps_nucleus': [1, 0, 0, 1] + "qv": [40, 40, 25, 25], + "feature_name": ["gene1", "gene2", "gene1", "gene2"], + "cell_id": [1, 1, -1, 2], + "overlaps_nucleus": [1, 0, 0, 1], } df = pd.DataFrame(data) metrics = compute_transcript_metrics(df, qv_threshold=30) - self.assertAlmostEqual(metrics['percent_assigned'], 50.0) - self.assertAlmostEqual(metrics['percent_cytoplasmic'], 50.0) - self.assertAlmostEqual(metrics['percent_nucleus'], 50.0) - self.assertAlmostEqual(metrics['percent_non_assigned_cytoplasmic'], 100.0) - self.assertEqual(len(metrics['gene_metrics']), 2) - self.assertTrue('gene1' in 
metrics['gene_metrics']['feature_name'].values) - self.assertTrue('gene2' in metrics['gene_metrics']['feature_name'].values) - + self.assertAlmostEqual(metrics["percent_assigned"], 50.0) + self.assertAlmostEqual(metrics["percent_cytoplasmic"], 50.0) + self.assertAlmostEqual(metrics["percent_nucleus"], 50.0) + self.assertAlmostEqual(metrics["percent_non_assigned_cytoplasmic"], 100.0) + self.assertEqual(len(metrics["gene_metrics"]), 2) + self.assertTrue("gene1" in metrics["gene_metrics"]["feature_name"].values) + self.assertTrue("gene2" in metrics["gene_metrics"]["feature_name"].values) + def setUp(self): data = { - 'x_location': [100, 200, 300], - 'y_location': [100, 200, 300], - 'z_location': [0, 0, 0], - 'qv': [40, 40, 25], - 'feature_name': ['gene1', 'gene2', 'gene3'], - 'transcript_id': [1, 2, 3], - 'overlaps_nucleus': [1, 0, 1], - 'cell_id': [1, -1, 2] + "x_location": [100, 200, 300], + "y_location": [100, 200, 300], + "z_location": [0, 0, 0], + "qv": [40, 40, 25], + "feature_name": ["gene1", "gene2", "gene3"], + "transcript_id": [1, 2, 3], + "overlaps_nucleus": [1, 0, 1], + "cell_id": [1, -1, 2], } self.df = pd.DataFrame(data) self.sample = XeniumSample(self.df) @@ -54,21 +52,18 @@ def setUp(self): def test_crop_transcripts(self): cropped_sample = self.sample.crop_transcripts(50, 50, 200, 200) self.assertEqual(len(cropped_sample.transcripts_df), 1) - self.assertEqual(cropped_sample.transcripts_df.iloc[0]['feature_name'], 'gene1') + self.assertEqual(cropped_sample.transcripts_df.iloc[0]["feature_name"], "gene1") def test_filter_transcripts(self): filtered_df = XeniumSample.filter_transcripts(self.df, min_qv=30) self.assertEqual(len(filtered_df), 2) - self.assertTrue('gene1' in filtered_df['feature_name'].values) - self.assertTrue('gene2' in filtered_df['feature_name'].values) + self.assertTrue("gene1" in filtered_df["feature_name"].values) + self.assertTrue("gene2" in filtered_df["feature_name"].values) def test_unassign_all_except_nucleus(self): unassigned_df = XeniumSample.unassign_all_except_nucleus(self.df) - self.assertEqual(unassigned_df.loc[unassigned_df['overlaps_nucleus'] == 0, 'cell_id'].values[0], 'UNASSIGNED') + self.assertEqual(unassigned_df.loc[unassigned_df["overlaps_nucleus"] == 0, "cell_id"].values[0], "UNASSIGNED") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() - - - diff --git a/tests/test_model.py b/tests/test_model.py index 802c17b..b6dfdf0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,21 +4,18 @@ from torch_geometric.nn import to_hetero from torch_geometric.data import HeteroData + class TestSeggerModel(unittest.TestCase): def setUp(self): - model = Segger( - init_emb=16, hidden_channels=32, out_channels=32, heads=3 - ) - metadata = ( - ["tx", "nc"], [("tx", "belongs", "nc"), ("tx", "neighbors", "tx")] - ) - self.model = to_hetero(model, metadata=metadata, aggr='sum') + model = Segger(init_emb=16, hidden_channels=32, out_channels=32, heads=3) + metadata = (["tx", "nc"], [("tx", "belongs", "nc"), ("tx", "neighbors", "tx")]) + self.model = to_hetero(model, metadata=metadata, aggr="sum") self.data = HeteroData() - self.data['tx'].x = torch.randn(10, 16) - self.data['nc'].x = torch.randn(5, 16) - self.data['tx', 'belongs', 'nc'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]], dtype=torch.long) - self.data['tx', 'neighbors', 'tx'].edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) + self.data["tx"].x = torch.randn(10, 16) + self.data["nc"].x = torch.randn(5, 16) + self.data["tx", "belongs", "nc"].edge_index = 
torch.tensor([[0, 1, 2], [0, 1, 2]], dtype=torch.long) + self.data["tx", "neighbors", "tx"].edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) def test_forward(self): out = self.model(self.data.x_dict, self.data.edge_index_dict) @@ -26,13 +23,15 @@ def test_forward(self): self.assertTrue("nc" in out) self.assertEqual(out["tx"].shape[1], 32 * 3) self.assertEqual(out["nc"].shape[1], 32 * 3) - ''' + + """ def test_decode(self): z = {'tx': torch.randn(10, 16), 'nc': torch.randn(5, 16)} edge_label_index = torch.tensor([[0, 1, 2], [0, 1, 2]], dtype=torch.long) out = self.model.decode(z, edge_label_index) self.assertEqual(out.shape[0], 3) - ''' + """ + if __name__ == "__main__": unittest.main() diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 9d90316..fd77227 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -4,21 +4,23 @@ from segger.models.segger_model import Segger from torch_geometric.data import HeteroData + class TestPrediction(unittest.TestCase): def setUp(self): self.model = Segger(init_emb=16, hidden_channels=32, out_channels=32, heads=3) - self.lit_model = load_model("path/to/checkpoint", 16, 32, 32, 3, 'sum') + self.lit_model = load_model("path/to/checkpoint", 16, 32, 32, 3, "sum") self.data = HeteroData() - self.data['tx'].x = torch.randn(10, 16) - self.data['nc'].x = torch.randn(5, 16) - self.data['tx', 'belongs', 'nc'].edge_label_index = torch.tensor([[0, 1, 2], [0, 1, 2]], dtype=torch.long) - self.data['tx', 'neighbors', 'tx'].edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) + self.data["tx"].x = torch.randn(10, 16) + self.data["nc"].x = torch.randn(5, 16) + self.data["tx", "belongs", "nc"].edge_label_index = torch.tensor([[0, 1, 2], [0, 1, 2]], dtype=torch.long) + self.data["tx", "neighbors", "tx"].edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) def test_predict(self): output_path = "path/to/output.csv.gz" predict(self.lit_model, "path/to/dataset", output_path, 0.5, 4, 20, 5, 10) self.assertTrue(os.path.exists(output_path)) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_training.py b/tests/test_training.py index 5154fef..11615f8 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -4,40 +4,32 @@ from torch_geometric.data import HeteroData import torch + class TestTraining(unittest.TestCase): def setUp(self): # Setup model and data - metadata = ( - ["tx", "nc"], [("tx", "belongs", "nc"), ("tx", "neighbors", "tx")] - ) + metadata = (["tx", "nc"], [("tx", "belongs", "nc"), ("tx", "neighbors", "tx")]) self.lit_segger = LitSegger( init_emb=16, hidden_channels=32, out_channels=32, heads=3, metadata=metadata, - aggr='sum', + aggr="sum", ) self.data = HeteroData() self.data["tx"].x = torch.randn(10, 16) self.data["nc"].x = torch.randn(5, 16) - self.data["tx", "belongs", "nc"].edge_label_index = torch.tensor( - [[0, 1, 2], [0, 1, 2]], dtype=torch.long - ) - self.data["tx", "belongs", "nc"].edge_label = torch.tensor( - [1.0, 0.0, 1.0], dtype=torch.float - ) - self.data["tx", "neighbors", "tx"].edge_index = torch.tensor( - [[0, 1], [1, 2]], dtype=torch.long - ) - + self.data["tx", "belongs", "nc"].edge_label_index = torch.tensor([[0, 1, 2], [0, 1, 2]], dtype=torch.long) + self.data["tx", "belongs", "nc"].edge_label = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float) + self.data["tx", "neighbors", "tx"].edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) + # Move model and data to GPU self.lit_segger.cuda() self.data.to("cuda") - def 
test_training_step(self): optimizer = self.lit_segger.configure_optimizers() self.lit_segger.train() @@ -47,5 +39,6 @@ def test_training_step(self): optimizer.step() self.assertGreater(loss.item(), 0) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main()
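
For orientation, a minimal usage sketch of the general-statistics plotting helpers whose formatting the hunks above touch. The import path (segger.validation.utils), the .h5ad file locations, and the example palette colors are assumptions for illustration only; the function name, its (segmentations_dict, output_path, palette) signature, the expected .obs columns, and the "10X" / "10X-nucleus" baseline keys follow the definitions shown in the diff.

    # Minimal usage sketch; module path and file locations are assumptions, not taken from the patch.
    from pathlib import Path
    from typing import Dict

    import scanpy as sc

    from segger.validation.utils import plot_general_statistics_plots  # assumed module path

    # One AnnData per segmentation method; each .obs is expected to carry
    # "transcripts" and "cell_area" columns, as used by the helpers above.
    segmentations_dict: Dict[str, sc.AnnData] = {
        "segger": sc.read_h5ad("results/segger.h5ad"),          # placeholder path
        "10X": sc.read_h5ad("results/10x.h5ad"),                 # placeholder path
        "10X-nucleus": sc.read_h5ad("results/10x_nucleus.h5ad"), # placeholder path
    }

    # One hex color per method key, matching the `palette` argument of the helpers.
    palette = {"segger": "#4c72b0", "10X": "#dd8452", "10X-nucleus": "#55a868"}

    output_path = Path("validation_figures")
    output_path.mkdir(parents=True, exist_ok=True)

    # Writes general_statistics_plots.pdf plus the per-metric CSVs and PDFs produced
    # by the individual plot_* helpers it calls.
    plot_general_statistics_plots(segmentations_dict, output_path, palette=palette)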