Skip to content

Commit

Permalink
Merge pull request #151 from Kulivox/compareSTR-comparison-improvements
Browse files Browse the repository at this point in the history
Compare str comparison improvements
  • Loading branch information
LiterallyUniqueLogin authored Mar 24, 2022
2 parents 30d70fd + 0310ddd commit 8dd9abf
Show file tree
Hide file tree
Showing 14 changed files with 334 additions and 105 deletions.
11 changes: 11 additions & 0 deletions RELEASE_NOTES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
Unreleased changes
-----

CompareSTR and mergeutils changes:

* CompareSTR: the tool now only compares records that start and end at the same position. If records overlap but are
not comparable, the program prints a warning to the user containing the IDs of the records and their positions.

* mergeutils: the function GetMinHarmonizedRecords was replaced by GetRecordComparabilityAndIncrement, which allows the caller
to supply a custom predicate that decides whether records are comparable.

4.0.2
-----

Expand Down
246 changes: 166 additions & 80 deletions trtools/compareSTR/compareSTR.py

Large diffs are not rendered by default.

95 changes: 94 additions & 1 deletion trtools/compareSTR/tests/test_compareSTR.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import argparse
import os
from typing import List

import numpy as np
import pytest

from trtools.utils.tests.test_mergeutils import DummyHarmonizedRecord
from ..compareSTR import *


Expand Down Expand Up @@ -173,6 +174,97 @@ def test_main(tmpdir, vcfdir):
retcode = main(args)
assert retcode == 1

def test_no_comparable_records(tmpdir, vcfdir, capfd):
    """main() returns 1 and reports on stderr when no records are comparable."""
    compare_dir = os.path.join(vcfdir, "compareSTR_vcfs")

    args = base_argparse(tmpdir)
    args.region = ""
    args.vcf1 = os.path.join(compare_dir, "test_no_comparable_records_1.vcf.gz")
    args.vcf2 = os.path.join(compare_dir, "test_no_comparable_records_2.vcf.gz")

    retcode = main(args)
    captured = capfd.readouterr()

    assert retcode == 1
    assert captured.err == "No comparable records were found, exiting!\n"


def test_better_comparability_calculation(tmpdir, vcfdir, capfd):
    """Overlapping records with different coordinates are warned about on stderr."""
    compare_dir = os.path.join(vcfdir, "compareSTR_vcfs")

    # One warning block is expected per overlapping pair, in this order.
    expected_err = (
        # same starting position, different end position
        "Records STR_40 and STR_40 overlap:\n"
        "STR_40: (112695, 112700)\n"
        "STR_40: (112695, 112702),\n"
        "but are NOT comparable!\n"
        # more general case: neither start nor end matches
        "Records STR_41 and STR_41 overlap:\n"
        "STR_41: (113695, 113700)\n"
        "STR_41: (113693, 113702),\n"
        "but are NOT comparable!\n"
        # same end position, different start position
        "Records STR_42 and STR_42 overlap:\n"
        "STR_42: (114695, 114700)\n"
        "STR_42: (114693, 114700),\n"
        "but are NOT comparable!\n"
    )

    args = base_argparse(tmpdir)
    args.region = ""
    args.vcf1 = os.path.join(compare_dir, "test_better_comparability_calculation_1.vcf.gz")
    args.vcftype1 = 'hipstr'
    args.vcf2 = os.path.join(compare_dir, "test_better_comparability_calculation_2.vcf.gz")
    args.vcftype2 = 'hipstr'

    assert main(args) == 0

    with open(tmpdir + "/test_compare-locuscompare.tab", "r") as locus_file:
        line_count = len(locus_file.readlines())
    # The locus comparison output must contain exactly two lines.
    assert line_count == 2

    captured = capfd.readouterr()
    assert captured.err == expected_err
