Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use stronger typing for SimdUnaryOp argument #606

Merged
merged 1 commit into from
Mar 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions rten-simd/src/safe/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,24 +64,24 @@ pub fn dispatch<Op: SimdOp>(op: Op) -> Op::Output {
pub trait SimdUnaryOp<T: Elem> {
/// Evaluate the unary function on the elements in `x`.
///
/// `eval` is passed an untyped SIMD vector. This can be cast to the
/// specific type expected by the operation.
/// In order to perform operations on `x`, it will need to be cast to
/// the specific type used by the ISA:
///
/// ```
/// use rten_simd::safe::{Isa, Simd, SimdFloatOps, SimdOps, SimdUnaryOp};
///
/// struct Reciprocal {}
///
/// impl SimdUnaryOp<f32> for Reciprocal {
/// fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits {
/// fn eval<I: Isa, S: Simd<Elem=f32, Isa=I>>(&self, isa: I, x: S) -> S {
/// let ops = isa.f32();
/// let x = ops.from_bits(x);
/// let x = x.same_cast();
/// let reciprocal = ops.div(ops.one(), x);
/// reciprocal.to_bits()
/// reciprocal.same_cast()
/// }
/// }
/// ```
fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits;
fn eval<I: Isa, S: Simd<Elem = T, Isa = I>>(&self, isa: I, x: S) -> S;

/// Evaluate the unary function on elements in `x`.
///
Expand All @@ -93,7 +93,7 @@ pub trait SimdUnaryOp<T: Elem> {
where
Self: Default,
{
S::from_bits(Self::default().eval(isa, x.to_bits()))
Self::default().eval(isa, x)
}

/// Apply this function to a slice.
Expand Down Expand Up @@ -161,7 +161,7 @@ macro_rules! impl_simd_map_op {
isa.$type(),
self.src_dest,
#[inline(always)]
|x| I::$cap_type::from_bits(self.op.eval(isa, x.to_bits())),
|x| self.op.eval(isa, x),
)
}
}
Expand Down Expand Up @@ -202,11 +202,11 @@ mod tests {
struct Reciprocal {}

impl SimdUnaryOp<f32> for Reciprocal {
fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits {
fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
let ops = isa.f32();
let x = ops.from_bits(x);
let x = x.same_cast();
let y = ops.div(ops.one(), x);
y.to_bits()
y.same_cast()
}
}

Expand All @@ -217,20 +217,30 @@ mod tests {
}

#[test]
fn test_unary_int_op() {
fn test_unary_generic_op() {
struct Double {}

impl SimdUnaryOp<i32> for Double {
fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits {
let ops = isa.i32();
let x = ops.from_bits(x);
ops.add(x, x).to_bits()
}
// Implement `SimdUnaryOp<$scalar>` for `Double`. The element type's
// identifier is reused as the name of the `Isa` accessor (e.g. `isa.i32()`)
// that provides the operations for that element type.
macro_rules! impl_double {
    ($scalar:ident) => {
        impl SimdUnaryOp<$scalar> for Double {
            fn eval<I: Isa, S: Simd<Elem = $scalar, Isa = I>>(&self, isa: I, x: S) -> S {
                let ops = isa.$scalar();
                // Cast to the vector type `ops` works with, add the vector
                // to itself, then cast the sum back to `S`.
                let lanes = x.same_cast();
                let doubled = ops.add(lanes, lanes);
                doubled.same_cast()
            }
        }
    };
}

impl_double!(i32);
impl_double!(f32);

let mut buf = [1, 2, 3, 4];
Double {}.map_mut(&mut buf);

assert_eq!(buf, [2, 4, 6, 8]);

let mut buf = [1., 2., 3., 4.];
Double {}.map_mut(&mut buf);
assert_eq!(buf, [2., 4., 6., 8.]);
}
}
12 changes: 12 additions & 0 deletions rten-simd/src/safe/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ pub trait Simd: Copy + Debug {
T::from_bits(self.to_bits())
}

/// Convert this vector into another `Simd` type that has the same ISA and
/// element type.
///
/// The conversion goes through the vector's bit representation (`to_bits`
/// followed by `from_bits`). It is intended to be a no-op which doesn't
/// generate any code. It is needed in some cases to downcast a `Simd` type
/// to one of an `Isa`'s associated types, or vice-versa.
fn same_cast<T>(self) -> T
where
    T: Simd<Elem = Self::Elem, Isa = Self::Isa>,
{
    let bits = self.to_bits();
    T::from_bits(bits)
}

/// Convert `self` to a SIMD array.
///
/// This is a cheap transmute in most cases, since SIMD vectors usually
Expand Down
12 changes: 6 additions & 6 deletions rten-vecmath/src/erf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ pub struct Erf {}

impl SimdUnaryOp<f32> for Erf {
#[inline(always)]
fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits {
fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
let ops = isa.f32();
let x = I::F32::from_bits(x);
let x = x.same_cast();

let neg_mask = ops.lt(x, ops.zero());

Expand Down Expand Up @@ -51,7 +51,7 @@ impl SimdUnaryOp<f32> for Erf {

// Approximation is valid only for x >= 0. For negative values approximation
// can be computed as -erf(-x).
ops.select(ops.neg(y), y, neg_mask).to_bits()
ops.select(ops.neg(y), y, neg_mask).same_cast()
}
}

Expand All @@ -63,15 +63,15 @@ pub struct Gelu {}

impl SimdUnaryOp<f32> for Gelu {
#[inline(always)]
fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits {
fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
let ops = isa.f32();
let x = I::F32::from_bits(x);
let x = x.same_cast();

let half_x = ops.mul(x, ops.splat(0.5));
let sqrt_2_rcp = ops.splat(SQRT_2_RCP);
let y = ops.mul(x, sqrt_2_rcp);
let y = ops.add(Erf::apply(isa, y), ops.splat(1.0));
ops.mul(half_x, y).to_bits()
ops.mul(half_x, y).same_cast()
}
}

Expand Down
25 changes: 13 additions & 12 deletions rten-vecmath/src/exp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ pub struct Exp {}
// into multiple steps to extend the domain.
impl SimdUnaryOp<f32> for Exp {
#[inline(always)]
fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits {
fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
let ops = isa.f32();
let int_ops = isa.i32();

let x = I::F32::from_bits(x);
let x = x.same_cast();

// Load constants
let inv_log_2 = ops.splat(INV_LOG2);
Expand Down Expand Up @@ -123,7 +123,7 @@ impl SimdUnaryOp<f32> for Exp {
let overflow_mask = ops.ge(x, ops.splat(104.0));
let underflow_mask = ops.le(x, ops.splat(-104.0));
let r = ops.select(ops.splat(f32::INFINITY), r, overflow_mask);
ops.select(ops.zero(), r, underflow_mask).to_bits()
ops.select(ops.zero(), r, underflow_mask).same_cast()
}
}

Expand All @@ -139,13 +139,13 @@ pub struct Sigmoid {}

impl SimdUnaryOp<f32> for Sigmoid {
#[inline(always)]
fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits {
fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
let ops = isa.f32();
let x = I::F32::from_bits(x);
let x = x.same_cast();

// 1. + exp(-x)
let denom = ops.add(ops.one(), Exp::apply(isa, ops.neg(x)));
ops.reciprocal(denom).to_bits()
ops.reciprocal(denom).same_cast()
}
}

Expand All @@ -156,11 +156,11 @@ pub struct Silu {}

impl SimdUnaryOp<f32> for Silu {
#[inline(always)]
fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits {
fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
let ops = isa.f32();
let x = I::F32::from_bits(x);
let x = x.same_cast();

ops.mul(x, Sigmoid::apply(isa, x)).to_bits()
ops.mul(x, Sigmoid::apply(isa, x)).same_cast()
}
}

Expand All @@ -173,12 +173,13 @@ pub struct Swish {

impl SimdUnaryOp<f32> for Swish {
#[inline(always)]
fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits {
fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
let ops = isa.f32();
let x = I::F32::from_bits(x);
let x = x.same_cast();

let beta = ops.splat(self.beta);
ops.mul(x, Sigmoid::apply(isa, ops.mul(x, beta))).to_bits()
ops.mul(x, Sigmoid::apply(isa, ops.mul(x, beta)))
.same_cast()
}
}

Expand Down
6 changes: 3 additions & 3 deletions rten-vecmath/src/tanh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ pub struct Tanh {}

impl SimdUnaryOp<f32> for Tanh {
#[inline(always)]
fn eval<I: Isa>(&self, isa: I, x: I::Bits) -> I::Bits {
fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
let ops = isa.f32();
let x = I::F32::from_bits(x);
let x = x.same_cast();

let x_negative = ops.le(x, ops.zero());
let abs_x = ops.abs(x);
Expand Down Expand Up @@ -60,7 +60,7 @@ impl SimdUnaryOp<f32> for Tanh {
let y = ops.select(abs_x, y, x_tiny);

// Flip sign if input was negative.
ops.select(ops.neg(y), y, x_negative).to_bits()
ops.select(ops.neg(y), y, x_negative).same_cast()
}
}

Expand Down