diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index 2d2a52b439..a7733cae22 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -6,7 +6,6 @@ import numpy as np import numbers import scipy.sparse as sparse -from .base_solver import validate_max_step import importlib @@ -92,7 +91,7 @@ def __init__( root_tol=1e-6, extrap_tol=None, max_step=np.inf, - output_variables=[], + output_variables=None, options=None, ): # set default options, @@ -115,7 +114,7 @@ def __init__( options[key] = value self._options = options - self.max_step = validate_max_step(max_step) + self.max_step = max_step self.output_variables = output_variables @@ -517,7 +516,7 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS): return base_set_up_return - def _integrate(self, model, t_eval, inputs_dict=None): + def _integrate(self, model, t_eval, inputs_dict=None, max_step=np.inf): """ Solve a DAE model defined by residuals with initial conditions y0. @@ -529,6 +528,8 @@ def _integrate(self, model, t_eval, inputs_dict=None): The times at which to compute the solution inputs_dict : dict, optional Any input parameters to pass to the model when solving + max_step : float, optional + Maximum allowed step size. Default is np.inf. """ inputs_dict = inputs_dict or {} # stack inputs @@ -580,6 +581,7 @@ def _integrate(self, model, t_eval, inputs_dict=None): y0full, ydot0full, inputs, + max_step=max_step, ) else: sol = idaklu.solve_python( @@ -601,6 +603,7 @@ def _integrate(self, model, t_eval, inputs_dict=None): rtol, inputs, self._setup["number_of_sensitivity_parameters"], + max_step=max_step, ) integration_time = timer.time() @@ -655,7 +658,7 @@ def _integrate(self, model, t_eval, inputs_dict=None): number_of_samples = sol.y.shape[0] // number_of_timesteps sol.y = sol.y.reshape((number_of_timesteps, number_of_samples)) startk = 0 - for vark, var in enumerate(self.output_variables): + for _vark, var in enumerate(self.output_variables): # ExplicitTimeIntegral's are not computed as part of the solver and # do not need to be converted if isinstance( diff --git a/pybamm/solvers/scikits_dae_solver.py b/pybamm/solvers/scikits_dae_solver.py index 571412cd31..b0bc43595a 100644 --- a/pybamm/solvers/scikits_dae_solver.py +++ b/pybamm/solvers/scikits_dae_solver.py @@ -8,8 +8,6 @@ import importlib import scipy.sparse as sparse -from .base_solver import validate_max_step - scikits_odes_spec = importlib.util.find_spec("scikits") if scikits_odes_spec is not None: scikits_odes_spec = importlib.util.find_spec("scikits.odes") @@ -72,13 +70,13 @@ def __init__( self.name = f"Scikits DAE solver ({method})" self.extra_options = extra_options or {} - self.max_step = validate_max_step(max_step) + self.max_step = max_step pybamm.citations.register("Malengier2018") pybamm.citations.register("Hindmarsh2000") pybamm.citations.register("Hindmarsh2005") - def _integrate(self, model, t_eval, inputs_dict=None): + def _integrate(self, model, t_eval, inputs_dict=None, max_step=np.inf): """ Solve a model defined by dydt with initial conditions y0. @@ -90,6 +88,8 @@ def _integrate(self, model, t_eval, inputs_dict=None): The times at which to compute the solution inputs_dict : dict, optional Any input parameters to pass to the model when solving + max_step : float, optional + Maximum allowed step size. Default is onp.inf. """ inputs_dict = inputs_dict or {} @@ -134,6 +134,7 @@ def rootfn(t, y, ydot, return_root): "old_api": False, "rtol": self.rtol, "atol": self.atol, + "max_step": max_step, } if jacobian: @@ -164,7 +165,7 @@ def jacfn(t, y, ydot, residuals, cj, J): sol = dae_solver.solve(t_eval, y0, ydot0) integration_time = timer.time() - # return solution, we need to tranpose y to match scipy's interface + # return solution, we need to transpose y to match scipy's interface if sol.flag in [0, 2]: # 0 = solved for all t_eval if sol.flag == 0: diff --git a/pybamm/solvers/scikits_ode_solver.py b/pybamm/solvers/scikits_ode_solver.py index d7282afa63..ca69c3d518 100644 --- a/pybamm/solvers/scikits_ode_solver.py +++ b/pybamm/solvers/scikits_ode_solver.py @@ -8,7 +8,6 @@ import importlib import scipy.sparse as sparse -from .base_solver import validate_max_step scikits_odes_spec = importlib.util.find_spec("scikits") if scikits_odes_spec is not None: @@ -64,13 +63,13 @@ def __init__( self.extra_options = extra_options or {} self.ode_solver = True self.name = f"Scikits ODE solver ({method})" - self.max_step = validate_max_step(max_step) + self.max_step = max_step pybamm.citations.register("Malengier2018") pybamm.citations.register("Hindmarsh2000") pybamm.citations.register("Hindmarsh2005") - def _integrate(self, model, t_eval, inputs_dict=None): + def _integrate(self, model, t_eval, inputs_dict=None, max_step=np.inf): """ Solve a model defined by dydt with initial conditions y0. @@ -82,6 +81,8 @@ def _integrate(self, model, t_eval, inputs_dict=None): The times at which to compute the solution inputs_dict : dict, optional Any input parameters to pass to the model when solving + max_step : float, optional + Maximum allowed step size. Default is onp.inf. """ inputs_dict = inputs_dict or {} @@ -141,6 +142,7 @@ def jac_times_setupfn(t, y, fy, userdata): "old_api": False, "rtol": self.rtol, "atol": self.atol, + "max_step": max_step, } # Read linsolver (defaults to dense) diff --git a/src/pybamm/solvers/algebraic_solver.py b/src/pybamm/solvers/algebraic_solver.py index 39c5da4776..9b6663d007 100644 --- a/src/pybamm/solvers/algebraic_solver.py +++ b/src/pybamm/solvers/algebraic_solver.py @@ -6,7 +6,6 @@ import numpy as np from scipy import optimize from scipy.sparse import issparse -from .base_solver import validate_max_step class AlgebraicSolver(pybamm.BaseSolver): @@ -25,9 +24,6 @@ class AlgebraicSolver(pybamm.BaseSolver): specified in the form "lsq_methodname" tol : float, optional The tolerance for the solver (default is 1e-6). - max_step : float, optional - Maximum allowed step size. Default is np.inf, i.e., the step size is not - bounded and determined solely by the solver. extra_options : dict, optional Any options to pass to the rootfinder. Vary depending on which method is chosen. Please consult `SciPy documentation @@ -35,9 +31,8 @@ class AlgebraicSolver(pybamm.BaseSolver): for details. """ - def __init__(self, method="lm", tol=1e-6, max_step=np.inf, extra_options=None): + def __init__(self, method="lm", tol=1e-6, extra_options=None): super().__init__(method=method) - self.max_step = validate_max_step(max_step) self.tol = tol self.extra_options = extra_options or {} self.name = f"Algebraic solver ({method})" diff --git a/src/pybamm/solvers/casadi_algebraic_solver.py b/src/pybamm/solvers/casadi_algebraic_solver.py index 10b3b366e7..b139199f8c 100644 --- a/src/pybamm/solvers/casadi_algebraic_solver.py +++ b/src/pybamm/solvers/casadi_algebraic_solver.py @@ -1,7 +1,6 @@ import casadi import pybamm import numpy as np -from .base_solver import validate_max_step class CasadiAlgebraicSolver(pybamm.BaseSolver): @@ -15,9 +14,6 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver): ---------- tol : float, optional The tolerance for the solver (default is 1e-6). - max_step : float, optional - Maximum allowed step size. Default is np.inf, i.e., the step size is not - bounded and determined solely by the solver. extra_options : dict, optional Any options to pass to the CasADi rootfinder. Please consult `CasADi documentation `_ for @@ -25,7 +21,7 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver): """ - def __init__(self, tol=1e-6, max_step=np.inf, extra_options=None): + def __init__(self, tol=1e-6, extra_options=None): super().__init__() self.tol = tol self.name = "CasADi algebraic solver" diff --git a/src/pybamm/solvers/casadi_solver.py b/src/pybamm/solvers/casadi_solver.py index e6f697ae28..629d818042 100644 --- a/src/pybamm/solvers/casadi_solver.py +++ b/src/pybamm/solvers/casadi_solver.py @@ -4,7 +4,6 @@ import warnings from scipy.interpolate import interp1d from .lrudict import LRUDict -from .base_solver import validate_max_step class CasadiSolver(pybamm.BaseSolver): @@ -109,7 +108,7 @@ def __init__( "'fast', for solving quickly without events, or 'safe without grid' or " "'fast with events' (both experimental)" ) - self.max_step = validate_max_step(max_step) + self.max_step = max_step self.max_step_decrease_count = max_step_decrease_count self.dt_event = dt_event or 600 @@ -152,6 +151,8 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): The times at which to compute the solution inputs_dict : dict, optional Any input parameters to pass to the model when solving + max_step : float, optional + Maximum allowed step size. Default is np.inf. """ # Record whether there are any symbolic inputs @@ -171,10 +172,14 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): use_event_switch = False # Create an integrator with the grid (we just need to do this once) self.create_integrator( - model, inputs, t_eval, use_event_switch=use_event_switch + model, + inputs, + t_eval, + use_event_switch=use_event_switch, + max_step=self.max_step, ) solution = self._run_integrator( - model, model.y0, inputs_dict, inputs, t_eval + model, model.y0, inputs_dict, inputs, t_eval, max_step=self.max_step ) # Check if the sign of an event changes, if so find an accurate # termination point and exit @@ -193,7 +198,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): # in "safe without grid" mode, # create integrator once, without grid, # to avoid having to create several times - self.create_integrator(model, inputs) + self.create_integrator(model, inputs, max_step=self.max_step) # Initialize solution solution = pybamm.Solution( np.array([t]), @@ -241,7 +246,10 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): if self.mode == "safe": # update integrator with the grid - self.create_integrator(model, inputs, t_window) + self.create_integrator( + model, inputs, t_window, max_step=self.max_step + ) + # Try to solve with the current global step, if it fails then # halve the step size and try again. try: @@ -257,6 +265,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): t_window, use_grid=use_grid, extract_sensitivities_in_solution=False, + max_step=self.max_step, ) first_ts_solved = True solved = True diff --git a/src/pybamm/solvers/jax_solver.py b/src/pybamm/solvers/jax_solver.py index 5e9cf52b74..9413aedf18 100644 --- a/src/pybamm/solvers/jax_solver.py +++ b/src/pybamm/solvers/jax_solver.py @@ -5,7 +5,6 @@ import asyncio import pybamm -from .base_solver import validate_max_step if pybamm.has_jax(): import jax @@ -85,7 +84,7 @@ def __init__( self._ode_solver = method == "RK45" self.extra_options = extra_options or {} self.name = f"JAX solver ({method})" - self.max_step = validate_max_step(max_step) + self.max_step = max_step self._cached_solves = dict() pybamm.citations.register("jax2018") @@ -216,7 +215,8 @@ def _integrate(self, model, t_eval, inputs=None, t_interp=None): The times at which to compute the solution inputs : dict, list[dict], optional Any input parameters to pass to the model when solving - + max_step : float, optional + Maximum allowed step size. Default is np.inf. Returns ------- list of `pybamm.Solution` @@ -257,7 +257,9 @@ async def solve_model_async(inputs_v): inputs_v = { key: jnp.array([dic[key] for dic in inputs]) for key in inputs[0] } - y.extend(jax.vmap(self._cached_solves[model])(inputs_v)) + y.extend( + jax.vmap(self._cached_solves[model])(inputs_v, max_step=self.max_step) + ) else: # Unknown platform, use serial execution as fallback print( @@ -265,7 +267,7 @@ async def solve_model_async(inputs_v): "falling back to serial execution" ) for inputs_v in inputs: - y.append(self._cached_solves[model](inputs_v)) + y.append(self._cached_solves[model](inputs_v, max_step=self.max_step)) # This code block implements single-program multiple-data execution # using pmap across multiple XLAs. It is currently commented out diff --git a/src/pybamm/solvers/scipy_solver.py b/src/pybamm/solvers/scipy_solver.py index 078b9ec299..d038f73e6e 100644 --- a/src/pybamm/solvers/scipy_solver.py +++ b/src/pybamm/solvers/scipy_solver.py @@ -7,8 +7,6 @@ import scipy.integrate as it import numpy as np -from .base_solver import validate_max_step - class ScipySolver(pybamm.BaseSolver): """Solve a discretised model, using scipy.integrate.solve_ivp. @@ -52,7 +50,7 @@ def __init__( self._ode_solver = True self.extra_options = extra_options or {} self.name = f"Scipy solver ({method})" - self.max_step = validate_max_step(max_step) + self.max_step = max_step pybamm.citations.register("Virtanen2020") def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): @@ -67,6 +65,8 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): The times at which to compute the solution inputs_dict : dict, optional Any input parameters to pass to the model when solving + max_step : float, optional + Maximum allowed step size. Default is onp.inf. Returns ------- @@ -82,7 +82,12 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): else: inputs = inputs_dict - extra_options = {**self.extra_options, "rtol": self.rtol, "atol": self.atol} + extra_options = { + **self.extra_options, + "rtol": self.rtol, + "atol": self.atol, + "max_step": self.max_step, + } # Initial conditions y0 = model.y0