Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more parameters to control fitting, and add data checks #24

Merged
merged 2 commits into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 156 additions & 39 deletions src/MLJGLMInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ end
fit_intercept::Bool = true
link::GLM.Link01 = GLM.LogitLink()
offsetcol::Union{Symbol, Nothing} = nothing
maxiter::Integer = 30
atol::Real = 1e-6
rtol::Real = 1e-6
minstepfac::Real = 0.001
report_keys::KEYS_TYPE = DEFAULT_KEYS::(isnothing(_) || issubset(_, VALID_KEYS))
end

Expand All @@ -68,6 +72,10 @@ end
distribution::Distribution = Poisson()
link::GLM.Link = GLM.LogLink()
offsetcol::Union{Symbol, Nothing} = nothing
maxiter::Integer = 30
atol::Real = 1e-6
rtol::Real = 1e-6
minstepfac::Real = 0.001
report_keys::KEYS_TYPE = DEFAULT_KEYS::(isnothing(_) || issubset(_, VALID_KEYS))
end

Expand All @@ -87,17 +95,19 @@ Augment the matrix `X` with a column of ones if the intercept is to be
fitted (`b=true`), return `X` otherwise.
"""
function augment_X(X::Matrix, b::Bool)::Matrix
    # When an intercept is to be fitted, append a constant column of ones
    # (eltype Float64 via `float(Int)`) to the design matrix.
    b && return hcat(X, ones(float(Int), size(X, 1), 1))
    return X
end

# Normalize an iterable to a `Vector`; no-op when the input already is one.
_to_vector(v::Vector) = v
_to_vector(v) = collect(v)
# Normalize an iterable to an `AbstractArray`; no-op when the input already is one.
_to_array(v::AbstractArray) = v
_to_array(v) = collect(v)

"""
    split_X_offset(X, offsetcol::Nothing)

When no offset is specified, return `X` and an empty vector.
"""
split_X_offset(X, offsetcol::Nothing) = (X, Float64[])

Expand All @@ -115,19 +125,125 @@ function split_X_offset(X, offsetcol::Symbol)
return newX, _to_vector(offset)
end

# Whether the dispersion parameter must be estimated from data. When this
# returns `false`, the dispersion is known a priori to equal `1`.
estimates_dispersion_param(::LinearRegressor) = true
estimates_dispersion_param(::LinearBinaryClassifier) = false
estimates_dispersion_param(model::LinearCountRegressor) =
    GLM.dispersion_parameter(model.distribution)

# Throw an `ArgumentError` explaining the minimum sample-size requirement for
# `model`; the message depends on intercept/offset settings and, for count
# regressors, on the distribution.
function _throw_sample_size_error(model, est_dispersion_param)
    requires_info = _requires_info(model, est_dispersion_param)

    offset_info = isnothing(model.offsetcol) ? " `offsetcol == nothing`" :
        " `offsetcol !== nothing`"

    modelname = nameof(typeof(model))
    # "\b" (backspace) swallows the preceding space when no distribution
    # clause is printed.
    distribution_info = model isa LinearCountRegressor ?
        "and `distribution = $(nameof(typeof(model.distribution)))()`" : "\b"

    throw(
        ArgumentError(
            " `$(modelname)` with `fit_intercept = $(model.fit_intercept)`,"*
            "$(offset_info) $(distribution_info) requires $(requires_info)"
        )
    )
    return nothing
end

"""
    _requires_info(model, est_dispersion_param)

