Skip to content

Commit

Permalink
Propogate Python exceptions through err_handler
Browse files Browse the repository at this point in the history
  • Loading branch information
c-randall committed Dec 20, 2024
1 parent 91933a0 commit c599525
Show file tree
Hide file tree
Showing 16 changed files with 283 additions and 38 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,17 @@ plt.show()
## Citing this Work
This work was authored by researchers at the National Renewable Energy Laboratory (NREL). The project is tracked in NREL's software records under SWR-24-137 and has a DOI available for citing the work. If you use use this package in your work, please include the following citation:

> Placeholder... waiting for DOI.
> Randall, Corey R. "scikit-SUNDAE ((SUN)DIALS Differential Algebraic Equations) [SWR-24-137]." Computer software. url: https://github.com/NREL/scikit-sundae. doi: https://doi.org/10.11578/dc.20241104.3.
For convenience, we also provide the following for your BibTex:

```
@misc{Randall2024,
title = {{scikit-SUNDAE: Python bindings to SUNDIALS DAE solvers}},
@misc{Randall-2024,
title = {{scikit-SUNDAE ((SUN)DIALS Differential Algebraic Equations) [SWR-24-137]}},
author = {Randall, Corey R.},
year = {2024},
doi = {placeholder... waiting for DOI},
doi = {10.11578/dc.20241104.3},
url = {https://github.com/NREL/scikit-sundae},
year = {2024},
}
```

