Skip to content

Commit 6a1063d

Browse files
ararslangrerogalenlynchalexmorley
committed
Add ConvergenceException from StatsBase
I've added the folks who contributed the type and improvements to it and its documentation as coauthors of the commit. (Apologies if I missed anyone.) Co-authored-by: Roger Herikstad <[email protected]> Co-authored-by: Galen Lynch <[email protected]> Co-authored-by: Alexander Morley <[email protected]>
1 parent 14e0ba7 commit 6a1063d

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

src/statisticalmodel.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,3 +314,41 @@ function adjr2(model::StatisticalModel, variant::Symbol)
314314
end
315315

316316
const adjr² = adjr2
317+
318+
"""
319+
ConvergenceException(iterations::Int, lastchange::Real=NaN, tolerance::Real=NaN,
320+
message::String="")
321+
322+
The fitting procedure failed to converge in `iterations` number of iterations. Typically
323+
this is because the `lastchange` between the objective in the final and penultimate
324+
iterations was greater than the specified `tolerance`. Further information can be provided
325+
by `message`.
326+
"""
327+
struct ConvergenceException{T<:Real} <: Exception
328+
iterations::Int
329+
lastchange::T
330+
tolerance::T
331+
message::String
332+
333+
function ConvergenceException(iterations, lastchange=NaN, tolerance=NaN, message="")
334+
if tolerance > lastchange
335+
throw(ArgumentError("can't construct `ConvergenceException` with change " *
336+
"less than tolerance; got $lastchange and $tolerance"))
337+
end
338+
T = promote_type(typeof(lastchange), typeof(tolerance))
339+
return new{T}(iterations, lastchange, tolerance, message)
340+
end
341+
end
342+
343+
function Base.showerror(io::IO, ce::ConvergenceException)
344+
print(io, "failure to converge after ", ce.iterations, " iterations")
345+
if !isnan(ce.lastchange)
346+
print(io, "; last change between iterations (", ce.lastchange, ") was greater ",
347+
"than tolerance (", ce.tolerance, ")")
348+
end
349+
print(io, '.')
350+
if !isempty(ce.message)
351+
print(io, ' ', ce.message)
352+
end
353+
return nothing
354+
end

test/statisticalmodel.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module TestStatisticalModel
22

33
using Test, StatsAPI
4-
using StatsAPI: StatisticalModel, stderror, aic, aicc, bic, r2, r², adjr2, adjr²
4+
using StatsAPI: ConvergenceException, StatisticalModel, stderror, aic, aicc, bic,
5+
r2, r², adjr2, adjr²
56

67
struct MyStatisticalModel <: StatisticalModel
78
end
@@ -36,4 +37,16 @@ StatsAPI.nobs(::MyStatisticalModel) = 100
3637
@test adjr2 === adjr²
3738
end
3839

39-
end # module TestStatisticalModel
40+
@testset "ConvergenceException" begin
41+
fail = "failure to converge after 10 iterations"
42+
chgtol = "last change between iterations (0.2) was greater than tolerance (0.1)"
43+
msg = "Try changing maxiter."
44+
@test sprint(showerror, ConvergenceException(10)) == "$fail."
45+
@test sprint(showerror, ConvergenceException(10, 0.2, 0.1)) == "$fail; $chgtol."
46+
@test sprint(showerror, ConvergenceException(10, 0.2, 0.1, msg)) == "$fail; $chgtol. $msg"
47+
err = @test_throws ArgumentError ConvergenceException(10, 0.1, 0.2)
48+
@test err.value.msg == string("can't construct `ConvergenceException` with change ",
49+
"less than tolerance; got 0.1 and 0.2")
50+
end
51+
52+
end # module TestStatisticalModel

0 commit comments

Comments
 (0)