Skip to content

Commit

Permalink
feat: add exclusion of samples in clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
boasvdp committed Jul 25, 2024
1 parent 31d01a8 commit c105a5d
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 3 deletions.
6 changes: 6 additions & 0 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ OUT = config["output_dir"]
# iget collection and save to a path passed to cli
PREVIOUS_CLUSTERING = config["previous_clustering"]

if PREVIOUS_CLUSTERING == "None":
    # Without a previous clustering there is no exclusion list to inherit:
    # make sure the output directory exists and seed an empty "previous" list.
    # The file name has to match the .tsv that the exclusion rules in
    # workflow/rules/clustering.smk read; a .txt file would never be used.
    Path(OUT).mkdir(parents=True, exist_ok=True)
    Path(OUT + "/previous_list_excluded_samples.tsv").touch()

# Configure pipeline outputs
expected_outputs = []

Expand All @@ -30,6 +34,8 @@ elif config["clustering_type"] == "mlst":
# Rules listed here are executed locally instead of being submitted as
# cluster jobs.
# copy_or_touch_list_excluded_samples and touch_list_excluded_samples are
# defined in opposite branches of the clustering_preset check in
# workflow/rules/clustering.smk, so only one of the two exists in a given run.
localrules:
all,
copy_assemblies_to_temp,
copy_or_touch_list_excluded_samples,
touch_list_excluded_samples,


include: "workflow/rules/combine_snp_profiles.smk"
Expand Down
6 changes: 5 additions & 1 deletion config/presets.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ mycobacterium_tuberculosis:
max_distance: 200
clustering_type: "alignment"
N_content_threshold: 0.5
coverage_threshold: 20
inclusion_pattern: "^NLA[a-zA-Z0-9]+"
salmonella:
  cluster_threshold: 7
  max_distance: 200
  clustering_type: "mlst"
  N_content_threshold: None
  # A threshold of 0 excludes no sample on coverage.
  coverage_threshold: 0
  # inclusion_pattern is used as a regular expression (compare the
  # mycobacterium_tuberculosis preset). ".*" matches every sample name,
  # whereas a bare "*" is not a valid regular expression.
  inclusion_pattern: ".*"
6 changes: 5 additions & 1 deletion juno_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ def _parse_args(self) -> argparse.Namespace:
args = super()._parse_args()

# Optional arguments are loaded into self here
self.previous_clustering: str = args.previous_clustering
self.previous_clustering: Optional[str] = args.previous_clustering
self.clustering_preset: str = args.clustering_preset
self.presets_path: Optional[Path] = args.presets_path
self.merged_cluster_separator: str = args.merged_cluster_separator
self.clustering_preset: str = args.clustering_preset

return args

Expand Down Expand Up @@ -99,10 +100,13 @@ def setup(self) -> None:
"exclusion_file": str(self.exclusion_file),
"previous_clustering": str(self.previous_clustering),
"merged_cluster_separator": str(self.merged_cluster_separator),
"clustering_preset": str(self.clustering_preset),
"cluster_threshold": str(self.cluster_threshold), # from presets
"max_distance": str(self.max_distance), # from presets
"clustering_type": str(self.clustering_type), # from presets
"N_content_threshold": str(self.N_content_threshold), # from presets
"coverage_threshold": str(self.coverage_threshold), # from presets
"inclusion_pattern": str(self.inclusion_pattern), # from presets
}

def set_presets(self) -> None:
Expand Down
63 changes: 62 additions & 1 deletion workflow/rules/clustering.smk
Original file line number Diff line number Diff line change
@@ -1,9 +1,67 @@
# Sample exclusion is only evaluated for the M. tuberculosis preset. Every
# other preset gets an empty exclusion list, so the clustering rules can always
# depend on OUT + "/list_excluded_samples.tsv".
if config["clustering_preset"] == "mycobacterium_tuberculosis":

    # Start from the exclusion list of the previous clustering run when that
    # file exists; otherwise start from an empty file.
    rule copy_or_touch_list_excluded_samples:
        output:
            temp(OUT + "/previous_list_excluded_samples.tsv"),
        params:
            previous_list=PREVIOUS_CLUSTERING + "/list_excluded_samples.tsv",
        shell:
            """
            if [ -f {params.previous_list} ]
            then
                cp {params.previous_list} {output}
            else
                touch {output}
            fi
            """

    # Append the samples of this run that fail the coverage threshold or the
    # inclusion pattern to the previous exclusion list.
    rule list_excluded_samples:
        input:
            seq_exp_json=expand(
                INPUT + "/mtb_typing/seq_exp_json/{sample}.json", sample=SAMPLES
            ),
            exclude_list=OUT + "/previous_list_excluded_samples.tsv",
        output:
            OUT + "/list_excluded_samples.tsv",
        log:
            OUT + "/log/list_excluded_samples.log",
        message:
            "Listing samples which should be excluded."
        resources:
            mem_gb=config["mem_gb"]["compression"],
        conda:
            "../envs/scripts.yaml"
        container:
            "docker://ghcr.io/boasvdp/juno_clustering_scripts:0.2"
        params:
            coverage_threshold=config["coverage_threshold"],
            inclusion_pattern=config["inclusion_pattern"],
        threads: config["threads"]["compression"]
        # The inclusion pattern is a regular expression, so it is shell-quoted
        # with the q format flag to keep bash from interpreting characters such
        # as brackets or asterisks. The redirect order matters: stdout is sent
        # to the log first and stderr is then pointed at the same target, so
        # both streams end up in the log file.
        shell:
            """
            # columns: sample, reason, date
            python workflow/scripts/list_excluded_samples.py \
                --input {input.seq_exp_json} \
                --previous-exclude-list {input.exclude_list} \
                --output {output} \
                --inclusion-pattern {params.inclusion_pattern:q} \
                --coverage-threshold {params.coverage_threshold} \
                > {log} 2>&1
            """

