diff --git a/controller/rest/restful.go b/controller/rest/restful.go index 8b56bd32..1c9b4b40 100644 --- a/controller/rest/restful.go +++ b/controller/rest/restful.go @@ -17,6 +17,7 @@ import ( "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/commonpool" + "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/driver/common" "github.com/taosdata/taosadapter/v3/driver/common/parser" tErrors "github.com/taosdata/taosadapter/v3/driver/errors" @@ -31,7 +32,6 @@ import ( "github.com/taosdata/taosadapter/v3/tools/generator" "github.com/taosdata/taosadapter/v3/tools/iptool" "github.com/taosdata/taosadapter/v3/tools/jsonbuilder" - "github.com/taosdata/taosadapter/v3/tools/layout" "github.com/taosdata/taosadapter/v3/tools/pool" "github.com/taosdata/taosadapter/v3/tools/sqltype" ) @@ -112,19 +112,30 @@ type TDEngineRestfulRespDoc struct { // @Router /rest/sql [post] func (ctl *Restful) sql(c *gin.Context) { db := c.Param("db") - timeZone := c.Query("tz") - location := time.UTC var err error logger := c.MustGet(LoggerKey).(*logrus.Entry) reqID := c.MustGet(config.ReqIDKey).(int64) - if len(timeZone) != 0 { - location, err = time.LoadLocation(timeZone) + connTimezone, exists := c.GetQuery("conn_tz") + location := time.UTC + if exists { + location, err = time.LoadLocation(connTimezone) if err != nil { - logger.Errorf("load location:%s error:%s", timeZone, err) + logger.Errorf("load conn_tz location:%s fail, error:%s", connTimezone, err) BadRequestResponseWithMsg(c, logger, 0xffff, err.Error()) return } + } else { + timezone, exists := c.GetQuery("tz") + if exists { + location, err = time.LoadLocation(timezone) + if err != nil { + logger.Errorf("load tz location:%s fail, error:%s", timezone, err) + BadRequestResponseWithMsg(c, logger, 0xffff, err.Error()) + return + } + } } + var returnObj bool if returnObjStr := c.Query("row_with_meta"); len(returnObjStr) != 0 { if returnObj, err = strconv.ParseBool(returnObjStr); err != nil { @@ -134,21 +145,7 @@ func (ctl *Restful) sql(c *gin.Context) { } } - timeBuffer := make([]byte, 0, 30) - DoQuery(c, db, func(builder *jsonbuilder.Stream, ts int64, precision int) { - timeBuffer = timeBuffer[:0] - switch precision { - case common.PrecisionMilliSecond: // milli-second - timeBuffer = time.Unix(ts/1e3, (ts%1e3)*1e6).In(location).AppendFormat(timeBuffer, layout.LayoutMillSecond) - case common.PrecisionMicroSecond: // micro-second - timeBuffer = time.Unix(ts/1e6, (ts%1e6)*1e3).In(location).AppendFormat(timeBuffer, layout.LayoutMicroSecond) - case common.PrecisionNanoSecond: // nano-second - timeBuffer = time.Unix(0, ts).In(location).AppendFormat(timeBuffer, layout.LayoutNanoSecond) - default: - logger.Errorf("unknown precision:%d", precision) - } - builder.WriteString(string(timeBuffer)) - }, reqID, returnObj, logger) + DoQuery(c, db, location, reqID, returnObj, logger) } type TDEngineRestfulResp struct { @@ -159,7 +156,7 @@ type TDEngineRestfulResp struct { Rows int `json:"rows,omitempty"` } -func DoQuery(c *gin.Context, db string, timeFunc ctools.FormatTimeFunc, reqID int64, returnObj bool, logger *logrus.Entry) { +func DoQuery(c *gin.Context, db string, location *time.Location, reqID int64, returnObj bool, logger *logrus.Entry) { var s time.Time isDebug := log.IsDebug() b, err := c.GetRawData() @@ -218,14 +215,37 @@ func DoQuery(c *gin.Context, db string, timeFunc ctools.FormatTimeFunc, reqID in } logger.Debugf("put connection finish, cost:%s", log.GetLogDuration(isDebug, s)) }() - + // set connection options + success := trySetConnectionOptions(c, taosConnect.TaosConnection, logger, isDebug) + if !success { + monitor.RestRecordResult(sqlType, false) + return + } if len(db) > 0 { // Attempt to select the database does not return even if there is an error // To avoid error reporting in the `create database` statement logger.Tracef("select db %s", db) _ = async.GlobalAsync.TaosExecWithoutResult(taosConnect.TaosConnection, logger, isDebug, fmt.Sprintf("use `%s`", db), reqID) } - execute(c, logger, isDebug, taosConnect.TaosConnection, sql, timeFunc, reqID, sqlType, returnObj) + execute(c, logger, isDebug, taosConnect.TaosConnection, sql, reqID, sqlType, returnObj, location) +} + +func trySetConnectionOptions(c *gin.Context, conn unsafe.Pointer, logger *logrus.Entry, isDebug bool) bool { + keys := [3]string{"conn_tz", "app", "ip"} + options := [3]int{common.TSDB_OPTION_CONNECTION_TIMEZONE, common.TSDB_OPTION_CONNECTION_USER_APP, common.TSDB_OPTION_CONNECTION_USER_IP} + for i := 0; i < 3; i++ { + val := c.Query(keys[i]) + if val != "" { + code := syncinterface.TaosOptionsConnection(conn, options[i], &val, logger, isDebug) + if code != httperror.SUCCESS { + errStr := wrapper.TaosErrorStr(nil) + logger.Errorf("set connection options error, option:%d, val:%s, code:%d, message:%s", options[i], val, code, errStr) + TaosErrorResponse(c, logger, code, errStr) + return false + } + } + } + return true } var ( @@ -241,7 +261,7 @@ var ( Timing = []byte(`,"timing":`) ) -func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect unsafe.Pointer, sql string, timeFormat ctools.FormatTimeFunc, reqID int64, sqlType sqltype.SqlType, returnObj bool) { +func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect unsafe.Pointer, sql string, reqID int64, sqlType sqltype.SqlType, returnObj bool, location *time.Location) { _, calculateTiming := c.Get(RequireTiming) st := c.MustGet(StartTimeKey) flushTiming := int64(0) @@ -364,6 +384,7 @@ func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect uns fetched := false pHeaderList := make([]unsafe.Pointer, fieldsCount) pStartList := make([]unsafe.Pointer, fieldsCount) + timeBuffer := make([]byte, 0, 30) for { if config.Conf.RestfulRowLimit > -1 && total == config.Conf.RestfulRowLimit { break @@ -412,7 +433,7 @@ func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect uns if returnObj { builder.WriteObjectField(rowsHeader.ColNames[column]) } - ctools.JsonWriteRawBlock(builder, rowsHeader.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, timeFormat) + ctools.JsonWriteRawBlock(builder, rowsHeader.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, location, timeBuffer, logger) if column != fieldsCount-1 { builder.WriteMore() } diff --git a/controller/rest/restful_test.go b/controller/rest/restful_test.go index 39a54d50..f940cb87 100644 --- a/controller/rest/restful_test.go +++ b/controller/rest/restful_test.go @@ -21,6 +21,7 @@ import ( "github.com/taosdata/taosadapter/v3/db" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" + "github.com/taosdata/taosadapter/v3/tools/layout" ) var router *gin.Engine @@ -674,3 +675,98 @@ func TestInternalError(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) } + +func TestSetConnectionOptions(t *testing.T) { + config.Conf.RestfulRowLimit = -1 + w := httptest.NewRecorder() + body := strings.NewReader("create database if not exists rest_test_options") + url := "/rest/sql?app=rest_test_options&ip=192.168.100.1&conn_tz=Europe/Moscow&tz=Asia/Shanghai" + req, _ := http.NewRequest(http.MethodPost, url, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") + router.ServeHTTP(w, req) + checkResp(t, w) + + defer func() { + body := strings.NewReader("drop database if exists rest_test_options") + req, _ := http.NewRequest(http.MethodPost, url, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + checkResp(t, w) + }() + + w = httptest.NewRecorder() + body = strings.NewReader("create table if not exists rest_test_options.t1(ts timestamp,v1 bool)") + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + checkResp(t, w) + + w = httptest.NewRecorder() + ts := "2024-12-04 12:34:56.789" + body = strings.NewReader(fmt.Sprintf(`insert into rest_test_options.t1 values ('%s',true)`, ts)) + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + checkResp(t, w) + + w = httptest.NewRecorder() + body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + var result TDEngineRestfulRespDoc + err := json.Unmarshal(w.Body.Bytes(), &result) + assert.NoError(t, err) + assert.Equal(t, 0, result.Code) + assert.Equal(t, 1, len(result.Data)) + + location, err := time.LoadLocation("Europe/Moscow") + assert.NoError(t, err) + expectTime, err := time.ParseInLocation("2006-01-02 15:04:05.000", ts, location) + assert.NoError(t, err) + expectTimeStr := expectTime.Format(layout.LayoutMillSecond) + assert.Equal(t, expectTimeStr, result.Data[0][0]) + t.Log(expectTimeStr, result.Data[0][0]) + + // wrong timezone + wrongTZUrl := "/rest/sql?app=rest_test_options&ip=192.168.100.1&tz=xxx" + body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) + req, _ = http.NewRequest(http.MethodPost, wrongTZUrl, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, 400, w.Code) + + // wrong conn_tz + wrongConnTZUrl := "/rest/sql?app=rest_test_options&ip=192.168.100.1&conn_tz=xxx" + body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) + req, _ = http.NewRequest(http.MethodPost, wrongConnTZUrl, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, 400, w.Code) + // wrong ip + wrongIPUrl := "/rest/sql?app=rest_test_options&ip=xxx.xxx.xxx.xxx&conn_tz=Europe/Moscow&tz=Asia/Shanghai" + req, _ = http.NewRequest(http.MethodPost, wrongIPUrl, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") + w = httptest.NewRecorder() + body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + err = json.Unmarshal(w.Body.Bytes(), &result) + assert.NoError(t, err) + assert.NotEqual(t, 0, result.Code) +} + +func checkResp(t *testing.T, w *httptest.ResponseRecorder) { + assert.Equal(t, 200, w.Code) + var result TDEngineRestfulRespDoc + err := json.Unmarshal(w.Body.Bytes(), &result) + assert.NoError(t, err) + assert.Equal(t, 0, result.Code) +} diff --git a/controller/ws/query/ws_test.go b/controller/ws/query/ws_test.go index baa99ef6..eb7476c3 100644 --- a/controller/ws/query/ws_test.go +++ b/controller/ws/query/ws_test.go @@ -1608,10 +1608,10 @@ func TestDropUser(t *testing.T) { assert.NoError(t, err) }() defer doRestful("drop user test_ws_query_drop_user", "") - code, message := doRestful("create user test_ws_query_drop_user pass 'pass'", "") + code, message := doRestful("create user test_ws_query_drop_user pass 'pass_123'", "") assert.Equal(t, 0, code, message) // connect - connReq := &WSConnectReq{ReqID: 1, User: "test_ws_query_drop_user", Password: "pass"} + connReq := &WSConnectReq{ReqID: 1, User: "test_ws_query_drop_user", Password: "pass_123"} resp, err := doWebSocket(ws, WSConnect, &connReq) assert.NoError(t, err) var connResp WSConnectResp diff --git a/controller/ws/schemaless/schemaless_test.go b/controller/ws/schemaless/schemaless_test.go index 1e54c2ae..934dded7 100644 --- a/controller/ws/schemaless/schemaless_test.go +++ b/controller/ws/schemaless/schemaless_test.go @@ -262,10 +262,10 @@ func TestDropUser(t *testing.T) { assert.NoError(t, err) }() defer doRestful("drop user test_ws_sml_drop_user", "") - code, message := doRestful("create user test_ws_sml_drop_user pass 'pass'", "") + code, message := doRestful("create user test_ws_sml_drop_user pass 'pass_123'", "") assert.Equal(t, 0, code, message) // connect - connReq := &schemalessConnReq{ReqID: 1, User: "test_ws_sml_drop_user", Password: "pass"} + connReq := &schemalessConnReq{ReqID: 1, User: "test_ws_sml_drop_user", Password: "pass_123"} resp, err := doWebSocket(ws, SchemalessConn, &connReq) assert.NoError(t, err) var connResp schemalessConnResp diff --git a/controller/ws/stmt/stmt_test.go b/controller/ws/stmt/stmt_test.go index 445a2fb8..ee86e668 100644 --- a/controller/ws/stmt/stmt_test.go +++ b/controller/ws/stmt/stmt_test.go @@ -1101,10 +1101,10 @@ func TestDropUser(t *testing.T) { assert.NoError(t, err) }() defer doRestful("drop user test_ws_stmt_drop_user", "") - code, message := doRestful("create user test_ws_stmt_drop_user pass 'pass'", "") + code, message := doRestful("create user test_ws_stmt_drop_user pass 'pass_123'", "") assert.Equal(t, 0, code, message) // connect - connReq := &StmtConnectReq{ReqID: 1, User: "test_ws_stmt_drop_user", Password: "pass"} + connReq := &StmtConnectReq{ReqID: 1, User: "test_ws_stmt_drop_user", Password: "pass_123"} resp, err := doWebSocket(ws, STMTConnect, &connReq) assert.NoError(t, err) var connResp StmtConnectResp diff --git a/controller/ws/tmq/tmq.go b/controller/ws/tmq/tmq.go index cdbcb18e..1cef7054 100644 --- a/controller/ws/tmq/tmq.go +++ b/controller/ws/tmq/tmq.go @@ -4,6 +4,8 @@ import ( "context" "encoding/binary" "encoding/json" + "errors" + "fmt" "net" "strconv" "sync" @@ -395,6 +397,9 @@ type TMQSubscribeReq struct { MsgConsumeExcluded string `json:"msg_consume_excluded"` SessionTimeoutMS string `json:"session_timeout_ms"` MaxPollIntervalMS string `json:"max_poll_interval_ms"` + TZ string `json:"tz"` + App string `json:"app"` + IP string `json:"ip"` } type TMQSubscribeResp struct { @@ -539,55 +544,77 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu logger.Tracef("tmq append topic:%s", topic) errCode = wrapper.TMQListAppend(topicList, topic) if errCode != 0 { - errStr := wrapper.TMQErr2Str(errCode) - logger.Errorf("tmq list append error, tpic:%s, code:%d, msg:%s", topic, errCode, errStr) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wsTMQErrorMsg(ctx, session, logger, int(errCode), errStr, action, req.ReqID, nil) + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(int(errCode), wrapper.TMQErr2Str(errCode)), fmt.Sprintf("tmq list append error, tpic:%s", topic)) return } } errCode = t.wrapperSubscribe(logger, isDebug, cPointer, topicList) if errCode != 0 { - errStr := wrapper.TMQErr2Str(errCode) - logger.Errorf("tmq subscribe error:%d %s", errCode, errStr) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wsTMQErrorMsg(ctx, session, logger, int(errCode), errStr, action, req.ReqID, nil) + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(int(errCode), wrapper.TMQErr2Str(errCode)), "tmq subscribe error") return } conn := wrapper.TMQGetConnect(cPointer) logger.Trace("get whitelist") whitelist, err := tool.GetWhitelist(conn) if err != nil { - logger.Errorf("get whitelist error:%s", err.Error()) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wstool.WSError(ctx, session, logger, err, action, req.ReqID) + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, err, "get whitelist error") return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) valid := tool.CheckWhitelist(whitelist, t.ip) if !valid { - logger.Errorf("whitelist prohibits current IP access, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wstool.WSErrorMsg(ctx, session, logger, 0xffff, "whitelist prohibits current IP access", action, req.ReqID) + errorExt := fmt.Sprintf("whitelist prohibits current IP access, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) + err = errors.New("whitelist prohibits current IP access") + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, err, errorExt) return } logger.Trace("register change whitelist") err = tool.RegisterChangeWhitelist(conn, t.whitelistChangeHandle) if err != nil { - logger.Errorf("register change whitelist error:%s", err) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wstool.WSError(ctx, session, logger, err, action, req.ReqID) + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, err, "register change whitelist error") return } logger.Trace("register drop user") err = tool.RegisterDropUser(conn, t.dropUserHandle) if err != nil { - logger.Errorf("register drop user error:%s", err) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wstool.WSError(ctx, session, logger, err, action, req.ReqID) + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, err, "register drop user error") + return + } + + // set connection ip + clientIP := t.ipStr + if req.IP != "" { + clientIP = req.IP + } + logger.Tracef("set connection ip, ip:%s", clientIP) + code := syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_IP, &clientIP, logger, isDebug) + logger.Trace("set connection ip done") + if code != 0 { + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set connection ip error") return } + // set timezone + if req.TZ != "" { + logger.Tracef("set timezone, tz:%s", req.TZ) + code = syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_TIMEZONE, &req.TZ, logger, isDebug) + logger.Trace("set timezone done") + if code != 0 { + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set timezone error") + return + } + } + // set connection app + if req.App != "" { + logger.Tracef("set app, app:%s", req.App) + code = syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &req.App, logger, isDebug) + logger.Trace("set app done") + if code != 0 { + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set app error") + return + } + } + t.conn = conn t.consumer = cPointer logger.Trace("start to wait signal") @@ -599,12 +626,24 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu }) } -func (t *TMQ) closeConsumerWithErrLog(logger *logrus.Entry, isDebug bool, consumer unsafe.Pointer) { +func (t *TMQ) closeConsumerWithErrLog( + ctx context.Context, + consumer unsafe.Pointer, + session *melody.Session, + logger *logrus.Entry, + isDebug bool, + action string, + reqID uint64, + err error, + errorExt string, +) { + logger.Errorf("%s, err: %s", errorExt, err) errCode := t.wrapperCloseConsumer(logger, isDebug, consumer) if errCode != 0 { errMsg := wrapper.TMQErr2Str(errCode) logger.Errorf("tmq close consumer error, consumer:%p, code:%d, msg:%s", t.consumer, errCode, errMsg) } + wstool.WSError(ctx, session, logger, err, action, reqID) } type TMQCommitReq struct { @@ -1577,8 +1616,7 @@ func (t *TMQ) wrapperConsumerNew(logger *logrus.Entry, isDebug bool, tmqConfig u logger.Tracef("new consumer result %x", uintptr(result.Consumer)) if len(result.ErrStr) > 0 { err = taoserrors.NewError(-1, result.ErrStr) - } - if result.Consumer == nil { + } else if result.Consumer == nil { err = taoserrors.NewError(-1, "new consumer return nil") } t.asyncLocker.Unlock() diff --git a/controller/ws/tmq/tmq_test.go b/controller/ws/tmq/tmq_test.go index 69fabf70..ebaf5c4b 100644 --- a/controller/ws/tmq/tmq_test.go +++ b/controller/ws/tmq/tmq_test.go @@ -3156,9 +3156,10 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message) } +// todo: not implemented //func TestDropUser(t *testing.T) { // defer doHttpSql("drop user test_tmq_drop_user") -// code, message := doHttpSql("create user test_tmq_drop_user pass 'pass'") +// code, message := doHttpSql("create user test_tmq_drop_user pass 'pass_123'") // assert.Equal(t, 0, code, message) // // dbName := "test_ws_tmq_drop_user" @@ -3186,7 +3187,7 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { // // subscribe // b, _ := json.Marshal(TMQSubscribeReq{ // User: "test_tmq_drop_user", -// Password: "pass", +// Password: "pass_123", // DB: dbName, // GroupID: "test", // Topics: []string{topic}, @@ -3206,3 +3207,124 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { // resp, err := doWebSocket(ws, wstool.ClientVersion, nil) // assert.Error(t, err, string(resp)) //} + +//type httpQueryResp struct { +// Code int `json:"code,omitempty"` +// Desc string `json:"desc,omitempty"` +// ColumnMeta [][]driver.Value `json:"column_meta,omitempty"` +// Data [][]driver.Value `json:"data,omitempty"` +// Rows int `json:"rows,omitempty"` +//} +// +//func restQuery(sql string, db string) *httpQueryResp { +// w := httptest.NewRecorder() +// body := strings.NewReader(sql) +// url := "/rest/sql" +// if db != "" { +// url = fmt.Sprintf("/rest/sql/%s", db) +// } +// req, _ := http.NewRequest(http.MethodPost, url, body) +// req.RemoteAddr = "127.0.0.1:33333" +// req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") +// router.ServeHTTP(w, req) +// if w.Code != http.StatusOK { +// return &httpQueryResp{ +// Code: w.Code, +// Desc: w.Body.String(), +// } +// } +// b, _ := io.ReadAll(w.Body) +// var res httpQueryResp +// _ = json.Unmarshal(b, &res) +// return &res +//} + +func TestConnectionOptions(t *testing.T) { + dbName := "test_ws_tmq_conn_options" + topic := "test_ws_tmq_conn_options_topic" + + before(t, dbName, topic) + + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/tmq", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + defer func() { + err = after(ws, dbName, topic) + assert.NoError(t, err) + }() + + // subscribe + b, _ := json.Marshal(TMQSubscribeReq{ + User: "root", + Password: "taosdata", + DB: dbName, + GroupID: "test", + Topics: []string{topic}, + AutoCommit: "false", + OffsetReset: "earliest", + SessionTimeoutMS: "100000", + App: "tmq_test_conn_protocol", + IP: "192.168.55.55", + TZ: "Asia/Shanghai", + }) + msg, err := doWebSocket(ws, TMQSubscribe, b) + assert.NoError(t, err) + var subscribeResp TMQSubscribeResp + err = json.Unmarshal(msg, &subscribeResp) + assert.NoError(t, err) + assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message) + + // todo: check connection options, C not implemented + //got := false + //for i := 0; i < 10; i++ { + // queryResp := restQuery("select conn_id from performance_schema.perf_connections where user_app = 'tmq_test_conn_protocol' and user_ip = '192.168.55.55'", "") + // if queryResp.Code == 0 && len(queryResp.Data) > 0 { + // got = true + // break + // } + // time.Sleep(time.Second) + //} + //assert.True(t, got) +} + +func TestWrongPass(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/tmq", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + // subscribe + b, _ := json.Marshal(TMQSubscribeReq{ + User: "root", + Password: "wrong_pass", + GroupID: "test", + Topics: []string{"test"}, + AutoCommit: "false", + OffsetReset: "earliest", + SessionTimeoutMS: "100000", + App: "tmq_test_conn_protocol", + IP: "192.168.55.55", + TZ: "Asia/Shanghai", + }) + msg, err := doWebSocket(ws, TMQSubscribe, b) + assert.NoError(t, err) + var subscribeResp TMQSubscribeResp + err = json.Unmarshal(msg, &subscribeResp) + assert.NoError(t, err) + assert.NotEqual(t, 0, subscribeResp.Code, subscribeResp.Message) +} diff --git a/controller/ws/ws/const.go b/controller/ws/ws/const.go index a1951a7b..d3f2254a 100644 --- a/controller/ws/ws/const.go +++ b/controller/ws/ws/const.go @@ -42,6 +42,9 @@ const ( STMT2Exec = "stmt2_exec" STMT2Result = "stmt2_result" STMT2Close = "stmt2_close" + + // options + OptionsConnection = "options_connection" ) const ( diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 88e47673..5efc3f2b 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -625,6 +625,20 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { config.ReqIDKey: req.ReqID, }) h.getServerInfo(ctx, session, action, req, logger, log.IsDebug()) + case OptionsConnection: + action = OptionsConnection + var req optionsConnectionRequest + if err := json.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal options connection request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal options connection request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.optionsConnection(ctx, session, action, req, logger, log.IsDebug()) default: h.logger.Errorf("unknown action %s", action) reqID := getReqID(request.Args) diff --git a/controller/ws/ws/handler_test.go b/controller/ws/ws/handler_test.go index f279c34b..c91a318f 100644 --- a/controller/ws/ws/handler_test.go +++ b/controller/ws/ws/handler_test.go @@ -26,10 +26,10 @@ func TestDropUser(t *testing.T) { assert.NoError(t, err) }() defer doRestful("drop user test_ws_drop_user", "") - code, message := doRestful("create user test_ws_drop_user pass 'pass'", "") + code, message := doRestful("create user test_ws_drop_user pass 'pass_123'", "") assert.Equal(t, 0, code, message) // connect - connReq := connRequest{ReqID: 1, User: "test_ws_drop_user", Password: "pass"} + connReq := connRequest{ReqID: 1, User: "test_ws_drop_user", Password: "pass_123"} resp, err := doWebSocket(ws, Connect, &connReq) assert.NoError(t, err) var connResp commonResp @@ -270,6 +270,12 @@ func Test_WrongJsonProtocol(t *testing.T) { args: "wrong", errorPrefix: "unmarshal get server info request error", }, + { + name: "options connection with wrong args", + action: OptionsConnection, + args: "wrong", + errorPrefix: "unmarshal options connection request error", + }, { name: "unknown action", action: "unknown", diff --git a/controller/ws/ws/misc.go b/controller/ws/ws/misc.go index 573f210f..b0f6a4b8 100644 --- a/controller/ws/ws/misc.go +++ b/controller/ws/ws/misc.go @@ -7,6 +7,7 @@ import ( "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/syncinterface" errors2 "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/tools/melody" "github.com/taosdata/taosadapter/v3/version" ) @@ -90,3 +91,30 @@ func (h *messageHandler) version(ctx context.Context, session *melody.Session, a } wstool.WSWriteJson(session, logger, resp) } + +type optionsConnectionRequest struct { + ReqID uint64 `json:"req_id"` + Options []*option `json:"options"` +} +type option struct { + Option int `json:"option"` + Value *string `json:"value"` +} + +func (h *messageHandler) optionsConnection(ctx context.Context, session *melody.Session, action string, req optionsConnectionRequest, logger *logrus.Entry, isDebug bool) { + logger.Trace("options connection") + for i := 0; i < len(req.Options); i++ { + code := syncinterface.TaosOptionsConnection(h.conn, req.Options[i].Option, req.Options[i].Value, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosErrorStr(nil) + val := "" + if req.Options[i].Value != nil { + val = *req.Options[i].Value + } + logger.Errorf("options connection error, option:%d, value:%s, code:%d, err:%s", req.Options[i].Option, val, code, errStr) + commonErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr) + return + } + } + commonSuccessResponse(ctx, session, logger, action, req.ReqID) +} diff --git a/controller/ws/ws/misc_test.go b/controller/ws/ws/misc_test.go index a2bafdf4..d605326b 100644 --- a/controller/ws/ws/misc_test.go +++ b/controller/ws/ws/misc_test.go @@ -6,9 +6,11 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" ) func TestGetCurrentDB(t *testing.T) { @@ -88,3 +90,95 @@ func TestGetServerInfo(t *testing.T) { assert.Equal(t, 0, serverInfoResp.Code, serverInfoResp.Message) t.Log(serverInfoResp.Info) } + +func TestOptionsConnection(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // set app name + app := "ws_test_options" + optionsConnectionReq := optionsConnectionRequest{ + ReqID: 2, + Options: []*option{ + {Option: common.TSDB_OPTION_CONNECTION_USER_APP, Value: &app}, + }, + } + resp, err = doWebSocket(ws, OptionsConnection, &optionsConnectionReq) + assert.NoError(t, err) + var optionsConnectionResp commonResp + err = json.Unmarshal(resp, &optionsConnectionResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), optionsConnectionResp.ReqID) + assert.Equal(t, 0, optionsConnectionResp.Code, optionsConnectionResp.Message) + + // get app name + got := false + for i := 0; i < 10; i++ { + queryResp := restQuery("select conn_id from performance_schema.perf_connections where user_app = 'ws_test_options'", "") + if queryResp.Code == 0 && len(queryResp.Data) > 0 { + got = true + break + } + time.Sleep(time.Second) + } + assert.True(t, got) + // clear app name + optionsConnectionReq = optionsConnectionRequest{ + ReqID: 3, + Options: []*option{ + {Option: common.TSDB_OPTION_CONNECTION_USER_APP, Value: nil}, + }, + } + resp, err = doWebSocket(ws, OptionsConnection, &optionsConnectionReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &optionsConnectionResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), optionsConnectionResp.ReqID) + assert.Equal(t, 0, optionsConnectionResp.Code, optionsConnectionResp.Message) + + // wrong option with nil value + optionsConnectionReq = optionsConnectionRequest{ + ReqID: 4, + Options: []*option{ + {Option: -10000, Value: nil}, + }, + } + resp, err = doWebSocket(ws, OptionsConnection, &optionsConnectionReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &optionsConnectionResp) + assert.NoError(t, err) + assert.Equal(t, uint64(4), optionsConnectionResp.ReqID) + assert.NotEqual(t, 0, optionsConnectionResp.Code) + // wrong option with non-nil value + optionsConnectionReq = optionsConnectionRequest{ + ReqID: 5, + Options: []*option{ + {Option: -10000, Value: &app}, + }, + } + resp, err = doWebSocket(ws, OptionsConnection, &optionsConnectionReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &optionsConnectionResp) + assert.NoError(t, err) + assert.Equal(t, uint64(5), optionsConnectionResp.ReqID) + assert.NotEqual(t, 0, optionsConnectionResp.Code) +} diff --git a/controller/ws/ws/query.go b/controller/ws/ws/query.go index 22162c81..7c71fb94 100644 --- a/controller/ws/ws/query.go +++ b/controller/ws/ws/query.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "unsafe" "github.com/sirupsen/logrus" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" @@ -12,7 +13,7 @@ import ( "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" "github.com/taosdata/taosadapter/v3/driver/common" - errors2 "github.com/taosdata/taosadapter/v3/driver/errors" + taoserrors "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" @@ -27,6 +28,9 @@ type connRequest struct { Password string `json:"password"` DB string `json:"db"` Mode *int `json:"mode"` + TZ string `json:"tz"` + App string `json:"app"` + IP string `json:"ip"` } func (h *messageHandler) connect(ctx context.Context, session *melody.Session, action string, req connRequest, logger *logrus.Entry, isDebug bool) { @@ -45,10 +49,7 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a conn, err := syncinterface.TaosConnect("", req.User, req.Password, req.DB, 0, logger, isDebug) if err != nil { - logger.Errorf("connect to TDengine error, err:%s", err) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, "connect to TDengine error") return } logger.Trace("get whitelist") @@ -56,19 +57,14 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a whitelist, err := tool.GetWhitelist(conn) logger.Debugf("get whitelist cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { - logger.Errorf("get whitelist error, err:%s", err) - syncinterface.TaosClose(conn, logger, isDebug) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, "get whitelist error") return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", h.ipStr, tool.IpNetSliceToString(whitelist)) valid := tool.CheckWhitelist(whitelist, h.ip) if !valid { - logger.Errorf("ip not in whitelist, ip:%s, whitelist:%s", h.ipStr, tool.IpNetSliceToString(whitelist)) - syncinterface.TaosClose(conn, logger, isDebug) - commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, "whitelist prohibits current IP access") + err = errors.New("ip not in whitelist") + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, "ip not in whitelist") return } s = log.GetLogNow(isDebug) @@ -76,11 +72,7 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a err = tool.RegisterChangeWhitelist(conn, h.whitelistChangeHandle) logger.Debugf("register whitelist change cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { - logger.Errorf("register whitelist change error, err:%s", err) - syncinterface.TaosClose(conn, logger, isDebug) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, "register whitelist change error") return } s = log.GetLogNow(isDebug) @@ -88,11 +80,7 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a err = tool.RegisterDropUser(conn, h.dropUserHandle) logger.Debugf("register drop user cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { - logger.Errorf("register drop user error, err:%s", err) - syncinterface.TaosClose(conn, logger, isDebug) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, "register drop user error") return } if req.Mode != nil { @@ -103,15 +91,44 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a code := wrapper.TaosSetConnMode(conn, common.TAOS_CONN_MODE_BI, 1) logger.Trace("set connection mode to BI done") if code != 0 { - logger.Errorf("set connection mode to BI error, err:%s", wrapper.TaosErrorStr(nil)) - syncinterface.TaosClose(conn, logger, isDebug) - commonErrorResponse(ctx, session, logger, action, req.ReqID, code, wrapper.TaosErrorStr(nil)) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set connection mode to BI error") return } default: - syncinterface.TaosClose(conn, logger, isDebug) - logger.Tracef("unexpected mode:%d", *req.Mode) - commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, fmt.Sprintf("unexpected mode:%d", *req.Mode)) + err = fmt.Errorf("unexpected mode:%d", *req.Mode) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, err.Error()) + return + } + } + // set connection ip + clientIP := h.ipStr + if req.IP != "" { + clientIP = req.IP + } + logger.Tracef("set connection ip, ip:%s", clientIP) + code := syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_IP, &clientIP, logger, isDebug) + logger.Trace("set connection ip done") + if code != 0 { + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set connection ip error") + return + } + // set timezone + if req.TZ != "" { + logger.Tracef("set timezone, tz:%s", req.TZ) + code = syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_TIMEZONE, &req.TZ, logger, isDebug) + logger.Trace("set timezone done") + if code != 0 { + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set timezone error") + return + } + } + // set connection app + if req.App != "" { + logger.Tracef("set app, app:%s", req.App) + code = syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &req.App, logger, isDebug) + logger.Trace("set app done") + if code != 0 { + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set app error") return } } @@ -121,6 +138,22 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a commonSuccessResponse(ctx, session, logger, action, req.ReqID) } +func handleConnectError(ctx context.Context, conn unsafe.Pointer, session *melody.Session, logger *logrus.Entry, isDebug bool, action string, reqID uint64, err error, errorExt string) { + var code int + var errStr string + taosError, ok := err.(*taoserrors.TaosError) + if ok { + code = int(taosError.Code) + errStr = taosError.ErrStr + } else { + code = 0xffff + errStr = err.Error() + } + logger.Errorf("%s, code:%d, message:%s", errorExt, code, errStr) + syncinterface.TaosClose(conn, logger, isDebug) + commonErrorResponse(ctx, session, logger, action, reqID, code, errStr) +} + type queryRequest struct { ReqID uint64 `json:"req_id"` Sql string `json:"sql"` diff --git a/controller/ws/ws/query_test.go b/controller/ws/ws/query_test.go index 93c58bb4..a9a9009a 100644 --- a/controller/ws/ws/query_test.go +++ b/controller/ws/ws/query_test.go @@ -62,14 +62,6 @@ func TestWSConnect(t *testing.T) { assert.Equal(t, "duplicate connections", connResp.Message) } -type TestConnRequest struct { - ReqID uint64 `json:"req_id"` - User string `json:"user"` - Password string `json:"password"` - DB string `json:"db"` - Mode int `json:"mode"` -} - func TestMode(t *testing.T) { s := httptest.NewServer(router) defer s.Close() @@ -84,7 +76,7 @@ func TestMode(t *testing.T) { }() wrongMode := 999 - connReq := TestConnRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: wrongMode} + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: &wrongMode} resp, err := doWebSocket(ws, Connect, &connReq) assert.NoError(t, err) var connResp commonResp @@ -96,7 +88,7 @@ func TestMode(t *testing.T) { //bi biMode := 0 - connReq = TestConnRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: biMode} + connReq = connRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: &biMode} resp, err = doWebSocket(ws, Connect, &connReq) assert.NoError(t, err) err = json.Unmarshal(resp, &connResp) @@ -106,9 +98,38 @@ func TestMode(t *testing.T) { } -func TestWrongConnect(t *testing.T) { - // mock tool.GetWhitelist return error +func TestConnectionOptions(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", IP: "192.168.44.55", App: "ws_test_conn_protocol", TZ: "Asia/Shanghai"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + // check connection options + got := false + for i := 0; i < 10; i++ { + queryResp := restQuery("select conn_id from performance_schema.perf_connections where user_app = 'ws_test_conn_protocol' and user_ip = '192.168.44.55'", "") + if queryResp.Code == 0 && len(queryResp.Data) > 0 { + got = true + break + } + time.Sleep(time.Second) + } + assert.True(t, got) } func TestWsQuery(t *testing.T) { diff --git a/db/commonpool/pool.go b/db/commonpool/pool.go index 9af5d545..775c91c1 100644 --- a/db/commonpool/pool.go +++ b/db/commonpool/pool.go @@ -14,6 +14,7 @@ import ( "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + "github.com/taosdata/taosadapter/v3/driver/common" tErrors "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" @@ -196,6 +197,7 @@ func (cp *ConnectorPool) Get() (unsafe.Pointer, error) { func (cp *ConnectorPool) Put(c unsafe.Pointer) error { wrapper.TaosResetCurrentDB(c) + wrapper.TaosOptionsConnection(c, common.TSDB_OPTION_CONNECTION_CLEAR, nil) return cp.pool.Put(c) } @@ -308,6 +310,9 @@ func getConnectDirect(connectionPool *ConnectorPool, clientIP net.IP) (*Conn, er if err != nil { return nil, err } + ipStr := clientIP.String() + // ignore error, because we have checked the ip + wrapper.TaosOptionsConnection(c, common.TSDB_OPTION_CONNECTION_USER_IP, &ipStr) return &Conn{ TaosConnection: c, pool: connectionPool, diff --git a/db/commonpool/pool_test.go b/db/commonpool/pool_test.go index 838415c4..e5c0b719 100644 --- a/db/commonpool/pool_test.go +++ b/db/commonpool/pool_test.go @@ -111,28 +111,28 @@ func TestChangePassword(t *testing.T) { c, err := GetConnection("root", "taosdata", net.ParseIP("127.0.0.1")) assert.NoError(t, err) - result := wrapper.TaosQuery(c.TaosConnection, "drop user test") + result := wrapper.TaosQuery(c.TaosConnection, "drop user test_change_pass_pool") assert.NotNil(t, result) wrapper.TaosFreeResult(result) - result = wrapper.TaosQuery(c.TaosConnection, "create user test pass 'test'") + result = wrapper.TaosQuery(c.TaosConnection, "create user test_change_pass_pool pass 'test_123'") assert.NotNil(t, result) errNo := wrapper.TaosError(result) assert.Equal(t, 0, errNo) wrapper.TaosFreeResult(result) defer func() { - r := wrapper.TaosQuery(c.TaosConnection, "drop user test") + r := wrapper.TaosQuery(c.TaosConnection, "drop user test_change_pass_pool") wrapper.TaosFreeResult(r) }() - conn, err := GetConnection("test", "test", net.ParseIP("127.0.0.1")) + conn, err := GetConnection("test_change_pass_pool", "test_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) - result = wrapper.TaosQuery(c.TaosConnection, "alter user test pass 'test2'") + result = wrapper.TaosQuery(c.TaosConnection, "alter user test_change_pass_pool pass 'test2_123'") assert.NotNil(t, result) errNo = wrapper.TaosError(result) - assert.Equal(t, 0, errNo) + assert.Equal(t, 0, errNo, wrapper.TaosErrorStr(result)) wrapper.TaosFreeResult(result) result = wrapper.TaosQuery(conn.TaosConnection, "show databases") @@ -149,11 +149,11 @@ func TestChangePassword(t *testing.T) { err = conn.Put() assert.NoError(t, err) - conn2, err := GetConnection("test", "test", net.ParseIP("127.0.0.1")) + conn2, err := GetConnection("test_change_pass_pool", "test_123", net.ParseIP("127.0.0.1")) assert.Error(t, err) assert.Nil(t, conn2) - conn3, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) + conn3, err := GetConnection("test_change_pass_pool", "test2_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn3) result2 := wrapper.TaosQuery(conn3.TaosConnection, "show databases") @@ -163,7 +163,7 @@ func TestChangePassword(t *testing.T) { err = conn3.Put() assert.NoError(t, err) - conn4, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) + conn4, err := GetConnection("test_change_pass_pool", "test2_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn4) result3 := wrapper.TaosQuery(conn4.TaosConnection, "show databases") @@ -178,27 +178,27 @@ func TestChangePasswordConcurrent(t *testing.T) { c, err := GetConnection("root", "taosdata", net.ParseIP("127.0.0.1")) assert.NoError(t, err) - result := wrapper.TaosQuery(c.TaosConnection, "drop user test") + result := wrapper.TaosQuery(c.TaosConnection, "drop user test_change_pass_con") assert.NotNil(t, result) wrapper.TaosFreeResult(result) - result = wrapper.TaosQuery(c.TaosConnection, "create user test pass 'test'") + result = wrapper.TaosQuery(c.TaosConnection, "create user test_change_pass_con pass 'test_123'") assert.NotNil(t, result) errNo := wrapper.TaosError(result) assert.Equal(t, 0, errNo) wrapper.TaosFreeResult(result) defer func() { - r := wrapper.TaosQuery(c.TaosConnection, "drop user test") + r := wrapper.TaosQuery(c.TaosConnection, "drop user test_change_pass_con") wrapper.TaosFreeResult(r) }() - conn, err := GetConnection("test", "test", net.ParseIP("127.0.0.1")) + conn, err := GetConnection("test_change_pass_con", "test_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) - result = wrapper.TaosQuery(c.TaosConnection, "alter user test pass 'test2'") + result = wrapper.TaosQuery(c.TaosConnection, "alter user test_change_pass_con pass 'test2_123'") assert.NotNil(t, result) errNo = wrapper.TaosError(result) - assert.Equal(t, 0, errNo) + assert.Equal(t, 0, errNo, wrapper.TaosErrorStr(result)) wrapper.TaosFreeResult(result) result = wrapper.TaosQuery(conn.TaosConnection, "show databases") @@ -215,7 +215,7 @@ func TestChangePasswordConcurrent(t *testing.T) { for i := 0; i < 5; i++ { go func() { defer wg.Done() - conn2, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) + conn2, err := GetConnection("test_change_pass_con", "test2_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn2) err = conn2.Put() @@ -223,11 +223,11 @@ func TestChangePasswordConcurrent(t *testing.T) { }() } wg.Wait() - conn2, err := GetConnection("test", "test", net.ParseIP("127.0.0.1")) + conn2, err := GetConnection("test_change_pass_con", "test_123", net.ParseIP("127.0.0.1")) assert.Error(t, err) assert.Nil(t, conn2) - conn3, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) + conn3, err := GetConnection("test_change_pass_con", "test2_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn3) result2 := wrapper.TaosQuery(conn3.TaosConnection, "show databases") @@ -237,7 +237,7 @@ func TestChangePasswordConcurrent(t *testing.T) { err = conn3.Put() assert.NoError(t, err) - conn4, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) + conn4, err := GetConnection("test_change_pass_con", "test2_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn4) result3 := wrapper.TaosQuery(conn4.TaosConnection, "show databases") diff --git a/db/syncinterface/wrapper.go b/db/syncinterface/wrapper.go index dd9c6222..ae189f17 100644 --- a/db/syncinterface/wrapper.go +++ b/db/syncinterface/wrapper.go @@ -421,3 +421,15 @@ func TaosStmt2GetStbFields(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug b thread.SyncLocker.Unlock() return code, count, fields } + +func TaosOptionsConnection(conn unsafe.Pointer, option int, value *string, logger *logrus.Entry, isDebug bool) int { + if value == nil { + logger.Tracef("call taos_options_connection, conn:%p, option:%d, value:", conn, option) + } else { + logger.Tracef("call taos_options_connection, conn:%p, option:%d, value:%s", conn, option, *value) + } + s := log.GetLogNow(isDebug) + code := wrapper.TaosOptionsConnection(conn, option, value) + logger.Debugf("taos_options_connection finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) + return code +} diff --git a/db/syncinterface/wrapper_test.go b/db/syncinterface/wrapper_test.go index 9917b7e8..ddd35808 100644 --- a/db/syncinterface/wrapper_test.go +++ b/db/syncinterface/wrapper_test.go @@ -641,3 +641,26 @@ func query(conn unsafe.Pointer, sql string) ([][]driver.Value, error) { } return result, nil } + +func TestTaosOptionsConnection(t *testing.T) { + reqID := generator.GetReqID() + var logger = logger.WithField("test", "TestTaosOptionsConnection").WithField(config.ReqIDKey, reqID) + conn, err := TaosConnect("", "root", "taosdata", "", 0, logger, isDebug) + if !assert.NoError(t, err) { + return + } + defer TaosClose(conn, logger, isDebug) + app := "test_sync_interface" + code := TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &app, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosErrorStr(nil) + t.Error(t, taoserrors.NewError(code, errStr)) + return + } + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, nil, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosErrorStr(nil) + t.Error(t, taoserrors.NewError(code, errStr)) + return + } +} diff --git a/db/tool/notify_test.go b/db/tool/notify_test.go index c1df3fbc..f5edc97c 100644 --- a/db/tool/notify_test.go +++ b/db/tool/notify_test.go @@ -266,13 +266,13 @@ func TestRegisterChangePass(t *testing.T) { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) assert.NoError(t, err) defer wrapper.TaosClose(conn) - err = exec(conn, "create user test_notify pass 'notify'") + err = exec(conn, "create user test_notify pass 'notify_123'") assert.NoError(t, err) defer func() { // ignore error _ = exec(conn, "drop user test_notify") }() - conn2, err := wrapper.TaosConnect("", "test_notify", "notify", "", 0) + conn2, err := wrapper.TaosConnect("", "test_notify", "notify_123", "", 0) assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) defer cancel() @@ -285,7 +285,7 @@ func TestRegisterChangePass(t *testing.T) { } ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second*5) defer cancel2() - err = exec(conn, "alter user test_notify pass 'test'") + err = exec(conn, "alter user test_notify pass 'test_123'") assert.NoError(t, err) select { case data := <-c: @@ -301,13 +301,13 @@ func TestRegisterDropUser(t *testing.T) { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) assert.NoError(t, err) defer wrapper.TaosClose(conn) - err = exec(conn, "create user test_drop_user pass 'notify'") + err = exec(conn, "create user test_drop_user pass 'notify_123'") assert.NoError(t, err) defer func() { // ignore error _ = exec(conn, "drop user test_drop_user") }() - conn2, err := wrapper.TaosConnect("", "test_drop_user", "notify", "", 0) + conn2, err := wrapper.TaosConnect("", "test_drop_user", "notify_123", "", 0) assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) defer cancel() diff --git a/driver/common/const.go b/driver/common/const.go index 49efee48..ae97d949 100644 --- a/driver/common/const.go +++ b/driver/common/const.go @@ -24,6 +24,14 @@ const ( TSDB_OPTION_USE_ADAPTER ) +const ( + TSDB_OPTION_CONNECTION_CLEAR = iota - 1 + TSDB_OPTION_CONNECTION_CHARSET + TSDB_OPTION_CONNECTION_TIMEZONE + TSDB_OPTION_CONNECTION_USER_IP + TSDB_OPTION_CONNECTION_USER_APP +) + const ( TMQ_RES_INVALID = -1 TMQ_RES_DATA = 1 diff --git a/driver/wrapper/notify_test.go b/driver/wrapper/notify_test.go index 8db0cb37..de102270 100644 --- a/driver/wrapper/notify_test.go +++ b/driver/wrapper/notify_test.go @@ -25,10 +25,10 @@ func TestNotify(t *testing.T) { _ = exec(conn, "drop user t_notify") }() _ = exec(conn, "drop user t_notify") - err = exec(conn, "create user t_notify pass 'notify'") + err = exec(conn, "create user t_notify pass 'notify_123'") assert.NoError(t, err) - conn2, err := TaosConnect("", "t_notify", "notify", "", 0) + conn2, err := TaosConnect("", "t_notify", "notify_123", "", 0) if err != nil { t.Error(err) return @@ -58,7 +58,7 @@ func TestNotify(t *testing.T) { t.Error(errCode, errStr) } - err = exec(conn, "alter user t_notify pass 'test'") + err = exec(conn, "alter user t_notify pass 'test_123'") assert.NoError(t, err) timeout, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() diff --git a/driver/wrapper/taosc.go b/driver/wrapper/taosc.go index e4ccb12b..de3f62ea 100644 --- a/driver/wrapper/taosc.go +++ b/driver/wrapper/taosc.go @@ -27,6 +27,9 @@ void taos_query_a_with_req_id_wrapper(TAOS *taos,const char *sql, void *param, i void taos_fetch_raw_block_a_wrapper(TAOS_RES *res, void *param){ return taos_fetch_raw_block_a(res,FetchRawBlockCallback,param); }; +int taos_options_connection_wrapper(TAOS *taos, TSDB_OPTION_CONNECTION option, void *arg) { + return taos_options_connection(taos,option,arg); +}; */ import "C" import ( @@ -287,3 +290,13 @@ func TaosGetServerInfo(conn unsafe.Pointer) string { info := C.taos_get_server_info(conn) return C.GoString(info) } + +// TaosOptionsConnection int taos_options_connection(TAOS *taos, TSDB_OPTION_CONNECTION option, const void *arg, ...) +func TaosOptionsConnection(conn unsafe.Pointer, option int, value *string) int { + cValue := unsafe.Pointer(nil) + if value != nil { + cValue = unsafe.Pointer(C.CString(*value)) + defer C.free(cValue) + } + return int(C.taos_options_connection_wrapper(conn, (C.TSDB_OPTION_CONNECTION)(option), cValue)) +} diff --git a/driver/wrapper/taosc_test.go b/driver/wrapper/taosc_test.go index 7a2c268e..ec533323 100644 --- a/driver/wrapper/taosc_test.go +++ b/driver/wrapper/taosc_test.go @@ -605,3 +605,111 @@ func TestTaosGetServerInfo(t *testing.T) { info := TaosGetServerInfo(conn) assert.NotEmpty(t, info) } + +func TestTaosOptionsConnection(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + ip := "192.168.9.9" + app := "test_options_connection" + // set ip + code := TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_IP, &ip) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + // set app + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &app) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + var values [][]driver.Value + for i := 0; i < 10; i++ { + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9' and user_app = 'test_options_connection'") + assert.NoError(t, err) + if len(values) == 1 { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 1, len(values)) + connID := values[0][0].(uint32) + + // clean app + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, nil) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + for i := 0; i < 10; i++ { + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9' and user_app = 'test_options_connection'") + assert.NoError(t, err) + if len(values) == 0 { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 0, len(values)) + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9'") + assert.NoError(t, err) + assert.Equal(t, 1, len(values)) + assert.Equal(t, connID, values[0][0].(uint32)) + + // set app + app = "test_options_2" + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &app) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + for i := 0; i < 20; i++ { + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9' and user_app = 'test_options_2'") + assert.NoError(t, err) + if len(values) == 1 { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 1, len(values)) + assert.Equal(t, connID, values[0][0].(uint32)) + + // clear ip + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_IP, nil) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + for i := 0; i < 10; i++ { + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9' and user_app = 'test_options_2'") + assert.NoError(t, err) + if len(values) == 0 { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 0, len(values)) + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_app = 'test_options_2'") + assert.NoError(t, err) + assert.Equal(t, 1, len(values)) + assert.Equal(t, connID, values[0][0].(uint32)) + + // clean all + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_CLEAR, nil) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + for i := 0; i < 10; i++ { + values, err = query(conn, fmt.Sprintf("select user_app,user_ip from performance_schema.perf_connections where conn_id = %d", connID)) + assert.NoError(t, err) + if len(values) == 1 && values[0][0].(string) == "" && values[0][1].(string) == "" { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 1, len(values)) + assert.Equal(t, "", values[0][0].(string)) + assert.Equal(t, "", values[0][1].(string)) + + ip = "192.168.9.9" + app = "test_options_connection" + // set ip + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_IP, &ip) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + // set app + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &app) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + for i := 0; i < 10; i++ { + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9' and user_app = 'test_options_connection'") + assert.NoError(t, err) + if len(values) == 1 { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 1, len(values)) + assert.Equal(t, connID, values[0][0].(uint32)) +} diff --git a/plugin/influxdb/plugin.go b/plugin/influxdb/plugin.go index 3c1696ac..cfcd62ee 100644 --- a/plugin/influxdb/plugin.go +++ b/plugin/influxdb/plugin.go @@ -9,7 +9,10 @@ import ( "github.com/gin-gonic/gin" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/commonpool" + "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/driver/common" tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" "github.com/taosdata/taosadapter/v3/plugin" @@ -188,6 +191,13 @@ func (p *Influxdb) write(c *gin.Context) { } }() conn := taosConn.TaosConnection + app := c.Query("app") + if app != "" { + errCode := syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &app, logger, isDebug) + if errCode != 0 { + logger.Errorf("set app error, app:%s, code:%d, msg:%s", app, errCode, wrapper.TaosErrorStr(nil)) + } + } s = log.GetLogNow(isDebug) logger.Tracef("start insert influxdb, data:%s", data) err = inserter.InsertInfluxdb(conn, data, db, precision, ttl, reqID, tableNameKey, logger) diff --git a/plugin/influxdb/plugin_test.go b/plugin/influxdb/plugin_test.go index 552f2468..1340b71b 100644 --- a/plugin/influxdb/plugin_test.go +++ b/plugin/influxdb/plugin_test.go @@ -61,19 +61,19 @@ func TestInfluxdb(t *testing.T) { }() w := httptest.NewRecorder() reader := strings.NewReader(fmt.Sprintf("measurement,host=host1 field1=%di,field2=2.0,fieldKey=\"Launch 🚀\" %d", number, time.Now().UnixNano())) - req, _ := http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb", reader) + req, _ := http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb&app=test_influxdb", reader) req.RemoteAddr = "127.0.0.1:33333" router.ServeHTTP(w, req) assert.Equal(t, 204, w.Code) w = httptest.NewRecorder() reader = strings.NewReader("measurement,host=host1 field1=a1") - req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb", reader) + req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb&app=test_influxdb", reader) req.RemoteAddr = "127.0.0.1:33333" router.ServeHTTP(w, req) assert.Equal(t, 500, w.Code) w = httptest.NewRecorder() reader = strings.NewReader(fmt.Sprintf("measurement,host=host1 field1=%di,field2=2.0,fieldKey=\"Launch 🚀\" %d", number, time.Now().UnixNano())) - req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata", reader) + req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata&app=test_influxdb", reader) req.RemoteAddr = "127.0.0.1:33333" router.ServeHTTP(w, req) assert.Equal(t, 400, w.Code) @@ -91,7 +91,7 @@ func TestInfluxdb(t *testing.T) { w = httptest.NewRecorder() reader = strings.NewReader(fmt.Sprintf("measurement_ttl,host=host1 field1=%di,field2=2.0,fieldKey=\"Launch 🚀\" %d", number, time.Now().UnixNano())) - req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb_ttl&ttl=1000", reader) + req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb_ttl&ttl=1000&app=test_influxdb", reader) req.RemoteAddr = "127.0.0.1:33333" router.ServeHTTP(w, req) time.Sleep(time.Second) diff --git a/plugin/opentsdb/plugin.go b/plugin/opentsdb/plugin.go index f22cfc35..a5229fac 100644 --- a/plugin/opentsdb/plugin.go +++ b/plugin/opentsdb/plugin.go @@ -11,7 +11,10 @@ import ( "github.com/gin-gonic/gin" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/commonpool" + "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/driver/common" tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" "github.com/taosdata/taosadapter/v3/plugin" @@ -152,6 +155,13 @@ func (p *Plugin) insertJson(c *gin.Context) { logger.WithError(putErr).Errorln("connect pool put error") } }() + app := c.Query("app") + if app != "" { + errCode := syncinterface.TaosOptionsConnection(taosConn.TaosConnection, common.TSDB_OPTION_CONNECTION_USER_APP, &app, logger, isDebug) + if errCode != 0 { + logger.Errorf("set app error, app:%s, code:%d, msg:%s", app, errCode, wrapper.TaosErrorStr(nil)) + } + } s = log.GetLogNow(isDebug) logger.Debugf("insert json payload, data:%s, db:%s, ttl:%d, table_name_key:%s", data, db, ttl, tableNameKey) err = inserter.InsertOpentsdbJson(taosConn.TaosConnection, data, db, ttl, reqID, tableNameKey, logger) @@ -271,6 +281,13 @@ func (p *Plugin) insertTelnet(c *gin.Context) { logger.WithError(putErr).Errorln("connect pool put error") } }() + app := c.Query("app") + if app != "" { + errCode := syncinterface.TaosOptionsConnection(taosConn.TaosConnection, common.TSDB_OPTION_CONNECTION_USER_APP, &app, logger, isDebug) + if errCode != 0 { + logger.Errorf("set app error, app:%s, code:%d, msg:%s", app, errCode, wrapper.TaosErrorStr(nil)) + } + } s = log.GetLogNow(isDebug) logger.Debugf("insert telnet payload, lines:%v, db:%s, ttl:%d, table_name_key: %s", lines, db, ttl, tableNameKey) err = inserter.InsertOpentsdbTelnetBatch(taosConn.TaosConnection, lines, db, ttl, reqID, tableNameKey, logger) diff --git a/plugin/opentsdb/plugin_test.go b/plugin/opentsdb/plugin_test.go index a2101fa0..4317e885 100644 --- a/plugin/opentsdb/plugin_test.go +++ b/plugin/opentsdb/plugin_test.go @@ -58,7 +58,7 @@ func TestOpentsdb(t *testing.T) { }() w := httptest.NewRecorder() reader := strings.NewReader(fmt.Sprintf("put metric %d %d host=web01 interface=eth0 ", time.Now().Unix(), number)) - req, _ := http.NewRequest("POST", "/put/telnet/test_plugin_opentsdb_http_telnet?ttl=1000", reader) + req, _ := http.NewRequest("POST", "/put/telnet/test_plugin_opentsdb_http_telnet?ttl=1000&app=test_telnet_http", reader) req.RemoteAddr = "127.0.0.1:33333" req.SetBasicAuth("root", "taosdata") router.ServeHTTP(w, req) @@ -73,7 +73,7 @@ func TestOpentsdb(t *testing.T) { "dc": "lga" } }`, time.Now().Unix(), number)) - req, _ = http.NewRequest("POST", "/put/json/test_plugin_opentsdb_http_json?ttl=1000", reader) + req, _ = http.NewRequest("POST", "/put/json/test_plugin_opentsdb_http_json?ttl=1000&app=test_json_http", reader) req.RemoteAddr = "127.0.0.1:33333" req.SetBasicAuth("root", "taosdata") router.ServeHTTP(w, req) diff --git a/tools/ctools/block.go b/tools/ctools/block.go index 72f0f08f..d7abdc76 100644 --- a/tools/ctools/block.go +++ b/tools/ctools/block.go @@ -3,16 +3,18 @@ package ctools import ( "math" "strconv" + "time" "unsafe" + "github.com/sirupsen/logrus" "github.com/taosdata/taosadapter/v3/driver/common" "github.com/taosdata/taosadapter/v3/driver/common/parser" "github.com/taosdata/taosadapter/v3/tools" + "github.com/taosdata/taosadapter/v3/tools/bytesutil" "github.com/taosdata/taosadapter/v3/tools/jsonbuilder" + "github.com/taosdata/taosadapter/v3/tools/layout" ) -type FormatTimeFunc func(builder *jsonbuilder.Stream, ts int64, precision int) - func IsVarDataType(colType uint8) bool { return colType == common.TSDB_DATA_TYPE_BINARY || colType == common.TSDB_DATA_TYPE_NCHAR || colType == common.TSDB_DATA_TYPE_JSON || colType == common.TSDB_DATA_TYPE_VARBINARY || colType == common.TSDB_DATA_TYPE_GEOMETRY } @@ -87,9 +89,20 @@ func WriteRawJsonDouble(builder *jsonbuilder.Stream, pStart unsafe.Pointer, row builder.WriteFloat64(math.Float64frombits(*((*uint64)(tools.AddPointer(pStart, uintptr(row)*parser.Float64Size))))) } -func WriteRawJsonTime(builder *jsonbuilder.Stream, pStart unsafe.Pointer, row int, precision int, timeFormat FormatTimeFunc) { - value := *((*int64)(tools.AddPointer(pStart, uintptr(row)*parser.Int64Size))) - timeFormat(builder, value, precision) +func WriteRawJsonTime(builder *jsonbuilder.Stream, pStart unsafe.Pointer, row int, precision int, location *time.Location, timeBuffer []byte, logger *logrus.Entry) { + ts := *((*int64)(tools.AddPointer(pStart, uintptr(row)*parser.Int64Size))) + timeBuffer = timeBuffer[:0] + switch precision { + case common.PrecisionMilliSecond: // milli-second + timeBuffer = time.Unix(ts/1e3, (ts%1e3)*1e6).In(location).AppendFormat(timeBuffer, layout.LayoutMillSecond) + case common.PrecisionMicroSecond: // micro-second + timeBuffer = time.Unix(ts/1e6, (ts%1e6)*1e3).In(location).AppendFormat(timeBuffer, layout.LayoutMicroSecond) + case common.PrecisionNanoSecond: // nano-second + timeBuffer = time.Unix(0, ts).In(location).AppendFormat(timeBuffer, layout.LayoutNanoSecond) + default: + logger.Errorf("unknown precision:%d", precision) + } + builder.WriteString(bytesutil.ToUnsafeString(timeBuffer)) } func WriteRawJsonBinary(builder *jsonbuilder.Stream, pHeader, pStart unsafe.Pointer, row int) { @@ -167,7 +180,7 @@ func WriteRawJsonJson(builder *jsonbuilder.Stream, pHeader, pStart unsafe.Pointe } } -func JsonWriteRawBlock(builder *jsonbuilder.Stream, colType uint8, pHeader, pStart unsafe.Pointer, row int, precision int, timeFormat FormatTimeFunc) { +func JsonWriteRawBlock(builder *jsonbuilder.Stream, colType uint8, pHeader, pStart unsafe.Pointer, row int, precision int, location *time.Location, timeBuffer []byte, logger *logrus.Entry) { if IsVarDataType(colType) { switch colType { case uint8(common.TSDB_DATA_TYPE_BINARY): @@ -209,7 +222,7 @@ func JsonWriteRawBlock(builder *jsonbuilder.Stream, colType uint8, pHeader, pSta case uint8(common.TSDB_DATA_TYPE_DOUBLE): WriteRawJsonDouble(builder, pStart, row) case uint8(common.TSDB_DATA_TYPE_TIMESTAMP): - WriteRawJsonTime(builder, pStart, row, precision, timeFormat) + WriteRawJsonTime(builder, pStart, row, precision, location, timeBuffer, logger) } } } diff --git a/tools/ctools/block_test.go b/tools/ctools/block_test.go index 03390fba..cb4550ff 100644 --- a/tools/ctools/block_test.go +++ b/tools/ctools/block_test.go @@ -6,6 +6,7 @@ import ( "time" "unsafe" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/taosdata/taosadapter/v3/driver/common" "github.com/taosdata/taosadapter/v3/driver/common/parser" @@ -13,6 +14,21 @@ import ( "github.com/taosdata/taosadapter/v3/tools/jsonbuilder" ) +type testHook struct { + hasError bool +} + +func (t *testHook) Levels() []logrus.Level { + return []logrus.Level{logrus.ErrorLevel} +} + +func (t *testHook) Fire(entry *logrus.Entry) error { + if entry.Level == logrus.ErrorLevel { + t.hasError = true + } + return nil +} + func TestJsonWriteRawBlock(t *testing.T) { raw := []byte{ 0x01, 0x00, 0x00, 0x00, @@ -134,13 +150,9 @@ func TestJsonWriteRawBlock(t *testing.T) { 0x07, 0x00, 0x7b, 0x22, 0x61, 0x22, 0x3a, 0x31, 0x7d, } - w := &strings.Builder{} - builder := jsonbuilder.BorrowStream(w) - defer jsonbuilder.ReturnStream(builder) fieldsCount := 17 fieldTypes := []uint8{9, 1, 2, 3, 4, 5, 11, 12, 13, 14, 6, 7, 8, 10, 16, 20, 15} blockSize := 2 - precision := 0 pHeaderList := make([]unsafe.Pointer, fieldsCount) pStartList := make([]unsafe.Pointer, fieldsCount) nullBitMapOffset := uintptr(BitmapLen(blockSize)) @@ -158,38 +170,91 @@ func TestJsonWriteRawBlock(t *testing.T) { } tmpPHeader = tools.AddPointer(pStartList[column], uintptr(colLength)) } - timeBuffer := make([]byte, 0, 30) - builder.WriteObjectStart() - for row := 0; row < blockSize; row++ { - builder.WriteArrayStart() - for column := 0; column < fieldsCount; column++ { - JsonWriteRawBlock(builder, fieldTypes[column], pHeaderList[column], pStartList[column], row, precision, func(builder *jsonbuilder.Stream, ts int64, precision int) { - timeBuffer = timeBuffer[:0] - switch precision { - case common.PrecisionMilliSecond: // milli-second - timeBuffer = time.Unix(0, ts*1e6).UTC().AppendFormat(timeBuffer, time.RFC3339Nano) - case common.PrecisionMicroSecond: // micro-second - timeBuffer = time.Unix(0, ts*1e3).UTC().AppendFormat(timeBuffer, time.RFC3339Nano) - case common.PrecisionNanoSecond: // nano-second - timeBuffer = time.Unix(0, ts).UTC().AppendFormat(timeBuffer, time.RFC3339Nano) - default: - panic("unknown precision") + cnLocation, _ := time.LoadLocation("Asia/Shanghai") + tests := []struct { + name string + precision int + timeZone *time.Location + expect string + expectError bool + }{ + { + name: "ms", + precision: common.PrecisionMilliSecond, + timeZone: time.UTC, + expect: `{["2022-08-10T07:02:40.500Z",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["2022-08-10T07:02:41.500Z",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "ns", + precision: common.PrecisionNanoSecond, + timeZone: time.UTC, + expect: `{["1970-01-01T00:27:40.114960500Z",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["1970-01-01T00:27:40.114961500Z",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "us", + precision: common.PrecisionMicroSecond, + timeZone: time.UTC, + expect: `{["1970-01-20T05:08:34.960500Z",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["1970-01-20T05:08:34.961500Z",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "ms_cn", + precision: common.PrecisionMilliSecond, + timeZone: cnLocation, + expect: `{["2022-08-10T15:02:40.500+08:00",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["2022-08-10T15:02:41.500+08:00",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "ns_cn", + precision: common.PrecisionNanoSecond, + timeZone: cnLocation, + expect: `{["1970-01-01T08:27:40.114960500+08:00",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["1970-01-01T08:27:40.114961500+08:00",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "us_cn", + precision: common.PrecisionMicroSecond, + timeZone: cnLocation, + expect: `{["1970-01-20T13:08:34.960500+08:00",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["1970-01-20T13:08:34.961500+08:00",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "invalid precision", + precision: -1, + timeZone: cnLocation, + expectError: true, + }, + } + for _, tt := range tests { + hook := &testHook{} + logger := logrus.New() + logger.AddHook(hook) + t.Run(tt.name, func(t *testing.T) { + w := &strings.Builder{} + builder := jsonbuilder.BorrowStream(w) + defer jsonbuilder.ReturnStream(builder) + builder.WriteObjectStart() + timeBuffer := make([]byte, 0, 35) + for row := 0; row < blockSize; row++ { + builder.WriteArrayStart() + for column := 0; column < fieldsCount; column++ { + JsonWriteRawBlock(builder, fieldTypes[column], pHeaderList[column], pStartList[column], row, tt.precision, tt.timeZone, timeBuffer, logger.WithField("test", "test")) + if column != fieldsCount-1 { + builder.WriteMore() + err := builder.Flush() + assert.NoError(t, err) + } + } + builder.WriteArrayEnd() + if row != blockSize-1 { + builder.WriteMore() } - builder.WriteString(string(timeBuffer)) - }) - if column != fieldsCount-1 { - builder.WriteMore() - err := builder.Flush() - assert.NoError(t, err) } - } - builder.WriteArrayEnd() - if row != blockSize-1 { - builder.WriteMore() - } + builder.WriteObjectEnd() + err := builder.Flush() + assert.NoError(t, err) + if tt.expectError { + assert.True(t, hook.hasError) + } else { + assert.False(t, hook.hasError) + assert.Equal(t, tt.expect, w.String()) + } + }) } - builder.WriteObjectEnd() - err := builder.Flush() - assert.NoError(t, err) - assert.Equal(t, `{["2022-08-10T07:02:40.5Z",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["2022-08-10T07:02:41.5Z",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, w.String()) }