Skip to content

Commit

Permalink
perf: improve runtime of diplotype comparison
Browse files Browse the repository at this point in the history
Runtime is improved by 1) reducing the number of diplotypes to compare and 2) replacing the matrix-comparison strategy with a path-searching scheme.
  • Loading branch information
BinglanLi committed Nov 23, 2023
1 parent 600ae6d commit 406930c
Show file tree
Hide file tree
Showing 3 changed files with 339 additions and 339 deletions.
76 changes: 76 additions & 0 deletions src/scripts/diplotype_comparison/compare_diplotype_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from timeit import default_timer as timer

import numpy as np
import pandas as pd

import utilities as util
Expand Down Expand Up @@ -115,6 +116,81 @@
allele_defining_variants: dict = util.get_allele_defining_variants(json_data)

# read in the list of alleles, allele-defining positions, and defining genotypes
allele_definitions: dict = util.get_allele_definitions(json_data, allele_defining_variants)

# get an allele-position array to identify alleles that do not share definitions with others
allele_arrays: np.ndarray = util.convert_dictionary_to_numpy_array(allele_definitions)
allele_position_array: np.ndarray = np.sum(allele_arrays, axis=1).astype(bool).astype('int8')

# identify reference alleles
allele_names: np.ndarray = np.array([*allele_definitions])
bool_reference_alleles: np.ndarray = np.all(allele_position_array, axis=1)
reference_alleles: list[str] = list(allele_names[bool_reference_alleles])

# identify the 'busy' positions that are shared by more than two alleles, i.e., that define more than one allele besides the reference allele
idx_busy_positions: np.ndarray = np.where(np.sum(allele_position_array, axis=0) > 2)[0]

# identify alleles that share definitions with others besides the reference allele
idx_gregarious_alleles: np.ndarray = np.unique(np.where(allele_position_array[:, idx_busy_positions])[0])
gregarious_alleles: list[str] = list(allele_names[idx_gregarious_alleles])

# skip the gene if none of its alleles shares any allele-defining positions with others
if len(gregarious_alleles):
# find all possible outcomes of alternative calls for each diplotype
dict_predicted_calls = dict()
if args.missing:
if gene == 'CYP2D6':
print(f'\tSkipping - too complex')
continue

# identify the combinations of missing positions to evaluate
hgvs_names = [*allele_defining_variants]
missing_combinations = util.find_missingness_combinations(allele_position_array, list(allele_names),
hgvs_names, reference_alleles,
gregarious_alleles)

# skip the gene if it has no positions whose absence would cause ambiguous PharmCAT calls
if not missing_combinations:
print(f'\tSkipping - no ambiguous calls')
continue
else:
print(f'\tNumber of combinations = {len(missing_combinations)}')
for m in missing_combinations:
# find possible calls for each missing position
dict_predicted_calls_m = util.predict_pharmcat_calls(allele_defining_variants,
allele_definitions, m)

# append to dict_predicted_calls
for key in dict_predicted_calls_m:
dict_predicted_calls[key] = dict_predicted_calls.get(key, []) + dict_predicted_calls_m[key]
else:
dict_predicted_calls = util.predict_pharmcat_calls(allele_defining_variants,
allele_definitions,
gregarious_alleles)

# convert the python dictionary to a pandas data frame for output
df_predicted_calls = pd.DataFrame.from_dict(dict_predicted_calls)
# add gene name to the data frame for readability in the output
df_predicted_calls['gene'] = gene
# write the selected rows to the tab-separated output file (append without a header if the file already exists)
if len(df_predicted_calls):
if args.clinical_outcome:
rows = df_predicted_calls[df_predicted_calls['actual'].str.contains(';')]
cols = ['gene', 'actual']
else:
rows = df_predicted_calls
cols = rows.columns
if m_output_file.is_file():
mode = 'a'
header = False
else:
mode = 'w'
header = True
rows.to_csv(m_output_file.absolute(), mode=mode, sep="\t", columns=cols, header=header, index=False)
else:
print(f'\tNo alternative calls')
else:
print(f'\tNo alleles share definitions with others')
allele_definitions = util.get_allele_definitions(json_data, allele_defining_variants)

# find all possible outcomes of alternative calls for each diplotype
Expand Down
15 changes: 2 additions & 13 deletions src/scripts/diplotype_comparison/filter_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,8 @@

# read reference predictions
reference_file_dir: Path = Path(globals().get("__file__", "./_")).absolute().parent
reference_file_pattern: Path = Path('predicted_pharmcat_calls_*tsv')
reference_files: list[str] = glob(str(reference_file_dir.joinpath(reference_file_pattern)))
reference_predictions = pd.DataFrame()
# iteratively read predictions
for file in reference_files:
# read the file
reference_prediction_i: pd.DataFrame = pd.read_csv(file, delimiter='\t')
# if no content, continue to the next file
if len(reference_prediction_i) == 0:
continue

# concatenate the reference predictions to the summary data frame
reference_predictions = pd.concat([reference_predictions, reference_prediction_i], ignore_index=True)
reference_file: Path = reference_file_dir.joinpath('predicted_pharmcat_calls.tsv')
reference_predictions: pd.DataFrame = pd.read_csv(str(reference_file), delimiter='\t')
# replace "NaN" values with ''
reference_predictions = reference_predictions.fillna('')

Expand Down
Loading

0 comments on commit 406930c

Please sign in to comment.