Skip to content

Commit

Permalink
added circuits and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yash1io committed Sep 8, 2024
1 parent ee52a07 commit cfb1acc
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 0 deletions.
103 changes: 103 additions & 0 deletions circuits/sha256Input.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
pragma circom 2.1.9;

include "circomlib/circuits/sha256/sha256.circom";
include "circomlib/circuits/mux1.circom";

// Copy over 1 block of sha256 input
// Sets bit to 1 at L_pos
template CopyOverBlock(ToCopyBits) {
// signals
signal input L_pos;
signal input in[ToCopyBits];
signal output out[ToCopyBits];

// copy over the block
component ie[ToCopyBits];
component mux[ToCopyBits];
for (var i = 0; i < ToCopyBits; i++) {
ie[i] = IsEqual();
ie[i].in[0] <== i;
ie[i].in[1] <== L_pos;

mux[i] = Mux1();
mux[i].c[0] <== in[i];
mux[i].c[1] <== 1;
mux[i].s <== ie[i].out;

out[i] <== mux[i].out;
}
}

// Prepare 1 sha256 input block
template Sha256InputBlock(BlockNumber) {
// constants
var BLOCK_LEN = 512;
var L_BITS = 64;

// variables
var PreLBlockLen = BLOCK_LEN - L_BITS;

// signals
signal input in[BLOCK_LEN];
signal input len;
signal input isLast;
signal output out[BLOCK_LEN];

// prepare CopyOverBlock
component cob = CopyOverBlock(BLOCK_LEN);
cob.L_pos <== len - (BlockNumber * BLOCK_LEN);

// prepare L
component n2b = Num2Bits(L_BITS);
n2b.in <== len;

// copy over the block up to pre-L length
for (var i = 0; i < BLOCK_LEN; i++) { cob.in[i] <== in[i]; }
for (var i = 0; i < PreLBlockLen; i++) { out[i] <== cob.out[i]; }

// copy over the L or the rest of the block
component mux[BLOCK_LEN - PreLBlockLen];
for (var i = PreLBlockLen; i < BLOCK_LEN; i++) {
var j = i - PreLBlockLen;
mux[j] = Mux1();
mux[j].c[1] <== n2b.out[BLOCK_LEN - 1 - i];
mux[j].c[0] <== cob.out[i];
mux[j].s <== isLast;
out[i] <== mux[j].out;
}
}

// Prepare sha256 input for Sha256_unsafe with tBlock as the current number of blocks
// and MaxBlockCount being the maximum number of blocks
// This template effectively implements https://datatracker.ietf.org/doc/html/rfc4634#section-4.1 as a circuit
template Sha256Input(MaxBlockCount) {

// constants
var BLOCK_LEN = 512;
var L_BITS = 64;

// variables
var PreLBlockLen = BLOCK_LEN - L_BITS;

// signals
signal input in[BLOCK_LEN * MaxBlockCount];
signal input len;
signal input tBlock;
signal output out[BLOCK_LEN * MaxBlockCount];

// copy over blocks
component inputBlock[MaxBlockCount];
component iz[MaxBlockCount];
for(var j = 0; j < MaxBlockCount; j++) {
var offset = j * BLOCK_LEN;

iz[j] = IsZero();
iz[j].in <== j - tBlock + 1;
inputBlock[j] = Sha256InputBlock(j);
inputBlock[j].len <== len;
inputBlock[j].isLast <== iz[j].out;
for (var i = 0; i < BLOCK_LEN; i++) { inputBlock[j].in[i] <== in[offset + i]; }
for (var i = 0; i < BLOCK_LEN; i++) { out[offset + i] <== inputBlock[j].out[i]; }
}

}
113 changes: 113 additions & 0 deletions circuits/sha256Unsafe.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
pragma circom 2.1.9;

