From 2d3dcce1df9e1f0eab226659e7f3e5887b87e247 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Tue, 24 Dec 2024 16:15:40 -0300 Subject: [PATCH 01/35] Update go.mod to include gorilla/websocket and enhance RPC provider with Websocket support - Added gorilla/websocket v1.5.3 as a direct dependency in go.mod. - Introduced NewWebsocketProvider function in provider.go to create a Websocket RPC Provider instance, enhancing the existing HTTP provider functionality. --- go.mod | 2 +- rpc/provider.go | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 80b238b2..cdfaca71 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.1 require ( github.com/NethermindEth/juno v0.12.2 github.com/ethereum/go-ethereum v1.14.8 + github.com/gorilla/websocket v1.5.3 github.com/joho/godotenv v1.4.0 github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249 github.com/pkg/errors v0.9.1 @@ -23,7 +24,6 @@ require ( github.com/deckarep/golang-set/v2 v2.6.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-ole/go-ole v1.3.0 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/holiman/uint256 v1.3.1 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/rpc/provider.go b/rpc/provider.go index f29ab549..95f7e4a8 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -8,6 +8,7 @@ import ( "github.com/NethermindEth/juno/core/felt" ethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/gorilla/websocket" "golang.org/x/net/publicsuffix" ) @@ -22,7 +23,7 @@ type Provider struct { chainID string } -// NewProvider creates a new rpc Provider instance. +// NewProvider creates a new HTTP rpc Provider instance. func NewProvider(url string, options ...ethrpc.ClientOption) (*Provider, error) { jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) if err != nil { @@ -40,6 +41,20 @@ func NewProvider(url string, options ...ethrpc.ClientOption) (*Provider, error) return &Provider{c: c}, nil } +// NewWebsocketProvider creates a new Websocket rpc Provider instance. +func NewWebsocketProvider(url string, options ...ethrpc.ClientOption) (*Provider, error) { + var dialer websocket.Dialer + // prepend the custom client to allow users to override + options = append([]ethrpc.ClientOption{ethrpc.WithWebsocketDialer(dialer)}, options...) + c, err := ethrpc.DialOptions(context.Background(), url, options...) + + if err != nil { + return nil, err + } + + return &Provider{c: c}, nil +} + //go:generate mockgen -destination=../mocks/mock_rpc_provider.go -package=mocks -source=provider.go api type RpcProvider interface { AddInvokeTransaction(ctx context.Context, invokeTxn BroadcastInvokeTxnType) (*AddInvokeTransactionResponse, error) From 418db12860d9eb16eea16fd75760a52f56995473 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Fri, 3 Jan 2025 17:41:15 -0300 Subject: [PATCH 02/35] Enhance NewWebsocketProvider with cookie support in provider.go --- examples/websocket/README.md | 10 ++++++++ examples/websocket/main.go | 44 ++++++++++++++++++++++++++++++++++++ rpc/provider.go | 7 +++++- 3 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 examples/websocket/README.md create mode 100644 examples/websocket/main.go diff --git a/examples/websocket/README.md b/examples/websocket/README.md new file mode 100644 index 00000000..c0aa9026 --- /dev/null +++ b/examples/websocket/README.md @@ -0,0 +1,10 @@ +This example calls two contract functions, with and without calldata. It uses an ERC20 token, but it can be any smart contract. + +Steps: +1. Rename the ".env.template" file located at the root of the "examples" folder to ".env" +1. Uncomment, and assign your Sepolia testnet endpoint to the `RPC_PROVIDER_URL` variable in the ".env" file +1. Uncomment, and assign your account address to the `ACCOUNT_ADDRESS` variable in the ".env" file +1. Make sure you are in the "simpleCall" directory +1. Execute `go run main.go` + +The calls outuputs will be returned at the end of the execution. \ No newline at end of file diff --git a/examples/websocket/main.go b/examples/websocket/main.go new file mode 100644 index 00000000..f10357f7 --- /dev/null +++ b/examples/websocket/main.go @@ -0,0 +1,44 @@ +package main + +import ( + "context" + "fmt" + + "github.com/NethermindEth/starknet.go/rpc" + + setup "github.com/NethermindEth/starknet.go/examples/internal" +) + +// main entry point of the program. +// +// It initializes the environment and establishes a connection with the client. +// It then makes two contract calls and prints the responses. +// +// Parameters: +// +// none +// +// Returns: +// +// none +func main() { + fmt.Println("Starting simpleCall example") + + // Load variables from '.env' file + rpcProviderUrl := setup.GetRpcProviderUrl() + + // Initialize connection to RPC provider + client, err := rpc.NewWebsocketProvider(rpcProviderUrl) + if err != nil { + panic(fmt.Sprintf("Error dialing the RPC provider: %s", err)) + } + + fmt.Println("Established connection with the client") + + chainID, err := client.ChainID(context.Background()) + if err != nil { + panic(fmt.Sprintf("Error getting chain ID: %s", err)) + } + fmt.Printf("Chain ID: %s\n", chainID) + +} diff --git a/rpc/provider.go b/rpc/provider.go index 95f7e4a8..45e0b666 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -43,7 +43,12 @@ func NewProvider(url string, options ...ethrpc.ClientOption) (*Provider, error) // NewWebsocketProvider creates a new Websocket rpc Provider instance. func NewWebsocketProvider(url string, options ...ethrpc.ClientOption) (*Provider, error) { - var dialer websocket.Dialer + jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + if err != nil { + return nil, err + } + dialer := websocket.Dialer{Jar: jar} + // prepend the custom client to allow users to override options = append([]ethrpc.ClientOption{ethrpc.WithWebsocketDialer(dialer)}, options...) c, err := ethrpc.DialOptions(context.Background(), url, options...) From 124776485fbcbfd855d43cd16443739319a9c19b Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Mon, 6 Jan 2025 08:51:00 -0300 Subject: [PATCH 03/35] Added new errors --- rpc/errors.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/rpc/errors.go b/rpc/errors.go index 5ea2f691..aad5eb5f 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -272,6 +272,22 @@ var ( Code: 63, Message: "An unexpected error occurred", } + ErrInvalidSubscriptionID = &RPCError{ + Code: 66, + Message: "Invalid subscription id", + } + ErrTooManyAddressesInFilter = &RPCError{ + Code: 67, + Message: "Too many addresses in filter sender_address filter", + } + ErrTooManyBlocksBack = &RPCError{ + Code: 68, + Message: "Cannot go back more than 1024 blocks", + } + ErrCallOnPending = &RPCError{ + Code: 69, + Message: "This method does not support being called on the pending block", + } ErrCompilationError = &RPCError{ Code: 100, Message: "Failed to compile the contract", From f7f342f43e819a6a897ce5d52449771c5958de91 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Tue, 7 Jan 2025 16:08:37 -0300 Subject: [PATCH 04/35] Refactor RPC client and enhance WebSocket support - Removed the dependency on github.com/ethereum/go-ethereum and replaced it with github.com/NethermindEth/starknet.go/client in go.mod and related files. - Introduced new WebSocket provider functionality in provider.go --- client/client.go | 721 ++++++++++++++++++++++++++++ client/client_opt.go | 144 ++++++ client/client_test.go | 951 +++++++++++++++++++++++++++++++++++++ client/context_headers.go | 56 +++ client/errors.go | 156 ++++++ client/handler.go | 612 ++++++++++++++++++++++++ client/http.go | 395 +++++++++++++++ client/http_test.go | 245 ++++++++++ client/inproc.go | 34 ++ client/json.go | 369 ++++++++++++++ client/log/format.go | 363 ++++++++++++++ client/log/handler.go | 199 ++++++++ client/log/logger.go | 216 +++++++++ client/log/root.go | 115 +++++ client/server.go | 271 +++++++++++ client/server_test.go | 194 ++++++++ client/service.go | 249 ++++++++++ client/subscription.go | 378 +++++++++++++++ client/testservice_test.go | 229 +++++++++ client/types.go | 44 ++ client/websocket.go | 376 +++++++++++++++ go.mod | 14 +- go.sum | 27 -- rpc/client.go | 11 +- rpc/provider.go | 23 +- rpc/websocket.go | 19 + 26 files changed, 6361 insertions(+), 50 deletions(-) create mode 100644 client/client.go create mode 100644 client/client_opt.go create mode 100644 client/client_test.go create mode 100644 client/context_headers.go create mode 100644 client/errors.go create mode 100644 client/handler.go create mode 100644 client/http.go create mode 100644 client/http_test.go create mode 100644 client/inproc.go create mode 100644 client/json.go create mode 100644 client/log/format.go create mode 100644 client/log/handler.go create mode 100644 client/log/logger.go create mode 100644 client/log/root.go create mode 100644 client/server.go create mode 100644 client/server_test.go create mode 100644 client/service.go create mode 100644 client/subscription.go create mode 100644 client/testservice_test.go create mode 100644 client/types.go create mode 100644 client/websocket.go create mode 100644 rpc/websocket.go diff --git a/client/client.go b/client/client.go new file mode 100644 index 00000000..3988f542 --- /dev/null +++ b/client/client.go @@ -0,0 +1,721 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "reflect" + "strconv" + "sync/atomic" + "time" + + "github.com/NethermindEth/starknet.go/client/log" +) + +var ( + ErrBadResult = errors.New("bad result in JSON-RPC response") + ErrClientQuit = errors.New("client is closed") + ErrNoResult = errors.New("JSON-RPC response has no result") + ErrMissingBatchResponse = errors.New("response batch did not contain a response to this call") + ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") + errClientReconnected = errors.New("client reconnected") + errDead = errors.New("connection lost") +) + +// Timeouts +const ( + defaultDialTimeout = 10 * time.Second // used if context has no deadline + subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls + unsubscribeTimeout = 10 * time.Second // timeout for *_unsubscribe calls +) + +const ( + // Subscriptions are removed when the subscriber cannot keep up. + // + // This can be worked around by supplying a channel with sufficiently sized buffer, + // but this can be inconvenient and hard to explain in the docs. Another issue with + // buffered channels is that the buffer is static even though it might not be needed + // most of the time. + // + // The approach taken here is to maintain a per-subscription linked list buffer + // shrinks on demand. If the buffer reaches the size below, the subscription is + // dropped. + maxClientSubscriptionBuffer = 20000 +) + +// BatchElem is an element in a batch request. +type BatchElem struct { + Method string + Args []interface{} + // The result is unmarshaled into this field. Result must be set to a + // non-nil pointer value of the desired type, otherwise the response will be + // discarded. + Result interface{} + // Error is set if the server returns an error for this request, or if + // unmarshalling into Result fails. It is not set for I/O errors. + Error error +} + +// Client represents a connection to an RPC server. +type Client struct { + idgen func() ID // for subscriptions + isHTTP bool // connection type: http, ws or ipc + services *serviceRegistry + + idCounter atomic.Uint32 + + // This function, if non-nil, is called when the connection is lost. + reconnectFunc reconnectFunc + + // config fields + batchItemLimit int + batchResponseMaxSize int + + // writeConn is used for writing to the connection on the caller's goroutine. It should + // only be accessed outside of dispatch, with the write lock held. The write lock is + // taken by sending on reqInit and released by sending on reqSent. + writeConn jsonWriter + + // for dispatch + close chan struct{} + closing chan struct{} // closed when client is quitting + didClose chan struct{} // closed when client quits + reconnected chan ServerCodec // where write/reconnect sends the new connection + readOp chan readOp // read messages + readErr chan error // errors from read + reqInit chan *requestOp // register response IDs, takes write lock + reqSent chan error // signals write completion, releases write lock + reqTimeout chan *requestOp // removes response IDs when call timeout expires +} + +type reconnectFunc func(context.Context) (ServerCodec, error) + +type clientContextKey struct{} + +type clientConn struct { + codec ServerCodec + handler *handler +} + +func (c *Client) newClientConn(conn ServerCodec) *clientConn { + ctx := context.Background() + ctx = context.WithValue(ctx, clientContextKey{}, c) + ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo()) + handler := newHandler(ctx, conn, c.idgen, c.services, c.batchItemLimit, c.batchResponseMaxSize) + return &clientConn{conn, handler} +} + +func (cc *clientConn) close(err error, inflightReq *requestOp) { + cc.handler.close(err, inflightReq) + cc.codec.close() +} + +type readOp struct { + msgs []*jsonrpcMessage + batch bool +} + +// requestOp represents a pending request. This is used for both batch and non-batch +// requests. +type requestOp struct { + ids []json.RawMessage + err error + resp chan []*jsonrpcMessage // the response goes here + sub *ClientSubscription // set for Subscribe requests. + hadResponse bool // true when the request was responded to +} + +func (op *requestOp) wait(ctx context.Context, c *Client) ([]*jsonrpcMessage, error) { + select { + case <-ctx.Done(): + // Send the timeout to dispatch so it can remove the request IDs. + if !c.isHTTP { + select { + case c.reqTimeout <- op: + case <-c.closing: + } + } + return nil, ctx.Err() + case resp := <-op.resp: + return resp, op.err + } +} + +// Dial creates a new client for the given URL. +// +// The currently supported URL schemes are "http", "https", "ws" and "wss". If rawurl is a +// file name with no URL scheme, a local socket connection is established using UNIX +// domain sockets on supported platforms and named pipes on Windows. +// +// If you want to further configure the transport, use DialOptions instead of this +// function. +// +// For websocket connections, the origin is set to the local host name. +// +// The client reconnects automatically when the connection is lost. +func Dial(rawurl string) (*Client, error) { + return DialOptions(context.Background(), rawurl) +} + +// DialContext creates a new RPC client, just like Dial. +// +// The context is used to cancel or time out the initial connection establishment. It does +// not affect subsequent interactions with the client. +func DialContext(ctx context.Context, rawurl string) (*Client, error) { + return DialOptions(ctx, rawurl) +} + +// DialOptions creates a new RPC client for the given URL. You can supply any of the +// pre-defined client options to configure the underlying transport. +// +// The context is used to cancel or time out the initial connection establishment. It does +// not affect subsequent interactions with the client. +// +// The client reconnects automatically when the connection is lost. +func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*Client, error) { + u, err := url.Parse(rawurl) + if err != nil { + return nil, err + } + + cfg := new(clientConfig) + for _, opt := range options { + opt.applyOption(cfg) + } + + var reconnect reconnectFunc + switch u.Scheme { + case "http", "https": + reconnect = newClientTransportHTTP(rawurl, cfg) + case "ws", "wss": + rc, err := newClientTransportWS(rawurl, cfg) + if err != nil { + return nil, err + } + reconnect = rc + default: + return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) + } + + return newClient(ctx, cfg, reconnect) +} + +// ClientFromContext retrieves the client from the context, if any. This can be used to perform +// 'reverse calls' in a handler method. +func ClientFromContext(ctx context.Context) (*Client, bool) { + client, ok := ctx.Value(clientContextKey{}).(*Client) + return client, ok +} + +func newClient(initctx context.Context, cfg *clientConfig, connect reconnectFunc) (*Client, error) { + conn, err := connect(initctx) + if err != nil { + return nil, err + } + c := initClient(conn, new(serviceRegistry), cfg) + c.reconnectFunc = connect + return c, nil +} + +func initClient(conn ServerCodec, services *serviceRegistry, cfg *clientConfig) *Client { + _, isHTTP := conn.(*httpConn) + c := &Client{ + isHTTP: isHTTP, + services: services, + idgen: cfg.idgen, + batchItemLimit: cfg.batchItemLimit, + batchResponseMaxSize: cfg.batchResponseLimit, + writeConn: conn, + close: make(chan struct{}), + closing: make(chan struct{}), + didClose: make(chan struct{}), + reconnected: make(chan ServerCodec), + readOp: make(chan readOp), + readErr: make(chan error), + reqInit: make(chan *requestOp), + reqSent: make(chan error, 1), + reqTimeout: make(chan *requestOp), + } + + // Set defaults. + if c.idgen == nil { + c.idgen = randomIDGenerator() + } + + // Launch the main loop. + if !isHTTP { + go c.dispatch(conn) + } + return c +} + +// RegisterName creates a service for the given receiver type under the given name. When no +// methods on the given receiver match the criteria to be either a RPC method or a +// subscription an error is returned. Otherwise a new service is created and added to the +// service collection this client provides to the server. +func (c *Client) RegisterName(name string, receiver interface{}) error { + return c.services.registerName(name, receiver) +} + +func (c *Client) nextID() json.RawMessage { + id := c.idCounter.Add(1) + return strconv.AppendUint(nil, uint64(id), 10) +} + +// SupportedModules calls the rpc_modules method, retrieving the list of +// APIs that are available on the server. +func (c *Client) SupportedModules() (map[string]string, error) { + var result map[string]string + ctx, cancel := context.WithTimeout(context.Background(), subscribeTimeout) + defer cancel() + err := c.CallContext(ctx, &result, "rpc_modules") + return result, err +} + +// Close closes the client, aborting any in-flight requests. +func (c *Client) Close() { + if c.isHTTP { + return + } + select { + case c.close <- struct{}{}: + <-c.didClose + case <-c.didClose: + } +} + +// SetHeader adds a custom HTTP header to the client's requests. +// This method only works for clients using HTTP, it doesn't have +// any effect for clients using another transport. +func (c *Client) SetHeader(key, value string) { + if !c.isHTTP { + return + } + conn := c.writeConn.(*httpConn) + conn.mu.Lock() + conn.headers.Set(key, value) + conn.mu.Unlock() +} + +// Call performs a JSON-RPC call with the given arguments and unmarshals into +// result if no error occurred. +// +// The result must be a pointer so that package json can unmarshal into it. You +// can also pass nil, in which case the result is ignored. +func (c *Client) Call(result interface{}, method string, args ...interface{}) error { + ctx := context.Background() + return c.CallContext(ctx, result, method, args...) +} + +// CallContext performs a JSON-RPC call with the given arguments. If the context is +// canceled before the call has successfully returned, CallContext returns immediately. +// +// The result must be a pointer so that package json can unmarshal into it. You +// can also pass nil, in which case the result is ignored. +func (c *Client) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { + if result != nil && reflect.TypeOf(result).Kind() != reflect.Ptr { + return fmt.Errorf("call result parameter must be pointer or nil interface: %v", result) + } + msg, err := c.newMessage(method, args...) + if err != nil { + return err + } + op := &requestOp{ + ids: []json.RawMessage{msg.ID}, + resp: make(chan []*jsonrpcMessage, 1), + } + + if c.isHTTP { + err = c.sendHTTP(ctx, op, msg) + } else { + err = c.send(ctx, op, msg) + } + if err != nil { + return err + } + + // dispatch has accepted the request and will close the channel when it quits. + batchresp, err := op.wait(ctx, c) + if err != nil { + return err + } + resp := batchresp[0] + switch { + case resp.Error != nil: + return resp.Error + case len(resp.Result) == 0: + return ErrNoResult + default: + if result == nil { + return nil + } + return json.Unmarshal(resp.Result, result) + } +} + +// BatchCall sends all given requests as a single batch and waits for the server +// to return a response for all of them. +// +// In contrast to Call, BatchCall only returns I/O errors. Any error specific to +// a request is reported through the Error field of the corresponding BatchElem. +// +// Note that batch calls may not be executed atomically on the server side. +func (c *Client) BatchCall(b []BatchElem) error { + ctx := context.Background() + return c.BatchCallContext(ctx, b) +} + +// BatchCallContext sends all given requests as a single batch and waits for the server +// to return a response for all of them. The wait duration is bounded by the +// context's deadline. +// +// In contrast to CallContext, BatchCallContext only returns errors that have occurred +// while sending the request. Any error specific to a request is reported through the +// Error field of the corresponding BatchElem. +// +// Note that batch calls may not be executed atomically on the server side. +func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { + var ( + msgs = make([]*jsonrpcMessage, len(b)) + byID = make(map[string]int, len(b)) + ) + op := &requestOp{ + ids: make([]json.RawMessage, len(b)), + resp: make(chan []*jsonrpcMessage, 1), + } + for i, elem := range b { + msg, err := c.newMessage(elem.Method, elem.Args...) + if err != nil { + return err + } + msgs[i] = msg + op.ids[i] = msg.ID + byID[string(msg.ID)] = i + } + + var err error + if c.isHTTP { + err = c.sendBatchHTTP(ctx, op, msgs) + } else { + err = c.send(ctx, op, msgs) + } + if err != nil { + return err + } + + batchresp, err := op.wait(ctx, c) + if err != nil { + return err + } + + // Wait for all responses to come back. + for n := 0; n < len(batchresp); n++ { + resp := batchresp[n] + if resp == nil { + // Ignore null responses. These can happen for batches sent via HTTP. + continue + } + + // Find the element corresponding to this response. + index, ok := byID[string(resp.ID)] + if !ok { + continue + } + delete(byID, string(resp.ID)) + + // Assign result and error. + elem := &b[index] + switch { + case resp.Error != nil: + elem.Error = resp.Error + case resp.Result == nil: + elem.Error = ErrNoResult + default: + elem.Error = json.Unmarshal(resp.Result, elem.Result) + } + } + + // Check that all expected responses have been received. + for _, index := range byID { + elem := &b[index] + elem.Error = ErrMissingBatchResponse + } + + return err +} + +// Notify sends a notification, i.e. a method call that doesn't expect a response. +func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) error { + op := new(requestOp) + msg, err := c.newMessage(method, args...) + if err != nil { + return err + } + msg.ID = nil + + if c.isHTTP { + return c.sendHTTP(ctx, op, msg) + } + return c.send(ctx, op, msg) +} + +// EthSubscribe registers a subscription under the "eth" namespace. +func (c *Client) EthSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { + return c.Subscribe(ctx, "eth", channel, args...) +} + +// ShhSubscribe registers a subscription under the "shh" namespace. +// Deprecated: use Subscribe(ctx, "shh", ...). +func (c *Client) ShhSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { + return c.Subscribe(ctx, "shh", channel, args...) +} + +// Subscribe calls the "_subscribe" method with the given arguments, +// registering a subscription. Server notifications for the subscription are +// sent to the given channel. The element type of the channel must match the +// expected type of content returned by the subscription. +// +// The context argument cancels the RPC request that sets up the subscription but has no +// effect on the subscription after Subscribe has returned. +// +// Slow subscribers will be dropped eventually. Client buffers up to 20000 notifications +// before considering the subscriber dead. The subscription Err channel will receive +// ErrSubscriptionQueueOverflow. Use a sufficiently large buffer on the channel or ensure +// that the channel usually has at least one reader to prevent this issue. +func (c *Client) Subscribe(ctx context.Context, namespace string, channel interface{}, args ...interface{}) (*ClientSubscription, error) { + // Check type of channel first. + chanVal := reflect.ValueOf(channel) + if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { + panic(fmt.Sprintf("channel argument of Subscribe has type %T, need writable channel", channel)) + } + if chanVal.IsNil() { + panic("channel given to Subscribe must not be nil") + } + if c.isHTTP { + return nil, ErrNotificationsUnsupported + } + + msg, err := c.newMessage(namespace+subscribeMethodSuffix, args...) + if err != nil { + return nil, err + } + op := &requestOp{ + ids: []json.RawMessage{msg.ID}, + resp: make(chan []*jsonrpcMessage, 1), + sub: newClientSubscription(c, namespace, chanVal), + } + + // Send the subscription request. + // The arrival and validity of the response is signaled on sub.quit. + if err := c.send(ctx, op, msg); err != nil { + return nil, err + } + if _, err := op.wait(ctx, c); err != nil { + return nil, err + } + return op.sub, nil +} + +// SupportsSubscriptions reports whether subscriptions are supported by the client +// transport. When this returns false, Subscribe and related methods will return +// ErrNotificationsUnsupported. +func (c *Client) SupportsSubscriptions() bool { + return !c.isHTTP +} + +func (c *Client) newMessage(method string, paramsIn ...interface{}) (*jsonrpcMessage, error) { + msg := &jsonrpcMessage{Version: vsn, ID: c.nextID(), Method: method} + if paramsIn != nil { // prevent sending "params":null + var err error + if msg.Params, err = json.Marshal(paramsIn); err != nil { + return nil, err + } + } + return msg, nil +} + +// send registers op with the dispatch loop, then sends msg on the connection. +// if sending fails, op is deregistered. +func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error { + select { + case c.reqInit <- op: + err := c.write(ctx, msg, false) + c.reqSent <- err + return err + case <-ctx.Done(): + // This can happen if the client is overloaded or unable to keep up with + // subscription notifications. + return ctx.Err() + case <-c.closing: + return ErrClientQuit + } +} + +func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error { + if c.writeConn == nil { + // The previous write failed. Try to establish a new connection. + if err := c.reconnect(ctx); err != nil { + return err + } + } + err := c.writeConn.writeJSON(ctx, msg, false) + if err != nil { + c.writeConn = nil + if !retry { + return c.write(ctx, msg, true) + } + } + return err +} + +func (c *Client) reconnect(ctx context.Context) error { + if c.reconnectFunc == nil { + return errDead + } + + if _, ok := ctx.Deadline(); !ok { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, defaultDialTimeout) + defer cancel() + } + newconn, err := c.reconnectFunc(ctx) + if err != nil { + log.Trace("RPC client reconnect failed", "err", err) + return err + } + select { + case c.reconnected <- newconn: + c.writeConn = newconn + return nil + case <-c.didClose: + newconn.close() + return ErrClientQuit + } +} + +// dispatch is the main loop of the client. +// It sends read messages to waiting calls to Call and BatchCall +// and subscription notifications to registered subscriptions. +func (c *Client) dispatch(codec ServerCodec) { + var ( + lastOp *requestOp // tracks last send operation + reqInitLock = c.reqInit // nil while the send lock is held + conn = c.newClientConn(codec) + reading = true + ) + defer func() { + close(c.closing) + if reading { + conn.close(ErrClientQuit, nil) + c.drainRead() + } + close(c.didClose) + }() + + // Spawn the initial read loop. + go c.read(codec) + + for { + select { + case <-c.close: + return + + // Read path: + case op := <-c.readOp: + if op.batch { + conn.handler.handleBatch(op.msgs) + } else { + conn.handler.handleMsg(op.msgs[0]) + } + + case err := <-c.readErr: + conn.handler.log.Debug("RPC connection read error", "err", err) + conn.close(err, lastOp) + reading = false + + // Reconnect: + case newcodec := <-c.reconnected: + log.Debug("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr()) + if reading { + // Wait for the previous read loop to exit. This is a rare case which + // happens if this loop isn't notified in time after the connection breaks. + // In those cases the caller will notice first and reconnect. Closing the + // handler terminates all waiting requests (closing op.resp) except for + // lastOp, which will be transferred to the new handler. + conn.close(errClientReconnected, lastOp) + c.drainRead() + } + go c.read(newcodec) + reading = true + conn = c.newClientConn(newcodec) + // Re-register the in-flight request on the new handler + // because that's where it will be sent. + conn.handler.addRequestOp(lastOp) + + // Send path: + case op := <-reqInitLock: + // Stop listening for further requests until the current one has been sent. + reqInitLock = nil + lastOp = op + conn.handler.addRequestOp(op) + + case err := <-c.reqSent: + if err != nil { + // Remove response handlers for the last send. When the read loop + // goes down, it will signal all other current operations. + conn.handler.removeRequestOp(lastOp) + } + // Let the next request in. + reqInitLock = c.reqInit + lastOp = nil + + case op := <-c.reqTimeout: + conn.handler.removeRequestOp(op) + } + } +} + +// drainRead drops read messages until an error occurs. +func (c *Client) drainRead() { + for { + select { + case <-c.readOp: + case <-c.readErr: + return + } + } +} + +// read decodes RPC messages from a codec, feeding them into dispatch. +func (c *Client) read(codec ServerCodec) { + for { + msgs, batch, err := codec.readBatch() + if _, ok := err.(*json.SyntaxError); ok { + msg := errorMessage(&parseError{err.Error()}) + codec.writeJSON(context.Background(), msg, true) + } + if err != nil { + c.readErr <- err + return + } + c.readOp <- readOp{msgs, batch} + } +} diff --git a/client/client_opt.go b/client/client_opt.go new file mode 100644 index 00000000..5a7c66ec --- /dev/null +++ b/client/client_opt.go @@ -0,0 +1,144 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "net/http" + + "github.com/gorilla/websocket" +) + +// ClientOption is a configuration option for the RPC client. +type ClientOption interface { + applyOption(*clientConfig) +} + +type clientConfig struct { + // HTTP settings + httpClient *http.Client + httpHeaders http.Header + httpAuth HTTPAuth + + // WebSocket options + wsDialer *websocket.Dialer + wsMessageSizeLimit *int64 // wsMessageSizeLimit nil = default, 0 = no limit + + // RPC handler options + idgen func() ID + batchItemLimit int + batchResponseLimit int +} + +func (cfg *clientConfig) initHeaders() { + if cfg.httpHeaders == nil { + cfg.httpHeaders = make(http.Header) + } +} + +func (cfg *clientConfig) setHeader(key, value string) { + cfg.initHeaders() + cfg.httpHeaders.Set(key, value) +} + +type optionFunc func(*clientConfig) + +func (fn optionFunc) applyOption(opt *clientConfig) { + fn(opt) +} + +// WithWebsocketDialer configures the websocket.Dialer used by the RPC client. +func WithWebsocketDialer(dialer websocket.Dialer) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.wsDialer = &dialer + }) +} + +// WithWebsocketMessageSizeLimit configures the websocket message size limit used by the RPC +// client. Passing a limit of 0 means no limit. +func WithWebsocketMessageSizeLimit(messageSizeLimit int64) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.wsMessageSizeLimit = &messageSizeLimit + }) +} + +// WithHeader configures HTTP headers set by the RPC client. Headers set using this option +// will be used for both HTTP and WebSocket connections. +func WithHeader(key, value string) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.initHeaders() + cfg.httpHeaders.Set(key, value) + }) +} + +// WithHeaders configures HTTP headers set by the RPC client. Headers set using this +// option will be used for both HTTP and WebSocket connections. +func WithHeaders(headers http.Header) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.initHeaders() + for k, vs := range headers { + cfg.httpHeaders[k] = vs + } + }) +} + +// WithHTTPClient configures the http.Client used by the RPC client. +func WithHTTPClient(c *http.Client) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.httpClient = c + }) +} + +// WithHTTPAuth configures HTTP request authentication. The given provider will be called +// whenever a request is made. Note that only one authentication provider can be active at +// any time. +func WithHTTPAuth(a HTTPAuth) ClientOption { + if a == nil { + panic("nil auth") + } + return optionFunc(func(cfg *clientConfig) { + cfg.httpAuth = a + }) +} + +// A HTTPAuth function is called by the client whenever a HTTP request is sent. +// The function must be safe for concurrent use. +// +// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add +// auth information to the request. +type HTTPAuth func(h http.Header) error + +// WithBatchItemLimit changes the maximum number of items allowed in batch requests. +// +// Note: this option applies when processing incoming batch requests. It does not affect +// batch requests sent by the client. +func WithBatchItemLimit(limit int) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.batchItemLimit = limit + }) +} + +// WithBatchResponseSizeLimit changes the maximum number of response bytes that can be +// generated for batch requests. When this limit is reached, further calls in the batch +// will not be processed. +// +// Note: this option applies when processing incoming batch requests. It does not affect +// batch requests sent by the client. +func WithBatchResponseSizeLimit(sizeLimit int) ClientOption { + return optionFunc(func(cfg *clientConfig) { + cfg.batchResponseLimit = sizeLimit + }) +} diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 00000000..45a722bf --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,951 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/rand" + "net" + "net/http" + "net/http/httptest" + "reflect" + "runtime" + "strings" + "sync" + "testing" + "time" + + "github.com/NethermindEth/starknet.go/client/log" + "github.com/davecgh/go-spew/spew" +) + +func TestClientRequest(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + var resp echoResult + if err := client.Call(&resp, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(resp, echoResult{"hello", 10, &echoArgs{"world"}}) { + t.Errorf("incorrect result %#v", resp) + } +} + +func TestClientResponseType(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + if err := client.Call(nil, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { + t.Errorf("Passing nil as result should be fine, but got an error: %v", err) + } + var resultVar echoResult + // Note: passing the var, not a ref + err := client.Call(resultVar, "test_echo", "hello", 10, &echoArgs{"world"}) + if err == nil { + t.Error("Passing a var as result should be an error") + } +} + +// This test checks calling a method that returns 'null'. +func TestClientNullResponse(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + + client := DialInProc(server) + defer client.Close() + + var result json.RawMessage + if err := client.Call(&result, "test_null"); err != nil { + t.Fatal(err) + } + if result == nil { + t.Fatal("Expected non-nil result") + } + if !reflect.DeepEqual(result, json.RawMessage("null")) { + t.Errorf("Expected null, got %s", result) + } +} + +// This test checks that server-returned errors with code and data come out of Client.Call. +func TestClientErrorData(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + var resp interface{} + err := client.Call(&resp, "test_returnError") + if err == nil { + t.Fatal("expected error") + } + + // Check code. + // The method handler returns an error value which implements the rpc.Error + // interface, i.e. it has a custom error code. The server returns this error code. + expectedCode := testError{}.ErrorCode() + if e, ok := err.(Error); !ok { + t.Fatalf("client did not return rpc.Error, got %#v", e) + } else if e.ErrorCode() != expectedCode { + t.Fatalf("wrong error code %d, want %d", e.ErrorCode(), expectedCode) + } + + // Check data. + if e, ok := err.(DataError); !ok { + t.Fatalf("client did not return rpc.DataError, got %#v", e) + } else if e.ErrorData() != (testError{}.ErrorData()) { + t.Fatalf("wrong error data %#v, want %#v", e.ErrorData(), testError{}.ErrorData()) + } +} + +func TestClientBatchRequest(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + batch := []BatchElem{ + { + Method: "test_echo", + Args: []interface{}{"hello", 10, &echoArgs{"world"}}, + Result: new(echoResult), + }, + { + Method: "test_echo", + Args: []interface{}{"hello2", 11, &echoArgs{"world"}}, + Result: new(echoResult), + }, + { + Method: "no_such_method", + Args: []interface{}{1, 2, 3}, + Result: new(int), + }, + } + if err := client.BatchCall(batch); err != nil { + t.Fatal(err) + } + wantResult := []BatchElem{ + { + Method: "test_echo", + Args: []interface{}{"hello", 10, &echoArgs{"world"}}, + Result: &echoResult{"hello", 10, &echoArgs{"world"}}, + }, + { + Method: "test_echo", + Args: []interface{}{"hello2", 11, &echoArgs{"world"}}, + Result: &echoResult{"hello2", 11, &echoArgs{"world"}}, + }, + { + Method: "no_such_method", + Args: []interface{}{1, 2, 3}, + Result: new(int), + Error: &jsonError{Code: -32601, Message: "the method no_such_method does not exist/is not available"}, + }, + } + if !reflect.DeepEqual(batch, wantResult) { + t.Errorf("batch results mismatch:\ngot %swant %s", spew.Sdump(batch), spew.Sdump(wantResult)) + } +} + +// This checks that, for HTTP connections, the length of batch responses is validated to +// match the request exactly. +func TestClientBatchRequest_len(t *testing.T) { + t.Parallel() + + b, err := json.Marshal([]jsonrpcMessage{ + {Version: "2.0", ID: json.RawMessage("1"), Result: json.RawMessage(`"0x1"`)}, + {Version: "2.0", ID: json.RawMessage("2"), Result: json.RawMessage(`"0x2"`)}, + }) + if err != nil { + t.Fatal("failed to encode jsonrpc message:", err) + } + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _, err := rw.Write(b) + if err != nil { + t.Error("failed to write response:", err) + } + })) + t.Cleanup(s.Close) + + t.Run("too-few", func(t *testing.T) { + t.Parallel() + + client, err := Dial(s.URL) + if err != nil { + t.Fatal("failed to dial test server:", err) + } + defer client.Close() + + batch := []BatchElem{ + {Method: "foo", Result: new(string)}, + {Method: "bar", Result: new(string)}, + {Method: "baz", Result: new(string)}, + } + ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) + defer cancelFn() + + if err := client.BatchCallContext(ctx, batch); err != nil { + t.Fatal("error:", err) + } + for i, elem := range batch[:2] { + if elem.Error != nil { + t.Errorf("expected no error for batch element %d, got %q", i, elem.Error) + } + } + for i, elem := range batch[2:] { + if elem.Error != ErrMissingBatchResponse { + t.Errorf("wrong error %q for batch element %d", elem.Error, i+2) + } + } + }) + + t.Run("too-many", func(t *testing.T) { + t.Parallel() + + client, err := Dial(s.URL) + if err != nil { + t.Fatal("failed to dial test server:", err) + } + defer client.Close() + + batch := []BatchElem{ + {Method: "foo", Result: new(string)}, + } + ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) + defer cancelFn() + + if err := client.BatchCallContext(ctx, batch); err != nil { + t.Fatal("error:", err) + } + for i, elem := range batch[:1] { + if elem.Error != nil { + t.Errorf("expected no error for batch element %d, got %q", i, elem.Error) + } + } + for i, elem := range batch[1:] { + if elem.Error != ErrMissingBatchResponse { + t.Errorf("wrong error %q for batch element %d", elem.Error, i+2) + } + } + }) +} + +// This checks that the client can handle the case where the server doesn't +// respond to all requests in a batch. +func TestClientBatchRequestLimit(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + server.SetBatchLimits(2, 100000) + client := DialInProc(server) + defer client.Close() + + batch := []BatchElem{ + {Method: "foo"}, + {Method: "bar"}, + {Method: "baz"}, + } + err := client.BatchCall(batch) + if err != nil { + t.Fatal("unexpected error:", err) + } + + // Check that the first response indicates an error with batch size. + var err0 Error + if !errors.As(batch[0].Error, &err0) { + t.Log("error zero:", batch[0].Error) + t.Fatalf("batch elem 0 has wrong error type: %T", batch[0].Error) + } else { + if err0.ErrorCode() != -32600 || err0.Error() != errMsgBatchTooLarge { + t.Fatalf("wrong error on batch elem zero: %v", err0) + } + } + + // Check that remaining response batch elements are reported as absent. + for i, elem := range batch[1:] { + if elem.Error != ErrMissingBatchResponse { + t.Fatalf("batch elem %d has unexpected error: %v", i+1, elem.Error) + } + } +} + +func TestClientNotify(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + if err := client.Notify(context.Background(), "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { + t.Fatal(err) + } +} + +// func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) } +func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) } +func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) } + +// This test checks that requests made through CallContext can be canceled by canceling +// the context. +func testClientCancel(transport string, t *testing.T) { + // These tests take a lot of time, run them all at once. + // You probably want to run with -parallel 1 or comment out + // the call to t.Parallel if you enable the logging. + t.Parallel() + + server := newTestServer() + defer server.Stop() + + // What we want to achieve is that the context gets canceled + // at various stages of request processing. The interesting cases + // are: + // - cancel during dial + // - cancel while performing a HTTP request + // - cancel while waiting for a response + // + // To trigger those, the times are chosen such that connections + // are killed within the deadline for every other call (maxKillTimeout + // is 2x maxCancelTimeout). + // + // Once a connection is dead, there is a fair chance it won't connect + // successfully because the accept is delayed by 1s. + maxContextCancelTimeout := 300 * time.Millisecond + fl := &flakeyListener{ + maxAcceptDelay: 1 * time.Second, + maxKillTimeout: 600 * time.Millisecond, + } + + var client *Client + switch transport { + case "ws", "http": + c, hs := httpTestClient(server, transport, fl) + defer hs.Close() + client = c + default: + panic("unknown transport: " + transport) + } + defer client.Close() + + // The actual test starts here. + var ( + wg sync.WaitGroup + nreqs = 10 + ncallers = 10 + ) + caller := func(index int) { + defer wg.Done() + for i := 0; i < nreqs; i++ { + var ( + ctx context.Context + cancel func() + timeout = time.Duration(rand.Int63n(int64(maxContextCancelTimeout))) + ) + if index < ncallers/2 { + // For half of the callers, create a context without deadline + // and cancel it later. + ctx, cancel = context.WithCancel(context.Background()) + time.AfterFunc(timeout, cancel) + } else { + // For the other half, create a context with a deadline instead. This is + // different because the context deadline is used to set the socket write + // deadline. + ctx, cancel = context.WithTimeout(context.Background(), timeout) + } + + // Now perform a call with the context. + // The key thing here is that no call will ever complete successfully. + err := client.CallContext(ctx, nil, "test_block") + switch { + case err == nil: + _, hasDeadline := ctx.Deadline() + t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline) + // default: + // t.Logf("got expected error with %v wait time: %v", timeout, err) + } + cancel() + } + } + wg.Add(ncallers) + for i := 0; i < ncallers; i++ { + go caller(i) + } + wg.Wait() +} + +func TestClientSubscribeInvalidArg(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + check := func(shouldPanic bool, arg interface{}) { + defer func() { + err := recover() + if shouldPanic && err == nil { + t.Errorf("EthSubscribe should've panicked for %#v", arg) + } + if !shouldPanic && err != nil { + t.Errorf("EthSubscribe shouldn't have panicked for %#v", arg) + buf := make([]byte, 1024*1024) + buf = buf[:runtime.Stack(buf, false)] + t.Error(err) + t.Error(string(buf)) + } + }() + client.EthSubscribe(context.Background(), arg, "foo_bar") + } + check(true, nil) + check(true, 1) + check(true, (chan int)(nil)) + check(true, make(<-chan int)) + check(false, make(chan int)) + check(false, make(chan<- int)) +} + +func TestClientSubscribe(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + nc := make(chan int) + count := 10 + sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", count, 0) + if err != nil { + t.Fatal("can't subscribe:", err) + } + for i := 0; i < count; i++ { + if val := <-nc; val != i { + t.Fatalf("value mismatch: got %d, want %d", val, i) + } + } + + sub.Unsubscribe() + select { + case v := <-nc: + t.Fatal("received value after unsubscribe:", v) + case err := <-sub.Err(): + if err != nil { + t.Fatalf("Err returned a non-nil error after explicit unsubscribe: %q", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("subscription not closed within 1s after unsubscribe") + } +} + +// In this test, the connection drops while Subscribe is waiting for a response. +func TestClientSubscribeClose(t *testing.T) { + t.Parallel() + + server := newTestServer() + service := ¬ificationTestService{ + gotHangSubscriptionReq: make(chan struct{}), + unblockHangSubscription: make(chan struct{}), + } + if err := server.RegisterName("nftest2", service); err != nil { + t.Fatal(err) + } + + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + var ( + nc = make(chan int) + errc = make(chan error, 1) + sub *ClientSubscription + err error + ) + go func() { + sub, err = client.Subscribe(context.Background(), "nftest2", nc, "hangSubscription", 999) + errc <- err + }() + + <-service.gotHangSubscriptionReq + client.Close() + service.unblockHangSubscription <- struct{}{} + + select { + case err := <-errc: + if err == nil { + t.Errorf("Subscribe returned nil error after Close") + } + if sub != nil { + t.Error("Subscribe returned non-nil subscription after Close") + } + case <-time.After(1 * time.Second): + t.Fatalf("Subscribe did not return within 1s after Close") + } +} + +// This test reproduces https://github.com/ethereum/go-ethereum/issues/17837 where the +// client hangs during shutdown when Unsubscribe races with Client.Close. +func TestClientCloseUnsubscribeRace(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + + for i := 0; i < 20; i++ { + client := DialInProc(server) + nc := make(chan int) + sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", 3, 1) + if err != nil { + t.Fatal(err) + } + go client.Close() + go sub.Unsubscribe() + select { + case <-sub.Err(): + case <-time.After(5 * time.Second): + t.Fatal("subscription not closed within timeout") + } + } +} + +// unsubscribeBlocker will wait for the quit channel to process an unsubscribe +// request. +type unsubscribeBlocker struct { + ServerCodec + quit chan struct{} +} + +func (b *unsubscribeBlocker) readBatch() ([]*jsonrpcMessage, bool, error) { + msgs, batch, err := b.ServerCodec.readBatch() + for _, msg := range msgs { + if msg.isUnsubscribe() { + <-b.quit + } + } + return msgs, batch, err +} + +// TestUnsubscribeTimeout verifies that calling the client's Unsubscribe +// function will eventually timeout and not block forever in case the serve does +// not respond. +// It reproducers the issue https://github.com/ethereum/go-ethereum/issues/30156 +func TestUnsubscribeTimeout(t *testing.T) { + t.Parallel() + + srv := NewServer() + srv.RegisterName("nftest", new(notificationTestService)) + + // Setup middleware to block on unsubscribe. + p1, p2 := net.Pipe() + blocker := &unsubscribeBlocker{ServerCodec: NewCodec(p1), quit: make(chan struct{})} + defer close(blocker.quit) + + // Serve the middleware. + go srv.ServeCodec(blocker, OptionMethodInvocation|OptionSubscriptions) + defer srv.Stop() + + // Create the client on the other end of the pipe. + cfg := new(clientConfig) + client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) { + return NewCodec(p2), nil + }) + defer client.Close() + + // Start subscription. + sub, err := client.Subscribe(context.Background(), "nftest", make(chan int), "someSubscription", 1, 1) + if err != nil { + t.Fatalf("failed to subscribe: %v", err) + } + + // Now on a separate thread, attempt to unsubscribe. Since the middleware + // won't return, the function will only return if it times out on the request. + done := make(chan struct{}) + go func() { + sub.Unsubscribe() + done <- struct{}{} + }() + + // Wait for the timeout. If the expected time for the timeout elapses, the + // test is considered failed. + select { + case <-done: + case <-time.After(unsubscribeTimeout + 3*time.Second): + t.Fatalf("Unsubscribe did not return within %s", unsubscribeTimeout) + } +} + +// unsubscribeRecorder collects the subscription IDs of *_unsubscribe calls. +type unsubscribeRecorder struct { + ServerCodec + unsubscribes map[string]bool +} + +func (r *unsubscribeRecorder) readBatch() ([]*jsonrpcMessage, bool, error) { + if r.unsubscribes == nil { + r.unsubscribes = make(map[string]bool) + } + + msgs, batch, err := r.ServerCodec.readBatch() + for _, msg := range msgs { + if msg.isUnsubscribe() { + var params []string + if err := json.Unmarshal(msg.Params, ¶ms); err != nil { + panic("unsubscribe decode error: " + err.Error()) + } + r.unsubscribes[params[0]] = true + } + } + return msgs, batch, err +} + +// This checks that Client calls the _unsubscribe method on the server when Unsubscribe is +// called on a subscription. +func TestClientSubscriptionUnsubscribeServer(t *testing.T) { + t.Parallel() + + // Create the server. + srv := NewServer() + srv.RegisterName("nftest", new(notificationTestService)) + p1, p2 := net.Pipe() + recorder := &unsubscribeRecorder{ServerCodec: NewCodec(p1)} + go srv.ServeCodec(recorder, OptionMethodInvocation|OptionSubscriptions) + defer srv.Stop() + + // Create the client on the other end of the pipe. + cfg := new(clientConfig) + client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) { + return NewCodec(p2), nil + }) + defer client.Close() + + // Create the subscription. + ch := make(chan int) + sub, err := client.Subscribe(context.Background(), "nftest", ch, "someSubscription", 1, 1) + if err != nil { + t.Fatal(err) + } + + // Unsubscribe and check that unsubscribe was called. + sub.Unsubscribe() + if !recorder.unsubscribes[sub.subid] { + t.Fatal("client did not call unsubscribe method") + } + if _, open := <-sub.Err(); open { + t.Fatal("subscription error channel not closed after unsubscribe") + } +} + +// This checks that the subscribed channel can be closed after Unsubscribe. +// It is the reproducer for https://github.com/ethereum/go-ethereum/issues/22322 +func TestClientSubscriptionChannelClose(t *testing.T) { + t.Parallel() + + var ( + srv = NewServer() + httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) + wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + ) + defer srv.Stop() + defer httpsrv.Close() + + srv.RegisterName("nftest", new(notificationTestService)) + client, _ := Dial(wsURL) + defer client.Close() + + for i := 0; i < 100; i++ { + ch := make(chan int, 100) + sub, err := client.Subscribe(context.Background(), "nftest", ch, "someSubscription", 100, 1) + if err != nil { + t.Fatal(err) + } + sub.Unsubscribe() + close(ch) + } +} + +// This test checks that Client doesn't lock up when a single subscriber +// doesn't read subscription events. +func TestClientNotificationStorm(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + + doTest := func(count int, wantError bool) { + client := DialInProc(server) + defer client.Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Subscribe on the server. It will start sending many notifications + // very quickly. + nc := make(chan int) + sub, err := client.Subscribe(ctx, "nftest", nc, "someSubscription", count, 0) + if err != nil { + t.Fatal("can't subscribe:", err) + } + defer sub.Unsubscribe() + + // Process each notification, try to run a call in between each of them. + for i := 0; i < count; i++ { + select { + case val := <-nc: + if val != i { + t.Fatalf("(%d/%d) unexpected value %d", i, count, val) + } + case err := <-sub.Err(): + if wantError && err != ErrSubscriptionQueueOverflow { + t.Fatalf("(%d/%d) got error %q, want %q", i, count, err, ErrSubscriptionQueueOverflow) + } else if !wantError { + t.Fatalf("(%d/%d) got unexpected error %q", i, count, err) + } + return + } + var r int + err := client.CallContext(ctx, &r, "nftest_echo", i) + if err != nil { + if !wantError { + t.Fatalf("(%d/%d) call error: %v", i, count, err) + } + return + } + } + if wantError { + t.Fatalf("didn't get expected error") + } + } + + doTest(8000, false) + doTest(24000, true) +} + +func TestClientSetHeader(t *testing.T) { + t.Parallel() + + var gotHeader bool + srv := newTestServer() + httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("test") == "ok" { + gotHeader = true + } + srv.ServeHTTP(w, r) + })) + defer httpsrv.Close() + defer srv.Stop() + + client, err := Dial(httpsrv.URL) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + client.SetHeader("test", "ok") + if _, err := client.SupportedModules(); err != nil { + t.Fatal(err) + } + if !gotHeader { + t.Fatal("client did not set custom header") + } + + // Check that Content-Type can be replaced. + client.SetHeader("content-type", "application/x-garbage") + _, err = client.SupportedModules() + if err == nil { + t.Fatal("no error for invalid content-type header") + } else if !strings.Contains(err.Error(), "Unsupported Media Type") { + t.Fatalf("error is not related to content-type: %q", err) + } +} + +func TestClientHTTP(t *testing.T) { + t.Parallel() + + server := newTestServer() + defer server.Stop() + + client, hs := httpTestClient(server, "http", nil) + defer hs.Close() + defer client.Close() + + // Launch concurrent requests. + var ( + results = make([]echoResult, 100) + errc = make(chan error, len(results)) + wantResult = echoResult{"a", 1, new(echoArgs)} + ) + for i := range results { + go func() { + errc <- client.Call(&results[i], "test_echo", wantResult.String, wantResult.Int, wantResult.Args) + }() + } + + // Wait for all of them to complete. + timeout := time.NewTimer(5 * time.Second) + defer timeout.Stop() + for i := range results { + select { + case err := <-errc: + if err != nil { + t.Fatal(err) + } + case <-timeout.C: + t.Fatalf("timeout (got %d/%d) results)", i+1, len(results)) + } + } + + // Check results. + for i := range results { + if !reflect.DeepEqual(results[i], wantResult) { + t.Errorf("result %d mismatch: got %#v, want %#v", i, results[i], wantResult) + } + } +} + +func TestClientReconnect(t *testing.T) { + t.Parallel() + + startServer := func(addr string) (*Server, net.Listener) { + srv := newTestServer() + l, err := net.Listen("tcp", addr) + if err != nil { + t.Fatal("can't listen:", err) + } + go http.Serve(l, srv.WebsocketHandler([]string{"*"})) + return srv, l + } + + ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) + defer cancel() + + // Start a server and corresponding client. + s1, l1 := startServer("127.0.0.1:0") + client, err := DialContext(ctx, "ws://"+l1.Addr().String()) + if err != nil { + t.Fatal("can't dial", err) + } + defer client.Close() + + // Perform a call. This should work because the server is up. + var resp echoResult + if err := client.CallContext(ctx, &resp, "test_echo", "", 1, nil); err != nil { + t.Fatal(err) + } + + // Shut down the server and allow for some cool down time so we can listen on the same + // address again. + l1.Close() + s1.Stop() + time.Sleep(2 * time.Second) + + // Try calling again. It shouldn't work. + if err := client.CallContext(ctx, &resp, "test_echo", "", 2, nil); err == nil { + t.Error("successful call while the server is down") + t.Logf("resp: %#v", resp) + } + + // Start it up again and call again. The connection should be reestablished. + // We spawn multiple calls here to check whether this hangs somehow. + s2, l2 := startServer(l1.Addr().String()) + defer l2.Close() + defer s2.Stop() + + start := make(chan struct{}) + errors := make(chan error, 20) + for i := 0; i < cap(errors); i++ { + go func() { + <-start + var resp echoResult + errors <- client.CallContext(ctx, &resp, "test_echo", "", 3, nil) + }() + } + close(start) + errcount := 0 + for i := 0; i < cap(errors); i++ { + if err = <-errors; err != nil { + errcount++ + } + } + t.Logf("%d errors, last error: %v", errcount, err) + if errcount > 1 { + t.Errorf("expected one error after disconnect, got %d", errcount) + } +} + +func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, *httptest.Server) { + // Create the HTTP server. + var hs *httptest.Server + switch transport { + case "ws": + hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"})) + case "http": + hs = httptest.NewUnstartedServer(srv) + default: + panic("unknown HTTP transport: " + transport) + } + // Wrap the listener if required. + if fl != nil { + fl.Listener = hs.Listener + hs.Listener = fl + } + // Connect the client. + hs.Start() + client, err := Dial(transport + "://" + hs.Listener.Addr().String()) + if err != nil { + panic(err) + } + return client, hs +} + +// flakeyListener kills accepted connections after a random timeout. +type flakeyListener struct { + net.Listener + maxKillTimeout time.Duration + maxAcceptDelay time.Duration +} + +func (l *flakeyListener) Accept() (net.Conn, error) { + delay := time.Duration(rand.Int63n(int64(l.maxAcceptDelay))) + time.Sleep(delay) + + c, err := l.Listener.Accept() + if err == nil { + timeout := time.Duration(rand.Int63n(int64(l.maxKillTimeout))) + time.AfterFunc(timeout, func() { + log.Debug(fmt.Sprintf("killing conn %v after %v", c.LocalAddr(), timeout)) + c.Close() + }) + } + return c, err +} diff --git a/client/context_headers.go b/client/context_headers.go new file mode 100644 index 00000000..c246d501 --- /dev/null +++ b/client/context_headers.go @@ -0,0 +1,56 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "net/http" +) + +type mdHeaderKey struct{} + +// NewContextWithHeaders wraps the given context, adding HTTP headers. These headers will +// be applied by Client when making a request using the returned context. +func NewContextWithHeaders(ctx context.Context, h http.Header) context.Context { + if len(h) == 0 { + // This check ensures the header map set in context will never be nil. + return ctx + } + + var ctxh http.Header + prev, ok := ctx.Value(mdHeaderKey{}).(http.Header) + if ok { + ctxh = setHeaders(prev.Clone(), h) + } else { + ctxh = h.Clone() + } + return context.WithValue(ctx, mdHeaderKey{}, ctxh) +} + +// headersFromContext is used to extract http.Header from context. +func headersFromContext(ctx context.Context) http.Header { + source, _ := ctx.Value(mdHeaderKey{}).(http.Header) + return source +} + +// setHeaders sets all headers from src in dst. +func setHeaders(dst http.Header, src http.Header) http.Header { + for key, values := range src { + dst[http.CanonicalHeaderKey(key)] = values + } + return dst +} diff --git a/client/errors.go b/client/errors.go new file mode 100644 index 00000000..fdd8fb2a --- /dev/null +++ b/client/errors.go @@ -0,0 +1,156 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import "fmt" + +// HTTPError is returned by client operations when the HTTP status code of the +// response is not a 2xx status. +type HTTPError struct { + StatusCode int + Status string + Body []byte +} + +func (err HTTPError) Error() string { + if len(err.Body) == 0 { + return err.Status + } + return fmt.Sprintf("%v: %s", err.Status, err.Body) +} + +// Error wraps RPC errors, which contain an error code in addition to the message. +type Error interface { + Error() string // returns the message + ErrorCode() int // returns the code +} + +// A DataError contains some data in addition to the error message. +type DataError interface { + Error() string // returns the message + ErrorData() interface{} // returns the error data +} + +// Error types defined below are the built-in JSON-RPC errors. + +var ( + _ Error = new(methodNotFoundError) + _ Error = new(subscriptionNotFoundError) + _ Error = new(parseError) + _ Error = new(invalidRequestError) + _ Error = new(invalidMessageError) + _ Error = new(invalidParamsError) + _ Error = new(internalServerError) +) + +const ( + errcodeDefault = -32000 + errcodeTimeout = -32002 + errcodeResponseTooLarge = -32003 + errcodePanic = -32603 + errcodeMarshalError = -32603 + + legacyErrcodeNotificationsUnsupported = -32001 +) + +const ( + errMsgTimeout = "request timed out" + errMsgResponseTooLarge = "response too large" + errMsgBatchTooLarge = "batch too large" +) + +type methodNotFoundError struct{ method string } + +func (e *methodNotFoundError) ErrorCode() int { return -32601 } + +func (e *methodNotFoundError) Error() string { + return fmt.Sprintf("the method %s does not exist/is not available", e.method) +} + +type notificationsUnsupportedError struct{} + +func (e notificationsUnsupportedError) Error() string { + return "notifications not supported" +} + +func (e notificationsUnsupportedError) ErrorCode() int { return -32601 } + +// Is checks for equivalence to another error. Here we define that all errors with code +// -32601 (method not found) are equivalent to notificationsUnsupportedError. This is +// done to enable the following pattern: +// +// sub, err := client.Subscribe(...) +// if errors.Is(err, rpc.ErrNotificationsUnsupported) { +// // server doesn't support subscriptions +// } +func (e notificationsUnsupportedError) Is(other error) bool { + if other == (notificationsUnsupportedError{}) { + return true + } + rpcErr, ok := other.(Error) + if ok { + code := rpcErr.ErrorCode() + return code == -32601 || code == legacyErrcodeNotificationsUnsupported + } + return false +} + +type subscriptionNotFoundError struct{ namespace, subscription string } + +func (e *subscriptionNotFoundError) ErrorCode() int { return -32601 } + +func (e *subscriptionNotFoundError) Error() string { + return fmt.Sprintf("no %q subscription in %s namespace", e.subscription, e.namespace) +} + +// Invalid JSON was received by the server. +type parseError struct{ message string } + +func (e *parseError) ErrorCode() int { return -32700 } + +func (e *parseError) Error() string { return e.message } + +// received message isn't a valid request +type invalidRequestError struct{ message string } + +func (e *invalidRequestError) ErrorCode() int { return -32600 } + +func (e *invalidRequestError) Error() string { return e.message } + +// received message is invalid +type invalidMessageError struct{ message string } + +func (e *invalidMessageError) ErrorCode() int { return -32700 } + +func (e *invalidMessageError) Error() string { return e.message } + +// unable to decode supplied params, or an invalid number of parameters +type invalidParamsError struct{ message string } + +func (e *invalidParamsError) ErrorCode() int { return -32602 } + +func (e *invalidParamsError) Error() string { return e.message } + +// internalServerError is used for server errors during request processing. +type internalServerError struct { + code int + message string +} + +func (e *internalServerError) ErrorCode() int { return e.code } + +func (e *internalServerError) Error() string { return e.message } diff --git a/client/handler.go b/client/handler.go new file mode 100644 index 00000000..abf05008 --- /dev/null +++ b/client/handler.go @@ -0,0 +1,612 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/NethermindEth/starknet.go/client/log" +) + +// handler handles JSON-RPC messages. There is one handler per connection. Note that +// handler is not safe for concurrent use. Message handling never blocks indefinitely +// because RPCs are processed on background goroutines launched by handler. +// +// The entry points for incoming messages are: +// +// h.handleMsg(message) +// h.handleBatch(message) +// +// Outgoing calls use the requestOp struct. Register the request before sending it +// on the connection: +// +// op := &requestOp{ids: ...} +// h.addRequestOp(op) +// +// Now send the request, then wait for the reply to be delivered through handleMsg: +// +// if err := op.wait(...); err != nil { +// h.removeRequestOp(op) // timeout, etc. +// } +type handler struct { + reg *serviceRegistry + unsubscribeCb *callback + idgen func() ID // subscription ID generator + respWait map[string]*requestOp // active client requests + clientSubs map[string]*ClientSubscription // active client subscriptions + callWG sync.WaitGroup // pending call goroutines + rootCtx context.Context // canceled by close() + cancelRoot func() // cancel function for rootCtx + conn jsonWriter // where responses will be sent + log log.Logger + allowSubscribe bool + batchRequestLimit int + batchResponseMaxSize int + + subLock sync.Mutex + serverSubs map[ID]*Subscription +} + +type callProc struct { + ctx context.Context + notifiers []*Notifier +} + +func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int) *handler { + rootCtx, cancelRoot := context.WithCancel(connCtx) + h := &handler{ + reg: reg, + idgen: idgen, + conn: conn, + respWait: make(map[string]*requestOp), + clientSubs: make(map[string]*ClientSubscription), + rootCtx: rootCtx, + cancelRoot: cancelRoot, + allowSubscribe: true, + serverSubs: make(map[ID]*Subscription), + log: log.Root(), + batchRequestLimit: batchRequestLimit, + batchResponseMaxSize: batchResponseMaxSize, + } + if conn.remoteAddr() != "" { + h.log = h.log.New("conn", conn.remoteAddr()) + } + h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe)) + return h +} + +// batchCallBuffer manages in progress call messages and their responses during a batch +// call. Calls need to be synchronized between the processing and timeout-triggering +// goroutines. +type batchCallBuffer struct { + mutex sync.Mutex + calls []*jsonrpcMessage + resp []*jsonrpcMessage + wrote bool +} + +// nextCall returns the next unprocessed message. +func (b *batchCallBuffer) nextCall() *jsonrpcMessage { + b.mutex.Lock() + defer b.mutex.Unlock() + + if len(b.calls) == 0 { + return nil + } + // The popping happens in `pushAnswer`. The in progress call is kept + // so we can return an error for it in case of timeout. + msg := b.calls[0] + return msg +} + +// pushResponse adds the response to last call returned by nextCall. +func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) { + b.mutex.Lock() + defer b.mutex.Unlock() + + if answer != nil { + b.resp = append(b.resp, answer) + } + b.calls = b.calls[1:] +} + +// write sends the responses. +func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b.doWrite(ctx, conn, false) +} + +// respondWithError sends the responses added so far. For the remaining unanswered call +// messages, it responds with the given error. +func (b *batchCallBuffer) respondWithError(ctx context.Context, conn jsonWriter, err error) { + b.mutex.Lock() + defer b.mutex.Unlock() + + for _, msg := range b.calls { + if !msg.isNotification() { + b.resp = append(b.resp, msg.errorResponse(err)) + } + } + b.doWrite(ctx, conn, true) +} + +// doWrite actually writes the response. +// This assumes b.mutex is held. +func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) { + if b.wrote { + return + } + b.wrote = true // can only write once + if len(b.resp) > 0 { + conn.writeJSON(ctx, b.resp, isErrorResponse) + } +} + +// handleBatch executes all messages in a batch and returns the responses. +func (h *handler) handleBatch(msgs []*jsonrpcMessage) { + // Emit error response for empty batches: + if len(msgs) == 0 { + h.startCallProc(func(cp *callProc) { + resp := errorMessage(&invalidRequestError{"empty batch"}) + h.conn.writeJSON(cp.ctx, resp, true) + }) + return + } + // Apply limit on total number of requests. + if h.batchRequestLimit != 0 && len(msgs) > h.batchRequestLimit { + h.startCallProc(func(cp *callProc) { + h.respondWithBatchTooLarge(cp, msgs) + }) + return + } + + // Handle non-call messages first. + // Here we need to find the requestOp that sent the request batch. + calls := make([]*jsonrpcMessage, 0, len(msgs)) + h.handleResponses(msgs, func(msg *jsonrpcMessage) { + calls = append(calls, msg) + }) + if len(calls) == 0 { + return + } + + // Process calls on a goroutine because they may block indefinitely: + h.startCallProc(func(cp *callProc) { + var ( + timer *time.Timer + cancel context.CancelFunc + callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))} + ) + + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // currently-running method might not return immediately on timeout, we must wait + // for the timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + err := &internalServerError{errcodeTimeout, errMsgTimeout} + callBuffer.respondWithError(cp.ctx, h.conn, err) + }) + } + + responseBytes := 0 + for { + // No need to handle rest of calls if timed out. + if cp.ctx.Err() != nil { + break + } + msg := callBuffer.nextCall() + if msg == nil { + break + } + resp := h.handleCallMsg(cp, msg) + callBuffer.pushResponse(resp) + if resp != nil && h.batchResponseMaxSize != 0 { + responseBytes += len(resp.Result) + if responseBytes > h.batchResponseMaxSize { + err := &internalServerError{errcodeResponseTooLarge, errMsgResponseTooLarge} + callBuffer.respondWithError(cp.ctx, h.conn, err) + break + } + } + } + if timer != nil { + timer.Stop() + } + + h.addSubscriptions(cp.notifiers) + callBuffer.write(cp.ctx, h.conn) + for _, n := range cp.notifiers { + n.activate() + } + }) +} + +func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage) { + resp := errorMessage(&invalidRequestError{errMsgBatchTooLarge}) + // Find the first call and add its "id" field to the error. + // This is the best we can do, given that the protocol doesn't have a way + // of reporting an error for the entire batch. + for _, msg := range batch { + if msg.isCall() { + resp.ID = msg.ID + break + } + } + h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true) +} + +// handleMsg handles a single non-batch message. +func (h *handler) handleMsg(msg *jsonrpcMessage) { + msgs := []*jsonrpcMessage{msg} + h.handleResponses(msgs, func(msg *jsonrpcMessage) { + h.startCallProc(func(cp *callProc) { + h.handleNonBatchCall(cp, msg) + }) + }) +} + +func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) { + var ( + responded sync.Once + timer *time.Timer + cancel context.CancelFunc + ) + cp.ctx, cancel = context.WithCancel(cp.ctx) + defer cancel() + + // Cancel the request context after timeout and send an error response. Since the + // running method might not return immediately on timeout, we must wait for the + // timeout concurrently with processing the request. + if timeout, ok := ContextRequestTimeout(cp.ctx); ok { + timer = time.AfterFunc(timeout, func() { + cancel() + responded.Do(func() { + resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) + h.conn.writeJSON(cp.ctx, resp, true) + }) + }) + } + + answer := h.handleCallMsg(cp, msg) + if timer != nil { + timer.Stop() + } + h.addSubscriptions(cp.notifiers) + if answer != nil { + responded.Do(func() { + h.conn.writeJSON(cp.ctx, answer, false) + }) + } + for _, n := range cp.notifiers { + n.activate() + } +} + +// close cancels all requests except for inflightReq and waits for +// call goroutines to shut down. +func (h *handler) close(err error, inflightReq *requestOp) { + h.cancelAllRequests(err, inflightReq) + h.callWG.Wait() + h.cancelRoot() + h.cancelServerSubscriptions(err) +} + +// addRequestOp registers a request operation. +func (h *handler) addRequestOp(op *requestOp) { + for _, id := range op.ids { + h.respWait[string(id)] = op + } +} + +// removeRequestOp stops waiting for the given request IDs. +func (h *handler) removeRequestOp(op *requestOp) { + for _, id := range op.ids { + delete(h.respWait, string(id)) + } +} + +// cancelAllRequests unblocks and removes pending requests and active subscriptions. +func (h *handler) cancelAllRequests(err error, inflightReq *requestOp) { + didClose := make(map[*requestOp]bool) + if inflightReq != nil { + didClose[inflightReq] = true + } + + for id, op := range h.respWait { + // Remove the op so that later calls will not close op.resp again. + delete(h.respWait, id) + + if !didClose[op] { + op.err = err + close(op.resp) + didClose[op] = true + } + } + for id, sub := range h.clientSubs { + delete(h.clientSubs, id) + sub.close(err) + } +} + +func (h *handler) addSubscriptions(nn []*Notifier) { + h.subLock.Lock() + defer h.subLock.Unlock() + + for _, n := range nn { + if sub := n.takeSubscription(); sub != nil { + h.serverSubs[sub.ID] = sub + } + } +} + +// cancelServerSubscriptions removes all subscriptions and closes their error channels. +func (h *handler) cancelServerSubscriptions(err error) { + h.subLock.Lock() + defer h.subLock.Unlock() + + for id, s := range h.serverSubs { + s.err <- err + close(s.err) + delete(h.serverSubs, id) + } +} + +// startCallProc runs fn in a new goroutine and starts tracking it in the h.calls wait group. +func (h *handler) startCallProc(fn func(*callProc)) { + h.callWG.Add(1) + go func() { + ctx, cancel := context.WithCancel(h.rootCtx) + defer h.callWG.Done() + defer cancel() + fn(&callProc{ctx: ctx}) + }() +} + +// handleResponses processes method call responses. +func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*jsonrpcMessage)) { + var resolvedops []*requestOp + handleResp := func(msg *jsonrpcMessage) { + op := h.respWait[string(msg.ID)] + if op == nil { + h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID}) + return + } + resolvedops = append(resolvedops, op) + delete(h.respWait, string(msg.ID)) + + // For subscription responses, start the subscription if the server + // indicates success. EthSubscribe gets unblocked in either case through + // the op.resp channel. + if op.sub != nil { + if msg.Error != nil { + op.err = msg.Error + } else { + op.err = json.Unmarshal(msg.Result, &op.sub.subid) + if op.err == nil { + go op.sub.run() + h.clientSubs[op.sub.subid] = op.sub + } + } + } + + if !op.hadResponse { + op.hadResponse = true + op.resp <- batch + } + } + + for _, msg := range batch { + start := time.Now() + switch { + case msg.isResponse(): + handleResp(msg) + h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) + + case msg.isNotification(): + if strings.HasSuffix(msg.Method, notificationMethodSuffix) { + h.handleSubscriptionResult(msg) + continue + } + handleCall(msg) + + default: + handleCall(msg) + } + } + + for _, op := range resolvedops { + h.removeRequestOp(op) + } +} + +// handleSubscriptionResult processes subscription notifications. +func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { + var result subscriptionResult + if err := json.Unmarshal(msg.Params, &result); err != nil { + h.log.Debug("Dropping invalid subscription message") + return + } + if h.clientSubs[result.ID] != nil { + h.clientSubs[result.ID].deliver(result.Result) + } +} + +// handleCallMsg executes a call message and returns the answer. +func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage { + start := time.Now() + switch { + case msg.isNotification(): + h.handleCall(ctx, msg) + h.log.Debug("Served "+msg.Method, "duration", time.Since(start)) + return nil + + case msg.isCall(): + resp := h.handleCall(ctx, msg) + var logctx []any + logctx = append(logctx, "reqid", idForLog{msg.ID}, "duration", time.Since(start)) + if resp.Error != nil { + logctx = append(logctx, "err", resp.Error.Message) + if resp.Error.Data != nil { + logctx = append(logctx, "errdata", formatErrorData(resp.Error.Data)) + } + h.log.Warn("Served "+msg.Method, logctx...) + } else { + h.log.Debug("Served "+msg.Method, logctx...) + } + return resp + + case msg.hasValidID(): + return msg.errorResponse(&invalidRequestError{"invalid request"}) + + default: + return errorMessage(&invalidRequestError{"invalid request"}) + } +} + +// handleCall processes method calls. +func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { + if msg.isSubscribe() { + return h.handleSubscribe(cp, msg) + } + var callb *callback + if msg.isUnsubscribe() { + callb = h.unsubscribeCb + } else { + callb = h.reg.callback(msg.Method) + } + if callb == nil { + return msg.errorResponse(&methodNotFoundError{method: msg.Method}) + } + + args, err := parsePositionalArguments(msg.Params, callb.argTypes) + if err != nil { + return msg.errorResponse(&invalidParamsError{err.Error()}) + } + answer := h.runMethod(cp.ctx, msg, callb, args) + + return answer +} + +// handleSubscribe processes *_subscribe method calls. +func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { + if !h.allowSubscribe { + return msg.errorResponse(ErrNotificationsUnsupported) + } + + // Subscription method name is first argument. + name, err := parseSubscriptionName(msg.Params) + if err != nil { + return msg.errorResponse(&invalidParamsError{err.Error()}) + } + namespace := msg.namespace() + callb := h.reg.subscription(namespace, name) + if callb == nil { + return msg.errorResponse(&subscriptionNotFoundError{namespace, name}) + } + + // Parse subscription name arg too, but remove it before calling the callback. + argTypes := append([]reflect.Type{stringType}, callb.argTypes...) + args, err := parsePositionalArguments(msg.Params, argTypes) + if err != nil { + return msg.errorResponse(&invalidParamsError{err.Error()}) + } + args = args[1:] + + // Install notifier in context so the subscription handler can find it. + n := &Notifier{h: h, namespace: namespace} + cp.notifiers = append(cp.notifiers, n) + ctx := context.WithValue(cp.ctx, notifierKey{}, n) + + return h.runMethod(ctx, msg, callb, args) +} + +// runMethod runs the Go callback for an RPC method. +func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *callback, args []reflect.Value) *jsonrpcMessage { + result, err := callb.call(ctx, msg.Method, args) + if err != nil { + return msg.errorResponse(err) + } + return msg.response(result) +} + +// unsubscribe is the callback function for all *_unsubscribe calls. +func (h *handler) unsubscribe(ctx context.Context, id ID) (bool, error) { + h.subLock.Lock() + defer h.subLock.Unlock() + + s := h.serverSubs[id] + if s == nil { + return false, ErrSubscriptionNotFound + } + close(s.err) + delete(h.serverSubs, id) + return true, nil +} + +type idForLog struct{ json.RawMessage } + +func (id idForLog) String() string { + if s, err := strconv.Unquote(string(id.RawMessage)); err == nil { + return s + } + return string(id.RawMessage) +} + +var errTruncatedOutput = errors.New("truncated output") + +type limitedBuffer struct { + output []byte + limit int +} + +func (buf *limitedBuffer) Write(data []byte) (int, error) { + avail := max(buf.limit, len(buf.output)) + if len(data) < avail { + buf.output = append(buf.output, data...) + return len(data), nil + } + buf.output = append(buf.output, data[:avail]...) + return avail, errTruncatedOutput +} + +func formatErrorData(v any) string { + buf := limitedBuffer{limit: 1024} + err := json.NewEncoder(&buf).Encode(v) + switch { + case err == nil: + return string(bytes.TrimRight(buf.output, "\n")) + case errors.Is(err, errTruncatedOutput): + return fmt.Sprintf("%s... (truncated)", buf.output) + default: + return fmt.Sprintf("bad error data (err=%v)", err) + } +} diff --git a/client/http.go b/client/http.go new file mode 100644 index 00000000..68d0453b --- /dev/null +++ b/client/http.go @@ -0,0 +1,395 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math" + "mime" + "net/http" + "net/url" + "strconv" + "sync" + "time" +) + +const ( + defaultBodyLimit = 5 * 1024 * 1024 + contentType = "application/json" +) + +// https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13 +var acceptedContentTypes = []string{contentType, "application/json-rpc", "application/jsonrequest"} + +type httpConn struct { + client *http.Client + url string + closeOnce sync.Once + closeCh chan interface{} + mu sync.Mutex // protects headers + headers http.Header + auth HTTPAuth +} + +// httpConn implements ServerCodec, but it is treated specially by Client +// and some methods don't work. The panic() stubs here exist to ensure +// this special treatment is correct. + +func (hc *httpConn) writeJSON(context.Context, interface{}, bool) error { + panic("writeJSON called on httpConn") +} + +func (hc *httpConn) peerInfo() PeerInfo { + panic("peerInfo called on httpConn") +} + +func (hc *httpConn) remoteAddr() string { + return hc.url +} + +func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) { + <-hc.closeCh + return nil, false, io.EOF +} + +func (hc *httpConn) close() { + hc.closeOnce.Do(func() { close(hc.closeCh) }) +} + +func (hc *httpConn) closed() <-chan interface{} { + return hc.closeCh +} + +// HTTPTimeouts represents the configuration params for the HTTP RPC server. +type HTTPTimeouts struct { + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. + // + // Because ReadTimeout does not let Handlers make per-request + // decisions on each request body's acceptable deadline or + // upload rate, most users will prefer to use + // ReadHeaderTimeout. It is valid to use them both. + ReadTimeout time.Duration + + // ReadHeaderTimeout is the amount of time allowed to read + // request headers. The connection's read deadline is reset + // after reading the headers and the Handler can decide what + // is considered too slow for the body. If ReadHeaderTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. + ReadHeaderTimeout time.Duration + + // WriteTimeout is the maximum duration before timing out + // writes of the response. It is reset whenever a new + // request's header is read. Like ReadTimeout, it does not + // let Handlers make decisions on a per-request basis. + WriteTimeout time.Duration + + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. If IdleTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, ReadHeaderTimeout is used. + IdleTimeout time.Duration +} + +// DefaultHTTPTimeouts represents the default timeout values used if further +// configuration is not provided. +var DefaultHTTPTimeouts = HTTPTimeouts{ + ReadTimeout: 30 * time.Second, + ReadHeaderTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, +} + +// DialHTTP creates a new RPC client that connects to an RPC server over HTTP. +func DialHTTP(endpoint string) (*Client, error) { + return DialHTTPWithClient(endpoint, new(http.Client)) +} + +// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP +// using the provided HTTP Client. +// +// Deprecated: use DialOptions and the WithHTTPClient option. +func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { + // Sanity check URL so we don't end up with a client that will fail every request. + _, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + + var cfg clientConfig + cfg.httpClient = client + fn := newClientTransportHTTP(endpoint, &cfg) + return newClient(context.Background(), &cfg, fn) +} + +func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc { + headers := make(http.Header, 2+len(cfg.httpHeaders)) + headers.Set("accept", contentType) + headers.Set("content-type", contentType) + for key, values := range cfg.httpHeaders { + headers[key] = values + } + + client := cfg.httpClient + if client == nil { + client = new(http.Client) + } + + hc := &httpConn{ + client: client, + headers: headers, + url: endpoint, + auth: cfg.httpAuth, + closeCh: make(chan interface{}), + } + + return func(ctx context.Context) (ServerCodec, error) { + return hc, nil + } +} + +func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error { + hc := c.writeConn.(*httpConn) + respBody, err := hc.doRequest(ctx, msg) + if err != nil { + return err + } + defer respBody.Close() + + var resp jsonrpcMessage + batch := [1]*jsonrpcMessage{&resp} + if err := json.NewDecoder(respBody).Decode(&resp); err != nil { + return err + } + op.resp <- batch[:] + return nil +} + +func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonrpcMessage) error { + hc := c.writeConn.(*httpConn) + respBody, err := hc.doRequest(ctx, msgs) + if err != nil { + return err + } + defer respBody.Close() + + var respmsgs []*jsonrpcMessage + if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil { + return err + } + op.resp <- respmsgs + return nil +} + +func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadCloser, error) { + body, err := json.Marshal(msg) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, hc.url, io.NopCloser(bytes.NewReader(body))) + if err != nil { + return nil, err + } + req.ContentLength = int64(len(body)) + req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(body)), nil } + + // set headers + hc.mu.Lock() + req.Header = hc.headers.Clone() + hc.mu.Unlock() + setHeaders(req.Header, headersFromContext(ctx)) + + if hc.auth != nil { + if err := hc.auth(req.Header); err != nil { + return nil, err + } + } + + // do request + resp, err := hc.client.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + var buf bytes.Buffer + var body []byte + if _, err := buf.ReadFrom(resp.Body); err == nil { + body = buf.Bytes() + } + resp.Body.Close() + return nil, HTTPError{ + Status: resp.Status, + StatusCode: resp.StatusCode, + Body: body, + } + } + return resp.Body, nil +} + +// httpServerConn turns a HTTP connection into a Conn. +type httpServerConn struct { + io.Reader + io.Writer + r *http.Request +} + +func (s *Server) newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec { + body := io.LimitReader(r.Body, int64(s.httpBodyLimit)) + conn := &httpServerConn{Reader: body, Writer: w, r: r} + + encoder := func(v any, isErrorResponse bool) error { + if !isErrorResponse { + return json.NewEncoder(conn).Encode(v) + } + + // It's an error response and requires special treatment. + // + // In case of a timeout error, the response must be written before the HTTP + // server's write timeout occurs. So we need to flush the response. The + // Content-Length header also needs to be set to ensure the client knows + // when it has the full response. + encdata, err := json.Marshal(v) + if err != nil { + return err + } + w.Header().Set("content-length", strconv.Itoa(len(encdata))) + + // If this request is wrapped in a handler that might remove Content-Length (such + // as the automatic gzip we do in package node), we need to ensure the HTTP server + // doesn't perform chunked encoding. In case WriteTimeout is reached, the chunked + // encoding might not be finished correctly, and some clients do not like it when + // the final chunk is missing. + w.Header().Set("transfer-encoding", "identity") + + _, err = w.Write(encdata) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return err + } + + dec := json.NewDecoder(conn) + dec.UseNumber() + + return NewFuncCodec(conn, encoder, dec.Decode) +} + +// Close does nothing and always returns nil. +func (t *httpServerConn) Close() error { return nil } + +// RemoteAddr returns the peer address of the underlying connection. +func (t *httpServerConn) RemoteAddr() string { + return t.r.RemoteAddr +} + +// SetWriteDeadline does nothing and always returns nil. +func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil } + +// ServeHTTP serves JSON-RPC requests over HTTP. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Permit dumb empty requests for remote health-checks (AWS) + if r.Method == http.MethodGet && r.ContentLength == 0 && r.URL.RawQuery == "" { + w.WriteHeader(http.StatusOK) + return + } + if code, err := s.validateRequest(r); err != nil { + http.Error(w, err.Error(), code) + return + } + + // Create request-scoped context. + connInfo := PeerInfo{Transport: "http", RemoteAddr: r.RemoteAddr} + connInfo.HTTP.Version = r.Proto + connInfo.HTTP.Host = r.Host + connInfo.HTTP.Origin = r.Header.Get("Origin") + connInfo.HTTP.UserAgent = r.Header.Get("User-Agent") + ctx := r.Context() + ctx = context.WithValue(ctx, peerInfoContextKey{}, connInfo) + + // All checks passed, create a codec that reads directly from the request body + // until EOF, writes the response to w, and orders the server to process a + // single request. + w.Header().Set("content-type", contentType) + codec := s.newHTTPServerConn(r, w) + defer codec.close() + s.serveSingleRequest(ctx, codec) +} + +// validateRequest returns a non-zero response code and error message if the +// request is invalid. +func (s *Server) validateRequest(r *http.Request) (int, error) { + if r.Method == http.MethodPut || r.Method == http.MethodDelete { + return http.StatusMethodNotAllowed, errors.New("method not allowed") + } + if r.ContentLength > int64(s.httpBodyLimit) { + err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, s.httpBodyLimit) + return http.StatusRequestEntityTooLarge, err + } + // Allow OPTIONS (regardless of content-type) + if r.Method == http.MethodOptions { + return 0, nil + } + // Check content-type + if mt, _, err := mime.ParseMediaType(r.Header.Get("content-type")); err == nil { + for _, accepted := range acceptedContentTypes { + if accepted == mt { + return 0, nil + } + } + } + // Invalid content-type + err := fmt.Errorf("invalid content type, only %s is supported", contentType) + return http.StatusUnsupportedMediaType, err +} + +// ContextRequestTimeout returns the request timeout derived from the given context. +func ContextRequestTimeout(ctx context.Context) (time.Duration, bool) { + timeout := time.Duration(math.MaxInt64) + hasTimeout := false + setTimeout := func(d time.Duration) { + if d < timeout { + timeout = d + hasTimeout = true + } + } + + if deadline, ok := ctx.Deadline(); ok { + setTimeout(time.Until(deadline)) + } + + // If the context is an HTTP request context, use the server's WriteTimeout. + httpSrv, ok := ctx.Value(http.ServerContextKey).(*http.Server) + if ok && httpSrv.WriteTimeout > 0 { + wt := httpSrv.WriteTimeout + // When a write timeout is configured, we need to send the response message before + // the HTTP server cuts connection. So our internal timeout must be earlier than + // the server's true timeout. + // + // Note: Timeouts are sanitized to be a minimum of 1 second. + // Also see issue: https://github.com/golang/go/issues/47229 + wt -= 100 * time.Millisecond + setTimeout(wt) + } + + return timeout, hasTimeout +} diff --git a/client/http_test.go b/client/http_test.go new file mode 100644 index 00000000..4b422ff0 --- /dev/null +++ b/client/http_test.go @@ -0,0 +1,245 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func confirmStatusCode(t *testing.T, got, want int) { + t.Helper() + if got == want { + return + } + if gotName := http.StatusText(got); len(gotName) > 0 { + if wantName := http.StatusText(want); len(wantName) > 0 { + t.Fatalf("response status code: got %d (%s), want %d (%s)", got, gotName, want, wantName) + } + } + t.Fatalf("response status code: got %d, want %d", got, want) +} + +func confirmRequestValidationCode(t *testing.T, method, contentType, body string, expectedStatusCode int) { + t.Helper() + + s := NewServer() + request := httptest.NewRequest(method, "http://url.com", strings.NewReader(body)) + if len(contentType) > 0 { + request.Header.Set("Content-Type", contentType) + } + code, err := s.validateRequest(request) + if code == 0 { + if err != nil { + t.Errorf("validation: got error %v, expected nil", err) + } + } else if err == nil { + t.Errorf("validation: code %d: got nil, expected error", code) + } + confirmStatusCode(t, code, expectedStatusCode) +} + +func TestHTTPErrorResponseWithDelete(t *testing.T) { + confirmRequestValidationCode(t, http.MethodDelete, contentType, "", http.StatusMethodNotAllowed) +} + +func TestHTTPErrorResponseWithPut(t *testing.T) { + confirmRequestValidationCode(t, http.MethodPut, contentType, "", http.StatusMethodNotAllowed) +} + +func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) { + body := make([]rune, defaultBodyLimit+1) + confirmRequestValidationCode(t, + http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge) +} + +func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) { + confirmRequestValidationCode(t, http.MethodPost, "", "", http.StatusUnsupportedMediaType) +} + +func TestHTTPErrorResponseWithValidRequest(t *testing.T) { + confirmRequestValidationCode(t, http.MethodPost, contentType, "", 0) +} + +func confirmHTTPRequestYieldsStatusCode(t *testing.T, method, contentType, body string, expectedStatusCode int) { + t.Helper() + s := Server{} + ts := httptest.NewServer(&s) + defer ts.Close() + + request, err := http.NewRequest(method, ts.URL, strings.NewReader(body)) + if err != nil { + t.Fatalf("failed to create a valid HTTP request: %v", err) + } + if len(contentType) > 0 { + request.Header.Set("Content-Type", contentType) + } + resp, err := http.DefaultClient.Do(request) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + confirmStatusCode(t, resp.StatusCode, expectedStatusCode) +} + +func TestHTTPResponseWithEmptyGet(t *testing.T) { + confirmHTTPRequestYieldsStatusCode(t, http.MethodGet, "", "", http.StatusOK) +} + +// This checks that maxRequestContentLength is not applied to the response of a request. +func TestHTTPRespBodyUnlimited(t *testing.T) { + const respLength = defaultBodyLimit * 3 + + s := NewServer() + defer s.Stop() + s.RegisterName("test", largeRespService{respLength}) + ts := httptest.NewServer(s) + defer ts.Close() + + c, err := DialHTTP(ts.URL) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + var r string + if err := c.Call(&r, "test_largeResp"); err != nil { + t.Fatal(err) + } + if len(r) != respLength { + t.Fatalf("response has wrong length %d, want %d", len(r), respLength) + } +} + +// Tests that an HTTP error results in an HTTPError instance +// being returned with the expected attributes. +func TestHTTPErrorResponse(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "error has occurred!", http.StatusTeapot) + })) + defer ts.Close() + + c, err := DialHTTP(ts.URL) + if err != nil { + t.Fatal(err) + } + + var r string + err = c.Call(&r, "test_method") + if err == nil { + t.Fatal("error was expected") + } + + httpErr, ok := err.(HTTPError) + if !ok { + t.Fatalf("unexpected error type %T", err) + } + + if httpErr.StatusCode != http.StatusTeapot { + t.Error("unexpected status code", httpErr.StatusCode) + } + if httpErr.Status != "418 I'm a teapot" { + t.Error("unexpected status text", httpErr.Status) + } + if body := string(httpErr.Body); body != "error has occurred!\n" { + t.Error("unexpected body", body) + } + + if errMsg := httpErr.Error(); errMsg != "418 I'm a teapot: error has occurred!\n" { + t.Error("unexpected error message", errMsg) + } +} + +func TestHTTPPeerInfo(t *testing.T) { + s := newTestServer() + defer s.Stop() + ts := httptest.NewServer(s) + defer ts.Close() + + c, err := Dial(ts.URL) + if err != nil { + t.Fatal(err) + } + c.SetHeader("user-agent", "ua-testing") + c.SetHeader("origin", "origin.example.com") + + // Request peer information. + var info PeerInfo + if err := c.Call(&info, "test_peerInfo"); err != nil { + t.Fatal(err) + } + + if info.RemoteAddr == "" { + t.Error("RemoteAddr not set") + } + if info.Transport != "http" { + t.Errorf("wrong Transport %q", info.Transport) + } + if info.HTTP.Version != "HTTP/1.1" { + t.Errorf("wrong HTTP.Version %q", info.HTTP.Version) + } + if info.HTTP.UserAgent != "ua-testing" { + t.Errorf("wrong HTTP.UserAgent %q", info.HTTP.UserAgent) + } + if info.HTTP.Origin != "origin.example.com" { + t.Errorf("wrong HTTP.Origin %q", info.HTTP.UserAgent) + } +} + +func TestNewContextWithHeaders(t *testing.T) { + expectedHeaders := 0 + server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + for i := 0; i < expectedHeaders; i++ { + key, want := fmt.Sprintf("key-%d", i), fmt.Sprintf("val-%d", i) + if have := request.Header.Get(key); have != want { + t.Errorf("wrong request headers for %s, want: %s, have: %s", key, want, have) + } + } + writer.WriteHeader(http.StatusOK) + _, _ = writer.Write([]byte(`{}`)) + })) + defer server.Close() + + client, err := Dial(server.URL) + if err != nil { + t.Fatalf("failed to dial: %s", err) + } + defer client.Close() + + newHdr := func(k, v string) http.Header { + header := http.Header{} + header.Set(k, v) + return header + } + ctx1 := NewContextWithHeaders(context.Background(), newHdr("key-0", "val-0")) + ctx2 := NewContextWithHeaders(ctx1, newHdr("key-1", "val-1")) + ctx3 := NewContextWithHeaders(ctx2, newHdr("key-2", "val-2")) + + expectedHeaders = 3 + if err := client.CallContext(ctx3, nil, "test"); err != ErrNoResult { + t.Error("call failed", err) + } + + expectedHeaders = 2 + if err := client.CallContext(ctx2, nil, "test"); err != ErrNoResult { + t.Error("call failed:", err) + } +} diff --git a/client/inproc.go b/client/inproc.go new file mode 100644 index 00000000..b0beee85 --- /dev/null +++ b/client/inproc.go @@ -0,0 +1,34 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "net" +) + +// DialInProc attaches an in-process connection to the given RPC server. +func DialInProc(handler *Server) *Client { + initctx := context.Background() + cfg := new(clientConfig) + c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) { + p1, p2 := net.Pipe() + go handler.ServeCodec(NewCodec(p1), 0) + return NewCodec(p2), nil + }) + return c +} diff --git a/client/json.go b/client/json.go new file mode 100644 index 00000000..e7908e68 --- /dev/null +++ b/client/json.go @@ -0,0 +1,369 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "reflect" + "strings" + "sync" + "time" +) + +const ( + vsn = "2.0" + serviceMethodSeparator = "_" + subscribeMethodSuffix = "_subscribe" + unsubscribeMethodSuffix = "_unsubscribe" + notificationMethodSuffix = "_subscription" + + defaultWriteTimeout = 10 * time.Second // used if context has no deadline +) + +var null = json.RawMessage("null") + +type subscriptionResult struct { + ID string `json:"subscription"` + Result json.RawMessage `json:"result,omitempty"` +} + +type subscriptionResultEnc struct { + ID string `json:"subscription"` + Result any `json:"result"` +} + +type jsonrpcSubscriptionNotification struct { + Version string `json:"jsonrpc"` + Method string `json:"method"` + Params subscriptionResultEnc `json:"params"` +} + +// A value of this type can a JSON-RPC request, notification, successful response or +// error response. Which one it is depends on the fields. +type jsonrpcMessage struct { + Version string `json:"jsonrpc,omitempty"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Error *jsonError `json:"error,omitempty"` + Result json.RawMessage `json:"result,omitempty"` +} + +func (msg *jsonrpcMessage) isNotification() bool { + return msg.hasValidVersion() && msg.ID == nil && msg.Method != "" +} + +func (msg *jsonrpcMessage) isCall() bool { + return msg.hasValidVersion() && msg.hasValidID() && msg.Method != "" +} + +func (msg *jsonrpcMessage) isResponse() bool { + return msg.hasValidVersion() && msg.hasValidID() && msg.Method == "" && msg.Params == nil && (msg.Result != nil || msg.Error != nil) +} + +func (msg *jsonrpcMessage) hasValidID() bool { + return len(msg.ID) > 0 && msg.ID[0] != '{' && msg.ID[0] != '[' +} + +func (msg *jsonrpcMessage) hasValidVersion() bool { + return msg.Version == vsn +} + +func (msg *jsonrpcMessage) isSubscribe() bool { + return strings.HasSuffix(msg.Method, subscribeMethodSuffix) +} + +func (msg *jsonrpcMessage) isUnsubscribe() bool { + return strings.HasSuffix(msg.Method, unsubscribeMethodSuffix) +} + +func (msg *jsonrpcMessage) namespace() string { + before, _, _ := strings.Cut(msg.Method, serviceMethodSeparator) + return before +} + +func (msg *jsonrpcMessage) String() string { + b, _ := json.Marshal(msg) + return string(b) +} + +func (msg *jsonrpcMessage) errorResponse(err error) *jsonrpcMessage { + resp := errorMessage(err) + resp.ID = msg.ID + return resp +} + +func (msg *jsonrpcMessage) response(result interface{}) *jsonrpcMessage { + enc, err := json.Marshal(result) + if err != nil { + return msg.errorResponse(&internalServerError{errcodeMarshalError, err.Error()}) + } + return &jsonrpcMessage{Version: vsn, ID: msg.ID, Result: enc} +} + +func errorMessage(err error) *jsonrpcMessage { + msg := &jsonrpcMessage{Version: vsn, ID: null, Error: &jsonError{ + Code: errcodeDefault, + Message: err.Error(), + }} + ec, ok := err.(Error) + if ok { + msg.Error.Code = ec.ErrorCode() + } + de, ok := err.(DataError) + if ok { + msg.Error.Data = de.ErrorData() + } + return msg +} + +type jsonError struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +func (err *jsonError) Error() string { + if err.Message == "" { + return fmt.Sprintf("json-rpc error %d", err.Code) + } + return err.Message +} + +func (err *jsonError) ErrorCode() int { + return err.Code +} + +func (err *jsonError) ErrorData() interface{} { + return err.Data +} + +// Conn is a subset of the methods of net.Conn which are sufficient for ServerCodec. +type Conn interface { + io.ReadWriteCloser + SetWriteDeadline(time.Time) error +} + +type deadlineCloser interface { + io.Closer + SetWriteDeadline(time.Time) error +} + +// ConnRemoteAddr wraps the RemoteAddr operation, which returns a description +// of the peer address of a connection. If a Conn also implements ConnRemoteAddr, this +// description is used in log messages. +type ConnRemoteAddr interface { + RemoteAddr() string +} + +// jsonCodec reads and writes JSON-RPC messages to the underlying connection. It also has +// support for parsing arguments and serializing (result) objects. +type jsonCodec struct { + remote string + closer sync.Once // close closed channel once + closeCh chan interface{} // closed on Close + decode decodeFunc // decoder to allow multiple transports + encMu sync.Mutex // guards the encoder + encode encodeFunc // encoder to allow multiple transports + conn deadlineCloser +} + +type encodeFunc = func(v interface{}, isErrorResponse bool) error + +type decodeFunc = func(v interface{}) error + +// NewFuncCodec creates a codec which uses the given functions to read and write. If conn +// implements ConnRemoteAddr, log messages will use it to include the remote address of +// the connection. +func NewFuncCodec(conn deadlineCloser, encode encodeFunc, decode decodeFunc) ServerCodec { + codec := &jsonCodec{ + closeCh: make(chan interface{}), + encode: encode, + decode: decode, + conn: conn, + } + if ra, ok := conn.(ConnRemoteAddr); ok { + codec.remote = ra.RemoteAddr() + } + return codec +} + +// NewCodec creates a codec on the given connection. If conn implements ConnRemoteAddr, log +// messages will use it to include the remote address of the connection. +func NewCodec(conn Conn) ServerCodec { + enc := json.NewEncoder(conn) + dec := json.NewDecoder(conn) + dec.UseNumber() + + encode := func(v interface{}, isErrorResponse bool) error { + return enc.Encode(v) + } + return NewFuncCodec(conn, encode, dec.Decode) +} + +func (c *jsonCodec) peerInfo() PeerInfo { + // This returns "ipc" because all other built-in transports have a separate codec type. + return PeerInfo{Transport: "ipc", RemoteAddr: c.remote} +} + +func (c *jsonCodec) remoteAddr() string { + return c.remote +} + +func (c *jsonCodec) readBatch() (messages []*jsonrpcMessage, batch bool, err error) { + // Decode the next JSON object in the input stream. + // This verifies basic syntax, etc. + var rawmsg json.RawMessage + if err := c.decode(&rawmsg); err != nil { + return nil, false, err + } + messages, batch = parseMessage(rawmsg) + for i, msg := range messages { + if msg == nil { + // Message is JSON 'null'. Replace with zero value so it + // will be treated like any other invalid message. + messages[i] = new(jsonrpcMessage) + } + } + return messages, batch, nil +} + +func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorResponse bool) error { + c.encMu.Lock() + defer c.encMu.Unlock() + + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(defaultWriteTimeout) + } + c.conn.SetWriteDeadline(deadline) + return c.encode(v, isErrorResponse) +} + +func (c *jsonCodec) close() { + c.closer.Do(func() { + close(c.closeCh) + c.conn.Close() + }) +} + +// closed returns a channel which will be closed when Close is called +func (c *jsonCodec) closed() <-chan interface{} { + return c.closeCh +} + +// parseMessage parses raw bytes as a (batch of) JSON-RPC message(s). There are no error +// checks in this function because the raw message has already been syntax-checked when it +// is called. Any non-JSON-RPC messages in the input return the zero value of +// jsonrpcMessage. +func parseMessage(raw json.RawMessage) ([]*jsonrpcMessage, bool) { + if !isBatch(raw) { + msgs := []*jsonrpcMessage{{}} + json.Unmarshal(raw, &msgs[0]) + return msgs, false + } + dec := json.NewDecoder(bytes.NewReader(raw)) + dec.Token() // skip '[' + var msgs []*jsonrpcMessage + for dec.More() { + msgs = append(msgs, new(jsonrpcMessage)) + dec.Decode(&msgs[len(msgs)-1]) + } + return msgs, true +} + +// isBatch returns true when the first non-whitespace characters is '[' +func isBatch(raw json.RawMessage) bool { + for _, c := range raw { + // skip insignificant whitespace (http://www.ietf.org/rfc/rfc4627.txt) + if c == 0x20 || c == 0x09 || c == 0x0a || c == 0x0d { + continue + } + return c == '[' + } + return false +} + +// parsePositionalArguments tries to parse the given args to an array of values with the +// given types. It returns the parsed values or an error when the args could not be +// parsed. Missing optional arguments are returned as reflect.Zero values. +func parsePositionalArguments(rawArgs json.RawMessage, types []reflect.Type) ([]reflect.Value, error) { + dec := json.NewDecoder(bytes.NewReader(rawArgs)) + var args []reflect.Value + tok, err := dec.Token() + switch { + case err == io.EOF || tok == nil && err == nil: + // "params" is optional and may be empty. Also allow "params":null even though it's + // not in the spec because our own client used to send it. + case err != nil: + return nil, err + case tok == json.Delim('['): + // Read argument array. + if args, err = parseArgumentArray(dec, types); err != nil { + return nil, err + } + default: + return nil, errors.New("non-array args") + } + // Set any missing args to nil. + for i := len(args); i < len(types); i++ { + if types[i].Kind() != reflect.Ptr { + return nil, fmt.Errorf("missing value for required argument %d", i) + } + args = append(args, reflect.Zero(types[i])) + } + return args, nil +} + +func parseArgumentArray(dec *json.Decoder, types []reflect.Type) ([]reflect.Value, error) { + args := make([]reflect.Value, 0, len(types)) + for i := 0; dec.More(); i++ { + if i >= len(types) { + return args, fmt.Errorf("too many arguments, want at most %d", len(types)) + } + argval := reflect.New(types[i]) + if err := dec.Decode(argval.Interface()); err != nil { + return args, fmt.Errorf("invalid argument %d: %v", i, err) + } + if argval.IsNil() && types[i].Kind() != reflect.Ptr { + return args, fmt.Errorf("missing value for required argument %d", i) + } + args = append(args, argval.Elem()) + } + // Read end of args array. + _, err := dec.Token() + return args, err +} + +// parseSubscriptionName extracts the subscription name from an encoded argument array. +func parseSubscriptionName(rawArgs json.RawMessage) (string, error) { + dec := json.NewDecoder(bytes.NewReader(rawArgs)) + if tok, _ := dec.Token(); tok != json.Delim('[') { + return "", errors.New("non-array args") + } + v, _ := dec.Token() + method, ok := v.(string) + if !ok { + return "", errors.New("expected subscription name as first argument") + } + return method, nil +} diff --git a/client/log/format.go b/client/log/format.go new file mode 100644 index 00000000..54c071b9 --- /dev/null +++ b/client/log/format.go @@ -0,0 +1,363 @@ +package log + +import ( + "bytes" + "fmt" + "log/slog" + "math/big" + "reflect" + "strconv" + "time" + "unicode/utf8" + + "github.com/holiman/uint256" +) + +const ( + timeFormat = "2006-01-02T15:04:05-0700" + floatFormat = 'f' + termMsgJust = 40 + termCtxMaxPadding = 40 +) + +// 40 spaces +var spaces = []byte(" ") + +// TerminalStringer is an analogous interface to the stdlib stringer, allowing +// own types to have custom shortened serialization formats when printed to the +// screen. +type TerminalStringer interface { + TerminalString() string +} + +func (h *TerminalHandler) format(buf []byte, r slog.Record, usecolor bool) []byte { + msg := escapeMessage(r.Message) + var color = "" + if usecolor { + switch r.Level { + case LevelCrit: + color = "\x1b[35m" + case slog.LevelError: + color = "\x1b[31m" + case slog.LevelWarn: + color = "\x1b[33m" + case slog.LevelInfo: + color = "\x1b[32m" + case slog.LevelDebug: + color = "\x1b[36m" + case LevelTrace: + color = "\x1b[34m" + } + } + if buf == nil { + buf = make([]byte, 0, 30+termMsgJust) + } + b := bytes.NewBuffer(buf) + + if color != "" { // Start color + b.WriteString(color) + b.WriteString(LevelAlignedString(r.Level)) + b.WriteString("\x1b[0m") + } else { + b.WriteString(LevelAlignedString(r.Level)) + } + b.WriteString("[") + writeTimeTermFormat(b, r.Time) + b.WriteString("] ") + b.WriteString(msg) + + // try to justify the log output for short messages + //length := utf8.RuneCountInString(msg) + length := len(msg) + if (r.NumAttrs()+len(h.attrs)) > 0 && length < termMsgJust { + b.Write(spaces[:termMsgJust-length]) + } + // print the attributes + h.formatAttributes(b, r, color) + + return b.Bytes() +} + +func (h *TerminalHandler) formatAttributes(buf *bytes.Buffer, r slog.Record, color string) { + writeAttr := func(attr slog.Attr, first, last bool) { + buf.WriteByte(' ') + + if color != "" { + buf.WriteString(color) + buf.Write(appendEscapeString(buf.AvailableBuffer(), attr.Key)) + buf.WriteString("\x1b[0m=") + } else { + buf.Write(appendEscapeString(buf.AvailableBuffer(), attr.Key)) + buf.WriteByte('=') + } + val := FormatSlogValue(attr.Value, buf.AvailableBuffer()) + + padding := h.fieldPadding[attr.Key] + + length := utf8.RuneCount(val) + if padding < length && length <= termCtxMaxPadding { + padding = length + h.fieldPadding[attr.Key] = padding + } + buf.Write(val) + if !last && padding > length { + buf.Write(spaces[:padding-length]) + } + } + var n = 0 + var nAttrs = len(h.attrs) + r.NumAttrs() + for _, attr := range h.attrs { + writeAttr(attr, n == 0, n == nAttrs-1) + n++ + } + r.Attrs(func(attr slog.Attr) bool { + writeAttr(attr, n == 0, n == nAttrs-1) + n++ + return true + }) + buf.WriteByte('\n') +} + +// FormatSlogValue formats a slog.Value for serialization to terminal. +func FormatSlogValue(v slog.Value, tmp []byte) (result []byte) { + var value any + defer func() { + if err := recover(); err != nil { + if v := reflect.ValueOf(value); v.Kind() == reflect.Ptr && v.IsNil() { + result = []byte("") + } else { + panic(err) + } + } + }() + + switch v.Kind() { + case slog.KindString: + return appendEscapeString(tmp, v.String()) + case slog.KindInt64: // All int-types (int8, int16 etc) wind up here + return appendInt64(tmp, v.Int64()) + case slog.KindUint64: // All uint-types (uint8, uint16 etc) wind up here + return appendUint64(tmp, v.Uint64(), false) + case slog.KindFloat64: + return strconv.AppendFloat(tmp, v.Float64(), floatFormat, 3, 64) + case slog.KindBool: + return strconv.AppendBool(tmp, v.Bool()) + case slog.KindDuration: + value = v.Duration() + case slog.KindTime: + // Performance optimization: No need for escaping since the provided + // timeFormat doesn't have any escape characters, and escaping is + // expensive. + return v.Time().AppendFormat(tmp, timeFormat) + default: + value = v.Any() + } + if value == nil { + return []byte("") + } + switch v := value.(type) { + case *big.Int: // Need to be before fmt.Stringer-clause + return appendBigInt(tmp, v) + case *uint256.Int: // Need to be before fmt.Stringer-clause + return appendU256(tmp, v) + case error: + return appendEscapeString(tmp, v.Error()) + case TerminalStringer: + return appendEscapeString(tmp, v.TerminalString()) + case fmt.Stringer: + return appendEscapeString(tmp, v.String()) + } + + // We can use the 'tmp' as a scratch-buffer, to first format the + // value, and in a second step do escaping. + internal := fmt.Appendf(tmp, "%+v", value) + return appendEscapeString(tmp, string(internal)) +} + +// appendInt64 formats n with thousand separators and writes into buffer dst. +func appendInt64(dst []byte, n int64) []byte { + if n < 0 { + return appendUint64(dst, uint64(-n), true) + } + return appendUint64(dst, uint64(n), false) +} + +// appendUint64 formats n with thousand separators and writes into buffer dst. +func appendUint64(dst []byte, n uint64, neg bool) []byte { + // Small numbers are fine as is + if n < 100000 { + if neg { + return strconv.AppendInt(dst, -int64(n), 10) + } else { + return strconv.AppendInt(dst, int64(n), 10) + } + } + // Large numbers should be split + const maxLength = 26 + + var ( + out = make([]byte, maxLength) + i = maxLength - 1 + comma = 0 + ) + for ; n > 0; i-- { + if comma == 3 { + comma = 0 + out[i] = ',' + } else { + comma++ + out[i] = '0' + byte(n%10) + n /= 10 + } + } + if neg { + out[i] = '-' + i-- + } + return append(dst, out[i+1:]...) +} + +// FormatLogfmtUint64 formats n with thousand separators. +func FormatLogfmtUint64(n uint64) string { + return string(appendUint64(nil, n, false)) +} + +// appendBigInt formats n with thousand separators and writes to dst. +func appendBigInt(dst []byte, n *big.Int) []byte { + if n.IsUint64() { + return appendUint64(dst, n.Uint64(), false) + } + if n.IsInt64() { + return appendInt64(dst, n.Int64()) + } + + var ( + text = n.String() + buf = make([]byte, len(text)+len(text)/3) + comma = 0 + i = len(buf) - 1 + ) + for j := len(text) - 1; j >= 0; j, i = j-1, i-1 { + c := text[j] + + switch { + case c == '-': + buf[i] = c + case comma == 3: + buf[i] = ',' + i-- + comma = 0 + fallthrough + default: + buf[i] = c + comma++ + } + } + return append(dst, buf[i+1:]...) +} + +// appendU256 formats n with thousand separators. +func appendU256(dst []byte, n *uint256.Int) []byte { + if n.IsUint64() { + return appendUint64(dst, n.Uint64(), false) + } + res := []byte(n.PrettyDec(',')) + return append(dst, res...) +} + +// appendEscapeString writes the string s to the given writer, with +// escaping/quoting if needed. +func appendEscapeString(dst []byte, s string) []byte { + needsQuoting := false + needsEscaping := false + for _, r := range s { + // If it contains spaces or equal-sign, we need to quote it. + if r == ' ' || r == '=' { + needsQuoting = true + continue + } + // We need to escape it, if it contains + // - character " (0x22) and lower (except space) + // - characters above ~ (0x7E), plus equal-sign + if r <= '"' || r > '~' { + needsEscaping = true + break + } + } + if needsEscaping { + return strconv.AppendQuote(dst, s) + } + // No escaping needed, but we might have to place within quote-marks, in case + // it contained a space + if needsQuoting { + dst = append(dst, '"') + dst = append(dst, []byte(s)...) + return append(dst, '"') + } + return append(dst, []byte(s)...) +} + +// escapeMessage checks if the provided string needs escaping/quoting, similarly +// to escapeString. The difference is that this method is more lenient: it allows +// for spaces and linebreaks to occur without needing quoting. +func escapeMessage(s string) string { + needsQuoting := false + for _, r := range s { + // Allow CR/LF/TAB. This is to make multi-line messages work. + if r == '\r' || r == '\n' || r == '\t' { + continue + } + // We quote everything below (0x20) and above~ (0x7E), + // plus equal-sign + if r < ' ' || r > '~' || r == '=' { + needsQuoting = true + break + } + } + if !needsQuoting { + return s + } + return strconv.Quote(s) +} + +// writeTimeTermFormat writes on the format "01-02|15:04:05.000" +func writeTimeTermFormat(buf *bytes.Buffer, t time.Time) { + _, month, day := t.Date() + writePosIntWidth(buf, int(month), 2) + buf.WriteByte('-') + writePosIntWidth(buf, day, 2) + buf.WriteByte('|') + hour, min, sec := t.Clock() + writePosIntWidth(buf, hour, 2) + buf.WriteByte(':') + writePosIntWidth(buf, min, 2) + buf.WriteByte(':') + writePosIntWidth(buf, sec, 2) + ns := t.Nanosecond() + buf.WriteByte('.') + writePosIntWidth(buf, ns/1e6, 3) +} + +// writePosIntWidth writes non-negative integer i to the buffer, padded on the left +// by zeroes to the given width. Use a width of 0 to omit padding. +// Adapted from pkg.go.dev/log/slog/internal/buffer +func writePosIntWidth(b *bytes.Buffer, i, width int) { + // Cheap integer to fixed-width decimal ASCII. + // Copied from log/log.go. + if i < 0 { + panic("negative int") + } + // Assemble decimal in reverse order. + var bb [20]byte + bp := len(bb) - 1 + for i >= 10 || width > 1 { + width-- + q := i / 10 + bb[bp] = byte('0' + i - q*10) + bp-- + i = q + } + // i < 10 + bb[bp] = byte('0' + i) + b.Write(bb[bp:]) +} diff --git a/client/log/handler.go b/client/log/handler.go new file mode 100644 index 00000000..56eff667 --- /dev/null +++ b/client/log/handler.go @@ -0,0 +1,199 @@ +package log + +import ( + "context" + "fmt" + "io" + "log/slog" + "math/big" + "reflect" + "sync" + "time" + + "github.com/holiman/uint256" +) + +type discardHandler struct{} + +// DiscardHandler returns a no-op handler +func DiscardHandler() slog.Handler { + return &discardHandler{} +} + +func (h *discardHandler) Handle(_ context.Context, r slog.Record) error { + return nil +} + +func (h *discardHandler) Enabled(_ context.Context, level slog.Level) bool { + return false +} + +func (h *discardHandler) WithGroup(name string) slog.Handler { + panic("not implemented") +} + +func (h *discardHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &discardHandler{} +} + +type TerminalHandler struct { + mu sync.Mutex + wr io.Writer + lvl slog.Level + useColor bool + attrs []slog.Attr + // fieldPadding is a map with maximum field value lengths seen until now + // to allow padding log contexts in a bit smarter way. + fieldPadding map[string]int + + buf []byte +} + +// NewTerminalHandler returns a handler which formats log records at all levels optimized for human readability on +// a terminal with color-coded level output and terser human friendly timestamp. +// This format should only be used for interactive programs or while developing. +// +// [LEVEL] [TIME] MESSAGE key=value key=value ... +// +// Example: +// +// [DBUG] [May 16 20:58:45] remove route ns=haproxy addr=127.0.0.1:50002 +func NewTerminalHandler(wr io.Writer, useColor bool) *TerminalHandler { + return NewTerminalHandlerWithLevel(wr, levelMaxVerbosity, useColor) +} + +// NewTerminalHandlerWithLevel returns the same handler as NewTerminalHandler but only outputs +// records which are less than or equal to the specified verbosity level. +func NewTerminalHandlerWithLevel(wr io.Writer, lvl slog.Level, useColor bool) *TerminalHandler { + return &TerminalHandler{ + wr: wr, + lvl: lvl, + useColor: useColor, + fieldPadding: make(map[string]int), + } +} + +func (h *TerminalHandler) Handle(_ context.Context, r slog.Record) error { + h.mu.Lock() + defer h.mu.Unlock() + buf := h.format(h.buf, r, h.useColor) + h.wr.Write(buf) + h.buf = buf[:0] + return nil +} + +func (h *TerminalHandler) Enabled(_ context.Context, level slog.Level) bool { + return level >= h.lvl +} + +func (h *TerminalHandler) WithGroup(name string) slog.Handler { + panic("not implemented") +} + +func (h *TerminalHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &TerminalHandler{ + wr: h.wr, + lvl: h.lvl, + useColor: h.useColor, + attrs: append(h.attrs, attrs...), + fieldPadding: make(map[string]int), + } +} + +// ResetFieldPadding zeroes the field-padding for all attribute pairs. +func (h *TerminalHandler) ResetFieldPadding() { + h.mu.Lock() + h.fieldPadding = make(map[string]int) + h.mu.Unlock() +} + +type leveler struct{ minLevel slog.Level } + +func (l *leveler) Level() slog.Level { + return l.minLevel +} + +// JSONHandler returns a handler which prints records in JSON format. +func JSONHandler(wr io.Writer) slog.Handler { + return JSONHandlerWithLevel(wr, levelMaxVerbosity) +} + +// JSONHandlerWithLevel returns a handler which prints records in JSON format that are less than or equal to +// the specified verbosity level. +func JSONHandlerWithLevel(wr io.Writer, level slog.Level) slog.Handler { + return slog.NewJSONHandler(wr, &slog.HandlerOptions{ + ReplaceAttr: builtinReplaceJSON, + Level: &leveler{level}, + }) +} + +// LogfmtHandler returns a handler which prints records in logfmt format, an easy machine-parseable but human-readable +// format for key/value pairs. +// +// For more details see: http://godoc.org/github.com/kr/logfmt +func LogfmtHandler(wr io.Writer) slog.Handler { + return slog.NewTextHandler(wr, &slog.HandlerOptions{ + ReplaceAttr: builtinReplaceLogfmt, + }) +} + +// LogfmtHandlerWithLevel returns the same handler as LogfmtHandler but it only outputs +// records which are less than or equal to the specified verbosity level. +func LogfmtHandlerWithLevel(wr io.Writer, level slog.Level) slog.Handler { + return slog.NewTextHandler(wr, &slog.HandlerOptions{ + ReplaceAttr: builtinReplaceLogfmt, + Level: &leveler{level}, + }) +} + +func builtinReplaceLogfmt(_ []string, attr slog.Attr) slog.Attr { + return builtinReplace(nil, attr, true) +} + +func builtinReplaceJSON(_ []string, attr slog.Attr) slog.Attr { + return builtinReplace(nil, attr, false) +} + +func builtinReplace(_ []string, attr slog.Attr, logfmt bool) slog.Attr { + switch attr.Key { + case slog.TimeKey: + if attr.Value.Kind() == slog.KindTime { + if logfmt { + return slog.String("t", attr.Value.Time().Format(timeFormat)) + } else { + return slog.Attr{Key: "t", Value: attr.Value} + } + } + case slog.LevelKey: + if l, ok := attr.Value.Any().(slog.Level); ok { + attr = slog.Any("lvl", LevelString(l)) + return attr + } + } + + switch v := attr.Value.Any().(type) { + case time.Time: + if logfmt { + attr = slog.String(attr.Key, v.Format(timeFormat)) + } + case *big.Int: + if v == nil { + attr.Value = slog.StringValue("") + } else { + attr.Value = slog.StringValue(v.String()) + } + case *uint256.Int: + if v == nil { + attr.Value = slog.StringValue("") + } else { + attr.Value = slog.StringValue(v.Dec()) + } + case fmt.Stringer: + if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) { + attr.Value = slog.StringValue("") + } else { + attr.Value = slog.StringValue(v.String()) + } + } + return attr +} diff --git a/client/log/logger.go b/client/log/logger.go new file mode 100644 index 00000000..016856c8 --- /dev/null +++ b/client/log/logger.go @@ -0,0 +1,216 @@ +package log + +import ( + "context" + "log/slog" + "math" + "os" + "runtime" + "time" +) + +const errorKey = "LOG_ERROR" + +const ( + legacyLevelCrit = iota + legacyLevelError + legacyLevelWarn + legacyLevelInfo + legacyLevelDebug + legacyLevelTrace +) + +const ( + levelMaxVerbosity slog.Level = math.MinInt + LevelTrace slog.Level = -8 + LevelDebug = slog.LevelDebug + LevelInfo = slog.LevelInfo + LevelWarn = slog.LevelWarn + LevelError = slog.LevelError + LevelCrit slog.Level = 12 + + // for backward-compatibility + LvlTrace = LevelTrace + LvlInfo = LevelInfo + LvlDebug = LevelDebug +) + +// FromLegacyLevel converts from old Geth verbosity level constants +// to levels defined by slog +func FromLegacyLevel(lvl int) slog.Level { + switch lvl { + case legacyLevelCrit: + return LevelCrit + case legacyLevelError: + return slog.LevelError + case legacyLevelWarn: + return slog.LevelWarn + case legacyLevelInfo: + return slog.LevelInfo + case legacyLevelDebug: + return slog.LevelDebug + case legacyLevelTrace: + return LevelTrace + default: + break + } + + // TODO: should we allow use of custom levels or force them to match existing max/min if they fall outside the range as I am doing here? + if lvl > legacyLevelTrace { + return LevelTrace + } + return LevelCrit +} + +// LevelAlignedString returns a 5-character string containing the name of a Lvl. +func LevelAlignedString(l slog.Level) string { + switch l { + case LevelTrace: + return "TRACE" + case slog.LevelDebug: + return "DEBUG" + case slog.LevelInfo: + return "INFO " + case slog.LevelWarn: + return "WARN " + case slog.LevelError: + return "ERROR" + case LevelCrit: + return "CRIT " + default: + return "unknown level" + } +} + +// LevelString returns a string containing the name of a Lvl. +func LevelString(l slog.Level) string { + switch l { + case LevelTrace: + return "trace" + case slog.LevelDebug: + return "debug" + case slog.LevelInfo: + return "info" + case slog.LevelWarn: + return "warn" + case slog.LevelError: + return "error" + case LevelCrit: + return "crit" + default: + return "unknown" + } +} + +// A Logger writes key/value pairs to a Handler +type Logger interface { + // With returns a new Logger that has this logger's attributes plus the given attributes + With(ctx ...interface{}) Logger + + // New returns a new Logger that has this logger's attributes plus the given attributes. Identical to 'With'. + New(ctx ...interface{}) Logger + + // Log logs a message at the specified level with context key/value pairs + Log(level slog.Level, msg string, ctx ...interface{}) + + // Trace log a message at the trace level with context key/value pairs + Trace(msg string, ctx ...interface{}) + + // Debug logs a message at the debug level with context key/value pairs + Debug(msg string, ctx ...interface{}) + + // Info logs a message at the info level with context key/value pairs + Info(msg string, ctx ...interface{}) + + // Warn logs a message at the warn level with context key/value pairs + Warn(msg string, ctx ...interface{}) + + // Error logs a message at the error level with context key/value pairs + Error(msg string, ctx ...interface{}) + + // Crit logs a message at the crit level with context key/value pairs, and exits + Crit(msg string, ctx ...interface{}) + + // Write logs a message at the specified level + Write(level slog.Level, msg string, attrs ...any) + + // Enabled reports whether l emits log records at the given context and level. + Enabled(ctx context.Context, level slog.Level) bool + + // Handler returns the underlying handler of the inner logger. + Handler() slog.Handler +} + +type logger struct { + inner *slog.Logger +} + +// NewLogger returns a logger with the specified handler set +func NewLogger(h slog.Handler) Logger { + return &logger{ + slog.New(h), + } +} + +func (l *logger) Handler() slog.Handler { + return l.inner.Handler() +} + +// Write logs a message at the specified level. +func (l *logger) Write(level slog.Level, msg string, attrs ...any) { + if !l.inner.Enabled(context.Background(), level) { + return + } + + var pcs [1]uintptr + runtime.Callers(3, pcs[:]) + + if len(attrs)%2 != 0 { + attrs = append(attrs, nil, errorKey, "Normalized odd number of arguments by adding nil") + } + r := slog.NewRecord(time.Now(), level, msg, pcs[0]) + r.Add(attrs...) + l.inner.Handler().Handle(context.Background(), r) +} + +func (l *logger) Log(level slog.Level, msg string, attrs ...any) { + l.Write(level, msg, attrs...) +} + +func (l *logger) With(ctx ...interface{}) Logger { + return &logger{l.inner.With(ctx...)} +} + +func (l *logger) New(ctx ...interface{}) Logger { + return l.With(ctx...) +} + +// Enabled reports whether l emits log records at the given context and level. +func (l *logger) Enabled(ctx context.Context, level slog.Level) bool { + return l.inner.Enabled(ctx, level) +} + +func (l *logger) Trace(msg string, ctx ...interface{}) { + l.Write(LevelTrace, msg, ctx...) +} + +func (l *logger) Debug(msg string, ctx ...interface{}) { + l.Write(slog.LevelDebug, msg, ctx...) +} + +func (l *logger) Info(msg string, ctx ...interface{}) { + l.Write(slog.LevelInfo, msg, ctx...) +} + +func (l *logger) Warn(msg string, ctx ...any) { + l.Write(slog.LevelWarn, msg, ctx...) +} + +func (l *logger) Error(msg string, ctx ...interface{}) { + l.Write(slog.LevelError, msg, ctx...) +} + +func (l *logger) Crit(msg string, ctx ...interface{}) { + l.Write(LevelCrit, msg, ctx...) + os.Exit(1) +} diff --git a/client/log/root.go b/client/log/root.go new file mode 100644 index 00000000..91209c46 --- /dev/null +++ b/client/log/root.go @@ -0,0 +1,115 @@ +package log + +import ( + "log/slog" + "os" + "sync/atomic" +) + +var root atomic.Value + +func init() { + root.Store(&logger{slog.New(DiscardHandler())}) +} + +// SetDefault sets the default global logger +func SetDefault(l Logger) { + root.Store(l) + if lg, ok := l.(*logger); ok { + slog.SetDefault(lg.inner) + } +} + +// Root returns the root logger +func Root() Logger { + return root.Load().(Logger) +} + +// The following functions bypass the exported logger methods (logger.Debug, +// etc.) to keep the call depth the same for all paths to logger.Write so +// runtime.Caller(2) always refers to the call site in client code. + +// Trace is a convenient alias for Root().Trace +// +// Log a message at the trace level with context key/value pairs +// +// # Usage +// +// log.Trace("msg") +// log.Trace("msg", "key1", val1) +// log.Trace("msg", "key1", val1, "key2", val2) +func Trace(msg string, ctx ...interface{}) { + Root().Write(LevelTrace, msg, ctx...) +} + +// Debug is a convenient alias for Root().Debug +// +// Log a message at the debug level with context key/value pairs +// +// # Usage Examples +// +// log.Debug("msg") +// log.Debug("msg", "key1", val1) +// log.Debug("msg", "key1", val1, "key2", val2) +func Debug(msg string, ctx ...interface{}) { + Root().Write(slog.LevelDebug, msg, ctx...) +} + +// Info is a convenient alias for Root().Info +// +// Log a message at the info level with context key/value pairs +// +// # Usage Examples +// +// log.Info("msg") +// log.Info("msg", "key1", val1) +// log.Info("msg", "key1", val1, "key2", val2) +func Info(msg string, ctx ...interface{}) { + Root().Write(slog.LevelInfo, msg, ctx...) +} + +// Warn is a convenient alias for Root().Warn +// +// Log a message at the warn level with context key/value pairs +// +// # Usage Examples +// +// log.Warn("msg") +// log.Warn("msg", "key1", val1) +// log.Warn("msg", "key1", val1, "key2", val2) +func Warn(msg string, ctx ...interface{}) { + Root().Write(slog.LevelWarn, msg, ctx...) +} + +// Error is a convenient alias for Root().Error +// +// Log a message at the error level with context key/value pairs +// +// # Usage Examples +// +// log.Error("msg") +// log.Error("msg", "key1", val1) +// log.Error("msg", "key1", val1, "key2", val2) +func Error(msg string, ctx ...interface{}) { + Root().Write(slog.LevelError, msg, ctx...) +} + +// Crit is a convenient alias for Root().Crit +// +// Log a message at the crit level with context key/value pairs, and then exit. +// +// # Usage Examples +// +// log.Crit("msg") +// log.Crit("msg", "key1", val1) +// log.Crit("msg", "key1", val1, "key2", val2) +func Crit(msg string, ctx ...interface{}) { + Root().Write(LevelCrit, msg, ctx...) + os.Exit(1) +} + +// New returns a new logger with the given context. +// New is a convenient alias for Root().New +func New(ctx ...interface{}) Logger { + return Root().With(ctx...) +} diff --git a/client/server.go b/client/server.go new file mode 100644 index 00000000..80b00966 --- /dev/null +++ b/client/server.go @@ -0,0 +1,271 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "errors" + "io" + "net" + "sync" + "sync/atomic" + + "github.com/NethermindEth/starknet.go/client/log" +) + +const MetadataApi = "rpc" +const EngineApi = "engine" + +// CodecOption specifies which type of messages a codec supports. +// +// Deprecated: this option is no longer honored by Server. +type CodecOption int + +const ( + // OptionMethodInvocation is an indication that the codec supports RPC method calls + OptionMethodInvocation CodecOption = 1 << iota + + // OptionSubscriptions is an indication that the codec supports RPC notifications + OptionSubscriptions = 1 << iota // support pub sub +) + +// Server is an RPC server. +type Server struct { + services serviceRegistry + idgen func() ID + + mutex sync.Mutex + codecs map[ServerCodec]struct{} + run atomic.Bool + batchItemLimit int + batchResponseLimit int + httpBodyLimit int +} + +// NewServer creates a new server instance with no registered handlers. +func NewServer() *Server { + server := &Server{ + idgen: randomIDGenerator(), + codecs: make(map[ServerCodec]struct{}), + httpBodyLimit: defaultBodyLimit, + } + server.run.Store(true) + // Register the default service providing meta information about the RPC service such + // as the services and methods it offers. + rpcService := &RPCService{server} + server.RegisterName(MetadataApi, rpcService) + return server +} + +// SetBatchLimits sets limits applied to batch requests. There are two limits: 'itemLimit' +// is the maximum number of items in a batch. 'maxResponseSize' is the maximum number of +// response bytes across all requests in a batch. +// +// This method should be called before processing any requests via ServeCodec, ServeHTTP, +// ServeListener etc. +func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) { + s.batchItemLimit = itemLimit + s.batchResponseLimit = maxResponseSize +} + +// SetHTTPBodyLimit sets the size limit for HTTP requests. +// +// This method should be called before processing any requests via ServeHTTP. +func (s *Server) SetHTTPBodyLimit(limit int) { + s.httpBodyLimit = limit +} + +// RegisterName creates a service for the given receiver type under the given name. When no +// methods on the given receiver match the criteria to be either an RPC method or a +// subscription an error is returned. Otherwise a new service is created and added to the +// service collection this server provides to clients. +func (s *Server) RegisterName(name string, receiver interface{}) error { + return s.services.registerName(name, receiver) +} + +// ServeListener accepts connections on l, serving JSON-RPC on them. +func (s *Server) ServeListener(l net.Listener) error { + for { + conn, err := l.Accept() + if isTemporaryError(err) { + log.Warn("RPC accept error", "err", err) + continue + } else if err != nil { + return err + } + log.Trace("Accepted RPC connection", "conn", conn.RemoteAddr()) + go s.ServeCodec(NewCodec(conn), 0) + } +} + +func isTemporaryError(err error) bool { + tempErr, ok := err.(interface { + Temporary() bool + }) + return ok && tempErr.Temporary() || false +} + +// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes +// the response back using the given codec. It will block until the codec is closed or the +// server is stopped. In either case the codec is closed. +// +// Note that codec options are no longer supported. +func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { + defer codec.close() + + if !s.trackCodec(codec) { + return + } + defer s.untrackCodec(codec) + + cfg := &clientConfig{ + idgen: s.idgen, + batchItemLimit: s.batchItemLimit, + batchResponseLimit: s.batchResponseLimit, + } + c := initClient(codec, &s.services, cfg) + <-codec.closed() + c.Close() +} + +func (s *Server) trackCodec(codec ServerCodec) bool { + s.mutex.Lock() + defer s.mutex.Unlock() + + if !s.run.Load() { + return false // Don't serve if server is stopped. + } + s.codecs[codec] = struct{}{} + return true +} + +func (s *Server) untrackCodec(codec ServerCodec) { + s.mutex.Lock() + defer s.mutex.Unlock() + + delete(s.codecs, codec) +} + +// serveSingleRequest reads and processes a single RPC request from the given codec. This +// is used to serve HTTP connections. Subscriptions and reverse calls are not allowed in +// this mode. +func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { + // Don't serve if server is stopped. + if !s.run.Load() { + return + } + + h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit) + h.allowSubscribe = false + defer h.close(io.EOF, nil) + + reqs, batch, err := codec.readBatch() + if err != nil { + if msg := messageForReadError(err); msg != "" { + resp := errorMessage(&invalidMessageError{msg}) + codec.writeJSON(ctx, resp, true) + } + return + } + if batch { + h.handleBatch(reqs) + } else { + h.handleMsg(reqs[0]) + } +} + +func messageForReadError(err error) string { + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + return "read timeout" + } else { + return "read error" + } + } else if err != io.EOF { + return "parse error" + } + return "" +} + +// Stop stops reading new requests, waits for stopPendingRequestTimeout to allow pending +// requests to finish, then closes all codecs which will cancel pending requests and +// subscriptions. +func (s *Server) Stop() { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.run.CompareAndSwap(true, false) { + log.Debug("RPC server shutting down") + for codec := range s.codecs { + codec.close() + } + } +} + +// RPCService gives meta information about the server. +// e.g. gives information about the loaded modules. +type RPCService struct { + server *Server +} + +// Modules returns the list of RPC services with their version number +func (s *RPCService) Modules() map[string]string { + s.server.services.mu.Lock() + defer s.server.services.mu.Unlock() + + modules := make(map[string]string) + for name := range s.server.services.services { + modules[name] = "1.0" + } + return modules +} + +// PeerInfo contains information about the remote end of the network connection. +// +// This is available within RPC method handlers through the context. Call +// PeerInfoFromContext to get information about the client connection related to +// the current method call. +type PeerInfo struct { + // Transport is name of the protocol used by the client. + // This can be "http", "ws" or "ipc". + Transport string + + // Address of client. This will usually contain the IP address and port. + RemoteAddr string + + // Additional information for HTTP and WebSocket connections. + HTTP struct { + // Protocol version, i.e. "HTTP/1.1". This is not set for WebSocket. + Version string + // Header values sent by the client. + UserAgent string + Origin string + Host string + } +} + +type peerInfoContextKey struct{} + +// PeerInfoFromContext returns information about the client's network connection. +// Use this with the context passed to RPC method handler functions. +// +// The zero value is returned if no connection info is present in ctx. +func PeerInfoFromContext(ctx context.Context) PeerInfo { + info, _ := ctx.Value(peerInfoContextKey{}).(PeerInfo) + return info +} diff --git a/client/server_test.go b/client/server_test.go new file mode 100644 index 00000000..d9ff9eb9 --- /dev/null +++ b/client/server_test.go @@ -0,0 +1,194 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "bufio" + "bytes" + "io" + "net" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestServerRegisterName(t *testing.T) { + server := NewServer() + service := new(testService) + + svcName := "test" + if err := server.RegisterName(svcName, service); err != nil { + t.Fatalf("%v", err) + } + + if len(server.services.services) != 2 { + t.Fatalf("Expected 2 service entries, got %d", len(server.services.services)) + } + + svc, ok := server.services.services[svcName] + if !ok { + t.Fatalf("Expected service %s to be registered", svcName) + } + + wantCallbacks := 14 + if len(svc.callbacks) != wantCallbacks { + t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks)) + } +} + +func TestServer(t *testing.T) { + files, err := os.ReadDir("testdata") + if err != nil { + t.Fatal("where'd my testdata go?") + } + for _, f := range files { + if f.IsDir() || strings.HasPrefix(f.Name(), ".") { + continue + } + path := filepath.Join("testdata", f.Name()) + name := strings.TrimSuffix(f.Name(), filepath.Ext(f.Name())) + t.Run(name, func(t *testing.T) { + runTestScript(t, path) + }) + } +} + +func runTestScript(t *testing.T, file string) { + server := newTestServer() + server.SetBatchLimits(4, 100000) + content, err := os.ReadFile(file) + if err != nil { + t.Fatal(err) + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + go server.ServeCodec(NewCodec(serverConn), 0) + readbuf := bufio.NewReader(clientConn) + for _, line := range strings.Split(string(content), "\n") { + line = strings.TrimSpace(line) + switch { + case len(line) == 0 || strings.HasPrefix(line, "//"): + // skip comments, blank lines + continue + case strings.HasPrefix(line, "--> "): + t.Log(line) + // write to connection + clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if _, err := io.WriteString(clientConn, line[4:]+"\n"); err != nil { + t.Fatalf("write error: %v", err) + } + case strings.HasPrefix(line, "<-- "): + t.Log(line) + want := line[4:] + // read line from connection and compare text + clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + sent, err := readbuf.ReadString('\n') + if err != nil { + t.Fatalf("read error: %v", err) + } + sent = strings.TrimRight(sent, "\r\n") + if sent != want { + t.Errorf("wrong line from server\ngot: %s\nwant: %s", sent, want) + } + default: + panic("invalid line in test script: " + line) + } + } +} + +// This test checks that responses are delivered for very short-lived connections that +// only carry a single request. +func TestServerShortLivedConn(t *testing.T) { + server := newTestServer() + defer server.Stop() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("can't listen:", err) + } + defer listener.Close() + go server.ServeListener(listener) + + var ( + request = `{"jsonrpc":"2.0","id":1,"method":"rpc_modules"}` + "\n" + wantResp = `{"jsonrpc":"2.0","id":1,"result":{"nftest":"1.0","rpc":"1.0","test":"1.0"}}` + "\n" + deadline = time.Now().Add(10 * time.Second) + ) + for i := 0; i < 20; i++ { + conn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + t.Fatal("can't dial:", err) + } + + conn.SetDeadline(deadline) + // Write the request, then half-close the connection so the server stops reading. + conn.Write([]byte(request)) + conn.(*net.TCPConn).CloseWrite() + // Now try to get the response. + buf := make([]byte, 2000) + n, err := conn.Read(buf) + conn.Close() + + if err != nil { + t.Fatal("read error:", err) + } + if !bytes.Equal(buf[:n], []byte(wantResp)) { + t.Fatalf("wrong response: %s", buf[:n]) + } + } +} + +func TestServerBatchResponseSizeLimit(t *testing.T) { + server := newTestServer() + defer server.Stop() + server.SetBatchLimits(100, 60) + var ( + batch []BatchElem + client = DialInProc(server) + ) + for i := 0; i < 5; i++ { + batch = append(batch, BatchElem{ + Method: "test_echo", + Args: []any{"x", 1}, + Result: new(echoResult), + }) + } + if err := client.BatchCall(batch); err != nil { + t.Fatal("error sending batch:", err) + } + for i := range batch { + // We expect the first two queries to be ok, but after that the size limit takes effect. + if i < 2 { + if batch[i].Error != nil { + t.Fatalf("batch elem %d has unexpected error: %v", i, batch[i].Error) + } + continue + } + // After two, we expect an error. + re, ok := batch[i].Error.(Error) + if !ok { + t.Fatalf("batch elem %d has wrong error: %v", i, batch[i].Error) + } + wantedCode := errcodeResponseTooLarge + if re.ErrorCode() != wantedCode { + t.Errorf("batch elem %d wrong error code, have %d want %d", i, re.ErrorCode(), wantedCode) + } + } +} diff --git a/client/service.go b/client/service.go new file mode 100644 index 00000000..65814ee3 --- /dev/null +++ b/client/service.go @@ -0,0 +1,249 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "fmt" + "reflect" + "runtime" + "strings" + "sync" + "unicode" + + "github.com/NethermindEth/starknet.go/client/log" +) + +var ( + contextType = reflect.TypeOf((*context.Context)(nil)).Elem() + errorType = reflect.TypeOf((*error)(nil)).Elem() + subscriptionType = reflect.TypeOf(Subscription{}) + stringType = reflect.TypeOf("") +) + +type serviceRegistry struct { + mu sync.Mutex + services map[string]service +} + +// service represents a registered object. +type service struct { + name string // name for service + callbacks map[string]*callback // registered handlers + subscriptions map[string]*callback // available subscriptions/notifications +} + +// callback is a method callback which was registered in the server +type callback struct { + fn reflect.Value // the function + rcvr reflect.Value // receiver object of method, set if fn is method + argTypes []reflect.Type // input argument types + hasCtx bool // method's first argument is a context (not included in argTypes) + errPos int // err return idx, of -1 when method cannot return error + isSubscribe bool // true if this is a subscription callback +} + +func (r *serviceRegistry) registerName(name string, rcvr interface{}) error { + rcvrVal := reflect.ValueOf(rcvr) + if name == "" { + return fmt.Errorf("no service name for type %s", rcvrVal.Type().String()) + } + callbacks := suitableCallbacks(rcvrVal) + if len(callbacks) == 0 { + return fmt.Errorf("service %T doesn't have any suitable methods/subscriptions to expose", rcvr) + } + + r.mu.Lock() + defer r.mu.Unlock() + if r.services == nil { + r.services = make(map[string]service) + } + svc, ok := r.services[name] + if !ok { + svc = service{ + name: name, + callbacks: make(map[string]*callback), + subscriptions: make(map[string]*callback), + } + r.services[name] = svc + } + for name, cb := range callbacks { + if cb.isSubscribe { + svc.subscriptions[name] = cb + } else { + svc.callbacks[name] = cb + } + } + return nil +} + +// callback returns the callback corresponding to the given RPC method name. +func (r *serviceRegistry) callback(method string) *callback { + before, after, found := strings.Cut(method, serviceMethodSeparator) + if !found { + return nil + } + r.mu.Lock() + defer r.mu.Unlock() + return r.services[before].callbacks[after] +} + +// subscription returns a subscription callback in the given service. +func (r *serviceRegistry) subscription(service, name string) *callback { + r.mu.Lock() + defer r.mu.Unlock() + return r.services[service].subscriptions[name] +} + +// suitableCallbacks iterates over the methods of the given type. It determines if a method +// satisfies the criteria for an RPC callback or a subscription callback and adds it to the +// collection of callbacks. See server documentation for a summary of these criteria. +func suitableCallbacks(receiver reflect.Value) map[string]*callback { + typ := receiver.Type() + callbacks := make(map[string]*callback) + for m := 0; m < typ.NumMethod(); m++ { + method := typ.Method(m) + if method.PkgPath != "" { + continue // method not exported + } + cb := newCallback(receiver, method.Func) + if cb == nil { + continue // function invalid + } + name := formatName(method.Name) + callbacks[name] = cb + } + return callbacks +} + +// newCallback turns fn (a function) into a callback object. It returns nil if the function +// is unsuitable as an RPC callback. +func newCallback(receiver, fn reflect.Value) *callback { + fntype := fn.Type() + c := &callback{fn: fn, rcvr: receiver, errPos: -1, isSubscribe: isPubSub(fntype)} + // Determine parameter types. They must all be exported or builtin types. + c.makeArgTypes() + + // Verify return types. The function must return at most one error + // and/or one other non-error value. + outs := make([]reflect.Type, fntype.NumOut()) + for i := 0; i < fntype.NumOut(); i++ { + outs[i] = fntype.Out(i) + } + if len(outs) > 2 { + return nil + } + // If an error is returned, it must be the last returned value. + switch { + case len(outs) == 1 && isErrorType(outs[0]): + c.errPos = 0 + case len(outs) == 2: + if isErrorType(outs[0]) || !isErrorType(outs[1]) { + return nil + } + c.errPos = 1 + } + return c +} + +// makeArgTypes composes the argTypes list. +func (c *callback) makeArgTypes() { + fntype := c.fn.Type() + // Skip receiver and context.Context parameter (if present). + firstArg := 0 + if c.rcvr.IsValid() { + firstArg++ + } + if fntype.NumIn() > firstArg && fntype.In(firstArg) == contextType { + c.hasCtx = true + firstArg++ + } + // Add all remaining parameters. + c.argTypes = make([]reflect.Type, fntype.NumIn()-firstArg) + for i := firstArg; i < fntype.NumIn(); i++ { + c.argTypes[i-firstArg] = fntype.In(i) + } +} + +// call invokes the callback. +func (c *callback) call(ctx context.Context, method string, args []reflect.Value) (res interface{}, errRes error) { + // Create the argument slice. + fullargs := make([]reflect.Value, 0, 2+len(args)) + if c.rcvr.IsValid() { + fullargs = append(fullargs, c.rcvr) + } + if c.hasCtx { + fullargs = append(fullargs, reflect.ValueOf(ctx)) + } + fullargs = append(fullargs, args...) + + // Catch panic while running the callback. + defer func() { + if err := recover(); err != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + log.Error("RPC method " + method + " crashed: " + fmt.Sprintf("%v\n%s", err, buf)) + errRes = &internalServerError{errcodePanic, "method handler crashed"} + } + }() + // Run the callback. + results := c.fn.Call(fullargs) + if len(results) == 0 { + return nil, nil + } + if c.errPos >= 0 && !results[c.errPos].IsNil() { + // Method has returned non-nil error value. + err := results[c.errPos].Interface().(error) + return reflect.Value{}, err + } + return results[0].Interface(), nil +} + +// Does t satisfy the error interface? +func isErrorType(t reflect.Type) bool { + return t.Implements(errorType) +} + +// Is t Subscription or *Subscription? +func isSubscriptionType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t == subscriptionType +} + +// isPubSub tests whether the given method's first argument is a context.Context and +// returns the pair (Subscription, error). +func isPubSub(methodType reflect.Type) bool { + // numIn(0) is the receiver type + if methodType.NumIn() < 2 || methodType.NumOut() != 2 { + return false + } + return methodType.In(1) == contextType && + isSubscriptionType(methodType.Out(0)) && + isErrorType(methodType.Out(1)) +} + +// formatName converts to first character of name to lowercase. +func formatName(name string) string { + ret := []rune(name) + if len(ret) > 0 { + ret[0] = unicode.ToLower(ret[0]) + } + return string(ret) +} diff --git a/client/subscription.go b/client/subscription.go new file mode 100644 index 00000000..db378be6 --- /dev/null +++ b/client/subscription.go @@ -0,0 +1,378 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "container/list" + "context" + crand "crypto/rand" + "encoding/binary" + "encoding/hex" + "encoding/json" + "errors" + "math/rand" + "reflect" + "strings" + "sync" + "time" +) + +var ( + // ErrNotificationsUnsupported is returned by the client when the connection doesn't + // support notifications. You can use this error value to check for subscription + // support like this: + // + // sub, err := client.EthSubscribe(ctx, channel, "newHeads", true) + // if errors.Is(err, rpc.ErrNotificationsUnsupported) { + // // Server does not support subscriptions, fall back to polling. + // } + // + ErrNotificationsUnsupported = notificationsUnsupportedError{} + + // ErrSubscriptionNotFound is returned when the notification for the given id is not found + ErrSubscriptionNotFound = errors.New("subscription not found") +) + +var globalGen = randomIDGenerator() + +// ID defines a pseudo random number that is used to identify RPC subscriptions. +type ID string + +// NewID returns a new, random ID. +func NewID() ID { + return globalGen() +} + +// randomIDGenerator returns a function generates a random IDs. +func randomIDGenerator() func() ID { + var buf = make([]byte, 8) + var seed int64 + if _, err := crand.Read(buf); err == nil { + seed = int64(binary.BigEndian.Uint64(buf)) + } else { + seed = int64(time.Now().Nanosecond()) + } + + var ( + mu sync.Mutex + rng = rand.New(rand.NewSource(seed)) + ) + return func() ID { + mu.Lock() + defer mu.Unlock() + id := make([]byte, 16) + rng.Read(id) + return encodeID(id) + } +} + +func encodeID(b []byte) ID { + id := hex.EncodeToString(b) + id = strings.TrimLeft(id, "0") + if id == "" { + id = "0" // ID's are RPC quantities, no leading zero's and 0 is 0x0. + } + return ID("0x" + id) +} + +type notifierKey struct{} + +// NotifierFromContext returns the Notifier value stored in ctx, if any. +func NotifierFromContext(ctx context.Context) (*Notifier, bool) { + n, ok := ctx.Value(notifierKey{}).(*Notifier) + return n, ok +} + +// Notifier is tied to an RPC connection that supports subscriptions. +// Server callbacks use the notifier to send notifications. +type Notifier struct { + h *handler + namespace string + + mu sync.Mutex + sub *Subscription + buffer []any + callReturned bool + activated bool +} + +// CreateSubscription returns a new subscription that is coupled to the +// RPC connection. By default subscriptions are inactive and notifications +// are dropped until the subscription is marked as active. This is done +// by the RPC server after the subscription ID is send to the client. +func (n *Notifier) CreateSubscription() *Subscription { + n.mu.Lock() + defer n.mu.Unlock() + + if n.sub != nil { + panic("can't create multiple subscriptions with Notifier") + } else if n.callReturned { + panic("can't create subscription after subscribe call has returned") + } + n.sub = &Subscription{ID: n.h.idgen(), namespace: n.namespace, err: make(chan error, 1)} + return n.sub +} + +// Notify sends a notification to the client with the given data as payload. +// If an error occurs the RPC connection is closed and the error is returned. +func (n *Notifier) Notify(id ID, data any) error { + n.mu.Lock() + defer n.mu.Unlock() + + if n.sub == nil { + panic("can't Notify before subscription is created") + } else if n.sub.ID != id { + panic("Notify with wrong ID") + } + if n.activated { + return n.send(n.sub, data) + } + n.buffer = append(n.buffer, data) + return nil +} + +// takeSubscription returns the subscription (if one has been created). No subscription can +// be created after this call. +func (n *Notifier) takeSubscription() *Subscription { + n.mu.Lock() + defer n.mu.Unlock() + n.callReturned = true + return n.sub +} + +// activate is called after the subscription ID was sent to client. Notifications are +// buffered before activation. This prevents notifications being sent to the client before +// the subscription ID is sent to the client. +func (n *Notifier) activate() error { + n.mu.Lock() + defer n.mu.Unlock() + + for _, data := range n.buffer { + if err := n.send(n.sub, data); err != nil { + return err + } + } + n.activated = true + return nil +} + +func (n *Notifier) send(sub *Subscription, data any) error { + msg := jsonrpcSubscriptionNotification{ + Version: vsn, + Method: n.namespace + notificationMethodSuffix, + Params: subscriptionResultEnc{ + ID: string(sub.ID), + Result: data, + }, + } + return n.h.conn.writeJSON(context.Background(), &msg, false) +} + +// A Subscription is created by a notifier and tied to that notifier. The client can use +// this subscription to wait for an unsubscribe request for the client, see Err(). +type Subscription struct { + ID ID + namespace string + err chan error // closed on unsubscribe +} + +// Err returns a channel that is closed when the client send an unsubscribe request. +func (s *Subscription) Err() <-chan error { + return s.err +} + +// MarshalJSON marshals a subscription as its ID. +func (s *Subscription) MarshalJSON() ([]byte, error) { + return json.Marshal(s.ID) +} + +// ClientSubscription is a subscription established through the Client's Subscribe or +// EthSubscribe methods. +type ClientSubscription struct { + client *Client + etype reflect.Type + channel reflect.Value + namespace string + subid string + + // The in channel receives notification values from client dispatcher. + in chan json.RawMessage + + // The error channel receives the error from the forwarding loop. + // It is closed by Unsubscribe. + err chan error + errOnce sync.Once + + // Closing of the subscription is requested by sending on 'quit'. This is handled by + // the forwarding loop, which closes 'forwardDone' when it has stopped sending to + // sub.channel. Finally, 'unsubDone' is closed after unsubscribing on the server side. + quit chan error + forwardDone chan struct{} + unsubDone chan struct{} +} + +// This is the sentinel value sent on sub.quit when Unsubscribe is called. +var errUnsubscribed = errors.New("unsubscribed") + +func newClientSubscription(c *Client, namespace string, channel reflect.Value) *ClientSubscription { + sub := &ClientSubscription{ + client: c, + namespace: namespace, + etype: channel.Type().Elem(), + channel: channel, + in: make(chan json.RawMessage), + quit: make(chan error), + forwardDone: make(chan struct{}), + unsubDone: make(chan struct{}), + err: make(chan error, 1), + } + return sub +} + +// Err returns the subscription error channel. The intended use of Err is to schedule +// resubscription when the client connection is closed unexpectedly. +// +// The error channel receives a value when the subscription has ended due to an error. The +// received error is nil if Close has been called on the underlying client and no other +// error has occurred. +// +// The error channel is closed when Unsubscribe is called on the subscription. +func (sub *ClientSubscription) Err() <-chan error { + return sub.err +} + +// Unsubscribe unsubscribes the notification and closes the error channel. +// It can safely be called more than once. +func (sub *ClientSubscription) Unsubscribe() { + sub.errOnce.Do(func() { + select { + case sub.quit <- errUnsubscribed: + <-sub.unsubDone + case <-sub.unsubDone: + } + close(sub.err) + }) +} + +// deliver is called by the client's message dispatcher to send a notification value. +func (sub *ClientSubscription) deliver(result json.RawMessage) (ok bool) { + select { + case sub.in <- result: + return true + case <-sub.forwardDone: + return false + } +} + +// close is called by the client's message dispatcher when the connection is closed. +func (sub *ClientSubscription) close(err error) { + select { + case sub.quit <- err: + case <-sub.forwardDone: + } +} + +// run is the forwarding loop of the subscription. It runs in its own goroutine and +// is launched by the client's handler after the subscription has been created. +func (sub *ClientSubscription) run() { + defer close(sub.unsubDone) + + unsubscribe, err := sub.forward() + + // The client's dispatch loop won't be able to execute the unsubscribe call if it is + // blocked in sub.deliver() or sub.close(). Closing forwardDone unblocks them. + close(sub.forwardDone) + + // Call the unsubscribe method on the server. + if unsubscribe { + sub.requestUnsubscribe() + } + + // Send the error. + if err != nil { + if err == ErrClientQuit { + // ErrClientQuit gets here when Client.Close is called. This is reported as a + // nil error because it's not an error, but we can't close sub.err here. + err = nil + } + sub.err <- err + } +} + +// forward is the forwarding loop. It takes in RPC notifications and sends them +// on the subscription channel. +func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { + cases := []reflect.SelectCase{ + {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.quit)}, + {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.in)}, + {Dir: reflect.SelectSend, Chan: sub.channel}, + } + buffer := list.New() + + for { + var chosen int + var recv reflect.Value + if buffer.Len() == 0 { + // Idle, omit send case. + chosen, recv, _ = reflect.Select(cases[:2]) + } else { + // Non-empty buffer, send the first queued item. + cases[2].Send = reflect.ValueOf(buffer.Front().Value) + chosen, recv, _ = reflect.Select(cases) + } + + switch chosen { + case 0: // <-sub.quit + if !recv.IsNil() { + err = recv.Interface().(error) + } + if err == errUnsubscribed { + // Exiting because Unsubscribe was called, unsubscribe on server. + return true, nil + } + return false, err + + case 1: // <-sub.in + val, err := sub.unmarshal(recv.Interface().(json.RawMessage)) + if err != nil { + return true, err + } + if buffer.Len() == maxClientSubscriptionBuffer { + return true, ErrSubscriptionQueueOverflow + } + buffer.PushBack(val) + + case 2: // sub.channel<- + cases[2].Send = reflect.Value{} // Don't hold onto the value. + buffer.Remove(buffer.Front()) + } + } +} + +func (sub *ClientSubscription) unmarshal(result json.RawMessage) (interface{}, error) { + val := reflect.New(sub.etype) + err := json.Unmarshal(result, val.Interface()) + return val.Elem().Interface(), err +} + +func (sub *ClientSubscription) requestUnsubscribe() error { + var result interface{} + ctx, cancel := context.WithTimeout(context.Background(), unsubscribeTimeout) + defer cancel() + err := sub.client.CallContext(ctx, &result, sub.namespace+unsubscribeMethodSuffix, sub.subid) + return err +} diff --git a/client/testservice_test.go b/client/testservice_test.go new file mode 100644 index 00000000..542aa7a6 --- /dev/null +++ b/client/testservice_test.go @@ -0,0 +1,229 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "encoding/binary" + "errors" + "strings" + "sync" + "time" +) + +func newTestServer() *Server { + server := NewServer() + server.idgen = sequentialIDGenerator() + if err := server.RegisterName("test", new(testService)); err != nil { + panic(err) + } + if err := server.RegisterName("nftest", new(notificationTestService)); err != nil { + panic(err) + } + return server +} + +func sequentialIDGenerator() func() ID { + var ( + mu sync.Mutex + counter uint64 + ) + return func() ID { + mu.Lock() + defer mu.Unlock() + counter++ + id := make([]byte, 8) + binary.BigEndian.PutUint64(id, counter) + return encodeID(id) + } +} + +type testService struct{} + +type echoArgs struct { + S string +} + +type echoResult struct { + String string + Int int + Args *echoArgs +} + +type testError struct{} + +func (testError) Error() string { return "testError" } +func (testError) ErrorCode() int { return 444 } +func (testError) ErrorData() interface{} { return "testError data" } + +type MarshalErrObj struct{} + +func (o *MarshalErrObj) MarshalText() ([]byte, error) { + return nil, errors.New("marshal error") +} + +func (s *testService) NoArgsRets() {} + +func (s *testService) Null() any { + return nil +} + +func (s *testService) Echo(str string, i int, args *echoArgs) echoResult { + return echoResult{str, i, args} +} + +func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *echoArgs) echoResult { + return echoResult{str, i, args} +} + +func (s *testService) Repeat(msg string, i int) string { + return strings.Repeat(msg, i) +} + +func (s *testService) PeerInfo(ctx context.Context) PeerInfo { + return PeerInfoFromContext(ctx) +} + +func (s *testService) Sleep(ctx context.Context, duration time.Duration) { + time.Sleep(duration) +} + +func (s *testService) Block(ctx context.Context) error { + <-ctx.Done() + return errors.New("context canceled in testservice_block") +} + +func (s *testService) Rets() (string, error) { + return "", nil +} + +//lint:ignore ST1008 returns error first on purpose. +func (s *testService) InvalidRets1() (error, string) { + return nil, "" +} + +func (s *testService) InvalidRets2() (string, string) { + return "", "" +} + +func (s *testService) InvalidRets3() (string, string, error) { + return "", "", nil +} + +func (s *testService) ReturnError() error { + return testError{} +} + +func (s *testService) MarshalError() *MarshalErrObj { + return &MarshalErrObj{} +} + +func (s *testService) Panic() string { + panic("service panic") +} + +func (s *testService) CallMeBack(ctx context.Context, method string, args []interface{}) (interface{}, error) { + c, ok := ClientFromContext(ctx) + if !ok { + return nil, errors.New("no client") + } + var result interface{} + err := c.Call(&result, method, args...) + return result, err +} + +func (s *testService) CallMeBackLater(ctx context.Context, method string, args []interface{}) error { + c, ok := ClientFromContext(ctx) + if !ok { + return errors.New("no client") + } + go func() { + <-ctx.Done() + var result interface{} + c.Call(&result, method, args...) + }() + return nil +} + +func (s *testService) Subscription(ctx context.Context) (*Subscription, error) { + return nil, nil +} + +type notificationTestService struct { + unsubscribed chan string + gotHangSubscriptionReq chan struct{} + unblockHangSubscription chan struct{} +} + +func (s *notificationTestService) Echo(i int) int { + return i +} + +func (s *notificationTestService) Unsubscribe(subid string) { + if s.unsubscribed != nil { + s.unsubscribed <- subid + } +} + +func (s *notificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) { + notifier, supported := NotifierFromContext(ctx) + if !supported { + return nil, ErrNotificationsUnsupported + } + + // By explicitly creating an subscription we make sure that the subscription id is send + // back to the client before the first subscription.Notify is called. Otherwise the + // events might be send before the response for the *_subscribe method. + subscription := notifier.CreateSubscription() + go func() { + for i := 0; i < n; i++ { + if err := notifier.Notify(subscription.ID, val+i); err != nil { + return + } + } + <-subscription.Err() + if s.unsubscribed != nil { + s.unsubscribed <- string(subscription.ID) + } + }() + return subscription, nil +} + +// HangSubscription blocks on s.unblockHangSubscription before sending anything. +func (s *notificationTestService) HangSubscription(ctx context.Context, val int) (*Subscription, error) { + notifier, supported := NotifierFromContext(ctx) + if !supported { + return nil, ErrNotificationsUnsupported + } + s.gotHangSubscriptionReq <- struct{}{} + <-s.unblockHangSubscription + subscription := notifier.CreateSubscription() + + go func() { + notifier.Notify(subscription.ID, val) + }() + return subscription, nil +} + +// largeRespService generates arbitrary-size JSON responses. +type largeRespService struct { + length int +} + +func (x largeRespService) LargeResp() string { + return strings.Repeat("x", x.length) +} diff --git a/client/types.go b/client/types.go new file mode 100644 index 00000000..5d835593 --- /dev/null +++ b/client/types.go @@ -0,0 +1,44 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" +) + +// ServerCodec implements reading, parsing and writing RPC messages for the server side of +// an RPC session. Implementations must be go-routine safe since the codec can be called in +// multiple go-routines concurrently. +type ServerCodec interface { + peerInfo() PeerInfo + readBatch() (msgs []*jsonrpcMessage, isBatch bool, err error) + close() + + jsonWriter +} + +// jsonWriter can write JSON messages to its underlying connection. +// Implementations must be safe for concurrent use. +type jsonWriter interface { + // writeJSON writes a message to the connection. + writeJSON(ctx context.Context, msg interface{}, isError bool) error + + // Closed returns a channel which is closed when the connection is closed. + closed() <-chan interface{} + // RemoteAddr returns the peer address of the connection. + remoteAddr() string +} diff --git a/client/websocket.go b/client/websocket.go new file mode 100644 index 00000000..ef2e36cc --- /dev/null +++ b/client/websocket.go @@ -0,0 +1,376 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/NethermindEth/starknet.go/client/log" + mapset "github.com/deckarep/golang-set/v2" + "github.com/gorilla/websocket" +) + +const ( + wsReadBuffer = 1024 + wsWriteBuffer = 1024 + wsPingInterval = 30 * time.Second + wsPingWriteTimeout = 5 * time.Second + wsPongTimeout = 30 * time.Second + wsDefaultReadLimit = 32 * 1024 * 1024 +) + +var wsBufferPool = new(sync.Pool) + +// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. +// +// allowedOrigins should be a comma-separated list of allowed origin URLs. +// To allow connections with any origin, pass "*". +func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { + var upgrader = websocket.Upgrader{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + CheckOrigin: wsHandshakeValidator(allowedOrigins), + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Debug("WebSocket upgrade failed", "err", err) + return + } + codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit) + s.ServeCodec(codec, 0) + }) +} + +// wsHandshakeValidator returns a handler that verifies the origin during the +// websocket upgrade process. When a '*' is specified as an allowed origins all +// connections are accepted. +func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { + origins := mapset.NewSet[string]() + allowAllOrigins := false + + for _, origin := range allowedOrigins { + if origin == "*" { + allowAllOrigins = true + } + if origin != "" { + origins.Add(origin) + } + } + // allow localhost if no allowedOrigins are specified. + if len(origins.ToSlice()) == 0 { + origins.Add("http://localhost") + if hostname, err := os.Hostname(); err == nil { + origins.Add("http://" + hostname) + } + } + log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) + + f := func(req *http.Request) bool { + // Skip origin verification if no Origin header is present. The origin check + // is supposed to protect against browser based attacks. Browsers always set + // Origin. Non-browser software can put anything in origin and checking it doesn't + // provide additional security. + if _, ok := req.Header["Origin"]; !ok { + return true + } + // Verify origin against allow list. + origin := strings.ToLower(req.Header.Get("Origin")) + if allowAllOrigins || originIsAllowed(origins, origin) { + return true + } + log.Warn("Rejected WebSocket connection", "origin", origin) + return false + } + + return f +} + +type wsHandshakeError struct { + err error + status string +} + +func (e wsHandshakeError) Error() string { + s := e.err.Error() + if e.status != "" { + s += " (HTTP status " + e.status + ")" + } + return s +} + +func (e wsHandshakeError) Unwrap() error { + return e.err +} + +func originIsAllowed(allowedOrigins mapset.Set[string], browserOrigin string) bool { + it := allowedOrigins.Iterator() + for origin := range it.C { + if ruleAllowsOrigin(origin, browserOrigin) { + return true + } + } + return false +} + +func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { + var ( + allowedScheme, allowedHostname, allowedPort string + browserScheme, browserHostname, browserPort string + err error + ) + allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) + if err != nil { + log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) + return false + } + browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) + if err != nil { + log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) + return false + } + if allowedScheme != "" && allowedScheme != browserScheme { + return false + } + if allowedHostname != "" && allowedHostname != browserHostname { + return false + } + if allowedPort != "" && allowedPort != browserPort { + return false + } + return true +} + +func parseOriginURL(origin string) (string, string, string, error) { + parsedURL, err := url.Parse(strings.ToLower(origin)) + if err != nil { + return "", "", "", err + } + var scheme, hostname, port string + if strings.Contains(origin, "://") { + scheme = parsedURL.Scheme + hostname = parsedURL.Hostname() + port = parsedURL.Port() + } else { + scheme = "" + hostname = parsedURL.Scheme + port = parsedURL.Opaque + if hostname == "" { + hostname = origin + } + } + return scheme, hostname, port, nil +} + +// DialWebsocketWithDialer creates a new RPC client using WebSocket. +// +// The context is used for the initial connection establishment. It does not +// affect subsequent interactions with the client. +// +// Deprecated: use DialOptions and the WithWebsocketDialer option. +func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { + cfg := new(clientConfig) + cfg.wsDialer = &dialer + if origin != "" { + cfg.setHeader("origin", origin) + } + connect, err := newClientTransportWS(endpoint, cfg) + if err != nil { + return nil, err + } + return newClient(ctx, cfg, connect) +} + +// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server +// that is listening on the given endpoint. +// +// The context is used for the initial connection establishment. It does not +// affect subsequent interactions with the client. +func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { + cfg := new(clientConfig) + if origin != "" { + cfg.setHeader("origin", origin) + } + connect, err := newClientTransportWS(endpoint, cfg) + if err != nil { + return nil, err + } + return newClient(ctx, cfg, connect) +} + +func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) { + dialer := cfg.wsDialer + if dialer == nil { + dialer = &websocket.Dialer{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + Proxy: http.ProxyFromEnvironment, + } + } + + dialURL, header, err := wsClientHeaders(endpoint, "") + if err != nil { + return nil, err + } + for key, values := range cfg.httpHeaders { + header[key] = values + } + + connect := func(ctx context.Context) (ServerCodec, error) { + header := header.Clone() + if cfg.httpAuth != nil { + if err := cfg.httpAuth(header); err != nil { + return nil, err + } + } + conn, resp, err := dialer.DialContext(ctx, dialURL, header) + if err != nil { + hErr := wsHandshakeError{err: err} + if resp != nil { + hErr.status = resp.Status + } + return nil, hErr + } + messageSizeLimit := int64(wsDefaultReadLimit) + if cfg.wsMessageSizeLimit != nil && *cfg.wsMessageSizeLimit >= 0 { + messageSizeLimit = *cfg.wsMessageSizeLimit + } + return newWebsocketCodec(conn, dialURL, header, messageSizeLimit), nil + } + return connect, nil +} + +func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { + endpointURL, err := url.Parse(endpoint) + if err != nil { + return endpoint, nil, err + } + header := make(http.Header) + if origin != "" { + header.Add("origin", origin) + } + if endpointURL.User != nil { + b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) + header.Add("authorization", "Basic "+b64auth) + endpointURL.User = nil + } + return endpointURL.String(), header, nil +} + +type websocketCodec struct { + *jsonCodec + conn *websocket.Conn + info PeerInfo + + wg sync.WaitGroup + pingReset chan struct{} + pongReceived chan struct{} +} + +func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, readLimit int64) ServerCodec { + conn.SetReadLimit(readLimit) + encode := func(v interface{}, isErrorResponse bool) error { + return conn.WriteJSON(v) + } + wc := &websocketCodec{ + jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), + conn: conn, + pingReset: make(chan struct{}, 1), + pongReceived: make(chan struct{}), + info: PeerInfo{ + Transport: "ws", + RemoteAddr: conn.RemoteAddr().String(), + }, + } + // Fill in connection details. + wc.info.HTTP.Host = host + wc.info.HTTP.Origin = req.Get("Origin") + wc.info.HTTP.UserAgent = req.Get("User-Agent") + // Start pinger. + conn.SetPongHandler(func(appData string) error { + select { + case wc.pongReceived <- struct{}{}: + case <-wc.closed(): + } + return nil + }) + wc.wg.Add(1) + go wc.pingLoop() + return wc +} + +func (wc *websocketCodec) close() { + wc.jsonCodec.close() + wc.wg.Wait() +} + +func (wc *websocketCodec) peerInfo() PeerInfo { + return wc.info +} + +func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error { + err := wc.jsonCodec.writeJSON(ctx, v, isError) + if err == nil { + // Notify pingLoop to delay the next idle ping. + select { + case wc.pingReset <- struct{}{}: + default: + } + } + return err +} + +// pingLoop sends periodic ping frames when the connection is idle. +func (wc *websocketCodec) pingLoop() { + var pingTimer = time.NewTimer(wsPingInterval) + defer wc.wg.Done() + defer pingTimer.Stop() + + for { + select { + case <-wc.closed(): + return + + case <-wc.pingReset: + if !pingTimer.Stop() { + <-pingTimer.C + } + pingTimer.Reset(wsPingInterval) + + case <-pingTimer.C: + wc.jsonCodec.encMu.Lock() + wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) + wc.conn.WriteMessage(websocket.PingMessage, nil) + wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) + wc.jsonCodec.encMu.Unlock() + pingTimer.Reset(wsPingInterval) + + case <-wc.pongReceived: + wc.conn.SetReadDeadline(time.Time{}) + } + } +} diff --git a/go.mod b/go.mod index cdfaca71..ffb56e0c 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.23.1 require ( github.com/NethermindEth/juno v0.12.2 - github.com/ethereum/go-ethereum v1.14.8 github.com/gorilla/websocket v1.5.3 github.com/joho/godotenv v1.4.0 github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249 @@ -16,23 +15,16 @@ require ( ) require ( - github.com/Microsoft/go-winio v0.6.2 // indirect github.com/bits-and-blooms/bitset v1.14.2 // indirect github.com/consensys/bavard v0.1.13 // indirect github.com/consensys/gnark-crypto v0.13.0 // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/deckarep/golang-set/v2 v2.6.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc + github.com/deckarep/golang-set/v2 v2.6.0 github.com/fxamacker/cbor/v2 v2.7.0 // indirect - github.com/go-ole/go-ole v1.3.0 // indirect - github.com/holiman/uint256 v1.3.1 // indirect + github.com/holiman/uint256 v1.3.1 github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/shirou/gopsutil v3.21.11+incompatible // indirect - github.com/tklauser/go-sysconf v0.3.13 // indirect - github.com/tklauser/numcpus v0.7.0 // indirect github.com/x448/float16 v0.8.4 // indirect - github.com/yusufpapurcu/wmi v1.2.4 // indirect - golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) diff --git a/go.sum b/go.sum index ab3c7895..cb47ffd6 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,11 @@ github.com/DataDog/zstd v1.5.6-0.20230824185856-869dae002e5e h1:ZIWapoIRN1VqT8GR8jAwb1Ie9GyehWjVcGh32Y2MznE= github.com/DataDog/zstd v1.5.6-0.20230824185856-869dae002e5e/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= -github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= -github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/NethermindEth/juno v0.12.2 h1:GnQUgqMCA93EO7KROlXI5AdXQ9IBvcDP2PEYFr3fQIY= github.com/NethermindEth/juno v0.12.2/go.mod h1:PlxXUUGgzFVRIiSDrLo/jrEtAPUIHpFAylChtoH3wK4= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bits-and-blooms/bitset v1.14.2 h1:YXVoyPndbdvcEVcseEovVfp0qjJp7S+i5+xgp/Nfbdc= github.com/bits-and-blooms/bitset v1.14.2/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= -github.com/btcsuite/btcd/btcec/v2 v2.3.4 h1:3EJjcN70HCu/mwqlUsGK8GcNVyLVxFDlWurTXGPFfiQ= -github.com/btcsuite/btcd/btcec/v2 v2.3.4/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cockroachdb/errors v1.11.3 h1:5bA+k2Y6r+oz/6Z/RFlNeVCesGARKuC6YymtcDrbC/I= @@ -28,25 +24,16 @@ github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/Yj github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/consensys/gnark-crypto v0.13.0 h1:VPULb/v6bbYELAPTDFINEVaMTTybV5GLxDdcjnS+4oc= github.com/consensys/gnark-crypto v0.13.0/go.mod h1:wKqwsieaKPThcFkHe0d0zMsbHEUWFmZcG7KBCse210o= -github.com/crate-crypto/go-kzg-4844 v1.1.0 h1:EN/u9k2TF6OWSHrCCDBBU6GLNMq88OspHHlMnHfoyU4= -github.com/crate-crypto/go-kzg-4844 v1.1.0/go.mod h1:JolLjpSff1tCCJKaJx4psrlEdlXuJEC996PL3tTAFks= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/deckarep/golang-set/v2 v2.6.0 h1:XfcQbWM1LlMB8BsJ8N9vW5ehnnPVIw0je80NsVHagjM= github.com/deckarep/golang-set/v2 v2.6.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= -github.com/ethereum/c-kzg-4844 v1.0.0 h1:0X1LBXxaEtYD9xsyj9B9ctQEZIpnvVDeoBx8aHEwTNA= -github.com/ethereum/c-kzg-4844 v1.0.0/go.mod h1:VewdlzQmpT5QSrVhbBuGoCdFJkpaJlO1aQputP83wc0= github.com/ethereum/go-ethereum v1.14.8 h1:NgOWvXS+lauK+zFukEvi85UmmsS/OkV0N23UZ1VTIig= github.com/ethereum/go-ethereum v1.14.8/go.mod h1:TJhyuDq0JDppAkFXgqjwpdlQApywnu/m10kFPxh8vvs= github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/getsentry/sentry-go v0.27.0 h1:Pv98CIbtB3LkMWmXi4Joa5OOcwbmnX88sF5qbK3r3Ps= github.com/getsentry/sentry-go v0.27.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= -github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= -github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= -github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk= @@ -87,22 +74,12 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= -github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/supranational/blst v0.3.11 h1:LyU6FolezeWAhvQk0k6O/d49jqgO52MSDDfYgbeoEm4= -github.com/supranational/blst v0.3.11/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw= -github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4= -github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0= -github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4= -github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= -github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= @@ -115,10 +92,6 @@ golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDT golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= diff --git a/rpc/client.go b/rpc/client.go index 4e3e0ec0..ed317dc9 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" - ethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/NethermindEth/starknet.go/client" ) type callCloser interface { @@ -12,6 +12,11 @@ type callCloser interface { Close() } +type wsConn interface { + callCloser + Subscribe(ctx context.Context, namespace string, channel interface{}, args ...interface{}) (*client.ClientSubscription, error) +} + // do is a function that performs a remote procedure call (RPC) using the provided callCloser. // // Parameters: @@ -44,6 +49,6 @@ func do(ctx context.Context, call callCloser, method string, data interface{}, a // Returns: // - *ethrpc.Client: a new ethrpc.Client // - error: an error if any occurred -func NewClient(url string) (*ethrpc.Client, error) { - return ethrpc.DialContext(context.Background(), url) +func NewClient(url string) (*client.Client, error) { + return client.DialContext(context.Background(), url) } diff --git a/rpc/provider.go b/rpc/provider.go index 45e0b666..9e2a7603 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -7,7 +7,7 @@ import ( "net/http/cookiejar" "github.com/NethermindEth/juno/core/felt" - ethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/NethermindEth/starknet.go/client" "github.com/gorilla/websocket" "golang.org/x/net/publicsuffix" ) @@ -23,16 +23,21 @@ type Provider struct { chainID string } +// WsProvider provides the provider for websocket starknet.go/rpc implementation. +type WsProvider struct { + c wsConn +} + // NewProvider creates a new HTTP rpc Provider instance. -func NewProvider(url string, options ...ethrpc.ClientOption) (*Provider, error) { +func NewProvider(url string, options ...client.ClientOption) (*Provider, error) { jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) if err != nil { return nil, err } - client := &http.Client{Jar: jar} + httpClient := &http.Client{Jar: jar} // prepend the custom client to allow users to override - options = append([]ethrpc.ClientOption{ethrpc.WithHTTPClient(client)}, options...) - c, err := ethrpc.DialOptions(context.Background(), url, options...) + options = append([]client.ClientOption{client.WithHTTPClient(httpClient)}, options...) + c, err := client.DialOptions(context.Background(), url, options...) if err != nil { return nil, err @@ -42,7 +47,7 @@ func NewProvider(url string, options ...ethrpc.ClientOption) (*Provider, error) } // NewWebsocketProvider creates a new Websocket rpc Provider instance. -func NewWebsocketProvider(url string, options ...ethrpc.ClientOption) (*Provider, error) { +func NewWebsocketProvider(url string, options ...client.ClientOption) (*WsProvider, error) { jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) if err != nil { return nil, err @@ -50,14 +55,14 @@ func NewWebsocketProvider(url string, options ...ethrpc.ClientOption) (*Provider dialer := websocket.Dialer{Jar: jar} // prepend the custom client to allow users to override - options = append([]ethrpc.ClientOption{ethrpc.WithWebsocketDialer(dialer)}, options...) - c, err := ethrpc.DialOptions(context.Background(), url, options...) + options = append([]client.ClientOption{client.WithWebsocketDialer(dialer)}, options...) + c, err := client.DialOptions(context.Background(), url, options...) if err != nil { return nil, err } - return &Provider{c: c}, nil + return &WsProvider{c: c}, nil } //go:generate mockgen -destination=../mocks/mock_rpc_provider.go -package=mocks -source=provider.go api diff --git a/rpc/websocket.go b/rpc/websocket.go new file mode 100644 index 00000000..b650a200 --- /dev/null +++ b/rpc/websocket.go @@ -0,0 +1,19 @@ +package rpc + +import "context" + +// New block headers subscription. +// Creates a WebSocket stream which will fire events for new block headers +// +// Parameters: +// - ctx: The context.Context object for controlling the function call +// - blockID: The ID of the block to retrieve the transactions from +// Returns: +// - subscriptionId: The subscription ID +// - error: An error, if any +func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, blockID BlockID) (subscriptionId int, err error) { + if err = do(ctx, provider.c, "starknet_subscribeNewHeads", &subscriptionId, blockID); err != nil { + return 0, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) + } + return subscriptionId, nil +} From 6d54e5a93aae37fcdc33b861f68c96d6e0645b66 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Wed, 8 Jan 2025 01:32:31 -0300 Subject: [PATCH 05/35] adds remaining test data and improves tests, and adapt subscribe method to work on starknet --- client/client.go | 13 ++++--------- client/client_test.go | 14 +++++++------- client/testdata/internal-error.js | 7 +++++++ client/testdata/invalid-badid.js | 7 +++++++ client/testdata/invalid-badversion.js | 19 +++++++++++++++++++ client/testdata/invalid-batch-toolarge.js | 13 +++++++++++++ client/testdata/invalid-batch.js | 17 +++++++++++++++++ client/testdata/invalid-idonly.js | 7 +++++++ client/testdata/invalid-nonobj.js | 7 +++++++ client/testdata/invalid-syntax.json | 5 +++++ client/testdata/reqresp-batch.js | 8 ++++++++ client/testdata/reqresp-echo.js | 16 ++++++++++++++++ client/testdata/reqresp-namedparam.js | 5 +++++ client/testdata/reqresp-noargsrets.js | 4 ++++ client/testdata/reqresp-nomethod.js | 4 ++++ client/testdata/reqresp-noparam.js | 4 ++++ client/testdata/reqresp-paramsnull.js | 4 ++++ client/testdata/revcall.js | 6 ++++++ client/testdata/revcall2.js | 7 +++++++ client/testdata/subscription.js | 12 ++++++++++++ rpc/client.go | 2 +- 21 files changed, 164 insertions(+), 17 deletions(-) create mode 100644 client/testdata/internal-error.js create mode 100644 client/testdata/invalid-badid.js create mode 100644 client/testdata/invalid-badversion.js create mode 100644 client/testdata/invalid-batch-toolarge.js create mode 100644 client/testdata/invalid-batch.js create mode 100644 client/testdata/invalid-idonly.js create mode 100644 client/testdata/invalid-nonobj.js create mode 100644 client/testdata/invalid-syntax.json create mode 100644 client/testdata/reqresp-batch.js create mode 100644 client/testdata/reqresp-echo.js create mode 100644 client/testdata/reqresp-namedparam.js create mode 100644 client/testdata/reqresp-noargsrets.js create mode 100644 client/testdata/reqresp-nomethod.js create mode 100644 client/testdata/reqresp-noparam.js create mode 100644 client/testdata/reqresp-paramsnull.js create mode 100644 client/testdata/revcall.js create mode 100644 client/testdata/revcall2.js create mode 100644 client/testdata/subscription.js diff --git a/client/client.go b/client/client.go index 3988f542..b4018b85 100644 --- a/client/client.go +++ b/client/client.go @@ -478,14 +478,9 @@ func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) } // EthSubscribe registers a subscription under the "eth" namespace. +// Note: this was kept for compatibility with the ethereum client tests func (c *Client) EthSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { - return c.Subscribe(ctx, "eth", channel, args...) -} - -// ShhSubscribe registers a subscription under the "shh" namespace. -// Deprecated: use Subscribe(ctx, "shh", ...). -func (c *Client) ShhSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { - return c.Subscribe(ctx, "shh", channel, args...) + return c.Subscribe(ctx, "eth", subscribeMethodSuffix, channel, args...) } // Subscribe calls the "_subscribe" method with the given arguments, @@ -500,7 +495,7 @@ func (c *Client) ShhSubscribe(ctx context.Context, channel interface{}, args ... // before considering the subscriber dead. The subscription Err channel will receive // ErrSubscriptionQueueOverflow. Use a sufficiently large buffer on the channel or ensure // that the channel usually has at least one reader to prevent this issue. -func (c *Client) Subscribe(ctx context.Context, namespace string, channel interface{}, args ...interface{}) (*ClientSubscription, error) { +func (c *Client) Subscribe(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args ...interface{}) (*ClientSubscription, error) { // Check type of channel first. chanVal := reflect.ValueOf(channel) if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { @@ -513,7 +508,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf return nil, ErrNotificationsUnsupported } - msg, err := c.newMessage(namespace+subscribeMethodSuffix, args...) + msg, err := c.newMessage(namespace+methodSuffix, args...) if err != nil { return nil, err } diff --git a/client/client_test.go b/client/client_test.go index 45a722bf..3a580517 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -447,7 +447,7 @@ func TestClientSubscribe(t *testing.T) { nc := make(chan int) count := 10 - sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", count, 0) + sub, err := client.Subscribe(context.Background(), "nftest", subscribeMethodSuffix, nc, "someSubscription", count, 0) if err != nil { t.Fatal("can't subscribe:", err) } @@ -494,7 +494,7 @@ func TestClientSubscribeClose(t *testing.T) { err error ) go func() { - sub, err = client.Subscribe(context.Background(), "nftest2", nc, "hangSubscription", 999) + sub, err = client.Subscribe(context.Background(), "nftest2", subscribeMethodSuffix, nc, "hangSubscription", 999) errc <- err }() @@ -526,7 +526,7 @@ func TestClientCloseUnsubscribeRace(t *testing.T) { for i := 0; i < 20; i++ { client := DialInProc(server) nc := make(chan int) - sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", 3, 1) + sub, err := client.Subscribe(context.Background(), "nftest", subscribeMethodSuffix, nc, "someSubscription", 3, 1) if err != nil { t.Fatal(err) } @@ -584,7 +584,7 @@ func TestUnsubscribeTimeout(t *testing.T) { defer client.Close() // Start subscription. - sub, err := client.Subscribe(context.Background(), "nftest", make(chan int), "someSubscription", 1, 1) + sub, err := client.Subscribe(context.Background(), "nftest", subscribeMethodSuffix, make(chan int), "someSubscription", 1, 1) if err != nil { t.Fatalf("failed to subscribe: %v", err) } @@ -652,7 +652,7 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) { // Create the subscription. ch := make(chan int) - sub, err := client.Subscribe(context.Background(), "nftest", ch, "someSubscription", 1, 1) + sub, err := client.Subscribe(context.Background(), "nftest", subscribeMethodSuffix, ch, "someSubscription", 1, 1) if err != nil { t.Fatal(err) } @@ -686,7 +686,7 @@ func TestClientSubscriptionChannelClose(t *testing.T) { for i := 0; i < 100; i++ { ch := make(chan int, 100) - sub, err := client.Subscribe(context.Background(), "nftest", ch, "someSubscription", 100, 1) + sub, err := client.Subscribe(context.Background(), "nftest", subscribeMethodSuffix, ch, "someSubscription", 100, 1) if err != nil { t.Fatal(err) } @@ -712,7 +712,7 @@ func TestClientNotificationStorm(t *testing.T) { // Subscribe on the server. It will start sending many notifications // very quickly. nc := make(chan int) - sub, err := client.Subscribe(ctx, "nftest", nc, "someSubscription", count, 0) + sub, err := client.Subscribe(ctx, "nftest", subscribeMethodSuffix, nc, "someSubscription", count, 0) if err != nil { t.Fatal("can't subscribe:", err) } diff --git a/client/testdata/internal-error.js b/client/testdata/internal-error.js new file mode 100644 index 00000000..c10ac475 --- /dev/null +++ b/client/testdata/internal-error.js @@ -0,0 +1,7 @@ +// These tests trigger various 'internal error' conditions. + +--> {"jsonrpc":"2.0","id":1,"method":"test_marshalError","params": []} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32603,"message":"json: error calling MarshalText for type *client.MarshalErrObj: marshal error"}} + +--> {"jsonrpc":"2.0","id":2,"method":"test_panic","params": []} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32603,"message":"method handler crashed"}} diff --git a/client/testdata/invalid-badid.js b/client/testdata/invalid-badid.js new file mode 100644 index 00000000..2202b8cc --- /dev/null +++ b/client/testdata/invalid-badid.js @@ -0,0 +1,7 @@ +// This test checks processing of messages with invalid ID. + +--> {"id":[],"method":"test_foo"} +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} + +--> {"id":{},"method":"test_foo"} +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} diff --git a/client/testdata/invalid-badversion.js b/client/testdata/invalid-badversion.js new file mode 100644 index 00000000..75b5291d --- /dev/null +++ b/client/testdata/invalid-badversion.js @@ -0,0 +1,19 @@ +// This test checks processing of messages with invalid Version. + +--> {"jsonrpc":"2.0","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"result":{"String":"x","Int":3,"Args":null}} + +--> {"jsonrpc":"2.1","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":"go-ethereum","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":1,"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":2.0,"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} diff --git a/client/testdata/invalid-batch-toolarge.js b/client/testdata/invalid-batch-toolarge.js new file mode 100644 index 00000000..218fea58 --- /dev/null +++ b/client/testdata/invalid-batch-toolarge.js @@ -0,0 +1,13 @@ +// This file checks the behavior of the batch item limit code. +// In tests, the batch item limit is set to 4. So to trigger the error, +// all batches in this file have 5 elements. + +// For batches that do not contain any calls, a response message with "id" == null +// is returned. + +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"batch too large"}}] + +// For batches with at least one call, the call's "id" is used. +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","id":3,"method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] +<-- [{"jsonrpc":"2.0","id":3,"error":{"code":-32600,"message":"batch too large"}}] diff --git a/client/testdata/invalid-batch.js b/client/testdata/invalid-batch.js new file mode 100644 index 00000000..768dbc83 --- /dev/null +++ b/client/testdata/invalid-batch.js @@ -0,0 +1,17 @@ +// This test checks the behavior of batches with invalid elements. +// Empty batches are not allowed. Batches may contain junk. + +--> [] +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"empty batch"}} + +--> [1] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] + +--> [1,2,3] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] + +--> [null] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] + +--> [{"jsonrpc":"2.0","id":1,"method":"test_echo","params":["foo",1]},55,{"jsonrpc":"2.0","id":2,"method":"unknown_method"},{"foo":"bar"}] +<-- [{"jsonrpc":"2.0","id":1,"result":{"String":"foo","Int":1,"Args":null}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":2,"error":{"code":-32601,"message":"the method unknown_method does not exist/is not available"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] diff --git a/client/testdata/invalid-idonly.js b/client/testdata/invalid-idonly.js new file mode 100644 index 00000000..79997bee --- /dev/null +++ b/client/testdata/invalid-idonly.js @@ -0,0 +1,7 @@ +// This test checks processing of messages that contain just the ID and nothing else. + +--> {"id":1} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":"2.0","id":1} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} diff --git a/client/testdata/invalid-nonobj.js b/client/testdata/invalid-nonobj.js new file mode 100644 index 00000000..ffdd4a5b --- /dev/null +++ b/client/testdata/invalid-nonobj.js @@ -0,0 +1,7 @@ +// This test checks behavior for invalid requests. + +--> 1 +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} + +--> null +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} diff --git a/client/testdata/invalid-syntax.json b/client/testdata/invalid-syntax.json new file mode 100644 index 00000000..b1942996 --- /dev/null +++ b/client/testdata/invalid-syntax.json @@ -0,0 +1,5 @@ +// This test checks that an error is written for invalid JSON requests. + +--> 'f +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"invalid character '\\'' looking for beginning of value"}} + diff --git a/client/testdata/reqresp-batch.js b/client/testdata/reqresp-batch.js new file mode 100644 index 00000000..977af766 --- /dev/null +++ b/client/testdata/reqresp-batch.js @@ -0,0 +1,8 @@ +// There is no response for all-notification batches. + +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] + +// This test checks regular batch calls. + +--> [{"jsonrpc":"2.0","id":2,"method":"test_echo","params":[]}, {"jsonrpc":"2.0","id": 3,"method":"test_echo","params":["x",3]}] +<-- [{"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 0"}},{"jsonrpc":"2.0","id":3,"result":{"String":"x","Int":3,"Args":null}}] diff --git a/client/testdata/reqresp-echo.js b/client/testdata/reqresp-echo.js new file mode 100644 index 00000000..7a9e9032 --- /dev/null +++ b/client/testdata/reqresp-echo.js @@ -0,0 +1,16 @@ +// This test calls the test_echo method. + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": []} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 0"}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x"]} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 1"}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x", 3]} +<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":null}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x", 3, {"S": "foo"}]} +<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":{"S":"foo"}}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echoWithCtx", "params": ["x", 3, {"S": "foo"}]} +<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":{"S":"foo"}}} diff --git a/client/testdata/reqresp-namedparam.js b/client/testdata/reqresp-namedparam.js new file mode 100644 index 00000000..9a9372b0 --- /dev/null +++ b/client/testdata/reqresp-namedparam.js @@ -0,0 +1,5 @@ +// This test checks that an error response is sent for calls +// with named parameters. + +--> {"jsonrpc":"2.0","method":"test_echo","params":{"int":23},"id":3} +<-- {"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"non-array args"}} diff --git a/client/testdata/reqresp-noargsrets.js b/client/testdata/reqresp-noargsrets.js new file mode 100644 index 00000000..e61cc708 --- /dev/null +++ b/client/testdata/reqresp-noargsrets.js @@ -0,0 +1,4 @@ +// This test calls the test_noArgsRets method. + +--> {"jsonrpc": "2.0", "id": "foo", "method": "test_noArgsRets", "params": []} +<-- {"jsonrpc":"2.0","id":"foo","result":null} diff --git a/client/testdata/reqresp-nomethod.js b/client/testdata/reqresp-nomethod.js new file mode 100644 index 00000000..58ea6f30 --- /dev/null +++ b/client/testdata/reqresp-nomethod.js @@ -0,0 +1,4 @@ +// This test calls a method that doesn't exist. + +--> {"jsonrpc": "2.0", "id": 2, "method": "invalid_method", "params": [2, 3]} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32601,"message":"the method invalid_method does not exist/is not available"}} diff --git a/client/testdata/reqresp-noparam.js b/client/testdata/reqresp-noparam.js new file mode 100644 index 00000000..2edf486d --- /dev/null +++ b/client/testdata/reqresp-noparam.js @@ -0,0 +1,4 @@ +// This test checks that calls with no parameters work. + +--> {"jsonrpc":"2.0","method":"test_noArgsRets","id":3} +<-- {"jsonrpc":"2.0","id":3,"result":null} diff --git a/client/testdata/reqresp-paramsnull.js b/client/testdata/reqresp-paramsnull.js new file mode 100644 index 00000000..8a01bae1 --- /dev/null +++ b/client/testdata/reqresp-paramsnull.js @@ -0,0 +1,4 @@ +// This test checks that calls with "params":null work. + +--> {"jsonrpc":"2.0","method":"test_noArgsRets","params":null,"id":3} +<-- {"jsonrpc":"2.0","id":3,"result":null} diff --git a/client/testdata/revcall.js b/client/testdata/revcall.js new file mode 100644 index 00000000..695d9858 --- /dev/null +++ b/client/testdata/revcall.js @@ -0,0 +1,6 @@ +// This test checks reverse calls. + +--> {"jsonrpc":"2.0","id":2,"method":"test_callMeBack","params":["foo",[1]]} +<-- {"jsonrpc":"2.0","id":1,"method":"foo","params":[1]} +--> {"jsonrpc":"2.0","id":1,"result":"my result"} +<-- {"jsonrpc":"2.0","id":2,"result":"my result"} diff --git a/client/testdata/revcall2.js b/client/testdata/revcall2.js new file mode 100644 index 00000000..acab4655 --- /dev/null +++ b/client/testdata/revcall2.js @@ -0,0 +1,7 @@ +// This test checks reverse calls. + +--> {"jsonrpc":"2.0","id":2,"method":"test_callMeBackLater","params":["foo",[1]]} +<-- {"jsonrpc":"2.0","id":2,"result":null} +<-- {"jsonrpc":"2.0","id":1,"method":"foo","params":[1]} +--> {"jsonrpc":"2.0","id":1,"result":"my result"} + diff --git a/client/testdata/subscription.js b/client/testdata/subscription.js new file mode 100644 index 00000000..9f100730 --- /dev/null +++ b/client/testdata/subscription.js @@ -0,0 +1,12 @@ +// This test checks basic subscription support. + +--> {"jsonrpc":"2.0","id":1,"method":"nftest_subscribe","params":["someSubscription",5,1]} +<-- {"jsonrpc":"2.0","id":1,"result":"0x1"} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":1}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":2}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":3}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":4}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":5}} + +--> {"jsonrpc":"2.0","id":2,"method":"nftest_echo","params":[11]} +<-- {"jsonrpc":"2.0","id":2,"result":11} diff --git a/rpc/client.go b/rpc/client.go index ed317dc9..868fa80f 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -14,7 +14,7 @@ type callCloser interface { type wsConn interface { callCloser - Subscribe(ctx context.Context, namespace string, channel interface{}, args ...interface{}) (*client.ClientSubscription, error) + Subscribe(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args ...interface{}) (*client.ClientSubscription, error) } // do is a function that performs a remote procedure call (RPC) using the provided callCloser. From f188b94a11dbebc2d7e87be37ff64ead51e7b0c2 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Wed, 8 Jan 2025 01:45:12 -0300 Subject: [PATCH 06/35] Enhance Starknet WebSocket support and subscription handling - Updated the handler to correctly process subscription IDs for Starknet, accommodating the new structure returned by the Starknet API. - Modified the subscriptionResult struct to include a Starknet-specific subscription ID field. - Adjusted the WebSocket provider to support new block header subscriptions, improving the overall subscription mechanism. - Updated example usage to reflect changes in subscription handling and error management. These changes improve compatibility with Starknet's API and enhance the robustness of the WebSocket client. --- client/handler.go | 23 +++++++++++++++++++---- client/json.go | 15 ++++++++++----- examples/websocket/main.go | 22 +++++++++++++--------- rpc/websocket.go | 16 +++++++++++----- 4 files changed, 53 insertions(+), 23 deletions(-) diff --git a/client/handler.go b/client/handler.go index abf05008..ef47d4f6 100644 --- a/client/handler.go +++ b/client/handler.go @@ -410,7 +410,16 @@ func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*json if msg.Error != nil { op.err = msg.Error } else { - op.err = json.Unmarshal(msg.Result, &op.sub.subid) + // starknet returns a object with a subid field instead of a string + if op.sub.namespace == "starknet" { + var subid struct { + SubID uint64 `json:"subscription_id"` + } + op.err = json.Unmarshal(msg.Result, &subid) + op.sub.subid = strconv.FormatUint(subid.SubID, 10) + } else { + op.err = json.Unmarshal(msg.Result, &op.sub.subid) + } if op.err == nil { go op.sub.run() h.clientSubs[op.sub.subid] = op.sub @@ -432,7 +441,7 @@ func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*json h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) case msg.isNotification(): - if strings.HasSuffix(msg.Method, notificationMethodSuffix) { + if strings.HasPrefix(msg.Method, starknetNotificationMethodPrefix) || strings.HasSuffix(msg.Method, notificationMethodSuffix) { h.handleSubscriptionResult(msg) continue } @@ -455,8 +464,14 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { h.log.Debug("Dropping invalid subscription message") return } - if h.clientSubs[result.ID] != nil { - h.clientSubs[result.ID].deliver(result.Result) + + id := strconv.FormatUint(result.StarknetID, 10) + if id == "0" { + id = result.ID + } + + if h.clientSubs[id] != nil { + h.clientSubs[id].deliver(result.Result) } } diff --git a/client/json.go b/client/json.go index e7908e68..56911e51 100644 --- a/client/json.go +++ b/client/json.go @@ -36,19 +36,24 @@ const ( unsubscribeMethodSuffix = "_unsubscribe" notificationMethodSuffix = "_subscription" + starknetSubscribeMethodPrefix = "starknet_subscribe" + starknetNotificationMethodPrefix = "starknet_subscription" + defaultWriteTimeout = 10 * time.Second // used if context has no deadline ) var null = json.RawMessage("null") type subscriptionResult struct { - ID string `json:"subscription"` - Result json.RawMessage `json:"result,omitempty"` + StarknetID uint64 `json:"subscription_id"` + ID string `json:"subscription,omitempty"` // ethereum field, kept for testing compatibility + Result json.RawMessage `json:"result,omitempty"` } type subscriptionResultEnc struct { - ID string `json:"subscription"` - Result any `json:"result"` + StarknetID uint64 `json:"subscription_id"` + ID string `json:"subscription,omitempty"` // ethereum field, kept for testing compatibility + Result any `json:"result"` } type jsonrpcSubscriptionNotification struct { @@ -89,7 +94,7 @@ func (msg *jsonrpcMessage) hasValidVersion() bool { } func (msg *jsonrpcMessage) isSubscribe() bool { - return strings.HasSuffix(msg.Method, subscribeMethodSuffix) + return strings.HasPrefix(msg.Method, starknetSubscribeMethodPrefix) || strings.HasSuffix(msg.Method, subscribeMethodSuffix) } func (msg *jsonrpcMessage) isUnsubscribe() bool { diff --git a/examples/websocket/main.go b/examples/websocket/main.go index f10357f7..d76ec0d2 100644 --- a/examples/websocket/main.go +++ b/examples/websocket/main.go @@ -5,8 +5,6 @@ import ( "fmt" "github.com/NethermindEth/starknet.go/rpc" - - setup "github.com/NethermindEth/starknet.go/examples/internal" ) // main entry point of the program. @@ -24,21 +22,27 @@ import ( func main() { fmt.Println("Starting simpleCall example") - // Load variables from '.env' file - rpcProviderUrl := setup.GetRpcProviderUrl() - // Initialize connection to RPC provider - client, err := rpc.NewWebsocketProvider(rpcProviderUrl) + client, err := rpc.NewWebsocketProvider("ws://localhost:6061") //local juno node for testing if err != nil { panic(fmt.Sprintf("Error dialing the RPC provider: %s", err)) } fmt.Println("Established connection with the client") - chainID, err := client.ChainID(context.Background()) + ch := make(chan *rpc.BlockHeader) + sub, err := client.SubscribeNewHeads(context.Background(), ch) if err != nil { - panic(fmt.Sprintf("Error getting chain ID: %s", err)) + rpcErr := err.(*rpc.RPCError) + panic(fmt.Sprintf("Error subscribing: %s", rpcErr.Error())) } - fmt.Printf("Chain ID: %s\n", chainID) + for { + select { + case resp := <-ch: + fmt.Printf("New block: %d \n", resp.BlockNumber) + case err := <-sub.Err(): + panic(fmt.Sprintf("Error subscribing to new heads: %s", err)) + } + } } diff --git a/rpc/websocket.go b/rpc/websocket.go index b650a200..38681427 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -1,6 +1,10 @@ package rpc -import "context" +import ( + "context" + + "github.com/NethermindEth/starknet.go/client" +) // New block headers subscription. // Creates a WebSocket stream which will fire events for new block headers @@ -11,9 +15,11 @@ import "context" // Returns: // - subscriptionId: The subscription ID // - error: An error, if any -func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, blockID BlockID) (subscriptionId int, err error) { - if err = do(ctx, provider.c, "starknet_subscribeNewHeads", &subscriptionId, blockID); err != nil { - return 0, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) + +func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, ch chan<- *BlockHeader) (*client.ClientSubscription, error) { + sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeNewHeads", ch) + if err != nil { + return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) } - return subscriptionId, nil + return sub, nil } From 71d5a4ecb794428f4b56b14cb1192930ba48f8d2 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Fri, 10 Jan 2025 14:49:34 -0300 Subject: [PATCH 07/35] improves newHeads method and WS close message --- client/subscription.go | 19 ++++++++++++++++++- client/websocket.go | 6 ++++++ rpc/provider.go | 5 +++++ rpc/websocket.go | 12 +++++++++--- 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/client/subscription.go b/client/subscription.go index db378be6..35014098 100644 --- a/client/subscription.go +++ b/client/subscription.go @@ -26,6 +26,7 @@ import ( "errors" "math/rand" "reflect" + "strconv" "strings" "sync" "time" @@ -268,6 +269,11 @@ func (sub *ClientSubscription) Unsubscribe() { }) } +// ID returns the subscription ID. +func (sub *ClientSubscription) ID() string { + return sub.subid +} + // deliver is called by the client's message dispatcher to send a notification value. func (sub *ClientSubscription) deliver(result json.RawMessage) (ok bool) { select { @@ -373,6 +379,17 @@ func (sub *ClientSubscription) requestUnsubscribe() error { var result interface{} ctx, cancel := context.WithTimeout(context.Background(), unsubscribeTimeout) defer cancel() - err := sub.client.CallContext(ctx, &result, sub.namespace+unsubscribeMethodSuffix, sub.subid) + + var err error + if sub.namespace == "starknet" { + var subId uint64 + subId, err = strconv.ParseUint(sub.subid, 10, 64) + if err != nil { + return err + } + err = sub.client.CallContext(ctx, &result, sub.namespace+unsubscribeMethodSuffix, subId) + } else { + err = sub.client.CallContext(ctx, &result, sub.namespace+unsubscribeMethodSuffix, sub.subid) + } return err } diff --git a/client/websocket.go b/client/websocket.go index ef2e36cc..e0c95f85 100644 --- a/client/websocket.go +++ b/client/websocket.go @@ -324,6 +324,12 @@ func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, readL } func (wc *websocketCodec) close() { + err := wc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "client is closed"), time.Time{}) + if err != nil { + // Handle error but ensure we still try to close the connection + log.Warn("Error sending close message: ", err) + } + wc.jsonCodec.close() wc.wg.Wait() } diff --git a/rpc/provider.go b/rpc/provider.go index 9e2a7603..77a249a1 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -28,6 +28,11 @@ type WsProvider struct { c wsConn } +// Close closes the client, aborting any in-flight requests. +func (p *WsProvider) Close() { + p.c.Close() +} + // NewProvider creates a new HTTP rpc Provider instance. func NewProvider(url string, options ...client.ClientOption) (*Provider, error) { jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) diff --git a/rpc/websocket.go b/rpc/websocket.go index 38681427..b6a9a11b 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -11,13 +11,19 @@ import ( // // Parameters: // - ctx: The context.Context object for controlling the function call -// - blockID: The ID of the block to retrieve the transactions from +// - headers: The channel to send the new block headers to +// - blockID (optional): The block to get notifications from, default is latest, limited to 1024 blocks back // Returns: // - subscriptionId: The subscription ID // - error: An error, if any +func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, blockID ...BlockID) (*client.ClientSubscription, error) { + // Convert blockID to []any + params := make([]any, len(blockID)) + for i, v := range blockID { + params[i] = v + } -func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, ch chan<- *BlockHeader) (*client.ClientSubscription, error) { - sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeNewHeads", ch) + sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeNewHeads", headers, params...) if err != nil { return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) } From 05d7e90372a78cab581983c756126c63144e0371 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Fri, 10 Jan 2025 16:26:34 -0300 Subject: [PATCH 08/35] Enhance Starknet compatibility and testing framework - Updated the main CI workflow to include testing for the client with mocks. - Modified subscription handling in `subscription.js` to accommodate the new `subscription_id` structure from Starknet. - Refactored `provider_test.go` to include WebSocket provider support and improved test configurations. - Introduced a new test file `websocket_test.go` to validate WebSocket subscriptions for new block headers, ensuring robust error handling and compatibility with the testnet environment. These changes improve the overall robustness of the RPC client and enhance compatibility with Starknet's API. --- .github/workflows/main_ci_check.yml | 4 + client/testdata/subscription.js | 11 +-- rpc/provider_test.go | 36 ++++++--- rpc/websocket_test.go | 113 ++++++++++++++++++++++++++++ 4 files changed, 150 insertions(+), 14 deletions(-) create mode 100644 rpc/websocket_test.go diff --git a/.github/workflows/main_ci_check.yml b/.github/workflows/main_ci_check.yml index 98cf6453..5ec204b6 100644 --- a/.github/workflows/main_ci_check.yml +++ b/.github/workflows/main_ci_check.yml @@ -71,3 +71,7 @@ jobs: cd ../simpleCall && go build cd ../simpleInvoke && go build cd ../deployContractUDC && go build + + # Test client on mock + - name: Test client with mocks + run: cd client && go test -v diff --git a/client/testdata/subscription.js b/client/testdata/subscription.js index 9f100730..ebbd0eab 100644 --- a/client/testdata/subscription.js +++ b/client/testdata/subscription.js @@ -2,11 +2,12 @@ --> {"jsonrpc":"2.0","id":1,"method":"nftest_subscribe","params":["someSubscription",5,1]} <-- {"jsonrpc":"2.0","id":1,"result":"0x1"} -<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":1}} -<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":2}} -<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":3}} -<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":4}} -<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":5}} +// changed from {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":1}} to accomodate the new subscription_id from starknet +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription_id":0,"subscription":"0x1","result":1}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription_id":0,"subscription":"0x1","result":2}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription_id":0,"subscription":"0x1","result":3}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription_id":0,"subscription":"0x1","result":4}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription_id":0,"subscription":"0x1","result":5}} --> {"jsonrpc":"2.0","id":2,"method":"nftest_echo","params":[11]} <-- {"jsonrpc":"2.0","id":2,"result":11} diff --git a/rpc/provider_test.go b/rpc/provider_test.go index 2a81ea0d..70c95be5 100644 --- a/rpc/provider_test.go +++ b/rpc/provider_test.go @@ -21,8 +21,10 @@ const ( // testConfiguration is a type that is used to configure tests type testConfiguration struct { - provider *Provider - base string + provider *Provider + wsProvider *WsProvider + base string + wsBase string } var ( @@ -31,15 +33,16 @@ var ( // testConfigurations are predefined test configurations testConfigurations = map[string]testConfiguration{ - // Requires a Mainnet Starknet JSON-RPC compliant node (e.g. pathfinder) - // (ref: https://github.com/eqlabs/pathfinder) + // Requires a Mainnet Starknet JSON-RPC compliant node (e.g. Juno) + // (ref: https://github.com/NethermindEth/juno) "mainnet": { base: "https://free-rpc.nethermind.io/mainnet-juno", }, - // Requires a Testnet Starknet JSON-RPC compliant node (e.g. pathfinder) - // (ref: https://github.com/eqlabs/pathfinder) + // Requires a Testnet Starknet JSON-RPC compliant node (e.g. Juno) + // (ref: https://github.com/NethermindEth/juno) "testnet": { base: "https://free-rpc.nethermind.io/sepolia-juno", + // wsBase: "ws://localhost:6061", }, // Requires a Devnet configuration running locally // (ref: https://github.com/0xSpaceShard/starknet-devnet-rs) @@ -96,15 +99,30 @@ func beforeEach(t *testing.T) *testConfiguration { if base != "" { testConfig.base = base } - c, err := NewProvider(testConfig.base) + + client, err := NewProvider(testConfig.base) if err != nil { t.Fatal("connect should succeed, instead:", err) } - - testConfig.provider = c + testConfig.provider = client t.Cleanup(func() { testConfig.provider.c.Close() }) + + wsBase := os.Getenv("WS_PROVIDER_URL") + if wsBase != "" { + testConfig.wsBase = wsBase + + wsClient, err := NewWebsocketProvider(testConfig.wsBase) + if err != nil { + t.Fatal("connect should succeed, instead:", err) + } + testConfig.wsProvider = wsClient + t.Cleanup(func() { + testConfig.wsProvider.c.Close() + }) + } + return &testConfig } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go new file mode 100644 index 00000000..8ca1dea1 --- /dev/null +++ b/rpc/websocket_test.go @@ -0,0 +1,113 @@ +package rpc + +import ( + "context" + "fmt" + "testing" + + "github.com/NethermindEth/starknet.go/client" + "github.com/stretchr/testify/require" +) + +func TestSubscribeNewHeads(t *testing.T) { + if testEnv != "testnet" { + t.Skip("Skipping test as it requires a testnet environment") + } + + testConfig := beforeEach(t) + require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") + + type testSetType struct { + headers chan *BlockHeader + blockID []BlockID + counter int + isErrorExpected bool + } + + provider := testConfig.provider + blockNumber, err := provider.BlockNumber(context.Background()) + require.NoError(t, err) + + latestBlockNumbers := []uint64{blockNumber, blockNumber + 1} // for the case the latest block number is updated + + testSet := map[string][]testSetType{ + "testnet": { + { // normal + headers: make(chan *BlockHeader), + isErrorExpected: false, + }, + { // with tag latest + headers: make(chan *BlockHeader), + blockID: []BlockID{WithBlockTag("latest")}, + isErrorExpected: false, + }, + { // with tag pending + headers: make(chan *BlockHeader), + blockID: []BlockID{WithBlockTag("pending")}, + isErrorExpected: true, + }, + { // with block number within the range of 1024 blocks + headers: make(chan *BlockHeader), + blockID: []BlockID{WithBlockNumber(blockNumber - 100)}, + counter: 100, + isErrorExpected: false, + }, + { // invalid, with block number out of the range of 1024 blocks + headers: make(chan *BlockHeader), + blockID: []BlockID{WithBlockNumber(blockNumber - 1025)}, + isErrorExpected: true, + }, + { // invalid, more than one blockID parameter + headers: make(chan *BlockHeader), + blockID: []BlockID{WithBlockTag("latest"), WithBlockTag("latest")}, + isErrorExpected: true, + }, + }, + }[testEnv] + + for index, test := range testSet { + t.Run(fmt.Sprintf("test %d", index+1), func(t *testing.T) { + + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + var sub *client.ClientSubscription + if len(test.blockID) == 0 { + sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers) + } else { + sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers, test.blockID...) + } + + if test.isErrorExpected { + require.Error(t, err) + return + } else { + require.NoError(t, err) + } + + require.NotNil(t, sub) + defer sub.Unsubscribe() + + for { + select { + case resp := <-test.headers: + require.IsType(t, &BlockHeader{}, resp) + + if test.counter != 0 { + if test.counter == 1 { + require.Contains(t, latestBlockNumbers, resp.BlockNumber+1) + return + } else { + test.counter-- + } + } else { + return + } + case err := <-sub.Err(): + require.NoError(t, err) + } + } + }) + } +} From c73bfced6d60dd018a5e60d3aef9e6eef8f1f5e9 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Mon, 13 Jan 2025 19:34:26 -0300 Subject: [PATCH 09/35] Fixes linter and test errors --- account/account_test.go | 16 ++++------------ client/client.go | 2 +- client/client_test.go | 12 +++++++----- client/handler.go | 14 +++++++------- client/http_test.go | 2 +- client/json.go | 8 ++++---- client/log/handler.go | 2 +- client/log/logger.go | 2 +- client/server.go | 4 ++-- client/server_test.go | 14 ++++++++------ client/subscription.go | 2 +- client/testdata/invalid-syntax.json | 5 ----- client/testservice_test.go | 4 ++-- client/websocket.go | 8 ++++---- 14 files changed, 43 insertions(+), 52 deletions(-) delete mode 100644 client/testdata/invalid-syntax.json diff --git a/account/account_test.go b/account/account_test.go index 627b8974..5ef237a9 100644 --- a/account/account_test.go +++ b/account/account_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "flag" - "fmt" "math/big" "os" "testing" @@ -18,7 +17,6 @@ import ( "github.com/NethermindEth/starknet.go/mocks" "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/utils" - "github.com/joho/godotenv" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -43,18 +41,12 @@ var ( // // none func TestMain(m *testing.M) { - flag.StringVar(&testEnv, "env", "mock", "set the test environment") + flag.StringVar(&testEnv, "env", "devnet", "set the test environment") flag.Parse() if testEnv == "mock" { return } - base = os.Getenv("INTEGRATION_BASE") - if base == "" { - if err := godotenv.Load(fmt.Sprintf(".env.%s", testEnv)); err != nil { - panic(fmt.Sprintf("Failed to load .env.%s, err: %s", testEnv, err)) - } - base = os.Getenv("INTEGRATION_BASE") - } + base = "http://localhost:5050" os.Exit(m.Run()) } @@ -1076,7 +1068,7 @@ func TestWaitForTransactionReceipt(t *testing.T) { Timeout: 3, // Should poll 3 times Hash: new(felt.Felt).SetUint64(100), ExpectedReceipt: rpc.TransactionReceipt{}, - ExpectedErr: rpc.Err(rpc.InternalError, &rpc.RPCData{Message: "Post \"http://localhost:5050\": context deadline exceeded"}), + ExpectedErr: rpc.Err(rpc.InternalError, &rpc.RPCData{Message: "context deadline exceeded"}), }, }, }[testEnv] @@ -1090,7 +1082,7 @@ func TestWaitForTransactionReceipt(t *testing.T) { rpcErr, ok := err.(*rpc.RPCError) require.True(t, ok) require.Equal(t, test.ExpectedErr.Code, rpcErr.Code) - require.Equal(t, test.ExpectedErr.Data.Message, rpcErr.Data.Message) + require.Contains(t, rpcErr.Data.Message, test.ExpectedErr.Data.Message) // sometimes the error message starts with "Post \"http://localhost:5050\":..." } else { require.Equal(t, test.ExpectedReceipt.ExecutionStatus, (*resp).ExecutionStatus) } diff --git a/client/client.go b/client/client.go index b4018b85..4c04fe8a 100644 --- a/client/client.go +++ b/client/client.go @@ -705,7 +705,7 @@ func (c *Client) read(codec ServerCodec) { msgs, batch, err := codec.readBatch() if _, ok := err.(*json.SyntaxError); ok { msg := errorMessage(&parseError{err.Error()}) - codec.writeJSON(context.Background(), msg, true) + _ = codec.writeJSON(context.Background(), msg, true) } if err != nil { c.readErr <- err diff --git a/client/client_test.go b/client/client_test.go index 3a580517..031b69f3 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -427,7 +427,7 @@ func TestClientSubscribeInvalidArg(t *testing.T) { t.Error(string(buf)) } }() - client.EthSubscribe(context.Background(), arg, "foo_bar") + _, _ = client.EthSubscribe(context.Background(), arg, "foo_bar") } check(true, nil) check(true, 1) @@ -565,7 +565,7 @@ func TestUnsubscribeTimeout(t *testing.T) { t.Parallel() srv := NewServer() - srv.RegisterName("nftest", new(notificationTestService)) + _ = srv.RegisterName("nftest", new(notificationTestService)) // Setup middleware to block on unsubscribe. p1, p2 := net.Pipe() @@ -637,7 +637,7 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) { // Create the server. srv := NewServer() - srv.RegisterName("nftest", new(notificationTestService)) + _ = srv.RegisterName("nftest", new(notificationTestService)) p1, p2 := net.Pipe() recorder := &unsubscribeRecorder{ServerCodec: NewCodec(p1)} go srv.ServeCodec(recorder, OptionMethodInvocation|OptionSubscriptions) @@ -680,7 +680,7 @@ func TestClientSubscriptionChannelClose(t *testing.T) { defer srv.Stop() defer httpsrv.Close() - srv.RegisterName("nftest", new(notificationTestService)) + _ = srv.RegisterName("nftest", new(notificationTestService)) client, _ := Dial(wsURL) defer client.Close() @@ -842,7 +842,9 @@ func TestClientReconnect(t *testing.T) { if err != nil { t.Fatal("can't listen:", err) } - go http.Serve(l, srv.WebsocketHandler([]string{"*"})) + go func() { + _ = http.Serve(l, srv.WebsocketHandler([]string{"*"})) + }() return srv, l } diff --git a/client/handler.go b/client/handler.go index ef47d4f6..130a0686 100644 --- a/client/handler.go +++ b/client/handler.go @@ -163,7 +163,7 @@ func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorR } b.wrote = true // can only write once if len(b.resp) > 0 { - conn.writeJSON(ctx, b.resp, isErrorResponse) + _ = conn.writeJSON(ctx, b.resp, isErrorResponse) } } @@ -173,7 +173,7 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { if len(msgs) == 0 { h.startCallProc(func(cp *callProc) { resp := errorMessage(&invalidRequestError{"empty batch"}) - h.conn.writeJSON(cp.ctx, resp, true) + _ = h.conn.writeJSON(cp.ctx, resp, true) }) return } @@ -245,7 +245,7 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { h.addSubscriptions(cp.notifiers) callBuffer.write(cp.ctx, h.conn) for _, n := range cp.notifiers { - n.activate() + _ = n.activate() } }) } @@ -261,7 +261,7 @@ func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage break } } - h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true) + _ = h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true) } // handleMsg handles a single non-batch message. @@ -291,7 +291,7 @@ func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) { cancel() responded.Do(func() { resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout}) - h.conn.writeJSON(cp.ctx, resp, true) + _ = h.conn.writeJSON(cp.ctx, resp, true) }) }) } @@ -303,11 +303,11 @@ func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) { h.addSubscriptions(cp.notifiers) if answer != nil { responded.Do(func() { - h.conn.writeJSON(cp.ctx, answer, false) + _ = h.conn.writeJSON(cp.ctx, answer, false) }) } for _, n := range cp.notifiers { - n.activate() + _ = n.activate() } } diff --git a/client/http_test.go b/client/http_test.go index 4b422ff0..88de7da1 100644 --- a/client/http_test.go +++ b/client/http_test.go @@ -110,7 +110,7 @@ func TestHTTPRespBodyUnlimited(t *testing.T) { s := NewServer() defer s.Stop() - s.RegisterName("test", largeRespService{respLength}) + _ = s.RegisterName("test", largeRespService{respLength}) ts := httptest.NewServer(s) defer ts.Close() diff --git a/client/json.go b/client/json.go index 56911e51..d8cdcbb4 100644 --- a/client/json.go +++ b/client/json.go @@ -260,7 +260,7 @@ func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}, isErrorRespons if !ok { deadline = time.Now().Add(defaultWriteTimeout) } - c.conn.SetWriteDeadline(deadline) + _ = c.conn.SetWriteDeadline(deadline) return c.encode(v, isErrorResponse) } @@ -283,15 +283,15 @@ func (c *jsonCodec) closed() <-chan interface{} { func parseMessage(raw json.RawMessage) ([]*jsonrpcMessage, bool) { if !isBatch(raw) { msgs := []*jsonrpcMessage{{}} - json.Unmarshal(raw, &msgs[0]) + _ = json.Unmarshal(raw, &msgs[0]) return msgs, false } dec := json.NewDecoder(bytes.NewReader(raw)) - dec.Token() // skip '[' + _, _ = dec.Token() // skip '[' var msgs []*jsonrpcMessage for dec.More() { msgs = append(msgs, new(jsonrpcMessage)) - dec.Decode(&msgs[len(msgs)-1]) + _ = dec.Decode(&msgs[len(msgs)-1]) } return msgs, true } diff --git a/client/log/handler.go b/client/log/handler.go index 56eff667..b5b4bf98 100644 --- a/client/log/handler.go +++ b/client/log/handler.go @@ -77,7 +77,7 @@ func (h *TerminalHandler) Handle(_ context.Context, r slog.Record) error { h.mu.Lock() defer h.mu.Unlock() buf := h.format(h.buf, r, h.useColor) - h.wr.Write(buf) + _, _ = h.wr.Write(buf) h.buf = buf[:0] return nil } diff --git a/client/log/logger.go b/client/log/logger.go index 016856c8..187802e1 100644 --- a/client/log/logger.go +++ b/client/log/logger.go @@ -170,7 +170,7 @@ func (l *logger) Write(level slog.Level, msg string, attrs ...any) { } r := slog.NewRecord(time.Now(), level, msg, pcs[0]) r.Add(attrs...) - l.inner.Handler().Handle(context.Background(), r) + _ = l.inner.Handler().Handle(context.Background(), r) } func (l *logger) Log(level slog.Level, msg string, attrs ...any) { diff --git a/client/server.go b/client/server.go index 80b00966..4d6acc01 100644 --- a/client/server.go +++ b/client/server.go @@ -67,7 +67,7 @@ func NewServer() *Server { // Register the default service providing meta information about the RPC service such // as the services and methods it offers. rpcService := &RPCService{server} - server.RegisterName(MetadataApi, rpcService) + _ = server.RegisterName(MetadataApi, rpcService) return server } @@ -177,7 +177,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { if err != nil { if msg := messageForReadError(err); msg != "" { resp := errorMessage(&invalidMessageError{msg}) - codec.writeJSON(ctx, resp, true) + _ = codec.writeJSON(ctx, resp, true) } return } diff --git a/client/server_test.go b/client/server_test.go index d9ff9eb9..36eff4dc 100644 --- a/client/server_test.go +++ b/client/server_test.go @@ -90,7 +90,7 @@ func runTestScript(t *testing.T, file string) { case strings.HasPrefix(line, "--> "): t.Log(line) // write to connection - clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + _ = clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)) if _, err := io.WriteString(clientConn, line[4:]+"\n"); err != nil { t.Fatalf("write error: %v", err) } @@ -98,7 +98,7 @@ func runTestScript(t *testing.T, file string) { t.Log(line) want := line[4:] // read line from connection and compare text - clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + _ = clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)) sent, err := readbuf.ReadString('\n') if err != nil { t.Fatalf("read error: %v", err) @@ -124,7 +124,9 @@ func TestServerShortLivedConn(t *testing.T) { t.Fatal("can't listen:", err) } defer listener.Close() - go server.ServeListener(listener) + go func() { + _ = server.ServeListener(listener) + }() var ( request = `{"jsonrpc":"2.0","id":1,"method":"rpc_modules"}` + "\n" @@ -137,10 +139,10 @@ func TestServerShortLivedConn(t *testing.T) { t.Fatal("can't dial:", err) } - conn.SetDeadline(deadline) + _ = conn.SetDeadline(deadline) // Write the request, then half-close the connection so the server stops reading. - conn.Write([]byte(request)) - conn.(*net.TCPConn).CloseWrite() + _, _ = conn.Write([]byte(request)) + _ = conn.(*net.TCPConn).CloseWrite() // Now try to get the response. buf := make([]byte, 2000) n, err := conn.Read(buf) diff --git a/client/subscription.go b/client/subscription.go index 35014098..25bdb69a 100644 --- a/client/subscription.go +++ b/client/subscription.go @@ -305,7 +305,7 @@ func (sub *ClientSubscription) run() { // Call the unsubscribe method on the server. if unsubscribe { - sub.requestUnsubscribe() + _ = sub.requestUnsubscribe() } // Send the error. diff --git a/client/testdata/invalid-syntax.json b/client/testdata/invalid-syntax.json deleted file mode 100644 index b1942996..00000000 --- a/client/testdata/invalid-syntax.json +++ /dev/null @@ -1,5 +0,0 @@ -// This test checks that an error is written for invalid JSON requests. - ---> 'f -<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"invalid character '\\'' looking for beginning of value"}} - diff --git a/client/testservice_test.go b/client/testservice_test.go index 542aa7a6..34c7d0a9 100644 --- a/client/testservice_test.go +++ b/client/testservice_test.go @@ -154,7 +154,7 @@ func (s *testService) CallMeBackLater(ctx context.Context, method string, args [ go func() { <-ctx.Done() var result interface{} - c.Call(&result, method, args...) + _ = c.Call(&result, method, args...) }() return nil } @@ -214,7 +214,7 @@ func (s *notificationTestService) HangSubscription(ctx context.Context, val int) subscription := notifier.CreateSubscription() go func() { - notifier.Notify(subscription.ID, val) + _ = notifier.Notify(subscription.ID, val) }() return subscription, nil } diff --git a/client/websocket.go b/client/websocket.go index e0c95f85..7bab6117 100644 --- a/client/websocket.go +++ b/client/websocket.go @@ -369,14 +369,14 @@ func (wc *websocketCodec) pingLoop() { case <-pingTimer.C: wc.jsonCodec.encMu.Lock() - wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) - wc.conn.WriteMessage(websocket.PingMessage, nil) - wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) + _ = wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) + _ = wc.conn.WriteMessage(websocket.PingMessage, nil) + _ = wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) wc.jsonCodec.encMu.Unlock() pingTimer.Reset(wsPingInterval) case <-wc.pongReceived: - wc.conn.SetReadDeadline(time.Time{}) + _ = wc.conn.SetReadDeadline(time.Time{}) } } } From 76d01c8a6605689e3e4860be37bd22a36002a39a Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Tue, 14 Jan 2025 15:42:50 -0300 Subject: [PATCH 10/35] draft implementation of subscribeEvents --- client/subscription.go | 3 +-- rpc/types_event.go | 6 ++++++ rpc/websocket.go | 38 +++++++++++++++++++++++++++++++++++++- 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/client/subscription.go b/client/subscription.go index 25bdb69a..3390f39d 100644 --- a/client/subscription.go +++ b/client/subscription.go @@ -201,8 +201,7 @@ func (s *Subscription) MarshalJSON() ([]byte, error) { return json.Marshal(s.ID) } -// ClientSubscription is a subscription established through the Client's Subscribe or -// EthSubscribe methods. +// ClientSubscription is a subscription established through the Client's Subscribe type ClientSubscription struct { client *Client etype reflect.Type diff --git a/rpc/types_event.go b/rpc/types_event.go index 0157edd2..b33f0072 100644 --- a/rpc/types_event.go +++ b/rpc/types_event.go @@ -45,3 +45,9 @@ type EventsInput struct { EventFilter ResultPageRequest } + +type EventSubscriptionInput struct { + FromAddress *felt.Felt `json:"from_address,omitempty"` // Optional. Filter events by from_address which emitted the event + Keys [][]*felt.Felt `json:"keys,omitempty"` // Optional. Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value + BlockID BlockID `json:"block_id,omitempty"` // Optional. The block to get notifications from, default is latest, limited to 1024 blocks back +} diff --git a/rpc/websocket.go b/rpc/websocket.go index b6a9a11b..23806adf 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -14,7 +14,7 @@ import ( // - headers: The channel to send the new block headers to // - blockID (optional): The block to get notifications from, default is latest, limited to 1024 blocks back // Returns: -// - subscriptionId: The subscription ID +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors // - error: An error, if any func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, blockID ...BlockID) (*client.ClientSubscription, error) { // Convert blockID to []any @@ -29,3 +29,39 @@ func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan< } return sub, nil } + +// Events subscription. +// Creates a WebSocket stream which will fire events for new Starknet events with applied filters +// +// Parameters: +// - ctx: The context.Context object for controlling the function call +// - events: The channel to send the new events to +// - input: The input struct containing the optional filters +// Returns: +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any +func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, input EventSubscriptionInput) (*client.ClientSubscription, error) { + // Convert struct fields to []any, only including non-empty fields + var params []any + + switch { + case input.BlockID.Number != nil: + params = append(params, input.BlockID.Number) + case input.BlockID.Hash != nil: + params = append(params, input.BlockID.Hash) + case input.BlockID.Tag != "": + params = append(params, input.BlockID.Tag) + } + if input.FromAddress != nil { + params = append(params, input.FromAddress) + } + if len(input.Keys) > 0 { + params = append(params, input.Keys) + } + + sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, params...) + if err != nil { + return nil, tryUnwrapToRPCErr(err, ErrTooManyKeysInFilter, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) + } + return sub, nil +} From f678b3ffd913e16709b5efb22ca37aa34a202167 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Tue, 14 Jan 2025 15:52:47 -0300 Subject: [PATCH 11/35] client adaptation and functional version of subscribeEvents Note: at the moment it's not possible to omit optional parameters in json-rpc calls using array as structure type with Juno, and the current client implementation only supports sending parameters as arrays. Therefore, I changed the Subscribe function and now we are able to send optional parameters as object too. That way Juno doesn't return an error --- client/client.go | 18 +++++++++++------- client/client_test.go | 14 +++++++------- rpc/client.go | 3 ++- rpc/types_event.go | 3 ++- rpc/websocket.go | 22 ++-------------------- 5 files changed, 24 insertions(+), 36 deletions(-) diff --git a/client/client.go b/client/client.go index 4c04fe8a..8a0e26de 100644 --- a/client/client.go +++ b/client/client.go @@ -334,7 +334,7 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str if result != nil && reflect.TypeOf(result).Kind() != reflect.Ptr { return fmt.Errorf("call result parameter must be pointer or nil interface: %v", result) } - msg, err := c.newMessage(method, args...) + msg, err := c.newMessage(method, args) if err != nil { return err } @@ -402,7 +402,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { resp: make(chan []*jsonrpcMessage, 1), } for i, elem := range b { - msg, err := c.newMessage(elem.Method, elem.Args...) + msg, err := c.newMessage(elem.Method, elem.Args) if err != nil { return err } @@ -465,7 +465,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { // Notify sends a notification, i.e. a method call that doesn't expect a response. func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) error { op := new(requestOp) - msg, err := c.newMessage(method, args...) + msg, err := c.newMessage(method, args) if err != nil { return err } @@ -480,7 +480,11 @@ func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) // EthSubscribe registers a subscription under the "eth" namespace. // Note: this was kept for compatibility with the ethereum client tests func (c *Client) EthSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { - return c.Subscribe(ctx, "eth", subscribeMethodSuffix, channel, args...) + return c.SubscribeWithSliceArgs(ctx, "eth", subscribeMethodSuffix, channel, args) +} + +func (c *Client) SubscribeWithSliceArgs(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args ...interface{}) (*ClientSubscription, error) { + return c.Subscribe(ctx, namespace, methodSuffix, channel, args) } // Subscribe calls the "_subscribe" method with the given arguments, @@ -495,7 +499,7 @@ func (c *Client) EthSubscribe(ctx context.Context, channel interface{}, args ... // before considering the subscriber dead. The subscription Err channel will receive // ErrSubscriptionQueueOverflow. Use a sufficiently large buffer on the channel or ensure // that the channel usually has at least one reader to prevent this issue. -func (c *Client) Subscribe(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args ...interface{}) (*ClientSubscription, error) { +func (c *Client) Subscribe(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args interface{}) (*ClientSubscription, error) { // Check type of channel first. chanVal := reflect.ValueOf(channel) if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { @@ -508,7 +512,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, methodSuffix s return nil, ErrNotificationsUnsupported } - msg, err := c.newMessage(namespace+methodSuffix, args...) + msg, err := c.newMessage(namespace+methodSuffix, args) if err != nil { return nil, err } @@ -536,7 +540,7 @@ func (c *Client) SupportsSubscriptions() bool { return !c.isHTTP } -func (c *Client) newMessage(method string, paramsIn ...interface{}) (*jsonrpcMessage, error) { +func (c *Client) newMessage(method string, paramsIn interface{}) (*jsonrpcMessage, error) { msg := &jsonrpcMessage{Version: vsn, ID: c.nextID(), Method: method} if paramsIn != nil { // prevent sending "params":null var err error diff --git a/client/client_test.go b/client/client_test.go index 031b69f3..d19b3824 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -447,7 +447,7 @@ func TestClientSubscribe(t *testing.T) { nc := make(chan int) count := 10 - sub, err := client.Subscribe(context.Background(), "nftest", subscribeMethodSuffix, nc, "someSubscription", count, 0) + sub, err := client.SubscribeWithSliceArgs(context.Background(), "nftest", subscribeMethodSuffix, nc, "someSubscription", count, 0) if err != nil { t.Fatal("can't subscribe:", err) } @@ -494,7 +494,7 @@ func TestClientSubscribeClose(t *testing.T) { err error ) go func() { - sub, err = client.Subscribe(context.Background(), "nftest2", subscribeMethodSuffix, nc, "hangSubscription", 999) + sub, err = client.SubscribeWithSliceArgs(context.Background(), "nftest2", subscribeMethodSuffix, nc, "hangSubscription", 999) errc <- err }() @@ -526,7 +526,7 @@ func TestClientCloseUnsubscribeRace(t *testing.T) { for i := 0; i < 20; i++ { client := DialInProc(server) nc := make(chan int) - sub, err := client.Subscribe(context.Background(), "nftest", subscribeMethodSuffix, nc, "someSubscription", 3, 1) + sub, err := client.SubscribeWithSliceArgs(context.Background(), "nftest", subscribeMethodSuffix, nc, "someSubscription", 3, 1) if err != nil { t.Fatal(err) } @@ -584,7 +584,7 @@ func TestUnsubscribeTimeout(t *testing.T) { defer client.Close() // Start subscription. - sub, err := client.Subscribe(context.Background(), "nftest", subscribeMethodSuffix, make(chan int), "someSubscription", 1, 1) + sub, err := client.SubscribeWithSliceArgs(context.Background(), "nftest", subscribeMethodSuffix, make(chan int), "someSubscription", 1, 1) if err != nil { t.Fatalf("failed to subscribe: %v", err) } @@ -652,7 +652,7 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) { // Create the subscription. ch := make(chan int) - sub, err := client.Subscribe(context.Background(), "nftest", subscribeMethodSuffix, ch, "someSubscription", 1, 1) + sub, err := client.SubscribeWithSliceArgs(context.Background(), "nftest", subscribeMethodSuffix, ch, "someSubscription", 1, 1) if err != nil { t.Fatal(err) } @@ -686,7 +686,7 @@ func TestClientSubscriptionChannelClose(t *testing.T) { for i := 0; i < 100; i++ { ch := make(chan int, 100) - sub, err := client.Subscribe(context.Background(), "nftest", subscribeMethodSuffix, ch, "someSubscription", 100, 1) + sub, err := client.SubscribeWithSliceArgs(context.Background(), "nftest", subscribeMethodSuffix, ch, "someSubscription", 100, 1) if err != nil { t.Fatal(err) } @@ -712,7 +712,7 @@ func TestClientNotificationStorm(t *testing.T) { // Subscribe on the server. It will start sending many notifications // very quickly. nc := make(chan int) - sub, err := client.Subscribe(ctx, "nftest", subscribeMethodSuffix, nc, "someSubscription", count, 0) + sub, err := client.SubscribeWithSliceArgs(ctx, "nftest", subscribeMethodSuffix, nc, "someSubscription", count, 0) if err != nil { t.Fatal("can't subscribe:", err) } diff --git a/rpc/client.go b/rpc/client.go index 868fa80f..677f59a2 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -14,7 +14,8 @@ type callCloser interface { type wsConn interface { callCloser - Subscribe(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args ...interface{}) (*client.ClientSubscription, error) + Subscribe(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args interface{}) (*client.ClientSubscription, error) + SubscribeWithSliceArgs(ctx context.Context, namespace string, methodSuffix string, channel interface{}, args ...interface{}) (*client.ClientSubscription, error) } // do is a function that performs a remote procedure call (RPC) using the provided callCloser. diff --git a/rpc/types_event.go b/rpc/types_event.go index b33f0072..49a8624e 100644 --- a/rpc/types_event.go +++ b/rpc/types_event.go @@ -49,5 +49,6 @@ type EventsInput struct { type EventSubscriptionInput struct { FromAddress *felt.Felt `json:"from_address,omitempty"` // Optional. Filter events by from_address which emitted the event Keys [][]*felt.Felt `json:"keys,omitempty"` // Optional. Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value - BlockID BlockID `json:"block_id,omitempty"` // Optional. The block to get notifications from, default is latest, limited to 1024 blocks back + BlockID BlockID `json:"block,omitempty"` // Optional. The block to get notifications from, default is latest, limited to 1024 blocks back + // TODO: change 'block' to 'block_id' as soon as Juno fixes the issue with the 'block' field } diff --git a/rpc/websocket.go b/rpc/websocket.go index 23806adf..3508af16 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -23,7 +23,7 @@ func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan< params[i] = v } - sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeNewHeads", headers, params...) + sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeNewHeads", headers, params...) if err != nil { return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) } @@ -41,25 +41,7 @@ func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan< // - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors // - error: An error, if any func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, input EventSubscriptionInput) (*client.ClientSubscription, error) { - // Convert struct fields to []any, only including non-empty fields - var params []any - - switch { - case input.BlockID.Number != nil: - params = append(params, input.BlockID.Number) - case input.BlockID.Hash != nil: - params = append(params, input.BlockID.Hash) - case input.BlockID.Tag != "": - params = append(params, input.BlockID.Tag) - } - if input.FromAddress != nil { - params = append(params, input.FromAddress) - } - if len(input.Keys) > 0 { - params = append(params, input.Keys) - } - - sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, params...) + sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, input) if err != nil { return nil, tryUnwrapToRPCErr(err, ErrTooManyKeysInFilter, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) } From 245dbacae6e6efe02717c83a9450271122ff2c7a Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Tue, 14 Jan 2025 23:48:13 -0300 Subject: [PATCH 12/35] Implement SubscribeEvents functionality in WebSocket provider with enhanced error handling and testing - Added the SubscribeEvents method to the WebSocket provider, allowing for event subscriptions with optional parameters. - Introduced a new test case in websocket_test.go to validate event subscriptions under various conditions, including handling of empty arguments and specific address filtering. - Updated the HexToFeltNoErr utility function to convert hexadecimal strings to felt objects without error handling, suitable for internal use. These changes improve the robustness of the WebSocket client and enhance compatibility with Starknet's event subscription model. --- rpc/websocket.go | 22 ++++++++++++- rpc/websocket_test.go | 76 +++++++++++++++++++++++++++++++++++++++++++ utils/Felt.go | 14 ++++++++ 3 files changed, 111 insertions(+), 1 deletion(-) diff --git a/rpc/websocket.go b/rpc/websocket.go index 3508af16..1302374e 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -3,6 +3,7 @@ package rpc import ( "context" + "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/client" ) @@ -41,7 +42,26 @@ func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan< // - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors // - error: An error, if any func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, input EventSubscriptionInput) (*client.ClientSubscription, error) { - sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, input) + var sub *client.ClientSubscription + var err error + + var emptyBlockID BlockID + if input.BlockID == emptyBlockID { + // BlockID has a custom MarshalJSON that doesn't allow zero values. + // Create a temporary struct without BlockID field to properly handle the optional parameter. + tempInput := struct { + FromAddress *felt.Felt `json:"from_address,omitempty"` + Keys [][]*felt.Felt `json:"keys,omitempty"` + }{ + FromAddress: input.FromAddress, + Keys: input.Keys, + } + + sub, err = provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, tempInput) + } else { + sub, err = provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, input) + } + if err != nil { return nil, tryUnwrapToRPCErr(err, ErrTooManyKeysInFilter, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 8ca1dea1..b87439ad 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -4,8 +4,11 @@ import ( "context" "fmt" "testing" + "time" + "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/client" + "github.com/NethermindEth/starknet.go/utils" "github.com/stretchr/testify/require" ) @@ -111,3 +114,76 @@ func TestSubscribeNewHeads(t *testing.T) { }) } } + +func TestSubscribeEvents(t *testing.T) { + if testEnv != "testnet" { + t.Skip("Skipping test as it requires a testnet environment") + } + + testConfig := beforeEach(t) + require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") + + provider := testConfig.provider + blockNumber, err := provider.BlockNumber(context.Background()) + require.NoError(t, err) + + latestBlockNumbers := []uint64{blockNumber, blockNumber + 1} // for the case the latest block number is updated + fromAddress := utils.HexToFeltNoErr("0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7") // sepolia StarkGate: ETH Token + key := utils.HexToFeltNoErr("0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9") + + t.Run("normal call, with empty args", func(t *testing.T) { + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{}) + require.NoError(t, err) + require.NotNil(t, sub) + defer sub.Unsubscribe() + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Contains(t, latestBlockNumbers, resp.BlockNumber) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) + + t.Run("normal call, with all arguments, within the range of 1024 blocks", func(t *testing.T) { + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{ + BlockID: WithBlockNumber(blockNumber - 100), + FromAddress: fromAddress, + Keys: [][]*felt.Felt{{key}}, + }) + require.NoError(t, err) + require.NotNil(t, sub) + defer sub.Unsubscribe() + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Less(t, resp.BlockNumber, blockNumber) + require.Equal(t, fromAddress, resp.FromAddress) + require.Equal(t, key, resp.Keys[0]) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) +} diff --git a/utils/Felt.go b/utils/Felt.go index 37343c10..3cf1a3f8 100644 --- a/utils/Felt.go +++ b/utils/Felt.go @@ -31,6 +31,20 @@ func HexToFelt(hex string) (*felt.Felt, error) { return new(felt.Felt).SetString(hex) } +// HexToFelt converts a hexadecimal string to a *felt.Felt object, ignoring errors. +// +// Note: only use this function if you are sure that the input is a valid felt input. +// Not recommended for production use. Always handle errors correctly. +// +// Parameters: +// - hex: the input hexadecimal string to be converted. +// Returns: +// - *felt.Felt: a *felt.Felt object +func HexToFeltNoErr(hex string) *felt.Felt { + felt, _ := new(felt.Felt).SetString(hex) + return felt +} + // HexArrToFelt converts an array of hexadecimal strings to an array of felt objects. // // The function iterates over each element in the hexArr array and calls the HexToFelt function to convert each hexadecimal value to a felt object. From 094d21fff0737fb25f70ad3256eb07e3e03018c0 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Wed, 15 Jan 2025 11:46:55 -0300 Subject: [PATCH 13/35] finish subscribeEvents tests --- rpc/websocket_test.go | 159 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index b87439ad..041e7941 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -3,6 +3,7 @@ package rpc import ( "context" "fmt" + "slices" "testing" "time" @@ -156,6 +157,113 @@ func TestSubscribeEvents(t *testing.T) { } }) + t.Run("normal call, fromAddress only", func(t *testing.T) { + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{ + FromAddress: fromAddress, + }) + require.NoError(t, err) + require.NotNil(t, sub) + defer sub.Unsubscribe() + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Contains(t, latestBlockNumbers, resp.BlockNumber) + + // Subscription with only fromAddress should return events from the specified address from the latest block onwards. + require.Equal(t, fromAddress, resp.FromAddress) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(20 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) + + t.Run("normal call, keys only", func(t *testing.T) { + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{ + Keys: [][]*felt.Felt{{key}}, + }) + require.NoError(t, err) + require.NotNil(t, sub) + defer sub.Unsubscribe() + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Contains(t, latestBlockNumbers, resp.BlockNumber) + + // Subscription with only keys should return events with the specified keys from the latest block onwards. + require.Equal(t, key, resp.Keys[0]) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(20 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) + + t.Run("normal call, blockID only", func(t *testing.T) { + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{ + BlockID: WithBlockNumber(blockNumber - 100), + }) + require.NoError(t, err) + require.NotNil(t, sub) + defer sub.Unsubscribe() + + differentFromAddressFound := false + differentKeyFound := false + + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + require.Less(t, resp.BlockNumber, blockNumber) + + // Subscription with only blockID should return events from all addresses and keys from the specified block onwards. + // Verify by checking for events with different addresses and keys than the test values. + if !differentFromAddressFound { + if resp.FromAddress != fromAddress { + differentFromAddressFound = true + } + } + + if !differentKeyFound { + if !slices.Contains(resp.Keys, key) { + differentKeyFound = true + } + } + + if differentFromAddressFound && differentKeyFound { + return + } + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) + t.Run("normal call, with all arguments, within the range of 1024 blocks", func(t *testing.T) { wsProvider, err := NewWebsocketProvider(testConfig.wsBase) require.NoError(t, err) @@ -186,4 +294,55 @@ func TestSubscribeEvents(t *testing.T) { } } }) + + t.Run("error calls", func(t *testing.T) { + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + type testSetType struct { + input EventSubscriptionInput + expectedError error + } + + keys := make([][]*felt.Felt, 1025) + for i := 0; i < 1025; i++ { + keys[i] = []*felt.Felt{utils.HexToFeltNoErr("0x1")} + } + + testSet := []testSetType{ + { + input: EventSubscriptionInput{ + Keys: keys, + }, + expectedError: ErrTooManyKeysInFilter, + }, + { + input: EventSubscriptionInput{ + BlockID: WithBlockNumber(blockNumber - 1025), + }, + expectedError: ErrTooManyBlocksBack, + }, + { + input: EventSubscriptionInput{ + BlockID: WithBlockNumber(blockNumber + 2), + }, + expectedError: ErrBlockNotFound, + }, + { + input: EventSubscriptionInput{ + BlockID: WithBlockTag("pending"), + }, + expectedError: ErrCallOnPending, + }, + } + + for _, test := range testSet { + events := make(chan *EmittedEvent) + defer close(events) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, test.input) + require.Nil(t, sub) + require.EqualError(t, err, test.expectedError.Error()) + } + }) } From d01082f7c4a78f623df30b424a528e7019419b1d Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Thu, 16 Jan 2025 12:19:58 -0300 Subject: [PATCH 14/35] add starknet-spec updates --- rpc/call.go | 2 +- rpc/call_test.go | 1 + rpc/errors.go | 4 ++++ rpc/types_contract.go | 5 +++-- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/rpc/call.go b/rpc/call.go index 4d29cf1d..0bc8f3e3 100644 --- a/rpc/call.go +++ b/rpc/call.go @@ -18,7 +18,7 @@ import ( func (provider *Provider) Call(ctx context.Context, request FunctionCall, blockID BlockID) ([]*felt.Felt, error) { var result []*felt.Felt if err := do(ctx, provider.c, "starknet_call", &result, request, blockID); err != nil { - return nil, tryUnwrapToRPCErr(err, ErrContractNotFound, ErrContractError, ErrBlockNotFound) + return nil, tryUnwrapToRPCErr(err, ErrContractNotFound, ErrEntrypointNotFound, ErrContractError, ErrBlockNotFound) } return result, nil } diff --git a/rpc/call_test.go b/rpc/call_test.go index 94f601b2..154c0fe2 100644 --- a/rpc/call_test.go +++ b/rpc/call_test.go @@ -66,6 +66,7 @@ func TestCall(t *testing.T) { BlockID: WithBlockTag("latest"), ExpectedPatternResult: utils.TestHexToFelt(t, "0x506f736974696f6e"), }, + // TODO: create a case for the ErrEntrypointNotFound error when Juno implement it { FunctionCall: FunctionCall{ ContractAddress: utils.TestHexToFelt(t, "0x025633c6142D9CA4126e3fD1D522Faa6e9f745144aba728c0B3FEE38170DF9e7"), diff --git a/rpc/errors.go b/rpc/errors.go index aad5eb5f..6fefa92e 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -164,6 +164,10 @@ var ( Code: 20, Message: "Contract not found", } + ErrEntrypointNotFound = &RPCError{ + Code: 21, + Message: "Requested entrypoint does not exist in the contract", + } ErrBlockNotFound = &RPCError{ Code: 24, Message: "Block not found", diff --git a/rpc/types_contract.go b/rpc/types_contract.go index 09e8c180..79d2b90c 100644 --- a/rpc/types_contract.go +++ b/rpc/types_contract.go @@ -109,8 +109,9 @@ type ContractsProof struct { // The nonce and class hash for each requested contract address, in the order in which // they appear in the request. These values are needed to construct the associated leaf node type ContractLeavesData struct { - Nonce *felt.Felt `json:"nonce"` - ClassHash *felt.Felt `json:"class_hash"` + Nonce *felt.Felt `json:"nonce"` + ClassHash *felt.Felt `json:"class_hash"` + StorageRoot *felt.Felt `json:"storage_root"` } type GlobalRoots struct { From f386d32066d615fed256797da679967603793d92 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Thu, 16 Jan 2025 15:41:42 -0300 Subject: [PATCH 15/35] Finish subscribeTransactionStatus --- rpc/types_transaction_receipt.go | 5 ++++ rpc/websocket.go | 24 +++++++++++++++ rpc/websocket_test.go | 50 ++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+) diff --git a/rpc/types_transaction_receipt.go b/rpc/types_transaction_receipt.go index 7a6e9e33..30dd01eb 100644 --- a/rpc/types_transaction_receipt.go +++ b/rpc/types_transaction_receipt.go @@ -155,6 +155,11 @@ type TxnStatusResp struct { FailureReason string `json:"failure_reason,omitempty"` } +type NewTxnStatusResp struct { + TransactionHash *felt.Felt `json:"transaction_hash"` + Status TxnStatusResp `json:"status"` +} + type TransactionReceiptWithBlockInfo struct { TransactionReceipt BlockHash *felt.Felt `json:"block_hash,omitempty"` diff --git a/rpc/websocket.go b/rpc/websocket.go index 1302374e..8f6303e0 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -67,3 +67,27 @@ func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- * } return sub, nil } + +// Transaction Status subscription. +// Creates a WebSocket stream which at first fires an event with the current known transaction status, +// followed by events for every transaction status update +// +// Parameters: +// - ctx: The context.Context object for controlling the function call +// - newStatus: The channel to send the new transaction status to +// - transactionHash: The transaction hash to fetch status updates for +// Returns: +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any +func (provider *WsProvider) SubscribeTransactionStatus(ctx context.Context, newStatus chan<- *NewTxnStatusResp, transactionHash *felt.Felt) (*client.ClientSubscription, error) { + sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash, WithBlockTag("latest")) + if err != nil { + return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound) + } + // TODO: wait for Juno to implement this. This is the correct implementation by the spec + // sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash) + // if err != nil { + // return nil, tryUnwrapToRPCErr(err) + // } + return sub, nil +} diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 041e7941..858b1b6a 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -346,3 +346,53 @@ func TestSubscribeEvents(t *testing.T) { } }) } + +func TestSubscribeTransactionStatus(t *testing.T) { + if testEnv != "testnet" { + t.Skip("Skipping test as it requires a testnet environment") + } + + testConfig := beforeEach(t) + require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") + + provider := testConfig.provider + blockInterface, err := provider.BlockWithTxHashes(context.Background(), WithBlockTag("latest")) + require.NoError(t, err) + block := blockInterface.(*BlockTxHashes) + + txHash := new(felt.Felt) + for _, tx := range block.Transactions { + status, err := provider.GetTransactionStatus(context.Background(), tx) + require.NoError(t, err) + if status.FinalityStatus == TxnStatus_Accepted_On_L2 { + txHash = tx + break + } + } + + t.Run("normal call", func(t *testing.T) { + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + events := make(chan *NewTxnStatusResp) + sub, err := wsProvider.SubscribeTransactionStatus(context.Background(), events, txHash) + require.NoError(t, err) + require.NotNil(t, sub) + defer sub.Unsubscribe() + + for { + select { + case resp := <-events: + require.IsType(t, &NewTxnStatusResp{}, resp) + require.Equal(t, txHash, resp.TransactionHash) + require.Equal(t, TxnStatus_Accepted_On_L2, resp.Status.FinalityStatus) + return + case err := <-sub.Err(): + require.NoError(t, err) + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for events") + } + } + }) +} From 4d61a710505d04fbde34ad46d5adf69febf51895 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Tue, 21 Jan 2025 22:33:30 -0300 Subject: [PATCH 16/35] Implement SubscribePendingTransactions --- rpc/types_transaction.go | 36 ++++++++++++++++++++++++++++++++++++ rpc/websocket.go | 23 +++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/rpc/types_transaction.go b/rpc/types_transaction.go index 3235f9ae..7eb948fc 100644 --- a/rpc/types_transaction.go +++ b/rpc/types_transaction.go @@ -401,3 +401,39 @@ func (v *TransactionVersion) BigInt() (*big.Int, error) { return big.NewInt(-1), errors.New(fmt.Sprint("TransactionVersion %i not supported", *v)) } } + +// SubPendingTxnsInput is the optional input of the starknet_subscribePendingTransactions subscription. +type SubPendingTxnsInput struct { + // Get all transaction details, and not only the hash. If not provided, only hash is returned. Default is false + TransactionDetails bool `json:"transaction_details,omitempty"` + // Filter transactions to only receive notification from address list + SenderAddress *felt.Felt `json:"sender_address,omitempty"` +} + +// SubPendingTxns is the response of the starknet_subscribePendingTransactions subscription. +type SubPendingTxns struct { + // The hashes of the pending transactions. Only present if transactionDetails is false. + TransactionHashes []*felt.Felt + // The full transaction details. Only present if transactionDetails is true. + Transactions []*BlockTransaction +} + +// UnmarshalJSON unmarshals the JSON data into a SubPendingTxns object. +// +// Parameters: +// - data: The JSON data to be unmarshalled +// Returns: +// - error: An error if the unmarshalling process fails +func (s *SubPendingTxns) UnmarshalJSON(data []byte) error { + var txnsHashes []*felt.Felt + if err := json.Unmarshal(data, &txnsHashes); err == nil { + s.TransactionHashes = txnsHashes + return nil + } + var txns []*BlockTransaction + if err := json.Unmarshal(data, &txns); err == nil { + s.Transactions = txns + return nil + } + return errors.New("failed to unmarshal SubPendingTxns") +} diff --git a/rpc/websocket.go b/rpc/websocket.go index 8f6303e0..467e5a70 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -91,3 +91,26 @@ func (provider *WsProvider) SubscribeTransactionStatus(ctx context.Context, newS // } return sub, nil } + +// New Pending Transactions subscription +// Creates a WebSocket stream which will fire events when a new pending transaction is added. +// While there is no mempool, this notifies of transactions in the pending block. +// +// Parameters: +// - ctx: The context.Context object for controlling the function call +// - pendingTxns: The channel to send the new pending transactions to +// - options: The optional input struct containing the optional filters. Set to nil if no filters are needed. +// Returns: +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any +func (provider *WsProvider) SubscribePendingTransactions(ctx context.Context, pendingTxns chan<- *SubPendingTxns, options *SubPendingTxnsInput) (*client.ClientSubscription, error) { + if options == nil { + options = &SubPendingTxnsInput{} + } + + sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribePendingTransactions", pendingTxns, options) + if err != nil { + return nil, tryUnwrapToRPCErr(err, ErrTooManyAddressesInFilter) + } + return sub, nil +} From 67c1eb737acb5b22245ede811370975c56c070bd Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Tue, 21 Jan 2025 23:29:22 -0300 Subject: [PATCH 17/35] revert mistake in the account_test.go file --- account/account_test.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/account/account_test.go b/account/account_test.go index 5ef237a9..f9841017 100644 --- a/account/account_test.go +++ b/account/account_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "flag" + "fmt" "math/big" "os" "testing" @@ -17,6 +18,7 @@ import ( "github.com/NethermindEth/starknet.go/mocks" "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/utils" + "github.com/joho/godotenv" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -41,12 +43,18 @@ var ( // // none func TestMain(m *testing.M) { - flag.StringVar(&testEnv, "env", "devnet", "set the test environment") + flag.StringVar(&testEnv, "env", "mock", "set the test environment") flag.Parse() if testEnv == "mock" { return } - base = "http://localhost:5050" + base = os.Getenv("INTEGRATION_BASE") + if base == "" { + if err := godotenv.Load(fmt.Sprintf(".env.%s", testEnv)); err != nil { + panic(fmt.Sprintf("Failed to load .env.%s, err: %s", testEnv, err)) + } + base = os.Getenv("INTEGRATION_BASE") + } os.Exit(m.Run()) } From 6c22c90430c54f67326b82475b657fc490e2038c Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Wed, 22 Jan 2025 10:41:42 -0300 Subject: [PATCH 18/35] improve SubscribeNewHeads param --- rpc/websocket.go | 12 +++++------- rpc/websocket_test.go | 23 ++++++++++------------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/rpc/websocket.go b/rpc/websocket.go index 467e5a70..a3b129eb 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -13,18 +13,16 @@ import ( // Parameters: // - ctx: The context.Context object for controlling the function call // - headers: The channel to send the new block headers to -// - blockID (optional): The block to get notifications from, default is latest, limited to 1024 blocks back +// - blockID (optional): The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used // Returns: // - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors // - error: An error, if any -func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, blockID ...BlockID) (*client.ClientSubscription, error) { - // Convert blockID to []any - params := make([]any, len(blockID)) - for i, v := range blockID { - params[i] = v +func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, blockID *BlockID) (*client.ClientSubscription, error) { + if blockID == nil { + blockID = &BlockID{Tag: "latest"} } - sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeNewHeads", headers, params...) + sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeNewHeads", headers, blockID) if err != nil { return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 858b1b6a..dd198d45 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -23,7 +23,7 @@ func TestSubscribeNewHeads(t *testing.T) { type testSetType struct { headers chan *BlockHeader - blockID []BlockID + blockID *BlockID counter int isErrorExpected bool } @@ -34,6 +34,8 @@ func TestSubscribeNewHeads(t *testing.T) { latestBlockNumbers := []uint64{blockNumber, blockNumber + 1} // for the case the latest block number is updated + blockIdEx1 := WithBlockNumber(blockNumber - 100) + blockIdEx2 := WithBlockNumber(blockNumber - 1025) testSet := map[string][]testSetType{ "testnet": { { // normal @@ -42,28 +44,23 @@ func TestSubscribeNewHeads(t *testing.T) { }, { // with tag latest headers: make(chan *BlockHeader), - blockID: []BlockID{WithBlockTag("latest")}, + blockID: &BlockID{Tag: "latest"}, isErrorExpected: false, }, { // with tag pending headers: make(chan *BlockHeader), - blockID: []BlockID{WithBlockTag("pending")}, + blockID: &BlockID{Tag: "pending"}, isErrorExpected: true, }, { // with block number within the range of 1024 blocks headers: make(chan *BlockHeader), - blockID: []BlockID{WithBlockNumber(blockNumber - 100)}, + blockID: &blockIdEx1, counter: 100, isErrorExpected: false, }, { // invalid, with block number out of the range of 1024 blocks headers: make(chan *BlockHeader), - blockID: []BlockID{WithBlockNumber(blockNumber - 1025)}, - isErrorExpected: true, - }, - { // invalid, more than one blockID parameter - headers: make(chan *BlockHeader), - blockID: []BlockID{WithBlockTag("latest"), WithBlockTag("latest")}, + blockID: &blockIdEx2, isErrorExpected: true, }, }, @@ -77,10 +74,10 @@ func TestSubscribeNewHeads(t *testing.T) { defer wsProvider.Close() var sub *client.ClientSubscription - if len(test.blockID) == 0 { - sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers) + if test.blockID == nil { + sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers, nil) } else { - sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers, test.blockID...) + sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers, test.blockID) } if test.isErrorExpected { From c5c147f839cf341ec22ec4772b1ffb836b0a1eb2 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Wed, 22 Jan 2025 13:32:37 -0300 Subject: [PATCH 19/35] improve SubscribeEvents method --- examples/websocket/main.go | 2 +- rpc/types_event.go | 3 +-- rpc/websocket.go | 30 ++++++++++++------------------ rpc/websocket_test.go | 13 +++++++------ 4 files changed, 21 insertions(+), 27 deletions(-) diff --git a/examples/websocket/main.go b/examples/websocket/main.go index d76ec0d2..1cfc9681 100644 --- a/examples/websocket/main.go +++ b/examples/websocket/main.go @@ -31,7 +31,7 @@ func main() { fmt.Println("Established connection with the client") ch := make(chan *rpc.BlockHeader) - sub, err := client.SubscribeNewHeads(context.Background(), ch) + sub, err := client.SubscribeNewHeads(context.Background(), ch, nil) if err != nil { rpcErr := err.(*rpc.RPCError) panic(fmt.Sprintf("Error subscribing: %s", rpcErr.Error())) diff --git a/rpc/types_event.go b/rpc/types_event.go index 49a8624e..b33f0072 100644 --- a/rpc/types_event.go +++ b/rpc/types_event.go @@ -49,6 +49,5 @@ type EventsInput struct { type EventSubscriptionInput struct { FromAddress *felt.Felt `json:"from_address,omitempty"` // Optional. Filter events by from_address which emitted the event Keys [][]*felt.Felt `json:"keys,omitempty"` // Optional. Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value - BlockID BlockID `json:"block,omitempty"` // Optional. The block to get notifications from, default is latest, limited to 1024 blocks back - // TODO: change 'block' to 'block_id' as soon as Juno fixes the issue with the 'block' field + BlockID BlockID `json:"block_id,omitempty"` // Optional. The block to get notifications from, default is latest, limited to 1024 blocks back } diff --git a/rpc/websocket.go b/rpc/websocket.go index a3b129eb..34fe6c4a 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -35,31 +35,25 @@ func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan< // Parameters: // - ctx: The context.Context object for controlling the function call // - events: The channel to send the new events to -// - input: The input struct containing the optional filters +// - options: The optional input struct containing the optional filters. Set to nil if no filters are needed. +// - fromAddress: Filter events by from_address which emitted the event +// - keys: Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value +// - blockID: The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used +// // Returns: // - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors // - error: An error, if any -func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, input EventSubscriptionInput) (*client.ClientSubscription, error) { - var sub *client.ClientSubscription - var err error +func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, options *EventSubscriptionInput) (*client.ClientSubscription, error) { + if options == nil { + options = &EventSubscriptionInput{} + } var emptyBlockID BlockID - if input.BlockID == emptyBlockID { - // BlockID has a custom MarshalJSON that doesn't allow zero values. - // Create a temporary struct without BlockID field to properly handle the optional parameter. - tempInput := struct { - FromAddress *felt.Felt `json:"from_address,omitempty"` - Keys [][]*felt.Felt `json:"keys,omitempty"` - }{ - FromAddress: input.FromAddress, - Keys: input.Keys, - } - - sub, err = provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, tempInput) - } else { - sub, err = provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, input) + if options.BlockID == emptyBlockID { + options.BlockID = WithBlockTag("latest") } + sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, options) if err != nil { return nil, tryUnwrapToRPCErr(err, ErrTooManyKeysInFilter, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index dd198d45..bbe45394 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -135,7 +135,7 @@ func TestSubscribeEvents(t *testing.T) { defer wsProvider.Close() events := make(chan *EmittedEvent) - sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{}) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{}) require.NoError(t, err) require.NotNil(t, sub) defer sub.Unsubscribe() @@ -160,7 +160,7 @@ func TestSubscribeEvents(t *testing.T) { defer wsProvider.Close() events := make(chan *EmittedEvent) - sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{ + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ FromAddress: fromAddress, }) require.NoError(t, err) @@ -190,7 +190,7 @@ func TestSubscribeEvents(t *testing.T) { defer wsProvider.Close() events := make(chan *EmittedEvent) - sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{ + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ Keys: [][]*felt.Felt{{key}}, }) require.NoError(t, err) @@ -220,7 +220,7 @@ func TestSubscribeEvents(t *testing.T) { defer wsProvider.Close() events := make(chan *EmittedEvent) - sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{ + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ BlockID: WithBlockNumber(blockNumber - 100), }) require.NoError(t, err) @@ -267,7 +267,7 @@ func TestSubscribeEvents(t *testing.T) { defer wsProvider.Close() events := make(chan *EmittedEvent) - sub, err := wsProvider.SubscribeEvents(context.Background(), events, EventSubscriptionInput{ + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ BlockID: WithBlockNumber(blockNumber - 100), FromAddress: fromAddress, Keys: [][]*felt.Felt{{key}}, @@ -335,9 +335,10 @@ func TestSubscribeEvents(t *testing.T) { } for _, test := range testSet { + t.Logf("test: %+v", test.expectedError.Error()) events := make(chan *EmittedEvent) defer close(events) - sub, err := wsProvider.SubscribeEvents(context.Background(), events, test.input) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &test.input) require.Nil(t, sub) require.EqualError(t, err, test.expectedError.Error()) } From 8f9901852c167878000f694ca4540b616b0d2377 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Wed, 22 Jan 2025 14:17:13 -0300 Subject: [PATCH 20/35] Add SubscribePendingTransactions tests --- rpc/types_transaction.go | 2 +- rpc/websocket_test.go | 95 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/rpc/types_transaction.go b/rpc/types_transaction.go index 7eb948fc..a2e3c616 100644 --- a/rpc/types_transaction.go +++ b/rpc/types_transaction.go @@ -407,7 +407,7 @@ type SubPendingTxnsInput struct { // Get all transaction details, and not only the hash. If not provided, only hash is returned. Default is false TransactionDetails bool `json:"transaction_details,omitempty"` // Filter transactions to only receive notification from address list - SenderAddress *felt.Felt `json:"sender_address,omitempty"` + SenderAddress []*felt.Felt `json:"sender_address,omitempty"` } // SubPendingTxns is the response of the starknet_subscribePendingTransactions subscription. diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index bbe45394..12e6419e 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -68,6 +68,7 @@ func TestSubscribeNewHeads(t *testing.T) { for index, test := range testSet { t.Run(fmt.Sprintf("test %d", index+1), func(t *testing.T) { + t.Parallel() wsProvider, err := NewWebsocketProvider(testConfig.wsBase) require.NoError(t, err) @@ -130,6 +131,8 @@ func TestSubscribeEvents(t *testing.T) { key := utils.HexToFeltNoErr("0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9") t.Run("normal call, with empty args", func(t *testing.T) { + t.Parallel() + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) require.NoError(t, err) defer wsProvider.Close() @@ -155,6 +158,8 @@ func TestSubscribeEvents(t *testing.T) { }) t.Run("normal call, fromAddress only", func(t *testing.T) { + t.Parallel() + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) require.NoError(t, err) defer wsProvider.Close() @@ -185,6 +190,8 @@ func TestSubscribeEvents(t *testing.T) { }) t.Run("normal call, keys only", func(t *testing.T) { + t.Parallel() + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) require.NoError(t, err) defer wsProvider.Close() @@ -215,6 +222,8 @@ func TestSubscribeEvents(t *testing.T) { }) t.Run("normal call, blockID only", func(t *testing.T) { + t.Parallel() + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) require.NoError(t, err) defer wsProvider.Close() @@ -262,6 +271,8 @@ func TestSubscribeEvents(t *testing.T) { }) t.Run("normal call, with all arguments, within the range of 1024 blocks", func(t *testing.T) { + t.Parallel() + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) require.NoError(t, err) defer wsProvider.Close() @@ -293,6 +304,8 @@ func TestSubscribeEvents(t *testing.T) { }) t.Run("error calls", func(t *testing.T) { + t.Parallel() + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) require.NoError(t, err) defer wsProvider.Close() @@ -394,3 +407,85 @@ func TestSubscribeTransactionStatus(t *testing.T) { } }) } + +func TestSubscribePendingTransactions(t *testing.T) { + if testEnv != "testnet" { + t.Skip("Skipping test as it requires a testnet environment") + } + + testConfig := beforeEach(t) + require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") + + type testSetType struct { + pendingTxns chan *SubPendingTxns + options *SubPendingTxnsInput + expectedError error + } + + addresses := make([]*felt.Felt, 1025) + for i := 0; i < 1025; i++ { + addresses[i] = utils.HexToFeltNoErr("0x1") + } + + testSet := map[string][]testSetType{ + "testnet": { + { // nil input + pendingTxns: make(chan *SubPendingTxns), + options: nil, + }, + { // empty input + pendingTxns: make(chan *SubPendingTxns), + options: &SubPendingTxnsInput{}, + }, + { // with transanctionDetails true + pendingTxns: make(chan *SubPendingTxns), + options: &SubPendingTxnsInput{TransactionDetails: true}, + }, + { // error: too many addresses + pendingTxns: make(chan *SubPendingTxns), + options: &SubPendingTxnsInput{SenderAddress: addresses}, + expectedError: ErrTooManyAddressesInFilter, + }, + }, + }[testEnv] + + for index, test := range testSet { + t.Run(fmt.Sprintf("test %d", index+1), func(t *testing.T) { + t.Parallel() + + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + sub, err := wsProvider.SubscribePendingTransactions(context.Background(), test.pendingTxns, test.options) + + if test.expectedError != nil { + require.Error(t, err) + return + } else { + require.NoError(t, err) + } + + require.NotNil(t, sub) + defer sub.Unsubscribe() + + for { + select { + case resp := <-test.pendingTxns: + require.IsType(t, &SubPendingTxns{}, resp) + + if test.options == nil || !test.options.TransactionDetails { + require.NotEmpty(t, resp.TransactionHashes) + require.Empty(t, resp.Transactions) + } else { + require.Empty(t, resp.TransactionHashes) + require.NotEmpty(t, resp.Transactions) + } + return + case err := <-sub.Err(): + require.NoError(t, err) + } + } + }) + } +} From 087ed79c8b1632f4c406ede867cea14243379ada Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Thu, 23 Jan 2025 13:40:46 -0300 Subject: [PATCH 21/35] Add support for Reorg notifications --- client/client.go | 2 +- client/json.go | 11 ++++++ client/subscription.go | 77 +++++++++++++++++++++++++++++++++++------- rpc/websocket_test.go | 9 +++++ 4 files changed, 85 insertions(+), 14 deletions(-) diff --git a/client/client.go b/client/client.go index 8a0e26de..e861d469 100644 --- a/client/client.go +++ b/client/client.go @@ -519,7 +519,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, methodSuffix s op := &requestOp{ ids: []json.RawMessage{msg.ID}, resp: make(chan []*jsonrpcMessage, 1), - sub: newClientSubscription(c, namespace, chanVal), + sub: newClientSubscription(c, namespace, chanVal, methodSuffix), } // Send the subscription request. diff --git a/client/json.go b/client/json.go index d8cdcbb4..e3dee5f3 100644 --- a/client/json.go +++ b/client/json.go @@ -27,6 +27,8 @@ import ( "strings" "sync" "time" + + "github.com/NethermindEth/juno/core/felt" ) const ( @@ -56,6 +58,15 @@ type subscriptionResultEnc struct { Result any `json:"result"` } +// Struct representing a reorganization of the chain +// Can be received from subscribing to newHeads, Events, TransactionStatus +type ReorgEvent struct { + StartBlockHash *felt.Felt `json:"starting_block_hash"` + StartBlockNum uint64 `json:"starting_block_number"` + EndBlockHash *felt.Felt `json:"ending_block_hash"` + EndBlockNum uint64 `json:"ending_block_number"` +} + type jsonrpcSubscriptionNotification struct { Version string `json:"jsonrpc"` Method string `json:"method"` diff --git a/client/subscription.go b/client/subscription.go index 3390f39d..9b2db86a 100644 --- a/client/subscription.go +++ b/client/subscription.go @@ -17,6 +17,7 @@ package client import ( + "bytes" "container/list" "context" crand "crypto/rand" @@ -203,11 +204,13 @@ func (s *Subscription) MarshalJSON() ([]byte, error) { // ClientSubscription is a subscription established through the Client's Subscribe type ClientSubscription struct { - client *Client - etype reflect.Type - channel reflect.Value - namespace string - subid string + client *Client + etype reflect.Type + channel reflect.Value + reorgEtype reflect.Type + reorgChannel chan *ReorgEvent + namespace string + subid string // The in channel receives notification values from client dispatcher. in chan json.RawMessage @@ -228,7 +231,7 @@ type ClientSubscription struct { // This is the sentinel value sent on sub.quit when Unsubscribe is called. var errUnsubscribed = errors.New("unsubscribed") -func newClientSubscription(c *Client, namespace string, channel reflect.Value) *ClientSubscription { +func newClientSubscription(c *Client, namespace string, channel reflect.Value, method string) *ClientSubscription { sub := &ClientSubscription{ client: c, namespace: namespace, @@ -240,6 +243,13 @@ func newClientSubscription(c *Client, namespace string, channel reflect.Value) * unsubDone: make(chan struct{}), err: make(chan error, 1), } + + // A reorg event can be received from subscribing to newHeads, Events, TransactionStatus + if strings.HasSuffix(method, "NewHeads") || strings.HasSuffix(method, "Events") || strings.HasSuffix(method, "TransactionStatus") { + sub.reorgChannel = make(chan *ReorgEvent) + sub.reorgEtype = reflect.TypeOf(&ReorgEvent{}) + } + return sub } @@ -255,6 +265,12 @@ func (sub *ClientSubscription) Err() <-chan error { return sub.err } +// Reorg returns a channel that notifies the subscriber of a reorganization of the chain. +// A reorg event could be received only from subscribing to NewHeads, Events, and TransactionStatus +func (sub *ClientSubscription) Reorg() <-chan *ReorgEvent { + return sub.reorgChannel +} + // Unsubscribe unsubscribes the notification and closes the error channel. // It can safely be called more than once. func (sub *ClientSubscription) Unsubscribe() { @@ -295,6 +311,9 @@ func (sub *ClientSubscription) close(err error) { // is launched by the client's handler after the subscription has been created. func (sub *ClientSubscription) run() { defer close(sub.unsubDone) + if sub.reorgChannel != nil { + defer close(sub.reorgChannel) + } unsubscribe, err := sub.forward() @@ -326,7 +345,15 @@ func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.in)}, {Dir: reflect.SelectSend, Chan: sub.channel}, } + + // a separate case for reorg events as it'll come in the same subscription + var reorgCases []reflect.SelectCase + if sub.reorgChannel != nil { + reorgCases = append(cases[:2:2], reflect.SelectCase{Dir: reflect.SelectSend, Chan: reflect.ValueOf(sub.reorgChannel)}) + } + buffer := list.New() + isReorg := false for { var chosen int @@ -336,8 +363,13 @@ func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { chosen, recv, _ = reflect.Select(cases[:2]) } else { // Non-empty buffer, send the first queued item. - cases[2].Send = reflect.ValueOf(buffer.Front().Value) - chosen, recv, _ = reflect.Select(cases) + if isReorg { + reorgCases[2].Send = reflect.ValueOf(buffer.Front().Value) + chosen, recv, _ = reflect.Select(reorgCases) + } else { + cases[2].Send = reflect.ValueOf(buffer.Front().Value) + chosen, recv, _ = reflect.Select(cases) + } } switch chosen { @@ -352,7 +384,7 @@ func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { return false, err case 1: // <-sub.in - val, err := sub.unmarshal(recv.Interface().(json.RawMessage)) + val, err := sub.unmarshal(recv.Interface().(json.RawMessage), &isReorg) if err != nil { return true, err } @@ -361,16 +393,35 @@ func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { } buffer.PushBack(val) - case 2: // sub.channel<- - cases[2].Send = reflect.Value{} // Don't hold onto the value. + case 2: // sub.channel<- OR sub.reorgChannel<- + if isReorg { + reorgCases[2].Send = reflect.Value{} // Don't hold onto the value. + } else { + cases[2].Send = reflect.Value{} // Don't hold onto the value. + } buffer.Remove(buffer.Front()) } } } -func (sub *ClientSubscription) unmarshal(result json.RawMessage) (interface{}, error) { +func (sub *ClientSubscription) unmarshal(result json.RawMessage, isReorg *bool) (interface{}, error) { val := reflect.New(sub.etype) - err := json.Unmarshal(result, val.Interface()) + dec := json.NewDecoder(bytes.NewReader(result)) + dec.DisallowUnknownFields() + err := dec.Decode(val.Interface()) + + // If there's an error when unmarshalling to the main channel type, maybe it's a reorg event + if err != nil && sub.reorgEtype != nil { + val = reflect.New(sub.reorgEtype) + err2 := json.Unmarshal(result, val.Interface()) + if err2 != nil { + err = errors.Join(err, err2) + } else { + *isReorg = true + return val.Elem().Interface(), nil + } + } + *isReorg = false return val.Elem().Interface(), err } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 12e6419e..d48abadc 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -14,6 +14,8 @@ import ( ) func TestSubscribeNewHeads(t *testing.T) { + t.Parallel() + if testEnv != "testnet" { t.Skip("Skipping test as it requires a testnet environment") } @@ -115,6 +117,8 @@ func TestSubscribeNewHeads(t *testing.T) { } func TestSubscribeEvents(t *testing.T) { + t.Parallel() + if testEnv != "testnet" { t.Skip("Skipping test as it requires a testnet environment") } @@ -359,6 +363,7 @@ func TestSubscribeEvents(t *testing.T) { } func TestSubscribeTransactionStatus(t *testing.T) { + t.Parallel() if testEnv != "testnet" { t.Skip("Skipping test as it requires a testnet environment") } @@ -409,6 +414,7 @@ func TestSubscribeTransactionStatus(t *testing.T) { } func TestSubscribePendingTransactions(t *testing.T) { + t.Parallel() if testEnv != "testnet" { t.Skip("Skipping test as it requires a testnet environment") } @@ -489,3 +495,6 @@ func TestSubscribePendingTransactions(t *testing.T) { }) } } + +// TODO: Add mock for testing reorg events. +// A simple test was made to make sure the reorg events are received; it'll be added in the PR 651 comments From e3bc2d135b2a428514a2a01234def16d9bb08e60 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Fri, 24 Jan 2025 11:58:24 -0300 Subject: [PATCH 22/35] add TestUnsubscribe --- rpc/websocket_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index d48abadc..13eabeba 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -496,5 +496,48 @@ func TestSubscribePendingTransactions(t *testing.T) { } } +func TestUnsubscribe(t *testing.T) { + t.Parallel() + + if testEnv != "testnet" { + t.Skip("Skipping test as it requires a testnet environment") + } + + testConfig := beforeEach(t) + require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") + + wsProvider, err := NewWebsocketProvider(testConfig.wsBase) + require.NoError(t, err) + defer wsProvider.Close() + + events := make(chan *EmittedEvent) + sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{}) + require.NoError(t, err) + require.NotNil(t, sub) + + go func(t *testing.T) { + timer := time.NewTimer(3 * time.Second) + <-timer.C + sub.Unsubscribe() + }(t) + +loop: + for { + select { + case resp := <-events: + require.IsType(t, &EmittedEvent{}, resp) + case err := <-sub.Err(): + // when unsubscribing, the error channel should return nil + require.Nil(t, err) + break loop + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for unsubscription") + } + } + + // Unsubscribe again to make sure nothing happens + sub.Unsubscribe() +} + // TODO: Add mock for testing reorg events. // A simple test was made to make sure the reorg events are received; it'll be added in the PR 651 comments From 8b908c5b7c194d44bb72a625070b4a93df5142ef Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Wed, 29 Jan 2025 22:53:19 -0300 Subject: [PATCH 23/35] Create WebsocketProvider interface --- rpc/provider.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rpc/provider.go b/rpc/provider.go index 77a249a1..0cd63ab6 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -106,4 +106,12 @@ type RpcProvider interface { TraceTransaction(ctx context.Context, transactionHash *felt.Felt) (TxnTrace, error) } +type WebsocketProvider interface { + SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, options *EventSubscriptionInput) (*client.ClientSubscription, error) + SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, blockID *BlockID) (*client.ClientSubscription, error) + SubscribePendingTransactions(ctx context.Context, pendingTxns chan<- *SubPendingTxns, options *SubPendingTxnsInput) (*client.ClientSubscription, error) + SubscribeTransactionStatus(ctx context.Context, newStatus chan<- *NewTxnStatusResp, transactionHash *felt.Felt) (*client.ClientSubscription, error) +} + var _ RpcProvider = &Provider{} +var _ WebsocketProvider = &WsProvider{} From a0dd75aeb0a58e6efee7084eb33bf90344032ec2 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Wed, 29 Jan 2025 22:53:53 -0300 Subject: [PATCH 24/35] Remove some TODO comments --- client/log/logger.go | 1 - typedData/revision.go | 1 - typedData/typedData.go | 1 - utils/keccak.go | 1 - 4 files changed, 4 deletions(-) diff --git a/client/log/logger.go b/client/log/logger.go index 187802e1..d8014c2a 100644 --- a/client/log/logger.go +++ b/client/log/logger.go @@ -55,7 +55,6 @@ func FromLegacyLevel(lvl int) slog.Level { break } - // TODO: should we allow use of custom levels or force them to match existing max/min if they fall outside the range as I am doing here? if lvl > legacyLevelTrace { return LevelTrace } diff --git a/typedData/revision.go b/typedData/revision.go index 4e2780f1..33b6423b 100644 --- a/typedData/revision.go +++ b/typedData/revision.go @@ -73,7 +73,6 @@ func init() { } type revision struct { - //TODO: create a enum version uint8 domain string hashMethod func(felts ...*felt.Felt) *felt.Felt diff --git a/typedData/typedData.go b/typedData/typedData.go index e56cf372..64945fa8 100644 --- a/typedData/typedData.go +++ b/typedData/typedData.go @@ -203,7 +203,6 @@ func shortGetStructHash( // - hash: A pointer to a felt.Felt representing the calculated hash. // - err: an error if any occurred during the hash calculation. func (td *TypedData) GetTypeHash(typeName string) (*felt.Felt, error) { - //TODO: create/update methods descriptions typeDef, ok := td.Types[typeName] if !ok { if typeDef, ok = td.Revision.Types().Preset[typeName]; !ok { diff --git a/utils/keccak.go b/utils/keccak.go index 79553b62..9989b9ab 100644 --- a/utils/keccak.go +++ b/utils/keccak.go @@ -149,7 +149,6 @@ func BigToHex(in *big.Int) string { // - funcName: the name of the function // Returns: // - *big.Int: the selector -// TODO: this is used by the signer. Should it return a felt? func GetSelectorFromName(funcName string) *big.Int { kec := Keccak256([]byte(funcName)) From ecfe19faa3b208a1d24a613ff57a84677004033a Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Wed, 29 Jan 2025 22:57:53 -0300 Subject: [PATCH 25/35] Reorg ws methods alphabetically --- rpc/websocket.go | 70 ++++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/rpc/websocket.go b/rpc/websocket.go index 34fe6c4a..3b539d79 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -7,28 +7,6 @@ import ( "github.com/NethermindEth/starknet.go/client" ) -// New block headers subscription. -// Creates a WebSocket stream which will fire events for new block headers -// -// Parameters: -// - ctx: The context.Context object for controlling the function call -// - headers: The channel to send the new block headers to -// - blockID (optional): The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used -// Returns: -// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors -// - error: An error, if any -func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, blockID *BlockID) (*client.ClientSubscription, error) { - if blockID == nil { - blockID = &BlockID{Tag: "latest"} - } - - sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeNewHeads", headers, blockID) - if err != nil { - return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) - } - return sub, nil -} - // Events subscription. // Creates a WebSocket stream which will fire events for new Starknet events with applied filters // @@ -60,27 +38,25 @@ func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- * return sub, nil } -// Transaction Status subscription. -// Creates a WebSocket stream which at first fires an event with the current known transaction status, -// followed by events for every transaction status update +// New block headers subscription. +// Creates a WebSocket stream which will fire events for new block headers // // Parameters: // - ctx: The context.Context object for controlling the function call -// - newStatus: The channel to send the new transaction status to -// - transactionHash: The transaction hash to fetch status updates for +// - headers: The channel to send the new block headers to +// - blockID (optional): The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used // Returns: // - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors // - error: An error, if any -func (provider *WsProvider) SubscribeTransactionStatus(ctx context.Context, newStatus chan<- *NewTxnStatusResp, transactionHash *felt.Felt) (*client.ClientSubscription, error) { - sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash, WithBlockTag("latest")) +func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, blockID *BlockID) (*client.ClientSubscription, error) { + if blockID == nil { + blockID = &BlockID{Tag: "latest"} + } + + sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeNewHeads", headers, blockID) if err != nil { - return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound) + return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) } - // TODO: wait for Juno to implement this. This is the correct implementation by the spec - // sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash) - // if err != nil { - // return nil, tryUnwrapToRPCErr(err) - // } return sub, nil } @@ -106,3 +82,27 @@ func (provider *WsProvider) SubscribePendingTransactions(ctx context.Context, pe } return sub, nil } + +// Transaction Status subscription. +// Creates a WebSocket stream which at first fires an event with the current known transaction status, +// followed by events for every transaction status update +// +// Parameters: +// - ctx: The context.Context object for controlling the function call +// - newStatus: The channel to send the new transaction status to +// - transactionHash: The transaction hash to fetch status updates for +// Returns: +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any +func (provider *WsProvider) SubscribeTransactionStatus(ctx context.Context, newStatus chan<- *NewTxnStatusResp, transactionHash *felt.Felt) (*client.ClientSubscription, error) { + sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash, WithBlockTag("latest")) + if err != nil { + return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound) + } + // TODO: wait for Juno to implement this. This is the correct implementation by the spec + // sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash) + // if err != nil { + // return nil, tryUnwrapToRPCErr(err) + // } + return sub, nil +} From 838bd82d29130493df8e7b20d0d41ee110700c55 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Wed, 29 Jan 2025 23:42:25 -0300 Subject: [PATCH 26/35] Improve ws methods description --- rpc/types_transaction.go | 4 ++-- rpc/websocket.go | 41 +++++++++++++++++++++++----------------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/rpc/types_transaction.go b/rpc/types_transaction.go index a2e3c616..02c1a387 100644 --- a/rpc/types_transaction.go +++ b/rpc/types_transaction.go @@ -404,9 +404,9 @@ func (v *TransactionVersion) BigInt() (*big.Int, error) { // SubPendingTxnsInput is the optional input of the starknet_subscribePendingTransactions subscription. type SubPendingTxnsInput struct { - // Get all transaction details, and not only the hash. If not provided, only hash is returned. Default is false + // Optional: Get all transaction details, and not only the hash. If not provided, only hash is returned. Default is false TransactionDetails bool `json:"transaction_details,omitempty"` - // Filter transactions to only receive notification from address list + // Optional: Filter transactions to only receive notification from address list SenderAddress []*felt.Felt `json:"sender_address,omitempty"` } diff --git a/rpc/websocket.go b/rpc/websocket.go index 3b539d79..6d1190e2 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -11,16 +11,20 @@ import ( // Creates a WebSocket stream which will fire events for new Starknet events with applied filters // // Parameters: +// // - ctx: The context.Context object for controlling the function call +// // - events: The channel to send the new events to +// // - options: The optional input struct containing the optional filters. Set to nil if no filters are needed. +// // - fromAddress: Filter events by from_address which emitted the event // - keys: Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value // - blockID: The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used // // Returns: -// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors -// - error: An error, if any +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, options *EventSubscriptionInput) (*client.ClientSubscription, error) { if options == nil { options = &EventSubscriptionInput{} @@ -42,12 +46,13 @@ func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- * // Creates a WebSocket stream which will fire events for new block headers // // Parameters: -// - ctx: The context.Context object for controlling the function call -// - headers: The channel to send the new block headers to -// - blockID (optional): The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used +// - ctx: The context.Context object for controlling the function call +// - headers: The channel to send the new block headers to +// - blockID (optional): The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used +// // Returns: -// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors -// - error: An error, if any +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, blockID *BlockID) (*client.ClientSubscription, error) { if blockID == nil { blockID = &BlockID{Tag: "latest"} @@ -65,12 +70,13 @@ func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan< // While there is no mempool, this notifies of transactions in the pending block. // // Parameters: -// - ctx: The context.Context object for controlling the function call -// - pendingTxns: The channel to send the new pending transactions to -// - options: The optional input struct containing the optional filters. Set to nil if no filters are needed. +// - ctx: The context.Context object for controlling the function call +// - pendingTxns: The channel to send the new pending transactions to +// - options: The optional input struct containing the optional filters. Set to nil if no filters are needed. +// // Returns: -// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors -// - error: An error, if any +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any func (provider *WsProvider) SubscribePendingTransactions(ctx context.Context, pendingTxns chan<- *SubPendingTxns, options *SubPendingTxnsInput) (*client.ClientSubscription, error) { if options == nil { options = &SubPendingTxnsInput{} @@ -88,12 +94,13 @@ func (provider *WsProvider) SubscribePendingTransactions(ctx context.Context, pe // followed by events for every transaction status update // // Parameters: -// - ctx: The context.Context object for controlling the function call -// - newStatus: The channel to send the new transaction status to -// - transactionHash: The transaction hash to fetch status updates for +// - ctx: The context.Context object for controlling the function call +// - newStatus: The channel to send the new transaction status to +// - transactionHash: The transaction hash to fetch status updates for +// // Returns: -// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors -// - error: An error, if any +// - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors +// - error: An error, if any func (provider *WsProvider) SubscribeTransactionStatus(ctx context.Context, newStatus chan<- *NewTxnStatusResp, transactionHash *felt.Felt) (*client.ClientSubscription, error) { sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash, WithBlockTag("latest")) if err != nil { From 35023f8da4dd63acdf5ab71bd58798829c5407cf Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Thu, 30 Jan 2025 01:06:17 -0300 Subject: [PATCH 27/35] Add WebSocket example with detailed usage and README update --- client/subscription.go | 2 +- examples/README.md | 3 ++ examples/internal/setup.go | 5 ++ examples/websocket/README.md | 7 ++- examples/websocket/main.go | 100 +++++++++++++++++++++++++++-------- 5 files changed, 89 insertions(+), 28 deletions(-) diff --git a/client/subscription.go b/client/subscription.go index 9b2db86a..dc5c86b4 100644 --- a/client/subscription.go +++ b/client/subscription.go @@ -271,7 +271,7 @@ func (sub *ClientSubscription) Reorg() <-chan *ReorgEvent { return sub.reorgChannel } -// Unsubscribe unsubscribes the notification and closes the error channel. +// Unsubscribe unsubscribes the notification by calling the 'starknet_unsubscribe' method and closes the error channel. // It can safely be called more than once. func (sub *ClientSubscription) Unsubscribe() { sub.errOnce.Do(func() { diff --git a/examples/README.md b/examples/README.md index 3b897110..6a1482ae 100644 --- a/examples/README.md +++ b/examples/README.md @@ -3,6 +3,7 @@ To successfully execute these examples you'll need to configure some environment 1. Rename the ".env.template" file located at the root of this folder to ".env" 1. Uncomment, and assign your Sepolia testnet endpoint to the `RPC_PROVIDER_URL` variable in the ".env" file +1. Uncomment, and assign your Sepolia websocket testnet endpoint to the `WS_PROVIDER_URL` variable in the ".env" file 1. Uncomment, and assign your account address to the `ACCOUNT_ADDRESS` variable in the ".env" file (make sure to have a few ETH in it) 1. Uncomment, and assign your starknet public key to the `PUBLIC_KEY` variable in the ".env" file 1. Uncomment, and assign your private key to the `PRIVATE_KEY` variable in the ".env" file @@ -40,3 +41,5 @@ To run an example: R: See [simpleCall](./simpleCall/main.go). 1. How to sign and verify a typed data? R: See [typedData](./typedData/main.go). +1. How to use WebSocket methods? How to subscribe, unsubscribe, handle errors, and read values from them? + R: See [websocket](./websocket/main.go). diff --git a/examples/internal/setup.go b/examples/internal/setup.go index 912fee98..930814cd 100644 --- a/examples/internal/setup.go +++ b/examples/internal/setup.go @@ -40,6 +40,11 @@ func GetRpcProviderUrl() string { return getEnv("RPC_PROVIDER_URL") } +// Validates whether the WS_PROVIDER_URL variable has been set in the '.env' file and returns it; panics otherwise. +func GetWsProviderUrl() string { + return getEnv("WS_PROVIDER_URL") +} + // Validates whether the PRIVATE_KEY variable has been set in the '.env' file and returns it; panics otherwise. func GetPrivateKey() string { return getEnv("PRIVATE_KEY") diff --git a/examples/websocket/README.md b/examples/websocket/README.md index c0aa9026..b7ab8304 100644 --- a/examples/websocket/README.md +++ b/examples/websocket/README.md @@ -1,10 +1,9 @@ -This example calls two contract functions, with and without calldata. It uses an ERC20 token, but it can be any smart contract. +This example demonstrates how to subscribe to new block headers using WebSocket. It can be adapted to subscribe to other methods as well. Steps: 1. Rename the ".env.template" file located at the root of the "examples" folder to ".env" -1. Uncomment, and assign your Sepolia testnet endpoint to the `RPC_PROVIDER_URL` variable in the ".env" file -1. Uncomment, and assign your account address to the `ACCOUNT_ADDRESS` variable in the ".env" file -1. Make sure you are in the "simpleCall" directory +1. Uncomment, and assign your Sepolia WebSocket testnet endpoint to the `WS_PROVIDER_URL` variable in the ".env" file +1. Make sure you are in the "websocket" directory 1. Execute `go run main.go` The calls outuputs will be returned at the end of the execution. \ No newline at end of file diff --git a/examples/websocket/main.go b/examples/websocket/main.go index 1cfc9681..0abbc96c 100644 --- a/examples/websocket/main.go +++ b/examples/websocket/main.go @@ -3,46 +3,100 @@ package main import ( "context" "fmt" + "time" "github.com/NethermindEth/starknet.go/rpc" + + setup "github.com/NethermindEth/starknet.go/examples/internal" ) -// main entry point of the program. -// -// It initializes the environment and establishes a connection with the client. -// It then makes two contract calls and prints the responses. -// -// Parameters: -// -// none -// -// Returns: -// -// none func main() { - fmt.Println("Starting simpleCall example") + fmt.Println("Starting websocket example") + + // Load variables from '.env' file + wsProviderUrl := setup.GetWsProviderUrl() - // Initialize connection to RPC provider - client, err := rpc.NewWebsocketProvider("ws://localhost:6061") //local juno node for testing + // Initialize connection to WS provider + wsClient, err := rpc.NewWebsocketProvider(wsProviderUrl) if err != nil { - panic(fmt.Sprintf("Error dialing the RPC provider: %s", err)) + panic(fmt.Sprintf("Error dialing the WS provider: %s", err)) } + defer wsClient.Close() // Close the WS client when the program finishes fmt.Println("Established connection with the client") - ch := make(chan *rpc.BlockHeader) - sub, err := client.SubscribeNewHeads(context.Background(), ch, nil) + // Let's now call the SubscribeNewHeads method. To do this, we need to create a channel to receive the new heads. + // + // Note: We'll need to do this for each of the methods we want to subscribe to, always creating a channel to receive the values from + // the node. Check each method's description for the type required for the channel. + newHeadsChan := make(chan *rpc.BlockHeader) + + // We then call the desired websocket method, passing in the channel and the parameters if needed. + // For example, to subscribe to new block headers, we call the SubscribeNewHeads method, passing in the channel and the blockID. + // As the description says it's optional, we pass nil for the blockID value. That way, the latest block will be used by default. + sub, err := wsClient.SubscribeNewHeads(context.Background(), newHeadsChan, nil) + if err != nil { + setup.PanicRPC(err) + } + fmt.Println() + fmt.Println("Successfully subscribed to the node. Subscription ID:", sub.ID()) + + var latestBlockNumber uint64 + + // Now we'll create the loop to continuously read the new heads from the channel. + // This will make the program wait indefinitely for new heads or errors if not interrupted. +loop1: + for { + select { + case newHead := <-newHeadsChan: + // This case will be triggered when a new block header is received. + fmt.Println("New block header received:", newHead.BlockNumber) + latestBlockNumber = newHead.BlockNumber + break loop1 // Let's exit the loop after receiving the first block header + case err := <-sub.Err(): + // This case will be triggered when an error occurs. + setup.PanicRPC(err) + } + } + + // We can also use the subscription returned by the WS methods to unsubscribe from the stream when we're done + sub.Unsubscribe() + + fmt.Printf("Unsubscribed from the subscription %s successfully\n", sub.ID()) + + olderBlockId := rpc.WithBlockNumber(latestBlockNumber - 10) + + // We'll now subscribe to the node again, but this time we'll pass in an older block number as the blockID. + // This way, the node will send us block headers from that block number onwards. + sub, err = wsClient.SubscribeNewHeads(context.Background(), newHeadsChan, &olderBlockId) if err != nil { - rpcErr := err.(*rpc.RPCError) - panic(fmt.Sprintf("Error subscribing: %s", rpcErr.Error())) + setup.PanicRPC(err) } + fmt.Println() + fmt.Println("Successfully subscribed to the node. Subscription ID:", sub.ID()) + go func() { + time.Sleep(20 * time.Second) + // Unsubscribe from the subscription after 20 seconds + sub.Unsubscribe() + }() + +loop2: for { select { - case resp := <-ch: - fmt.Printf("New block: %d \n", resp.BlockNumber) + case newHead := <-newHeadsChan: + fmt.Println("New block header received:", newHead.BlockNumber) case err := <-sub.Err(): - panic(fmt.Sprintf("Error subscribing to new heads: %s", err)) + if err == nil { // when sub.Unsubscribe() is called a nil error is returned, so let's just break the loop if that's the case + fmt.Printf("Unsubscribed from the subscription %s successfully\n", sub.ID()) + break loop2 + } + setup.PanicRPC(err) } } + + // This example can be used to understand how to use all the methods that return a subscription. + // It's just a matter of creating a channel to receive the values from the node and calling the + // desired method, passing in the channel and the parameters if needed. Remember to check the method's + // description for the type required for the channel and whether there are any other parameters needed. } From 2457999238d32af283008de91bfa843830a350eb Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Thu, 30 Jan 2025 02:16:21 -0300 Subject: [PATCH 28/35] Update README with WebSocket example documentation --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index a2551938..e84c84d8 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ operations on the wallets. The package has excellent documentation for a smooth - [invoke transaction example](./examples/simpleInvoke) to add a new invoke transaction on testnet. - [deploy contract UDC example](./examples/deployContractUDC) to deploy an ERC20 token using [UDC (Universal Deployer Contract)](https://docs.starknet.io/architecture-and-concepts/accounts/universal-deployer/) on testnet. - [typed data example](./examples/typedData) to sign and verify a typed data. +- [websocket example](./examples/websocket) to learn how to subscribe to WebSocket methods. ### Run Examples @@ -105,6 +106,15 @@ go run main.go > Check [here](examples/typedData/README.md) for more details. +***starknet websocket*** + +```sh +cd examples/websocket +go run main.go +``` + +> Check [here](examples/websocket/README.md) for more details. +
Check [here](https://github.com/NethermindEth/starknet.go/tree/main/examples) for some FAQ answered by these examples. From 9bae8fbd2e244d6ec2e350b54c23340075114064 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Thu, 30 Jan 2025 02:19:31 -0300 Subject: [PATCH 29/35] Update CI workflow to build WebSocket and typedData projects --- .github/workflows/main_ci_check.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/main_ci_check.yml b/.github/workflows/main_ci_check.yml index 5ec204b6..c8de8718 100644 --- a/.github/workflows/main_ci_check.yml +++ b/.github/workflows/main_ci_check.yml @@ -71,7 +71,12 @@ jobs: cd ../simpleCall && go build cd ../simpleInvoke && go build cd ../deployContractUDC && go build + cd ../typedData && go build + cd ../websocket && go build # Test client on mock - name: Test client with mocks run: cd client && go test -v + + +#TODO: handle websocket tests in the CI \ No newline at end of file From 956df892f6ae4bedd47fbf4d20cb61309a8a04ca Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Thu, 30 Jan 2025 09:18:25 -0300 Subject: [PATCH 30/35] Remove HexToFeltNoErr --- rpc/websocket_test.go | 10 +++++----- utils/Felt.go | 14 -------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 13eabeba..d45f14be 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -130,9 +130,9 @@ func TestSubscribeEvents(t *testing.T) { blockNumber, err := provider.BlockNumber(context.Background()) require.NoError(t, err) - latestBlockNumbers := []uint64{blockNumber, blockNumber + 1} // for the case the latest block number is updated - fromAddress := utils.HexToFeltNoErr("0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7") // sepolia StarkGate: ETH Token - key := utils.HexToFeltNoErr("0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9") + latestBlockNumbers := []uint64{blockNumber, blockNumber + 1} // for the case the latest block number is updated + fromAddress := utils.TestHexToFelt(t, "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7") // sepolia StarkGate: ETH Token + key := utils.TestHexToFelt(t, "0x99cd8bde557814842a3121e8ddfd433a539b8c9f14bf31ebf108d12e6196e9") t.Run("normal call, with empty args", func(t *testing.T) { t.Parallel() @@ -321,7 +321,7 @@ func TestSubscribeEvents(t *testing.T) { keys := make([][]*felt.Felt, 1025) for i := 0; i < 1025; i++ { - keys[i] = []*felt.Felt{utils.HexToFeltNoErr("0x1")} + keys[i] = []*felt.Felt{utils.TestHexToFelt(t, "0x1")} } testSet := []testSetType{ @@ -430,7 +430,7 @@ func TestSubscribePendingTransactions(t *testing.T) { addresses := make([]*felt.Felt, 1025) for i := 0; i < 1025; i++ { - addresses[i] = utils.HexToFeltNoErr("0x1") + addresses[i] = utils.TestHexToFelt(t, "0x1") } testSet := map[string][]testSetType{ diff --git a/utils/Felt.go b/utils/Felt.go index 3cf1a3f8..37343c10 100644 --- a/utils/Felt.go +++ b/utils/Felt.go @@ -31,20 +31,6 @@ func HexToFelt(hex string) (*felt.Felt, error) { return new(felt.Felt).SetString(hex) } -// HexToFelt converts a hexadecimal string to a *felt.Felt object, ignoring errors. -// -// Note: only use this function if you are sure that the input is a valid felt input. -// Not recommended for production use. Always handle errors correctly. -// -// Parameters: -// - hex: the input hexadecimal string to be converted. -// Returns: -// - *felt.Felt: a *felt.Felt object -func HexToFeltNoErr(hex string) *felt.Felt { - felt, _ := new(felt.Felt).SetString(hex) - return felt -} - // HexArrToFelt converts an array of hexadecimal strings to an array of felt objects. // // The function iterates over each element in the hexArr array and calls the HexToFelt function to convert each hexadecimal value to a felt object. From 4beb3ffaa3c2ba8280f7979ac17bc239a7551d4a Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Thu, 30 Jan 2025 09:56:04 -0300 Subject: [PATCH 31/35] Refactor WebSocket subscription types with SubscriptionBlockID and remove unused error --- examples/websocket/main.go | 4 +--- rpc/errors.go | 4 ---- rpc/provider.go | 2 +- rpc/types_block.go | 34 ++++++++++++++++++++++++++++++++++ rpc/types_event.go | 6 +++--- rpc/websocket.go | 19 +++++-------------- rpc/websocket_test.go | 35 +++++++++-------------------------- 7 files changed, 53 insertions(+), 51 deletions(-) diff --git a/examples/websocket/main.go b/examples/websocket/main.go index 0abbc96c..cd39d719 100644 --- a/examples/websocket/main.go +++ b/examples/websocket/main.go @@ -64,11 +64,9 @@ loop1: fmt.Printf("Unsubscribed from the subscription %s successfully\n", sub.ID()) - olderBlockId := rpc.WithBlockNumber(latestBlockNumber - 10) - // We'll now subscribe to the node again, but this time we'll pass in an older block number as the blockID. // This way, the node will send us block headers from that block number onwards. - sub, err = wsClient.SubscribeNewHeads(context.Background(), newHeadsChan, &olderBlockId) + sub, err = wsClient.SubscribeNewHeads(context.Background(), newHeadsChan, &rpc.SubscriptionBlockID{Number: latestBlockNumber - 10}) if err != nil { setup.PanicRPC(err) } diff --git a/rpc/errors.go b/rpc/errors.go index 6fefa92e..cf1803df 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -288,10 +288,6 @@ var ( Code: 68, Message: "Cannot go back more than 1024 blocks", } - ErrCallOnPending = &RPCError{ - Code: 69, - Message: "This method does not support being called on the pending block", - } ErrCompilationError = &RPCError{ Code: 100, Message: "Failed to compile the contract", diff --git a/rpc/provider.go b/rpc/provider.go index 0cd63ab6..daa912e0 100644 --- a/rpc/provider.go +++ b/rpc/provider.go @@ -108,7 +108,7 @@ type RpcProvider interface { type WebsocketProvider interface { SubscribeEvents(ctx context.Context, events chan<- *EmittedEvent, options *EventSubscriptionInput) (*client.ClientSubscription, error) - SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, blockID *BlockID) (*client.ClientSubscription, error) + SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, subBlockID *SubscriptionBlockID) (*client.ClientSubscription, error) SubscribePendingTransactions(ctx context.Context, pendingTxns chan<- *SubPendingTxns, options *SubPendingTxnsInput) (*client.ClientSubscription, error) SubscribeTransactionStatus(ctx context.Context, newStatus chan<- *NewTxnStatusResp, transactionHash *felt.Felt) (*client.ClientSubscription, error) } diff --git a/rpc/types_block.go b/rpc/types_block.go index c8acb86c..4a915d94 100644 --- a/rpc/types_block.go +++ b/rpc/types_block.go @@ -27,6 +27,13 @@ type BlockID struct { Tag string `json:"block_tag,omitempty"` } +// Block hash, number or tag, same as BLOCK_ID, but without 'pending' +type SubscriptionBlockID struct { + Number uint64 `json:"block_number,omitempty"` + Hash *felt.Felt `json:"block_hash,omitempty"` + Tag string `json:"block_tag,omitempty"` +} + // MarshalJSON marshals the BlockID to JSON format. // // It returns a byte slice and an error. The byte slice contains the JSON representation of the BlockID, @@ -57,7 +64,34 @@ func (b BlockID) MarshalJSON() ([]byte, error) { } return nil, ErrInvalidBlockID +} + +// MarshalJSON marshals the SubscriptionBlockID to JSON format. +// +// It returns a byte slice and an error. The byte slice contains the JSON representation of the SubscriptionBlockID, +// while the error indicates any error that occurred during the marshaling process. +// +// Parameters: +// +// none +// +// Returns: +// - []byte: the JSON representation of the SubscriptionBlockID +// - error: any error that occurred during the marshaling process +func (b SubscriptionBlockID) MarshalJSON() ([]byte, error) { + if b.Number != 0 { + return []byte(fmt.Sprintf(`{"block_number":%d}`, b.Number)), nil + } + if b.Hash != nil && b.Hash.BigInt(big.NewInt(0)).BitLen() != 0 { + return []byte(fmt.Sprintf(`{"block_hash":"%s"}`, b.Hash.String())), nil + } + + if b.Tag == "latest" || b.Tag == "" { + return []byte(strconv.Quote("latest")), nil + } + + return nil, ErrInvalidBlockID } type BlockStatus string diff --git a/rpc/types_event.go b/rpc/types_event.go index b33f0072..df39d6f6 100644 --- a/rpc/types_event.go +++ b/rpc/types_event.go @@ -47,7 +47,7 @@ type EventsInput struct { } type EventSubscriptionInput struct { - FromAddress *felt.Felt `json:"from_address,omitempty"` // Optional. Filter events by from_address which emitted the event - Keys [][]*felt.Felt `json:"keys,omitempty"` // Optional. Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value - BlockID BlockID `json:"block_id,omitempty"` // Optional. The block to get notifications from, default is latest, limited to 1024 blocks back + FromAddress *felt.Felt `json:"from_address,omitempty"` // Optional. Filter events by from_address which emitted the event + Keys [][]*felt.Felt `json:"keys,omitempty"` // Optional. Per key (by position), designate the possible values to be matched for events to be returned. Empty array designates 'any' value + BlockID SubscriptionBlockID `json:"block_id,omitempty"` // Optional. The block to get notifications from, default is latest, limited to 1024 blocks back } diff --git a/rpc/websocket.go b/rpc/websocket.go index 6d1190e2..65f6c58d 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -30,14 +30,9 @@ func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- * options = &EventSubscriptionInput{} } - var emptyBlockID BlockID - if options.BlockID == emptyBlockID { - options.BlockID = WithBlockTag("latest") - } - sub, err := provider.c.Subscribe(ctx, "starknet", "_subscribeEvents", events, options) if err != nil { - return nil, tryUnwrapToRPCErr(err, ErrTooManyKeysInFilter, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) + return nil, tryUnwrapToRPCErr(err, ErrTooManyKeysInFilter, ErrTooManyBlocksBack, ErrBlockNotFound) } return sub, nil } @@ -48,19 +43,15 @@ func (provider *WsProvider) SubscribeEvents(ctx context.Context, events chan<- * // Parameters: // - ctx: The context.Context object for controlling the function call // - headers: The channel to send the new block headers to -// - blockID (optional): The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used +// - subBlockID (optional): The block to get notifications from, limited to 1024 blocks back. If set to nil, the latest block will be used // // Returns: // - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors // - error: An error, if any -func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, blockID *BlockID) (*client.ClientSubscription, error) { - if blockID == nil { - blockID = &BlockID{Tag: "latest"} - } - - sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeNewHeads", headers, blockID) +func (provider *WsProvider) SubscribeNewHeads(ctx context.Context, headers chan<- *BlockHeader, subBlockID *SubscriptionBlockID) (*client.ClientSubscription, error) { + sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeNewHeads", headers, subBlockID) if err != nil { - return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound, ErrCallOnPending) + return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound) } return sub, nil } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index d45f14be..5fa310a3 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -25,7 +25,7 @@ func TestSubscribeNewHeads(t *testing.T) { type testSetType struct { headers chan *BlockHeader - blockID *BlockID + subBlockID *SubscriptionBlockID counter int isErrorExpected bool } @@ -36,8 +36,6 @@ func TestSubscribeNewHeads(t *testing.T) { latestBlockNumbers := []uint64{blockNumber, blockNumber + 1} // for the case the latest block number is updated - blockIdEx1 := WithBlockNumber(blockNumber - 100) - blockIdEx2 := WithBlockNumber(blockNumber - 1025) testSet := map[string][]testSetType{ "testnet": { { // normal @@ -46,23 +44,18 @@ func TestSubscribeNewHeads(t *testing.T) { }, { // with tag latest headers: make(chan *BlockHeader), - blockID: &BlockID{Tag: "latest"}, + subBlockID: &SubscriptionBlockID{Tag: "latest"}, isErrorExpected: false, }, - { // with tag pending - headers: make(chan *BlockHeader), - blockID: &BlockID{Tag: "pending"}, - isErrorExpected: true, - }, { // with block number within the range of 1024 blocks headers: make(chan *BlockHeader), - blockID: &blockIdEx1, + subBlockID: &SubscriptionBlockID{Number: blockNumber - 100}, counter: 100, isErrorExpected: false, }, { // invalid, with block number out of the range of 1024 blocks headers: make(chan *BlockHeader), - blockID: &blockIdEx2, + subBlockID: &SubscriptionBlockID{Number: blockNumber - 1025}, isErrorExpected: true, }, }, @@ -77,11 +70,7 @@ func TestSubscribeNewHeads(t *testing.T) { defer wsProvider.Close() var sub *client.ClientSubscription - if test.blockID == nil { - sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers, nil) - } else { - sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers, test.blockID) - } + sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers, test.subBlockID) if test.isErrorExpected { require.Error(t, err) @@ -234,7 +223,7 @@ func TestSubscribeEvents(t *testing.T) { events := make(chan *EmittedEvent) sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ - BlockID: WithBlockNumber(blockNumber - 100), + BlockID: SubscriptionBlockID{Number: blockNumber - 100}, }) require.NoError(t, err) require.NotNil(t, sub) @@ -283,7 +272,7 @@ func TestSubscribeEvents(t *testing.T) { events := make(chan *EmittedEvent) sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ - BlockID: WithBlockNumber(blockNumber - 100), + BlockID: SubscriptionBlockID{Number: blockNumber - 100}, FromAddress: fromAddress, Keys: [][]*felt.Felt{{key}}, }) @@ -333,22 +322,16 @@ func TestSubscribeEvents(t *testing.T) { }, { input: EventSubscriptionInput{ - BlockID: WithBlockNumber(blockNumber - 1025), + BlockID: SubscriptionBlockID{Number: blockNumber - 1025}, }, expectedError: ErrTooManyBlocksBack, }, { input: EventSubscriptionInput{ - BlockID: WithBlockNumber(blockNumber + 2), + BlockID: SubscriptionBlockID{Number: blockNumber + 2}, }, expectedError: ErrBlockNotFound, }, - { - input: EventSubscriptionInput{ - BlockID: WithBlockTag("pending"), - }, - expectedError: ErrCallOnPending, - }, } for _, test := range testSet { From 3dc8e5fee08f9184755b52094a1528bc4bbcf35e Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Thu, 30 Jan 2025 11:02:18 -0300 Subject: [PATCH 32/35] Fix PR review comments --- examples/.env.template | 1 + examples/websocket/README.md | 2 +- rpc/websocket_test.go | 130 ++++++++++++++++++----------------- 3 files changed, 68 insertions(+), 65 deletions(-) diff --git a/examples/.env.template b/examples/.env.template index 6bdf5e2b..e477140c 100644 --- a/examples/.env.template +++ b/examples/.env.template @@ -1,5 +1,6 @@ # ----- use this variable to set the RPC provider URL #RPC_PROVIDER_URL=http_insert_end_point +#WS_PROVIDER_URL=ws_insert_end_point # ----- Use these variables to set up your account #ACCOUNT_ADDRESS=0xyour_account_address diff --git a/examples/websocket/README.md b/examples/websocket/README.md index b7ab8304..c70db85f 100644 --- a/examples/websocket/README.md +++ b/examples/websocket/README.md @@ -6,4 +6,4 @@ Steps: 1. Make sure you are in the "websocket" directory 1. Execute `go run main.go` -The calls outuputs will be returned at the end of the execution. \ No newline at end of file +The call outputs will be returned at the end of the execution. \ No newline at end of file diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 5fa310a3..cf412efe 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -21,13 +21,13 @@ func TestSubscribeNewHeads(t *testing.T) { } testConfig := beforeEach(t) - require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") type testSetType struct { headers chan *BlockHeader subBlockID *SubscriptionBlockID counter int isErrorExpected bool + description string } provider := testConfig.provider @@ -38,49 +38,51 @@ func TestSubscribeNewHeads(t *testing.T) { testSet := map[string][]testSetType{ "testnet": { - { // normal + { headers: make(chan *BlockHeader), isErrorExpected: false, + description: "normal call", }, - { // with tag latest + { headers: make(chan *BlockHeader), subBlockID: &SubscriptionBlockID{Tag: "latest"}, isErrorExpected: false, + description: "with tag latest", }, - { // with block number within the range of 1024 blocks + { headers: make(chan *BlockHeader), subBlockID: &SubscriptionBlockID{Number: blockNumber - 100}, counter: 100, isErrorExpected: false, + description: "with block number within the range of 1024 blocks", }, - { // invalid, with block number out of the range of 1024 blocks + { headers: make(chan *BlockHeader), subBlockID: &SubscriptionBlockID{Number: blockNumber - 1025}, isErrorExpected: true, + description: "invalid, with block number out of the range of 1024 blocks", }, }, }[testEnv] - for index, test := range testSet { - t.Run(fmt.Sprintf("test %d", index+1), func(t *testing.T) { + for _, test := range testSet { + t.Run(fmt.Sprintf("test: %s", test.description), func(t *testing.T) { t.Parallel() - wsProvider, err := NewWebsocketProvider(testConfig.wsBase) - require.NoError(t, err) - defer wsProvider.Close() + wsProvider := testConfig.wsProvider var sub *client.ClientSubscription sub, err = wsProvider.SubscribeNewHeads(context.Background(), test.headers, test.subBlockID) + if sub != nil { + defer sub.Unsubscribe() + } if test.isErrorExpected { require.Error(t, err) return - } else { - require.NoError(t, err) } - + require.NoError(t, err) require.NotNil(t, sub) - defer sub.Unsubscribe() for { select { @@ -113,7 +115,6 @@ func TestSubscribeEvents(t *testing.T) { } testConfig := beforeEach(t) - require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") provider := testConfig.provider blockNumber, err := provider.BlockNumber(context.Background()) @@ -126,15 +127,15 @@ func TestSubscribeEvents(t *testing.T) { t.Run("normal call, with empty args", func(t *testing.T) { t.Parallel() - wsProvider, err := NewWebsocketProvider(testConfig.wsBase) - require.NoError(t, err) - defer wsProvider.Close() + wsProvider := testConfig.wsProvider events := make(chan *EmittedEvent) sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{}) + if sub != nil { + defer sub.Unsubscribe() + } require.NoError(t, err) require.NotNil(t, sub) - defer sub.Unsubscribe() for { select { @@ -153,17 +154,17 @@ func TestSubscribeEvents(t *testing.T) { t.Run("normal call, fromAddress only", func(t *testing.T) { t.Parallel() - wsProvider, err := NewWebsocketProvider(testConfig.wsBase) - require.NoError(t, err) - defer wsProvider.Close() + wsProvider := testConfig.wsProvider events := make(chan *EmittedEvent) sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ FromAddress: fromAddress, }) + if sub != nil { + defer sub.Unsubscribe() + } require.NoError(t, err) require.NotNil(t, sub) - defer sub.Unsubscribe() for { select { @@ -185,17 +186,17 @@ func TestSubscribeEvents(t *testing.T) { t.Run("normal call, keys only", func(t *testing.T) { t.Parallel() - wsProvider, err := NewWebsocketProvider(testConfig.wsBase) - require.NoError(t, err) - defer wsProvider.Close() + wsProvider := testConfig.wsProvider events := make(chan *EmittedEvent) sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ Keys: [][]*felt.Felt{{key}}, }) + if sub != nil { + defer sub.Unsubscribe() + } require.NoError(t, err) require.NotNil(t, sub) - defer sub.Unsubscribe() for { select { @@ -217,17 +218,17 @@ func TestSubscribeEvents(t *testing.T) { t.Run("normal call, blockID only", func(t *testing.T) { t.Parallel() - wsProvider, err := NewWebsocketProvider(testConfig.wsBase) - require.NoError(t, err) - defer wsProvider.Close() + wsProvider := testConfig.wsProvider events := make(chan *EmittedEvent) sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ BlockID: SubscriptionBlockID{Number: blockNumber - 100}, }) + if sub != nil { + defer sub.Unsubscribe() + } require.NoError(t, err) require.NotNil(t, sub) - defer sub.Unsubscribe() differentFromAddressFound := false differentKeyFound := false @@ -266,9 +267,7 @@ func TestSubscribeEvents(t *testing.T) { t.Run("normal call, with all arguments, within the range of 1024 blocks", func(t *testing.T) { t.Parallel() - wsProvider, err := NewWebsocketProvider(testConfig.wsBase) - require.NoError(t, err) - defer wsProvider.Close() + wsProvider := testConfig.wsProvider events := make(chan *EmittedEvent) sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{ @@ -276,15 +275,19 @@ func TestSubscribeEvents(t *testing.T) { FromAddress: fromAddress, Keys: [][]*felt.Felt{{key}}, }) + if sub != nil { + defer sub.Unsubscribe() + } require.NoError(t, err) require.NotNil(t, sub) - defer sub.Unsubscribe() for { select { case resp := <-events: require.IsType(t, &EmittedEvent{}, resp) require.Less(t, resp.BlockNumber, blockNumber) + // 'fromAddress' is the address of the sepolia StarkGate: ETH Token, which is very likely to have events, + // so we can use it to verify the events are returned correctly. require.Equal(t, fromAddress, resp.FromAddress) require.Equal(t, key, resp.Keys[0]) return @@ -299,9 +302,7 @@ func TestSubscribeEvents(t *testing.T) { t.Run("error calls", func(t *testing.T) { t.Parallel() - wsProvider, err := NewWebsocketProvider(testConfig.wsBase) - require.NoError(t, err) - defer wsProvider.Close() + wsProvider := testConfig.wsProvider type testSetType struct { input EventSubscriptionInput @@ -339,6 +340,9 @@ func TestSubscribeEvents(t *testing.T) { events := make(chan *EmittedEvent) defer close(events) sub, err := wsProvider.SubscribeEvents(context.Background(), events, &test.input) + if sub != nil { + defer sub.Unsubscribe() + } require.Nil(t, sub) require.EqualError(t, err, test.expectedError.Error()) } @@ -352,7 +356,6 @@ func TestSubscribeTransactionStatus(t *testing.T) { } testConfig := beforeEach(t) - require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") provider := testConfig.provider blockInterface, err := provider.BlockWithTxHashes(context.Background(), WithBlockTag("latest")) @@ -370,15 +373,15 @@ func TestSubscribeTransactionStatus(t *testing.T) { } t.Run("normal call", func(t *testing.T) { - wsProvider, err := NewWebsocketProvider(testConfig.wsBase) - require.NoError(t, err) - defer wsProvider.Close() + wsProvider := testConfig.wsProvider events := make(chan *NewTxnStatusResp) sub, err := wsProvider.SubscribeTransactionStatus(context.Background(), events, txHash) + if sub != nil { + defer sub.Unsubscribe() + } require.NoError(t, err) require.NotNil(t, sub) - defer sub.Unsubscribe() for { select { @@ -403,12 +406,12 @@ func TestSubscribePendingTransactions(t *testing.T) { } testConfig := beforeEach(t) - require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") type testSetType struct { pendingTxns chan *SubPendingTxns options *SubPendingTxnsInput expectedError error + description string } addresses := make([]*felt.Felt, 1025) @@ -418,45 +421,47 @@ func TestSubscribePendingTransactions(t *testing.T) { testSet := map[string][]testSetType{ "testnet": { - { // nil input + { pendingTxns: make(chan *SubPendingTxns), options: nil, + description: "nil input", }, - { // empty input + { pendingTxns: make(chan *SubPendingTxns), options: &SubPendingTxnsInput{}, + description: "empty input", }, - { // with transanctionDetails true + { pendingTxns: make(chan *SubPendingTxns), options: &SubPendingTxnsInput{TransactionDetails: true}, + description: "with transanctionDetails true", }, - { // error: too many addresses + { pendingTxns: make(chan *SubPendingTxns), options: &SubPendingTxnsInput{SenderAddress: addresses}, expectedError: ErrTooManyAddressesInFilter, + description: "error: too many addresses", }, }, }[testEnv] - for index, test := range testSet { - t.Run(fmt.Sprintf("test %d", index+1), func(t *testing.T) { + for _, test := range testSet { + t.Run(fmt.Sprintf("test: %s", test.description), func(t *testing.T) { t.Parallel() - wsProvider, err := NewWebsocketProvider(testConfig.wsBase) - require.NoError(t, err) - defer wsProvider.Close() + wsProvider := testConfig.wsProvider sub, err := wsProvider.SubscribePendingTransactions(context.Background(), test.pendingTxns, test.options) + if sub != nil { + defer sub.Unsubscribe() + } if test.expectedError != nil { require.Error(t, err) return - } else { - require.NoError(t, err) } - + require.NoError(t, err) require.NotNil(t, sub) - defer sub.Unsubscribe() for { select { @@ -487,14 +492,14 @@ func TestUnsubscribe(t *testing.T) { } testConfig := beforeEach(t) - require.NotNil(t, testConfig.wsBase, "wsProvider base is not set") - wsProvider, err := NewWebsocketProvider(testConfig.wsBase) - require.NoError(t, err) - defer wsProvider.Close() + wsProvider := testConfig.wsProvider events := make(chan *EmittedEvent) sub, err := wsProvider.SubscribeEvents(context.Background(), events, &EventSubscriptionInput{}) + if sub != nil { + defer sub.Unsubscribe() + } require.NoError(t, err) require.NotNil(t, sub) @@ -517,9 +522,6 @@ loop: t.Fatal("timeout waiting for unsubscription") } } - - // Unsubscribe again to make sure nothing happens - sub.Unsubscribe() } // TODO: Add mock for testing reorg events. From 4cf22d84a49660e7106d5e1a9f7c0b57c3c6f97a Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Thu, 30 Jan 2025 11:24:49 -0300 Subject: [PATCH 33/35] Improve TestSubscribeEvents testcase description and fix PR comment --- rpc/websocket_test.go | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index cf412efe..c1a5427f 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -3,7 +3,6 @@ package rpc import ( "context" "fmt" - "slices" "testing" "time" @@ -230,30 +229,22 @@ func TestSubscribeEvents(t *testing.T) { require.NoError(t, err) require.NotNil(t, sub) - differentFromAddressFound := false - differentKeyFound := false + uniqueAddresses := make(map[string]bool) + uniqueKeys := make(map[string]bool) for { select { case resp := <-events: require.IsType(t, &EmittedEvent{}, resp) require.Less(t, resp.BlockNumber, blockNumber) - // Subscription with only blockID should return events from all addresses and keys from the specified block onwards. - // Verify by checking for events with different addresses and keys than the test values. - if !differentFromAddressFound { - if resp.FromAddress != fromAddress { - differentFromAddressFound = true - } - } + // As none filters are applied, the events should be from all addresses and keys. - if !differentKeyFound { - if !slices.Contains(resp.Keys, key) { - differentKeyFound = true - } - } + uniqueAddresses[resp.FromAddress.String()] = true + uniqueKeys[resp.Keys[0].String()] = true - if differentFromAddressFound && differentKeyFound { + // check if there are at least 3 different addresses and keys in the received events + if len(uniqueAddresses) >= 3 && len(uniqueKeys) >= 3 { return } case err := <-sub.Err(): @@ -329,7 +320,7 @@ func TestSubscribeEvents(t *testing.T) { }, { input: EventSubscriptionInput{ - BlockID: SubscriptionBlockID{Number: blockNumber + 2}, + BlockID: SubscriptionBlockID{Number: blockNumber + 100}, }, expectedError: ErrBlockNotFound, }, From 03ba2e373dc5547fea24365b9a92d25bc0738b0d Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Fri, 31 Jan 2025 11:42:11 -0300 Subject: [PATCH 34/35] Fix SubscribeTransactionStatus with new Juno update --- rpc/websocket.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/rpc/websocket.go b/rpc/websocket.go index 65f6c58d..a1475ad2 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -93,14 +93,9 @@ func (provider *WsProvider) SubscribePendingTransactions(ctx context.Context, pe // - clientSubscription: The client subscription object, used to unsubscribe from the stream and to get errors // - error: An error, if any func (provider *WsProvider) SubscribeTransactionStatus(ctx context.Context, newStatus chan<- *NewTxnStatusResp, transactionHash *felt.Felt) (*client.ClientSubscription, error) { - sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash, WithBlockTag("latest")) + sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash) if err != nil { - return nil, tryUnwrapToRPCErr(err, ErrTooManyBlocksBack, ErrBlockNotFound) + return nil, tryUnwrapToRPCErr(err) } - // TODO: wait for Juno to implement this. This is the correct implementation by the spec - // sub, err := provider.c.SubscribeWithSliceArgs(ctx, "starknet", "_subscribeTransactionStatus", newStatus, transactionHash) - // if err != nil { - // return nil, tryUnwrapToRPCErr(err) - // } return sub, nil } From cb50e2315fb354b7aa29e9d485b06194f52ff550 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Fri, 31 Jan 2025 16:12:16 -0300 Subject: [PATCH 35/35] New 'is_reverted' field from spec update --- rpc/types_trace.go | 3 +++ rpc/websocket_test.go | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/rpc/types_trace.go b/rpc/types_trace.go index 6c833730..9651ffa2 100644 --- a/rpc/types_trace.go +++ b/rpc/types_trace.go @@ -123,6 +123,9 @@ type FnInvocation struct { // Resources consumed by the internal call ExecutionResources InnerCallExecutionResources `json:"execution_resources"` + + // True if this inner call panicked + IsReverted bool `json:"is_reverted"` } // the resources consumed by an inner call (does not account for state diffs since data is squashed across the transaction) diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index c1a5427f..295713c2 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -40,7 +40,7 @@ func TestSubscribeNewHeads(t *testing.T) { { headers: make(chan *BlockHeader), isErrorExpected: false, - description: "normal call", + description: "normal call, without subBlockID", }, { headers: make(chan *BlockHeader),