Skip to content

Commit

Permalink
chore: add compile-time assertions on generic arguments of stdlib fun…
Browse files Browse the repository at this point in the history
…ctions (#6981)

Co-authored-by: Tom French <[email protected]>
  • Loading branch information
michaeljklein and TomAFrench authored Jan 23, 2025
1 parent c8d5ce5 commit dce2c7d
Show file tree
Hide file tree
Showing 15 changed files with 210 additions and 14 deletions.
7 changes: 6 additions & 1 deletion compiler/noirc_evaluator/src/acir/generated_acir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,12 +362,17 @@ impl<F: AcirField> GeneratedAcir<F> {
bit_size: u32,
) -> Result<Vec<Witness>, RuntimeError> {
let radix_big = BigUint::from(radix);
let radix_range = BigUint::from(2u128)..=BigUint::from(256u128);
assert!(
radix_range.contains(&radix_big),
"ICE: Radix must be in the range 2..=256, but found: {:?}",
radix
);
assert_eq!(
BigUint::from(2u128).pow(bit_size),
radix_big,
"ICE: Radix must be a power of 2"
);

let limb_witnesses = self.brillig_to_radix(input_expr, radix, limb_count);

let mut composed_limbs = Expression::default();
Expand Down
7 changes: 6 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,12 @@ fn constant_to_radix(
) -> SimplifyResult {
let bit_size = u32::BITS - (radix - 1).leading_zeros();
let radix_big = BigUint::from(radix);
assert_eq!(BigUint::from(2u128).pow(bit_size), radix_big, "ICE: Radix must be a power of 2");
let radix_range = BigUint::from(2u128)..=BigUint::from(256u128);
if !radix_range.contains(&radix_big) || BigUint::from(2u128).pow(bit_size) != radix_big {
// NOTE: expect an error to be thrown later in
// acir::generated_acir::radix_le_decompose
return SimplifyResult::None;
}
let big_integer = BigUint::from_bytes_be(&field.to_be_bytes());

// Decompose the integer into its radix digits in little endian form.
Expand Down
25 changes: 24 additions & 1 deletion compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl<'local, 'context> Interpreter<'local, 'context> {
"array_as_str_unchecked" => array_as_str_unchecked(interner, arguments, location),
"array_len" => array_len(interner, arguments, location),
"array_refcount" => Ok(Value::U32(0)),
"assert_constant" => Ok(Value::Bool(true)),
"assert_constant" => Ok(Value::Unit),
"as_slice" => as_slice(interner, arguments, location),
"ctstring_eq" => ctstring_eq(arguments, location),
"ctstring_hash" => ctstring_hash(arguments, location),
Expand Down Expand Up @@ -175,6 +175,7 @@ impl<'local, 'context> Interpreter<'local, 'context> {
"slice_push_front" => slice_push_front(interner, arguments, location),
"slice_refcount" => Ok(Value::U32(0)),
"slice_remove" => slice_remove(interner, arguments, location, call_stack),
"static_assert" => static_assert(interner, arguments, location, call_stack),
"str_as_bytes" => str_as_bytes(interner, arguments, location),
"str_as_ctstring" => str_as_ctstring(interner, arguments, location),
"struct_def_add_attribute" => struct_def_add_attribute(interner, arguments, location),
Expand Down Expand Up @@ -327,6 +328,28 @@ fn slice_push_back(
Ok(Value::Slice(values, typ))
}

// static_assert<let N: u32>(predicate: bool, message: str<N>)
fn static_assert(
interner: &NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
call_stack: &im::Vector<Location>,
) -> IResult<Value> {
let (predicate, message) = check_two_arguments(arguments, location)?;
let predicate = get_bool(predicate)?;
let message = get_str(interner, message)?;

if predicate {
Ok(Value::Unit)
} else {
failing_constraint(
format!("static_assert failed: {}", message).clone(),
location,
call_stack,
)
}
}

fn str_as_bytes(
interner: &NodeInterner,
arguments: Vec<(Value, Location)>,
Expand Down
21 changes: 21 additions & 0 deletions compiler/noirc_frontend/src/tests/metaprogramming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use noirc_errors::Spanned;
use crate::{
ast::Ident,
hir::{
comptime::InterpreterError,
def_collector::{
dc_crate::CompilationError,
errors::{DefCollectorErrorKind, DuplicateType},
Expand All @@ -26,6 +27,26 @@ fn comptime_let() {
assert_eq!(errors.len(), 0);
}

#[test]
fn comptime_code_rejects_dynamic_variable() {
let src = r#"fn main(x: Field) {
comptime let my_var = (x - x) + 2;
assert_eq(my_var, 2);
}"#;
let errors = get_program_errors(src);

assert_eq!(errors.len(), 1);
match &errors[0].0 {
CompilationError::InterpreterError(InterpreterError::NonComptimeVarReferenced {
name,
..
}) => {
assert_eq!(name, "x");
}
_ => panic!("expected an InterpreterError"),
}
}

#[test]
fn comptime_type_in_runtime_code() {
let source = "pub fn foo(_f: FunctionDefinition) {}";
Expand Down
4 changes: 2 additions & 2 deletions noir_stdlib/src/collections/bounded_vec.nr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{cmp::Eq, convert::From, runtime::is_unconstrained};
use crate::{cmp::Eq, convert::From, runtime::is_unconstrained, static_assert};

/// A `BoundedVec<T, MaxLen>` is a growable storage similar to a `Vec<T>` except that it
/// is bounded with a maximum possible length. Unlike `Vec`, `BoundedVec` is not implemented
Expand Down Expand Up @@ -345,7 +345,7 @@ impl<T, let MaxLen: u32> BoundedVec<T, MaxLen> {
/// let bounded_vec: BoundedVec<Field, 10> = BoundedVec::from_array([1, 2, 3])
/// ```
pub fn from_array<let Len: u32>(array: [T; Len]) -> Self {
assert(Len <= MaxLen, "from array out of bounds");
static_assert(Len <= MaxLen, "from array out of bounds");
let mut vec: BoundedVec<T, MaxLen> = BoundedVec::new();
vec.extend_from_array(array);
vec
Expand Down
94 changes: 91 additions & 3 deletions noir_stdlib/src/field/mod.nr
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub mod bn254;
use crate::runtime::is_unconstrained;
use crate::{runtime::is_unconstrained, static_assert};
use bn254::lt as bn254_lt;

impl Field {
Expand All @@ -10,7 +10,10 @@ impl Field {
// docs:start:assert_max_bit_size
pub fn assert_max_bit_size<let BIT_SIZE: u32>(self) {
// docs:end:assert_max_bit_size
assert(BIT_SIZE < modulus_num_bits() as u32);
static_assert(
BIT_SIZE < modulus_num_bits() as u32,
"BIT_SIZE must be less than modulus_num_bits",
);
self.__assert_max_bit_size(BIT_SIZE);
}

Expand Down Expand Up @@ -61,6 +64,10 @@ impl Field {
// docs:start:to_le_bytes
pub fn to_le_bytes<let N: u32>(self: Self) -> [u8; N] {
// docs:end:to_le_bytes
static_assert(
N <= modulus_le_bytes().len(),
"N must be less than or equal to modulus_le_bytes().len()",
);
// Compute the byte decomposition
let bytes = self.to_le_radix(256);

Expand Down Expand Up @@ -94,6 +101,10 @@ impl Field {
// docs:start:to_be_bytes
pub fn to_be_bytes<let N: u32>(self: Self) -> [u8; N] {
// docs:end:to_be_bytes
static_assert(
N <= modulus_le_bytes().len(),
"N must be less than or equal to modulus_le_bytes().len()",
);
// Compute the byte decomposition
let bytes = self.to_be_radix(256);

Expand All @@ -119,7 +130,9 @@ impl Field {
pub fn to_le_radix<let N: u32>(self: Self, radix: u32) -> [u8; N] {
// Brillig does not need an immediate radix
if !crate::runtime::is_unconstrained() {
crate::assert_constant(radix);
static_assert(1 < radix, "radix must be greater than 1");
static_assert(radix <= 256, "radix must be less than or equal to 256");
static_assert(radix & (radix - 1) == 0, "radix must be a power of 2");
}
self.__to_le_radix(radix)
}
Expand All @@ -139,6 +152,7 @@ impl Field {
#[builtin(to_le_radix)]
fn __to_le_radix<let N: u32>(self, radix: u32) -> [u8; N] {}

// `_radix` must be less than 256
#[builtin(to_be_radix)]
fn __to_be_radix<let N: u32>(self, radix: u32) -> [u8; N] {}

Expand Down Expand Up @@ -172,6 +186,10 @@ impl Field {
/// Convert a little endian byte array to a field element.
/// If the provided byte array overflows the field modulus then the Field will silently wrap around.
pub fn from_le_bytes<let N: u32>(bytes: [u8; N]) -> Field {
static_assert(
N <= modulus_le_bytes().len(),
"N must be less than or equal to modulus_le_bytes().len()",
);
let mut v = 1;
let mut result = 0;

Expand Down Expand Up @@ -262,6 +280,7 @@ fn lt_fallback(x: Field, y: Field) -> bool {
}

mod tests {
use crate::{panic::panic, runtime};
use super::field_less_than;

#[test]
Expand Down Expand Up @@ -322,6 +341,75 @@ mod tests {
}
// docs:end:to_le_radix_example

#[test(should_fail_with = "radix must be greater than 1")]
fn test_to_le_radix_1() {
// this test should only fail in constrained mode
if !runtime::is_unconstrained() {
let field = 2;
let _: [u8; 8] = field.to_le_radix(1);
} else {
panic(f"radix must be greater than 1");
}
}

#[test]
fn test_to_le_radix_brillig_1() {
// this test should only fail in constrained mode
if runtime::is_unconstrained() {
let field = 1;
let out: [u8; 8] = field.to_le_radix(1);
crate::println(out);
let expected = [0; 8];
assert(out == expected, "unexpected result");
}
}

#[test(should_fail_with = "radix must be a power of 2")]
fn test_to_le_radix_3() {
// this test should only fail in constrained mode
if !runtime::is_unconstrained() {
let field = 2;
let _: [u8; 8] = field.to_le_radix(3);
} else {
panic(f"radix must be a power of 2");
}
}

#[test]
fn test_to_le_radix_brillig_3() {
// this test should only fail in constrained mode
if runtime::is_unconstrained() {
let field = 1;
let out: [u8; 8] = field.to_le_radix(3);
let mut expected = [0; 8];
expected[0] = 1;
assert(out == expected, "unexpected result");
}
}

#[test(should_fail_with = "radix must be less than or equal to 256")]
fn test_to_le_radix_512() {
// this test should only fail in constrained mode
if !runtime::is_unconstrained() {
let field = 2;
let _: [u8; 8] = field.to_le_radix(512);
} else {
panic(f"radix must be less than or equal to 256")
}
}

#[test]
fn test_to_le_radix_brillig_512() {
// this test should only fail in constrained mode
if runtime::is_unconstrained() {
let field = 1;
let out: [u8; 8] = field.to_le_radix(512);
let mut expected = [0; 8];
expected[0] = 1;
assert(out == expected, "unexpected result");
}
}

#[test]
unconstrained fn test_field_less_than() {
assert(field_less_than(0, 1));
Expand Down
3 changes: 2 additions & 1 deletion noir_stdlib/src/meta/ctstring.nr
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ impl CtString {
"".as_ctstring()
}

// Bug: using &mut self as the object results in this method not being found
// TODO(https://github.com/noir-lang/noir/issues/6980): Bug: using &mut self
// as the object results in this method not being found
// docs:start:append_str
pub comptime fn append_str<let N: u32>(self, s: str<N>) -> Self {
// docs:end:append_str
Expand Down
4 changes: 2 additions & 2 deletions noir_stdlib/src/uint128.nr
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::cmp::{Eq, Ord, Ordering};
use crate::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Not, Rem, Shl, Shr, Sub};
use crate::static_assert;
use super::{convert::AsPrimitive, default::Default};

global pow64: Field = 18446744073709551616; //2^64;
Expand Down Expand Up @@ -67,11 +68,10 @@ impl U128 {
}

pub fn from_hex<let N: u32>(hex: str<N>) -> U128 {
let N = N as u32;
let bytes = hex.as_bytes();
// string must starts with "0x"
assert((bytes[0] == 48) & (bytes[1] == 120), "Invalid hexadecimal string");
assert(N < 35, "Input does not fit into a U128");
static_assert(N < 35, "Input does not fit into a U128");

let mut lo = 0;
let mut hi = 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "comptime_static_assert_failure"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use std::static_assert;

comptime fn foo(x: Field) -> bool {
static_assert(x == 4, "x != 4");
x == 4
}

fn main() {
comptime {
static_assert(foo(3), "expected message");
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "comptime_static_assert"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use std::static_assert;

comptime fn foo(x: Field) -> bool {
static_assert(x == 4, "x != 4");
x == 4
}

global C: bool = {
let out = foo(2 + 2);
static_assert(out, "foo did not pass in C");
out
};

fn main() {
comptime {
static_assert(foo(4), "foo did not pass in main");
static_assert(C, "C did not pass")
}
}
8 changes: 7 additions & 1 deletion tooling/nargo_cli/src/cli/compile_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ mod tests {
use noirc_driver::{CompileOptions, CrateName};

use crate::cli::compile_cmd::{get_target_width, parse_workspace, read_workspace};
use crate::cli::test_cmd::formatters::diagnostic_to_string;

/// Try to find the directory that Cargo sets when it is running;
/// otherwise fallback to assuming the CWD is the root of the repository
Expand Down Expand Up @@ -414,7 +415,12 @@ mod tests {
&CompileOptions::default(),
None,
)
.expect("failed to compile");
.unwrap_or_else(|err| {
for diagnostic in err {
println!("{}", diagnostic_to_string(&diagnostic, &file_manager));
}
panic!("Failed to compile")
});

let width = get_target_width(package.expression_width, None);

Expand Down
2 changes: 1 addition & 1 deletion tooling/nargo_cli/src/cli/test_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::{cli::check_cmd::check_crate_and_report_errors, errors::CliError};

use super::{NargoConfig, PackageOptions};

mod formatters;
pub(crate) mod formatters;

/// Run the tests for this program
#[derive(Debug, Clone, Args)]
Expand Down
Loading

0 comments on commit dce2c7d

Please sign in to comment.