diff --git a/Snakefile b/Snakefile
index ccbfe05..7aac42b 100644
--- a/Snakefile
+++ b/Snakefile
@@ -15,6 +15,10 @@ OUT = config["output_dir"]
 # iget collection and save to a path passed to cli
 PREVIOUS_CLUSTERING = config["previous_clustering"]
 
+if PREVIOUS_CLUSTERING == "None":
+    Path(OUT).mkdir(parents=True, exist_ok=True)
+    Path(OUT + "/previous_list_excluded_samples.tsv").touch()
+
 # Configure pipeline outputs
 expected_outputs = []
 
@@ -30,6 +34,8 @@ elif config["clustering_type"] == "mlst":
 localrules:
     all,
     copy_assemblies_to_temp,
+    copy_or_touch_list_excluded_samples,
+    touch_list_excluded_samples,
 
 
 include: "workflow/rules/combine_snp_profiles.smk"
diff --git a/config/presets.yaml b/config/presets.yaml
index 908388c..e30d037 100644
--- a/config/presets.yaml
+++ b/config/presets.yaml
@@ -3,8 +3,12 @@ mycobacterium_tuberculosis:
   max_distance: 200
   clustering_type: "alignment"
   N_content_threshold: 0.5
+  coverage_threshold: 20
+  inclusion_pattern: "^NLA[a-zA-Z0-9]+"
 salmonella:
   cluster_threshold: 7
   max_distance: 200
   clustering_type: "mlst"
-  N_content_threshold: None
\ No newline at end of file
+  N_content_threshold: None
+  coverage_threshold: 0
+  inclusion_pattern: ".*"
\ No newline at end of file
diff --git a/juno_clustering.py b/juno_clustering.py
index 4f8a073..2c3a453 100644
--- a/juno_clustering.py
+++ b/juno_clustering.py
@@ -67,10 +67,10 @@ def _parse_args(self) -> argparse.Namespace:
         args = super()._parse_args()
 
         # Optional arguments are loaded into self here
-        self.previous_clustering: str = args.previous_clustering
+        self.previous_clustering: Optional[str] = args.previous_clustering
         self.clustering_preset: str = args.clustering_preset
         self.presets_path: Optional[Path] = args.presets_path
         self.merged_cluster_separator: str = args.merged_cluster_separator
 
         return args
 
@@ -99,10 +100,13 @@ def setup(self) -> None:
             "exclusion_file": str(self.exclusion_file),
             "previous_clustering": 
str(self.previous_clustering), "merged_cluster_separator": str(self.merged_cluster_separator), + "clustering_preset": str(self.clustering_preset), "cluster_threshold": str(self.cluster_threshold), # from presets "max_distance": str(self.max_distance), # from presets "clustering_type": str(self.clustering_type), # from presets "N_content_threshold": str(self.N_content_threshold), # from presets + "coverage_threshold": str(self.coverage_threshold), # from presets + "inclusion_pattern": str(self.inclusion_pattern), # from presets } def set_presets(self) -> None: diff --git a/workflow/rules/clustering.smk b/workflow/rules/clustering.smk index 8a5b1de..90fe269 100644 --- a/workflow/rules/clustering.smk +++ b/workflow/rules/clustering.smk @@ -1,9 +1,67 @@ +if config["clustering_preset"] == "mycobacterium_tuberculosis": + rule copy_or_touch_list_excluded_samples: + output: + temp(OUT + "/previous_list_excluded_samples.tsv"), + params: + previous_list = PREVIOUS_CLUSTERING + "/list_excluded_samples.tsv" + shell: + """ +if [ -f {params.previous_list} ] +then + cp {params.previous_list} {output} +else + touch {output} +fi + """ + + rule list_excluded_samples: + input: + seq_exp_json = expand(INPUT + "/mtb_typing/seq_exp_json/{sample}.json", sample=SAMPLES), + exclude_list = OUT + "/previous_list_excluded_samples.tsv", + output: + OUT + "/list_excluded_samples.tsv", + log: + OUT + "/log/list_excluded_samples.log", + message: + "Listing samples which should be excluded." 
+        resources:
+            mem_gb=config["mem_gb"]["compression"],
+        conda:
+            "../envs/scripts.yaml"
+        container:
+            "docker://ghcr.io/boasvdp/juno_clustering_scripts:0.2"
+        params:
+            coverage_threshold=config["coverage_threshold"],
+            inclusion_pattern=config["inclusion_pattern"],
+        threads: config["threads"]["compression"]
+        shell:
+            """
+# columns: sample, reason, date
+python workflow/scripts/list_excluded_samples.py \
+--input {input.seq_exp_json} \
+--previous-exclude-list {input.exclude_list} \
+--output {output} \
+--inclusion-pattern {params.inclusion_pattern:q} \
+--coverage-threshold {params.coverage_threshold} \
+> {log} 2>&1
+            """
+else:
+    rule touch_list_excluded_samples:
+        output:
+            temp(OUT + "/list_excluded_samples.tsv"),
+        shell:
+            """
+touch {output}
+            """
+
+
 # PREVIOUS_CLUSTERING is read into config as a str
 if PREVIOUS_CLUSTERING == "None":
 
     rule clustering_from_scratch:
         input:
             distances=OUT + "/distances.tsv",
