Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-DSA: AVX2 implementations of some arithmetic functions and t1 deserialization. #455

Merged
merged 7 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions libcrux-intrinsics/src/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ pub fn mm256_loadu_si256_i32(input: &[i32]) -> Vec256 {
/// Returns a 256-bit vector with all lanes set to zero.
pub fn mm256_setzero_si256() -> Vec256 {
    // SAFETY: thin wrapper over the AVX intrinsic; like the other wrappers in
    // this file it assumes the caller has established AVX2 availability —
    // TODO confirm the crate's CPU feature-detection story.
    unsafe { _mm256_setzero_si256() }
}
/// Combines two 128-bit vectors into one 256-bit vector: `hi` becomes the
/// upper 128 bits of the result and `lo` the lower 128 bits.
//
// `#[inline(always)]` added for consistency with the other single-instruction
// wrappers in this file (e.g. `mm256_abs_epi32`); these are called from a
// separate crate, where cross-crate inlining otherwise needs the hint.
#[inline(always)]
pub fn mm256_set_m128i(hi: Vec128, lo: Vec128) -> Vec256 {
    // SAFETY: thin wrapper over the AVX intrinsic; assumes the caller has
    // established AVX2 availability, matching the other wrappers here.
    unsafe { _mm256_set_m128i(hi, lo) }
}

pub fn mm_set_epi8(
byte15: u8,
Expand Down Expand Up @@ -220,6 +223,11 @@ pub fn mm256_add_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_add_epi32(lhs, rhs) }
}

/// Lane-wise absolute value of the eight packed 32-bit integers in `a`.
#[inline(always)]
pub fn mm256_abs_epi32(a: Vec256) -> Vec256 {
    // SAFETY: thin wrapper over the AVX2 intrinsic; assumes AVX2 availability
    // has been established by the caller — TODO confirm at crate level.
    unsafe { _mm256_abs_epi32(a) }
}

/// Lane-wise wrapping subtraction of sixteen packed 16-bit integers:
/// `lhs - rhs`.
pub fn mm256_sub_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
    // SAFETY: thin wrapper over the AVX2 intrinsic; assumes AVX2 availability
    // has been established by the caller.
    unsafe { _mm256_sub_epi16(lhs, rhs) }
}
Expand Down Expand Up @@ -270,6 +278,10 @@ pub fn mm256_and_si256(lhs: Vec256, rhs: Vec256) -> Vec256 {
unsafe { _mm256_and_si256(lhs, rhs) }
}

/// Computes the bitwise AND of `lhs` and `rhs` and returns 1 if the result is
/// all zeros, 0 otherwise (the value of the ZF flag set by `vptest`).
//
// `#[inline(always)]` added for consistency with the other single-instruction
// wrappers in this file (e.g. `mm256_abs_epi32`); these are called from a
// separate crate, where cross-crate inlining otherwise needs the hint.
#[inline(always)]
pub fn mm256_testz_si256(lhs: Vec256, rhs: Vec256) -> i32 {
    // SAFETY: thin wrapper over the AVX intrinsic; assumes AVX2 availability
    // has been established by the caller, matching the other wrappers here.
    unsafe { _mm256_testz_si256(lhs, rhs) }
}

/// Bitwise XOR of two 256-bit vectors.
pub fn mm256_xor_si256(lhs: Vec256, rhs: Vec256) -> Vec256 {
    // SAFETY: thin wrapper over the AVX2 intrinsic; assumes AVX2 availability
    // has been established by the caller.
    unsafe { _mm256_xor_si256(lhs, rhs) }
}
Expand Down
89 changes: 24 additions & 65 deletions libcrux-ml-dsa/src/simd/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ pub struct AVX2SIMDUnit {
pub(crate) coefficients: libcrux_intrinsics::avx2::Vec256,
}

impl From<libcrux_intrinsics::avx2::Vec256> for AVX2SIMDUnit {
fn from(coefficients: libcrux_intrinsics::avx2::Vec256) -> Self {
Self { coefficients }
}
}

