Skip to content

Commit

Permalink
Merge pull request #240 from YosefLab/ecdna-simulator
Browse files Browse the repository at this point in the history
ecDNA simulator
  • Loading branch information
mattjones315 authored May 8, 2024
2 parents 41dbff8 + fca94b2 commit 78442c4
Show file tree
Hide file tree
Showing 15 changed files with 1,766 additions and 202,252 deletions.
4 changes: 4 additions & 0 deletions cassiopeia/mixins/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class DistanceSolverError(Exception):

pass

class ecDNABirthDeathSimulatorError(Exception):
    """Exception type raised for errors in ecDNABirthDeathSimulator."""

class FitchCountError(Exception):
"""An ExceptionClass for FitchCount."""
Expand Down
3 changes: 3 additions & 0 deletions cassiopeia/plotting/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,9 @@ def place_tree_and_annotations(
colorstrips.extend(heatmap)

# Any other annotations
if type(meta_data) == str:
meta_data = [meta_data]

for meta_item in meta_data:
if meta_item not in tree.cell_meta.columns:
raise PlottingError(
Expand Down
213 changes: 171 additions & 42 deletions cassiopeia/simulator/BirthDeathFitnessSimulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
process, including differing fitness on lineages on the tree. Allows for a
variety of division and fitness regimes to be specified by the user.
"""

from typing import Callable, Dict, Generator, List, Optional, Union

import networkx as nx
import numpy as np
from queue import PriorityQueue

from cassiopeia.data.CassiopeiaTree import CassiopeiaTree
from cassiopeia.mixins import TreeSimulatorError
from cassiopeia.mixins import CassiopeiaTreeError, TreeSimulatorError
from cassiopeia.simulator.TreeSimulator import TreeSimulator


Expand Down Expand Up @@ -94,6 +95,11 @@ class BirthDeathFitnessSimulator(TreeSimulator):
collapse_unifurcations: Specifies whether to collapse unifurcations in
the tree resulting from pruning dead lineages
random_seed: A seed for reproducibility
initial_tree: A tree used for initializing the simulation. When this
argument is passed, a simulation will pick up from the leaves of the
specified tree. This can be useful for simulating trees when
selection may change over time (for example, in the presence
or absence of a drug pressure).
Raises:
TreeSimulatorError if invalid stopping conditions are provided or if a
Expand All @@ -114,6 +120,7 @@ def __init__(
experiment_time: Optional[float] = None,
collapse_unifurcations: bool = True,
random_seed: int = None,
initial_tree: Optional[CassiopeiaTree] = None,
):
if num_extant is None and experiment_time is None:
raise TreeSimulatorError(
Expand Down Expand Up @@ -147,6 +154,103 @@ def __init__(
self.collapse_unifurcations = collapse_unifurcations
self.random_seed = random_seed

# useful for resuming a simulation, perhaps under different pressures.
self.initial_tree = initial_tree

def initialize_tree(self, names) -> nx.DiGraph:
    """Builds the graph that the simulation starts from.

    Without an initial tree, the graph holds a single root node named with
    the next value drawn from ``names``; its "birth_scale" is
    ``self.initial_birth_scale`` and its "time" is 0.

    With an initial tree, the graph is the topology of that tree. Every
    node gets its "birth_scale" and "time" from the initial tree's
    attributes, and ``names`` is not consumed.

    Args:
        names: A generator of node names. Only used to name the root when
            no initial tree was provided.

    Returns:
        The nx.DiGraph that the simulation will extend.
    """
    if not self.initial_tree:
        graph = nx.DiGraph()
        root_name = next(names)
        graph.add_node(root_name)
        root_data = graph.nodes[root_name]
        root_data["birth_scale"] = self.initial_birth_scale
        root_data["time"] = 0
        return graph

    seed_tree = self.initial_tree
    graph = seed_tree.get_tree_topology()
    for node in seed_tree.nodes:
        try:
            birth_scale = seed_tree.get_attribute(node, "birth_scale")
        except CassiopeiaTreeError:
            # The initial tree has no birth scale recorded for this node, so
            # fall back to the simulator's starting value.
            birth_scale = self.initial_birth_scale
        graph.nodes[node]["birth_scale"] = birth_scale
        graph.nodes[node]["time"] = seed_tree.get_attribute(node, "time")

    return graph

def make_initial_lineage_dict(self, tree: nx.DiGraph):
    """Creates the starting lineage(s) from the leaves of a tree.

    Each leaf (node with out-degree 0) becomes an active lineage whose
    birth scale and total time are read from the node's "birth_scale" and
    "time" attributes.

    Args:
        tree: The graph whose leaves seed the simulation.

    Returns:
        When the tree has exactly one node, the lineage dict of its leaf is
        returned directly. Otherwise, a PriorityQueue holding one
        ``(time, node, lineage dict)`` entry per leaf.
    """
    lineage_queue = PriorityQueue()
    terminal_nodes = [node for node in tree if tree.out_degree(node) == 0]

    for terminal in terminal_nodes:
        attributes = tree.nodes[terminal]
        birth_scale = attributes["birth_scale"]
        start_time = attributes["time"]
        lineage = self.make_lineage_dict(
            terminal, birth_scale, start_time, True
        )

        # A lone node is handed back as a bare lineage rather than queued.
        if len(tree.nodes) == 1:
            return lineage

        lineage_queue.put((start_time, terminal, lineage))

    return lineage_queue

def make_lineage_dict(
    self,
    id_value,
    birth_scale,
    total_time,
    active_flag,
):
    """Bundles the state of one lineage into a dict with fixed keys.

    Args:
        id_value: Identifier of the lineage, stored under "id".
        birth_scale: Birth scale of the lineage, stored under
            "birth_scale".
        total_time: Age of the lineage, stored under "total_time".
        active_flag: Whether the lineage is active, stored under "active".

    Returns:
        A dict with the keys "id", "birth_scale", "total_time" and
        "active" mapped to the given values.
    """
    return dict(
        id=id_value,
        birth_scale=birth_scale,
        total_time=total_time,
        active=active_flag,
    )

def simulate_tree(
self,
) -> CassiopeiaTree:
Expand All @@ -169,39 +273,41 @@ def simulate_tree(
TreeSimulatorError if all lineages die before a stopping condition
"""

def node_name_generator() -> Generator[str, None, None]:
def node_name_generator(start=0) -> Generator[str, None, None]:
"""Generates unique node names for the tree."""
i = 0
i = start
while True:
yield str(i)
i += 1

names = node_name_generator()
starting_index = 0
if self.initial_tree:
starting_index = (
np.max([int(l) for l in self.initial_tree.leaves]) + 1
)
names = node_name_generator(starting_index)

# Set the seed
if self.random_seed:
np.random.seed(self.random_seed)

# Instantiate the implicit root
tree = nx.DiGraph()
root = next(names)
tree.add_node(root)
tree.nodes[root]["birth_scale"] = self.initial_birth_scale
tree.nodes[root]["time"] = 0
current_lineages = PriorityQueue()
tree = self.initialize_tree(names)

current_lineages = PriorityQueue() # instantiate queue
# Records the nodes that are observed at the end of the experiment

# TO DO: update to accept arbitrary fields in the dict.
observed_nodes = []
starting_lineage = {
"id": root,
"birth_scale": self.initial_birth_scale,
"total_time": 0,
"active": True,
}

# Sample the waiting time until the first division
self.sample_lineage_event(
starting_lineage, current_lineages, tree, names, observed_nodes
)
starting_lineage = self.make_initial_lineage_dict(tree)

if len(tree.nodes) == 1:
# Sample the waiting time until the first division
self.sample_lineage_event(
starting_lineage, current_lineages, tree, names, observed_nodes
)
else:
current_lineages = starting_lineage

# Perform the process until there are no active extant lineages left
while not current_lineages.empty():
Expand All @@ -219,6 +325,7 @@ def node_name_generator() -> Generator[str, None, None]:
min_total_time = remaining_lineages[0]["total_time"]
for lineage in remaining_lineages:
parent = list(tree.predecessors(lineage["id"]))[0]

tree.nodes[lineage["id"]]["time"] += (
min_total_time - lineage["total_time"]
)
Expand All @@ -229,31 +336,18 @@ def node_name_generator() -> Generator[str, None, None]:
break
# Pop the minimum age lineage to simulate forward time
_, _, lineage = current_lineages.get()

# If the lineage is no longer active, just remove it from the queue.
# This represents the time at which the lineage dies.
if lineage["active"]:
for _ in range(2):
for i in range(2):
self.sample_lineage_event(
lineage, current_lineages, tree, names, observed_nodes
)

cassiopeia_tree = CassiopeiaTree(tree=tree)
time_dictionary = {}
for i in tree.nodes:
time_dictionary[i] = tree.nodes[i]["time"]
cassiopeia_tree.set_times(time_dictionary)

# Prune dead lineages and collapse resulting unifurcations
to_remove = list(set(cassiopeia_tree.leaves) - set(observed_nodes))
cassiopeia_tree.remove_leaves_and_prune_lineages(to_remove)
if self.collapse_unifurcations and len(cassiopeia_tree.nodes) > 1:
cassiopeia_tree.collapse_unifurcations(source="1")

# If only implicit root remains after pruning dead lineages, error
if len(cassiopeia_tree.nodes) == 1:
raise TreeSimulatorError(
"All lineages died before stopping condition"
)
cassiopeia_tree = self.populate_tree_from_simulation(
tree, observed_nodes
)

return cassiopeia_tree

Expand All @@ -266,7 +360,6 @@ def sample_lineage_event(
observed_nodes: List[str],
) -> None:
"""A helper function that samples an event for a lineage.
Takes a lineage and determines the next event in that lineage's
future. Simulates the lifespan of a new descendant. Birth and
death waiting times are sampled, representing how long the
Expand All @@ -285,7 +378,6 @@ def sample_lineage_event(
the lifespan is cut off at the experiment time and a final observed
sample is added to the tree. In this case the lineage is marked as
inactive as well.
Args:
unique_id: The unique ID number to be used to name a new node
added to the tree
Expand All @@ -300,7 +392,6 @@ def sample_lineage_event(
names: A generator providing unique names for tree nodes
observed_nodes: A list of nodes that are observed at the end of
the experiment
Raises:
TreeSimulatorError if a negative waiting time is sampled or a
non-active lineage is passed in
Expand Down Expand Up @@ -435,3 +526,41 @@ def update_fitness(self, birth_scale: float) -> float:
self.fitness_base ** self.fitness_distribution()
)
return birth_scale * base_selection_coefficient

def populate_tree_from_simulation(
    self, tree: nx.DiGraph, observed_nodes: List[str]
) -> CassiopeiaTree:
    """Converts a simulated graph into a pruned CassiopeiaTree.

    Wraps the graph in a CassiopeiaTree, copies each node's "birth_scale"
    attribute and "time" onto it, removes every leaf that was not observed,
    and optionally collapses the unifurcations left behind.

    Args:
        tree: The simulated graph; every node must carry "time" and
            "birth_scale" attributes.
        observed_nodes: The leaves observed at the end of the simulation.

    Returns:
        A CassiopeiaTree with times and birth scales filled in.

    Raises:
        TreeSimulatorError if only a single node is left after pruning.
    """
    result = CassiopeiaTree(tree=tree)

    for node in tree.nodes:
        result.set_attribute(
            node, "birth_scale", tree.nodes[node]["birth_scale"]
        )
    result.set_times({node: tree.nodes[node]["time"] for node in tree.nodes})

    # Drop unobserved (dead) lineages, then collapse the unifurcations that
    # pruning leaves behind.
    unobserved_leaves = list(set(result.leaves) - set(observed_nodes))
    result.remove_leaves_and_prune_lineages(unobserved_leaves)
    if self.collapse_unifurcations and len(result.nodes) > 1:
        result.collapse_unifurcations(source="1")

    # Only the implicit root survived pruning: nothing was observed.
    if len(result.nodes) == 1:
        raise TreeSimulatorError("All lineages died before stopping condition")

    return result
1 change: 1 addition & 0 deletions cassiopeia/simulator/UniformLeafSubsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,5 @@ def subsample_leaves(
[(node, tree.get_time(node)) for node in subsampled_tree.nodes]
)
)

return subsampled_tree
2 changes: 2 additions & 0 deletions cassiopeia/simulator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .ClonalSpatialDataSimulator import ClonalSpatialDataSimulator
from .CompleteBinarySimulator import CompleteBinarySimulator
from .DataSimulator import DataSimulator
from .ecDNABirthDeathSimulator import ecDNABirthDeathSimulator
from .LeafSubsampler import LeafSubsampler
from .LineageTracingDataSimulator import LineageTracingDataSimulator
from .SimpleFitSubcloneSimulator import SimpleFitSubcloneSimulator
Expand All @@ -25,6 +26,7 @@
"SeqeuntialLineageTracingDataSimulator",
"CompleteBinarySimulator",
"DataSimulator",
"ecDNABirthDeathSimulator",
"LeafSubsampler",
"LineageTracingDataSimulator",
"SimpleFitSubcloneSimulator",
Expand Down
Loading

0 comments on commit 78442c4

Please sign in to comment.