Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix visibility warnings #8

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 8 additions & 121 deletions src/lib.nr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mod quicksort;
pub mod quicksort;
use crate::quicksort::quicksort::quicksort as quicksort;
use crate::quicksort::quicksort_explicit::quicksort as quicksort_explicit;
use dep::check_shuffle::{check_shuffle, get_shuffle_indices};
Expand All @@ -14,7 +14,7 @@ pub fn sort<T, let N: u32>(input: [T; N]) -> [T; N]
where
T: std::cmp::Ord + std::cmp::Eq,
{
let sorted = quicksort(input);
let sorted = unsafe { quicksort(input) };

for i in 0..N - 1 {
assert(sorted[i] <= sorted[i + 1]);
Expand All @@ -36,7 +36,7 @@ pub fn sort_via<T, let N: u32>(input: [T; N], sortfn: fn(T, T) -> bool) -> [T; N
where
T: std::cmp::Eq,
{
let sorted = quicksort_explicit(input, sortfn);
let sorted = unsafe { quicksort_explicit(input, sortfn) };

for i in 0..N - 1 {
assert(sortfn(sorted[i], sorted[i + 1]));
Expand Down Expand Up @@ -67,7 +67,7 @@ pub fn sort_extended<T, let N: u32>(
where
T: std::cmp::Eq,
{
let sorted = quicksort_explicit(input, sortfn);
let sorted = unsafe { quicksort_explicit(input, sortfn) };

for i in 0..N - 1 {
sortfn_assert(sorted[i], sorted[i + 1]);
Expand All @@ -76,9 +76,9 @@ where
sorted
}

struct SortResult<T, let N: u32> {
sorted: [T; N],
sort_indices: [Field; N],
pub struct SortResult<T, let N: u32> {
pub sorted: [T; N],
pub sort_indices: [Field; N],
}
pub fn sort_advanced<T, let N: u32>(
input: [T; N],
Expand All @@ -88,7 +88,7 @@ pub fn sort_advanced<T, let N: u32>(
where
T: std::cmp::Eq,
{
let sorted = quicksort_explicit(input, sortfn);
let sorted = unsafe { quicksort_explicit(input, sortfn) };

let sort_indices = get_shuffle_indices(input, sorted);

Expand Down Expand Up @@ -148,116 +148,3 @@ mod test {
}
}

fn sort_u32(a: u32, b: u32) -> bool {
a <= b
}

fn lt_u32(a: u32, b: u32) -> bool {
a < b
}
// unconditional_lt will cost fewer constraints than the `<=` operator
// as we do not need to constrain the case where `a > b`, and assign a boolean variable to the result
fn unconditional_lt(_a: u32, _b: u32) {
let a = _a as Field;
let b = _b as Field;

let diff = b - a;
diff.assert_max_bit_size::<32>();
}

struct TestStruct {
a: bool,
b: u32,
c: Field,
}

impl std::cmp::Eq for TestStruct {
fn eq(self, other: Self) -> bool {
(self.a == other.a) & (self.b == other.b) & (self.c == other.c)
}
}

pub unconstrained fn get_lt_predicate_f(x: Field, y: Field) -> bool {
let a = x as u32;
let b = y as u32;
let r = a < b;
r
}

pub fn lt_f(x: Field, y: Field) -> bool {
let predicate = get_lt_predicate_f(x, y);
let delta = y as Field - x as Field;
let lt_parameter = 2 * (predicate as Field) * delta - predicate as Field - delta;
lt_parameter.assert_max_bit_size::<32>();

predicate
}

fn less_than_for_test_struct(lhs: TestStruct, rhs: TestStruct) -> bool {
let a_lt = lhs.a < rhs.a;
let b_lt = lhs.b < rhs.b;
let c_lt = lt_f(lhs.c, rhs.c);

let a_eq = lhs.a == rhs.a;
let b_eq = lhs.b == rhs.b;

let b_flag = a_eq;

let c_flag = a_eq & b_eq;
let result = a_lt | (b_flag & b_lt) | (c_flag & c_lt);

result
}

fn unconditional_lte(lhs: TestStruct, rhs: TestStruct) {
// lhs < rhs implies:
// a == false, b == false
// a == false, b == true
// a == true, b == true
// i.e. a == true, b == false is not allowed
assert(lhs.a as Field * (1 - rhs.a as Field) == 0);

// a < b as u32 implies
// b - a > 0
let diff = lhs.b as Field - rhs.b as Field;
diff.assert_max_bit_size::<32>();

// a < b as Field (32 bit condition)
let diff = lhs.c as Field - rhs.c as Field;
diff.assert_max_bit_size::<32>();
}

global Num: u32 = 100;

// // size 100: 7,638
// // size 1,000: 51,738
// // diff = 49
// fn main2(x: [TestStruct; Num]) {
// let sorted = sort_extended(x, less_than_for_test_struct, unconditional_lte);
// println(f"{sorted}");
// }

// // size 100: 9,321
// // size 1,000: 68,721
// // diff = 59,400 = 66 per
// fn main3(x: [TestStruct; Num]) {
// let sorted = sort_via(x, less_than_for_test_struct);
// println(f"{sorted}");
// }

fn unconditional_lt_f(a: Field, b: Field) {
let diff = b - a;
diff.assert_max_bit_size::<32>();
}

// 5,089
fn main20(x: [Field; Num]) {
let sorted = sort_via(x, lt_f);
println(f"{sorted}");
}

// 4,891
fn main000(x: [Field; Num]) {
let sorted = sort_extended(x, lt_f, unconditional_lt_f);
println(f"{sorted}");
}
4 changes: 2 additions & 2 deletions src/quicksort.nr
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
mod quicksort;
mod quicksort_explicit;
pub mod quicksort;
pub mod quicksort_explicit;
2 changes: 1 addition & 1 deletion src/quicksort/quicksort.nr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
trait Swap {
pub trait Swap {
fn swap(&mut self, i: u32, j: u32);
}

Expand Down
2 changes: 1 addition & 1 deletion src/quicksort/quicksort_explicit.nr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
trait Swap {
pub trait Swap {
fn swap(&mut self, i: u32, j: u32);
}

Expand Down
Loading