From facdbdc603c99df07e7c84da0788994505eb4098 Mon Sep 17 00:00:00 2001 From: "Corey R. Randall" Date: Tue, 14 Jan 2025 17:06:25 -0700 Subject: [PATCH] resfn and rhsfn now work with numba jit/njit --- .github/workflows/ci.yml | 8 +++ .github/workflows/release.yml | 36 +++++-------- inprogress/getfullargspec.py | 97 ----------------------------------- noxfile.py | 1 + src/sksundae/_cy_cvode.pyx | 24 ++++----- src/sksundae/_cy_ida.pyx | 24 ++++----- 6 files changed, 44 insertions(+), 146 deletions(-) delete mode 100644 inprogress/getfullargspec.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 624dd89..60a6db5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -95,3 +95,11 @@ jobs: - name: Pytest run: nox -s tests -- no-reports + + - name: learn micromamba + run: | + micromamba remove sundials + micromamba create -n test python=${{ matrix.python-version }} + micromamba activate test + micromamba info + micromamba list diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 48336dd..db72757 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -87,24 +87,19 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Setup conda/python - uses: conda-incubator/setup-miniconda@v3 - with: # ci_environment.yml specifies sundials version to compile against - auto-update-conda: true - miniconda-version: latest + - name: Setup Python and SUNDIALS + uses: mamba-org/setup-micromamba@v2 + with: # ci_environment.yml specifies sundials version to compile environment-file: environments/ci_environment.yml - python-version: ${{ matrix.python-version }} - activate-environment: sun - channels: conda-forge - conda-remove-defaults: "true" + create-args: python=${{ matrix.python-version }} - name: Install build run: pip install build - name: List info run: | - conda info - conda list + micromamba info + micromamba list - name: Set up environment variables for MacOS if: runner.os == 'macOS' @@ -153,9 +148,9 @@ jobs: env: # Remove known SUNDIALS header and lib paths DYLD_LIBRARY_PATH: run: | - conda uninstall sundials - conda create -n test python=${{ matrix.python-version }} - conda activate test + micromamba remove sundials + micromamba create -n test python=${{ matrix.python-version }} + micromamba activate test python -m pip install --upgrade pip pip install wheels/*.whl -v @@ -188,16 +183,11 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Setup conda/python - uses: conda-incubator/setup-miniconda@v3 - with: # ci_environment.yml specifies sundials version to compile against - auto-update-conda: true - miniconda-version: latest + - name: Setup Python and SUNDIALS + uses: mamba-org/setup-micromamba@v2 + with: # ci_environment.yml specifies sundials version to compile environment-file: environments/ci_environment.yml - python-version: ${{ matrix.python-version }} - activate-environment: sun - channels: conda-forge - conda-remove-defaults: "true" + create-args: python=${{ matrix.python-version }} - name: Install build run: pip install build diff --git a/inprogress/getfullargspec.py b/inprogress/getfullargspec.py deleted file mode 100644 index 842fd4d..0000000 --- a/inprogress/getfullargspec.py +++ /dev/null @@ -1,97 +0,0 @@ -import inspect -from collections import namedtuple -from inspect import (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD, _VAR_POSITIONAL, - _KEYWORD_ONLY, _VAR_KEYWORD,) - -from jax import jit as jjit -from numba import jit as njit - - -def py_func(a: int, b: int, *c, x: int = 1, y: int = 2, **z): - return x + y - - -@njit -def nb_func(a: int, b: int, *c, x: int = 1, y: int = 2, **z): - return x + y - - -@jjit -def jax_func(a: int, b: int, *c, x: int = 1, y: int = 2, **z): - return x + y - - -class py_call: - - def __call__(self, a: int, b: int, *c, x: int = 1, y: int = 2, **z): - return x + y - - @classmethod - def call(cls, a: int, b: int, *c, x: int = 1, y: int = 2, **z): - return x + y - - -FullArgSpec = namedtuple( - 'FullArgSpec', - 'args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations', -) - - -def getfullargspec(func) -> FullArgSpec: - - try: - sig = inspect.signature(func) - except Exception as ex: - raise TypeError('unsupported callable') from ex - - args = [] - varargs = None - varkw = None - posonlyargs = [] - kwonlyargs = [] - annotations = {} - defaults = () - kwdefaults = {} - - if sig.return_annotation is not sig.empty: - annotations['return'] = sig.return_annotation - - for param in sig.parameters.values(): - kind = param.kind - name = param.name - - if kind is _POSITIONAL_ONLY: - posonlyargs.append(name) - if param.default is not param.empty: - defaults += (param.default,) - elif kind is _POSITIONAL_OR_KEYWORD: - args.append(name) - if param.default is not param.empty: - defaults += (param.default,) - elif kind is _VAR_POSITIONAL: - varargs = name - elif kind is _KEYWORD_ONLY: - kwonlyargs.append(name) - if param.default is not param.empty: - kwdefaults[name] = param.default - elif kind is _VAR_KEYWORD: - varkw = name - - if param.annotation is not param.empty: - annotations[name] = param.annotation - - if not kwdefaults: - kwdefaults = None - - if not defaults: - defaults = None - - return FullArgSpec(posonlyargs + args, varargs, varkw, defaults, - kwonlyargs, kwdefaults, annotations) - - -print('py:', getfullargspec(py_func)) -print('nb:', getfullargspec(nb_func)) -print('jax:', getfullargspec(jax_func)) -print('call:', getfullargspec(py_call())) -print('cls:', getfullargspec(py_call.call)) diff --git a/noxfile.py b/noxfile.py index b76a605..ef2c469 100644 --- a/noxfile.py +++ b/noxfile.py @@ -189,4 +189,5 @@ def run_build_ext(session: nox.Session) -> None: """ + session.run('pip', 'install', '--upgrade', '--quiet', 'cython') session.run('python', 'setup.py', 'build_ext', '--inplace') diff --git a/src/sksundae/_cy_cvode.pyx b/src/sksundae/_cy_cvode.pyx index f2831de..68754c9 100644 --- a/src/sksundae/_cy_cvode.pyx +++ b/src/sksundae/_cy_cvode.pyx @@ -4,9 +4,9 @@ # cython: embedsignature=True, embeddedsignature.format='python' # Standard library +import inspect from warnings import warn from types import MethodType -from inspect import getfullargspec from typing import Callable, Iterable # Dependencies @@ -814,27 +814,25 @@ cdef _collect_stats(void* mem): def _check_signature(name: str, func: Callable, expected: tuple[int]) -> int: """Check 'rhsfn', 'eventsfn', and 'jacfn' signatures.""" - argspec = getfullargspec(func) - if isinstance(func, MethodType): # if method, remove self/cls - argspec.args.pop(0) - elif argspec.args[0] in ("self", "cls"): - argspec.args.pop(0) + signature = inspect.signature(func) + parameters = signature.parameters.values() - if argspec.varargs or argspec.varkw: + has_args = any([p.kind == inspect._VAR_POSITIONAL for p in parameters]) + has_kwargs = any([p.kind == inspect._VAR_KEYWORD for p in parameters]) + + if has_args or has_kwargs: raise ValueError(f"'{name}' cannot include *args or **kwargs.") - elif argspec.kwonlyargs: - raise ValueError(f"'{name}' cannot include keyword-only args.") - if name == "rhsfn" and len(argspec.args) not in expected: + if name == "resfn" and len(parameters) not in expected: raise ValueError(f"'{name}' has an invalid signature. It must only" " have 3 (w/o userdata) or 4 (w/ userdata) args.") - elif len(argspec.args) not in expected: + elif len(parameters) not in expected: raise ValueError(f"'{name}' signature is inconsistent with 'rhsfn'." " look for a missing or extraneous 'userdata' arg.") - if name == "rhsfn" and len(argspec.args) == 3: + if name == "rhsfn" and len(parameters) == 3: with_userdata = 0 - elif name == "rhsfn" and len(argspec.args) == 4: + elif name == "rhsfn" and len(parameters) == 4: with_userdata = 1 else: with_userdata = None diff --git a/src/sksundae/_cy_ida.pyx b/src/sksundae/_cy_ida.pyx index 18a6648..0aeb6ea 100644 --- a/src/sksundae/_cy_ida.pyx +++ b/src/sksundae/_cy_ida.pyx @@ -4,9 +4,9 @@ # cython: embedsignature=True, embeddedsignature.format='python' # Standard library +import inspect from warnings import warn from types import MethodType -from inspect import getfullargspec from typing import Callable, Iterable # Dependencies @@ -878,27 +878,25 @@ cdef _collect_stats(void* mem): def _check_signature(name: str, func: Callable, expected: tuple[int]) -> int: """Check 'resfn', 'eventsfn', and 'jacfn' signatures.""" - argspec = getfullargspec(func) - if isinstance(func, MethodType): # if method, remove self/cls - argspec.args.pop(0) - elif argspec.args[0] in ("self", "cls"): - argspec.args.pop(0) + signature = inspect.signature(func) + parameters = signature.parameters.values() - if argspec.varargs or argspec.varkw: + has_args = any([p.kind == inspect._VAR_POSITIONAL for p in parameters]) + has_kwargs = any([p.kind == inspect._VAR_KEYWORD for p in parameters]) + + if has_args or has_kwargs: raise ValueError(f"'{name}' cannot include *args or **kwargs.") - elif argspec.kwonlyargs: - raise ValueError(f"'{name}' cannot include keyword-only args.") - if name == "resfn" and len(argspec.args) not in expected: + if name == "resfn" and len(parameters) not in expected: raise ValueError(f"'{name}' has an invalid signature. It must only" " have 4 (w/o userdata) or 5 (w/ userdata) args.") - elif len(argspec.args) not in expected: + elif len(parameters) not in expected: raise ValueError(f"'{name}' signature is inconsistent with 'resfn'." " look for a missing or extraneous 'userdata' arg.") - if name == "resfn" and len(argspec.args) == 4: + if name == "resfn" and len(parameters) == 4: with_userdata = 0 - elif name == "resfn" and len(argspec.args) == 5: + elif name == "resfn" and len(parameters) == 5: with_userdata = 1 else: with_userdata = None