Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Very rough draft of parallelizing Sorter. #792

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/main/scala/com/fulcrumgenomics/bam/Bams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ object Bams extends LazyLogging {
def sorter(order: SamOrder,
header: SAMFileHeader,
maxRecordsInRam: Int = MaxInMemory,
tmpDir: DirPath = Io.tmpDir): Sorter[SamRecord,order.A] = {
new Sorter(maxRecordsInRam, new SamRecordCodec(header), order.sortkey, tmpDir=tmpDir)
tmpDir: DirPath = Io.tmpDir,
threads: Int = 4): Sorter[SamRecord,order.A] = {
new Sorter(maxRecordsInRam, new SamRecordCodec(header), order.sortkey, tmpDir=tmpDir, threads=threads)
}

/** A wrapper to order objects of type [[TagType]] using the ordering given. Used when sorting by tag where we wish
Expand Down
20 changes: 15 additions & 5 deletions src/main/scala/com/fulcrumgenomics/bam/SortBam.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import com.fulcrumgenomics.bam.api.{SamOrder, SamSource, SamWriter}
import com.fulcrumgenomics.cmdline.{ClpGroups, FgBioTool}
import com.fulcrumgenomics.commons.util.LazyLogging
import com.fulcrumgenomics.sopt.{arg, clp}
import com.fulcrumgenomics.util.Io
import com.fulcrumgenomics.util.{Io, ProgressLogger, Sorter}

@clp(group=ClpGroups.SamOrBam, description =
"""
Expand Down Expand Up @@ -57,15 +57,25 @@ class SortBam
( @arg(flag='i', doc="Input SAM or BAM.") val input: PathToBam,
@arg(flag='o', doc="Output SAM or BAM.") val output: PathToBam,
@arg(flag='s', doc="Order into which to sort the records.") val sortOrder: SamOrder = SamOrder.Coordinate,
@arg(flag='m', doc="Max records in RAM.") val maxRecordsInRam: Int = SamWriter.DefaultMaxRecordsInRam
@arg(flag='m', doc="Max records in RAM.") val maxRecordsInRam: Int = SamWriter.DefaultMaxRecordsInRam,
@arg(flag='t', doc="Number of threads to use.") val threads: Int = 4
) extends FgBioTool with LazyLogging {
override def execute(): Unit = {
Io.assertReadable(input)
Io.assertCanWriteFile(output)

val in = SamSource(input)
val out = SamWriter(output, in.header.clone(), sort=Some(sortOrder), maxRecordsInRam=maxRecordsInRam)
out ++= in
val in = SamSource(input)
val header = in.header.clone()
sortOrder.applyTo(header)
val sorter = Bams.sorter(sortOrder, header, maxRecordsInRam=maxRecordsInRam, threads=threads)
val progress = ProgressLogger(logger, verb="sorted", unit=2e6.toInt)
in.foreach { r =>
sorter += r
progress.record(r)
}

val out = SamWriter(output, header=header)
sorter.foreach { r => out += r }
out.close()
}
}
137 changes: 105 additions & 32 deletions src/main/scala/com/fulcrumgenomics/util/Sorter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@ package com.fulcrumgenomics.util

import java.io._
import java.nio.file.{Files, Path}
import java.util

import java.{lang, util}
import com.fulcrumgenomics.FgBioDef._
import com.fulcrumgenomics.commons.io.Writer
import com.fulcrumgenomics.commons.collection.SelfClosingIterator
import com.fulcrumgenomics.commons.util.LazyLogging
import com.fulcrumgenomics.util.Sorter.{Codec, SortEntry}
import htsjdk.samtools.util.TempStreamFactory

import java.util.concurrent.{Executors, TimeUnit}
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.tools.nsc.doc.html.HtmlTags.B

/**
* Companion object for [[Sorter]] that contains various types used.
Expand Down Expand Up @@ -75,15 +78,87 @@ object Sorter {
*
* Both must be thread-safe as they may be invoked across threads without external synchronization
*/
class Sorter[A,B <: Ordered[B]](val maxObjectsInRam: Int,
private val codec: Codec[A],
private val keyfunc: A => B,
private val tmpDir: DirPath = Io.tmpDir) extends Iterable[A] with Writer[A] {
class Sorter[A : ClassTag, B <: Ordered[B]](val maxObjectsInRam: Int,
private val codec: Codec[A],
private val keyfunc: A => B,
private val tmpDir: DirPath = Io.tmpDir,
private val threads: Int = 10) extends Iterable[A] with Writer[A] with LazyLogging {
require(maxObjectsInRam > 1, "Max records in RAM must be at least 2, and probably much higher!")
private val stash = new Array[SortEntry[B]](maxObjectsInRam)
private val files = new ArrayBuffer[Path]()
private var recordsInMemory: Int = 0
private val _tmpStreamFactory = new TempStreamFactory
private val l1StashSize: Int = (maxObjectsInRam / threads.toDouble).toInt / 10
private val l2StashSize: Int = l1StashSize * 10
private var l1Stash: Array[A] = newL1Stash
private var l2Stash: Array[SortEntry[B]] = newL2Stash
private var l1StashCount: Int = 0 // Number of records in current L1 stash
private var l2StashCount: Int = 0 // Number of records in current L2 stash
private val stashLock: Object = new Object
private val files = new ArrayBuffer[Path]()
private val _tmpStreamFactory = new TempStreamFactory
private val executor = {
if (threads <= 1) {
Executors.newSingleThreadExecutor()
}

Executors.newFixedThreadPool(threads)
}

@inline private def newL1Stash: Array[A] = new Array[A](l1StashSize)
@inline private def newL2Stash: Array[SortEntry[B]] = new Array[SortEntry[B]](l2StashSize)

private class StashDrainer(val l1: Array[A], val n: Int, forceFlush: Boolean = false) extends Runnable {
override def run(): Unit = {
// Encode all the L1 stash items
val encodeStart = System.currentTimeMillis()
val items = new Array[SortEntry[B]](n)
forloop (from=0, until=n) { i =>
val item = l1(i)
val key = keyfunc(item)
val bytes = codec.encode(item)
items(i) = SortEntry(key, bytes)
}
val encodeEnd = System.currentTimeMillis()

// Now lock the L2 stash and write into it, returning the taken l2 stash if it also needs flushing
val copyStart = System.currentTimeMillis()
val l2ToDrain = Sorter.this.stashLock.synchronized {
Array.copy(items, 0, l2Stash, l2StashCount, items.length)
l2StashCount += items.length

if (forceFlush || l2StashCount == l2StashSize) {
val l2 = l2Stash
val l2Count = l2StashCount
l2Stash = newL2Stash
l2StashCount = 0

// Make the file we're going to write to and add it to the list of files
val path = Io.makeTempFile("sorter.", ".tmp", dir=Some(tmpDir))
files += path
path.toFile.deleteOnExit()

Some(l2, l2Count, path)
}
else {
None
}
}
val copyEnd = System.currentTimeMillis()

logger.debug(s"Encoding took ${encodeEnd-encodeStart}ms, copying took ${copyEnd-copyStart}ms on thread ${Thread.currentThread().getName}.")

l2ToDrain.foreach { case (l2, n, path) =>
val drainStart = System.currentTimeMillis()
util.Arrays.parallelSort(l2, 0, n)
val out = new DataOutputStream(_tmpStreamFactory.wrapTempOutputStream(Io.toOutputStream(path), Io.bufferSize))
forloop(from = 0, until = n) { i =>
val bytes = l2(i).bytes
out.writeInt(bytes.length)
out.write(bytes)
}
out.close()
val drainEnd = System.currentTimeMillis()
logger.debug(s"Sorting and draining to file took ${drainEnd-drainStart}ms.")
}
}
}

/**
* An iterator that consumes data from a single tmp file of sorted data and produces
Expand Down Expand Up @@ -141,6 +216,7 @@ class Sorter[A,B <: Ordered[B]](val maxObjectsInRam: Int,
def close(): Unit = if (!closed) { this.stream.safelyClose(); closed = true }
}

//////////////////////////////////////////////////////////////////////////////////////////////////

/**
* An iterator that merges records from [[SortedIterator]]s and maintains ordering.
Expand Down Expand Up @@ -173,38 +249,30 @@ class Sorter[A,B <: Ordered[B]](val maxObjectsInRam: Int,
}
}

//////////////////////////////////////////////////////////////////////////////////////////////////

/**
* Adds a record to the sorter. This is an amortized constant time operation, but the individual times
* will vary wildly as the accrued objects are written to disk once the max in memory is reached.
*/
override def write(item: A): Unit = {
val key = keyfunc(item)
val bytes = this.codec.encode(item)
stash(recordsInMemory) = SortEntry(key, bytes)
recordsInMemory += 1

if (recordsInMemory == maxObjectsInRam) spill()
this.l1Stash(l1StashCount) = item
l1StashCount += 1
if (l1StashCount == l1StashSize) drain(force=false)
}

/**
* Writes out a temporary file containing all the accrued objects and releases the objects so
* that they can be garbage collected.
*/
private def spill(): Unit = {
if (recordsInMemory > 0) {
util.Arrays.parallelSort(stash, 0, recordsInMemory)
val path = Io.makeTempFile("sorter.", ".tmp", dir=Some(this.tmpDir))
val out = new DataOutputStream(_tmpStreamFactory.wrapTempOutputStream(Io.toOutputStream(path), Io.bufferSize))
forloop(from = 0, until = recordsInMemory) { i =>
val bytes = stash(i).bytes
out.writeInt(bytes.length)
out.write(bytes)
stash(i) = null
}
out.close()
path.toFile.deleteOnExit()
this.files += path
this.recordsInMemory = 0
private def drain(force: Boolean): Unit = {
if (force || l1StashCount > 0 || l2StashCount > 0) {
val l1 = this.l1Stash
val l1Count = this.l1StashCount
this.l1Stash = newL1Stash
this.l1StashCount = 0

this.executor.submit(new StashDrainer(l1, l1Count, force))
}
}

Expand All @@ -215,7 +283,12 @@ class Sorter[A,B <: Ordered[B]](val maxObjectsInRam: Int,
* not be invoked too frequently!
*/
def iterator: SelfClosingIterator[A] = {
spill()
drain(force=true)
if (!this.executor.isShutdown()) {
this.executor.shutdown()
this.executor.awaitTermination(600, TimeUnit.SECONDS)
}

val streams = files.iterator.map(f => _tmpStreamFactory.wrapTempInputStream(Io.toInputStream(f), Io.bufferSize)).toSeq
val mergingIterator = new MergingIterator(streams)
new SelfClosingIterator(mergingIterator, () => mergingIterator.close())
Expand Down