diff --git a/src/main/scala/com/fulcrumgenomics/bam/Bams.scala b/src/main/scala/com/fulcrumgenomics/bam/Bams.scala
index 6e1650dc3..72e3fb0f8 100644
--- a/src/main/scala/com/fulcrumgenomics/bam/Bams.scala
+++ b/src/main/scala/com/fulcrumgenomics/bam/Bams.scala
@@ -161,8 +161,9 @@ object Bams extends LazyLogging {
   def sorter(order: SamOrder,
             header: SAMFileHeader,
             maxRecordsInRam: Int = MaxInMemory,
-            tmpDir: DirPath = Io.tmpDir): Sorter[SamRecord,order.A] = {
-    new Sorter(maxRecordsInRam, new SamRecordCodec(header), order.sortkey, tmpDir=tmpDir)
+            tmpDir: DirPath = Io.tmpDir,
+            threads: Int = 4): Sorter[SamRecord,order.A] = {
+    new Sorter(maxRecordsInRam, new SamRecordCodec(header), order.sortkey, tmpDir=tmpDir, threads=threads)
   }
 
   /** A wrapper to order objects of type [[TagType]] using the ordering given. Used when sorting by tag where we wish
diff --git a/src/main/scala/com/fulcrumgenomics/bam/SortBam.scala b/src/main/scala/com/fulcrumgenomics/bam/SortBam.scala
index 4e9a59360..9a936104f 100644
--- a/src/main/scala/com/fulcrumgenomics/bam/SortBam.scala
+++ b/src/main/scala/com/fulcrumgenomics/bam/SortBam.scala
@@ -29,7 +29,7 @@ import com.fulcrumgenomics.bam.api.{SamOrder, SamSource, SamWriter}
 import com.fulcrumgenomics.cmdline.{ClpGroups, FgBioTool}
 import com.fulcrumgenomics.commons.util.LazyLogging
 import com.fulcrumgenomics.sopt.{arg, clp}
-import com.fulcrumgenomics.util.Io
+import com.fulcrumgenomics.util.{Io, ProgressLogger, Sorter}
 
 @clp(group=ClpGroups.SamOrBam, description =
   """
@@ -57,15 +57,25 @@ class SortBam
 ( @arg(flag='i', doc="Input SAM or BAM.") val input: PathToBam,
   @arg(flag='o', doc="Output SAM or BAM.") val output: PathToBam,
   @arg(flag='s', doc="Order into which to sort the records.") val sortOrder: SamOrder = SamOrder.Coordinate,
-  @arg(flag='m', doc="Max records in RAM.") val maxRecordsInRam: Int = SamWriter.DefaultMaxRecordsInRam
+  @arg(flag='m', doc="Max records in RAM.") val maxRecordsInRam: Int = SamWriter.DefaultMaxRecordsInRam,
+  @arg(flag='t', doc="Number of threads to use.") val threads: Int = 4
 ) extends FgBioTool with LazyLogging {
 
   override def execute(): Unit = {
     Io.assertReadable(input)
     Io.assertCanWriteFile(output)
-    val in  = SamSource(input)
-    val out = SamWriter(output, in.header.clone(), sort=Some(sortOrder), maxRecordsInRam=maxRecordsInRam)
-    out ++= in
+    val in     = SamSource(input)
+    val header = in.header.clone()
+    sortOrder.applyTo(header)
+    val sorter   = Bams.sorter(sortOrder, header, maxRecordsInRam=maxRecordsInRam, threads=threads)
+    val progress = ProgressLogger(logger, verb="sorted", unit=2e6.toInt)
+    in.foreach { r =>
+      sorter += r
+      progress.record(r)
+    }
+
+    val out = SamWriter(output, header=header)
+    sorter.foreach { r => out += r }
     out.close()
   }
 }
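
An illustrative sketch (not part of the patch) of how a caller other than SortBam would pick up the new `threads` parameter on `Bams.sorter`. Here `header` is assumed to be a cloned SAMFileHeader already updated via `SamOrder.applyTo`, and `records`/`output` are placeholder names:

    // Hypothetical caller; `records: Iterator[SamRecord]` and `output: PathToBam` are placeholders.
    val sorter = Bams.sorter(SamOrder.Coordinate, header, maxRecordsInRam = 2e6.toInt, threads = 8)
    records.foreach(sorter += _)     // encoding and spilling of sorted runs happens on background threads
    val out = SamWriter(output, header = header)
    sorter.foreach(out += _)         // merge-reads the spilled runs back in coordinate order
    out.close()
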
diff --git a/src/main/scala/com/fulcrumgenomics/util/Sorter.scala b/src/main/scala/com/fulcrumgenomics/util/Sorter.scala
index 494115163..f3bf1ef47 100644
--- a/src/main/scala/com/fulcrumgenomics/util/Sorter.scala
+++ b/src/main/scala/com/fulcrumgenomics/util/Sorter.scala
@@ -26,15 +26,17 @@ package com.fulcrumgenomics.util
 
 import java.io._
 import java.nio.file.{Files, Path}
 import java.util
-
 import com.fulcrumgenomics.FgBioDef._
 import com.fulcrumgenomics.commons.io.Writer
 import com.fulcrumgenomics.commons.collection.SelfClosingIterator
+import com.fulcrumgenomics.commons.util.LazyLogging
 import com.fulcrumgenomics.util.Sorter.{Codec, SortEntry}
 import htsjdk.samtools.util.TempStreamFactory
+import java.util.concurrent.{Executors, TimeUnit}
 import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
 
 /**
  * Companion object for [[Sorter]] that contains various types used.
  */
@@ -75,15 +77,87 @@
   *
   * Both must be thread-safe as they may be invoked across threads without external synchronization
   */
-class Sorter[A,B <: Ordered[B]](val maxObjectsInRam: Int,
-                                private val codec: Codec[A],
-                                private val keyfunc: A => B,
-                                private val tmpDir: DirPath = Io.tmpDir) extends Iterable[A] with Writer[A] {
+class Sorter[A : ClassTag, B <: Ordered[B]](val maxObjectsInRam: Int,
+                                            private val codec: Codec[A],
+                                            private val keyfunc: A => B,
+                                            private val tmpDir: DirPath = Io.tmpDir,
+                                            private val threads: Int = 10) extends Iterable[A] with Writer[A] with LazyLogging {
   require(maxObjectsInRam > 1, "Max records in RAM must be at least 2, and probably much higher!")
-  private val stash = new Array[SortEntry[B]](maxObjectsInRam)
-  private val files = new ArrayBuffer[Path]()
-  private var recordsInMemory: Int = 0
-  private val _tmpStreamFactory = new TempStreamFactory
+  private val l1StashSize: Int = (maxObjectsInRam / threads.toDouble).toInt / 10
+  private val l2StashSize: Int = l1StashSize * 10
+  private var l1Stash: Array[A] = newL1Stash
+  private var l2Stash: Array[SortEntry[B]] = newL2Stash
+  private var l1StashCount: Int = 0  // Number of records in current L1 stash
+  private var l2StashCount: Int = 0  // Number of records in current L2 stash
+  private val stashLock: Object = new Object
+  private val files = new ArrayBuffer[Path]()
+  private val _tmpStreamFactory = new TempStreamFactory
+  private val executor = {
+    if (threads <= 1) {
+      Executors.newSingleThreadExecutor()
+    } else {
+      Executors.newFixedThreadPool(threads)
+    }
+  }
+
+  @inline private def newL1Stash: Array[A] = new Array[A](l1StashSize)
+  @inline private def newL2Stash: Array[SortEntry[B]] = new Array[SortEntry[B]](l2StashSize)
+
+  private class StashDrainer(val l1: Array[A], val n: Int, forceFlush: Boolean = false) extends Runnable {
+    override def run(): Unit = {
+      // Encode all the L1 stash items
+      val encodeStart = System.currentTimeMillis()
+      val items = new Array[SortEntry[B]](n)
+      forloop (from=0, until=n) { i =>
+        val item  = l1(i)
+        val key   = keyfunc(item)
+        val bytes = codec.encode(item)
+        items(i) = SortEntry(key, bytes)
+      }
+      val encodeEnd = System.currentTimeMillis()
+
+      // Now lock the L2 stash and write into it, returning the taken L2 stash if it also needs flushing
+      val copyStart = System.currentTimeMillis()
+      val l2ToDrain = Sorter.this.stashLock.synchronized {
+        Array.copy(items, 0, l2Stash, l2StashCount, items.length)
+        l2StashCount += items.length
+
+        if (forceFlush || l2StashCount == l2StashSize) {
+          val l2      = l2Stash
+          val l2Count = l2StashCount
+          l2Stash      = newL2Stash
+          l2StashCount = 0
+
+          // Make the file we're going to write to and add it to the list of files
+          val path = Io.makeTempFile("sorter.", ".tmp", dir=Some(tmpDir))
+          files += path
+          path.toFile.deleteOnExit()
+
+          Some((l2, l2Count, path))
+        }
+        else {
+          None
+        }
+      }
+      val copyEnd = System.currentTimeMillis()
+
+      logger.debug(s"Encoding took ${encodeEnd-encodeStart}ms, copying took ${copyEnd-copyStart}ms on thread ${Thread.currentThread().getName}.")
+
+      l2ToDrain.foreach { case (l2, n, path) =>
+        val drainStart = System.currentTimeMillis()
+        util.Arrays.parallelSort(l2, 0, n)
+        val out = new DataOutputStream(_tmpStreamFactory.wrapTempOutputStream(Io.toOutputStream(path), Io.bufferSize))
+        forloop(from = 0, until = n) { i =>
+          val bytes = l2(i).bytes
+          out.writeInt(bytes.length)
+          out.write(bytes)
+        }
+        out.close()
+        val drainEnd = System.currentTimeMillis()
+        logger.debug(s"Sorting and draining to file took ${drainEnd-drainStart}ms.")
+      }
+    }
+  }
 
   /**
    * An iterator that consumes data from a single tmp file of sorted data and produces
@@ -141,6 +215,7 @@ class Sorter[A,B <: Ordered[B]](val maxObjectsInRam: Int,
     def close(): Unit = if (!closed) { this.stream.safelyClose(); closed = true }
   }
 
+  //////////////////////////////////////////////////////////////////////////////////////////////////
   /**
    * An iterator that merges records from [[SortedIterator]]s and maintains ordering.
    *
@@ -173,38 +248,30 @@ class Sorter[A,B <: Ordered[B]](val maxObjectsInRam: Int,
     }
   }
 
+  //////////////////////////////////////////////////////////////////////////////////////////////////
+
   /**
    * Adds a record to the sorter. This is an amortized constant time operation, but the individual times
    * will vary wildly as the accrued objects are written to disk once the max in memory is reached.
    */
   override def write(item: A): Unit = {
-    val key   = keyfunc(item)
-    val bytes = this.codec.encode(item)
-    stash(recordsInMemory) = SortEntry(key, bytes)
-    recordsInMemory += 1
-
-    if (recordsInMemory == maxObjectsInRam) spill()
+    this.l1Stash(l1StashCount) = item
+    l1StashCount += 1
+    if (l1StashCount == l1StashSize) drain(force=false)
   }
 
   /**
    * Writes out a temporary file containing all the accrued objects and releases the objects so
    * that they can be garbage collected.
    */
-  private def spill(): Unit = {
-    if (recordsInMemory > 0) {
-      util.Arrays.parallelSort(stash, 0, recordsInMemory)
-      val path = Io.makeTempFile("sorter.", ".tmp", dir=Some(this.tmpDir))
-      val out = new DataOutputStream(_tmpStreamFactory.wrapTempOutputStream(Io.toOutputStream(path), Io.bufferSize))
-      forloop(from = 0, until = recordsInMemory) { i =>
-        val bytes = stash(i).bytes
-        out.writeInt(bytes.length)
-        out.write(bytes)
-        stash(i) = null
-      }
-      out.close()
-      path.toFile.deleteOnExit()
-      this.files += path
-      this.recordsInMemory = 0
+  private def drain(force: Boolean): Unit = {
+    if (force || l1StashCount > 0 || l2StashCount > 0) {
+      val l1      = this.l1Stash
+      val l1Count = this.l1StashCount
+      this.l1Stash = newL1Stash
+      this.l1StashCount = 0
+
+      this.executor.submit(new StashDrainer(l1, l1Count, force))
     }
   }
 
@@ -215,7 +282,12 @@ class Sorter[A,B <: Ordered[B]](val maxObjectsInRam: Int,
    * not be invoked too frequently!
    */
   def iterator: SelfClosingIterator[A] = {
-    spill()
+    drain(force=true)
+    if (!this.executor.isShutdown()) {
+      this.executor.shutdown()
+      this.executor.awaitTermination(600, TimeUnit.SECONDS)
+    }
+
     val streams = files.iterator.map(f => _tmpStreamFactory.wrapTempInputStream(Io.toInputStream(f), Io.bufferSize)).toSeq
     val mergingIterator = new MergingIterator(streams)
     new SelfClosingIterator(mergingIterator, () => mergingIterator.close())
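
Because the threaded `Sorter` is generic, the new machinery can be exercised outside of BAM sorting. A minimal sketch follows, assuming `Sorter.Codec`'s existing `encode`/`decode` signatures; `IntCodec` and `IntKey` are hypothetical helpers written only for this example:

    import java.nio.ByteBuffer
    import com.fulcrumgenomics.util.Sorter

    // Hypothetical fixed-width codec for Ints, for illustration only.
    object IntCodec extends Sorter.Codec[Int] {
      override def encode(a: Int): Array[Byte] = ByteBuffer.allocate(4).putInt(a).array()
      override def decode(bs: Array[Byte], start: Int, length: Int): Int = ByteBuffer.wrap(bs, start, length).getInt
    }

    // Hypothetical key wrapper, since Sorter requires B <: Ordered[B].
    case class IntKey(value: Int) extends Ordered[IntKey] {
      override def compare(that: IntKey): Int = this.value.compare(that.value)
    }

    val sorter = new Sorter[Int, IntKey](maxObjectsInRam = 1e6.toInt, codec = IntCodec, keyfunc = IntKey.apply, threads = 8)
    scala.util.Random.shuffle((1 to 100000).toVector).foreach(sorter += _)
    val sorted = sorter.iterator.toVector  // ascending 1..100000; iterator drains, shuts down the executor, then merges

With `threads = 1` the pool degenerates to a single background thread, so behavior stays close to the old synchronous spill-on-write path.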