Skip to content

Commit

Permalink
Merge pull request #2946 from isaacsas/better_vrj_support
Browse files Browse the repository at this point in the history
support JumpProblems over ODEProblems
  • Loading branch information
ChrisRackauckas authored Aug 11, 2024
2 parents c10d00b + 42ebdf4 commit e4e7d97
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 52 deletions.
16 changes: 8 additions & 8 deletions src/systems/dependency_graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ equation_dependencies(jumpsys)
equation_dependencies(jumpsys, variables = parameters(jumpsys))
```
"""
function equation_dependencies(sys::AbstractSystem; variables = unknowns(sys))
eqs = equations(sys)
function equation_dependencies(sys::AbstractSystem; variables = unknowns(sys),
eqs = equations(sys))
deps = Set()
depeqs_to_vars = Vector{Vector}(undef, length(eqs))

Expand Down Expand Up @@ -114,8 +114,9 @@ digr = asgraph(jumpsys)
```
"""
function asgraph(sys::AbstractSystem; variables = unknowns(sys),
variablestoids = Dict(v => i for (i, v) in enumerate(variables)))
asgraph(equation_dependencies(sys, variables = variables), variablestoids)
variablestoids = Dict(v => i for (i, v) in enumerate(variables)),
eqs = equations(sys))
asgraph(equation_dependencies(sys; variables, eqs), variablestoids)
end

"""
Expand All @@ -141,8 +142,7 @@ variable_dependencies(jumpsys)
```
"""
function variable_dependencies(sys::AbstractSystem; variables = unknowns(sys),
variablestoids = nothing)
eqs = equations(sys)
variablestoids = nothing, eqs = equations(sys))
vtois = isnothing(variablestoids) ? Dict(v => i for (i, v) in enumerate(variables)) :
variablestoids

Expand Down Expand Up @@ -193,8 +193,8 @@ dg = asdigraph(digr, jumpsys)
```
"""
function asdigraph(g::BipartiteGraph, sys::AbstractSystem; variables = unknowns(sys),
equationsfirst = true)
neqs = length(equations(sys))
equationsfirst = true, eqs = equations(sys))
neqs = length(eqs)
nvars = length(variables)
fadjlist = deepcopy(g.fadjlist)
badjlist = deepcopy(g.badjlist)
Expand Down
69 changes: 64 additions & 5 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ function JumpSystem(eqs, iv, unknowns, ps;
metadata, gui_metadata, checks = checks)
end

has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1])
has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2])
has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3])

function generate_rate_function(js::JumpSystem, rate)
consts = collect_constants(rate)
if !isempty(consts) # The SymbolicUtils._build_function method of this case doesn't support postprocess_fbody
Expand Down Expand Up @@ -311,7 +315,7 @@ end
```julia
DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan,
parammap = DiffEqBase.NullParameters;
use_union = false,
use_union = true,
kwargs...)
```
Expand All @@ -331,7 +335,6 @@ dprob = DiscreteProblem(complete(js), u₀map, tspan, parammap)
"""
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
parammap = DiffEqBase.NullParameters();
checkbounds = false,
use_union = true,
eval_expression = false,
eval_module = @__MODULE__,
Expand Down Expand Up @@ -385,7 +388,7 @@ struct DiscreteProblemExpr{iip} end

function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
parammap = DiffEqBase.NullParameters();
use_union = false,
use_union = true,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`")
Expand All @@ -412,6 +415,60 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No
end
end

"""
```julia
DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan,
parammap = DiffEqBase.NullParameters;
use_union = true,
kwargs...)
```
Generates a blank ODEProblem for a pure jump JumpSystem to utilize as its `prob.prob`. This
is used in the case where there are no ODEs and no SDEs associated with the system but there
are jumps with an explicit time dependency (i.e. `VariableRateJump`s). If no jumps have an
explicit time dependence, i.e. all are `ConstantRateJump`s or `MassActionJump`s then
`DiscreteProblem` should be preferred for performance reasons.
Continuing the example from the [`JumpSystem`](@ref) definition:
```julia
using DiffEqBase, JumpProcesses
u₀map = [S => 999, I => 1, R => 0]
parammap = [β => 0.1 / 1000, γ => 0.01]
tspan = (0.0, 250.0)
oprob = ODEProblem(complete(js), u₀map, tspan, parammap)
```
"""
function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
parammap = DiffEqBase.NullParameters();
use_union = true,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...)
if !iscomplete(sys)
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
end
dvs = unknowns(sys)
ps = parameters(sys)

defs = defaults(sys)
defs = mergedefaults(defs, parammap, ps)
defs = mergedefaults(defs, u0map, dvs)

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

f = (du, u, p, t) -> (du .= 0; nothing)
df = ODEFunction(f; sys, observed = observedfun)
ODEProblem(df, u0, tspan, p; kwargs...)
end

"""
```julia
DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
Expand Down Expand Up @@ -449,10 +506,12 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob,
error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)

