Skip to content

Commit

Permalink
Merge pull request #361 from Clinical-Genomics/speedup_gene_overview
Browse files Browse the repository at this point in the history
Speedup gene overview
  • Loading branch information
northwestwitch authored Oct 4, 2024
2 parents bccaba0 + 30ae18e commit 1046858
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 190 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- Updated tests workflow to cargo install the latest d4tools from git (master branch)
- Computing coverage completeness stats using d4tools `perc_cov` stat function (much quicker reports)
- Moved functions computing the coverage stats to a separate `meta/handle_coverage_stats.py` module
- Refactored code collecting stats shown on gene overview report
### Fixed
- Updated dependencies including `certifi` to address dependabot alert
- Update pytest to v.7.4.4 to address a `ReDoS` vulnerability
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM clinicalgenomics/python3.11-venv-d4tools:3.0
FROM clinicalgenomics/python3.11-venv-d4tools:3.0.1

LABEL about.home="https://github.com/Clinical-Genomics/chanjo2"
LABEL about.license="MIT License (MIT)"
Expand Down
198 changes: 54 additions & 144 deletions src/chanjo2/meta/handle_d4.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
import logging
import tempfile
from statistics import mean
from typing import Dict, List, Optional, Tuple, Union

from sqlalchemy.orm import Session

from chanjo2.crud.intervals import get_gene_intervals, set_sql_intervals
from chanjo2.meta.handle_bed import sort_interval_ids_coords
from chanjo2.meta.handle_completeness_stats import get_completeness_stats
from chanjo2.meta.handle_completeness_stats import (
get_completeness_stats,
get_d4tools_intervals_completeness,
)
from chanjo2.meta.handle_coverage_stats import (
get_d4tools_chromosome_mean_coverage,
get_d4tools_intervals_coverage,
get_d4tools_intervals_mean_coverage,
)
from chanjo2.models import SQLExon, SQLGene, SQLTranscript
from chanjo2.models.pydantic_models import (
GeneCoverage,
IntervalCoverage,
IntervalType,
Sex,
TranscriptTag,
)
from chanjo2.models.pydantic_models import ReportQuerySample, Sex

LOG = logging.getLogger(__name__)


def get_report_sample_interval_coverage(
Expand Down Expand Up @@ -119,139 +118,6 @@ def get_report_sample_interval_coverage(
)


def get_sample_interval_coverage(
db: Session,
d4_file_path: str,
genes: List[SQLGene],
interval_type: Union[SQLGene, SQLTranscript, SQLExon],
completeness_thresholds: List[Optional[int]],
transcript_tags: Optional[List[TranscriptTag]] = [],
) -> List[GeneCoverage]:
"""Compute stats to populate a coverage overview report for one sample."""

if not genes:
return []

genes_coverage_stats: List[GeneCoverage] = []

sql_intervals: List[Union[SQLGene, SQLTranscript, SQLExon]] = set_sql_intervals(
db=db, interval_type=interval_type, genes=genes, transcript_tags=transcript_tags
)
interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = [
(interval.ensembl_id, (interval.chromosome, interval.start, interval.stop))
for interval in sql_intervals
]
interval_ids_coords = sort_interval_ids_coords(interval_ids_coords)

intervals_coverage_completeness: Dict[str, dict] = get_completeness_stats(
d4_file_path=d4_file_path,
thresholds=completeness_thresholds,
interval_ids_coords=interval_ids_coords,
)

# Create GeneCoverage objects
for gene in genes:
gene_coverage = GeneCoverage(
**{
"ensembl_gene_id": gene.ensembl_id,
"hgnc_id": gene.hgnc_id,
"hgnc_symbol": gene.hgnc_symbol,
"interval_type": IntervalType.GENES,
"interval_id": gene.ensembl_id,
"mean_coverage": 0,
"completeness": {},
"inner_intervals": [],
}
)

if interval_type == SQLGene: # The interval requested is the genes itself
gene_coverage.mean_coverage = mean(
get_d4tools_intervals_mean_coverage(
d4_file_path=d4_file_path,
interval_ids_coords=interval_ids_coords,
)
)
gene_coverage.completeness = intervals_coverage_completeness.get(
gene.ensembl_id, {}
)

else: # Retrieve transcripts or exons for this gene

gene_intervals: List[Union[SQLTranscript, SQLExon]] = get_gene_intervals(
db=db,
build=gene.build,
interval_type=interval_type,
ensembl_ids=None,
hgnc_ids=None,
hgnc_symbols=None,
ensembl_gene_ids=[gene.ensembl_id],
limit=None,
transcript_tags=transcript_tags,
)

inner_intervals_ensembl_ids = set()
interval_ids_coords: List[Tuple[str, Tuple]] = []
intervals_mean_completeness: Dict[int:List] = {
threshold: [] for threshold in completeness_thresholds
}

for interval in gene_intervals:
if interval.ensembl_id in inner_intervals_ensembl_ids:
continue

interval_tuple: Tuple[str, Tuple] = (
interval.ensembl_id,
(interval.chromosome, interval.start, interval.stop),
)
interval_ids_coords.append(interval_tuple)

for threshold in completeness_thresholds:
intervals_mean_completeness[threshold].append(
intervals_coverage_completeness[interval.ensembl_id][threshold]
)

interval_coverage = IntervalCoverage(
**{
"interval_type": interval_type.__tablename__,
"interval_id": interval.ensembl_id,
"mean_coverage": mean(
get_d4tools_intervals_mean_coverage(
d4_file_path=d4_file_path,
interval_ids_coords=[interval_tuple],
)
),
"completeness": intervals_coverage_completeness[
interval.ensembl_id
],
}
)

gene_coverage.inner_intervals.append(interval_coverage)
inner_intervals_ensembl_ids.add(interval.ensembl_id)

gene_intervals_mean_coverage: List[float] = (
get_d4tools_intervals_mean_coverage(
d4_file_path=d4_file_path, interval_ids_coords=interval_ids_coords
)
)
gene_coverage.mean_coverage = (
mean(gene_intervals_mean_coverage)
if gene_intervals_mean_coverage
else 0
)

for threshold in completeness_thresholds:
gene_coverage.completeness[threshold] = (
mean(intervals_mean_completeness[threshold])
if intervals_mean_completeness[threshold]
else 0
)

genes_coverage_stats.append(gene_coverage)

return genes_coverage_stats


def predict_sex(x_cov: float, y_cov: float) -> str:
"""Return predict sex based on sex chromosomes coverage - this code is taken from the old chanjo."""
if y_cov == 0:
Expand Down Expand Up @@ -282,3 +148,47 @@ def get_samples_sex_metrics(d4_file_path: str) -> Dict:
x_cov=sex_chroms_coverage[0][1], y_cov=sex_chroms_coverage[1][1]
),
}


