Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: RF refactoring #2924

Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include "oneapi/dal/backend/memory.hpp"
#include "oneapi/dal/backend/interop/common.hpp"
#include "oneapi/dal/table/homogen.hpp"
#include "oneapi/dal/backend/primitives/rng/rng_engine.hpp"
#include "oneapi/dal/backend/primitives/rng/rng_cpu.hpp"
#include "oneapi/dal/detail/threading.hpp"

namespace oneapi::dal::preview::connected_components::backend {
Expand Down Expand Up @@ -90,9 +90,9 @@ std::int32_t most_frequent_element(const std::atomic<std::int32_t> *components,
const std::int64_t &samples_count = 1024) {
std::int32_t *rnd_vertex_ids = allocate(vertex_allocator, samples_count);

dal::backend::primitives::engine eng;
dal::backend::primitives::rng<std::int32_t> rn_gen;
rn_gen.uniform(samples_count, rnd_vertex_ids, eng.get_state(), 0, vertex_count);
dal::backend::primitives::daal_engine eng;
dal::backend::primitives::daal_rng<std::int32_t> rn_gen;
rn_gen.uniform(samples_count, rnd_vertex_ids, eng.get_cpu_engine_state(), 0, vertex_count);

std::int32_t *root_sample_counts = allocate(vertex_allocator, vertex_count);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "oneapi/dal/detail/policy.hpp"
#include "oneapi/dal/table/row_accessor.hpp"
#include "oneapi/dal/detail/profiler.hpp"

#include <iostream>
#include "oneapi/dal/algo/decision_forest/backend/gpu/infer_kernel_impl.hpp"

namespace oneapi::dal::decision_forest::backend {
Expand All @@ -44,9 +44,11 @@ void infer_kernel_impl<Float, Index, Task>::validate_input(const descriptor_t& d
if (data.get_row_count() > de::limits<Index>::max()) {
throw domain_error(dal::detail::error_messages::invalid_range_of_rows());
}

if (data.get_column_count() > de::limits<Index>::max()) {
throw domain_error(dal::detail::error_messages::invalid_range_of_columns());
}

if (model.get_tree_count() > de::limits<Index>::max()) {
throw domain_error(dal::detail::error_messages::invalid_number_of_trees());
}
Expand All @@ -67,6 +69,7 @@ void infer_kernel_impl<Float, Index, Task>::init_params(infer_context_t& ctx,
ctx.class_count = de::integral_cast<Index>(desc.get_class_count());
ctx.voting_mode = desc.get_voting_mode();
}

ctx.row_count = de::integral_cast<Index>(data.get_row_count());
ctx.column_count = de::integral_cast<Index>(data.get_column_count());

Expand Down Expand Up @@ -245,6 +248,7 @@ infer_kernel_impl<Float, Index, Task>::predict_by_tree_group(const infer_context
{ local_size, 1 });

sycl::event last_event = zero_obs_response_event;

for (Index proc_tree_count = 0; proc_tree_count < tree_count;
proc_tree_count += ctx.tree_in_group_count) {
last_event = queue_.submit([&](sycl::handler& cgh) {
Expand Down Expand Up @@ -347,6 +351,7 @@ infer_kernel_impl<Float, Index, Task>::reduce_tree_group_response(
be::make_multiple_nd_range_1d({ ctx.max_group_count * local_size }, { local_size });

sycl::event last_event = zero_response_event;

last_event = queue_.submit([&](sycl::handler& cgh) {
cgh.depends_on(deps);
cgh.depends_on(last_event);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "oneapi/dal/table/row_accessor.hpp"
#include "oneapi/dal/backend/memory.hpp"
#include "oneapi/dal/detail/profiler.hpp"

#include <iostream>
#ifdef ONEDAL_DATA_PARALLEL

namespace oneapi::dal::decision_forest::backend {
Expand All @@ -29,6 +29,12 @@ namespace de = dal::detail;
namespace bk = dal::backend;
namespace pr = dal::backend::primitives;

template <typename Float>
std::int64_t propose_block_size(const sycl::queue& q, const std::int64_t r) {
constexpr std::int64_t fsize = sizeof(Float);
return 0x10000l * (8 / fsize);
}

template <typename Float, typename Index>
inline sycl::event sort_inplace(sycl::queue& queue_,
pr::ndarray<Float, 1>& src,
Expand Down Expand Up @@ -56,18 +62,29 @@ sycl::event indexed_features<Float, Bin, Index>::extract_column(
Float* values = values_nd.get_mutable_data();
Index* indices = indices_nd.get_mutable_data();
auto column_count = column_count_;

const sycl::range<1> range = de::integral_cast<std::size_t>(row_count_);

auto event = queue_.submit([&](sycl::handler& h) {
h.depends_on(deps);
h.parallel_for(range, [=](sycl::id<1> idx) {
values[idx] = data[idx * column_count + feature_id];
indices[idx] = idx;
const auto block_size = propose_block_size<Float>(queue_, row_count_);
const bk::uniform_blocking blocking(row_count_, block_size);

std::vector<sycl::event> events(blocking.get_block_count());
for (std::int64_t block_index = 0; block_index < blocking.get_block_count(); ++block_index) {
const auto first_row = blocking.get_block_start_index(block_index);
const auto last_row = blocking.get_block_end_index(block_index);
const auto curr_block = last_row - first_row;
ONEDAL_ASSERT(curr_block > 0);

auto event = queue_.submit([&](sycl::handler& cgh) {
cgh.depends_on(deps);
cgh.parallel_for<>(de::integral_cast<std::size_t>(curr_block), [=](sycl::id<1> idx) {
const std::int64_t row = idx + first_row;

values[row] = data[row * column_count + feature_id];
indices[row] = row;
});
});
});

return event;
events.push_back(event);
}
return bk::wait_or_pass(events);
}

template <typename Float, typename Bin, typename Index>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
#include "oneapi/dal/backend/primitives/ndarray.hpp"
#include "oneapi/dal/backend/primitives/utils.hpp"
#include "oneapi/dal/algo/decision_forest/train_types.hpp"

#include "oneapi/dal/backend/primitives/rng/rng.hpp"
#include "oneapi/dal/backend/primitives/rng/rng_cpu.hpp"
#include "oneapi/dal/backend/primitives/rng/rng_engine_collection.hpp"

#include "oneapi/dal/algo/decision_forest/backend/gpu/train_misc_structs.hpp"
Expand Down Expand Up @@ -50,7 +51,7 @@ class train_kernel_hist_impl {
using model_manager_t = train_model_manager<Float, Index, Task>;
using train_context_t = train_context<Float, Index, Task>;
using imp_data_t = impurity_data<Float, Index, Task>;
using rng_engine_t = pr::engine;
using rng_engine_t = pr::oneapi_engine<pr::engine_list::mcg59>;
using rng_engine_list_t = std::vector<rng_engine_t>;
using msg = dal::detail::error_messages;
using comm_t = bk::communicator<spmd::device_memory_access::usm>;
Expand All @@ -62,7 +63,7 @@ class train_kernel_hist_impl {
train_kernel_hist_impl(const bk::context_gpu& ctx)
: queue_(ctx.get_queue()),
comm_(ctx.get_communicator()),
train_service_kernels_(queue_) {}
train_service_kernels_(ctx.get_queue()) {}
~train_kernel_hist_impl() = default;

result_t operator()(const descriptor_t& desc,
Expand All @@ -83,13 +84,11 @@ class train_kernel_hist_impl {
pr::ndarray<Index, 1>& node_list,
pr::ndarray<Index, 1>& tree_order_level,
Index engine_offset,
Index node_count);
Index node_count,
const bk::event_vector& deps = {});

void validate_input(const descriptor_t& desc, const table& data, const table& labels) const;

Index get_row_total_count(bool distr_mode, Index row_count);
Index get_global_row_offset(bool distr_mode, Index row_count);

/// Initializes `ctx` training context structure based on data and
/// descriptor class. Filling and calculating all parameters in context,
/// for example, tree count, required memory size, calculating indexed features, etc.
Expand Down Expand Up @@ -149,6 +148,24 @@ class train_kernel_hist_impl {
Index node_count,
const bk::event_vector& deps = {});

sycl::event compute_initial_imp_for_node_list_regression(
const train_context_t& ctx,
const pr::ndarray<Index, 1>& node_list,
const pr::ndarray<Float, 1>& local_sum_hist,
const pr::ndarray<Float, 1>& local_sum2cent_hist,
imp_data_t& imp_data_list,
Index node_count,
const bk::event_vector& deps = {});

sycl::event compute_local_sum_histogram(const train_context_t& ctx,
const pr::ndarray<Float, 1>& response,
const pr::ndarray<Index, 1>& tree_order,
const pr::ndarray<Index, 1>& node_list,
pr::ndarray<Float, 1>& local_sum_hist,
pr::ndarray<Float, 1>& local_sum2cent_hist,
Index node_count,
const bk::event_vector& deps = {});

/// Computes initial histograms for each node to compute impurity.
///
/// @param[in] ctx a training context structure for a GPU backend
Expand Down
Loading
Loading