def test_comparability_handler(tmpdir, vcfdir, capfd):
    """Check the verdict of handle_overlaps on hand-built pairs of records."""
    no_chrom = [np.inf, np.inf]

    # Both slots are empty.
    assert not handle_overlaps([None, None], no_chrom, np.inf)

    # Only the first slot holds a record; the chromosome indices are still
    # [inf, inf] for this call.
    assert not handle_overlaps([DummyHarmonizedRecord("chr1", 10), None], no_chrom, 0)

    # Only the second slot holds a record.
    assert not handle_overlaps(
        [None, DummyHarmonizedRecord("chr1", 10, 4, "AC")], [np.inf, 0], 0
    )

    cross_chrom_pair = [
        DummyHarmonizedRecord("chr2", 10, 4, "AC", end_pos=17),
        DummyHarmonizedRecord("chr1", 10, 4, "AC", end_pos=17),
    ]
    # records from different chromosomes aren't comparable
    assert not handle_overlaps(cross_chrom_pair, [1, 0], 0)

    # Same pair again, but now both chromosome indices are passed as equal.
    same_chrom = [0, 0]
    assert handle_overlaps(cross_chrom_pair, same_chrom, 0)

    # Same start, different end: not comparable, and the stderr captured so
    # far must be non-empty.
    different_end_pair = [
        DummyHarmonizedRecord("chr1", 10, 5, "AC", "rec1", end_pos=19),
        DummyHarmonizedRecord("chr1", 10, 4, "AC", "rec2", end_pos=17),
    ]
    assert not handle_overlaps(different_end_pair, same_chrom, 0)
    assert capfd.readouterr().err != ""

    # Identical start and end with different motifs.
    different_motif_pair = [
        DummyHarmonizedRecord("chr1", 10, 4, "AC", end_pos=17),
        DummyHarmonizedRecord("chr1", 10, 4, "TG", end_pos=17),
    ]
    assert handle_overlaps(different_motif_pair, same_chrom, 0)

    # Same end, different start.
    different_start_pair = [
        DummyHarmonizedRecord("chr1", 8, 5, "AC", end_pos=17),
        DummyHarmonizedRecord("chr1", 10, 4, "AC", end_pos=17),
    ]
    assert not handle_overlaps(different_start_pair, same_chrom, 0)


def test_hipstr_position_harmonisation(tmpdir, vcfdir):
vcfcomp = os.path.join(vcfdir, "compareSTR_vcfs")

Expand All @@ -188,6 +280,7 @@ def test_hipstr_position_harmonisation(tmpdir, vcfdir):
retcode = main(args)
assert retcode == 0


with open(tmpdir + "/test_compare-locuscompare.tab", "r") as out_overall:
lines = out_overall.readlines()
## vcf1 : flank bp at start of record, vcf2: no flanking bp
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
38 changes: 23 additions & 15 deletions trtools/utils/mergeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import trtools.utils.common as common
import trtools.utils.tr_harmonizer as trh

from typing import List, Union, Any, Optional, Callable
from typing import List, Union, Any, Optional, Callable, Tuple

CYVCF_RECORD = cyvcf2.Variant
CYVCF_READER = cyvcf2.VCF
Expand Down Expand Up @@ -249,12 +249,13 @@ def GetMinRecords(record_list: List[Optional[trh.TRRecord]], chroms: List[str])



def GetRecordComparabilityAndIncrement(record_list: List[Optional[trh.TRRecord]],
                                       chroms: List[str],
                                       overlap_callback: Callable[[List[Optional[trh.TRRecord]], List[int], int], bool]) \
        -> Tuple[List[bool], bool]:
    r"""Get list that says which records should be skipped in the next
    iteration, and whether they are all comparable with each other

    Use order in chroms to determine sort order of chromosomes

    Parameters
    ----------
    record_list : list of trh.TRRecord
        Current records, one per input. Entries may be None; such entries
        are never flagged for increment.
    chroms : list of str
        Ordered list of all chromosomes
    overlap_callback: Callable[[List[Optional[trh.TRRecord]], List[int], int], bool]
        Function that calculates whether the records are comparable. It is
        called with the record list, the chromosome index of each record
        (np.inf for None) and the lowest chromosome index.

    Returns
    -------
    increment : list of bool
        List of bools, where items are set to True when the record at the index of the item
        is first in sort order (lowest chromosome, then lowest position on that chromosome).
    comparable: bool
        Value that determines whether current records are comparable / mergeable, depending on the callback

    Raises
    ------
    ValueError
        If record_list is empty, or a record's chromosome is not in chroms.
    """
    chrom_order = [np.inf if r is None else chroms.index(r.chrom) for r in record_list]
    min_chrom_index = min(chrom_order)

    # The minimum position must be taken only among records on the lowest
    # chromosome. A global minimum can belong to a record on a later
    # chromosome, in which case no record would be flagged and the caller
    # could never advance.
    min_pos = min(
        (r.pos for r, chrom_idx in zip(record_list, chrom_order)
         if r is not None and chrom_idx == min_chrom_index),
        default=np.inf,
    )

    increment = [
        r is not None and chrom_idx == min_chrom_index and r.pos == min_pos
        for r, chrom_idx in zip(record_list, chrom_order)
    ]
    comparable = overlap_callback(record_list, chrom_order, min_chrom_index)

    return increment, comparable