impl Operations for AVX2SIMDUnit {
fn ZERO() -> Self {
Self {
Expand All @@ -34,62 +40,37 @@ impl Operations for AVX2SIMDUnit {
}

fn add(lhs: &Self, rhs: &Self) -> Self {
Self {
coefficients: arithmetic::add(lhs.coefficients, rhs.coefficients),
}
arithmetic::add(lhs.coefficients, rhs.coefficients).into()
}

fn subtract(lhs: &Self, rhs: &Self) -> Self {
Self {
coefficients: arithmetic::subtract(lhs.coefficients, rhs.coefficients),
}
arithmetic::subtract(lhs.coefficients, rhs.coefficients).into()
}

fn montgomery_multiply_by_constant(simd_unit: Self, constant: i32) -> Self {
Self {
coefficients: arithmetic::montgomery_multiply_by_constant(
simd_unit.coefficients,
constant,
),
}
arithmetic::montgomery_multiply_by_constant(simd_unit.coefficients, constant).into()
}
fn montgomery_multiply(lhs: Self, rhs: Self) -> Self {
Self {
coefficients: arithmetic::montgomery_multiply(lhs.coefficients, rhs.coefficients),
}
arithmetic::montgomery_multiply(lhs.coefficients, rhs.coefficients).into()
}
fn shift_left_then_reduce<const SHIFT_BY: i32>(simd_unit: Self) -> Self {
Self {
coefficients: arithmetic::shift_left_then_reduce::<SHIFT_BY>(simd_unit.coefficients),
}
arithmetic::shift_left_then_reduce::<SHIFT_BY>(simd_unit.coefficients).into()
}

fn power2round(simd_unit: Self) -> (Self, Self) {
let simd_unit = PortableSIMDUnit::from_coefficient_array(&simd_unit.to_coefficient_array());

let (lower, upper) = PortableSIMDUnit::power2round(simd_unit);
let (lower, upper) = arithmetic::power2round(simd_unit.coefficients);

(
Self::from_coefficient_array(&lower.to_coefficient_array()),
Self::from_coefficient_array(&upper.to_coefficient_array()),
)
(lower.into(), upper.into())
}

fn infinity_norm_exceeds(simd_unit: Self, bound: i32) -> bool {
let simd_unit = PortableSIMDUnit::from_coefficient_array(&simd_unit.to_coefficient_array());

PortableSIMDUnit::infinity_norm_exceeds(simd_unit, bound)
arithmetic::infinity_norm_exceeds(simd_unit.coefficients, bound)
}

fn decompose<const GAMMA2: i32>(simd_unit: Self) -> (Self, Self) {
let simd_unit = PortableSIMDUnit::from_coefficient_array(&simd_unit.to_coefficient_array());

let (lower, upper) = PortableSIMDUnit::decompose::<GAMMA2>(simd_unit);
let (lower, upper) = arithmetic::decompose::<GAMMA2>(simd_unit.coefficients);

(
Self::from_coefficient_array(&lower.to_coefficient_array()),
Self::from_coefficient_array(&upper.to_coefficient_array()),
)
(lower.into(), upper.into())
}

fn compute_hint<const GAMMA2: i32>(low: Self, high: Self) -> (usize, Self) {
Expand Down Expand Up @@ -139,9 +120,7 @@ impl Operations for AVX2SIMDUnit {
encoding::error::serialize::<OUTPUT_SIZE>(simd_unit.coefficients)
}
fn error_deserialize<const ETA: usize>(serialized: &[u8]) -> Self {
AVX2SIMDUnit {
coefficients: encoding::error::deserialize::<ETA>(serialized),
}
encoding::error::deserialize::<ETA>(serialized).into()
}

fn t0_serialize(simd_unit: Self) -> [u8; 13] {
Expand All @@ -159,25 +138,17 @@ impl Operations for AVX2SIMDUnit {
encoding::t1::serialize(simd_unit.coefficients)
}
fn t1_deserialize(serialized: &[u8]) -> Self {
let result = PortableSIMDUnit::t1_deserialize(serialized);

Self::from_coefficient_array(&result.to_coefficient_array())
encoding::t1::deserialize(serialized).into()
}

fn ntt_at_layer_0(simd_unit: Self, zeta0: i32, zeta1: i32, zeta2: i32, zeta3: i32) -> Self {
Self {
coefficients: ntt::ntt_at_layer_0(simd_unit.coefficients, zeta0, zeta1, zeta2, zeta3),
}
ntt::ntt_at_layer_0(simd_unit.coefficients, zeta0, zeta1, zeta2, zeta3).into()
}
fn ntt_at_layer_1(simd_unit: Self, zeta0: i32, zeta1: i32) -> Self {
Self {
coefficients: ntt::ntt_at_layer_1(simd_unit.coefficients, zeta0, zeta1),
}
ntt::ntt_at_layer_1(simd_unit.coefficients, zeta0, zeta1).into()
}
fn ntt_at_layer_2(simd_unit: Self, zeta: i32) -> Self {
Self {
coefficients: ntt::ntt_at_layer_2(simd_unit.coefficients, zeta),
}
ntt::ntt_at_layer_2(simd_unit.coefficients, zeta).into()
}

fn invert_ntt_at_layer_0(
Expand All @@ -187,24 +158,12 @@ impl Operations for AVX2SIMDUnit {
zeta2: i32,
zeta3: i32,
) -> Self {
Self {
coefficients: ntt::invert_ntt_at_layer_0(
simd_unit.coefficients,
zeta0,
zeta1,
zeta2,
zeta3,
),
}
ntt::invert_ntt_at_layer_0(simd_unit.coefficients, zeta0, zeta1, zeta2, zeta3).into()
}
fn invert_ntt_at_layer_1(simd_unit: Self, zeta0: i32, zeta1: i32) -> Self {
Self {
coefficients: ntt::invert_ntt_at_layer_1(simd_unit.coefficients, zeta0, zeta1),
}
ntt::invert_ntt_at_layer_1(simd_unit.coefficients, zeta0, zeta1).into()
}
fn invert_ntt_at_layer_2(simd_unit: Self, zeta: i32) -> Self {
Self {
coefficients: ntt::invert_ntt_at_layer_2(simd_unit.coefficients, zeta),
}
ntt::invert_ntt_at_layer_2(simd_unit.coefficients, zeta).into()
}
}
113 changes: 112 additions & 1 deletion libcrux-ml-dsa/src/simd/avx2/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
use crate::simd::traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R};
use crate::{
constants::BITS_IN_LOWER_PART_OF_T,
simd::traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R},
};

use libcrux_intrinsics::avx2::*;

/// Maps each lane of `t` to its unsigned representative modulo the field:
/// lanes that are negative get `FIELD_MODULUS` added, non-negative lanes are
/// left unchanged. Branch-free, so it runs in constant time per vector.
//
// `#[inline(always)]` added for consistency with the other arithmetic
// functions in this file (`add`, `power2round`, `decompose`, ...).
#[inline(always)]
fn to_unsigned_representatives(t: Vec256) -> Vec256 {
    // Arithmetic right shift by 31 broadcasts the sign bit: all-ones for
    // negative lanes, all-zeros for non-negative ones.
    let signs = mm256_srai_epi32::<31>(t);
    // The mask therefore selects FIELD_MODULUS exactly in the negative lanes.
    let conditional_add_field_modulus = mm256_and_si256(signs, mm256_set1_epi32(FIELD_MODULUS));

    mm256_add_epi32(t, conditional_add_field_modulus)
}

#[inline(always)]
pub fn add(lhs: Vec256, rhs: Vec256) -> Vec256 {
mm256_add_epi32(lhs, rhs)
Expand Down Expand Up @@ -72,3 +82,104 @@ pub fn shift_left_then_reduce<const SHIFT_BY: i32>(simd_unit: Vec256) -> Vec256

mm256_sub_epi32(shifted, quotient_times_field_modulus)
}

// TODO: Revisit this function when doing the range analysis and testing
// additional KATs.
#[inline(always)]
pub fn infinity_norm_exceeds(simd_unit: Vec256, bound: i32) -> bool {
let absolute_values = mm256_abs_epi32(simd_unit);

// We will test if |simd_unit| > bound - 1, because if this is the case then
// it follows that |simd_unit| >= bound
let bound = mm256_set1_epi32(bound - 1);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory this could underflow, right? Is that why you want to revisit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I'm not sure if it can underflow, but I put the TODO there since I haven't tried to properly calculate the representative ranges that this function will see.

let compare_with_bound = mm256_cmpgt_epi32(absolute_values, bound);

// If every lane of |result| is 0, all coefficients are <= bound - 1
let result = mm256_testz_si256(compare_with_bound, compare_with_bound);

if result == 1 {
false
} else {
true
}
}

/// Splits each coefficient `r` lane-wise into `(r0, r1)` such that
/// `r = r1 * 2^d + r0` with `d = BITS_IN_LOWER_PART_OF_T`
/// (the ML-DSA `Power2Round` operation, vectorized over 8 lanes).
#[inline(always)]
pub fn power2round(r: Vec256) -> (Vec256, Vec256) {
    // Work on unsigned representatives so the arithmetic shifts below act as
    // the intended divisions for every input lane.
    let r = to_unsigned_representatives(r);

    // r1 = (r + 2^(d-1) - 1) >> d, i.e. divide by 2^d with rounding chosen so
    // the remainder r0 lands in the centered range.
    let r1 = mm256_add_epi32(
        r,
        mm256_set1_epi32((1 << (BITS_IN_LOWER_PART_OF_T - 1)) - 1),
    );
    let r1 = mm256_srai_epi32::<{ BITS_IN_LOWER_PART_OF_T as i32 }>(r1);

    // r0 = r - r1 * 2^d (the low part left over after removing r1's
    // contribution).
    let r0 = mm256_slli_epi32::<{ BITS_IN_LOWER_PART_OF_T as i32 }>(r1);
    let r0 = mm256_sub_epi32(r, r0);

    (r0, r1)
}

/// Lane-wise ML-DSA `Decompose`: splits each coefficient `r` into high and low
/// parts `(r0, r1)` with `r ≡ r1·ALPHA + r0 (mod FIELD_MODULUS)`, where
/// `ALPHA = 2·GAMMA2`. Only the two `GAMMA2` values that yield
/// `ALPHA ∈ {190_464, 523_776}` are supported; any other value panics.
/// Branch-free per lane (the `match` is on the const-derived `ALPHA`, not on
/// secret data), so it runs in constant time.
#[allow(non_snake_case)]
#[inline(always)]
pub fn decompose<const GAMMA2: i32>(r: Vec256) -> (Vec256, Vec256) {
    let r = to_unsigned_representatives(r);

    let field_modulus_halved = mm256_set1_epi32((FIELD_MODULUS - 1) / 2);

    // When const-generic expressions are available, this could be turned into a
    // const value.
    let ALPHA: i32 = GAMMA2 * 2;

    let r1 = {
        // ⌊(r + 127) / 128⌋. Note that 190_464 = 1488·128 and
        // 523_776 = 4092·128, so the fixed-point approximations below divide
        // by 1488 (resp. 4092) to complete the division by ALPHA.
        let ceil_of_r_by_128 = mm256_add_epi32(r, mm256_set1_epi32(127));
        let ceil_of_r_by_128 = mm256_srai_epi32::<7>(ceil_of_r_by_128);

        match ALPHA {
            190_464 => {
                // We approximate 1 / 1488 as:
                // ⌊2²⁴ / 1488⌋ / 2²⁴ = 11,275 / 2²⁴
                let result = mm256_mullo_epi32(ceil_of_r_by_128, mm256_set1_epi32(11_275));
                let result = mm256_add_epi32(result, mm256_set1_epi32(1 << 23));
                let result = mm256_srai_epi32::<24>(result);

                // For the corner-case a₁ = (q-1)/α = 44, we have to set a₁=0.
                // mask is all-ones exactly when result > 43.
                let mask = mm256_sub_epi32(mm256_set1_epi32(43), result);
                let mask = mm256_srai_epi32::<31>(mask);

                // result AND (result XOR mask): unchanged when mask is zero,
                // zeroed when mask is all-ones.
                let not_result = mm256_xor_si256(result, mask);

                mm256_and_si256(result, not_result)
            }

            523_776 => {
                // We approximate 1 / 4092 as:
                // ⌊2²² / 4092⌋ / 2²² = 1025 / 2²²
                let result = mm256_mullo_epi32(ceil_of_r_by_128, mm256_set1_epi32(1025));
                let result = mm256_add_epi32(result, mm256_set1_epi32(1 << 21));
                let result = mm256_srai_epi32::<22>(result);

                // For the corner-case a₁ = (q-1)/α = 16, we have to set a₁=0;
                // since results are in 0..=16 here, masking with 15 does it.
                mm256_and_si256(result, mm256_set1_epi32(15))
            }

            _ => unreachable!(),
        }
    };

    // In the corner-case, when we set a₁=0, we will incorrectly
    // have a₀ > (q-1)/2 and we'll need to subtract q. As we
    // return a₀ + q, that comes down to adding q if a₀ < (q-1)/2.
    let r0 = mm256_mullo_epi32(r1, mm256_set1_epi32(ALPHA));
    let r0 = mm256_sub_epi32(r, r0);

    // mask is all-ones exactly where r0 > (q-1)/2.
    let mask = mm256_sub_epi32(field_modulus_halved, r0);
    let mask = mm256_srai_epi32::<31>(mask);

    let field_modulus_and_mask = mm256_and_si256(mask, mm256_set1_epi32(FIELD_MODULUS));

    let r0 = mm256_sub_epi32(r0, field_modulus_and_mask);

    (r0, r1)
}
27 changes: 26 additions & 1 deletion libcrux-ml-dsa/src/simd/avx2/encoding/t1.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use libcrux_intrinsics::avx2::*;

#[inline(always)]
pub fn serialize(simd_unit: Vec256) -> [u8; 10] {
pub(crate) fn serialize(simd_unit: Vec256) -> [u8; 10] {
let mut serialized = [0u8; 24];

let adjacent_2_combined =
Expand All @@ -26,3 +26,28 @@ pub fn serialize(simd_unit: Vec256) -> [u8; 10] {

serialized[0..10].try_into().unwrap()
}

/// Unpacks 10 serialized bytes into eight 10-bit `t1` coefficients, one per
/// 32-bit lane of the returned vector.
///
/// Layout: coefficient i occupies bits [10·i, 10·i + 10) of the input byte
/// stream, so each coefficient spans two adjacent bytes (shifted by 0, 2, 4,
/// or 6 bits within them).
#[inline(always)]
pub(crate) fn deserialize(bytes: &[u8]) -> Vec256 {
    debug_assert_eq!(bytes.len(), 10);

    const COEFFICIENT_MASK: i32 = (1 << 10) - 1;

    // Zero-pad the 10 input bytes to 16 so a full 128-bit load is in bounds.
    let mut bytes_extended = [0u8; 16];
    bytes_extended[0..10].copy_from_slice(bytes);

    // Broadcast the 128-bit load into both halves of a 256-bit vector so the
    // in-lane shuffle below can pick bytes for all eight coefficients.
    let bytes_loaded = mm_loadu_si128(&bytes_extended);
    let bytes_loaded = mm256_set_m128i(bytes_loaded, bytes_loaded);

    // Place the two source bytes of each coefficient into the low 16 bits of
    // its own 32-bit lane (-1 zeroes the upper bytes). Lower half handles
    // coefficients 0-3 (bytes 0-4), upper half coefficients 4-7 (bytes 5-9).
    let coefficients = mm256_shuffle_epi8(
        bytes_loaded,
        mm256_set_epi8(
            -1, -1, 9, 8, -1, -1, 8, 7, -1, -1, 7, 6, -1, -1, 6, 5, -1, -1, 4, 3, -1, -1, 3, 2, -1,
            -1, 2, 1, -1, -1, 1, 0,
        ),
    );

    // Per-lane right shift by the coefficient's bit offset within its bytes
    // (0, 2, 4, 6 bits for the four coefficients in each half).
    let coefficients = mm256_srlv_epi32(coefficients, mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0));

    // Keep only the 10 coefficient bits in each lane.
    mm256_and_si256(coefficients, mm256_set1_epi32(COEFFICIENT_MASK))
}
1 change: 1 addition & 0 deletions libcrux-ml-dsa/src/simd/avx2/rejection_sample.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub(crate) mod less_than_eta;
pub(crate) mod less_than_field_modulus;
mod shuffle_table;
mod utils;
29 changes: 7 additions & 22 deletions libcrux-ml-dsa/src/simd/avx2/rejection_sample/less_than_eta.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
use crate::simd::avx2::{encoding, rejection_sample::shuffle_table::SHUFFLE_TABLE};
use crate::simd::avx2::{
encoding,
rejection_sample::{shuffle_table::SHUFFLE_TABLE, utils},
};

use libcrux_intrinsics::avx2::*;

#[inline(always)]
fn extract_least_significant_bits(simd_unit: Vec256) -> u8 {
let first_byte_from_each_i32_lane = mm256_shuffle_epi8(
simd_unit,
mm256_set_epi8(
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, 12, 8, 4, 0,
),
);

let bytes_grouped = mm256_permutevar8x32_epi32(
first_byte_from_each_i32_lane,
mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0),
);
let bytes_grouped = mm256_castsi256_si128(bytes_grouped);

let bits = mm_movemask_epi8(bytes_grouped);

(bits & 0xFF) as u8
}
// TODO: This code seems to slow the implementation down, but stabilizes
// benchmarks. Revisit this once the other functions are vectorized.

#[inline(always)]
fn shift_interval<const ETA: usize>(coefficients: Vec256) -> Vec256 {
Expand Down Expand Up @@ -59,7 +44,7 @@ pub(crate) fn sample<const ETA: usize>(input: &[u8], output: &mut [i32]) -> usiz
// Since every bit in each lane is either 0 or all 1s, we only need one bit
// from each lane to tell us what coefficients to keep and what to throw-away.
// Combine all the bits (there are 8) into one byte.
let good = extract_least_significant_bits(compare_with_interval_boundary);
let good = utils::extract_least_significant_bits(compare_with_interval_boundary);

let good_lower_half = good & 0x0F;
let good_upper_half = good >> 4;
Expand Down
Loading
Loading