I4087-multiprocessing #4260

Draft: wants to merge 23 commits into develop

martinjrobins (Contributor) commented Jul 12, 2024

Description

Partially fixes #4087

This PR expands pybamm's ability to run multiple simulations with different values for input parameters, using multiple CPU threads. The new features added are:

  1. input parameters can now be included in initial conditions, so you can run many simulations in parallel all from different initial conditions
  2. if n sets of input parameters are given to solve, the user can set the batch_size=n_b argument to batch these n simulations into groups of size n_b, reducing the number of solves to n / n_b. This effectively duplicates the set of equations n_b times, so that each solve has n_s * n_b equations, where n_s is the original number of equations. Events are handled using a softmax function over the batch of simulations, so that the simulation continues until the event is triggered in all simulations in the batch. Note that experiments show that batching reduces simulation time even on a single thread, although it is probably most useful for running simulations on a GPU (or for small systems of equations). However, it is not an effective replacement for the multiprocessing library when targeting CPUs, which was the original plan in #4087 (refactor multiprocessing and multiple inputs).
  3. all solvers have a "num_threads" option that sets the number of threads to use when solving simulations in parallel. The threads are first allocated to solving groups of simulations independently on different threads, and if there are more threads than simulations, then the threads are allocated to solving each batch of simulations in parallel (this is not very effective unless you have a big system of equations)
  4. Each solver has a different parallelisation strategy:
    • the casadi and scipy solvers use the original strategy of the python multiprocessing library. The casadi solver has an effective vmap that we could use, which would be better than multiprocessing, but I think it's worth waiting until casadi implements events later this year before we swap to it.
    • the idaklu solver uses OpenMP in the C++ solver code to run simulations in parallel, so it can do this much more efficiently than multiprocessing
    • jax uses the original strategy: vmap for GPU targets and asyncio for CPU targets
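
The batching and event handling in item 2 can be illustrated with a toy, pure-Python sketch. This is not the PR's implementation: the model (dy/dt = -k * y, so n_s = 1), the forward-Euler integrator, and the helpers `split_into_batches`, `make_batched_rhs`, `smooth_max` and `solve_batch` are all hypothetical. The log-sum-exp smooth maximum is one common choice of "softmax"; the exact function used in the PR may differ.

```python
import math


def split_into_batches(inputs, batch_size):
    """Hypothetical helper: group n input dicts into ceil(n / batch_size) batches."""
    return [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]


def make_batched_rhs(rhs, n_s):
    """Stack n_b copies of rhs(y, p), which acts on n_s states, into one
    right-hand side acting on a flat state vector of length n_s * n_b."""
    def batched(y_flat, batch):
        out = []
        for j, p in enumerate(batch):
            out.extend(rhs(y_flat[j * n_s:(j + 1) * n_s], p))
        return out
    return batched


def smooth_max(values, beta=50.0):
    """Log-sum-exp smooth maximum: never below max(values), and at most
    log(len(values)) / beta above it."""
    m = max(values)
    return m + math.log(sum(math.exp(beta * (v - m)) for v in values)) / beta


def decay_rhs(y, p):
    """Toy single-simulation model with n_s = 1: dy/dt = -k * y."""
    return [-p["k"] * y[0]]


def solve_batch(batch, y0=1.0, threshold=0.5, dt=1e-3, t_max=10.0):
    """Forward-Euler solve of one batch; stop when the combined event fires."""
    n_s = 1
    rhs = make_batched_rhs(decay_rhs, n_s)
    y = [y0] * (n_s * len(batch))
    t = 0.0
    while t < t_max:
        # Per-simulation event g_j = y_j - threshold fires when g_j <= 0.
        g = [y[j * n_s] - threshold for j in range(len(batch))]
        # smooth_max(g) <= 0 implies every g_j <= 0, so the batch keeps
        # running until the event has fired in all of its simulations.
        if smooth_max(g) <= 0.0:
            break
        dy = rhs(y, batch)
        y = [yi + dt * dyi for yi, dyi in zip(y, dy)]
        t += dt
    return t, y


inputs = [{"k": 1.0}, {"k": 2.0}, {"k": 4.0}]
batches = split_into_batches(inputs, batch_size=2)
results = [solve_batch(b) for b in batches]
```

Because the smooth maximum is never below the true maximum, the combined event can only fire once every member's event has fired, so a batch runs about as long as its slowest member: the first batch above stops near t = ln 2, the crossing time of the k = 1 simulation.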

Note that while the original plan for #4087 was to remove multiprocessing, I've only managed to do this for the idaklu solver for now. Once casadi implements events, we can remove multiprocessing for this solver as well (replacing it with casadi.vmap).

Other:

  • all backends (jax, python, casadi) have been made more consistent in how they take inputs into their evaluate functions: they now all take an array (this was the original behaviour for the casadi backend)
  • the Symbol.diff function now returns a vector if a vector expression is differentiated with respect to itself (previously this gave a scalar). This was the source of a bug, and I think this makes more sense mathematically

Here is a script to demonstrate the timing for solving 1000 SPM models in parallel:

```python
import time

import pybamm

n = 1000
# One input set per simulation: the current ranges from 0.1 A to 1.099 A
inputs = [{"Current function [A]": 0.1 + i / n} for i in range(n)]

for solver_cls in [pybamm.CasadiSolver, pybamm.IDAKLUSolver]:
    print("=====================================")
    print(f"solver_cls = {solver_cls}")
    print("=====================================")
    for num_threads in [1, 5, 10, 20]:
        for batch_size in [1, 5, 10]:
            model = pybamm.lithium_ion.SPM()
            param = model.default_parameter_values
            # Make the applied current an input parameter so it can vary per simulation
            param.update({"Current function [A]": "[input]"})
            solver = solver_cls(options={"num_threads": num_threads})
            sim = pybamm.Simulation(model, parameter_values=param, solver=solver)
            start = time.perf_counter()
            # batch_size is the new argument added in this PR
            sim.solve(t_eval=[0, 3600], inputs=inputs, batch_size=batch_size)
            # Without batching: sim.solve(t_eval=[0, 3600], inputs=inputs)
            end = time.perf_counter()
            print(f"num_threads = {num_threads}, batch_size = {batch_size}, time = {end - start}")
```

This gives, on my machine:

```
=====================================
solver_cls = <class 'pybamm.solvers.casadi_solver.CasadiSolver'>
=====================================
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
num_threads = 1, batch_size = 1, time = 8.082125603104942
num_threads = 1, batch_size = 5, time = 5.098211347940378
num_threads = 1, batch_size = 10, time = 4.66524504206609
num_threads = 5, batch_size = 1, time = 5.394547090982087
num_threads = 5, batch_size = 5, time = 4.402106147957966
num_threads = 5, batch_size = 10, time = 4.022565938998014
num_threads = 10, batch_size = 1, time = 8.073186431080103
num_threads = 10, batch_size = 5, time = 6.842109054909088
num_threads = 10, batch_size = 10, time = 5.623942147009075
num_threads = 20, batch_size = 1, time = 11.686325625982136
num_threads = 20, batch_size = 5, time = 8.376022804994136
num_threads = 20, batch_size = 10, time = 6.587947316933423
=====================================
solver_cls = <class 'pybamm.solvers.idaklu_solver.IDAKLUSolver'>
=====================================
num_threads = 1, batch_size = 1, time = 3.4076489090221003
num_threads = 1, batch_size = 5, time = 3.1277068270137534
num_threads = 1, batch_size = 10, time = 3.058925225981511
num_threads = 5, batch_size = 1, time = 1.5330516750691459
num_threads = 5, batch_size = 5, time = 1.8383638469967991
num_threads = 5, batch_size = 10, time = 1.2027837909990922
num_threads = 10, batch_size = 1, time = 1.4807175860041752
num_threads = 10, batch_size = 5, time = 1.043990166974254
num_threads = 10, batch_size = 10, time = 0.933833296992816
num_threads = 20, batch_size = 1, time = 1.5065614050254226
num_threads = 20, batch_size = 5, time = 0.8734990020748228
num_threads = 20, batch_size = 10, time = 0.9002695180242881
```
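
To make the timings easier to compare, the sketch below copies them into dictionaries keyed by (num_threads, batch_size) and finds, for each solver, the fastest combination and its speedup over the num_threads = 1, batch_size = 1 baseline. Only the numbers come from the output; the helper `best_config` is hypothetical.

```python
# Timings in seconds, copied from the output: (num_threads, batch_size) -> time
casadi = {
    (1, 1): 8.082125603104942, (1, 5): 5.098211347940378, (1, 10): 4.66524504206609,
    (5, 1): 5.394547090982087, (5, 5): 4.402106147957966, (5, 10): 4.022565938998014,
    (10, 1): 8.073186431080103, (10, 5): 6.842109054909088, (10, 10): 5.623942147009075,
    (20, 1): 11.686325625982136, (20, 5): 8.376022804994136, (20, 10): 6.587947316933423,
}
idaklu = {
    (1, 1): 3.4076489090221003, (1, 5): 3.1277068270137534, (1, 10): 3.058925225981511,
    (5, 1): 1.5330516750691459, (5, 5): 1.8383638469967991, (5, 10): 1.2027837909990922,
    (10, 1): 1.4807175860041752, (10, 5): 1.043990166974254, (10, 10): 0.933833296992816,
    (20, 1): 1.5065614050254226, (20, 5): 0.8734990020748228, (20, 10): 0.9002695180242881,
}


def best_config(timings):
    """Return the fastest (num_threads, batch_size) and its speedup over (1, 1)."""
    best = min(timings, key=timings.get)
    return best, timings[(1, 1)] / timings[best]


casadi_best, casadi_speedup = best_config(casadi)
idaklu_best, idaklu_speedup = best_config(idaklu)
print(f"CasadiSolver: fastest {casadi_best}, speedup {casadi_speedup:.2f}x")
print(f"IDAKLUSolver: fastest {idaklu_best}, speedup {idaklu_speedup:.2f}x")
```

On these measurements the fastest IDAKLUSolver run is (20, 5), about 3.9x faster than its baseline, and the fastest CasadiSolver run is (5, 10), about 2.0x faster; the CasadiSolver times also increase again beyond 5 threads.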

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)

Key checklist:

  • [x] No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All tests pass: $ python run-tests.py --all (or $ nox -s tests)
  • The documentation builds: $ python run-tests.py --doctest (or $ nox -s doctests)

You can run integration tests, unit tests, and doctests together at once, using $ python run-tests.py --quick (or $ nox -s quick).

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

kratman (Contributor) commented Jan 7, 2025

This will have to be moved to pybammsolvers

kratman closed this Jan 7, 2025

kratman (Contributor) commented Jan 7, 2025

Sorry, did not mean to close this when I commented.

kratman reopened this Jan 7, 2025
Development: successfully merging this pull request may close #4087 (refactor multiprocessing and multiple inputs).
2 participants