Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Allow passing nested arrays and slices into foreign calls #4053

Closed
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,40 @@ namespace Circuit {
static BinaryIntOp bincodeDeserialize(std::vector<uint8_t>);
};

struct HeapValueType;

struct HeapValueType {

struct Simple {
friend bool operator==(const Simple&, const Simple&);
std::vector<uint8_t> bincodeSerialize() const;
static Simple bincodeDeserialize(std::vector<uint8_t>);
};

struct Array {
std::vector<Circuit::HeapValueType> value_types;
uint64_t size;

friend bool operator==(const Array&, const Array&);
std::vector<uint8_t> bincodeSerialize() const;
static Array bincodeDeserialize(std::vector<uint8_t>);
};

struct Vector {
std::vector<Circuit::HeapValueType> value_types;

friend bool operator==(const Vector&, const Vector&);
std::vector<uint8_t> bincodeSerialize() const;
static Vector bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Simple, Array, Vector> value;

friend bool operator==(const HeapValueType&, const HeapValueType&);
std::vector<uint8_t> bincodeSerialize() const;
static HeapValueType bincodeDeserialize(std::vector<uint8_t>);
};

struct RegisterIndex {
uint64_t value;

Expand All @@ -371,6 +405,7 @@ namespace Circuit {
struct HeapArray {
Circuit::RegisterIndex pointer;
uint64_t size;
std::vector<Circuit::HeapValueType> value_types;

friend bool operator==(const HeapArray&, const HeapArray&);
std::vector<uint8_t> bincodeSerialize() const;
Expand All @@ -380,6 +415,7 @@ namespace Circuit {
struct HeapVector {
Circuit::RegisterIndex pointer;
Circuit::RegisterIndex size;
std::vector<Circuit::HeapValueType> value_types;

friend bool operator==(const HeapVector&, const HeapVector&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4083,6 +4119,7 @@ namespace Circuit {
inline bool operator==(const HeapArray &lhs, const HeapArray &rhs) {
if (!(lhs.pointer == rhs.pointer)) { return false; }
if (!(lhs.size == rhs.size)) { return false; }
if (!(lhs.value_types == rhs.value_types)) { return false; }
return true;
}

Expand All @@ -4109,6 +4146,7 @@ void serde::Serializable<Circuit::HeapArray>::serialize(const Circuit::HeapArray
serializer.increase_container_depth();
serde::Serializable<decltype(obj.pointer)>::serialize(obj.pointer, serializer);
serde::Serializable<decltype(obj.size)>::serialize(obj.size, serializer);
serde::Serializable<decltype(obj.value_types)>::serialize(obj.value_types, serializer);
serializer.decrease_container_depth();
}

Expand All @@ -4119,15 +4157,173 @@ Circuit::HeapArray serde::Deserializable<Circuit::HeapArray>::deserialize(Deseri
Circuit::HeapArray obj;
obj.pointer = serde::Deserializable<decltype(obj.pointer)>::deserialize(deserializer);
obj.size = serde::Deserializable<decltype(obj.size)>::deserialize(deserializer);
obj.value_types = serde::Deserializable<decltype(obj.value_types)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
}

namespace Circuit {

inline bool operator==(const HeapValueType &lhs, const HeapValueType &rhs) {
if (!(lhs.value == rhs.value)) { return false; }
return true;
}

inline std::vector<uint8_t> HeapValueType::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<HeapValueType>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline HeapValueType HeapValueType::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<HeapValueType>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::HeapValueType>::serialize(const Circuit::HeapValueType &obj, Serializer &serializer) {
serializer.increase_container_depth();
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
serializer.decrease_container_depth();
}

template <>
template <typename Deserializer>
Circuit::HeapValueType serde::Deserializable<Circuit::HeapValueType>::deserialize(Deserializer &deserializer) {
deserializer.increase_container_depth();
Circuit::HeapValueType obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
}

namespace Circuit {

inline bool operator==(const HeapValueType::Simple &lhs, const HeapValueType::Simple &rhs) {
return true;
}

inline std::vector<uint8_t> HeapValueType::Simple::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<HeapValueType::Simple>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline HeapValueType::Simple HeapValueType::Simple::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<HeapValueType::Simple>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::HeapValueType::Simple>::serialize(const Circuit::HeapValueType::Simple &obj, Serializer &serializer) {
}

template <>
template <typename Deserializer>
Circuit::HeapValueType::Simple serde::Deserializable<Circuit::HeapValueType::Simple>::deserialize(Deserializer &deserializer) {
Circuit::HeapValueType::Simple obj;
return obj;
}

namespace Circuit {

inline bool operator==(const HeapValueType::Array &lhs, const HeapValueType::Array &rhs) {
if (!(lhs.value_types == rhs.value_types)) { return false; }
if (!(lhs.size == rhs.size)) { return false; }
return true;
}

inline std::vector<uint8_t> HeapValueType::Array::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<HeapValueType::Array>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline HeapValueType::Array HeapValueType::Array::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<HeapValueType::Array>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::HeapValueType::Array>::serialize(const Circuit::HeapValueType::Array &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.value_types)>::serialize(obj.value_types, serializer);
serde::Serializable<decltype(obj.size)>::serialize(obj.size, serializer);
}

template <>
template <typename Deserializer>
Circuit::HeapValueType::Array serde::Deserializable<Circuit::HeapValueType::Array>::deserialize(Deserializer &deserializer) {
Circuit::HeapValueType::Array obj;
obj.value_types = serde::Deserializable<decltype(obj.value_types)>::deserialize(deserializer);
obj.size = serde::Deserializable<decltype(obj.size)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const HeapValueType::Vector &lhs, const HeapValueType::Vector &rhs) {
if (!(lhs.value_types == rhs.value_types)) { return false; }
return true;
}

inline std::vector<uint8_t> HeapValueType::Vector::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<HeapValueType::Vector>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline HeapValueType::Vector HeapValueType::Vector::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<HeapValueType::Vector>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::HeapValueType::Vector>::serialize(const Circuit::HeapValueType::Vector &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.value_types)>::serialize(obj.value_types, serializer);
}

template <>
template <typename Deserializer>
Circuit::HeapValueType::Vector serde::Deserializable<Circuit::HeapValueType::Vector>::deserialize(Deserializer &deserializer) {
Circuit::HeapValueType::Vector obj;
obj.value_types = serde::Deserializable<decltype(obj.value_types)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const HeapVector &lhs, const HeapVector &rhs) {
if (!(lhs.pointer == rhs.pointer)) { return false; }
if (!(lhs.size == rhs.size)) { return false; }
if (!(lhs.value_types == rhs.value_types)) { return false; }
return true;
}

Expand All @@ -4154,6 +4350,7 @@ void serde::Serializable<Circuit::HeapVector>::serialize(const Circuit::HeapVect
serializer.increase_container_depth();
serde::Serializable<decltype(obj.pointer)>::serialize(obj.pointer, serializer);
serde::Serializable<decltype(obj.size)>::serialize(obj.size, serializer);
serde::Serializable<decltype(obj.value_types)>::serialize(obj.value_types, serializer);
serializer.decrease_container_depth();
}

Expand All @@ -4164,6 +4361,7 @@ Circuit::HeapVector serde::Deserializable<Circuit::HeapVector>::deserialize(Dese
Circuit::HeapVector obj;
obj.pointer = serde::Deserializable<decltype(obj.pointer)>::deserialize(deserializer);
obj.size = serde::Deserializable<decltype(obj.size)>::deserialize(deserializer);
obj.value_types = serde::Deserializable<decltype(obj.value_types)>::deserialize(deserializer);
deserializer.decrease_container_depth();
return obj;
}
Expand Down
4 changes: 3 additions & 1 deletion acvm-repo/acir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
mod reflection {
//! Getting test failures? You've probably changed the ACIR serialization format.
//!
//! These tests generate C++ deserializers for [`ACIR bytecode`][super::circuit::Circuit]

Check warning on line 20 in acvm-repo/acir/src/lib.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (deserializers)
//! and the [`WitnessMap`] structs. These get checked against the C++ files committed to the `codegen` folder
//! to see if changes have been to the serialization format. These are almost always a breaking change!
//!
Expand All @@ -32,7 +32,8 @@
};

use brillig::{
BinaryFieldOp, BinaryIntOp, BlackBoxOp, Opcode as BrilligOpcode, RegisterOrMemory,
BinaryFieldOp, BinaryIntOp, BlackBoxOp, HeapValueType, Opcode as BrilligOpcode,
RegisterOrMemory,
};
use serde_reflection::{Tracer, TracerConfig};

Expand Down Expand Up @@ -70,6 +71,7 @@
tracer.trace_simple_type::<BlackBoxOp>().unwrap();
tracer.trace_simple_type::<Directive>().unwrap();
tracer.trace_simple_type::<RegisterOrMemory>().unwrap();
tracer.trace_simple_type::<HeapValueType>().unwrap();

let registry = tracer.registry().unwrap();

Expand Down
28 changes: 18 additions & 10 deletions acvm-repo/acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use acir::{
native_types::{Expression, Witness},
};
use acir_field::FieldElement;
use brillig::{HeapArray, RegisterIndex, RegisterOrMemory};
use brillig::{HeapArray, HeapValueType, RegisterIndex, RegisterOrMemory};

#[test]
fn addition_circuit() {
Expand Down Expand Up @@ -245,11 +245,19 @@ fn complex_brillig_foreign_call() {
brillig::Opcode::ForeignCall {
function: "complex".into(),
inputs: vec![
RegisterOrMemory::HeapArray(HeapArray { pointer: 0.into(), size: 3 }),
RegisterOrMemory::HeapArray(HeapArray {
pointer: 0.into(),
size: 3,
value_types: vec![HeapValueType::Simple],
}),
RegisterOrMemory::RegisterIndex(RegisterIndex::from(1)),
],
destinations: vec![
RegisterOrMemory::HeapArray(HeapArray { pointer: 0.into(), size: 3 }),
RegisterOrMemory::HeapArray(HeapArray {
pointer: 0.into(),
size: 3,
value_types: vec![HeapValueType::Simple],
}),
RegisterOrMemory::RegisterIndex(RegisterIndex::from(1)),
RegisterOrMemory::RegisterIndex(RegisterIndex::from(2)),
],
Expand All @@ -269,13 +277,13 @@ fn complex_brillig_foreign_call() {
let bytes = Circuit::serialize_circuit(&circuit);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 219, 10, 128, 48, 8, 117, 174, 139, 159, 179,
254, 160, 127, 137, 222, 138, 122, 236, 243, 19, 114, 32, 22, 244, 144, 131, 118, 64, 156,
178, 29, 14, 59, 74, 0, 16, 224, 66, 228, 64, 57, 7, 169, 53, 242, 189, 81, 114, 250, 134,
33, 248, 113, 165, 82, 26, 177, 2, 141, 177, 128, 198, 60, 15, 63, 245, 219, 211, 23, 215,
255, 139, 15, 251, 211, 112, 180, 28, 157, 212, 189, 100, 82, 179, 64, 170, 63, 109, 235,
190, 204, 135, 166, 178, 150, 216, 62, 154, 252, 250, 70, 147, 35, 220, 119, 93, 227, 4,
182, 131, 81, 25, 36, 4, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 219, 10, 128, 48, 8, 117, 174, 203, 62, 103,
253, 65, 255, 18, 189, 21, 245, 216, 231, 55, 200, 193, 193, 122, 137, 28, 180, 3, 226, 20,
39, 135, 29, 103, 32, 34, 71, 23, 124, 50, 150, 179, 147, 24, 145, 235, 70, 241, 241, 27,
6, 103, 215, 43, 150, 226, 200, 21, 112, 244, 5, 56, 230, 121, 248, 169, 222, 150, 186,
152, 190, 159, 127, 248, 63, 77, 178, 54, 89, 39, 113, 47, 62, 192, 44, 4, 200, 79, 219,
186, 47, 243, 129, 173, 180, 36, 152, 211, 49, 43, 255, 234, 62, 22, 48, 221, 119, 0, 226,
4, 104, 45, 56, 241, 60, 4, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down
10 changes: 5 additions & 5 deletions acvm-repo/acvm_js/test/shared/complex_foreign_call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `complex_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 219, 10, 128, 48, 8, 117, 174, 139, 159, 179, 254, 160, 127, 137, 222,
138, 122, 236, 243, 19, 114, 32, 22, 244, 144, 131, 118, 64, 156, 178, 29, 14, 59, 74, 0, 16, 224, 66, 228, 64, 57, 7,
169, 53, 242, 189, 81, 114, 250, 134, 33, 248, 113, 165, 82, 26, 177, 2, 141, 177, 128, 198, 60, 15, 63, 245, 219,
211, 23, 215, 255, 139, 15, 251, 211, 112, 180, 28, 157, 212, 189, 100, 82, 179, 64, 170, 63, 109, 235, 190, 204, 135,
166, 178, 150, 216, 62, 154, 252, 250, 70, 147, 35, 220, 119, 93, 227, 4, 182, 131, 81, 25, 36, 4, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 219, 10, 128, 48, 8, 117, 174, 203, 62, 103, 253, 65, 255, 18, 189, 21,
245, 216, 231, 55, 200, 193, 193, 122, 137, 28, 180, 3, 226, 20, 39, 135, 29, 103, 32, 34, 71, 23, 124, 50, 150, 179,
147, 24, 145, 235, 70, 241, 241, 27, 6, 103, 215, 43, 150, 226, 200, 21, 112, 244, 5, 56, 230, 121, 248, 169, 222,
150, 186, 152, 190, 159, 127, 248, 63, 77, 178, 54, 89, 39, 113, 47, 62, 192, 44, 4, 200, 79, 219, 186, 47, 243, 129,
173, 180, 36, 152, 211, 49, 43, 255, 234, 62, 22, 48, 221, 119, 0, 226, 4, 104, 45, 56, 241, 60, 4, 0, 0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000001'],
Expand Down
2 changes: 1 addition & 1 deletion acvm-repo/brillig/src/black_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};

/// These opcodes provide an equivalent of ACIR blackbox functions.
/// They are implemented as native functions in the VM.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlackBoxOp {
/// Calculates the SHA256 hash of the inputs.
Sha256 { message: HeapVector, output: HeapArray },
Expand Down
3 changes: 2 additions & 1 deletion acvm-repo/brillig/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ mod value;
pub use black_box::BlackBoxOp;
pub use foreign_call::{ForeignCallParam, ForeignCallResult};
pub use opcodes::{
BinaryFieldOp, BinaryIntOp, HeapArray, HeapVector, RegisterIndex, RegisterOrMemory,
BinaryFieldOp, BinaryIntOp, HeapArray, HeapValueType, HeapVector, RegisterIndex,
RegisterOrMemory,
};
pub use opcodes::{BrilligOpcode as Opcode, Label};
pub use value::Typ;
Expand Down
Loading
Loading