
Commit

add FitResult again
tiemvanderdeure committed Feb 15, 2024
1 parent 068ba59 commit 4e05569
Showing 2 changed files with 39 additions and 40 deletions.
75 changes: 37 additions & 38 deletions src/MLJGLMInterface.jl
@@ -18,7 +18,7 @@ import MLJModelInterface
import MLJModelInterface: metadata_pkg, metadata_model, Table, Continuous, Count, Finite,
OrderedFactor, Multiclass, @mlj_model
using Distributions: Bernoulli, Distribution, Poisson
using StatsModels: ConstantTerm, Term, FormulaTerm, term
using StatsModels: ConstantTerm, Term, FormulaTerm, term, modelcols
using Tables
import GLM

@@ -218,7 +218,6 @@ function prepare_inputs(model, X)
Xminoffset, offset = split_X_offset(Xcols, model.offsetcol)
features = Tables.columnnames(Xminoffset)

@show first(Xminoffset)
check_sample_size(model, length(first(Xminoffset)), p)

return Xminoffset, offset, _to_array(features)
@@ -322,7 +321,9 @@ end
#### FIT FUNCTIONS
####

struct FitResult{V<:AbstractVector, T, R}
struct FitResult{F, V<:AbstractVector, T, R}
"Formula containing all coefficients and their types"
formula::F
"Vector containg coeficients of the predictors and intercept"
coefs::V
"An estimate of the dispersion parameter of the glm model. "
@@ -331,10 +332,11 @@ struct FitResult{V<:AbstractVector, T, R}
params::R
end

FitResult(fitted_glm, features) = FitResult(GLM.formula(fitted_glm), GLM.coef(fitted_glm), GLM.dispersion(fitted_glm.model), (features = features,))

dispersion(fr::FitResult) = fr.dispersion
params(fr::FitResult) = fr.params


coefs(fr::FitResult) = fr.coefs
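
For orientation, a minimal usage sketch of the wrapper added above. The data, column names, and the unqualified accessor calls are hypothetical; only the `FitResult` constructor and the `coefs`/`dispersion`/`params` accessors come from this commit.

```julia
# Hypothetical sketch, not part of the diff: wrap a fitted GLM and read it back.
using GLM

data = (y = [1.0, 2.1, 2.9, 4.2], x1 = [0.0, 1.0, 2.0, 3.0])  # made-up table
fitted_lm = GLM.lm(GLM.@formula(y ~ 1 + x1), data)

fr = FitResult(fitted_lm, [:x1])  # convenience constructor defined above

coefs(fr)       # coefficient vector, intercept included
dispersion(fr)  # GLM.dispersion of the underlying model
params(fr)      # (features = [:x1],)
```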

function MMI.fit(model::LinearRegressor, verbosity::Int, X, y, w=nothing)
# apply the model
@@ -345,11 +347,13 @@ function MMI.fit(model::LinearRegressor, verbosity::Int, X, y, w=nothing)
form = glm_formula(model, features)
fitted_lm = GLM.lm(form, data; model.dropcollinear, wts)

fitresult = FitResult(fitted_lm, features)

# form the report
report = glm_report(fitted_lm, features, model.report_keys)
cache = nothing
# return
return fitted_lm, cache, report
return fitresult, cache, report
end

function MMI.fit(model::LinearCountRegressor, verbosity::Int, X, y, w=nothing)
@@ -368,11 +372,13 @@ function MMI.fit(model::LinearCountRegressor, verbosity::Int, X, y, w=nothing)
wts
)

fitresult = FitResult(fitted_glm, features)

# form the report
report = glm_report(fitted_glm, features, model.report_keys)
cache = nothing
# return
return fitted_glm, cache, report
return fitresult, cache, report
end

function MMI.fit(model::LinearBinaryClassifier, verbosity::Int, X, y, w=nothing)
@@ -393,21 +399,24 @@ function MMI.fit(model::LinearBinaryClassifier, verbosity::Int, X, y, w=nothing)
wts
)

fitresult = FitResult(fitted_glm, features)

# form the report
report = glm_report(fitted_glm, features, model.report_keys)
cache = nothing
# return
return (fitted_glm, decode), cache, report
return (fitresult, decode), cache, report
end

glm_fitresult(::LinearRegressor, fitresult) = fitresult
glm_fitresult(::LinearCountRegressor, fitresult) = fitresult
glm_fitresult(::LinearBinaryClassifier, fitresult) = fitresult[1]


function MMI.fitted_params(model::GLM_MODELS, fitresult)
result = glm_fitresult(model, fitresult)
coef = GLM.coef(result)
features = Symbol.(filter(x -> x!="(Intercept)", GLM.coefnames(result)))
coef = coefs(result)
features = copy(params(result).features)
if model.fit_intercept
intercept = coef[end]
coef_ = coef[1:end-1]
@@ -424,39 +433,30 @@ end
glm_link(model) = model.link
glm_link(::LinearRegressor) = GLM.IdentityLink()

# more efficient than MLJBase fallback
function MMI.predict_mean(model::GLM_MODELS, fitresult, Xnew)
Xmatrix, offset, _ = prepare_inputs(model, Xnew; handle_intercept=true)
result = glm_fitresult(model, fitresult) # ::FitResult
coef = coefs(result)
p = size(Xmatrix, 2)
if p != length(coef)
throw(
DimensionMismatch(
"The number of features in training and prediction datasets must be equal"
)
)
end
link = glm_link(model)
return glm_predict(link, coef, Xmatrix, model.offsetcol, offset)
function glm_predict(link, terms, coef, offsetcol::Nothing, Xnew)
mm = modelcols(terms, Xnew)
η = mm * coef
μ = GLM.linkinv.(link, η)
return μ
end


# More efficient fallback. mean is not defined for LinearBinaryClassifier
function MMI.predict_mean(model::LinearRegressor, fitresult, Xnew)
X_col_table, offset, features = prepare_inputs(model, Xnew)
p = GLM.predict(fitresult, X_col_table)
isempty(offset) ? p : p .+ offset
function glm_predict(link, terms, coef, offsetcol::Symbol, Xnew)
mm = modelcols(terms, Xnew)
offset = Tables.getcolumn(Xnew, offsetcol)
η = mm * coef .+ offset
μ = GLM.linkinv.(link, η)
return μ
end
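
Both `glm_predict` methods above lean on `StatsModels.modelcols` to rebuild the design matrix from the formula stored in the `FitResult`. Below is a rough sketch of that mechanism with invented data; only `modelcols`, `GLM.formula`, and `GLM.linkinv` are the calls actually used by the code above.

```julia
# Sketch of the modelcols-based prediction path (hypothetical data):
using GLM, StatsModels

train = (y = [0.0, 1.1, 2.0, 3.2], x = [0.0, 1.0, 2.0, 3.0])
fitted = GLM.lm(GLM.@formula(y ~ 1 + x), train)
rhs = GLM.formula(fitted).rhs            # schema-applied right-hand side terms

Xnew = (x = [4.0, 5.0],)                 # new table with the same column names
mm = modelcols(rhs, Xnew)                # design matrix for the new data
η = mm * GLM.coef(fitted)                # linear predictor
μ = GLM.linkinv.(GLM.IdentityLink(), η)  # inverse link; identity for a linear model
```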

function MMI.predict_mean(model::LinearCountRegressor, fitresult, Xnew)
X_col_table, offset, features = prepare_inputs(model, Xnew)
return GLM.predict(fitresult, X_col_table; offset = offset)
# More efficient fallback. predict_mean is not defined for LinearBinaryClassifier
function MMI.predict_mean(model::Union{LinearRegressor, LinearCountRegressor}, fitresult, Xnew)
p = glm_predict(glm_link(model), fitresult.formula.rhs, fitresult.coefs, model.offsetcol, Xnew)
return p
end

function MMI.predict(model::LinearRegressor, fitresult, Xnew)
μ = MMI.predict_mean(model, fitresult, Xnew)
σ̂ = GLM.dispersion(fitresult.model)
σ̂ = dispersion(fitresult)
return [GLM.Normal(μᵢ, σ̂) for μᵢ ∈ μ]
end

@@ -466,9 +466,8 @@ function MMI.predict(model::LinearCountRegressor, fitresult, Xnew)
end

function MMI.predict(model::LinearBinaryClassifier, (fitresult, decode), Xnew)
X_col_table, offset, features = prepare_inputs(model, Xnew)
π = GLM.predict(fitresult, X_col_table; offset)
return MMI.UnivariateFinite(decode, π, augment=true)
p = glm_predict(glm_link(model), fitresult.formula.rhs, fitresult.coefs, model.offsetcol, Xnew)
return MMI.UnivariateFinite(decode, p, augment=true)
end
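
With classifier prediction now routed through the same `glm_predict` path, an end-to-end check via the MLJ machine interface might look like the sketch below. The data is invented; `machine`, `fit!`, `predict`, and `pdf` are the standard MLJ/Distributions calls already used in the package tests.

```julia
# Hypothetical end-to-end sketch, not part of the diff:
using MLJBase, CategoricalArrays
import MLJGLMInterface

X = (x1 = randn(50), x2 = randn(50))
y = categorical(rand(Bool, 50))

mach = machine(MLJGLMInterface.LinearBinaryClassifier(), X, y)
fit!(mach, verbosity = 0)

yhat = predict(mach, X)   # vector of UnivariateFinite distributions
pdf.(yhat, true)          # per-row probability of the `true` class
```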

# NOTE: predict_mode uses MLJBase's fallback
4 changes: 2 additions & 2 deletions test/runtests.jl
@@ -49,7 +49,7 @@ expit(X) = 1 ./ (1 .+ exp.(-X))

# test `predict` object
p_distr = predict(atom_ols, fitresult, selectrows(X, test))
dispersion = GLM.dispersion(fitresult.model)
dispersion = MLJGLMInterface.dispersion(fitresult)
@test p_distr[1] == Normal(p[1], dispersion)

# test metadata
@@ -344,7 +344,7 @@ end

fp = fitted_params(mach)

@test fp.features == Symbol.(["x1", "x2: 0", "x2: 1"])
@test fp.features == [:x1, :x2]
@test_throws KeyError predict(mach, (x1 = [2,3,4], x2 = categorical([0,1,2])))
@test all(isapprox.(pdf.(predict(mach, X), true), [0,0,1], atol = 1e-3))
end
