Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ready for Review] - Constant-complexity NRM implementation #455

Merged
merged 34 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
<!-- [![Coverage Status](https://coveralls.io/repos/github/SciML/JumpProcesses.jl/badge.svg?branch=master)](https://coveralls.io/github/SciML/JumpProcesses.jl?branch=master)
[![codecov](https://codecov.io/gh/SciML/JumpProcesses.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/SciML/JumpProcesses.jl) -->
<!-- [![Join the chat at https://julialang.zulipchat.com #sciml-bridged](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/279055-sciml-bridged) -->

[![Build Status](https://github.com/SciML/JumpProcesses.jl/workflows/CI/badge.svg)](https://github.com/SciML/JumpProcesses.jl/actions?query=workflow%3ACI)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)
Expand Down
3 changes: 2 additions & 1 deletion src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ include("aggregators/directcr.jl")
include("aggregators/rssacr.jl")
include("aggregators/rdirect.jl")
include("aggregators/coevolve.jl")
include("aggregators/ccnrm.jl")

# spatial:
include("spatial/spatial_massaction_jump.jl")
Expand Down Expand Up @@ -85,7 +86,7 @@ export SplitCoupledJumpProblem

export Direct, DirectFW, SortingDirect, DirectCR
export BracketData, RSSA
export FRM, FRMFW, NRM
export FRM, FRMFW, NRM, CCNRM
export RSSACR, RDirect
export Coevolve

Expand Down
12 changes: 11 additions & 1 deletion src/aggregators/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ evolution, Journal of Machine Learning Research 18(1), 1305–1353 (2017). doi:
"""
struct Coevolve <: AbstractAggregatorAlgorithm end

"""
A constant-complexity NRM method. Stores next reaction times in a table with a specified bin width.

Kevin R. Sanft and Hans G. Othmer, Constant-complexity stochastic simulation
algorithm with optimal binning, Journal of Chemical Physics 143, 074108
(2015). doi: 10.1063/1.4928635.
"""
struct CCNRM <: AbstractAggregatorAlgorithm end

# spatial methods

"""
Expand All @@ -164,7 +173,7 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108
struct DirectCRDirect <: AbstractAggregatorAlgorithm end

const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(),
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve())
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve(), CCNRM())

# For JumpProblem construction without an aggregator
struct NullAggregator <: AbstractAggregatorAlgorithm end
Expand All @@ -174,6 +183,7 @@ needs_depgraph(aggregator::AbstractAggregatorAlgorithm) = false
needs_depgraph(aggregator::DirectCR) = true
needs_depgraph(aggregator::SortingDirect) = true
needs_depgraph(aggregator::NRM) = true
needs_depgraph(aggregator::CCNRM) = true
needs_depgraph(aggregator::RDirect) = true
needs_depgraph(aggregator::Coevolve) = true

Expand Down
159 changes: 159 additions & 0 deletions src/aggregators/ccnrm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Implementation of the constant-complexity Next Reaction Method
# Kevin R. Sanft and Hans G. Othmer, Constant-complexity stochastic simulation
# algorithm with optimal binning, Journal of Chemical Physics 143, 074108
# (2015). doi: 10.1063/1.4928635.

mutable struct CCNRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PT} <:
AbstractSSAJumpAggregator{T, S, F1, F2, RNG}
next_jump::Int
prev_jump::Int
next_jump_time::T
end_time::T
cur_rates::Vector{T}
sum_rate::T
ma_jumps::S
rates::F1
affects!::F2
save_positions::Tuple{Bool, Bool}
rng::RNG
dep_gr::DEPGR
ptt::PT
end

function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T,
maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool},
rng::RNG; num_specs, dep_graph = nothing,
kwargs...) where {T, S, F1, F2, RNG}

# a dependency graph is needed and must be provided if there are constant rate jumps
if dep_graph === nothing
if (get_num_majumps(maj) == 0) || !isempty(rs)
error("To use ConstantRateJumps with the constant-complexity Next Reaction Method (CCNRM) algorithm a dependency graph must be supplied.")
else
dg = make_dependency_graph(num_specs, maj)
end
else
dg = dep_graph

# make sure each jump depends on itself
add_self_dependencies!(dg)
end

ptt = PriorityTimeTable(zeros(T, length(crs)), 0.0, 1.0) # We will re-initialize this in initialize!()

affecttype = F2 <: Tuple ? F2 : Any
CCNRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(ptt)}(
nj, nj, njt, et,
crs, sr, maj,
rs, affs!, sps,
rng, dg, ptt)
end

+############################# Required Functions ##############################
# creating the JumpAggregation structure (function wrapper-based constant jumps)
function aggregate(aggregator::CCNRM, u, p, t, end_time, constant_jumps,
ma_jumps, save_positions, rng; kwargs...)

# handle constant jumps using function wrappers
rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps)

build_jump_aggregation(CCNRMJumpAggregation, u, p, t, end_time, ma_jumps,
rates, affects!, save_positions, rng; num_specs = length(u),
kwargs...)
end

# set up a new simulation and calculate the first jump / jump time
function initialize!(p::CCNRMJumpAggregation, integrator, u, params, t)
p.end_time = integrator.sol.prob.tspan[2]
initialize_rates_and_times!(p, u, params, t)
generate_jumps!(p, integrator, u, params, t)
nothing
end

# execute one jump, changing the system state
function execute_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t, affects!)
# execute jump
u = update_state!(p, integrator, u, affects!)

