From 406930cedffa551ed3e21331766c8f630a9b7881 Mon Sep 17 00:00:00 2001 From: Binglan Li Date: Wed, 22 Nov 2023 16:06:11 -0800 Subject: [PATCH] perf: improve runtime of diplotype comparison Improved by 1) reducing the number of diplotypes to compare and 2) replacing the matrix comparison strategy with the path searching scheme. --- .../compare_diplotype_definition.py | 76 +++ .../diplotype_comparison/filter_tests.py | 15 +- src/scripts/diplotype_comparison/utilities.py | 587 ++++++++---------- 3 files changed, 339 insertions(+), 339 deletions(-) diff --git a/src/scripts/diplotype_comparison/compare_diplotype_definition.py b/src/scripts/diplotype_comparison/compare_diplotype_definition.py index 076cb4e73..9e57b0bd2 100644 --- a/src/scripts/diplotype_comparison/compare_diplotype_definition.py +++ b/src/scripts/diplotype_comparison/compare_diplotype_definition.py @@ -7,6 +7,7 @@ from pathlib import Path from timeit import default_timer as timer +import numpy as np import pandas as pd import utilities as util @@ -115,6 +116,81 @@ allele_defining_variants: dict = util.get_allele_defining_variants(json_data) # read in the list of alleles, allele-defining positions, and defining genotypes + allele_definitions: dict = util.get_allele_definitions(json_data, allele_defining_variants) + + # get an allele-position array to identify alleles that do not share definitions with others + allele_arrays: np.ndarray = util.convert_dictionary_to_numpy_array(allele_definitions) + allele_position_array: np.ndarray = np.sum(allele_arrays, axis=1).astype(bool).astype('int8') + + # identify reference alleles + allele_names: np.ndarray = np.array([*allele_definitions]) + bool_reference_alleles: np.ndarray = np.all(allele_position_array, axis=1) + reference_alleles: list[str] = list(allele_names[bool_reference_alleles]) + + # identify the 'busy' positions that define only one allele besides the reference allele + idx_busy_positions: np.ndarray = np.where(np.sum(allele_position_array, axis=0) > 2)[0] + + # identify alleles that share definitions with others besides the reference allele + idx_gregarious_alleles: np.ndarray = np.unique(np.where(allele_position_array[:, idx_busy_positions])[0]) + gregarious_alleles: list[str] = list(allele_names[idx_gregarious_alleles]) + + # skip the gene if none of its alleles shares any allele-defining positions with others + if len(gregarious_alleles): + # find all possible outcomes of alternative calls for each diplotype + dict_predicted_calls = dict() + if args.missing: + if gene == 'CYP2D6': + print(f'\tSkipping - too complex') + continue + + # identify the combinations of missing positions to evaluate + hgvs_names = [*allele_defining_variants] + missing_combinations = util.find_missingness_combinations(allele_position_array, list(allele_names), + hgvs_names, reference_alleles, + gregarious_alleles) + + # if the gene does not have positions whose absence will cause ambiguous pharmcat calls, then skip + if not missing_combinations: + print(f'\tSkipping - no ambiguous calls') + continue + else: + print(f'\tNumber of combinations = {len(missing_combinations)}') + for m in missing_combinations: + # find possible calls for each missing position + dict_predicted_calls_m = util.predict_pharmcat_calls(allele_defining_variants, + allele_definitions, m) + + # append to dict_predicted_calls + for key in dict_predicted_calls_m: + dict_predicted_calls[key] = dict_predicted_calls.get(key, []) + dict_predicted_calls_m[key] + else: + dict_predicted_calls = util.predict_pharmcat_calls(allele_defining_variants, + allele_definitions, + gregarious_alleles) + + # convert the python dictionary to a pandas data frame for output + df_predicted_calls = pd.DataFrame.from_dict(dict_predicted_calls) + # add gene name to the data frame for readability in the output + df_predicted_calls['gene'] = gene + # write to an output + if len(df_predicted_calls): + if args.clinical_outcome: + rows = df_predicted_calls[df_predicted_calls['actual'].str.contains(';')] + cols = ['gene', 'actual'] + else: + rows = df_predicted_calls + cols = rows.columns + if m_output_file.is_file(): + mode = 'a' + header = False + else: + mode = 'w' + header = True + rows.to_csv(m_output_file.absolute(), mode=mode, sep="\t", columns=cols, header=header, index=False) + else: + print(f'\tNo alternative calls') + else: + print(f'\tNo alleles share definitions with others') allele_definitions = util.get_allele_definitions(json_data, allele_defining_variants) # find all possible outcomes of alternative calls for each diplotype diff --git a/src/scripts/diplotype_comparison/filter_tests.py b/src/scripts/diplotype_comparison/filter_tests.py index 26ba4def2..115cf12c1 100644 --- a/src/scripts/diplotype_comparison/filter_tests.py +++ b/src/scripts/diplotype_comparison/filter_tests.py @@ -11,19 +11,8 @@ # read reference predictions reference_file_dir: Path = Path(globals().get("__file__", "./_")).absolute().parent -reference_file_pattern: Path = Path('predicted_pharmcat_calls_*tsv') -reference_files: list[str] = glob(str(reference_file_dir.joinpath(reference_file_pattern))) -reference_predictions = pd.DataFrame() -# iteratively read predictions -for file in reference_files: - # read the file - reference_prediction_i: pd.DataFrame = pd.read_csv(file, delimiter='\t') - # if no content, continue to the next file - if len(reference_prediction_i) == 0: - continue - - # concatenate the reference predictions to the summary data frame - reference_predictions = pd.concat([reference_predictions, reference_prediction_i], ignore_index=True) +reference_file: Path = reference_file_dir.joinpath('predicted_pharmcat_calls.tsv') +reference_predictions: pd.DataFrame = pd.read_csv(str(reference_file), delimiter='\t') # replace "NaN" values with '' reference_predictions = reference_predictions.fillna('') diff --git a/src/scripts/diplotype_comparison/utilities.py b/src/scripts/diplotype_comparison/utilities.py index 1d9c75a32..b0d87a5cc 100644 --- a/src/scripts/diplotype_comparison/utilities.py +++ b/src/scripts/diplotype_comparison/utilities.py @@ -6,7 +6,7 @@ import numpy as np from pathlib import Path -from typing import List, Set, Tuple, Optional +from typing import List, Set, Optional # todo: convert sys.exit(1) to reportableExceptions @@ -14,7 +14,7 @@ # define the list of wobble genotypes _wobble_genotype_list: list = ['S', 'Y', 'M', 'K', 'R', 'W', 'V', 'H', 'D', 'B', 'N'] # build a dictionary of wobble genotypes and their corresponding basepairs -_wobble_match_table: dict[str, List[str]] = { +_wobble_match_table: dict[str, Set[str]] = { 'M': {'A', 'C'}, 'R': {'A', 'G'}, 'W': {'A', 'T'}, @@ -87,38 +87,39 @@ def get_allele_defining_variants(json_data: dict) -> dict[str, dict[str, str]]: def get_allele_definitions(json_data: dict, - variants: dict[str, dict[str, str]]) -> dict[str, dict[str, list[str]]]: + variants: dict[str, dict[str, str]]) -> dict[str, dict[str, str]]: """ Extract allele definitions :param json_data: a dictionary of allele-defining variants and named alleles from PharmCAT JSON :param variants: a dictionary of all allele-defining positions :return: a dictionary of {allele_name: { - defining_genotype: list of defining genotypes, whose sequence should correspond to json - 'variants', - reference_genotype: list of reference genotype at each position} } - """ + allele-defining position (by hgvs names): {allele-defining genotype} } + } += """ # extract named allele definitions try: - allele_list: List[str] = [entry['name'] for entry in json_data['namedAlleles']] - geno_list: List[List[str]] = [entry['alleles'] for entry in json_data['namedAlleles']] - - # check whether there are any empty values - if None in allele_list: + alleles: list[str] = [entry['name'] for entry in json_data['namedAlleles']] + genotypes: list[list[str]] = [entry['alleles'] for entry in json_data['namedAlleles']] + hgvs_names: list[str] = [*variants] + n_variants = len(hgvs_names) + n_alleles = len(alleles) + + # check whether there are any missing values + if None in alleles: print('One of the allele names is missing') # check whether any allele definition is empty - if any(x.count(None) == len(x) for x in geno_list): + if any(x.count(None) == len(x) for x in genotypes): print('One of the alleles is not defined by any genotypes at any positions.') # check whether each allele is defined by the same number of positions in the json 'variants' attribute - n_variants = len(variants) - if any(len(g) != n_variants for g in geno_list): + if any(len(g) != n_variants for g in genotypes): print('One of the alleles has a different number of allele-defining positions.') - # generate a dictionary of allele definitions with the allele name, defining genotypes, and reference genotypes - n_alleles = len(allele_list) - allele_definitions: dict[str, dict[str, list[str]]] = { - allele_list[i]: { - 'genotypes': geno_list[i] - } for i in range(n_alleles) - } + # generate a dictionary of allele definitions with the allele name and allele-defining genotypes + allele_definitions: dict[str, dict[str, str]] = {} + for i in range(n_alleles): + allele_definitions[alleles[i]] = {} + for j in range(n_variants): + allele_definitions[alleles[i]][hgvs_names[j]] = genotypes[i][j] return allele_definitions @@ -128,18 +129,23 @@ def get_allele_definitions(json_data: dict, f'Check whether the attributes exist for all allele-defining positions.') -def count_allele_scores(allele_definitions: dict[str, dict[str, list[str]]]) -> dict[str, int]: +def count_allele_scores(dict_allele_definitions: dict[str, dict[str, str]]) -> dict[str, int]: """ calculate the haplotype scores by counting the number of non-empty allele-defining positions - :param allele_definitions: a dictionary of allele definitions {n, {p, p} } + :param dict_allele_definitions: a dictionary of allele definitions + {allele_name: {hgvs_name: genotype}} :return: hap_score: dict[str, int] = { allele_name: allele_score } """ - # for each allele, count the number of allele-defining genotypes - n_defining_genotypes: list[int] = [len(x['genotypes']) - x['genotypes'].count(None) - for x in allele_definitions.values()] + # initialize the return variable + allele_scores: dict[str, int] = dict() - # create a dictionary of haplotype scores where dict['allele name'] = {haplotype score} - allele_scores: dict[str, int] = dict(zip(allele_definitions, n_defining_genotypes)) + # iterate over alleles + for allele, dict_defining_genotypes in dict_allele_definitions.items(): + # count the number of allele-defining genotypes + genotypes = list(dict_defining_genotypes.values()) + n_defining_genotypes: int = len(genotypes) - genotypes.count(None) + # add the allele score to allele_scores + allele_scores[allele] = n_defining_genotypes return allele_scores @@ -147,8 +153,8 @@ def count_allele_scores(allele_definitions: dict[str, dict[str, list[str]]]) -> def count_diplotype_scores(allele_scores: dict[str, int]) -> dict[str, int]: """ Based on a dictionary of allele/haplotype scores, get all the diplotype combinations and their scores - :param allele_scores: a dictionary where keys = allele name, values = allele score - :return: dip_score: a dictionary where keys = the diplotype names and the values = the diplotype scores + :param allele_scores: a dictionary {allele_name: allele_score} + :return: dip_score: a dictionary {diplotype_name: diplotype_score} """ # calculate the diplotype scores diplotype_scores: dict[str, int] = {k1 + '/' + k2: v1 + v2 @@ -158,48 +164,31 @@ def count_diplotype_scores(allele_scores: dict[str, int]) -> dict[str, int]: return diplotype_scores -def fill_definitions_with_references(allele_definitions: dict[str, dict[str, list[str]]], - variants: dict[str, dict[str, str]]) \ - -> dict[str, dict[str, list[str]]]: +def fill_definitions_with_references(dict_allele_definitions: dict[str, dict[str, str]], + dict_allele_defining_variants: dict[str, dict[str, str]]) \ + -> dict[str, dict[str, str]]: """ fill up empty allele-defining positions with the reference genotype :return: a list of genotypes where empty/None genotypes have been filled with reference genotypes at the position """ - allele_definitions_filled = dict() - for a, sub_dict in allele_definitions.items(): - # get the list of allele defining genotypes and the reference genotypes at the corresponding positions - definitions: list[str] = sub_dict['genotypes'] - references: list[str] = [x['ref'] for x in variants.values()] - - # get the list of definitions where NAs are filled with reference genotypes - n = len(definitions) - complete_genotypes: list[str] = [definitions[i] if definitions[i] else references[i] for i in range(n)] - allele_definitions_filled[a] = { - 'genotypes': complete_genotypes - } - - return allele_definitions_filled + # initialize the return variable + dict_allele_definitions_filled: dict[str, dict[str, str]] = dict() + # iterate over each allele's definition + for allele, dict_defining_genotypes in dict_allele_definitions.items(): + # initialize an empty dictionary for an allele + dict_allele_definitions_filled[allele] = {} + # find the allele-defining genotype at each position + for position, genotype in dict_defining_genotypes.items(): + # if the genotype at this position is None, use the reference genotype + updated_genotype: str = dict_allele_defining_variants[position]['ref'] if genotype is None else genotype + # add the update genotype to dict_allele_definitions_filled + dict_allele_definitions_filled[allele][position] = updated_genotype -def split_wobl_nonwobl_genotypes(genotypes: List[str]) -> Tuple[List[str], List[str]]: - """ - split wobble and non-wobble genotypes from a list into two separate lists - :param genotypes: a list of genotypes - :return: g_w: a list of wobble genotypes - :return: g_nw: a list of non-wobble genotypes - """ - # set up empty lists for wobble and non-wobble genotypes - g_w: List[str] = [] - g_nw: List[str] = [] + return dict_allele_definitions_filled - # iterate over genotypes to find wobble vs non-wobble genotypes - for g in genotypes: - (g_nw, g_w)[g in _wobble_genotype_list].append(g) - return g_w, g_nw - - -def replace_wobble(wobl_genotype: str) -> List[str]: +def replace_wobble(wobl_genotype: str) -> Set[str]: """ replace wobble genotypes with basic basepairs A/T/C/G :param wobl_genotype: an allele-defining genotype @@ -209,16 +198,16 @@ def replace_wobble(wobl_genotype: str) -> List[str]: return g_flat -def get_unique_combinations(g1: str, g2: str) -> Set[str]: +def get_unique_combinations(g1: str, g2: str) -> set[str]: """ get all possible combinations of genotypes at an allele-defining positions - :param g1: a list of allele-defining genotypes from allele 1 - :param g2: a list of allele-defining genotypes from allele 2 + :param g1: an allele-defining genotypes from allele 1 + :param g2: an allele-defining genotypes from allele 2 :return: a set of unique combinations of genotypes """ # replace wobbles - g1_ = replace_wobble(g1) if g1 in _wobble_genotype_list else [g1] - g2_ = replace_wobble(g2) if g2 in _wobble_genotype_list else [g2] + g1_: list[str] = replace_wobble(g1) if g1 in _wobble_genotype_list else [g1] + g2_: list[str] = replace_wobble(g2) if g2 in _wobble_genotype_list else [g2] # get all possible genotype combinations at this positions if None in g1_ and None in g2_: @@ -233,59 +222,47 @@ def get_unique_combinations(g1: str, g2: str) -> Set[str]: return uniq_comb -def get_diplotype_definition_dictionary(allele_definitions: dict[str, dict[str, list[str]]], - allele_defining_variants: dict[str, dict[str, str]]) \ - -> dict[str, dict[str, Set[str]]]: +def get_diplotype_definition_dictionary(dict_allele_definitions: dict[str, dict[str, str]], + alleles_to_test: list[str]) -> dict[str, dict[str, Set[str]]]: """ obtain a dictionary of diplotype definitions - :param allele_definitions: a dictionary of allele definitions - :param allele_defining_variants: a dictionary of allele-defining positions + :param dict_allele_definitions: a dictionary of allele definitions + :param alleles_to_test: list of alleles that share allele-defining positions with other alleles :return: a dictionary {diplotype_name: {position: {unique combinations of genotypes at a position} } } """ - # initialize variables # set up an empty dictionary that stores each diplotype's definition - diplotype_definitions: dict[str, dict[str, Set[str]]] = {} + diplotype_definitions: dict[str, dict[str, set[str]]] = {} # get the list of alleles - alleles = [*allele_definitions] - hgvs_names = [*allele_defining_variants] - - # get the length of the allele list for the loop to iterate through each element - n_allel = len(allele_definitions) + allele_names: list[str] = [*dict_allele_definitions] + n_alleles: int = len(allele_names) # iterate over the allele list - for i in range(n_allel): - for j in range(i, n_allel): + for i in range(n_alleles): + # get the allele name + a1: str = allele_names[i] + for j in range(i, n_alleles): # get the allele names - a1: str = alleles[i] - a2: str = alleles[j] + a2: str = allele_names[j] + + # skip if both alleles are defined by exclusive positions + if (a1 not in alleles_to_test) and (a2 not in alleles_to_test): + continue # specify the diplotype's name, which will be used as the key in dic_mat diplotype_name = a1 + '/' + a2 - # check whether the diplotype name is already in the dic_mat dictionary - if diplotype_name in diplotype_definitions: - # each diplotype should only be processed once, so each key/diplotype name should be unique - print(f'Diplotype {diplotype_name} is processed more than once. Something is wrong.') - sys.exit(1) - - # if no reportable errors, continue to generate the diplotype dictionary - # first, get the complete defining genotypes without empty values for the alleles - g1 = allele_definitions[a1]['genotypes'] - g2 = allele_definitions[a2]['genotypes'] - - # second, get unique combinations of genotypes at each genetic position - genotype_combinations = [get_unique_combinations(x, y) for x, y in zip(g1, g2)] - # third, create a dictionary variable to store the unique genotype combinations at each position - definition_i: dict[str, Set[str]] = dict(zip(hgvs_names, genotype_combinations)) + single_diplotype_definition: dict[str, set[str]] = dict() + for hgvs_name in dict_allele_definitions[a1]: + g1 = dict_allele_definitions[a1][hgvs_name] + g2 = dict_allele_definitions[a2][hgvs_name] + single_diplotype_definition[hgvs_name] = get_unique_combinations(g1, g2) # finally, append the diplotype and its genotypes to the dictionary - diplotype_definitions[diplotype_name] = definition_i - # end - for j in range(n_allel) - # end - for i in range(i, n_allel) + diplotype_definitions[diplotype_name] = single_diplotype_definition return diplotype_definitions @@ -304,75 +281,52 @@ def find_powerset(s): yield {ss for mask, ss in zip(masks, s) if i & mask} -def find_missingness_combinations(allele_defining_variants: dict[str, dict[str, str]], - allele_definitions: dict[str, dict[str, list[str]]]) -> list[list[str]]: +def find_missingness_combinations(allele_array: np.ndarray, + allele_names: list[str], + hgvs_names: list[str], + reference_alleles: list[str], + alleles_to_test: list[str]) -> list[list[str]]: """ Find the combinations of positions to be set to missing for autogenerated tests - :param allele_defining_variants: - :param allele_definitions: + :param allele_array: an allele-position array (A, P) where each cell denotes whether + the allele is defined by a specific genotype at a position + A = number of alleles in a gene + P = allele-defining positions + :param allele_names: list of allele names of length A + :param hgvs_names: list of hgvs names of length P + :param reference_alleles: list of reference alleles + :param alleles_to_test: list of alleles that share allele-defining positions with others :return: power set (all combinations) of missing positions if missing_position_combinations is empty, then there is no autogenerate test with missing positions if missing_position_combinations is not empty, it lists all missing combinations except the empty set """ p_separator: str = ';;' - allele_defining_positions: np.ndarray = np.array([*allele_defining_variants]) - allele_names: list[str] = [*allele_definitions] - # get the numbers of alleles and allele-defining genotypes - n_positions: int = len(allele_defining_positions) - n_alleles: int = len(allele_definitions) - - # create an empty genotype-allele array - allele_genotypes: np.ndarray = np.zeros(shape=(n_alleles, n_positions)) - for allele in allele_definitions: - allele_idx = allele_names.index(allele) - genotypes = allele_definitions[allele]['genotypes'] - for i, g in enumerate(genotypes): - if g is not None: - allele_genotypes[allele_idx, i] = 1 - - # identify reference allele(s) - ref_allel_bool: np.ndarray = np.all(allele_genotypes, axis=1) - ref_allel: list[str] = [y for x, y in zip(ref_allel_bool, allele_names) if x] - - # remove the row of reference allele - allele_genotypes_no_ref: np.ndarray = allele_genotypes[~ref_allel_bool, :] - allele_names_no_ref: list[str] = [y for x, y in zip(ref_allel_bool, allele_names) if not x] - - # find positions that define more than one allele - multi_allele_position_idx: np.ndarray = np.where(np.sum(allele_genotypes_no_ref, axis=0) > 1)[0] - # find alleles that are defined by multi-allele positions - alleles_to_test_idx: np.ndarray = np.unique(np.where(allele_genotypes_no_ref[:, multi_allele_position_idx])[0]) - alleles_to_test: list[str] = [allele_names_no_ref[i] for i in alleles_to_test_idx] # identify missing position combinations for each diplotype missing_position_combinations: list[list[str]] = [] for a1 in allele_names: - for a2 in allele_names: - # skip if none of the alleles are ambiguous - if a1 not in alleles_to_test and a2 not in alleles_to_test: - continue - + for a2 in alleles_to_test: # skip reference/reference - if a1 in ref_allel and a2 in ref_allel: + if a1 in reference_alleles and a2 in reference_alleles: continue # otherwise, get the index of allele-defining positions # this can include positions exclusive to either alleles - elif a1 in ref_allel: - i: int = allele_names_no_ref.index(a2) - p_idx: np.ndarray = np.unique(np.where(allele_genotypes_no_ref[i, :])[0]) - elif a2 in ref_allel: - i: int = allele_names_no_ref.index(a1) - p_idx: np.ndarray = np.unique(np.where(allele_genotypes_no_ref[i, :])[0]) + elif a1 in reference_alleles: + i: int = allele_names.index(a2) + p_idx: np.ndarray = np.where(allele_array[i, :])[0] + elif a2 in reference_alleles: + i: int = allele_names.index(a1) + p_idx: np.ndarray = np.unique(np.where(allele_array[i, :])[0]) else: - i: int = allele_names_no_ref.index(a1) - j: int = allele_names_no_ref.index(a2) - p_idx: np.ndarray = np.unique(np.where(allele_genotypes_no_ref[[i, j], :])[1]) + i: int = allele_names.index(a1) + j: int = allele_names.index(a2) + p_idx: np.ndarray = np.unique(np.where(allele_array[[i, j], :])[1]) # get the positions' hgvs names - p_names: np.ndarray = allele_defining_positions[p_idx] + p_names: np.ndarray = np.array(hgvs_names)[p_idx] # find indices of positions that need to be missing together for an allele to be mistaken as another - u, indices = np.unique(allele_genotypes_no_ref[:, p_idx], axis=1, return_inverse=True) + u, indices = np.unique(allele_array[:, p_idx], axis=1, return_inverse=True) # get the list of missing positions based on indices p_to_permute: set[str] = set() for k in range(indices.max() + 1): @@ -390,54 +344,73 @@ def find_missingness_combinations(allele_defining_variants: dict[str, dict[str, return missing_position_combinations -def convert_dictionary_to_numpy_array(definition_dictionary: dict[str, dict[str, Set[str]]], - variants: dict[str, dict[str, str]]) -> np.ndarray: +def convert_dictionary_to_numpy_array(dict_definitions: dict[str, dict]) -> np.ndarray: """ convert the diplotype definition dictionary to a numpy array - :param variants: a dictionary of {hgvs_name: {position: [], ref: []}} - :param definition_dictionary: {'diplotype name': {'position hgvs name': {unique set of genotypes at a position} } } - :return: a 3D numpy array [diplotypes, genotype combinations, positions]. here genotype combinations indicates 'A/G' + :param dict_definitions: + {allele_name: {position_hgvs_name: nucleotide} } + or + {diplotype_name: {position_hgvs_name: nucleotide_combinations} } + :return: a 3D numpy array [alleles/diplotypes, genotypes, positions] """ # initialize variables - # set up an empty set to record all the possible genotype combinations across all positions in a gene + # set up an empty set to record all the possible genotypes across positions in a gene unique_genotypes: set[str] = set() - # use python Set to find all the unique genotype combinations across positions - unique_genotypes.update([g for sub_dict in definition_dictionary.values() for x in sub_dict.values() for g in x]) + for dict_defining_genotypes in dict_definitions.values(): + # get the specific genotype(s) at a certain position + for genotypes in dict_defining_genotypes.values(): + # for diplotypes, genotypes are a set of nucleotide combinations, like 'A/A' or 'delT/T' + if isinstance(genotypes, set): + # add each genotype to the list and skip None + for x in genotypes: + if x is not None: + unique_genotypes.add(x) + # for alleles, genotype(s) are a single nucleotide that defines an allele at a position + elif isinstance(genotypes, str): + unique_genotypes.add(genotypes) + # for alleles, besides the specific allele-defining positions, others are reference and left as None + elif genotypes is None: + continue + else: + print(f'\tUnexpected data type was found in convert_dictionary_to_numpy_array') # convert the genotype set to a list which has an order genotype_list: list[str] = list(unique_genotypes) - # get the list of diplotype names - diplotype_names: list[str] = [*definition_dictionary] + # get the list of allele or diplotype names + definition_names: list[str] = [*dict_definitions] # get the list of hgvs names for all positions - hgvs_names: list[str] = [*variants] + hgvs_names: list[str] = [*dict_definitions[definition_names[0]]] # initialize an empty numpy array with rows and columns - n_diplotypes: int = len(diplotype_names) + n_definitions: int = len(definition_names) n_positions: int = len(hgvs_names) n_genotypes: int = len(genotype_list) - diplotype_array: np.ndarray = np.zeros((n_diplotypes, n_genotypes, n_positions), dtype='int8') - - # fill the numpy array with 1 where position(key)-genotype(value) combinations exist - for diplotype, one_diplotype_definition in definition_dictionary.items(): + definition_arrays: np.ndarray = np.zeros((n_definitions, n_genotypes, n_positions), dtype='int8') - # get the diplotype index for the numpy array - diplotype_idx = diplotype_names.index(diplotype) - - for position, genotypes in one_diplotype_definition.items(): + # fill the numpy arrays with 1 where position(key)-genotype(value) combinations exist + for definition, dict_one_definition in dict_definitions.items(): + # get the allele or diplotype index for the numpy array + idx_definition = definition_names.index(definition) + for position, genotypes in dict_one_definition.items(): # get the position index for the numpy array position_idx = hgvs_names.index(position) - for g in genotypes: - # get the genotype index for the numpy array - genotype_idx = genotype_list.index(g) + # skip None + if genotypes is None: + continue + # otherwise, denote the corresponding genotype-position as '1' in the matrix + else: + for g in genotypes: + # get the genotype index for the numpy array + genotype_idx = genotype_list.index(g) - # fill the corresponding cell in the numpy array with 1 to denote the definition - diplotype_array[diplotype_idx, genotype_idx, position_idx] = 1 + # fill the corresponding cell in the numpy array with 1 to denote the definition + definition_arrays[idx_definition, genotype_idx, position_idx] = 1 - return diplotype_array + return definition_arrays def is_sharing_definitions(g1: np.ndarray, g2: np.ndarray, axis: Optional[int] = 1) -> np.ndarray[bool]: @@ -524,7 +497,7 @@ def is_included(g1: np.ndarray, g2: np.ndarray) -> bool: g_subs = g1 - g2 # find diplotypes that includes g1 - status = set(g_subs.flatten()) == {0, -1} + status = set(np.unique(g_subs)) == {0, -1} return status @@ -643,7 +616,7 @@ def find_possible_calls(g1: np.ndarray, # sys.exit(1) # get the names of the diplotypes that share definitions with g1 - diplotypes_sharing_definitions = [d for d, s in zip(diplotype_names, status_sharing_definitions) if s] + diplotypes_sharing_definitions: list[str] = [d for d, s in zip(diplotype_names, status_sharing_definitions) if s] return diplotypes_sharing_definitions @@ -653,9 +626,9 @@ def find_wobble_subsets(genotypes: np.ndarray, diplotypes: np.ndarray, wobble_su """ find subsets of possible calls for provided diplotypes :param genotypes: a numpy array of multiple diplotypes (D, M, P) - D = the total number of diplotypes for a gene - M = all possible genotypes at a position for a diplotype - P = allele-defining position for a gene + D = the total number of diplotypes for a gene + M = all possible genotypes at a position for a diplotype + P = allele-defining position for a gene :param diplotypes: an 1D array of diplotypes (D, ) :param wobble_subsets: a set that saves all possible calls based on the input genotype matrix :return: subsets of diplotypes that share unique genotype definitions @@ -695,172 +668,137 @@ def find_wobble_subsets(genotypes: np.ndarray, diplotypes: np.ndarray, wobble_su return wobble_subsets -def find_all_possible_calls(diplotype_definitions: dict[str, dict[str, Set[str]]], - allele_defining_variants: dict[str, dict[str, str]]) -> dict: - """ - find the alternative calls for all diplotypes. - check out the function 'compare_diplotype_pairs' for how the relationships between diplotypes were determined. - - :param diplotype_definitions: a dictionary of diplotype definitions - :param allele_defining_variants: a dictionary of all allele-defining variants for a gene - :return: a dictionary of all the alternative calls one will find for a gene - {diplotypes: {case123: [list of all possible calls]}} ## not all wobble diplotypes have multiple cases - """ - # initialize an empty dictionary to store the alternative calls - dict_possible_calls: dict = dict() - - # convert the diplotype definition dictionaries to numpy arrays of genotypes for each diplotype - # vectorized computation by numpy arrays speeds up diplotype comparison - definition_arrays: np.ndarray = convert_dictionary_to_numpy_array(diplotype_definitions, - allele_defining_variants) - - # get the list of wobble diplotypes - wobble_diplotypes: list[str] = [] - for d, v in diplotype_definitions.items(): - for g in v.values(): - if len(g) > 1: - wobble_diplotypes.append(d) - break - - # get the list of diplotypes that have overlapping definitions with other diplotype - diplotype_names: list[str] = [*diplotype_definitions] - # get the number of diplotypes - n_diplotypes = len(diplotype_definitions) - - # loop over all diplotypes - for i in range(n_diplotypes): - - # get the genotype array of the target diplotype - d1 = diplotype_names[i] - g1 = definition_arrays[i, :, :] - - # find the alternative calls of a diplotype - possible_calls: list[str] = find_possible_calls(g1, definition_arrays, diplotype_names) - - # if there is no other alternative calls except d1 itself, move on to the next diplotype - if [d1] == possible_calls: - continue - # if d1 is not a wobble diplotype, add the alternative calls to the summary dictionary and continue - elif d1 not in wobble_diplotypes: - dict_possible_calls[d1] = {'case_1': possible_calls} - # if there are only two diplotypes in the possible calls, continue to the next - elif len(possible_calls) == 2: - dict_possible_calls[d1] = {'case_1': possible_calls} - else: - # use a set to temporary store subsets of possible calls - # set is not hashable in a set, so use hashable frozensets in a set instead - wobble_subsets: set[frozenset] = {frozenset(possible_calls)} - - # identify the indices of possible_calls in the definition_arrays - d_indices: list[int] = [diplotype_names.index(d) for d in possible_calls] - # extract genotypes of the possible_calls - g_possible_calls: np.ndarray = definition_arrays[d_indices, :, :] - - # since this is to find alternative calls of d1,trim g_possible_calls to - # retain only definitions shared with d1 - # numbers in g_possible_calls represents how many diplotypes, including d1, share a specific definitions - g_possible_calls = g_possible_calls * g1 - - # find subsets of possible calls for d1 - possible_calls_np: np.ndarray = np.array(possible_calls) - wobble_subsets = find_wobble_subsets(g_possible_calls, possible_calls_np, wobble_subsets) - - # add the frozensets to the summary dictionary - if len(wobble_subsets) == 1: - dict_possible_calls[d1] = list(list(wobble_subsets)[0]) - else: - # initialize a dictionary item for d1 - dict_possible_calls[d1] = dict() - # add different sets of alternative calls to the summary dictionary for d1 - for j, x in enumerate(wobble_subsets): - dict_possible_calls[d1]['case_' + str(j + 1)] = list(x) +def find_all_possible_calls(definition_arrays: np.ndarray, diplotype_names: list[str], + possible_call_sets: list[set[str]]) -> list[set[str]]: + # initialize variables + n_positions = definition_arrays.shape[2] + # initialize an empty list to store the sets of overlapping diplotypes based on the current position + current_possible_call_sets: list[set[str]] = [] + + for idx_position in range(n_positions): + # get the total number of possible_call_sets + n_combinations = len(possible_call_sets) + + # update the overlapping diplotype calls based on the current position + for i in range(n_combinations): + # get the possible calls + possible_calls: np.ndarray = np.array(list(possible_call_sets[i])) + idx_possible_calls: list[int] = [diplotype_names.index(x) for x in possible_calls] + genotypes_of_possible_calls = definition_arrays[idx_possible_calls, :, idx_position] + + # find genotypes that define more than one diplotype + busy_genotypes = np.where(np.sum(genotypes_of_possible_calls, axis=0) > 1)[0] + if len(busy_genotypes): + for idx_genotype in busy_genotypes: + idx_diplotypes = np.unique(np.where(genotypes_of_possible_calls[:, idx_genotype])[0]) + current_possible_calls = set(possible_calls[idx_diplotypes]) + if current_possible_calls not in current_possible_call_sets: + current_possible_call_sets.append(current_possible_calls) + + # update the possible_call_sets + possible_call_sets = current_possible_call_sets + # reset current_possible_call_sets + current_possible_call_sets = [] # return the summary dictionary of alternative calls for all diplotypes - return dict_possible_calls - - -def get_actual_and_alternative_calls(possible_calls: list[str], diplotype_scores: dict[str, int]) \ - -> Tuple[str, str]: - # initialize output variables - actual_calls: list[str] = [] - alternative_calls: list[str] = [] - - # get the scores for the possible diplotypes - scores = np.array([diplotype_scores[d] for d in possible_calls]) - # get the diplotypes with the max scores - max_idx = np.where(scores == np.max(scores))[0] - - # separate the actual and alternative calls based on diplotype scores - for i, d in enumerate(possible_calls): - (alternative_calls, actual_calls)[i in max_idx].append(d) + return possible_call_sets - # return str - return ';'.join(actual_calls), ';'.join(alternative_calls) - -def find_predicted_calls(dict_possible_calls: dict[str, str], +def find_predicted_calls(possible_call_sets: list[set[str]], diplotype_scores: dict[str, int], missing_positions: list[str] = None) -> dict[list[str], list[str], list[str]]: + """ + for a diplotype that have alternative calls, find the call with the highest score that PharmCAT will return + :param possible_call_sets: (list) sets of possible calls. + Each set comprises diplotypes with overlapping definitions. + :param diplotype_scores: dictionary of diplotype scores based on the number of allele-defining positions + :param missing_positions: list of positions that are presumed missing and not considered in this round of comparison + :return: dictionary of three lists of the same length + expected calls: purported genotypes for a diplotype in a VCF + actual calls: diplotypes with the highest score based on the number of allele-defining positions + alternative calls: other diplotypes that share definitions with the actual calls but have lower scores + """ # initialize the return dictionary predict_calls_cols = ['expected', 'actual', 'alternative', 'missing_positions'] predicted_calls: dict = {key: [] for key in predict_calls_cols} - # iterate over different expected calls - for expected_call, possible_calls in dict_possible_calls.items(): - # for an expected call, iterate over all possible sets of alternative calls - for case in possible_calls.values(): - # get the lists of actual calls and alternative calls based on scores - actual_calls, alternative_calls = get_actual_and_alternative_calls(case, diplotype_scores) + # iterate over possible_call_set + for possible_call_set in possible_call_sets: + # initialize lists to concatenate + actual_calls: list[str] = [] + alternative_calls: list[str] = [] + + # get diplotype scores for the possible call + scores = np.array([diplotype_scores[d] for d in possible_call_set]) + # get the diplotypes with the max scores + idx_max_score = np.where(scores == np.max(scores))[0] + # separate the actual and alternative calls based on diplotype scores + for i, d in enumerate(possible_call_set): + (alternative_calls, actual_calls)[i in idx_max_score].append(d) + + # add results to the predicted_calls + for d in possible_call_set: # add to the summary dictionary - predicted_calls['expected'].append(expected_call) - predicted_calls['actual'].append(actual_calls) - predicted_calls['alternative'].append(alternative_calls) + predicted_calls['expected'].append(d) + predicted_calls['actual'].append(';'.join(actual_calls)) + predicted_calls['alternative'].append(';'.join(alternative_calls)) predicted_calls['missing_positions'].append(';'.join(missing_positions) if missing_positions else None) # return the expected, actual, and alternative calls return predicted_calls -def predict_pharmcat_calls(allele_defining_variants: dict[str, dict[str, str]], - allele_definitions: dict[str, dict[str, list[str]]], +def predict_pharmcat_calls(dict_allele_defining_variants: dict[str, dict[str, str]], + dict_allele_definitions: dict[str, dict[str, set[str]]], + alleles_to_test: list[str], missing_positions: Optional[set[str]] = None) -> dict[list[str], list[str], list[str]]: # initialize values definitions = dict() defining_variants = dict() - # remove missing positions from allele definitions - positions = [*allele_defining_variants] - for allele in allele_definitions: - # get genotypes and remove missing positions - g = allele_definitions[allele]['genotypes'] - if missing_positions: - g = [g[i] for i, x in enumerate(positions) if x not in missing_positions] - # if definitions are not empty for this allele, add to the definitions dictionary - if any(g): - definitions[allele] = {'genotypes': g} - # else: - # print(f'...Discarding {allele} as the definition is emtpy due to missing positions') - # remove missing positions from the list of allele defining positions - for m in allele_defining_variants: - # skip if m is one of the missing positions + for position in dict_allele_defining_variants: + # skip if position is one of the missing positions if missing_positions: - if m in missing_positions: + if position in missing_positions: continue # add the position to the new dictionary - defining_variants[m] = { - 'position': allele_defining_variants[m]['position'], - 'ref': allele_defining_variants[m]['ref'] + defining_variants[position] = { + 'position': dict_allele_defining_variants[position]['position'], + 'ref': dict_allele_defining_variants[position]['ref'] } + # remove missing positions from allele definitions + for allele, dict_defining_genotypes in dict_allele_definitions.items(): + if missing_positions: + # get the hgvs names of allele-defining positions, excluding missing positions + p: list[str] = [k for k in dict_defining_genotypes if k not in missing_positions] + # get genotypes at each allele-defining positions, excluding missing positions + g: list[str] = [dict_defining_genotypes[hgvs_name] for hgvs_name in p] + # if definitions are not empty for this allele, add to the definitions dictionary + if any(g): + definitions[allele] = {k: v for k, v in zip(p, g)} + else: + definitions[allele] = dict_defining_genotypes + # update the allele_definitions and fill empty cells with reference genotypes - allele_definitions_filled = fill_definitions_with_references(definitions, defining_variants) + allele_definitions_filled: dict[str, dict[str, str]] = fill_definitions_with_references(definitions, + defining_variants) # get a dictionary of diplotype definitions for comparison # {diplotype_name: {position: {unique combinations of genotypes at a position} } } - diplotype_definitions = get_diplotype_definition_dictionary(allele_definitions_filled, defining_variants) - dict_possible_calls = find_all_possible_calls(diplotype_definitions, defining_variants) + diplotype_definitions = get_diplotype_definition_dictionary(allele_definitions_filled, alleles_to_test) + + # get diplotype definition arrays + # vectorized computation by numpy arrays speeds up diplotype comparison + diplotype_definition_arrays: np.ndarray = convert_dictionary_to_numpy_array(diplotype_definitions) + # get diplotype names + diplotype_names: list[str] = [*diplotype_definitions] + # initialize a possible call sets + possible_call_sets: list[set[str]] = [set(diplotype_names)] + # find possible calls for each diplotype + possible_call_sets = find_all_possible_calls(diplotype_definition_arrays, diplotype_names, possible_call_sets) # get a dictionary of calculated diplotype scores based on the number of core allele-defining positions # first, calculate the haplotype scores @@ -868,7 +806,7 @@ def predict_pharmcat_calls(allele_defining_variants: dict[str, dict[str, str]], # then, calculate the diplotype scores by adding up the allele scores diplotype_scores = count_diplotype_scores(allele_scores) # summarize expected vs actual calls - dict_predicted_calls = find_predicted_calls(dict_possible_calls, diplotype_scores, missing_positions) + dict_predicted_calls = find_predicted_calls(possible_call_sets, diplotype_scores, missing_positions) # compare diplotype pairs and identify pairwise relationships # pairwise_comparisons = compare_diplotype_pairs(diplotype_definitions, allele_defining_variants) @@ -876,8 +814,7 @@ def predict_pharmcat_calls(allele_defining_variants: dict[str, dict[str, str]], return dict_predicted_calls -def compare_diplotype_pairs(diplotype_definitions: dict, - allele_defining_variants: dict[str, dict[str, str]]) -> dict: +def compare_diplotype_pairs(dict_diplotype_definitions: dict[str, dict[str, set[str]]]) -> dict: """ compare unphased diplotype definitions G = a MxP matrix where rows are all possible nucleotide pairs in a gene (M) and columns are positions (P) @@ -941,8 +878,7 @@ def compare_diplotype_pairs(diplotype_definitions: dict, empty column = G1 and G2 do not share any defining genotypes at a position - :param allele_defining_variants: - :param diplotype_definitions: + :param dict_diplotype_definitions: :return: a dictionary { 'diplotype_1': { score: 123, @@ -954,13 +890,12 @@ def compare_diplotype_pairs(diplotype_definitions: dict, # convert the diplotype definition dictionaries to numpy arrays of genotypes for each diplotype # vectorized computation by numpy arrays speeds up diplotype comparison - definition_arrays: np.ndarray = convert_dictionary_to_numpy_array(diplotype_definitions, - allele_defining_variants) + definition_arrays: np.ndarray = convert_dictionary_to_numpy_array(dict_diplotype_definitions) # get the list of diplotype names - diplotype_names: list[str] = [*diplotype_definitions] + diplotype_names: list[str] = [*dict_diplotype_definitions] # get the number of diplotypes - n_diplotypes = len(diplotype_definitions) + n_diplotypes = len(dict_diplotype_definitions) # iterate over diplotypes for i in range(n_diplotypes): @@ -970,7 +905,7 @@ def compare_diplotype_pairs(diplotype_definitions: dict, g1 = definition_arrays[i, :, :] # check whether g1 shares definitions with any other diplotype definitions - status_sharing_definitions = is_sharing_definitions(g1, definition_arrays) + status_sharing_definitions: np.ndarray = is_sharing_definitions(g1, definition_arrays) # check whether each of the diplotypes differs from g1 status_discrepant = is_discrepant(g1, definition_arrays) # sanity check to make sure no diplotypes can be discrepant but share definitions at the same time