diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 388820c5..88e47673 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -210,11 +210,18 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case wstool.ClientVersion: action = wstool.ClientVersion var req versionRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - h.logger.Errorf("unmarshal version request error, request:%s, err:%s", request.Args, err) - reqID := getReqID(request.Args) - commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal version request error") - return + var reqID uint64 + if request.Args != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal version request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal version request error") + return + } + reqID = req.ReqID + } + req = versionRequest{ + ReqID: reqID, } logger := h.logger.WithFields(logrus.Fields{ actionKey: action, diff --git a/controller/ws/ws/ws_test.go b/controller/ws/ws/ws_test.go index bc9278ba..6433d9ef 100644 --- a/controller/ws/ws/ws_test.go +++ b/controller/ws/ws/ws_test.go @@ -151,4 +151,17 @@ func TestVersion(t *testing.T) { assert.Equal(t, 0, versionResp.Code, versionResp.Message) assert.Equal(t, version.TaosClientVersion, versionResp.Version) assert.Equal(t, wstool.ClientVersion, versionResp.Action) + + req := "{\"action\":\"version\"}" + err = ws.WriteMessage(websocket.TextMessage, []byte(req)) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + + assert.NoError(t, err) + err = json.Unmarshal(resp, &versionResp) + assert.NoError(t, err) + assert.Equal(t, 0, versionResp.Code, versionResp.Message) + assert.Equal(t, version.TaosClientVersion, versionResp.Version) + assert.Equal(t, wstool.ClientVersion, versionResp.Action) } diff --git a/driver/wrapper/stmt_test.go b/driver/wrapper/stmt_test.go index 9d63b7a9..de58e228 100644 --- a/driver/wrapper/stmt_test.go +++ b/driver/wrapper/stmt_test.go @@ -898,7 +898,7 @@ func TestGetFieldsCommonTable(t *testing.T) { return } code, num, _ := TaosStmtGetTagFields(stmt) - assert.Equal(t, 0, code) + assert.NotEqual(t, 0, code) assert.Equal(t, 0, num) code, columnCount, columnsP := TaosStmtGetColFields(stmt) if code != 0 {