include "circomlib/circuits/sha256/constants.circom";
include "circomlib/circuits/sha256/sha256compression.circom";
include "circomlib/circuits/comparators.circom";
/*
SHA256 Unsafe
Calculates the SHA256 hash of the input, using a signal to select the output round corresponding to the number of
non-empty input blocks. This implementation is referred to as "unsafe", as it relies upon the caller to ensure that
the input is padded correctly, and to ensure that the tBlock input corresponds to the actual terminating data block.
Crafted inputs could result in Length Extension Attacks.
Construction Parameters:
- nBlocks: Maximum number of 512-bit blocks for payload input
Inputs:
- in: An array of blocks exactly nBlocks in length, each block containing an array of exactly 512 bits.
Padding of the input according to RFC4634 Section 4.1 is left to the caller.
Blocks following tBlock must be supplied, and *should* contain all zeroes
- tBlock: An integer corresponding to the terminating block of the input, which contains the message padding
Outputs:
- out: An array of 256 bits corresponding to the SHA256 output as of the terminating block
*/
template Sha256_unsafe(nBlocks) {
signal input in[nBlocks][512];
signal input tBlock;

signal output out[256];

component ha0 = H(0);
component hb0 = H(1);
component hc0 = H(2);
component hd0 = H(3);
component he0 = H(4);
component hf0 = H(5);
component hg0 = H(6);
component hh0 = H(7);

component sha256compression[nBlocks];

for(var i=0; i < nBlocks; i++) {

sha256compression[i] = Sha256compression();

if (i==0) {
for(var k = 0; k < 32; k++) {
sha256compression[i].hin[0*32+k] <== ha0.out[k];
sha256compression[i].hin[1*32+k] <== hb0.out[k];
sha256compression[i].hin[2*32+k] <== hc0.out[k];
sha256compression[i].hin[3*32+k] <== hd0.out[k];
sha256compression[i].hin[4*32+k] <== he0.out[k];
sha256compression[i].hin[5*32+k] <== hf0.out[k];
sha256compression[i].hin[6*32+k] <== hg0.out[k];
sha256compression[i].hin[7*32+k] <== hh0.out[k];
}
} else {
for(var k = 0; k < 32; k++) {
sha256compression[i].hin[32*0+k] <== sha256compression[i-1].out[32*0+31-k];
sha256compression[i].hin[32*1+k] <== sha256compression[i-1].out[32*1+31-k];
sha256compression[i].hin[32*2+k] <== sha256compression[i-1].out[32*2+31-k];
sha256compression[i].hin[32*3+k] <== sha256compression[i-1].out[32*3+31-k];
sha256compression[i].hin[32*4+k] <== sha256compression[i-1].out[32*4+31-k];
sha256compression[i].hin[32*5+k] <== sha256compression[i-1].out[32*5+31-k];
sha256compression[i].hin[32*6+k] <== sha256compression[i-1].out[32*6+31-k];
sha256compression[i].hin[32*7+k] <== sha256compression[i-1].out[32*7+31-k];
}
}

for (var k = 0; k < 512; k++) {
sha256compression[i].inp[k] <== in[i][k];
}
}

// Collapse the hashing result at the terminating data block
// A modified Quin Selector allows us to select the block based on the tBlock signal
component calcTotal[256];
component eqs[256][nBlocks];

// For each bit of the output
for(var k = 0; k < 256; k++) {
calcTotal[k] = CalculateTotal(nBlocks);

// For each possible block
for (var i = 0; i < nBlocks; i++) {
// Determine if the given block index is equal to the terminating data block index
eqs[k][i] = IsEqual();
eqs[k][i].in[0] <== i;
eqs[k][i].in[1] <== tBlock - 1;

// eqs[k][i].out is 1 if the index matches. As such, at most one input to calcTotal is not 0.
// The bit corresponding to the terminating data block will be raised
calcTotal[k].nums[i] <== eqs[k][i].out * sha256compression[i].out[k];
}

out[k] <== calcTotal[k].sum;
}

}

// This circuit returns the sum of the inputs.
// n must be greater than 0.
template CalculateTotal(n) {
signal input nums[n];
signal output sum;

signal sums[n];
sums[0] <== nums[0];

for (var i=1; i < n; i++) {
sums[i] <== sums[i - 1] + nums[i];
}

sum <== sums[n - 1];
}
77 changes: 77 additions & 0 deletions circuits/sha256Var.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
pragma circom 2.1.9;

