diff --git a/src/statmodels.jl b/src/statmodels.jl index 30eaf0a74..12702c175 100644 --- a/src/statmodels.jl +++ b/src/statmodels.jl @@ -336,12 +336,31 @@ function show(io::IO, ct::CoefTable) end """ - ConvergenceException(iters::Int) + ConvergenceException(iters::Int, lastchange::Real=NaN, tol::Real=NaN) -The fitting procedure failed to converge in `iters` number of iterations. +The fitting procedure failed to converge in `iters` number of iterations, +i.e. the `lastchange` between the cost of the final and penultimate iteration was greater than +specified tolerance `tol`. """ -struct ConvergenceException <: Exception +struct ConvergenceException{T<:Real} <: Exception iters::Int + lastchange::T + tol::T + function ConvergenceException{T}(iters, lastchange::T, tol::T) where T<:Real + if tol > lastchange + throw(ArgumentError("Change must be greater than tol.")) + else + new(iters, lastchange, tol) + end + end end -Base.showerror(io::IO, ce::ConvergenceException) = print(io, "failure to converge after $(ce.iters) iterations") +ConvergenceException(iters, lastchange::T=NaN, tol::T=NaN) where {T<:Real} = + ConvergenceException{T}(iters, lastchange, tol) + +function Base.showerror(io::IO, ce::ConvergenceException) + print(io, "failure to converge after $(ce.iters) iterations.") + if !isnan(ce.lastchange) + print(io, " Last change ($(ce.lastchange)) was greater than tolerance ($(ce.tol)).") + end +end diff --git a/test/statmodels.jl b/test/statmodels.jl index 7cc7b5c52..794ea07b3 100644 --- a/test/statmodels.jl +++ b/test/statmodels.jl @@ -31,8 +31,10 @@ x3 0.453058 0.72525 0.999172 0.5567 @test_throws ErrorException StatsBase.PValue(-0.1) @test_throws ErrorException StatsBase.PValue(1.1) -try - throw(ConvergenceException(10)) -catch ex - @test sprint(showerror, ex) == "failure to converge after 10 iterations" -end +@test sprint(showerror, ConvergenceException(10)) == "failure to converge after 10 iterations." + +@test sprint(showerror, ConvergenceException(10, 0.2, 0.1)) == + "failure to converge after 10 iterations. Last change (0.2) was greater than tolerance (0.1)." + +err = @test_throws ArgumentError ConvergenceException(10,.1,.2) +@test err.value.msg == "Change must be greater than tol."