Skip to content

Commit

Permalink
made changes
Browse files Browse the repository at this point in the history
  • Loading branch information
medha-14 committed Dec 14, 2024
1 parent ac02b61 commit 0900b03
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 39 deletions.
13 changes: 8 additions & 5 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
import numbers
import scipy.sparse as sparse
from .base_solver import validate_max_step

import importlib

Expand Down Expand Up @@ -92,7 +91,7 @@ def __init__(
root_tol=1e-6,
extrap_tol=None,
max_step=np.inf,
output_variables=[],
output_variables=None,
options=None,
):
# set default options,
Expand All @@ -115,7 +114,7 @@ def __init__(
options[key] = value
self._options = options

self.max_step = validate_max_step(max_step)
self.max_step = max_step

self.output_variables = output_variables

Expand Down Expand Up @@ -517,7 +516,7 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS):

return base_set_up_return

def _integrate(self, model, t_eval, inputs_dict=None):
def _integrate(self, model, t_eval, inputs_dict=None, max_step=np.inf):
"""
Solve a DAE model defined by residuals with initial conditions y0.
Expand All @@ -529,6 +528,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):
The times at which to compute the solution
inputs_dict : dict, optional
Any input parameters to pass to the model when solving
max_step : float, optional
Maximum allowed step size. Default is np.inf.
"""
inputs_dict = inputs_dict or {}
# stack inputs
Expand Down Expand Up @@ -580,6 +581,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
y0full,
ydot0full,
inputs,
max_step=max_step,
)
else:
sol = idaklu.solve_python(
Expand All @@ -601,6 +603,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
rtol,
inputs,
self._setup["number_of_sensitivity_parameters"],
max_step=max_step,
)
integration_time = timer.time()