Expand Down
12 changes: 12 additions & 0 deletions docs/source/_templates/copyright.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{# Displays the copyright information (which is defined in conf.py). #}
{% if show_copyright and copyright %}
<p class="copyright">
{% if hasdoc('copyright') %}
© <a href="{{ pathto('copyright') }}">{% trans copyright=copyright|e %}Copyright {{ copyright }}{% endtrans %}</a>.
<br/>
{% else %}
{% trans copyright=copyright|e %}© {{ copyright }}.{% endtrans %}
<br/>
{% endif %}
</p>
{% endif %}
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sksundae as sun

project = 'scikit-sundae'
copyright = '2024, Corey R. Randall'
copyright = 'Alliance for Sustainable Energy, LLC'
author = 'Corey R. Randall'
version = sun.__version__
release = sun.__version__
Expand Down
2 changes: 1 addition & 1 deletion environments/ci_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ name: sun
channels:
- conda-forge
dependencies:
- sundials=7.1
- sundials=7.2
2 changes: 1 addition & 1 deletion environments/rtd_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- defaults
dependencies:
- python=3.12
- sundials=7.1
- sundials=7.2
- pip>=24.3
- pip:
- ../.[docs]
8 changes: 4 additions & 4 deletions images/tests.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
132 changes: 132 additions & 0 deletions scripts/version_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import requests
import argparse
from packaging.version import Version


def get_latest_version(package: str, prefix: str = None) -> str:
"""
Fetch the latest version with matching prefix from PyPI.
Parameters
----------
package : str
The name of the package to query.
prefix : str
A filtering prefix used to get a subset of releases, e.g., '1.1' will
return the latest patch to version 1.1 even if 1.2 exists.
Returns
-------
latest_version : str
The latest version available on PyPI. '0.0.0' is returned if there
are no versions.
Raises
------
ValueError
Failed to fetch PyPI data for requested package.
"""

url = f"https://pypi.org/pypi/{package}/json"

response = requests.get(url)
if response.status_code != 200:
raise ValueError(f"Failed to fetch PyPI data for '{package}'.")

data = response.json()
versions = list(data['releases'].keys())
if not versions:
print(f"{package=} not found on PyPI.")
return '0.0.0'

if prefix:
versions = [v for v in versions if v.startswith(prefix)]
assert len(versions) != 0, f"{prefix=} has no existing matches."

sorted_versions = sorted(versions, key=Version, reverse=True)
latest_version = sorted_versions[0]

print(f"Latest PyPI version for {package}: {latest_version}.")
return latest_version


def check_against_pypi(pypi: str, local: str) -> None:
"""
Verify the local version is newer than PyPI.
Parameters
----------
pypi : str
Latest version on PyPI.
local : str
Local package version.
Returns
-------
None.
Raises
------
ValueError
Local package is older than PyPI.
"""

pypi = Version(pypi)
local = Version(local)

if local < pypi:
raise ValueError(f"Local package {local} is older than PyPI {pypi}.")

print(f"Local package {local} is newer than PyPI {pypi}.")


def check_against_tag(tag: str, local: str) -> None:
"""
Check that the tag matches the package version.
Parameters
----------
tag : str
Semmantically versioned tag.
local : str
Local package version.
Returns
-------
None.
Raises
------
ValueError
Version mismatch: tag differs from local.
"""

tag = Version(tag)
local = Version(local)

if tag != local:
raise ValueError(f"Version mismatch: {tag=} vs. {local=}")

print(f"Local and tag versions match: {tag} == {local}.")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--tag', required=True)
parser.add_argument('--local', required=True)
args = parser.parse_args()

check_against_tag(args.tag, args.local)

patch_check = Version(args.local)
if patch_check.micro > 0:
prefix = str(patch_check.major) + '.' + str(patch_check.minor)
else:
prefix = None

pypi = get_latest_version('scikit-sundae', prefix)

check_against_pypi(pypi, args.local)
2 changes: 1 addition & 1 deletion src/sksundae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@

__all__ = ['ida', 'utils', 'cvode', 'SUNDIALS_VERSION']

__version__ = '1.0.0rc3'
__version__ = '1.0.0'
2 changes: 1 addition & 1 deletion src/sksundae/_cy_common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
cimport numpy as np

# Extern cdef headers
from .c_sundials cimport * # Access to types
from .c_sundials cimport * # Access to C types

# Convert between N_Vector and numpy array
cdef svec2np(N_Vector nvec, np.ndarray[DTYPE_t, ndim=1] np_array)
Expand Down
8 changes: 4 additions & 4 deletions src/sksundae/_cy_cvode.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ LSMESSAGES = {


cdef int _rhsfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp,
void* data) noexcept:
void* data) except? -1:
"""Wraps 'rhsfn' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -101,7 +101,7 @@ cdef int _rhsfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp,


cdef int _eventsfn_wrapper(sunrealtype t, N_Vector yy, sunrealtype* ee,
void* data) noexcept:
void* data) except? -1:
"""Wraps 'eventsfn' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -120,7 +120,7 @@ cdef int _eventsfn_wrapper(sunrealtype t, N_Vector yy, sunrealtype* ee,

cdef int _jacfn_wrapper(sunrealtype t, N_Vector yy, N_Vector fy, SUNMatrix JJ,
void* data, N_Vector tmp1, N_Vector tmp2,
N_Vector tmp3) noexcept:
N_Vector tmp3) except? -1:
"""Wraps 'jacfn' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -140,7 +140,7 @@ cdef int _jacfn_wrapper(sunrealtype t, N_Vector yy, N_Vector fy, SUNMatrix JJ,

cdef void _err_handler(int line, const char* func, const char* file,
const char* msg, int err_code, void* err_user_data,
SUNContext ctx) noexcept:
SUNContext ctx) except *:
"""Custom error handler for shorter messages (no line or file)."""

decoded_func = func.decode("utf-8")
Expand Down
8 changes: 4 additions & 4 deletions src/sksundae/_cy_ida.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ LSMESSAGES = {


cdef int _resfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, N_Vector rr,
void* data) noexcept:
void* data) except? -1:
"""Wraps 'resfn' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -100,7 +100,7 @@ cdef int _resfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, N_Vector rr,