else:

    # No exclusion criteria for the other presets: provide an empty list.
    rule touch_list_excluded_samples:
        output:
            temp(OUT + "/list_excluded_samples.tsv"),
        shell:
            """
            touch {output}
            """


# PREVIOUS_CLUSTERING is read into config as a str
if PREVIOUS_CLUSTERING == "None":

rule clustering_from_scratch:
input:
distances=OUT + "/distances.tsv",
exclude_list=OUT + "/list_excluded_samples.tsv",
output:
OUT + "/clusters.csv",
log:
Expand All @@ -29,7 +87,8 @@ python workflow/scripts/cluster.py \
--log {log} \
--verbose \
--merged-cluster-separator {params.merged_cluster_separator:q} \
--output {output}
--output {output} \
--exclude {input.exclude_list}
"""

else:
Expand All @@ -38,6 +97,7 @@ else:
input:
    distances=OUT + "/distances.tsv",
    previous_clustering=PREVIOUS_CLUSTERING + "/clusters.csv",
    # Same file the exclusion rules write and the from-scratch clustering
    # rule reads: list_excluded_samples.tsv (not .txt).
    exclude_list=OUT + "/list_excluded_samples.tsv",
output:
OUT + "/clusters.csv",
log:
Expand All @@ -64,5 +124,6 @@ python workflow/scripts/cluster.py \
--log {log} \
--verbose \
--merged-cluster-separator {params.merged_cluster_separator:q} \
--output {output} \
--exclude {input.exclude_list}
"""
35 changes: 35 additions & 0 deletions workflow/scripts/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,36 @@ def read_data(distances, previous_clustering):
)
return df_distances, df_previous_clustering

@timing
def exclude_samples(df_distances, exclude_list):
    """
    Drop every pairwise distance that involves an excluded sample.

    Parameters
    ----------
    df_distances : pd.DataFrame
        Dataframe with distances; must contain the columns ``sample1`` and
        ``sample2``.
    exclude_list : Path or None
        Tab-separated file with a ``sample`` column naming the samples to
        exclude. ``None``, an empty file or a file without data rows means
        that nothing is excluded.

    Returns
    -------
    df_distances : pd.DataFrame
        The distances without the rows in which ``sample1`` or ``sample2``
        is an excluded sample. The input is returned unchanged when there
        is nothing to exclude.
    """
    # --exclude-list is optional on the command line, so no path may be given.
    if exclude_list is None:
        return df_distances

    try:
        df_exclude = pd.read_csv(exclude_list, sep="\t")
    except pd.errors.EmptyDataError:
        # The workflow creates an empty placeholder file when no samples are
        # excluded; pandas cannot parse a file that has no header line.
        return df_distances

    if df_exclude.empty:
        return df_distances

    logging.info("Excluding samples")
    excluded = set(df_exclude["sample"])
    involves_excluded = df_distances["sample1"].isin(excluded) | df_distances[
        "sample2"
    ].isin(excluded)
    return df_distances[~involves_excluded]

