Skip to content

Commit

Permalink
Add lastchange and tol fields to ConvergenceException (#284)
Browse files Browse the repository at this point in the history
Added new fields for `ConvergenceException` to show how far away from convergence the model was.
  • Loading branch information
alexmorley authored and nalimilan committed Aug 22, 2017
1 parent de36108 commit 4d03581
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
27 changes: 23 additions & 4 deletions src/statmodels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,31 @@ function show(io::IO, ct::CoefTable)
end

"""
ConvergenceException(iters::Int)
ConvergenceException(iters::Int, lastchange::Real=NaN, tol::Real=NaN)
The fitting procedure failed to converge in `iters` number of iterations.
The fitting procedure failed to converge in `iters` number of iterations,
i.e. the `lastchange` between the cost of the final and penultimate iteration was greater than
specified tolerance `tol`.
"""
struct ConvergenceException <: Exception
struct ConvergenceException{T<:Real} <: Exception
iters::Int
lastchange::T
tol::T
function ConvergenceException{T}(iters, lastchange::T, tol::T) where T<:Real
if tol > lastchange
throw(ArgumentError("Change must be greater than tol."))
else
new(iters, lastchange, tol)
end
end
end

Base.showerror(io::IO, ce::ConvergenceException) = print(io, "failure to converge after $(ce.iters) iterations")
ConvergenceException(iters, lastchange::T=NaN, tol::T=NaN) where {T<:Real} =
ConvergenceException{T}(iters, lastchange, tol)

function Base.showerror(io::IO, ce::ConvergenceException)
print(io, "failure to converge after $(ce.iters) iterations.")
if !isnan(ce.lastchange)
print(io, " Last change ($(ce.lastchange)) was greater than tolerance ($(ce.tol)).")
end
end
12 changes: 7 additions & 5 deletions test/statmodels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ x3 0.453058 0.72525 0.999172 0.5567
@test_throws ErrorException StatsBase.PValue(-0.1)
@test_throws ErrorException StatsBase.PValue(1.1)

try
throw(ConvergenceException(10))
catch ex
@test sprint(showerror, ex) == "failure to converge after 10 iterations"
end
@test sprint(showerror, ConvergenceException(10)) == "failure to converge after 10 iterations."

@test sprint(showerror, ConvergenceException(10, 0.2, 0.1)) ==
"failure to converge after 10 iterations. Last change (0.2) was greater than tolerance (0.1)."

err = @test_throws ArgumentError ConvergenceException(10,.1,.2)
@test err.value.msg == "Change must be greater than tol."

0 comments on commit 4d03581

Please sign in to comment.