Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZIR-232: Add Assert and Abort externs #69

Merged
merged 4 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions zirgen/Conversions/Typing/BuiltinComponents.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ function Div(lhs: Val, rhs: Val) {
}

extern Log(message: String, vals: Val...);
extern Abort();
extern Assert(x: Val, message: String);

)";

Expand Down
8 changes: 4 additions & 4 deletions zirgen/Dialect/BigInt/IR/test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ Digest hashPublic(llvm::ArrayRef<APInt> inputs) {

struct CheckedBytesExternHandler : public Zll::ExternHandler {
std::deque<uint8_t> coeffs;
std::vector<uint64_t> doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const Zll::InterpVal*> arg,
size_t outCount) override {
std::optional<std::vector<uint64_t>> doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const Zll::InterpVal*> arg,
size_t outCount) override {
if (name == "readCoefficients") {
assert(outCount == 16);
if (coeffs.size() < 16) {
Expand Down
14 changes: 7 additions & 7 deletions zirgen/Dialect/Zll/IR/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ template <typename T> void formatFieldElem(const InterpVal* interpVal, llvm::raw

} // namespace

std::vector<uint64_t> ExternHandler::doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const InterpVal*> args,
size_t outCount) {
std::optional<std::vector<uint64_t>> ExternHandler::doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const InterpVal*> args,
size_t outCount) {
if (name == "readCoefficients") {
// TODO: Migrate users of readCoefficients to use readInput, or
// move readCoefficients to a circuit-specific extern handler.
Expand All @@ -68,7 +68,7 @@ std::vector<uint64_t> ExternHandler::doExtern(llvm::StringRef name,
throw std::runtime_error("wrong number of arguments to configureInput");
size_t bytesPerElem = fpArgs[0];
inputBytesPerElem[extra] = bytesPerElem;
return {};
return std::vector<uint64_t>{};
}
if (name == "readInput") {
// Usage: readInput(/*extra=*/inputName)
Expand Down Expand Up @@ -179,7 +179,7 @@ std::vector<uint64_t> ExternHandler::doExtern(llvm::StringRef name,
throw std::runtime_error(("Unused arguments in format " + extra).str());
}
os << "\n";
return {};
return std::vector<uint64_t>{};
}
throw std::runtime_error(("Unknown extern: " + name).str());
}
Expand Down Expand Up @@ -507,7 +507,7 @@ FailureOr<SmallVector<Attribute>> Interpreter::runBlock(mlir::Block& block) {
evaluator = eval;
if (failed(evaluate(evaluator))) {
if (!gotErrorMsg && !getSilenceErrors())
eval->op->emitError() << "Unknown evaluation error occured";
eval->op->emitError() << "Evaluation error occured";
return failure();
}
}
Expand Down
8 changes: 4 additions & 4 deletions zirgen/Dialect/Zll/IR/Interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ std::vector<uint64_t> asFpArray(llvm::ArrayRef<const Zll::InterpVal*> array);
class ExternHandler {
public:
virtual ~ExternHandler() {}
virtual std::vector<uint64_t> doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const InterpVal*> arg,
size_t outCount);
virtual std::optional<std::vector<uint64_t>> doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const InterpVal*> arg,
size_t outCount);

// Add input data bytes available through the readInput extern.
void addInput(llvm::StringRef inputName, llvm::StringRef inputBytes);
Expand Down
29 changes: 18 additions & 11 deletions zirgen/Dialect/Zll/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,13 @@ LogicalResult ExternOp::evaluate(Interpreter& interp,
}
// TODO: We used to flatten extension field elements here... is that necessary?
size_t outCount = getNumResults();
std::vector<uint64_t> outFp = handler->doExtern(getName(), getExtra(), adaptor.getIn(), outCount);
assert(outFp.size() == outCount);
std::optional<std::vector<uint64_t>> outFp =
handler->doExtern(getName(), getExtra(), adaptor.getIn(), outCount);
if (!outFp)
return failure();
assert(outFp->size() == outCount);
for (size_t i = 0; i < getNumResults(); i++) {
outs[i]->setVal(outFp[i]);
outs[i]->setVal((*outFp)[i]);
}
return success();
}
Expand Down Expand Up @@ -620,17 +623,19 @@ LogicalResult HashCheckedBytesOp::evaluate(Interpreter& interp,
std::vector<uint32_t> accumCoeffs(16, 0);
size_t countAccumed = 0;
for (size_t i = 0; i < adaptor.getEvalsCount(); i++) {
std::vector<uint64_t> newCoeffs = handler->doExtern("readCoefficients", "", {}, 16);
std::optional<std::vector<uint64_t>> newCoeffs =
handler->doExtern("readCoefficients", "", {}, 16);
assert(newCoeffs && "readCoefficients shouldn't fail");
auto result = field.Zero();
auto currentPower = field.One();
for (size_t j = 0; j < 16; j++) {
if (newCoeffs[j] > 255) {
if ((*newCoeffs)[j] > 255) {
throw std::runtime_error("Coefficient fails range check");
}
result = field.Add(result, field.Mul(newCoeffs[j], currentPower));
result = field.Add(result, field.Mul((*newCoeffs)[j], currentPower));
currentPower = field.Mul(currentPower, evalPt);
accumCoeffs[j] *= 256;
accumCoeffs[j] += newCoeffs[j];
accumCoeffs[j] += (*newCoeffs)[j];
}
outs[1 + i]->setVal(result);
countAccumed++;
Expand Down Expand Up @@ -667,18 +672,20 @@ LogicalResult HashCheckedBytesPublicOp::evaluate(Interpreter& interp,
auto evalPt = adaptor.getEvalPt()->getVal();
std::vector<uint32_t> coeffs;
for (size_t i = 0; i < adaptor.getEvalsCount(); i++) {
std::vector<uint64_t> newCoeffs = handler->doExtern("readCoefficients", "", {}, 16);
std::optional<std::vector<uint64_t>> newCoeffs =
handler->doExtern("readCoefficients", "", {}, 16);
assert(newCoeffs && "readCoefficients shouldn't fail");
auto result = field.Zero();
auto currentPower = field.One();
for (size_t j = 0; j < 16; j++) {
if (newCoeffs[j] > 255) {
if ((*newCoeffs)[j] > 255) {
throw std::runtime_error("Coefficient fails range check");
}
result = field.Add(result, field.Mul(newCoeffs[j], currentPower));
result = field.Add(result, field.Mul((*newCoeffs)[j], currentPower));
currentPower = field.Mul(currentPower, evalPt);
}
outs[2 + i]->setVal(result);
coeffs.insert(coeffs.end(), newCoeffs.begin(), newCoeffs.end());
coeffs.insert(coeffs.end(), (*newCoeffs).begin(), (*newCoeffs).end());
}
auto hashVal1 = psuite->hash(coeffs.data(), coeffs.size());
outs[0]->setDigest(hashVal1);
Expand Down
28 changes: 20 additions & 8 deletions zirgen/Main/RunTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,16 @@ struct TestExternHandler : public zirgen::Zll::ExternHandler {
results.push_back(rem >> 16);
}

std::vector<uint64_t> doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const zirgen::Zll::InterpVal*> args,
size_t outCount) override {
std::optional<std::vector<uint64_t>> doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const zirgen::Zll::InterpVal*> args,
size_t outCount) override {
auto& os = llvm::outs();
os << "[" << cycle << "] ";
llvm::printEscapedString(name, os);

// Including arguments for Log duplicates information in the output
if (name != "Log") {
if (name != "Log" && name != "Assert") {
os << "(";
if (!extra.empty()) {
printEscapedString(extra, os);
Expand Down Expand Up @@ -183,18 +183,30 @@ struct TestExternHandler : public zirgen::Zll::ExternHandler {
vals[i].setVal(varArgs[i].cast<mlir::PolynomialAttr>().asArrayRef());
valPtrs[i] = &vals[i];
}
results = zirgen::Zll::ExternHandler::doExtern("log", message, valPtrs, outCount);
results = *zirgen::Zll::ExternHandler::doExtern("log", message, valPtrs, outCount);
} else if (name == "Abort") {
os << ")\n";
os.flush();
return std::nullopt;
} else if (name == "Assert") {
auto condition = args[0]->getBaseFieldVal();
llvm::StringRef message = args[1]->getAttr<mlir::StringAttr>().getValue();
if (condition != 0) {
os << " failed: " << message << "\n";
os.flush();
return std::nullopt;
}
} else if (name == "configureInput" || name == "readInput") {
// Pass through to common implementation
results = zirgen::Zll::ExternHandler::doExtern(name, extra, args, outCount);
results = *zirgen::Zll::ExternHandler::doExtern(name, extra, args, outCount);
} else {
// By default, let random externs pass
// Fill with 0, 1, 2, ...
for (uint64_t i = 0; i != outCount; ++i) {
results.push_back(i);
}
}
if (name != "Log") {
if (name != "Log" && name != "Assert") {
interleaveComma(results, os);
os << ")\n";
}
Expand Down
13 changes: 13 additions & 0 deletions zirgen/circuit/keccak/src/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,19 @@ impl<'a> CpuExecContext<'a> {
tracing::trace!("Read returns {val:?}");
Ok(val)
}

pub fn abort(&self) -> Result<()> {
Err(anyhow!("circuit aborted proving"))
}

pub fn assert(&self, condition: Val, message: &str) -> Result<()> {
if condition == Val::ZERO {
Err(anyhow!(message.to_owned()))
} else {
Ok(())
}
}

pub fn log(&self, message: &str, x: impl AsRef<[Val]>) -> Result<()> {
risc0_zirgen_dsl::codegen::default_log(message, x.as_ref())
}
Expand Down
10 changes: 5 additions & 5 deletions zirgen/circuit/recursion/test/runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ struct Runner::RecursionExternHandler : public WomExternHandler {
size_t offset;
std::deque<llvm::SmallVector<uint64_t, 4>> body;

std::vector<uint64_t> doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const InterpVal*> args,
size_t outCount) override {
std::optional<std::vector<uint64_t>> doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const InterpVal*> args,
size_t outCount) override {
auto fpArgs = asFpArray(args);
// TODO: this probably breaks log externs
if (name == "readIOPHeader") {
Expand Down Expand Up @@ -75,7 +75,7 @@ struct Runner::RecursionExternHandler : public WomExternHandler {
body.push_back(poly);
}
}
return {};
return std::vector<uint64_t>{};
}
if (name == "readIOPBody") {
auto front = body.front();
Expand Down
13 changes: 7 additions & 6 deletions zirgen/circuit/recursion/wom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,11 @@ WomExternHandler::WomExternHandler() {
state[0] = {0, 0, 0, 0};
}

std::vector<uint64_t> WomExternHandler::doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const Zll::InterpVal*> args,
size_t outCount) {
std::optional<std::vector<uint64_t>>
WomExternHandler::doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const Zll::InterpVal*> args,
size_t outCount) {
if (name == "womWrite") {
uint64_t addr = args[0]->getBaseFieldVal();
if (state.count(addr) != 0) {
Expand All @@ -136,7 +137,7 @@ std::vector<uint64_t> WomExternHandler::doExtern(llvm::StringRef name,
data[i] = args[1 + i]->getBaseFieldVal();
}
state[addr] = data;
return {};
return std::vector<uint64_t>{};
}
if (name == "womRead") {
uint32_t addr = args[0]->getBaseFieldVal();
Expand All @@ -145,7 +146,7 @@ std::vector<uint64_t> WomExternHandler::doExtern(llvm::StringRef name,
throw std::runtime_error("INVALID WOM READ");
}
auto data = state[addr];
return {data[0], data[1], data[2], data[3]};
return std::vector<uint64_t>{data[0], data[1], data[2], data[3]};
}
return PlonkExternHandler::doExtern(name, extra, args, outCount);
}
Expand Down
8 changes: 4 additions & 4 deletions zirgen/circuit/recursion/wom.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ using WomBody = Comp<WomBodyImpl>;
class WomExternHandler : public PlonkExternHandler {
public:
WomExternHandler();
std::vector<uint64_t> doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const Zll::InterpVal*> args,
size_t outCount) override;
std::optional<std::vector<uint64_t>> doExtern(llvm::StringRef name,
llvm::StringRef extra,
llvm::ArrayRef<const Zll::InterpVal*> args,
size_t outCount) override;

std::map<size_t, std::array<uint64_t, kExtSize>> state;
};
Expand Down
Loading
Loading