@timing
def emit_and_save_critical_warning(message, output_path):
"""
Expand Down Expand Up @@ -399,6 +429,8 @@ def main(args):
args.distances, args.previous_clustering
)

df_distances = exclude_samples(df_distances, args.exclude_list)

df_nodes = get_df_nodes(df_distances, df_previous_clustering)

df_distances_filtered = filter_edges(df_distances, args.threshold)
Expand Down Expand Up @@ -434,6 +466,9 @@ def main(args):
help="Separator for merged clusters",
default="|",
)
parser.add_argument(
"--exclude-list", type=Path, help="Path to list of samples to exclude from clustering"
)
parser.add_argument(
"--log", type=Path, help="Path to log file", default="cluster.log"
)
Expand Down
57 changes: 57 additions & 0 deletions workflow/scripts/list_excluded_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import argparse
import json
from pathlib import Path
import pandas as pd
import datetime
import logging

def read_input_data(input_files):
    """
    Load per-sample JSON files into a single dataframe.

    Parameters
    ----------
    input_files : iterable of Path or str
        JSON files, one per sample. The file name without its extension is
        used as the sample name, and every file must provide a
        ``mean_coverage`` value.

    Returns
    -------
    df : pd.DataFrame
        One row per input file with a ``sample`` column followed by one
        column per JSON key; ``mean_coverage`` is converted to float.

    Raises
    ------
    KeyError
        If none of the files contains a ``mean_coverage`` key.
    """
    records = {}
    for file in input_files:
        # Accept plain strings as well as Path objects.
        path = Path(file)
        # JSON is UTF-8 encoded; do not depend on the locale's default encoding.
        with open(path, encoding="utf-8") as handle:
            records[path.stem] = json.load(handle)
    df = pd.DataFrame.from_dict(records, orient="index").reset_index(names="sample")
    df["mean_coverage"] = df["mean_coverage"].astype(float)
    return df

def exclude_on_coverage(df, threshold):
    """
    Select the samples whose mean coverage is below the threshold.

    Parameters
    ----------
    df : pd.DataFrame
        Sample table with a numeric ``mean_coverage`` column.
    threshold : float
        Samples with ``mean_coverage`` strictly below this value are selected.

    Returns
    -------
    pd.DataFrame
        A new dataframe holding only the selected rows; ``df`` itself is
        left untouched.
    """
    below_threshold = df['mean_coverage'].lt(threshold)
    return df.loc[below_threshold].copy()

def exclude_on_pattern(df, pattern):
    """
    Select the samples whose name does not match the inclusion pattern.

    Parameters
    ----------
    df : pd.DataFrame
        Sample table with a string ``sample`` column.
    pattern : str
        Regular expression searched for in every sample name.

    Returns
    -------
    pd.DataFrame
        A new dataframe holding only the rows whose ``sample`` value does
        not contain a match; ``df`` itself is left untouched.
    """
    matches_pattern = df['sample'].str.contains(pattern)
    return df.loc[~matches_pattern].copy()

def read_previous_exclude_list(file):
    """
    Read the exclusion list of a previous run.

    Parameters
    ----------
    file : Path or str
        Tab-separated exclusion list. The file may be empty or contain only
        blank lines, e.g. when it was created as a placeholder.

    Returns
    -------
    pd.DataFrame
        The parsed list. For a file without any content an empty dataframe
        with the columns ``sample``, ``reason`` and ``date`` is returned.
    """
    try:
        return pd.read_csv(file, sep="\t")
    except pd.errors.EmptyDataError:
        # Nothing to parse (no header line): fall back to an empty list with
        # the expected columns so it can be concatenated with new exclusions.
        return pd.DataFrame(columns=['sample', 'reason', 'date'])

def main(args):
    """
    Build the exclusion list for this run and write it as a TSV file.

    Samples are excluded for low coverage and/or for not matching the
    inclusion pattern; a sample failing both checks is listed once per
    reason. The new entries are appended to the previous exclusion list.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments providing ``input``,
        ``coverage_threshold``, ``inclusion_pattern``,
        ``previous_exclude_list`` and ``output``.
    """
    samples = read_input_data(args.input)

    low_coverage = exclude_on_coverage(samples, args.coverage_threshold).assign(
        reason="low_coverage"
    )
    # NOTE: the reason label is fixed, whatever inclusion pattern is configured.
    off_pattern = exclude_on_pattern(samples, args.inclusion_pattern).assign(
        reason="not_NLA"
    )

    kept_columns = ['sample', 'reason']
    newly_excluded = pd.concat(
        [low_coverage[kept_columns], off_pattern[kept_columns]]
    )
    # All entries of one run share the same timestamp.
    newly_excluded['date'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    previously_excluded = read_previous_exclude_list(args.previous_exclude_list)
    combined = pd.concat([previously_excluded, newly_excluded])
    combined.to_csv(args.output, sep="\t", index=False)

if __name__ == '__main__':
    # Command line entry point; every argument is required.
    parser = argparse.ArgumentParser(
        description=(
            "List samples that should be excluded from clustering because of "
            "low coverage or a sample name that does not match the inclusion "
            "pattern."
        )
    )

    parser.add_argument(
        '--input',
        type=Path,
        required=True,
        nargs='+',
        help=(
            "JSON files, one per sample, each containing a mean_coverage "
            "value; the file name without extension is the sample name"
        ),
    )
    parser.add_argument(
        '--previous-exclude-list',
        type=Path,
        required=True,
        help="Tab-separated exclusion list of a previous run (may be empty)",
    )
    parser.add_argument(
        '--output',
        type=Path,
        required=True,
        help="Path of the tab-separated exclusion list to write",
    )
    parser.add_argument(
        '--inclusion-pattern',
        type=str,
        required=True,
        help=(
            "Regular expression; samples whose name does not match are "
            "excluded"
        ),
    )
    parser.add_argument(
        '--coverage-threshold',
        type=float,
        required=True,
        help="Samples with a mean coverage below this value are excluded",
    )

    args = parser.parse_args()

    main(args)

0 comments on commit c105a5d

Please sign in to comment.