Skip to content

Commit

Permalink
Merge branch 'DOR-895-demux-summary' into 'master'
Browse files Browse the repository at this point in the history
[DOR-895] Fix missing filenames in demux+alignment summaries

Closes DOR-895

See merge request machine-learning/dorado!1266
  • Loading branch information
blawrence-ont committed Nov 13, 2024
2 parents e83a80c + 86ff2f9 commit affea85
Show file tree
Hide file tree
Showing 16 changed files with 270 additions and 233 deletions.
271 changes: 131 additions & 140 deletions dorado/data_loader/DataLoader.cpp

Large diffs are not rendered by default.

20 changes: 16 additions & 4 deletions dorado/read_pipeline/HtsReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,27 @@ HtsReader::HtsReader(const std::string& filename,
}

template <typename T>
bool HtsReader::try_initialise_generator(const std::string& filename) {
auto generator = std::make_shared<T>(filename); // shared to allow copy assignment
bool HtsReader::try_initialise_generator(const std::string& filepath) {
auto generator = std::make_shared<T>(filepath); // shared to allow copy assignment
if (!generator->is_valid()) {
return false;
}
m_header = generator->header();
m_format = generator->format();
m_bam_record_generator = [generator_ = std::move(generator)](bam1_t& bam_record) {
return generator_->try_get_next_record(bam_record);
m_bam_record_generator = [generator_ = std::move(generator),
filename = std::filesystem::path(filepath).filename().string(),
this](bam1_t& bam_record) {
if (!generator_->try_get_next_record(bam_record)) {
return false;
}

// If the record doesn't have a filename set then say that it came from the currently processing file.
if (m_add_filename_tag && !bam_aux_get(&bam_record, "fn")) {
bam_aux_append(&bam_record, "fn", 'Z', static_cast<int>(filename.size() + 1),
reinterpret_cast<const uint8_t*>(filename.c_str()));
}

return true;
};
return true;
}
Expand Down
17 changes: 12 additions & 5 deletions dorado/read_pipeline/HtsReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class HtsReader {
public:
HtsReader(const std::string& filename,
std::optional<std::unordered_set<std::string>> read_list);

// By default we'll add a filename tag to each record to match the current file
// if one isn't included in the data, but that can be disabled with this method.
void set_add_filename_tag(bool should) { m_add_filename_tag = should; }

bool read();

// If reading directly into a pipeline need to set the client info on the messages
Expand All @@ -36,21 +41,22 @@ class HtsReader {
void set_record_mutator(std::function<void(BamPtr&)> mutator);

bool is_aligned{false};
BamPtr record{nullptr};
BamPtr record;

sam_hdr_t* header();
const sam_hdr_t* header() const;
const std::string& format() const;

private:
sam_hdr_t* m_header{nullptr}; // non-owning
std::string m_format{};
std::string m_format;
std::shared_ptr<ClientInfo> m_client_info;

std::function<void(BamPtr&)> m_record_mutator{};
std::function<void(BamPtr&)> m_record_mutator;
std::optional<std::unordered_set<std::string>> m_read_list;

std::function<bool(bam1_t&)> m_bam_record_generator{};
std::function<bool(bam1_t&)> m_bam_record_generator;
bool m_add_filename_tag{true};

template <typename T>
bool try_initialise_generator(const std::string& filename);
Expand All @@ -69,7 +75,8 @@ T HtsReader::get_tag(const char* tagname) {
} else if constexpr (std::is_floating_point_v<T>) {
tag_value = static_cast<T>(bam_aux2f(tag));
} else {
tag_value = static_cast<T>(bam_aux2Z(tag));
const char* val = bam_aux2Z(tag);
tag_value = val ? val : T{};
}

return tag_value;
Expand Down
4 changes: 2 additions & 2 deletions dorado/read_pipeline/messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ void ReadCommon::generate_read_tags(bam1_t *aln, bool emit_moves, bool is_duplex
int rn = attributes.read_number;
bam_aux_append(aln, "rn", 'i', sizeof(rn), (uint8_t *)&rn);

bam_aux_append(aln, "fn", 'Z', int(attributes.fast5_filename.length() + 1),
(uint8_t *)attributes.fast5_filename.c_str());
bam_aux_append(aln, "fn", 'Z', int(attributes.filename.length() + 1),
(uint8_t *)attributes.filename.c_str());

float sm = shift;
bam_aux_append(aln, "sm", 'f', sizeof(sm), (uint8_t *)&sm);
Expand Down
6 changes: 4 additions & 2 deletions dorado/read_pipeline/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ struct Attributes {
int32_t read_number{-1}; // Per-channel number of each read as it was acquired by minknow
int32_t channel_number{-1}; //Channel ID
std::string start_time{}; //Read acquisition start time
std::string fast5_filename{};
uint64_t num_samples;
std::string filename{};
// Indicates if this read had end reason `mux_change` or `unblock_mux_change`
bool is_end_reason_mux_change{false};

// Only used by tests, and only valid for POD5 data.
uint64_t num_samples{};
};

} // namespace details
Expand Down
66 changes: 40 additions & 26 deletions dorado/summary/summary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <cctype>
#include <csignal>
#include <filesystem>
#include <string_view>

namespace {

Expand Down Expand Up @@ -37,27 +38,44 @@ volatile sig_atomic_t SigIntHandler::interrupt{};

namespace dorado {

std::vector<std::string> SummaryData::s_required_fields = {"filename", "read_id"};
namespace {

using namespace std::string_view_literals;

const std::array s_required_fields = {
"filename"sv,
"read_id"sv,
};

const std::array s_general_fields = {
"run_id"sv,
"channel"sv,
"mux"sv,
"start_time"sv,
"duration"sv,
"template_start"sv,
"template_duration"sv,
"sequence_length_template"sv,
"mean_qscore_template"sv,
};

std::vector<std::string> SummaryData::s_general_fields = {"run_id",
"channel",
"mux",
"start_time",
"duration",
"template_start",
"template_duration",
"sequence_length_template",
"mean_qscore_template"};
const std::array s_barcoding_fields = {
"barcode"sv,
};

std::vector<std::string> SummaryData::s_barcoding_fields = {"barcode"};
const std::array s_alignment_fields = {
"alignment_genome"sv, "alignment_genome_start"sv,
"alignment_genome_end"sv, "alignment_strand_start"sv,
"alignment_strand_end"sv, "alignment_direction"sv,
"alignment_length"sv, "alignment_num_aligned"sv,
"alignment_num_correct"sv, "alignment_num_insertions"sv,
"alignment_num_deletions"sv, "alignment_num_substitutions"sv,
"alignment_mapq"sv, "alignment_strand_coverage"sv,
"alignment_identity"sv, "alignment_accuracy"sv,
"alignment_bed_hits"sv,
};

std::vector<std::string> SummaryData::s_alignment_fields = {
"alignment_genome", "alignment_genome_start", "alignment_genome_end",
"alignment_strand_start", "alignment_strand_end", "alignment_direction",
"alignment_length", "alignment_num_aligned", "alignment_num_correct",
"alignment_num_insertions", "alignment_num_deletions", "alignment_num_substitutions",
"alignment_mapq", "alignment_strand_coverage", "alignment_identity",
"alignment_accuracy", "alignment_bed_hits"};
} // namespace

SummaryData::SummaryData() = default;

Expand All @@ -73,7 +91,7 @@ void SummaryData::set_fields(FieldFlags flags) {
m_field_flags = flags;
}

bool SummaryData::process_file(const std::string& filename, std::ostream& writer) {
void SummaryData::process_file(const std::string& filename, std::ostream& writer) {
SigIntHandler sig_handler;
HtsReader reader(filename, std::nullopt);
m_field_flags = GENERAL_FIELDS | BARCODING_FIELDS;
Expand All @@ -82,7 +100,7 @@ bool SummaryData::process_file(const std::string& filename, std::ostream& writer
}
auto read_group_exp_start_time = utils::get_read_group_info(reader.header(), "DT");
write_header(writer);
return write_rows_from_reader(reader, writer, read_group_exp_start_time);
write_rows_from_reader(reader, writer, read_group_exp_start_time);
}

bool SummaryData::process_tree(const std::string& folder, std::ostream& writer) {
Expand All @@ -104,10 +122,7 @@ bool SummaryData::process_tree(const std::string& folder, std::ostream& writer)
for (const auto& read_file : files) {
HtsReader reader(read_file, std::nullopt);
auto read_group_exp_start_time = utils::get_read_group_info(reader.header(), "DT");
bool ok = write_rows_from_reader(reader, writer, read_group_exp_start_time);
if (!ok) {
spdlog::error("File {} could not be processed. Skipping file.", read_file);
}
write_rows_from_reader(reader, writer, read_group_exp_start_time);
}
return true;
}
Expand Down Expand Up @@ -137,7 +152,7 @@ void SummaryData::write_header(std::ostream& writer) {
writer << '\n';
}

bool SummaryData::write_rows_from_reader(
void SummaryData::write_rows_from_reader(
HtsReader& reader,
std::ostream& writer,
const std::map<std::string, std::string>& read_group_exp_start_time) {
Expand Down Expand Up @@ -266,7 +281,6 @@ bool SummaryData::write_rows_from_reader(
}
writer << '\n';
}
return true;
}

} // namespace dorado
9 changes: 2 additions & 7 deletions dorado/summary/summary.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,17 @@ class SummaryData {
void set_fields(FieldFlags flags);

/// This will automatically set the fields based on the contents of the file.
bool process_file(const std::string& filename, std::ostream& writer);
void process_file(const std::string& filename, std::ostream& writer);

/// For this method the fields must already be set.
bool process_tree(const std::string& folder, std::ostream& writer);

private:
static std::vector<std::string> s_required_fields;
static std::vector<std::string> s_general_fields;
static std::vector<std::string> s_barcoding_fields;
static std::vector<std::string> s_alignment_fields;

char m_separator{'\t'};
FieldFlags m_field_flags{};

void write_header(std::ostream& writer);
bool write_rows_from_reader(HtsReader& reader,
void write_rows_from_reader(HtsReader& reader,
std::ostream& writer,
const std::map<std::string, std::string>& rgst);
};
Expand Down
29 changes: 13 additions & 16 deletions tests/AlignerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ class AlignerNodeTestFixture {
MessageTypePtr RunPipelineForRead(
const std::shared_ptr<dorado::alignment::AlignmentInfo>& loaded_align_info,
const std::shared_ptr<dorado::alignment::AlignmentInfo>& client_align_info,
std::string read_id,
std::string sequence) {
const std::string& read_id,
const std::string& sequence) {
auto index_file_access = std::make_shared<dorado::alignment::IndexFileAccess>();
auto bed_file_access = std::make_shared<dorado::alignment::BedFileAccess>();
CHECK(index_file_access->load_index(loaded_align_info->reference_file,
Expand All @@ -110,16 +110,14 @@ class AlignerNodeTestFixture {
create_pipeline(index_file_access, bed_file_access, thread_pool,
dorado::utils::concurrency::TaskPriority::normal);

dorado::ReadCommon read_common{};
auto client_info = std::make_shared<dorado::DefaultClientInfo>();
client_info->contexts().register_context<const dorado::alignment::AlignmentInfo>(
client_align_info);
read_common.client_info = client_info;
read_common.read_id = std::move(read_id);
read_common.seq = std::move(sequence);

auto read = std::make_unique<MessageType>();
read->read_common = std::move(read_common);
read->read_common.client_info = std::move(client_info);
read->read_common.read_id = read_id;
read->read_common.seq = sequence;

pipeline->push_message(std::move(read));
pipeline->terminate({});
Expand Down Expand Up @@ -544,6 +542,7 @@ TEST_CASE_METHOD(AlignerNodeTestFixture,

// Get the sam line from BAM pipeline
dorado::HtsReader bam_reader(query, std::nullopt);
bam_reader.set_add_filename_tag(false);
auto bam_records = RunPipelineWithBamMessages(bam_reader, ref, "", options, 2);
CHECK(bam_records.size() == 1);
auto sam_line_from_bam_ptr = get_sam_line_from_bam(std::move(bam_records[0]));
Expand All @@ -553,23 +552,21 @@ TEST_CASE_METHOD(AlignerNodeTestFixture,
auto align_info = std::make_shared<dorado::alignment::AlignmentInfo>();
align_info->minimap_options = options;
align_info->reference_file = ref;
auto simplex_read = RunPipelineForRead<dorado::SimplexRead>(
align_info, align_info, std::move(read_id), std::move(sequence));
auto simplex_read =
RunPipelineForRead<dorado::SimplexRead>(align_info, align_info, read_id, sequence);
auto sam_line_from_read_common =
std::move(simplex_read->read_common.alignment_results[0].sam_string);

// Do the comparison checks
CHECK_FALSE(sam_line_from_read_common.empty());

if (sam_line_from_read_common.at(sam_line_from_read_common.size() - 1) == '\n') {
sam_line_from_read_common =
sam_line_from_read_common.substr(0, sam_line_from_read_common.size() - 1);
REQUIRE_FALSE(sam_line_from_read_common.empty());
if (sam_line_from_read_common.back() == '\n') {
sam_line_from_read_common.resize(sam_line_from_read_common.size() - 1);
}

const auto bam_fields = dorado::utils::split(sam_line_from_bam_ptr, '\t');
const auto read_common_fields = dorado::utils::split(sam_line_from_read_common, '\t');
CHECK(bam_fields.size() == read_common_fields.size());
CHECK(bam_fields.size() >= 11);
REQUIRE(bam_fields.size() == read_common_fields.size());
REQUIRE(bam_fields.size() >= 11);
// first 11 mandatory fields should be identical
for (std::size_t field_index{0}; field_index < 11; ++field_index) {
CAPTURE(field_index);
Expand Down
25 changes: 22 additions & 3 deletions tests/BamReaderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,17 @@ TEST_CASE("HtsReaderTest: get_tag", TEST_GROUP) {

dorado::HtsReader reader(sam.string(), std::nullopt);
while (reader.read()) {
// All records in small.sam have this set to 0.
CHECK(reader.get_tag<int>("rl") == 0);
// All records in small.sam have these set.
CHECK(reader.get_tag<int>("XA") == 42);
CHECK(reader.get_tag<std::string>("XB") == "test");
// Intentionally bad tag to test that missing tags don't return garbage.
CHECK(reader.get_tag<int>("##") == 0);
CHECK(reader.get_tag<float>("##") == 0);
CHECK(reader.get_tag<std::string>("##") == "");
// Type mismatch doesn't crash.
CHECK(reader.get_tag<int>("XB") == 0);
CHECK(reader.get_tag<float>("XB") == 0);
CHECK(reader.get_tag<std::string>("XA") == "");
}
}

Expand Down Expand Up @@ -129,4 +136,16 @@ TEST_CASE(
"read=1728 ch=332 start_time=2017-06-16T15:31:55Z");
}

} // namespace dorado::hts_reader::test
TEST_CASE("HtsReaderTest: filename tag added if missing", TEST_GROUP) {
fs::path aligner_test_dir = fs::path(get_data_dir("bam_reader"));
auto filename = GENERATE("input.fa", "fastq_with_tags.fq");
auto fasta = aligner_test_dir / filename;

dorado::HtsReader reader(fasta.string(), std::nullopt);
while (reader.read()) {
// All should be given the name of input file.
CHECK(reader.get_tag<std::string>("fn") == filename);
}
}

} // namespace dorado::hts_reader::test
1 change: 1 addition & 0 deletions tests/BamUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ TEST_CASE("BamUtilsTest: Remove all alignment tags", TEST_GROUP) {
auto sam = bam_utils_test_dir / "aligned_record.bam";

HtsReader reader(sam.string(), std::nullopt);
reader.set_add_filename_tag(false);
REQUIRE(reader.read()); // Parse first and only record.
auto record = reader.record.get();

Expand Down
1 change: 0 additions & 1 deletion tests/FastxRandomReaderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ BamPtr generate_bam_entry(const std::string& read_id,
TEST_CASE("Check if read can be loaded correctly.", "FastxRandomReader") {
auto temp_dir = tests::make_temp_dir("fastx_random_reader_test");
auto temp_input_file = temp_dir.m_path / "input.fq";
spdlog::info("{}", temp_dir.m_path.string());

const std::string seq = "ACTGATCG";
const std::vector<uint8_t> qscore = {20, 20, 30, 30, 20, 20, 40, 40};
Expand Down
2 changes: 1 addition & 1 deletion tests/NodeSmokeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class NodeSmokeTestBase {
read->read_common.attributes.read_number = 12345;
read->read_common.attributes.channel_number = 5;
read->read_common.attributes.start_time = "2017-04-29T09:10:04Z";
read->read_common.attributes.fast5_filename = "test.fast5";
read->read_common.attributes.filename = "test.fast5";
read->read_common.client_info = client_info;
return read;
}
Expand Down
Loading

0 comments on commit affea85

Please sign in to comment.