From 4e05569cc401805dbf5b5f796ff8cf1d2704a5b2 Mon Sep 17 00:00:00 2001 From: tiemvanderdeure Date: Thu, 15 Feb 2024 10:23:02 +0100 Subject: [PATCH] add FitResult again --- src/MLJGLMInterface.jl | 75 +++++++++++++++++++++--------------------- test/runtests.jl | 4 +-- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/src/MLJGLMInterface.jl b/src/MLJGLMInterface.jl index 1869f31..cad5b59 100644 --- a/src/MLJGLMInterface.jl +++ b/src/MLJGLMInterface.jl @@ -18,7 +18,7 @@ import MLJModelInterface import MLJModelInterface: metadata_pkg, metadata_model, Table, Continuous, Count, Finite, OrderedFactor, Multiclass, @mlj_model using Distributions: Bernoulli, Distribution, Poisson -using StatsModels: ConstantTerm, Term, FormulaTerm, term +using StatsModels: ConstantTerm, Term, FormulaTerm, term, modelcols using Tables import GLM @@ -218,7 +218,6 @@ function prepare_inputs(model, X) Xminoffset, offset = split_X_offset(Xcols, model.offsetcol) features = Tables.columnnames(Xminoffset) - @show first(Xminoffset) check_sample_size(model, length(first(Xminoffset)), p) return Xminoffset, offset, _to_array(features) @@ -322,7 +321,9 @@ end #### FIT FUNCTIONS #### -struct FitResult{V<:AbstractVector, T, R} +struct FitResult{F, V<:AbstractVector, T, R} + "Formula containing all coefficients and their types" + formula::F "Vector containg coeficients of the predictors and intercept" coefs::V "An estimate of the dispersion parameter of the glm model. " @@ -331,10 +332,11 @@ struct FitResult{V<:AbstractVector, T, R} params::R end +FitResult(fitted_glm, features) = FitResult(GLM.formula(fitted_glm), GLM.coef(fitted_glm), GLM.dispersion(fitted_glm.model), (features = features,)) + dispersion(fr::FitResult) = fr.dispersion params(fr::FitResult) = fr.params - - +coefs(fr::FitResult) = fr.coefs function MMI.fit(model::LinearRegressor, verbosity::Int, X, y, w=nothing) # apply the model @@ -345,11 +347,13 @@ function MMI.fit(model::LinearRegressor, verbosity::Int, X, y, w=nothing) form = glm_formula(model, features) fitted_lm = GLM.lm(form, data; model.dropcollinear, wts) + fitresult = FitResult(fitted_lm, features) + # form the report report = glm_report(fitted_lm, features, model.report_keys) cache = nothing # return - return fitted_lm, cache, report + return fitresult, cache, report end function MMI.fit(model::LinearCountRegressor, verbosity::Int, X, y, w=nothing) @@ -368,11 +372,13 @@ function MMI.fit(model::LinearCountRegressor, verbosity::Int, X, y, w=nothing) wts ) + fitresult = FitResult(fitted_glm, features) + # form the report report = glm_report(fitted_glm, features, model.report_keys) cache = nothing # return - return fitted_glm, cache, report + return fitresult, cache, report end function MMI.fit(model::LinearBinaryClassifier, verbosity::Int, X, y, w=nothing) @@ -393,21 +399,24 @@ function MMI.fit(model::LinearBinaryClassifier, verbosity::Int, X, y, w=nothing) wts ) + fitresult = FitResult(fitted_glm, features) + # form the report report = glm_report(fitted_glm, features, model.report_keys) cache = nothing # return - return (fitted_glm, decode), cache, report + return (fitresult, decode), cache, report end glm_fitresult(::LinearRegressor, fitresult) = fitresult glm_fitresult(::LinearCountRegressor, fitresult) = fitresult glm_fitresult(::LinearBinaryClassifier, fitresult) = fitresult[1] + function MMI.fitted_params(model::GLM_MODELS, fitresult) result = glm_fitresult(model, fitresult) - coef = GLM.coef(result) - features = Symbol.(filter(x -> x!="(Intercept)", GLM.coefnames(result))) + coef = coefs(result) + features = copy(params(result).features) if model.fit_intercept intercept = coef[end] coef_ = coef[1:end-1] @@ -424,39 +433,30 @@ end glm_link(model) = model.link glm_link(::LinearRegressor) = GLM.IdentityLink() -# more efficient than MLJBase fallback -function MMI.predict_mean(model::GLM_MODELS, fitresult, Xnew) - Xmatrix, offset, _ = prepare_inputs(model, Xnew; handle_intercept=true) - result = glm_fitresult(model, fitresult) # ::FitResult - coef = coefs(result) - p = size(Xmatrix, 2) - if p != length(coef) - throw( - DimensionMismatch( - "The number of features in training and prediction datasets must be equal" - ) - ) - end - link = glm_link(model) - return glm_predict(link, coef, Xmatrix, model.offsetcol, offset) +function glm_predict(link, terms, coef, offsetcol::Nothing, Xnew) + mm = modelcols(terms, Xnew) + η = mm * coef + μ = GLM.linkinv.(link, η) + return μ end - -# More efficient fallback. mean is not defined for LinearBinaryClassifier -function MMI.predict_mean(model::LinearRegressor, fitresult, Xnew) - X_col_table, offset, features = prepare_inputs(model, Xnew) - p = GLM.predict(fitresult, X_col_table) - isempty(offset) ? p : p .+ offset +function glm_predict(link, terms, coef, offsetcol::Symbol, Xnew) + mm = modelcols(terms, Xnew) + offset = Tables.getcolumn(Xnew, offsetcol) + η = mm * coef .+ offset + μ = GLM.linkinv.(link, η) + return μ end -function MMI.predict_mean(model::LinearCountRegressor, fitresult, Xnew) - X_col_table, offset, features = prepare_inputs(model, Xnew) - return GLM.predict(fitresult, X_col_table; offset = offset) +# More efficient fallback. predict_mean is not defined for LinearBinaryClassifier +function MMI.predict_mean(model::Union{LinearRegressor, LinearCountRegressor}, fitresult, Xnew) + p = glm_predict(glm_link(model), fitresult.formula.rhs, fitresult.coefs, model.offsetcol, Xnew) + return p end function MMI.predict(model::LinearRegressor, fitresult, Xnew) μ = MMI.predict_mean(model, fitresult, Xnew) - σ̂ = GLM.dispersion(fitresult.model) + σ̂ = dispersion(fitresult) return [GLM.Normal(μᵢ, σ̂) for μᵢ ∈ μ] end @@ -466,9 +466,8 @@ function MMI.predict(model::LinearCountRegressor, fitresult, Xnew) end function MMI.predict(model::LinearBinaryClassifier, (fitresult, decode), Xnew) - X_col_table, offset, features = prepare_inputs(model, Xnew) - π = GLM.predict(fitresult, X_col_table; offset) - return MMI.UnivariateFinite(decode, π, augment=true) + p = glm_predict(glm_link(model), fitresult.formula.rhs, fitresult.coefs, model.offsetcol, Xnew) + return MMI.UnivariateFinite(decode, p, augment=true) end # NOTE: predict_mode uses MLJBase's fallback diff --git a/test/runtests.jl b/test/runtests.jl index 6df6da7..bb383f2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -49,7 +49,7 @@ expit(X) = 1 ./ (1 .+ exp.(-X)) # test `predict` object p_distr = predict(atom_ols, fitresult, selectrows(X, test)) - dispersion = GLM.dispersion(fitresult.model) + dispersion = MLJGLMInterface.dispersion(fitresult) @test p_distr[1] == Normal(p[1], dispersion) # test metadata @@ -344,7 +344,7 @@ end fp = fitted_params(mach) - @test fp.features == Symbol.(["x1", "x2: 0", "x2: 1"]) + @test fp.features == [:x1, :x2] @test_throws KeyError predict(mach, (x1 = [2,3,4], x2 = categorical([0,1,2]))) @test all(isapprox.(pdf.(predict(mach, X), true), [0,0,1], atol = 1e-3)) end