Adds max_step param to base_solver #4673

Draft · wants to merge 18 commits into base: develop
2 changes: 1 addition & 1 deletion .github/workflows/work_precision_sets.yml
@@ -24,7 +24,7 @@ jobs:
run: python -m pip install pybamm==${{ env.VERSION }}
- name: Run time_vs_* benchmarks for PyBaMM v${{ env.VERSION }}
run: |
- python benchmarks/work_precision_sets/time_vs_dt_max.py
+ python benchmarks/work_precision_sets/time_vs_dt_event.py
python benchmarks/work_precision_sets/time_vs_mesh_size.py
python benchmarks/work_precision_sets/time_vs_no_of_states.py
python benchmarks/work_precision_sets/time_vs_reltols.py
1 change: 1 addition & 0 deletions CHANGELOG.md
Review comment (project member) on CHANGELOG.md:

We should also add an entry to the breaking changes section about the renaming of `dt_max` to `dt_event`.
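As a rough sketch, such an entry could follow the style of the existing CHANGELOG items (the wording below is a suggestion only, not part of this diff):

```markdown
## Breaking changes

<!-- Suggested wording only; adjust to match the final rename in this PR (#4673). -->
- Renamed the `dt_max` option of `CasadiSolver` to `dt_event` ([#4673](https://github.com/pybamm-team/PyBaMM/pull/4673))
```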

@@ -126,6 +126,7 @@ package to install PyBaMM with only the required dependencies. ([conda-forge/pyb
- Added `WyciskOpenCircuitPotential` for differential capacity hysteresis state open-circuit potential submodel ([#3593](https://github.com/pybamm-team/PyBaMM/pull/3593))
- Transport efficiency submodel has new options from the literature relating to different tortuosity factor models and also a new option called "tortuosity factor" for specifying the value or function directly as parameters ([#3437](https://github.com/pybamm-team/PyBaMM/pull/3437))
- Heat of mixing source term can now be included into thermal models ([#2837](https://github.com/pybamm-team/PyBaMM/pull/2837))
+ - Added `max_step` parameter to `BaseSolver` and passed it to dependent solvers ([#3106](https://github.com/pybamm-team/PyBaMM/pull/3106))

## Bug Fixes

4 changes: 2 additions & 2 deletions benchmarks/release-work-precision-sets.md
@@ -12,9 +12,9 @@

<img src='./benchmark_images/time_vs_mesh_size_22.7.png'>

- ## Solve Time vs dt_max
+ ## Solve Time vs dt_event

- <img src='./benchmark_images/time_vs_dt_max_22.7.png'>
+ <img src='./benchmark_images/time_vs_dt_event_22.7.png'>

## Solve Time vs Number of states

@@ -12,7 +12,7 @@

models = {"SPM": pybamm.lithium_ion.SPM(), "DFN": pybamm.lithium_ion.DFN()}

- dt_max = [
+ dt_event = [
10,
20,
50,
@@ -70,8 +70,8 @@
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

- for t in dt_max:
-     solver = pybamm.CasadiSolver(dt_max=t)
+ for t in dt_event:
+     solver = pybamm.CasadiSolver(dt_event=t)

solver.solve(model, t_eval=t_eval)
time = 0
@@ -85,20 +85,20 @@

ax.set_xscale("log")
ax.set_yscale("log")
- ax.set_xlabel("dt_max")
+ ax.set_xlabel("dt_event")
ax.set_ylabel("time(s)")
ax.set_title(f"{model_name}")
- ax.plot(dt_max, time_points)
+ ax.plot(dt_event, time_points)

plt.tight_layout()
plt.gca().legend(
parameters,
loc="upper right",
)
- plt.savefig(f"benchmarks/benchmark_images/time_vs_dt_max_{pybamm.__version__}.png")
+ plt.savefig(f"benchmarks/benchmark_images/time_vs_dt_event_{pybamm.__version__}.png")


- content = f"## Solve Time vs dt_max\n<img src='./benchmark_images/time_vs_dt_max_{pybamm.__version__}.png'>\n"
+ content = f"## Solve Time vs dt_event\n<img src='./benchmark_images/time_vs_dt_event_{pybamm.__version__}.png'>\n"

with open("./benchmarks/release_work_precision_sets.md") as original:
data = original.read()
@@ -263,7 +263,7 @@
" sim = pybamm.Simulation(\n",
" model,\n",
" parameter_values=param,\n",
" solver=pybamm.CasadiSolver(dt_max=5),\n",
" solver=pybamm.CasadiSolver(dt_event=5),\n",
" )\n",
" solution.append(sim.solve(t_eval=t_eval))\n",
"stop = timeit.default_timer()\n",
@@ -887,7 +887,7 @@
" model,\n",
" experiment=experiment,\n",
" parameter_values=param,\n",
" solver=pybamm.CasadiSolver(dt_max=5),\n",
" solver=pybamm.CasadiSolver(dt_event=5),\n",
" )\n",
" solution.append(sim.solve(calc_esoh=False))\n",
"stop = timeit.default_timer()\n",
@@ -100,7 +100,7 @@
" [f\"Discharge at {C_rate:.4f}C until 3.2V\"], period=f\"{10 / C_rate:.4f} seconds\"\n",
" )\n",
" sim = pybamm.Simulation(\n",
" model, experiment=experiment, solver=pybamm.CasadiSolver(dt_max=120)\n",
" model, experiment=experiment, solver=pybamm.CasadiSolver(dt_event=120)\n",
" )\n",
" sim.solve()\n",
"\n",
@@ -144,7 +144,7 @@
"sim = pybamm.Simulation(\n",
" model,\n",
" parameter_values=param,\n",
" solver=pybamm.CasadiSolver(dt_max=600),\n",
" solver=pybamm.CasadiSolver(dt_event=600),\n",
" var_pts=var_pts,\n",
")\n",
"solution = sim.solve(t_eval=[0, 3600], inputs={\"C-rate\": 1})\n",
48 changes: 24 additions & 24 deletions docs/source/examples/notebooks/solvers/speed-up-solver.ipynb
@@ -366,22 +366,22 @@
"metadata": {},
"outputs": [],
"source": [
"safe_solver_2 = pybamm.CasadiSolver(mode=\"safe\", dt_max=30)\n",
"safe_solver_2 = pybamm.CasadiSolver(mode=\"safe\", dt_event=30)\n",
"safe_sol_2 = sim.solve([0, 160], solver=safe_solver_2, inputs={\"Crate\": 10})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Choosing dt_max to speed up the safe mode"
"### Choosing dt_event to speed up the safe mode"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter `dt_max` controls how large the steps taken by the `CasadiSolver` with \"safe\" mode are when looking for events."
"The parameter `dt_event` controls how large the steps taken by the `CasadiSolver` with \"safe\" mode are when looking for events."
]
},
{
@@ -393,24 +393,24 @@
"name": "stdout",
"output_type": "stream",
"text": [
"With dt_max=10, took 575.783 ms (integration time: 508.473 ms)\n",
"With dt_max=20, took 575.500 ms (integration time: 510.705 ms)\n",
"With dt_max=100, took 316.721 ms (integration time: 275.459 ms)\n",
"With dt_max=1000, took 76.646 ms (integration time: 49.294 ms)\n",
"With dt_max=3700, took 48.773 ms (integration time: 32.436 ms)\n",
"With dt_event=10, took 575.783 ms (integration time: 508.473 ms)\n",
"With dt_event=20, took 575.500 ms (integration time: 510.705 ms)\n",
"With dt_event=100, took 316.721 ms (integration time: 275.459 ms)\n",
"With dt_event=1000, took 76.646 ms (integration time: 49.294 ms)\n",
"With dt_event=3700, took 48.773 ms (integration time: 32.436 ms)\n",
"With 'fast' mode, took 42.224 ms (integration time: 32.177 ms)\n"
]
}
],
"source": [
"for dt_max in [10, 20, 100, 1000, 3700]:\n",
"for dt_event in [10, 20, 100, 1000, 3700]:\n",
" safe_sol = sim.solve(\n",
" [0, 3600],\n",
" solver=pybamm.CasadiSolver(mode=\"safe\", dt_max=dt_max),\n",
" solver=pybamm.CasadiSolver(mode=\"safe\", dt_event=dt_event),\n",
" inputs={\"Crate\": 1},\n",
" )\n",
" print(\n",
" f\"With dt_max={dt_max}, took {safe_sol.solve_time} \"\n",
" f\"With dt_event={dt_event}, took {safe_sol.solve_time} \"\n",
" + f\"(integration time: {safe_sol.integration_time})\"\n",
" )\n",
"\n",
@@ -425,17 +425,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In general, a larger value of `dt_max` gives a faster solution, since fewer integrator creations and calls are required.\n",
"In general, a larger value of `dt_event` gives a faster solution, since fewer integrator creations and calls are required.\n",
"\n",
"Below the solution time interval of 36s, the value of `dt_max` does not affect the solve time, since steps must be at least 36s large.\n",
"Below the solution time interval of 36s, the value of `dt_event` does not affect the solve time, since steps must be at least 36s large.\n",
"The discrepancy between the solve time and integration time is due to the extra operations recorded by \"solve time\", such as creating the integrator. The \"fast\" solver does not need to do this (it reuses the first one it had already created), so the solve time is much closer to the integration time."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The example above was a case where no events are triggered, so the largest `dt_max` works well. If we step over events, then it is possible to makes `dt_max` too large, so that the solver will attempt (and fail) to take large steps past the event, iteratively reducing the step size until it works. For example:"
"The example above was a case where no events are triggered, so the largest `dt_event` works well. If we step over events, then it is possible to makes `dt_event` too large, so that the solver will attempt (and fail) to take large steps past the event, iteratively reducing the step size until it works. For example:"
]
},
{
@@ -447,10 +447,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"With dt_max=10, took 504.163 ms (integration time: 419.740 ms)\n",
"With dt_max=20, took 504.691 ms (integration time: 421.396 ms)\n",
"With dt_max=100, took 286.620 ms (integration time: 238.390 ms)\n",
"With dt_max=1000, took 98.500 ms (integration time: 60.880 ms)\n"
"With dt_event=10, took 504.163 ms (integration time: 419.740 ms)\n",
"With dt_event=20, took 504.691 ms (integration time: 421.396 ms)\n",
"With dt_event=100, took 286.620 ms (integration time: 238.390 ms)\n",
"With dt_event=1000, took 98.500 ms (integration time: 60.880 ms)\n"
]
},
{
@@ -466,22 +466,22 @@
"name": "stdout",
"output_type": "stream",
"text": [
"With dt_max=3600, took 645.118 ms (integration time: 32.601 ms)\n"
"With dt_event=3600, took 645.118 ms (integration time: 32.601 ms)\n"
]
}
],
"source": [
"for dt_max in [10, 20, 100, 1000, 3600]:\n",
"for dt_event in [10, 20, 100, 1000, 3600]:\n",
" # Reduce max_num_steps to fail faster\n",
" safe_sol = sim.solve(\n",
" [0, 4500],\n",
" solver=pybamm.CasadiSolver(\n",
" mode=\"safe\", dt_max=dt_max, extra_options_setup={\"max_num_steps\": 1000}\n",
" mode=\"safe\", dt_event=dt_event, extra_options_setup={\"max_num_steps\": 1000}\n",
" ),\n",
" inputs={\"Crate\": 1},\n",
" )\n",
" print(\n",
" f\"With dt_max={dt_max}, took {safe_sol.solve_time} \"\n",
" f\"With dt_event={dt_event}, took {safe_sol.solve_time} \"\n",
" + f\"(integration time: {safe_sol.integration_time})\"\n",
" )"
]
@@ -490,7 +490,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The integration time with `dt_max=3600` remains the fastest, but the solve time is the slowest due to all the failed steps."
"The integration time with `dt_event=3600` remains the fastest, but the solve time is the slowest due to all the failed steps."
]
},
{
@@ -504,7 +504,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The \"period\" argument of the experiments also affects how long the simulations take, for a similar reason to `dt_max`. Therefore, this argument can be manually tuned to speed up how long an experiment takes to solve."
"The \"period\" argument of the experiments also affects how long the simulations take, for a similar reason to `dt_event`. Therefore, this argument can be manually tuned to speed up how long an experiment takes to solve."
]
},
{
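To make the last point in the notebook concrete, here is a minimal sketch of tuning the experiment "period"; the experiment string and period value are illustrative and not taken from this PR:

```python
import pybamm

# Illustrative values only: a longer period lets the solver take larger steps
# between event checks (similar in spirit to a larger dt_event), usually at the
# cost of coarser output resolution.
experiment = pybamm.Experiment(
    ["Discharge at 1C until 3.3 V"],
    period="60 seconds",
)
sim = pybamm.Simulation(pybamm.lithium_ion.SPM(), experiment=experiment)
sim.solve()
```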
5 changes: 5 additions & 0 deletions src/pybamm/solvers/base_solver.py
@@ -37,6 +37,9 @@ class BaseSolver:
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not. Default is 0.
+ max_step : float, optional
+     Maximum allowed step size. Default is np.inf, i.e., the step size is not
+     bounded and determined solely by the solver.
output_variables : list[str], optional
List of variables to calculate and return. If none are specified then
the complete state vector is returned (can be very large) (default is [])
@@ -51,13 +54,15 @@ def __init__(
root_tol=1e-6,
extrap_tol=None,
output_variables=None,
+ max_step=np.inf,
):
self.method = method
self.rtol = rtol
self.atol = atol
self.root_tol = root_tol
self.root_method = root_method
self.extrap_tol = extrap_tol or -1e-10
+ self.max_step = max_step
self.output_variables = [] if output_variables is None else output_variables
self._model_set_up = {}

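For context, a minimal usage sketch of the new parameter. Note the assumptions: the diff above only shows `BaseSolver`, and the changelog entry says the keyword is passed to dependent solvers, so a subclass such as `ScipySolver` is assumed here to accept `max_step`; the model and values are illustrative.

```python
import numpy as np
import pybamm

model = pybamm.lithium_ion.SPM()
sim = pybamm.Simulation(
    model,
    # Assumption: ScipySolver forwards max_step to BaseSolver after this PR.
    # This caps the internal step size at 10 s, independent of rtol/atol.
    solver=pybamm.ScipySolver(method="BDF", max_step=10),
)
sim.solve(t_eval=np.linspace(0, 3600, 100))
```

For the SciPy backend this presumably maps onto the `max_step` option of `scipy.integrate.solve_ivp`; with the default `np.inf`, the step size is left entirely to the solver's error control, as the docstring above states.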