diff --git a/truvari/bench.py b/truvari/bench.py index 93306d54..08a0e1d6 100644 --- a/truvari/bench.py +++ b/truvari/bench.py @@ -333,8 +333,10 @@ def __init__(self, bench, params): with open(os.path.join(self.m_bench.outdir, 'params.json'), 'w') as fout: json.dump(param_dict, fout) - b_vcf = truvari.VariantFile(self.m_bench.base_vcf, params=self.m_params) - c_vcf = truvari.VariantFile(self.m_bench.comp_vcf, params=self.m_params) + b_vcf = truvari.VariantFile( + self.m_bench.base_vcf, params=self.m_params) + c_vcf = truvari.VariantFile( + self.m_bench.comp_vcf, params=self.m_params) self.n_headers = {'b': edit_header(b_vcf), 'c': edit_header(c_vcf)} @@ -407,6 +409,7 @@ def close_outputs(self): self.stats_box.write_json(os.path.join( self.m_bench.outdir, "summary.json")) + class Bench(): """ Object to perform operations of truvari bench @@ -524,6 +527,7 @@ def compare_chunk(self, chunk): logging.debug("Comparing chunk %s", chunk_id) result = self.compare_calls( chunk_dict["base"], chunk_dict["comp"], chunk_id) + # Not checking BNDs as part of refine_candidates because they can't be refined. self.check_refine_candidate(result) # Check BNDs separately if self.params.bnddist != -1 and (chunk_dict['base_BND'] or chunk_dict['comp_BND']): @@ -619,8 +623,8 @@ def check_refine_candidate(self, result): # min(10, self.matcher.chunksize) need to make sure the refine covers the region buf = 10 start = max(0, min(*pos) - buf) - self.refine_candidates.append( - f"{chrom}\t{start}\t{max(*pos) + buf}") + end = max(*pos) + buf + self.refine_candidates.append(f"{chrom}\t{start}\t{end}") ################# # Match Pickers # diff --git a/truvari/matching.py b/truvari/matching.py index 5978dc6a..fcf1c8ef 100644 --- a/truvari/matching.py +++ b/truvari/matching.py @@ -101,13 +101,25 @@ def __repr__(self): def file_zipper(*start_files): """ - Zip files to yield the entries in order. - Each file must be sorted in the same order. - start_files is a tuple of ('key', iterable) - where key is the identifier (so we know which file the yielded entry came from) - and iterable is usually a truvari.VariantFile + Zip multiple files to yield their entries in order. - yields key, truvari.VariantRecord + The function takes as input tuples of (`key`, `iterable`), where: + + - `key` is an identifier (used to track which file the yielded entry comes from). + - `iterable` is an iterable object, typically a `truvari.VariantFile`. + + The function iterates through all input files in a coordinated manner, yielding the entries in order. + + :param start_files: A variable-length argument list of tuples (`key`, `iterable`). + :type start_files: tuple + + :yields: A tuple of (`key`, `truvari.VariantRecord`), where `key` is the file identifier and the second element is the next record from the corresponding file. + :rtype: tuple + + :raises StopIteration: Raised when all input files have been exhausted. + + **Logs**: + - Logs a summary of the zipping process after all files have been processed. """ markers = [] # list of lists: [name, file_handler, top_entry] file_counts = Counter() diff --git a/truvari/variant_record.py b/truvari/variant_record.py index c07eeab5..730309a0 100644 --- a/truvari/variant_record.py +++ b/truvari/variant_record.py @@ -10,7 +10,6 @@ RC = str.maketrans("ATCG", "TAGC") - class VariantRecord: """ Wrapper around pysam.VariantRecords with helper functions of variant properties and basic comparisons @@ -35,6 +34,21 @@ def __getattr__(self, name): """ return getattr(self._record, name) + def __setattr__(self, name, value): + """ + Attempt to delegate attribute setting to the original VariantRecord first + """ + if name.startswith("_") or not hasattr(self, "_record"): + # Directly set internal attributes or during __init__ before _record is set + super().__setattr__(name, value) + else: + try: + # Try to set the attribute on the wrapped _record + setattr(self._record, name, value) + except AttributeError: + # If the wrapped object does not have the attribute, set it on self + super().__setattr__(name, value) + def __str__(self): return str(self._record) @@ -348,8 +362,7 @@ def is_present(self, sample=0, allow_missing=True): gt = self.gt(sample) if allow_missing: return 1 in gt - # Hemi... - return truvari.get_gt(gt)[truvari.GT.HET, truvari.GT.HOM] + return truvari.get_gt(gt) in [truvari.GT.HET, truvari.GT.HOM] def is_single_bnd(self): """