# Automatic step sizes for SVRG #207

**Merged** — BalzaniEdoardo merged 61 commits into `flatironinstitute:development` from `bagibence:auto_stepsize_svrg` on Oct 28, 2024.
## Commits

The diff below shows changes from 46 of the 61 commits.
- `0f9f1da` bagibence — Add automatic batch and step size for softplus-Poisson GLMs optimized…
- `a88bc61` bagibence — Add docstrings
- `a021cbb` bagibence — Handle stepsize not being in solver_kwargs
- `e606fc1` bagibence — Add new way to calculate stepsize and also b_tilde
- `e94cbc8` BalzaniEdoardo — Merge branch 'development' into auto_stepsize_svrg
- `95870f8` BalzaniEdoardo — started renaming vars
- `5661899` BalzaniEdoardo — added ref to algorithm
- `481220d` BalzaniEdoardo — renamed function and generalized lookup
- `68c3240` BalzaniEdoardo — added the table calculations
- `db5b70f` BalzaniEdoardo — Merge branch 'development' into auto_stepsize_svrg
- `730b789` BalzaniEdoardo — brought back maxiter to 10K.
- `d15ad8d` BalzaniEdoardo — moved pieces around
- `ab4dbfd` BalzaniEdoardo — improved doscrsrings and added test for config
- `886bdeb` BalzaniEdoardo — improved naming and docstrings
- `d5baa02` BalzaniEdoardo — started testing
- `deae6a1` BalzaniEdoardo — changed naming
- `496702a` BalzaniEdoardo — linted
- `e084844` BalzaniEdoardo — added two missed lines for cov
- `0d36aaf` BalzaniEdoardo — added test all table cases
- `aada505` BalzaniEdoardo — linted
- `e9028a8` BalzaniEdoardo — linted
- `b8801b5` BalzaniEdoardo — added glm tests
- `7e8d576` BalzaniEdoardo — linted
- `06b9f53` BalzaniEdoardo — improved glm docstrings
- `cae93ac` BalzaniEdoardo — removed args from docstrings
- `5eec1e3` BalzaniEdoardo — Merge branch 'development' into auto_stepsize_svrg
- `35622bd` BalzaniEdoardo — added billy's comments
- `e49603b` BalzaniEdoardo — Merge branch 'development' into auto_stepsize_svrg
- `8dcc9b2` BalzaniEdoardo — Merge branch 'main' into auto_stepsize_svrg
- `2cadad0` BalzaniEdoardo — merged dev
- `8f33e7f` BalzaniEdoardo — added tests for auto-stepsize
- `028952d` BalzaniEdoardo — removed unused import
- `7d766eb` BalzaniEdoardo — fixed warns in tests
- `93a7033` BalzaniEdoardo — fix warn svrg default
- `3e55ce7` BalzaniEdoardo — moved the methods around for re-usability
- `4fdcf32` BalzaniEdoardo — fixed mockregressor
- `30934ca` BalzaniEdoardo — fix comment
- `b3f801f` BalzaniEdoardo — batched multiply
- `26499a7` BalzaniEdoardo — changed typing
- `05f4f36` BalzaniEdoardo — fixed tests
- `d8e8216` BalzaniEdoardo — add batched compute of lmax
- `b10c192` BalzaniEdoardo — added comment
- `24e5f4b` BalzaniEdoardo — Merge branch 'development' into auto_stepsize_svrg
- `e09623b` BalzaniEdoardo — linted
- `c220a1d` BalzaniEdoardo — linted
- `4c9553d` BalzaniEdoardo — modified svrg error to match GradientDescent and ProximalGradient fro…
- `009561e` BalzaniEdoardo — added pdf
- `60b2f27` BalzaniEdoardo — improved error message and provided solution
- `f0e133e` BalzaniEdoardo — change default to power iteration
- `7fc7e66` BalzaniEdoardo — saga removed
- `c0103aa` BalzaniEdoardo — removed model.fit
- `78cb8ed` BalzaniEdoardo — removed test saga
- `7d87c4a` BalzaniEdoardo — merged development
- `37f6b11` BalzaniEdoardo — added extra example
- `0718819` BalzaniEdoardo — added warning
- `6ecf1a1` BalzaniEdoardo — changed descr of svrg usage
- `6bb16b1` BalzaniEdoardo — linted
- `d901ebb` BalzaniEdoardo — improved docstrings
- `01bc153` BalzaniEdoardo — added expectation in test
- `9dc8e6e` BalzaniEdoardo — added typeerror and valueerror, as well as tests
- `028b26a` BalzaniEdoardo — Merge branch 'development' into auto_stepsize_svrg
## Files changed

**GLM module** (the hunks below reference `class GLM(BaseRegressor)` and `class PopulationGLM(GLM)`):
```diff
@@ -19,6 +19,7 @@
 from .exceptions import NotFittedError
 from .pytrees import FeaturePytree
 from .regularizer import GroupLasso, Lasso, Regularizer, Ridge
+from .solvers._compute_defaults import glm_compute_optimal_stepsize_configs
 from .type_casting import jnp_asarray_if, support_pynapple
 from .typing import DESIGN_INPUT_TYPE

@@ -55,10 +56,35 @@ class GLM(BaseRegressor):

     | Regularizer   | Default Solver   | Available Solvers                                           |
     | ------------- | ---------------- | ----------------------------------------------------------- |
-    | UnRegularized | GradientDescent  | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
-    | Ridge         | GradientDescent  | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
-    | Lasso         | ProximalGradient | ProximalGradient                                            |
-    | GroupLasso    | ProximalGradient | ProximalGradient                                            |
+    | UnRegularized | GradientDescent  | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
+    | Ridge         | GradientDescent  | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
+    | Lasso         | ProximalGradient | ProximalGradient, ProxSVRG                                  |
+    | GroupLasso    | ProximalGradient | ProximalGradient, ProxSVRG                                  |
+
+    **Fitting Large Models**
+
+    For very large models, you may consider using the Stochastic Variance Reduced Gradient
+    ([SVRG](../solvers/_svrg/#nemos.solvers._svrg.SVRG)) or its proximal variant
+    ([ProxSVRG](../solvers/_svrg/#nemos.solvers._svrg.ProxSVRG)) solvers,
+    which take advantage of batched computation.
+
+    The performance of the SVRG solver depends critically on the choice of the `batch_size` and
+    `stepsize` hyperparameters. These parameters control the size of the mini-batches used for
+    gradient computations and the step size for each iteration, respectively. Improper selection
+    of these parameters can lead to slow convergence or even divergence of the optimization.
+
+    To assist with this, for certain GLM configurations, we provide recommended `batch_size` and
+    `stepsize` values that are theoretically guaranteed to ensure fast convergence.
+
+    Below is a list of the configurations for which we can provide guaranteed hyperparameters:
+
+    | GLM / PopulationGLM Configuration   | Stepsize | Batch Size |
+    | ----------------------------------- | :------: | :--------: |
+    | Poisson + soft-plus + UnRegularized |    ✅    |     ❌     |
+    | Poisson + soft-plus + Ridge         |    ✅    |     ✅     |
+    | Poisson + soft-plus + Lasso         |    ✅    |     ❌     |
+    | Poisson + soft-plus + GroupLasso    |    ✅    |     ❌     |

     Parameters
     ----------
```
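As a concrete illustration of the workflow this docstring describes, here is a hedged sketch. It assumes the `nemos` constructor accepts the arguments shown (`observation_model`, `regularizer`, `solver_name`, matching the attributes referenced elsewhere in this diff), and the data are synthetic:

```python
import jax
import numpy as np
import nemos as nmo

# Illustrative data: 1000 samples, 5 features, Poisson-like counts.
X = np.random.normal(size=(1000, 5))
y = np.random.poisson(lam=1.0, size=1000).astype(float)

# Poisson + soft-plus + Ridge is the one configuration in the table above
# for which both the stepsize and the batch size can be guaranteed.
model = nmo.glm.GLM(
    observation_model=nmo.observation_models.PoissonObservations(
        inverse_link_function=jax.nn.softplus
    ),
    regularizer=nmo.regularizer.Ridge(),
    solver_name="SVRG",
)

# Because stepsize/batch_size are not passed via solver_kwargs, the defaults
# computed by this PR are used when the solver state is initialized.
model.fit(X, y)
```

Per the table, the other soft-plus Poisson configurations get only a guaranteed stepsize; their batch size is left to the user.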
```diff
@@ -890,8 +916,10 @@ def initialize_state(
         )
         self.regularizer.mask = jnp.ones((1, data.shape[1]))

+        opt_solver_kwargs = self.optimize_solver_params(data, y)
+
         # set up the solver init/run/update attrs
-        self.instantiate_solver()
+        self.instantiate_solver(solver_kwargs=opt_solver_kwargs)

         opt_state = self.solver_init_state(init_params, data, y)
         return opt_state
```
```diff
@@ -988,6 +1016,10 @@ def update(

         return opt_step

+    def get_optimal_solver_params_config(self):
+        """Return the functions for computing default step and batch size for the solver."""
+        return glm_compute_optimal_stepsize_configs(self)
+

 class PopulationGLM(GLM):
     """
```
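`optimize_solver_params` itself is outside this diff's hunks, so the following is a hypothetical sketch of how the triple returned by `get_optimal_solver_params_config` could be folded into user-supplied `solver_kwargs`. Every name except `get_optimal_solver_params_config` is illustrative, and the real call signature of the computed-params function may differ:

```python
def sketch_optimize_solver_params(model, X, y, solver_kwargs):
    """Hypothetical sketch: fill in stepsize/batch_size only when guarantees exist."""
    compute_optimal_params, compute_smoothness, strong_convexity = (
        model.get_optimal_solver_params_config()
    )
    # Not an SVRG-style solver, or no smoothness computation is known for this
    # observation model: leave the user's kwargs untouched.
    if compute_optimal_params is None or compute_smoothness is None:
        return dict(solver_kwargs)

    new_kwargs = dict(solver_kwargs)
    # Hypothetical call: the real svrg_optimal_batch_and_stepsize signature
    # may differ from this.
    optimal = compute_optimal_params(
        compute_smoothness, X, y, strong_convexity=strong_convexity
    )
    # Respect anything the user set explicitly; fill only the missing keys.
    for key in ("stepsize", "batch_size"):
        if key in optimal and key not in new_kwargs:
            new_kwargs[key] = optimal[key]
    return new_kwargs
```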
```diff
@@ -1003,10 +1035,35 @@ class PopulationGLM(GLM):

     | Regularizer   | Default Solver   | Available Solvers                                           |
     | ------------- | ---------------- | ----------------------------------------------------------- |
-    | UnRegularized | GradientDescent  | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
-    | Ridge         | GradientDescent  | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
-    | Lasso         | ProximalGradient | ProximalGradient                                            |
-    | GroupLasso    | ProximalGradient | ProximalGradient                                            |
+    | UnRegularized | GradientDescent  | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
+    | Ridge         | GradientDescent  | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
+    | Lasso         | ProximalGradient | ProximalGradient, ProxSVRG                                  |
+    | GroupLasso    | ProximalGradient | ProximalGradient, ProxSVRG                                  |
+
+    **Fitting Large Models**
+
+    For very large models, you may consider using the Stochastic Variance Reduced Gradient
+    ([SVRG](../solvers/_svrg/#nemos.solvers._svrg.SVRG)) or its proximal variant
+    ([ProxSVRG](../solvers/_svrg/#nemos.solvers._svrg.ProxSVRG)) solvers,
+    which take advantage of batched computation.
+
+    The performance of the SVRG solver depends critically on the choice of the `batch_size` and
+    `stepsize` hyperparameters. These parameters control the size of the mini-batches used for
+    gradient computations and the step size for each iteration, respectively. Improper selection
+    of these parameters can lead to slow convergence or even divergence of the optimization.
+
+    To assist with this, for certain GLM configurations, we provide recommended `batch_size` and
+    `stepsize` values that are theoretically guaranteed to ensure fast convergence.
+
+    Below is a list of the configurations for which we can provide guaranteed hyperparameters:
+
+    | GLM / PopulationGLM Configuration   | Stepsize | Batch Size |
+    | ----------------------------------- | :------: | :--------: |
+    | Poisson + soft-plus + UnRegularized |    ✅    |     ❌     |
+    | Poisson + soft-plus + Ridge         |    ✅    |     ✅     |
+    | Poisson + soft-plus + Lasso         |    ✅    |     ❌     |
+    | Poisson + soft-plus + GroupLasso    |    ✅    |     ❌     |

     Parameters
     ----------
```

An inline review exchange on the "Fitting Large Models" paragraph above — "same point about example and doc link" — was answered with "subsequent pr" (the docs example is deferred; see also the review comments at the end).
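Background, not part of the diff: the docstrings above promise stepsizes "theoretically guaranteed to ensure fast convergence". The classical SVRG analysis (Johnson & Zhang, 2013) shows why the two constants computed by the new helpers, smoothness and strong convexity, are exactly what such a guarantee needs; the bound behind this PR's lookup table may be a refinement of it.

```latex
% Classical SVRG convergence guarantee (Johnson & Zhang, 2013, Thm. 1),
% stated as background. For an L-smooth, mu-strongly convex objective P,
% SVRG with stepsize eta < 1/(2L) and inner-loop length m satisfies
%   E[P(w_s) - P(w_*)] <= alpha^s * (P(w_0) - P(w_*)),  with
\[
  \alpha \;=\; \frac{1}{\mu\,\eta\,(1 - 2L\eta)\,m} \;+\; \frac{2L\eta}{1 - 2L\eta} \;<\; 1 .
\]
% Choosing eta as a small constant multiple of 1/L (with m large enough)
% makes alpha < 1, i.e. linear convergence. Hence knowing L (from the
% softplus-Poisson loss) and mu (here supplied by the Ridge penalty
% strength) pins down a stepsize with a convergence guarantee.
```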
**New file** — the solvers package `__init__`, re-exporting the SVRG solvers and the new defaults helpers (`@@ -0,0 +1,5 @@`):

```python
from ._svrg import SVRG, ProxSVRG
from ._svrg_defaults import (
    glm_softplus_poisson_l_max_and_l,
    svrg_optimal_batch_and_stepsize,
)
```
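Assuming the package layout implied by the relative imports (i.e. this file is the `__init__.py` of a `nemos.solvers` subpackage), the re-exports give users a flat import path:

```python
# Hedged: assumes the subpackage is importable as nemos.solvers.
from nemos.solvers import SVRG, ProxSVRG, svrg_optimal_batch_and_stepsize
```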
**New file** — `_compute_defaults.py` in the solvers package, per the `from .solvers._compute_defaults import ...` line added to the GLM module above (`@@ -0,0 +1,71 @@`):

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

import jax

from ..observation_models import PoissonObservations
from ..regularizer import Ridge
from ._svrg_defaults import (
    glm_softplus_poisson_l_max_and_l,
    svrg_optimal_batch_and_stepsize,
)

if TYPE_CHECKING:
    from ..glm import GLM, PopulationGLM


def glm_compute_optimal_stepsize_configs(
    model: Union[GLM, PopulationGLM]
) -> Tuple[Optional[Callable], Optional[Callable], Optional[float]]:
    """
    Compute configuration functions for optimal step size selection based on the model.

    This function returns a tuple of three elements that are used for configuring the
    optimal step size and batch size for variance reduced gradient (SVRG and
    ProxSVRG) algorithms. If the model is configured with specific solver names,
    the appropriate computation functions are returned. Additionally, it determines
    the smoothness and strong convexity constants based on the model's observation
    model and regularizer.

    Parameters
    ----------
    model :
        The generalized linear model object for which the optimal step size and batch
        configuration need to be computed. The model should have attributes like
        `solver_name`, `observation_model`, and `regularizer`.

    Returns
    -------
    compute_optimal_params :
        A function to compute the optimal batch size and step size if the model
        is configured with the SVRG or ProxSVRG solver, None otherwise.
    compute_smoothness :
        A function to compute the smoothness constant of the loss function if the
        observation model uses a softplus inverse link function and is a Poisson
        observation model, None otherwise.
    strong_convexity :
        The strong convexity constant of the loss function if the model has a
        Ridge regularizer. If the model does not have a Ridge regularizer, this
        value will be None.
    """
    # initialize funcs and strong convexity constant
    compute_optimal_params = None
    compute_smoothness = None
    strong_convexity = (
        None if not isinstance(model.regularizer, Ridge) else model.regularizer_strength
    )

    # look-up table for selecting the optimal step and batch
    if model.solver_name in ("SVRG", "ProxSVRG"):
        compute_optimal_params = svrg_optimal_batch_and_stepsize

    # get the smoothness parameter compute function
    if model.observation_model.inverse_link_function is jax.nn.softplus and isinstance(
        model.observation_model, PoissonObservations
    ):
        compute_smoothness = glm_softplus_poisson_l_max_and_l

    return compute_optimal_params, compute_smoothness, strong_convexity
```
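A hedged trace of the dispatch logic above, reusing the constructor arguments assumed in the earlier GLM example; the comments note which branch produces each element of the triple:

```python
import jax
import nemos as nmo
from nemos.solvers._compute_defaults import glm_compute_optimal_stepsize_configs

model = nmo.glm.GLM(
    observation_model=nmo.observation_models.PoissonObservations(
        inverse_link_function=jax.nn.softplus
    ),
    regularizer=nmo.regularizer.Ridge(),
    solver_name="SVRG",
)

params_fn, smoothness_fn, strong_convexity = glm_compute_optimal_stepsize_configs(model)
# params_fn        -> svrg_optimal_batch_and_stepsize   (solver_name is "SVRG")
# smoothness_fn    -> glm_softplus_poisson_l_max_and_l  (softplus + Poisson)
# strong_convexity -> model.regularizer_strength        (Ridge regularizer)
```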
## Review comments

> Would be nice to point to an example in the docs here.

Reply: "I agree, but for a future PR in the documentation. I'll link to this comment in the docs project."