From 07eeb21a0cf5cfad28ace4148e181e81ea686c7c Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 11 Dec 2024 15:12:22 +0800 Subject: [PATCH] fix: get client version without request id --- controller/ws/ws/handler.go | 17 ++++++++++++----- controller/ws/ws/ws_test.go | 13 +++++++++++++ driver/wrapper/stmt_test.go | 2 +- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 388820c..88e4767 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -210,11 +210,18 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case wstool.ClientVersion: action = wstool.ClientVersion var req versionRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - h.logger.Errorf("unmarshal version request error, request:%s, err:%s", request.Args, err) - reqID := getReqID(request.Args) - commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal version request error") - return + var reqID uint64 + if request.Args != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal version request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal version request error") + return + } + reqID = req.ReqID + } + req = versionRequest{ + ReqID: reqID, } logger := h.logger.WithFields(logrus.Fields{ actionKey: action, diff --git a/controller/ws/ws/ws_test.go b/controller/ws/ws/ws_test.go index bc9278b..6433d9e 100644 --- a/controller/ws/ws/ws_test.go +++ b/controller/ws/ws/ws_test.go @@ -151,4 +151,17 @@ func TestVersion(t *testing.T) { assert.Equal(t, 0, versionResp.Code, versionResp.Message) assert.Equal(t, version.TaosClientVersion, versionResp.Version) assert.Equal(t, wstool.ClientVersion, versionResp.Action) + + req := "{\"action\":\"version\"}" + err = ws.WriteMessage(websocket.TextMessage, []byte(req)) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + + assert.NoError(t, err) + err = json.Unmarshal(resp, &versionResp) + assert.NoError(t, err) + assert.Equal(t, 0, versionResp.Code, versionResp.Message) + assert.Equal(t, version.TaosClientVersion, versionResp.Version) + assert.Equal(t, wstool.ClientVersion, versionResp.Action) } diff --git a/driver/wrapper/stmt_test.go b/driver/wrapper/stmt_test.go index 9d63b7a..de58e22 100644 --- a/driver/wrapper/stmt_test.go +++ b/driver/wrapper/stmt_test.go @@ -898,7 +898,7 @@ func TestGetFieldsCommonTable(t *testing.T) { return } code, num, _ := TaosStmtGetTagFields(stmt) - assert.Equal(t, 0, code) + assert.NotEqual(t, 0, code) assert.Equal(t, 0, num) code, columnCount, columnsP := TaosStmtGetColFields(stmt) if code != 0 {