diff --git a/truvari/bench.py b/truvari/bench.py index 065a7cf7..abdf3319 100644 --- a/truvari/bench.py +++ b/truvari/bench.py @@ -585,8 +585,8 @@ def build_matrix(self, base_variants, comp_variants, chunk_id=0, skip_gt=False, for bid, b in enumerate(base_variants): base_matches = [] for cid, c in enumerate(comp_variants): - mat = matcher(b, c, [f"{chunk_id}.{bid}", f"{chunk_id}.{cid}"], - skip_gt, self.short_circuit) + mat = matcher(b, c, matid=[f"{chunk_id}.{bid}", f"{chunk_id}.{cid}"], + skip_gt=skip_gt, short=self.short_circuit) logging.debug("Made mat -> %s", mat) base_matches.append(mat) match_matrix.append(base_matches) diff --git a/truvari/matching.py b/truvari/matching.py index 070b6be5..d2ed1cb7 100644 --- a/truvari/matching.py +++ b/truvari/matching.py @@ -194,11 +194,11 @@ def compare_gts(self, match, base, comp): match.comp_gt_count = sum(1 for _ in match.comp_gt if _ == 1) match.gt_match = abs(match.base_gt_count - match.comp_gt_count) - def build_match(self, base, comp, matid=None, skip_gt=False, short_circuit=False): + def build_match(self, base, comp, matid=None, skip_gt=False, short=False): """ Build a MatchResult if skip_gt, don't do genotype comparison - if short_circuit, return after first failure + if short, return after first failure """ ret = MatchResult() ret.base = base @@ -211,7 +211,7 @@ def build_match(self, base, comp, matid=None, skip_gt=False, short_circuit=False logging.debug("%s and %s are not the same SVTYPE", str(base), str(comp)) ret.state = False - if short_circuit: + if short: return ret bstart, bend = truvari.entry_boundaries(base) @@ -220,7 +220,7 @@ def build_match(self, base, comp, matid=None, skip_gt=False, short_circuit=False logging.debug("%s and %s are not within REFDIST", str(base), str(comp)) ret.state = False - if short_circuit: + if short: return ret ret.sizesim, ret.sizediff = truvari.entry_size_similarity(base, comp) @@ -228,7 +228,7 @@ def build_match(self, base, comp, matid=None, skip_gt=False, short_circuit=False logging.debug("%s and %s size similarity is too low (%.3f)", str(base), str(comp), ret.sizesim) ret.state = False - if short_circuit: + if short: return ret if not skip_gt: @@ -239,7 +239,7 @@ def build_match(self, base, comp, matid=None, skip_gt=False, short_circuit=False logging.debug("%s and %s overlap percent is too low (%.3f)", str(base), str(comp), ret.ovlpct) ret.state = False - if short_circuit: + if short: return ret if self.params.pctseq > 0: @@ -249,7 +249,7 @@ def build_match(self, base, comp, matid=None, skip_gt=False, short_circuit=False logging.debug("%s and %s sequence similarity is too low (%.3ff)", str(base), str(comp), ret.seqsim) ret.state = False - if short_circuit: + if short: return ret else: ret.seqsim = 0 @@ -259,19 +259,30 @@ def build_match(self, base, comp, matid=None, skip_gt=False, short_circuit=False return ret - def bnd_build_match(self, base, comp, matid=None, *_args, **_kwargs): + def bnd_build_match(self, base, comp, matid=None, **_kwargs): """ Build a MatchResult for bnds """ + def bounds(entry, pos, key='POS'): + """ + Inflate a bnd position + """ + start = pos - self.params.bnddist + end = pos + self.params.bnddist + if key + '1' in entry.info: + start -= entry.info['CI' + key + '1'] + if key + '2' in entry.info: + end += entry.info['CI' + key + '2'] + return start, end + ret = truvari.MatchResult() ret.base = base ret.comp = comp ret.matid = matid ret.state = base.chrom == comp.chrom - ret.st_dist = base.pos - comp.pos - ret.state &= abs(ret.st_dist) < self.params.bnddist + ret.state &= truvari.overlaps(*bounds(base, base.pos, 'POS'), *bounds(comp, comp.pos, 'POS')) b_bnd = bnd_direction_strand(base.alts[0]) c_bnd = bnd_direction_strand(comp.alts[0]) ret.state &= b_bnd == c_bnd @@ -280,14 +291,16 @@ def bnd_build_match(self, base, comp, matid=None, *_args, **_kwargs): c_pos2 = bnd_position(comp.alts[0]) ret.ed_dist = b_pos2[1] - c_pos2[1] ret.state &= b_pos2[0] == c_pos2[0] + + ret.state &= truvari.overlaps(*bounds(base, b_pos2[1], 'END'), *bounds(comp, c_pos2[1], 'END')) ret.state &= ret.ed_dist < self.params.bnddist self.compare_gts(ret, base, comp) # Score is percent of allowed distance needed to find this match - ret.score = (1 - ((abs(ret.st_dist) + abs(ret.ed_dist)) / - 2) / self.params.bnddist) * 100 - # I think I'm missing GT stuff here + ret.score = (1 - ((abs(ret.st_dist) + abs(ret.ed_dist)) / 2) + / self.params.bnddist) * 100 + return ret ############################