Returns one of the following strings
- "`n_samples >= n_features`", "`n_samples > n_features`"
- "`n_samples >= n_features - 1`", "`n_samples > n_features - 1`"
- "`n_samples >= n_features + 1`", "`n_samples > n_features + 1`"
"""
function _requires_info(model, est_dispersion_param)
    # Strict inequality is needed only when a dispersion parameter is
    # estimated from the data.
    inequality = est_dispersion_param ? ">" : ">="

    # Net adjustment: +1 for an intercept column, -1 when a column is
    # consumed as the offset.
    int_num = model.fit_intercept - !isnothing(model.offsetcol)

    if int_num > 0
        int_num_string = "+ $(int_num)"
    elseif int_num < 0
        int_num_string = "- $(abs(int_num))"
    else
        # "\b" (backspace) swallows the trailing space before the backtick.
        int_num_string = "\b"
    end

    return "`n_samples $(inequality) n_features $(int_num_string)`."
end

# Validate that the number of samples `n` is large enough relative to the
# number of features `p` (accounting for a fitted intercept); throws via
# `_throw_sample_size_error` otherwise.
function check_sample_size(model, n, p)
    threshold = p + model.fit_intercept
    if estimates_dispersion_param(model)
        # Dispersion estimation consumes one extra degree of freedom.
        n > threshold || _throw_sample_size_error(model, true)
    else
        n >= threshold || _throw_sample_size_error(model, false)
    end
    return nothing
end

# Convert a columns-accessible table to a design matrix, returning the matrix
# together with the feature (column) names. The intercept column is appended
# only when `handle_intercept` is set (i.e. at prediction time).
function _matrix_and_features(model, Xcols, handle_intercept=false)
    feature_names = Tables.columnnames(Xcols)
    nrows = Tables.rowcount(Xcols)
    ncols = length(feature_names)

    # The sample-size check only runs during `fit`, where
    # `handle_intercept == false`.
    handle_intercept || check_sample_size(model, nrows, ncols)

    # `Tables.matrix` cannot build a zero-column matrix; construct an empty
    # Float64 matrix directly in that case.
    Xmatrix = ncols == 0 ? Matrix{float(Int)}(undef, nrows, ncols) :
        Tables.matrix(Xcols)

    return augment_X(Xmatrix, handle_intercept && model.fit_intercept), feature_names
end

# Materialize a Tables.jl source as a columns object; no-op when the input
# already satisfies the columns interface.
_to_columns(t::Tables.AbstractColumns) = t
_to_columns(t) = Tables.Columns(t)

"""
    prepare_inputs(model, X; handle_intercept=false)

Handle `model.offsetcol` and `model.fit_intercept` if `handle_intercept=true`.
`handle_intercept` is disabled for fitting since the StatsModels.@formula
handles the intercept. Returns `(Xmatrix, offset, features)`.

Throws `ArgumentError` when `X` has no feature columns, when the configured
offset column is missing from `X`, or when removing the offset column would
leave no features while `fit_intercept == false`.
"""
function prepare_inputs(model, X; handle_intercept=false)
    Xcols = _to_columns(X)
    table_features = Tables.columnnames(Xcols)
    p = length(table_features)
    p >= 1 || throw(
        ArgumentError("`X` must contain at least one feature column.")
    )
    if !isnothing(model.offsetcol)
        # Fixed: the original error message had an unclosed backtick
        # around `X`.
        model.offsetcol in table_features || throw(
            ArgumentError("offset column `$(model.offsetcol)` not found in table `X`.")
        )
        if p < 2 && !model.fit_intercept
            throw(
                ArgumentError(
                    "At least 2 feature columns are required for learning with"*
                    " `offsetcol !== nothing` and `fit_intercept == false`."
                )
            )
        end
    end
    Xminoffset, offset = split_X_offset(Xcols, model.offsetcol)
    Xminoffset_cols = _to_columns(Xminoffset)
    Xmatrix, features = _matrix_and_features(model, Xminoffset_cols, handle_intercept)
    return Xmatrix, offset, _to_array(features)
end

"""
Expand Down Expand Up @@ -170,7 +286,6 @@ function glm_report(glm_model, features, reportkeys)
return NamedTuple{Tuple(keys(report_dict))}(values(report_dict))
end


"""
glm_formula(model, features) -> FormulaTerm

Expand All @@ -191,31 +306,10 @@ end
Return data which is ready to be passed to `fit(form, data, ...)`.
"""
function glm_data(model, Xmatrix, y, features)
    # Build a table whose columns are the features plus the response `:y`,
    # ready to be passed to `fit(form, data, ...)`.
    data = Tables.table([Xmatrix y]; header=[features...; :y])
    return data
end

