Skip to content

Commit

Permalink
Replace explicit bitmask requirement with abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
aumetra committed Jul 12, 2024
1 parent c8dc14f commit 7a3a864
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
13 changes: 8 additions & 5 deletions src/util/simd/v256.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::ops::{BitAnd, BitOr, BitOrAssign};

use super::{bits::combine_u16, Mask, Simd};
use super::{bits::combine_u16, BitMask, Mask, Simd};
use crate::impl_lanes;

impl_lanes!([impl<B: Simd> Simd256u<B>] 32);
Expand All @@ -19,7 +19,10 @@ pub struct Simd256i<B: Simd = super::Simd128i>((B, B));
#[repr(transparent)]
pub struct Mask256<M: Mask = super::Mask128>(pub(crate) (M, M));

impl<M: Mask<BitMask = u16>> Mask for Mask256<M> {
impl<M: Mask> Mask for Mask256<M>
where
<M as Mask>::BitMask: BitMask<Primitive = u16>,
{
type BitMask = u32;
type Element = u8;

Expand All @@ -31,7 +34,7 @@ impl<M: Mask<BitMask = u16>> Mask for Mask256<M> {
let(v0, v1) = self.0;
unsafe { super::neon::to_bitmask32(v0.0, v1.0) }
} else {
combine_u16(self.0 .0.bitmask(), self.0 .1.bitmask())
combine_u16(self.0.0.bitmask().as_primitive(), self.0.1.bitmask().as_primitive())
}
}
}
Expand Down Expand Up @@ -75,7 +78,7 @@ impl<M: Mask> BitAnd<Mask256<M>> for Mask256<M> {
impl<B> Simd for Simd256u<B>
where
B: Simd<Element = u8>,
B::Mask: Mask<BitMask = u16>,
<B::Mask as Mask>::BitMask: BitMask<Primitive = u16>,
{
const LANES: usize = 32;

Expand Down Expand Up @@ -125,7 +128,7 @@ where
impl<B> Simd for Simd256i<B>
where
B: Simd<Element = i8>,
B::Mask: Mask<BitMask = u16>,
<B::Mask as Mask>::BitMask: BitMask<Primitive = u16>,
{
const LANES: usize = 32;

Expand Down
13 changes: 8 additions & 5 deletions src/util/simd/v512.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::ops::{BitAnd, BitOr, BitOrAssign};

use super::{bits::combine_u32, Mask, Simd};
use super::{bits::combine_u32, BitMask, Mask, Simd};
use crate::impl_lanes;

impl_lanes!([impl<B: Simd> Simd512u<B>] 64);
Expand All @@ -19,7 +19,10 @@ pub struct Simd512i<B: Simd = super::Simd256i>((B, B));
#[repr(transparent)]
pub struct Mask512<M: Mask = super::Mask256>((M, M));

impl<M: Mask<BitMask = u32>> Mask for Mask512<M> {
impl<M: Mask> Mask for Mask512<M>
where
<M as Mask>::BitMask: BitMask<Primitive = u32>,
{
type BitMask = u64;
type Element = u8;

Expand All @@ -33,7 +36,7 @@ impl<M: Mask<BitMask = u32>> Mask for Mask512<M> {
let (m2, m3) = v1.0;
unsafe { super::neon::to_bitmask64(m0.0, m1.0, m2.0, m3.0) }
} else {
combine_u32(self.0 .0.bitmask(), self.0 .1.bitmask())
combine_u32(self.0.0.bitmask().as_primitive(), self.0.1.bitmask().as_primitive())
}
}
}
Expand Down Expand Up @@ -77,7 +80,7 @@ impl<M: Mask> BitAnd<Mask512<M>> for Mask512<M> {
impl<B> Simd for Simd512u<B>
where
B: Simd<Element = u8>,
B::Mask: Mask<BitMask = u32>,
<B::Mask as Mask>::BitMask: BitMask<Primitive = u32>,
{
const LANES: usize = 64;
type Element = u8;
Expand Down Expand Up @@ -126,7 +129,7 @@ where
impl<B> Simd for Simd512i<B>
where
B: Simd<Element = i8>,
B::Mask: Mask<BitMask = u32>,
<B::Mask as Mask>::BitMask: BitMask<Primitive = u32>,
{
const LANES: usize = 64;
type Element = i8;
Expand Down

0 comments on commit 7a3a864

Please sign in to comment.