Skip to content

Commit

Permalink
Additional abstractions for statsmodels (#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nosferican authored and nalimilan committed Apr 8, 2018
1 parent 8ef287b commit 4cf13cd
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 13 deletions.
11 changes: 10 additions & 1 deletion docs/src/statmodels.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,29 @@ deviance
dof
fit
fit!
informationmatrix
isfitted
islinear
loglikelihood
mss
nobs
nulldeviance
r2
rss
score
stderr
vcov
weights(::StatisticalModel)
```

`RegressionModel` extends `StatisticalModel` by implementing the following additional methods.
```@docs
dof_residual
fitted
leverage
meanresponse
modelmatrix
model_response
response
predict
predict!
residuals
Expand Down
10 changes: 9 additions & 1 deletion src/StatsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,19 +166,27 @@ module StatsBase
fit,
fit!,
fitted,
informationmatrix,
isfitted,
islinear,
leverage,
loglikelihood,
meanresponse,
modelmatrix,
mss,
response,
nobs,
nulldeviance,
nullloglikelihood,
rss,
score,
stderr,
vcov,
predict,
predict!,
residuals,
r2,
r²,
model_response,

ConvergenceException

Expand Down
3 changes: 2 additions & 1 deletion src/deprecates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Base.varm, Base.stdm
# Deprecated capitalised spellings forward to their lowercase replacements;
# `model_response` forwards to `response`.
@deprecate R²(obj::StatisticalModel, variant::Symbol) r²(obj, variant)
@deprecate adjR2(obj::StatisticalModel, variant::Symbol) adjr2(obj, variant)
@deprecate adjR²(obj::StatisticalModel, variant::Symbol) adjr²(obj, variant)
@deprecate model_response(obj::StatisticalModel) response(obj)

@deprecate norepeats(a::AbstractArray) allunique(a)

Expand Down Expand Up @@ -86,4 +87,4 @@ rand(s::RandIntSampler) = rand(Compat.Random.GLOBAL_RNG, s)

@deprecate(mad!(v::AbstractArray{T}, center;
constant::Real = 1 / (-sqrt(2 * one(T)) * erfcinv(3 * one(T) / 2))) where T<:Real,
mad!(v, center=center, constant=constant))
mad!(v, center=center, constant=constant))
100 changes: 91 additions & 9 deletions src/statmodels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ the likelihood of the model.
"""
function deviance(obj::StatisticalModel)
    # Fallback method: report that `deviance` has no implementation for this model type.
    msg = "deviance is not defined for $(typeof(obj))."
    error(msg)
end

"""
    islinear(obj::StatisticalModel)

Indicate whether the model is linear.

This generic method raises an error stating that `islinear` is not defined for the
type of `obj`.
"""
function islinear(obj::StatisticalModel)
    msg = "islinear is not defined for $(typeof(obj))."
    error(msg)
end

"""
nulldeviance(obj::StatisticalModel)
Expand All @@ -61,6 +68,14 @@ This is usually the model containing only the intercept.
"""
function nullloglikelihood(obj::StatisticalModel)
    # Fallback method: report that `nullloglikelihood` has no implementation for this
    # model type.
    msg = "nullloglikelihood is not defined for $(typeof(obj))."
    error(msg)
end

"""
    score(obj::StatisticalModel)

Return the score of the statistical model, i.e. the gradient of the log-likelihood
with respect to the coefficients.

This generic method raises an error stating that `score` is not defined for the
type of `obj`.
"""
function score(obj::StatisticalModel)
    msg = "score is not defined for $(typeof(obj))."
    error(msg)
end

"""
nobs(obj::StatisticalModel)
Expand All @@ -79,6 +94,29 @@ when applicable the intercept and the distribution's dispersion parameter.
"""
function dof(obj::StatisticalModel)
    # Fallback method: report that `dof` has no implementation for this model type.
    msg = "dof is not defined for $(typeof(obj))."
    error(msg)
end

"""
    mss(obj::StatisticalModel)

