diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 9f8ff5a..388820c 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -208,7 +208,19 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { // no need connection actions switch request.Action { case wstool.ClientVersion: - wstool.WSWriteVersion(session, h.logger) + action = wstool.ClientVersion + var req versionRequest + if err := json.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal version request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal version request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.version(ctx, session, action, req, logger, log.IsDebug()) return case Connect: action = Connect diff --git a/controller/ws/ws/handler_test.go b/controller/ws/ws/handler_test.go index dbd232f..f279c34 100644 --- a/controller/ws/ws/handler_test.go +++ b/controller/ws/ws/handler_test.go @@ -102,6 +102,12 @@ func Test_WrongJsonProtocol(t *testing.T) { args: nil, errorPrefix: "request no action", }, + { + name: "version with wrong args", + action: wstool.ClientVersion, + args: "wrong", + errorPrefix: "unmarshal version request error", + }, { name: "connect with wrong args", action: Connect, diff --git a/controller/ws/ws/misc.go b/controller/ws/ws/misc.go index da4e634..573f210 100644 --- a/controller/ws/ws/misc.go +++ b/controller/ws/ws/misc.go @@ -8,6 +8,7 @@ import ( "github.com/taosdata/taosadapter/v3/db/syncinterface" errors2 "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/tools/melody" + "github.com/taosdata/taosadapter/v3/version" ) type getCurrentDBRequest struct { @@ -65,3 +66,27 @@ func (h *messageHandler) getServerInfo(ctx context.Context, session *melody.Sess } wstool.WSWriteJson(session, logger, resp) } + +type versionRequest struct { + ReqID uint64 `json:"req_id"` +} + +type versionResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + Version string `json:"version"` +} + +func (h *messageHandler) version(ctx context.Context, session *melody.Session, action string, req versionRequest, logger *logrus.Entry, isDebug bool) { + logger.Trace("get version") + resp := &versionResp{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + Version: version.TaosClientVersion, + } + wstool.WSWriteJson(session, logger, resp) +} diff --git a/controller/ws/ws/stmt2.go b/controller/ws/ws/stmt2.go index e2ff2d5..f7ccb5b 100644 --- a/controller/ws/ws/stmt2.go +++ b/controller/ws/ws/stmt2.go @@ -200,7 +200,7 @@ type stmt2UseResultResponse struct { ReqID uint64 `json:"req_id"` Timing int64 `json:"timing"` StmtID uint64 `json:"stmt_id"` - ResultID uint64 `json:"result_id"` + ID uint64 `json:"id"` FieldsCount int `json:"fields_count"` FieldsNames []string `json:"fields_names"` FieldsTypes jsontype.JsonUint8 `json:"fields_types"` @@ -227,7 +227,7 @@ func (h *messageHandler) stmt2UseResult(ctx context.Context, session *melody.Ses ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), StmtID: req.StmtID, - ResultID: idx, + ID: idx, FieldsCount: fieldsCount, FieldsNames: rowsHeader.ColNames, FieldsTypes: rowsHeader.ColTypes, diff --git a/controller/ws/ws/stmt2_test.go b/controller/ws/ws/stmt2_test.go index 55f0fc1..8486ec4 100644 --- a/controller/ws/ws/stmt2_test.go +++ b/controller/ws/ws/stmt2_test.go @@ -510,7 +510,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { assert.Equal(t, 0, useResultResp.Code, useResultResp.Message) // fetch - fetchReq := fetchRequest{ReqID: 8, ID: useResultResp.ResultID} + fetchReq := fetchRequest{ReqID: 8, ID: useResultResp.ID} resp, err = doWebSocket(ws, WSFetch, &fetchReq) assert.NoError(t, err) var fetchResp fetchResponse @@ -521,7 +521,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { assert.Equal(t, 1, fetchResp.Rows) // fetch block - fetchBlockReq := fetchBlockRequest{ReqID: 9, ID: useResultResp.ResultID} + fetchBlockReq := fetchBlockRequest{ReqID: 9, ID: useResultResp.ID} fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) assert.NoError(t, err) _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], useResultResp.FieldsTypes, fetchResp.Rows, useResultResp.Precision) @@ -531,7 +531,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { assert.Equal(t, float32(0.31), blockResult[0][3]) // free result - freeResultReq, _ := json.Marshal(freeResultRequest{ReqID: 10, ID: useResultResp.ResultID}) + freeResultReq, _ := json.Marshal(freeResultRequest{ReqID: 10, ID: useResultResp.ID}) a, _ := json.Marshal(Request{Action: WSFreeResult, Args: freeResultReq}) err = ws.WriteMessage(websocket.TextMessage, a) assert.NoError(t, err) diff --git a/controller/ws/ws/ws_test.go b/controller/ws/ws/ws_test.go index cf67435..bc9278b 100644 --- a/controller/ws/ws/ws_test.go +++ b/controller/ws/ws/ws_test.go @@ -131,11 +131,6 @@ func doWebSocketWithoutResp(ws *websocket.Conn, action string, arg interface{}) return nil } -type versionResponse struct { - commonResp - Version string -} - func TestVersion(t *testing.T) { s := httptest.NewServer(router) defer s.Close() @@ -150,7 +145,7 @@ func TestVersion(t *testing.T) { }() resp, err := doWebSocket(ws, wstool.ClientVersion, nil) assert.NoError(t, err) - var versionResp versionResponse + var versionResp versionResp err = json.Unmarshal(resp, &versionResp) assert.NoError(t, err) assert.Equal(t, 0, versionResp.Code, versionResp.Message)