cdef int _eventsfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp,
sunrealtype* ee, void* data) noexcept:
sunrealtype* ee, void* data) except? -1:
"""Wraps 'eventsfn' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -120,7 +120,7 @@ cdef int _eventsfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp,

cdef int _jacfn_wrapper(sunrealtype t, sunrealtype cj, N_Vector yy, N_Vector yp,
N_Vector rr, SUNMatrix JJ, void* data, N_Vector tmp1,
N_Vector tmp2, N_Vector tmp3) noexcept:
N_Vector tmp2, N_Vector tmp3) except? -1:
"""Wraps 'jacfn' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -142,7 +142,7 @@ cdef int _jacfn_wrapper(sunrealtype t, sunrealtype cj, N_Vector yy, N_Vector yp,

cdef void _err_handler(int line, const char* func, const char* file,
const char* msg, int err_code, void* err_user_data,
SUNContext ctx) noexcept:
SUNContext ctx) except *:
"""Custom error handler for shorter messages (no line or file)."""

decoded_func = func.decode("utf-8")
Expand Down
8 changes: 4 additions & 4 deletions src/sksundae/c_cvode.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ from .c_sundials cimport * # Access to types
cdef extern from "cvode/cvode.h":

# user-supplied functions
ctypedef int (*CVRhsFn)(sunrealtype t, N_Vector yy, N_Vector yp, void* data)
ctypedef int (*CVRootFn)(sunrealtype t, N_Vector yy, sunrealtype* ee, void* data)
ctypedef int (*CVRhsFn)(sunrealtype t, N_Vector yy, N_Vector yp, void* data) except? -1
ctypedef int (*CVRootFn)(sunrealtype t, N_Vector yy, sunrealtype* ee, void* data) except? -1

# imethod
int CV_ADAMS
Expand Down Expand Up @@ -66,8 +66,8 @@ cdef extern from "cvode/cvode.h":
cdef extern from "cvode/cvode_ls.h":

# user-supplied functions
ctypedef int (*CVLsJacFn)(sunrealtype t, N_Vector yy, N_Vector fy, SUNMatrix JJ,
void* data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
ctypedef int (*CVLsJacFn)(sunrealtype t, N_Vector yy, N_Vector fy, SUNMatrix JJ, void* data,
N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) except? -1

# exported functions
int CVodeSetLinearSolver(void* mem, SUNLinearSolver LS, SUNMatrix A)
Expand Down
6 changes: 3 additions & 3 deletions src/sksundae/c_ida.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ from .c_sundials cimport *
cdef extern from "ida/ida.h":

# user-supplied functions
ctypedef int (*IDAResFn)(sunrealtype t, N_Vector yy, N_Vector yp, N_Vector rr, void* data)
ctypedef int (*IDARootFn)(sunrealtype t, N_Vector yy, N_Vector yp, sunrealtype* ee, void* data)
ctypedef int (*IDAResFn)(sunrealtype t, N_Vector yy, N_Vector yp, N_Vector rr, void* data) except? -1
ctypedef int (*IDARootFn)(sunrealtype t, N_Vector yy, N_Vector yp, sunrealtype* ee, void* data) except? -1

# itask
int IDA_NORMAL
Expand Down Expand Up @@ -73,7 +73,7 @@ cdef extern from "ida/ida_ls.h":
# user-supplied functions
ctypedef int (*IDALsJacFn)(sunrealtype t, sunrealtype cj, N_Vector yy, N_Vector yp,
N_Vector rr, SUNMatrix JJ, void* data, N_Vector tmp1,
N_Vector tmp2, N_Vector tmp3)
N_Vector tmp2, N_Vector tmp3) except? -1

# exported functions
int IDASetLinearSolver(void* mem, SUNLinearSolver LS, SUNMatrix A)
Expand Down
2 changes: 1 addition & 1 deletion src/sksundae/c_sundials.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ cdef extern from "sundials/sundials_types.h":
ctypedef int SUNComm
ctypedef void (*SUNErrHandlerFn)(int line, const char* func, const char* file,
const char* msg, int err_code, void* err_user_data,
SUNContext ctx)
SUNContext ctx) except *

int SUN_COMM_NULL

Expand Down
Loading

0 comments on commit c599525

Please sign in to comment.