Expand Down Expand Up @@ -655,7 +658,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
number_of_samples = sol.y.shape[0] // number_of_timesteps
sol.y = sol.y.reshape((number_of_timesteps, number_of_samples))
startk = 0
for vark, var in enumerate(self.output_variables):
for _vark, var in enumerate(self.output_variables):
# ExplicitTimeIntegral's are not computed as part of the solver and
# do not need to be converted
if isinstance(
Expand Down
11 changes: 6 additions & 5 deletions pybamm/solvers/scikits_dae_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import importlib
import scipy.sparse as sparse

from .base_solver import validate_max_step

scikits_odes_spec = importlib.util.find_spec("scikits")
if scikits_odes_spec is not None:
scikits_odes_spec = importlib.util.find_spec("scikits.odes")
Expand Down Expand Up @@ -72,13 +70,13 @@ def __init__(
self.name = f"Scikits DAE solver ({method})"

self.extra_options = extra_options or {}
self.max_step = validate_max_step(max_step)
self.max_step = max_step

pybamm.citations.register("Malengier2018")
pybamm.citations.register("Hindmarsh2000")
pybamm.citations.register("Hindmarsh2005")

def _integrate(self, model, t_eval, inputs_dict=None):
def _integrate(self, model, t_eval, inputs_dict=None, max_step=np.inf):
"""
Solve a model defined by dydt with initial conditions y0.
Expand All @@ -90,6 +88,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):
The times at which to compute the solution
inputs_dict : dict, optional
Any input parameters to pass to the model when solving
max_step : float, optional
Maximum allowed step size. Default is onp.inf.
"""
inputs_dict = inputs_dict or {}
Expand Down Expand Up @@ -134,6 +134,7 @@ def rootfn(t, y, ydot, return_root):
"old_api": False,
"rtol": self.rtol,
"atol": self.atol,
"max_step": max_step,
}

if jacobian:
Expand Down Expand Up @@ -164,7 +165,7 @@ def jacfn(t, y, ydot, residuals, cj, J):
sol = dae_solver.solve(t_eval, y0, ydot0)
integration_time = timer.time()

# return solution, we need to tranpose y to match scipy's interface
# return solution, we need to transpose y to match scipy's interface
if sol.flag in [0, 2]:
# 0 = solved for all t_eval
if sol.flag == 0:
Expand Down
8 changes: 5 additions & 3 deletions pybamm/solvers/scikits_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import importlib
import scipy.sparse as sparse

from .base_solver import validate_max_step

scikits_odes_spec = importlib.util.find_spec("scikits")
if scikits_odes_spec is not None:
Expand Down Expand Up @@ -64,13 +63,13 @@ def __init__(
self.extra_options = extra_options or {}
self.ode_solver = True
self.name = f"Scikits ODE solver ({method})"
self.max_step = validate_max_step(max_step)
self.max_step = max_step

pybamm.citations.register("Malengier2018")
pybamm.citations.register("Hindmarsh2000")
pybamm.citations.register("Hindmarsh2005")

def _integrate(self, model, t_eval, inputs_dict=None):
def _integrate(self, model, t_eval, inputs_dict=None, max_step=np.inf):
"""
Solve a model defined by dydt with initial conditions y0.
Expand All @@ -82,6 +81,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):
The times at which to compute the solution
inputs_dict : dict, optional
Any input parameters to pass to the model when solving
max_step : float, optional
Maximum allowed step size. Default is onp.inf.
"""
inputs_dict = inputs_dict or {}
Expand Down Expand Up @@ -141,6 +142,7 @@ def jac_times_setupfn(t, y, fy, userdata):
"old_api": False,
"rtol": self.rtol,
"atol": self.atol,
"max_step": max_step,
}

# Read linsolver (defaults to dense)
Expand Down
7 changes: 1 addition & 6 deletions src/pybamm/solvers/algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
from scipy import optimize
from scipy.sparse import issparse
from .base_solver import validate_max_step


class AlgebraicSolver(pybamm.BaseSolver):
Expand All @@ -25,19 +24,15 @@ class AlgebraicSolver(pybamm.BaseSolver):
specified in the form "lsq_methodname"
tol : float, optional
The tolerance for the solver (default is 1e-6).
max_step : float, optional
Maximum allowed step size. Default is np.inf, i.e., the step size is not
bounded and determined solely by the solver.
extra_options : dict, optional
Any options to pass to the rootfinder. Vary depending on which method is chosen.
Please consult `SciPy documentation
<https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.show_options.html>`_
for details.
"""

def __init__(self, method="lm", tol=1e-6, max_step=np.inf, extra_options=None):
def __init__(self, method="lm", tol=1e-6, extra_options=None):
super().__init__(method=method)
self.max_step = validate_max_step(max_step)
self.tol = tol
self.extra_options = extra_options or {}
self.name = f"Algebraic solver ({method})"
Expand Down
6 changes: 1 addition & 5 deletions src/pybamm/solvers/casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import casadi
import pybamm
import numpy as np
from .base_solver import validate_max_step


class CasadiAlgebraicSolver(pybamm.BaseSolver):
Expand All @@ -15,17 +14,14 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver):
----------
tol : float, optional
The tolerance for the solver (default is 1e-6).
max_step : float, optional
Maximum allowed step size. Default is np.inf, i.e., the step size is not
bounded and determined solely by the solver.
extra_options : dict, optional
Any options to pass to the CasADi rootfinder.
Please consult `CasADi documentation <https://web.casadi.org/python-api/#rootfinding>`_ for
details.
"""

def __init__(self, tol=1e-6, max_step=np.inf, extra_options=None):
def __init__(self, tol=1e-6, extra_options=None):
super().__init__()
self.tol = tol
self.name = "CasADi algebraic solver"
Expand Down
21 changes: 15 additions & 6 deletions src/pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import warnings
from scipy.interpolate import interp1d
from .lrudict import LRUDict
from .base_solver import validate_max_step


class CasadiSolver(pybamm.BaseSolver):
Expand Down Expand Up @@ -109,7 +108,7 @@ def __init__(
"'fast', for solving quickly without events, or 'safe without grid' or "
"'fast with events' (both experimental)"
)
self.max_step = validate_max_step(max_step)
self.max_step = max_step
self.max_step_decrease_count = max_step_decrease_count
self.dt_event = dt_event or 600

Expand Down Expand Up @@ -152,6 +151,8 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
The times at which to compute the solution
inputs_dict : dict, optional
Any input parameters to pass to the model when solving
max_step : float, optional
Maximum allowed step size. Default is np.inf.
"""

# Record whether there are any symbolic inputs
Expand All @@ -171,10 +172,14 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
use_event_switch = False
# Create an integrator with the grid (we just need to do this once)
self.create_integrator(
model, inputs, t_eval, use_event_switch=use_event_switch
model,
inputs,
t_eval,
use_event_switch=use_event_switch,
max_step=self.max_step,
)
solution = self._run_integrator(
model, model.y0, inputs_dict, inputs, t_eval
model, model.y0, inputs_dict, inputs, t_eval, max_step=self.max_step
)
# Check if the sign of an event changes, if so find an accurate
# termination point and exit
Expand All @@ -193,7 +198,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
# in "safe without grid" mode,
# create integrator once, without grid,
# to avoid having to create several times
self.create_integrator(model, inputs)
self.create_integrator(model, inputs, max_step=self.max_step)
# Initialize solution
solution = pybamm.Solution(
np.array([t]),
Expand Down Expand Up @@ -241,7 +246,10 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):

if self.mode == "safe":
# update integrator with the grid
self.create_integrator(model, inputs, t_window)
self.create_integrator(
model, inputs, t_window, max_step=self.max_step
)

# Try to solve with the current global step, if it fails then
# halve the step size and try again.
try:
Expand All @@ -257,6 +265,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
t_window,
use_grid=use_grid,
extract_sensitivities_in_solution=False,
max_step=self.max_step,
)
first_ts_solved = True
solved = True
Expand Down
12 changes: 7 additions & 5 deletions src/pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import asyncio

import pybamm
from .base_solver import validate_max_step

if pybamm.has_jax():
import jax
Expand Down Expand Up @@ -85,7 +84,7 @@ def __init__(
self._ode_solver = method == "RK45"
self.extra_options = extra_options or {}
self.name = f"JAX solver ({method})"
self.max_step = validate_max_step(max_step)
self.max_step = max_step
self._cached_solves = dict()
pybamm.citations.register("jax2018")

Expand Down Expand Up @@ -216,7 +215,8 @@ def _integrate(self, model, t_eval, inputs=None, t_interp=None):
The times at which to compute the solution
inputs : dict, list[dict], optional
Any input parameters to pass to the model when solving
max_step : float, optional
Maximum allowed step size. Default is np.inf.
Returns
-------
list of `pybamm.Solution`
Expand Down Expand Up @@ -257,15 +257,17 @@ async def solve_model_async(inputs_v):
inputs_v = {
key: jnp.array([dic[key] for dic in inputs]) for key in inputs[0]
}
y.extend(jax.vmap(self._cached_solves[model])(inputs_v))
y.extend(
jax.vmap(self._cached_solves[model])(inputs_v, max_step=self.max_step)
)
else:
# Unknown platform, use serial execution as fallback
print(
f'Unknown platform requested: "{platform}", '
"falling back to serial execution"
)
for inputs_v in inputs:
y.append(self._cached_solves[model](inputs_v))
y.append(self._cached_solves[model](inputs_v, max_step=self.max_step))

# This code block implements single-program multiple-data execution
# using pmap across multiple XLAs. It is currently commented out
Expand Down
13 changes: 9 additions & 4 deletions src/pybamm/solvers/scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import scipy.integrate as it
import numpy as np

from .base_solver import validate_max_step


class ScipySolver(pybamm.BaseSolver):
"""Solve a discretised model, using scipy.integrate.solve_ivp.
Expand Down Expand Up @@ -52,7 +50,7 @@ def __init__(
self._ode_solver = True
self.extra_options = extra_options or {}
self.name = f"Scipy solver ({method})"
self.max_step = validate_max_step(max_step)
self.max_step = max_step
pybamm.citations.register("Virtanen2020")

def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
Expand All @@ -67,6 +65,8 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
The times at which to compute the solution
inputs_dict : dict, optional
Any input parameters to pass to the model when solving
max_step : float, optional
Maximum allowed step size. Default is onp.inf.
Returns
-------
Expand All @@ -82,7 +82,12 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
else:
inputs = inputs_dict

extra_options = {**self.extra_options, "rtol": self.rtol, "atol": self.atol}
extra_options = {
**self.extra_options,
"rtol": self.rtol,
"atol": self.atol,
"max_step": self.max_step,
}

# Initial conditions
y0 = model.y0
Expand Down

0 comments on commit 0900b03

Please sign in to comment.