From 4cf12b1b0f77d1dc8c4c65abe829b81e20f07a9d Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Fri, 16 Aug 2024 15:37:40 -0700 Subject: [PATCH 1/4] Add Nondet Inv op to BigInt --- zirgen/Dialect/BigInt/IR/Eval.cpp | 31 +++++++++++ zirgen/Dialect/BigInt/IR/Ops.cpp | 34 ++++++++++++ zirgen/Dialect/BigInt/IR/Ops.td | 2 + zirgen/Dialect/BigInt/Transforms/BUILD.bazel | 1 + .../BigInt/Transforms/LowerModularInv.cpp | 55 +++++++++++++++++++ zirgen/Dialect/BigInt/Transforms/LowerZll.cpp | 4 +- zirgen/Dialect/BigInt/Transforms/Passes.h | 1 + zirgen/Dialect/BigInt/Transforms/Passes.td | 5 ++ zirgen/circuit/bigint/BUILD.bazel | 1 + zirgen/circuit/bigint/gen_bigint.cpp | 6 ++ zirgen/circuit/bigint/op_tests.cpp | 19 +++++++ zirgen/circuit/bigint/op_tests.h | 2 +- 12 files changed, 158 insertions(+), 3 deletions(-) create mode 100644 zirgen/Dialect/BigInt/Transforms/LowerModularInv.cpp diff --git a/zirgen/Dialect/BigInt/IR/Eval.cpp b/zirgen/Dialect/BigInt/IR/Eval.cpp index 66c836da..c92502c8 100644 --- a/zirgen/Dialect/BigInt/IR/Eval.cpp +++ b/zirgen/Dialect/BigInt/IR/Eval.cpp @@ -100,6 +100,31 @@ BytePoly nondetRem(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) { return fromAPInt(rem, coeffs); } +BytePoly nondetInvMod(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) { + // Uses the formula n^(p-2) * n = 1 (mod p) to invert `lhs` (mod `rhs`) + // (via the square and multiply technique) + auto lhsInt = toAPInt(lhs); + auto rhsInt = toAPInt(rhs); + size_t maxSize = rhsInt.getBitWidth(); + APInt inv(2 * maxSize, 1); // Initialize inverse to zero, twice the width of `prime` to allow multiplication + APInt sqr(lhsInt); // Will be repeatedly squared + APInt position(2 * maxSize, 1); // Bit at `idx` will be 1, other bits will be 0 + sqr = sqr.zext(2 * maxSize); + rhsInt = rhsInt.zext(2 * maxSize); + APInt exp = rhsInt - 2; + for (size_t idx = 0; idx < maxSize; idx++) { + if (exp.intersects(position)) { + // multiply in the current power of n (i.e., n^(2^idx)) + inv = (inv * sqr).urem(rhsInt); + } + position <<= 1; // increment the bit position to test in `exp` + sqr = (sqr * sqr).urem(rhsInt); // square `sqr` to increment to `n^(2^(idx+1))` + } + inv = inv.trunc(maxSize); // We don't need the extra space used as multiply buffer + LLVM_DEBUG({ dbgs() << "inv (mod " << rhsInt << "): " << inv << "\n"; }); + return fromAPInt(inv, coeffs); +} + void printEval(const std::string& message, BytePoly poly) { risc0::FpExt tot(0); risc0::FpExt mul(1); @@ -190,6 +215,12 @@ EvalOutput eval(func::FuncOp inFunc, ArrayRef witnessValues) { polys[op.getOut()] = poly; ret.privateWitness.push_back(poly); }) + .Case([&](auto op) { + uint32_t coeffs = op.getOut().getType().getCoeffs(); + auto poly = nondetInvMod(polys[op.getLhs()], polys[op.getRhs()], coeffs); + polys[op.getOut()] = poly; + ret.privateWitness.push_back(poly); + }) .Case([&](auto op) { auto poly = polys[op.getIn()]; if (toAPInt(poly) != 0) { diff --git a/zirgen/Dialect/BigInt/IR/Ops.cpp b/zirgen/Dialect/BigInt/IR/Ops.cpp index 62da43a4..54b788c8 100644 --- a/zirgen/Dialect/BigInt/IR/Ops.cpp +++ b/zirgen/Dialect/BigInt/IR/Ops.cpp @@ -135,6 +135,34 @@ LogicalResult NondetQuotOp::inferReturnTypes(MLIRContext* ctx, return success(); } +LogicalResult NondetInvModOp::inferReturnTypes(MLIRContext* ctx, + std::optional loc, + Adaptor adaptor, + SmallVectorImpl& out) { + auto rhsType = adaptor.getRhs().getType().cast(); + size_t coeffsWidth = ceilDiv(rhsType.getMaxBits(), kBitsPerCoeff); + out.push_back(BigIntType::get(ctx, + /*coeffs=*/coeffsWidth, + /*maxPos=*/(1 << kBitsPerCoeff) - 1, + /*maxNeg=*/0, + /*minBits=*/0)); + return success(); +} + +LogicalResult ModularInvOp::inferReturnTypes(MLIRContext* ctx, + std::optional loc, + Adaptor adaptor, + SmallVectorImpl& out) { + auto rhsType = adaptor.getRhs().getType().cast(); + size_t coeffsWidth = ceilDiv(rhsType.getMaxBits(), kBitsPerCoeff); + out.push_back(BigIntType::get(ctx, + /*coeffs=*/coeffsWidth, + /*maxPos=*/(1 << kBitsPerCoeff) - 1, + /*maxNeg=*/0, + /*minBits=*/0)); + return success(); +} + LogicalResult ReduceOp::inferReturnTypes(MLIRContext* ctx, std::optional loc, Adaptor adaptor, @@ -187,6 +215,12 @@ void NondetQuotOp::emitExpr(codegen::CodegenEmitter& cg) { {getLhs(), getRhs(), toConstantValue(cg, getContext(), getType().getCoeffs())}); } +void NondetInvModOp::emitExpr(codegen::CodegenEmitter& cg) { + cg.emitFuncCall(cg.getStringAttr("nondet_inv"), + /*contextArgs=*/{"ctx"}, + {getLhs(), getRhs(), toConstantValue(cg, getContext(), getType().getCoeffs())}); +} + void ConstOp::emitExpr(codegen::CodegenEmitter& cg) { auto bytePoly = fromAPInt(getValue(), getType().getCoeffs()); SmallVector macroArgs; diff --git a/zirgen/Dialect/BigInt/IR/Ops.td b/zirgen/Dialect/BigInt/IR/Ops.td index ef04110f..b803784b 100644 --- a/zirgen/Dialect/BigInt/IR/Ops.td +++ b/zirgen/Dialect/BigInt/IR/Ops.td @@ -38,6 +38,8 @@ def SubOp : BinaryOp<"sub", [Pure, ]> {} def MulOp : BinaryOp<"mul", [Pure, Commutative]> {} def NondetRemOp : BinaryOp<"nondet_rem", [DeclareOpInterfaceMethods]> {} def NondetQuotOp : BinaryOp<"nondet_quot", [DeclareOpInterfaceMethods]> {} +def NondetInvModOp : BinaryOp<"nondet_invmod", [DeclareOpInterfaceMethods]> {} +def ModularInvOp : BinaryOp<"inv", []> {} def ReduceOp : BinaryOp<"reduce", []> {} def EqualZeroOp : BigIntOp<"eqz", [DeclareOpInterfaceMethods]> { diff --git a/zirgen/Dialect/BigInt/Transforms/BUILD.bazel b/zirgen/Dialect/BigInt/Transforms/BUILD.bazel index 91b13c22..2c4752fa 100644 --- a/zirgen/Dialect/BigInt/Transforms/BUILD.bazel +++ b/zirgen/Dialect/BigInt/Transforms/BUILD.bazel @@ -32,6 +32,7 @@ gentbl_cc_library( cc_library( name = "Transforms", srcs = [ + "LowerModularInv.cpp", "LowerReduce.cpp", "LowerZll.cpp", "PassDetail.h", diff --git a/zirgen/Dialect/BigInt/Transforms/LowerModularInv.cpp b/zirgen/Dialect/BigInt/Transforms/LowerModularInv.cpp new file mode 100644 index 00000000..95698e64 --- /dev/null +++ b/zirgen/Dialect/BigInt/Transforms/LowerModularInv.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2024 RISC Zero, Inc. +// +// All rights reserved. + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "zirgen/Dialect/BigInt/IR/BigInt.h" +#include "zirgen/Dialect/BigInt/Transforms/PassDetail.h" +#include "zirgen/Dialect/BigInt/Transforms/Passes.h" + +using namespace mlir; + +namespace zirgen::BigInt { + +namespace { + +struct ReplaceModularInv : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ModularInvOp op, PatternRewriter& rewriter) const override { + // Construct the constant 1 + mlir::Type oneType = rewriter.getIntegerType(1); // a `1` is bitwidth 1 + auto oneAttr = rewriter.getIntegerAttr(oneType, 1); // value 1 + auto one = rewriter.create(op.getLoc(), oneAttr); + + auto inv = rewriter.create(op.getLoc(), op.getLhs(), op.getRhs()); + auto remult = rewriter.create(op.getLoc(), op.getLhs(), inv); + auto reduced = rewriter.create(op.getLoc(), remult, op.getRhs()); + auto diff = rewriter.create(op.getLoc(), reduced, one); + rewriter.create(op.getLoc(), diff); + rewriter.replaceOp(op, inv); + return success(); + } +}; + +struct LowerModularInvPass : public LowerModularInvBase { + void runOnOperation() override { + auto ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.insert(ctx); + if (applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)).failed()) { + return signalPassFailure(); + } + } +}; + +} // End namespace + +std::unique_ptr> createLowerModularInvPass() { + return std::make_unique(); +} + +} // namespace zirgen::BigInt diff --git a/zirgen/Dialect/BigInt/Transforms/LowerZll.cpp b/zirgen/Dialect/BigInt/Transforms/LowerZll.cpp index 8b13b516..01d62b1b 100644 --- a/zirgen/Dialect/BigInt/Transforms/LowerZll.cpp +++ b/zirgen/Dialect/BigInt/Transforms/LowerZll.cpp @@ -57,7 +57,7 @@ void lower(func::FuncOp inFunc) { } }) .Case([&](auto op) { countConst += op.getOut().getType().getNormalWitnessSize(); }) - .Case( + .Case( [&](auto op) { countPrivate += op.getOut().getType().getNormalWitnessSize(); }) .Case( [&](auto op) { countPrivate += op.getIn().getType().getCarryWitnessSize(); }); @@ -158,7 +158,7 @@ void lower(func::FuncOp inFunc) { valMap[op.getOut()] = builder.create(loc, valMap[op.getLhs()], valMap[op.getRhs()]); }) - .Case([&](auto op) { + .Case([&](auto op) { valMap[op.getOut()] = extractPoly(cbPrivate.getEvaluations(), curPrivate, op.getOut().getType()); }) diff --git a/zirgen/Dialect/BigInt/Transforms/Passes.h b/zirgen/Dialect/BigInt/Transforms/Passes.h index f0c1ad40..8b6152b5 100644 --- a/zirgen/Dialect/BigInt/Transforms/Passes.h +++ b/zirgen/Dialect/BigInt/Transforms/Passes.h @@ -12,6 +12,7 @@ namespace zirgen::BigInt { // Pass constructors +std::unique_ptr> createLowerModularInvPass(); std::unique_ptr> createLowerReducePass(); std::unique_ptr> createLowerZllPass(); diff --git a/zirgen/Dialect/BigInt/Transforms/Passes.td b/zirgen/Dialect/BigInt/Transforms/Passes.td index 9cf7f452..edc46268 100644 --- a/zirgen/Dialect/BigInt/Transforms/Passes.td +++ b/zirgen/Dialect/BigInt/Transforms/Passes.td @@ -4,6 +4,11 @@ include "mlir/Pass/PassBase.td" include "mlir/Rewrite/PassUtil.td" +def LowerModularInv : Pass<"lower-modular-inv", "mlir::ModuleOp"> { + let summary = "Remove BigInt::ModularInvOp by lowering it to other ops"; + let constructor = "zirgen::BigInt::createLowerModularInvPass()"; +} + def LowerReduce : Pass<"lower-reduce", "mlir::ModuleOp"> { let summary = "Remove BigInt::ReduceOp by lowering it to other ops"; let constructor = "zirgen::BigInt::createLowerReducePass()"; diff --git a/zirgen/circuit/bigint/BUILD.bazel b/zirgen/circuit/bigint/BUILD.bazel index b757e742..3e0d88d6 100644 --- a/zirgen/circuit/bigint/BUILD.bazel +++ b/zirgen/circuit/bigint/BUILD.bazel @@ -45,6 +45,7 @@ ZKRS = [ "mul_test_128", "reduce_test_8", "reduce_test_128", + "nondet_inv_test_8", ] build_circuit( diff --git a/zirgen/circuit/bigint/gen_bigint.cpp b/zirgen/circuit/bigint/gen_bigint.cpp index 94d18e79..d3fa07c5 100644 --- a/zirgen/circuit/bigint/gen_bigint.cpp +++ b/zirgen/circuit/bigint/gen_bigint.cpp @@ -119,6 +119,12 @@ int main(int argc, char* argv[]) { BigInt::setIterationCount(funcOp, rsa.iters); } // TODO: More bitwidth coverage? + for (size_t numBits : {8}) { + module.addFunc<0>("nondet_inv_test_" + std::to_string(numBits), {}, [&]() { + auto& builder = Module::getCurModule()->getBuilder(); + zirgen::BigInt::makeNondetInvTest(builder, builder.getUnknownLoc(), numBits); + }); + } for (size_t numBits : {8}) { module.addFunc<0>("const_add_test_" + std::to_string(numBits), {}, [&]() { auto& builder = Module::getCurModule()->getBuilder(); diff --git a/zirgen/circuit/bigint/op_tests.cpp b/zirgen/circuit/bigint/op_tests.cpp index e7add91a..d5d3e206 100644 --- a/zirgen/circuit/bigint/op_tests.cpp +++ b/zirgen/circuit/bigint/op_tests.cpp @@ -128,4 +128,23 @@ void makeReduceTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits) { builder.create(loc, diff); } +void makeNondetInvTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits) { + auto inp = builder.create(loc, bits, 0, true); + auto prime = builder.create(loc, bits, 1, true, bits - 1); // TODO: Set to 131 if we need an actual number + auto expected = builder.create(loc, bits, 2, true); + + // Construct constants + mlir::Type oneType = builder.getIntegerType(1); // a `1` is bitwidth 1 + auto oneAttr = builder.getIntegerAttr(oneType, 1); // value 1 + auto one = builder.create(loc, oneAttr); + + auto inv = builder.create(loc, inp, prime); + auto prod = builder.create(loc, inp, inv); + auto reduced = builder.create(loc, prod, prime); + auto expect_zero = builder.create(loc, reduced, one); + builder.create(loc, expect_zero); + auto result_match = builder.create(loc, inv, expected); + builder.create(loc, result_match); +} + } // namespace zirgen::BigInt diff --git a/zirgen/circuit/bigint/op_tests.h b/zirgen/circuit/bigint/op_tests.h index ed33ae39..88ce8153 100644 --- a/zirgen/circuit/bigint/op_tests.h +++ b/zirgen/circuit/bigint/op_tests.h @@ -10,7 +10,6 @@ using namespace mlir; namespace zirgen::BigInt { -void makeIsOddTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); void makeConstAddTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); void makeConstAddAltTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); void makeConstMulTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); @@ -21,5 +20,6 @@ void makeConstTwoByteTest(mlir::OpBuilder builder, mlir::Location loc, size_t bi void makeSubTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); void makeMulTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); void makeReduceTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); +void makeNondetInvTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits); } // namespace zirgen::BigInt From 8fb180f667dbb6bb6983663cfe9a6d0230965bac Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Fri, 16 Aug 2024 15:43:56 -0700 Subject: [PATCH 2/4] Add cargo bootstrap alias --- .cargo/config.toml | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .cargo/config.toml diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..cbd10938 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[alias] +bootstrap = "run --release --manifest-path zirgen/bootstrap/Cargo.toml" \ No newline at end of file From 38399f58fb1f2fef71c1cff78add6c365a7f5cc1 Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Fri, 16 Aug 2024 15:43:56 -0700 Subject: [PATCH 3/4] Add cargo bootstrap alias --- .cargo/config.toml | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .cargo/config.toml diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..211df290 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[alias] +bootstrap = "run --release --manifest-path zirgen/bootstrap/Cargo.toml" From eb314b710024d703eb0c6031cf9d5346b31de77c Mon Sep 17 00:00:00 2001 From: Tim Zerrell Date: Fri, 16 Aug 2024 16:45:11 -0700 Subject: [PATCH 4/4] Format --- zirgen/Dialect/BigInt/IR/Eval.cpp | 13 +++++++------ zirgen/Dialect/BigInt/IR/Ops.cpp | 6 +++--- .../Dialect/BigInt/Transforms/LowerModularInv.cpp | 4 ++-- zirgen/circuit/bigint/op_tests.cpp | 6 +++--- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/zirgen/Dialect/BigInt/IR/Eval.cpp b/zirgen/Dialect/BigInt/IR/Eval.cpp index c92502c8..f16c616f 100644 --- a/zirgen/Dialect/BigInt/IR/Eval.cpp +++ b/zirgen/Dialect/BigInt/IR/Eval.cpp @@ -106,9 +106,10 @@ BytePoly nondetInvMod(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) { auto lhsInt = toAPInt(lhs); auto rhsInt = toAPInt(rhs); size_t maxSize = rhsInt.getBitWidth(); - APInt inv(2 * maxSize, 1); // Initialize inverse to zero, twice the width of `prime` to allow multiplication - APInt sqr(lhsInt); // Will be repeatedly squared - APInt position(2 * maxSize, 1); // Bit at `idx` will be 1, other bits will be 0 + APInt inv(2 * maxSize, + 1); // Initialize inverse to zero, twice the width of `prime` to allow multiplication + APInt sqr(lhsInt); // Will be repeatedly squared + APInt position(2 * maxSize, 1); // Bit at `idx` will be 1, other bits will be 0 sqr = sqr.zext(2 * maxSize); rhsInt = rhsInt.zext(2 * maxSize); APInt exp = rhsInt - 2; @@ -117,10 +118,10 @@ BytePoly nondetInvMod(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) { // multiply in the current power of n (i.e., n^(2^idx)) inv = (inv * sqr).urem(rhsInt); } - position <<= 1; // increment the bit position to test in `exp` - sqr = (sqr * sqr).urem(rhsInt); // square `sqr` to increment to `n^(2^(idx+1))` + position <<= 1; // increment the bit position to test in `exp` + sqr = (sqr * sqr).urem(rhsInt); // square `sqr` to increment to `n^(2^(idx+1))` } - inv = inv.trunc(maxSize); // We don't need the extra space used as multiply buffer + inv = inv.trunc(maxSize); // We don't need the extra space used as multiply buffer LLVM_DEBUG({ dbgs() << "inv (mod " << rhsInt << "): " << inv << "\n"; }); return fromAPInt(inv, coeffs); } diff --git a/zirgen/Dialect/BigInt/IR/Ops.cpp b/zirgen/Dialect/BigInt/IR/Ops.cpp index 54b788c8..b30accfb 100644 --- a/zirgen/Dialect/BigInt/IR/Ops.cpp +++ b/zirgen/Dialect/BigInt/IR/Ops.cpp @@ -150,9 +150,9 @@ LogicalResult NondetInvModOp::inferReturnTypes(MLIRContext* ctx, } LogicalResult ModularInvOp::inferReturnTypes(MLIRContext* ctx, - std::optional loc, - Adaptor adaptor, - SmallVectorImpl& out) { + std::optional loc, + Adaptor adaptor, + SmallVectorImpl& out) { auto rhsType = adaptor.getRhs().getType().cast(); size_t coeffsWidth = ceilDiv(rhsType.getMaxBits(), kBitsPerCoeff); out.push_back(BigIntType::get(ctx, diff --git a/zirgen/Dialect/BigInt/Transforms/LowerModularInv.cpp b/zirgen/Dialect/BigInt/Transforms/LowerModularInv.cpp index 95698e64..b4656885 100644 --- a/zirgen/Dialect/BigInt/Transforms/LowerModularInv.cpp +++ b/zirgen/Dialect/BigInt/Transforms/LowerModularInv.cpp @@ -21,8 +21,8 @@ struct ReplaceModularInv : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ModularInvOp op, PatternRewriter& rewriter) const override { // Construct the constant 1 - mlir::Type oneType = rewriter.getIntegerType(1); // a `1` is bitwidth 1 - auto oneAttr = rewriter.getIntegerAttr(oneType, 1); // value 1 + mlir::Type oneType = rewriter.getIntegerType(1); // a `1` is bitwidth 1 + auto oneAttr = rewriter.getIntegerAttr(oneType, 1); // value 1 auto one = rewriter.create(op.getLoc(), oneAttr); auto inv = rewriter.create(op.getLoc(), op.getLhs(), op.getRhs()); diff --git a/zirgen/circuit/bigint/op_tests.cpp b/zirgen/circuit/bigint/op_tests.cpp index d5d3e206..04191baf 100644 --- a/zirgen/circuit/bigint/op_tests.cpp +++ b/zirgen/circuit/bigint/op_tests.cpp @@ -130,12 +130,12 @@ void makeReduceTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits) { void makeNondetInvTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits) { auto inp = builder.create(loc, bits, 0, true); - auto prime = builder.create(loc, bits, 1, true, bits - 1); // TODO: Set to 131 if we need an actual number + auto prime = builder.create(loc, bits, 1, true, bits - 1); auto expected = builder.create(loc, bits, 2, true); // Construct constants - mlir::Type oneType = builder.getIntegerType(1); // a `1` is bitwidth 1 - auto oneAttr = builder.getIntegerAttr(oneType, 1); // value 1 + mlir::Type oneType = builder.getIntegerType(1); // a `1` is bitwidth 1 + auto oneAttr = builder.getIntegerAttr(oneType, 1); // value 1 auto one = builder.create(loc, oneAttr); auto inv = builder.create(loc, inp, prime);