Skip to content

Commit

Permalink
feat: Introduce BackendManager and PolicyManager
Browse files Browse the repository at this point in the history
fix tests

chore: remove unnecessary use of super()

fix typo

cosmetics

proposal: inject backend through decorators

add incremental basic statistics

add dbscan

add kmeans

remove BackendMixin from kmeans

rm BackendMixin

fix pyproject; remove basemixin and add covariance

add svm

add pca

fix backend import

add forest

add linear_model

add neighbors

fixup svm

refactor: unify host/dpc backend decorators into single decorator

update backend import

cleanup

update_abstractmethods fix for py3.9

fixup

fixup

fixup

fixup

fixup after rebase

decorate methods not classes - preparation for spmd

simplify backend import

align spmd

fixup

update tests

fixup

fixup

fixup

fixup: add license

improve debug message

wip: fix spmd

spmd test cases

fix more spmd tests

more spmd test fixups

further fixups

further fixups

revert error message

_get_queue -> _get_policy

fix for spmd classes that rely on batch functions

more fixes for default policy
  • Loading branch information
ahuber21 committed Nov 23, 2024
1 parent 935c56b commit bb1559e
Show file tree
Hide file tree
Showing 55 changed files with 1,330 additions and 916 deletions.
85 changes: 61 additions & 24 deletions onedal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@

from daal4py.sklearn._utils import daal_check_version