# dep graphs are only for constant rate jumps
nonvrjs = ArrayPartition(eqs.x[1], eqs.x[2])
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator) ||
(aggregator isa JumpProcesses.NullAggregator)
jdeps = asgraph(js)
vdeps = variable_dependencies(js)
jdeps = asgraph(js; eqs = nonvrjs)
vdeps = variable_dependencies(js; eqs = nonvrjs)
vtoj = jdeps.badjlist
jtov = vdeps.badjlist
jtoj = needs_depgraph(aggregator) ? eqeq_dependencies(jdeps, vdeps).fadjlist :
Expand Down
86 changes: 80 additions & 6 deletions test/dep_graphs.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Test
using ModelingToolkit, Graphs, JumpProcesses
using ModelingToolkit, Graphs, JumpProcesses, RecursiveArrayTools
using ModelingToolkit: t_nounits as t, D_nounits as D
import ModelingToolkit: value

Expand All @@ -16,11 +16,11 @@ j₅ = ConstantRateJump(k1 * I, [R ~ R + 1])
j₆ = VariableRateJump(k1 * k2 / (1 + t) * S, [S ~ S - 1, R ~ R + 1])
eqs = [j₁, j₂, j₃, j₄, j₅, j₆]
@named js = JumpSystem(eqs, t, [S, I, R], [k1, k2])
S = value(S);
I = value(I);
R = value(R);
k1 = value(k1);
k2 = value(k2);
S = value(S)
I = value(I)
R = value(R)
k1 = value(k1)
k2 = value(k2)
# eq to vars they depend on
eq_sdeps = [Variable[], [S], [S, I], [S, R], [I], [S]]
eq_sidepsf = [Int[], [1], [1, 2], [1, 3], [2], [1]]
Expand Down Expand Up @@ -72,6 +72,80 @@ end
dg4 = varvar_dependencies(depsbg, deps2)
@test dg == dg4

# testing when ignoring VariableRateJumps
let
@parameters k1 k2
@variables S(t) I(t) R(t)
j₁ = MassActionJump(k1, [0 => 1], [S => 1])
j₂ = MassActionJump(k1, [S => 1], [S => -1])
j₃ = MassActionJump(k2, [S => 1, I => 1], [S => -1, I => 1])
j₄ = MassActionJump(k2, [S => 2, R => 1], [R => -1])
j₅ = ConstantRateJump(k1 * I, [R ~ R + 1])
j₆ = VariableRateJump(k1 * k2 / (1 + t) * S, [S ~ S - 1, R ~ R + 1])
eqs = [j₁, j₂, j₃, j₄, j₅, j₆]
@named js = JumpSystem(eqs, t, [S, I, R], [k1, k2])
S = value(S)
I = value(I)
R = value(R)
k1 = value(k1)
k2 = value(k2)
# eq to vars they depend on
eq_sdeps = [Variable[], [S], [S, I], [S, R], [I]]
eq_sidepsf = [Int[], [1], [1, 2], [1, 3], [2]]
eq_sidepsb = [[2, 3, 4], [3, 5], [4]]

# filter out vrjs in making graphs
eqs = ArrayPartition(equations(js).x[1], equations(js).x[2])
deps = equation_dependencies(js; eqs)
@test length(deps) == length(eq_sdeps)
@test all(i -> isequal(Set(eq_sdeps[i]), Set(deps[i])), 1:length(eqs))
depsbg = asgraph(js; eqs)
@test depsbg.fadjlist == eq_sidepsf
@test depsbg.badjlist == eq_sidepsb

# eq to params they depend on
eq_pdeps = [[k1], [k1], [k2], [k2], [k1]]
eq_pidepsf = [[1], [1], [2], [2], [1]]
eq_pidepsb = [[1, 2, 5], [3, 4]]
deps = equation_dependencies(js; variables = parameters(js), eqs)
@test length(deps) == length(eq_pdeps)
@test all(i -> isequal(Set(eq_pdeps[i]), Set(deps[i])), 1:length(eqs))
depsbg2 = asgraph(js; variables = parameters(js), eqs)
@test depsbg2.fadjlist == eq_pidepsf
@test depsbg2.badjlist == eq_pidepsb

# var to eqs that modify them
s_eqdepsf = [[1, 2, 3], [3], [4, 5]]
s_eqdepsb = [[1], [1], [1, 2], [3], [3]]
ne = 6
bg = BipartiteGraph(ne, s_eqdepsf, s_eqdepsb)
deps2 = variable_dependencies(js; eqs)
@test isequal(bg, deps2)

# eq to eqs that depend on them
eq_eqdeps = [[2, 3, 4], [2, 3, 4], [2, 3, 4, 5], [4], [4], [2, 3, 4]]
dg = SimpleDiGraph(5)
for (eqidx, eqdeps) in enumerate(eq_eqdeps)
for eqdepidx in eqdeps
add_edge!(dg, eqidx, eqdepidx)
end
end
dg3 = eqeq_dependencies(depsbg, deps2)
@test dg == dg3

# var to vars that depend on them
var_vardeps = [[1, 2, 3], [1, 2, 3], [3]]
ne = 7
dg = SimpleDiGraph(3)
for (vidx, vdeps) in enumerate(var_vardeps)
for vdepidx in vdeps
add_edge!(dg, vidx, vdepidx)
end
end
dg4 = varvar_dependencies(depsbg, deps2)
@test dg == dg4
end

#####################################
# testing for ODE/SDEs
#####################################
Expand Down
Loading

0 comments on commit e4e7d97

Please sign in to comment.