Automatic step sizes for SVRG #207

Merged

Changes from 46 commits (61 commits total)
0f9f1da
Add automatic batch and step size for softplus-Poisson GLMs optimized…
bagibence Jul 29, 2024
a88bc61
Add docstrings
bagibence Aug 8, 2024
a021cbb
Handle stepsize not being in solver_kwargs
bagibence Aug 9, 2024
e606fc1
Add new way to calculate stepsize and also b_tilde
bagibence Aug 9, 2024
e94cbc8
Merge branch 'development' into auto_stepsize_svrg
BalzaniEdoardo Aug 16, 2024
95870f8
started renaming vars
BalzaniEdoardo Aug 21, 2024
5661899
added ref to algorithm
BalzaniEdoardo Aug 22, 2024
481220d
renamed function and generalized lookup
BalzaniEdoardo Aug 22, 2024
68c3240
added the table calculations
BalzaniEdoardo Aug 22, 2024
db5b70f
Merge branch 'development' into auto_stepsize_svrg
BalzaniEdoardo Aug 23, 2024
730b789
brought back maxiter to 10K.
BalzaniEdoardo Aug 23, 2024
d15ad8d
moved pieces around
BalzaniEdoardo Aug 23, 2024
ab4dbfd
improved doscrsrings and added test for config
BalzaniEdoardo Aug 26, 2024
886bdeb
improved naming and docstrings
BalzaniEdoardo Aug 26, 2024
d5baa02
started testing
BalzaniEdoardo Aug 26, 2024
deae6a1
changed naming
BalzaniEdoardo Aug 26, 2024
496702a
linted
BalzaniEdoardo Aug 26, 2024
e084844
added two missed lines for cov
BalzaniEdoardo Aug 26, 2024
0d36aaf
added test all table cases
BalzaniEdoardo Aug 26, 2024
aada505
linted
BalzaniEdoardo Aug 26, 2024
e9028a8
linted
BalzaniEdoardo Aug 26, 2024
b8801b5
added glm tests
BalzaniEdoardo Aug 26, 2024
7e8d576
linted
BalzaniEdoardo Aug 26, 2024
06b9f53
improved glm docstrings
BalzaniEdoardo Aug 26, 2024
cae93ac
removed args from docstrings
BalzaniEdoardo Aug 27, 2024
5eec1e3
Merge branch 'development' into auto_stepsize_svrg
BalzaniEdoardo Aug 27, 2024
35622bd
added billy's comments
BalzaniEdoardo Oct 1, 2024
e49603b
Merge branch 'development' into auto_stepsize_svrg
BalzaniEdoardo Oct 1, 2024
8dcc9b2
Merge branch 'main' into auto_stepsize_svrg
BalzaniEdoardo Oct 1, 2024
2cadad0
merged dev
BalzaniEdoardo Oct 8, 2024
8f33e7f
added tests for auto-stepsize
BalzaniEdoardo Oct 8, 2024
028952d
removed unused import
BalzaniEdoardo Oct 8, 2024
7d766eb
fixed warns in tests
BalzaniEdoardo Oct 8, 2024
93a7033
fix warn svrg default
BalzaniEdoardo Oct 8, 2024
3e55ce7
moved the methods around for re-usability
BalzaniEdoardo Oct 8, 2024
4fdcf32
fixed mockregressor
BalzaniEdoardo Oct 8, 2024
30934ca
fix comment
BalzaniEdoardo Oct 9, 2024
b3f801f
batched multiply
BalzaniEdoardo Oct 16, 2024
26499a7
changed typing
BalzaniEdoardo Oct 16, 2024
05f4f36
fixed tests
BalzaniEdoardo Oct 16, 2024
d8e8216
add batched compute of lmax
BalzaniEdoardo Oct 16, 2024
b10c192
added comment
BalzaniEdoardo Oct 16, 2024
24e5f4b
Merge branch 'development' into auto_stepsize_svrg
BalzaniEdoardo Oct 18, 2024
e09623b
linted
BalzaniEdoardo Oct 21, 2024
c220a1d
linted
BalzaniEdoardo Oct 21, 2024
4c9553d
modified svrg error to match GradientDescent and ProximalGradient fro…
BalzaniEdoardo Oct 21, 2024
009561e
added pdf
BalzaniEdoardo Oct 25, 2024
60b2f27
improved error message and provided solution
BalzaniEdoardo Oct 25, 2024
f0e133e
change default to power iteration
BalzaniEdoardo Oct 25, 2024
7fc7e66
saga removed
BalzaniEdoardo Oct 25, 2024
c0103aa
removed model.fit
BalzaniEdoardo Oct 25, 2024
78cb8ed
removed test saga
BalzaniEdoardo Oct 25, 2024
7d87c4a
merged development
BalzaniEdoardo Oct 25, 2024
37f6b11
added extra example
BalzaniEdoardo Oct 28, 2024
0718819
added warning
BalzaniEdoardo Oct 28, 2024
6ecf1a1
changed descr of svrg usage
BalzaniEdoardo Oct 28, 2024
6bb16b1
linted
BalzaniEdoardo Oct 28, 2024
d901ebb
improved docstrings
BalzaniEdoardo Oct 28, 2024
01bc153
added expectation in test
BalzaniEdoardo Oct 28, 2024
9dc8e6e
added typeerror and valueerror, as well as tests
BalzaniEdoardo Oct 28, 2024
028b26a
Merge branch 'development' into auto_stepsize_svrg
BalzaniEdoardo Oct 28, 2024
71 changes: 66 additions & 5 deletions src/nemos/base_regressor.py
@@ -6,6 +6,7 @@
import abc
import inspect
import warnings
from abc import abstractmethod
from copy import deepcopy
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

@@ -228,7 +229,7 @@ def solver_name(self, solver_name: str):
if solver_name not in self._regularizer.allowed_solvers:
raise ValueError(
f"The solver: {solver_name} is not allowed for "
f"{self._regularizer.__class__.__name__} regularizaration. Allowed solvers are "
f"{self._regularizer.__class__.__name__} regularization. Allowed solvers are "
f"{self._regularizer.allowed_solvers}."
)
self._solver_name = solver_name
@@ -270,7 +271,9 @@ def _check_solver_kwargs(solver_class, solver_kwargs):
f"kwargs {undefined_kwargs} in solver_kwargs not a kwarg for {solver_class.__name__}!"
)

def instantiate_solver(self, *args) -> BaseRegressor:
def instantiate_solver(
self, *args, solver_kwargs: Optional[dict] = None
) -> BaseRegressor:
"""
Instantiate the solver with the provided loss function.

@@ -289,6 +292,9 @@ def instantiate_solver(self, *args) -> BaseRegressor:
*args:
Positional arguments for the jaxopt `solver.run` method, e.g. the regularizing
strength for proximal gradient methods.
solver_kwargs:
Optional dictionary with the solver kwargs.
If nothing is provided, it defaults to self.solver_kwargs.

Returns
-------
@@ -299,7 +305,7 @@ def instantiate_solver(self, *args) -> BaseRegressor:
if self.solver_name not in self.regularizer.allowed_solvers:
raise ValueError(
f"The solver: {self.solver_name} is not allowed for "
f"{self._regularizer.__class__.__name__} regularizaration. Allowed solvers are "
f"{self._regularizer.__class__.__name__} regularization. Allowed solvers are "
f"{self._regularizer.allowed_solvers}."
)

@@ -313,8 +319,9 @@ def instantiate_solver(self, *args) -> BaseRegressor:
else:
loss = self._predict_and_compute_loss

# copy dictionary of kwargs to avoid modifying user settings
solver_kwargs = deepcopy(self.solver_kwargs)
if solver_kwargs is None:
# copy dictionary of kwargs to avoid modifying user settings
solver_kwargs = deepcopy(self.solver_kwargs)

# check that the loss is Callable
utils.assert_is_callable(loss, "loss")
@@ -600,3 +607,57 @@ def _get_solver_class(solver_name: str):
)

return solver_class

def optimize_solver_params(self, X: DESIGN_INPUT_TYPE, y: jnp.ndarray) -> dict:
"""
Compute and update solver parameters with optimal defaults if available.

This method checks the current solver configuration and, if an optimal
configuration is known for the given model parameters, computes the optimal
batch size, step size, and other hyperparameters to ensure faster convergence.

Parameters
----------
X :
Input data used to compute smoothness and strong convexity constants.
y :
Target values used in conjunction with X for the same purpose.

Returns
-------
:
A dictionary containing the solver parameters, updated with optimal defaults
where applicable.

"""
# Start with a copy of the existing solver parameters
new_solver_kwargs = self.solver_kwargs.copy()

# get the model specific configs
compute_defaults, compute_l_smooth, strong_convexity = (
self.get_optimal_solver_params_config()
)
if compute_defaults and compute_l_smooth:
# Check if the user has provided batch size or stepsize, or else use None
batch_size = new_solver_kwargs.get("batch_size", None)
stepsize = new_solver_kwargs.get("stepsize", None)

# Compute the optimal batch size and stepsize based on smoothness, strong convexity, etc.
new_params = compute_defaults(
compute_l_smooth,
X,
y,
batch_size=batch_size,
stepsize=stepsize,
strong_convexity=strong_convexity,
)

# Update the solver parameters with the computed optimal values
new_solver_kwargs.update(new_params)

return new_solver_kwargs

@abstractmethod
def get_optimal_solver_params_config(self):
"""Return the functions for computing default step and batch size for the solver."""
pass
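To illustrate the new hook, here is a minimal sketch of what a `BaseRegressor` subclass has to provide. The subclass and its body are hypothetical; only the `(compute_defaults, compute_l_smooth, strong_convexity)` contract and the call into `instantiate_solver` come from the diff above.

```python
from typing import Callable, Optional, Tuple

class MyRegressor(BaseRegressor):
    """Hypothetical subclass showing only the new abstract hook."""

    def get_optimal_solver_params_config(
        self,
    ) -> Tuple[Optional[Callable], Optional[Callable], Optional[float]]:
        # Return (compute_defaults, compute_l_smooth, strong_convexity).
        # optimize_solver_params only computes new defaults when the first
        # two entries are callables; returning None leaves solver_kwargs as-is.
        return None, None, None

# GLM.initialize_state (see the glm.py diff below) wires this in as, roughly:
#   opt_kwargs = self.optimize_solver_params(X, y)
#   self.instantiate_solver(solver_kwargs=opt_kwargs)
```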
75 changes: 66 additions & 9 deletions src/nemos/glm.py
@@ -19,6 +19,7 @@
from .exceptions import NotFittedError
from .pytrees import FeaturePytree
from .regularizer import GroupLasso, Lasso, Regularizer, Ridge
from .solvers._compute_defaults import glm_compute_optimal_stepsize_configs
from .type_casting import jnp_asarray_if, support_pynapple
from .typing import DESIGN_INPUT_TYPE

@@ -55,10 +56,35 @@ class GLM(BaseRegressor):

| Regularizer | Default Solver | Available Solvers |
| ------------- | ---------------- | ----------------------------------------------------------- |
| UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Lasso | ProximalGradient | ProximalGradient |
| GroupLasso | ProximalGradient | ProximalGradient |
| UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
| Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
| Lasso | ProximalGradient | ProximalGradient, ProxSVRG |
| GroupLasso | ProximalGradient | ProximalGradient, ProxSVRG |


**Fitting Large Models**

For very large models, you may consider using the Stochastic Variance Reduced Gradient
([SVRG](../solvers/_svrg/#nemos.solvers._svrg.SVRG)) or its proximal variant
([ProxSVRG](../solvers/_svrg/#nemos.solvers._svrg.ProxSVRG)) solver,
which takes advantage of batched computation.

Review comment (Member): would be nice to point to example in the docs here
Reply (Collaborator): I agree, but for a future PR in the documentation. I'll link to this comment in the docs project

The performance of the SVRG solver depends critically on the choice of `batch_size` and `stepsize`
hyperparameters. These parameters control the size of the mini-batches used for gradient computations
and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow
convergence or even divergence of the optimization process.

To assist with this, for certain GLM configurations, we provide recommended `batch_size` and `stepsize`
values that are theoretically guaranteed to ensure fast convergence.

Below is a list of the configurations for which we can provide guaranteed hyperparameters:

| GLM / PopulationGLM Configuration | Stepsize | Batch Size |
| --------------------------------- | :------: | :---------: |
| Poisson + soft-plus + UnRegularized | ✅ | ❌ |
| Poisson + soft-plus + Ridge | ✅ | ✅ |
| Poisson + soft-plus + Lasso | ✅ | ❌ |
| Poisson + soft-plus + GroupLasso | ✅ | ❌ |

Parameters
----------
@@ -890,8 +916,10 @@ def initialize_state(
)
self.regularizer.mask = jnp.ones((1, data.shape[1]))

opt_solver_kwargs = self.optimize_solver_params(data, y)

# set up the solver init/run/update attrs
self.instantiate_solver()
self.instantiate_solver(solver_kwargs=opt_solver_kwargs)

opt_state = self.solver_init_state(init_params, data, y)
return opt_state
@@ -988,6 +1016,10 @@ def update(

return opt_step

def get_optimal_solver_params_config(self):
"""Return the functions for computing default step and batch size for the solver."""
return glm_compute_optimal_stepsize_configs(self)


class PopulationGLM(GLM):
"""
@@ -1003,10 +1035,35 @@ class PopulationGLM(GLM):

| Regularizer | Default Solver | Available Solvers |
| ------------- | ---------------- | ----------------------------------------------------------- |
| UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Lasso | ProximalGradient | ProximalGradient |
| GroupLasso | ProximalGradient | ProximalGradient |
| UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
| Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
| Lasso | ProximalGradient | ProximalGradient, ProxSVRG |
| GroupLasso | ProximalGradient | ProximalGradient, ProxSVRG |


**Fitting Large Models**

For very large models, you may consider using the Stochastic Variance Reduced Gradient
([SVRG](../solvers/_svrg/#nemos.solvers._svrg.SVRG)) or its proximal variant
([ProxSVRG](../solvers/_svrg/#nemos.solvers._svrg.ProxSVRG)) solver,
which takes advantage of batched computation.

Review comment (Member): same point about example and doc link
Reply (Collaborator): subsequent pr

The performance of the SVRG solver depends critically on the choice of `batch_size` and `stepsize`
hyperparameters. These parameters control the size of the mini-batches used for gradient computations
and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow
convergence or even divergence of the optimization process.

To assist with this, for certain GLM configurations, we provide recommended `batch_size` and `stepsize`
values that are theoretically guaranteed to ensure fast convergence.

Below is a list of the configurations for which we can provide guaranteed hyperparameters:

| GLM / PopulationGLM Configuration | Stepsize | Batch Size |
| --------------------------------- | :------: | :---------: |
| Poisson + soft-plus + UnRegularized | ✅ | ❌ |
| Poisson + soft-plus + Ridge | ✅ | ✅ |
| Poisson + soft-plus + Lasso | ✅ | ❌ |
| Poisson + soft-plus + GroupLasso | ✅ | ❌ |

Parameters
----------
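As a usage sketch of what the new docstring text describes: with an SVRG-compatible configuration you can leave `batch_size` and `stepsize` out of `solver_kwargs` and let the model derive them. The constructor arguments and data shapes below are assumptions for illustration, not taken from this PR.

```python
import jax
import jax.numpy as jnp
import nemos as nmo

# Toy data, shapes chosen arbitrarily: 1000 samples, 5 features.
X = jnp.ones((1000, 5))
y = jnp.ones((1000,))

# Poisson + soft-plus + Ridge is one of the configurations in the table above
# for which both stepsize and batch_size have guaranteed defaults.
model = nmo.glm.GLM(
    observation_model=nmo.observation_models.PoissonObservations(
        inverse_link_function=jax.nn.softplus
    ),
    regularizer=nmo.regularizer.Ridge(),
    solver_name="SVRG",  # no stepsize/batch_size supplied in solver_kwargs
)

# initialize_state calls optimize_solver_params, which fills in the
# recommended stepsize and batch_size before instantiating the solver.
model.fit(X, y)
```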
5 changes: 5 additions & 0 deletions src/nemos/solvers/__init__.py
@@ -0,0 +1,5 @@
from ._svrg import SVRG, ProxSVRG
from ._svrg_defaults import (
glm_softplus_poisson_l_max_and_l,
svrg_optimal_batch_and_stepsize,
)
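The two exported helpers can also be called directly. The sketch below mirrors how `BaseRegressor.optimize_solver_params` invokes them (see the diff above); the keyword names are taken from that call site, and the data are placeholders.

```python
import jax.numpy as jnp
from nemos.solvers import (
    glm_softplus_poisson_l_max_and_l,
    svrg_optimal_batch_and_stepsize,
)

X = jnp.ones((1000, 5))  # placeholder design matrix
y = jnp.ones((1000,))    # placeholder spike counts

# Same call pattern as optimize_solver_params: pass the smoothness helper,
# the data, any user-provided batch_size/stepsize (None here), and the
# strong convexity constant (the Ridge strength, when applicable).
params = svrg_optimal_batch_and_stepsize(
    glm_softplus_poisson_l_max_and_l,
    X,
    y,
    batch_size=None,
    stepsize=None,
    strong_convexity=None,
)
print(params)  # dict holding the computed "batch_size" and/or "stepsize"
```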
71 changes: 71 additions & 0 deletions src/nemos/solvers/_compute_defaults.py
@@ -0,0 +1,71 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

import jax

from ..observation_models import PoissonObservations
from ..regularizer import Ridge
from ._svrg_defaults import (
glm_softplus_poisson_l_max_and_l,
svrg_optimal_batch_and_stepsize,
)

if TYPE_CHECKING:
from ..glm import GLM, PopulationGLM


def glm_compute_optimal_stepsize_configs(
model: Union[GLM, PopulationGLM]
) -> Tuple[Optional[Callable], Optional[Callable], Optional[float]]:
"""
Compute configuration functions for optimal step size selection based on the model.

This function returns a tuple of three elements that are used for configuring the
optimal step size and batch size for variance reduced gradient (SVRG and
ProxSVRG) algorithms. If the model is configured with specific solver names,
the appropriate computation functions are returned. Additionally, it determines the
smoothness and strong convexity constants based on the model's observation and regularizer.

Parameters
----------
model :
The generalized linear model object for which the optimal step size and batch
configuration need to be computed. The model should have attributes like
`solver_name`, `observation_model`, and `regularizer`.

Returns
-------
compute_optimal_params :
A function to compute the optimal batch size and step size if the model
is configured with the SVRG or ProxSVRG solver, None otherwise.

compute_smoothness :
A function to compute the smoothness constant of the loss function if the
observation model uses a softplus inverse link function and is a Poisson
observation model, None otherwise.

strong_convexity :
The strong convexity constant of the loss function if the model has a
Ridge regularizer. If the model does not have a Ridge regularizer, this
value will be None.

"""
# initialize funcs and strong convexity constant
compute_optimal_params = None
compute_smoothness = None
strong_convexity = (
None if not isinstance(model.regularizer, Ridge) else model.regularizer_strength
)

# look-up table for selecting the optimal step and batch
if model.solver_name in ("SVRG", "ProxSVRG"):
compute_optimal_params = svrg_optimal_batch_and_stepsize

# get the smoothness parameter compute function
if model.observation_model.inverse_link_function is jax.nn.softplus and isinstance(
model.observation_model, PoissonObservations
):
compute_smoothness = glm_softplus_poisson_l_max_and_l

return compute_optimal_params, compute_smoothness, strong_convexity
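
A quick illustration of the returned tuple for an SVRG, softplus-Poisson, Ridge-regularized model; the model construction is an assumption, while the mapping in the comments follows the function body above.

```python
import jax
import nemos as nmo
from nemos.solvers._compute_defaults import glm_compute_optimal_stepsize_configs

model = nmo.glm.GLM(
    observation_model=nmo.observation_models.PoissonObservations(
        inverse_link_function=jax.nn.softplus
    ),
    regularizer=nmo.regularizer.Ridge(),
    solver_name="SVRG",
)

compute_defaults, compute_l_smooth, strong_convexity = (
    glm_compute_optimal_stepsize_configs(model)
)
# compute_defaults  -> svrg_optimal_batch_and_stepsize   (solver is SVRG/ProxSVRG)
# compute_l_smooth  -> glm_softplus_poisson_l_max_and_l  (softplus-Poisson model)
# strong_convexity  -> model.regularizer_strength        (Ridge regularizer)
```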
16 changes: 8 additions & 8 deletions src/nemos/solvers.py → src/nemos/solvers/_svrg.py
@@ -9,8 +9,8 @@
from jaxopt._src import loop
from jaxopt.prox import prox_none

from .tree_utils import tree_add_scalar_mul, tree_l2_norm, tree_slice, tree_sub
from .typing import KeyArrayLike, Pytree
from ..tree_utils import tree_add_scalar_mul, tree_l2_norm, tree_slice, tree_sub
from ..typing import KeyArrayLike, Pytree


class SVRGState(NamedTuple):
@@ -207,7 +207,7 @@ def _inner_loop_param_update_step(
# gradient of f_{i_k} at x_{k} in the pseudocode of Gower et al. 2020
minibatch_grad_at_current_params = self.loss_gradient(params, *args)
# gradient on batch_{i_k} evaluated at the anchor point
# gradient of f_{i_k} at x_{x} in the pseudocode of Gower et al. 2020
# gradient of f_{i_k} at x_{k} in the pseudocode of Gower et al. 2020
minibatch_grad_at_reference_point = self.loss_gradient(reference_point, *args)

# SVRG gradient estimate
@@ -575,7 +575,7 @@ def inner_loop_body(_, carry):
@staticmethod
def _error(x, x_prev, stepsize):
"""
Calculate the magnitude of the update relative to the parameters.
Calculate the magnitude of the update relative to the stepsize.
Used for terminating the algorithm if a certain tolerance is reached.

Params
@@ -589,15 +589,15 @@ def _error(x, x_prev, stepsize):
-------
Scaled update magnitude.
"""
# stepsize is an argument to be consistent with jaxopt
return tree_l2_norm(tree_sub(x, x_prev)) / tree_l2_norm(x_prev)
return tree_l2_norm(tree_sub(x, x_prev)) / stepsize


class SVRG(ProxSVRG):
"""
SVRG solver
SVRG solver.

Equivalent to ProxSVRG with prox as the identity function and hyperparams_prox=None.
This solver implements "Algorithm 3" of [1]. Equivalent to ProxSVRG with prox as the identity
function and hyperparams_prox=None.

Attributes
----------
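For concreteness, a small sketch of the changed convergence check: the update norm is now divided by the stepsize (matching the error used by jaxopt's GradientDescent and ProximalGradient, per the commit log) instead of by the norm of the previous iterate. The toy pytree values are placeholders.

```python
import jax.numpy as jnp
from nemos.tree_utils import tree_l2_norm, tree_sub

x_prev = {"coef": jnp.array([1.0, 2.0]), "intercept": jnp.array([0.5])}
x = {"coef": jnp.array([1.1, 2.0]), "intercept": jnp.array([0.5])}
stepsize = 0.05

# previous criterion: update norm relative to the previous parameters
old_error = tree_l2_norm(tree_sub(x, x_prev)) / tree_l2_norm(x_prev)

# new criterion (this PR): update norm scaled by the stepsize
new_error = tree_l2_norm(tree_sub(x, x_prev)) / stepsize
```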