class Backend:
    """Wrap a backend module together with its dpc/spmd capability flags.

    Any attribute that is not found on the wrapper itself is forwarded to the
    wrapped module, so an instance can be used in place of the raw backend
    module while additionally exposing ``is_dpc`` and ``is_spmd``.

    Parameters
    ----------
    backend_module : object
        The backend (in this file, an imported extension module) to wrap.
        Stored as ``backend``.
    is_dpc : bool
        Stored as ``is_dpc``.
    is_spmd : bool
        Stored as ``is_spmd``.
    """

    def __init__(self, backend_module, is_dpc, is_spmd):
        self.backend = backend_module
        self.is_dpc = is_dpc
        self.is_spmd = is_spmd

    def __getattr__(self, name):
        """Forward lookups that failed on the wrapper to the wrapped backend.

        ``__getattr__`` is only invoked after normal attribute lookup fails.

        Raises
        ------
        AttributeError
            If the wrapped backend has no such attribute, or if the wrapper
            has no ``backend`` yet (``__init__`` has not run).
        """
        # Read the wrapped module straight from the instance dict. Going
        # through ``self.backend`` would re-enter ``__getattr__`` whenever
        # ``backend`` is not set yet (e.g. on the bare instance that
        # ``copy``/``pickle`` create before restoring state) and recurse
        # until RecursionError instead of reporting a missing attribute.
        try:
            backend = self.__dict__["backend"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(backend, name)


if "Windows" in platform.system():
import os
import site
Expand All @@ -40,44 +54,67 @@
pass
os.environ["PATH"] = path_to_libs + os.pathsep + os.environ["PATH"]

try:
import onedal._onedal_py_dpc as _backend

_is_dpc_backend = True
except ImportError:
import onedal._onedal_py_host as _backend

_is_dpc_backend = False

_is_spmd_backend = False
try:
# use dpc backend if available
import onedal._onedal_py_dpc

if _is_dpc_backend:
try:
import onedal._onedal_py_spmd_dpc as _spmd_backend
_dpc_backend = Backend(onedal._onedal_py_dpc, is_dpc=True, is_spmd=False)

_is_spmd_backend = True
except ImportError:
_is_spmd_backend = False
_host_backend = None
except ImportError:
# fall back to host backend
_dpc_backend = None

import onedal._onedal_py_host

__all__ = ["covariance", "decomposition", "ensemble", "neighbors", "primitives", "svm"]
_host_backend = Backend(onedal._onedal_py_host, is_dpc=False, is_spmd=False)

if _is_spmd_backend:
__all__.append("spmd")
try:
# also load spmd backend if available
import onedal._onedal_py_spmd_dpc

_spmd_backend = Backend(onedal._onedal_py_spmd_dpc, is_dpc=True, is_spmd=True)
except ImportError:
_spmd_backend = None

# if/elif/else layout required for pylint to realize _default_backend cannot be None
if _dpc_backend is not None:
_default_backend = _dpc_backend
elif _host_backend is not None:
_default_backend = _host_backend
else:
raise ImportError("No oneDAL backend available")

# Core modules to export
__all__ = [
"_host_backend",
"_default_backend",
"_dpc_backend",
"_spmd_backend",
"covariance",
"decomposition",
"ensemble",
"neighbors",
"primitives",
"svm",
]

# Additional features based on version checks
if daal_check_version((2023, "P", 100)):
__all__ += ["basic_statistics", "linear_model"]
if daal_check_version((2023, "P", 200)):
__all__ += ["cluster"]

if _is_spmd_backend:
# Exports if SPMD backend is available
if _spmd_backend is not None:
__all__ += ["spmd"]
if daal_check_version((2023, "P", 100)):
__all__ += [
"spmd.basic_statistics",
"spmd.decomposition",
"spmd.linear_model",
"spmd.neighbors",
]

if daal_check_version((2023, "P", 200)):
__all__ += ["cluster"]

if _is_spmd_backend:
if daal_check_version((2023, "P", 200)):
__all__ += ["spmd.cluster"]
3 changes: 1 addition & 2 deletions onedal/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
# ==============================================================================

import logging
from collections.abc import Iterable
from functools import wraps

Expand All @@ -36,7 +35,7 @@
# in _get_global_queue always true for situations without the
# dpc backend when `device_offload` is used. Instead, it will
# fail at the policy check phase yielding a RuntimeError
SyclQueue = getattr(onedal._backend, "SyclQueue", object)
SyclQueue = getattr(onedal._dpc_backend, "SyclQueue", object)

if dpnp_available:
import dpnp
Expand Down
14 changes: 9 additions & 5 deletions onedal/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,28 @@
# limitations under the License.
# ==============================================================================

import warnings
from abc import ABCMeta, abstractmethod

import numpy as np

from ..common._base import BaseEstimator
from ..common._backend import bind_default_backend
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _is_csr
from ..utils.validation import _check_array


class BaseBasicStatistics(BaseEstimator, metaclass=ABCMeta):
class BaseBasicStatistics(metaclass=ABCMeta):
@abstractmethod
def __init__(self, result_options, algorithm):
self.options = result_options
self.algorithm = algorithm

@bind_default_backend("basic_statistics")
def _get_policy(self, queue, *data): ...

@bind_default_backend("basic_statistics")
def compute(self, policy, params, data_table, weights_table): ...

@staticmethod
def get_all_result_options():
return [
Expand Down Expand Up @@ -99,9 +104,8 @@ def fit(self, data, sample_weight=None, queue=None):
def _compute_raw(
self, data_table, weights_table, policy, dtype=np.float32, is_csr=False
):
module = self._get_backend("basic_statistics")
params = self._get_onedal_params(is_csr, dtype)
result = module.compute(policy, params, data_table, weights_table)
result = self.compute(policy, params, data_table, weights_table)
options = self._get_result_options(self.options).split("|")

return {opt: getattr(result, opt) for opt in options}
38 changes: 18 additions & 20 deletions onedal/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
# limitations under the License.
# ==============================================================================

from abc import abstractmethod

import numpy as np

from daal4py.sklearn._utils import get_dtype
from onedal.common._backend import bind_default_backend

from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array
Expand Down Expand Up @@ -69,10 +72,18 @@ def __init__(self, result_options="all"):
super().__init__(result_options, algorithm="by_default")
self._reset()

@bind_default_backend("basic_statistics")
def partial_compute_result(self): ...

@bind_default_backend("basic_statistics")
def partial_compute(self, *args, **kwargs): ...

@bind_default_backend("basic_statistics")
def finalize_compute(self, *args, **kwargs): ...

def _reset(self):
self._partial_result = self._get_backend(
"basic_statistics", None, "partial_compute_result"
)
# get the _partial_result pointer from backend
self._partial_result = self.partial_compute_result()

def partial_fit(self, X, weights=None, queue=None):
"""
Expand Down Expand Up @@ -113,15 +124,8 @@ def partial_fit(self, X, weights=None, queue=None):
self._onedal_params = self._get_onedal_params(False, dtype=dtype)

X_table, weights_table = to_table(X, weights)
self._partial_result = self._get_backend(
"basic_statistics",
None,
"partial_compute",
policy,
self._onedal_params,
self._partial_result,
X_table,
weights_table,
self._partial_result = self.partial_compute(
policy, self._onedal_params, self._partial_result, X_table, weights_table
)

def finalize_fit(self, queue=None):
Expand All @@ -145,14 +149,8 @@ def finalize_fit(self, queue=None):
else:
policy = self._get_policy(self._queue)

result = self._get_backend(
"basic_statistics",
None,
"finalize_compute",
policy,
self._onedal_params,
self._partial_result,
)
result = self.finalize_compute(policy, self._onedal_params, self._partial_result)

options = self._get_result_options(self.options).split("|")
for opt in options:
setattr(self, opt, from_table(getattr(result, opt)).ravel())
Expand Down
1 change: 0 additions & 1 deletion onedal/basic_statistics/tests/test_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from numpy.testing import assert_allclose
from scipy import sparse as sp

from daal4py.sklearn._utils import daal_check_version
from onedal.basic_statistics import BasicStatistics
from onedal.basic_statistics.tests.utils import options_and_tests
from onedal.tests.utils._device_selection import get_queues
Expand Down
44 changes: 12 additions & 32 deletions onedal/cluster/dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@
# limitations under the License.
# ===============================================================================

from abc import abstractmethod

import numpy as np

from daal4py.sklearn._utils import get_dtype, make2d
from onedal.common._backend import bind_default_backend

from ..common._base import BaseEstimator
from ..common._mixin import ClusterMixin
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array


class BaseDBSCAN(BaseEstimator, ClusterMixin):
class DBSCAN(ClusterMixin):
def __init__(
self,
eps=0.5,
Expand All @@ -46,6 +48,12 @@ def __init__(
self.p = p
self.n_jobs = n_jobs

@bind_default_backend("dbscan")
def _get_policy(self, queue, *data): ...

@bind_default_backend("dbscan.clustering")
def compute(self, policy, params, data_table, weights_table): ...

def _get_onedal_params(self, dtype=np.float32):
return {
"fptype": "float" if dtype == np.float32 else "double",
Expand All @@ -56,7 +64,7 @@ def _get_onedal_params(self, dtype=np.float32):
"result_options": "core_observation_indices|responses",
}

def _fit(self, X, y, sample_weight, module, queue):
def fit(self, X, y=None, sample_weight=None, queue=None):
policy = self._get_policy(queue, X)
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
sample_weight = make2d(sample_weight) if sample_weight is not None else None
Expand All @@ -68,7 +76,7 @@ def _fit(self, X, y, sample_weight, module, queue):
X = _convert_to_supported(policy, X)
dtype = get_dtype(X)
params = self._get_onedal_params(dtype)
result = module.compute(policy, params, to_table(X), to_table(sample_weight))
result = self.compute(policy, params, to_table(X), to_table(sample_weight))

self.labels_ = from_table(result.responses).ravel()
if result.core_observation_indices is not None:
Expand All @@ -80,31 +88,3 @@ def _fit(self, X, y, sample_weight, module, queue):
self.components_ = np.take(X, self.core_sample_indices_, axis=0)
self.n_features_in_ = X.shape[1]
return self


class DBSCAN(BaseDBSCAN):
def __init__(
self,
eps=0.5,
*,
min_samples=5,
metric="euclidean",
metric_params=None,
algorithm="auto",
leaf_size=30,
p=None,
n_jobs=None,
):
self.eps = eps
self.min_samples = min_samples
self.metric = metric
self.metric_params = metric_params
self.algorithm = algorithm
self.leaf_size = leaf_size
self.p = p
self.n_jobs = n_jobs

def fit(self, X, y=None, sample_weight=None, queue=None):
return super()._fit(
X, y, sample_weight, self._get_backend("dbscan", "clustering", None), queue
)
Loading

0 comments on commit bb1559e

Please sign in to comment.