Skip to content

Commit

Permalink
Implement casm class hash calculation (#1154)
Browse files Browse the repository at this point in the history
Co-authored-by: Kirill <[email protected]>
  • Loading branch information
Brivan-26 and kirugan authored Nov 27, 2023
1 parent a6b8453 commit 9bac881
Show file tree
Hide file tree
Showing 17 changed files with 298 additions and 117 deletions.
1 change: 0 additions & 1 deletion adapters/core2p2p/class.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ func AdaptClass(class core.Class, compiledHash *felt.Felt) *spec.Class {
case *core.Cairo1Class:
return &spec.Class{
CompiledHash: AdaptHash(compiledHash),
Definition: v.Compiled,
}
default:
panic(fmt.Errorf("unsupported cairo class %T (version=%d)", v, class.Version()))
Expand Down
123 changes: 123 additions & 0 deletions adapters/core2sn/core2sn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package core2sn

import (
"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/starknet"
"github.com/NethermindEth/juno/utils"
)

func AdaptCompiledClass(coreCompiledClass *core.CompiledClass) starknet.CompiledClass {
feederCompiledClass := new(starknet.CompiledClass)
feederCompiledClass.Bytecode = coreCompiledClass.Bytecode
feederCompiledClass.PythonicHints = coreCompiledClass.PythonicHints
feederCompiledClass.CompilerVersion = coreCompiledClass.CompilerVersion
feederCompiledClass.Hints = coreCompiledClass.Hints
feederCompiledClass.Prime = "0x" + coreCompiledClass.Prime.Text(felt.Base16)

feederCompiledClass.EntryPoints.External = make([]starknet.CompiledEntryPoint, len(coreCompiledClass.External))
for i, external := range coreCompiledClass.External {
feederCompiledClass.EntryPoints.External[i] = starknet.CompiledEntryPoint{
Selector: external.Selector,
Builtins: external.Builtins,
Offset: external.Offset,
}
}

feederCompiledClass.EntryPoints.L1Handler = make([]starknet.CompiledEntryPoint, len(coreCompiledClass.L1Handler))
for i, external := range coreCompiledClass.L1Handler {
feederCompiledClass.EntryPoints.L1Handler[i] = starknet.CompiledEntryPoint{
Selector: external.Selector,
Builtins: external.Builtins,
Offset: external.Offset,
}
}

feederCompiledClass.EntryPoints.Constructor = make([]starknet.CompiledEntryPoint, len(coreCompiledClass.Constructor))
for i, external := range coreCompiledClass.Constructor {
feederCompiledClass.EntryPoints.Constructor[i] = starknet.CompiledEntryPoint{
Selector: external.Selector,
Builtins: external.Builtins,
Offset: external.Offset,
}
}

return *feederCompiledClass
}

func AdaptSierraClass(class *core.Cairo1Class) *starknet.SierraDefinition {
constructors := make([]starknet.SierraEntryPoint, 0, len(class.EntryPoints.Constructor))
for _, entryPoint := range class.EntryPoints.Constructor {
constructors = append(constructors, starknet.SierraEntryPoint{
Selector: entryPoint.Selector,
Index: entryPoint.Index,
})
}

external := make([]starknet.SierraEntryPoint, 0, len(class.EntryPoints.External))
for _, entryPoint := range class.EntryPoints.External {
external = append(external, starknet.SierraEntryPoint{
Selector: entryPoint.Selector,
Index: entryPoint.Index,
})
}

handlers := make([]starknet.SierraEntryPoint, 0, len(class.EntryPoints.L1Handler))
for _, entryPoint := range class.EntryPoints.L1Handler {
handlers = append(handlers, starknet.SierraEntryPoint{
Selector: entryPoint.Selector,
Index: entryPoint.Index,
})
}

return &starknet.SierraDefinition{
Version: class.SemanticVersion,
Program: class.Program,
EntryPoints: starknet.SierraEntryPoints{
Constructor: constructors,
External: external,
L1Handler: handlers,
},
}
}

func AdaptCairo0Class(class *core.Cairo0Class) (*starknet.Cairo0Definition, error) {
decompressedProgram, err := utils.Gzip64Decode(class.Program)
if err != nil {
return nil, err
}

constructors := make([]starknet.EntryPoint, 0, len(class.Constructors))
for _, entryPoint := range class.Constructors {
constructors = append(constructors, starknet.EntryPoint{
Selector: entryPoint.Selector,
Offset: entryPoint.Offset,
})
}

external := make([]starknet.EntryPoint, 0, len(class.Externals))
for _, entryPoint := range class.Externals {
external = append(external, starknet.EntryPoint{
Selector: entryPoint.Selector,
Offset: entryPoint.Offset,
})
}

handlers := make([]starknet.EntryPoint, 0, len(class.L1Handlers))
for _, entryPoint := range class.L1Handlers {
handlers = append(handlers, starknet.EntryPoint{
Selector: entryPoint.Selector,
Offset: entryPoint.Offset,
})
}

return &starknet.Cairo0Definition{
Program: decompressedProgram,
Abi: class.Abi,
EntryPoints: starknet.EntryPoints{
Constructor: constructors,
External: external,
L1Handler: handlers,
},
}, nil
}
48 changes: 40 additions & 8 deletions adapters/sn2core/sn2core.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package sn2core

import (
"encoding/json"
"errors"
"fmt"
"math/big"

"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/core/crypto"
Expand Down Expand Up @@ -244,7 +244,7 @@ func AdaptDeployAccountTransaction(t *starknet.Transaction) *core.DeployAccountT
}
}

func AdaptCairo1Class(response *starknet.SierraDefinition, compiledClass json.RawMessage) (core.Class, error) {
func AdaptCairo1Class(response *starknet.SierraDefinition, compiledClass *starknet.CompiledClass) (*core.Cairo1Class, error) {
var err error

class := new(core.Cairo1Class)
Expand Down Expand Up @@ -272,28 +272,52 @@ func AdaptCairo1Class(response *starknet.SierraDefinition, compiledClass json.Ra
}

if compiledClass != nil {
if err = json.Unmarshal(compiledClass, &class.Compiled); err != nil {
adaptedCompiledClass, err := AdaptCompiledClass(compiledClass)
if err != nil {
return nil, err
}
class.Compiled = *adaptedCompiledClass
}

return class, nil
}

func AdaptCompiledClass(compiledClass *starknet.CompiledClass) (*core.CompiledClass, error) {
compiled := new(core.CompiledClass)
compiled.Bytecode = compiledClass.Bytecode
compiled.PythonicHints = compiledClass.PythonicHints
compiled.CompilerVersion = compiledClass.CompilerVersion
compiled.Hints = compiledClass.Hints

var ok bool
compiled.Prime, ok = new(big.Int).SetString(compiledClass.Prime, 0)
if !ok {
return nil, fmt.Errorf("couldn't convert prime value to big.Int: %d", compiled.Prime)
}

entryPoints := compiledClass.EntryPoints
compiled.External = utils.Map(entryPoints.External, adaptCompiledEntryPoint)
compiled.L1Handler = utils.Map(entryPoints.L1Handler, adaptCompiledEntryPoint)
compiled.Constructor = utils.Map(entryPoints.Constructor, adaptCompiledEntryPoint)

return compiled, nil
}

func AdaptCairo0Class(response *starknet.Cairo0Definition) (core.Class, error) {
class := new(core.Cairo0Class)
class.Abi = response.Abi

class.Externals = []core.EntryPoint{}
class.Externals = make([]core.EntryPoint, 0, len(response.EntryPoints.External))
for _, v := range response.EntryPoints.External {
class.Externals = append(class.Externals, core.EntryPoint{Selector: v.Selector, Offset: v.Offset})
}

class.L1Handlers = []core.EntryPoint{}
class.L1Handlers = make([]core.EntryPoint, 0, len(response.EntryPoints.L1Handler))
for _, v := range response.EntryPoints.L1Handler {
class.L1Handlers = append(class.L1Handlers, core.EntryPoint{Selector: v.Selector, Offset: v.Offset})
}

class.Constructors = []core.EntryPoint{}
class.Constructors = make([]core.EntryPoint, 0, len(response.EntryPoints.Constructor))
for _, v := range response.EntryPoints.Constructor {
class.Constructors = append(class.Constructors, core.EntryPoint{Selector: v.Selector, Offset: v.Offset})
}
Expand Down Expand Up @@ -335,7 +359,7 @@ func AdaptStateUpdate(response *starknet.StateUpdate) (*core.StateUpdate, error)
}
}

stateDiff.Nonces = make(map[felt.Felt]*felt.Felt)
stateDiff.Nonces = make(map[felt.Felt]*felt.Felt, len(response.StateDiff.Nonces))
for addrStr, nonce := range response.StateDiff.Nonces {
addr, err := new(felt.Felt).SetString(addrStr)
if err != nil {
Expand All @@ -344,7 +368,7 @@ func AdaptStateUpdate(response *starknet.StateUpdate) (*core.StateUpdate, error)
stateDiff.Nonces[*addr] = nonce
}

stateDiff.StorageDiffs = make(map[felt.Felt][]core.StorageDiff)
stateDiff.StorageDiffs = make(map[felt.Felt][]core.StorageDiff, len(response.StateDiff.StorageDiffs))
for addrStr, diffs := range response.StateDiff.StorageDiffs {
addr, err := new(felt.Felt).SetString(addrStr)
if err != nil {
Expand Down Expand Up @@ -374,3 +398,11 @@ func safeFeltToUint64(f *felt.Felt) uint64 {
}
return 0
}

func adaptCompiledEntryPoint(entryPoint starknet.CompiledEntryPoint) core.CompiledEntryPoint {
return core.CompiledEntryPoint{
Offset: entryPoint.Offset,
Selector: entryPoint.Selector,
Builtins: entryPoint.Builtins,
}
}
11 changes: 6 additions & 5 deletions adapters/sn2core/sn2core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -481,16 +481,17 @@ func TestClassV1(t *testing.T) {
compiled, err := client.CompiledClassDefinition(context.Background(), classHash)
require.NoError(t, err)

class, err := sn2core.AdaptCairo1Class(feederClass.V1, compiled)
v1Class, err := sn2core.AdaptCairo1Class(feederClass.V1, compiled)
require.NoError(t, err)

v1Class, ok := class.(*core.Cairo1Class)
require.True(t, ok)

assert.Equal(t, feederClass.V1.Abi, v1Class.Abi)
assert.Equal(t, feederClass.V1.Program, v1Class.Program)
assert.Equal(t, feederClass.V1.Version, v1Class.SemanticVersion)
assert.Equal(t, compiled, v1Class.Compiled)
assert.Equal(t, compiled.Prime, "0x"+v1Class.Compiled.Prime.Text(felt.Base16))
assert.Equal(t, compiled.Bytecode, v1Class.Compiled.Bytecode)
assert.Equal(t, compiled.Hints, v1Class.Compiled.Hints)
assert.Equal(t, compiled.CompilerVersion, v1Class.Compiled.CompilerVersion)
assert.Equal(t, len(compiled.EntryPoints.External), len(v1Class.Compiled.External))

assert.Equal(t, len(feederClass.V1.EntryPoints.External), len(v1Class.EntryPoints.External))
for i, v := range feederClass.V1.EntryPoints.External {
Expand Down
6 changes: 3 additions & 3 deletions clients/feeder/feeder.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (c *Client) ClassDefinition(ctx context.Context, classHash *felt.Felt) (*st
return class, nil
}

func (c *Client) CompiledClassDefinition(ctx context.Context, classHash *felt.Felt) (json.RawMessage, error) {
func (c *Client) CompiledClassDefinition(ctx context.Context, classHash *felt.Felt) (*starknet.CompiledClass, error) {
queryURL := c.buildQueryString("get_compiled_class_by_class_hash", map[string]string{
"classHash": classHash.String(),
})
Expand All @@ -332,8 +332,8 @@ func (c *Client) CompiledClassDefinition(ctx context.Context, classHash *felt.Fe
}
defer body.Close()

var class json.RawMessage
if err = json.NewDecoder(body).Decode(&class); err != nil {
class := new(starknet.CompiledClass)
if err = json.NewDecoder(body).Decode(class); err != nil {
return nil, err
}
return class, nil
Expand Down
9 changes: 7 additions & 2 deletions clients/feeder/feeder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package feeder_test

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strconv"
Expand Down Expand Up @@ -538,7 +537,13 @@ func TestCompiledClassDefinition(t *testing.T) {
classHash := utils.HexToFelt(t, "0x1cd2edfb485241c4403254d550de0a097fa76743cd30696f714a491a454bad5")
class, err := client.CompiledClassDefinition(context.Background(), classHash)
require.NoError(t, err)
require.True(t, json.Valid(class))
assert.Equal(t, "1.0.0", class.CompilerVersion)
assert.Equal(t, "0x800000000000011000000000000000000000000000000000000000000000001", class.Prime)
assert.Equal(t, 3900, len(class.Bytecode))
assert.Equal(t, 10, len(class.EntryPoints.External))
assert.Equal(t, 1, len(class.EntryPoints.External[9].Builtins))
assert.Equal(t, "range_check", class.EntryPoints.External[9].Builtins[0])
assert.Equal(t, "0x3604cea1cdb094a73a31144f14a3e5861613c008e1e879939ebc4827d10cd50", class.EntryPoints.External[9].Selector.String())
}

func TestTransactionStatusRevertError(t *testing.T) {
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

49 changes: 48 additions & 1 deletion core/class.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package core
import (
"encoding/json"
"fmt"
"math/big"

"github.com/NethermindEth/juno/core/crypto"
"github.com/NethermindEth/juno/core/felt"
Expand Down Expand Up @@ -56,7 +57,24 @@ type Cairo1Class struct {
Program []*felt.Felt
ProgramHash *felt.Felt
SemanticVersion string
Compiled json.RawMessage
Compiled CompiledClass
}

type CompiledClass struct {
Bytecode []*felt.Felt
PythonicHints json.RawMessage
CompilerVersion string
Hints json.RawMessage
Prime *big.Int
External []CompiledEntryPoint
L1Handler []CompiledEntryPoint
Constructor []CompiledEntryPoint
}

type CompiledEntryPoint struct {
Offset uint64
Builtins []string
Selector *felt.Felt
}

type SierraEntryPoint struct {
Expand All @@ -79,6 +97,18 @@ func (c *Cairo1Class) Hash() *felt.Felt {
)
}

var compiledClassV1Prefix = new(felt.Felt).SetBytes([]byte("COMPILED_CLASS_V1"))

func (c *CompiledClass) Hash() *felt.Felt {
return crypto.PoseidonArray(
compiledClassV1Prefix,
crypto.PoseidonArray(flattenCompiledEntryPoints(c.External)...),
crypto.PoseidonArray(flattenCompiledEntryPoints(c.L1Handler)...),
crypto.PoseidonArray(flattenCompiledEntryPoints(c.Constructor)...),
crypto.PoseidonArray(c.Bytecode...),
)
}

func flattenSierraEntryPoints(entryPoints []SierraEntryPoint) []*felt.Felt {
result := make([]*felt.Felt, len(entryPoints)*2)
for i, entryPoint := range entryPoints {
Expand All @@ -90,6 +120,23 @@ func flattenSierraEntryPoints(entryPoints []SierraEntryPoint) []*felt.Felt {
return result
}

func flattenCompiledEntryPoints(entryPoints []CompiledEntryPoint) []*felt.Felt {
result := make([]*felt.Felt, len(entryPoints)*3)
for i, entryPoint := range entryPoints {
// It is important that Selector is first, then Offset is second because the order
// influences the class hash.
result[3*i] = entryPoint.Selector
result[3*i+1] = new(felt.Felt).SetUint64(entryPoint.Offset)
builtins := make([]*felt.Felt, len(entryPoint.Builtins))
for idx, buil := range entryPoint.Builtins {
builtins[idx] = new(felt.Felt).SetBytes([]byte(buil))
}
result[3*i+2] = crypto.PoseidonArray(builtins...)
}

return result
}

func VerifyClassHashes(classes map[felt.Felt]Class) error {
for hash, class := range classes {
cairo1Class, ok := class.(*Cairo1Class)
Expand Down
Loading

0 comments on commit 9bac881

Please sign in to comment.