def get_gene_overview_stats(
sql_intervals: List[SQLTranscript],
samples: List[ReportQuerySample],
completeness_thresholds: List[int],
) -> Dict[str, list]:
"""Returns stats to be included in the gene overview page."""
interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = [
(interval.ensembl_id, (interval.chromosome, interval.start, interval.stop))
for interval in sql_intervals
]
interval_ids_coords = tuple(
sort_interval_ids_coords(set(interval_ids_coords))
) # removes duplicates and orders intervals by chromosome, start and stop
transcripts_stats = {interval_id: [] for interval_id, _ in interval_ids_coords}

# create a temp bed file containing transcripts coordinates
bed_lines = [
f"{coords[0]}\t{coords[1]}\t{coords[2]}" for _, coords in interval_ids_coords
]
temp_bed_file = tempfile.NamedTemporaryFile()
with open(temp_bed_file.name, "w") as intervals_bed:
intervals_bed.write("\n".join(bed_lines))
intervals_bed.flush()

for sample in samples:
transcripts_coverage = get_d4tools_intervals_coverage(
d4_file_path=sample.coverage_file_path, bed_file_path=temp_bed_file.name
)
transcripts_completeness = get_d4tools_intervals_completeness(
d4_file_path=sample.coverage_file_path,
bed_file_path=temp_bed_file.name,
completeness_thresholds=completeness_thresholds,
)
for idx, transcripts_coords in enumerate(interval_ids_coords):
append_tuple = (
sample.name,
transcripts_coverage[idx],
transcripts_completeness[idx],
)
transcripts_stats[transcripts_coords[0]].append(append_tuple)

return transcripts_stats
58 changes: 13 additions & 45 deletions src/chanjo2/meta/handle_report_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from chanjo2.crud.intervals import get_genes, get_hgnc_gene, set_sql_intervals
from chanjo2.meta.handle_d4 import (
get_gene_overview_stats,
get_report_sample_interval_coverage,
get_sample_interval_coverage,
get_samples_sex_metrics,
)
from chanjo2.models import SQLExon, SQLGene, SQLTranscript
Expand Down Expand Up @@ -201,7 +201,7 @@ def get_report_sex_rows(samples: List[ReportQuerySample]) -> List[Dict]:


def get_gene_overview_coverage_stats(form_data: GeneReportForm, session: Session):
"""Returns coverage stats over the intervals sof one gene for one or more samples."""
"""Returns coverage stats over the intervals of one gene for one or more samples."""

gene_stats = {
"levels": get_ordered_levels(
Expand All @@ -218,49 +218,17 @@ def get_gene_overview_coverage_stats(form_data: GeneReportForm, session: Session
)
if gene is None:
return gene_stats
gene_stats["gene"] = gene

samples_coverage_stats: Dict[str, List[GeneCoverage]] = {
sample.name: get_sample_interval_coverage(
db=session,
d4_file_path=sample.coverage_file_path,
genes=[gene],
interval_type=INTERVAL_TYPE_SQL_TYPE[form_data.interval_type],
completeness_thresholds=form_data.completeness_thresholds,
)
for sample in form_data.samples
}
gene_stats["samples_coverage_stats_by_interval"] = (
get_gene_coverage_stats_by_interval(coverage_by_sample=samples_coverage_stats)
gene_stats["gene"] = gene
sql_intervals = set_sql_intervals(
db=session,
interval_type=SQLTranscript,
genes=[gene],
transcript_tags=[],
)
gene_stats["samples_coverage_stats_by_interval"] = get_gene_overview_stats(
sql_intervals=sql_intervals,
samples=form_data.samples,
completeness_thresholds=form_data.completeness_thresholds,
)
return gene_stats


def get_gene_coverage_stats_by_interval(
coverage_by_sample: Dict[str, List[GeneCoverage]]
) -> Dict[str, List[Tuple]]:
"""Arrange coverage stats by interval id instead of by sample."""

intervals_stats: Dict[str, List] = {}

for sample, stats in coverage_by_sample.items():
for gene_interval in stats:
for inner_interval in gene_interval.inner_intervals:
if inner_interval.interval_id in intervals_stats:
intervals_stats[inner_interval.interval_id].append(
(
sample,
inner_interval.mean_coverage,
inner_interval.completeness,
)
)
else:
intervals_stats[inner_interval.interval_id] = [
(
sample,
inner_interval.mean_coverage,
inner_interval.completeness,
)
]

return intervals_stats

0 comments on commit 1046858

Please sign in to comment.