Return the model sum of squares.

This generic method raises an error stating that `mss` is not defined for the
type of `obj`.
"""
function mss(obj::StatisticalModel)
    msg = "mss is not defined for $(typeof(obj))."
    error(msg)
end

"""
    rss(obj::StatisticalModel)

Return the residual sum of squares.

This generic method raises an error stating that `rss` is not defined for the
type of `obj`.
"""
function rss(obj::StatisticalModel)
    msg = "rss is not defined for $(typeof(obj))."
    error(msg)
end

"""
    informationmatrix(model::StatisticalModel; expected::Bool = true)

Return the information matrix. By default the Fisher information matrix is returned,
while the observed information matrix can be requested with `expected = false`.

This generic method raises an error stating that `informationmatrix` is not defined
for the type of `model`.
"""
function informationmatrix(model::StatisticalModel; expected::Bool = true)
    # The message must interpolate `model` (the actual parameter name); referring to
    # any other name would raise `UndefVarError` before the message is built.
    error("informationmatrix is not defined for $(typeof(model)).")
end

"""
stderr(obj::StatisticalModel)
Expand All @@ -93,6 +131,20 @@ Return the variance-covariance matrix for the coefficients of the model.
"""
function vcov(obj::StatisticalModel)
    # Fallback method: report that `vcov` has no implementation for this model type.
    msg = "vcov is not defined for $(typeof(obj))."
    error(msg)
end

"""
    weights(obj::StatisticalModel)

Return the weights used in the model.

This generic method raises an error stating that `weights` is not defined for the
type of `obj`.
"""
function weights(obj::StatisticalModel)
    msg = "weights is not defined for $(typeof(obj))."
    error(msg)
end

"""
    isfitted(obj::StatisticalModel)

Indicate whether the model has been fitted.

This generic method raises an error stating that `isfitted` is not defined for the
type of `obj`.
"""
function isfitted(obj::StatisticalModel)
    msg = "isfitted is not defined for $(typeof(obj))."
    error(msg)
end

"""
Fit a statistical model.
"""
Expand Down Expand Up @@ -137,15 +189,23 @@ the likelihood of the model, ``k`` its number of consumed degrees of freedom
function bic(obj::StatisticalModel)
    # Evaluate the likelihood term first, then the complexity penalty, so the calls
    # happen in the same order as in the one-line form.
    fit_term = -2loglikelihood(obj)
    penalty = dof(obj) * log(nobs(obj))
    return fit_term + penalty
end

"""
r2(obj::StatisticalModel, variant::Symbol)
r²(obj::StatisticalModel, variant::Symbol)
r2(obj::StatisticalModel)
r²(obj::StatisticalModel)
Coefficient of determination (R-squared).
For a linear model, the R² is defined as ``ESS/TSS``, with ``ESS`` the explained sum of squares
and ``TSS`` the total sum of squares, and `variant` can be omitted.
and ``TSS`` the total sum of squares.
"""
function r2(obj::StatisticalModel)
    # R² = ESS / TSS. The total sum of squares is taken as ESS + RSS; dividing by
    # `deviance(obj)` would use the residual sum of squares as the denominator for a
    # linear model, which is not the documented ratio.
    # NOTE(review): the ESS + RSS decomposition assumes a least-squares fit with an
    # intercept — confirm for the model types that rely on this default.
    ess = mss(obj)
    return ess / (ess + rss(obj))
end

For other models, one of several pseudo R² definitions must be chosen via `variant`.
"""
r2(obj::StatisticalModel, variant::Symbol)
r²(obj::StatisticalModel, variant::Symbol)
Pseudo-coefficient of determination (pseudo R-squared).
For nonlinear models, one of several pseudo R² definitions must be chosen via `variant`.
Supported variants are:
- `:MacFadden` (a.k.a. likelihood ratio index), defined as ``1 - \\log L/\\log L0``.
- `:CoxSnell`, defined as ``1 - (L0/L)^{2/n}``
Expand Down Expand Up @@ -174,16 +234,24 @@ end
# Unicode alias for `r2`.
const r² = r2

