From c599525cd6e6d78c0b9c8a88e8a241fd25df52a1 Mon Sep 17 00:00:00 2001 From: "Corey R. Randall" Date: Fri, 20 Dec 2024 11:52:10 -0700 Subject: [PATCH] Propogate Python exceptions through err_handler --- README.md | 10 +- docs/source/_templates/copyright.html | 12 +++ docs/source/conf.py | 2 +- environments/ci_environment.yml | 2 +- environments/rtd_environment.yml | 2 +- images/tests.svg | 8 +- scripts/version_checker.py | 132 ++++++++++++++++++++++++++ src/sksundae/__init__.py | 2 +- src/sksundae/_cy_common.pxd | 2 +- src/sksundae/_cy_cvode.pyx | 8 +- src/sksundae/_cy_ida.pyx | 8 +- src/sksundae/c_cvode.pxd | 8 +- src/sksundae/c_ida.pxd | 6 +- src/sksundae/c_sundials.pxd | 2 +- tests/test_cvode.py | 48 ++++++++++ tests/test_ida.py | 69 ++++++++++++-- 16 files changed, 283 insertions(+), 38 deletions(-) create mode 100644 docs/source/_templates/copyright.html create mode 100644 scripts/version_checker.py diff --git a/README.md b/README.md index 0d0f5c3..70c5c41 100644 --- a/README.md +++ b/README.md @@ -85,17 +85,17 @@ plt.show() ## Citing this Work This work was authored by researchers at the National Renewable Energy Laboratory (NREL). The project is tracked in NREL's software records under SWR-24-137 and has a DOI available for citing the work. If you use use this package in your work, please include the following citation: -> Placeholder... waiting for DOI. +> Randall, Corey R. "scikit-SUNDAE ((SUN)DIALS Differential Algebraic Equations) [SWR-24-137]." Computer software. url: https://github.com/NREL/scikit-sundae. doi: https://doi.org/10.11578/dc.20241104.3. For convenience, we also provide the following for your BibTex: ``` -@misc{Randall2024, - title = {{scikit-SUNDAE: Python bindings to SUNDIALS DAE solvers}}, +@misc{Randall-2024, + title = {{scikit-SUNDAE ((SUN)DIALS Differential Algebraic Equations) [SWR-24-137]}}, author = {Randall, Corey R.}, - year = {2024}, - doi = {placeholder... waiting for DOI}, + doi = {10.11578/dc.20241104.3}, url = {https://github.com/NREL/scikit-sundae}, + year = {2024}, } ``` diff --git a/docs/source/_templates/copyright.html b/docs/source/_templates/copyright.html new file mode 100644 index 0000000..08f6c80 --- /dev/null +++ b/docs/source/_templates/copyright.html @@ -0,0 +1,12 @@ +{# Displays the copyright information (which is defined in conf.py). #} +{% if show_copyright and copyright %} + +{% endif %} \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 0a66995..73ef902 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,7 +9,7 @@ import sksundae as sun project = 'scikit-sundae' -copyright = '2024, Corey R. Randall' +copyright = 'Alliance for Sustainable Energy, LLC' author = 'Corey R. Randall' version = sun.__version__ release = sun.__version__ diff --git a/environments/ci_environment.yml b/environments/ci_environment.yml index 4f6f1e4..1818a53 100644 --- a/environments/ci_environment.yml +++ b/environments/ci_environment.yml @@ -2,4 +2,4 @@ name: sun channels: - conda-forge dependencies: - - sundials=7.1 \ No newline at end of file + - sundials=7.2 \ No newline at end of file diff --git a/environments/rtd_environment.yml b/environments/rtd_environment.yml index f6d4b0a..e008fdd 100644 --- a/environments/rtd_environment.yml +++ b/environments/rtd_environment.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: - python=3.12 - - sundials=7.1 + - sundials=7.2 - pip>=24.3 - pip: - ../.[docs] \ No newline at end of file diff --git a/images/tests.svg b/images/tests.svg index 320f899..282c9bb 100644 --- a/images/tests.svg +++ b/images/tests.svg @@ -1,5 +1,5 @@ - - tests: 29 + + tests: 31 @@ -15,7 +15,7 @@ tests - - 29 + + 31 diff --git a/scripts/version_checker.py b/scripts/version_checker.py new file mode 100644 index 0000000..5a7c804 --- /dev/null +++ b/scripts/version_checker.py @@ -0,0 +1,132 @@ +import requests +import argparse +from packaging.version import Version + + +def get_latest_version(package: str, prefix: str = None) -> str: + """ + Fetch the latest version with matching prefix from PyPI. + + Parameters + ---------- + package : str + The name of the package to query. + prefix : str + A filtering prefix used to get a subset of releases, e.g., '1.1' will + return the latest patch to version 1.1 even if 1.2 exists. + + Returns + ------- + latest_version : str + The latest version available on PyPI. '0.0.0' is returned if there + are no versions. + + Raises + ------ + ValueError + Failed to fetch PyPI data for requested package. + + """ + + url = f"https://pypi.org/pypi/{package}/json" + + response = requests.get(url) + if response.status_code != 200: + raise ValueError(f"Failed to fetch PyPI data for '{package}'.") + + data = response.json() + versions = list(data['releases'].keys()) + if not versions: + print(f"{package=} not found on PyPI.") + return '0.0.0' + + if prefix: + versions = [v for v in versions if v.startswith(prefix)] + assert len(versions) != 0, f"{prefix=} has no existing matches." + + sorted_versions = sorted(versions, key=Version, reverse=True) + latest_version = sorted_versions[0] + + print(f"Latest PyPI version for {package}: {latest_version}.") + return latest_version + + +def check_against_pypi(pypi: str, local: str) -> None: + """ + Verify the local version is newer than PyPI. + + Parameters + ---------- + pypi : str + Latest version on PyPI. + local : str + Local package version. + + Returns + ------- + None. + + Raises + ------ + ValueError + Local package is older than PyPI. + + """ + + pypi = Version(pypi) + local = Version(local) + + if local < pypi: + raise ValueError(f"Local package {local} is older than PyPI {pypi}.") + + print(f"Local package {local} is newer than PyPI {pypi}.") + + +def check_against_tag(tag: str, local: str) -> None: + """ + Check that the tag matches the package version. + + Parameters + ---------- + tag : str + Semmantically versioned tag. + local : str + Local package version. + + Returns + ------- + None. + + Raises + ------ + ValueError + Version mismatch: tag differs from local. + + """ + + tag = Version(tag) + local = Version(local) + + if tag != local: + raise ValueError(f"Version mismatch: {tag=} vs. {local=}") + + print(f"Local and tag versions match: {tag} == {local}.") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--tag', required=True) + parser.add_argument('--local', required=True) + args = parser.parse_args() + + check_against_tag(args.tag, args.local) + + patch_check = Version(args.local) + if patch_check.micro > 0: + prefix = str(patch_check.major) + '.' + str(patch_check.minor) + else: + prefix = None + + pypi = get_latest_version('scikit-sundae', prefix) + + check_against_pypi(pypi, args.local) diff --git a/src/sksundae/__init__.py b/src/sksundae/__init__.py index 407ff82..d2baf37 100644 --- a/src/sksundae/__init__.py +++ b/src/sksundae/__init__.py @@ -61,4 +61,4 @@ __all__ = ['ida', 'utils', 'cvode', 'SUNDIALS_VERSION'] -__version__ = '1.0.0rc3' +__version__ = '1.0.0' diff --git a/src/sksundae/_cy_common.pxd b/src/sksundae/_cy_common.pxd index b9e2110..457e3c6 100644 --- a/src/sksundae/_cy_common.pxd +++ b/src/sksundae/_cy_common.pxd @@ -4,7 +4,7 @@ cimport numpy as np # Extern cdef headers -from .c_sundials cimport * # Access to types +from .c_sundials cimport * # Access to C types # Convert between N_Vector and numpy array cdef svec2np(N_Vector nvec, np.ndarray[DTYPE_t, ndim=1] np_array) diff --git a/src/sksundae/_cy_cvode.pyx b/src/sksundae/_cy_cvode.pyx index bf506cb..a91fca1 100644 --- a/src/sksundae/_cy_cvode.pyx +++ b/src/sksundae/_cy_cvode.pyx @@ -83,7 +83,7 @@ LSMESSAGES = { cdef int _rhsfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, - void* data) noexcept: + void* data) except? -1: """Wraps 'rhsfn' by converting between N_Vector and ndarray types.""" aux = data @@ -101,7 +101,7 @@ cdef int _rhsfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, cdef int _eventsfn_wrapper(sunrealtype t, N_Vector yy, sunrealtype* ee, - void* data) noexcept: + void* data) except? -1: """Wraps 'eventsfn' by converting between N_Vector and ndarray types.""" aux = data @@ -120,7 +120,7 @@ cdef int _eventsfn_wrapper(sunrealtype t, N_Vector yy, sunrealtype* ee, cdef int _jacfn_wrapper(sunrealtype t, N_Vector yy, N_Vector fy, SUNMatrix JJ, void* data, N_Vector tmp1, N_Vector tmp2, - N_Vector tmp3) noexcept: + N_Vector tmp3) except? -1: """Wraps 'jacfn' by converting between N_Vector and ndarray types.""" aux = data @@ -140,7 +140,7 @@ cdef int _jacfn_wrapper(sunrealtype t, N_Vector yy, N_Vector fy, SUNMatrix JJ, cdef void _err_handler(int line, const char* func, const char* file, const char* msg, int err_code, void* err_user_data, - SUNContext ctx) noexcept: + SUNContext ctx) except *: """Custom error handler for shorter messages (no line or file).""" decoded_func = func.decode("utf-8") diff --git a/src/sksundae/_cy_ida.pyx b/src/sksundae/_cy_ida.pyx index b52480e..b20bb2a 100644 --- a/src/sksundae/_cy_ida.pyx +++ b/src/sksundae/_cy_ida.pyx @@ -81,7 +81,7 @@ LSMESSAGES = { cdef int _resfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, N_Vector rr, - void* data) noexcept: + void* data) except? -1: """Wraps 'resfn' by converting between N_Vector and ndarray types.""" aux = data @@ -100,7 +100,7 @@ cdef int _resfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, N_Vector rr, cdef int _eventsfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, - sunrealtype* ee, void* data) noexcept: + sunrealtype* ee, void* data) except? -1: """Wraps 'eventsfn' by converting between N_Vector and ndarray types.""" aux = data @@ -120,7 +120,7 @@ cdef int _eventsfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, cdef int _jacfn_wrapper(sunrealtype t, sunrealtype cj, N_Vector yy, N_Vector yp, N_Vector rr, SUNMatrix JJ, void* data, N_Vector tmp1, - N_Vector tmp2, N_Vector tmp3) noexcept: + N_Vector tmp2, N_Vector tmp3) except? -1: """Wraps 'jacfn' by converting between N_Vector and ndarray types.""" aux = data @@ -142,7 +142,7 @@ cdef int _jacfn_wrapper(sunrealtype t, sunrealtype cj, N_Vector yy, N_Vector yp, cdef void _err_handler(int line, const char* func, const char* file, const char* msg, int err_code, void* err_user_data, - SUNContext ctx) noexcept: + SUNContext ctx) except *: """Custom error handler for shorter messages (no line or file).""" decoded_func = func.decode("utf-8") diff --git a/src/sksundae/c_cvode.pxd b/src/sksundae/c_cvode.pxd index 22eb8ca..56c3546 100644 --- a/src/sksundae/c_cvode.pxd +++ b/src/sksundae/c_cvode.pxd @@ -6,8 +6,8 @@ from .c_sundials cimport * # Access to types cdef extern from "cvode/cvode.h": # user-supplied functions - ctypedef int (*CVRhsFn)(sunrealtype t, N_Vector yy, N_Vector yp, void* data) - ctypedef int (*CVRootFn)(sunrealtype t, N_Vector yy, sunrealtype* ee, void* data) + ctypedef int (*CVRhsFn)(sunrealtype t, N_Vector yy, N_Vector yp, void* data) except? -1 + ctypedef int (*CVRootFn)(sunrealtype t, N_Vector yy, sunrealtype* ee, void* data) except? -1 # imethod int CV_ADAMS @@ -66,8 +66,8 @@ cdef extern from "cvode/cvode.h": cdef extern from "cvode/cvode_ls.h": # user-supplied functions - ctypedef int (*CVLsJacFn)(sunrealtype t, N_Vector yy, N_Vector fy, SUNMatrix JJ, - void* data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) + ctypedef int (*CVLsJacFn)(sunrealtype t, N_Vector yy, N_Vector fy, SUNMatrix JJ, void* data, + N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) except? -1 # exported functions int CVodeSetLinearSolver(void* mem, SUNLinearSolver LS, SUNMatrix A) diff --git a/src/sksundae/c_ida.pxd b/src/sksundae/c_ida.pxd index c13d8ee..d416aa2 100644 --- a/src/sksundae/c_ida.pxd +++ b/src/sksundae/c_ida.pxd @@ -6,8 +6,8 @@ from .c_sundials cimport * cdef extern from "ida/ida.h": # user-supplied functions - ctypedef int (*IDAResFn)(sunrealtype t, N_Vector yy, N_Vector yp, N_Vector rr, void* data) - ctypedef int (*IDARootFn)(sunrealtype t, N_Vector yy, N_Vector yp, sunrealtype* ee, void* data) + ctypedef int (*IDAResFn)(sunrealtype t, N_Vector yy, N_Vector yp, N_Vector rr, void* data) except? -1 + ctypedef int (*IDARootFn)(sunrealtype t, N_Vector yy, N_Vector yp, sunrealtype* ee, void* data) except? -1 # itask int IDA_NORMAL @@ -73,7 +73,7 @@ cdef extern from "ida/ida_ls.h": # user-supplied functions ctypedef int (*IDALsJacFn)(sunrealtype t, sunrealtype cj, N_Vector yy, N_Vector yp, N_Vector rr, SUNMatrix JJ, void* data, N_Vector tmp1, - N_Vector tmp2, N_Vector tmp3) + N_Vector tmp2, N_Vector tmp3) except? -1 # exported functions int IDASetLinearSolver(void* mem, SUNLinearSolver LS, SUNMatrix A) diff --git a/src/sksundae/c_sundials.pxd b/src/sksundae/c_sundials.pxd index 7d54e76..3190392 100644 --- a/src/sksundae/c_sundials.pxd +++ b/src/sksundae/c_sundials.pxd @@ -35,7 +35,7 @@ cdef extern from "sundials/sundials_types.h": ctypedef int SUNComm ctypedef void (*SUNErrHandlerFn)(int line, const char* func, const char* file, const char* msg, int err_code, void* err_user_data, - SUNContext ctx) + SUNContext ctx) except * int SUN_COMM_NULL diff --git a/tests/test_cvode.py b/tests/test_cvode.py index 74478af..6d4b7a9 100644 --- a/tests/test_cvode.py +++ b/tests/test_cvode.py @@ -199,6 +199,54 @@ def jacfn(t, y, fy, JJ): assert np.allclose(soln.y, ode_soln(soln.t, y0)) +def test_failures_on_exceptions(): + + # exception in rhsfn + def bad_ode(t, y, yp): + if t > 1: + raise ValueError() + + yp[0] = 0.1 + yp[1] = y[1] + + y0 = np.array([1, 2]) + + solver = CVODE(bad_ode, rtol=1e-9, atol=1e-12) + + with pytest.raises(ValueError): + tspan = np.linspace(0, 10, 11) + soln = solver.solve(tspan, y0) + assert soln.status < 0 + + # exceptions in eventsfn + def eventsfn(t, y, events): + if t > 1: + raise ValueError() + + events[0] = y[0] - 1.55 + + solver = CVODE(ode, rtol=1e-9, atol=1e-12, eventsfn=eventsfn, num_events=1) + + with pytest.raises(ValueError): + tspan = np.linspace(0, 10, 11) + soln = solver.solve(tspan, y0) + assert soln.status < 0 + + # exceptions in jacfn + def jacfn(t, y, fy, JJ): + if t > 1: + raise ValueError() + + JJ[1, 1] = 1 + + solver = CVODE(ode, rtol=1e-9, atol=1e-12, jacfn=jacfn) + + with pytest.raises(ValueError): + tspan = np.linspace(0, 10, 11) + soln = solver.solve(tspan, y0) + assert soln.status < 0 + + def test_CVODEResult(): y0 = np.array([1, 2]) diff --git a/tests/test_ida.py b/tests/test_ida.py index 9c0a932..11e1536 100644 --- a/tests/test_ida.py +++ b/tests/test_ida.py @@ -78,7 +78,7 @@ def test_ida_dae_solve(): y0 = np.array([1, 2]) yp0 = np.array([0.1, 0.2]) - solver = IDA(dae, rtol=1e-9, atol=1e-12) + solver = IDA(dae, rtol=1e-9, atol=1e-12, algebraic_idx=[1]) tspan = np.linspace(0, 10, 11) # normal solve - user picks times soln = solver.solve(tspan, y0, yp0) @@ -95,7 +95,7 @@ def test_ida_dae_step(): y0 = np.array([1, 2]) yp0 = np.array([0.1, 0.2]) - solver = IDA(dae, rtol=1e-9, atol=1e-12) + solver = IDA(dae, rtol=1e-9, atol=1e-12, algebraic_idx=[1]) with pytest.raises(ValueError): # have to call init_step first _ = solver.step(10) @@ -122,7 +122,7 @@ def dae_w_data(t, y, yp, res, userdata): with pytest.raises(ValueError): # userdata keyword arg cannot be None _ = IDA(dae_w_data, rtol=1e-9, atol=1e-12) - solver = IDA(dae_w_data, rtol=1e-9, atol=1e-12, + solver = IDA(dae_w_data, rtol=1e-9, atol=1e-12, algebraic_idx=[1], userdata={'rate': 0.1, 'ratio': 2}) tspan = np.linspace(0, 10, 11) @@ -165,8 +165,8 @@ def test_ida_linsolver(): with pytest.raises(ValueError): # forgot bandwidth(s) _ = IDA(ode, rtol=1e-9, atol=1e-12, linsolver='band') - solver = IDA(ode, rtol=1e-9, atol=1e-12, linsolver='band', lband=0, - uband=0) + solver = IDA(ode, rtol=1e-9, atol=1e-12, algebraic_idx=[1], + linsolver='band', lband=0, uband=0) tspan = np.linspace(0, 10, 11) soln = solver.solve(tspan, y0, yp0) @@ -287,19 +287,72 @@ def jacfn(t, y, yp, res, cj, JJ): soln = solver.solve(tspan, y0, yp0) assert np.allclose(soln.y, dae_soln(soln.t, y0)) - solver = IDA(dae, rtol=1e-9, atol=1e-12, linsolver='band', - lband=1, uband=0, jacfn=jacfn) + solver = IDA(dae, rtol=1e-9, atol=1e-12, algebraic_idx=[1], + linsolver='band', lband=1, uband=0, jacfn=jacfn) tspan = np.linspace(0, 10, 11) soln = solver.solve(tspan, y0, yp0) assert np.allclose(soln.y, dae_soln(soln.t, y0)) +def test_failures_on_exceptions(): + + # exception in resfn + def bad_dae(t, y, yp, res): + if t > 1: + raise ValueError() + + res[0] = yp[0] - 0.1 + res[1] = 2*y[0] - y[1] + + y0 = np.array([1, 2]) + yp0 = np.array([0.1, 0.2]) + + solver = IDA(bad_dae, rtol=1e-9, atol=1e-12, algebraic_idx=[1]) + + with pytest.raises(ValueError): + tspan = np.linspace(0, 10, 11) + soln = solver.solve(tspan, y0, yp0) + assert soln.status < 0 + + # exceptions in eventsfn + def eventsfn(t, y, yp, events): + if t > 1: + raise ValueError() + + events[0] = y[0] - 1.55 + + solver = IDA(dae, rtol=1e-9, atol=1e-12, algebraic_idx=[1], + eventsfn=eventsfn, num_events=1) + + with pytest.raises(ValueError): + tspan = np.linspace(0, 10, 11) + soln = solver.solve(tspan, y0, yp0) + assert soln.status < 0 + + # exceptions in jacfn + def jacfn(t, y, yp, res, cj, JJ): + if t > 1: + raise ValueError() + + JJ[0, 0] = cj + JJ[1, 0] = 2 + JJ[1, 1] = -1 + + solver = IDA(dae, rtol=1e-9, atol=1e-12, algebraic_idx=[1], + jacfn=jacfn) + + with pytest.raises(ValueError): + tspan = np.linspace(0, 10, 11) + soln = solver.solve(tspan, y0, yp0) + assert soln.status < 0 + + def test_IDAResult(): y0 = np.array([1, 2]) yp0 = np.array([0.1, 0.2]) - solver = IDA(dae, rtol=1e-9, atol=1e-12) + solver = IDA(dae, rtol=1e-9, atol=1e-12, algebraic_idx=[1]) tspan = np.linspace(0, 10, 11) soln = solver.solve(tspan, y0, yp0)