diff --git a/Project.toml b/Project.toml index 9b54b5a676..5245347838 100644 --- a/Project.toml +++ b/Project.toml @@ -89,7 +89,7 @@ FunctionWrappersWrappers = "0.1" Graphs = "1.5.2" InteractiveUtils = "1" JuliaFormatter = "1.0.47" -JumpProcesses = "9.1" +JumpProcesses = "9.13.1" LabelledArrays = "1.3" Latexify = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16" Libdl = "1" diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index da75b7dfd6..0ad1a009eb 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -426,7 +426,8 @@ jprob = JumpProblem(complete(js), dprob, Direct()) sol = solve(jprob, SSAStepper()) ``` """ -function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; callback = nothing, +function JumpProcesses.JumpProblem(js::JumpSystem, prob, + aggregator = JumpProcesses.NullAggregator(); callback = nothing, eval_expression = false, eval_module = @__MODULE__, kwargs...) if !iscomplete(js) error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `JumpProblem`") @@ -448,7 +449,8 @@ function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; callback = error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps") jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs) - if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator) + if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator) || + (aggregator isa JumpProcesses.NullAggregator) jdeps = asgraph(js) vdeps = variable_dependencies(js) vtoj = jdeps.badjlist diff --git a/test/jumpsystem.jl b/test/jumpsystem.jl index 11c9fc1cd9..d14fa8b545 100644 --- a/test/jumpsystem.jl +++ b/test/jumpsystem.jl @@ -69,16 +69,21 @@ parammap = [β => 0.1 / 1000, γ => 0.01] dprob = DiscreteProblem(js2, u₀map, tspan, parammap) jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng) Nsims = 30000 -function getmean(jprob, Nsims) +function getmean(jprob, Nsims; use_stepper = true) m = 0.0 for i in 1:Nsims - sol = solve(jprob, SSAStepper()) + sol = use_stepper ? solve(jprob, SSAStepper()) : solve(jprob) m += sol[end, end] end m / Nsims end m = getmean(jprob, Nsims) +# test auto-alg selection works +jprobb = JumpProblem(js2, dprob; save_positions = (false, false), rng) +mb = getmean(jprobb, Nsims; use_stepper = false) +@test abs(m - mb) / m < 0.01 + @variables S2(t) obs = [S2 ~ 2 * S] @named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ], observed = obs) @@ -89,7 +94,6 @@ sol = solve(jprob, SSAStepper(), saveat = tspan[2] / 10) @test all(2 .* sol[S] .== sol[S2]) # test save_positions is working - jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng) sol = solve(jprob, SSAStepper(), saveat = 1.0) @test all((sol.t) .== collect(0.0:tspan[2])) @@ -270,3 +274,22 @@ affect = [X ~ X - 1] j1 = ConstantRateJump(k, [X ~ X - 1]) @test_nowarn @mtkbuild js1 = JumpSystem([j1], t, [X], [k]) + +# test correct autosolver is selected, which implies appropriate dep graphs are available +let + @parameters k + @variables X(t) + rate = k + affect = [X ~ X - 1] + j1 = ConstantRateJump(k, [X ~ X - 1]) + + Nv = [1, JumpProcesses.USE_DIRECT_THRESHOLD + 1, JumpProcesses.USE_RSSA_THRESHOLD + 1] + algtypes = [Direct, RSSA, RSSACR] + for (N, algtype) in zip(Nv, algtypes) + @named jsys = JumpSystem([deepcopy(j1) for _ in 1:N], t, [X], [k]) + jsys = complete(jsys) + dprob = DiscreteProblem(jsys, [X => 10], (0.0, 10.0), [k => 1]) + jprob = JumpProblem(jsys, dprob) + @test jprob.aggregator isa algtype + end +end