def DoneReading(records: List[Union[CYVCF_RECORD, trh.TRRecord]]) -> bool:
Expand Down
34 changes: 25 additions & 9 deletions trtools/utils/tests/test_mergeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ def __init__(self, chrom, pos, ref, alts=[], info={}):


class DummyHarmonizedRecord:
    """Minimal stand-in for a harmonized record, used only by tests.

    It stores the constructor arguments as plain attributes and performs
    no validation.
    """

    def __init__(self, chrom, pos, reflen=None, motif=None, record_id=None, end_pos=None):
        # Only chrom and pos are required; the remaining attributes default
        # to None so that two-argument construction keeps working.
        self.chrom = chrom
        self.pos = pos
        self.end_pos = end_pos
        self.ref_allele_length = reflen
        self.motif = motif
        self.record_id = record_id


def test_DebugPrintRecordLocations(capsys):
Expand Down Expand Up @@ -95,26 +99,38 @@ def test_UnzippedUnindexedFile(mrgvcfdir):
assert "Could not find VCF index" in str(info.value)


def test_GetRecordComparabilityAndIncrement():
    """Check the increment flags and that the comparable flag mirrors the callback."""
    chromosomes = ["chr1", "chr2", "chr3"]

    def comp_callback_true(x, y, z):
        return True

    def comp_callback_false(x, y, z):
        return False

    pair = [DummyHarmonizedRecord("chr1", 20), DummyHarmonizedRecord("chr1", 20)]
    assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_true) == ([True, True], True)

    # these two test cases show that the second result of GetRecordComparabilityAndIncrement is
    # entirely dependent on the callback
    pair = [DummyHarmonizedRecord("chr1", 21), DummyHarmonizedRecord("chr1", 20)]
    assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_false) == ([False, True], False)

    pair = [DummyHarmonizedRecord("chr1", 21), DummyHarmonizedRecord("chr1", 20)]
    assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_true) == ([False, True], True)

    pair = [DummyHarmonizedRecord("chr2", 20), DummyHarmonizedRecord("chr1", 20)]
    assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_false) == ([False, True], False)

    pair = [DummyHarmonizedRecord("chr1", 20), DummyHarmonizedRecord("chr1", 21)]
    assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_true) == ([True, False], True)

    pair = [None, None]
    assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_false) == ([False, False], False)

    pair = [DummyHarmonizedRecord("chr1", 20), None]
    assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_false) == ([True, False], False)

    pair = [None, DummyHarmonizedRecord("chr1", 20)]
    assert mergeutils.GetRecordComparabilityAndIncrement(pair, chromosomes, comp_callback_false) == ([False, True], False)
15 changes: 15 additions & 0 deletions trtools/utils/tr_harmonizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,12 @@ class TRRecord:
The chromosome this locus is in
pos : int
The bp along the chromosome that this locus is at (ignoring flanking base pairs/full alleles)
end_pos:
Position of the last bp of ref allele (ignoring flanking base pairs/full alleles)
full_alleles_pos:
Position of the first bp of the full ref allele (including the flanking base pairs)
full_alleles_end_pos:
Position of the last bp of the full ref allele (including the flanking base pairs)
info : Dict[str, Any]
The dictionary of INFO fields at this locus
format : Dict[str, np.ndarray]
Expand Down Expand Up @@ -680,6 +686,13 @@ def __init__(self,
self.has_fabricated_ref_allele = False
self.ref_allele_length = len(ref_allele) / len(motif)

# Declaration of the end_pos variables. Values are rounded because self.ref_allele_length can
# sometimes be a float because of partial repeats. This can cause floating-point problems, and a simple cast
# is not enough to ensure that the proper position is calculated
self.end_pos = round(self.pos + self.ref_allele_length * len(motif) - 1)
self.full_alleles_end_pos = self.end_pos if full_alleles is None else \
round(self.full_alleles_pos + len(self.full_alleles[0]) - 1)

if alt_allele_lengths is not None:
self.has_fabricated_alt_alleles = True
self.alt_alleles = [
Expand All @@ -692,6 +705,8 @@ def __init__(self,
len(allele) / len(motif) for allele in self.alt_alleles
]



try:
self._CheckRecord()
except ValueError as e:
Expand Down

0 comments on commit 8dd9abf

Please sign in to comment.