include "./Sha256Input.circom";
include "circomlib/circuits/mux1.circom";
include "circomlib/circuits/mux2.circom";
include "circomlib/circuits/mux3.circom";
include "circomlib/circuits/mux4.circom";
include "./sha256Unsafe.circom";

// Calculate power of x ^ y
function pow(x, y) {
if (y == 0) {
return 1;
} else {
return x * pow(x, y - 1);
}
}

// Caclulate sha256 of input of any length within (64 * (2 ^ BlockSpace)) characters
// Takes in array of bits and length of the string in bits
// If any of the bits after len are not 0, the result is undefined behavior
template Sha256Var(BlockSpace) {

// constants
var BLOCK_LEN = 512;
var SHA256_LEN = 256;

// variables
var MaxBlockCount = pow(2, BlockSpace);
var MaxLen = BLOCK_LEN * MaxBlockCount;
var LenMaxBits = 9 + BlockSpace; // can hold values from 2 ^ 10 to 2 ^ 13

// signals
signal input in[MaxLen];
signal input len;
signal output out[SHA256_LEN];


// calculate number of blocks needed (as bits)
signal len_plus_64;
len_plus_64 <== len + 64;
component n2b = Num2Bits(LenMaxBits);
n2b.in <== len_plus_64;
component shr = ShR(LenMaxBits, 9); // len_plus_64 >> 9
for (var i = 0; i < LenMaxBits; i++) {
shr.in[i] <== n2b.out[i];
}

// calculate number of blocks needed (as integer)
component b2n = Bits2Num(BlockSpace);
for (var k = 0; k < BlockSpace; k++) { b2n.in[k] <== shr.out[k]; }

// prepare input based on length and number of blocks
component input_blocks = Sha256Input(MaxBlockCount);
input_blocks.len <== len;
input_blocks.tBlock <== b2n.out + 1;
for (var j = 0; j < MaxBlockCount; j++) {
for (var i = 0; i < BLOCK_LEN; i++) {
input_blocks.in[j * BLOCK_LEN + i] <== in[j * BLOCK_LEN + i];
}
}

// put the selected input into sha256
component sha256_unsafe = Sha256_unsafe(MaxBlockCount);
sha256_unsafe.tBlock <== b2n.out + 1;
for (var j = 0; j < MaxBlockCount; j++) {
for (var i = 0; i < BLOCK_LEN; i++) {
sha256_unsafe.in[j][i] <== input_blocks.out[j * BLOCK_LEN + i];
}
}

// copy the output
for (var i = 0; i < SHA256_LEN; i++) {
out[i] <== sha256_unsafe.out[i];
}
}

50 changes: 50 additions & 0 deletions tests/sha256.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { WitnessTester } from "circomkit";
import { circomkit } from "./common";
import crypto from "crypto";

describe("Sha256Var", () => {
let circuit: WitnessTester<["in", "len"], ["out"]>;

describe("Sha256Var", () => {
//todo add more tests
it("Should generate input for 120-183 len (3 blocks)", async () => {
circuit = await circomkit.WitnessTester(`Sha256Var`, {
file: "sha256Var",
template: "Sha256Var",
params: [3],
});
for (let i = 120; i < 183; i++) {
const message = Array(183).fill("a").join("");
const len = message.length * 8;
const input = msgToBits(message, 8);
const msgHash = crypto.createHash("sha256").update(message).digest("hex");

await circuit.expectPass(
{
in: input,
len,
},
{ out: bufferToBitArray(Buffer.from(msgHash, "hex")) }
);
}
});
});
});

function msgToBits(msg: string, blocks: number) {
let inn = bufferToBitArray(Buffer.from(msg));
const overall_len = blocks * 512;
const add_bits = overall_len - inn.length;
inn = inn.concat(Array(add_bits).fill(0));
return inn;
}

function bufferToBitArray(b: Buffer) {
const res = [];
for (let i = 0; i < b.length; i++) {
for (let j = 0; j < 8; j++) {
res.push((b[i] >> (7 - j)) & 1);
}
}
return res;
}

0 comments on commit cfb1acc

Please sign in to comment.