From 07dd27b5fa196c2ed30d33e5f2f1f8e6bd5f872d Mon Sep 17 00:00:00 2001
From: Adam English
Date: Mon, 6 Jan 2025 21:30:33 -0500
Subject: [PATCH] More API improvements

Matcher methods make more sense inside VariantRecords. Flattened Matcher so
that matcher.params.abc becomes matcher.abc. VariantFile can be sent a
Matcher so that users don't have to tote it around. Some progress on
documentation
---
 docs/api/truvari.examples.rst |  40 ++++++
 repo_utils/run_unittest.py    |   9 +-
 truvari/annotations/chunks.py |   8 +-
 truvari/bench.py              |  51 +++----
 truvari/collapse.py           |  41 +++---
 truvari/matching.py           | 114 ++++-----------
 truvari/refine.py             |   2 +-
 truvari/variants.py           | 264 +++++++++++++++++++++++++++-------
 8 files changed, 330 insertions(+), 199 deletions(-)

diff --git a/docs/api/truvari.examples.rst b/docs/api/truvari.examples.rst
index 058a233b..f22c3a5e 100644
--- a/docs/api/truvari.examples.rst
+++ b/docs/api/truvari.examples.rst
@@ -14,5 +14,45 @@ statements).
     print(entry.info['SVTYPE']) # pysam access to the variant's INFO fields
     print(entry.allele_freq_annos()) # truvari calculation of a variant's allele frequency
+    # pysam GT access
+    if 'GT' in entry.samples['SAMPLE']:
+        print(entry.samples[0]['GT'])
+    # truvari GT access
+    print(entry.gt('SAMPLE'))
+
 Details of all available functions are in :ref:`package documentation `
+
+Besides some helpful accessor functions, the `truvari.VariantRecord` also makes comparing two VCF entries easy.
+
+.. code-block:: python
+
+    # Given two `truvari.VariantRecords`, entry1 and entry2
+    match = entry1.match(entry2)
+    print("Entries' Sequence Similarity:", match.seqsim)
+    print("Entries' Size Similarity:", match.sizesim)
+    print("Is the match above thresholds:", match.state)
+
+This returns a `truvari.MatchResult`. We can customize the thresholds for matching by giving the `truvari.VariantFile` a
+`truvari.Matcher`.
+
+.. code-block:: python
+
+    # Turn off sequence and size similarity; require 50% reciprocal overlap
+    matcher = truvari.Matcher(pctseq=0, pctsize=0, pctovl=0.5)
+    vcf = truvari.VariantFile("input.vcf.gz", matcher=matcher)
+    entry1 = next(vcf)
+    entry2 = next(vcf)
+    match = entry1.match(entry2)
+
+Another useful function is a quick way to filter variants. The `truvari.Matcher` has parameters for e.g. the minimum or
+maximum size of SVs one wants to analyze, which can be leveraged via:
+
+.. code-block:: python
+
+    matcher = truvari.Matcher(sizemin=200, sizemax=500)
+    vcf = truvari.VariantFile("input.vcf.gz", matcher=matcher)
+    # Grab all of the variant records between sizemin and sizemax
+    results = [entry for entry in vcf if not entry.size_filter()]
+
+Additional filtering for things such as monomorphic reference sites or single-end BNDs is available by calling `entry.filter_call()`.
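+
+For example, a minimal filtering loop that combines both checks might look like the following sketch (the
+thresholds shown are illustrative, not defaults):
+
+.. code-block:: python
+
+    matcher = truvari.Matcher(sizemin=50, sizemax=10000, passonly=True)
+    vcf = truvari.VariantFile("input.vcf.gz", matcher=matcher)
+    # Keep records that pass both the call-level and the size-level checks.
+    # Note that filter_call() raises a ValueError on multi-allelic records.
+    keep = [entry for entry in vcf
+            if not entry.filter_call() and not entry.size_filter()]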
diff --git a/repo_utils/run_unittest.py b/repo_utils/run_unittest.py
index d472d4fe..fdb96d11 100644
--- a/repo_utils/run_unittest.py
+++ b/repo_utils/run_unittest.py
@@ -75,13 +75,10 @@
 """
 Filtering logic
 """
-vcf = truvari.VariantFile("repo_utils/test_files/variants/filter.vcf")
-matcher = truvari.Matcher()
-matcher.params.sizemin = 0
-matcher.params.sizefilt = 0
-matcher.params.passonly = True
+matcher = truvari.Matcher(sizemin=0, sizefilt=0, passonly=True)
+vcf = truvari.VariantFile("repo_utils/test_files/variants/filter.vcf", matcher=matcher)
 for entry in vcf:
     try:
-        assert matcher.filter_call(entry), f"Didn't filter {str(entry)}"
+        assert entry.filter_call(), f"Didn't filter {str(entry)}"
     except ValueError as e:
         assert e.args[0].startswith("Cannot compare multi-allelic"), f"Unknown exception {str(entry)}"
diff --git a/truvari/annotations/chunks.py b/truvari/annotations/chunks.py
index 6eabba67..a964ab53 100644
--- a/truvari/annotations/chunks.py
+++ b/truvari/annotations/chunks.py
@@ -49,13 +49,7 @@ def chunks_main(args):
     """
     args = parse_args(args)
     v = truvari.VariantFile(args.input)
-    m = truvari.Matcher()
-    m.params.pctseq = 0
-    m.params.sizemin = args.sizemin
-    m.params.sizefilt = args.sizemin
-    m.params.sizemax = args.sizemax
-    m.params.chunksize = args.chunksize
-    m.params.refdist = args.chunksize
+    m = truvari.Matcher(args=args, pctseq=0)
     if args.bed:
         v = v.bed_fetch(args.bed)
     c = truvari.chunker(m, ('base', v))
diff --git a/truvari/bench.py b/truvari/bench.py
index e89ef5d9..dc083d95 100644
--- a/truvari/bench.py
+++ b/truvari/bench.py
@@ -323,7 +323,7 @@ def __init__(self, bench, matcher):
             os.mkdir(self.m_bench.outdir)
 
         param_dict = self.m_bench.param_dict()
-        param_dict.update(vars(self.m_matcher.params))
+        param_dict.update(vars(self.m_matcher))
 
         if self.m_bench.do_logging:
             truvari.setup_logging(self.m_bench.debug, truvari.LogFileStderr(
@@ -333,8 +333,8 @@ def __init__(self, bench, matcher):
         with open(os.path.join(self.m_bench.outdir, 'params.json'), 'w') as fout:
             json.dump(param_dict, fout)
 
-        b_vcf = truvari.VariantFile(self.m_bench.base_vcf)
-        c_vcf = truvari.VariantFile(self.m_bench.comp_vcf)
+        b_vcf = truvari.VariantFile(self.m_bench.base_vcf, matcher=self.m_matcher)
+        c_vcf = truvari.VariantFile(self.m_bench.comp_vcf, matcher=self.m_matcher)
         self.n_headers = {'b': edit_header(b_vcf),
                           'c': edit_header(c_vcf)}
@@ -345,14 +345,14 @@ def __init__(self, bench, matcher):
         self.out_vcfs = {}
         for key in ['tpb', 'fn']:
             self.out_vcfs[key] = truvari.VariantFile(
-                self.vcf_filenames[key], mode='w', header=self.n_headers['b'])
+                self.vcf_filenames[key], mode='w', header=self.n_headers['b'], matcher=self.m_matcher)
         for key in ['tpc', 'fp']:
             self.out_vcfs[key] = truvari.VariantFile(
-                self.vcf_filenames[key], mode='w', header=self.n_headers['c'])
+                self.vcf_filenames[key], mode='w', header=self.n_headers['c'], matcher=self.m_matcher)
 
         self.stats_box = StatsBox()
 
-    def write_match(self, match, resolved=False):
+    def write_match(self, match):
         """
         Annotate a MatchResults' entries then write to the apppropriate file
         and do the stats counting.
@@ -368,30 +368,30 @@ def write_match(self, match, resolved=False): box["gt_matrix"][gtBase][gtComp] += 1 box["TP-base"] += 1 - self.out_vcfs["tpb"].write(match.base, resolved) + self.out_vcfs["tpb"].write(match.base) if match.gt_match == 0: box["TP-base_TP-gt"] += 1 else: box["TP-base_FP-gt"] += 1 else: box["FN"] += 1 - self.out_vcfs["fn"].write(match.base, resolved) + self.out_vcfs["fn"].write(match.base) if match.comp: annotate_entry(match.comp, match, self.n_headers['c']) if match.state: box["comp cnt"] += 1 box["TP-comp"] += 1 - self.out_vcfs["tpc"].write(match.comp, resolved) + self.out_vcfs["tpc"].write(match.comp) if match.gt_match == 0: box["TP-comp_TP-gt"] += 1 else: box["TP-comp_FP-gt"] += 1 - elif match.comp.size() >= self.m_matcher.params.sizemin: + elif match.comp.size() >= self.m_matcher.sizemin: # The if is because we don't count FPs between sizefilt-sizemin box["comp cnt"] += 1 box["FP"] += 1 - self.out_vcfs["fp"].write(match.comp, resolved) + self.out_vcfs["fp"].write(match.comp) def close_outputs(self): """ @@ -422,8 +422,7 @@ class Bench(): .. code-block:: python - matcher = truvari.Matcher() - matcher.params.pctseq = 0.50 + matcher = truvari.Matcher(pctseq=0.50) m_bench = truvari.Bench(matcher) To run on a chunk of :class:`truvari.VariantRecord` already loaded, pass them in as lists to: @@ -451,10 +450,8 @@ class Bench(): However, the returned `BenchOutput` has attributes pointing to all the results. """ - #pylint: disable=too-many-arguments def __init__(self, matcher=None, base_vcf=None, comp_vcf=None, outdir=None, - includebed=None, extend=0, debug=False, do_logging=False, short_circuit=False, - write_resolved=False): + includebed=None, extend=0, debug=False, do_logging=False, short_circuit=False): """ Initilize """ @@ -468,8 +465,6 @@ def __init__(self, matcher=None, base_vcf=None, comp_vcf=None, outdir=None, self.do_logging = do_logging self.short_circuit = short_circuit self.refine_candidates = [] - self.write_resolved = write_resolved - #pylint: enable=too-many-arguments def param_dict(self): """ @@ -492,8 +487,8 @@ def run(self): output = BenchOutput(self, self.matcher) - base = truvari.VariantFile(self.base_vcf) - comp = truvari.VariantFile(self.comp_vcf) + base = truvari.VariantFile(self.base_vcf, matcher=self.matcher) + comp = truvari.VariantFile(self.comp_vcf, matcher=self.matcher) region_tree = truvari.build_region_tree(base, comp, self.includebed) truvari.merge_region_tree_overlaps(region_tree) @@ -513,7 +508,7 @@ def run(self): and not match.state and not match.comp.within_tree(region_tree)): match.comp = None - output.write_match(match, self.write_resolved) + output.write_match(match) with open(os.path.join(self.outdir, 'candidate.refine.bed'), 'w') as fout: fout.write("\n".join(self.refine_candidates)) @@ -531,7 +526,7 @@ def compare_chunk(self, chunk): chunk_dict["base"], chunk_dict["comp"], chunk_id) self.check_refine_candidate(result) # Check BNDs separately - if self.matcher.params.bnddist != -1 and (chunk_dict['base_BND'] or chunk_dict['comp_BND']): + if self.matcher.bnddist != -1 and (chunk_dict['base_BND'] or chunk_dict['comp_BND']): result.extend(self.compare_calls(chunk_dict['base_BND'], chunk_dict['comp_BND'], chunk_id)) return result @@ -584,7 +579,7 @@ def compare_calls(self, base_variants, comp_variants, chunk_id=0): base_variants, comp_variants, chunk_id) if isinstance(match_matrix, list): return match_matrix - return PICKERS[self.matcher.params.pick](match_matrix) + return PICKERS[self.matcher.pick](match_matrix) def build_matrix(self, 
base_variants, comp_variants, chunk_id=0, skip_gt=False): """ @@ -597,8 +592,7 @@ def build_matrix(self, base_variants, comp_variants, chunk_id=0, skip_gt=False): for bid, base in enumerate(base_variants): base_matches = [] for cid, comp in enumerate(comp_variants): - mat = base.match(comp, matcher=self.matcher, - skip_gt=skip_gt, short_circuit=self.short_circuit) + mat = base.match(comp, skip_gt=skip_gt, short_circuit=self.short_circuit) mat.matid = [f"{chunk_id}.{bid}", f"{chunk_id}.{cid}"] logging.debug("Made mat -> %s", mat) base_matches.append(mat) @@ -615,14 +609,14 @@ def check_refine_candidate(self, result): chrom = None for match in result: has_unmatched |= not match.state - if match.base is not None and match.base.size() >= self.matcher.params.sizemin: + if match.base is not None and match.base.size() >= self.matcher.sizemin: chrom = match.base.chrom pos.extend(match.base.boundaries()) if match.comp is not None: chrom = match.comp.chrom pos.extend(match.comp.boundaries()) if has_unmatched and pos: - # min(10, self.matcher.params.chunksize) need to make sure the refine covers the region + # min(10, self.matcher.chunksize) need to make sure the refine covers the region buf = 10 start = max(0, min(*pos) - buf) self.refine_candidates.append( @@ -783,8 +777,7 @@ def bench_main(cmdargs): extend=args.extend, debug=args.debug, do_logging=True, - short_circuit=args.short, - write_resolved=args.write_resolved) + short_circuit=args.short) output = m_bench.run() logging.info("Stats: %s", json.dumps(output.stats_box, indent=4)) diff --git a/truvari/collapse.py b/truvari/collapse.py index 59d883a9..ef1949d1 100644 --- a/truvari/collapse.py +++ b/truvari/collapse.py @@ -95,7 +95,7 @@ def combine(self, other): self.genotype_mask |= other.genotype_mask -def chain_collapse(cur_collapse, all_collapse, matcher): +def chain_collapse(cur_collapse, all_collapse): """ Perform transitive matching of cur_collapse to all_collapse Check the cur_collapse's entry to all other collapses' consolidated entries @@ -103,7 +103,6 @@ def chain_collapse(cur_collapse, all_collapse, matcher): for m_collap in all_collapse: for other in m_collap.matches: mat = cur_collapse.entry.match(other.comp, - matcher=matcher, skip_gt=True, short_circuit=True) mat.matid = m_collap.match_id @@ -138,7 +137,6 @@ def collapse_chunk(chunk, matcher): # Sort based on size difference to current call for candidate in sorted(remaining_calls, key=partial(relative_size_sorter, m_collap.entry)): mat = m_collap.entry.match(candidate, - matcher=matcher, skip_gt=True, short_circuit=True) mat.matid = m_collap.match_id @@ -148,13 +146,13 @@ def collapse_chunk(chunk, matcher): mat.state = False if mat.state: m_collap.matches.append(mat) - elif mat.sizesim is not None and mat.sizesim < matcher.params.pctsize: + elif mat.sizesim is not None and mat.sizesim < matcher.pctsize: # Can we do this? The sort tells us that we're going through most->least # similar size. So the next one will only be worse... break # Does this collap need to go into a previous collap? 
- if not matcher.chain or not chain_collapse(m_collap, ret, matcher): + if not matcher.chain or not chain_collapse(m_collap, ret): ret.append(m_collap) # If hap, only allow the best match @@ -640,12 +638,12 @@ def consolidate(self, entry, write_resolved=False): entry = n_entry return "\t".join(str(entry).split('\t')[:10]) + '\n' - def write(self, entry, write_resolved=False): + def write(self, entry): """ Writes header (str) or entries (VariantRecords) """ if isinstance(entry, truvari.VariantRecord): - entry = self.consolidate(entry, write_resolved) + entry = self.consolidate(entry, entry.matcher.write_resolved) if self.isgz: entry = entry.encode() self.fh.write(entry) @@ -662,14 +660,14 @@ class CollapseOutput(dict): Output writer for collapse """ - def __init__(self, args): + def __init__(self, args, matcher): """ Makes all of the output files for collapse """ super().__init__() logging.info("Params:\n%s", json.dumps(vars(args), indent=4)) - in_vcf = truvari.VariantFile(args.input) + in_vcf = truvari.VariantFile(args.input, matcher=matcher) self["o_header"] = edit_header(in_vcf, args.median_info) self["c_header"] = trubench.edit_header(in_vcf) num_samps = len(self["o_header"].samples) @@ -681,12 +679,11 @@ def __init__(self, args): self["output_vcf"] = IntraMergeOutput(args.output, self["o_header"]) else: self["output_vcf"] = truvari.VariantFile(args.output, 'w', - header=self["o_header"]) + header=self["o_header"], matcher=matcher) self["collap_vcf"] = truvari.VariantFile(args.removed_output, 'w', - header=self["c_header"]) + header=self["c_header"], matcher=matcher) self["stats_box"] = {"collap_cnt": 0, "kept_cnt": 0, "out_cnt": 0, "consol_cnt": 0} - self.write_resolved = args.write_resolved def write(self, collap, median_info=False): """ @@ -696,17 +693,17 @@ def write(self, collap, median_info=False): self["stats_box"]["out_cnt"] += 1 # Nothing collapsed, no need to annotate if not collap.matches: - self["output_vcf"].write(collap.entry, self.write_resolved) + self["output_vcf"].write(collap.entry) return collap.annotate_entry(self["o_header"], median_info) - self["output_vcf"].write(collap.entry, self.write_resolved) + self["output_vcf"].write(collap.entry) self["stats_box"]["kept_cnt"] += 1 self["stats_box"]["consol_cnt"] += collap.gt_consolidate_count for match in collap.matches: trubench.annotate_entry(match.comp, match, self["c_header"]) - self["collap_vcf"].write(match.comp, self.write_resolved) + self["collap_vcf"].write(match.comp) self['stats_box']["collap_cnt"] += 1 def close(self): @@ -817,8 +814,8 @@ def tree_size_chunker(matcher, chunks): for entry in chunk['base']: # How much smaller/larger would be in sizesim? 
            sz = entry.size()
-            diff = sz * (1 - matcher.params.pctsize)
-            if not matcher.params.typeignore:
+            diff = sz * (1 - matcher.pctsize)
+            if not matcher.typeignore:
                 sz *= -1 if entry.var_type() == truvari.SV.DEL else 1
             to_add.append((sz - diff, sz + diff, LinkedList(entry)))
         tree = merge_intervals(to_add)
@@ -844,8 +841,8 @@ def tree_dist_chunker(matcher, chunks):
         to_add = []
         for entry in chunk['base']:
             st, ed = entry.boundaries()
-            st -= matcher.params.refdist
-            ed += matcher.params.refdist
+            st -= matcher.refdist
+            ed += matcher.refdist
             to_add.append((st, ed, LinkedList(entry)))
         tree = merge_intervals(to_add)
         for intv in tree:
@@ -872,7 +869,7 @@ def collapse_main(args):
     args.sizefilt = args.sizemin
     args.no_ref = False
     matcher = truvari.Matcher(args=args)
-    matcher.params.includebed = None
+    matcher.includebed = None
     matcher.keep = args.keep
     matcher.hap = args.hap
     matcher.gt = args.gt
@@ -881,7 +878,7 @@ def collapse_main(args):
     matcher.no_consolidate = args.no_consolidate
     matcher.picker = 'single'
 
-    base = truvari.VariantFile(args.input)
+    base = truvari.VariantFile(args.input, matcher=matcher)
     regions = truvari.build_region_tree(base, includebed=args.bed)
     truvari.merge_region_tree_overlaps(regions)
     base_i = base.regions_fetch(regions)
@@ -890,7 +887,7 @@ def collapse_main(args):
     smaller_chunks = tree_size_chunker(matcher, chunks)
     even_smaller_chunks = tree_dist_chunker(matcher, smaller_chunks)
 
-    outputs = CollapseOutput(args)
+    outputs = CollapseOutput(args, matcher)
     m_collap = partial(collapse_chunk, matcher=matcher)
     for call in itertools.chain.from_iterable(map(m_collap, even_smaller_chunks)):
         outputs.write(call, args.median_info)
diff --git a/truvari/matching.py b/truvari/matching.py
index b15c7659..6b750cbe 100644
--- a/truvari/matching.py
+++ b/truvari/matching.py
@@ -7,6 +7,7 @@
 from functools import total_ordering
 
 import pysam
+
 @total_ordering
 class MatchResult():  # pylint: disable=too-many-instance-attributes
     """
@@ -67,28 +68,33 @@ class Matcher():
     Example
         >>> import truvari
-        >>> mat = truvari.Matcher()
-        >>> mat.params.pctseq = 0
-        >>> v = truvari.VariantFile('repo_utils/test_files/variants/input1.vcf.gz')
+        >>> mat = truvari.Matcher(pctseq=0)
+        >>> v = truvari.VariantFile('repo_utils/test_files/variants/input1.vcf.gz', matcher=mat)
         >>> one = next(v); two = next(v)
-        >>> mat.build_match(one, two)
+        >>> one.match(two)
 
     Look at `Matcher.make_match_params()` for a list of all params and their defaults
     """
 
-    def __init__(self, args=None):
+    def __init__(self, args=None, **kwargs):
         """
         Initalize.
args is a Namespace from argparse """ if args is not None: - self.params = self.make_match_params_from_args(args) + params = self.make_match_params_from_args(args) else: - self.params = self.make_match_params() + params = self.make_match_params() + + # Override parameters with those provided in kwargs + for key, value in kwargs.items(): + if hasattr(params, key): + setattr(params, key, value) + else: + raise ValueError(f"Invalid parameter: {key}") - self.reference = None - if self.params.reference is not None: - self.reference = pysam.FastaFile(self.params.reference) + for key, value in params.__dict__.items(): + setattr(self, key, value) @staticmethod def make_match_params(): @@ -117,83 +123,23 @@ def make_match_params(): params.ignore_monref = True params.check_multi = True params.check_monref = True + params.no_single_bnd = True + params.write_resolved = False return params @staticmethod def make_match_params_from_args(args): """ - Makes a simple namespace of matching parameters - """ - ret = types.SimpleNamespace() - ret.reference = args.reference - ret.refdist = args.refdist - ret.pctseq = args.pctseq - ret.pctsize = args.pctsize - ret.pctovl = args.pctovl - ret.typeignore = args.typeignore - ret.no_roll = args.no_roll - ret.chunksize = args.chunksize - ret.bSample = args.bSample if args.bSample else 0 - ret.cSample = args.cSample if args.cSample else 0 - ret.dup_to_ins = args.dup_to_ins if "dup_to_ins" in args else False - ret.bnddist = args.bnddist if 'bnddist' in args else -1 - # filtering properties - ret.sizemin = args.sizemin - ret.sizefilt = args.sizefilt - ret.sizemax = args.sizemax - ret.passonly = args.passonly - ret.no_ref = args.no_ref - ret.pick = args.pick if "pick" in args else "single" - ret.check_monref = True - ret.check_multi = True - return ret - - def filter_call(self, entry, base=False): - """ - Returns True if the call should be filtered based on parameters or truvari requirements - Base has different filtering requirements, so let the method know + Makes a simple namespace of matching parameters. + Populates defaults from make_match_params, then updates with values from args. """ - if self.params.check_monref and entry.is_monrefstar(): - return True - - if self.params.check_multi and entry.is_multi(): - raise ValueError( - f"Cannot compare multi-allelic records. 
Please split\nline {str(entry)}") - - if self.params.passonly and entry.is_filtered(): - return True + ret = Matcher.make_match_params() - prefix = 'b' if base else 'c' - if self.params.no_ref in ["a", prefix] or self.params.pick == 'ac': - samp = self.params.bSample if base else self.params.cSample - if not entry.is_present(samp): - return True + for key in vars(ret): + if hasattr(args, key): + setattr(ret, key, getattr(args, key)) - # No single end BNDs - return entry.is_single_bnd() - - def size_filter(self, entry, base=False): - """ - Returns True if entry should be filtered due to its size - """ - size = entry.size() - return (size > self.params.sizemax) \ - or (base and size < self.params.sizemin) \ - or (not base and size < self.params.sizefilt) - - def compare_gts(self, match, base, comp): - """ - Given a MatchResult, populate the genotype specific comparisons in place - """ - b_gt = base.gt(self.params.bSample) - c_gt = comp.gt(self.params.cSample) - if b_gt: - match.base_gt = b_gt - match.base_gt_count = sum(1 for _ in match.base_gt if _ == 1) - if c_gt: - match.comp_gt = c_gt - match.comp_gt_count = sum(1 for _ in match.comp_gt if _ == 1) - match.gt_match = abs(match.base_gt_count - match.comp_gt_count) + return ret ############################ # Parsing and set building # @@ -252,20 +198,22 @@ def chunker(matcher, *files): cur_end = 0 cur_chunk = defaultdict(list) unresolved_warned = False + reference = pysam.FastaFile(matcher.reference) if matcher.reference is not None else None + for key, entry in file_zipper(*files): - if matcher.filter_call(entry, key == 'base'): + if entry.filter_call(key == 'base'): cur_chunk['__filtered'].append(entry) call_counts['__filtered'] += 1 continue - if not entry.is_bnd() and matcher.size_filter(entry, key == 'base'): + if not entry.is_bnd() and entry.size_filter(key == 'base'): cur_chunk['__filtered'].append(entry) call_counts['__filtered'] += 1 continue # check symbolic, resolve if needed/possible - if matcher.params.pctseq != 0 and entry.alts[0].startswith('<'): - was_resolved = entry.resolve(matcher.reference, matcher.params.dup_to_ins) + if matcher.pctseq != 0 and entry.alts[0].startswith('<'): + was_resolved = entry.resolve(reference) if not was_resolved: if not unresolved_warned: logging.warning("Some symbolic SVs couldn't be resolved") @@ -275,7 +223,7 @@ def chunker(matcher, *files): continue new_chrom = cur_chrom and entry.chrom != cur_chrom - new_chunk = cur_end and cur_end + matcher.params.chunksize < entry.start + new_chunk = cur_end and cur_end + matcher.chunksize < entry.start if new_chunk or new_chrom: chunk_count += 1 yield cur_chunk, chunk_count diff --git a/truvari/refine.py b/truvari/refine.py index f6317b9a..f0b8746f 100644 --- a/truvari/refine.py +++ b/truvari/refine.py @@ -421,7 +421,7 @@ def refine_main(cmdargs): # Now run bench on the phab harmonized variants logging.info("Running bench") matcher = truvari.Matcher(params) - matcher.params.no_ref = 'a' + matcher.no_ref = 'a' outdir = os.path.join(args.benchdir, "phab_bench") m_bench = truvari.Bench(matcher=matcher, base_vcf=phab_vcf, comp_vcf=phab_vcf, outdir=outdir, includebed=reeval_bed, short_circuit=True) diff --git a/truvari/variants.py b/truvari/variants.py index c2904ef0..1c261fd7 100644 --- a/truvari/variants.py +++ b/truvari/variants.py @@ -15,62 +15,109 @@ class VariantFile: """ - Wrapper around pysam.VariantFile with helper functions for iteration - Note: The context manager functionality of pysam.VariantFile is not available with truvari.VariantFile + Wrapper 
around pysam.VariantFile with helper functions for iteration. + + .. note:: + The context manager functionality of pysam.VariantFile is not available with truvari.VariantFile. """ - def __init__(self, filename, *args, **kwargs): + def __init__(self, filename, *args, matcher=None, **kwargs): + """ + Initialize the VariantFile wrapper. + + :param filename: Path to the VCF file to be opened. + :type filename: str + :param matcher: Matcher to apply to all VariantRecords + :type matcher: `truvari.Matcher` + :param args: Additional positional arguments to pass to pysam.VariantFile. + :param kwargs: Additional keyword arguments to pass to pysam.VariantFile. + """ + self.matcher = matcher self._vcf = pysam.VariantFile(filename, *args, **kwargs) def __getattr__(self, name): """ - Delegate attribute access to the original VariantFile + Delegate attribute access to the original VariantFile. + + :param name: Attribute name to access from the underlying pysam.VariantFile object. + :type name: str + :return: The requested attribute from the underlying pysam.VariantFile. + :rtype: Any """ return getattr(self._vcf, name) def __iter__(self): """ - Iterate the VariantFile, wrapping into truvari VariantRecords + Iterate over the `pysam.VariantFile`, wrapping entries into `truvari.VariantRecord`. + + :return: Iterator of truvari.VariantRecord objects. + :rtype: iterator """ for i in self._vcf: - yield VariantRecord(i) + yield VariantRecord(i, self.matcher) def __next__(self): """ - Return the next + Return the next truvari.VariantRecord in the VariantFile. + + :return: Next truvari.VariantRecord. + :rtype: truvari.VariantRecord """ - return VariantRecord(next(self._vcf)) + return VariantRecord(next(self._vcf), self.matcher) def fetch(self, *args, **kwargs): """ - Fetch from the VariantFile, wrapping into truvari VariantRecords + Fetch variants from the `pysam.VariantFile`, wrapping them into `truvari.VariantRecords`. + + :param args: Positional arguments for the pysam.VariantFile.fetch method. + :param kwargs: Keyword arguments for the pysam.VariantFile.fetch method. + :return: Iterator of truvari.VariantRecord objects. + :rtype: iterator """ + for i in self._vcf.fetch(*args, **kwargs): - yield truvari.VariantRecord(i) + yield truvari.VariantRecord(i, self.matcher) def regions_fetch(self, tree, inside=True, with_region=False): """ - Given a tree of chrom:IntervalTree, fetch variants from the VCF - that are inside (or outside) the tree's regions - with_regions will return tuples of the (variant, region) + Fetch variants from the VCF based on regions defined in a tree of chrom:IntervalTree. + + :param tree: Tree of chrom:IntervalTree defining regions of interest. + :type tree: dict + :param inside: If True, fetch variants inside the regions. If False, fetch variants outside the regions. + :type inside: bool + :param with_region: If True, return tuples of (`truvari.VariantRecord`, region). Defaults to False. + :type with_region: bool + :return: Iterator of truvari.VariantRecord objects or tuples of (`truvari.VariantRecord`, region). + :rtype: iterator """ return region_filter(self, tree, inside, with_region) def bed_fetch(self, bed_fn, inside=True, with_region=False): """ - Given a bed file, iterate variants inside or outside the regions. - with_regions will return tuples of the (variant, region) + Fetch variants from the VCF based on regions defined in a BED file. + + :param bed_fn: Path to the BED file defining regions of interest. 
+ :type bed_fn: str + :param inside: If True, fetch variants inside the regions. If False, fetch variants outside the regions. + :type inside: bool + :param with_region: If True, return tuples of (`truvari.VariantRecord`, region). Defaults to False. + :type with_region: bool + :return: Iterator of truvari.VariantRecord objects or tuples of (`truvari.VariantRecord`, region). + :rtype: iterator """ tree = truvari.read_bed_tree(bed_fn) return self.regions_fetch(tree, inside, with_region) - def write(self, record, resolved=False): + def write(self, record): """ - Pull pysam VarianRecord out of truvari VariantRecord before writing - If resolved, replace the record's REF and ALT with self.get_ref() self.get_alt() + Write a `truvari.VariantRecord` to the `pysam.VariantFile`. + + :param record: The truvari.VariantRecord to be written. + :type record: `truvari.VariantRecord` """ out = record.get_record() - if resolved: + if self.matcher and self.matcher.write_resolved: out.ref = record.get_ref() out.alts = (record.get_alt(),) self._vcf.write(out) @@ -81,7 +128,7 @@ class VariantRecord: Wrapper around pysam.VariantRecords with helper functions of variant properties and basic comparisons """ - def __init__(self, record): + def __init__(self, record, matcher=None): """ Initialize with just the internal record """ @@ -89,6 +136,10 @@ def __init__(self, record): self.resolved_ref = None self.resolved_alt = None self.end = record.stop + if matcher is None: + self.matcher = truvari.Matcher() + else: + self.matcher = matcher def __getattr__(self, name): """ @@ -128,11 +179,19 @@ def allele_freq_annos(self, samples=None): def bnd_direction_strand(self): """ Parses a BND ALT string to determine its direction and strand. - A direction of 'left' means the piece is anchored on the left of the breakpoint - Note that this method assumes `self.is_bnd()` - Returns: - tuple: A tuple containing the direction ("left" or "right") and strand ("direct" or "complement"). + A BND (breakend) ALT string indicates a structural variant breakpoint. This method parses the ALT string to determine: + + - The direction: "left" means the piece is anchored on the left side of the breakpoint, while "right" means it's anchored on the right. + - The strand: "direct" indicates the base is on the direct strand, and "complement" indicates the base is on the complement strand. + + .. note:: + This method assumes that `self.is_bnd()` is `True`, meaning the variant is a BND-type structural variant. + + :return: A tuple containing the direction ("left" or "right") and the strand ("direct" or "complement"). + :rtype: tuple (str, str) + + :raises ValueError: If the ALT string does not follow the expected BND format. """ bnd = self.alts[0] if bnd.startswith('[') or bnd.endswith('['): @@ -156,8 +215,14 @@ def bnd_position(self): """ Extracts the chromosome and position from a BND ALT string. - Returns: - tuple: A tuple containing the chromosome (str) and position (int). + Breakend (BND) ALT strings indicate structural variant breakpoints that span across chromosomes or positions. + This method parses the ALT string to extract the target chromosome and position of the breakpoint. + + :return: A tuple containing the chromosome (as a string) and the position (as an integer). + :rtype: tuple (str, int) + + :raises ValueError: If the ALT string does not follow the expected BND format. 
+ """ # Regular expression to match the BND format and extract chrom:pos match = re.search(r'[\[\]]([^\[\]:]+):(\d+)[\[\]]', self.alts[0]) @@ -169,19 +234,16 @@ def bnd_position(self): return chrom, pos - def bnd_match(self, other, matcher=None, **_kwargs): + def bnd_match(self, other, **_kwargs): """ Build a MatchResult for bnds """ - if matcher is None: - matcher = truvari.Matcher() - def bounds(entry, pos, key): """ Inflate a bnd position based on CIPOS. """ - start = pos - matcher.params.bnddist - end = pos + matcher.params.bnddist + start = pos - self.matcher.bnddist + end = pos + self.matcher.bnddist key = 'CI' + key idx = 0 if key == 'POS' else 1 @@ -232,12 +294,12 @@ def bounds(entry, pos, key): str(self), str(other)) return ret - matcher.compare_gts(ret, self, other) + self.compare_gts(other, ret) # Score is percent of allowed distance needed to find this match - if matcher.params.bnddist > 0: + if self.matcher.bnddist > 0: ret.score = max(0, (1 - ((abs(ret.st_dist) + abs(ret.ed_dist)) / 2) - / matcher.params.bnddist) * 100) + / self.matcher.bnddist) * 100) else: ret.score = int(ret.state) * 100 @@ -396,7 +458,7 @@ def get_record(self): """ return self._record - def match(self, other, matcher=None, skip_gt=False, short_circuit=False): + def match(self, other, skip_gt=False, short_circuit=False): """ Build a MatchResult from comparison of two VariantRecords If self and other are non-bnd, calls VariantRecord.var_match, @@ -404,12 +466,10 @@ def match(self, other, matcher=None, skip_gt=False, short_circuit=False): Otherwise, raises a TypeError. If no matcher is provided, builds one from defaults """ - if matcher is None: - matcher = truvari.Matcher() if not self.is_bnd() and not other.is_bnd(): - return self.var_match(other, matcher, skip_gt, short_circuit) + return self.var_match(other, skip_gt, short_circuit) if self.is_bnd() and other.is_bnd(): - return self.bnd_match(other, matcher) + return self.bnd_match(other) raise TypeError("Incompatible Variants (BND and !BND) can't be matched") def move_record(self, out_vcf, sample=None): @@ -449,7 +509,7 @@ def recovl(self, other, ins_inflate=True): bstart, bend = other.boundaries(ins_inflate) return truvari.reciprocal_overlap(astart, aend, bstart, bend) - def resolve(self, ref, dup_to_ins=False): + def resolve(self, ref): """ Attempts to resolve an SV's REF/ALT sequences. 
        Stores in self.resolved_ref, self.resolved_alt, and self.end
        """
@@ -464,7 +524,7 @@
         elif svtype == truvari.SV.INV:
             self.resolved_ref = seq
             self.resolved_alt = seq.translate(RC)[::-1]
-        elif svtype == truvari.SV.DUP and dup_to_ins:
+        elif svtype == truvari.SV.DUP and self.matcher.dup_to_ins:
             self.resolved_ref = seq[0]
             self.resolved_alt = seq
             self.end = self.start + 1
@@ -594,7 +654,7 @@ def sizesim(self, other):
         """
         return truvari.sizesim(self.size(), other.size())
 
-    def var_match(self, other, matcher=None, skip_gt=False, short_circuit=False):
+    def var_match(self, other, skip_gt=False, short_circuit=False):
         """
         Build a MatchResult
         if skip_gt, don't do genotype comparison
@@ -606,7 +666,7 @@ def var_match(self, other, matcher=None, skip_gt=False, short_circuit=False):
         ret.state = True
 
-        if not matcher.params.typeignore and not self.same_type(other, matcher.params.dup_to_ins):
+        if not self.matcher.typeignore and not self.same_type(other):
             logging.debug("%s and %s are not the same SVTYPE", str(self), str(other))
             ret.state = False
@@ -615,7 +675,7 @@ def var_match(self, other, matcher=None, skip_gt=False, short_circuit=False):
         bstart, bend = self.boundaries()
         cstart, cend = other.boundaries()
-        if not truvari.overlaps(bstart - matcher.params.refdist, bend + matcher.params.refdist, cstart, cend):
+        if not truvari.overlaps(bstart - self.matcher.refdist, bend + self.matcher.refdist, cstart, cend):
             logging.debug("%s and %s are not within REFDIST", str(self), str(other))
             ret.state = False
@@ -623,7 +683,7 @@ def var_match(self, other, matcher=None, skip_gt=False, short_circuit=False):
                 return ret
 
         ret.sizesim, ret.sizediff = self.sizesim(other)
-        if ret.sizesim < matcher.params.pctsize:
+        if ret.sizesim < self.matcher.pctsize:
             logging.debug("%s and %s size similarity is too low (%.3f)", str(self), str(other), ret.sizesim)
             ret.state = False
@@ -631,19 +691,19 @@ def var_match(self, other, matcher=None, skip_gt=False, short_circuit=False):
                 return ret
 
         if not skip_gt:
-            matcher.compare_gts(ret, self, other)
+            self.compare_gts(other, ret)
 
         ret.ovlpct = self.recovl(other)
-        if ret.ovlpct < matcher.params.pctovl:
+        if ret.ovlpct < self.matcher.pctovl:
             logging.debug("%s and %s overlap percent is too low (%.3f)", str(self), str(other), ret.ovlpct)
             ret.state = False
             if short_circuit:
                 return ret
 
-        if matcher.params.pctseq > 0:
-            ret.seqsim = self.seqsim(other, matcher.params.no_roll)
-            if ret.seqsim < matcher.params.pctseq:
+        if self.matcher.pctseq > 0:
+            ret.seqsim = self.seqsim(other)
+            if ret.seqsim < self.matcher.pctseq:
                 logging.debug("%s and %s sequence similarity is too low (%.3ff)", str(self), str(other), ret.seqsim)
                 ret.state = False
@@ -760,3 +820,105 @@ def within(self, rstart, rend):
         qstart, qend = self.boundaries()
         end_within = self.var_type() != truvari.SV.INS
         return truvari.coords_within(qstart, qend, rstart, rend, end_within)
+
+    def filter_call(self, base=False):
+        """
+        Determines whether a variant call should be filtered based on Truvari parameters or specific requirements.
+
+        This method evaluates the variant and checks if it should be excluded from further processing
+        based on filtering criteria such as monomorphic reference, multi-allelic records, filtering status,
+        sample presence, or unsupported single-end BNDs.
+
+        :param base: A flag indicating whether the entry is the "base" (reference) call or the "comparison" call.
+            Filtering behavior may differ based on this flag.
+        :type base: bool, optional
+
+        :return: `True` if the variant should be filtered (excluded), otherwise `False`.
+        :rtype: bool
+
+        :raises ValueError: If the entry is multi-allelic and `check_multi` is enabled in the Truvari parameters.
+
+        Filtering Logic:
+        - **Monomorphic Reference:** If `check_monref` is enabled and the entry is a monomorphic reference, it is filtered.
+        - **Multi-Allelic Records:** If `check_multi` is enabled and the entry is multi-allelic, an error is raised.
+        - **Filtered Variants:** If `passonly` is enabled and the entry is flagged as filtered, it is excluded.
+        - **Sample Presence:** If `no_ref` is set to include the entry's type (base or comparison) or `pick == 'ac'`, the sample must be present in the entry.
+        - **Single-End BNDs:** Single-end BNDs are always excluded.
+        """
+        if self.matcher.check_monref and self.is_monrefstar():
+            return True
+
+        if self.matcher.check_multi and self.is_multi():
+            raise ValueError(
+                f"Cannot compare multi-allelic records. Please split\nline {str(self)}")
+
+        if self.matcher.passonly and self.is_filtered():
+            return True
+
+        prefix = 'b' if base else 'c'
+        if self.matcher.no_ref in ["a", prefix] or self.matcher.pick == 'ac':
+            samp = self.matcher.bSample if base else self.matcher.cSample
+            if not self.is_present(samp):
+                return True
+
+        if self.matcher.no_single_bnd and self.is_single_bnd():
+            return True
+        return False
+
+    def size_filter(self, base=False):
+        """
+        Determines whether a variant entry should be filtered based on its size.
+
+        This method evaluates the size of the variant and checks if it falls outside the specified size thresholds.
+        Filtering criteria depend on whether the entry is a "base" (reference) or "comparison" call.
+
+        :param base: A flag indicating whether the entry is the "base" (reference) call. If `True`, the `sizemin`
+            parameter is used as the minimum size threshold. Otherwise, the `sizefilt` parameter is used.
+        :type base: bool, optional
+
+        :return: `True` if the variant should be filtered due to its size, otherwise `False`.
+        :rtype: bool
+
+        Filtering Logic:
+        - **Maximum Size:** If the size exceeds `sizemax`, the variant is filtered.
+        - **Minimum Size (Base):** If `base=True` and the size is less than `sizemin`, the variant is filtered.
+        - **Minimum Size (Comparison):** If `base=False` and the size is less than `sizefilt`, the variant is filtered.
+        """
+        size = self.size()
+        return (size > self.matcher.sizemax) \
+            or (base and size < self.matcher.sizemin) \
+            or (not base and size < self.matcher.sizefilt)
+
+    def compare_gts(self, other, match):
+        """
+        Populates genotype-specific comparison details in a `MatchResult`.
+
+        This method compares the genotypes of a "base" (reference) variant and a "comparison" variant and updates
+        the provided `MatchResult` object with the results. It computes the genotype counts and determines the
+        difference between the base and comparison genotypes.
+
+        :param other: The other variant entry.
+        :type other: `truvari.VariantRecord`
+        :param match: The `MatchResult` object to update with genotype comparison details.
+        :type match: truvari.MatchResult
+
+        Updates the `MatchResult` object with the following attributes:
+        - `base_gt`: The genotype of the base sample.
+        - `base_gt_count`: The count of alternate alleles (1s) in the base genotype.
+        - `comp_gt`: The genotype of the comparison sample.
+        - `comp_gt_count`: The count of alternate alleles (1s) in the comparison genotype.
+        - `gt_match`: The absolute difference between `base_gt_count` and `comp_gt_count`.
+        """
+        b_gt = self.gt(self.matcher.bSample)
+        c_gt = other.gt(self.matcher.cSample)
+        if b_gt:
+            match.base_gt = b_gt
+            match.base_gt_count = sum(1 for _ in match.base_gt if _ == 1)
+        if c_gt:
+            match.comp_gt = c_gt
+            match.comp_gt_count = sum(1 for _ in match.comp_gt if _ == 1)
+        match.gt_match = abs(match.base_gt_count - match.comp_gt_count)
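
As a quick usage sketch of the genotype comparison above (the VCF path is a placeholder, and `bSample`/`cSample`
default to the first sample in each VCF):

.. code-block:: python

    import truvari

    vcf = truvari.VariantFile("input.vcf.gz", matcher=truvari.Matcher(pctseq=0))
    entry1 = next(vcf)
    entry2 = next(vcf)
    match = entry1.match(entry2)               # var_match() runs compare_gts() unless skip_gt=True
    print(match.base_gt, match.base_gt_count)  # genotype and alt-allele count for the base sample
    print(match.comp_gt, match.comp_gt_count)  # genotype and alt-allele count for the comparison sample
    print(match.gt_match)                      # absolute difference of the alt-allele counts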
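
BND comparisons follow the same record-centric pattern; a sketch, assuming a breakend-only VCF (the file name and
`bnddist` value here are illustrative):

.. code-block:: python

    import truvari

    matcher = truvari.Matcher(bnddist=100)
    vcf = truvari.VariantFile("bnd_calls.vcf.gz", matcher=matcher)
    b1 = next(vcf)
    b2 = next(vcf)
    if b1.is_bnd() and b2.is_bnd():
        match = b1.match(b2)  # dispatches to bnd_match() when both records are BNDs
        print(match.state, match.score)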