refact(call): use time.perf_counter for perf tracking
davidlougheed committed Jun 18, 2024
1 parent 1c7f42c commit 6234269
Showing 3 changed files with 21 additions and 23 deletions.
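The swap is mechanical throughout: every wall-clock timestamp pair used to measure a duration becomes a monotonic counter read. A minimal sketch of the before/after pattern (with `do_work` as a hypothetical stand-in for any timed operation):

```python
import time
from datetime import datetime


def do_work() -> None:
    time.sleep(0.1)  # stand-in for any timed operation


# Before: wall-clock time. datetime.now() can jump if the system clock is
# adjusted (NTP sync, manual changes), so measured durations may be skewed
# or even negative, and a timedelta needs a .total_seconds() round-trip.
t0 = datetime.now()
do_work()
elapsed_old = (datetime.now() - t0).total_seconds()

# After: time.perf_counter() is a monotonic, high-resolution clock intended
# specifically for measuring elapsed time, and returns float seconds directly.
t1 = time.perf_counter()
do_work()
elapsed_new = time.perf_counter() - t1
print(f"{elapsed_old:.3f}s vs {elapsed_new:.3f}s")
```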
28 changes: 14 additions & 14 deletions strkit/call/call_locus.py
@@ -6,10 +6,10 @@
 import numpy as np
 import operator
 import threading
+import time

 from collections import Counter
 from collections.abc import Sequence
-from datetime import datetime
 from pysam import FastaFile
 from sklearn.cluster import AgglomerativeClustering
 from sklearn.mixture import GaussianMixture
@@ -787,7 +787,7 @@ def call_locus(
     read_file_has_chr: bool = True,
     ref_file_has_chr: bool = True,
 ) -> Optional[LocusResult]:
-    call_timer = datetime.now()
+    call_timer = time.perf_counter()

     # params de-structuring ------------
     consensus = params.consensus
@@ -836,7 +836,7 @@

     # Get reference sequence and copy number ---------------------------------------------------------------------------

-    ref_timer = datetime.now()
+    ref_timer = time.perf_counter()

     ref_seq_offset_l = left_coord - left_flank_coord
     ref_seq_offset_r = right_coord - left_flank_coord
@@ -922,7 +922,7 @@
     locus_result["start_adj"] = left_coord_adj
     locus_result["end_adj"] = right_coord_adj

-    ref_time = (datetime.now() - ref_timer).total_seconds()
+    ref_time = time.perf_counter() - ref_timer

     # Find the initial set of overlapping aligned segments with associated read lengths + whether we have in-locus
     # chimera reads (i.e., reads which aligned twice with different soft-clipping, likely due to a large indel.) -------
@@ -1110,7 +1110,7 @@
         read_kmers.clear()
         read_kmers.update(tr_read_seq_wc[i:i+motif_size] for i in range(0, tr_len - motif_size + 1))

-        rc_timer = datetime.now()
+        rc_timer = time.perf_counter()
         # Set initial integer copy number guess based on aligned TR size, plus the previous read offset (how much the
         # last guess was wrong by, as a delta.)
         read_sc = round(tr_len / motif_size)
@@ -1126,7 +1126,7 @@
         # Update using +=, since if we use an offset that was correct, the new returned offset will be 0, so we really
         # want to keep the old offset, not set it to 0.
         read_offset_frac_from_starting_guess += new_offset_from_starting_count / max(read_cn, 1)
-        rc_time = (datetime.now() - rc_timer).total_seconds()
+        rc_time = time.perf_counter() - rc_timer

         if n_read_cn_iters >= max_rc_iters:
             logger_.debug(f"{locus_log_str} - locus repeat counting exceeded maximum # iterations ({n_read_cn_iters})")
@@ -1153,7 +1153,7 @@
             **locus_result,
             "peaks": None,
             "read_peaks_called": False,
-            "time": (datetime.now() - call_timer).total_seconds(),
+            "time": time.perf_counter() - call_timer,
         }

         if read_adj_score < min_read_align_score:
@@ -1172,7 +1172,7 @@
                 **locus_result,
                 "peaks": None,
                 "read_peaks_called": False,
-                "time": (datetime.now() - call_timer).total_seconds(),
+                "time": time.perf_counter() - call_timer,
             }

             continue
@@ -1270,7 +1270,7 @@
             **locus_result,
             "peaks": None,
             "read_peaks_called": False,
-            "time": (datetime.now() - call_timer).total_seconds(),
+            "time": time.perf_counter() - call_timer,
         }

     # Now, we know we have enough reads to maybe make a call -----------------------------------------------------------
@@ -1291,7 +1291,7 @@
             have_rare_realigns = True
             break

-    allele_start_time = datetime.now()
+    allele_start_time = time.perf_counter()

     if params.use_hp:
         top_ps = phase_sets.most_common(1)
@@ -1379,7 +1379,7 @@
     if single_or_dist_assign:  # Didn't use SNVs, so call the 'old-fashioned' way - using only copy number
         call_data = call_alleles_with_gmm(params, n_alleles, read_dict, assign_method, rng, logger_, locus_log_str)

-    allele_time = (datetime.now() - allele_start_time).total_seconds()
+    allele_time = time.perf_counter() - allele_start_time

     logger_.debug(f"{locus_log_str} - finished assigning alleles using {assign_method} method: took {allele_time:.4f}s")

@@ -1399,7 +1399,7 @@

     # Assign reads to peaks and compute peak k-mers (and optionally consensus sequences) -------------------------------

-    assign_start_time = datetime.now()
+    assign_start_time = time.perf_counter()

     # We cannot call read-level cluster labels with >2 peaks using distance alone;
     # don't know how re-sampling has occurred.
@@ -1496,11 +1496,11 @@
         **({"seqs": call_seqs} if consensus else {}),
     } if call_data else None

-    assign_time = (datetime.now() - assign_start_time).total_seconds()
+    assign_time = time.perf_counter() - assign_start_time

     # Calculate call time ----------------------------------------------------------------------------------------------

-    call_time = (datetime.now() - call_timer).total_seconds()
+    call_time = time.perf_counter() - call_timer

     if call_time > CALL_WARN_TIME:
         logger_.warning(
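The same start/measure pattern now appears five times in `call_locus` (`call_timer`, `ref_timer`, `rc_timer`, `allele_start_time`, `assign_start_time`). A hypothetical context-manager helper, not part of this commit or of strkit, could centralize it:

```python
import time


class Stopwatch:
    """Minimal perf_counter-based timer (illustrative sketch only)."""

    def __enter__(self) -> "Stopwatch":
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.elapsed = time.perf_counter() - self._start


# Usage mirroring the ref_timer block above:
with Stopwatch() as sw:
    time.sleep(0.05)  # ... fetch reference sequence and copy number ...
ref_time = sw.elapsed
```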
11 changes: 5 additions & 6 deletions strkit/call/call_sample.py
@@ -14,7 +14,6 @@
 import time
 import traceback

-from datetime import datetime
 from operator import itemgetter
 from multiprocessing.synchronize import Event as EventClass  # For type hinting
 from typing import Iterable, Literal, Optional
@@ -196,7 +195,7 @@ def locus_worker(

 def progress_worker(
     sample_id: Optional[str],
-    start_time: datetime,
+    start_time: float,
     log_level: int,
     locus_queue: mp.Queue,
     locus_counter: mmg.ValueProxy,
@@ -215,7 +214,7 @@ def progress_worker(
     def _log():
         try:
             processed_loci = int(locus_counter.get())
-            n_seconds = (datetime.now() - start_time).total_seconds()
+            n_seconds = time.perf_counter() - start_time
             loci_per_second = processed_loci / n_seconds
             est_time_remaining = ((num_loci - processed_loci) / loci_per_second) if loci_per_second else float("inf")
             lg.info(
@@ -272,7 +271,7 @@ def call_sample(
     logger = get_main_logger()

     # Start the call timer
-    start_time = datetime.now()
+    start_time = time.perf_counter()

     logger.info(
         f"Starting STR genotyping; sample={params.sample_id}, hq={params.hq}, targeted={params.targeted}, "
@@ -474,9 +473,9 @@
     finish_event.set()
     progress_job.join()

-    time_taken = datetime.now() - start_time
+    time_taken = time.perf_counter() - start_time

-    logger.info(f"Finished STR genotyping in {time_taken.total_seconds():.1f}s")
+    logger.info(f"Finished STR genotyping in {time_taken:.1f}s")

     if json_path:
         output_json_report_footer(time_taken, json_path, indent_json)
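In `progress_worker`, the float `start_time` feeds a simple throughput/ETA estimate. Isolated from the multiprocessing plumbing, the arithmetic is just the following (the locus counts are made-up values):

```python
import time

start_time = time.perf_counter()  # captured once, in call_sample
num_loci = 10_000                 # total loci to genotype (made-up)
time.sleep(0.2)                   # stand-in for genotyping work
processed_loci = 2_500            # loci completed so far (made-up)

n_seconds = time.perf_counter() - start_time
loci_per_second = processed_loci / n_seconds
# Guard against a zero rate early in the run, as the worker does:
est_time_remaining = (
    (num_loci - processed_loci) / loci_per_second if loci_per_second else float("inf")
)
print(f"{loci_per_second:.1f} loci/s, ~{est_time_remaining:.0f}s remaining")
```

One subtlety: the Python docs only define `perf_counter` values by the difference between two calls, so differencing a parent-process `start_time` inside a worker process relies on the platform clock being system-wide (true for `CLOCK_MONOTONIC`-backed implementations, though not an explicit documented guarantee).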
5 changes: 2 additions & 3 deletions strkit/call/output/json_report.py
@@ -1,5 +1,4 @@
 import sys
-from datetime import timedelta
 from typing import Callable, Literal

 from strkit import __version__
@@ -72,8 +71,8 @@ def output_json_report_results(results: tuple[LocusResult, ...], is_last: bool,
     _write_bytes(results_bytes, json_path, "ab")


-def output_json_report_footer(time_taken: timedelta, json_path: str, indent_json: bool):
-    runtime_bytes = dumps(time_taken.total_seconds())
+def output_json_report_footer(time_taken: float, json_path: str, indent_json: bool):
+    runtime_bytes = dumps(time_taken)
     if indent_json:
         footer_bytes = b'\n ],\n "runtime": ' + runtime_bytes + b'\n}\n'
     else:
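The footer signature change is what lets the new float flow straight through to JSON: a `timedelta` is not JSON-serializable, which is why the old code converted with `.total_seconds()` at write time. A quick illustration with the stdlib `json` module (the actual `dumps` imported by this file is not shown in the hunk):

```python
import json
from datetime import timedelta

time_taken = 12.345  # float seconds, as produced by perf_counter arithmetic

print(json.dumps(time_taken))  # -> 12.345

try:
    json.dumps(timedelta(seconds=12.345))
except TypeError as e:
    # This is why the old code needed the explicit .total_seconds() call.
    print(f"TypeError: {e}")
```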
