Skip to content

Commit

Permalink
Add --versioned-constants-file flag (#1920)
Browse files Browse the repository at this point in the history
  • Loading branch information
kirugan authored Jul 2, 2024
1 parent 70c9ccf commit 2b51b4f
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 66 deletions.
94 changes: 49 additions & 45 deletions cmd/juno/juno.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,50 +38,51 @@ Juno is a Go implementation of a Starknet full-node client created by Nethermind
`

const (
configF = "config"
logLevelF = "log-level"
httpF = "http"
httpHostF = "http-host"
httpPortF = "http-port"
wsF = "ws"
wsHostF = "ws-host"
wsPortF = "ws-port"
dbPathF = "db-path"
networkF = "network"
ethNodeF = "eth-node"
pprofF = "pprof"
pprofHostF = "pprof-host"
pprofPortF = "pprof-port"
colourF = "colour"
pendingPollIntervalF = "pending-poll-interval"
p2pF = "p2p"
p2pAddrF = "p2p-addr"
p2pPeersF = "p2p-peers"
p2pFeederNodeF = "p2p-feeder-node"
p2pPrivateKey = "p2p-private-key"
metricsF = "metrics"
metricsHostF = "metrics-host"
metricsPortF = "metrics-port"
grpcF = "grpc"
grpcHostF = "grpc-host"
grpcPortF = "grpc-port"
maxVMsF = "max-vms"
maxVMQueueF = "max-vm-queue"
remoteDBF = "remote-db"
rpcMaxBlockScanF = "rpc-max-block-scan"
dbCacheSizeF = "db-cache-size"
dbMaxHandlesF = "db-max-handles"
gwAPIKeyF = "gw-api-key" //nolint: gosec
gwTimeoutF = "gw-timeout" //nolint: gosec
cnNameF = "cn-name"
cnFeederURLF = "cn-feeder-url"
cnGatewayURLF = "cn-gateway-url"
cnL1ChainIDF = "cn-l1-chain-id"
cnL2ChainIDF = "cn-l2-chain-id"
cnCoreContractAddressF = "cn-core-contract-address"
cnUnverifiableRangeF = "cn-unverifiable-range"
callMaxStepsF = "rpc-call-max-steps"
corsEnableF = "rpc-cors-enable"
configF = "config"
logLevelF = "log-level"
httpF = "http"
httpHostF = "http-host"
httpPortF = "http-port"
wsF = "ws"
wsHostF = "ws-host"
wsPortF = "ws-port"
dbPathF = "db-path"
networkF = "network"
ethNodeF = "eth-node"
pprofF = "pprof"
pprofHostF = "pprof-host"
pprofPortF = "pprof-port"
colourF = "colour"
pendingPollIntervalF = "pending-poll-interval"
p2pF = "p2p"
p2pAddrF = "p2p-addr"
p2pPeersF = "p2p-peers"
p2pFeederNodeF = "p2p-feeder-node"
p2pPrivateKey = "p2p-private-key"
metricsF = "metrics"
metricsHostF = "metrics-host"
metricsPortF = "metrics-port"
grpcF = "grpc"
grpcHostF = "grpc-host"
grpcPortF = "grpc-port"
maxVMsF = "max-vms"
maxVMQueueF = "max-vm-queue"
remoteDBF = "remote-db"
rpcMaxBlockScanF = "rpc-max-block-scan"
dbCacheSizeF = "db-cache-size"
dbMaxHandlesF = "db-max-handles"
gwAPIKeyF = "gw-api-key" //nolint: gosec
gwTimeoutF = "gw-timeout" //nolint: gosec
cnNameF = "cn-name"
cnFeederURLF = "cn-feeder-url"
cnGatewayURLF = "cn-gateway-url"
cnL1ChainIDF = "cn-l1-chain-id"
cnL2ChainIDF = "cn-l2-chain-id"
cnCoreContractAddressF = "cn-core-contract-address"
cnUnverifiableRangeF = "cn-unverifiable-range"
callMaxStepsF = "rpc-call-max-steps"
corsEnableF = "rpc-cors-enable"
versionedConstantsFileF = "versioned-constants-file"

defaultConfig = ""
defaulHost = "localhost"
Expand Down Expand Up @@ -117,6 +118,7 @@ const (
defaultCallMaxSteps = 4_000_000
defaultGwTimeout = 5 * time.Second
defaultCorsEnable = false
defaultVersionedConstantsFile = ""

configFlagUsage = "The YAML configuration file."
logLevelFlagUsage = "Options: trace, debug, info, warn, error."
Expand Down Expand Up @@ -165,7 +167,8 @@ const (
gwTimeoutUsage = "Timeout for requests made to the gateway" //nolint: gosec
callMaxStepsUsage = "Maximum number of steps to be executed in starknet_call requests. " +
"The upper limit is 4 million steps, and any higher value will still be capped at 4 million."
corsEnableUsage = "Enable CORS on RPC endpoints"
corsEnableUsage = "Enable CORS on RPC endpoints"
versionedConstantsFileUsage = "Use custom versioned constants from provided file"
)

var Version string
Expand Down Expand Up @@ -345,6 +348,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr
junoCmd.Flags().Uint(callMaxStepsF, defaultCallMaxSteps, callMaxStepsUsage)
junoCmd.Flags().Duration(gwTimeoutF, defaultGwTimeout, gwTimeoutUsage)
junoCmd.Flags().Bool(corsEnableF, defaultCorsEnable, corsEnableUsage)
junoCmd.Flags().String(versionedConstantsFileF, defaultVersionedConstantsFile, versionedConstantsFileUsage)
junoCmd.MarkFlagsMutuallyExclusive(p2pFeederNodeF, p2pPeersF)
junoCmd.AddCommand(GenP2PKeyPair())

Expand Down
48 changes: 28 additions & 20 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,27 @@ const (

// Config is the top-level juno configuration.
type Config struct {
LogLevel utils.LogLevel `mapstructure:"log-level"`
HTTP bool `mapstructure:"http"`
HTTPHost string `mapstructure:"http-host"`
HTTPPort uint16 `mapstructure:"http-port"`
RPCCorsEnable bool `mapstructure:"rpc-cors-enable"`
Websocket bool `mapstructure:"ws"`
WebsocketHost string `mapstructure:"ws-host"`
WebsocketPort uint16 `mapstructure:"ws-port"`
GRPC bool `mapstructure:"grpc"`
GRPCHost string `mapstructure:"grpc-host"`
GRPCPort uint16 `mapstructure:"grpc-port"`
DatabasePath string `mapstructure:"db-path"`
Network utils.Network `mapstructure:"network"`
EthNode string `mapstructure:"eth-node"`
Pprof bool `mapstructure:"pprof"`
PprofHost string `mapstructure:"pprof-host"`
PprofPort uint16 `mapstructure:"pprof-port"`
Colour bool `mapstructure:"colour"`
PendingPollInterval time.Duration `mapstructure:"pending-poll-interval"`
RemoteDB string `mapstructure:"remote-db"`
LogLevel utils.LogLevel `mapstructure:"log-level"`
HTTP bool `mapstructure:"http"`
HTTPHost string `mapstructure:"http-host"`
HTTPPort uint16 `mapstructure:"http-port"`
RPCCorsEnable bool `mapstructure:"rpc-cors-enable"`
Websocket bool `mapstructure:"ws"`
WebsocketHost string `mapstructure:"ws-host"`
WebsocketPort uint16 `mapstructure:"ws-port"`
GRPC bool `mapstructure:"grpc"`
GRPCHost string `mapstructure:"grpc-host"`
GRPCPort uint16 `mapstructure:"grpc-port"`
DatabasePath string `mapstructure:"db-path"`
Network utils.Network `mapstructure:"network"`
EthNode string `mapstructure:"eth-node"`
Pprof bool `mapstructure:"pprof"`
PprofHost string `mapstructure:"pprof-host"`
PprofPort uint16 `mapstructure:"pprof-port"`
Colour bool `mapstructure:"colour"`
PendingPollInterval time.Duration `mapstructure:"pending-poll-interval"`
RemoteDB string `mapstructure:"remote-db"`
VersionedConstantsFile string `mapstructure:"versioned-constants-file"`

Metrics bool `mapstructure:"metrics"`
MetricsHost string `mapstructure:"metrics-host"`
Expand Down Expand Up @@ -140,6 +141,13 @@ func New(cfg *Config, version string) (*Node, error) { //nolint:gocyclo,funlen
}
}

if cfg.VersionedConstantsFile != "" {
err = vm.SetVersionedConstants(cfg.VersionedConstantsFile)
if err != nil {
return nil, fmt.Errorf("failed to set versioned constants: %w", err)
}
}

client := feeder.NewClient(cfg.Network.FeederURL).WithUserAgent(ua).WithLogger(log).
WithTimeout(cfg.GatewayTimeout).WithAPIKey(cfg.GatewayAPIKey)
synchronizer := sync.New(chain, adaptfeeder.New(client), log, cfg.PendingPollInterval, dbIsRemote)
Expand Down
37 changes: 36 additions & 1 deletion vm/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,11 +433,14 @@ lazy_static! {
};
}

#[allow(static_mut_refs)]
fn get_versioned_constants(version: *const c_char) -> VersionedConstants {
let version_str = unsafe { CStr::from_ptr(version) }.to_str().unwrap();
let version = StarknetVersion::from_str(&version_str).unwrap_or(StarknetVersion::from_str(&"0.0.0").unwrap());

if version < StarknetVersion::from_str(&"0.13.1").unwrap() {
if let Some(constants) = unsafe{ &CUSTOM_VERSIONED_CONSTANTS } {
constants.clone()
} else if version < StarknetVersion::from_str(&"0.13.1").unwrap() {
CONSTANTS.get(&"0.13.0".to_string()).unwrap().to_owned()
} else if version < StarknetVersion::from_str(&"0.13.1.1").unwrap() {
CONSTANTS.get(&"0.13.1".to_string()).unwrap().to_owned()
Expand Down Expand Up @@ -478,4 +481,36 @@ impl FromStr for StarknetVersion {

Ok(StarknetVersion(a, b, c, d))
}
}

static mut CUSTOM_VERSIONED_CONSTANTS: Option<VersionedConstants> = None;

#[no_mangle]
pub extern "C" fn setVersionedConstants(json_bytes: *const c_char) -> *const c_char {
let json_str = unsafe {
match CStr::from_ptr(json_bytes).to_str() {
Ok(s) => s,
Err(_) => return CString::new("Failed to convert JSON bytes to string").unwrap().into_raw(),
}
};

match serde_json::from_str(json_str) {
Ok(parsed) => unsafe {
CUSTOM_VERSIONED_CONSTANTS = Some(parsed);
CString::new("").unwrap().into_raw() // No error, return an empty string
},
Err(_) => CString::new("Failed to parse JSON").unwrap().into_raw(),
}
}

#[no_mangle]
pub extern "C" fn freeString(s: *mut c_char) {
if !s.is_null() {
unsafe {
// Convert the raw C string pointer back to a CString. This operation
// takes ownership of the memory again and ensures it gets deallocated
// when drop function returns.
drop(CString::from_raw(s));
}
}
}
5 changes: 5 additions & 0 deletions vm/rust/src/versioned_constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use blockifier::versioned_constants::VersionedConstants;
use std::{
ffi::{c_char, c_uchar, c_void, CStr}
}

34 changes: 34 additions & 0 deletions vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ extern void cairoVMExecute(char* txns_json, char* classes_json, char* paid_fees_
BlockInfo* block_info_ptr, uintptr_t readerHandle, char* chain_id,
unsigned char skip_charge_fee, unsigned char skip_validate, unsigned char err_on_revert);
extern char* setVersionedConstants(char* json);
extern void freeString(char* str);
#cgo vm_debug LDFLAGS: -L./rust/target/debug -ljuno_starknet_rs -ldl -lm
#cgo !vm_debug LDFLAGS: -L./rust/target/release -ljuno_starknet_rs -ldl -lm
*/
Expand All @@ -44,6 +47,8 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"os"
"runtime"
"runtime/cgo"
"unsafe"
Expand Down Expand Up @@ -220,6 +225,8 @@ func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateRead
handle := cgo.NewHandle(context)
defer handle.Delete()

C.setVersionedConstants(C.CString("my_json"))

cCallInfo, callInfoPinner := makeCCallInfo(callInfo)
cBlockInfo := makeCBlockInfo(blockInfo, useBlobData)
chainID := C.CString(network.L2ChainID)
Expand Down Expand Up @@ -351,3 +358,30 @@ func marshalTxnsAndDeclaredClasses(txns []core.Transaction, declaredClasses []co

return txnsJSON, classesJSON, nil
}

func SetVersionedConstants(filename string) error {
fd, err := os.Open(filename)
if err != nil {
return err
}
defer fd.Close()

buff, err := io.ReadAll(fd)
if err != nil {
return err
}

jsonStr := C.CString(string(buff))
if errCStr := C.setVersionedConstants(jsonStr); errCStr != nil {
var errStr string = C.GoString(errCStr)
// empty string is not an error
if errStr != "" {
err = errors.New(errStr)
}
// here we rely on free call on Rust side, because on Go side we can have different allocator
C.freeString((*C.char)(unsafe.Pointer(errCStr)))
}
C.free(unsafe.Pointer(jsonStr))

return err
}
16 changes: 16 additions & 0 deletions vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package vm

import (
"context"
"os"
"reflect"
"testing"

Expand Down Expand Up @@ -224,3 +225,18 @@ func TestExecute(t *testing.T) {
require.NoError(t, err)
})
}

func TestSetVersionedConstants(t *testing.T) {
t.Run("ok", func(t *testing.T) {
err := SetVersionedConstants("./rust/versioned_constants_13_1.json")
assert.NoError(t, err)
})
t.Run("not valid json", func(t *testing.T) {
fd, err := os.CreateTemp("", "versioned_constants_test*")
require.NoError(t, err)
defer os.Remove(fd.Name())

err = SetVersionedConstants(fd.Name())
assert.ErrorContains(t, err, "Failed to parse JSON")
})
}

0 comments on commit 2b51b4f

Please sign in to comment.