SpliceAI Batching Support #94

Open · wants to merge 2 commits into master
26 changes: 26 additions & 0 deletions README.md
@@ -47,6 +47,21 @@ Required parameters:
Optional parameters:
- ```-D```: Maximum distance between the variant and gained/lost splice site (default: 50).
- ```-M```: Mask scores representing annotated acceptor/donor gain and unannotated acceptor/donor loss (default: 0).
- ```-B```: Number of predictions to collect before running the models on them as a batch (default: 1, i.e. no batching).
- ```-T```: Internal TensorFlow `predict()` batch size, if you want it to differ from the `-B` value (default: the `-B` value).
- ```-V```: Enables verbose logging during the run.

**Batching Considerations:** When setting the batching parameters, be mindful of the system and GPU memory of the machine you
are running the script on. Feel free to experiment, but reasonable starting values for `-B` are 64 or 128. An example invocation is shown below.
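
For example, batching can be enabled on a GPU machine like this (a sketch; the input/output/reference file names are placeholders):

```
spliceai -I input.vcf -O output.vcf -R genome.fa -A grch37 -B 64 -T 64 -V
```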

Batching Performance Benchmarks:

| Hardware | Throughput (predictions/hour) |
| -------- | ----------- |
| n1-standard-2 CPU (GCP) | ~800 |
| CPU (2019 MacBook Pro) | ~3,000 |
| K80 GPU (GCP) | ~25,000 |
| V100 GPU (GCP) | ~150,000 |

Details of SpliceAI INFO field:

@@ -107,5 +122,16 @@ acceptor_prob = y[0, :, 1]
donor_prob = y[0, :, 2]
```

### Modifications to Original

**Batching Support** - Invitae (_December 2021_)

* Adds new command-line parameters, `--prediction-batch-size` and `--tensorflow-batch-size`, to support batching variants and make better use of a GPU during prediction
* Adds a `VCFPredictionBatch` class that manages collecting VCF records and grouping their encoded tensors into batches by tensor size. Once the batch size is reached, predictions are run in batches and the output is written back in the original order, reassembling the annotations for each VCF record. Each VCF record carries a lookup key recording where each of its ref/alt encodings sits within its batch, so the results can be located during reassembly
* Breaks the code in the existing `get_delta_scores` method out into reusable methods shared by the batching code and the original code path, so the batching code can use the same logic while the original version is preserved
* Adds batch utility methods that split up the work previously done entirely in `get_delta_scores`: `encode_batch_record` handles the first half, taking a VCF record and generating one-hot encoded matrices for the ref/alt sequences; `extract_delta_scores` handles the second half, reassembling the annotations from the batched predictions
* Adds test cases that run a small file against a generated FASTA reference to verify that the results are identical with no batching and with different batch sizes
* Slightly modifies the entrypoint so the code is easier to unit test, by allowing arguments that would normally come from the argument parser to be passed in directly (see the sketch below)
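
Based on the refactored `run_spliceai` entrypoint in `spliceai/__main__.py` below, a test can invoke the tool directly instead of going through argparse. A minimal sketch, with hypothetical fixture paths:

```python
# Minimal sketch: calling the refactored entrypoint directly from a test,
# bypassing the argument parser. All file paths here are hypothetical fixtures.
from spliceai.__main__ import run_spliceai

run_spliceai(
    input_data='tests/data/input.vcf',
    output_data='tests/data/output.vcf',
    reference='tests/data/test_reference.fa',  # small generated FASTA
    annotation='grch37',
    distance=50,
    mask=0,
    prediction_batch_size=64,
    tensorflow_batch_size=64,
)
```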

### Contact
Kishore Jaganathan: [email protected]
5 changes: 4 additions & 1 deletion setup.py
@@ -1,3 +1,6 @@
# Original source code modified to add prediction batching support by Invitae in 2021.
# Modifications copyright (c) 2021 Invitae Corporation.

from setuptools import setup
import io

@@ -11,7 +14,7 @@
author_email='[email protected]',
license='GPLv3',
url='https://github.com/illumina/SpliceAI',
packages=['spliceai'],
packages=['spliceai', 'spliceai.batch'],
install_requires=['keras>=2.0.5',
'pyfaidx>=0.5.0',
'pysam>=0.10.0',
9 changes: 8 additions & 1 deletion spliceai/__init__.py
@@ -1,8 +1,15 @@
# Original source code modified to add prediction batching support by Invitae in 2021.
# Modifications copyright (c) 2021 Invitae Corporation.

import signal
from pkg_resources import get_distribution


signal.signal(signal.SIGINT, lambda x, y: exit(0))
try:
signal.signal(signal.SIGINT, lambda x, y: exit(0))
except ValueError:
# signal.signal raises ValueError when called outside the main thread; skip the handler in that case
pass

name = 'spliceai'
__version__ = get_distribution(name).version
72 changes: 61 additions & 11 deletions spliceai/__main__.py
@@ -1,9 +1,13 @@
# Original source code modified to add prediction batching support by Invitae in 2021.
# Modifications copyright (c) 2021 Invitae Corporation.

import sys
import argparse
import logging
import pysam
from spliceai.utils import Annotator, get_delta_scores

from spliceai.batch.batch import VCFPredictionBatch
from spliceai.utils import Annotator, get_delta_scores

try:
from sys.stdin import buffer as std_in
@@ -34,22 +38,46 @@ def get_options():
type=int, choices=[0, 1],
help='mask scores representing annotated acceptor/donor gain and '
'unannotated acceptor/donor loss, defaults to 0')
parser.add_argument('-B', '--prediction-batch-size', metavar='prediction_batch_size', default=1, type=int,
help='number of predictions to process at a time, note a single vcf record '
'may have multiple predictions for overlapping genes and multiple alts')
parser.add_argument('-T', '--tensorflow-batch-size', metavar='tensorflow_batch_size', type=int,
help='tensorflow batch size for model predictions')
parser.add_argument('-V', '--verbose', action='store_true', help='enables verbose logging')
args = parser.parse_args()

return args


def main():

args = get_options()

if args.verbose:
logging.basicConfig(
format='%(asctime)s %(levelname)s %(name)s: - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.DEBUG,
)

if None in [args.I, args.O, args.D, args.M]:
logging.error('Usage: spliceai [-h] [-I [input]] [-O [output]] -R reference -A annotation '
'[-D [distance]] [-M [mask]]')
'[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T [tensorflow_batch_size]]')
exit()

# Default the tensorflow batch size to the prediction_batch_size if it's not supplied in the args
tensorflow_batch_size = args.tensorflow_batch_size if args.tensorflow_batch_size else args.prediction_batch_size

run_spliceai(input_data=args.I, output_data=args.O, reference=args.R,
annotation=args.A, distance=args.D, mask=args.M,
prediction_batch_size=args.prediction_batch_size,
tensorflow_batch_size=tensorflow_batch_size)


def run_spliceai(input_data, output_data, reference, annotation, distance, mask, prediction_batch_size,
tensorflow_batch_size):

try:
vcf = pysam.VariantFile(args.I)
vcf = pysam.VariantFile(input_data)
except (IOError, ValueError) as e:
logging.error('{}'.format(e))
exit()
@@ -61,21 +89,43 @@ def main():
'Format: ALLELE|SYMBOL|DS_AG|DS_AL|DS_DG|DS_DL|DP_AG|DP_AL|DP_DG|DP_DL">')

try:
output = pysam.VariantFile(args.O, mode='w', header=header)
output_data = pysam.VariantFile(output_data, mode='w', header=header)
except (IOError, ValueError) as e:
logging.error('{}'.format(e))
exit()

ann = Annotator(args.R, args.A)
ann = Annotator(reference, annotation)
batch = None

# Only use the batching code if we are batching
if prediction_batch_size > 1:
batch = VCFPredictionBatch(
ann=ann,
output=output_data,
dist=distance,
mask=mask,
prediction_batch_size=prediction_batch_size,
tensorflow_batch_size=tensorflow_batch_size,
)

for record in vcf:
scores = get_delta_scores(record, ann, args.D, args.M)
if len(scores) > 0:
record.info['SpliceAI'] = scores
output.write(record)
if batch:
# Add record to the batch; once the batch fills, all records are processed at once
batch.add_record(record)
else:
# If we're not batching, let's run the original code
scores = get_delta_scores(record, ann, distance, mask)
if len(scores) > 0:
record.info['SpliceAI'] = scores
output_data.write(record)

if batch:
# Ensure we process any leftover records in the batch once we finish iterating the VCF. This
# would be a good candidate for a context manager if the original non-batching code above were removed
batch.finish()

vcf.close()
output.close()
output_data.close()


if __name__ == '__main__':
Empty file added spliceai/batch/__init__.py
200 changes: 200 additions & 0 deletions spliceai/batch/batch.py
@@ -0,0 +1,200 @@
# Original source code modified to add prediction batching support by Invitae in 2021.
# Modifications copyright (c) 2021 Invitae Corporation.

import collections
import logging
import time

import numpy as np

from spliceai.batch.batch_utils import extract_delta_scores, get_preds, encode_batch_records

logger = logging.getLogger(__name__)

SequenceType_REF = 0
SequenceType_ALT = 1


BatchLookupIndex = collections.namedtuple(
'BatchLookupIndex', 'sequence_type tensor_size batch_index'
)
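# BatchLookupIndex records, for one ref/alt encoding: whether it is a ref or alt
# (sequence_type), which per-tensor-size batch it went into (tensor_size), and
# its position within that batch (batch_index).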

PreparedVCFRecord = collections.namedtuple(
'PreparedVCFRecord', 'vcf_record gene_info locations'
)
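# PreparedVCFRecord ties a VCF record and its cached gene info to the list of
# BatchLookupIndex locations for all of its ref/alt encodings.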


class VCFPredictionBatch:
def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_batch_size):
self.ann = ann
self.output = output
self.dist = dist
self.mask = mask
# This is the maximum number of predictions to parse/encode/predict at a time
self.prediction_batch_size = prediction_batch_size
# This is the size of the batch tensorflow will use to make the predictions
self.tensorflow_batch_size = tensorflow_batch_size

# Batch vars
self.batches = {}
self.prepared_vcf_records = []

# Counts
self.batch_predictions = 0
self.total_predictions = 0
self.total_vcf_records = 0

def _clear_batch(self):
self.batch_predictions = 0
self.batches.clear()
del self.prepared_vcf_records[:]

def _process_batch(self):
start = time.time()
total_batch_predictions = 0
logger.debug('Starting process_batch')

# Sanity check dump of batch sizes
batch_sizes = ["{}:{}".format(tensor_size, len(batch)) for tensor_size, batch in self.batches.items()]
logger.debug('Batch Sizes: {}'.format(batch_sizes))

# Collect each batch's predictions
batch_preds = {}
for tensor_size, batch in self.batches.items():
# Convert list of encodings into a proper sized numpy matrix
prediction_batch = np.concatenate(batch, axis=0)

# Run predictions
batch_preds[tensor_size] = np.mean(
get_preds(self.ann, prediction_batch, self.prediction_batch_size), axis=0
)

# Iterate over original list of vcf records, reconstructing record with annotations
for prepared_record in self.prepared_vcf_records:
record_predictions = self._write_record(prepared_record, batch_preds)
total_batch_predictions += record_predictions

self._clear_batch()
logger.debug('Predictions: {}, VCF Records: {}'.format(self.total_predictions, self.total_vcf_records))
duration = time.time() - start
preds_per_sec = total_batch_predictions / duration
preds_per_hour = preds_per_sec * 60 * 60
logger.debug('Finished in {:0.2f}s, per sec: {:0.2f}, per hour: {:0.2f}'.format(duration,
preds_per_sec,
preds_per_hour))

def _write_record(self, prepared_record, batch_preds):
record = prepared_record.vcf_record
gene_info = prepared_record.gene_info
record_predictions = 0

all_y_ref = []
all_y_alt = []

# Each prediction in the batch is located and put into the correct y
for location in prepared_record.locations:
# No prediction here
if location.tensor_size == 0:
if location.sequence_type == SequenceType_REF:
all_y_ref.append(None)
else:
all_y_alt.append(None)
continue

# Extract the prediction from the batch into a list of predictions for this record
batch = batch_preds[location.tensor_size]
if location.sequence_type == SequenceType_REF:
all_y_ref.append(batch[[location.batch_index], :, :])
else:
all_y_alt.append(batch[[location.batch_index], :, :])

delta_scores = extract_delta_scores(
all_y_ref=all_y_ref,
all_y_alt=all_y_alt,
record=record,
ann=self.ann,
dist_var=self.dist,
mask=self.mask,
gene_info=gene_info,
)

# If there are predictions, write them to the VCF INFO section
if len(delta_scores) > 0:
record.info['SpliceAI'] = delta_scores
record_predictions += len(delta_scores)

self.output.write(record)
return record_predictions

def add_record(self, record):
"""
Adds a record to the batch. It captures the gene information for the record and
saves it for later to avoid looking it up again, then encodes the ref and alt
sequences from the VCF record and places the encoded values into lists grouped
by tensor size. As the encoded values are added, a BatchLookupIndex is created
so that after the predictions are made, we know where to look up the
corresponding prediction for the VCF record.

Once the batch reaches its capacity, all the predictions for the encoded
batches are processed.
"""

self.total_vcf_records += 1
# Collect gene information for this record
gene_info = self.ann.get_name_and_strand(record.chrom, record.pos)

# Keep track of how many predictions we're going to make
prediction_count = len(record.alts) * len(gene_info.genes)
self.batch_predictions += prediction_count
self.total_predictions += prediction_count

# Collect lists of encoded ref/alt sequences
x_ref, x_alt = encode_batch_records(record, self.ann, self.dist, gene_info)

# List of BatchLookupIndex's so we know how to lookup predictions for records from
# the batches
batch_lookup_indexes = []

# Process the encodings into batches
for var_type, encoded_seq in zip((SequenceType_REF, SequenceType_ALT), (x_ref, x_alt)):

if len(encoded_seq) == 0:
# Add BatchLookupIndex with zeros so when the batch collects the outputs
# it knows that there is no prediction for this record
batch_lookup_indexes.append(BatchLookupIndex(var_type, 0, 0))
continue

# Iterate over the encoded sequence and drop into the correct batch by size and
# create an index to use to pull out the result after batch is processed
for row in encoded_seq:
# Extract the size of the sequence that was encoded to build a batch from
tensor_size = row.shape[1]

# Create batch for this size
if tensor_size not in self.batches:
self.batches[tensor_size] = []

# Add encoded record to batch
self.batches[tensor_size].append(row)

# Get the index of the record we just added in the batch
cur_batch_record_ix = len(self.batches[tensor_size]) - 1

# Store a reference so we can pull out the prediction for this item from the batches
batch_lookup_indexes.append(BatchLookupIndex(var_type, tensor_size, cur_batch_record_ix))

# Save the batch locations for this record on the composite object
prepared_record = PreparedVCFRecord(
vcf_record=record, gene_info=gene_info, locations=batch_lookup_indexes
)
self.prepared_vcf_records.append(prepared_record)

# If we've reached our threshold for the max items to process, then process the batch
if self.batch_predictions >= self.prediction_batch_size:
self._process_batch()

def finish(self):
"""
Method to process all the remaining items that have been added to the batch.
"""
if len(self.prepared_vcf_records) > 0:
self._process_batch()
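
Note that `spliceai/batch/batch_utils.py` itself is not shown in this diff view. Based on how `get_preds` is called from `_process_batch`, and on the model-averaging logic in the original `get_delta_scores`, it plausibly looks something like the following sketch (an assumption, not the file's actual contents):

```python
# Hypothetical sketch of get_preds from spliceai/batch/batch_utils.py.
# Assumes ann.models is the list of five trained SpliceAI Keras models,
# as in the original spliceai/utils.py.
def get_preds(ann, x, batch_size=32):
    # Run each of the five models over the encoded batch; the caller
    # (_process_batch) averages the five outputs with np.mean(..., axis=0).
    return [model.predict(x, batch_size=batch_size) for model in ann.models]
```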