Skip to content

Commit

Permalink
There are mask utilities that perform some of the functions here, use…
Browse files Browse the repository at this point in the history
… those instead of replicating
  • Loading branch information
ChuckHastings committed Dec 1, 2023
1 parent c6aa981 commit afc00ee
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 61 deletions.
33 changes: 10 additions & 23 deletions cpp/src/structure/detail/structure_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cugraph/utilities/dataframe_buffer.hpp>
#include <cugraph/utilities/device_functors.cuh>
#include <cugraph/utilities/error.hpp>
#include <cugraph/utilities/mask_utils.cuh>
#include <cugraph/utilities/misc_utils.cuh>
#include <cugraph/utilities/packed_bool_utils.hpp>

Expand Down Expand Up @@ -524,35 +525,21 @@ std::tuple<size_t, rmm::device_uvector<uint32_t>> mark_entries(raft::handle_t co
return word;
});

// FIXME: use detail::count_set_bits
size_t bit_count = thrust::transform_reduce(
handle.get_thrust_policy(),
marked_entries.begin(),
marked_entries.end(),
[] __device__(auto word) { return __popc(word); },
size_t{0},
thrust::plus<size_t>());
size_t bit_count = detail::count_set_bits(handle, marked_entries.begin(), num_entries);

return std::make_tuple(bit_count, std::move(marked_entries));
}

template <typename T>
rmm::device_uvector<T> remove_flagged_elements(raft::handle_t const& handle,
rmm::device_uvector<T>&& vector,
raft::device_span<uint32_t const> remove_flags,
size_t remove_count)
rmm::device_uvector<T> keep_flagged_elements(raft::handle_t const& handle,
rmm::device_uvector<T>&& vector,
raft::device_span<uint32_t const> keep_flags,
size_t keep_count)
{
rmm::device_uvector<T> result(vector.size() - remove_count, handle.get_stream());

thrust::copy_if(
handle.get_thrust_policy(),
thrust::make_counting_iterator(size_t{0}),
thrust::make_counting_iterator(vector.size()),
thrust::make_transform_output_iterator(result.begin(),
indirection_t<size_t, T*>{vector.data()}),
[remove_flags] __device__(size_t i) {
return !(remove_flags[cugraph::packed_bool_offset(i)] & cugraph::packed_bool_mask(i));
});
rmm::device_uvector<T> result(keep_count, handle.get_stream());

detail::copy_if_mask_set(
handle, vector.begin(), vector.end(), keep_flags.begin(), result.begin());

return result;
}
Expand Down
37 changes: 17 additions & 20 deletions cpp/src/structure/remove_multi_edges_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ remove_multi_edges(raft::handle_t const& handle,
}
}

auto [multi_edge_count, multi_edges_to_delete] =
auto [keep_count, keep_flags] =
detail::mark_entries(handle,
edgelist_srcs.size(),
[d_edgelist_srcs = edgelist_srcs.data(),
Expand All @@ -263,41 +263,38 @@ remove_multi_edges(raft::handle_t const& handle,
(d_edgelist_dsts[idx - 1] == d_edgelist_dsts[idx]);
});

if (multi_edge_count > 0) {
edgelist_srcs = detail::remove_flagged_elements(
if (keep_count < edgelist_srcs.size()) {
edgelist_srcs = detail::keep_flagged_elements(
handle,
std::move(edgelist_srcs),
raft::device_span<uint32_t const>{multi_edges_to_delete.data(), multi_edges_to_delete.size()},
multi_edge_count);
edgelist_dsts = detail::remove_flagged_elements(
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
edgelist_dsts = detail::keep_flagged_elements(
handle,
std::move(edgelist_dsts),
raft::device_span<uint32_t const>{multi_edges_to_delete.data(), multi_edges_to_delete.size()},
multi_edge_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_weights)
edgelist_weights = detail::remove_flagged_elements(
edgelist_weights = detail::keep_flagged_elements(
handle,
std::move(*edgelist_weights),
raft::device_span<uint32_t const>{multi_edges_to_delete.data(),
multi_edges_to_delete.size()},
multi_edge_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_edge_ids)
edgelist_edge_ids = detail::remove_flagged_elements(
edgelist_edge_ids = detail::keep_flagged_elements(
handle,
std::move(*edgelist_edge_ids),
raft::device_span<uint32_t const>{multi_edges_to_delete.data(),
multi_edges_to_delete.size()},
multi_edge_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_edge_types)
edgelist_edge_types = detail::remove_flagged_elements(
edgelist_edge_types = detail::keep_flagged_elements(
handle,
std::move(*edgelist_edge_types),
raft::device_span<uint32_t const>{multi_edges_to_delete.data(),
multi_edges_to_delete.size()},
multi_edge_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
}

return std::make_tuple(std::move(edgelist_srcs),
Expand Down
36 changes: 18 additions & 18 deletions cpp/src/structure/remove_self_loops_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,44 +44,44 @@ remove_self_loops(raft::handle_t const& handle,
std::optional<rmm::device_uvector<edge_t>>&& edgelist_edge_ids,
std::optional<rmm::device_uvector<edge_type_t>>&& edgelist_edge_types)
{
auto [self_loop_count, self_loops_to_delete] =
auto [keep_count, keep_flags] =
detail::mark_entries(handle,
edgelist_srcs.size(),
[d_srcs = edgelist_srcs.data(), d_dsts = edgelist_dsts.data()] __device__(
size_t i) { return d_srcs[i] == d_dsts[i]; });
size_t i) { return d_srcs[i] != d_dsts[i]; });

if (self_loop_count > 0) {
edgelist_srcs = detail::remove_flagged_elements(
if (keep_count < edgelist_srcs.size()) {
edgelist_srcs = detail::keep_flagged_elements(
handle,
std::move(edgelist_srcs),
raft::device_span<uint32_t const>{self_loops_to_delete.data(), self_loops_to_delete.size()},
self_loop_count);
edgelist_dsts = detail::remove_flagged_elements(
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
edgelist_dsts = detail::keep_flagged_elements(
handle,
std::move(edgelist_dsts),
raft::device_span<uint32_t const>{self_loops_to_delete.data(), self_loops_to_delete.size()},
self_loop_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_weights)
edgelist_weights = detail::remove_flagged_elements(
edgelist_weights = detail::keep_flagged_elements(
handle,
std::move(*edgelist_weights),
raft::device_span<uint32_t const>{self_loops_to_delete.data(), self_loops_to_delete.size()},
self_loop_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_edge_ids)
edgelist_edge_ids = detail::remove_flagged_elements(
edgelist_edge_ids = detail::keep_flagged_elements(
handle,
std::move(*edgelist_edge_ids),
raft::device_span<uint32_t const>{self_loops_to_delete.data(), self_loops_to_delete.size()},
self_loop_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_edge_types)
edgelist_edge_types = detail::remove_flagged_elements(
edgelist_edge_types = detail::keep_flagged_elements(
handle,
std::move(*edgelist_edge_types),
raft::device_span<uint32_t const>{self_loops_to_delete.data(), self_loops_to_delete.size()},
self_loop_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
}

return std::make_tuple(std::move(edgelist_srcs),
Expand Down

0 comments on commit afc00ee

Please sign in to comment.