# update current jump rates and times
update_dependent_rates!(p, u, params, t)
nothing
end

# calculate the next jump / jump time
# just the first reaction in the first non-empty bin in the priority table
function generate_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t)
p.next_jump, p.next_jump_time = getfirst(p.ptt)

# Rebuild the table if no next jump is found.
if p.next_jump === nothing
vyudu marked this conversation as resolved.
Show resolved Hide resolved
binwidth = 16 / sum(p.cur_rates)
min_time = minimum(p.ptt.times)
rebuild!(p.ptt, min_time, binwidth)
p.next_jump, p.next_jump_time = getfirst(p.ptt)
end

nothing
end

######################## SSA specific helper routines ########################

# Recalculate jump rates for jumps that depend on the just executed jump (p.next_jump)
function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t)
@inbounds dep_rxs = p.dep_gr[p.next_jump]
@unpack ptt, cur_rates, rates, ma_jumps, end_time = p
num_majumps = get_num_majumps(ma_jumps)

@inbounds for rx in dep_rxs
oldrate = cur_rates[rx]
times = ptt.times
oldtime = times[rx]

# update the jump rate
@inbounds cur_rates[rx] = calculate_jump_rate(ma_jumps, num_majumps, rates, u,
params, t, rx)

# Calculate new jump times for dependent jumps
if rx != p.next_jump && oldrate > zero(oldrate)
if cur_rates[rx] > zero(eltype(cur_rates))
update!(ptt, rx, oldtime, t + oldrate / cur_rates[rx] * (times[rx] - t))
vyudu marked this conversation as resolved.
Show resolved Hide resolved
else
update!(ptt, rx, oldtime, 2 * end_time)
vyudu marked this conversation as resolved.
Show resolved Hide resolved
end
else
if cur_rates[rx] > zero(eltype(cur_rates))
update!(ptt, rx, oldtime, t + randexp(p.rng) / cur_rates[rx])
else
update!(ptt, rx, oldtime, 2 * end_time)
vyudu marked this conversation as resolved.
Show resolved Hide resolved
end
end
end
nothing
end

# Evaluate all the rates and initialize the times in the priority table.
function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t)
# Initialize next-reaction times for the mass action jumps
majumps = p.ma_jumps
cur_rates = p.cur_rates
pttdata = Vector{typeof(t)}(undef, length(cur_rates))
@inbounds for i in 1:get_num_majumps(majumps)
cur_rates[i] = evalrxrate(u, i, majumps)
pttdata[i] = t + randexp(p.rng) / cur_rates[i]
end

# Initialize next-reaction times for the constant rates
rates = p.rates
idx = get_num_majumps(majumps) + 1
@inbounds for rate in rates
cur_rates[idx] = rate(u, params, t)
pttdata[idx] = t + randexp(p.rng) / cur_rates[idx]
idx += 1
end

# Build the priority time table with the times and bin width.
binwidth = 16 / sum(cur_rates)
vyudu marked this conversation as resolved.
Show resolved Hide resolved
p.ptt.times = pttdata
rebuild!(p.ptt, t, binwidth)
nothing
end
Loading
Loading