Skip to content

Commit

Permalink
Merge pull request #850 from 0xPolygonHermez/feature/optFflonkMemory
Browse files Browse the repository at this point in the history
Adding Fflonk precomputed buffer into previously allocated memory
  • Loading branch information
RogerTaule authored Jun 21, 2024
2 parents b927edf + cd7ba95 commit 35f2e0c
Show file tree
Hide file tree
Showing 9 changed files with 298 additions and 277 deletions.
145 changes: 42 additions & 103 deletions src/prover/prover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "binfile_utils.hpp"
#include "zkey_utils.hpp"
#include "wtns_utils.hpp"
#include "groth16.hpp"
#include "sm/storage/storage_executor.hpp"
#include "timer.hpp"
#include "execFile.hpp"
Expand Down Expand Up @@ -103,7 +102,7 @@ Prover::Prover(Goldilocks &fr,
pthread_create(&cleanerPthread, NULL, cleanerThread, this);

bool reduceMemoryZkevm = REDUCE_ZKEVM_MEMORY ? true : false;

StarkInfo _starkInfo(config.zkevmStarkInfo, reduceMemoryZkevm);

// Allocate an area of memory, mapped to file, to store all the committed polynomials,
Expand Down Expand Up @@ -134,53 +133,14 @@ Prover::Prover(Goldilocks &fr,
alloc_pinned_mem(uint64_t(1<<24) * _starkInfo.mapSectionsN.section[eSection::cm1_n]);
warmup_gpu();
#endif
TimerStopAndLog(PROVER_INIT);
TimerStart(PROVER_INIT_FFLONK);

zkey = BinFileUtils::openExisting(config.finalStarkZkey, "zkey", 1);
protocolId = Zkey::getProtocolIdFromZkey(zkey.get());
if (Zkey::GROTH16_PROTOCOL_ID == protocolId)
{
zkeyHeader = ZKeyUtils::loadHeader(zkey.get());

if (mpz_cmp(zkeyHeader->rPrime, altBbn128r) != 0)
{
throw std::invalid_argument("zkey curve not supported");
}

groth16Prover = Groth16::makeProver<AltBn128::Engine>(
zkeyHeader->nVars,
zkeyHeader->nPublic,
zkeyHeader->domainSize,
zkeyHeader->nCoefs,
zkeyHeader->vk_alpha1,
zkeyHeader->vk_beta1,
zkeyHeader->vk_beta2,
zkeyHeader->vk_delta1,
zkeyHeader->vk_delta2,
zkey->getSectionData(4), // Coefs
zkey->getSectionData(5), // pointsA
zkey->getSectionData(6), // pointsB1
zkey->getSectionData(7), // pointsB2
zkey->getSectionData(8), // pointsC
zkey->getSectionData(9) // pointsH1
);
} else {
prover = new Fflonk::FflonkProver<AltBn128::Engine>(AltBn128::Engine::engine, pAddress, polsSize);
prover->setZkey(zkey.get());
}

BinFileUtils::BinFile *pZkey = zkey.release();
assert(zkey.get() == nullptr);
assert(zkey == nullptr);
delete pZkey;

TimerStopAndLog(PROVER_INIT_FFLONK);


json finalVerkeyJson;
file2json(config.finalVerkey, finalVerkeyJson);
domainSizeFflonk = 1 << uint64_t(finalVerkeyJson["power"]);
nPublicsFflonk = finalVerkeyJson["nPublic"];

TimerStopAndLog(PROVER_INIT);
TimerStart(PROVER_INIT_STARKINFO);
StarkInfo _starkInfoRecursiveF(config.recursivefStarkInfo);
pAddressStarksRecursiveF = (void *)malloc(_starkInfoRecursiveF.mapTotalN * sizeof(Goldilocks::Element));

string zkevmChelpers = USE_GENERIC_PARSER ? config.zkevmGenericCHelpers : config.zkevmCHelpers;
string c12aChelpers = USE_GENERIC_PARSER ? config.c12aGenericCHelpers : config.c12aCHelpers;
Expand All @@ -200,7 +160,7 @@ Prover::Prover(Goldilocks &fr,
starksRecursive2 = new Starks(config, {config.recursive2ConstPols, config.mapConstPolsFile, config.recursive2ConstantsTree, config.recursive2StarkInfo, recursive2Chelpers}, false, pAddress);
TimerStopAndLog(PROVER_INIT_STARK_RECURSIVE2);
TimerStart(PROVER_INIT_STARK_RECURSIVEF);
starksRecursiveF = new StarkRecursiveF(config, pAddressStarksRecursiveF);
starksRecursiveF = new StarkRecursiveF(config, pAddress);
TimerStopAndLog(PROVER_INIT_STARK_RECURSIVEF);
}
}
Expand All @@ -217,17 +177,6 @@ Prover::~Prover()

if (config.generateProof())
{
Groth16::Prover<AltBn128::Engine> *pGroth16 = groth16Prover.release();
ZKeyUtils::Header *pZkeyHeader = zkeyHeader.release();

assert(groth16Prover.get() == nullptr);
assert(groth16Prover == nullptr);
assert(zkeyHeader.get() == nullptr);
assert(zkeyHeader == nullptr);

delete pGroth16;
delete pZkeyHeader;

// Unmap committed polynomials address
if (config.zkevmCmPols.size() > 0)
{
Expand All @@ -237,13 +186,10 @@ Prover::~Prover()
{
free_zkevm(pAddress);
}
free(pAddressStarksRecursiveF);
#ifdef __USE_CUDA__
free_pinned_mem();
#endif

delete prover;

delete starkZkevm;
delete starksC12a;
delete starksRecursive1;
Expand Down Expand Up @@ -974,7 +920,7 @@ void Prover::genFinalProof(ProverRequest *pProverRequest)
publics[i] = Goldilocks::fromString(zkinFinal["publics"][i]);
}

CommitPolsStarks cmPolsRecursiveF((uint8_t *)pAddressStarksRecursiveF + starksRecursiveF->starkInfo.mapOffsets.section[cm1_n] * sizeof(Goldilocks::Element), (1 << starksRecursiveF->starkInfo.starkStruct.nBits), starksRecursiveF->starkInfo.nCm1);
CommitPolsStarks cmPolsRecursiveF((uint8_t *)pAddress + starksRecursiveF->starkInfo.mapOffsets.section[cm1_n] * sizeof(Goldilocks::Element), (1 << starksRecursiveF->starkInfo.starkStruct.nBits), starksRecursiveF->starkInfo.nCm1);
#if (PROVER_FORK_ID == 10)
CircomRecursiveFFork10::getCommitedPols(&cmPolsRecursiveF, config.recursivefVerifier, config.recursivefExec, zkinFinal, (1 << starksRecursiveF->starkInfo.starkStruct.nBits), starksRecursiveF->starkInfo.nCm1);
#else
Expand Down Expand Up @@ -1087,65 +1033,58 @@ void Prover::genFinalProof(ProverRequest *pProverRequest)
json2file(publicJson, pProverRequest->publicsOutputFile());
TimerStopAndLog(SAVE_PUBLICS_JSON);

if (Zkey::GROTH16_PROTOCOL_ID != protocolId)
{
TimerStart(RAPID_SNARK);
try
{
auto [jsonProof, publicSignalsJson] = prover->prove(pWitnessFinal);
// Save proof to file
if (config.saveProofToFile)
{
json2file(jsonProof, pProverRequest->filePrefix + "final_proof.proof.json");
}
TimerStopAndLog(RAPID_SNARK);
TimerStart(PROVER_INIT_FFLONK);

// Populate Proof with the correct data
PublicInputsExtended publicInputsExtended;
publicInputsExtended.publicInputs = pProverRequest->input.publicInputsExtended.publicInputs;
pProverRequest->proof.load(jsonProof, publicSignalsJson);
prover = new Fflonk::FflonkProver<AltBn128::Engine>(AltBn128::Engine::engine, pAddress, polsSize, true);

pProverRequest->result = ZKR_SUCCESS;
}
catch (std::exception &e)
{
zklog.error("Prover::genProof() got exception in rapid SNARK:" + string(e.what()));
exitProcess();
}
uint64_t lengthPrecomputedBuffer = prover->getLengthPrecomputedBuffer(domainSizeFflonk, nPublicsFflonk);
FrElement* binPointer = (FrElement *)pAddress + lengthPrecomputedBuffer;

zkey = BinFileUtils::openExisting(config.finalStarkZkey, "zkey", 1, binPointer, polsSize - lengthPrecomputedBuffer);
protocolId = Zkey::getProtocolIdFromZkey(zkey.get());
if(protocolId != Zkey::FFLONK_PROTOCOL_ID) {
zklog.error("Prover::genBatchProof() zkey protocolId has to be Fflonk");
exitProcess();
}
else
{
// Generate Groth16 via rapid SNARK
TimerStart(RAPID_SNARK);
json jsonProof;
try
{
auto proof = groth16Prover->prove(pWitnessFinal);
jsonProof = proof->toJson();
}
catch (std::exception &e)
{
zklog.error("Prover::genProof() got exception in rapid SNARK:" + string(e.what()));
exitProcess();
}
TimerStopAndLog(RAPID_SNARK);

prover->setZkey(zkey.get());

BinFileUtils::BinFile *pZkey = zkey.release();
assert(zkey.get() == nullptr);
assert(zkey == nullptr);
delete pZkey;

TimerStopAndLog(PROVER_INIT_FFLONK);

TimerStart(RAPID_SNARK);
try
{
auto [jsonProof, publicSignalsJson] = prover->prove(pWitnessFinal);
// Save proof to file
if (config.saveProofToFile)
{
json2file(jsonProof, pProverRequest->filePrefix + "final_proof.proof.json");
}
TimerStopAndLog(RAPID_SNARK);

// Populate Proof with the correct data
PublicInputsExtended publicInputsExtended;
publicInputsExtended.publicInputs = pProverRequest->input.publicInputsExtended.publicInputs;
pProverRequest->proof.load(jsonProof, publicJson);
pProverRequest->proof.load(jsonProof, publicSignalsJson);

pProverRequest->result = ZKR_SUCCESS;
}
catch (std::exception &e)
{
zklog.error("Prover::genProof() got exception in rapid SNARK:" + string(e.what()));
exitProcess();
}

/***********/
/* Cleanup */
/***********/
delete prover;

free(pWitnessFinal);

TimerStopAndLog(PROVER_FINAL_PROOF);
Expand Down
7 changes: 3 additions & 4 deletions src/prover/prover.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "rom.hpp"
#include "proof_fflonk.hpp"
#include "alt_bn128.hpp"
#include "groth16.hpp"
#include "binfile_utils.hpp"
#include "zkey_utils.hpp"
#include "prover_request.hpp"
Expand Down Expand Up @@ -37,9 +36,7 @@ class Prover
Starks *starksRecursive2;

Fflonk::FflonkProver<AltBn128::Engine> *prover;
std::unique_ptr<Groth16::Prover<AltBn128::Engine>> groth16Prover;
std::unique_ptr<BinFileUtils::BinFile> zkey;
std::unique_ptr<ZKeyUtils::Header> zkeyHeader;
mpz_t altBbn128r;

public:
Expand All @@ -54,10 +51,12 @@ class Prover
pthread_t cleanerPthread; // Garbage collector
pthread_mutex_t mutex; // Mutex to protect the requests queues
void *pAddress = NULL;
void *pAddressStarksRecursiveF = NULL;
int protocolId;
uint64_t polsSize;
uint64_t N;

uint64_t domainSizeFflonk;
uint64_t nPublicsFflonk;
public:
const Config &config;
sem_t pendingRequestSem; // Semaphore to wakeup prover thread when a new request is available
Expand Down
20 changes: 15 additions & 5 deletions src/rapidsnark/binfile_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace BinFileUtils
readingSection = NULL;
}

BinFile::BinFile(std::string fileName, std::string _type, uint32_t maxVersion)
BinFile::BinFile(std::string fileName, std::string _type, uint32_t maxVersion, void* reservedMemoryPtr, uint64_t reservedMemorySize)
{

int fd;
Expand All @@ -70,8 +70,16 @@ namespace BinFileUtils

size = sb.st_size;
close(fd);
addr = malloc(size);

if(NULL == reservedMemoryPtr) {
addr = malloc(size);
} else {
if(size > reservedMemorySize) {
throw std::runtime_error("There is not enough memory");
}
useReservedMemory = true;
addr = reservedMemoryPtr;
}

// Determine the number of chunks and the size of each chunk
size_t numChunks = 8; //omp_get_max_threads()/2;
Expand Down Expand Up @@ -130,7 +138,9 @@ namespace BinFileUtils

BinFile::~BinFile()
{
free(addr);
if(!useReservedMemory) {
free(addr);
}
}

void BinFile::startReadSection(u_int32_t sectionId, u_int32_t sectionPos)
Expand Down Expand Up @@ -241,9 +251,9 @@ namespace BinFileUtils
return res;
}

std::unique_ptr<BinFile> openExisting(std::string filename, std::string type, uint32_t maxVersion)
std::unique_ptr<BinFile> openExisting(std::string filename, std::string type, uint32_t maxVersion, void* reservedMemoryPtr, uint64_t reservedMemorySize)
{
return std::unique_ptr<BinFile>(new BinFile(filename, type, maxVersion));
return std::unique_ptr<BinFile>(new BinFile(filename, type, maxVersion, reservedMemoryPtr, reservedMemorySize));
}

} // Namespace
6 changes: 4 additions & 2 deletions src/rapidsnark/binfile_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace BinFileUtils
class BinFile
{

bool useReservedMemory = false;

void *addr;
u_int64_t size;
u_int64_t pos;
Expand All @@ -33,7 +35,7 @@ namespace BinFileUtils

public:
BinFile(void *data, uint64_t size, std::string type, uint32_t maxVersion);
BinFile(std::string fileName, std::string type, uint32_t maxVersion);
BinFile(std::string fileName, std::string type, uint32_t maxVersion, void* reservedMemoryPtr = NULL, uint64_t reservedMemorySize = 0);

~BinFile();

Expand All @@ -55,7 +57,7 @@ namespace BinFileUtils
void *read(uint64_t l);
};

std::unique_ptr<BinFile> openExisting(std::string filename, std::string type, uint32_t maxVersion);
std::unique_ptr<BinFile> openExisting(std::string filename, std::string type, uint32_t maxVersion, void* reservedMemoryPtr = NULL, uint64_t reservedMemorySize = 0);
}

#endif // BINFILE_UTILS_H
Loading

0 comments on commit 35f2e0c

Please sign in to comment.