"""
adjr2(obj::StatisticalModel, variant::Symbol)
adjr²(obj::StatisticalModel, variant::Symbol)
adjr2(obj::StatisticalModel)
adjr²(obj::StatisticalModel)
Adjusted coefficient of determination (adjusted R-squared).
For linear models, the adjusted R² is defined as ``1 - (1 - (1-R^2)(n-1)/(n-p))``, with ``R^2``
the coefficient of determination, ``n`` the number of observations, and ``p`` the number of
coefficients (including the intercept). This definition is generally known as the Wherry Formula I.
"""
function adjr2(obj::StatisticalModel)
    # Fallback method: report that `adjr2` has no implementation for this model type.
    msg = "adjr2 is not defined for $(typeof(obj))."
    error(msg)
end

For other models, one of the several pseudo R² definitions must be chosen via `variant`.
"""
adjr2(obj::StatisticalModel, variant::Symbol)
adjr²(obj::StatisticalModel, variant::Symbol)
Adjusted pseudo-coefficient of determination (adjusted pseudo R-squared).
For nonlinear models, one of the several pseudo R² definitions must be chosen via `variant`.
The only currently supported variant is `:MacFadden`, defined as ``1 - (\\log L - k)/\\log L0``.
In this formula, ``L`` is the likelihood of the model, ``L0`` that of the null model
(the model including only the intercept). These two quantities are taken to be minus half
Expand Down Expand Up @@ -214,11 +282,18 @@ Return the fitted values of the model.
function fitted(obj::RegressionModel)
    # Fallback method: report that `fitted` has no implementation for this model type.
    msg = "fitted is not defined for $(typeof(obj))."
    error(msg)
end

"""
model_response(obj::RegressionModel)
response(obj::RegressionModel)
Return the model response (a.k.a. the dependent variable).
"""
function model_response(obj::RegressionModel)
    # Fallback method: report that `model_response` has no implementation for this
    # model type.
    msg = "model_response is not defined for $(typeof(obj))."
    error(msg)
end
function response(obj::RegressionModel)
    # Fallback method: report that `response` has no implementation for this model type.
    msg = "response is not defined for $(typeof(obj))."
    error(msg)
end

"""
    meanresponse(obj::RegressionModel)

Return the mean of the response.

This generic method raises an error stating that `meanresponse` is not defined for
the type of `obj`.
"""
function meanresponse(obj::RegressionModel)
    msg = "meanresponse is not defined for $(typeof(obj))."
    error(msg)
end

"""
modelmatrix(obj::RegressionModel)
Expand All @@ -227,6 +302,13 @@ Return the model matrix (a.k.a. the design matrix).
"""
function modelmatrix(obj::RegressionModel)
    # Fallback method: report that `modelmatrix` has no implementation for this model
    # type.
    msg = "modelmatrix is not defined for $(typeof(obj))."
    error(msg)
end

"""
    leverage(obj::RegressionModel)

Return the diagonal of the projection matrix.

This generic method raises an error stating that `leverage` is not defined for the
type of `obj`.
"""
function leverage(obj::RegressionModel)
    msg = "leverage is not defined for $(typeof(obj))."
    error(msg)
end

"""
residuals(obj::RegressionModel)
Expand Down
2 changes: 1 addition & 1 deletion test/statmodels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ x3 0.453058 0.72525 0.999172 0.5567

@test sprint(showerror, ConvergenceException(10)) == "failure to converge after 10 iterations."

@test sprint(showerror, ConvergenceException(10, 0.2, 0.1)) ==
@test sprint(showerror, ConvergenceException(10, 0.2, 0.1)) ==
"failure to converge after 10 iterations. Last change (0.2) was greater than tolerance (0.1)."

err = @test_throws ArgumentError ConvergenceException(10,.1,.2)
Expand Down

0 comments on commit 4cf13cd

Please sign in to comment.