feat(call): progressive output for TSVs/VCFs
davidlougheed committed Feb 4, 2024
1 parent eeda0b5 commit 4339ed0
Showing 5 changed files with 133 additions and 99 deletions.
71 changes: 56 additions & 15 deletions strkit/call/call_sample.py
@@ -7,6 +7,7 @@
 import multiprocessing.managers as mmg
 import numpy as np
 import os
+import pysam
 import re
 import threading
 import time
@@ -15,13 +16,13 @@
 from multiprocessing.synchronize import Event as EventClass  # For type hinting

 from typing import Literal, Optional
-from strkit.logger import logger

+from strkit.logger import logger
 from .allele import get_n_alleles
 from .call_locus import call_locus
 from .non_daemonic_pool import NonDaemonicPool
 from .params import CallParams
-from .output import output_json_report, output_tsv as output_tsv_fn, output_vcf
+from .output import output_json_report, output_tsv as output_tsv_fn, build_vcf_header, output_vcf_lines
 from .utils import get_new_seed

 __all__ = [
@@ -171,6 +172,7 @@ def call_sample(
     vcf_path: Optional[str] = None,
     indent_json: bool = False,
     output_tsv: bool = True,
+    output_chunk_size: int = 10000,
 ) -> None:
     # Start the call timer
     start_time = datetime.now()
@@ -182,6 +184,13 @@
     # Seed the random number generator if a seed is provided, for replicability
     rng: np.random.Generator = np.random.default_rng(seed=params.seed)

+    # If we're outputting a VCF, open the file and write the header
+    sample_id_str = params.sample_id or "sample"
+    vf: Optional[pysam.VariantFile] = None
+    if vcf_path is not None:
+        vh = build_vcf_header(sample_id_str, params.reference_file)
+        vf = pysam.VariantFile(vcf_path if vcf_path != "stdout" else "-", "w", header=vh)
+
     manager: mmg.SyncManager = mp.Manager()
     locus_queue = manager.Queue()
@@ -200,6 +209,15 @@
         # We use locus-specific random seeds for replicability, no matter which order
         # the loci are yanked out of the queue / how many processes we have.
         # Tuple of (1-indexed locus index, locus data, locus-specific random seed)
+
+        # Add occasional None breaks to make the jobs terminate. Then, as long as we have
+        # entries in the locus queue, we can get chunks to order and write to disk
+        # instead of keeping everything in RAM.
+
+        if num_loci and (num_loci % output_chunk_size == 0):
+            for _ in range(params.processes):
+                locus_queue.put(None)
+
         locus_queue.put((t_idx, t, n_alleles, get_new_seed(rng)))
         contig_set.add(contig)
         num_loci += 1
@@ -210,7 +228,11 @@
         locus_queue.put(None)

     is_single_processed = params.processes == 1
+    should_keep_all_results_in_mem = json_path is not None
+
     result_lists = []
+    # Only populated if we're outputting JSON; otherwise, we don't want to keep everything in memory at once.
+    all_results: list[dict] = []

     pool_class = mpd.Pool if is_single_processed else NonDaemonicPool
     finish_event = mp.Event()
@@ -244,32 +266,51 @@
             snv_genotype_cache,
             is_single_processed,
         )
-        jobs = [p.apply_async(locus_worker, job_args) for _ in range(params.processes)]
-
-        # Gather the process-specific results for combining.
-        for j in jobs:
-            result_lists.append(j.get())
+        qsize: int = locus_queue.qsize()
+        while qsize > 0:
+            jobs = [p.apply_async(locus_worker, job_args) for _ in range(params.processes)]
+
+            # Gather the process-specific results for combining.
+            for j in jobs:
+                result_lists.append(j.get())
+
+            # Write results
+            # - merge sorted result lists into single sorted list
+            results: tuple[dict, ...] = tuple(heapq.merge(*result_lists, key=lambda x: x["locus_index"]))
+
+            # - write partial results to stdout if we're writing a stdout TSV
+            if output_tsv:
+                output_tsv_fn(results, has_snv_vcf=params.snv_vcf is not None)
+
+            if should_keep_all_results_in_mem:
+                all_results.extend(results)
+
+            # - write partial results to VCF if we're writing a VCF
+            if vf is not None:
+                output_vcf_lines(sample_id_str, vf, results)
+
+            last_qsize = qsize
+            qsize = locus_queue.qsize()
+            if last_qsize == qsize:
+                # If this happens, we've stalled out on completing work while having a
+                # positive qsize, so we need to terminate (but something went wrong.)
+                logger.warning(f"Terminating with non-zero queue size: {qsize}")
+                break

         finish_event.set()
         progress_job.join()

-    # Merge sorted result lists into single sorted list.
-    results: tuple[dict, ...] = tuple(heapq.merge(*result_lists, key=lambda x: x["locus_index"]))
-
     time_taken = datetime.now() - start_time

-    if output_tsv:
-        output_tsv_fn(results, has_snv_vcf=params.snv_vcf is not None)
     logger.info(f"Finished STR genotyping in {time_taken.total_seconds()}s")

     if json_path:
         output_json_report(
             params,
             time_taken,
             contig_set,
-            results,
+            all_results,
             json_path,
             indent_json,
         )
-
-    if vcf_path:
-        output_vcf(params, results, vcf_path)
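
The heart of the change is the chunked dispatch pattern above: None sentinels are interleaved into the locus queue at every output_chunk_size boundary, each pass of the while loop drains one chunk through the worker pool, and the sorted per-worker lists are k-way merged with heapq.merge before being written out. Below is a minimal, self-contained sketch of the same idea, with a hypothetical worker and toy records standing in for STRKit's actual locus worker:

    import heapq
    import multiprocessing as mp

    CHUNK_SIZE = 4
    N_PROCS = 2

    def worker(q) -> list[dict]:
        # Pull work items until a None sentinel arrives, then return this chunk's
        # results sorted by locus_index so the parent can k-way merge them.
        res = []
        while True:
            item = q.get()
            if item is None:
                return sorted(res, key=lambda r: r["locus_index"])
            idx, datum = item
            res.append({"locus_index": idx, "value": datum * 2})  # stand-in for a real locus call

    if __name__ == "__main__":
        manager = mp.Manager()
        q = manager.Queue()

        num_queued = 0
        for datum in range(10):
            # Interleave one sentinel per worker at each chunk boundary, mirroring
            # the `num_loci % output_chunk_size == 0` check in the diff above.
            if num_queued and num_queued % CHUNK_SIZE == 0:
                for _ in range(N_PROCS):
                    q.put(None)
            q.put((num_queued + 1, datum))
            num_queued += 1
        for _ in range(N_PROCS):  # final sentinels so the last chunk terminates
            q.put(None)

        with mp.Pool(N_PROCS) as p:
            while q.qsize() > 0:
                jobs = [p.apply_async(worker, (q,)) for _ in range(N_PROCS)]
                chunk_lists = [j.get() for j in jobs]
                # Merge this chunk's sorted per-worker lists and "write" it immediately,
                # instead of accumulating every result in RAM until the end.
                for rec in heapq.merge(*chunk_lists, key=lambda r: r["locus_index"]):
                    print(rec)
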
5 changes: 3 additions & 2 deletions strkit/call/output/__init__.py
@@ -1,9 +1,10 @@
 from .json_report import output_json_report
 from .tsv import output_tsv
-from .vcf import output_vcf
+from .vcf import build_vcf_header, output_vcf_lines

 __all__ = [
     "output_json_report",
     "output_tsv",
-    "output_vcf",
+    "build_vcf_header",
+    "output_vcf_lines",
 ]
2 changes: 1 addition & 1 deletion strkit/call/output/json_report.py
@@ -13,7 +13,7 @@ def output_json_report(
     params: CallParams,
     time_taken: timedelta,
     contig_set: set[str],
-    results: tuple[dict, ...],
+    results: list[dict],
     json_path: str,
     indent_json: bool,
 ):
149 changes: 69 additions & 80 deletions strkit/call/output/vcf.py
@@ -2,10 +2,12 @@
 import pysam
 from os.path import commonprefix

-from ..params import CallParams
 from ..utils import cat_strs

-__all__ = ["output_vcf"]
+__all__ = [
+    "build_vcf_header",
+    "output_vcf_lines",
+]


 # VCF_ALLELE_CNV_TR = "<CNV:TR>"
@@ -22,7 +24,7 @@
 # )


-def _build_variant_header(sample_id: str, reference_file: str) -> pysam.VariantHeader:
+def build_vcf_header(sample_id: str, reference_file: str) -> pysam.VariantHeader:
     vh = pysam.VariantHeader()  # automatically sets VCF version to 4.2

     # Add an absolute path to the reference genome
@@ -64,16 +66,7 @@ def _reversed_str(s: str) -> str:
     return cat_strs(reversed(s))


-def output_vcf(
-    params: CallParams,
-    results: tuple[dict, ...],
-    vcf_path: str,
-):
-    sample_id_str: str = params.sample_id or "sample"
-
-    vh = _build_variant_header(sample_id_str, params.reference_file)
-    vf = pysam.VariantFile(vcf_path if vcf_path != "stdout" else "-", "w", header=vh)
-
+def output_vcf_lines(sample_id: str, variant_file: pysam.VariantFile, results: tuple[dict, ...]):
     contig_vrs: list[pysam.VariantRecord] = []

     def _write_contig_vrs():
@@ -82,96 +75,92 @@ def _write_contig_vrs():

         # write them to the VCF
         for contig_vr in contig_vrs:
-            vf.write(contig_vr)
+            variant_file.write(contig_vr)

         # clear the contig variant record list for the new contig
         contig_vrs.clear()

-    try:
-        last_contig = results[0]["contig"] if results else ""
+    last_contig = results[0]["contig"] if results else ""

-        # has_at_least_one_snv_set = next((r.get("snvs") is not None for r in results), None) is not None
-        snvs_written: set[str] = set()
+    # has_at_least_one_snv_set = next((r.get("snvs") is not None for r in results), None) is not None
+    snvs_written: set[str] = set()

-        for result_idx, result in enumerate(results, 1):
-            contig = result["contig"]
+    for result_idx, result in enumerate(results, 1):
+        contig = result["contig"]

-            if contig != last_contig:
-                # we moved on from the last contig, so write the last batch of variant records to the VCF
-                _write_contig_vrs()
+        if contig != last_contig:
+            # we moved on from the last contig, so write the last batch of variant records to the VCF
+            _write_contig_vrs()

-            ref_start_anchor = result["ref_start_anchor"].upper()
+        ref_start_anchor = result["ref_start_anchor"].upper()

-            ref_seq = result["ref_seq"].upper()
+        ref_seq = result["ref_seq"].upper()

-            seqs = tuple(map(str.upper, (result["peaks"] or {}).get("seqs", ())))
+        seqs = tuple(map(str.upper, (result["peaks"] or {}).get("seqs", ())))

-            seq_alts = sorted(set(filter(lambda c: c != ref_seq, seqs)))
-            common_suffix_idx = -1 * len(commonprefix(tuple(map(_reversed_str, (ref_seq, *seqs)))))
+        seq_alts = sorted(set(filter(lambda c: c != ref_seq, seqs)))
+        common_suffix_idx = -1 * len(commonprefix(tuple(map(_reversed_str, (ref_seq, *seqs)))))

-            call = result["call"]
-            seq_alleles_raw: tuple[str, ...] = (ref_seq, *(seq_alts or (".",))) if call is not None else (".",)
-            seq_alleles: list[str] = []
+        call = result["call"]
+        seq_alleles_raw: tuple[str, ...] = (ref_seq, *(seq_alts or (".",))) if call is not None else (".",)
+        seq_alleles: list[str] = []

-            if call is not None:
-                seq_alleles.append(ref_start_anchor + ref_seq[:common_suffix_idx])
-                if seq_alts:
-                    seq_alleles.extend(ref_start_anchor + a[:common_suffix_idx] for a in seq_alts)
-            else:
-                seq_alleles.append(".")
+        if call is not None:
+            seq_alleles.append(ref_start_anchor + ref_seq[:common_suffix_idx])
+            if seq_alts:
+                seq_alleles.extend(ref_start_anchor + a[:common_suffix_idx] for a in seq_alts)
+        else:
+            seq_alleles.append(".")

-            start = result.get("start_adj", result["start"]) - len(ref_start_anchor)
-            vr: pysam.VariantRecord = vf.new_record(
-                contig=contig,
-                start=start,
-                alleles=seq_alleles,
-            )
+        start = result.get("start_adj", result["start"]) - len(ref_start_anchor)
+        vr: pysam.VariantRecord = variant_file.new_record(
+            contig=contig,
+            start=start,
+            alleles=seq_alleles,
+        )

-            vr.samples[sample_id_str]["GT"] = tuple(map(seq_alleles_raw.index, seqs)) if seqs else (".",)
-            vr.samples[sample_id_str]["DP"] = sum(result["peaks"]["n_reads"])
-            vr.samples[sample_id_str]["AD"] = tuple(result["peaks"]["n_reads"])
-            vr.samples[sample_id_str]["MC"] = tuple(map(int, result["call"]))  # TODO: support fractional
+        vr.samples[sample_id]["GT"] = tuple(map(seq_alleles_raw.index, seqs)) if seqs else (".",)
+        vr.samples[sample_id]["DP"] = sum(result["peaks"]["n_reads"])
+        vr.samples[sample_id]["AD"] = tuple(result["peaks"]["n_reads"])
+        vr.samples[sample_id]["MC"] = tuple(map(int, result["call"]))  # TODO: support fractional

-            ps = result["ps"]
+        ps = result["ps"]

-            if ps is not None:  # have phase set on call, so mark as phased
-                vr.samples[sample_id_str].phased = True
-                vr.samples[sample_id_str]["PS"] = ps
+        if ps is not None:  # have phase set on call, so mark as phased
+            vr.samples[sample_id].phased = True
+            vr.samples[sample_id]["PS"] = ps

-            for snv in result.get("snvs", ()):
-                snv_id = snv["id"]
-                if snv_id in snvs_written:
-                    continue
-                snvs_written.add(snv_id)
+        for snv in result.get("snvs", ()):
+            snv_id = snv["id"]
+            if snv_id in snvs_written:
+                continue
+            snvs_written.add(snv_id)

-                ref = snv["ref"]
-                snv_alts = tuple(filter(lambda v: v != ref, snv["call"]))
-                snv_alleles = (ref, *snv_alts)
-                snv_pos = snv["pos"]
+            ref = snv["ref"]
+            snv_alts = tuple(filter(lambda v: v != ref, snv["call"]))
+            snv_alleles = (ref, *snv_alts)
+            snv_pos = snv["pos"]

-                snv_vr = vf.new_record(
-                    contig=contig,
-                    id=snv_id,
-                    start=snv_pos,
-                    stop=snv_pos + 1,
-                    alleles=snv_alleles,
-                )
-
-                # TODO: write "rcs" for sample SNV genotypes - list of #reads per allele
+            snv_vr = variant_file.new_record(
+                contig=contig,
+                id=snv_id,
+                start=snv_pos,
+                stop=snv_pos + 1,
+                alleles=snv_alleles,
+            )

-                snv_vr.samples[sample_id_str]["GT"] = tuple(map(snv_alleles.index, snv["call"]))
-                snv_vr.samples[sample_id_str]["DP"] = sum(snv["rcs"])
-                snv_vr.samples[sample_id_str]["AD"] = snv["rcs"]
+            # TODO: write "rcs" for sample SNV genotypes - list of #reads per allele

-                if ps is not None:
-                    snv_vr.samples[sample_id_str].phased = True
-                    snv_vr.samples[sample_id_str]["PS"] = ps
+            snv_vr.samples[sample_id]["GT"] = tuple(map(snv_alleles.index, snv["call"]))
+            snv_vr.samples[sample_id]["DP"] = sum(snv["rcs"])
+            snv_vr.samples[sample_id]["AD"] = snv["rcs"]

-                contig_vrs.append(snv_vr)
+            if ps is not None:
+                snv_vr.samples[sample_id].phased = True
+                snv_vr.samples[sample_id]["PS"] = ps

-            contig_vrs.append(vr)
+            contig_vrs.append(snv_vr)

-        _write_contig_vrs()  # write the final contig's worth of variant records to the VCF at the end
+        contig_vrs.append(vr)

-    finally:
-        vf.close()
+    _write_contig_vrs()  # write the final contig's worth of variant records to the VCF at the end
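
After this refactor, the caller owns the pysam.VariantFile: build_vcf_header is called once up front, and output_vcf_lines appends records chunk by chunk to the already-open handle. Below is a minimal sketch of that open-once/append-per-chunk pattern; the header here is a toy (one assumed contig, GT only), while the real one also records the reference path and declares the DP/AD/MC/PS FORMAT fields used above:

    import pysam

    def build_header(sample_id: str) -> pysam.VariantHeader:
        vh = pysam.VariantHeader()  # defaults to VCFv4.2, as noted in the diff
        vh.contigs.add("chr1", length=248956422)  # assumed contig; the real header derives contigs from the reference
        vh.formats.add("GT", 1, "String", "Genotype")
        vh.add_sample(sample_id)
        return vh

    def write_chunk(sample_id: str, vf: pysam.VariantFile, chunk: list[dict]) -> None:
        # Append one chunk's records; the file handle stays open across chunks.
        for rec in chunk:
            vr = vf.new_record(contig=rec["contig"], start=rec["start"], alleles=rec["alleles"])
            vr.samples[sample_id]["GT"] = rec["gt"]
            vf.write(vr)

    if __name__ == "__main__":
        sample = "sample"
        vf = pysam.VariantFile("out.vcf", "w", header=build_header(sample))
        try:
            # In call_sample(), chunks arrive from the worker-pool loop; here, two toy chunks.
            write_chunk(sample, vf, [{"contig": "chr1", "start": 99, "alleles": ("A", "AT"), "gt": (0, 1)}])
            write_chunk(sample, vf, [{"contig": "chr1", "start": 499, "alleles": ("G", "GGC"), "gt": (1, 1)}])
        finally:
            vf.close()  # the caller closes once at the end, rather than per chunk
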
5 changes: 4 additions & 1 deletion strkit/entry.py
@@ -169,6 +169,9 @@ def add_call_parser_args(call_parser):
         action="store_true",
         help="If passed, no TSV call output will be written to stdout.")

+    call_parser.add_argument(
+        "--output-chunk-size", type=int, default=10000, help="How many call records to write to disk at a time.")
+
     # END FILE OUTPUT ARGUMENTS ========================================================================================

     call_parser.add_argument(
@@ -355,7 +358,7 @@ def _exec_call(p_args) -> None:
         indent_json=p_args.indent_json,
         vcf_path=p_args.vcf,
         output_tsv=not p_args.no_tsv,
-        # seed=p_args.seed,
+        output_chunk_size=p_args.output_chunk_size,
     )
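
With output_chunk_size plumbed through _exec_call, the flush interval becomes user-tunable. A hypothetical invocation (all other arguments elided):

    # Flush results every 5000 loci instead of the default 10000
    strkit call --vcf out.vcf --output-chunk-size 5000 ...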


