diff --git a/Cargo.lock b/Cargo.lock index d55918fb1..d1d606598 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1447,6 +1447,7 @@ dependencies = [ "lock_api", "once_cell", "parking_lot_core", + "rayon", "serde", ] diff --git a/Cargo.toml b/Cargo.toml index fbf5f84a6..7bd6d2221 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,7 +83,7 @@ rand_distr = "0.4.3" rustc-hash = "2.0.0" twox-hash = "2.1.0" lock_api = { version = "0.4.11", features = ["arc_lock", "serde"] } -dashmap = { version = "6.0.1", features = ["serde"] } +dashmap = { version = "6.0.1", features = ["serde", "rayon"] } enum_dispatch = "0.3.12" glam = "0.29.0" quad-rand = "0.2.1" diff --git a/raphtory/src/db/api/state/group_by.rs b/raphtory/src/db/api/state/group_by.rs index 2883a6121..a872743a6 100644 --- a/raphtory/src/db/api/state/group_by.rs +++ b/raphtory/src/db/api/state/group_by.rs @@ -5,8 +5,10 @@ use crate::{ }, prelude::{GraphViewOps, NodeStateOps}, }; +use dashmap::DashMap; use raphtory_api::core::entities::VID; -use std::{collections::HashMap, hash::Hash, sync::Arc}; +use rayon::prelude::*; +use std::{hash::Hash, sync::Arc}; #[derive(Clone, Debug)] pub struct NodeGroups { @@ -14,14 +16,15 @@ pub struct NodeGroups { graph: G, } -impl<'graph, V: Hash + Eq, G: GraphViewOps<'graph>> NodeGroups { - pub(crate) fn new(values: impl Iterator, graph: G) -> Self { - let mut groups: HashMap> = HashMap::new(); - for (node, v) in values { +impl<'graph, V: Hash + Eq + Send + Sync + Clone, G: GraphViewOps<'graph>> NodeGroups { + pub(crate) fn new(values: impl ParallelIterator, graph: G) -> Self { + let groups: DashMap, ahash::RandomState> = DashMap::default(); + values.for_each(|(node, v)| { groups.entry(v).or_insert_with(Vec::new).push(node); - } + }); + let groups = groups - .into_iter() + .into_par_iter() .map(|(k, v)| (k, Index::new(v))) .collect(); Self { groups, graph } diff --git a/raphtory/src/db/api/state/node_state_ops.rs b/raphtory/src/db/api/state/node_state_ops.rs index 7ba019671..1e293b3c1 100644 --- a/raphtory/src/db/api/state/node_state_ops.rs +++ b/raphtory/src/db/api/state/node_state_ops.rs @@ -208,12 +208,12 @@ pub trait NodeStateOps<'graph>: values.into_iter().nth(median_index) } - fn group_by V + Sync>( + fn group_by V + Sync>( &self, group_fn: F, ) -> NodeGroups { NodeGroups::new( - self.iter() + self.par_iter() .map(|(node, v)| (node.node, group_fn(v.borrow()))), self.graph().clone(), )