diff --git a/cpp/src/barretenberg/honk/composer/eccvm_composer.cpp b/cpp/src/barretenberg/honk/composer/eccvm_composer.cpp new file mode 100644 index 0000000000..0ce8c5310b --- /dev/null +++ b/cpp/src/barretenberg/honk/composer/eccvm_composer.cpp @@ -0,0 +1,126 @@ +#include "./eccvm_composer.hpp" +#include "barretenberg/honk/proof_system/ultra_prover.hpp" +#include "barretenberg/proof_system/composer/composer_lib.hpp" +#include "barretenberg/proof_system/composer/permutation_lib.hpp" + +namespace proof_system::honk { + +/** + * @brief Compute witness polynomials + * + */ +template +void ECCVMComposerHelper_::compute_witness(CircuitConstructor& circuit_constructor) +{ + if (computed_witness) { + return; + } + + auto polynomials = circuit_constructor.compute_full_polynomials(); + + auto key_wires = proving_key->get_wires(); + auto poly_wires = polynomials.get_wires(); + + for (size_t i = 0; i < key_wires.size(); ++i) { + std::copy(poly_wires[i].begin(), poly_wires[i].end(), key_wires[i].begin()); + } + + computed_witness = true; +} + +template +ECCVMProver_ ECCVMComposerHelper_::create_prover(CircuitConstructor& circuit_constructor) +{ + compute_proving_key(circuit_constructor); + compute_witness(circuit_constructor); + compute_commitment_key(proving_key->circuit_size); + + ECCVMProver_ output_state(proving_key, commitment_key); + + return output_state; +} + +/** + * Create verifier: compute verification key, + * initialize verifier with it and an initial manifest and initialize commitment_scheme. + * + * @return The verifier. + * */ +template +ECCVMVerifier_ ECCVMComposerHelper_::create_verifier(CircuitConstructor& circuit_constructor) +{ + auto verification_key = compute_verification_key(circuit_constructor); + + ECCVMVerifier_ output_state(verification_key); + + auto pcs_verification_key = std::make_unique(verification_key->circuit_size, crs_factory_); + + output_state.pcs_verification_key = std::move(pcs_verification_key); + + return output_state; +} + +template +std::shared_ptr ECCVMComposerHelper_::compute_proving_key( + CircuitConstructor& circuit_constructor) +{ + if (proving_key) { + return proving_key; + } + + // Initialize proving_key + // TODO(#392)(Kesha): replace composer types. + { + // TODO: get num gates in a more efficient way + const auto rows = circuit_constructor.compute_full_polynomials(); + const size_t subgroup_size = rows.lagrange_first.size(); + // Differentiate between Honk and Plonk here since Plonk pkey requires crs whereas Honk pkey does not + proving_key = std::make_shared(subgroup_size, 0); + } + + // construct_selector_polynomials(circuit_constructor, proving_key.get()); + + // TODO(@zac-williamson): We don't enforce nonzero selectors atm. Will create problems in recursive setting. Needs + // fix enforce_nonzero_polynomial_selectors(circuit_constructor, proving_key.get()); + + compute_first_and_last_lagrange_polynomials(proving_key.get()); + { + const size_t n = proving_key->circuit_size; + typename Flavor::Polynomial lagrange_polynomial_second(n); + lagrange_polynomial_second[1] = 1; + proving_key->lagrange_second = lagrange_polynomial_second; + } + + proving_key->contains_recursive_proof = false; + + return proving_key; +} + +/** + * Compute verification key consisting of selector precommitments. + * + * @return Pointer to created circuit verification key. + * */ +template +std::shared_ptr ECCVMComposerHelper_::compute_verification_key( + CircuitConstructor& circuit_constructor) +{ + if (verification_key) { + return verification_key; + } + + if (!proving_key) { + compute_proving_key(circuit_constructor); + } + + verification_key = + std::make_shared(proving_key->circuit_size, proving_key->num_public_inputs); + + verification_key->lagrange_first = commitment_key->commit(proving_key->lagrange_first); + verification_key->lagrange_second = commitment_key->commit(proving_key->lagrange_second); + verification_key->lagrange_last = commitment_key->commit(proving_key->lagrange_last); + return verification_key; +} +template class ECCVMComposerHelper_; + +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/composer/eccvm_composer.hpp b/cpp/src/barretenberg/honk/composer/eccvm_composer.hpp new file mode 100644 index 0000000000..8c9fe7425e --- /dev/null +++ b/cpp/src/barretenberg/honk/composer/eccvm_composer.hpp @@ -0,0 +1,72 @@ +#pragma once + +#include "barretenberg/honk/proof_system/eccvm_prover.hpp" +#include "barretenberg/honk/proof_system/eccvm_verifier.hpp" +#include "barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp" +#include "barretenberg/proof_system/composer/composer_lib.hpp" +#include "barretenberg/srs/factories/file_crs_factory.hpp" + +namespace proof_system::honk { +template class ECCVMComposerHelper_ { + public: + using CircuitConstructor = ECCVMCircuitConstructor; + using ProvingKey = typename Flavor::ProvingKey; + using VerificationKey = typename Flavor::VerificationKey; + using PCSParams = typename Flavor::PCSParams; + using PCS = typename Flavor::PCS; + using PCSCommitmentKey = typename PCSParams::CommitmentKey; + using PCSVerificationKey = typename PCSParams::VerificationKey; + + static constexpr std::string_view NAME_STRING = "ECCVM"; + static constexpr size_t NUM_RESERVED_GATES = 0; // equal to the number of multilinear evaluations leaked + static constexpr size_t NUM_WIRES = CircuitConstructor::NUM_WIRES; + std::shared_ptr proving_key; + std::shared_ptr verification_key; + + // The crs_factory holds the path to the srs and exposes methods to extract the srs elements + std::shared_ptr crs_factory_; + + // The commitment key is passed to the prover but also used herein to compute the verfication key commitments + std::shared_ptr commitment_key; + + std::vector recursive_proof_public_input_indices; + bool contains_recursive_proof = false; + bool computed_witness = false; + + ECCVMComposerHelper_() + : crs_factory_(barretenberg::srs::get_crs_factory()){}; + + explicit ECCVMComposerHelper_(std::shared_ptr crs_factory) + : crs_factory_(std::move(crs_factory)) + {} + + ECCVMComposerHelper_(std::shared_ptr p_key, std::shared_ptr v_key) + : proving_key(std::move(p_key)) + , verification_key(std::move(v_key)) + {} + + ECCVMComposerHelper_(ECCVMComposerHelper_&& other) noexcept = default; + ECCVMComposerHelper_(ECCVMComposerHelper_ const& other) noexcept = default; + ECCVMComposerHelper_& operator=(ECCVMComposerHelper_&& other) noexcept = default; + ECCVMComposerHelper_& operator=(ECCVMComposerHelper_ const& other) noexcept = default; + ~ECCVMComposerHelper_() = default; + + std::shared_ptr compute_proving_key(CircuitConstructor& circuit_constructor); + std::shared_ptr compute_verification_key(CircuitConstructor& circuit_constructor); + + void compute_witness(CircuitConstructor& circuit_constructor); + + ECCVMProver_ create_prover(CircuitConstructor& circuit_constructor); + ECCVMVerifier_ create_verifier(CircuitConstructor& circuit_constructor); + + void add_table_column_selector_poly_to_proving_key(polynomial& small, const std::string& tag); + + void compute_commitment_key(size_t circuit_size) + { + commitment_key = std::make_shared(circuit_size, crs_factory_); + }; +}; +extern template class ECCVMComposerHelper_; +// TODO(#532): this pattern is weird; is this not instantiating the templates? +using ECCVMComposerHelper = ECCVMComposerHelper_; +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/composer/eccvm_composer.test.cpp b/cpp/src/barretenberg/honk/composer/eccvm_composer.test.cpp new file mode 100644 index 0000000000..0682e7fa74 --- /dev/null +++ b/cpp/src/barretenberg/honk/composer/eccvm_composer.test.cpp @@ -0,0 +1,93 @@ +#include +#include +#include +#include + +#include "barretenberg/honk/composer/eccvm_composer.hpp" +#include "barretenberg/honk/proof_system/prover.hpp" +#include "barretenberg/honk/sumcheck/relations/permutation_relation.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" +#include "barretenberg/honk/sumcheck/sumcheck_round.hpp" +#include "barretenberg/honk/utils/grand_product_delta.hpp" +#include "barretenberg/numeric/uint256/uint256.hpp" +#include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp" + +using namespace proof_system::honk; + +namespace test_standard_honk_composer { + +class ECCVMComposerTests : public ::testing::Test { + protected: + static void SetUpTestSuite() { barretenberg::srs::init_crs_factory("../srs_db/ignition"); } +}; +namespace { +auto& engine = numeric::random::get_debug_engine(); +} +proof_system::ECCVMCircuitConstructor generate_trace(numeric::random::Engine* engine = nullptr) +{ + proof_system::ECCVMCircuitConstructor result; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::g1::element b = grumpkin::get_generator(1); + grumpkin::g1::element c = grumpkin::get_generator(2); + grumpkin::fr x = grumpkin::fr::random_element(engine); + grumpkin::fr y = grumpkin::fr::random_element(engine); + + grumpkin::g1::element expected_1 = (a * x) + a + a + (b * y) + (b * x) + (b * x); + grumpkin::g1::element expected_2 = (a * x) + c + (b * x); + + result.add_accumulate(a); + result.mul_accumulate(a, x); + result.mul_accumulate(b, x); + result.mul_accumulate(b, y); + result.add_accumulate(a); + result.mul_accumulate(b, x); + result.eq(expected_1); + result.add_accumulate(c); + result.mul_accumulate(a, x); + result.mul_accumulate(b, x); + result.eq(expected_2); + result.mul_accumulate(a, x); + result.mul_accumulate(b, x); + result.mul_accumulate(c, x); + + return result; +} + +TEST_F(ECCVMComposerTests, BaseCase) +{ + auto circuit_constructor = generate_trace(&engine); + + auto composer = ECCVMComposerHelper(); + auto prover = composer.create_prover(circuit_constructor); + + // / size_t pidx = 0; + // for (auto& p : prover.prover_polynomials) { + // size_t count = 0; + // for (auto& x : p) { + // std::cout << "poly[" << pidx << "][" << count << "] = " << x << std::endl; + // count++; + // } + // pidx++; + // } + auto proof = prover.construct_proof(); + auto verifier = composer.create_verifier(circuit_constructor); + bool verified = verifier.verify_proof(proof); + ASSERT_TRUE(verified); +} + +TEST_F(ECCVMComposerTests, EqFails) +{ + auto circuit_constructor = generate_trace(&engine); + // create an eq opcode that is not satisfied + circuit_constructor.eq(grumpkin::g1::affine_one); + auto composer = ECCVMComposerHelper(); + auto prover = composer.create_prover(circuit_constructor); + + auto proof = prover.construct_proof(); + auto verifier = composer.create_verifier(circuit_constructor); + bool verified = verifier.verify_proof(proof); + ASSERT_FALSE(verified); +} +} // namespace test_standard_honk_composer diff --git a/cpp/src/barretenberg/honk/flavor/ecc_vm.hpp b/cpp/src/barretenberg/honk/flavor/ecc_vm.hpp new file mode 100644 index 0000000000..b9850aba73 --- /dev/null +++ b/cpp/src/barretenberg/honk/flavor/ecc_vm.hpp @@ -0,0 +1,857 @@ +#pragma once +#include "../sumcheck/relations/relation_definitions_fwd.hpp" +#include "../sumcheck/relations/relation_types.hpp" +#include "barretenberg/honk/pcs/commitment_key.hpp" +#include "barretenberg/honk/pcs/ipa/ipa.hpp" +#include "barretenberg/honk/pcs/kzg/kzg.hpp" +#include "barretenberg/honk/sumcheck/polynomials/univariate.hpp" +#include "barretenberg/proof_system/flavor/flavor.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_msm_relation.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_point_table_relation.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_set_relation.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_transcript_relation.hpp" +#include "barretenberg/proof_system/relations/ecc_vm/ecc_wnaf_relation.hpp" +#include +#include +#include +#include +#include +#include + +// NOLINTBEGIN(cppcoreguidelines-avoid-const-or-ref-data-members) + +namespace proof_system::honk { +namespace flavor { + +template typename PCS_T> +class ECCVMBase { + public: + using CycleGroup = CycleGroup_T; + // forward template params into the ECCVMBase namespace + using G1 = G1_T; + using PCSParams = PCSParams_T; + using PCS = PCS_T; + + using FF = typename G1::subgroup_field; + using Polynomial = barretenberg::Polynomial; + using PolynomialHandle = std::span; + using GroupElement = typename G1::element; + using Commitment = typename G1::affine_element; + using CommitmentHandle = typename G1::affine_element; + + static constexpr size_t NUM_WIRES = 74; + + // The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often + // need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`. + // Note: this number does not include the individual sorted list polynomials. + static constexpr size_t NUM_ALL_ENTITIES = 105; + // The number of polynomials precomputed to describe a circuit and to aid a prover in constructing a satisfying + // assignment of witnesses. We again choose a neutral name. + static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 3; + // The total number of witness entities not including shifts. + static constexpr size_t NUM_WITNESS_ENTITIES = 76; + + using GrandProductRelations = std::tuple>; + // define the tuple of Relations that comprise the Sumcheck relation + using Relations = std::tuple, + sumcheck::ECCVMPointTableRelation, + sumcheck::ECCVMWnafRelation, + sumcheck::ECCVMMSMRelation, + sumcheck::ECCVMSetRelation, + sumcheck::ECCVMLookupRelation>; + + using LookupRelation = sumcheck::ECCVMLookupRelation; + static constexpr size_t MAX_RELATION_LENGTH = get_max_relation_length(); + + // MAX_RANDOM_RELATION_LENGTH = algebraic degree of sumcheck relation *after* multiplying by the `pow_zeta` random + // polynomial e.g. For \sum(x) [A(x) * B(x) + C(x)] * PowZeta(X), relation length = 2 and random relation length = 3 + static constexpr size_t MAX_RANDOM_RELATION_LENGTH = MAX_RELATION_LENGTH + 1; + static constexpr size_t NUM_RELATIONS = std::tuple_size::value; + + // Instantiate the BarycentricData needed to extend each Relation Univariate + // static_assert(instantiate_barycentric_utils()); + + // define the containers for storing the contributions from each relation in Sumcheck + using RelationUnivariates = decltype(create_relation_univariates_container()); + using RelationValues = decltype(create_relation_values_container()); + + private: + // class Counter { + // constexpr size_t foo() + // { + // return Thing<>; + // } + // }; + /** + * @brief A base class labelling precomputed entities and (ordered) subsets of interest. + * @details Used to build the proving key and verification key. + */ + template + class PrecomputedEntities : public PrecomputedEntities_ { + public: + DataType& lagrange_first = std::get<0>(this->_data); + DataType& lagrange_second = std::get<1>(this->_data); + DataType& lagrange_last = std::get<2>(this->_data); + + std::vector get_selectors() override { return { lagrange_first, lagrange_second, lagrange_last }; }; + std::vector get_sigma_polynomials() override { return {}; }; + std::vector get_id_polynomials() override { return {}; }; + std::vector get_table_polynomials() { return {}; }; + }; + + /** + * @brief Container for all witness polynomials used/constructed by the prover. + * @details Shifts are not included here since they do not occupy their own memory. + */ + template + class WitnessEntities : public WitnessEntities_ { + public: + // clang-format off + DataType& q_transcript_add = std::get<0>(this->_data); + DataType& q_transcript_mul = std::get<1>(this->_data); + DataType& q_transcript_eq = std::get<2>(this->_data); + DataType& q_transcript_accumulate = std::get<3>(this->_data); + DataType& q_transcript_msm_transition = std::get<4>(this->_data); + DataType& transcript_pc = std::get<5>(this->_data); + DataType& transcript_msm_count = std::get<6>(this->_data); + DataType& transcript_x = std::get<7>(this->_data); + DataType& transcript_y = std::get<8>(this->_data); + DataType& transcript_z1 = std::get<9>(this->_data); + DataType& transcript_z2 = std::get<10>(this->_data); + DataType& transcript_z1zero = std::get<11>(this->_data); + DataType& transcript_z2zero = std::get<12>(this->_data); + DataType& transcript_op = std::get<13>(this->_data); + DataType& transcript_accumulator_x = std::get<14>(this->_data); + DataType& transcript_accumulator_y = std::get<15>(this->_data); + DataType& transcript_msm_x = std::get<16>(this->_data); + DataType& transcript_msm_y = std::get<17>(this->_data); + DataType& table_pc = std::get<18>(this->_data); + DataType& table_point_transition = std::get<19>(this->_data); + DataType& table_round = std::get<20>(this->_data); + DataType& table_scalar_sum = std::get<21>(this->_data); + DataType& table_s1 = std::get<22>(this->_data); + DataType& table_s2 = std::get<23>(this->_data); + DataType& table_s3 = std::get<24>(this->_data); + DataType& table_s4 = std::get<25>(this->_data); + DataType& table_s5 = std::get<26>(this->_data); + DataType& table_s6 = std::get<27>(this->_data); + DataType& table_s7 = std::get<28>(this->_data); + DataType& table_s8 = std::get<29>(this->_data); + DataType& table_skew = std::get<30>(this->_data); + DataType& table_dx = std::get<31>(this->_data); + DataType& table_dy = std::get<32>(this->_data); + DataType& table_tx = std::get<33>(this->_data); + DataType& table_ty = std::get<34>(this->_data); + DataType& q_msm_transition = std::get<35>(this->_data); + DataType& msm_q_add = std::get<36>(this->_data); + DataType& msm_q_double = std::get<37>(this->_data); + DataType& msm_q_skew = std::get<38>(this->_data); + DataType& msm_accumulator_x = std::get<39>(this->_data); + DataType& msm_accumulator_y = std::get<40>(this->_data); + DataType& msm_pc = std::get<41>(this->_data); + DataType& msm_size_of_msm = std::get<42>(this->_data); + DataType& msm_count = std::get<43>(this->_data); + DataType& msm_round = std::get<44>(this->_data); + DataType& msm_q_add1 = std::get<45>(this->_data); + DataType& msm_q_add2 = std::get<46>(this->_data); + DataType& msm_q_add3 = std::get<47>(this->_data); + DataType& msm_q_add4 = std::get<48>(this->_data); + DataType& msm_x1 = std::get<49>(this->_data); + DataType& msm_y1 = std::get<50>(this->_data); + DataType& msm_x2 = std::get<51>(this->_data); + DataType& msm_y2 = std::get<52>(this->_data); + DataType& msm_x3 = std::get<53>(this->_data); + DataType& msm_y3 = std::get<54>(this->_data); + DataType& msm_x4 = std::get<55>(this->_data); + DataType& msm_y4 = std::get<56>(this->_data); + DataType& msm_collision_x1 = std::get<57>(this->_data); + DataType& msm_collision_x2 = std::get<58>(this->_data); + DataType& msm_collision_x3 = std::get<59>(this->_data); + DataType& msm_collision_x4 = std::get<60>(this->_data); + DataType& msm_lambda1 = std::get<61>(this->_data); + DataType& msm_lambda2 = std::get<62>(this->_data); + DataType& msm_lambda3 = std::get<63>(this->_data); + DataType& msm_lambda4 = std::get<64>(this->_data); + DataType& msm_slice1 = std::get<65>(this->_data); + DataType& msm_slice2 = std::get<66>(this->_data); + DataType& msm_slice3 = std::get<67>(this->_data); + DataType& msm_slice4 = std::get<68>(this->_data); + DataType& transcript_accumulator_empty = std::get<69>(this->_data); + DataType& transcript_q_reset_accumulator = std::get<70>(this->_data); + DataType& q_wnaf = std::get<71>(this->_data); + DataType& lookup_read_counts_0 = std::get<72>(this->_data); + DataType& lookup_read_counts_1 = std::get<73>(this->_data); + DataType& z_perm = std::get<74>(this->_data); + DataType& lookup_inverses = std::get<75>(this->_data); + + // clang-format on + std::vector get_wires() override + { + return { + q_transcript_add, + q_transcript_mul, + q_transcript_eq, + q_transcript_accumulate, + q_transcript_msm_transition, + transcript_pc, + transcript_msm_count, + transcript_x, + transcript_y, + transcript_z1, + transcript_z2, + transcript_z1zero, + transcript_z2zero, + transcript_op, + transcript_accumulator_x, + transcript_accumulator_y, + transcript_msm_x, + transcript_msm_y, + table_pc, + table_point_transition, + table_round, + table_scalar_sum, + table_s1, + table_s2, + table_s3, + table_s4, + table_s5, + table_s6, + table_s7, + table_s8, + table_skew, + table_dx, + table_dy, + table_tx, + table_ty, + q_msm_transition, + msm_q_add, + msm_q_double, + msm_q_skew, + msm_accumulator_x, + msm_accumulator_y, + msm_pc, + msm_size_of_msm, + msm_count, + msm_round, + msm_q_add1, + msm_q_add2, + msm_q_add3, + msm_q_add4, + msm_x1, + msm_y1, + msm_x2, + msm_y2, + msm_x3, + msm_y3, + msm_x4, + msm_y4, + msm_collision_x1, + msm_collision_x2, + msm_collision_x3, + msm_collision_x4, + msm_lambda1, + msm_lambda2, + msm_lambda3, + msm_lambda4, + msm_slice1, + msm_slice2, + msm_slice3, + msm_slice4, + transcript_accumulator_empty, + transcript_q_reset_accumulator, + q_wnaf, + lookup_read_counts_0, + lookup_read_counts_1, + }; + }; + // The sorted concatenations of table and witness data needed for plookup. + std::vector get_sorted_polynomials() { return {}; }; + }; + + /** + * @brief A base class labelling all entities (for instance, all of the polynomials used by the prover during + * sumcheck) in this Honk variant along with particular subsets of interest + * @details Used to build containers for: the prover's polynomial during sumcheck; the sumcheck's folded + * polynomials; the univariates consturcted during during sumcheck; the evaluations produced by sumcheck. + * + * Symbolically we have: AllEntities = PrecomputedEntities + WitnessEntities + "ShiftedEntities". It could be + * implemented as such, but we have this now. + */ + // SUEHRGFPIEAUHFPAWEIUFHEAWP9UFH NEED TO MAKE SURE POINTS ARE NOT POINTS AT INFINITY + // I3EUBFPEWUBEWOPFUHEWPIFUHEPWFUHEQWOIFUHEOLRFHEQPFUHEQPFUH I.E. ALL ARE NONZERO EPFHUEPFHUEPGRFGHBEWOFIEHUPOFUHRNF + template + class AllEntities : public AllEntities_ { + public: + // clang-format off + DataType& lagrange_first = std::get<0>(this->_data); + DataType& lagrange_second = std::get<1>(this->_data); + DataType& lagrange_last = std::get<2>(this->_data); + DataType& q_transcript_add = std::get<3 + 0>(this->_data); + DataType& q_transcript_mul = std::get<3 + 1>(this->_data); + DataType& q_transcript_eq = std::get<3 + 2>(this->_data); + DataType& q_transcript_accumulate = std::get<3 + 3>(this->_data); + DataType& q_transcript_msm_transition = std::get<3 + 4>(this->_data); + DataType& transcript_pc = std::get<3 + 5>(this->_data); + DataType& transcript_msm_count = std::get<3 + 6>(this->_data); + DataType& transcript_x = std::get<3 + 7>(this->_data); + DataType& transcript_y = std::get<3 + 8>(this->_data); + DataType& transcript_z1 = std::get<3 + 9>(this->_data); + DataType& transcript_z2 = std::get<3 + 10>(this->_data); + DataType& transcript_z1zero = std::get<3 + 11>(this->_data); + DataType& transcript_z2zero = std::get<3 + 12>(this->_data); + DataType& transcript_op = std::get<3 + 13>(this->_data); + DataType& transcript_accumulator_x = std::get<3 + 14>(this->_data); + DataType& transcript_accumulator_y = std::get<3 + 15>(this->_data); + DataType& transcript_msm_x = std::get<3 + 16>(this->_data); + DataType& transcript_msm_y = std::get<3 + 17>(this->_data); + DataType& table_pc = std::get<3 + 18>(this->_data); + DataType& table_point_transition = std::get<3 + 19>(this->_data); + DataType& table_round = std::get<3 + 20>(this->_data); + DataType& table_scalar_sum = std::get<3 + 21>(this->_data); + DataType& table_s1 = std::get<3 + 22>(this->_data); + DataType& table_s2 = std::get<3 + 23>(this->_data); + DataType& table_s3 = std::get<3 + 24>(this->_data); + DataType& table_s4 = std::get<3 + 25>(this->_data); + DataType& table_s5 = std::get<3 + 26>(this->_data); + DataType& table_s6 = std::get<3 + 27>(this->_data); + DataType& table_s7 = std::get<3 + 28>(this->_data); + DataType& table_s8 = std::get<3 + 29>(this->_data); + DataType& table_skew = std::get<3 + 30>(this->_data); + DataType& table_dx = std::get<3 + 31>(this->_data); + DataType& table_dy = std::get<3 + 32>(this->_data); + DataType& table_tx = std::get<3 + 33>(this->_data); + DataType& table_ty = std::get<3 + 34>(this->_data); + DataType& q_msm_transition = std::get<3 + 35>(this->_data); + DataType& msm_q_add = std::get<3 + 36>(this->_data); + DataType& msm_q_double = std::get<3 + 37>(this->_data); + DataType& msm_q_skew = std::get<3 + 38>(this->_data); + DataType& msm_accumulator_x = std::get<3 + 39>(this->_data); + DataType& msm_accumulator_y = std::get<3 + 40>(this->_data); + DataType& msm_pc = std::get<3 + 41>(this->_data); + DataType& msm_size_of_msm = std::get<3 + 42>(this->_data); + DataType& msm_count = std::get<3 + 43>(this->_data); + DataType& msm_round = std::get<3 + 44>(this->_data); + DataType& msm_q_add1 = std::get<3 + 45>(this->_data); + DataType& msm_q_add2 = std::get<3 + 46>(this->_data); + DataType& msm_q_add3 = std::get<3 + 47>(this->_data); + DataType& msm_q_add4 = std::get<3 + 48>(this->_data); + DataType& msm_x1 = std::get<3 + 49>(this->_data); + DataType& msm_y1 = std::get<3 + 50>(this->_data); + DataType& msm_x2 = std::get<3 + 51>(this->_data); + DataType& msm_y2 = std::get<3 + 52>(this->_data); + DataType& msm_x3 = std::get<3 + 53>(this->_data); + DataType& msm_y3 = std::get<3 + 54>(this->_data); + DataType& msm_x4 = std::get<3 + 55>(this->_data); + DataType& msm_y4 = std::get<3 + 56>(this->_data); + DataType& msm_collision_x1 = std::get<3 + 57>(this->_data); + DataType& msm_collision_x2 = std::get<3 + 58>(this->_data); + DataType& msm_collision_x3 = std::get<3 + 59>(this->_data); + DataType& msm_collision_x4 = std::get<3 + 60>(this->_data); + DataType& msm_lambda1 = std::get<3 + 61>(this->_data); + DataType& msm_lambda2 = std::get<3 + 62>(this->_data); + DataType& msm_lambda3 = std::get<3 + 63>(this->_data); + DataType& msm_lambda4 = std::get<3 + 64>(this->_data); + DataType& msm_slice1 = std::get<3 + 65>(this->_data); + DataType& msm_slice2 = std::get<3 + 66>(this->_data); + DataType& msm_slice3 = std::get<3 + 67>(this->_data); + DataType& msm_slice4 = std::get<3 + 68>(this->_data); + DataType& transcript_accumulator_empty = std::get<3 + 69>(this->_data); + DataType& transcript_q_reset_accumulator = std::get<3 + 70>(this->_data); + DataType& q_wnaf = std::get<3 + 71>(this->_data); + DataType& lookup_read_counts_0 = std::get<3 + 72>(this->_data); + DataType& lookup_read_counts_1 = std::get<3 + 73>(this->_data); + DataType& z_perm = std::get<3 + 74>(this->_data); + DataType& lookup_inverses = std::get<3 + 75>(this->_data); + DataType& q_transcript_mul_shift = std::get<3 + 76>(this->_data); + DataType& q_transcript_accumulate_shift = std::get<3 + 77>(this->_data); + DataType& transcript_msm_count_shift = std::get<3 + 78>(this->_data); + DataType& transcript_accumulator_x_shift = std::get<3 + 79>(this->_data); + DataType& transcript_accumulator_y_shift = std::get<3 + 80>(this->_data); + DataType& table_scalar_sum_shift = std::get<3 + 81>(this->_data); + DataType& table_dx_shift = std::get<3 + 82>(this->_data); + DataType& table_dy_shift = std::get<3 + 83>(this->_data); + DataType& table_tx_shift = std::get<3 + 84>(this->_data); + DataType& table_ty_shift = std::get<3 + 85>(this->_data); + DataType& q_msm_transition_shift = std::get<3 + 86>(this->_data); + DataType& msm_q_add_shift = std::get<3 + 87>(this->_data); + DataType& msm_q_double_shift = std::get<3 + 88>(this->_data); + DataType& msm_q_skew_shift = std::get<3 + 89>(this->_data); + DataType& msm_accumulator_x_shift = std::get<3 + 90>(this->_data); + DataType& msm_accumulator_y_shift = std::get<3 + 91>(this->_data); + DataType& msm_count_shift = std::get<3 + 92>(this->_data); + DataType& msm_round_shift = std::get<3 + 93>(this->_data); + DataType& msm_q_add1_shift = std::get<3 + 94>(this->_data); + DataType& msm_pc_shift = std::get<3 + 95>(this->_data); + DataType& table_pc_shift = std::get<3 + 96>(this->_data); + DataType& transcript_pc_shift = std::get<3 + 97>(this->_data); + DataType& table_round_shift = std::get<3 + 98>(this->_data); + DataType& transcript_accumulator_empty_shift= std::get<3 + 99>(this->_data); + DataType& q_wnaf_shift = std::get<3 + 100>(this->_data); + DataType& z_perm_shift = std::get<3 + 101>(this->_data); + + template + [[nodiscard]] const DataType& lookup_read_counts() const + { + static_assert(index == 0 || index == 1); + return std::get<75 + index>(this->_data); + } + // clang-format on + + std::vector get_wires() override + { + return { + q_transcript_add, + q_transcript_mul, + q_transcript_eq, + q_transcript_accumulate, + q_transcript_msm_transition, + transcript_pc, + transcript_msm_count, + transcript_x, + transcript_y, + transcript_z1, + transcript_z2, + transcript_z1zero, + transcript_z2zero, + transcript_op, + transcript_accumulator_x, + transcript_accumulator_y, + transcript_msm_x, + transcript_msm_y, + table_pc, + table_point_transition, + table_round, + table_scalar_sum, + table_s1, + table_s2, + table_s3, + table_s4, + table_s5, + table_s6, + table_s7, + table_s8, + table_skew, + table_dx, + table_dy, + table_tx, + table_ty, + q_msm_transition, + msm_q_add, + msm_q_double, + msm_q_skew, + msm_accumulator_x, + msm_accumulator_y, + msm_pc, + msm_size_of_msm, + msm_count, + msm_round, + msm_q_add1, + msm_q_add2, + msm_q_add3, + msm_q_add4, + msm_x1, + msm_y1, + msm_x2, + msm_y2, + msm_x3, + msm_y3, + msm_x4, + msm_y4, + msm_collision_x1, + msm_collision_x2, + msm_collision_x3, + msm_collision_x4, + msm_lambda1, + msm_lambda2, + msm_lambda3, + msm_lambda4, + msm_slice1, + msm_slice2, + msm_slice3, + msm_slice4, + transcript_accumulator_empty, + transcript_q_reset_accumulator, + q_wnaf, + lookup_read_counts_0, + lookup_read_counts_1, + }; + }; + // Gemini-specific getters. + std::vector get_unshifted() override + { + return { + lagrange_first, + lagrange_second, + lagrange_last, + q_transcript_add, + q_transcript_eq, + q_transcript_msm_transition, + transcript_x, + transcript_y, + transcript_z1, + transcript_z2, + transcript_z1zero, + transcript_z2zero, + transcript_op, + transcript_msm_x, + transcript_msm_y, + table_point_transition, + table_s1, + table_s2, + table_s3, + table_s4, + table_s5, + table_s6, + table_s7, + table_s8, + table_skew, + msm_size_of_msm, + msm_q_add2, + msm_q_add3, + msm_q_add4, + msm_x1, + msm_y1, + msm_x2, + msm_y2, + msm_x3, + msm_y3, + msm_x4, + msm_y4, + msm_collision_x1, + msm_collision_x2, + msm_collision_x3, + msm_collision_x4, + msm_lambda1, + msm_lambda2, + msm_lambda3, + msm_lambda4, + msm_slice1, + msm_slice2, + msm_slice3, + msm_slice4, + transcript_q_reset_accumulator, + lookup_read_counts_0, + lookup_read_counts_1, + lookup_inverses, + }; + }; + + std::vector get_to_be_shifted() override + { + return { + q_transcript_mul, + q_transcript_accumulate, // NOT USED + transcript_msm_count, + transcript_accumulator_x, + transcript_accumulator_y, + table_scalar_sum, + table_dx, + table_dy, + table_tx, + table_ty, + q_msm_transition, + msm_q_add, + msm_q_double, + msm_q_skew, + msm_accumulator_x, + msm_accumulator_y, + msm_count, + msm_round, + msm_q_add1, + msm_pc, + table_pc, + transcript_pc, + table_round, + transcript_accumulator_empty, + q_wnaf, + z_perm, + }; + }; + std::vector get_shifted() override + { + return { + q_transcript_mul_shift, + q_transcript_accumulate_shift, + transcript_msm_count_shift, + transcript_accumulator_x_shift, + transcript_accumulator_y_shift, + table_scalar_sum_shift, + table_dx_shift, + table_dy_shift, + table_tx_shift, + table_ty_shift, + q_msm_transition_shift, + msm_q_add_shift, + msm_q_double_shift, + msm_q_skew_shift, + msm_accumulator_x_shift, + msm_accumulator_y_shift, + msm_count_shift, + msm_round_shift, + msm_q_add1_shift, + msm_pc_shift, + table_pc_shift, + transcript_pc_shift, + table_round_shift, + transcript_accumulator_empty_shift, + q_wnaf_shift, + z_perm_shift, + }; + }; + + AllEntities() = default; + + AllEntities(const AllEntities& other) + : AllEntities_(other){}; + + AllEntities(AllEntities&& other) noexcept + : AllEntities_(other){}; + + AllEntities& operator=(const AllEntities& other) + { + if (this == &other) { + return *this; + } + AllEntities_::operator=(other); + return *this; + } + + AllEntities& operator=(AllEntities&& other) noexcept + { + AllEntities_::operator=(other); + return *this; + } + + ~AllEntities() override = default; + }; + + public: + /** + * @brief The proving key is responsible for storing the polynomials used by the prover. + * @note TODO(Cody): Maybe multiple inheritance is the right thing here. In that case, nothing should eve inherit + * from ProvingKey. + */ + class ProvingKey : public ProvingKey_, + WitnessEntities> { + public: + // Expose constructors on the base class + using Base = ProvingKey_, + WitnessEntities>; + using Base::Base; + + // The plookup wires that store plookup read data. + std::array get_table_column_wires() { return {}; }; + }; + + /** + * @brief The verification key is responsible for storing the the commitments to the precomputed (non-witnessk) + * polynomials used by the verifier. + * + * @note Note the discrepancy with what sort of data is stored here vs in the proving key. We may want to resolve + * that, and split out separate PrecomputedPolynomials/Commitments data for clarity but also for portability of our + * circuits. + */ + using VerificationKey = VerificationKey_>; + + /** + * @brief A container for polynomials handles; only stores spans. + */ + using ProverPolynomials = AllEntities; + + /** + * @brief A container for polynomials produced after the first round of sumcheck. + * @todo TODO(#394) Use polynomial classes for guaranteed memory alignment. + */ + using FoldedPolynomials = AllEntities, PolynomialHandle>; + + using RawPolynomials = AllEntities; + + /** + * @brief A container for polynomials produced after the first round of sumcheck. + * @todo TODO(#394) Use polynomial classes for guaranteed memory alignment. + */ + using RowPolynomials = AllEntities; + + /** + * @brief A container for storing the partially evaluated multivariates produced by sumcheck. + */ + class PartiallyEvaluatedMultivariates : public AllEntities { + + public: + PartiallyEvaluatedMultivariates() = default; + PartiallyEvaluatedMultivariates(const size_t circuit_size) + { + // Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2) + for (auto& poly : this->_data) { + poly = Polynomial(circuit_size / 2); + } + } + }; + + /** + * @brief A container for univariates produced during the hot loop in sumcheck. + * @todo TODO(#390): Simplify this by moving MAX_RELATION_LENGTH? + */ + template + using ExtendedEdges = + AllEntities, sumcheck::Univariate>; + + /** + * @brief A container for the polynomials evaluations produced during sumcheck, which are purported to be the + * evaluations of polynomials committed in earlier rounds. + */ + class ClaimedEvaluations : public AllEntities { + public: + using Base = AllEntities; + using Base::Base; + ClaimedEvaluations(std::array _data_in) { this->_data = _data_in; } + }; + + /** + * @brief A container for commitment labels. + * @note It's debatable whether this should inherit from AllEntities. since most entries are not strictly needed. It + * has, however, been useful during debugging to have these labels available. + * + */ + class CommitmentLabels : public AllEntities { + private: + using Base = AllEntities; + + public: + CommitmentLabels() + : AllEntities() + { + Base::q_transcript_add = "Q_TRANSCRIPT_ADD"; + Base::q_transcript_mul = "Q_TRANSCRIPT_MUL"; + Base::q_transcript_eq = "Q_TRANSCRIPT_EQ"; + Base::q_transcript_accumulate = "Q_TRANSCRIPT_ACCUMULATE"; + Base::q_transcript_msm_transition = "Q_TRANSCRIPT_MSM_TRANSITION"; + Base::transcript_pc = "TRANSCRIPT_PC"; + Base::transcript_msm_count = "TRANSCRIPT_MSM_COUNT"; + Base::transcript_x = "TRANSCRIPT_X"; + Base::transcript_y = "TRANSCRIPT_Y"; + Base::transcript_z1 = "TRANSCRIPT_Z1"; + Base::transcript_z2 = "TRANSCRIPT_Z2"; + Base::transcript_z1zero = "TRANSCRIPT_Z1ZERO"; + Base::transcript_z2zero = "TRANSCRIPT_Z2ZERO"; + Base::transcript_op = "TRANSCRIPT_OP"; + Base::transcript_accumulator_x = "TRANSCRIPT_ACCUMULATOR_X"; + Base::transcript_accumulator_y = "TRANSCRIPT_ACCUMULATOR_Y"; + Base::transcript_msm_x = "TRANSCRIPT_MSM_X"; + Base::transcript_msm_y = "TRANSCRIPT_MSM_Y"; + Base::table_pc = "TABLE_PC"; + Base::table_point_transition = "TABLE_POINT_TRANSITION"; + Base::table_round = "TABLE_ROUND"; + Base::table_scalar_sum = "TABLE_SCALAR_SUM"; + Base::table_s1 = "TABLE_S1"; + Base::table_s2 = "TABLE_S2"; + Base::table_s3 = "TABLE_S3"; + Base::table_s4 = "TABLE_S4"; + Base::table_s5 = "TABLE_S5"; + Base::table_s6 = "TABLE_S6"; + Base::table_s7 = "TABLE_S7"; + Base::table_s8 = "TABLE_S8"; + Base::table_skew = "TABLE_SKEW"; + Base::table_dx = "TABLE_DX"; + Base::table_dy = "TABLE_DY"; + Base::table_tx = "TABLE_TX"; + Base::table_ty = "TABLE_TY"; + Base::q_msm_transition = "Q_MSM_TRANSITION"; + Base::msm_q_add = "MSM_Q_ADD"; + Base::msm_q_double = "MSM_Q_DOUBLE"; + Base::msm_q_skew = "MSM_Q_SKEW"; + Base::msm_accumulator_x = "MSM_ACCUMULATOR_X"; + Base::msm_accumulator_y = "MSM_ACCUMULATOR_Y"; + Base::msm_pc = "MSM_PC"; + Base::msm_size_of_msm = "MSM_SIZE_OF_MSM"; + Base::msm_count = "MSM_COUNT"; + Base::msm_round = "MSM_ROUND"; + Base::msm_q_add1 = "MSM_Q_ADD1"; + Base::msm_q_add2 = "MSM_Q_ADD2"; + Base::msm_q_add3 = "MSM_Q_ADD3"; + Base::msm_q_add4 = "MSM_Q_ADD4"; + Base::msm_x1 = "MSM_X1"; + Base::msm_y1 = "MSM_Y1"; + Base::msm_x2 = "MSM_X2"; + Base::msm_y2 = "MSM_Y2"; + Base::msm_x3 = "MSM_X3"; + Base::msm_y3 = "MSM_Y3"; + Base::msm_x4 = "MSM_X4"; + Base::msm_y4 = "MSM_Y4"; + Base::msm_collision_x1 = "MSM_COLLISION_X1"; + Base::msm_collision_x2 = "MSM_COLLISION_X2"; + Base::msm_collision_x3 = "MSM_COLLISION_X3"; + Base::msm_collision_x4 = "MSM_COLLISION_X4"; + Base::msm_lambda1 = "MSM_LAMBDA1"; + Base::msm_lambda2 = "MSM_LAMBDA2"; + Base::msm_lambda3 = "MSM_LAMBDA3"; + Base::msm_lambda4 = "MSM_LAMBDA4"; + Base::msm_slice1 = "MSM_SLICE1"; + Base::msm_slice2 = "MSM_SLICE2"; + Base::msm_slice3 = "MSM_SLICE3"; + Base::msm_slice4 = "MSM_SLICE4"; + Base::transcript_accumulator_empty = "TRANSCRIPT_ACCUMULATOR_EMPTY"; + Base::transcript_q_reset_accumulator = "TRANSCRIPT_Q_RESET_ACCUMULATOR"; + Base::q_wnaf = "Q_WNAF"; + Base::lookup_read_counts_0 = "LOOKUP_READ_COUNTS_0"; + Base::lookup_read_counts_1 = "LOOKUP_READ_COUNTS_1"; + Base::z_perm = "Z_PERM"; + Base::lookup_inverses = "LOOKUP_INVERSES"; + // The ones beginning with "__" are only used for debugging + Base::lagrange_first = "__LAGRANGE_FIRST"; + Base::lagrange_second = "__LAGRANGE_SECOND"; + Base::lagrange_last = "__LAGRANGE_LAST"; + }; + }; + + class VerifierCommitments : public AllEntities { + private: + using Base = AllEntities; + + public: + VerifierCommitments(const std::shared_ptr& verification_key, + const VerifierTranscript& transcript) + { + static_cast(transcript); + Base::lagrange_first = verification_key->lagrange_first; + Base::lagrange_second = verification_key->lagrange_second; + Base::lagrange_last = verification_key->lagrange_last; + } + }; +}; + +class ECCVM : public ECCVMBase {}; +// not actually grumpkin, need to finish supporting grumpkin in ipa +class ECCVMGrumpkin : public ECCVMBase {}; + +// NOLINTEND(cppcoreguidelines-avoid-const-or-ref-data-members) + +} // namespace flavor +namespace sumcheck { + +extern template class ECCVMTranscriptRelationBase; +extern template class ECCVMWnafRelationBase; +extern template class ECCVMPointTableRelationBase; +extern template class ECCVMMSMRelationBase; +extern template class ECCVMSetRelationBase; +extern template class ECCVMLookupRelationBase; + +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMTranscriptRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMWnafRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMPointTableRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMMSMRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMSetRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMLookupRelationBase, flavor::ECCVM); + +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMTranscriptRelationBase, flavor::ECCVMGrumpkin); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMWnafRelationBase, flavor::ECCVMGrumpkin); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMPointTableRelationBase, flavor::ECCVMGrumpkin); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMMSMRelationBase, flavor::ECCVMGrumpkin); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMSetRelationBase, flavor::ECCVMGrumpkin); +DECLARE_SUMCHECK_RELATION_CLASS(ECCVMLookupRelationBase, flavor::ECCVMGrumpkin); + +DECLARE_SUMCHECK_PERMUTATION_CLASS(ECCVMSetRelationBase, flavor::ECCVM); +DECLARE_SUMCHECK_PERMUTATION_CLASS(ECCVMSetRelationBase, flavor::ECCVMGrumpkin); +} // namespace sumcheck +} // namespace proof_system::honk \ No newline at end of file diff --git a/cpp/src/barretenberg/honk/proof_system/eccvm_prover.cpp b/cpp/src/barretenberg/honk/proof_system/eccvm_prover.cpp new file mode 100644 index 0000000000..aac0a2b46b --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/eccvm_prover.cpp @@ -0,0 +1,385 @@ +#include "eccvm_prover.hpp" +#include "barretenberg/honk/pcs/claim.hpp" +#include "barretenberg/honk/pcs/commitment_key.hpp" +#include "barretenberg/honk/proof_system/lookup_library.hpp" +#include "barretenberg/honk/proof_system/permutation_library.hpp" +#include "barretenberg/honk/proof_system/prover_library.hpp" +#include "barretenberg/honk/sumcheck/polynomials/univariate.hpp" // will go away +#include "barretenberg/honk/sumcheck/relations/lookup_relation.hpp" +#include "barretenberg/honk/sumcheck/relations/permutation_relation.hpp" +#include "barretenberg/honk/sumcheck/sumcheck.hpp" +#include "barretenberg/honk/utils/power_polynomial.hpp" +#include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/transcript/transcript_wrappers.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proof_system::honk { + +/** + * Create ECCVMProver_ from proving key, witness and manifest. + * + * @param input_key Proving key. + * @param input_manifest Input manifest + * + * @tparam settings Settings class. + * */ +template +ECCVMProver_::ECCVMProver_(std::shared_ptr input_key, + std::shared_ptr commitment_key) + : key(input_key) + , queue(commitment_key, transcript) + , pcs_commitment_key(commitment_key) +{ + + // TODO(@zac-williamson) is there a cleaner way of doing this? + prover_polynomials.q_transcript_add = key->q_transcript_add; + prover_polynomials.q_transcript_mul = key->q_transcript_mul; + prover_polynomials.q_transcript_eq = key->q_transcript_eq; + prover_polynomials.q_transcript_accumulate = key->q_transcript_accumulate; + prover_polynomials.q_transcript_msm_transition = key->q_transcript_msm_transition; + prover_polynomials.transcript_pc = key->transcript_pc; + prover_polynomials.transcript_msm_count = key->transcript_msm_count; + prover_polynomials.transcript_x = key->transcript_x; + prover_polynomials.transcript_y = key->transcript_y; + prover_polynomials.transcript_z1 = key->transcript_z1; + prover_polynomials.transcript_z2 = key->transcript_z2; + prover_polynomials.transcript_z1zero = key->transcript_z1zero; + prover_polynomials.transcript_z2zero = key->transcript_z2zero; + prover_polynomials.transcript_op = key->transcript_op; + prover_polynomials.transcript_accumulator_x = key->transcript_accumulator_x; + prover_polynomials.transcript_accumulator_y = key->transcript_accumulator_y; + prover_polynomials.transcript_msm_x = key->transcript_msm_x; + prover_polynomials.transcript_msm_y = key->transcript_msm_y; + prover_polynomials.table_pc = key->table_pc; + prover_polynomials.table_point_transition = key->table_point_transition; + prover_polynomials.table_round = key->table_round; + prover_polynomials.table_scalar_sum = key->table_scalar_sum; + prover_polynomials.table_s1 = key->table_s1; + prover_polynomials.table_s2 = key->table_s2; + prover_polynomials.table_s3 = key->table_s3; + prover_polynomials.table_s4 = key->table_s4; + prover_polynomials.table_s5 = key->table_s5; + prover_polynomials.table_s6 = key->table_s6; + prover_polynomials.table_s7 = key->table_s7; + prover_polynomials.table_s8 = key->table_s8; + prover_polynomials.table_skew = key->table_skew; + prover_polynomials.table_dx = key->table_dx; + prover_polynomials.table_dy = key->table_dy; + prover_polynomials.table_tx = key->table_tx; + prover_polynomials.table_ty = key->table_ty; + prover_polynomials.q_msm_transition = key->q_msm_transition; + prover_polynomials.msm_q_add = key->msm_q_add; + prover_polynomials.msm_q_double = key->msm_q_double; + prover_polynomials.msm_q_skew = key->msm_q_skew; + prover_polynomials.msm_accumulator_x = key->msm_accumulator_x; + prover_polynomials.msm_accumulator_y = key->msm_accumulator_y; + prover_polynomials.msm_pc = key->msm_pc; + prover_polynomials.msm_size_of_msm = key->msm_size_of_msm; + prover_polynomials.msm_count = key->msm_count; + prover_polynomials.msm_round = key->msm_round; + prover_polynomials.msm_q_add1 = key->msm_q_add1; + prover_polynomials.msm_q_add2 = key->msm_q_add2; + prover_polynomials.msm_q_add3 = key->msm_q_add3; + prover_polynomials.msm_q_add4 = key->msm_q_add4; + prover_polynomials.msm_x1 = key->msm_x1; + prover_polynomials.msm_y1 = key->msm_y1; + prover_polynomials.msm_x2 = key->msm_x2; + prover_polynomials.msm_y2 = key->msm_y2; + prover_polynomials.msm_x3 = key->msm_x3; + prover_polynomials.msm_y3 = key->msm_y3; + prover_polynomials.msm_x4 = key->msm_x4; + prover_polynomials.msm_y4 = key->msm_y4; + prover_polynomials.msm_collision_x1 = key->msm_collision_x1; + prover_polynomials.msm_collision_x2 = key->msm_collision_x2; + prover_polynomials.msm_collision_x3 = key->msm_collision_x3; + prover_polynomials.msm_collision_x4 = key->msm_collision_x4; + prover_polynomials.msm_lambda1 = key->msm_lambda1; + prover_polynomials.msm_lambda2 = key->msm_lambda2; + prover_polynomials.msm_lambda3 = key->msm_lambda3; + prover_polynomials.msm_lambda4 = key->msm_lambda4; + prover_polynomials.msm_slice1 = key->msm_slice1; + prover_polynomials.msm_slice2 = key->msm_slice2; + prover_polynomials.msm_slice3 = key->msm_slice3; + prover_polynomials.msm_slice4 = key->msm_slice4; + prover_polynomials.transcript_accumulator_empty = key->transcript_accumulator_empty; + prover_polynomials.transcript_q_reset_accumulator = key->transcript_q_reset_accumulator; + prover_polynomials.q_wnaf = key->q_wnaf; + prover_polynomials.lookup_read_counts_0 = key->lookup_read_counts_0; + prover_polynomials.lookup_read_counts_1 = key->lookup_read_counts_1; + prover_polynomials.q_transcript_mul_shift = key->q_transcript_mul.shifted(); + prover_polynomials.q_transcript_accumulate_shift = key->q_transcript_accumulate.shifted(); + prover_polynomials.transcript_msm_count_shift = key->transcript_msm_count.shifted(); + prover_polynomials.transcript_accumulator_x_shift = key->transcript_accumulator_x.shifted(); + prover_polynomials.transcript_accumulator_y_shift = key->transcript_accumulator_y.shifted(); + prover_polynomials.table_scalar_sum_shift = key->table_scalar_sum.shifted(); + prover_polynomials.table_dx_shift = key->table_dx.shifted(); + prover_polynomials.table_dy_shift = key->table_dy.shifted(); + prover_polynomials.table_tx_shift = key->table_tx.shifted(); + prover_polynomials.table_ty_shift = key->table_ty.shifted(); + prover_polynomials.q_msm_transition_shift = key->q_msm_transition.shifted(); + prover_polynomials.msm_q_add_shift = key->msm_q_add.shifted(); + prover_polynomials.msm_q_double_shift = key->msm_q_double.shifted(); + prover_polynomials.msm_q_skew_shift = key->msm_q_skew.shifted(); + prover_polynomials.msm_accumulator_x_shift = key->msm_accumulator_x.shifted(); + prover_polynomials.msm_accumulator_y_shift = key->msm_accumulator_y.shifted(); + prover_polynomials.msm_count_shift = key->msm_count.shifted(); + prover_polynomials.msm_round_shift = key->msm_round.shifted(); + prover_polynomials.msm_q_add1_shift = key->msm_q_add1.shifted(); + prover_polynomials.msm_pc_shift = key->msm_pc.shifted(); + prover_polynomials.table_pc_shift = key->table_pc.shifted(); + prover_polynomials.transcript_pc_shift = key->transcript_pc.shifted(); + prover_polynomials.table_round_shift = key->table_round.shifted(); + prover_polynomials.transcript_accumulator_empty_shift = key->transcript_accumulator_empty.shifted(); + prover_polynomials.q_wnaf_shift = key->q_wnaf.shifted(); + prover_polynomials.lagrange_first = key->lagrange_first; + prover_polynomials.lagrange_second = key->lagrange_second; + prover_polynomials.lagrange_last = key->lagrange_last; + + prover_polynomials.lookup_inverses = key->lookup_inverses; + key->z_perm = Polynomial(key->circuit_size); + prover_polynomials.z_perm = key->z_perm; +} + +/** + * @brief Commit to the first three wires only + * + */ +template void ECCVMProver_::compute_wire_commitments() +{ + auto wire_polys = key->get_wires(); + auto labels = commitment_labels.get_wires(); + for (size_t idx = 0; idx < wire_polys.size(); ++idx) { + queue.add_commitment(wire_polys[idx], labels[idx]); + } +} + +/** + * @brief Add circuit size, public input size, and public inputs to transcript + * + */ +template void ECCVMProver_::execute_preamble_round() +{ + const auto circuit_size = static_cast(key->circuit_size); + + transcript.send_to_verifier("circuit_size", circuit_size); +} + +/** + * @brief Compute commitments to the first three wires + * + */ +template void ECCVMProver_::execute_wire_commitments_round() +{ + auto wire_polys = key->get_wires(); + auto labels = commitment_labels.get_wires(); + for (size_t idx = 0; idx < wire_polys.size(); ++idx) { + queue.add_commitment(wire_polys[idx], labels[idx]); + } +} + +/** + * @brief Compute sorted witness-table accumulator + * + */ +template void ECCVMProver_::execute_log_derivative_commitments_round() +{ + // Compute and add eta to relation parameters + auto [eta, gamma] = transcript.get_challenges("beta", "gamma"); + // TODO(#583)(@zac-williamson): fix Transcript to be able to generate more than 2 challenges per round! oof. + auto eta_sqr = eta * eta; + relation_parameters.gamma = gamma; + relation_parameters.eta = eta; + relation_parameters.eta_sqr = eta_sqr; + relation_parameters.eta_cube = eta_sqr * eta; + relation_parameters.permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + relation_parameters.permutation_offset = relation_parameters.permutation_offset.invert(); + // Compute inverse polynomial for our logarithmic-derivative lookup method + lookup_library::compute_logderivative_inverse( + prover_polynomials, relation_parameters, key->circuit_size); + queue.add_commitment(key->lookup_inverses, commitment_labels.lookup_inverses); + prover_polynomials.lookup_inverses = key->lookup_inverses; +} + +/** + * @brief Compute permutation and lookup grand product polynomials and commitments + * + */ +template void ECCVMProver_::execute_grand_product_computation_round() +{ + // Compute permutation grand product and their commitments + permutation_library::compute_permutation_grand_products(key, prover_polynomials, relation_parameters); + + queue.add_commitment(key->z_perm, commitment_labels.z_perm); +} + +/** + * @brief Run Sumcheck resulting in u = (u_1,...,u_d) challenges and all evaluations at u being calculated. + * + */ +template void ECCVMProver_::execute_relation_check_rounds() +{ + using Sumcheck = sumcheck::Sumcheck>; + + auto sumcheck = Sumcheck(key->circuit_size, transcript); + + sumcheck_output = sumcheck.execute_prover(prover_polynomials, relation_parameters); +} + +/** + * - Get rho challenge + * - Compute d+1 Fold polynomials and their evaluations. + * + * */ +template void ECCVMProver_::execute_univariatization_round() +{ + const size_t NUM_POLYNOMIALS = Flavor::NUM_ALL_ENTITIES; + + // Generate batching challenge ρ and powers 1,ρ,…,ρᵐ⁻¹ + FF rho = transcript.get_challenge("rho"); + std::vector rhos = Gemini::powers_of_rho(rho, NUM_POLYNOMIALS); + + // Batch the unshifted polynomials and the to-be-shifted polynomials using ρ + Polynomial batched_poly_unshifted(key->circuit_size); // batched unshifted polynomials + size_t poly_idx = 0; // TODO(#391) zip + for (auto& unshifted_poly : prover_polynomials.get_unshifted()) { + batched_poly_unshifted.add_scaled(unshifted_poly, rhos[poly_idx]); + ++poly_idx; + } + + Polynomial batched_poly_to_be_shifted(key->circuit_size); // batched to-be-shifted polynomials + for (auto& to_be_shifted_poly : prover_polynomials.get_to_be_shifted()) { + batched_poly_to_be_shifted.add_scaled(to_be_shifted_poly, rhos[poly_idx]); + ++poly_idx; + }; + + // Compute d-1 polynomials Fold^(i), i = 1, ..., d-1. + fold_polynomials = Gemini::compute_fold_polynomials( + sumcheck_output.challenge_point, std::move(batched_poly_unshifted), std::move(batched_poly_to_be_shifted)); + + // Compute and add to trasnscript the commitments [Fold^(i)], i = 1, ..., d-1 + for (size_t l = 0; l < key->log_circuit_size - 1; ++l) { + queue.add_commitment(fold_polynomials[l + 2], "Gemini:FOLD_" + std::to_string(l + 1)); + } +} + +/** + * - Do Fiat-Shamir to get "r" challenge + * - Compute remaining two partially evaluated Fold polynomials Fold_{r}^(0) and Fold_{-r}^(0). + * - Compute and aggregate opening pairs (challenge, evaluation) for each of d Fold polynomials. + * - Add d-many Fold evaluations a_i, i = 0, ..., d-1 to the transcript, excluding eval of Fold_{r}^(0) + * */ +template void ECCVMProver_::execute_pcs_evaluation_round() +{ + const FF r_challenge = transcript.get_challenge("Gemini:r"); + gemini_output = Gemini::compute_fold_polynomial_evaluations( + sumcheck_output.challenge_point, std::move(fold_polynomials), r_challenge); + + for (size_t l = 0; l < key->log_circuit_size; ++l) { + std::string label = "Gemini:a_" + std::to_string(l); + const auto& evaluation = gemini_output.opening_pairs[l + 1].evaluation; + transcript.send_to_verifier(label, evaluation); + } +} + +/** + * - Do Fiat-Shamir to get "nu" challenge. + * - Compute commitment [Q]_1 + * */ +template void ECCVMProver_::execute_shplonk_batched_quotient_round() +{ + nu_challenge = transcript.get_challenge("Shplonk:nu"); + + batched_quotient_Q = + Shplonk::compute_batched_quotient(gemini_output.opening_pairs, gemini_output.witnesses, nu_challenge); + + // commit to Q(X) and add [Q] to the transcript + queue.add_commitment(batched_quotient_Q, "Shplonk:Q"); +} + +/** + * - Do Fiat-Shamir to get "z" challenge. + * - Compute polynomial Q(X) - Q_z(X) + * */ +template void ECCVMProver_::execute_shplonk_partial_evaluation_round() +{ + const FF z_challenge = transcript.get_challenge("Shplonk:z"); + + shplonk_output = Shplonk::compute_partially_evaluated_batched_quotient( + gemini_output.opening_pairs, gemini_output.witnesses, std::move(batched_quotient_Q), nu_challenge, z_challenge); +} +/** + * - Compute final PCS opening proof: + * - For KZG, this is the quotient commitment [W]_1 + * - For IPA, the vectors L and R + * */ +template void ECCVMProver_::execute_final_pcs_round() +{ + PCS::compute_opening_proof(pcs_commitment_key, shplonk_output.opening_pair, shplonk_output.witness, transcript); + // queue.add_commitment(quotient_W, "KZG:W"); +} + +template plonk::proof& ECCVMProver_::export_proof() +{ + proof.proof_data = transcript.proof_data; + return proof; +} + +template plonk::proof& ECCVMProver_::construct_proof() +{ + // Add circuit size public input size and public inputs to transcript. + execute_preamble_round(); + + // Compute first three wire commitments + execute_wire_commitments_round(); + queue.process_queue(); + + // Compute sorted list accumulator and commitment + execute_log_derivative_commitments_round(); + queue.process_queue(); + + // Fiat-Shamir: beta & gamma + // Compute grand product(s) and commitments. + execute_grand_product_computation_round(); + queue.process_queue(); + + // Fiat-Shamir: alpha + // Run sumcheck subprotocol. + execute_relation_check_rounds(); + + // Fiat-Shamir: rho + // Compute Fold polynomials and their commitments. + execute_univariatization_round(); + queue.process_queue(); + + // Fiat-Shamir: r + // Compute Fold evaluations + execute_pcs_evaluation_round(); + + // Fiat-Shamir: nu + // Compute Shplonk batched quotient commitment Q + execute_shplonk_batched_quotient_round(); + queue.process_queue(); + + // Fiat-Shamir: z + // Compute partial evaluation Q_z + execute_shplonk_partial_evaluation_round(); + + // Fiat-Shamir: z + // Compute PCS opening proof (either KZG quotient commitment or IPA opening proof) + execute_final_pcs_round(); + + return export_proof(); +} + +template class ECCVMProver_; +template class ECCVMProver_; + +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/proof_system/eccvm_prover.hpp b/cpp/src/barretenberg/honk/proof_system/eccvm_prover.hpp new file mode 100644 index 0000000000..2756667791 --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/eccvm_prover.hpp @@ -0,0 +1,85 @@ +#pragma once +#include "barretenberg/honk/flavor/ecc_vm.hpp" +#include "barretenberg/honk/pcs/gemini/gemini.hpp" +#include "barretenberg/honk/pcs/shplonk/shplonk_single.hpp" +#include "barretenberg/honk/proof_system/work_queue.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" +#include "barretenberg/honk/sumcheck/sumcheck_output.hpp" +#include "barretenberg/honk/transcript/transcript.hpp" +#include "barretenberg/plonk/proof_system/types/proof.hpp" + +namespace proof_system::honk { + +// We won't compile this class with honk::flavor::Standard, but we will like want to compile it (at least for testing) +// with a flavor that uses the curve Grumpkin, or a flavor that does/does not have zk, etc. +template class ECCVMProver_ { + + using FF = typename Flavor::FF; + using PCSParams = typename Flavor::PCSParams; + using PCS = typename Flavor::PCS; + using PCSCommitmentKey = typename Flavor::PCSParams::CommitmentKey; + using ProvingKey = typename Flavor::ProvingKey; + using Polynomial = typename Flavor::Polynomial; + using ProverPolynomials = typename Flavor::ProverPolynomials; + using CommitmentLabels = typename Flavor::CommitmentLabels; + + public: + explicit ECCVMProver_(std::shared_ptr input_key, std::shared_ptr commitment_key); + + void execute_preamble_round(); + void execute_wire_commitments_round(); + void execute_log_derivative_commitments_round(); + void execute_grand_product_computation_round(); + void execute_relation_check_rounds(); + void execute_univariatization_round(); + void execute_pcs_evaluation_round(); + void execute_shplonk_batched_quotient_round(); + void execute_shplonk_partial_evaluation_round(); + void execute_final_pcs_round(); + + void compute_wire_commitments(); + + plonk::proof& export_proof(); + plonk::proof& construct_proof(); + + ProverTranscript transcript; + + std::vector public_inputs; + + sumcheck::RelationParameters relation_parameters; + + std::shared_ptr key; + + // Container for spans of all polynomials required by the prover (i.e. all multivariates evaluated by Sumcheck). + ProverPolynomials prover_polynomials; + + CommitmentLabels commitment_labels; + + // Container for d + 1 Fold polynomials produced by Gemini + std::vector fold_polynomials; + + Polynomial batched_quotient_Q; // batched quotient poly computed by Shplonk + FF nu_challenge; // needed in both Shplonk rounds + + Polynomial quotient_W; + + work_queue queue; + + sumcheck::SumcheckOutput sumcheck_output; + pcs::gemini::ProverOutput gemini_output; + pcs::shplonk::ProverOutput shplonk_output; + std::shared_ptr pcs_commitment_key; + + using Gemini = pcs::gemini::MultilinearReductionScheme; + using Shplonk = pcs::shplonk::SingleBatchOpeningScheme; + + private: + plonk::proof proof; +}; + +extern template class ECCVMProver_; +extern template class ECCVMProver_; + +using ECCVMProver = ECCVMProver_; + +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.cpp b/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.cpp new file mode 100644 index 0000000000..e9eb55dd69 --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.cpp @@ -0,0 +1,256 @@ +#include "./eccvm_verifier.hpp" +#include "barretenberg/honk/flavor/standard.hpp" +#include "barretenberg/honk/transcript/transcript.hpp" +#include "barretenberg/honk/utils/power_polynomial.hpp" +#include "barretenberg/numeric/bitop/get_msb.hpp" + +using namespace barretenberg; +using namespace proof_system::honk::sumcheck; + +namespace proof_system::honk { +template +ECCVMVerifier_::ECCVMVerifier_(std::shared_ptr verifier_key) + : key(verifier_key) +{} + +template +ECCVMVerifier_::ECCVMVerifier_(ECCVMVerifier_&& other) noexcept + : key(std::move(other.key)) + , pcs_verification_key(std::move(other.pcs_verification_key)) +{} + +template ECCVMVerifier_& ECCVMVerifier_::operator=(ECCVMVerifier_&& other) noexcept +{ + key = other.key; + pcs_verification_key = (std::move(other.pcs_verification_key)); + commitments.clear(); + pcs_fr_elements.clear(); + return *this; +} + +/** + * @brief This function verifies an ECCVM Honk proof for given program settings. + * + */ +template bool ECCVMVerifier_::verify_proof(const plonk::proof& proof) +{ + using FF = typename Flavor::FF; + using GroupElement = typename Flavor::GroupElement; + using Commitment = typename Flavor::Commitment; + using PCSParams = typename Flavor::PCSParams; + using PCS = typename Flavor::PCS; + using Gemini = pcs::gemini::MultilinearReductionScheme; + using Shplonk = pcs::shplonk::SingleBatchOpeningScheme; + using VerifierCommitments = typename Flavor::VerifierCommitments; + using CommitmentLabels = typename Flavor::CommitmentLabels; + + RelationParameters relation_parameters; + + transcript = VerifierTranscript{ proof.proof_data }; + + auto commitments = VerifierCommitments(key, transcript); + auto commitment_labels = CommitmentLabels(); + + // TODO(Adrian): Change the initialization of the transcript to take the VK hash? + const auto circuit_size = transcript.template receive_from_prover("circuit_size"); + + if (circuit_size != key->circuit_size) { + return false; + } + + // Get commitments to VM wires + commitments.q_transcript_add = + transcript.template receive_from_prover(commitment_labels.q_transcript_add); + commitments.q_transcript_mul = + transcript.template receive_from_prover(commitment_labels.q_transcript_mul); + commitments.q_transcript_eq = + transcript.template receive_from_prover(commitment_labels.q_transcript_eq); + commitments.q_transcript_accumulate = + transcript.template receive_from_prover(commitment_labels.q_transcript_accumulate); + commitments.q_transcript_msm_transition = + transcript.template receive_from_prover(commitment_labels.q_transcript_msm_transition); + commitments.transcript_pc = transcript.template receive_from_prover(commitment_labels.transcript_pc); + commitments.transcript_msm_count = + transcript.template receive_from_prover(commitment_labels.transcript_msm_count); + commitments.transcript_x = transcript.template receive_from_prover(commitment_labels.transcript_x); + commitments.transcript_y = transcript.template receive_from_prover(commitment_labels.transcript_y); + commitments.transcript_z1 = transcript.template receive_from_prover(commitment_labels.transcript_z1); + commitments.transcript_z2 = transcript.template receive_from_prover(commitment_labels.transcript_z2); + commitments.transcript_z1zero = + transcript.template receive_from_prover(commitment_labels.transcript_z1zero); + commitments.transcript_z2zero = + transcript.template receive_from_prover(commitment_labels.transcript_z2zero); + commitments.transcript_op = transcript.template receive_from_prover(commitment_labels.transcript_op); + commitments.transcript_accumulator_x = + transcript.template receive_from_prover(commitment_labels.transcript_accumulator_x); + commitments.transcript_accumulator_y = + transcript.template receive_from_prover(commitment_labels.transcript_accumulator_y); + commitments.transcript_msm_x = + transcript.template receive_from_prover(commitment_labels.transcript_msm_x); + commitments.transcript_msm_y = + transcript.template receive_from_prover(commitment_labels.transcript_msm_y); + commitments.table_pc = transcript.template receive_from_prover(commitment_labels.table_pc); + commitments.table_point_transition = + transcript.template receive_from_prover(commitment_labels.table_point_transition); + commitments.table_round = transcript.template receive_from_prover(commitment_labels.table_round); + commitments.table_scalar_sum = + transcript.template receive_from_prover(commitment_labels.table_scalar_sum); + commitments.table_s1 = transcript.template receive_from_prover(commitment_labels.table_s1); + commitments.table_s2 = transcript.template receive_from_prover(commitment_labels.table_s2); + commitments.table_s3 = transcript.template receive_from_prover(commitment_labels.table_s3); + commitments.table_s4 = transcript.template receive_from_prover(commitment_labels.table_s4); + commitments.table_s5 = transcript.template receive_from_prover(commitment_labels.table_s5); + commitments.table_s6 = transcript.template receive_from_prover(commitment_labels.table_s6); + commitments.table_s7 = transcript.template receive_from_prover(commitment_labels.table_s7); + commitments.table_s8 = transcript.template receive_from_prover(commitment_labels.table_s8); + commitments.table_skew = transcript.template receive_from_prover(commitment_labels.table_skew); + commitments.table_dx = transcript.template receive_from_prover(commitment_labels.table_dx); + commitments.table_dy = transcript.template receive_from_prover(commitment_labels.table_dy); + commitments.table_tx = transcript.template receive_from_prover(commitment_labels.table_tx); + commitments.table_ty = transcript.template receive_from_prover(commitment_labels.table_ty); + commitments.q_msm_transition = + transcript.template receive_from_prover(commitment_labels.q_msm_transition); + commitments.msm_q_add = transcript.template receive_from_prover(commitment_labels.msm_q_add); + commitments.msm_q_double = transcript.template receive_from_prover(commitment_labels.msm_q_double); + commitments.msm_q_skew = transcript.template receive_from_prover(commitment_labels.msm_q_skew); + commitments.msm_accumulator_x = + transcript.template receive_from_prover(commitment_labels.msm_accumulator_x); + commitments.msm_accumulator_y = + transcript.template receive_from_prover(commitment_labels.msm_accumulator_y); + commitments.msm_pc = transcript.template receive_from_prover(commitment_labels.msm_pc); + commitments.msm_size_of_msm = + transcript.template receive_from_prover(commitment_labels.msm_size_of_msm); + commitments.msm_count = transcript.template receive_from_prover(commitment_labels.msm_count); + commitments.msm_round = transcript.template receive_from_prover(commitment_labels.msm_round); + commitments.msm_q_add1 = transcript.template receive_from_prover(commitment_labels.msm_q_add1); + commitments.msm_q_add2 = transcript.template receive_from_prover(commitment_labels.msm_q_add2); + commitments.msm_q_add3 = transcript.template receive_from_prover(commitment_labels.msm_q_add3); + commitments.msm_q_add4 = transcript.template receive_from_prover(commitment_labels.msm_q_add4); + commitments.msm_x1 = transcript.template receive_from_prover(commitment_labels.msm_x1); + commitments.msm_y1 = transcript.template receive_from_prover(commitment_labels.msm_y1); + commitments.msm_x2 = transcript.template receive_from_prover(commitment_labels.msm_x2); + commitments.msm_y2 = transcript.template receive_from_prover(commitment_labels.msm_y2); + commitments.msm_x3 = transcript.template receive_from_prover(commitment_labels.msm_x3); + commitments.msm_y3 = transcript.template receive_from_prover(commitment_labels.msm_y3); + commitments.msm_x4 = transcript.template receive_from_prover(commitment_labels.msm_x4); + commitments.msm_y4 = transcript.template receive_from_prover(commitment_labels.msm_y4); + commitments.msm_collision_x1 = + transcript.template receive_from_prover(commitment_labels.msm_collision_x1); + commitments.msm_collision_x2 = + transcript.template receive_from_prover(commitment_labels.msm_collision_x2); + commitments.msm_collision_x3 = + transcript.template receive_from_prover(commitment_labels.msm_collision_x3); + commitments.msm_collision_x4 = + transcript.template receive_from_prover(commitment_labels.msm_collision_x4); + commitments.msm_lambda1 = transcript.template receive_from_prover(commitment_labels.msm_lambda1); + commitments.msm_lambda2 = transcript.template receive_from_prover(commitment_labels.msm_lambda2); + commitments.msm_lambda3 = transcript.template receive_from_prover(commitment_labels.msm_lambda3); + commitments.msm_lambda4 = transcript.template receive_from_prover(commitment_labels.msm_lambda4); + commitments.msm_slice1 = transcript.template receive_from_prover(commitment_labels.msm_slice1); + commitments.msm_slice2 = transcript.template receive_from_prover(commitment_labels.msm_slice2); + commitments.msm_slice3 = transcript.template receive_from_prover(commitment_labels.msm_slice3); + commitments.msm_slice4 = transcript.template receive_from_prover(commitment_labels.msm_slice4); + commitments.transcript_accumulator_empty = + transcript.template receive_from_prover(commitment_labels.transcript_accumulator_empty); + commitments.transcript_q_reset_accumulator = + transcript.template receive_from_prover(commitment_labels.transcript_q_reset_accumulator); + commitments.q_wnaf = transcript.template receive_from_prover(commitment_labels.q_wnaf); + commitments.lookup_read_counts_0 = + transcript.template receive_from_prover(commitment_labels.lookup_read_counts_0); + commitments.lookup_read_counts_1 = + transcript.template receive_from_prover(commitment_labels.lookup_read_counts_1); + + // Get challenge for sorted list batching and wire four memory records + auto [eta, gamma] = transcript.get_challenges("beta", "gamma"); + relation_parameters.gamma = gamma; + auto eta_sqr = eta * eta; + relation_parameters.eta = eta; + relation_parameters.eta_sqr = eta_sqr; + relation_parameters.eta_cube = eta_sqr * eta; + relation_parameters.permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + relation_parameters.permutation_offset = relation_parameters.permutation_offset.invert(); + + // Get commitment to permutation and lookup grand products + commitments.lookup_inverses = + transcript.template receive_from_prover(commitment_labels.lookup_inverses); + commitments.z_perm = transcript.template receive_from_prover(commitment_labels.z_perm); + + // Execute Sumcheck Verifier + auto sumcheck = Sumcheck>(circuit_size, transcript); + + std::optional sumcheck_output = sumcheck.execute_verifier(relation_parameters); + + // If Sumcheck does not return an output, sumcheck verification has failed + if (!sumcheck_output.has_value()) { + return false; + } + + auto [multivariate_challenge, purported_evaluations] = *sumcheck_output; + + // Execute Gemini/Shplonk verification: + + // Construct inputs for Gemini verifier: + // - Multivariate opening point u = (u_0, ..., u_{d-1}) + // - batched unshifted and to-be-shifted polynomial commitments + auto batched_commitment_unshifted = GroupElement::zero(); + auto batched_commitment_to_be_shifted = GroupElement::zero(); + + // Compute powers of batching challenge rho + FF rho = transcript.get_challenge("rho"); + std::vector rhos = Gemini::powers_of_rho(rho, Flavor::NUM_ALL_ENTITIES); + + // Compute batched multivariate evaluation + FF batched_evaluation = FF::zero(); + size_t evaluation_idx = 0; + for (auto& value : purported_evaluations.get_unshifted()) { + batched_evaluation += value * rhos[evaluation_idx]; + ++evaluation_idx; + } + for (auto& value : purported_evaluations.get_shifted()) { + batched_evaluation += value * rhos[evaluation_idx]; + ++evaluation_idx; + } + + // Construct batched commitment for NON-shifted polynomials + size_t commitment_idx = 0; + for (auto& commitment : commitments.get_unshifted()) { + // very lazy point at infinity check. not complete. fix. + if (commitment.y != 0) { + batched_commitment_unshifted += commitment * rhos[commitment_idx]; + } else { + std::cout << "point at infinity (unshifted)" << std::endl; + } + ++commitment_idx; + } + + // Construct batched commitment for to-be-shifted polynomials + for (auto& commitment : commitments.get_to_be_shifted()) { + // very lazy point at infinity check. not complete. fix. + if (commitment.y != 0) { + batched_commitment_to_be_shifted += commitment * rhos[commitment_idx]; + } else { + std::cout << "point at infinity (to be shifted)" << std::endl; + } + ++commitment_idx; + } + + // Produce a Gemini claim consisting of: + // - d+1 commitments [Fold_{r}^(0)], [Fold_{-r}^(0)], and [Fold^(l)], l = 1:d-1 + // - d+1 evaluations a_0_pos, and a_l, l = 0:d-1 + auto gemini_claim = Gemini::reduce_verify(multivariate_challenge, + batched_evaluation, + batched_commitment_unshifted, + batched_commitment_to_be_shifted, + transcript); + + // Produce a Shplonk claim: commitment [Q] - [Q_z], evaluation zero (at random challenge z) + auto shplonk_claim = Shplonk::reduce_verify(gemini_claim, transcript); + + // // Verify the Shplonk claim with KZG or IPA + return PCS::verify(pcs_verification_key, shplonk_claim, transcript); +} + +template class ECCVMVerifier_; + +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.hpp b/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.hpp new file mode 100644 index 0000000000..5fc03225ce --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/eccvm_verifier.hpp @@ -0,0 +1,47 @@ +#pragma once +#include "barretenberg/honk/flavor/ecc_vm.hpp" +#include "barretenberg/honk/sumcheck/sumcheck.hpp" +#include "barretenberg/plonk/proof_system/types/proof.hpp" + +namespace proof_system::honk { +template class ECCVMVerifier_ { + using FF = typename Flavor::FF; + using Commitment = typename Flavor::Commitment; + using VerificationKey = typename Flavor::VerificationKey; + using PCSVerificationKey = typename Flavor::PCSParams::VerificationKey; + + public: + explicit ECCVMVerifier_(std::shared_ptr verifier_key = nullptr); + ECCVMVerifier_(std::shared_ptr key, + std::map commitments, + std::map pcs_fr_elements, + std::shared_ptr pcs_verification_key, + VerifierTranscript transcript) + : key(std::move(key)) + , commitments(std::move(commitments)) + , pcs_fr_elements(std::move(pcs_fr_elements)) + , pcs_verification_key(std::move(pcs_verification_key)) + , transcript(std::move(transcript)) + {} + ECCVMVerifier_(ECCVMVerifier_&& other) noexcept; + ECCVMVerifier_(const ECCVMVerifier_& other) = delete; + ECCVMVerifier_& operator=(const ECCVMVerifier_& other) = delete; + ECCVMVerifier_& operator=(ECCVMVerifier_&& other) noexcept; + ~ECCVMVerifier_() = default; + + bool verify_proof(const plonk::proof& proof); + + std::shared_ptr key; + std::map commitments; + std::map pcs_fr_elements; + std::shared_ptr pcs_verification_key; + VerifierTranscript transcript; +}; + +extern template class ECCVMVerifier_; +extern template class ECCVMVerifier_; + +using ECCVMVerifier = ECCVMVerifier_; +using ECCVMVerifierGrumpkin = ECCVMVerifier_; + +} // namespace proof_system::honk diff --git a/cpp/src/barretenberg/honk/proof_system/lookup_library.hpp b/cpp/src/barretenberg/honk/proof_system/lookup_library.hpp new file mode 100644 index 0000000000..65f36427d7 --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/lookup_library.hpp @@ -0,0 +1,64 @@ +#pragma once +#include "barretenberg/honk/sumcheck/sumcheck.hpp" +#include + +namespace proof_system::honk::lookup_library { + +/** + * @brief Compute the inverse polynomial I(X) required for logderivative lookups + * * + * @details + * Inverse may be defined in terms of its values on X_i = 0,1,...,n-1 as Z_perm[0] = 1 and for i = 1:n-1 + * 1 1 + * Inverse[i] = ∏ -------------------------- * ∏' -------------------------- + * relation::read_term(j) relation::write_term(j) + * + * where ∏ := ∏_{j=0:relation::NUM_READ_TERMS-1} and ∏' := ∏'_{j=0:relation::NUM_WRITE_TERMS-1} + * + * If row [i] does not contain a lookup read gate or a write gate, Inverse[i] = 0 + * N.B. by "write gate" we mean; do the lookup table polynomials contain nonzero values at this row? + * (in the ECCVM, the lookup table is not precomputed, so we have a concept of a "write gate", unlike when precomputed + * lookup tables are used) + * + * The specific algebraic relations that define read terms and write terms are defined in Flavor::LookupRelation + * + */ +template +void compute_logderivative_inverse(auto& polynomials, + sumcheck::RelationParameters& relation_parameters, + const size_t circuit_size) +{ + using FF = typename Flavor::FF; + using Accumulator = typename Relation::ValueAccumTypes; + constexpr size_t READ_TERMS = Relation::READ_TERMS; + constexpr size_t WRITE_TERMS = Relation::WRITE_TERMS; + auto& inverse_polynomial = polynomials.lookup_inverses; + // auto& inverse_polynomial = key->lookup_inverses; + // const size_t circuit_size = key->circuit_size; + + auto lookup_relation = Relation(); + for (size_t i = 0; i < circuit_size; ++i) { + bool has_inverse = + lookup_relation.template lookup_exists_at_row_index(polynomials, relation_parameters, i); + if (!has_inverse) { + continue; + } + FF denominator = 1; + barretenberg::constexpr_for<0, READ_TERMS, 1>([&] { + auto denominator_term = lookup_relation.template compute_read_term( + polynomials, relation_parameters, i); + denominator *= denominator_term; + }); + barretenberg::constexpr_for<0, WRITE_TERMS, 1>([&] { + auto denominator_term = lookup_relation.template compute_write_term( + polynomials, relation_parameters, i); + denominator *= denominator_term; + }); + inverse_polynomial[i] = denominator; + }; + + // todo might be inverting zero in field bleh bleh + FF::batch_invert(inverse_polynomial); +} + +} // namespace proof_system::honk::lookup_library \ No newline at end of file diff --git a/cpp/src/barretenberg/honk/proof_system/permutation_library.hpp b/cpp/src/barretenberg/honk/proof_system/permutation_library.hpp new file mode 100644 index 0000000000..2756219ae3 --- /dev/null +++ b/cpp/src/barretenberg/honk/proof_system/permutation_library.hpp @@ -0,0 +1,165 @@ +#pragma once +#include "barretenberg/honk/sumcheck/sumcheck.hpp" +#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp" +#include "barretenberg/polynomials/polynomial.hpp" +#include + +namespace proof_system::honk::permutation_library { + +/** + * @brief Compute a permutation grand product polynomial Z_perm(X) + * * + * @details + * Z_perm may be defined in terms of its values on X_i = 0,1,...,n-1 as Z_perm[0] = 1 and for i = 1:n-1 + * relation::numerator(j) + * Z_perm[i] = ∏ -------------------------------------------------------------------------------- + * relation::denominator(j) + * + * where ∏ := ∏_{j=0:i-1} + * + * The specific algebraic relation used by Z_perm is defined by Flavor::GrandProductRelations + * + * For example, in Flavor::Standard the relation describes: + * + * (w_1(j) + β⋅id_1(j) + γ) ⋅ (w_2(j) + β⋅id_2(j) + γ) ⋅ (w_3(j) + β⋅id_3(j) + γ) + * Z_perm[i] = ∏ -------------------------------------------------------------------------------- + * (w_1(j) + β⋅σ_1(j) + γ) ⋅ (w_2(j) + β⋅σ_2(j) + γ) ⋅ (w_3(j) + β⋅σ_3(j) + γ) + * where ∏ := ∏_{j=0:i-1} and id_i(X) = id(X) + n*(i-1) + * + * For Flavor::Ultra both the UltraPermutation and Lookup grand products are computed by this method. + * + * The grand product is constructed over the course of three steps. + * + * For expositional simplicity, write Z_perm[i] as + * + * A(j) + * Z_perm[i] = ∏ -------------------------- + * B(h) + * + * Step 1) Compute 2 length-n polynomials A, B + * Step 2) Compute 2 length-n polynomials numerator = ∏ A(j), nenominator = ∏ B(j) + * Step 3) Compute Z_perm[i + 1] = numerator[i] / denominator[i] (recall: Z_perm[0] = 1) + * + * Note: Step (3) utilizes Montgomery batch inversion to replace n-many inversions with + */ +template +void compute_permutation_grand_product(const size_t circuit_size, + auto& full_polynomials, + sumcheck::RelationParameters& relation_parameters) +{ + using FF = typename Flavor::FF; + using Polynomial = typename Flavor::Polynomial; + using ValueAccumTypes = typename PermutationRelation::ValueAccumTypes; + + // Allocate numerator/denominator polynomials that will serve as scratch space + // TODO(zac) we can re-use the permutation polynomial as the numerator polynomial. Reduces readability + Polynomial numerator = Polynomial{ circuit_size }; + Polynomial denominator = Polynomial{ circuit_size }; + + // Step (1) + // Populate `numerator` and `denominator` with the algebra described by PermutationRelation + const size_t num_threads = circuit_size >= get_num_cpus_pow2() ? get_num_cpus_pow2() : 1; + const size_t block_size = circuit_size / num_threads; + parallel_for(num_threads, [&](size_t thread_idx) { + const size_t start = thread_idx * block_size; + const size_t end = (thread_idx + 1) * block_size; + for (size_t i = start; i < end; ++i) { + + typename Flavor::ClaimedEvaluations evaluations; + for (size_t k = 0; k < Flavor::NUM_ALL_ENTITIES; ++k) { + evaluations[k] = full_polynomials[k].size() > i ? full_polynomials[k][i] : 0; + } + numerator[i] = PermutationRelation::template compute_permutation_numerator( + evaluations, relation_parameters, i); + denominator[i] = PermutationRelation::template compute_permutation_denominator( + evaluations, relation_parameters, i); + } + }); + + // Step (2) + // Compute the accumulating product of the numerator and denominator terms. + // This step is split into three parts for efficient multithreading: + // (i) compute ∏ A(j), ∏ B(j) subproducts for each thread + // (ii) compute scaling factor required to convert each subproduct into a single running product + // (ii) combine subproducts into a single running product + // + // For example, consider 4 threads and a size-8 numerator { a0, a1, a2, a3, a4, a5, a6, a7 } + // (i) Each thread computes 1 element of N = {{ a0, a0a1 }, { a2, a2a3 }, { a4, a4a5 }, { a6, a6a7 }} + // (ii) Take partial products P = { 1, a0a1, a2a3, a4a5 } + // (iii) Each thread j computes N[i][j]*P[j]= + // {{a0,a0a1},{a0a1a2,a0a1a2a3},{a0a1a2a3a4,a0a1a2a3a4a5},{a0a1a2a3a4a5a6,a0a1a2a3a4a5a6a7}} + std::vector partial_numerators(num_threads); + std::vector partial_denominators(num_threads); + + parallel_for(num_threads, [&](size_t thread_idx) { + const size_t start = thread_idx * block_size; + const size_t end = (thread_idx + 1) * block_size; + for (size_t i = start; i < end - 1; ++i) { + numerator[i + 1] *= numerator[i]; + denominator[i + 1] *= denominator[i]; + } + partial_numerators[thread_idx] = numerator[end - 1]; + partial_denominators[thread_idx] = denominator[end - 1]; + }); + + parallel_for(num_threads, [&](size_t thread_idx) { + const size_t start = thread_idx * block_size; + const size_t end = (thread_idx + 1) * block_size; + if (thread_idx > 0) { + FF numerator_scaling = 1; + FF denominator_scaling = 1; + + for (size_t j = 0; j < thread_idx; ++j) { + numerator_scaling *= partial_numerators[j]; + denominator_scaling *= partial_denominators[j]; + } + for (size_t i = start; i < end; ++i) { + numerator[i] *= numerator_scaling; + denominator[i] *= denominator_scaling; + } + } + + // Final step: invert denominator + FF::batch_invert(std::span{ &denominator[start], block_size }); + }); + + // Step (3) Compute z_perm[i] = numerator[i] / denominator[i] + auto& grand_product_polynomial = PermutationRelation::get_grand_product_polynomial(full_polynomials); + grand_product_polynomial[0] = 0; + parallel_for(num_threads, [&](size_t thread_idx) { + const size_t start = thread_idx * block_size; + const size_t end = (thread_idx == num_threads - 1) ? circuit_size - 1 : (thread_idx + 1) * block_size; + for (size_t i = start; i < end; ++i) { + grand_product_polynomial[i + 1] = numerator[i] * denominator[i]; + } + }); +} + +template +void compute_permutation_grand_products(std::shared_ptr& key, + typename Flavor::ProverPolynomials& full_polynomials, + sumcheck::RelationParameters& relation_parameters) +{ + using GrandProductRelations = typename Flavor::GrandProductRelations; + using FF = typename Flavor::FF; + + constexpr size_t NUM_RELATIONS = std::tuple_size{}; + barretenberg::constexpr_for<0, NUM_RELATIONS, 1>([&]() { + using PermutationRelation = typename std::tuple_element::type; + + // Assign the grand product polynomial to the relevant std::span member of `full_polynomials` (and its shift) + // For example, for UltraPermutationRelation, this will be `full_polynomials.z_perm` + // For example, for LookupRelation, this will be `full_polynomials.z_lookup` + std::span& full_polynomial = PermutationRelation::get_grand_product_polynomial(full_polynomials); + auto& key_polynomial = PermutationRelation::get_grand_product_polynomial(*key); + full_polynomial = key_polynomial; + + compute_permutation_grand_product( + key->circuit_size, full_polynomials, relation_parameters); + std::span& full_polynomial_shift = + PermutationRelation::get_shifted_grand_product_polynomial(full_polynomials); + full_polynomial_shift = key_polynomial.shifted(); + }); +} + +} // namespace proof_system::honk::permutation_library \ No newline at end of file diff --git a/cpp/src/barretenberg/honk/proof_system/prover_library.hpp b/cpp/src/barretenberg/honk/proof_system/prover_library.hpp index 9990400c74..39ead5849c 100644 --- a/cpp/src/barretenberg/honk/proof_system/prover_library.hpp +++ b/cpp/src/barretenberg/honk/proof_system/prover_library.hpp @@ -1,10 +1,18 @@ #pragma once +#include "barretenberg/common/constexpr_utils.hpp" #include "barretenberg/ecc/curves/bn254/fr.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" #include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp" -#include "barretenberg/plonk/proof_system/types/program_settings.hpp" #include "barretenberg/plonk/proof_system/types/proof.hpp" -#include "barretenberg/polynomials/polynomial.hpp" +// TODO(@zac-williamson). We used to include `program_settings.hpp` in this file. Needed to remove due to circular +// dependency. `program_settings.hpp` included header files that added "using namespace proof_system" and "using +// namespace barretenberg" declarations. This effects downstream code that relies on these using declarations. This is a +// big code smell (should really not have using declarations in header files!), however fixing it requires changes in a +// LOT of files. This would clutter the eccvm feature PR. Adding these following "using namespace" declarations is a +// temp workaround. Once this work is merged in we should fix the root problem (no using declarations in header files) +using namespace proof_system; +using namespace barretenberg; namespace proof_system::honk::prover_library { template diff --git a/cpp/src/barretenberg/honk/sumcheck/polynomials/univariate.hpp b/cpp/src/barretenberg/honk/sumcheck/polynomials/univariate.hpp index 948382b845..5765f74cb5 100644 --- a/cpp/src/barretenberg/honk/sumcheck/polynomials/univariate.hpp +++ b/cpp/src/barretenberg/honk/sumcheck/polynomials/univariate.hpp @@ -96,6 +96,15 @@ template class Univariate { res -= other; return res; } + Univariate operator-() const + { + Univariate res(*this); + for (auto& eval : res.evaluations) { + eval = -eval; + } + return res; + } + Univariate operator*(const Univariate& other) const { Univariate res(*this); @@ -249,6 +258,15 @@ template class UnivariateView { return res; } + Univariate operator-() const + { + Univariate res(*this); + for (auto& eval : res.evaluations) { + eval = -eval; + } + return res; + } + Univariate operator*(const UnivariateView& other) const { Univariate res(*this); diff --git a/cpp/src/barretenberg/honk/sumcheck/relations/ecc_vm/ecc_vm_relation.test.cpp b/cpp/src/barretenberg/honk/sumcheck/relations/ecc_vm/ecc_vm_relation.test.cpp new file mode 100644 index 0000000000..b4888899d0 --- /dev/null +++ b/cpp/src/barretenberg/honk/sumcheck/relations/ecc_vm/ecc_vm_relation.test.cpp @@ -0,0 +1,360 @@ +#include "barretenberg/honk/composer/eccvm_composer.hpp" +#include "barretenberg/honk/flavor/ecc_vm.hpp" +#include "barretenberg/honk/proof_system/lookup_library.hpp" +#include "barretenberg/honk/proof_system/permutation_library.hpp" +#include "barretenberg/honk/proof_system/prover_library.hpp" +#include "barretenberg/honk/sumcheck/sumcheck.hpp" +#include "barretenberg/numeric/random/engine.hpp" +#include "barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp" +#include + +/** + * We want to test if all three relations (namely, ArithmeticRelation, GrandProductComputationRelation, + * GrandProductInitializationRelation) provide correct contributions by manually computing their + * contributions with deterministic and random inputs. The relations are supposed to work with + * univariates (edges) of degree one (length 2) and spit out polynomials of corresponding degrees. We have + * MAX_RELATION_LENGTH = 5, meaning the output of a relation can atmost be a degree 5 polynomial. Hence, + * we use a method compute_mock_extended_edges() which starts with degree one input polynomial (two evaluation + points), + * extends them (using barycentric formula) to six evaluation points, and stores them to an array of polynomials. + */ + +using namespace proof_system::honk::sumcheck; +using Flavor = proof_system::honk::flavor::ECCVM; +using FF = typename Flavor::FF; +using ProverPolynomials = typename Flavor::ProverPolynomials; +using RawPolynomials = typename Flavor::RawPolynomials; + +static constexpr size_t NUM_POLYNOMIALS = Flavor::NUM_ALL_ENTITIES; + +namespace proof_system::honk_relation_tests_ecc_vm_full { + +namespace { +auto& engine = numeric::random::get_debug_engine(); +} + +ECCVMCircuitConstructor generate_trace(numeric::random::Engine* engine = nullptr) +{ + static bool init = false; + static grumpkin::g1::element a; + static grumpkin::g1::element b; + static grumpkin::g1::element c; + static grumpkin::fr x; + static grumpkin::fr y; + + ECCVMCircuitConstructor result; + if (!init) { + a = grumpkin::get_generator(0); + b = grumpkin::get_generator(1); + c = grumpkin::get_generator(2); + x = grumpkin::fr::random_element(engine); + y = grumpkin::fr::random_element(engine); + init = true; + } + + result.mul_accumulate(a, x); + + return result; +} + +TEST(SumcheckRelation, ECCVMLookupRelationAlgebra) +{ + const auto run_test = []() { + auto lookup_relation = ECCVMLookupRelation(); + + barretenberg::fr scaling_factor = barretenberg::fr::random_element(); + const FF gamma = FF::random_element(&engine); + const FF eta = FF::random_element(&engine); + const FF eta_sqr = eta.sqr(); + const FF eta_cube = eta_sqr * eta; + auto permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + permutation_offset = permutation_offset.invert(); + honk::sumcheck::RelationParameters relation_params{ + .eta = eta, + .beta = 1, + .gamma = gamma, + .public_input_delta = 1, + .lookup_grand_product_delta = 1, + .eta_sqr = eta_sqr, + .eta_cube = eta_cube, + .permutation_offset = permutation_offset, + }; + + auto circuit_constructor = generate_trace(&engine); + auto rows = circuit_constructor.compute_full_polynomials(); + const size_t num_rows = rows[0].size(); + honk::lookup_library::compute_logderivative_inverse>( + rows, relation_params, num_rows); + honk::permutation_library::compute_permutation_grand_product>( + num_rows, rows, relation_params); + rows.z_perm_shift = Flavor::Polynomial(rows.z_perm.shifted()); + + // auto transcript_trace = transcript_trace.export_rows(); + + ECCVMLookupRelation::RelationValues result; + for (auto& r : result) { + r = 0; + } + for (size_t i = 0; i < num_rows; ++i) { + Flavor::RowPolynomials row; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + row[j] = rows[j][i]; + } + lookup_relation.add_full_relation_value_contribution(result, row, relation_params, scaling_factor); + } + + for (auto r : result) { + EXPECT_EQ(r, 0); + } + }; + run_test(); +} + +TEST(SumcheckRelation, ECCVMFullRelationAlgebra) +{ + const auto run_test = []() { + // auto transcript_relation = ECCVMTranscriptRelation(); + // auto point_relation = ECCVMPointTableRelation(); + // auto wnaf_relation = ECCVMWnafRelation(); + // auto msm_relation = ECCVMMSMRelation(); + // auto set_relation = ECCVMSetRelation(); + auto lookup_relation = ECCVMLookupRelation(); + + barretenberg::fr scaling_factor = barretenberg::fr::random_element(); + const FF gamma = FF::random_element(&engine); + const FF eta = FF::random_element(&engine); + const FF eta_sqr = eta.sqr(); + const FF eta_cube = eta_sqr * eta; + auto permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + permutation_offset = permutation_offset.invert(); + honk::sumcheck::RelationParameters relation_params{ + .eta = eta, + .beta = 1, + .gamma = gamma, + .public_input_delta = 1, + .lookup_grand_product_delta = 1, + .eta_sqr = eta_sqr, + .eta_cube = eta_cube, + .permutation_offset = permutation_offset, + }; + auto circuit_constructor = generate_trace(&engine); + auto rows = circuit_constructor.compute_full_polynomials(); + const size_t num_rows = rows[0].size(); + honk::lookup_library::compute_logderivative_inverse>( + rows, relation_params, num_rows); + honk::permutation_library::compute_permutation_grand_product>( + num_rows, rows, relation_params); + rows.z_perm_shift = Flavor::Polynomial(rows.z_perm.shifted()); + + // compute_permutation_polynomials(rows, relation_params); + // compute_lookup_inverse_polynomial(rows, relation_params); + + // auto transcript_trace = transcript_trace.export_rows(); + + ECCVMLookupRelation::RelationValues lookup_result; + for (auto& r : lookup_result) { + r = 0; + } + + const auto evaluate_relation = [&](const std::string& relation_name) { + auto relation = Relation(); + typename Relation::RelationValues result; + for (auto& r : result) { + r = 0; + } + constexpr size_t NUM_SUBRELATIONS = result.size(); + std::array relation_fail{}; + std::array relation_fails_at_row{}; + + for (size_t i = 0; i < num_rows; ++i) { + Flavor::RowPolynomials row; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + row[j] = rows[j][i]; + } + relation.add_full_relation_value_contribution(result, row, relation_params, scaling_factor); + + for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) { + if (result[j] != 0) { + if (!relation_fail[j]) { + relation_fail[j] = true; + relation_fails_at_row[j] = i; + } + } + } + } + + for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) { + EXPECT_EQ(relation_fail[j], false); + if (relation_fail[j]) { + std::cerr << "relation " << relation_name << ", subrelation " << j + << " fails. First failure at row " << relation_fails_at_row[j] << std::endl; + } + } + }; + + evaluate_relation.template operator()>("ECCVMTranscriptRelation"); + evaluate_relation.template operator()>("ECCVMPointTableRelation"); + evaluate_relation.template operator()>("ECCVMWnafRelation"); + evaluate_relation.template operator()>("ECCVMMSMRelation"); + evaluate_relation.template operator()>("ECCVMSetRelation"); + + for (size_t i = 0; i < num_rows; ++i) { + Flavor::RowPolynomials row; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + row[j] = rows[j][i]; + } + { + lookup_relation.add_full_relation_value_contribution( + lookup_result, row, relation_params, scaling_factor); + } + } + for (auto r : lookup_result) { + EXPECT_EQ(r, 0); + } + }; + run_test(); +} + +TEST(SumcheckRelation, ECCVMFullRelationProver) +{ + const auto run_test = []() { + const FF gamma = FF::random_element(&engine); + const FF eta = FF::random_element(&engine); + const FF eta_sqr = eta.sqr(); + const FF eta_cube = eta_sqr * eta; + auto permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + permutation_offset = permutation_offset.invert(); + honk::sumcheck::RelationParameters relation_params{ + .eta = eta, + .beta = 1, + .gamma = gamma, + .public_input_delta = 1, + .lookup_grand_product_delta = 1, + .eta_sqr = eta_sqr, + .eta_cube = eta_cube, + .permutation_offset = permutation_offset, + }; + + auto circuit_constructor = generate_trace(&engine); + auto full_polynomials = circuit_constructor.compute_full_polynomials(); + const size_t num_rows = full_polynomials[0].size(); + honk::lookup_library::compute_logderivative_inverse>( + full_polynomials, relation_params, num_rows); + + honk::permutation_library::compute_permutation_grand_product>( + num_rows, full_polynomials, relation_params); + full_polynomials.z_perm_shift = Flavor::Polynomial(full_polynomials.z_perm.shifted()); + + // size_t pidx = 0; + // for (auto& p : full_polynomials) { + // size_t count = 0; + // for (auto& x : p) { + // std::cout << "poly[" << pidx << "][" << count << "] = " << x << std::endl; + // count++; + // } + // pidx++; + // } + // auto foo = full_polynomials.get_to_be_shifted(); + // size_t c = 0; + // for (auto& x : foo) { + // if (x[0] != 0) { + // std::cout << "shift at " << c << "not zero :/" << std::endl; + // } + // c += 1; + // } + const size_t multivariate_n = full_polynomials[0].size(); + const size_t multivariate_d = static_cast(numeric::get_msb64(multivariate_n)); + + EXPECT_EQ(1ULL << multivariate_d, multivariate_n); + + auto prover_transcript = honk::ProverTranscript::init_empty(); + + auto sumcheck_prover = Sumcheck>(multivariate_n, prover_transcript); + + auto prover_output = sumcheck_prover.execute_prover(full_polynomials, relation_params); + + auto verifier_transcript = honk::VerifierTranscript::init_empty(prover_transcript); + + auto sumcheck_verifier = Sumcheck>(multivariate_n, verifier_transcript); + + std::optional verifier_output = sumcheck_verifier.execute_verifier(relation_params); + + ASSERT_TRUE(verifier_output.has_value()); + ASSERT_EQ(prover_output, *verifier_output); + }; + run_test(); +} + +class ECCVMComposerTestsB : public ::testing::Test { + protected: + static void SetUpTestSuite() { barretenberg::srs::init_crs_factory("../srs_db/ignition"); } +}; +TEST_F(ECCVMComposerTestsB, BaseCase) +{ + auto circuit_constructor = generate_trace(&engine); + // auto composer = honk::ECCVMComposerHelper(); + // auto prover = composer.create_prover(circuit_constructor); + + // prover.construct_proof(); + // auto eta = prover.relation_parameters.eta; + // auto beta = prover.relation_parameters.beta; + // auto gamma = prover.relation_parameters.gamma; + // ECCVMBuilder trace2 = generate_trace(&engine); + + auto eta = FF::random_element(&engine); // prover.relation_parameters.eta; + auto beta = FF::random_element(&engine); // prover.relation_parameters.beta; + auto gamma = FF::random_element(&engine); // prover.relation_parameters.gamma; + const FF eta_sqr = eta.sqr(); + const FF eta_cube = eta_sqr * eta; + auto permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + permutation_offset = permutation_offset.invert(); + + honk::sumcheck::RelationParameters relation_params{ + .eta = eta, + .beta = beta, + .gamma = gamma, + .public_input_delta = 0, + .lookup_grand_product_delta = 0, + .eta_sqr = eta_sqr, + .eta_cube = eta_cube, + .permutation_offset = permutation_offset, + }; + // std::cout << "gamma eta = " << gamma << " , " << eta << std::endl; + + // RawPolynomials full_polynomials = trace2.compute_full_polynomials(); + + // auto& full_polynomials = prover.prover_polynomials; + auto full_polynomials = circuit_constructor.compute_full_polynomials(); + // compute_logderivative_inverse(prover.proving_key, full_polynomials) + const size_t multivariate_n = full_polynomials[0].size(); + const size_t multivariate_d = static_cast(numeric::get_msb64(multivariate_n)); + + EXPECT_EQ(1ULL << multivariate_d, multivariate_n); + + honk::lookup_library::compute_logderivative_inverse>( + full_polynomials, relation_params, multivariate_n); + + honk::permutation_library::compute_permutation_grand_product>( + multivariate_n, full_polynomials, relation_params); + full_polynomials.z_perm_shift = Flavor::Polynomial(full_polynomials.z_perm.shifted()); + + auto prover_transcript = honk::ProverTranscript::init_empty(); + + auto sumcheck_prover = Sumcheck>(multivariate_n, prover_transcript); + + auto prover_output = sumcheck_prover.execute_prover(full_polynomials, relation_params); + + auto verifier_transcript = honk::VerifierTranscript::init_empty(prover_transcript); + + auto sumcheck_verifier = Sumcheck>(multivariate_n, verifier_transcript); + + std::optional verifier_output = sumcheck_verifier.execute_verifier(relation_params); + + ASSERT_TRUE(verifier_output.has_value()); + ASSERT_EQ(prover_output, *verifier_output); +} +} // namespace proof_system::honk_relation_tests_ecc_vm_full diff --git a/cpp/src/barretenberg/honk/sumcheck/relations/relation_definitions_fwd.hpp b/cpp/src/barretenberg/honk/sumcheck/relations/relation_definitions_fwd.hpp new file mode 100644 index 0000000000..53709d5e9b --- /dev/null +++ b/cpp/src/barretenberg/honk/sumcheck/relations/relation_definitions_fwd.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include "relation_types.hpp" + +#define ExtendedEdge(Flavor) Flavor::ExtendedEdges +#define EvaluationEdge(Flavor) Flavor::ClaimedEvaluations +#define EntityEdge(Flavor) Flavor::AllEntities + +#define ADD_EDGE_CONTRIBUTION(...) _ADD_EDGE_CONTRIBUTION(__VA_ARGS__) +#define _ADD_EDGE_CONTRIBUTION(Preface, Relation, Flavor, AccumulatorType, EdgeType) \ + Preface template void \ + Relation::add_edge_contribution_impl::AccumulatorType, \ + EdgeType(Flavor)>( \ + RelationWrapper::AccumulatorType::Accumulators&, \ + EdgeType(Flavor) const&, \ + RelationParameters const&, \ + Flavor::FF const&) const; + +#define PERMUTATION_METHOD(...) _PERMUTATION_METHOD(__VA_ARGS__) +#define _PERMUTATION_METHOD(Preface, MethodName, Relation, Flavor, AccumulatorType, EdgeType) \ + Preface template Relation::template Accumulator< \ + RelationWrapper::AccumulatorType> \ + Relation::MethodName::AccumulatorType, EdgeType(Flavor)>( \ + EdgeType(Flavor) const&, RelationParameters const&, size_t const); + +#define SUMCHECK_RELATION_CLASS(...) _SUMCHECK_RELATION_CLASS(__VA_ARGS__) +#define _SUMCHECK_RELATION_CLASS(Preface, Relation, Flavor) \ + ADD_EDGE_CONTRIBUTION(Preface, Relation, Flavor, UnivariateAccumTypes, ExtendedEdge) \ + ADD_EDGE_CONTRIBUTION(Preface, Relation, Flavor, ValueAccumTypes, EvaluationEdge) \ + ADD_EDGE_CONTRIBUTION(Preface, Relation, Flavor, ValueAccumTypes, EntityEdge) + +#define DECLARE_SUMCHECK_RELATION_CLASS(Relation, Flavor) SUMCHECK_RELATION_CLASS(extern, Relation, Flavor) +#define DEFINE_SUMCHECK_RELATION_CLASS(Relation, Flavor) SUMCHECK_RELATION_CLASS(, Relation, Flavor) + +#define SUMCHECK_PERMUTATION_CLASS(...) _SUMCHECK_PERMUTATION_CLASS(__VA_ARGS__) +#define _SUMCHECK_PERMUTATION_CLASS(Preface, Relation, Flavor) \ + PERMUTATION_METHOD(Preface, compute_permutation_numerator, Relation, Flavor, UnivariateAccumTypes, ExtendedEdge) \ + PERMUTATION_METHOD(Preface, compute_permutation_numerator, Relation, Flavor, ValueAccumTypes, EvaluationEdge) \ + PERMUTATION_METHOD(Preface, compute_permutation_numerator, Relation, Flavor, ValueAccumTypes, EntityEdge) \ + PERMUTATION_METHOD(Preface, compute_permutation_denominator, Relation, Flavor, UnivariateAccumTypes, ExtendedEdge) \ + PERMUTATION_METHOD(Preface, compute_permutation_denominator, Relation, Flavor, ValueAccumTypes, EvaluationEdge) \ + PERMUTATION_METHOD(Preface, compute_permutation_denominator, Relation, Flavor, ValueAccumTypes, EntityEdge) + +#define DECLARE_SUMCHECK_PERMUTATION_CLASS(Relation, Flavor) SUMCHECK_PERMUTATION_CLASS(extern, Relation, Flavor) +#define DEFINE_SUMCHECK_PERMUTATION_CLASS(Relation, Flavor) SUMCHECK_PERMUTATION_CLASS(, Relation, Flavor) diff --git a/cpp/src/barretenberg/honk/sumcheck/relations/relation_parameters.hpp b/cpp/src/barretenberg/honk/sumcheck/relations/relation_parameters.hpp index 863688560c..ccc74b3391 100644 --- a/cpp/src/barretenberg/honk/sumcheck/relations/relation_parameters.hpp +++ b/cpp/src/barretenberg/honk/sumcheck/relations/relation_parameters.hpp @@ -14,5 +14,8 @@ template struct RelationParameters { FF gamma = FF::zero(); // Permutation + Lookup FF public_input_delta = FF::zero(); // Permutation FF lookup_grand_product_delta = FF::zero(); // Lookup + FF eta_sqr = FF::zero(); + FF eta_cube = FF::zero(); + FF permutation_offset = FF::zero(); // TODO(@zac-williamson) explain what this is (to do w. set equality check) }; } // namespace proof_system::honk::sumcheck diff --git a/cpp/src/barretenberg/honk/sumcheck/relations/relation_types.hpp b/cpp/src/barretenberg/honk/sumcheck/relations/relation_types.hpp index e66cc8f980..38c0e1d237 100644 --- a/cpp/src/barretenberg/honk/sumcheck/relations/relation_types.hpp +++ b/cpp/src/barretenberg/honk/sumcheck/relations/relation_types.hpp @@ -1,9 +1,15 @@ #pragma once +#include "relation_parameters.hpp" #include +#include #include -#include "../polynomials/univariate.hpp" -#include "relation_parameters.hpp" +// forward-declare Polynomial so we can use in a concept +namespace barretenberg { +template class Polynomial; +} +template class Univariate; +template class UnivariateView; namespace proof_system::honk::sumcheck { template concept HasSubrelationLinearlyIndependentMember = requires(T) @@ -29,7 +35,7 @@ template concept HasSubrelationLinearlyIndependentMember = requires * @brief Getter method that will return `input[index]` iff `input` is a std::span container * * @tparam FF - * @tparam TypeMuncher + * @tparam AccumulatorTypes * @tparam T * @param input * @param index @@ -44,7 +50,25 @@ requires std::is_same, T>::value inline } /** - * @brief Getter method that will return `input[index]` iff `input` is not a std::span container + * @brief Getter method that will return `input[index]` iff `input` is a Polynomial container + * + * @tparam FF + * @tparam TypeMuncher + * @tparam T + * @param input + * @param index + * @return requires + */ +template +requires std::is_same, T>::value inline + typename std::tuple_element<0, typename AccumulatorTypes::AccumulatorViews>::type + get_view(const T& input, const size_t index) +{ + return input[index]; +} + +/** + * @brief Getter method that will return `input[index]` iff `input` is not a std::span or a Polynomial container * * @tparam FF * @tparam TypeMuncher @@ -104,30 +128,29 @@ template typename RelationBase> class Relation Relation::template add_edge_contribution_impl( accumulator, input, relation_parameters, scaling_factor); } - /** * @brief Check is subrelation is linearly independent - * Method always returns true if relation has no SUBRELATION_LINEARLY_INDEPENDENT std::array - * (i.e. default is to make linearly independent) + * Method is active if relation has SUBRELATION_LINEARLY_INDEPENDENT array defined * @tparam size_t */ template static constexpr bool is_subrelation_linearly_independent() requires( !HasSubrelationLinearlyIndependentMember) { - return true; + return std::get(Relation::SUBRELATION_LINEARLY_INDEPENDENT); } /** * @brief Check is subrelation is linearly independent - * Method is active if relation has SUBRELATION_LINEARLY_INDEPENDENT array defined + * Method always returns true if relation has no SUBRELATION_LINEARLY_INDEPENDENT std::array + * (i.e. default is to make linearly independent) * @tparam size_t */ template static constexpr bool is_subrelation_linearly_independent() requires( HasSubrelationLinearlyIndependentMember) { - return std::get(Relation::SUBRELATION_LINEARLY_INDEPENDENT); + return true; } }; diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_builder_types.hpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_builder_types.hpp new file mode 100644 index 0000000000..978d0bed55 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_builder_types.hpp @@ -0,0 +1,36 @@ +#pragma once + +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" + +namespace proof_system_eccvm { + +static constexpr size_t NUM_SCALAR_BITS = 128; +static constexpr size_t WNAF_SLICE_BITS = 4; +static constexpr size_t NUM_WNAF_SLICES = (NUM_SCALAR_BITS + WNAF_SLICE_BITS - 1) / WNAF_SLICE_BITS; +static constexpr uint64_t WNAF_MASK = static_cast((1ULL << WNAF_SLICE_BITS) - 1ULL); +static constexpr size_t POINT_TABLE_SIZE = 1ULL << (WNAF_SLICE_BITS); +static constexpr size_t WNAF_SLICES_PER_ROW = 4; +static constexpr size_t ADDITIONS_PER_ROW = 4; + +template struct VMOperation { + bool add = false; + bool mul = false; + bool eq = false; + bool reset = false; + typename CycleGroup::affine_element base_point = typename CycleGroup::affine_element{ 0, 0 }; + uint256_t z1 = 0; + uint256_t z2 = 0; + typename CycleGroup::subgroup_field mul_scalar_full = 0; +}; +template struct ScalarMul { + uint32_t pc; + uint256_t scalar; + typename CycleGroup::affine_element base_point; + std::array wnaf_slices; + bool wnaf_skew; + std::array precomputed_table; +}; + +template using MSM = std::vector>; + +} // namespace proof_system_eccvm \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp new file mode 100644 index 0000000000..d55fb61305 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.hpp @@ -0,0 +1,489 @@ +#pragma once + +#include "./eccvm_builder_types.hpp" +#include "barretenberg/ecc/curves/bn254/fr.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" + +#include "./msm_builder.hpp" +#include "./precomputed_tables_builder.hpp" +#include "./transcript_builder.hpp" +#include "barretenberg/honk/flavor/ecc_vm.hpp" +#include "barretenberg/honk/proof_system/lookup_library.hpp" +#include "barretenberg/honk/proof_system/permutation_library.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" + +namespace proof_system { + +template class ECCVMCircuitConstructor { + public: + using CycleGroup = typename Flavor::CycleGroup; + using CycleScalar = typename CycleGroup::subgroup_field; + using FF = typename Flavor::FF; + using Element = typename CycleGroup::element; + using AffineElement = typename CycleGroup::affine_element; + static constexpr size_t NUM_SCALAR_BITS = proof_system_eccvm::NUM_SCALAR_BITS; + static constexpr size_t WNAF_SLICE_BITS = proof_system_eccvm::WNAF_SLICE_BITS; + static constexpr size_t NUM_WNAF_SLICES = proof_system_eccvm::NUM_WNAF_SLICES; + static constexpr uint64_t WNAF_MASK = proof_system_eccvm::WNAF_MASK; + static constexpr size_t POINT_TABLE_SIZE = proof_system_eccvm::POINT_TABLE_SIZE; + static constexpr size_t WNAF_SLICES_PER_ROW = proof_system_eccvm::WNAF_SLICES_PER_ROW; + static constexpr size_t ADDITIONS_PER_ROW = proof_system_eccvm::ADDITIONS_PER_ROW; + + static constexpr size_t NUM_POLYNOMIALS = Flavor::NUM_ALL_ENTITIES; + static constexpr size_t NUM_WIRES = Flavor::NUM_WIRES; + + using MSM = proof_system_eccvm::MSM; + using VMOperation = proof_system_eccvm::VMOperation; + std::vector vm_operations; + using ScalarMul = proof_system_eccvm::ScalarMul; + using RawPolynomials = typename Flavor::RawPolynomials; + using Polynomial = barretenberg::Polynomial; + uint32_t get_number_of_muls() + { + uint32_t num_muls = 0; + for (auto& op : vm_operations) { + if (op.mul) { + if (op.z1 != 0) { + num_muls++; + } + if (op.z2 != 0) { + num_muls++; + } + } + } + return num_muls; + } + + std::vector get_msms() + { + const uint32_t num_muls = get_number_of_muls(); + /** + * For input point [P], return { -15[P], -13[P], ..., -[P], [P], ..., 13[P], 15[P] } + */ + const auto compute_precomputed_table = [](const AffineElement& base_point) { + const auto d2 = Element(base_point).dbl(); + std::array table; + table[POINT_TABLE_SIZE / 2] = base_point; + for (size_t i = 1; i < POINT_TABLE_SIZE / 2; ++i) { + table[i + POINT_TABLE_SIZE / 2] = Element(table[i + POINT_TABLE_SIZE / 2 - 1]) + d2; + } + for (size_t i = 0; i < POINT_TABLE_SIZE / 2; ++i) { + table[i] = -table[POINT_TABLE_SIZE - 1 - i]; + } + return table; + }; + const auto compute_wnaf_slices = [](uint256_t scalar) { + std::array output; + int previous_slice = 0; + for (size_t i = 0; i < NUM_WNAF_SLICES; ++i) { + // slice the scalar into 4-bit chunks, starting with the least significant bits + uint64_t raw_slice = static_cast(scalar) & WNAF_MASK; + + bool is_even = ((raw_slice & 1ULL) == 0ULL); + + int wnaf_slice = static_cast(raw_slice); + + if (i == 0 && is_even) { + // if least significant slice is even, we add 1 to create an odd value && set 'skew' to true + wnaf_slice += 1; + } else if (is_even) { + // for other slices, if it's even, we add 1 to the slice value + // and subtract 16 from the previous slice to preserve the total scalar sum + static constexpr int borrow_constant = static_cast(1ULL << WNAF_SLICE_BITS); + previous_slice -= borrow_constant; + wnaf_slice += 1; + } + + if (i > 0) { + const size_t idx = i - 1; + output[NUM_WNAF_SLICES - idx - 1] = previous_slice; + } + previous_slice = wnaf_slice; + + // downshift raw_slice by 4 bits + scalar = scalar >> WNAF_SLICE_BITS; + } + + ASSERT(scalar == 0); + + output[0] = previous_slice; + + return output; + }; + std::vector msms; + std::vector active_msm; + + // We start pc at `num_muls` and decrement for each mul processed. + // This gives us two desired properties: + // 1: the value of pc at the 1st row = number of muls (easy to check) + // 2: the value of pc for the final mul = 1 + // The latter point is valuable as it means that we can add empty rows (where pc = 0) and still satisfy our + // sumcheck relations that involve pc (if we did the other way around, starting at 1 and ending at num_muls, + // we create a discontinuity in pc values between the last transcript row and the following empty row) + uint32_t pc = num_muls; + + const auto process_mul = [&active_msm, &pc, &compute_wnaf_slices, &compute_precomputed_table]( + const auto& scalar, const auto& base_point) { + if (scalar != 0) { + active_msm.push_back(ScalarMul{ + .pc = pc, + .scalar = scalar, + .base_point = base_point, + .wnaf_slices = compute_wnaf_slices(scalar), + .wnaf_skew = (scalar & 1) == 0, + .precomputed_table = compute_precomputed_table(base_point), + }); + pc--; + } + }; + + for (auto& op : vm_operations) { + if (op.mul) { + process_mul(op.z1, op.base_point); + process_mul(op.z2, AffineElement{ op.base_point.x * FF::cube_root_of_unity(), -op.base_point.y }); + + } else { + if (!active_msm.empty()) { + msms.push_back(active_msm); + active_msm = {}; + } + } + } + if (!active_msm.empty()) { + msms.push_back(active_msm); + } + + ASSERT(pc == 0); + return msms; + } + + static std::vector get_flattened_scalar_muls(const std::vector& msms) + { + std::vector result; + for (const auto& msm : msms) { + for (const auto& mul : msm) { + result.push_back(mul); + } + } + return result; + } + + void add_accumulate(const AffineElement& to_add) + { + vm_operations.emplace_back(VMOperation{ + .add = true, + .mul = false, + .eq = false, + .reset = false, + .base_point = to_add, + .z1 = 0, + .z2 = 0, + .mul_scalar_full = 0, + }); + } + + void mul_accumulate(const AffineElement& to_mul, const CycleScalar& scalar) + { + CycleScalar z1 = 0; + CycleScalar z2 = 0; + auto converted = scalar.from_montgomery_form(); + CycleScalar::split_into_endomorphism_scalars(converted, z1, z2); + z1 = z1.to_montgomery_form(); + z2 = z2.to_montgomery_form(); + vm_operations.emplace_back(VMOperation{ + .add = false, + .mul = true, + .eq = false, + .reset = false, + .base_point = to_mul, + .z1 = z1, + .z2 = z2, + .mul_scalar_full = scalar, + }); + } + void eq(const AffineElement& expected) + { + vm_operations.emplace_back(VMOperation{ + .add = false, + .mul = false, + .eq = true, + .reset = true, + .base_point = expected, + .z1 = 0, + .z2 = 0, + .mul_scalar_full = 0, + }); + } + + void empty_row() + { + vm_operations.emplace_back(VMOperation{ + .add = false, + .mul = false, + .eq = false, + .reset = false, + .base_point = CycleGroup::affine_point_at_infinity, + .z1 = 0, + .z2 = 0, + .mul_scalar_full = 0, + }); + } + + RawPolynomials compute_full_polynomials() + { + const auto msms = get_msms(); + const auto flattened_muls = get_flattened_scalar_muls(msms); + + std::array, 2> point_table_read_counts; + const auto transcript_state = + ECCVMTranscriptBuilder::compute_transcript_state(vm_operations, get_number_of_muls()); + const auto precompute_table_state = + ECCVMPrecomputedTablesBuilder::compute_precompute_state(flattened_muls); + const auto msm_state = + ECCVMMSMMBuilder::compute_msm_state(msms, point_table_read_counts, get_number_of_muls()); + + const size_t msm_size = msm_state.size(); + const size_t transcript_size = transcript_state.size(); + const size_t precompute_table_size = precompute_table_state.size(); + + const size_t num_rows = std::max(precompute_table_size, std::max(msm_size, transcript_size)); + + const size_t num_rows_log2 = static_cast(numeric::get_msb64(num_rows)); + size_t num_rows_pow2 = 1UL << (num_rows_log2 + (1UL << num_rows_log2 == num_rows ? 0 : 1)); + + RawPolynomials rows; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + rows[j] = Polynomial(num_rows_pow2); + } + + rows.lagrange_first[0] = 1; + rows.lagrange_second[1] = 1; + rows.lagrange_last[rows.lagrange_last.size() - 1] = 1; + + for (size_t i = 0; i < point_table_read_counts[0].size(); ++i) { + // TODO(@zac-williamson) explain off-by-one offset + // When computing the WNAF slice for a point at point counter value `pc` and a round index `round`, the row + // number that computes the slice can be derived. This row number is then mapped to the index of + // `lookup_read_counts`. We do this mapping in `ecc_msm_relation`. We are off-by-one because we add an empty + // row at the start of the WNAF columns that is not accounted for (index of lookup_read_counts maps to the + // row in our WNAF columns that computes a slice for a given value of pc and round) + rows.lookup_read_counts_0[i + 1] = point_table_read_counts[0][i]; + rows.lookup_read_counts_1[i + 1] = point_table_read_counts[1][i]; + } + for (size_t i = 0; i < transcript_state.size(); ++i) { + rows.transcript_accumulator_empty[i] = transcript_state[i].accumulator_empty; + rows.q_transcript_add[i] = transcript_state[i].q_add; + rows.q_transcript_mul[i] = transcript_state[i].q_mul; + rows.q_transcript_eq[i] = transcript_state[i].q_eq; + rows.transcript_q_reset_accumulator[i] = transcript_state[i].q_reset_accumulator; + rows.q_transcript_msm_transition[i] = transcript_state[i].q_msm_transition; + rows.transcript_pc[i] = transcript_state[i].pc; + rows.transcript_msm_count[i] = transcript_state[i].msm_count; + rows.transcript_x[i] = transcript_state[i].base_x; + rows.transcript_y[i] = transcript_state[i].base_y; + rows.transcript_z1[i] = transcript_state[i].z1; + rows.transcript_z2[i] = transcript_state[i].z2; + rows.transcript_z1zero[i] = transcript_state[i].z1_zero; + rows.transcript_z2zero[i] = transcript_state[i].z2_zero; + rows.transcript_op[i] = transcript_state[i].opcode; + rows.transcript_accumulator_x[i] = transcript_state[i].accumulator_x; + rows.transcript_accumulator_y[i] = transcript_state[i].accumulator_y; + rows.transcript_msm_x[i] = transcript_state[i].msm_output_x; + rows.transcript_msm_y[i] = transcript_state[i].msm_output_y; + } + + // TODO(@zac-williamson) if final opcode resets accumulator, all subsequent "is_accumulator_empty" row values + // must be 1. Ideally we find a way to tweak this so that empty rows that do nothing have column values that are + // all zero + if (transcript_state[transcript_state.size() - 1].accumulator_empty == 1) { + for (size_t i = transcript_state.size(); i < num_rows_pow2; ++i) { + rows.transcript_accumulator_empty[i] = 1; + } + } + for (size_t i = 0; i < precompute_table_state.size(); ++i) { + rows.q_wnaf[i] = (i != 0) ? 1 : 0; // todo document, derive etc etc // first row is empty! + rows.table_pc[i] = precompute_table_state[i].pc; + rows.table_point_transition[i] = static_cast(precompute_table_state[i].point_transition); + // rows.table_point_transition_shift = static_cast(table_state[i].point_transition); + rows.table_round[i] = precompute_table_state[i].round; + rows.table_scalar_sum[i] = precompute_table_state[i].scalar_sum; + + rows.table_s1[i] = precompute_table_state[i].s1; + rows.table_s2[i] = precompute_table_state[i].s2; + rows.table_s3[i] = precompute_table_state[i].s3; + rows.table_s4[i] = precompute_table_state[i].s4; + rows.table_s5[i] = precompute_table_state[i].s5; + rows.table_s6[i] = precompute_table_state[i].s6; + rows.table_s7[i] = precompute_table_state[i].s7; + rows.table_s8[i] = precompute_table_state[i].s8; + // todo explain why skew is 7 not 1 + rows.table_skew[i] = precompute_table_state[i].skew ? 7 : 0; + + rows.table_dx[i] = precompute_table_state[i].precompute_double.x; + rows.table_dy[i] = precompute_table_state[i].precompute_double.y; + rows.table_tx[i] = precompute_table_state[i].precompute_accumulator.x; + rows.table_ty[i] = precompute_table_state[i].precompute_accumulator.y; + } + + for (size_t i = 0; i < msm_state.size(); ++i) { + rows.q_msm_transition[i] = static_cast(msm_state[i].q_msm_transition); + rows.msm_q_add[i] = static_cast(msm_state[i].q_add); + rows.msm_q_double[i] = static_cast(msm_state[i].q_double); + rows.msm_q_skew[i] = static_cast(msm_state[i].q_skew); + rows.msm_accumulator_x[i] = msm_state[i].accumulator_x; + rows.msm_accumulator_y[i] = msm_state[i].accumulator_y; + rows.msm_pc[i] = msm_state[i].pc; + rows.msm_size_of_msm[i] = msm_state[i].msm_size; + rows.msm_count[i] = msm_state[i].msm_count; + rows.msm_round[i] = msm_state[i].msm_round; + rows.msm_q_add1[i] = static_cast(msm_state[i].add_state[0].add); + rows.msm_q_add2[i] = static_cast(msm_state[i].add_state[1].add); + rows.msm_q_add3[i] = static_cast(msm_state[i].add_state[2].add); + rows.msm_q_add4[i] = static_cast(msm_state[i].add_state[3].add); + rows.msm_x1[i] = msm_state[i].add_state[0].point.x; + rows.msm_y1[i] = msm_state[i].add_state[0].point.y; + rows.msm_x2[i] = msm_state[i].add_state[1].point.x; + rows.msm_y2[i] = msm_state[i].add_state[1].point.y; + rows.msm_x3[i] = msm_state[i].add_state[2].point.x; + rows.msm_y3[i] = msm_state[i].add_state[2].point.y; + rows.msm_x4[i] = msm_state[i].add_state[3].point.x; + rows.msm_y4[i] = msm_state[i].add_state[3].point.y; + rows.msm_collision_x1[i] = msm_state[i].add_state[0].collision_inverse; + rows.msm_collision_x2[i] = msm_state[i].add_state[1].collision_inverse; + rows.msm_collision_x3[i] = msm_state[i].add_state[2].collision_inverse; + rows.msm_collision_x4[i] = msm_state[i].add_state[3].collision_inverse; + rows.msm_lambda1[i] = msm_state[i].add_state[0].lambda; + rows.msm_lambda2[i] = msm_state[i].add_state[1].lambda; + rows.msm_lambda3[i] = msm_state[i].add_state[2].lambda; + rows.msm_lambda4[i] = msm_state[i].add_state[3].lambda; + rows.msm_slice1[i] = msm_state[i].add_state[0].slice; + rows.msm_slice2[i] = msm_state[i].add_state[1].slice; + rows.msm_slice3[i] = msm_state[i].add_state[2].slice; + rows.msm_slice4[i] = msm_state[i].add_state[3].slice; + } + + rows.q_transcript_mul_shift = typename Flavor::Polynomial(rows.q_transcript_mul.shifted()); + rows.q_transcript_accumulate_shift = typename Flavor::Polynomial(rows.q_transcript_accumulate.shifted()); + rows.transcript_msm_count_shift = typename Flavor::Polynomial(rows.transcript_msm_count.shifted()); + rows.transcript_accumulator_x_shift = typename Flavor::Polynomial(rows.transcript_accumulator_x.shifted()); + rows.transcript_accumulator_y_shift = typename Flavor::Polynomial(rows.transcript_accumulator_y.shifted()); + rows.table_scalar_sum_shift = typename Flavor::Polynomial(rows.table_scalar_sum.shifted()); + rows.table_dx_shift = typename Flavor::Polynomial(rows.table_dx.shifted()); + rows.table_dy_shift = typename Flavor::Polynomial(rows.table_dy.shifted()); + rows.table_tx_shift = typename Flavor::Polynomial(rows.table_tx.shifted()); + rows.table_ty_shift = typename Flavor::Polynomial(rows.table_ty.shifted()); + rows.q_msm_transition_shift = typename Flavor::Polynomial(rows.q_msm_transition.shifted()); + rows.msm_q_add_shift = typename Flavor::Polynomial(rows.msm_q_add.shifted()); + rows.msm_q_double_shift = typename Flavor::Polynomial(rows.msm_q_double.shifted()); + rows.msm_q_skew_shift = typename Flavor::Polynomial(rows.msm_q_skew.shifted()); + rows.msm_accumulator_x_shift = typename Flavor::Polynomial(rows.msm_accumulator_x.shifted()); + rows.msm_accumulator_y_shift = typename Flavor::Polynomial(rows.msm_accumulator_y.shifted()); + rows.msm_count_shift = typename Flavor::Polynomial(rows.msm_count.shifted()); + rows.msm_round_shift = typename Flavor::Polynomial(rows.msm_round.shifted()); + rows.msm_q_add1_shift = typename Flavor::Polynomial(rows.msm_q_add1.shifted()); + rows.msm_pc_shift = typename Flavor::Polynomial(rows.msm_pc.shifted()); + rows.table_pc_shift = typename Flavor::Polynomial(rows.table_pc.shifted()); + rows.transcript_pc_shift = typename Flavor::Polynomial(rows.transcript_pc.shifted()); + rows.table_round_shift = typename Flavor::Polynomial(rows.table_round.shifted()); + rows.transcript_accumulator_empty_shift = + typename Flavor::Polynomial(rows.transcript_accumulator_empty.shifted()); + rows.q_wnaf_shift = typename Flavor::Polynomial(rows.q_wnaf.shifted()); + return rows; + } + + bool check_circuit() + { + const FF gamma = FF::random_element(); + const FF eta = FF::random_element(); + const FF eta_sqr = eta.sqr(); + const FF eta_cube = eta_sqr * eta; + auto permutation_offset = + gamma * (gamma + eta_sqr) * (gamma + eta_sqr + eta_sqr) * (gamma + eta_sqr + eta_sqr + eta_sqr); + permutation_offset = permutation_offset.invert(); + proof_system::honk::sumcheck::RelationParameters params{ + .eta = eta, + .beta = 0, + .gamma = gamma, + .public_input_delta = 0, + .lookup_grand_product_delta = 0, + .eta_sqr = eta_sqr, + .eta_cube = eta_cube, + .permutation_offset = permutation_offset, + }; + + auto rows = compute_full_polynomials(); + const size_t num_rows = rows[0].size(); + proof_system::honk::lookup_library::compute_logderivative_inverse>( + rows, params, num_rows); + + honk::permutation_library::compute_permutation_grand_product>( + num_rows, rows, params); + + rows.z_perm_shift = typename Flavor::Polynomial(rows.z_perm.shifted()); + + const auto evaluate_relation = [&](const std::string& relation_name) { + auto relation = Relation(); + typename Relation::RelationValues result; + for (auto& r : result) { + r = 0; + } + constexpr size_t NUM_SUBRELATIONS = result.size(); + + for (size_t i = 0; i < num_rows; ++i) { + typename Flavor::RowPolynomials row; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + row[j] = rows[j][i]; + } + relation.add_full_relation_value_contribution(result, row, params, 1); + + bool x = true; + for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) { + if (result[j] != 0) { + info("Relation ", relation_name, ", subrelation index ", j, " failed at row ", i); + x = false; + } + } + if (!x) { + return false; + } + } + return true; + }; + + bool result = true; + result = result && evaluate_relation.template operator()>( + "ECCVMTranscriptRelation"); + result = result && evaluate_relation.template operator()>( + "ECCVMPointTableRelation"); + result = + result && evaluate_relation.template operator()>("ECCVMWnafRelation"); + result = + result && evaluate_relation.template operator()>("ECCVMMSMRelation"); + result = + result && evaluate_relation.template operator()>("ECCVMSetRelation"); + + auto lookup_relation = honk::sumcheck::ECCVMLookupRelation(); + typename honk::sumcheck::ECCVMLookupRelation::RelationValues lookup_result; + for (auto& r : lookup_result) { + r = 0; + } + for (size_t i = 0; i < num_rows; ++i) { + typename Flavor::RowPolynomials row; + for (size_t j = 0; j < NUM_POLYNOMIALS; ++j) { + row[j] = rows[j][i]; + } + { + lookup_relation.add_full_relation_value_contribution(lookup_result, row, params, 1); + } + } + for (auto r : lookup_result) { + if (r != 0) { + info("Relation ECCVMLookupRelation failed."); + return false; + } + } + return result; + } +}; +} // namespace proof_system \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.test.cpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.test.cpp new file mode 100644 index 0000000000..2f44e45af8 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/eccvm_circuit_builder.test.cpp @@ -0,0 +1,222 @@ +#include "barretenberg/crypto/generators/generator_data.hpp" +#include "barretenberg/crypto/pedersen_commitment/pedersen.hpp" +#include "eccvm_circuit_builder.hpp" +#include + +using namespace barretenberg; + +namespace { +auto& engine = numeric::random::get_debug_engine(); +} + +namespace eccvm_circuit_builder_tests { + +TEST(ECCVMCircuitConstructor, BaseCase) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::g1::element b = grumpkin::get_generator(1); + grumpkin::g1::element c = grumpkin::get_generator(2); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + grumpkin::fr y = grumpkin::fr::random_element(&engine); + + grumpkin::g1::element expected_1 = (a * x) + a + a + (b * y) + (b * x) + (b * x); + grumpkin::g1::element expected_2 = (a * x) + c + (b * x); + + circuit.add_accumulate(a); + circuit.mul_accumulate(a, x); + circuit.mul_accumulate(b, x); + circuit.mul_accumulate(b, y); + circuit.add_accumulate(a); + circuit.mul_accumulate(b, x); + circuit.eq(expected_1); + circuit.add_accumulate(c); + circuit.mul_accumulate(a, x); + circuit.mul_accumulate(b, x); + circuit.eq(expected_2); + circuit.mul_accumulate(a, x); + circuit.mul_accumulate(b, x); + circuit.mul_accumulate(c, x); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, Add) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + + circuit.add_accumulate(a); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, Mul) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + circuit.mul_accumulate(a, x); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, ShortMul) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + uint256_t small_x = 0; + // make sure scalar is less than 127 bits to fit in z1 + small_x.data[0] = engine.get_random_uint64(); + small_x.data[1] = engine.get_random_uint64() & 0xFFFFFFFFFFFFULL; + grumpkin::fr x = small_x; + + circuit.mul_accumulate(a, x); + circuit.eq(a * small_x); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EqFails) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + circuit.mul_accumulate(a, x); + circuit.eq(a); + bool result = circuit.check_circuit(); + EXPECT_EQ(result, false); +} + +TEST(ECCVMCircuitConstructor, EmptyRow) +{ + proof_system::ECCVMCircuitConstructor circuit; + + circuit.empty_row(); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EmptyRowBetweenOps) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + grumpkin::g1::element expected_1 = (a * x); + + circuit.mul_accumulate(a, x); + circuit.empty_row(); + circuit.eq(expected_1); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EndWithEq) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + grumpkin::g1::element expected_1 = (a * x); + + circuit.mul_accumulate(a, x); + circuit.eq(expected_1); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EndWithAdd) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + grumpkin::g1::element expected_1 = (a * x); + + circuit.mul_accumulate(a, x); + circuit.eq(expected_1); + circuit.add_accumulate(a); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EndWithMul) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + circuit.add_accumulate(a); + circuit.eq(a); + circuit.mul_accumulate(a, x); + + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, EndWithNoop) +{ + proof_system::ECCVMCircuitConstructor circuit; + + grumpkin::g1::element a = grumpkin::get_generator(0); + grumpkin::fr x = grumpkin::fr::random_element(&engine); + + circuit.add_accumulate(a); + circuit.eq(a); + circuit.mul_accumulate(a, x); + circuit.empty_row(); + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} + +TEST(ECCVMCircuitConstructor, MSM) +{ + const auto try_msms = [&](const size_t num_msms, auto& circuit) { + std::vector points; + std::vector scalars; + grumpkin::g1::element expected = grumpkin::g1::point_at_infinity; + for (size_t i = 0; i < num_msms; ++i) { + points.emplace_back(grumpkin::get_generator(i)); + scalars.emplace_back(grumpkin::fr::random_element(&engine)); + expected += (points[i] * scalars[i]); + circuit.mul_accumulate(points[i], scalars[i]); + } + circuit.eq(expected); + }; + + // single msms + for (size_t j = 1; j < 9; ++j) { + proof_system::ECCVMCircuitConstructor circuit; + try_msms(j, circuit); + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); + } + // chain msms + proof_system::ECCVMCircuitConstructor circuit; + for (size_t j = 1; j < 9; ++j) { + try_msms(j, circuit); + } + bool result = circuit.check_circuit(); + EXPECT_EQ(result, true); +} +} // namespace eccvm_circuit_builder_tests \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/msm_builder.hpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/msm_builder.hpp new file mode 100644 index 0000000000..c4af616d03 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/msm_builder.hpp @@ -0,0 +1,263 @@ +#pragma once + +#include + +#include "./eccvm_builder_types.hpp" + +namespace proof_system { + +template class ECCVMMSMMBuilder { + public: + using CycleGroup = typename Flavor::CycleGroup; + using FF = typename Flavor::FF; + using Element = typename CycleGroup::element; + using AffineElement = typename CycleGroup::affine_element; + + static constexpr size_t ADDITIONS_PER_ROW = proof_system_eccvm::ADDITIONS_PER_ROW; + static constexpr size_t NUM_SCALAR_BITS = proof_system_eccvm::NUM_SCALAR_BITS; + static constexpr size_t WNAF_SLICE_BITS = proof_system_eccvm::WNAF_SLICE_BITS; + + struct MSMState { + uint32_t pc = 0; + uint32_t msm_size = 0; + uint32_t msm_count = 0; + uint32_t msm_round = 0; + bool q_msm_transition = false; + bool q_add = false; + bool q_double = false; + bool q_skew = false; + + struct AddState { + bool add = false; + int slice = 0; + AffineElement point{ 0, 0 }; + FF lambda = 0; + FF collision_inverse = 0; + }; + std::array add_state{ AddState{ false, 0, { 0, 0 }, 0, 0 }, + AddState{ false, 0, { 0, 0 }, 0, 0 }, + AddState{ false, 0, { 0, 0 }, 0, 0 }, + AddState{ false, 0, { 0, 0 }, 0, 0 } }; + FF accumulator_x = 0; + FF accumulator_y = 0; + }; + + static std::vector compute_msm_state(const std::vector>& msms, + std::array, 2>& point_table_read_counts, + const uint32_t total_number_of_muls) + { + // when we define our point lookup table, we have 2 write columns and 4 read columns + // when we perform a read on a given row, we need to increment the read count on the respective write column by + // 1 we can define the following struture: 1st write column = positive 2nd write column = negative the row + // number is a function of pc and slice value row = pc_delta * rows_per_point_table + some function of the slice + // value pc_delta = total_number_of_muls - pc std::vector point_table_read_counts; + const size_t table_rows = static_cast(total_number_of_muls) * 8; + point_table_read_counts[0].reserve(table_rows); + point_table_read_counts[1].reserve(table_rows); + for (size_t i = 0; i < table_rows; ++i) { + point_table_read_counts[0].emplace_back(0); + point_table_read_counts[1].emplace_back(0); + } + const auto update_read_counts = [&](const size_t pc, const int slice) { + // When we compute our wnaf/point tables, we start with the point with the largest pc value. + // i.e. if we are reading a slice for point with a point counter value `pc`, + // its position in the wnaf/point table (relative to other points) will be `total_number_of_muls - pc` + const size_t pc_delta = total_number_of_muls - pc; + const size_t pc_offset = pc_delta * 8; + bool slice_negative = slice < 0; + const int slice_row = (slice + 15) / 2; + + const size_t column_index = slice_negative ? 1 : 0; + + if (slice_negative) { + point_table_read_counts[column_index][pc_offset + static_cast(slice_row)]++; + } else { + // 8 maps to 7 + // 15 maps to 0 + + // 15 - x + point_table_read_counts[column_index][pc_offset + 15 - static_cast(slice_row)]++; + } + // slice : row + // -15 : 0 + // -13 : 1 + // -11 : 2 + // -9 : 3 + // -7 : 4 + // -5 : 5 + // -3 : 6 + // -1 : 7 + // 1 : 8 + // 3 : 9 + // 5 : 10 + // 7 : 11 + // 9 : 12 + // 11 : 13 + // 13 : 14 + // 15 : 15 + }; + std::vector msm_state; + // start with empty row (shiftable polynomials must have 0 as first coefficient) + msm_state.emplace_back(MSMState{}); + uint32_t pc = total_number_of_muls; + AffineElement accumulator = CycleGroup::affine_point_at_infinity; + + for (const auto& msm : msms) { + const size_t msm_size = msm.size(); + + const size_t rows_per_round = (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); + static constexpr size_t num_rounds = NUM_SCALAR_BITS / WNAF_SLICE_BITS; + + const auto add_points = [](auto& P1, auto& P2, auto& lambda, auto& collision_inverse, bool predicate) { + lambda = predicate ? (P2.y - P1.y) / (P2.x - P1.x) : 0; + collision_inverse = predicate ? (P2.x - P1.x).invert() : 0; + auto x3 = predicate ? lambda * lambda - (P2.x + P1.x) : P1.x; + auto y3 = predicate ? lambda * (P1.x - x3) - P1.y : P1.y; + return AffineElement(x3, y3); + }; + for (size_t j = 0; j < num_rounds; ++j) { + for (size_t k = 0; k < rows_per_round; ++k) { + MSMState row; + const size_t points_per_row = + (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; + const size_t idx = k * ADDITIONS_PER_ROW; + row.q_msm_transition = (j == 0) && (k == 0); + + AffineElement acc(accumulator); + Element acc_expected = accumulator; + for (size_t m = 0; m < 4; ++m) { + auto& add_state = row.add_state[m]; + add_state.add = points_per_row > m; + int slice = add_state.add ? msm[idx + m].wnaf_slices[j] : 0; + add_state.slice = add_state.add ? (slice + 15) / 2 : 0; + add_state.point = add_state.add + ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] + : AffineElement{ 0, 0 }; + bool add_predicate = (m == 0 ? (j != 0 || k != 0) : add_state.add); + + auto& p1 = (m == 0) ? add_state.point : acc; + auto& p2 = (m == 0) ? acc : add_state.point; + + acc_expected = add_predicate ? (acc_expected + add_state.point) : Element(p1); + if (add_state.add) { + update_read_counts(pc - idx - m, slice); + } + acc = add_points(p1, p2, add_state.lambda, add_state.collision_inverse, add_predicate); + ASSERT(acc == AffineElement(acc_expected)); + } + row.q_add = true; + row.q_double = false; + row.q_skew = false; + row.msm_round = static_cast(j); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(idx); + row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x; + row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y; + row.pc = pc; + accumulator = acc; + msm_state.push_back(row); + } + if (j < num_rounds - 1) { + MSMState row; + row.q_msm_transition = false; + row.msm_round = static_cast(j + 1); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(0); + row.q_add = false; + row.q_double = true; + row.q_skew = false; + + auto dx = accumulator.x; + auto dy = accumulator.y; + for (size_t m = 0; m < 4; ++m) { + auto& add_state = row.add_state[m]; + add_state.add = false; + add_state.slice = 0; + add_state.point = { 0, 0 }; + add_state.collision_inverse = 0; + add_state.lambda = ((dx + dx + dx) * dx) / (dy + dy); + auto x3 = add_state.lambda.sqr() - dx - dx; + dy = add_state.lambda * (dx - x3) - dy; + dx = x3; + } + + row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x; + row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y; + accumulator = Element(accumulator).dbl().dbl().dbl().dbl(); + row.pc = pc; + msm_state.push_back(row); + } else { + for (size_t k = 0; k < rows_per_round; ++k) { + MSMState row; + + const size_t points_per_row = + (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; + const size_t idx = k * ADDITIONS_PER_ROW; + row.q_msm_transition = false; + + AffineElement acc(accumulator); + Element acc_expected = accumulator; + + for (size_t m = 0; m < 4; ++m) { + auto& add_state = row.add_state[m]; + add_state.add = points_per_row > m; + add_state.slice = add_state.add ? msm[idx + m].wnaf_skew ? 7 : 0 : 0; + + add_state.point = add_state.add + ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] + : AffineElement{ 0, 0 }; + bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; + if (add_state.add) { + update_read_counts(pc - idx - m, msm[idx + m].wnaf_skew ? -1 : -15); + } + acc = add_points( + acc, add_state.point, add_state.lambda, add_state.collision_inverse, add_predicate); + acc_expected = add_predicate ? (acc_expected + add_state.point) : acc_expected; + ASSERT(acc == AffineElement(acc_expected)); + } + row.q_add = false; + row.q_double = false; + row.q_skew = true; + row.msm_round = static_cast(j + 1); + row.msm_size = static_cast(msm_size); + row.msm_count = static_cast(idx); + + row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x; + row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y; + + row.pc = pc; + accumulator = acc; + msm_state.emplace_back(row); + } + } + } + pc -= static_cast(msm_size); + // Validate our computed accumulator matches the real MSM result! + Element expected = CycleGroup::point_at_infinity; + for (size_t i = 0; i < msm.size(); ++i) { + expected += (Element(msm[i].base_point) * msm[i].scalar); + } + // Validate the accumulator is correct! + ASSERT(accumulator == AffineElement(expected)); + } + + MSMState final_row; + final_row.pc = pc; + final_row.q_msm_transition = true; + final_row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x; + final_row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y; + final_row.msm_size = 0; + final_row.msm_count = 0; + final_row.q_add = false; + final_row.q_double = false; + final_row.q_skew = false; + final_row.add_state = { typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } }; + + msm_state.emplace_back(final_row); + return msm_state; + } +}; +} // namespace proof_system \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/precomputed_tables_builder.hpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/precomputed_tables_builder.hpp new file mode 100644 index 0000000000..27c9cf48df --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/precomputed_tables_builder.hpp @@ -0,0 +1,112 @@ +#pragma once + +#include "./eccvm_builder_types.hpp" + +namespace proof_system { + +template class ECCVMPrecomputedTablesBuilder { + public: + using CycleGroup = typename Flavor::CycleGroup; + using FF = typename Flavor::FF; + using Element = typename CycleGroup::element; + using AffineElement = typename CycleGroup::affine_element; + + static constexpr size_t NUM_WNAF_SLICES = proof_system_eccvm::NUM_WNAF_SLICES; + static constexpr size_t WNAF_SLICES_PER_ROW = proof_system_eccvm::WNAF_SLICES_PER_ROW; + static constexpr size_t WNAF_SLICE_BITS = proof_system_eccvm::WNAF_SLICE_BITS; + + struct PrecomputeState { + int s1 = 0; + int s2 = 0; + int s3 = 0; + int s4 = 0; + int s5 = 0; + int s6 = 0; + int s7 = 0; + int s8 = 0; + bool skew = false; + bool point_transition = false; + uint32_t pc = 0; + uint32_t round = 0; + uint256_t scalar_sum = 0; + AffineElement precompute_accumulator{ 0, 0 }; + AffineElement precompute_double{ 0, 0 }; + }; + + static std::vector compute_precompute_state( + const std::vector>& ecc_muls) + { + std::vector precompute_state; + + // start with empty row (shiftable polynomials must have 0 as first coefficient) + precompute_state.push_back(PrecomputeState{}); + static constexpr size_t num_rows_per_scalar = NUM_WNAF_SLICES / WNAF_SLICES_PER_ROW; + + // current impl doesn't work if not 4 + static_assert(WNAF_SLICES_PER_ROW == 4); + + for (const auto& entry : ecc_muls) { + const auto& slices = entry.wnaf_slices; + uint256_t scalar_sum = 0; + + const Element point = entry.base_point; + const Element d2 = point.dbl(); + + for (size_t i = 0; i < num_rows_per_scalar; ++i) { + PrecomputeState row; + const int slice0 = slices[i * WNAF_SLICES_PER_ROW]; + const int slice1 = slices[i * WNAF_SLICES_PER_ROW + 1]; + const int slice2 = slices[i * WNAF_SLICES_PER_ROW + 2]; + const int slice3 = slices[i * WNAF_SLICES_PER_ROW + 3]; + + const int slice0base2 = (slice0 + 15) / 2; + const int slice1base2 = (slice1 + 15) / 2; + const int slice2base2 = (slice2 + 15) / 2; + const int slice3base2 = (slice3 + 15) / 2; + + // convert into 2-bit chunks + row.s1 = slice0base2 >> 2; + row.s2 = slice0base2 & 3; + row.s3 = slice1base2 >> 2; + row.s4 = slice1base2 & 3; + row.s5 = slice2base2 >> 2; + row.s6 = slice2base2 & 3; + row.s7 = slice3base2 >> 2; + row.s8 = slice3base2 & 3; + bool last_row = (i == num_rows_per_scalar - 1); + + row.skew = last_row ? entry.wnaf_skew : false; + + row.scalar_sum = scalar_sum; + + // TODO(@zac-williamson). If 1st row do we apply constraint that requires slice0 to be positive? + // Need this if we want to rule out negative values (i.e. input has not yet been range + // constrained) + const int row_chunk = slice3 + slice2 * (1 << 4) + slice1 * (1 << 8) + slice0 * (1 << 12); + + bool chunk_negative = row_chunk < 0; + + scalar_sum = scalar_sum << (WNAF_SLICE_BITS * WNAF_SLICES_PER_ROW); + if (chunk_negative) { + scalar_sum -= static_cast(-row_chunk); + } else { + scalar_sum += static_cast(row_chunk); + } + row.round = static_cast(i); + row.point_transition = last_row; + row.pc = entry.pc; + + if (last_row) { + ASSERT(scalar_sum - entry.wnaf_skew == entry.scalar); + } + + row.precompute_double = d2; + // fill accumulator in reverse order i.e. first row = 15[P], then 13[P], ..., 1[P] + row.precompute_accumulator = entry.precomputed_table[proof_system_eccvm::POINT_TABLE_SIZE - 1 - i]; + precompute_state.emplace_back(row); + } + } + return precompute_state; + } +}; +} // namespace proof_system \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/transcript_builder.hpp b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/transcript_builder.hpp new file mode 100644 index 0000000000..74a2b6ab8d --- /dev/null +++ b/cpp/src/barretenberg/proof_system/circuit_builder/eccvm/transcript_builder.hpp @@ -0,0 +1,175 @@ +#pragma once + +#include "./eccvm_builder_types.hpp" + +namespace proof_system { + +template class ECCVMTranscriptBuilder { + public: + using CycleGroup = typename Flavor::CycleGroup; + using FF = typename Flavor::FF; + using Element = typename CycleGroup::element; + using AffineElement = typename CycleGroup::affine_element; + + struct TranscriptState { + bool accumulator_empty = false; + bool q_add = false; + bool q_mul = false; + bool q_eq = false; + bool q_reset_accumulator = false; + bool q_msm_transition = false; + uint32_t pc = 0; + uint32_t msm_count = 0; + FF base_x = 0; + FF base_y = 0; + uint256_t z1 = 0; + uint256_t z2 = 0; + bool z1_zero = false; + bool z2_zero = false; + uint32_t opcode = 0; + FF accumulator_x = 0; + FF accumulator_y = 0; + FF msm_output_x = 0; + FF msm_output_y = 0; + }; + struct VMState { + uint32_t pc = 0; + uint32_t count = 0; + AffineElement accumulator = CycleGroup::affine_point_at_infinity; + AffineElement msm_accumulator = CycleGroup::affine_point_at_infinity; + bool is_accumulator_empty = true; + }; + struct Opcode { + bool add; + bool mul; + bool eq; + bool reset; + [[nodiscard]] uint32_t value() const + { + auto res = static_cast(add); + res += res; + res += static_cast(mul); + res += res; + res += static_cast(eq); + res += res; + res += static_cast(reset); + return res; + } + }; + static std::vector compute_transcript_state( + const std::vector>& vm_operations, + const uint32_t total_number_of_muls) + { + std::vector transcript_state; + VMState state{ + .pc = total_number_of_muls, + .count = 0, + .accumulator = CycleGroup::affine_point_at_infinity, + .msm_accumulator = CycleGroup::affine_point_at_infinity, + .is_accumulator_empty = true, + }; + VMState updated_state; + + // add an empty row. 1st row all zeroes because of our shiftable polynomials + transcript_state.emplace_back(TranscriptState{}); + for (size_t i = 0; i < vm_operations.size(); ++i) { + TranscriptState row; + const proof_system_eccvm::VMOperation& entry = vm_operations[i]; + + const bool is_mul = entry.mul; + const bool z1_zero = (entry.mul) ? entry.z1 == 0 : true; + const bool z2_zero = (entry.mul) ? entry.z2 == 0 : true; + const uint32_t num_muls = is_mul ? (static_cast(!z1_zero) + static_cast(!z2_zero)) : 0; + + updated_state = state; + + if (entry.reset) { + updated_state.is_accumulator_empty = true; + updated_state.msm_accumulator = CycleGroup::affine_point_at_infinity; + } + updated_state.pc = state.pc - num_muls; + + bool last_row = i == (vm_operations.size() - 1); + // msm transition = current row is doing a lookup to validate output = msm output + // i.e. next row is not part of MSM and current row is part of MSM + // or next row is irrelevent and current row is a straight MUL + bool next_not_msm = last_row ? true : !vm_operations[i + 1].mul; + + bool msm_transition = entry.mul && next_not_msm; + + // we reset the count in updated state if we are not accumulating and not doing an msm + bool current_msm = entry.mul; + bool current_ongoing_msm = entry.mul && !next_not_msm; + updated_state.count = current_ongoing_msm ? state.count + num_muls : 0; + + if (current_msm) { + const auto P = grumpkin::g1::element(entry.base_point); + const auto R = grumpkin::g1::element(state.msm_accumulator); + updated_state.msm_accumulator = R + P * entry.mul_scalar_full; + } + + if (entry.mul && next_not_msm) { + if (state.is_accumulator_empty) { + updated_state.accumulator = updated_state.msm_accumulator; + } else { + const auto R = grumpkin::g1::element(state.accumulator); + updated_state.accumulator = R + updated_state.msm_accumulator; + } + updated_state.is_accumulator_empty = false; + } + + bool add_accumulate = entry.add; + if (add_accumulate) { + if (state.is_accumulator_empty) { + + updated_state.accumulator = entry.base_point; + } else { + updated_state.accumulator = grumpkin::g1::element(state.accumulator) + entry.base_point; + } + updated_state.is_accumulator_empty = false; + } + row.accumulator_empty = state.is_accumulator_empty; + row.q_add = entry.add; + row.q_mul = entry.mul; + row.q_eq = entry.eq; + row.q_reset_accumulator = entry.reset; + row.q_msm_transition = msm_transition; + row.pc = state.pc; + row.msm_count = state.count; + row.base_x = (entry.add || entry.mul || entry.eq) ? entry.base_point.x : 0; + row.base_y = (entry.add || entry.mul || entry.eq) ? entry.base_point.y : 0; + row.z1 = (entry.mul) ? entry.z1 : 0; + row.z2 = (entry.mul) ? entry.z2 : 0; + row.z1_zero = z1_zero; + row.z2_zero = z2_zero; + row.opcode = Opcode{ .add = entry.add, .mul = entry.mul, .eq = entry.eq, .reset = entry.reset }.value(); + row.accumulator_x = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.x; + row.accumulator_y = (state.accumulator.is_point_at_infinity()) ? 0 : state.accumulator.y; + row.msm_output_x = + msm_transition + ? (updated_state.msm_accumulator.is_point_at_infinity() ? 0 : updated_state.msm_accumulator.x) + : 0; + row.msm_output_y = + msm_transition + ? (updated_state.msm_accumulator.is_point_at_infinity() ? 0 : updated_state.msm_accumulator.y) + : 0; + + state = updated_state; + + if (entry.mul && next_not_msm) { + state.msm_accumulator = CycleGroup::affine_point_at_infinity; + } + transcript_state.emplace_back(row); + } + + TranscriptState final_row; + final_row.pc = updated_state.pc; + final_row.accumulator_x = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.x; + final_row.accumulator_y = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.y; + final_row.accumulator_empty = updated_state.is_accumulator_empty; + + transcript_state.push_back(final_row); + return transcript_state; + } +}; +} // namespace proof_system \ No newline at end of file diff --git a/cpp/src/barretenberg/proof_system/flavor/flavor.hpp b/cpp/src/barretenberg/proof_system/flavor/flavor.hpp index 99bbe7938c..3f4254c6d6 100644 --- a/cpp/src/barretenberg/proof_system/flavor/flavor.hpp +++ b/cpp/src/barretenberg/proof_system/flavor/flavor.hpp @@ -275,6 +275,8 @@ class Standard; class StandardGrumpkin; class Ultra; class UltraGrumpkin; +class ECCVM; +class ECCVMGrumpkin; class GoblinUltra; } // namespace proof_system::honk::flavor @@ -313,5 +315,7 @@ template concept StandardFlavor = IsAnyOf concept UltraFlavor = IsAnyOf; +template concept ECCVMFlavor = IsAnyOf; + // clang-format on } // namespace proof_system diff --git a/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.cpp b/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.cpp new file mode 100644 index 0000000000..33cd8eac38 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.cpp @@ -0,0 +1,89 @@ +#include "barretenberg/honk/flavor/ecc_vm.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_definitions_fwd.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" +#include "ecc_msm_relation.hpp" + +namespace proof_system::honk::sumcheck { + +/** + * @brief Expression for the StandardArithmetic gate. + * @details The relation is defined as C(extended_edges(X)...) = + * (q_m * w_r * w_l) + (q_l * w_l) + (q_r * w_r) + (q_o * w_o) + q_c + * + * @param evals transformed to `evals + C(extended_edges(X)...)*scaling_factor` + * @param extended_edges an std::array containing the fully extended Accumulator edges. + * @param parameters contains beta, gamma, and public_input_delta, .... + * @param scaling_factor optional term to scale the evaluation before adding to evals. + */ +template +template +void ECCVMLookupRelationBase::add_edge_contribution_impl(typename AccumulatorTypes::Accumulators& accumulator, + const auto& extended_edges, + const RelationParameters& relation_params, + const FF& /*unused*/) const +{ + using View = typename std::tuple_element<0, typename AccumulatorTypes::AccumulatorViews>::type; + using Accumulator = typename std::tuple_element<0, typename AccumulatorTypes::Accumulators>::type; + + auto lookup_inverses = View(extended_edges.lookup_inverses); + + constexpr size_t NUM_TOTAL_TERMS = READ_TERMS + WRITE_TERMS; + std::array lookup_terms; + std::array denominator_accumulator; + + // The lookup relation = \sum_j (1 / read_term[j]) - \sum_k (read_counts[k] / write_term[k]) + // To get the inverses (1 / read_term[i]), (1 / write_term[i]), we have a commitment to the product of all inverses + // i.e. lookup_inverse = \prod_j (1 / read_term[j]) * \prod_k (1 / write_term[k]) + // The purpose of this next section is to derive individual inverse terms using `lookup_inverses` + // i.e. (1 / read_term[i]) = lookup_inverse * \prod_{j /ne i} (read_term[j]) * \prod_k (write_term[k]) + // (1 / write_term[i]) = lookup_inverse * \prod_j (read_term[j]) * \prod_{k ne i} (write_term[k]) + barretenberg::constexpr_for<0, READ_TERMS, 1>([&]() { + lookup_terms[i] = compute_read_term(extended_edges, relation_params, 0); + }); + barretenberg::constexpr_for<0, WRITE_TERMS, 1>([&]() { + lookup_terms[i + READ_TERMS] = compute_write_term(extended_edges, relation_params, 0); + }); + + barretenberg::constexpr_for<0, NUM_TOTAL_TERMS, 1>( + [&]() { denominator_accumulator[i] = lookup_terms[i]; }); + + barretenberg::constexpr_for<0, NUM_TOTAL_TERMS - 1, 1>( + [&]() { denominator_accumulator[i + 1] *= denominator_accumulator[i]; }); + + Accumulator inverse_accumulator = Accumulator(lookup_inverses); // denominator_accumulator[NUM_TOTAL_TERMS - 1]; + + const auto row_has_write = View(extended_edges.q_wnaf); + const auto row_has_read = View(extended_edges.msm_q_add) + View(extended_edges.msm_q_skew); + const auto inverse_exists = row_has_write + row_has_read - (row_has_write * row_has_read); + + std::get<1>(accumulator) += denominator_accumulator[NUM_TOTAL_TERMS - 1] * lookup_inverses - inverse_exists; + + // After this algo, total degree of denominator_accumulator = NUM_TOTAL_TERMA + for (size_t i = 0; i < NUM_TOTAL_TERMS - 1; ++i) { + denominator_accumulator[NUM_TOTAL_TERMS - 1 - i] = + denominator_accumulator[NUM_TOTAL_TERMS - 2 - i] * inverse_accumulator; + inverse_accumulator = inverse_accumulator * lookup_terms[NUM_TOTAL_TERMS - 1 - i]; + } + denominator_accumulator[0] = inverse_accumulator; + + // each predicate is degree-1 + // degree of relation at this point = NUM_TOTAL_TERMS + 1 + barretenberg::constexpr_for<0, READ_TERMS, 1>([&]() { + std::get<0>(accumulator) += + compute_read_term_predicate(extended_edges, relation_params, 0) * + denominator_accumulator[i]; + }); + + // each predicate is degree-1, `lookup_read_counts` is degree-1 + // degree of relation = NUM_TOTAL_TERMS + 2 = 6 + 2 + barretenberg::constexpr_for<0, WRITE_TERMS, 1>([&]() { + const auto p = compute_write_term_predicate(extended_edges, relation_params, 0); + const auto lookup_read_count = View(extended_edges.template lookup_read_counts()); + std::get<0>(accumulator) -= p * (denominator_accumulator[i + READ_TERMS] * lookup_read_count); + }); +} +template class ECCVMLookupRelationBase; +DEFINE_SUMCHECK_RELATION_CLASS(ECCVMLookupRelationBase, flavor::ECCVM); +DEFINE_SUMCHECK_RELATION_CLASS(ECCVMLookupRelationBase, flavor::ECCVMGrumpkin); + +} // namespace proof_system::honk::sumcheck diff --git a/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.hpp b/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.hpp new file mode 100644 index 0000000000..c3dd7b85a2 --- /dev/null +++ b/cpp/src/barretenberg/proof_system/relations/ecc_vm/ecc_lookup_relation.hpp @@ -0,0 +1,259 @@ +#pragma once +#include +#include + +#include "barretenberg/common/constexpr_utils.hpp" +#include "barretenberg/honk/sumcheck/polynomials/univariate.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_parameters.hpp" +#include "barretenberg/honk/sumcheck/relations/relation_types.hpp" +#include "barretenberg/polynomials/polynomial.hpp" + +namespace proof_system::honk::sumcheck { + +template class ECCVMLookupRelationBase { + public: + static constexpr size_t READ_TERMS = 4; + static constexpr size_t WRITE_TERMS = 2; + // 1 + polynomial degree of this relation + static constexpr size_t RELATION_LENGTH = READ_TERMS + WRITE_TERMS + 3; // 9 + + static constexpr size_t LEN_1 = RELATION_LENGTH; // grand product construction sub-relation + static constexpr size_t LEN_2 = RELATION_LENGTH; // left-shiftable polynomial sub-relation + template