diff --git a/README.md b/README.md index 30fde6e0..d2f9c4bc 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,6 @@ - [![Build Status](https://github.com/SciML/JumpProcesses.jl/workflows/CI/badge.svg)](https://github.com/SciML/JumpProcesses.jl/actions?query=workflow%3ACI) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index f11d5b0a..a382d208 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -53,6 +53,7 @@ include("aggregators/directcr.jl") include("aggregators/rssacr.jl") include("aggregators/rdirect.jl") include("aggregators/coevolve.jl") +include("aggregators/ccnrm.jl") # spatial: include("spatial/spatial_massaction_jump.jl") @@ -85,7 +86,7 @@ export SplitCoupledJumpProblem export Direct, DirectFW, SortingDirect, DirectCR export BracketData, RSSA -export FRM, FRMFW, NRM +export FRM, FRMFW, NRM, CCNRM export RSSACR, RDirect export Coevolve diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index e64501d7..b8dd8dd5 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -138,6 +138,15 @@ evolution, Journal of Machine Learning Research 18(1), 1305–1353 (2017). doi: """ struct Coevolve <: AbstractAggregatorAlgorithm end +""" +A constant-complexity NRM method. Stores next reaction times in a table with a specified bin width. + +Kevin R. Sanft and Hans G. Othmer, Constant-complexity stochastic simulation +algorithm with optimal binning, Journal of Chemical Physics 143, 074108 +(2015). doi: 10.1063/1.4928635. +""" +struct CCNRM <: AbstractAggregatorAlgorithm end + # spatial methods """ @@ -164,7 +173,7 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108 struct DirectCRDirect <: AbstractAggregatorAlgorithm end const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), - FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve()) + FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve(), CCNRM()) # For JumpProblem construction without an aggregator struct NullAggregator <: AbstractAggregatorAlgorithm end @@ -174,6 +183,7 @@ needs_depgraph(aggregator::AbstractAggregatorAlgorithm) = false needs_depgraph(aggregator::DirectCR) = true needs_depgraph(aggregator::SortingDirect) = true needs_depgraph(aggregator::NRM) = true +needs_depgraph(aggregator::CCNRM) = true needs_depgraph(aggregator::RDirect) = true needs_depgraph(aggregator::Coevolve) = true diff --git a/src/aggregators/ccnrm.jl b/src/aggregators/ccnrm.jl new file mode 100644 index 00000000..9f1800f8 --- /dev/null +++ b/src/aggregators/ccnrm.jl @@ -0,0 +1,162 @@ +# Implementation of the constant-complexity Next Reaction Method +# Kevin R. Sanft and Hans G. Othmer, Constant-complexity stochastic simulation +# algorithm with optimal binning, Journal of Chemical Physics 143, 074108 +# (2015). doi: 10.1063/1.4928635. + +mutable struct CCNRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PT} <: + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + next_jump::Int + prev_jump::Int + next_jump_time::T + end_time::T + cur_rates::Vector{T} + sum_rate::T + ma_jumps::S + rates::F1 + affects!::F2 + save_positions::Tuple{Bool, Bool} + rng::RNG + dep_gr::DEPGR + ptt::PT +end + +function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, + rng::RNG; num_specs, dep_graph = nothing, + kwargs...) where {T, S, F1, F2, RNG} + + # a dependency graph is needed and must be provided if there are constant rate jumps + if dep_graph === nothing + if (get_num_majumps(maj) == 0) || !isempty(rs) + error("To use ConstantRateJumps with the constant-complexity Next Reaction Method (CCNRM) algorithm a dependency graph must be supplied.") + else + dg = make_dependency_graph(num_specs, maj) + end + else + dg = dep_graph + + # make sure each jump depends on itself + add_self_dependencies!(dg) + end + + binwidthconst = haskey(kwargs, :binwidthconst) ? kwargs[:binwidthconst] : 16 + numbinsconst = haskey(kwargs, :numbinsconst) ? kwargs[:numbinsconst] : 20 + ptt = PriorityTimeTable(zeros(T, length(crs)), zero(T), one(T), + binwidthconst = binwidthconst, numbinsconst = numbinsconst) # We will re-initialize this in initialize!() + + affecttype = F2 <: Tuple ? F2 : Any + CCNRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(ptt)}( + nj, nj, njt, et, + crs, sr, maj, + rs, affs!, sps, + rng, dg, ptt) +end + ++############################# Required Functions ############################## +# creating the JumpAggregation structure (function wrapper-based constant jumps) +function aggregate(aggregator::CCNRM, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; kwargs...) + + # handle constant jumps using function wrappers + rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) + + build_jump_aggregation(CCNRMJumpAggregation, u, p, t, end_time, ma_jumps, + rates, affects!, save_positions, rng; num_specs = length(u), + kwargs...) +end + +# set up a new simulation and calculate the first jump / jump time +function initialize!(p::CCNRMJumpAggregation, integrator, u, params, t) + p.end_time = integrator.sol.prob.tspan[2] + initialize_rates_and_times!(p, u, params, t) + generate_jumps!(p, integrator, u, params, t) + nothing +end + +# execute one jump, changing the system state +function execute_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t, affects!) + # execute jump + u = update_state!(p, integrator, u, affects!) + + # update current jump rates and times + update_dependent_rates!(p, u, params, t) + nothing +end + +# calculate the next jump / jump time +# just the first reaction in the first non-empty bin in the priority table +function generate_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t) + p.next_jump, p.next_jump_time = getfirst(p.ptt) + + # Rebuild the table if no next jump is found. + if p.next_jump == 0 + timestep = 1 / sum(p.cur_rates) + min_time = minimum(p.ptt.times) + rebuild!(p.ptt, min_time, timestep) + p.next_jump, p.next_jump_time = getfirst(p.ptt) + end + + nothing +end + +######################## SSA specific helper routines ######################## + +# Recalculate jump rates for jumps that depend on the just executed jump (p.next_jump) +function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t) + @inbounds dep_rxs = p.dep_gr[p.next_jump] + @unpack ptt, cur_rates, rates, ma_jumps, end_time = p + num_majumps = get_num_majumps(ma_jumps) + + @inbounds for rx in dep_rxs + oldrate = cur_rates[rx] + times = ptt.times + oldtime = times[rx] + + # update the jump rate + @inbounds cur_rates[rx] = calculate_jump_rate(ma_jumps, num_majumps, rates, u, + params, t, rx) + + # Calculate new jump times for dependent jumps + if rx != p.next_jump && oldrate > zero(oldrate) + if cur_rates[rx] > zero(eltype(cur_rates)) + update!(ptt, rx, oldtime, t + oldrate / cur_rates[rx] * (oldtime - t)) + else + update!(ptt, rx, oldtime, floatmax(typeof(t))) + end + else + if cur_rates[rx] > zero(eltype(cur_rates)) + update!(ptt, rx, oldtime, t + randexp(p.rng) / cur_rates[rx]) + else + update!(ptt, rx, oldtime, floatmax(typeof(t))) + end + end + end + nothing +end + +# Evaluate all the rates and initialize the times in the priority table. +function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t) + # Initialize next-reaction times for the mass action jumps + majumps = p.ma_jumps + cur_rates = p.cur_rates + pttdata = Vector{typeof(t)}(undef, length(cur_rates)) + @inbounds for i in 1:get_num_majumps(majumps) + cur_rates[i] = evalrxrate(u, i, majumps) + pttdata[i] = t + randexp(p.rng) / cur_rates[i] + end + + # Initialize next-reaction times for the constant rates + rates = p.rates + idx = get_num_majumps(majumps) + 1 + @inbounds for rate in rates + cur_rates[idx] = rate(u, params, t) + pttdata[idx] = t + randexp(p.rng) / cur_rates[idx] + idx += 1 + end + + # Build the priority time table with the times and bin width. + timestep = 1 / sum(cur_rates) + p.ptt.times = pttdata + rebuild!(p.ptt, t, timestep) + nothing +end diff --git a/src/aggregators/prioritytable.jl b/src/aggregators/prioritytable.jl index fae23478..7184ee02 100644 --- a/src/aggregators/prioritytable.jl +++ b/src/aggregators/prioritytable.jl @@ -11,7 +11,7 @@ The ranges are assumed to be powers of two: bin 1 = {0}, bin 2 = (0,`minpriority`), bin 3 = [`minpriority`,`2*minpriority`)... - bin N = [`.5*maxpriority`,`maxpriority`) +in N = [`.5*maxpriority`,`maxpriority`) *Assumes* the `priortogid` function that maps priorities to groups maps the upper end of the interval to the next group. i.e. maxpriority -> N+1 """ @@ -57,17 +57,23 @@ end lastpid end +@inline function ids(pg::PriorityGroup) + pg.pids[1:(pg.numpids)] +end + function Base.show(io::IO, pg::PriorityGroup) println(io, " ", summary(pg)) println(io, " maxpriority = ", pg.maxpriority) println(io, " numpids = ", pg.numpids) - println(io, " pids = ", pg.pids[1:(pg.numpids)]) + println(io, " pids = ", ids(pg)) end """ Table to store the groups. """ -mutable struct PriorityTable{F, S, T, U <: Function} +abstract type AbstractPriorityTable end + +mutable struct PriorityTable{F, S, T, U <: Function} <: AbstractPriorityTable "non-zero values below this are binned together, static" minpriority::F @@ -250,9 +256,9 @@ function Base.show(io::IO, pt::PriorityTable) end end -####################################################### -# routines for DirectCR -####################################################### +############################# +### routines for DirectCR ### +############################# # map priority (i.e. jump rate) to integer # add two as 0. -> 1 and priority < minpriority ==> pid -> 2 @@ -312,3 +318,165 @@ function sample(pt::PriorityTable, priorities, rng = DEFAULT_RNG) # sample element within the group @inbounds sample(groups[gid], priorities, rng) end + +########################## +### Routines for CCNRM ### +########################## + +struct TimeGrouper{T <: Number} + mintime::T + binwidth::T +end + +@inline function (t::TimeGrouper{T})(time::T) where {T} + return floor(Int, (time - t.mintime) / t.binwidth) + 1 +end + +mutable struct PriorityTimeTable{T, F <: Int} + groups::Vector{PriorityGroup{T, Vector{F}}} + pidtogroup::Vector{Tuple{F, F}} + times::Vector{T} + timegrouper::TimeGrouper{T} + minbin::F + steps::F # TODO: For adaptive rebuilding. + maxtime::T + binwidthconst::F + numbinsconst::F +end + +# Construct the time table with the default optimal bin width and number of bins. +# DEFAULT NUMBINS: 20 * √length(times) +# DEFAULT BINWIDTH: 16 / sum(propensities) +function PriorityTimeTable( + times::AbstractVector, mintime, timestep; binwidthconst = 16, numbinsconst = 20) + binwidth = binwidthconst * timestep + numbins = floor(Int64, numbinsconst * sqrt(length(times))) + maxtime = mintime + numbins * binwidth + + pidtype = typeof(numbins) + ptype = eltype(times) + groups = Vector{PriorityGroup{ptype, Vector{pidtype}}}() + pidtogroup = Vector{Tuple{Int, Int}}(undef, length(times)) + + ttgdata = TimeGrouper{ptype}(mintime, binwidth) + # Create the groups, [t_min, t_min + τ), [t_min + τ, t_min + 2τ)... + for i in 1:numbins + push!(groups, PriorityGroup{pidtype}(mintime + i * binwidth)) + end + + ptt = PriorityTimeTable( + groups, pidtogroup, times, ttgdata, zero(pidtype), + zero(pidtype), maxtime, binwidthconst, numbinsconst) + # Insert priority ids into the groups + for (pid, time) in enumerate(times) + if time > maxtime + pidtogroup[pid] = (0, 0) + continue + end + insert!(ptt, pid, time) + end + + ptt.minbin = findfirst(g -> g.numpids > (0), groups) + ptt.minbin === nothing && (ptt.minbin = 0) + ptt +end + +# Rebuild the table when there are no more reaction times within the current +# time window. +function rebuild!(ptt::PriorityTimeTable{T, F}, mintime, timestep) where {T, F} + @unpack pidtogroup, groups, times, binwidthconst = ptt + fill!(pidtogroup, (zero(F), zero(F))) + + numbins = length(groups) + binwidth = binwidthconst * timestep + ptt.maxtime = mintime + numbins * binwidth + ptt.timegrouper = TimeGrouper(mintime, binwidth) + + groupmaxtime = mintime + for group in groups + group.numpids = zero(F) + groupmaxtime += binwidth + group.maxpriority = groupmaxtime + end + + # Reinsert the times into the groups. + for (id, time) in enumerate(times) + time > ptt.maxtime && continue + insert!(ptt, id, time) + end + ptt.minbin = findfirst(g -> g.numpids > (0), groups) + ptt.minbin === nothing && (ptt.minbin = 0) + ptt.steps = 0 + + return nothing +end + +# Get the reaction with the earliest timestep. +function getfirst(ptt::PriorityTimeTable) + @unpack groups, times, minbin = ptt + minbin == 0 && return (0, 0) + + while groups[minbin].numpids == 0 + minbin += 1 + if minbin > length(groups) + return (0, 0) + end + end + + ptt.minbin = minbin + ptt.steps += 1 + min_time = typemax(eltype(times)) + min_idx = 0 + @inbounds for i in 1:(groups[minbin].numpids) + pid = groups[minbin].pids[i] + times[pid] < min_time && begin + min_time = times[pid] + min_idx = pid + end + end + + return min_idx, min_time +end + +function insert!(ptt::PriorityTimeTable, pid, time) + @unpack timegrouper, pidtogroup, groups = ptt + gid = timegrouper(time) + @inbounds pididx = insert!(groups[gid], pid) + @inbounds pidtogroup[pid] = (gid, pididx) + + return nothing +end + +# Update the priority table when a reaction time gets updated. We only shift +# between bins if the new time is within the current time window; otherwise +# we remove the reaction and wait until rebuild. +function update!(ptt::PriorityTimeTable{T, F}, pid, oldtime, newtime) where {T, F} + @unpack times, timegrouper, maxtime, pidtogroup, groups = ptt + + times[pid] = newtime + if oldtime >= maxtime + # If a reaction comes back into the time window, insert it. + newtime < maxtime ? insert!(ptt, pid, newtime) : return nothing + elseif newtime >= maxtime + # If the new time lands outside of current window, remove it. + @inbounds begin + gid, pidx = pidtogroup[pid] + movedpid = remove!(groups[gid], pidx) + pidtogroup[movedpid] = (gid, pidx) + pidtogroup[pid] = (zero(F), zero(F)) + end + else + # Move bins if the reaction was already inside. + oldgid = timegrouper(oldtime) + newgid = timegrouper(newtime) + oldgid == newgid && return nothing + @inbounds begin + pidx = pidtogroup[pid][2] + movedpid = remove!(groups[oldgid], pidx) + pidtogroup[movedpid] = (oldgid, pidx) + newpidx = insert!(groups[newgid], pid) + pidtogroup[pid] = (newgid, newpidx) + end + end + return nothing +end diff --git a/test/table_test.jl b/test/table_test.jl index 02313fd3..d7c3e6cd 100644 --- a/test/table_test.jl +++ b/test/table_test.jl @@ -61,3 +61,37 @@ for i in 1:Nsamps (pid == 8) && (cnt += 1) end @test abs(cnt // Nsamps - 0.008968535978248484) / 0.008968535978248484 < 0.05 + +##### PRIORITY TIME TABLE TESTS FOR CCNRM +mintime = 0.0; +maxtime = 100.0; +timestep = 1.5/16; +times = [2.0, 8.0, 13.0, 15.0, 74.0] + +ptt = DJ.PriorityTimeTable(times, mintime, timestep) +@test DJ.getfirst(ptt) == (1, 2.0) +@test ptt.pidtogroup[5] == (0, 0) # Should not store the last one, outside the window. + +# Test update +DJ.update!(ptt, 1, times[1], 10 * times[1]) # 2. -> 20., group 2 to group 14 +@test ptt.groups[14].numpids == 1 +@test DJ.getfirst(ptt) == (2, 8.0) +# Updating beyond the time window should not change the max priority. +DJ.update!(ptt, 1, times[1], 70.0) # 20. -> 70. +@test ptt.groups[14].numpids == 0 +@test ptt.maxtime == 66.0 +@test ptt.pidtogroup[1] == (0, 0) + +# Test rebuild +for i in 1:4 + DJ.update!(ptt, i, times[i], times[i] + 66.0) +end +@test DJ.getfirst(ptt) === (0, 0) # No more left. + +mintime = 66.0; +timestep = 0.75/16; +DJ.rebuild!(ptt, mintime, timestep) +@test ptt.groups[11].numpids == 2 # 73.5-74.25 +@test ptt.groups[18].numpids == 1 +@test ptt.groups[21].numpids == 1 +@test ptt.pidtogroup[1] == (0, 0)