From ffb3905845dd98143017267870b8dc3eb4c6a585 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Sun, 1 Sep 2024 22:17:16 -0400 Subject: [PATCH] make terminate! really terminate --- src/SSA_stepper.jl | 38 +++++++++++++++++++++----------------- test/extinction_test.jl | 5 +++-- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index 9e56d5bf..948b6883 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -122,27 +122,31 @@ function DiffEqBase.solve!(integrator::SSAIntegrator) while should_continue_solve(integrator) # It stops before adding a tstop over step!(integrator) end - integrator.t = end_time - # check callbacks one last time - if !(integrator.opts.callback.discrete_callbacks isa Tuple{}) - DiffEqBase.apply_discrete_callback!(integrator, - integrator.opts.callback.discrete_callbacks...) - end + # if the user terminated the solve we shouldn't advance in time any more + if integrator.sol.retcode !== ReturnCode.Terminated + integrator.t = end_time - if integrator.saveat !== nothing && !isempty(integrator.saveat) - # Split to help prediction - while integrator.cur_saveat <= length(integrator.saveat) && - integrator.saveat[integrator.cur_saveat] < integrator.t - push!(integrator.sol.t, integrator.saveat[integrator.cur_saveat]) - push!(integrator.sol.u, copy(integrator.u)) - integrator.cur_saveat += 1 + # check callbacks one last time + if !(integrator.opts.callback.discrete_callbacks isa Tuple{}) + DiffEqBase.apply_discrete_callback!(integrator, + integrator.opts.callback.discrete_callbacks...) end - end - if integrator.save_end && integrator.sol.t[end] != end_time - push!(integrator.sol.t, end_time) - push!(integrator.sol.u, copy(integrator.u)) + if integrator.saveat !== nothing && !isempty(integrator.saveat) + # Split to help prediction + while integrator.cur_saveat <= length(integrator.saveat) && + integrator.saveat[integrator.cur_saveat] < integrator.t + push!(integrator.sol.t, integrator.saveat[integrator.cur_saveat]) + push!(integrator.sol.u, copy(integrator.u)) + integrator.cur_saveat += 1 + end + end + + if integrator.save_end && integrator.sol.t[end] != end_time + push!(integrator.sol.t, end_time) + push!(integrator.sol.u, copy(integrator.u)) + end end DiffEqBase.finalize!(integrator.opts.callback, integrator.u, integrator.t, integrator) diff --git a/test/extinction_test.jl b/test/extinction_test.jl index 0fef42a5..880254ba 100644 --- a/test/extinction_test.jl +++ b/test/extinction_test.jl @@ -73,7 +73,8 @@ end cb = DiscreteCallback(extinction_condition2, extinction_affect!2, save_positions = (false, false)) dprob = DiscreteProblem(u0, (0.0, 1000.0), rates) -jprob = JumpProblem(dprob, Direct(), majump; save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), callback = cb, save_end = false) +jprob = JumpProblem(dprob, majump; save_positions = (false, false), rng) +sol = solve(jprob; callback = cb, save_end = false) @test sol[1, end] == 1 @test sol.retcode == ReturnCode.Terminated +@test sol.t[end] < 1000.0