+            exclude_list=OUT + "/list_excluded_samples.tsv",
         output:
             OUT + "/clusters.csv",
         log:
@@ -29,7 +87,8 @@ python workflow/scripts/cluster.py \
 --log {log} \
 --verbose \
 --merged-cluster-separator {params.merged_cluster_separator:q} \
---output {output}
+--output {output} \
+--exclude-list {input.exclude_list}
 """
 
 else:
 
@@ -38,6 +97,7 @@ else:
         input:
             distances=OUT + "/distances.tsv",
             previous_clustering=PREVIOUS_CLUSTERING + "/clusters.csv",
+            exclude_list=OUT + "/list_excluded_samples.tsv",
         output:
             OUT + "/clusters.csv",
         log:
@@ -64,5 +124,6 @@ python workflow/scripts/cluster.py \
 --log {log} \
 --verbose \
 --merged-cluster-separator {params.merged_cluster_separator:q} \
+--exclude-list {input.exclude_list} \
 --output {output}
 """
diff --git a/workflow/scripts/cluster.py b/workflow/scripts/cluster.py
index 518a0e3..f1bab3a 100644
--- a/workflow/scripts/cluster.py
+++ b/workflow/scripts/cluster.py
@@ -66,6 +66,36 @@ def read_data(distances, previous_clustering):
     )
     return df_distances, df_previous_clustering
 
+@timing
+def exclude_samples(df_distances, exclude_list):
""" + Exclude samples from the distances dataframe + + Parameters + ---------- + df_distances : pd.DataFrame + Dataframe with distances + exclude_list : Path + Path to list of samples to exclude + + Returns + ------- + df_distances : pd.DataFrame + Dataframe with distances + + """ + with open(exclude_list) as f: + nr_lines = len(f.readlines()) + if nr_lines > 0: + logging.info(f"Excluding samples") + df_exclude = pd.read_csv(exclude_list, sep="\t") + exclude_samples = df_exclude["sample"].tolist() + df_distances = df_distances[ + ~df_distances["sample1"].isin(exclude_samples) + & ~df_distances["sample2"].isin(exclude_samples) + ] + return df_distances + @timing def emit_and_save_critical_warning(message, output_path): """ @@ -399,6 +429,8 @@ def main(args): args.distances, args.previous_clustering ) + df_distances = exclude_samples(df_distances, args.exclude_list) + df_nodes = get_df_nodes(df_distances, df_previous_clustering) df_distances_filtered = filter_edges(df_distances, args.threshold) @@ -434,6 +466,9 @@ def main(args): help="Separator for merged clusters", default="|", ) + parser.add_argument( + "--exclude-list", type=Path, help="Path to list of samples to exclude from clustering" + ) parser.add_argument( "--log", type=Path, help="Path to log file", default="cluster.log" ) diff --git a/workflow/scripts/list_excluded_samples.py b/workflow/scripts/list_excluded_samples.py new file mode 100644 index 0000000..ff2e6a8 --- /dev/null +++ b/workflow/scripts/list_excluded_samples.py @@ -0,0 +1,57 @@ +import argparse +import json +from pathlib import Path +import pandas as pd +import datetime +import logging + +def read_input_data(input_files): + data = {} + for file in input_files: + with open(file) as f: + data[file.stem] = json.load(f) + df = pd.DataFrame.from_dict(data, orient='index').reset_index(names="sample") + df['mean_coverage'] = df['mean_coverage'].astype(float) + return df + +def exclude_on_coverage(df, threshold): + df_copy = df.copy() + return 
df_copy[df_copy['mean_coverage'] < threshold] + +def exclude_on_pattern(df, pattern): + df_copy = df.copy() + return df_copy[~df_copy['sample'].str.contains(pattern)] + +def read_previous_exclude_list(file): + with open(file) as f: + lines = f.readlines() + if len(lines) == 0: + return pd.DataFrame(columns=['sample', 'reason', 'date']) + else: + df = pd.read_csv(file, sep="\t") + return df + +def main(args): + df = read_input_data(args.input) + df_coverage_excluded = exclude_on_coverage(df, args.coverage_threshold) + df_coverage_excluded["reason"] = "low_coverage" + df_pattern_excluded = exclude_on_pattern(df, args.inclusion_pattern) + df_pattern_excluded["reason"] = "not_NLA" + df_excluded = pd.concat([df_coverage_excluded[['sample', 'reason']], df_pattern_excluded[['sample', 'reason']]]) + df_excluded['date'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + df_previous_excluded = read_previous_exclude_list(args.previous_exclude_list) + df_final = pd.concat([df_previous_excluded, df_excluded]) + df_final.to_csv(args.output, sep="\t", index=False) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--input', type=Path, required=True, nargs='+') + parser.add_argument('--previous-exclude-list', type=Path, required=True) + parser.add_argument('--output', type=Path, required=True) + parser.add_argument('--inclusion-pattern', type=str, required=True) + parser.add_argument('--coverage-threshold', type=float, required=True) + + args = parser.parse_args() + + main(args) \ No newline at end of file