diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index 0c68968e90..51d685c2bf 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -38,50 +38,51 @@ Juno is a Go implementation of a Starknet full-node client created by Nethermind ` const ( - configF = "config" - logLevelF = "log-level" - httpF = "http" - httpHostF = "http-host" - httpPortF = "http-port" - wsF = "ws" - wsHostF = "ws-host" - wsPortF = "ws-port" - dbPathF = "db-path" - networkF = "network" - ethNodeF = "eth-node" - pprofF = "pprof" - pprofHostF = "pprof-host" - pprofPortF = "pprof-port" - colourF = "colour" - pendingPollIntervalF = "pending-poll-interval" - p2pF = "p2p" - p2pAddrF = "p2p-addr" - p2pPeersF = "p2p-peers" - p2pFeederNodeF = "p2p-feeder-node" - p2pPrivateKey = "p2p-private-key" - metricsF = "metrics" - metricsHostF = "metrics-host" - metricsPortF = "metrics-port" - grpcF = "grpc" - grpcHostF = "grpc-host" - grpcPortF = "grpc-port" - maxVMsF = "max-vms" - maxVMQueueF = "max-vm-queue" - remoteDBF = "remote-db" - rpcMaxBlockScanF = "rpc-max-block-scan" - dbCacheSizeF = "db-cache-size" - dbMaxHandlesF = "db-max-handles" - gwAPIKeyF = "gw-api-key" //nolint: gosec - gwTimeoutF = "gw-timeout" //nolint: gosec - cnNameF = "cn-name" - cnFeederURLF = "cn-feeder-url" - cnGatewayURLF = "cn-gateway-url" - cnL1ChainIDF = "cn-l1-chain-id" - cnL2ChainIDF = "cn-l2-chain-id" - cnCoreContractAddressF = "cn-core-contract-address" - cnUnverifiableRangeF = "cn-unverifiable-range" - callMaxStepsF = "rpc-call-max-steps" - corsEnableF = "rpc-cors-enable" + configF = "config" + logLevelF = "log-level" + httpF = "http" + httpHostF = "http-host" + httpPortF = "http-port" + wsF = "ws" + wsHostF = "ws-host" + wsPortF = "ws-port" + dbPathF = "db-path" + networkF = "network" + ethNodeF = "eth-node" + pprofF = "pprof" + pprofHostF = "pprof-host" + pprofPortF = "pprof-port" + colourF = "colour" + pendingPollIntervalF = "pending-poll-interval" + p2pF = "p2p" + p2pAddrF = "p2p-addr" + p2pPeersF = "p2p-peers" + p2pFeederNodeF = "p2p-feeder-node" + p2pPrivateKey = "p2p-private-key" + metricsF = "metrics" + metricsHostF = "metrics-host" + metricsPortF = "metrics-port" + grpcF = "grpc" + grpcHostF = "grpc-host" + grpcPortF = "grpc-port" + maxVMsF = "max-vms" + maxVMQueueF = "max-vm-queue" + remoteDBF = "remote-db" + rpcMaxBlockScanF = "rpc-max-block-scan" + dbCacheSizeF = "db-cache-size" + dbMaxHandlesF = "db-max-handles" + gwAPIKeyF = "gw-api-key" //nolint: gosec + gwTimeoutF = "gw-timeout" //nolint: gosec + cnNameF = "cn-name" + cnFeederURLF = "cn-feeder-url" + cnGatewayURLF = "cn-gateway-url" + cnL1ChainIDF = "cn-l1-chain-id" + cnL2ChainIDF = "cn-l2-chain-id" + cnCoreContractAddressF = "cn-core-contract-address" + cnUnverifiableRangeF = "cn-unverifiable-range" + callMaxStepsF = "rpc-call-max-steps" + corsEnableF = "rpc-cors-enable" + versionedConstantsFileF = "versioned-constants-file" defaultConfig = "" defaulHost = "localhost" @@ -117,6 +118,7 @@ const ( defaultCallMaxSteps = 4_000_000 defaultGwTimeout = 5 * time.Second defaultCorsEnable = false + defaultVersionedConstantsFile = "" configFlagUsage = "The YAML configuration file." logLevelFlagUsage = "Options: trace, debug, info, warn, error." @@ -165,7 +167,8 @@ const ( gwTimeoutUsage = "Timeout for requests made to the gateway" //nolint: gosec callMaxStepsUsage = "Maximum number of steps to be executed in starknet_call requests. " + "The upper limit is 4 million steps, and any higher value will still be capped at 4 million." - corsEnableUsage = "Enable CORS on RPC endpoints" + corsEnableUsage = "Enable CORS on RPC endpoints" + versionedConstantsFileUsage = "Use custom versioned constants from provided file" ) var Version string @@ -345,6 +348,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr junoCmd.Flags().Uint(callMaxStepsF, defaultCallMaxSteps, callMaxStepsUsage) junoCmd.Flags().Duration(gwTimeoutF, defaultGwTimeout, gwTimeoutUsage) junoCmd.Flags().Bool(corsEnableF, defaultCorsEnable, corsEnableUsage) + junoCmd.Flags().String(versionedConstantsFileF, defaultVersionedConstantsFile, versionedConstantsFileUsage) junoCmd.MarkFlagsMutuallyExclusive(p2pFeederNodeF, p2pPeersF) junoCmd.AddCommand(GenP2PKeyPair()) diff --git a/node/node.go b/node/node.go index 98b63750bb..1c92bae2ff 100644 --- a/node/node.go +++ b/node/node.go @@ -44,26 +44,27 @@ const ( // Config is the top-level juno configuration. type Config struct { - LogLevel utils.LogLevel `mapstructure:"log-level"` - HTTP bool `mapstructure:"http"` - HTTPHost string `mapstructure:"http-host"` - HTTPPort uint16 `mapstructure:"http-port"` - RPCCorsEnable bool `mapstructure:"rpc-cors-enable"` - Websocket bool `mapstructure:"ws"` - WebsocketHost string `mapstructure:"ws-host"` - WebsocketPort uint16 `mapstructure:"ws-port"` - GRPC bool `mapstructure:"grpc"` - GRPCHost string `mapstructure:"grpc-host"` - GRPCPort uint16 `mapstructure:"grpc-port"` - DatabasePath string `mapstructure:"db-path"` - Network utils.Network `mapstructure:"network"` - EthNode string `mapstructure:"eth-node"` - Pprof bool `mapstructure:"pprof"` - PprofHost string `mapstructure:"pprof-host"` - PprofPort uint16 `mapstructure:"pprof-port"` - Colour bool `mapstructure:"colour"` - PendingPollInterval time.Duration `mapstructure:"pending-poll-interval"` - RemoteDB string `mapstructure:"remote-db"` + LogLevel utils.LogLevel `mapstructure:"log-level"` + HTTP bool `mapstructure:"http"` + HTTPHost string `mapstructure:"http-host"` + HTTPPort uint16 `mapstructure:"http-port"` + RPCCorsEnable bool `mapstructure:"rpc-cors-enable"` + Websocket bool `mapstructure:"ws"` + WebsocketHost string `mapstructure:"ws-host"` + WebsocketPort uint16 `mapstructure:"ws-port"` + GRPC bool `mapstructure:"grpc"` + GRPCHost string `mapstructure:"grpc-host"` + GRPCPort uint16 `mapstructure:"grpc-port"` + DatabasePath string `mapstructure:"db-path"` + Network utils.Network `mapstructure:"network"` + EthNode string `mapstructure:"eth-node"` + Pprof bool `mapstructure:"pprof"` + PprofHost string `mapstructure:"pprof-host"` + PprofPort uint16 `mapstructure:"pprof-port"` + Colour bool `mapstructure:"colour"` + PendingPollInterval time.Duration `mapstructure:"pending-poll-interval"` + RemoteDB string `mapstructure:"remote-db"` + VersionedConstantsFile string `mapstructure:"versioned-constants-file"` Metrics bool `mapstructure:"metrics"` MetricsHost string `mapstructure:"metrics-host"` @@ -140,6 +141,13 @@ func New(cfg *Config, version string) (*Node, error) { //nolint:gocyclo,funlen } } + if cfg.VersionedConstantsFile != "" { + err = vm.SetVersionedConstants(cfg.VersionedConstantsFile) + if err != nil { + return nil, fmt.Errorf("failed to set versioned constants: %w", err) + } + } + client := feeder.NewClient(cfg.Network.FeederURL).WithUserAgent(ua).WithLogger(log). WithTimeout(cfg.GatewayTimeout).WithAPIKey(cfg.GatewayAPIKey) synchronizer := sync.New(chain, adaptfeeder.New(client), log, cfg.PendingPollInterval, dbIsRemote) diff --git a/vm/rust/src/lib.rs b/vm/rust/src/lib.rs index 54a9c9ab3a..2a7207c9f4 100644 --- a/vm/rust/src/lib.rs +++ b/vm/rust/src/lib.rs @@ -433,11 +433,14 @@ lazy_static! { }; } +#[allow(static_mut_refs)] fn get_versioned_constants(version: *const c_char) -> VersionedConstants { let version_str = unsafe { CStr::from_ptr(version) }.to_str().unwrap(); let version = StarknetVersion::from_str(&version_str).unwrap_or(StarknetVersion::from_str(&"0.0.0").unwrap()); - if version < StarknetVersion::from_str(&"0.13.1").unwrap() { + if let Some(constants) = unsafe{ &CUSTOM_VERSIONED_CONSTANTS } { + constants.clone() + } else if version < StarknetVersion::from_str(&"0.13.1").unwrap() { CONSTANTS.get(&"0.13.0".to_string()).unwrap().to_owned() } else if version < StarknetVersion::from_str(&"0.13.1.1").unwrap() { CONSTANTS.get(&"0.13.1".to_string()).unwrap().to_owned() @@ -478,4 +481,36 @@ impl FromStr for StarknetVersion { Ok(StarknetVersion(a, b, c, d)) } +} + +static mut CUSTOM_VERSIONED_CONSTANTS: Option = None; + +#[no_mangle] +pub extern "C" fn setVersionedConstants(json_bytes: *const c_char) -> *const c_char { + let json_str = unsafe { + match CStr::from_ptr(json_bytes).to_str() { + Ok(s) => s, + Err(_) => return CString::new("Failed to convert JSON bytes to string").unwrap().into_raw(), + } + }; + + match serde_json::from_str(json_str) { + Ok(parsed) => unsafe { + CUSTOM_VERSIONED_CONSTANTS = Some(parsed); + CString::new("").unwrap().into_raw() // No error, return an empty string + }, + Err(_) => CString::new("Failed to parse JSON").unwrap().into_raw(), + } +} + +#[no_mangle] +pub extern "C" fn freeString(s: *mut c_char) { + if !s.is_null() { + unsafe { + // Convert the raw C string pointer back to a CString. This operation + // takes ownership of the memory again and ensures it gets deallocated + // when drop function returns. + drop(CString::from_raw(s)); + } + } } \ No newline at end of file diff --git a/vm/rust/src/versioned_constants.rs b/vm/rust/src/versioned_constants.rs new file mode 100644 index 0000000000..1b1f197af8 --- /dev/null +++ b/vm/rust/src/versioned_constants.rs @@ -0,0 +1,5 @@ +use blockifier::versioned_constants::VersionedConstants; +use std::{ + ffi::{c_char, c_uchar, c_void, CStr} +} + diff --git a/vm/vm.go b/vm/vm.go index bc3c3f0645..d20eeae826 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -35,6 +35,9 @@ extern void cairoVMExecute(char* txns_json, char* classes_json, char* paid_fees_ BlockInfo* block_info_ptr, uintptr_t readerHandle, char* chain_id, unsigned char skip_charge_fee, unsigned char skip_validate, unsigned char err_on_revert); +extern char* setVersionedConstants(char* json); +extern void freeString(char* str); + #cgo vm_debug LDFLAGS: -L./rust/target/debug -ljuno_starknet_rs -ldl -lm #cgo !vm_debug LDFLAGS: -L./rust/target/release -ljuno_starknet_rs -ldl -lm */ @@ -44,6 +47,8 @@ import ( "encoding/json" "errors" "fmt" + "io" + "os" "runtime" "runtime/cgo" "unsafe" @@ -220,6 +225,8 @@ func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateRead handle := cgo.NewHandle(context) defer handle.Delete() + C.setVersionedConstants(C.CString("my_json")) + cCallInfo, callInfoPinner := makeCCallInfo(callInfo) cBlockInfo := makeCBlockInfo(blockInfo, useBlobData) chainID := C.CString(network.L2ChainID) @@ -351,3 +358,30 @@ func marshalTxnsAndDeclaredClasses(txns []core.Transaction, declaredClasses []co return txnsJSON, classesJSON, nil } + +func SetVersionedConstants(filename string) error { + fd, err := os.Open(filename) + if err != nil { + return err + } + defer fd.Close() + + buff, err := io.ReadAll(fd) + if err != nil { + return err + } + + jsonStr := C.CString(string(buff)) + if errCStr := C.setVersionedConstants(jsonStr); errCStr != nil { + var errStr string = C.GoString(errCStr) + // empty string is not an error + if errStr != "" { + err = errors.New(errStr) + } + // here we rely on free call on Rust side, because on Go side we can have different allocator + C.freeString((*C.char)(unsafe.Pointer(errCStr))) + } + C.free(unsafe.Pointer(jsonStr)) + + return err +} diff --git a/vm/vm_test.go b/vm/vm_test.go index 4194e8a96b..dc65cb2179 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -2,6 +2,7 @@ package vm import ( "context" + "os" "reflect" "testing" @@ -224,3 +225,18 @@ func TestExecute(t *testing.T) { require.NoError(t, err) }) } + +func TestSetVersionedConstants(t *testing.T) { + t.Run("ok", func(t *testing.T) { + err := SetVersionedConstants("./rust/versioned_constants_13_1.json") + assert.NoError(t, err) + }) + t.Run("not valid json", func(t *testing.T) { + fd, err := os.CreateTemp("", "versioned_constants_test*") + require.NoError(t, err) + defer os.Remove(fd.Name()) + + err = SetVersionedConstants(fd.Name()) + assert.ErrorContains(t, err, "Failed to parse JSON") + }) +}