# Normalize an iterable to an `AbstractArray`; no-op when the input already is one.
_to_array(v::AbstractArray) = v
_to_array(v) = collect(v)

"""
    glm_features(model, X)

Returns an iterable features object, to be used in the construction of
glm formula and glm data header. The offset column, if any, is excluded.
"""
function glm_features(model, X)
    if Tables.columnaccess(X)
        table_features = _to_array(keys(Tables.columns(X)))
    else
        # `iterate` returns `nothing` for an empty row iterator; guard BEFORE
        # unpacking. (The original called `iterate(rows, 1)[1]`, which both
        # passed a bogus iteration state and indexed a possibly-`nothing`
        # result, making its `=== nothing` check unreachable.)
        it = iterate(Tables.rows(X))
        table_features = it === nothing ? Symbol[] : _to_array(keys(first(it)))
    end
    filter!(!=(model.offsetcol), table_features)
    return table_features
end

"""
check_weights(w, y)

Expand Down Expand Up @@ -259,8 +353,7 @@ params(fr::FitResult) = fr.params

function MMI.fit(model::LinearRegressor, verbosity::Int, X, y, w=nothing)
# apply the model
Xmatrix, offset = prepare_inputs(model, X)
features = glm_features(model, X)
Xmatrix, offset, features = prepare_inputs(model, X)
y_ = isempty(offset) ? y : y .- offset
wts = check_weights(w, y_)
data = glm_data(model, Xmatrix, y_, features)
Expand All @@ -278,12 +371,20 @@ end

function MMI.fit(model::LinearCountRegressor, verbosity::Int, X, y, w=nothing)
# apply the model
Xmatrix, offset = prepare_inputs(model, X)
features = glm_features(model, X)
Xmatrix, offset, features = prepare_inputs(model, X)
data = glm_data(model, Xmatrix, y, features)
wts = check_weights(w, y)
form = glm_formula(model, features)
fitted_glm = GLM.glm(form, data, model.distribution, model.link; offset, wts).model
fitted_glm_frame = GLM.glm(
form, data, model.distribution, model.link;
offset,
model.maxiter,
model.atol,
model.rtol,
model.minstepfac,
wts
)
fitted_glm = fitted_glm_frame.model
fitresult = FitResult(
GLM.coef(fitted_glm), GLM.dispersion(fitted_glm), (features = features,)
)
Expand All @@ -299,11 +400,19 @@ function MMI.fit(model::LinearBinaryClassifier, verbosity::Int, X, y, w=nothing)
decode = y[1]
y_plain = MMI.int(y) .- 1 # 0, 1 of type Int
wts = check_weights(w, y_plain)
Xmatrix, offset = prepare_inputs(model, X)
features = glm_features(model, X)
Xmatrix, offset, features = prepare_inputs(model, X)
data = glm_data(model, Xmatrix, y_plain, features)
form = glm_formula(model, features)
fitted_glm = GLM.glm(form, data, Bernoulli(), model.link; offset, wts).model
fitted_glm_frame = GLM.glm(
form, data, Bernoulli(), model.link;
offset,
model.maxiter,
model.atol,
model.rtol,
model.minstepfac,
wts
)
fitted_glm = fitted_glm_frame.model
fitresult = FitResult(
GLM.coef(fitted_glm), GLM.dispersion(fitted_glm), (features = features,)
)
Expand Down Expand Up @@ -342,9 +451,17 @@ glm_link(::LinearRegressor) = GLM.IdentityLink()

# more efficient than MLJBase fallback
function MMI.predict_mean(model::GLM_MODELS, fitresult, Xnew)
    Xmatrix, offset, _ = prepare_inputs(model, Xnew; handle_intercept=true)
    result = glm_fitresult(model, fitresult) # ::FitResult
    coef = coefs(result)
    # Guard against a train/predict feature-count mismatch before the
    # linear-predictor computation below.
    p = size(Xmatrix, 2)
    if p != length(coef)
        throw(
            DimensionMismatch(
                "The number of features in training and prediction datasets must be equal"
            )
        )
    end
    link = glm_link(model)
    return glm_predict(link, coef, Xmatrix, model.offsetcol, offset)
end
Expand Down
Loading