From 2bb3e29968536856398e9ae266c20d87eed598a8 Mon Sep 17 00:00:00 2001 From: huskar-t <1172915550@qq.com> Date: Wed, 23 Oct 2024 00:20:46 +0800 Subject: [PATCH 01/48] fix: add max number of concurrent calls allowed for the C method configuration item --- config/config.go | 24 ++++++ config/config_test.go | 2 + config/config_windows_test.go | 2 + controller/rest/configcontroller.go | 8 +- controller/rest/restful.go | 6 +- controller/ws/query/ws.go | 12 +-- controller/ws/tmq/tmq.go | 6 +- controller/ws/ws/query_result.go | 2 +- db/async/row.go | 36 +++++--- db/init.go | 3 +- db/syncinterface/wrapper.go | 124 ++++++++++++++-------------- db/syncinterface/wrapper_test.go | 2 + db/tool/notify.go | 4 +- example/config/taosadapter.toml | 8 +- thread/locker.go | 21 +---- thread/locker_test.go | 9 -- 16 files changed, 149 insertions(+), 120 deletions(-) diff --git a/config/config.go b/config/config.go index 60228101..0d01effc 100644 --- a/config/config.go +++ b/config/config.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/pflag" "github.com/spf13/viper" + "github.com/taosdata/taosadapter/v3/thread" "github.com/taosdata/taosadapter/v3/version" "go.uber.org/automaxprocs/maxprocs" ) @@ -15,6 +16,8 @@ type Config struct { InstanceID uint8 Cors CorsConfig TaosConfigDir string + MaxSyncMethodLimit int + MaxAsyncMethodLimit int Debug bool Port int LogLevel string @@ -81,6 +84,8 @@ func Init() { HttpCodeServerError: viper.GetBool("httpCodeServerError"), SMLAutoCreateDB: viper.GetBool("smlAutoCreateDB"), InstanceID: uint8(viper.GetInt("instanceId")), + MaxSyncMethodLimit: viper.GetInt("maxSyncMethodLimit"), + MaxAsyncMethodLimit: viper.GetInt("maxAsyncMethodLimit"), } Conf.Log.setValue() Conf.Cors.setValue() @@ -99,6 +104,17 @@ func Init() { if !viper.IsSet("logLevel") { viper.Set("logLevel", "") } + maxAsyncMethodLimit := Conf.MaxAsyncMethodLimit + if maxAsyncMethodLimit == 0 { + maxAsyncMethodLimit = runtime.NumCPU() + } + thread.AsyncLocker = thread.NewLocker(maxAsyncMethodLimit) + + maxSyncMethodLimit := Conf.MaxSyncMethodLimit + if maxSyncMethodLimit == 0 { + maxSyncMethodLimit = runtime.NumCPU() + } + thread.SyncLocker = thread.NewLocker(maxSyncMethodLimit) } // arg > file > env @@ -135,6 +151,14 @@ func init() { _ = viper.BindEnv("instanceId", "TAOS_ADAPTER_INSTANCE_ID") pflag.Int("instanceId", 32, `instance ID. Env "TAOS_ADAPTER_INSTANCE_ID"`) + viper.SetDefault("maxSyncMethodLimit", 0) + _ = viper.BindEnv("maxSyncMethodLimit", "TAOS_ADAPTER_MAX_SYNC_METHOD_LIMIT") + pflag.Int("maxSyncMethodLimit", 0, `The maximum number of concurrent calls allowed for the C synchronized method. 0 means use CPU core count. Env "TAOS_ADAPTER_MAX_SYNC_METHOD_LIMIT"`) + + viper.SetDefault("maxAsyncMethodLimit", 0) + _ = viper.BindEnv("maxAsyncMethodLimit", "TAOS_ADAPTER_MAX_ASYNC_METHOD_LIMIT") + pflag.Int("maxAsyncMethodLimit", 0, `The maximum number of concurrent calls allowed for the C asynchronous method. 0 means use CPU core count. Env "TAOS_ADAPTER_MAX_ASYNC_METHOD_LIMIT"`) + initLog() initCors() initPool() diff --git a/config/config_test.go b/config/config_test.go index 6a59d883..9be30dee 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -36,6 +36,8 @@ func TestInit(t *testing.T) { AllowWebSockets: false, }, TaosConfigDir: "", + MaxSyncMethodLimit: 0, + MaxAsyncMethodLimit: 0, Debug: true, Port: 6041, LogLevel: "info", diff --git a/config/config_windows_test.go b/config/config_windows_test.go index d78a075b..42ef60ee 100644 --- a/config/config_windows_test.go +++ b/config/config_windows_test.go @@ -36,6 +36,8 @@ func TestInit(t *testing.T) { AllowWebSockets: false, }, TaosConfigDir: "", + MaxSyncMethodLimit: 0, + MaxAsyncMethodLimit: 0, Debug: true, Port: 6041, LogLevel: "info", diff --git a/controller/rest/configcontroller.go b/controller/rest/configcontroller.go index 9f600b50..4e63e7c8 100644 --- a/controller/rest/configcontroller.go +++ b/controller/rest/configcontroller.go @@ -2,6 +2,7 @@ package rest import ( "encoding/json" + "github.com/taosdata/driver-go/v3/wrapper" "net/http" "sync/atomic" @@ -10,7 +11,6 @@ import ( taoserrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/db/commonpool" - "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/iptool" @@ -50,13 +50,15 @@ func (ctl *ConfigController) changeConfig(c *gin.Context) { defer unlock() user := c.MustGet(UserKey).(string) password := c.MustGet(PasswordKey).(string) - conn, err := syncinterface.TaosConnect("", user, password, "", 0, logger, log.IsDebug()) - //conn, err := wrapper.TaosConnect("", user, password, "", 0) + conn, err := wrapper.TaosConnect("", user, password, "", 0) if err != nil { taosErr := err.(*taoserrors.TaosError) ErrorResponse(c, logger, http.StatusUnauthorized, int(taosErr.Code), taosErr.ErrStr) return } + defer func() { + wrapper.TaosClose(conn) + }() whitelist, err := tool.GetWhitelist(conn) if err != nil { logger.Errorf("get whitelist failed, err: %s", err) diff --git a/controller/rest/restful.go b/controller/rest/restful.go index 52fe6269..09328fa2 100644 --- a/controller/rest/restful.go +++ b/controller/rest/restful.go @@ -21,7 +21,6 @@ import ( "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/commonpool" - "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" @@ -223,7 +222,8 @@ func DoQuery(c *gin.Context, db string, timeFunc ctools.FormatTimeFunc, reqID in if len(db) > 0 { // Attempt to select the database does not return even if there is an error // To avoid error reporting in the `create database` statement - syncinterface.TaosSelectDB(taosConnect.TaosConnection, db, logger, isDebug) + logger.Tracef("select db %s", db) + _ = async.GlobalAsync.TaosExecWithoutResult(taosConnect.TaosConnection, logger, isDebug, fmt.Sprintf("use `%s`", db), reqID) } execute(c, logger, isDebug, taosConnect.TaosConnection, sql, timeFunc, reqID, sqlType, returnObj) } @@ -250,7 +250,7 @@ func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect uns result := async.GlobalAsync.TaosQuery(taosConnect, logger, isDebug, sql, handler, reqID) defer func() { if result != nil && result.Res != nil { - syncinterface.FreeResult(result.Res, logger, isDebug) + async.FreeResultAsync(result.Res, logger, isDebug) } }() res := result.Res diff --git a/controller/ws/query/ws.go b/controller/ws/query/ws.go index c3b53eea..c3cacbad 100644 --- a/controller/ws/query/ws.go +++ b/controller/ws/query/ws.go @@ -615,13 +615,13 @@ func (t *Taos) writeRaw(ctx context.Context, session *melody.Session, reqID, mes meta := wrapper.BuildRawMeta(length, metaType, data) s = log.GetLogNow(isDebug) logger.Trace("get thread lock for write raw meta") - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) logger.Trace("write raw meta") errCode := wrapper.TMQWriteRaw(t.conn, meta) logger.Debugf("write_raw_meta cost:%s", log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() if errCode != 0 { errStr := wrapper.TMQErr2Str(errCode) logger.Errorf("write raw meta error, code: %d, message: %s", errCode, errStr) @@ -658,12 +658,12 @@ func (t *Taos) writeRawBlock(ctx context.Context, session *melody.Session, reqID } logger.Trace("get thread lock for write raw block") s = log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) logger.Trace("write raw block") errCode := wrapper.TaosWriteRawBlockWithReqID(t.conn, numOfRows, rawBlock, tableName, int64(reqID)) - thread.Unlock() + thread.SyncLocker.Unlock() logger.Debugf("write raw cost:%s", log.GetLogDuration(isDebug, s)) if errCode != 0 { errStr := wrapper.TMQErr2Str(int32(errCode)) @@ -702,12 +702,12 @@ func (t *Taos) writeRawBlockWithFields(ctx context.Context, session *melody.Sess } logger.Trace("get thread lock for write raw block with fields") s = log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) logger.Trace("write raw block with fields") errCode := wrapper.TaosWriteRawBlockWithFieldsWithReqID(t.conn, numOfRows, rawBlock, tableName, fields, numFields, int64(reqID)) - thread.Unlock() + thread.SyncLocker.Unlock() logger.Debugf("write raw block with fields cost:%s", log.GetLogDuration(isDebug, s)) if errCode != 0 { errStr := wrapper.TMQErr2Str(int32(errCode)) diff --git a/controller/ws/tmq/tmq.go b/controller/ws/tmq/tmq.go index 477113bd..8a42fbea 100644 --- a/controller/ws/tmq/tmq.go +++ b/controller/ws/tmq/tmq.go @@ -1467,16 +1467,16 @@ func (t *TMQ) listTopics(ctx context.Context, session *melody.Session, req *TMQL isDebug := log.IsDebug() s := log.GetLogNow(isDebug) logger.Trace("subscription get thread lock") - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("subscription get thread lock cost:%s", log.GetLogDuration(isDebug, s)) if t.isClosed() { - thread.Unlock() + thread.SyncLocker.Unlock() logger.Trace("server closed") return } s = log.GetLogNow(isDebug) code, topicsPointer := wrapper.TMQSubscription(t.consumer) - thread.Unlock() + thread.SyncLocker.Unlock() logger.Debugf("subscription cost:%s", log.GetLogDuration(isDebug, s)) defer wrapper.TMQListDestroy(topicsPointer) if code != 0 { diff --git a/controller/ws/ws/query_result.go b/controller/ws/ws/query_result.go index 0f019b8f..27028489 100644 --- a/controller/ws/ws/query_result.go +++ b/controller/ws/ws/query_result.go @@ -44,7 +44,7 @@ func (r *QueryResult) free(logger *logrus.Entry) { return } logger.Tracef("free result:%d", r.index) - syncinterface.FreeResult(r.TaosResult, logger, log.IsDebug()) + async.FreeResultAsync(r.TaosResult, logger, log.IsDebug()) r.TaosResult = nil } diff --git a/db/async/row.go b/db/async/row.go index 0ffdddf8..d15d2f66 100644 --- a/db/async/row.go +++ b/db/async/row.go @@ -11,7 +11,6 @@ import ( tErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" - "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/thread" @@ -36,7 +35,7 @@ func (a *Async) TaosExec(taosConnect unsafe.Pointer, logger *logrus.Entry, isDeb var s time.Time defer func() { if result != nil && result.Res != nil { - syncinterface.FreeResult(result.Res, logger, isDebug) + FreeResultAsync(result.Res, logger, isDebug) } }() res := result.Res @@ -79,12 +78,12 @@ func (a *Async) TaosExec(taosConnect unsafe.Pointer, logger *logrus.Entry, isDeb var row unsafe.Pointer logger.Tracef("get thread lock for fetch row, row:%d", i) s = log.GetLogNow(isDebug) - thread.Lock() + thread.AsyncLocker.Lock() logger.Debugf("get thread lock for fetch row cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) row = wrapper.TaosFetchRow(res) logger.Debugf("taos_fetch_row finish, row:%p, cost:%s", row, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.AsyncLocker.Unlock() lengths := wrapper.FetchLengths(res, len(rowsHeader.ColNames)) logger.Tracef("fetch lengths:%d", lengths) values := make([]driver.Value, len(rowsHeader.ColNames)) @@ -114,12 +113,12 @@ func (a *Async) TaosQuery(taosConnect unsafe.Pointer, logger *logrus.Entry, isDe logger = logger.WithField(config.ReqIDKey, reqID) } s := log.GetLogNow(isDebug) - thread.Lock() + thread.AsyncLocker.Lock() logger.Debugf("get thread lock for taos_query_a cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) wrapper.TaosQueryAWithReqID(taosConnect, sql, handler.Handler, reqID) logger.Debugf("taos_query_a finish, cost:%s", log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.AsyncLocker.Unlock() logger.Trace("wait for query result") s = log.GetLogNow(isDebug) r := <-handler.Caller.QueryResult @@ -130,12 +129,12 @@ func (a *Async) TaosQuery(taosConnect unsafe.Pointer, logger *logrus.Entry, isDe func (a *Async) TaosFetchRowsA(res unsafe.Pointer, logger *logrus.Entry, isDebug bool, handler *Handler) *Result { logger.Tracef("call taos_fetch_rows_a, res:%p", res) s := log.GetLogNow(isDebug) - thread.Lock() + thread.AsyncLocker.Lock() logger.Debugf("get thread lock for fetch_rows_a cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) wrapper.TaosFetchRowsA(res, handler.Handler) logger.Debugf("taos_fetch_rows_a finish, cost:%s", log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.AsyncLocker.Unlock() logger.Trace("wait for fetch rows result") s = log.GetLogNow(isDebug) r := <-handler.Caller.FetchResult @@ -146,13 +145,13 @@ func (a *Async) TaosFetchRowsA(res unsafe.Pointer, logger *logrus.Entry, isDebug func (a *Async) TaosFetchRawBlockA(res unsafe.Pointer, logger *logrus.Entry, isDebug bool, handler *Handler) *Result { logger.Tracef("call taos_fetch_raw_block_a, res:%p", res) s := log.GetLogNow(isDebug) - thread.Lock() + thread.AsyncLocker.Lock() logger.Debugf("get thread lock for fetch_raw_block_a cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) logger.Trace("start fetch_raw_block_a") wrapper.TaosFetchRawBlockA(res, handler.Handler) logger.Debugf("taos_fetch_raw_block_a finish, cost:%s", log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.AsyncLocker.Unlock() logger.Trace("wait for fetch raw block result") s = log.GetLogNow(isDebug) r := <-handler.Caller.FetchResult @@ -173,7 +172,7 @@ func (a *Async) TaosExecWithoutResult(taosConnect unsafe.Pointer, logger *logrus result := a.TaosQuery(taosConnect, logger, isDebug, sql, handler, reqID) defer func() { if result != nil && result.Res != nil { - syncinterface.FreeResult(result.Res, logger, isDebug) + FreeResultAsync(result.Res, logger, isDebug) } }() res := result.Res @@ -185,3 +184,18 @@ func (a *Async) TaosExecWithoutResult(taosConnect unsafe.Pointer, logger *logrus } return nil } + +func FreeResultAsync(res unsafe.Pointer, logger *logrus.Entry, isDebug bool) { + logger.Tracef("call taos_free_result async, res:%p", res) + if res == nil { + logger.Trace("result is nil") + return + } + s := log.GetLogNow(isDebug) + thread.AsyncLocker.Lock() + logger.Debugf("get thread lock for free result cost:%s", log.GetLogDuration(isDebug, s)) + s = log.GetLogNow(isDebug) + wrapper.TaosFreeResult(res) + logger.Debugf("taos_free_result finish, cost:%s", log.GetLogDuration(isDebug, s)) + thread.AsyncLocker.Unlock() +} diff --git a/db/init.go b/db/init.go index 5de6e564..b8600039 100644 --- a/db/init.go +++ b/db/init.go @@ -30,6 +30,7 @@ func PrepareConnection() { err := errors.NewError(code, errStr) logger.WithError(err).Panic("set option TSDB_OPTION_USE_ADAPTER error") } + + async.GlobalAsync = async.NewAsync(async.NewHandlerPool(10000)) }) - async.GlobalAsync = async.NewAsync(async.NewHandlerPool(10000)) } diff --git a/db/syncinterface/wrapper.go b/db/syncinterface/wrapper.go index d59b5015..96e57c65 100644 --- a/db/syncinterface/wrapper.go +++ b/db/syncinterface/wrapper.go @@ -20,12 +20,12 @@ func FreeResult(res unsafe.Pointer, logger *logrus.Entry, isDebug bool) { return } s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for free result cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) wrapper.TaosFreeResult(res) logger.Debugf("taos_free_result finish, cost:%s", log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() } func TaosClose(conn unsafe.Pointer, logger *logrus.Entry, isDebug bool) { @@ -35,203 +35,203 @@ func TaosClose(conn unsafe.Pointer, logger *logrus.Entry, isDebug bool) { return } s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_close cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) wrapper.TaosClose(conn) logger.Debugf("taos_close finish, cost:%s", log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() } func TaosSelectDB(conn unsafe.Pointer, db string, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_select_db, conn:%p, db:%s", conn, db) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_select_db cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosSelectDB(conn, db) logger.Debugf("taos_select_db finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosConnect(host, user, pass, db string, port int, logger *logrus.Entry, isDebug bool) (unsafe.Pointer, error) { logger.Tracef("call taos_connect, host:%s, user:%s, pass:%s, db:%s, port:%d", host, user, pass, db, port) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_connect cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) conn, err := wrapper.TaosConnect(host, user, pass, db, port) logger.Debugf("taos_connect finish, conn:%p, err:%v, cost:%s", conn, err, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return conn, err } func TaosGetTablesVgID(conn unsafe.Pointer, db string, tables []string, logger *logrus.Entry, isDebug bool) ([]int, int) { logger.Tracef("call taos_get_tables_vgId, conn:%p, db:%s, tables:%s", conn, db, strings.Join(tables, ", ")) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_get_tables_vgId cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) vgIDs, code := wrapper.TaosGetTablesVgID(conn, db, tables) logger.Debugf("taos_get_tables_vgId finish, vgid:%v, code:%d, cost:%s", vgIDs, code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return vgIDs, code } func TaosStmtInitWithReqID(conn unsafe.Pointer, reqID int64, logger *logrus.Entry, isDebug bool) unsafe.Pointer { logger.Tracef("call taos_stmt_init_with_reqid, conn:%p, QID:0x%x", conn, reqID) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_init_with_reqid cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) stmtInit := wrapper.TaosStmtInitWithReqID(conn, reqID) logger.Debugf("taos_stmt_init_with_reqid result:%p, cost:%s", stmtInit, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return stmtInit } func TaosStmtPrepare(stmt unsafe.Pointer, sql string, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_stmt_init_with_reqid, stmt:%p, sql:%s", stmt, log.GetLogSql(sql)) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_prepare cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosStmtPrepare(stmt, sql) logger.Debugf("taos_stmt_prepare code:%d cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosStmtIsInsert(stmt unsafe.Pointer, logger *logrus.Entry, isDebug bool) (bool, int) { logger.Tracef("call taos_stmt_is_insert, stmt:%p", stmt) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_is_insert cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) isInsert, code := wrapper.TaosStmtIsInsert(stmt) logger.Debugf("taos_stmt_is_insert isInsert finish, insert:%t, code:%d, cost:%s", isInsert, code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return isInsert, code } func TaosStmtSetTBName(stmt unsafe.Pointer, tbname string, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_stmt_set_tbname, stmt:%p, tbname:%s", stmt, tbname) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_set_tbname cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosStmtSetTBName(stmt, tbname) logger.Debugf("taos_stmt_set_tbname finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosStmtGetTagFields(stmt unsafe.Pointer, logger *logrus.Entry, isDebug bool) (int, int, unsafe.Pointer) { logger.Tracef("call taos_stmt_get_tag_fields, stmt:%p", stmt) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_get_tag_fields cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code, num, fields := wrapper.TaosStmtGetTagFields(stmt) logger.Debugf("taos_stmt_get_tag_fields finish, code:%d, num:%d, fields:%p, cost:%s", code, num, fields, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code, num, fields } func TaosStmtSetTags(stmt unsafe.Pointer, tags []driver.Value, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_stmt_set_tags, stmt:%p, tags:%v", stmt, tags) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_set_tags cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosStmtSetTags(stmt, tags) logger.Debugf("taos_stmt_set_tags finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosStmtGetColFields(stmt unsafe.Pointer, logger *logrus.Entry, isDebug bool) (int, int, unsafe.Pointer) { logger.Tracef("call taos_stmt_get_col_fields, stmt:%p", stmt) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_get_col_fields cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code, num, fields := wrapper.TaosStmtGetColFields(stmt) logger.Debugf("taos_stmt_get_col_fields code:%d, num:%d, fields:%p, cost:%s", code, num, fields, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code, num, fields } func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bindType []*types.ColumnType, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_stmt_bind_param_batch, stmt:%p, multiBind:%v, bindType:%v", stmt, multiBind, bindType) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_bind_param_batch cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosStmtBindParamBatch(stmt, multiBind, bindType) logger.Debugf("taos_stmt_bind_param_batch code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosStmtAddBatch(stmt unsafe.Pointer, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_stmt_add_batch, stmt:%p", stmt) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_add_batch cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosStmtAddBatch(stmt) logger.Debugf("taos_stmt_add_batch code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosStmtExecute(stmt unsafe.Pointer, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_stmt_execute, stmt:%p", stmt) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_execute cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosStmtExecute(stmt) logger.Debugf("taos_stmt_execute code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosStmtClose(stmt unsafe.Pointer, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_stmt_close, stmt:%p", stmt) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_close cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosStmtClose(stmt) logger.Debugf("taos_stmt_close code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TMQWriteRaw(conn unsafe.Pointer, raw unsafe.Pointer, logger *logrus.Entry, isDebug bool) int32 { logger.Tracef("call tmq_write_raw, conn:%p, raw:%p", conn, raw) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for tmq_write_raw cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TMQWriteRaw(conn, raw) logger.Debugf("tmq_write_raw finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosWriteRawBlockWithReqID(conn unsafe.Pointer, numOfRows int, pData unsafe.Pointer, tableName string, reqID int64, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_write_raw_block_with_reqid, conn:%p, numOfRows:%d, pData:%p, tableName:%s, reqID:%d", conn, numOfRows, pData, tableName, reqID) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_write_raw_block_with_reqid cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosWriteRawBlockWithReqID(conn, numOfRows, pData, tableName, reqID) logger.Debugf("taos_write_raw_block_with_reqid finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } @@ -257,155 +257,155 @@ func TaosWriteRawBlockWithFieldsWithReqID( reqID, ) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_write_raw_block_with_fields_with_reqid cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosWriteRawBlockWithFieldsWithReqID(conn, numOfRows, pData, tableName, fields, numFields, reqID) logger.Debugf("taos_write_raw_block_with_fields_with_reqid finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosGetCurrentDB(conn unsafe.Pointer, logger *logrus.Entry, isDebug bool) (string, error) { logger.Tracef("call taos_get_current_db, conn:%p", conn) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_get_current_db cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) db, err := wrapper.TaosGetCurrentDB(conn) logger.Debugf("taos_get_current_db finish, db:%s, err:%v, cost:%s", db, err, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return db, err } func TaosGetServerInfo(conn unsafe.Pointer, logger *logrus.Entry, isDebug bool) string { logger.Tracef("call taos_get_server_info, conn:%p", conn) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_get_server_info cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) info := wrapper.TaosGetServerInfo(conn) logger.Debugf("taos_get_server_info finish, info:%s, cost:%s", info, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return info } func TaosStmtNumParams(stmt unsafe.Pointer, logger *logrus.Entry, isDebug bool) (int, int) { logger.Tracef("call taos_stmt_num_params, stmt:%p", stmt) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_num_params cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) num, errCode := wrapper.TaosStmtNumParams(stmt) logger.Debugf("taos_stmt_num_params finish, num:%d, code:%d, cost:%s", num, errCode, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return num, errCode } func TaosStmtGetParam(stmt unsafe.Pointer, index int, logger *logrus.Entry, isDebug bool) (int, int, error) { logger.Tracef("call taos_stmt_get_param, stmt:%p, index:%d", stmt, index) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_get_param cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) dataType, dataLength, err := wrapper.TaosStmtGetParam(stmt, index) logger.Debugf("taos_stmt_get_param finish, type:%d, len:%d, err:%v, cost:%s", dataType, dataLength, err, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return dataType, dataLength, err } func TaosSchemalessInsertRawTTLWithReqIDTBNameKey(conn unsafe.Pointer, lines string, protocol int, precision string, ttl int, reqID int64, tbNameKey string, logger *logrus.Entry, isDebug bool) (int32, unsafe.Pointer) { logger.Tracef("call taos_schemaless_insert_raw_ttl_with_reqid_tbname_key, conn:%p, lines:%s, protocol:%d, precision:%s, ttl:%d, reqID:%d, tbnameKey:%s", conn, lines, protocol, precision, ttl, reqID, tbNameKey) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_schemaless_insert_raw_ttl_with_reqid_tbname_key cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) rows, result := wrapper.TaosSchemalessInsertRawTTLWithReqIDTBNameKey(conn, lines, protocol, precision, ttl, reqID, tbNameKey) logger.Debugf("taos_schemaless_insert_raw_ttl_with_reqid_tbname_key finish, rows:%d, result:%p, cost:%s", rows, result, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return rows, result } func TaosStmt2Init(taosConnect unsafe.Pointer, reqID int64, singleStbInsert bool, singleTableBindOnce bool, handle cgo.Handle, logger *logrus.Entry, isDebug bool) unsafe.Pointer { logger.Tracef("call taos_stmt2_init, taosConnect:%p, reqID:%d, singleStbInsert:%t, singleTableBindOnce:%t, handle:%p", taosConnect, reqID, singleStbInsert, singleTableBindOnce, handle.Pointer()) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt2_init cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) stmt2 := wrapper.TaosStmt2Init(taosConnect, reqID, singleStbInsert, singleTableBindOnce, handle) logger.Debugf("taos_stmt2_init finish, stmt2:%p, cost:%s", stmt2, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return stmt2 } func TaosStmt2Prepare(stmt2 unsafe.Pointer, sql string, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_stmt2_prepare, stmt2:%p, sql:%s", stmt2, log.GetLogSql(sql)) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt2_prepare cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosStmt2Prepare(stmt2, sql) logger.Debugf("taos_stmt2_prepare finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosStmt2IsInsert(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug bool) (bool, int) { logger.Tracef("call taos_stmt2_is_insert, stmt2:%p", stmt2) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt2_is_insert cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) isInsert, code := wrapper.TaosStmt2IsInsert(stmt2) logger.Debugf("taos_stmt2_is_insert finish, isInsert:%t, code:%d, cost:%s", isInsert, code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return isInsert, code } func TaosStmt2GetFields(stmt2 unsafe.Pointer, fieldType int, logger *logrus.Entry, isDebug bool) (code, count int, fields unsafe.Pointer) { logger.Tracef("call taos_stmt2_get_fields, stmt2:%p, fieldType:%d", stmt2, fieldType) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt2_get_fields cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code, count, fields = wrapper.TaosStmt2GetFields(stmt2, fieldType) logger.Debugf("taos_stmt2_get_fields finish, code:%d, count:%d, fields:%p, cost:%s", code, count, fields, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code, count, fields } func TaosStmt2Exec(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_stmt2_exec, stmt2:%p", stmt2) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt2_exec cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosStmt2Exec(stmt2) logger.Debugf("taos_stmt2_exec finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosStmt2Close(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug bool) int { logger.Tracef("call taos_stmt2_close, stmt2:%p", stmt2) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt2_close cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) code := wrapper.TaosStmt2Close(stmt2) logger.Debugf("taos_stmt2_close finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return code } func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32, logger *logrus.Entry, isDebug bool) error { logger.Tracef("call taos_stmt_bind_binary, stmt2:%p, colIdx:%d, data:%v", stmt2, colIdx, data) s := log.GetLogNow(isDebug) - thread.Lock() + thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt_bind_binary cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) err := wrapper.TaosStmt2BindBinary(stmt2, data, colIdx) logger.Debugf("taos_stmt_bind_binary finish, err:%v, cost:%s", err, log.GetLogDuration(isDebug, s)) - thread.Unlock() + thread.SyncLocker.Unlock() return err } diff --git a/db/syncinterface/wrapper_test.go b/db/syncinterface/wrapper_test.go index 27ac018e..52e808b0 100644 --- a/db/syncinterface/wrapper_test.go +++ b/db/syncinterface/wrapper_test.go @@ -15,6 +15,7 @@ import ( "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/config" + "github.com/taosdata/taosadapter/v3/db" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/generator" ) @@ -26,6 +27,7 @@ const isDebug = true func TestMain(m *testing.M) { config.Init() log.SetLevel("trace") + db.PrepareConnection() m.Run() } func TestTaosConnect(t *testing.T) { diff --git a/db/tool/notify.go b/db/tool/notify.go index 9184f394..cf5dd8f5 100644 --- a/db/tool/notify.go +++ b/db/tool/notify.go @@ -45,9 +45,9 @@ func putWhiteListHandle(handle cgo.Handle) { func GetWhitelist(conn unsafe.Pointer) ([]*net.IPNet, error) { c, handler := getWhiteListHandle() defer putWhiteListHandle(handler) - thread.Lock() + thread.SyncLocker.Lock() wrapper.TaosFetchWhitelistA(conn, handler) - thread.Unlock() + thread.SyncLocker.Unlock() data := <-c if data.ErrCode != 0 { err := errors.NewError(int(data.ErrCode), wrapper.TaosErrorStr(nil)) diff --git a/example/config/taosadapter.toml b/example/config/taosadapter.toml index e0a7e111..b44e6bc1 100644 --- a/example/config/taosadapter.toml +++ b/example/config/taosadapter.toml @@ -16,12 +16,18 @@ smlAutoCreateDB = false # Instance ID of the taosAdapter. instanceId = 32 +# The maximum number of concurrent calls allowed for the C synchronized method.0 means use CPU core count. +# maxSyncMethodLimit = 0 + +# The maximum number of concurrent calls allowed for the C asynchronous method. 0 means use CPU core count. +#maxAsyncMethodLimit = 0 + [cors] # If set to true, allows cross-origin requests from any origin (CORS). allowAllOrigins = true [pool] -# The maximum number of connections to the server. If set to 0, no limit is imposed. +# The maximum number of connections to the server. If set to 0, use cpu count. # maxConnect = 0 # The maximum number of idle connections to the server. Should match maxConnect. diff --git a/thread/locker.go b/thread/locker.go index d87afc0b..9c475464 100644 --- a/thread/locker.go +++ b/thread/locker.go @@ -1,13 +1,12 @@ package thread -import ( - "runtime" -) - type Locker struct { c chan struct{} } +var SyncLocker *Locker +var AsyncLocker *Locker + func NewLocker(count int) *Locker { return &Locker{c: make(chan struct{}, count)} } @@ -19,17 +18,3 @@ func (l *Locker) Lock() { func (l *Locker) Unlock() { <-l.c } - -var c chan struct{} - -func Lock() { - c <- struct{}{} -} - -func Unlock() { - <-c -} - -func init() { - c = make(chan struct{}, runtime.NumCPU()) -} diff --git a/thread/locker_test.go b/thread/locker_test.go index f49be5c3..3d7c74b4 100644 --- a/thread/locker_test.go +++ b/thread/locker_test.go @@ -30,12 +30,3 @@ func TestNewLocker(t *testing.T) { }) } } - -// @author: xftan -// @date: 2021/12/14 15:16 -// @description: test DefaultLocker -func TestDefaultLocker(t *testing.T) { - Lock() - t.Log("success") - defer Unlock() -} From c4130da2f789389aae33d4cee4a3b0f0ff8a527a Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 23 Oct 2024 15:06:59 +0800 Subject: [PATCH 02/48] enh: rename config item --- config/config.go | 16 ++++++++-------- example/config/taosadapter.toml | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/config/config.go b/config/config.go index 0d01effc..0e54324d 100644 --- a/config/config.go +++ b/config/config.go @@ -84,8 +84,8 @@ func Init() { HttpCodeServerError: viper.GetBool("httpCodeServerError"), SMLAutoCreateDB: viper.GetBool("smlAutoCreateDB"), InstanceID: uint8(viper.GetInt("instanceId")), - MaxSyncMethodLimit: viper.GetInt("maxSyncMethodLimit"), - MaxAsyncMethodLimit: viper.GetInt("maxAsyncMethodLimit"), + MaxSyncMethodLimit: viper.GetInt("maxSyncConcurrentLimit"), + MaxAsyncMethodLimit: viper.GetInt("maxAsyncConcurrentLimit"), } Conf.Log.setValue() Conf.Cors.setValue() @@ -151,13 +151,13 @@ func init() { _ = viper.BindEnv("instanceId", "TAOS_ADAPTER_INSTANCE_ID") pflag.Int("instanceId", 32, `instance ID. Env "TAOS_ADAPTER_INSTANCE_ID"`) - viper.SetDefault("maxSyncMethodLimit", 0) - _ = viper.BindEnv("maxSyncMethodLimit", "TAOS_ADAPTER_MAX_SYNC_METHOD_LIMIT") - pflag.Int("maxSyncMethodLimit", 0, `The maximum number of concurrent calls allowed for the C synchronized method. 0 means use CPU core count. Env "TAOS_ADAPTER_MAX_SYNC_METHOD_LIMIT"`) + viper.SetDefault("maxSyncConcurrentLimit", 0) + _ = viper.BindEnv("maxSyncConcurrentLimit", "TAOS_ADAPTER_MAX_SYNC_CONCURRENT_LIMIT") + pflag.Int("maxSyncConcurrentLimit", 0, `The maximum number of concurrent calls allowed for the C synchronized method. 0 means use CPU core count. Env "TAOS_ADAPTER_MAX_SYNC_CONCURRENT_LIMIT"`) - viper.SetDefault("maxAsyncMethodLimit", 0) - _ = viper.BindEnv("maxAsyncMethodLimit", "TAOS_ADAPTER_MAX_ASYNC_METHOD_LIMIT") - pflag.Int("maxAsyncMethodLimit", 0, `The maximum number of concurrent calls allowed for the C asynchronous method. 0 means use CPU core count. Env "TAOS_ADAPTER_MAX_ASYNC_METHOD_LIMIT"`) + viper.SetDefault("maxAsyncConcurrentLimit", 0) + _ = viper.BindEnv("maxAsyncConcurrentLimit", "TAOS_ADAPTER_MAX_ASYNC_CONCURRENT_LIMIT") + pflag.Int("maxAsyncConcurrentLimit", 0, `The maximum number of concurrent calls allowed for the C asynchronous method. 0 means use CPU core count. Env "TAOS_ADAPTER_MAX_ASYNC_CONCURRENT_LIMIT"`) initLog() initCors() diff --git a/example/config/taosadapter.toml b/example/config/taosadapter.toml index b44e6bc1..4e60e272 100644 --- a/example/config/taosadapter.toml +++ b/example/config/taosadapter.toml @@ -17,10 +17,10 @@ smlAutoCreateDB = false instanceId = 32 # The maximum number of concurrent calls allowed for the C synchronized method.0 means use CPU core count. -# maxSyncMethodLimit = 0 +#maxSyncConcurrentLimit = 0 # The maximum number of concurrent calls allowed for the C asynchronous method. 0 means use CPU core count. -#maxAsyncMethodLimit = 0 +#maxAsyncConcurrentLimit = 0 [cors] # If set to true, allows cross-origin requests from any origin (CORS). From eb88166e65142bc2f6992c57bb5c70c3cf6d32ba Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 8 Nov 2024 14:55:16 +0800 Subject: [PATCH 03/48] test: add unit test --- controller/ws/ws/handler.go | 39 ++++++++++++++----------------------- controller/ws/ws/ws_test.go | 32 ++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 6e89597a..9020793c 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -98,13 +98,7 @@ func (h *messageHandler) waitSignal(logger *logrus.Entry) { return } logger.Info("user dropped, close connection") - s := log.GetLogNow(isDebug) - h.session.Close() - h.Unlock() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) - h.Close() - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + h.signalExit(logger, isDebug) return case <-h.whitelistChangeChan: logger.Info("get whitelist change signal") @@ -116,32 +110,17 @@ func (h *messageHandler) waitSignal(logger *logrus.Entry) { return } logger.Trace("get whitelist") - s := log.GetLogNow(isDebug) whitelist, err := tool.GetWhitelist(h.conn) if err != nil { logger.Errorf("get whitelist error, close connection, err:%s", err) - s = log.GetLogNow(isDebug) - h.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - h.Unlock() - s = log.GetLogNow(isDebug) - h.Close() - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + h.signalExit(logger, isDebug) return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", h.ipStr, tool.IpNetSliceToString(whitelist)) valid := tool.CheckWhitelist(whitelist, h.ip) if !valid { logger.Errorf("ip not in whitelist! close connection, ip:%s, whitelist:%s", h.ipStr, tool.IpNetSliceToString(whitelist)) - logger.Trace("close session") - s = log.GetLogNow(isDebug) - h.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - h.Unlock() - logger.Trace("close handler") - s = log.GetLogNow(isDebug) - h.Close() - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + h.signalExit(logger, isDebug) return } h.Unlock() @@ -151,6 +130,18 @@ func (h *messageHandler) waitSignal(logger *logrus.Entry) { } } +func (h *messageHandler) signalExit(logger *logrus.Entry, isDebug bool) { + logger.Trace("close session") + s := log.GetLogNow(isDebug) + h.session.Close() + logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) + h.Unlock() + logger.Trace("close handler") + s = log.GetLogNow(isDebug) + h.Close() + logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) +} + func (h *messageHandler) lock(logger *logrus.Entry, isDebug bool) { logger.Trace("get handler lock") s := log.GetLogNow(isDebug) diff --git a/controller/ws/ws/ws_test.go b/controller/ws/ws/ws_test.go index 0863d399..f7244f43 100644 --- a/controller/ws/ws/ws_test.go +++ b/controller/ws/ws/ws_test.go @@ -3316,3 +3316,35 @@ func TestWSTMQWriteRaw(t *testing.T) { } } } + +func TestDropUser(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err := ws.Close() + assert.NoError(t, err) + }() + defer doRestful("drop user test_ws_drop_user", "") + code, message := doRestful("create user test_ws_drop_user pass 'pass'", "") + assert.Equal(t, 0, code, message) + // connect + connReq := ConnRequest{ReqID: 1, User: "test_ws_drop_user", Password: "pass"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp BaseResponse + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + // drop user + code, message = doRestful("drop user test_ws_drop_user", "") + assert.Equal(t, 0, code, message) + time.Sleep(time.Second * 3) + resp, err = doWebSocket(ws, wstool.ClientVersion, nil) + assert.Error(t, err, resp) +} From c292a68010bd8f1025644da132def58c5d69ecdf Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 8 Nov 2024 16:04:51 +0800 Subject: [PATCH 04/48] test: add unit test --- log/logger.go | 7 +- log/logger_test.go | 223 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+), 5 deletions(-) diff --git a/log/logger.go b/log/logger.go index bde31567..0f169dd4 100644 --- a/log/logger.go +++ b/log/logger.go @@ -252,14 +252,11 @@ func (t *TaosLogFormatter) Format(entry *logrus.Entry) ([]byte, error) { keys = append(keys, k) } for _, k := range keys { - v := entry.Data[k] - if k == config.ReqIDKey && v == nil { - continue - } + value := entry.Data[k] b.WriteString(", ") b.WriteString(k) b.WriteByte(':') - fmt.Fprintf(b, "%v", v) + fmt.Fprintf(b, "%v", value) } b.WriteByte('\n') diff --git a/log/logger_test.go b/log/logger_test.go index 44191728..086521d0 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -1,9 +1,14 @@ package log import ( + "bytes" + "context" + "fmt" "strings" "testing" + "time" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/taosdata/taosadapter/v3/config" ) @@ -21,6 +26,10 @@ func TestConfigLog(t *testing.T) { ConfigLog() logger := GetLogger("TST") logger.Info("test config log") + time.Sleep(time.Second * 6) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + Close(ctx) } func TestIsDebug(t *testing.T) { @@ -49,3 +58,217 @@ func TestGetLogSql(t *testing.T) { sql = GetLogSql(str) assert.Equal(t, sql, str[:MaxLogSqlLength]) } +func TestTaosLogFormatter_Format(t1 *testing.T) { + type args struct { + entry *logrus.Entry + } + tests := []struct { + name string + args args + want []byte + wantErr assert.ErrorAssertionFunc + }{ + { + name: "common_panic", + args: args{ + entry: &logrus.Entry{ + Time: time.Unix(1657084598, 0), + Message: "select 1\n", + Data: map[string]interface{}{ + config.ModelKey: "test", + config.SessionIDKey: 1, + config.ReqIDKey: 1, + "ext": "111", + }, + Level: logrus.PanicLevel, + }, + }, + want: []byte(fmt.Sprintf("%s %s test PANIC SID:0x1, QID:0x1 select 1, ext:111\n", time.Unix(1657084598, 0).Format("01/02 15:04:05.000000"), ServerID)), + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + if err != nil { + t.Errorf("%s,%v", err.Error(), i) + return false + } + return true + }, + }, + { + name: "common_fatal", + args: args{ + entry: &logrus.Entry{ + Time: time.Unix(1657084598, 0), + Message: "select 1\n", + Data: map[string]interface{}{ + config.ModelKey: "test", + config.SessionIDKey: 1, + config.ReqIDKey: 1, + "ext": "111", + }, + Level: logrus.FatalLevel, + }, + }, + want: []byte(fmt.Sprintf("%s %s test FATAL SID:0x1, QID:0x1 select 1, ext:111\n", time.Unix(1657084598, 0).Format("01/02 15:04:05.000000"), ServerID)), + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + if err != nil { + t.Errorf("%s,%v", err.Error(), i) + return false + } + return true + }, + }, + { + name: "common_error", + args: args{ + entry: &logrus.Entry{ + Time: time.Unix(1657084598, 0), + Message: "select 1\n", + Data: map[string]interface{}{ + config.ModelKey: "test", + config.SessionIDKey: 1, + config.ReqIDKey: 1, + "ext": "111", + }, + Level: logrus.ErrorLevel, + }, + }, + want: []byte(fmt.Sprintf("%s %s test ERROR SID:0x1, QID:0x1 select 1, ext:111\n", time.Unix(1657084598, 0).Format("01/02 15:04:05.000000"), ServerID)), + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + if err != nil { + t.Errorf("%s,%v", err.Error(), i) + return false + } + return true + }, + }, + { + name: "common_warn", + args: args{ + entry: &logrus.Entry{ + Time: time.Unix(1657084598, 0), + Message: "select 1\n", + Data: map[string]interface{}{ + config.ModelKey: "test", + config.SessionIDKey: 1, + config.ReqIDKey: 1, + "ext": "111", + }, + Level: logrus.WarnLevel, + }, + }, + want: []byte(fmt.Sprintf("%s %s test WARN SID:0x1, QID:0x1 select 1, ext:111\n", time.Unix(1657084598, 0).Format("01/02 15:04:05.000000"), ServerID)), + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + if err != nil { + t.Errorf("%s,%v", err.Error(), i) + return false + } + return true + }, + }, + { + name: "common_info", + args: args{ + entry: &logrus.Entry{ + Time: time.Unix(1657084598, 0), + Message: "select 1\n", + Data: map[string]interface{}{ + config.ModelKey: "test", + config.SessionIDKey: 1, + config.ReqIDKey: 1, + "ext": "111", + }, + Level: logrus.InfoLevel, + }, + }, + want: []byte(fmt.Sprintf("%s %s test INFO SID:0x1, QID:0x1 select 1, ext:111\n", time.Unix(1657084598, 0).Format("01/02 15:04:05.000000"), ServerID)), + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + if err != nil { + t.Errorf("%s,%v", err.Error(), i) + return false + } + return true + }, + }, + { + name: "common_debug", + args: args{ + entry: &logrus.Entry{ + Time: time.Unix(1657084598, 0), + Message: "select 1\n", + Data: map[string]interface{}{ + config.ModelKey: "test", + config.SessionIDKey: 1, + config.ReqIDKey: 1, + "ext": "111", + }, + Level: logrus.DebugLevel, + }, + }, + want: []byte(fmt.Sprintf("%s %s test DEBUG SID:0x1, QID:0x1 select 1, ext:111\n", time.Unix(1657084598, 0).Format("01/02 15:04:05.000000"), ServerID)), + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + if err != nil { + t.Errorf("%s,%v", err.Error(), i) + return false + } + return true + }, + }, + { + name: "common_debug", + args: args{ + entry: &logrus.Entry{ + Time: time.Unix(1657084598, 0), + Message: "select 1\n", + Data: map[string]interface{}{ + config.ModelKey: "test", + config.SessionIDKey: 1, + config.ReqIDKey: 1, + "ext": "111", + }, + Level: logrus.TraceLevel, + }, + }, + want: []byte(fmt.Sprintf("%s %s test TRACE SID:0x1, QID:0x1 select 1, ext:111\n", time.Unix(1657084598, 0).Format("01/02 15:04:05.000000"), ServerID)), + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + if err != nil { + t.Errorf("%s,%v", err.Error(), i) + return false + } + return true + }, + }, + { + name: "common", + args: args{ + entry: &logrus.Entry{ + Time: time.Unix(1657084598, 0), + Message: "select 1", + Data: map[string]interface{}{ + config.SessionIDKey: 1, + config.ReqIDKey: nil, + "ext": "111", + }, + Level: logrus.InfoLevel, + Buffer: &bytes.Buffer{}, + }, + }, + want: []byte(fmt.Sprintf("%s %s CLI INFO SID:0x1, select 1, ext:111\n", time.Unix(1657084598, 0).Format("01/02 15:04:05.000000"), ServerID)), + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + if err != nil { + t.Errorf("%s,%v", err.Error(), i) + return false + } + return true + }, + }, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + t := &TaosLogFormatter{} + got, err := t.Format(tt.args.entry) + if !tt.wantErr(t1, err, fmt.Sprintf("Format(%v)", tt.args.entry)) { + return + } + assert.Equalf(t1, tt.want, got, "Format(%v)", tt.args.entry) + }) + } +} From 9272bb2e11a16183f4d440beabb886fd2cfcfc2f Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 13 Nov 2024 14:38:39 +0800 Subject: [PATCH 05/48] ci: add golangci-lint --- .github/workflows/linux.yml | 53 ++++- .golangci.yaml | 38 ++++ benchmark/remotewrite/main.go | 9 +- config/config.go | 5 +- controller/controller_test.go | 2 +- controller/ping/controller.go | 5 +- controller/rest/configcontroller.go | 2 +- controller/rest/restful.go | 141 +++++++------- controller/rest/restful_test.go | 2 +- controller/ws/query/ws.go | 123 ++++++------ controller/ws/query/ws_test.go | 34 ++-- controller/ws/schemaless/schemaless.go | 40 ++-- controller/ws/schemaless/schemaless_test.go | 5 +- controller/ws/stmt/convert.go | 5 +- controller/ws/stmt/stmt.go | 124 ++++++------ controller/ws/stmt/stmt_test.go | 17 +- controller/ws/tmq/const.go | 2 - controller/ws/tmq/tmq.go | 51 +++-- controller/ws/tmq/tmq_test.go | 202 +++++++++++++------- controller/ws/ws/handler.go | 130 ++++++------- controller/ws/ws/ws_test.go | 58 +++--- controller/ws/wstool/const.go | 4 +- controller/ws/wstool/error.go | 16 +- controller/ws/wstool/error_test.go | 11 +- controller/ws/wstool/resp.go | 10 +- controller/ws/wstool/resp_test.go | 7 +- db/async/handlerpool.go | 6 +- db/async/handlerpool_test.go | 2 +- db/async/row.go | 55 +++--- db/async/row_test.go | 4 +- db/asynctmq/tmq.go | 3 + db/asynctmq/tmq_windows.go | 1 + db/asynctmq/tmqhandle/handler.go | 6 +- db/asynctmq/tmqhandle/handler_test.go | 12 +- db/commonpool/pool.go | 47 ++--- db/commonpool/pool_test.go | 27 ++- db/init_test.go | 2 +- db/syncinterface/wrapper_test.go | 2 +- db/tool/notify_test.go | 11 +- httperror/errors.go | 5 +- log/logger.go | 18 +- log/logger_test.go | 2 +- log/web_test.go | 2 +- monitor/keeper.go | 4 +- monitor/keeper_test.go | 8 +- plugin/collectd/plugin_test.go | 21 +- plugin/influxdb/plugin_test.go | 21 +- plugin/interface_test.go | 4 +- plugin/nodeexporter/plugin.go | 11 +- plugin/nodeexporter/plugin_test.go | 26 ++- plugin/opentsdb/plugin.go | 7 +- plugin/opentsdb/plugin_test.go | 31 ++- plugin/opentsdbtelnet/plugin.go | 25 ++- plugin/opentsdbtelnet/plugin_test.go | 21 +- plugin/prometheus/plugin.go | 2 +- plugin/prometheus/plugin_test.go | 21 +- plugin/prometheus/process.go | 12 +- plugin/statsd/plugin_test.go | 16 +- plugin/statsd/statsd.go | 40 ++-- plugin/statsd/statsd_test.go | 4 +- schemaless/capi/influxdb_test.go | 2 - schemaless/proto/influx.go | 7 - system/controller.go | 17 +- system/main.go | 2 +- system/main_test.go | 2 +- system/plugin.go | 14 +- thread/locker_test.go | 4 +- tools/bytesutil/bytesutil.go | 13 +- tools/connectpool/pool.go | 27 +-- tools/connectpool/pool_test.go | 29 ++- tools/ctools/block.go | 16 +- tools/ctools/block_test.go | 9 +- tools/generator/reqid_test.go | 4 +- tools/joinerror/join.go | 41 ++++ tools/joinerror/join_test.go | 68 +++++++ tools/jsonbuilder/builder_test.go | 3 +- tools/jsonbuilder/stream.go | 4 +- tools/jsonbuilder/stream_rune.go | 45 +---- tools/jsonbuilder/stream_test.go | 23 ++- tools/jsontype/uint8.go | 2 + tools/layout/time.go | 8 +- tools/monitor/collect_test.go | 2 +- tools/monitor/util_test.go | 10 +- tools/web/middleware_test.go | 2 +- version/version.go | 2 + 85 files changed, 1173 insertions(+), 758 deletions(-) create mode 100644 .golangci.yaml delete mode 100644 schemaless/proto/influx.go create mode 100644 tools/joinerror/join.go create mode 100644 tools/joinerror/join_test.go diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index b23fa431..462d3944 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -249,7 +249,6 @@ jobs: run: sudo go test -v --count=1 -coverprofile=coverage.txt -covermode=atomic ./... - name: Upload coverage to Codecov - if: ${{ matrix.go }} == '1.20' uses: codecov/codecov-action@v4-beta with: files: ./coverage.txt @@ -260,4 +259,54 @@ jobs: if: always() && (steps.test.outcome == 'failure' || steps.test.outcome == 'cancelled') with: name: ${{ runner.os }}-${{ matrix.go }}-log - path: /var/log/taos/ \ No newline at end of file + path: /var/log/taos/ + + golangci: + name: lint + runs-on: ubuntu-latest + needs: build + steps: + - name: get cache server by pr + if: github.event_name == 'pull_request' + id: get-cache-server-pr + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ needs.build.outputs.commit_id }} + restore-keys: | + ${{ runner.os }}-build-${{ github.base_ref }}- + + - name: get cache server by push + if: github.event_name == 'push' + id: get-cache-server-push + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ needs.build.outputs.commit_id }} + restore-keys: | + ${{ runner.os }}-build-${{ github.ref_name }}- + + - name: get cache server manually + if: github.event_name == 'workflow_dispatch' + id: get-cache-server-manually + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ needs.build.outputs.commit_id }} + restore-keys: | + ${{ runner.os }}-build-${{ inputs.tbBranch }}- + + + - name: install + run: | + tar -zxvf server.tar.gz + cd release && sudo sh install.sh + + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: stable + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.61.0 \ No newline at end of file diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 00000000..3fabc1c1 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,38 @@ +run: + timeout: 5m + modules-download-mode: readonly + +linters: + enable: + - goimports + - revive + - govet + - staticcheck + - gosimple + +issues: + exclude-use-default: false + max-issues-per-linter: 0 + max-same-issues: 0 + +linters-settings: + revive: + rules: + - name: context-keys-type + - name: time-naming + - name: var-declaration + - name: unexported-return + - name: errorf + - name: blank-imports + - name: context-as-argument + - name: dot-imports + - name: error-return + - name: error-strings + - name: error-naming + - name: var-naming + arguments: + - [ "ID","IP","JSON","URL","HTTP","SQL","CPU","URI" ] + - [ ] + - name: range + - name: receiver-naming + - name: indent-error-flow \ No newline at end of file diff --git a/benchmark/remotewrite/main.go b/benchmark/remotewrite/main.go index 21a42e1d..94125161 100644 --- a/benchmark/remotewrite/main.go +++ b/benchmark/remotewrite/main.go @@ -56,10 +56,10 @@ func main() { } if resp.StatusCode != 202 { d, _ := io.ReadAll(resp.Body) - resp.Body.Close() + _ = resp.Body.Close() panic(string(d)) } - resp.Body.Close() + _ = resp.Body.Close() } }() } @@ -103,11 +103,6 @@ func generateData(id string, loop int) [][]byte { } const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -const ( - letterIdxBits = 6 // 6 bits to represent a letter index - letterIdxMask = 1< file > env func init() { - maxprocs.Set() + // get the number of CPU cores, and set GOMAXPROCS to match the number of CPU cores + _, _ = maxprocs.Set() viper.SetDefault("debug", true) _ = viper.BindEnv("debug", "TAOS_ADAPTER_DEBUG") pflag.Bool("debug", true, `enable debug mode. Env "TAOS_ADAPTER_DEBUG"`) diff --git a/controller/controller_test.go b/controller/controller_test.go index f6e366b1..46d5b2e7 100644 --- a/controller/controller_test.go +++ b/controller/controller_test.go @@ -8,7 +8,7 @@ import ( type MockController struct{} -func (mc MockController) Init(r gin.IRouter) {} +func (mc MockController) Init(_ gin.IRouter) {} func TestAddAndGetControllers(t *testing.T) { mockController := MockController{} diff --git a/controller/ping/controller.go b/controller/ping/controller.go index 8cb79260..1b48fcab 100644 --- a/controller/ping/controller.go +++ b/controller/ping/controller.go @@ -18,10 +18,9 @@ func (c Controller) Init(r gin.IRouter) { if monitor.QueryPaused() { c.Status(http.StatusServiceUnavailable) return - } else { - c.Status(http.StatusOK) - return } + c.Status(http.StatusOK) + return } if monitor.AllPaused() { c.Status(http.StatusServiceUnavailable) diff --git a/controller/rest/configcontroller.go b/controller/rest/configcontroller.go index 4e63e7c8..e5490213 100644 --- a/controller/rest/configcontroller.go +++ b/controller/rest/configcontroller.go @@ -2,13 +2,13 @@ package rest import ( "encoding/json" - "github.com/taosdata/driver-go/v3/wrapper" "net/http" "sync/atomic" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" taoserrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/db/commonpool" "github.com/taosdata/taosadapter/v3/db/tool" diff --git a/controller/rest/restful.go b/controller/rest/restful.go index 09328fa2..d7667e52 100644 --- a/controller/rest/restful.go +++ b/controller/rest/restful.go @@ -315,12 +315,11 @@ func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect uns w.Flush() } return - } else { - if monitor.QueryPaused() { - logger.Errorf("query memory exceeds threshold, QID:0x%x", reqID) - c.AbortWithStatusJSON(http.StatusServiceUnavailable, "query memory exceeds threshold") - return - } + } + if monitor.QueryPaused() { + logger.Errorf("query memory exceeds threshold, QID:0x%x", reqID) + c.AbortWithStatusJSON(http.StatusServiceUnavailable, "query memory exceeds threshold") + return } fieldsCount := wrapper.TaosNumFields(res) logger.Tracef("get fieldsCount:%d", fieldsCount) @@ -357,7 +356,7 @@ func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect uns if err != nil { return } - tmpFlushTiming += tmpFlushTiming + flushTiming += tmpFlushTiming total := 0 builder.WritePure(Query3) precision := wrapper.TaosResultPrecision(res) @@ -373,76 +372,73 @@ func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect uns if result.N == 0 { logger.Trace("fetch finished") break + } + if result.N < 0 { + logger.Tracef("fetch error, result.N:%d", result.N) + break + } + res = result.Res + if fetched { + builder.WriteMore() } else { - if result.N < 0 { - logger.Tracef("fetch error, result.N:%d", result.N) - break + fetched = true + } + logger.Tracef("get fetch result, rows:%d", result.N) + block := wrapper.TaosGetRawBlock(res) + logger.Trace("start parse block") + blockSize := result.N + nullBitMapOffset := uintptr(ctools.BitmapLen(blockSize)) + lengthOffset := parser.RawBlockGetColumnLengthOffset(fieldsCount) + tmpPHeader := tools.AddPointer(block, parser.RawBlockGetColDataOffset(fieldsCount)) + for column := 0; column < fieldsCount; column++ { + colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*parser.Int32Size))) + if ctools.IsVarDataType(rowsHeader.ColTypes[column]) { + pHeaderList[column] = tmpPHeader + pStartList[column] = tools.AddPointer(tmpPHeader, uintptr(4*blockSize)) + } else { + pHeaderList[column] = tmpPHeader + pStartList[column] = tools.AddPointer(tmpPHeader, nullBitMapOffset) } - res = result.Res - if fetched { - builder.WriteMore() + tmpPHeader = tools.AddPointer(pStartList[column], uintptr(colLength)) + } + + for row := 0; row < result.N; row++ { + if returnObj { + builder.WriteObjectStart() } else { - fetched = true + builder.WriteArrayStart() } - logger.Tracef("get fetch result, rows:%d", result.N) - block := wrapper.TaosGetRawBlock(res) - logger.Trace("start parse block") - blockSize := result.N - nullBitMapOffset := uintptr(ctools.BitmapLen(blockSize)) - lengthOffset := parser.RawBlockGetColumnLengthOffset(fieldsCount) - tmpPHeader := tools.AddPointer(block, parser.RawBlockGetColDataOffset(fieldsCount)) - tmpPStart := tmpPHeader for column := 0; column < fieldsCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*parser.Int32Size))) - if ctools.IsVarDataType(rowsHeader.ColTypes[column]) { - pHeaderList[column] = tmpPHeader - tmpPStart = tools.AddPointer(tmpPHeader, uintptr(4*blockSize)) - pStartList[column] = tmpPStart - } else { - pHeaderList[column] = tmpPHeader - tmpPStart = tools.AddPointer(tmpPHeader, nullBitMapOffset) - pStartList[column] = tmpPStart - } - tmpPHeader = tools.AddPointer(tmpPStart, uintptr(colLength)) - } - - for row := 0; row < result.N; row++ { - if returnObj { - builder.WriteObjectStart() - } else { - builder.WriteArrayStart() - } - for column := 0; column < fieldsCount; column++ { - if returnObj { - builder.WriteObjectField(rowsHeader.ColNames[column]) - } - ctools.JsonWriteRawBlock(builder, rowsHeader.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, timeFormat) - if column != fieldsCount-1 { - builder.WriteMore() - } - } - // try flushing after parsing a row of data - tmpFlushTiming, err = tryFlush(w, builder, calculateTiming) - if err != nil { - return - } - flushTiming += tmpFlushTiming if returnObj { - builder.WriteObjectEnd() - } else { - builder.WriteArrayEnd() + builder.WriteObjectField(rowsHeader.ColNames[column]) } - total += 1 - if config.Conf.RestfulRowLimit > -1 && total == config.Conf.RestfulRowLimit { - logger.Tracef("row limit %d reached", config.Conf.RestfulRowLimit) - break - } - if row != result.N-1 { + ctools.JsonWriteRawBlock(builder, rowsHeader.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, timeFormat) + if column != fieldsCount-1 { builder.WriteMore() } } - logger.Trace("parse block finished") + // try flushing after parsing a row of data + tmpFlushTiming, err = tryFlush(w, builder, calculateTiming) + if err != nil { + return + } + flushTiming += tmpFlushTiming + if returnObj { + builder.WriteObjectEnd() + } else { + builder.WriteArrayEnd() + } + total += 1 + if config.Conf.RestfulRowLimit > -1 && total == config.Conf.RestfulRowLimit { + logger.Tracef("row limit %d reached", config.Conf.RestfulRowLimit) + break + } + if row != result.N-1 { + builder.WriteMore() + } } + logger.Trace("parse block finished") + } builder.WritePure(Query4) builder.WriteInt(total) @@ -468,7 +464,7 @@ func tryFlush(w gin.ResponseWriter, builder *jsonbuilder.Stream, calculateTiming if calculateTiming { s = time.Now() w.Flush() - return time.Now().Sub(s).Nanoseconds(), nil + return time.Since(s).Nanoseconds(), nil } w.Flush() } @@ -554,7 +550,7 @@ func (ctl *Restful) upload(c *gin.Context) { }() s = log.GetLogNow(isDebug) logger.Tracef("exec sql: %s", sql) - result, err := async.GlobalAsync.TaosExec(taosConnect.TaosConnection, logger, isDebug, sql, func(ts int64, precision int) driver.Value { + result, err := async.GlobalAsync.TaosExec(taosConnect.TaosConnection, logger, isDebug, sql, func(ts int64, _ int) driver.Value { return ts }, reqID) logger.Debugf("describe table cost:%s", log.GetLogDuration(isDebug, s)) @@ -666,7 +662,7 @@ func (ctl *Restful) upload(c *gin.Context) { buffer.WriteString(tableName) buffer.WriteString(" values") } - colBuffer.WriteTo(buffer) + _, _ = colBuffer.WriteTo(buffer) } } if buffer.Len() > prefixLength { @@ -728,7 +724,12 @@ func (ctl *Restful) des(c *gin.Context) { UnAuthResponse(c, logger, httperror.TSDB_CODE_RPC_AUTH_FAILURE) return } - conn.Put() + err = conn.Put() + if err != nil { + logger.Errorf("put connection error, err:%s", err) + InternalErrorResponse(c, logger, httperror.HTTP_GEN_TAOSD_TOKEN_ERR, "put connection error") + return + } token, err := EncodeDes(user, password) if err != nil { logger.Errorf("encode token error, err:%s", err) diff --git a/controller/rest/restful_test.go b/controller/rest/restful_test.go index 3aa7ad32..39a54d50 100644 --- a/controller/rest/restful_test.go +++ b/controller/rest/restful_test.go @@ -342,7 +342,7 @@ func TestWrongEmptySql(t *testing.T) { type ErrorReader struct{} -func (e *ErrorReader) Read(p []byte) (n int, err error) { +func (e *ErrorReader) Read(_ []byte) (n int, err error) { return 0, errors.New("forced read error") } diff --git a/controller/ws/query/ws.go b/controller/ws/query/ws.go index c3cacbad..dd3b1f63 100644 --- a/controller/ws/query/ws.go +++ b/controller/ws/query/ws.go @@ -72,7 +72,7 @@ func NewQueryController() *QueryController { } switch action.Action { case wstool.ClientVersion: - session.Write(wstool.VersionResp) + _ = session.Write(wstool.VersionResp) case WSConnect: var wsConnect WSConnectReq err = json.Unmarshal(action.Args, &wsConnect) @@ -112,7 +112,7 @@ func NewQueryController() *QueryController { logger.WithError(err).WithField(config.ReqIDKey, fetchJson.ReqID).Errorln("unmarshal fetch_json args") return } - t.freeResult(session, &fetchJson) + t.freeResult(&fetchJson) default: logger.WithError(err).Errorln("unknown action :" + action.Action) return @@ -268,7 +268,7 @@ func (t *Taos) waitSignal(logger *logrus.Entry) { logger.WithField("clientIP", t.ipStr).Info("user dropped! close connection!") logger.Trace("close session") s := log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() logger.Trace("close handler") @@ -279,7 +279,6 @@ func (t *Taos) waitSignal(logger *logrus.Entry) { case <-t.whitelistChangeChan: logger.Info("get whitelist change signal") isDebug := log.IsDebug() - s := log.GetLogNow(isDebug) t.lock(logger, isDebug) if t.closed { logger.Trace("server closed") @@ -287,13 +286,13 @@ func (t *Taos) waitSignal(logger *logrus.Entry) { return } logger.Trace("get whitelist") - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) whitelist, err := tool.GetWhitelist(t.conn) logger.Debugf("get whitelist cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { logger.WithField("clientIP", t.ipStr).WithError(err).Errorln("get whitelist error! close connection!") s = log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() logger.Trace("close handler") @@ -308,7 +307,7 @@ func (t *Taos) waitSignal(logger *logrus.Entry) { logger.WithField("clientIP", t.ipStr).Errorln("ip not in whitelist! close connection!") logger.Trace("close session") s = log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() logger.Trace("close handler") @@ -340,6 +339,7 @@ type Result struct { func (r *Result) FreeResult(logger *logrus.Entry) { r.Lock() + defer r.Unlock() r.FieldsCount = 0 r.Header = nil r.Lengths = nil @@ -351,9 +351,8 @@ func (r *Result) FreeResult(logger *logrus.Entry) { } if r.TaosResult != nil { syncinterface.FreeResult(r.TaosResult, logger, log.IsDebug()) + r.TaosResult = nil } - r.logger = nil - r.Unlock() } func (t *Taos) addResult(result *Result) { @@ -395,8 +394,8 @@ func (t *Taos) getResult(index uint64) *list.Element { func (t *Taos) removeResult(item *list.Element) { t.resultLocker.Lock() + defer t.resultLocker.Unlock() t.Results.Remove(item) - t.resultLocker.Unlock() } type WSConnectReq struct { @@ -419,7 +418,6 @@ func (t *Taos) connect(ctx context.Context, session *melody.Session, req *WSConn logrus.Fields{"action": WSConnect, config.ReqIDKey: req.ReqID}, ) isDebug := log.IsDebug() - s := log.GetLogNow(isDebug) t.lock(logger, isDebug) defer t.Unlock() if t.closed { @@ -434,17 +432,17 @@ func (t *Taos) connect(ctx context.Context, session *melody.Session, req *WSConn conn, err := syncinterface.TaosConnect("", req.User, req.Password, req.DB, 0, logger, isDebug) if err != nil { logger.WithError(err).Errorln("connect to TDengine error") - wstool.WSError(ctx, session, err, WSConnect, req.ReqID) + wstool.WSError(ctx, session, logger, err, WSConnect, req.ReqID) return } logger.Trace("get whitelist") - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) whitelist, err := tool.GetWhitelist(conn) logger.Debugf("get whitelist cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { logger.WithError(err).Errorln("get whitelist error") syncinterface.TaosClose(conn, logger, isDebug) - wstool.WSError(ctx, session, err, WSConnect, req.ReqID) + wstool.WSError(ctx, session, logger, err, WSConnect, req.ReqID) return } logger.Tracef("check whitelist, ip: %s, whitelist: %s", t.ipStr, tool.IpNetSliceToString(whitelist)) @@ -452,7 +450,7 @@ func (t *Taos) connect(ctx context.Context, session *melody.Session, req *WSConn if !valid { logger.Errorf("ip not in whitelist, ip: %s, whitelist: %s", t.ipStr, tool.IpNetSliceToString(whitelist)) syncinterface.TaosClose(conn, logger, isDebug) - wstool.WSErrorMsg(ctx, session, 0xffff, "whitelist prohibits current IP access", WSConnect, req.ReqID) + wstool.WSErrorMsg(ctx, session, logger, 0xffff, "whitelist prohibits current IP access", WSConnect, req.ReqID) return } s = log.GetLogNow(isDebug) @@ -462,7 +460,7 @@ func (t *Taos) connect(ctx context.Context, session *melody.Session, req *WSConn if err != nil { logger.WithError(err).Errorln("register whitelist change error") syncinterface.TaosClose(conn, logger, isDebug) - wstool.WSError(ctx, session, err, WSConnect, req.ReqID) + wstool.WSError(ctx, session, logger, err, WSConnect, req.ReqID) return } s = log.GetLogNow(isDebug) @@ -472,7 +470,7 @@ func (t *Taos) connect(ctx context.Context, session *melody.Session, req *WSConn if err != nil { logger.WithError(err).Errorln("register drop user error") syncinterface.TaosClose(conn, logger, isDebug) - wstool.WSError(ctx, session, err, WSConnect, req.ReqID) + wstool.WSError(ctx, session, logger, err, WSConnect, req.ReqID) return } t.conn = conn @@ -532,7 +530,6 @@ func (t *Taos) query(ctx context.Context, session *melody.Session, req *WSQueryR monitor.WSRecordResult(sqlType, false) errStr := wrapper.TaosErrorStr(result.Res) logger.Errorf("query error, code: %d, message: %s", code, errStr) - s = log.GetLogNow(isDebug) logger.Trace("get thread lock for free result") syncinterface.FreeResult(result.Res, logger, isDebug) wsErrorMsg(ctx, session, code, errStr, WSQuery, req.ReqID) @@ -551,39 +548,38 @@ func (t *Taos) query(ctx context.Context, session *melody.Session, req *WSQueryR logger.Debugf("affected_rows %d cost: %s", affectRows, log.GetLogDuration(isDebug, s)) queryResult.IsUpdate = true queryResult.AffectedRows = affectRows - s = log.GetLogNow(isDebug) logger.Trace("get thread lock for free result") syncinterface.FreeResult(result.Res, logger, isDebug) queryResult.Timing = wstool.GetDuration(ctx) wstool.WSWriteJson(session, logger, queryResult) return - } else { - s = log.GetLogNow(isDebug) - fieldsCount := wrapper.TaosNumFields(result.Res) - logger.Debugf("num_fields %d cost: %s", fieldsCount, log.GetLogDuration(isDebug, s)) - queryResult.FieldsCount = fieldsCount - s = log.GetLogNow(isDebug) - rowsHeader, _ := wrapper.ReadColumn(result.Res, fieldsCount) - logger.Debugf("read column cost:%s", log.GetLogDuration(isDebug, s)) - queryResult.FieldsNames = rowsHeader.ColNames - queryResult.FieldsLengths = rowsHeader.ColLength - queryResult.FieldsTypes = rowsHeader.ColTypes - s = log.GetLogNow(isDebug) - precision := wrapper.TaosResultPrecision(result.Res) - logger.Debugf("result_precision %d cost: %s ", precision, log.GetLogDuration(isDebug, s)) - queryResult.Precision = precision - result := &Result{ - TaosResult: result.Res, - FieldsCount: fieldsCount, - Header: rowsHeader, - precision: precision, - } - logger.Trace("add result to list") - t.addResult(result) - queryResult.ID = result.index - queryResult.Timing = wstool.GetDuration(ctx) - wstool.WSWriteJson(session, logger, queryResult) } + // query + s = log.GetLogNow(isDebug) + fieldsCount := wrapper.TaosNumFields(result.Res) + logger.Debugf("num_fields %d cost: %s", fieldsCount, log.GetLogDuration(isDebug, s)) + queryResult.FieldsCount = fieldsCount + s = log.GetLogNow(isDebug) + rowsHeader, _ := wrapper.ReadColumn(result.Res, fieldsCount) + logger.Debugf("read column cost:%s", log.GetLogDuration(isDebug, s)) + queryResult.FieldsNames = rowsHeader.ColNames + queryResult.FieldsLengths = rowsHeader.ColLength + queryResult.FieldsTypes = rowsHeader.ColTypes + s = log.GetLogNow(isDebug) + precision := wrapper.TaosResultPrecision(result.Res) + logger.Debugf("result_precision %d cost: %s ", precision, log.GetLogDuration(isDebug, s)) + queryResult.Precision = precision + resultItem := &Result{ + TaosResult: result.Res, + FieldsCount: fieldsCount, + Header: rowsHeader, + precision: precision, + } + logger.Trace("add result to list") + t.addResult(resultItem) + queryResult.ID = resultItem.index + queryResult.Timing = wstool.GetDuration(ctx) + wstool.WSWriteJson(session, logger, queryResult) } type WSWriteMetaResp struct { @@ -600,7 +596,6 @@ func (t *Taos) writeRaw(ctx context.Context, session *melody.Session, reqID, mes logrus.Fields{"action": WSWriteRaw, config.ReqIDKey: reqID}, ) isDebug := log.IsDebug() - s := log.GetLogNow(isDebug) t.lock(logger, isDebug) defer t.Unlock() if t.closed { @@ -613,7 +608,7 @@ func (t *Taos) writeRaw(ctx context.Context, session *melody.Session, reqID, mes return } meta := wrapper.BuildRawMeta(length, metaType, data) - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) logger.Trace("get thread lock for write raw meta") thread.SyncLocker.Lock() logger.Debugf("get thread lock cost:%s", log.GetLogDuration(isDebug, s)) @@ -645,7 +640,6 @@ func (t *Taos) writeRawBlock(ctx context.Context, session *melody.Session, reqID logrus.Fields{"action": WSWriteRawBlock, config.ReqIDKey: reqID}, ) isDebug := log.IsDebug() - s := log.GetLogNow(isDebug) t.lock(logger, isDebug) defer t.Unlock() if t.closed { @@ -657,7 +651,7 @@ func (t *Taos) writeRawBlock(ctx context.Context, session *melody.Session, reqID return } logger.Trace("get thread lock for write raw block") - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) thread.SyncLocker.Lock() logger.Debugf("get thread lock cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) @@ -688,7 +682,6 @@ func (t *Taos) writeRawBlockWithFields(ctx context.Context, session *melody.Sess logrus.Fields{"action": WSWriteRawBlockWithFields, config.ReqIDKey: reqID}, ) isDebug := log.IsDebug() - s := log.GetLogNow(isDebug) t.lock(logger, isDebug) defer t.Unlock() if t.closed { @@ -701,7 +694,7 @@ func (t *Taos) writeRawBlockWithFields(ctx context.Context, session *melody.Sess return } logger.Trace("get thread lock for write raw block with fields") - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) thread.SyncLocker.Lock() logger.Debugf("get thread lock cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) @@ -753,6 +746,13 @@ func (t *Taos) fetch(ctx context.Context, session *melody.Session, req *WSFetchR return } resultS := resultItem.Value.(*Result) + resultS.Lock() + if resultS.TaosResult == nil { + resultS.Unlock() + logger.Errorf("result is nil") + wsErrorMsg(ctx, session, 0xffff, "result is nil", WSFetch, req.ReqID) + return + } s := log.GetLogNow(isDebug) handler := async.GlobalAsync.HandlerPool.Get() logger.Debugf("get handler cost:%s", log.GetLogDuration(isDebug, s)) @@ -763,6 +763,7 @@ func (t *Taos) fetch(ctx context.Context, session *melody.Session, req *WSFetchR logger.Debugf("fetch_raw_block_a cost:%s", log.GetLogDuration(isDebug, s)) if result.N == 0 { logger.Trace("fetch raw block completed") + resultS.Unlock() t.FreeResult(resultItem, logger) wstool.WSWriteJson(session, logger, &WSFetchResp{ Action: WSFetch, @@ -776,6 +777,7 @@ func (t *Taos) fetch(ctx context.Context, session *melody.Session, req *WSFetchR if result.N < 0 { errStr := wrapper.TaosErrorStr(result.Res) logger.Errorf("fetch raw block error, code: %d, message: %s", result.N, errStr) + resultS.Unlock() t.FreeResult(resultItem, logger) wsErrorMsg(ctx, session, result.N&0xffff, errStr, WSFetch, req.ReqID) return @@ -789,7 +791,7 @@ func (t *Taos) fetch(ctx context.Context, session *melody.Session, req *WSFetchR logger.Debugf("get_raw_block cost:%s", log.GetLogDuration(isDebug, s)) resultS.Block = block resultS.Size = result.N - + resultS.Unlock() wstool.WSWriteJson(session, logger, &WSFetchResp{ Action: WSFetch, ReqID: req.ReqID, @@ -822,11 +824,17 @@ func (t *Taos) fetchBlock(ctx context.Context, session *melody.Session, req *WSF return } resultS := resultItem.Value.(*Result) + resultS.Lock() + if resultS.TaosResult == nil { + resultS.Unlock() + wsErrorMsg(ctx, session, 0xffff, "result is nil", WSFetchBlock, req.ReqID) + return + } if resultS.Block == nil { + resultS.Unlock() wsErrorMsg(ctx, session, 0xffff, "block is nil", WSFetchBlock, req.ReqID) return } - resultS.Lock() blockLength := int(parser.RawBlockGetLength(resultS.Block)) if resultS.buffer == nil { resultS.buffer = new(bytes.Buffer) @@ -850,7 +858,7 @@ type WSFreeResultReq struct { ID uint64 `json:"id"` } -func (t *Taos) freeResult(session *melody.Session, req *WSFreeResultReq) { +func (t *Taos) freeResult(req *WSFreeResultReq) { logger := t.logger.WithFields( logrus.Fields{"action": WSFreeResult, config.ReqIDKey: req.ReqID}, ) @@ -910,7 +918,6 @@ func (t *Taos) freeAllResult() { func (t *Taos) Close() { isDebug := log.IsDebug() - s := log.GetLogNow(isDebug) t.lock(t.logger, isDebug) defer t.Unlock() if t.closed { @@ -933,7 +940,7 @@ func (t *Taos) Close() { t.logger.Trace("all task finished") } t.logger.Trace("free all result") - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) t.freeAllResult() t.logger.Debugf("free all result cost:%s", log.GetLogDuration(isDebug, s)) if t.conn != nil { @@ -981,7 +988,7 @@ func wsErrorMsg(ctx context.Context, session *melody.Session, code int, message Timing: wstool.GetDuration(ctx), }) wstool.GetLogger(session).Tracef("write error message: %s", b) - session.Write(b) + _ = session.Write(b) } type WSTMQErrorResp struct { @@ -1003,7 +1010,7 @@ func wsTMQErrorMsg(ctx context.Context, session *melody.Session, code int, messa MessageID: messageID, }) wstool.GetLogger(session).Tracef("write error message: %s", b) - session.Write(b) + _ = session.Write(b) } func init() { diff --git a/controller/ws/query/ws_test.go b/controller/ws/query/ws_test.go index 223c0628..96319303 100644 --- a/controller/ws/query/ws_test.go +++ b/controller/ws/query/ws_test.go @@ -84,7 +84,10 @@ func TestWebsocket(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() const ( AfterConnect = 1 AfterQuery = 2 @@ -102,7 +105,7 @@ func TestWebsocket(t *testing.T) { //var jsonResult [][]interface{} var resultID uint64 var blockResult [][]driver.Value - testMessageHandler := func(messageType int, message []byte) error { + testMessageHandler := func(_ int, message []byte) error { //json switch status { case AfterConnect: @@ -392,7 +395,7 @@ func TestWriteBlock(t *testing.T) { var queryResult *WSQueryResult var rows int finish := make(chan struct{}) - testMessageHandler := func(messageType int, message []byte) error { + testMessageHandler := func(_ int, message []byte) error { //json switch status { case AfterConnect: @@ -596,14 +599,15 @@ func TestWriteBlock(t *testing.T) { return } <-finish - ws.Close() + err = ws.Close() + assert.NoError(t, err) ws, _, err = websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/ws", nil) if err != nil { t.Error(err) return } var blockResult [][]driver.Value - testMessageHandler2 := func(messageType int, message []byte) error { + testMessageHandler2 := func(_ int, message []byte) error { switch status { case AfterConnect: var d WSConnectResp @@ -776,7 +780,8 @@ func TestWriteBlock(t *testing.T) { return } <-finish - ws.Close() + err = ws.Close() + assert.NoError(t, err) assert.Equal(t, 3, len(blockResult)) assert.Equal(t, true, blockResult[0][1]) assert.Equal(t, int8(2), blockResult[0][2]) @@ -884,7 +889,7 @@ func TestWriteBlockWithFields(t *testing.T) { var queryResult *WSQueryResult var rows int finish := make(chan struct{}) - testMessageHandler := func(messageType int, message []byte) error { + testMessageHandler := func(_ int, message []byte) error { //json switch status { case AfterConnect: @@ -1116,14 +1121,15 @@ func TestWriteBlockWithFields(t *testing.T) { return } <-finish - ws.Close() + err = ws.Close() + assert.NoError(t, err) ws, _, err = websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/ws", nil) if err != nil { t.Error(err) return } var blockResult [][]driver.Value - testMessageHandler2 := func(messageType int, message []byte) error { + testMessageHandler2 := func(_ int, message []byte) error { switch status { case AfterConnect: var d WSConnectResp @@ -1296,7 +1302,8 @@ func TestWriteBlockWithFields(t *testing.T) { return } <-finish - ws.Close() + err = ws.Close() + assert.NoError(t, err) assert.Equal(t, 3, len(blockResult)) assert.Equal(t, now, blockResult[0][0].(time.Time).UnixNano()/1e6) assert.Equal(t, true, blockResult[0][1]) @@ -1359,7 +1366,10 @@ func TestQueryAllType(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() const ( AfterConnect = 1 AfterQuery = 2 @@ -1377,7 +1387,7 @@ func TestQueryAllType(t *testing.T) { //var jsonResult [][]interface{} var resultID uint64 var blockResult [][]driver.Value - testMessageHandler := func(messageType int, message []byte) error { + testMessageHandler := func(_ int, message []byte) error { //json switch status { case AfterConnect: diff --git a/controller/ws/schemaless/schemaless.go b/controller/ws/schemaless/schemaless.go index 3072e559..54953f0d 100644 --- a/controller/ws/schemaless/schemaless.go +++ b/controller/ws/schemaless/schemaless.go @@ -57,17 +57,17 @@ func NewSchemalessController() *SchemalessController { err := json.Unmarshal(bytes, &action) if err != nil { logger.Errorf("unmarshal ws request error, err:%s", err) - wstool.WSError(ctx, session, err, action.Action, 0) + wstool.WSError(ctx, session, logger, err, action.Action, 0) return } switch action.Action { case wstool.ClientVersion: - session.Write(wstool.VersionResp) + _ = session.Write(wstool.VersionResp) case SchemalessConn: var req schemalessConnReq if err = json.Unmarshal(action.Args, &req); err != nil { logger.Errorf("unmarshal connect args, err:%s, args:%s", err, action.Args) - wstool.WSError(ctx, session, err, SchemalessConn, req.ReqID) + wstool.WSError(ctx, session, logger, err, SchemalessConn, req.ReqID) return } t.connect(ctx, session, req) @@ -75,7 +75,7 @@ func NewSchemalessController() *SchemalessController { var req schemalessWriteReq if err = json.Unmarshal(action.Args, &req); err != nil { logger.Errorf("unmarshal schemaless insert args, err:%s, args:%s", err, action.Args) - wstool.WSError(ctx, session, err, SchemalessWrite, req.ReqID) + wstool.WSError(ctx, session, logger, err, SchemalessWrite, req.ReqID) return } t.insert(ctx, session, req) @@ -178,7 +178,7 @@ func (t *TaosSchemaless) waitSignal(logger *logrus.Entry) { } logger.Info("user dropped! close connection!") s := log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() s = log.GetLogNow(isDebug) @@ -202,7 +202,7 @@ func (t *TaosSchemaless) waitSignal(logger *logrus.Entry) { logger.Errorf("get whitelist error, close connection, err:%s", err) wstool.GetLogger(t.session).WithField("ip", t.ipStr).WithError(err).Errorln("get whitelist error! close connection!") s = log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() s = log.GetLogNow(isDebug) @@ -215,7 +215,7 @@ func (t *TaosSchemaless) waitSignal(logger *logrus.Entry) { if !valid { logger.Errorf("ip not in whitelist, close connection, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) s = log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() s = log.GetLogNow(isDebug) @@ -293,13 +293,13 @@ func (t *TaosSchemaless) connect(ctx context.Context, session *melody.Session, r } if t.conn != nil { logger.Errorf("duplicate connections") - wsSchemalessErrorMsg(ctx, session, 0xffff, "duplicate connections", action, req.ReqID) + wsSchemalessErrorMsg(ctx, session, logger, 0xffff, "duplicate connections", action, req.ReqID) return } conn, err := syncinterface.TaosConnect("", req.User, req.Password, req.DB, 0, logger, isDebug) if err != nil { logger.Errorf("connect error, err:%s", err) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } s := log.GetLogNow(isDebug) @@ -308,7 +308,7 @@ func (t *TaosSchemaless) connect(ctx context.Context, session *melody.Session, r if err != nil { logger.Errorf("get whitelist error, close connection, err:%s", err) syncinterface.TaosClose(conn, t.logger, isDebug) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) @@ -316,7 +316,7 @@ func (t *TaosSchemaless) connect(ctx context.Context, session *melody.Session, r if !valid { logger.Errorf("ip not in whitelist, close connection, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) syncinterface.TaosClose(conn, t.logger, isDebug) - wstool.WSErrorMsg(ctx, session, 0xffff, "whitelist prohibits current IP access", action, req.ReqID) + wstool.WSErrorMsg(ctx, session, logger, 0xffff, "whitelist prohibits current IP access", action, req.ReqID) return } logger.Trace("register change whitelist") @@ -324,7 +324,7 @@ func (t *TaosSchemaless) connect(ctx context.Context, session *melody.Session, r if err != nil { logger.Errorf("register change whitelist error:%s", err) syncinterface.TaosClose(conn, t.logger, isDebug) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } logger.Trace("register drop user") @@ -332,7 +332,7 @@ func (t *TaosSchemaless) connect(ctx context.Context, session *melody.Session, r if err != nil { logger.Errorf("register drop user error:%s", err) syncinterface.TaosClose(conn, t.logger, isDebug) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } t.conn = conn @@ -371,12 +371,12 @@ func (t *TaosSchemaless) insert(ctx context.Context, session *melody.Session, re isDebug := log.IsDebug() if req.Protocol == 0 { logger.Errorf("args error, protocol is 0") - wsSchemalessErrorMsg(ctx, session, 0xffff, "args error", action, req.ReqID) + wsSchemalessErrorMsg(ctx, session, logger, 0xffff, "args error", action, req.ReqID) return } if t.conn == nil { logger.Errorf("server not connected") - wsSchemalessErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID) + wsSchemalessErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID) return } var result unsafe.Pointer @@ -405,7 +405,7 @@ func (t *TaosSchemaless) insert(ctx context.Context, session *melody.Session, re } if err != nil { logger.Errorf("insert error, err:%s", err) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } affectedRows = wrapper.TaosAffectedRows(result) @@ -428,15 +428,15 @@ type WSSchemalessErrorResp struct { Timing int64 `json:"timing"` } -func wsSchemalessErrorMsg(ctx context.Context, session *melody.Session, code int, message string, action string, reqID uint64) { - b, _ := json.Marshal(&WSSchemalessErrorResp{ +func wsSchemalessErrorMsg(ctx context.Context, session *melody.Session, logger *logrus.Entry, code int, message string, action string, reqID uint64) { + data := &WSSchemalessErrorResp{ Code: code & 0xffff, Message: message, Action: action, ReqID: reqID, Timing: wstool.GetDuration(ctx), - }) - session.Write(b) + } + wstool.WSWriteJson(session, logger, data) } func init() { diff --git a/controller/ws/schemaless/schemaless_test.go b/controller/ws/schemaless/schemaless_test.go index d0661025..a2c87d1a 100644 --- a/controller/ws/schemaless/schemaless_test.go +++ b/controller/ws/schemaless/schemaless_test.go @@ -143,7 +143,10 @@ func TestRestful_InitSchemaless(t *testing.T) { if err != nil { t.Fatal("connect error", err) } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() j, _ := json.Marshal(map[string]interface{}{ "action": "conn", diff --git a/controller/ws/stmt/convert.go b/controller/ws/stmt/convert.go index d1c7e9f1..17537dab 100644 --- a/controller/ws/stmt/convert.go +++ b/controller/ws/stmt/convert.go @@ -217,7 +217,7 @@ func BlockConvert(block unsafe.Pointer, blockSize int, fields []*stmtCommon.Stmt nullBitMapOffset := uintptr(parser.BitmapLen(blockSize)) lengthOffset := parser.RawBlockGetColumnLengthOffset(colCount) pHeader := tools.AddPointer(block, parser.RawBlockGetColDataOffset(colCount)) - pStart := pHeader + var pStart unsafe.Pointer length := 0 for column := 0; column < colCount; column++ { r[column] = make([]driver.Value, blockSize) @@ -285,9 +285,8 @@ func ItemIsNull(pHeader unsafe.Pointer, row int) bool { func rawConvertBool(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { if (*((*byte)(tools.AddPointer(pStart, uintptr(row)*1)))) != 0 { return types.TaosBool(true) - } else { - return types.TaosBool(false) } + return types.TaosBool(false) } func rawConvertTinyint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { diff --git a/controller/ws/stmt/stmt.go b/controller/ws/stmt/stmt.go index a80e49ad..9eb5b918 100644 --- a/controller/ws/stmt/stmt.go +++ b/controller/ws/stmt/stmt.go @@ -71,7 +71,7 @@ func NewSTMTController() *STMTController { } switch action.Action { case wstool.ClientVersion: - session.Write(wstool.VersionResp) + _ = session.Write(wstool.VersionResp) case STMTConnect: var req StmtConnectReq err = json.Unmarshal(action.Args, &req) @@ -304,7 +304,7 @@ func (t *TaosStmt) waitSignal(logger *logrus.Entry) { } logger.Info("user dropped! close connection!") s := log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() s = log.GetLogNow(isDebug) @@ -328,7 +328,7 @@ func (t *TaosStmt) waitSignal(logger *logrus.Entry) { logger.Errorf("get whitelist error, close connection, err:%s", err) wstool.GetLogger(t.session).WithField("ip", t.ipStr).WithError(err).Errorln("get whitelist error! close connection!") s = log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() s = log.GetLogNow(isDebug) @@ -341,7 +341,7 @@ func (t *TaosStmt) waitSignal(logger *logrus.Entry) { if !valid { logger.Errorf("ip not in whitelist, close connection, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) s = log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() s = log.GetLogNow(isDebug) @@ -442,13 +442,13 @@ func (t *TaosStmt) connect(ctx context.Context, session *melody.Session, req *St } if t.conn != nil { logger.Errorf("duplicate connections") - wsStmtErrorMsg(ctx, session, 0xffff, "duplicate connections", action, req.ReqID, nil) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "duplicate connections", action, req.ReqID, nil) return } conn, err := syncinterface.TaosConnect("", req.User, req.Password, req.DB, 0, logger, isDebug) if err != nil { logger.Errorf("connect error, err:%s", err) - wsStmtError(ctx, session, err, action, req.ReqID, nil) + wsStmtError(ctx, session, logger, err, action, req.ReqID, nil) return } s := log.GetLogNow(isDebug) @@ -457,7 +457,7 @@ func (t *TaosStmt) connect(ctx context.Context, session *melody.Session, req *St if err != nil { logger.Errorf("get whitelist error, close connection, err:%s", err) syncinterface.TaosClose(conn, logger, isDebug) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) @@ -465,7 +465,7 @@ func (t *TaosStmt) connect(ctx context.Context, session *melody.Session, req *St if !valid { logger.Errorf("ip not in whitelist, close connection, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) syncinterface.TaosClose(conn, logger, isDebug) - wstool.WSErrorMsg(ctx, session, 0xffff, "whitelist prohibits current IP access", action, req.ReqID) + wstool.WSErrorMsg(ctx, session, logger, 0xffff, "whitelist prohibits current IP access", action, req.ReqID) return } logger.Trace("register change whitelist") @@ -473,7 +473,7 @@ func (t *TaosStmt) connect(ctx context.Context, session *melody.Session, req *St if err != nil { logger.Errorf("register change whitelist error, err:%s", err) syncinterface.TaosClose(conn, logger, isDebug) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } logger.Trace("register drop user") @@ -481,7 +481,7 @@ func (t *TaosStmt) connect(ctx context.Context, session *melody.Session, req *St if err != nil { logger.Errorf("register drop user error, err:%s", err) syncinterface.TaosClose(conn, logger, isDebug) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } t.conn = conn @@ -511,7 +511,7 @@ func (t *TaosStmt) init(ctx context.Context, session *melody.Session, req *StmtI logger.Tracef("stmt init request:%+v", req) if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID, nil) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID, nil) return } isDebug := log.IsDebug() @@ -519,7 +519,7 @@ func (t *TaosStmt) init(ctx context.Context, session *melody.Session, req *StmtI if stmt == nil { errStr := wrapper.TaosStmtErrStr(stmt) logger.Errorf("stmt init error, err:%s", errStr) - wsStmtErrorMsg(ctx, session, 0xffff, errStr, action, req.ReqID, nil) + wsStmtErrorMsg(ctx, session, logger, 0xffff, errStr, action, req.ReqID, nil) return } stmtItem := &StmtItem{ @@ -552,14 +552,14 @@ func (t *TaosStmt) prepare(ctx context.Context, session *melody.Session, req *St if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) return } stmtItem := t.getStmtItem(req.StmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -568,7 +568,7 @@ func (t *TaosStmt) prepare(ctx context.Context, session *melody.Session, req *St if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt prepare error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, req.ReqID, &req.StmtID) return } logger.Tracef("stmt prepare success, stmt_id:%d", req.StmtID) @@ -601,13 +601,13 @@ func (t *TaosStmt) setTableName(ctx context.Context, session *melody.Session, re logger.Tracef("stmt set table name, stmt_id:%d, name:%s", req.StmtID, req.Name) if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) return } stmtItem := t.getStmtItem(req.StmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -616,7 +616,7 @@ func (t *TaosStmt) setTableName(ctx context.Context, session *melody.Session, re if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt set table name error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, req.ReqID, &req.StmtID) return } resp := &StmtSetTableNameResp{ @@ -650,13 +650,13 @@ func (t *TaosStmt) setTags(ctx context.Context, session *melody.Session, req *St logger.Tracef("stmt set tags, stmt_id:%d, tags:%+v", req.StmtID, req.Tags) if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) return } stmtItem := t.getStmtItem(req.StmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -665,7 +665,7 @@ func (t *TaosStmt) setTags(ctx context.Context, session *melody.Session, req *St if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt get tag fields error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, req.ReqID, &req.StmtID) return } defer func() { @@ -693,14 +693,14 @@ func (t *TaosStmt) setTags(ctx context.Context, session *melody.Session, req *St logger.Debugf("stmt parse tag json cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { logger.Errorf("stmt parse tag json error, err:%s", err) - wsStmtErrorMsg(ctx, session, 0xffff, fmt.Sprintf("stmt parse tag json:%s", err.Error()), action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, fmt.Sprintf("stmt parse tag json:%s", err.Error()), action, req.ReqID, &req.StmtID) return } code = syncinterface.TaosStmtSetTags(stmt.stmt, data, logger, isDebug) if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt set tags error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, req.ReqID, &req.StmtID) return } resp.Timing = wstool.GetDuration(ctx) @@ -729,13 +729,13 @@ func (t *TaosStmt) getTagFields(ctx context.Context, session *melody.Session, re logger.Tracef("stmt get tag fields, stmt_id:%d", req.StmtID) if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) return } stmtItem := t.getStmtItem(req.StmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -744,7 +744,7 @@ func (t *TaosStmt) getTagFields(ctx context.Context, session *melody.Session, re if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt get tag fields error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, req.ReqID, &req.StmtID) return } defer func() { @@ -788,13 +788,13 @@ func (t *TaosStmt) getColFields(ctx context.Context, session *melody.Session, re logger.Tracef("stmt get tag fields, stmt_id:%d", req.StmtID) if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) return } stmtItem := t.getStmtItem(req.StmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -803,7 +803,7 @@ func (t *TaosStmt) getColFields(ctx context.Context, session *melody.Session, re if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt get col fields error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, req.ReqID, &req.StmtID) return } defer func() { @@ -847,13 +847,13 @@ func (t *TaosStmt) bind(ctx context.Context, session *melody.Session, req *StmtB if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) return } stmtItem := t.getStmtItem(req.StmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -862,7 +862,7 @@ func (t *TaosStmt) bind(ctx context.Context, session *melody.Session, req *StmtB if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt get col fields error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, req.ReqID, &req.StmtID) return } defer func() { @@ -888,7 +888,7 @@ func (t *TaosStmt) bind(ctx context.Context, session *melody.Session, req *StmtB fieldTypes[i], err = fields[i].GetType() if err != nil { logger.Errorf("stmt get column type error, err:%s", err) - wsStmtErrorMsg(ctx, session, 0xffff, fmt.Sprintf("stmt get column type error, err:%s", err.Error()), action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, fmt.Sprintf("stmt get column type error, err:%s", err.Error()), action, req.ReqID, &req.StmtID) return } } @@ -897,14 +897,14 @@ func (t *TaosStmt) bind(ctx context.Context, session *melody.Session, req *StmtB logger.Debugf("stmt parse column json cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { logger.Errorf("stmt parse column json error, err:%s", err) - wsStmtErrorMsg(ctx, session, 0xffff, fmt.Sprintf("stmt parse column json:%s", err.Error()), action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, fmt.Sprintf("stmt parse column json:%s", err.Error()), action, req.ReqID, &req.StmtID) return } code = syncinterface.TaosStmtBindParamBatch(stmt.stmt, data, fieldTypes, logger, isDebug) if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt bind error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, req.ReqID, &req.StmtID) return } resp.Timing = wstool.GetDuration(ctx) @@ -931,13 +931,13 @@ func (t *TaosStmt) addBatch(ctx context.Context, session *melody.Session, req *S logger.Tracef("stmt add batch, stmt_id:%d", req.StmtID) if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) return } stmtItem := t.getStmtItem(req.StmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -946,7 +946,7 @@ func (t *TaosStmt) addBatch(ctx context.Context, session *melody.Session, req *S if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt add batch error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, req.ReqID, &req.StmtID) return } resp := &StmtAddBatchResp{ @@ -979,13 +979,13 @@ func (t *TaosStmt) exec(ctx context.Context, session *melody.Session, req *StmtE logger.Tracef("stmt exec, stmt_id:%d", req.StmtID) if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) return } stmtItem := t.getStmtItem(req.StmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -994,7 +994,7 @@ func (t *TaosStmt) exec(ctx context.Context, session *melody.Session, req *StmtE if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt exec error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, req.ReqID, &req.StmtID) return } s := log.GetLogNow(isDebug) @@ -1021,13 +1021,13 @@ func (t *TaosStmt) close(ctx context.Context, session *melody.Session, req *Stmt logger.Tracef("stmt close, stmt_id:%d", req.StmtID) if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, req.ReqID, &req.StmtID) return } stmtItem := t.getStmtItem(req.StmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, req.ReqID, &req.StmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -1041,18 +1041,18 @@ func (t *TaosStmt) setTagsBlock(ctx context.Context, session *melody.Session, re logger.Tracef("stmt set tags with block, stmt_id:%d", stmtID) if rows != 1 { logger.Errorf("rows not equal 1") - wsStmtErrorMsg(ctx, session, 0xffff, "rows not equal 1", action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "rows not equal 1", action, reqID, &stmtID) return } if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, reqID, &stmtID) return } stmtItem := t.getStmtItem(stmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", stmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, reqID, &stmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -1061,7 +1061,7 @@ func (t *TaosStmt) setTagsBlock(ctx context.Context, session *melody.Session, re if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt get tag fields error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, reqID, &stmtID) return } defer func() { @@ -1078,7 +1078,7 @@ func (t *TaosStmt) setTagsBlock(ctx context.Context, session *melody.Session, re } if columns != tagNums { logger.Errorf("stmt tags count not match, columns:%d, tagNums:%d", columns, tagNums) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt tags count not match", action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt tags count not match", action, reqID, &stmtID) return } s := log.GetLogNow(isDebug) @@ -1095,7 +1095,7 @@ func (t *TaosStmt) setTagsBlock(ctx context.Context, session *melody.Session, re if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt set tags error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, reqID, &stmtID) return } resp.Timing = wstool.GetDuration(ctx) @@ -1108,13 +1108,13 @@ func (t *TaosStmt) bindBlock(ctx context.Context, session *melody.Session, reqID logger.Tracef("stmt bind with block, stmt_id:%d", stmtID) if t.conn == nil { logger.Error("server not connected") - wsStmtErrorMsg(ctx, session, 0xffff, "server not connected", action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "server not connected", action, reqID, &stmtID) return } stmtItem := t.getStmtItem(stmtID) if stmtItem == nil { logger.Errorf("stmt is nil, stmt_id:%d", stmtID) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt is nil", action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt is nil", action, reqID, &stmtID) return } stmt := stmtItem.Value.(*StmtItem) @@ -1123,7 +1123,7 @@ func (t *TaosStmt) bindBlock(ctx context.Context, session *melody.Session, reqID if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt get col fields error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, reqID, &stmtID) return } defer func() { @@ -1148,13 +1148,13 @@ func (t *TaosStmt) bindBlock(ctx context.Context, session *melody.Session, reqID fieldTypes[i], err = fields[i].GetType() if err != nil { logger.Errorf("stmt get column type error, err:%s", err) - wsStmtErrorMsg(ctx, session, 0xffff, fmt.Sprintf("stmt get column type error, err:%s", err.Error()), action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, fmt.Sprintf("stmt get column type error, err:%s", err.Error()), action, reqID, &stmtID) return } } if columns != colNums { logger.Errorf("stmt column count not match, columns:%d, colNums:%d", columns, colNums) - wsStmtErrorMsg(ctx, session, 0xffff, "stmt column count not match", action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, "stmt column count not match", action, reqID, &stmtID) return } s = log.GetLogNow(isDebug) @@ -1164,7 +1164,7 @@ func (t *TaosStmt) bindBlock(ctx context.Context, session *melody.Session, reqID if code != httperror.SUCCESS { errStr := wrapper.TaosStmtErrStr(stmt.stmt) logger.Errorf("stmt bind error, code:%d, msg:%s", code, errStr) - wsStmtErrorMsg(ctx, session, code, errStr, action, reqID, &stmtID) + wsStmtErrorMsg(ctx, session, logger, code, errStr, action, reqID, &stmtID) return } resp.Timing = wstool.GetDuration(ctx) @@ -1227,23 +1227,23 @@ type WSStmtErrorResp struct { StmtID *uint64 `json:"stmt_id,omitempty"` } -func wsStmtErrorMsg(ctx context.Context, session *melody.Session, code int, message string, action string, reqID uint64, stmtID *uint64) { - b, _ := json.Marshal(&WSStmtErrorResp{ +func wsStmtErrorMsg(ctx context.Context, session *melody.Session, logger *logrus.Entry, code int, message string, action string, reqID uint64, stmtID *uint64) { + data := &WSStmtErrorResp{ Code: code & 0xffff, Message: message, Action: action, ReqID: reqID, Timing: wstool.GetDuration(ctx), StmtID: stmtID, - }) - session.Write(b) + } + wstool.WSWriteJson(session, logger, data) } -func wsStmtError(ctx context.Context, session *melody.Session, err error, action string, reqID uint64, stmtID *uint64) { +func wsStmtError(ctx context.Context, session *melody.Session, logger *logrus.Entry, err error, action string, reqID uint64, stmtID *uint64) { e, is := err.(*tErrors.TaosError) if is { - wsStmtErrorMsg(ctx, session, int(e.Code)&0xffff, e.ErrStr, action, reqID, stmtID) + wsStmtErrorMsg(ctx, session, logger, int(e.Code)&0xffff, e.ErrStr, action, reqID, stmtID) } else { - wsStmtErrorMsg(ctx, session, 0xffff, err.Error(), action, reqID, stmtID) + wsStmtErrorMsg(ctx, session, logger, 0xffff, err.Error(), action, reqID, stmtID) } } diff --git a/controller/ws/stmt/stmt_test.go b/controller/ws/stmt/stmt_test.go index 81b0a66e..4101f333 100644 --- a/controller/ws/stmt/stmt_test.go +++ b/controller/ws/stmt/stmt_test.go @@ -91,7 +91,10 @@ func TestSTMT(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() const ( AfterConnect = iota + 1 AfterInit @@ -366,6 +369,7 @@ func TestSTMT(t *testing.T) { nil, }, }) + assert.NoError(t, err) b, _ := json.Marshal(&StmtBindReq{ ReqID: 5, StmtID: stmtID, @@ -512,7 +516,8 @@ func TestSTMT(t *testing.T) { return } <-finish - ws.Close() + err = ws.Close() + assert.NoError(t, err) time.Sleep(time.Second) w = httptest.NewRecorder() body = strings.NewReader("select * from st") @@ -721,7 +726,10 @@ func TestBlock(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() const ( AfterConnect = iota + 1 AfterInit @@ -1027,7 +1035,8 @@ func TestBlock(t *testing.T) { return } <-finish - ws.Close() + err = ws.Close() + assert.NoError(t, err) w = httptest.NewRecorder() body = strings.NewReader("select * from stb") req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_stmt", body) diff --git a/controller/ws/tmq/const.go b/controller/ws/tmq/const.go index 3ce4fffc..5915f7ba 100644 --- a/controller/ws/tmq/const.go +++ b/controller/ws/tmq/const.go @@ -24,5 +24,3 @@ const ( ) const OffsetInvalid = -2147467247 - -const LoggerKey = "logger" diff --git a/controller/ws/tmq/tmq.go b/controller/ws/tmq/tmq.go index 8a42fbea..657a0aaa 100644 --- a/controller/ws/tmq/tmq.go +++ b/controller/ws/tmq/tmq.go @@ -72,7 +72,7 @@ func NewTMQController() *TMQController { } switch action.Action { case wstool.ClientVersion: - session.Write(wstool.VersionResp) + _ = session.Write(wstool.VersionResp) case TMQSubscribe: var req TMQSubscribeReq err = json.Unmarshal(action.Args, &req) @@ -326,7 +326,7 @@ func (t *TMQ) waitSignal(logger *logrus.Entry) { } logger.Info("user dropped! close connection!") s := log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() s = log.GetLogNow(isDebug) @@ -336,7 +336,6 @@ func (t *TMQ) waitSignal(logger *logrus.Entry) { case <-t.whitelistChangeChan: logger.Info("get whitelist change signal") isDebug := log.IsDebug() - s := log.GetLogNow(isDebug) t.lock(logger, isDebug) if t.isClosed() { logger.Trace("server closed") @@ -344,13 +343,13 @@ func (t *TMQ) waitSignal(logger *logrus.Entry) { return } logger.Trace("get whitelist") - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) whitelist, err := tool.GetWhitelist(t.conn) logger.Debugf("get whitelist cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { logger.Errorf("get whitelist error, close connection, err:%s", err) s = log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() s = log.GetLogNow(isDebug) @@ -363,7 +362,7 @@ func (t *TMQ) waitSignal(logger *logrus.Entry) { if !valid { logger.Errorf("ip not in whitelist, close connection, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) s = log.GetLogNow(isDebug) - t.session.Close() + _ = t.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) t.Unlock() s = log.GetLogNow(isDebug) @@ -416,7 +415,6 @@ type TMQSubscribeResp struct { func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSubscribeReq) { action := TMQSubscribe logger := t.logger.WithField("action", action).WithField(config.ReqIDKey, req.ReqID) - ctx = context.WithValue(ctx, LoggerKey, logger) isDebug := log.IsDebug() logger.Tracef("subscribe request:%+v", req) // lock for consumer and unsubscribed @@ -456,11 +454,11 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu Timing: wstool.GetDuration(ctx), }) return - } else { - logger.Errorf("tmq should have unsubscribed first") - wsTMQErrorMsg(ctx, session, logger, 0xffff, "tmq should have unsubscribed first", action, req.ReqID, nil) - return } + logger.Errorf("tmq should have unsubscribed first") + wsTMQErrorMsg(ctx, session, logger, 0xffff, "tmq should have unsubscribed first", action, req.ReqID, nil) + return + } tmqConfig := wrapper.TMQConfNew() defer func() { @@ -546,7 +544,7 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu if err != nil { logger.Errorf("get whitelist error:%s", err.Error()) t.wrapperCloseConsumer(logger, isDebug, cPointer) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) @@ -554,7 +552,7 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu if !valid { logger.Errorf("whitelist prohibits current IP access, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) t.wrapperCloseConsumer(logger, isDebug, cPointer) - wstool.WSErrorMsg(ctx, session, 0xffff, "whitelist prohibits current IP access", action, req.ReqID) + wstool.WSErrorMsg(ctx, session, logger, 0xffff, "whitelist prohibits current IP access", action, req.ReqID) return } logger.Trace("register change whitelist") @@ -562,7 +560,7 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu if err != nil { logger.Errorf("register change whitelist error:%s", err) t.wrapperCloseConsumer(logger, isDebug, cPointer) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } logger.Trace("register drop user") @@ -570,7 +568,7 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu if err != nil { logger.Errorf("register drop user error:%s", err) t.wrapperCloseConsumer(logger, isDebug, cPointer) - wstool.WSError(ctx, session, err, action, req.ReqID) + wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } t.conn = conn @@ -890,7 +888,7 @@ func (t *TMQ) fetchBlock(ctx context.Context, session *melody.Session, req *TMQF wsTMQErrorMsg(ctx, session, logger, 0xffff, "message type is not data", action, req.ReqID, &req.MessageID) return } - if message.buffer == nil || len(message.buffer) == 0 { + if len(message.buffer) == 0 { logger.Errorf("no fetch data") wsTMQErrorMsg(ctx, session, logger, 0xffff, "no fetch data", action, req.ReqID, &req.MessageID) return @@ -898,7 +896,7 @@ func (t *TMQ) fetchBlock(ctx context.Context, session *melody.Session, req *TMQF s = log.GetLogNow(isDebug) binary.LittleEndian.PutUint64(message.buffer, uint64(wstool.GetDuration(ctx))) logger.Debugf("handle data cost:%s", log.GetLogDuration(isDebug, s)) - session.WriteBinary(message.buffer) + wstool.WSWriteBinary(session, message.buffer, logger) } type TMQFetchRawReq struct { @@ -971,7 +969,7 @@ func (t *TMQ) fetchRawBlockNew(ctx context.Context, session *melody.Session, req logger.Tracef("fetch raw request:%+v", req) if t.consumer == nil { logger.Trace("tmq not init") - tmqFetchRawBlockErrorMsg(ctx, session, 0xffff, "tmq not init", req.ReqID, req.MessageID) + tmqFetchRawBlockErrorMsg(ctx, session, logger, 0xffff, "tmq not init", req.ReqID, req.MessageID) return } isDebug := log.IsDebug() @@ -981,16 +979,15 @@ func (t *TMQ) fetchRawBlockNew(ctx context.Context, session *melody.Session, req logger.Debugf("get message lock cost:%s", log.GetLogDuration(isDebug, s)) if t.tmpMessage.CPointer == nil { logger.Error("message has been freed") - tmqFetchRawBlockErrorMsg(ctx, session, 0xffff, "message has been freed", req.ReqID, req.MessageID) + tmqFetchRawBlockErrorMsg(ctx, session, logger, 0xffff, "message has been freed", req.ReqID, req.MessageID) return } message := t.tmpMessage if message.Index != req.MessageID { logger.Errorf("message ID are not equal, req:%d, message:%d", req.MessageID, message.Index) - tmqFetchRawBlockErrorMsg(ctx, session, 0xffff, "message ID is not equal", req.ReqID, req.MessageID) + tmqFetchRawBlockErrorMsg(ctx, session, logger, 0xffff, "message ID is not equal", req.ReqID, req.MessageID) return } - s = log.GetLogNow(isDebug) rawData := asynctmq.TaosaInitTMQRaw() defer asynctmq.TaosaFreeTMQRaw(rawData) errCode, closed := t.wrapperGetRaw(logger, isDebug, message.CPointer, rawData) @@ -1001,7 +998,7 @@ func (t *TMQ) fetchRawBlockNew(ctx context.Context, session *melody.Session, req if errCode != 0 { errStr := wrapper.TMQErr2Str(errCode) logger.Errorf("tmq get raw error, code:%d, msg:%s", errCode, errStr) - tmqFetchRawBlockErrorMsg(ctx, session, int(errCode), errStr, req.ReqID, req.MessageID) + tmqFetchRawBlockErrorMsg(ctx, session, logger, int(errCode), errStr, req.ReqID, req.MessageID) return } s = log.GetLogNow(isDebug) @@ -1319,11 +1316,7 @@ func wsTMQErrorMsg(ctx context.Context, session *melody.Session, logger *logrus. MessageID: messageID, }) logger.Tracef("write json:%s", b) - session.Write(b) -} - -func canGetMeta(messageType int32) bool { - return messageType == common.TMQ_RES_TABLE_META || messageType == common.TMQ_RES_METADATA + _ = session.Write(b) } func canGetData(messageType int32) bool { @@ -1814,7 +1807,7 @@ func (t *TMQ) wrapperCommitOffset(logger *logrus.Entry, isDebug bool, topic stri TMQRawBlock []byte //RawBlockLength 56 + MessageLen + RawBlockLength */ -func tmqFetchRawBlockErrorMsg(ctx context.Context, session *melody.Session, code int, message string, reqID uint64, messageID uint64) { +func tmqFetchRawBlockErrorMsg(ctx context.Context, session *melody.Session, logger *logrus.Entry, code int, message string, reqID uint64, messageID uint64) { bufLength := 8 + 8 + 2 + 8 + 8 + 4 + 4 + len(message) + 8 buf := make([]byte, bufLength) binary.LittleEndian.PutUint64(buf, 0xffffffffffffffff) @@ -1826,7 +1819,7 @@ func tmqFetchRawBlockErrorMsg(ctx context.Context, session *melody.Session, code binary.LittleEndian.PutUint32(buf[38:], uint32(len(message))) copy(buf[42:], message) binary.LittleEndian.PutUint64(buf[42+len(message):], messageID) - session.WriteBinary(buf) + wstool.WSWriteBinary(session, buf, logger) } func wsFetchRawBlockMessage(ctx context.Context, buf []byte, reqID uint64, resultID uint64, MetaType uint16, blockLength uint32, rawBlock unsafe.Pointer) []byte { diff --git a/controller/ws/tmq/tmq_test.go b/controller/ws/tmq/tmq_test.go index 83aa7eed..733b66a3 100644 --- a/controller/ws/tmq/tmq_test.go +++ b/controller/ws/tmq/tmq_test.go @@ -137,7 +137,10 @@ func TestTMQ(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() const ( AfterTMQSubscribe = iota + 1 AfterTMQPoll @@ -434,7 +437,8 @@ func TestTMQ(t *testing.T) { return } <-finish - ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + assert.NoError(t, err) time.Sleep(time.Second * 5) w = httptest.NewRecorder() @@ -482,7 +486,10 @@ func TestMeta(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() const ( AfterTMQSubscribe = iota + 1 AfterTMQPoll @@ -864,7 +871,8 @@ func TestMeta(t *testing.T) { } <-finish - ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + assert.NoError(t, err) time.Sleep(time.Second * 5) w = httptest.NewRecorder() body = strings.NewReader("describe stb") @@ -950,7 +958,10 @@ func writeRaw(t *testing.T, rawData []byte) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() const ( AfterConnect = 1 AfterWriteRaw = 2 @@ -960,7 +971,7 @@ func writeRaw(t *testing.T, rawData []byte) { //total := 0 finish := make(chan struct{}) //var jsonResult [][]interface{} - testMessageHandler := func(messageType int, message []byte) error { + testMessageHandler := func(_ int, message []byte) error { //json switch status { case AfterConnect: @@ -1037,7 +1048,8 @@ func writeRaw(t *testing.T, rawData []byte) { return } <-finish - ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + assert.NoError(t, err) } func TestTMQAutoCommit(t *testing.T) { @@ -1115,7 +1127,10 @@ func TestTMQAutoCommit(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() const ( AfterTMQSubscribe = iota + 1 AfterTMQPoll @@ -1402,7 +1417,8 @@ func TestTMQAutoCommit(t *testing.T) { return } <-finish - ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + assert.NoError(t, err) assert.Equal(t, true, expectError) time.Sleep(time.Second * 5) w = httptest.NewRecorder() @@ -1514,7 +1530,10 @@ func TestTMQUnsubscribeAndSubscribe(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() const ( AfterTMQSubscribe = iota + 1 AfterTMQPoll @@ -1986,7 +2005,8 @@ func TestTMQUnsubscribeAndSubscribe(t *testing.T) { return } <-finish - ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + assert.NoError(t, err) time.Sleep(time.Second * 5) w = httptest.NewRecorder() body = strings.NewReader("drop topic if exists test_tmq_ws_unsubscribe_topic") @@ -2082,7 +2102,10 @@ func TestTMQSeek(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() //sub { @@ -2198,28 +2221,28 @@ func TestTMQSeek(t *testing.T) { assert.Equal(t, 0, tmqFetchResp.Code) if tmqFetchResp.Completed { break - } else { - req := TMQFetchBlockReq{ - ReqID: 1, - MessageID: tmqFetchResp.MessageID, - } - b, _ := json.Marshal(req) - action, _ := json.Marshal(&wstool.WSAction{ - Action: TMQFetchBlock, - Args: b, - }) - err = ws.WriteMessage( - websocket.TextMessage, - action, - ) - assert.NoError(t, err) - mt, message, err := ws.ReadMessage() - assert.NoError(t, err) - assert.Equal(t, websocket.BinaryMessage, mt) - _, _, value := parseblock.ParseTmqBlock(message[8:], tmqFetchResp.FieldsTypes, tmqFetchResp.Rows, tmqFetchResp.Precision) - t.Log(value) - rowCount += 1 } + fetchBlockReq := TMQFetchBlockReq{ + ReqID: 1, + MessageID: tmqFetchResp.MessageID, + } + b, _ = json.Marshal(fetchBlockReq) + action, _ = json.Marshal(&wstool.WSAction{ + Action: TMQFetchBlock, + Args: b, + }) + err = ws.WriteMessage( + websocket.TextMessage, + action, + ) + assert.NoError(t, err) + mt, message, err = ws.ReadMessage() + assert.NoError(t, err) + assert.Equal(t, websocket.BinaryMessage, mt) + _, _, value := parseblock.ParseTmqBlock(message[8:], tmqFetchResp.FieldsTypes, tmqFetchResp.Rows, tmqFetchResp.Precision) + t.Log(value) + rowCount += 1 + } { req := TMQCommitReq{ @@ -2404,28 +2427,27 @@ func TestTMQSeek(t *testing.T) { assert.Equal(t, 0, tmqFetchResp.Code) if tmqFetchResp.Completed { break - } else { - req := TMQFetchBlockReq{ - ReqID: 1, - MessageID: tmqFetchResp.MessageID, - } - b, _ := json.Marshal(req) - action, _ := json.Marshal(&wstool.WSAction{ - Action: TMQFetchBlock, - Args: b, - }) - err = ws.WriteMessage( - websocket.TextMessage, - action, - ) - assert.NoError(t, err) - mt, message, err := ws.ReadMessage() - assert.NoError(t, err) - assert.Equal(t, websocket.BinaryMessage, mt) - _, _, value := parseblock.ParseTmqBlock(message[8:], tmqFetchResp.FieldsTypes, tmqFetchResp.Rows, tmqFetchResp.Precision) - t.Log(value) - rowCount += 1 } + fetchBlockReq := TMQFetchBlockReq{ + ReqID: 1, + MessageID: tmqFetchResp.MessageID, + } + b, _ = json.Marshal(fetchBlockReq) + action, _ = json.Marshal(&wstool.WSAction{ + Action: TMQFetchBlock, + Args: b, + }) + err = ws.WriteMessage( + websocket.TextMessage, + action, + ) + assert.NoError(t, err) + mt, message, err = ws.ReadMessage() + assert.NoError(t, err) + assert.Equal(t, websocket.BinaryMessage, mt) + _, _, value := parseblock.ParseTmqBlock(message[8:], tmqFetchResp.FieldsTypes, tmqFetchResp.Rows, tmqFetchResp.Precision) + t.Log(value) + rowCount += 1 } { req := TMQCommitReq{ @@ -2492,7 +2514,8 @@ func TestTMQSeek(t *testing.T) { assert.NoError(t, err) _, _, err = ws.ReadMessage() assert.NoError(t, err) - ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + assert.NoError(t, err) time.Sleep(time.Second * 3) w = httptest.NewRecorder() body = strings.NewReader("drop topic if exists " + topic) @@ -2559,13 +2582,17 @@ func before(t *testing.T, dbName string, topic string) { assert.Equal(t, 0, code, message) } -func after(ws *websocket.Conn, dbName string, topic string) { +func after(ws *websocket.Conn, dbName string, topic string) error { b, _ := json.Marshal(TMQUnsubscribeReq{ReqID: 0}) _, _ = doWebSocket(ws, TMQUnsubscribe, b) - ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + err := ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + return err + } time.Sleep(time.Second * 5) doHttpSql(fmt.Sprintf("drop topic if exists %s", topic)) doHttpSql(fmt.Sprintf("drop database if exists %s", dbName)) + return nil } func TestTMQ_Position_And_Committed(t *testing.T) { @@ -2581,8 +2608,15 @@ func TestTMQ_Position_And_Committed(t *testing.T) { t.Error(err) return } - defer ws.Close() - defer after(ws, dbName, topic) + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + defer func() { + err = after(ws, dbName, topic) + assert.NoError(t, err) + }() // subscribe b, _ := json.Marshal(TMQSubscribeReq{ @@ -2662,9 +2696,15 @@ func TestTMQ_ListTopics(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() - defer after(ws, dbName, topic) + defer func() { + err = after(ws, dbName, topic) + assert.NoError(t, err) + }() // subscribe b, _ := json.Marshal(TMQSubscribeReq{ @@ -2706,9 +2746,15 @@ func TestTMQ_CommitOffset(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() - defer after(ws, dbName, topic) + defer func() { + err = after(ws, dbName, topic) + assert.NoError(t, err) + }() // subscribe b, _ := json.Marshal(TMQSubscribeReq{ @@ -2785,10 +2831,16 @@ func TestTMQ_PollAfterClose(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() before(t, dbName, topic) - defer after(ws, dbName, topic) + defer func() { + err = after(ws, dbName, topic) + assert.NoError(t, err) + }() // subscribe b, _ := json.Marshal(TMQSubscribeReq{ @@ -2946,9 +2998,15 @@ func TestTMQ_FetchRawNew(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() - defer after(ws, dbName, topic) + defer func() { + err = after(ws, dbName, topic) + assert.NoError(t, err) + }() // subscribe b, _ := json.Marshal(TMQSubscribeReq{ @@ -3069,9 +3127,15 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() - defer after(ws, dbName, topic) + defer func() { + err = after(ws, dbName, topic) + assert.NoError(t, err) + }() // subscribe b, _ := json.Marshal(TMQSubscribeReq{ diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 9020793c..2970d39f 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -14,7 +14,6 @@ import ( "unsafe" "github.com/huskar-t/melody" - jsoniter "github.com/json-iterator/go" "github.com/sirupsen/logrus" "github.com/taosdata/driver-go/v3/common" "github.com/taosdata/driver-go/v3/common/parser" @@ -133,7 +132,7 @@ func (h *messageHandler) waitSignal(logger *logrus.Entry) { func (h *messageHandler) signalExit(logger *logrus.Entry, isDebug bool) { logger.Trace("close session") s := log.GetLogNow(isDebug) - h.session.Close() + _ = h.session.Close() logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) h.Unlock() logger.Trace("close handler") @@ -168,8 +167,6 @@ type Request struct { Args json.RawMessage `json:"args"` } -var jsonI = jsoniter.ConfigCompatibleWithStandardLibrary - func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) h.logger.Debugf("get ws message data:%s", data) @@ -288,7 +285,7 @@ type RequestID struct { ReqID uint64 `json:"req_id"` } -type dealFunc func(context.Context, Request, *logrus.Entry, bool, time.Time) Response +type dealFunc func(context.Context, Request, *logrus.Entry, bool) Response type dealBinaryRequest struct { action messageType @@ -297,7 +294,7 @@ type dealBinaryRequest struct { p0 unsafe.Pointer message []byte } -type dealBinaryFunc func(context.Context, dealBinaryRequest, *logrus.Entry, bool, time.Time) Response +type dealBinaryFunc func(context.Context, dealBinaryRequest, *logrus.Entry, bool) Response func (h *messageHandler) deal(ctx context.Context, session *melody.Session, request Request, f dealFunc) { h.wait.Add(1) @@ -324,9 +321,7 @@ func (h *messageHandler) deal(ctx context.Context, session *melody.Session, requ return } - s := log.GetLogNow(isDebug) - - resp := f(ctx, request, logger, isDebug, s) + resp := f(ctx, request, logger, isDebug) h.writeResponse(ctx, session, resp, request.Action, reqID, logger) }() } @@ -344,8 +339,6 @@ func (h *messageHandler) dealBinary(ctx context.Context, session *melody.Session return } - s := log.GetLogNow(isDebug) - req := dealBinaryRequest{ action: action, reqID: reqID, @@ -353,7 +346,7 @@ func (h *messageHandler) dealBinary(ctx context.Context, session *melody.Session p0: p0, message: message, } - resp := f(ctx, req, logger, isDebug, s) + resp := f(ctx, req, logger, isDebug) h.writeResponse(ctx, session, resp, action.String(), reqID, logger) }() } @@ -417,15 +410,15 @@ func (h *messageHandler) stop() { }) } -func (h *messageHandler) handleDefault(_ context.Context, request Request, _ *logrus.Entry, _ bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleDefault(_ context.Context, request Request, _ *logrus.Entry, _ bool) (resp Response) { return wsCommonErrorMsg(0xffff, fmt.Sprintf("unknown action %s", request.Action)) } -func (h *messageHandler) handleDefaultBinary(_ context.Context, req dealBinaryRequest, _ *logrus.Entry, _ bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleDefaultBinary(_ context.Context, req dealBinaryRequest, _ *logrus.Entry, _ bool) (resp Response) { return wsCommonErrorMsg(0xffff, fmt.Sprintf("unknown action %v", req.action)) } -func (h *messageHandler) handleVersion(_ context.Context, _ Request, _ *logrus.Entry, _ bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleVersion(_ context.Context, _ Request, _ *logrus.Entry, _ bool) (resp Response) { return &VersionResponse{Version: version.TaosClientVersion} } @@ -437,7 +430,7 @@ type ConnRequest struct { Mode *int `json:"mode"` } -func (h *messageHandler) handleConnect(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleConnect(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req ConnRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal connect request:%s, error, err:%s", string(request.Args), err) @@ -464,7 +457,7 @@ func (h *messageHandler) handleConnect(_ context.Context, request Request, logge return wsCommonErrorMsg(int(taosErr.Code), taosErr.ErrStr) } logger.Trace("get whitelist") - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) whitelist, err := tool.GetWhitelist(conn) logger.Debugf("get whitelist cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { @@ -544,7 +537,7 @@ type QueryResponse struct { Precision int `json:"precision"` } -func (h *messageHandler) handleQuery(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleQuery(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req QueryRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal ws query request %s error, err:%s", request.Args, err) @@ -552,6 +545,7 @@ func (h *messageHandler) handleQuery(_ context.Context, request Request, logger } sqlType := monitor.WSRecordRequest(req.Sql) logger.Debugf("get query request, sql:%s", req.Sql) + s := log.GetLogNow(isDebug) handler := async.GlobalAsync.HandlerPool.Get() defer async.GlobalAsync.HandlerPool.Put(handler) logger.Debugf("get handler cost:%s", log.GetLogDuration(isDebug, s)) @@ -613,7 +607,7 @@ type FetchResponse struct { Rows int `json:"rows"` } -func (h *messageHandler) handleFetch(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleFetch(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req FetchRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal ws fetch request %s error, err:%s", request.Args, err) @@ -632,7 +626,7 @@ func (h *messageHandler) handleFetch(_ context.Context, request Request, logger logger.Errorf("result has been freed") return wsCommonErrorMsg(0xffff, "result has been freed") } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) handler := async.GlobalAsync.HandlerPool.Get() defer async.GlobalAsync.HandlerPool.Put(handler) logger.Debugf("get handler, cost:%s", log.GetLogDuration(isDebug, s)) @@ -669,7 +663,7 @@ type FetchBlockRequest struct { ID uint64 `json:"id"` } -func (h *messageHandler) handleFetchBlock(ctx context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleFetchBlock(ctx context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req FetchBlockRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal ws fetch block request, req:%s, error, err:%s", request.Args, err) @@ -696,6 +690,7 @@ func (h *messageHandler) handleFetchBlock(ctx context.Context, request Request, if blockLength <= 0 { return wsCommonErrorMsg(0xffff, "block length illegal") } + s := log.GetLogNow(isDebug) if cap(item.buf) < blockLength+16 { item.buf = make([]byte, 0, blockLength+16) } @@ -714,7 +709,7 @@ type FreeResultRequest struct { ID uint64 `json:"id"` } -func (h *messageHandler) handleFreeResult(_ context.Context, request Request, logger *logrus.Entry, _ bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleFreeResult(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { var req FreeResultRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal ws fetch request %s error, err:%s", request.Args, err) @@ -742,7 +737,7 @@ type SchemalessWriteResponse struct { TotalRows int32 `json:"total_rows"` } -func (h *messageHandler) handleSchemalessWrite(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleSchemalessWrite(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req SchemalessWriteRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal schemaless write request %s error, err:%s", request.Args, err) @@ -775,7 +770,7 @@ type StmtInitResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) handleStmtInit(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleStmtInit(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { stmtInit := syncinterface.TaosStmtInitWithReqID(h.conn, int64(request.ReqID), logger, isDebug) if stmtInit == nil { errStr := wrapper.TaosStmtErrStr(stmtInit) @@ -800,7 +795,7 @@ type StmtPrepareResponse struct { IsInsert bool `json:"is_insert"` } -func (h *messageHandler) handleStmtPrepare(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleStmtPrepare(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req StmtPrepareRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt prepare request %s error, err:%s", request.Args, err) @@ -812,7 +807,7 @@ func (h *messageHandler) handleStmtPrepare(_ context.Context, request Request, l logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) logger.Trace("get stmt lock") stmtItem.Lock() logger.Debugf("get stmt lock cost:%s", log.GetLogDuration(isDebug, s)) @@ -850,7 +845,7 @@ type StmtSetTableNameResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) handleStmtSetTableName(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleStmtSetTableName(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req StmtSetTableNameRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt set table name request %s error, err:%s", request.Args, err) @@ -889,7 +884,7 @@ type StmtSetTagsResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) handleStmtSetTags(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleStmtSetTags(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req StmtSetTagsRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt set tags request %s error, err:%s", request.Args, err) @@ -921,7 +916,7 @@ func (h *messageHandler) handleStmtSetTags(_ context.Context, request Request, l logger.Trace("no tags") return &StmtSetTagsResponse{StmtID: req.StmtID} } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) fields := wrapper.StmtParseFields(tagNums, tagFields) logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) tags := make([][]driver.Value, tagNums) @@ -955,7 +950,7 @@ type StmtBindResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) handleStmtBind(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleStmtBind(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req StmtBindRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt bind tag request %s error, err:%s", request.Args, err) @@ -986,7 +981,7 @@ func (h *messageHandler) handleStmtBind(_ context.Context, request Request, logg logger.Trace("no columns") return &StmtBindResponse{StmtID: req.StmtID} } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) fields := wrapper.StmtParseFields(colNums, colFields) logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) fieldTypes := make([]*types.ColumnType, colNums) @@ -1015,7 +1010,7 @@ func (h *messageHandler) handleStmtBind(_ context.Context, request Request, logg return &StmtBindResponse{StmtID: req.StmtID} } -func (h *messageHandler) handleBindMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleBindMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) (resp Response) { block := tools.AddPointer(req.p0, uintptr(24)) columns := parser.RawBlockGetNumOfCols(block) rows := parser.RawBlockGetNumOfRows(block) @@ -1047,7 +1042,7 @@ func (h *messageHandler) handleBindMessage(_ context.Context, req dealBinaryRequ logger.Trace("no columns") return &StmtBindResponse{StmtID: req.id} } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) fields := wrapper.StmtParseFields(colNums, colFields) logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) fieldTypes = make([]*types.ColumnType, colNums) @@ -1165,7 +1160,7 @@ type StmtAddBatchResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) handleStmtAddBatch(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleStmtAddBatch(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req StmtAddBatchRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt add batch request %s error, err:%s", request.Args, err) @@ -1204,7 +1199,7 @@ type StmtExecResponse struct { Affected int `json:"affected"` } -func (h *messageHandler) handleStmtExec(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleStmtExec(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req StmtExecRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt exec request %s error, err:%s", request.Args, err) @@ -1228,7 +1223,7 @@ func (h *messageHandler) handleStmtExec(_ context.Context, request Request, logg logger.Errorf("stmt execute error, err:%s", errStr) return wsStmtErrorMsg(code, errStr, req.StmtID) } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) affected := wrapper.TaosStmtAffectedRowsOnce(stmtItem.stmt) logger.Debugf("stmt_affected_rows_once, affected:%d, cost:%s", affected, log.GetLogDuration(isDebug, s)) return &StmtExecResponse{StmtID: req.StmtID, Affected: affected} @@ -1244,7 +1239,7 @@ type StmtCloseResponse struct { StmtID uint64 `json:"stmt_id,omitempty"` } -func (h *messageHandler) handleStmtClose(_ context.Context, request Request, logger *logrus.Entry, _ bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleStmtClose(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { var req StmtCloseRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt close request %s error, err:%s", request.Args, err) @@ -1273,7 +1268,7 @@ type StmtGetColFieldsResponse struct { Fields []*stmtCommon.StmtField `json:"fields"` } -func (h *messageHandler) handleStmtGetColFields(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleStmtGetColFields(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req StmtGetColFieldsRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt get col request %s error, err:%s", request.Args, err) @@ -1303,7 +1298,7 @@ func (h *messageHandler) handleStmtGetColFields(_ context.Context, request Reque if colNums == 0 { return &StmtGetColFieldsResponse{StmtID: req.StmtID} } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) fields := wrapper.StmtParseFields(colNums, colFields) logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) return &StmtGetColFieldsResponse{StmtID: req.StmtID, Fields: fields} @@ -1320,7 +1315,7 @@ type StmtGetTagFieldsResponse struct { Fields []*stmtCommon.StmtField `json:"fields,omitempty"` } -func (h *messageHandler) handleStmtGetTagFields(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleStmtGetTagFields(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req StmtGetTagFieldsRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt get tags request %s error, err:%s", request.Args, err) @@ -1350,7 +1345,7 @@ func (h *messageHandler) handleStmtGetTagFields(_ context.Context, request Reque if tagNums == 0 { return &StmtGetTagFieldsResponse{StmtID: req.StmtID} } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) fields := wrapper.StmtParseFields(tagNums, tagFields) logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) return &StmtGetTagFieldsResponse{StmtID: req.StmtID, Fields: fields} @@ -1372,7 +1367,7 @@ type StmtUseResultResponse struct { Precision int `json:"precision"` } -func (h *messageHandler) handleStmtUseResult(_ context.Context, request Request, logger *logrus.Entry, _ bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleStmtUseResult(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { var req StmtUseResultRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt use result request %s error, err:%s", request.Args, err) @@ -1416,7 +1411,7 @@ func (h *messageHandler) handleStmtUseResult(_ context.Context, request Request, } } -func (h *messageHandler) handleSetTagsMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleSetTagsMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) (resp Response) { block := tools.AddPointer(req.p0, uintptr(24)) columns := parser.RawBlockGetNumOfCols(block) rows := parser.RawBlockGetNumOfRows(block) @@ -1453,7 +1448,7 @@ func (h *messageHandler) handleSetTagsMessage(_ context.Context, req dealBinaryR logger.Tracef("stmt tags count not match %d != %d", columns, tagNums) return wsStmtErrorMsg(0xffff, "stmt tags count not match", req.id) } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) fields := wrapper.StmtParseFields(tagNums, tagFields) logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) @@ -1473,12 +1468,13 @@ func (h *messageHandler) handleSetTagsMessage(_ context.Context, req dealBinaryR return &StmtSetTagsResponse{StmtID: req.id} } -func (h *messageHandler) handleTMQRawMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleTMQRawMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) (resp Response) { length := *(*uint32)(tools.AddPointer(req.p0, uintptr(24))) metaType := *(*uint16)(tools.AddPointer(req.p0, uintptr(28))) data := tools.AddPointer(req.p0, uintptr(30)) logger.Tracef("get write raw message, length:%d, metaType:%d", length, metaType) logger.Trace("get global lock for raw message") + s := log.GetLogNow(isDebug) h.Lock() logger.Debugf("get global lock cost:%s", log.GetLogDuration(isDebug, s)) defer h.Unlock() @@ -1498,7 +1494,7 @@ func (h *messageHandler) handleTMQRawMessage(_ context.Context, req dealBinaryRe return &BaseResponse{} } -func (h *messageHandler) handleRawBlockMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleRawBlockMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) (resp Response) { numOfRows := *(*int32)(tools.AddPointer(req.p0, uintptr(24))) tableNameLength := *(*uint16)(tools.AddPointer(req.p0, uintptr(28))) tableName := make([]byte, tableNameLength) @@ -1507,7 +1503,7 @@ func (h *messageHandler) handleRawBlockMessage(_ context.Context, req dealBinary } rawBlock := tools.AddPointer(req.p0, uintptr(30+tableNameLength)) logger.Tracef("raw block message, table:%s, rows:%d", tableName, numOfRows) - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) h.Lock() logger.Debugf("get global lock cost:%s", log.GetLogDuration(isDebug, s)) defer h.Unlock() @@ -1525,7 +1521,7 @@ func (h *messageHandler) handleRawBlockMessage(_ context.Context, req dealBinary return &BaseResponse{} } -func (h *messageHandler) handleRawBlockMessageWithFields(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleRawBlockMessageWithFields(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) (resp Response) { numOfRows := *(*int32)(tools.AddPointer(req.p0, uintptr(24))) tableNameLength := int(*(*uint16)(tools.AddPointer(req.p0, uintptr(28)))) tableName := make([]byte, tableNameLength) @@ -1537,7 +1533,6 @@ func (h *messageHandler) handleRawBlockMessageWithFields(_ context.Context, req numOfColumn := int(parser.RawBlockGetNumOfCols(rawBlock)) fieldsBlock := tools.AddPointer(req.p0, uintptr(30+tableNameLength+blockLength)) logger.Tracef("raw block message with fields, table:%s, rows:%d", tableName, numOfRows) - s = log.GetLogNow(isDebug) h.Lock() defer h.Unlock() if h.closed { @@ -1554,7 +1549,7 @@ func (h *messageHandler) handleRawBlockMessageWithFields(_ context.Context, req return &BaseResponse{} } -func (h *messageHandler) handleBinaryQuery(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool, s time.Time) Response { +func (h *messageHandler) handleBinaryQuery(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) Response { message := req.message if len(message) < 31 { return wsCommonErrorMsg(0xffff, "message length is too short") @@ -1574,7 +1569,7 @@ func (h *messageHandler) handleBinaryQuery(_ context.Context, req dealBinaryRequ } logger.Debugf("binary query, sql:%s", log.GetLogSql(bytesutil.ToUnsafeString(sql))) sqlType := monitor.WSRecordRequest(bytesutil.ToUnsafeString(sql)) - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) handler := async.GlobalAsync.HandlerPool.Get() defer async.GlobalAsync.HandlerPool.Put(handler) logger.Debugf("get handler cost:%s", log.GetLogDuration(isDebug, s)) @@ -1602,7 +1597,6 @@ func (h *messageHandler) handleBinaryQuery(_ context.Context, req dealBinaryRequ s = log.GetLogNow(isDebug) fieldsCount := wrapper.TaosNumFields(result.Res) logger.Debugf("num_fields cost:%s", log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) rowsHeader, _ := wrapper.ReadColumn(result.Res, fieldsCount) s = log.GetLogNow(isDebug) logger.Debugf("read column cost:%s", log.GetLogDuration(isDebug, s)) @@ -1622,7 +1616,7 @@ func (h *messageHandler) handleBinaryQuery(_ context.Context, req dealBinaryRequ } } -func (h *messageHandler) handleFetchRawBlock(ctx context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool, s time.Time) Response { +func (h *messageHandler) handleFetchRawBlock(ctx context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) Response { message := req.message if len(message) < 26 { return wsFetchRawBlockErrorMsg(0xffff, "message length is too short", req.reqID, req.id, uint64(wstool.GetDuration(ctx))) @@ -1643,7 +1637,7 @@ func (h *messageHandler) handleFetchRawBlock(ctx context.Context, req dealBinary logger.Errorf("result has been freed, result_id:%d", req.id) return wsFetchRawBlockErrorMsg(0xffff, "result has been freed", req.reqID, req.id, uint64(wstool.GetDuration(ctx))) } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) handler := async.GlobalAsync.HandlerPool.Get() defer async.GlobalAsync.HandlerPool.Put(handler) logger.Debugf("get handler cost:%s", log.GetLogDuration(isDebug, s)) @@ -1686,7 +1680,7 @@ type Stmt2BindResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) handleStmt2Bind(ctx context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool, s time.Time) Response { +func (h *messageHandler) handleStmt2Bind(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) Response { message := req.message if len(message) < 30 { return wsStmtErrorMsg(0xffff, "message length is too short", req.id) @@ -1726,7 +1720,7 @@ type GetCurrentDBResponse struct { DB string `json:"db"` } -func (h *messageHandler) handleGetCurrentDB(_ context.Context, _ Request, logger *logrus.Entry, isDebug bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleGetCurrentDB(_ context.Context, _ Request, logger *logrus.Entry, isDebug bool) (resp Response) { db, err := syncinterface.TaosGetCurrentDB(h.conn, logger, isDebug) if err != nil { var taosErr *errors2.TaosError @@ -1742,7 +1736,7 @@ type GetServerInfoResponse struct { Info string `json:"info"` } -func (h *messageHandler) handleGetServerInfo(_ context.Context, _ Request, logger *logrus.Entry, isDebug bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleGetServerInfo(_ context.Context, _ Request, logger *logrus.Entry, isDebug bool) (resp Response) { serverInfo := syncinterface.TaosGetServerInfo(h.conn, logger, isDebug) return &GetServerInfoResponse{Info: serverInfo} } @@ -1757,7 +1751,7 @@ type NumFieldsResponse struct { NumFields int `json:"num_fields"` } -func (h *messageHandler) handleNumFields(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleNumFields(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { var req NumFieldsRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt num params request %s error, err:%s", request.Args, err) @@ -1790,7 +1784,7 @@ type StmtNumParamsResponse struct { NumParams int `json:"num_params"` } -func (h *messageHandler) handleStmtNumParams(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleStmtNumParams(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req StmtNumParamsRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt num params request %s error, err:%s", request.Args, err) @@ -1831,7 +1825,7 @@ type StmtGetParamResponse struct { Length int `json:"length"` } -func (h *messageHandler) handleStmtGetParam(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleStmtGetParam(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req StmtGetParamRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt get param request %s error, err:%s", request.Args, err) @@ -1872,7 +1866,7 @@ type Stmt2InitResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) handleStmt2Init(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleStmt2Init(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req Stmt2InitRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt2 init request %s error, err:%s", request.Args, err) @@ -1912,7 +1906,7 @@ type Stmt2PrepareResponse struct { FieldsCount int `json:"fields_count"` } -func (h *messageHandler) handleStmt2Prepare(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) Response { +func (h *messageHandler) handleStmt2Prepare(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) Response { var req Stmt2PrepareRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt2 prepare request %s error, err:%s", request.Args, err) @@ -1924,7 +1918,7 @@ func (h *messageHandler) handleStmt2Prepare(_ context.Context, request Request, logger.Errorf("stmt2 is nil, stmt_id:%d", req.StmtID) return wsStmtErrorMsg(0xffff, "stmt2 is nil", req.StmtID) } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) logger.Trace("get stmt2 lock") stmtItem.Lock() logger.Debugf("get stmt2 lock cost:%s", log.GetLogDuration(isDebug, s)) @@ -2026,7 +2020,7 @@ type Stmt2GetFieldsResponse struct { TagFields []*stmtCommon.StmtField `json:"tag_fields"` } -func (h *messageHandler) handleStmt2GetFields(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleStmt2GetFields(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req Stmt2GetFieldsRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt2 get fields request %s error, err:%s", request.Args, err) @@ -2087,7 +2081,7 @@ type Stmt2ExecResponse struct { Affected int `json:"affected"` } -func (h *messageHandler) handleStmt2Exec(_ context.Context, request Request, logger *logrus.Entry, isDebug bool, s time.Time) (resp Response) { +func (h *messageHandler) handleStmt2Exec(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { var req Stmt2ExecRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt2 exec request %s error, err:%s", request.Args, err) @@ -2111,7 +2105,7 @@ func (h *messageHandler) handleStmt2Exec(_ context.Context, request Request, log logger.Errorf("stmt2 execute error, err:%s", errStr) return wsStmtErrorMsg(code, errStr, req.StmtID) } - s = log.GetLogNow(isDebug) + s := log.GetLogNow(isDebug) logger.Tracef("stmt2 execute wait callback, stmt_id:%d", req.StmtID) result := <-stmtItem.caller.ExecResult logger.Debugf("stmt2 execute wait callback finish, affected:%d, res:%p, n:%d, cost:%s", result.Affected, result.Res, result.N, log.GetLogDuration(isDebug, s)) @@ -2129,7 +2123,7 @@ type Stmt2CloseResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) handleStmt2Close(_ context.Context, request Request, logger *logrus.Entry, _ bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleStmt2Close(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { var req StmtCloseRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt close request %s error, err:%s", request.Args, err) @@ -2162,7 +2156,7 @@ type Stmt2UseResultResponse struct { Precision int `json:"precision"` } -func (h *messageHandler) handleStmt2UseResult(_ context.Context, request Request, logger *logrus.Entry, _ bool, _ time.Time) (resp Response) { +func (h *messageHandler) handleStmt2UseResult(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { var req Stmt2UseResultRequest if err := json.Unmarshal(request.Args, &req); err != nil { logger.Errorf("unmarshal stmt2 use result request %s error, err:%s", request.Args, err) diff --git a/controller/ws/ws/ws_test.go b/controller/ws/ws/ws_test.go index f7244f43..66251775 100644 --- a/controller/ws/ws/ws_test.go +++ b/controller/ws/ws/ws_test.go @@ -144,7 +144,10 @@ func TestVersion(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() resp, err := doWebSocket(ws, wstool.ClientVersion, nil) assert.NoError(t, err) var versionResp VersionResponse @@ -182,7 +185,7 @@ func TestWsQuery(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -275,7 +278,7 @@ func TestWsQuery(t *testing.T) { assert.NoError(t, err) resultID, blockResult = parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) checkBlockResult(t, blockResult) - + assert.Equal(t, queryResp.ID, resultID) // fetch fetchReq = FetchRequest{ReqID: 9, ID: queryResp.ID} resp, err = doWebSocket(ws, WSFetch, &fetchReq) @@ -547,7 +550,8 @@ func TestWsQuery(t *testing.T) { fetchBlockResp, err = doWebSocket(ws, WSFetchBlock, &fetchBlockReq) assert.NoError(t, err) resultID, blockResult = parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) - + assert.Equal(t, queryResp.ID, resultID) + checkBlockResult(t, blockResult) // fetch fetchReq = FetchRequest{ReqID: 13, ID: queryResp.ID} resp, err = doWebSocket(ws, WSFetch, &fetchReq) @@ -635,7 +639,7 @@ func TestWsBinaryQuery(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -1050,7 +1054,6 @@ func TestWsBinaryQuery(t *testing.T) { assert.Equal(t, false, fetchRawBlockResp.Finished) blockResult = ReadBlockSimple(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0]), queryResp.Precision) checkBlockResult(t, blockResult) - rawBlock = fetchRawBlockResp.RawBlock buffer.Reset() wstool.WriteUint64(&buffer, 13) // req id @@ -1093,6 +1096,7 @@ func TestWsBinaryQuery(t *testing.T) { err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) assert.NoError(t, err) _, resp, err = ws.ReadMessage() + assert.NoError(t, err) err = json.Unmarshal(resp, &queryResp) assert.NoError(t, err) assert.Equal(t, 65535, queryResp.Code, queryResp.Message) @@ -1109,6 +1113,7 @@ func TestWsBinaryQuery(t *testing.T) { err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) assert.NoError(t, err) _, resp, err = ws.ReadMessage() + assert.NoError(t, err) err = json.Unmarshal(resp, &queryResp) assert.NoError(t, err) assert.Equal(t, 65535, queryResp.Code, queryResp.Message) @@ -1125,6 +1130,7 @@ func TestWsBinaryQuery(t *testing.T) { err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) assert.NoError(t, err) _, resp, err = ws.ReadMessage() + assert.NoError(t, err) err = json.Unmarshal(resp, &queryResp) assert.NoError(t, err) assert.NotEqual(t, 0, queryResp.Code, queryResp.Message) @@ -1141,6 +1147,7 @@ func TestWsBinaryQuery(t *testing.T) { err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) assert.NoError(t, err) _, resp, err = ws.ReadMessage() + assert.NoError(t, err) err = json.Unmarshal(resp, &queryResp) assert.NoError(t, err) assert.Equal(t, 0, queryResp.Code, queryResp.Message) @@ -1158,6 +1165,7 @@ func TestWsBinaryQuery(t *testing.T) { err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) assert.NoError(t, err) _, resp, err = ws.ReadMessage() + assert.NoError(t, err) err = json.Unmarshal(resp, &queryResp) assert.NoError(t, err) assert.Equal(t, 0, queryResp.Code, queryResp.Message) @@ -1316,7 +1324,7 @@ func TestWsSchemaless(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -1481,7 +1489,7 @@ func TestWsStmt(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -1994,7 +2002,7 @@ func StmtQuery(t *testing.T, db string, prepareDataSql []string) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -2138,7 +2146,7 @@ func TestStmtNumParams(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -2207,7 +2215,7 @@ func TestStmtGetParams(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -2276,7 +2284,7 @@ func TestGetCurrentDB(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -2311,7 +2319,7 @@ func TestGetServerInfo(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -2359,7 +2367,7 @@ func TestNumFields(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -2416,7 +2424,7 @@ func TestWsStmt2(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -2651,7 +2659,7 @@ func TestStmt2Prepare(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -2777,7 +2785,7 @@ func TestStmt2GetFields(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -2990,7 +2998,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -3047,6 +3055,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { }, } b, err := stmtCommon.MarshalStmt2Binary(params, false, nil, nil) + assert.NoError(t, err) block.Write(b) err = ws.WriteMessage(websocket.BinaryMessage, block.Bytes()) @@ -3112,6 +3121,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { assert.NoError(t, err) var closeResp Stmt2CloseResponse err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) assert.Equal(t, 0, closeResp.Code, closeResp.Message) } @@ -3124,7 +3134,7 @@ func TestWSConnect(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -3176,7 +3186,7 @@ func TestMode(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -3203,10 +3213,6 @@ func TestMode(t *testing.T) { } -func TestStmtBinary(t *testing.T) { - -} - func TestWSTMQWriteRaw(t *testing.T) { s := httptest.NewServer(router) defer s.Close() @@ -3216,7 +3222,7 @@ func TestWSTMQWriteRaw(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() @@ -3326,7 +3332,7 @@ func TestDropUser(t *testing.T) { return } defer func() { - err := ws.Close() + err = ws.Close() assert.NoError(t, err) }() defer doRestful("drop user test_ws_drop_user", "") diff --git a/controller/ws/wstool/const.go b/controller/ws/wstool/const.go index ef7e019d..80e22cf6 100644 --- a/controller/ws/wstool/const.go +++ b/controller/ws/wstool/const.go @@ -1,4 +1,6 @@ package wstool -const StartTimeKey = 1 +type ContextTypeInt int + +const StartTimeKey ContextTypeInt = 1 const ClientVersion = "version" diff --git a/controller/ws/wstool/error.go b/controller/ws/wstool/error.go index e1819a24..009c085c 100644 --- a/controller/ws/wstool/error.go +++ b/controller/ws/wstool/error.go @@ -2,9 +2,9 @@ package wstool import ( "context" - "encoding/json" "github.com/huskar-t/melody" + "github.com/sirupsen/logrus" tErrors "github.com/taosdata/driver-go/v3/errors" ) @@ -16,21 +16,21 @@ type WSErrorResp struct { Timing int64 `json:"timing"` } -func WSErrorMsg(ctx context.Context, session *melody.Session, code int, message string, action string, reqID uint64) { - b, _ := json.Marshal(&WSErrorResp{ +func WSErrorMsg(ctx context.Context, session *melody.Session, logger *logrus.Entry, code int, message string, action string, reqID uint64) { + data := &WSErrorResp{ Code: code & 0xffff, Message: message, Action: action, ReqID: reqID, Timing: GetDuration(ctx), - }) - session.Write(b) + } + WSWriteJson(session, logger, data) } -func WSError(ctx context.Context, session *melody.Session, err error, action string, reqID uint64) { +func WSError(ctx context.Context, session *melody.Session, logger *logrus.Entry, err error, action string, reqID uint64) { e, is := err.(*tErrors.TaosError) if is { - WSErrorMsg(ctx, session, int(e.Code)&0xffff, e.ErrStr, action, reqID) + WSErrorMsg(ctx, session, logger, int(e.Code)&0xffff, e.ErrStr, action, reqID) } else { - WSErrorMsg(ctx, session, 0xffff, err.Error(), action, reqID) + WSErrorMsg(ctx, session, logger, 0xffff, err.Error(), action, reqID) } } diff --git a/controller/ws/wstool/error_test.go b/controller/ws/wstool/error_test.go index 035e3115..75d5a7d3 100644 --- a/controller/ws/wstool/error_test.go +++ b/controller/ws/wstool/error_test.go @@ -12,6 +12,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/huskar-t/melody" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" tErrors "github.com/taosdata/driver-go/v3/errors" ) @@ -26,15 +27,16 @@ func TestWSError(t *testing.T) { ErrStr: "test error", } commonErr := errors.New("test common error") + logger := logrus.New().WithField("test", "TestWSError") m.HandleMessage(func(session *melody.Session, data []byte) { if m.IsClosed() { return } switch data[0] { case '1': - WSError(ctx, session, taosErr, "test action", reqID) + WSError(ctx, session, logger, taosErr, "test action", reqID) case '2': - WSError(ctx, session, commonErr, "test common error action", reqID) + WSError(ctx, session, logger, commonErr, "test common error action", reqID) } }) @@ -50,7 +52,10 @@ func TestWSError(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() err = ws.WriteMessage(websocket.TextMessage, []byte{'1'}) assert.NoError(t, err) wt, resp, err := ws.ReadMessage() diff --git a/controller/ws/wstool/resp.go b/controller/ws/wstool/resp.go index ccb9fa08..0b544a85 100644 --- a/controller/ws/wstool/resp.go +++ b/controller/ws/wstool/resp.go @@ -18,14 +18,18 @@ type TDEngineRestfulResp struct { } func WSWriteJson(session *melody.Session, logger *logrus.Entry, data interface{}) { - b, _ := json.Marshal(data) + b, err := json.Marshal(data) + if err != nil { + logger.Errorf("marshal json failed:%s, data:%#v", err, data) + return + } logger.Tracef("write json:%s", b) - session.Write(b) + _ = session.Write(b) } func WSWriteBinary(session *melody.Session, data []byte, logger *logrus.Entry) { logger.Tracef("write binary:%+v", data) - session.WriteBinary(data) + _ = session.WriteBinary(data) } type WSVersionResp struct { diff --git a/controller/ws/wstool/resp_test.go b/controller/ws/wstool/resp_test.go index e883d18e..e47d737d 100644 --- a/controller/ws/wstool/resp_test.go +++ b/controller/ws/wstool/resp_test.go @@ -22,7 +22,7 @@ func TestWSWriteJson(t *testing.T) { Action: "version", Version: "1.0.0", } - m.HandleMessage(func(session *melody.Session, msg []byte) { + m.HandleMessage(func(session *melody.Session, _ []byte) { if m.IsClosed() { return } @@ -42,7 +42,10 @@ func TestWSWriteJson(t *testing.T) { t.Error(err) return } - defer ws.Close() + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() err = ws.WriteMessage(websocket.TextMessage, []byte{'1'}) assert.NoError(t, err) wt, resp, err := ws.ReadMessage() diff --git a/db/async/handlerpool.go b/db/async/handlerpool.go index a32cd527..79b37877 100644 --- a/db/async/handlerpool.go +++ b/db/async/handlerpool.go @@ -97,9 +97,7 @@ func (c *HandlerPool) Put(handler *Handler) { } c.mu.Unlock() return - } else { - c.handlers <- handler - c.mu.Unlock() - return } + c.handlers <- handler + c.mu.Unlock() } diff --git a/db/async/handlerpool_test.go b/db/async/handlerpool_test.go index 95f04e38..3f6202d2 100644 --- a/db/async/handlerpool_test.go +++ b/db/async/handlerpool_test.go @@ -32,7 +32,7 @@ func TestNewHandlerPool(t *testing.T) { }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.name, func(_ *testing.T) { got := NewHandlerPool(tt.args.count) l := make([]*Handler, tt.args.count) for i := 0; i < tt.args.count; i++ { diff --git a/db/async/row.go b/db/async/row.go index d15d2f66..c012c2fa 100644 --- a/db/async/row.go +++ b/db/async/row.go @@ -17,7 +17,7 @@ import ( "github.com/taosdata/taosadapter/v3/tools/generator" ) -var FetchRowError = errors.New("fetch row error") +var ErrFetchRowError = errors.New("fetch row error") var GlobalAsync *Async type Async struct { @@ -72,35 +72,34 @@ func (a *Async) TaosExec(taosConnect unsafe.Pointer, logger *logrus.Entry, isDeb if result.N == 0 { logger.Trace("fetch finished") return execResult, nil - } else { - res = result.Res - for i := 0; i < result.N; i++ { - var row unsafe.Pointer - logger.Tracef("get thread lock for fetch row, row:%d", i) - s = log.GetLogNow(isDebug) - thread.AsyncLocker.Lock() - logger.Debugf("get thread lock for fetch row cost:%s", log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) - row = wrapper.TaosFetchRow(res) - logger.Debugf("taos_fetch_row finish, row:%p, cost:%s", row, log.GetLogDuration(isDebug, s)) - thread.AsyncLocker.Unlock() - lengths := wrapper.FetchLengths(res, len(rowsHeader.ColNames)) - logger.Tracef("fetch lengths:%d", lengths) - values := make([]driver.Value, len(rowsHeader.ColNames)) - for j := range rowsHeader.ColTypes { - if row == nil { - logger.Error("fetch row error, row is nil") - return nil, FetchRowError - } - v := wrapper.FetchRow(row, j, rowsHeader.ColTypes[j], lengths[j], precision, timeFormat) - if vv, is := v.([]byte); is { - v = json.RawMessage(vv) - } - values[j] = v + } + res = result.Res + for i := 0; i < result.N; i++ { + var row unsafe.Pointer + logger.Tracef("get thread lock for fetch row, row:%d", i) + s = log.GetLogNow(isDebug) + thread.AsyncLocker.Lock() + logger.Debugf("get thread lock for fetch row cost:%s", log.GetLogDuration(isDebug, s)) + s = log.GetLogNow(isDebug) + row = wrapper.TaosFetchRow(res) + logger.Debugf("taos_fetch_row finish, row:%p, cost:%s", row, log.GetLogDuration(isDebug, s)) + thread.AsyncLocker.Unlock() + lengths := wrapper.FetchLengths(res, len(rowsHeader.ColNames)) + logger.Tracef("fetch lengths:%d", lengths) + values := make([]driver.Value, len(rowsHeader.ColNames)) + for j := range rowsHeader.ColTypes { + if row == nil { + logger.Error("fetch row error, row is nil") + return nil, ErrFetchRowError + } + v := wrapper.FetchRow(row, j, rowsHeader.ColTypes[j], lengths[j], precision, timeFormat) + if vv, is := v.([]byte); is { + v = json.RawMessage(vv) } - logger.Tracef("get data, %v", values) - execResult.Data = append(execResult.Data, values) + values[j] = v } + logger.Tracef("get data, %v", values) + execResult.Data = append(execResult.Data, values) } } } diff --git a/db/async/row_test.go b/db/async/row_test.go index eee21be5..65789004 100644 --- a/db/async/row_test.go +++ b/db/async/row_test.go @@ -13,7 +13,7 @@ import ( func TestMain(m *testing.M) { config.Init() - log.SetLevel("trace") + _ = log.SetLevel("trace") m.Run() } @@ -48,7 +48,7 @@ func TestAsync_TaosExec(t *testing.T) { args: args{ taosConnect: conn, sql: "select 1", - timeFormat: func(ts int64, precision int) driver.Value { + timeFormat: func(ts int64, _ int) driver.Value { return ts }, }, diff --git a/db/asynctmq/tmq.go b/db/asynctmq/tmq.go index 6cbe9266..4582c849 100644 --- a/db/asynctmq/tmq.go +++ b/db/asynctmq/tmq.go @@ -1,6 +1,7 @@ //go:build !windows // +build !windows +// Package asynctmq is a cgo wrapper for TDengine tmq API package asynctmq /* @@ -696,10 +697,12 @@ func TaosaTMQConsumerCloseA(tmqThread unsafe.Pointer, tmq unsafe.Pointer, caller C.taosa_tmq_consumer_close_a_wrapper((*C.tmq_thread)(tmqThread), (*C.tmq_t)(tmq), C.uintptr_t(caller)) } +// TaosaTMQGetRawA malloc tmq_raw_data func TaosaInitTMQRaw() unsafe.Pointer { return unsafe.Pointer(C.malloc(C.sizeof_struct_tmq_raw_data)) } +// TaosaFreeTMQRaw free tmq_raw_data func TaosaFreeTMQRaw(raw unsafe.Pointer) { C.free(raw) } diff --git a/db/asynctmq/tmq_windows.go b/db/asynctmq/tmq_windows.go index ade79589..71c0b65c 100644 --- a/db/asynctmq/tmq_windows.go +++ b/db/asynctmq/tmq_windows.go @@ -1,6 +1,7 @@ //go:build windows // +build windows +// Package asynctmq is a cgo wrapper for TDengine tmq API package asynctmq /* diff --git a/db/asynctmq/tmqhandle/handler.go b/db/asynctmq/tmqhandle/handler.go index 58b369e5..b915a723 100644 --- a/db/asynctmq/tmqhandle/handler.go +++ b/db/asynctmq/tmqhandle/handler.go @@ -206,11 +206,9 @@ func (c *TMQHandlerPool) Put(handler *TMQHandler) { } c.mu.Unlock() return - } else { - c.handlers <- handler - c.mu.Unlock() - return } + c.handlers <- handler + c.mu.Unlock() } var GlobalTMQHandlerPoll = NewHandlerPool(10000) diff --git a/db/asynctmq/tmqhandle/handler_test.go b/db/asynctmq/tmqhandle/handler_test.go index db08bde5..746fa6fc 100644 --- a/db/asynctmq/tmqhandle/handler_test.go +++ b/db/asynctmq/tmqhandle/handler_test.go @@ -70,7 +70,8 @@ func TestTMQCaller(t *testing.T) { caller := NewTMQCaller() // Test PollCall - res := unsafe.Pointer(uintptr(0x12345)) + a := 1 + res := unsafe.Pointer(&a) caller.PollCall(res) if <-caller.PollResult != res { t.Error("PollCall failed") @@ -88,10 +89,11 @@ func TestTMQCaller(t *testing.T) { } // Test FetchRawBlockCall + b := 2 frbr := &FetchRawBlockResult{ Code: 456, BlockSize: 789, - Block: unsafe.Pointer(uintptr(0x67890)), + Block: unsafe.Pointer(&b), } caller.FetchRawBlockCall(frbr.Code, frbr.BlockSize, frbr.Block) result := <-caller.FetchRawBlockResult @@ -100,7 +102,8 @@ func TestTMQCaller(t *testing.T) { } // Test NewConsumerCall - consumer := unsafe.Pointer(uintptr(0xabcd)) + c := 3 + consumer := unsafe.Pointer(&c) errStr := "some error" caller.NewConsumerCall(consumer, errStr) result2 := <-caller.NewConsumerResult @@ -137,7 +140,8 @@ func TestTMQCaller(t *testing.T) { } // Test GetJsonMetaCall - meta := unsafe.Pointer(uintptr(0xeffe)) + d := 4 + meta := unsafe.Pointer(&d) caller.GetJsonMetaCall(meta) if <-caller.GetJsonMetaResult != meta { t.Error("GetJsonMetaCall failed") diff --git a/db/commonpool/pool.go b/db/commonpool/pool.go index 8125ea3a..499f24af 100644 --- a/db/commonpool/pool.go +++ b/db/commonpool/pool.go @@ -87,7 +87,7 @@ func NewConnectorPool(user, password string) (*ConnectorPool, error) { cp.ctx, cp.cancel = context.WithCancel(context.Background()) err = tool.RegisterChangePass(v, cp.changePassHandle) if err != nil { - p.Put(v) + _ = p.Put(v) p.Release() cp.putHandle() return nil, err @@ -95,7 +95,7 @@ func NewConnectorPool(user, password string) (*ConnectorPool, error) { // notify drop err = tool.RegisterDropUser(v, cp.dropUserHandle) if err != nil { - p.Put(v) + _ = p.Put(v) p.Release() cp.putHandle() return nil, err @@ -103,7 +103,7 @@ func NewConnectorPool(user, password string) (*ConnectorPool, error) { // whitelist ipNets, err := tool.GetWhitelist(v) if err != nil { - p.Put(v) + _ = p.Put(v) p.Release() cp.putHandle() return nil, err @@ -112,12 +112,12 @@ func NewConnectorPool(user, password string) (*ConnectorPool, error) { // register whitelist modify callback err = tool.RegisterChangeWhitelist(v, cp.whitelistChangeHandle) if err != nil { - p.Put(v) + _ = p.Put(v) p.Release() cp.putHandle() return nil, err } - p.Put(v) + _ = p.Put(v) go func() { defer func() { cp.logger.Warn("connector pool exit") @@ -174,11 +174,10 @@ func (cp *ConnectorPool) factory() (unsafe.Pointer, error) { return conn, err } -func (cp *ConnectorPool) close(v unsafe.Pointer) error { +func (cp *ConnectorPool) close(v unsafe.Pointer) { if v != nil { syncinterface.TaosClose(v, cp.logger, log.IsDebug()) } - return nil } var AuthFailureError = tErrors.NewError(httperror.TSDB_CODE_MND_AUTH_FAILURE, "Authentication failure") @@ -250,7 +249,7 @@ func (c *Conn) Put() error { } var singleGroup singleflight.Group -var ErrWhitelistForbidden error = errors.New("whitelist prohibits current IP access") +var ErrWhitelistForbidden = errors.New("whitelist prohibits current IP access") func GetConnection(user, password string, clientIp net.IP) (*Conn, error) { cp, err := getConnectionPool(user, password) @@ -266,16 +265,7 @@ func getConnectionPool(user, password string) (*ConnectorPool, error) { connectionPool := p.(*ConnectorPool) if connectionPool.verifyPassword(password) { return connectionPool, nil - } else { - cp, err, _ := singleGroup.Do(fmt.Sprintf("%s:%s", user, password), func() (interface{}, error) { - return getConnectorPoolSafe(user, password) - }) - if err != nil { - return nil, err - } - return cp.(*ConnectorPool), nil } - } else { cp, err, _ := singleGroup.Do(fmt.Sprintf("%s:%s", user, password), func() (interface{}, error) { return getConnectorPoolSafe(user, password) }) @@ -284,6 +274,13 @@ func getConnectionPool(user, password string) (*ConnectorPool, error) { } return cp.(*ConnectorPool), nil } + cp, err, _ := singleGroup.Do(fmt.Sprintf("%s:%s", user, password), func() (interface{}, error) { + return getConnectorPoolSafe(user, password) + }) + if err != nil { + return nil, err + } + return cp.(*ConnectorPool), nil } func VerifyClientIP(user, password string, clientIP net.IP) (authed bool, valid bool, connectionPoolExits bool) { @@ -325,21 +322,19 @@ func getConnectorPoolSafe(user, password string) (*ConnectorPool, error) { connectionPool := p.(*ConnectorPool) if connectionPool.verifyPassword(password) { return connectionPool, nil - } else { - newPool, err := NewConnectorPool(user, password) - if err != nil { - return nil, err - } - connectionPool.Release() - connectionMap.Store(user, newPool) - return newPool, nil } - } else { newPool, err := NewConnectorPool(user, password) if err != nil { return nil, err } + connectionPool.Release() connectionMap.Store(user, newPool) return newPool, nil } + newPool, err := NewConnectorPool(user, password) + if err != nil { + return nil, err + } + connectionMap.Store(user, newPool) + return newPool, nil } diff --git a/db/commonpool/pool_test.go b/db/commonpool/pool_test.go index 870cd892..4d7b4e03 100644 --- a/db/commonpool/pool_test.go +++ b/db/commonpool/pool_test.go @@ -24,7 +24,11 @@ func BenchmarkGetConnection(b *testing.B) { b.Error(err) return } - conn.Put() + err = conn.Put() + if err != nil { + b.Error(err) + return + } } } @@ -142,7 +146,8 @@ func TestChangePassword(t *testing.T) { wc, err := conn.pool.Get() assert.Equal(t, AuthFailureError, err) assert.Equal(t, unsafe.Pointer(nil), wc) - conn.Put() + err = conn.Put() + assert.NoError(t, err) conn2, err := GetConnection("test", "test", net.ParseIP("127.0.0.1")) assert.Error(t, err) @@ -155,7 +160,8 @@ func TestChangePassword(t *testing.T) { errNo = wrapper.TaosError(result2) wrapper.TaosFreeResult(result2) assert.Equal(t, 0, errNo) - conn3.Put() + err = conn3.Put() + assert.NoError(t, err) conn4, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) assert.NoError(t, err) @@ -164,7 +170,8 @@ func TestChangePassword(t *testing.T) { errNo = wrapper.TaosError(result3) wrapper.TaosFreeResult(result3) assert.Equal(t, 0, errNo) - conn3.Put() + err = conn3.Put() + assert.NoError(t, err) } func TestChangePasswordConcurrent(t *testing.T) { @@ -201,7 +208,8 @@ func TestChangePasswordConcurrent(t *testing.T) { wc1, err := conn.pool.Get() assert.Equal(t, AuthFailureError, err) assert.Equal(t, unsafe.Pointer(nil), wc1) - conn.Put() + err = conn.Put() + assert.NoError(t, err) wg := sync.WaitGroup{} wg.Add(5) for i := 0; i < 5; i++ { @@ -210,7 +218,8 @@ func TestChangePasswordConcurrent(t *testing.T) { conn2, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn2) - conn2.Put() + err = conn2.Put() + assert.NoError(t, err) }() } wg.Wait() @@ -225,7 +234,8 @@ func TestChangePasswordConcurrent(t *testing.T) { errNo = wrapper.TaosError(result2) wrapper.TaosFreeResult(result2) assert.Equal(t, 0, errNo) - conn3.Put() + err = conn3.Put() + assert.NoError(t, err) conn4, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) assert.NoError(t, err) @@ -234,5 +244,6 @@ func TestChangePasswordConcurrent(t *testing.T) { errNo = wrapper.TaosError(result3) wrapper.TaosFreeResult(result3) assert.Equal(t, 0, errNo) - conn3.Put() + err = conn3.Put() + assert.NoError(t, err) } diff --git a/db/init_test.go b/db/init_test.go index be4d676b..7daa74e8 100644 --- a/db/init_test.go +++ b/db/init_test.go @@ -10,7 +10,7 @@ import ( // @author: xftan // @date: 2021/12/14 15:06 // @description: test database init -func TestPrepareConnection(t *testing.T) { +func TestPrepareConnection(_ *testing.T) { viper.Set("taosConfigDir", "/etc/taos/") config.Init() PrepareConnection() diff --git a/db/syncinterface/wrapper_test.go b/db/syncinterface/wrapper_test.go index 52e808b0..c37a3b5c 100644 --- a/db/syncinterface/wrapper_test.go +++ b/db/syncinterface/wrapper_test.go @@ -26,7 +26,7 @@ const isDebug = true func TestMain(m *testing.M) { config.Init() - log.SetLevel("trace") + _ = log.SetLevel("trace") db.PrepareConnection() m.Run() } diff --git a/db/tool/notify_test.go b/db/tool/notify_test.go index 0879fbd0..6c9a8a17 100644 --- a/db/tool/notify_test.go +++ b/db/tool/notify_test.go @@ -228,6 +228,7 @@ func TestGetWhitelist(t *testing.T) { assert.Equal(t, []*net.IPNet{ipNet}, ipNets) ipNets, err = GetWhitelist(nil) assert.Error(t, err) + assert.Nil(t, ipNets) } func TestCheckWhitelist(t *testing.T) { @@ -267,7 +268,10 @@ func TestRegisterChangePass(t *testing.T) { defer wrapper.TaosClose(conn) err = exec(conn, "create user test_notify pass 'notify'") assert.NoError(t, err) - defer exec(conn, "drop user test_notify") + defer func() { + // ignore error + _ = exec(conn, "drop user test_notify") + }() conn2, err := wrapper.TaosConnect("", "test_notify", "notify", "", 0) assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) @@ -299,7 +303,10 @@ func TestRegisterDropUser(t *testing.T) { defer wrapper.TaosClose(conn) err = exec(conn, "create user test_drop_user pass 'notify'") assert.NoError(t, err) - defer exec(conn, "drop user test_drop_user") + defer func() { + // ignore error + _ = exec(conn, "drop user test_drop_user") + }() conn2, err := wrapper.TaosConnect("", "test_drop_user", "notify", "", 0) assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) diff --git a/httperror/errors.go b/httperror/errors.go index 09ae13fc..736a9b5d 100644 --- a/httperror/errors.go +++ b/httperror/errors.go @@ -1,4 +1,5 @@ package httperror +// Code generated from TDengine. DO NOT EDIT. const ( SUCCESS = 0x0 @@ -127,12 +128,12 @@ const ( TSDB_CODE_PAR_INVALID_FILL_TIME_RANGE = 0x263B ) -// 502 - +// RPC_NETWORK_UNAVAIL return 502 status code const ( RPC_NETWORK_UNAVAIL = 0x000B ) +// ErrorMsgMap is the map of error code and error message. var ErrorMsgMap = map[int]string{ TSDB_CODE_RPC_AUTH_FAILURE: "Authentication failure", HTTP_SERVER_OFFLINE: "http server is not onlin", diff --git a/log/logger.go b/log/logger.go index 0f169dd4..8b5f16e0 100644 --- a/log/logger.go +++ b/log/logger.go @@ -60,14 +60,18 @@ func NewFileHook(formatter logrus.Formatter, writer io.WriteCloser) *FileHook { //can be optimized by tryLock fh.Lock() if fh.buf.Len() > 0 { - fh.flush() + // flush log ignore error, because it have been printed to stderr + _ = fh.flush() } fh.Unlock() case <-exit: fh.Lock() - fh.flush() + _ = fh.flush() fh.Unlock() - writer.Close() + err := writer.Close() + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "close log file error:", err) + } ticker.Stop() close(finish) return @@ -106,7 +110,7 @@ func (f *FileHook) flush() error { _, err := f.writer.Write(f.buf.Bytes()) f.buf.Reset() if err != nil { - fmt.Fprintln(os.Stderr, "write log error:", err) + _, _ = fmt.Fprintln(os.Stderr, "write log error:", err) } return err } @@ -134,9 +138,9 @@ func ConfigLog() { if err != nil { panic(err) } - fmt.Fprintln(writer, "==================================================") - fmt.Fprintln(writer, " new log file") - fmt.Fprintln(writer, "==================================================") + _, _ = fmt.Fprintln(writer, "==================================================") + _, _ = fmt.Fprintln(writer, " new log file") + _, _ = fmt.Fprintln(writer, "==================================================") hook := NewFileHook(globalLogFormatter, writer) logger.AddHook(hook) if config.Conf.Log.EnableRecordHttpSql { diff --git a/log/logger_test.go b/log/logger_test.go index 086521d0..41c49350 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -21,7 +21,7 @@ func TestMain(m *testing.M) { // @author: xftan // @date: 2021/12/14 15:07 // @description: test config log -func TestConfigLog(t *testing.T) { +func TestConfigLog(_ *testing.T) { config.Conf.Log.EnableRecordHttpSql = true ConfigLog() logger := GetLogger("TST") diff --git a/log/web_test.go b/log/web_test.go index c9bea983..103fd89c 100644 --- a/log/web_test.go +++ b/log/web_test.go @@ -21,7 +21,7 @@ func TestGinLog(t *testing.T) { router.POST("/rest/sql", func(c *gin.Context) { c.Status(200) }) - router.POST("/panic", func(c *gin.Context) { + router.POST("/panic", func(_ *gin.Context) { panic("test") }) w := httptest.NewRecorder() diff --git a/monitor/keeper.go b/monitor/keeper.go index f71ea80b..82ed708a 100644 --- a/monitor/keeper.go +++ b/monitor/keeper.go @@ -268,7 +268,7 @@ func StartUpload() { go func() { nextUploadTime := getNextUploadTime() logger.Debugf("start upload keeper when %s", nextUploadTime.Format("2006-01-02 15:04:05.000000000")) - startTimer := time.NewTimer(nextUploadTime.Sub(time.Now())) + startTimer := time.NewTimer(time.Until(nextUploadTime)) <-startTimer.C startTimer.Stop() go func() { @@ -352,7 +352,7 @@ func doRequest(client *http.Client, data []byte, reqID int64) error { if err != nil { return err } - resp.Body.Close() + _ = resp.Body.Close() logger.Debugf("upload_id:0x%x, upload to keeper success", reqID) if resp.StatusCode != http.StatusOK { logger.Errorf("upload_id:0x%x, upload keeper error, code: %d", reqID, resp.StatusCode) diff --git a/monitor/keeper_test.go b/monitor/keeper_test.go index 976e26a3..e8b82695 100644 --- a/monitor/keeper_test.go +++ b/monitor/keeper_test.go @@ -2,7 +2,7 @@ package monitor_test import ( "context" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -224,7 +224,11 @@ func TestUpload(t *testing.T) { done := make(chan struct{}) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost && r.URL.Path == "/upload" { - b, _ := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + return + } t.Log(string(b)) times += 1 if times == 1 { diff --git a/plugin/collectd/plugin_test.go b/plugin/collectd/plugin_test.go index 41e5f1cd..ec64395d 100644 --- a/plugin/collectd/plugin_test.go +++ b/plugin/collectd/plugin_test.go @@ -30,11 +30,15 @@ func TestCollectd(t *testing.T) { } afC, err := af.NewConnector(conn) assert.NoError(t, err) - defer afC.Close() + defer func() { + err = afC.Close() + assert.NoError(t, err) + }() _, err = afC.Exec("drop database if exists collectd") assert.NoError(t, err) _, err = afC.Exec("create database if not exists collectd") assert.NoError(t, err) + //nolint:staticcheck rand.Seed(time.Now().UnixNano()) p := &Plugin{} config.Init() @@ -75,7 +79,10 @@ func TestCollectd(t *testing.T) { t.Error(err) return } - defer c.Close() + defer func() { + err = c.Close() + assert.NoError(t, err) + }() _, err = c.Write(bytes) if err != nil { t.Error(err) @@ -97,7 +104,10 @@ func TestCollectd(t *testing.T) { t.Error(err) return } - defer r.Close() + defer func() { + err = r.Close() + assert.NoError(t, err) + }() values := make([]driver.Value, 1) err = r.Next(values) assert.NoError(t, err) @@ -111,7 +121,10 @@ func TestCollectd(t *testing.T) { t.Error(err) return } - defer r.Close() + defer func() { + err = r.Close() + assert.NoError(t, err) + }() values = make([]driver.Value, 1) err = r.Next(values) assert.NoError(t, err) diff --git a/plugin/influxdb/plugin_test.go b/plugin/influxdb/plugin_test.go index 91f4b7ee..6b0ce01a 100644 --- a/plugin/influxdb/plugin_test.go +++ b/plugin/influxdb/plugin_test.go @@ -24,6 +24,7 @@ import ( // @date: 2021/12/14 15:07 // @description: test influxdb plugin func TestInfluxdb(t *testing.T) { + //nolint:staticcheck rand.Seed(time.Now().UnixNano()) viper.Set("smlAutoCreateDB", true) defer viper.Set("smlAutoCreateDB", false) @@ -39,7 +40,10 @@ func TestInfluxdb(t *testing.T) { } afC, err := af.NewConnector(conn) assert.NoError(t, err) - defer afC.Close() + defer func() { + err = afC.Close() + assert.NoError(t, err) + }() _, err = afC.Exec("create database if not exists test_plugin_influxdb") assert.NoError(t, err) defer func() { @@ -56,7 +60,10 @@ func TestInfluxdb(t *testing.T) { err = p.Start() assert.NoError(t, err) number := rand.Int31() - defer p.Stop() + defer func() { + err = p.Stop() + assert.NoError(t, err) + }() w := httptest.NewRecorder() reader := strings.NewReader(fmt.Sprintf("measurement,host=host1 field1=%di,field2=2.0,fieldKey=\"Launch 🚀\" %d", number, time.Now().UnixNano())) req, _ := http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb", reader) @@ -81,7 +88,10 @@ func TestInfluxdb(t *testing.T) { t.Error(err) return } - defer r.Close() + defer func() { + err = r.Close() + assert.NoError(t, err) + }() fieldCount := len(r.Columns()) values := make([]driver.Value, fieldCount) err = r.Next(values) @@ -112,7 +122,10 @@ func TestInfluxdb(t *testing.T) { t.Error(err) return } - defer r.Close() + defer func() { + err = r.Close() + assert.NoError(t, err) + }() values = make([]driver.Value, 1) err = r.Next(values) assert.NoError(t, err) diff --git a/plugin/interface_test.go b/plugin/interface_test.go index 578b42e9..1e31c82e 100644 --- a/plugin/interface_test.go +++ b/plugin/interface_test.go @@ -9,7 +9,7 @@ import ( type fakePlugin struct { } -func (f *fakePlugin) Init(r gin.IRouter) error { +func (f *fakePlugin) Init(_ gin.IRouter) error { return nil } @@ -32,7 +32,7 @@ func (f *fakePlugin) Version() string { // @author: xftan // @date: 2021/12/14 15:09 // @description: test plugin register -func TestRegister(t *testing.T) { +func TestRegister(_ *testing.T) { Register(&fakePlugin{}) r := gin.Default() Init(r) diff --git a/plugin/nodeexporter/plugin.go b/plugin/nodeexporter/plugin.go index 11b4f353..61e39cf0 100644 --- a/plugin/nodeexporter/plugin.go +++ b/plugin/nodeexporter/plugin.go @@ -200,7 +200,12 @@ func (p *NodeExporter) Gather() { logger.WithError(err).Errorln("commonpool.GetConnection error") return } - defer conn.Put() + defer func() { + err = conn.Put() + if err != nil { + logger.WithError(err).Errorln("conn.Put error") + } + }() for _, req := range p.request { err := p.requestSingle(conn.TaosConnection, req) if err != nil { @@ -214,7 +219,9 @@ func (p *NodeExporter) requestSingle(conn unsafe.Pointer, req *Req) error { if err != nil { return err } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() if resp.StatusCode != http.StatusOK { return fmt.Errorf("%s returned HTTP status %s", req.req.URL, resp.Status) diff --git a/plugin/nodeexporter/plugin_test.go b/plugin/nodeexporter/plugin_test.go index 6d0ee9c9..71bb0a60 100644 --- a/plugin/nodeexporter/plugin_test.go +++ b/plugin/nodeexporter/plugin_test.go @@ -39,7 +39,10 @@ func TestNodeExporter_Gather(t *testing.T) { config.Init() db.PrepareConnection() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(s)) + _, err := w.Write([]byte(s)) + if err != nil { + return + } })) defer ts.Close() api := ts.URL @@ -49,7 +52,12 @@ func TestNodeExporter_Gather(t *testing.T) { viper.Set("node_exporter.ttl", 1000) conn, err := af.Open("", "", "", "", 0) assert.NoError(t, err) - defer conn.Close() + defer func() { + err = conn.Close() + if err != nil { + t.Error(err) + } + }() _, err = conn.Exec("create database if not exists node_exporter precision 'ns'") assert.NoError(t, err) err = conn.SelectDB("node_exporter") @@ -62,7 +70,12 @@ func TestNodeExporter_Gather(t *testing.T) { time.Sleep(time.Second * 2) rows, err := conn.Query("select last(`value`) as `value` from node_exporter.test_metric;") assert.NoError(t, err) - defer rows.Close() + defer func() { + err = rows.Close() + if err != nil { + t.Error(err) + } + }() assert.Equal(t, 1, len(rows.Columns())) d := make([]driver.Value, 1) err = rows.Next(d) @@ -77,7 +90,12 @@ func TestNodeExporter_Gather(t *testing.T) { t.Error(err) return } - defer rows.Close() + defer func() { + err = rows.Close() + if err != nil { + t.Error(err) + } + }() values := make([]driver.Value, 1) err = rows.Next(values) assert.NoError(t, err) diff --git a/plugin/opentsdb/plugin.go b/plugin/opentsdb/plugin.go index 4b42c869..0ceb0477 100644 --- a/plugin/opentsdb/plugin.go +++ b/plugin/opentsdb/plugin.go @@ -230,11 +230,10 @@ func (p *Plugin) insertTelnet(c *gin.Context) { if err != nil { if err == io.EOF { break - } else { - logger.Errorf("read line error, err:%s", err) - p.errorResponse(c, http.StatusBadRequest, err) - return } + logger.Errorf("read line error, err:%s", err) + p.errorResponse(c, http.StatusBadRequest, err) + return } tmp.Write(l) if !hasNext { diff --git a/plugin/opentsdb/plugin_test.go b/plugin/opentsdb/plugin_test.go index 404c3716..10a4d4c5 100644 --- a/plugin/opentsdb/plugin_test.go +++ b/plugin/opentsdb/plugin_test.go @@ -24,6 +24,7 @@ import ( // @date: 2021/12/14 15:08 // @description: test opentsdb test func TestOpentsdb(t *testing.T) { + //nolint:staticcheck rand.Seed(time.Now().UnixNano()) viper.Set("smlAutoCreateDB", true) defer viper.Set("smlAutoCreateDB", false) @@ -40,7 +41,10 @@ func TestOpentsdb(t *testing.T) { } afC, err := af.NewConnector(conn) assert.NoError(t, err) - defer afC.Close() + defer func() { + err = afC.Close() + assert.NoError(t, err) + }() _, err = afC.Exec("create database if not exists test_plugin_opentsdb_http_json") assert.NoError(t, err) err = p.Init(router) @@ -48,7 +52,10 @@ func TestOpentsdb(t *testing.T) { err = p.Start() assert.NoError(t, err) number := rand.Int31() - defer p.Stop() + defer func() { + err = p.Stop() + assert.NoError(t, err) + }() w := httptest.NewRecorder() reader := strings.NewReader(fmt.Sprintf("put metric %d %d host=web01 interface=eth0 ", time.Now().Unix(), number)) req, _ := http.NewRequest("POST", "/put/telnet/test_plugin_opentsdb_http_telnet?ttl=1000", reader) @@ -96,7 +103,10 @@ func TestOpentsdb(t *testing.T) { t.Error(err) return } - defer r.Close() + defer func() { + err = r.Close() + assert.NoError(t, err) + }() values := make([]driver.Value, 1) err = r.Next(values) assert.NoError(t, err) @@ -109,7 +119,10 @@ func TestOpentsdb(t *testing.T) { t.Error(err) return } - defer r2.Close() + defer func() { + err = r2.Close() + assert.NoError(t, err) + }() values = make([]driver.Value, 1) err = r2.Next(values) assert.NoError(t, err) @@ -123,7 +136,10 @@ func TestOpentsdb(t *testing.T) { t.Error(err) return } - defer rows.Close() + defer func() { + err = rows.Close() + assert.NoError(t, err) + }() values = make([]driver.Value, 1) err = rows.Next(values) assert.NoError(t, err) @@ -137,7 +153,10 @@ func TestOpentsdb(t *testing.T) { t.Error(err) return } - defer rows.Close() + defer func() { + err = rows.Close() + assert.NoError(t, err) + }() values = make([]driver.Value, 1) err = rows.Next(values) assert.NoError(t, err) diff --git a/plugin/opentsdbtelnet/plugin.go b/plugin/opentsdbtelnet/plugin.go index 72de7074..f5c12f90 100644 --- a/plugin/opentsdbtelnet/plugin.go +++ b/plugin/opentsdbtelnet/plugin.go @@ -19,6 +19,7 @@ import ( "github.com/taosdata/taosadapter/v3/plugin" "github.com/taosdata/taosadapter/v3/schemaless/inserter" "github.com/taosdata/taosadapter/v3/tools/generator" + "github.com/taosdata/taosadapter/v3/tools/joinerror" ) var logger = log.GetLogger("PLG").WithField("mod", "telnet") @@ -138,7 +139,7 @@ type Connection struct { func (c *Connection) handle() { defer func() { c.l.wg.Done() - c.conn.Close() + _ = c.conn.Close() c.l.accept <- true c.l.forget(c.id) }() @@ -214,11 +215,15 @@ func (c *Connection) handle() { } s = s[:len(s)-1] if s == versionCommand { - c.conn.Write([]byte{'1'}) + _, err = c.conn.Write([]byte{'1'}) + if err != nil { + logger.WithError(err).Error("conn write") + c.close() + return + } continue - } else { - dataChan <- s } + dataChan <- s } } } @@ -233,7 +238,6 @@ func (c *Connection) close() { func (l *TCPListener) stop() error { close(l.done) - l.listener.Close() var tcpConnList []*Connection l.cleanup.Lock() for _, conn := range l.connList { @@ -244,7 +248,7 @@ func (l *TCPListener) stop() error { conn.close() } l.wg.Wait() - return nil + return l.listener.Close() } func (p *Plugin) Init(_ gin.IRouter) error { @@ -284,10 +288,17 @@ func (p *Plugin) Stop() error { if p.done != nil { close(p.done) } + var errs []error for _, listener := range p.TCPListeners { - listener.stop() + err := listener.stop() + if err != nil { + errs = append(errs, err) + } } p.wg.Wait() + if len(errs) > 0 { + return joinerror.Join(errs...) + } return nil } diff --git a/plugin/opentsdbtelnet/plugin_test.go b/plugin/opentsdbtelnet/plugin_test.go index 86a8df24..6a8ffd10 100644 --- a/plugin/opentsdbtelnet/plugin_test.go +++ b/plugin/opentsdbtelnet/plugin_test.go @@ -22,6 +22,7 @@ import ( // @date: 2021/12/14 15:08 // @description: test opentsdb_telnet plugin func TestPlugin(t *testing.T) { + //nolint:staticcheck rand.Seed(time.Now().UnixNano()) p := &opentsdbtelnet.Plugin{} config.Init() @@ -36,7 +37,10 @@ func TestPlugin(t *testing.T) { } afC, err := af.NewConnector(conn) assert.NoError(t, err) - defer afC.Close() + defer func() { + err = afC.Close() + assert.NoError(t, err) + }() _, err = afC.Exec("create database if not exists opentsdb_telnet") assert.NoError(t, err) err = p.Init(nil) @@ -50,7 +54,10 @@ func TestPlugin(t *testing.T) { number := rand.Int31() c, err := net.Dial("tcp", "127.0.0.1:6046") assert.NoError(t, err) - defer c.Close() + defer func() { + err = c.Close() + assert.NoError(t, err) + }() _, err = c.Write([]byte(fmt.Sprintf("put sys.if.bytes.out 1479496100 %d host=web01 interface=eth0\r\n", number))) assert.NoError(t, err) time.Sleep(time.Second) @@ -70,7 +77,10 @@ func TestPlugin(t *testing.T) { t.Error(err) return } - defer r.Close() + defer func() { + err = r.Close() + assert.NoError(t, err) + }() values := make([]driver.Value, 1) err = r.Next(values) assert.NoError(t, err) @@ -84,7 +94,10 @@ func TestPlugin(t *testing.T) { t.Error(err) return } - defer rows.Close() + defer func() { + err = rows.Close() + assert.NoError(t, err) + }() values = make([]driver.Value, 1) err = rows.Next(values) assert.NoError(t, err) diff --git a/plugin/prometheus/plugin.go b/plugin/prometheus/plugin.go index 42c8dead..5b462a1f 100644 --- a/plugin/prometheus/plugin.go +++ b/plugin/prometheus/plugin.go @@ -36,7 +36,7 @@ func (p *Plugin) Init(r gin.IRouter) error { return nil } r.Use(plugin.Auth(func(c *gin.Context, code int, err error) { - c.AbortWithError(code, err) + _ = c.AbortWithError(code, err) })) r.POST("remote_read/:db", func(c *gin.Context) { if monitor.QueryPaused() { diff --git a/plugin/prometheus/plugin_test.go b/plugin/prometheus/plugin_test.go index a6e5b954..7b77e1ff 100644 --- a/plugin/prometheus/plugin_test.go +++ b/plugin/prometheus/plugin_test.go @@ -20,6 +20,7 @@ import ( ) func TestMain(m *testing.M) { + //nolint:staticcheck rand.Seed(time.Now().UnixNano()) config.Init() viper.Set("prometheus.enable", true) @@ -47,7 +48,9 @@ func TestPrometheus(t *testing.T) { err = p.Start() assert.NoError(t, err) number := rand.Float64() - defer p.Stop() + defer func() { + _ = p.Stop() + }() w := httptest.NewRecorder() now := time.Now().UnixNano() / 1e6 var wReq = prompb.WriteRequest{ @@ -133,7 +136,9 @@ func TestPrometheusEscapeString(t *testing.T) { err = p.Start() assert.NoError(t, err) number := rand.Float64() - defer p.Stop() + defer func() { + _ = p.Stop() + }() w := httptest.NewRecorder() now := time.Now().UnixNano() / 1e6 var wReq = prompb.WriteRequest{ @@ -216,7 +221,9 @@ func TestPrometheusWithTTL(t *testing.T) { err = p.Start() assert.NoError(t, err) number := rand.Float64() - defer p.Stop() + defer func() { + _ = p.Stop() + }() w := httptest.NewRecorder() now := time.Now().UnixNano() / 1e6 var wReq = prompb.WriteRequest{ @@ -302,7 +309,9 @@ func TestPrometheusEscape(t *testing.T) { err = p.Start() assert.NoError(t, err) number := rand.Float64() - defer p.Stop() + defer func() { + _ = p.Stop() + }() w := httptest.NewRecorder() now := time.Now().UnixNano() / 1e6 var wReq = prompb.WriteRequest{ @@ -387,7 +396,9 @@ func TestPrometheusWithLimit(t *testing.T) { err = p.Start() assert.NoError(t, err) number := rand.Float64() - defer p.Stop() + defer func() { + _ = p.Stop() + }() w := httptest.NewRecorder() now := time.Now().UnixNano() / 1e6 var wReq = prompb.WriteRequest{ diff --git a/plugin/prometheus/process.go b/plugin/prometheus/process.go index edb896a8..f1880fb7 100644 --- a/plugin/prometheus/process.go +++ b/plugin/prometheus/process.go @@ -72,14 +72,12 @@ func processWrite(taosConn unsafe.Pointer, req *prompbWrite.WriteRequest, db str } logger.Debug("retry processWrite cost", time.Since(start)) return nil - } else { - logger.WithError(err).Error(bp.String()) - return err } - } else { logger.WithError(err).Error(bp.String()) return err } + logger.WithError(err).Error(bp.String()) + return err } return nil } @@ -120,7 +118,7 @@ func generateWriteSql(timeseries []prompbWrite.TimeSeries, sql *bytes.Buffer, tt jsonBuilder.WriteString(bytesutil.ToUnsafeString(timeseries[i].Labels[labelIndex].Value)) } jsonBuilder.WriteObjectEnd() - jsonBuilder.Flush() + _ = jsonBuilder.Flush() sql.Write(escapeBytes(bb.B)) sql.WriteString("') ") if ttl > 0 { @@ -140,7 +138,7 @@ func generateWriteSql(timeseries []prompbWrite.TimeSeries, sql *bytes.Buffer, tt } else if math.IsInf(sample.Value, 0) { sql.WriteString("null") } else { - fmt.Fprintf(sql, "%v", sample.Value) + _, _ = fmt.Fprintf(sql, "%v", sample.Value) } sql.WriteString(") ") } @@ -268,7 +266,7 @@ func generateReadSql(query *prompb.Query) (string, error) { } if config.Conf.RestfulRowLimit > -1 { sql.WriteString(" limit ") - fmt.Fprintf(sql, "%d", config.Conf.RestfulRowLimit) + _, _ = fmt.Fprintf(sql, "%d", config.Conf.RestfulRowLimit) } return sql.String(), nil diff --git a/plugin/statsd/plugin_test.go b/plugin/statsd/plugin_test.go index 15cb7093..06a03807 100644 --- a/plugin/statsd/plugin_test.go +++ b/plugin/statsd/plugin_test.go @@ -21,6 +21,7 @@ import ( // @date: 2021/12/14 15:08 // @description: test statsd plugin func TestStatsd(t *testing.T) { + //nolint:staticcheck rand.Seed(time.Now().UnixNano()) p := &Plugin{} config.Init() @@ -35,7 +36,10 @@ func TestStatsd(t *testing.T) { } afC, err := af.NewConnector(conn) assert.NoError(t, err) - defer afC.Close() + defer func() { + err = afC.Close() + assert.NoError(t, err) + }() _, err = afC.Exec("create database if not exists statsd") assert.NoError(t, err) err = p.Init(nil) @@ -68,7 +72,10 @@ func TestStatsd(t *testing.T) { t.Error(err) return } - defer r.Close() + defer func() { + err = r.Close() + assert.NoError(t, err) + }() values := make([]driver.Value, 1) err = r.Next(values) assert.NoError(t, err) @@ -82,7 +89,10 @@ func TestStatsd(t *testing.T) { t.Error(err) return } - defer rows.Close() + defer func() { + err = rows.Close() + assert.NoError(t, err) + }() values = make([]driver.Value, 1) err = rows.Next(values) assert.NoError(t, err) diff --git a/plugin/statsd/statsd.go b/plugin/statsd/statsd.go index 0766ee34..99ab2056 100644 --- a/plugin/statsd/statsd.go +++ b/plugin/statsd/statsd.go @@ -3,7 +3,6 @@ package statsd import ( "bufio" "bytes" - _ "embed" "fmt" "math/rand" "net" @@ -28,10 +27,7 @@ const ( defaultFieldName = "value" - defaultProtocol = "udp4" - - defaultSeparator = "_" - defaultAllowPendingMessage = 10000 + defaultSeparator = "_" parserGoRoutines = 5 ) @@ -399,7 +395,12 @@ func (s *Statsd) tcpListen(listener *net.TCPListener) error { _ = conn.Close() continue } - taosConn.Put() + err = taosConn.Put() + if err != nil { + s.Log.Errorf("PutConnection error: %v", err) + _ = conn.Close() + continue + } authed = true valid = true } @@ -461,10 +462,11 @@ func (s *Statsd) udpListen(conn *net.UDPConn) error { s.Log.Errorf("Error reading: %s", err.Error()) continue } - return nil + return err } if addr == nil { s.Log.Errorf("RemoteAddr is nil") + continue } authed, valid, poolExists := commonpool.VerifyClientIP(s.User, s.Password, addr.IP) if !poolExists { @@ -474,7 +476,12 @@ func (s *Statsd) udpListen(conn *net.UDPConn) error { _ = conn.Close() continue } - taosConn.Put() + err = taosConn.Put() + if err != nil { + s.Log.Errorf("PutConnection error: %v", err) + _ = conn.Close() + continue + } authed = true valid = true } @@ -927,7 +934,11 @@ func (s *Statsd) handler(conn *net.TCPConn, id string) { s.Log.Errorf("get connection error, err:%s", err) return } - taosConn.Put() + err = taosConn.Put() + if err != nil { + s.Log.Errorf("PutConnection error: %v", err) + return + } select { case s.in <- input{Buffer: b, Time: time.Now(), Addr: remoteIP}: default: @@ -1044,10 +1055,11 @@ func (s *Statsd) expireCachedMetrics() { const alphanum string = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" func RandomString(n int) string { - var bytes = make([]byte, n) - rand.Read(bytes) - for i, b := range bytes { - bytes[i] = alphanum[b%byte(len(alphanum))] + var bs = make([]byte, n) + //nolint:staticcheck + rand.Read(bs) + for i, b := range bs { + bs[i] = alphanum[b%byte(len(alphanum))] } - return string(bytes) + return string(bs) } diff --git a/plugin/statsd/statsd_test.go b/plugin/statsd/statsd_test.go index 2bf18aa3..9e308023 100644 --- a/plugin/statsd/statsd_test.go +++ b/plugin/statsd/statsd_test.go @@ -1616,7 +1616,7 @@ func TestTCP(t *testing.T) { if len(acc.Metrics) > 0 { break } - if time.Now().Sub(start) > time.Second*5 { + if time.Since(start) > time.Second*5 { t.Fatal("timeout waiting for metrics") } } @@ -1665,7 +1665,7 @@ func TestUdp(t *testing.T) { if len(acc.Metrics) > 0 { break } - if time.Now().Sub(start) > time.Second*5 { + if time.Since(start) > time.Second*5 { t.Fatal("timeout waiting for metrics") } } diff --git a/schemaless/capi/influxdb_test.go b/schemaless/capi/influxdb_test.go index 02bb4d2e..73635d0b 100644 --- a/schemaless/capi/influxdb_test.go +++ b/schemaless/capi/influxdb_test.go @@ -8,7 +8,6 @@ import ( "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/schemaless/capi" - "github.com/taosdata/taosadapter/v3/schemaless/proto" ) // @author: xftan @@ -47,7 +46,6 @@ func TestInsertInfluxdb(t *testing.T) { tests := []struct { name string args args - want *proto.InfluxResult wantErr bool }{ { diff --git a/schemaless/proto/influx.go b/schemaless/proto/influx.go deleted file mode 100644 index 88ed39e6..00000000 --- a/schemaless/proto/influx.go +++ /dev/null @@ -1,7 +0,0 @@ -package proto - -type InfluxResult struct { - SuccessCount int - FailCount int - ErrorList []string -} diff --git a/system/controller.go b/system/controller.go index 1548e7c8..15ccd633 100644 --- a/system/controller.go +++ b/system/controller.go @@ -2,13 +2,14 @@ package system import ( // http - _ "github.com/taosdata/taosadapter/v3/controller/metrics" - _ "github.com/taosdata/taosadapter/v3/controller/ping" - _ "github.com/taosdata/taosadapter/v3/controller/rest" + _ "github.com/taosdata/taosadapter/v3/controller/metrics" // metrics + _ "github.com/taosdata/taosadapter/v3/controller/ping" // http ping + _ "github.com/taosdata/taosadapter/v3/controller/rest" // http rest api + // websocket - _ "github.com/taosdata/taosadapter/v3/controller/ws/query" - _ "github.com/taosdata/taosadapter/v3/controller/ws/schemaless" - _ "github.com/taosdata/taosadapter/v3/controller/ws/stmt" - _ "github.com/taosdata/taosadapter/v3/controller/ws/tmq" - _ "github.com/taosdata/taosadapter/v3/controller/ws/ws" + _ "github.com/taosdata/taosadapter/v3/controller/ws/query" // old query + _ "github.com/taosdata/taosadapter/v3/controller/ws/schemaless" // old schemaless + _ "github.com/taosdata/taosadapter/v3/controller/ws/stmt" // old stmt + _ "github.com/taosdata/taosadapter/v3/controller/ws/tmq" // tmq + _ "github.com/taosdata/taosadapter/v3/controller/ws/ws" // ws(query, stmt, schemaless) ) diff --git a/system/main.go b/system/main.go index 640a36bc..38d5b1f1 100644 --- a/system/main.go +++ b/system/main.go @@ -7,7 +7,7 @@ import ( "sort" "strconv" "time" - _ "time/tzdata" + _ "time/tzdata" // load time zone data "github.com/gin-contrib/cors" "github.com/gin-contrib/gzip" diff --git a/system/main_test.go b/system/main_test.go index d1941d40..095cb6f8 100644 --- a/system/main_test.go +++ b/system/main_test.go @@ -32,7 +32,7 @@ func TestStart(t *testing.T) { time.Sleep(time.Second) continue } - resp.Body.Close() + _ = resp.Body.Close() success = true break } diff --git a/system/plugin.go b/system/plugin.go index 05e59201..b15c5749 100644 --- a/system/plugin.go +++ b/system/plugin.go @@ -1,11 +1,11 @@ package system import ( - _ "github.com/taosdata/taosadapter/v3/plugin/collectd" - _ "github.com/taosdata/taosadapter/v3/plugin/influxdb" - _ "github.com/taosdata/taosadapter/v3/plugin/nodeexporter" - _ "github.com/taosdata/taosadapter/v3/plugin/opentsdb" - _ "github.com/taosdata/taosadapter/v3/plugin/opentsdbtelnet" - _ "github.com/taosdata/taosadapter/v3/plugin/prometheus" - _ "github.com/taosdata/taosadapter/v3/plugin/statsd" + _ "github.com/taosdata/taosadapter/v3/plugin/collectd" // import collectd plugin + _ "github.com/taosdata/taosadapter/v3/plugin/influxdb" // import influxdb plugin + _ "github.com/taosdata/taosadapter/v3/plugin/nodeexporter" // import nodeexporter plugin + _ "github.com/taosdata/taosadapter/v3/plugin/opentsdb" // import opentsdb plugin + _ "github.com/taosdata/taosadapter/v3/plugin/opentsdbtelnet" // import opentsdbtelnet plugin + _ "github.com/taosdata/taosadapter/v3/plugin/prometheus" // import prometheus plugin + _ "github.com/taosdata/taosadapter/v3/plugin/statsd" // import statsd plugin ) diff --git a/thread/locker_test.go b/thread/locker_test.go index 3d7c74b4..f7a1b692 100644 --- a/thread/locker_test.go +++ b/thread/locker_test.go @@ -23,9 +23,11 @@ func TestNewLocker(t *testing.T) { }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.name, func(_ *testing.T) { locker := NewLocker(tt.args.count) locker.Lock() + a := 1 + _ = a locker.Unlock() }) } diff --git a/tools/bytesutil/bytesutil.go b/tools/bytesutil/bytesutil.go index 8d9d1c56..aeaf0d43 100644 --- a/tools/bytesutil/bytesutil.go +++ b/tools/bytesutil/bytesutil.go @@ -72,21 +72,12 @@ func ToUnsafeString(b []byte) string { // // The returned byte slice is valid only until s is reachable and unmodified. func ToUnsafeBytes(s string) (b []byte) { + //nolint:staticcheck sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) + //nolint:staticcheck slh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) slh.Data = sh.Data slh.Len = sh.Len slh.Cap = sh.Len return b } - -// LimitStringLen limits the length of s to maxLen. -// -// If len(s) > maxLen, then the function concatenates s prefix with s suffix. -func LimitStringLen(s string, maxLen int) string { - if maxLen <= 4 || len(s) <= maxLen { - return s - } - n := maxLen/2 - 1 - return s[:n] + ".." + s[len(s)-n:] -} diff --git a/tools/connectpool/pool.go b/tools/connectpool/pool.go index 9b9a2dfb..78748647 100644 --- a/tools/connectpool/pool.go +++ b/tools/connectpool/pool.go @@ -20,7 +20,7 @@ type Config struct { MaxWait int WaitTimeout time.Duration Factory func() (unsafe.Pointer, error) - Close func(pointer unsafe.Pointer) error + Close func(pointer unsafe.Pointer) } type connReq struct { @@ -31,7 +31,7 @@ type ConnectPool struct { mu sync.RWMutex conns chan unsafe.Pointer factory func() (unsafe.Pointer, error) - close func(pointer unsafe.Pointer) error + close func(pointer unsafe.Pointer) maxActive int openingConns int maxWait int @@ -182,7 +182,7 @@ func (c *ConnectPool) Put(conn unsafe.Pointer) error { if c.released { c.mu.Unlock() - c.Close(conn) + _ = c.Close(conn) if c.openingConns == 0 { c.close = nil } @@ -198,16 +198,16 @@ func (c *ConnectPool) Put(conn unsafe.Pointer) error { } c.mu.Unlock() return nil - } else { - select { - case c.conns <- conn: - c.mu.Unlock() - return nil - default: - c.mu.Unlock() - return c.Close(conn) - } } + select { + case c.conns <- conn: + c.mu.Unlock() + return nil + default: + c.mu.Unlock() + return c.Close(conn) + } + } func (c *ConnectPool) Close(conn unsafe.Pointer) error { @@ -220,7 +220,8 @@ func (c *ConnectPool) Close(conn unsafe.Pointer) error { return nil } c.openingConns-- - return c.close(conn) + c.close(conn) + return nil } func (c *ConnectPool) Release() { diff --git a/tools/connectpool/pool_test.go b/tools/connectpool/pool_test.go index e93247b9..6e81c499 100644 --- a/tools/connectpool/pool_test.go +++ b/tools/connectpool/pool_test.go @@ -20,9 +20,8 @@ func TestConnectPool(t *testing.T) { factoryCalled++ return unsafe.Pointer(&struct{}{}), nil }, - Close: func(pointer unsafe.Pointer) error { + Close: func(_ unsafe.Pointer) { closeCalled++ - return nil }, }) @@ -41,9 +40,12 @@ func TestConnectPool(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, conn3) - pool.Put(conn1) - pool.Put(conn2) - pool.Put(conn3) + err = pool.Put(conn1) + assert.NoError(t, err) + err = pool.Put(conn2) + assert.NoError(t, err) + err = pool.Put(conn3) + assert.NoError(t, err) arr := make([]unsafe.Pointer, 5) for i := 0; i < 5; i++ { @@ -56,7 +58,6 @@ func TestConnectPool(t *testing.T) { err = pool.Put(arr[i]) assert.NoError(t, err) } - arr = nil assert.Equal(t, 5, factoryCalled) pool.Release() @@ -73,9 +74,8 @@ func TestTimeout(t *testing.T) { Factory: func() (unsafe.Pointer, error) { return nil, nil }, - Close: func(pointer unsafe.Pointer) error { + Close: func(_ unsafe.Pointer) { t.Log("close") - return nil }, }) assert.NoError(t, err) @@ -95,9 +95,8 @@ func TestGetAfterRelease(t *testing.T) { Factory: func() (unsafe.Pointer, error) { return unsafe.Pointer(&a), nil }, - Close: func(pointer unsafe.Pointer) error { + Close: func(_ unsafe.Pointer) { t.Log("close") - return nil }, }) assert.NoError(t, err) @@ -136,9 +135,8 @@ func TestMaxWait(t *testing.T) { Factory: func() (unsafe.Pointer, error) { return unsafe.Pointer(&a), nil }, - Close: func(pointer unsafe.Pointer) error { + Close: func(_ unsafe.Pointer) { t.Log("close") - return nil }, }) assert.NoError(t, err) @@ -150,7 +148,8 @@ func TestMaxWait(t *testing.T) { _, err = pool.Get() assert.Equal(t, ErrMaxWait, err) // put back connection - pool.Put(c) + err = pool.Put(c) + assert.NoError(t, err) }() // wait for connection put back _, err = pool.Get() @@ -216,9 +215,7 @@ func TestNewConnectPool(t *testing.T) { Factory: func() (unsafe.Pointer, error) { return nil, errors.New("connect error") }, - Close: func(pointer unsafe.Pointer) error { - return nil - }, + Close: func(pointer unsafe.Pointer) {}, InitialCap: 1, }, }, diff --git a/tools/ctools/block.go b/tools/ctools/block.go index 3432e756..0f73900e 100644 --- a/tools/ctools/block.go +++ b/tools/ctools/block.go @@ -102,11 +102,11 @@ func WriteRawJsonBinary(builder *jsonbuilder.Stream, pHeader, pStart unsafe.Poin clen := *((*uint16)(currentRow)) currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - builder.WriteByte('"') + builder.AddByte('"') for index := uint16(0); index < clen; index++ { builder.WriteStringByte(*((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index))))) } - builder.WriteByte('"') + builder.AddByte('"') } func WriteRawJsonVarBinary(builder *jsonbuilder.Stream, pHeader, pStart unsafe.Pointer, row int) { @@ -119,17 +119,17 @@ func WriteRawJsonVarBinary(builder *jsonbuilder.Stream, pHeader, pStart unsafe.P clen := *((*uint16)(currentRow)) currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - builder.WriteByte('"') + builder.AddByte('"') var b byte for index := uint16(0); index < clen; index++ { b = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) s := strconv.FormatInt(int64(b), 16) if len(s) == 1 { - builder.WriteByte('0') + builder.AddByte('0') } builder.WriteRaw(s) } - builder.WriteByte('"') + builder.AddByte('"') } func WriteRawJsonGeometry(builder *jsonbuilder.Stream, pHeader, pStart unsafe.Pointer, row int) { @@ -145,11 +145,11 @@ func WriteRawJsonNchar(builder *jsonbuilder.Stream, pHeader, pStart unsafe.Point currentRow := tools.AddPointer(pStart, uintptr(offset)) clen := *((*uint16)(currentRow)) / 4 currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - builder.WriteByte('"') + builder.AddByte('"') for index := uint16(0); index < clen; index++ { builder.WriteRuneString(*((*rune)(unsafe.Pointer(uintptr(currentRow) + uintptr(index*4))))) } - builder.WriteByte('"') + builder.AddByte('"') } func WriteRawJsonJson(builder *jsonbuilder.Stream, pHeader, pStart unsafe.Pointer, row int) { @@ -163,7 +163,7 @@ func WriteRawJsonJson(builder *jsonbuilder.Stream, pHeader, pStart unsafe.Pointe currentRow = unsafe.Pointer(uintptr(currentRow) + 2) for index := uint16(0); index < clen; index++ { - builder.WriteByte(*((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index))))) + builder.AddByte(*((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index))))) } } diff --git a/tools/ctools/block_test.go b/tools/ctools/block_test.go index aa686dce..79e7e1ca 100644 --- a/tools/ctools/block_test.go +++ b/tools/ctools/block_test.go @@ -147,19 +147,16 @@ func TestJsonWriteRawBlock(t *testing.T) { block := unsafe.Pointer(&raw[0]) lengthOffset := parser.RawBlockGetColumnLengthOffset(fieldsCount) tmpPHeader := tools.AddPointer(block, parser.RawBlockGetColDataOffset(fieldsCount)) - tmpPStart := tmpPHeader for column := 0; column < fieldsCount; column++ { colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*parser.Int32Size))) if IsVarDataType(fieldTypes[column]) { pHeaderList[column] = tmpPHeader - tmpPStart = tools.AddPointer(tmpPHeader, uintptr(4*blockSize)) - pStartList[column] = tmpPStart + pStartList[column] = tools.AddPointer(tmpPHeader, uintptr(4*blockSize)) } else { pHeaderList[column] = tmpPHeader - tmpPStart = tools.AddPointer(tmpPHeader, nullBitMapOffset) - pStartList[column] = tmpPStart + pStartList[column] = tools.AddPointer(tmpPHeader, nullBitMapOffset) } - tmpPHeader = tools.AddPointer(tmpPStart, uintptr(colLength)) + tmpPHeader = tools.AddPointer(pStartList[column], uintptr(colLength)) } timeBuffer := make([]byte, 0, 30) builder.WriteObjectStart() diff --git a/tools/generator/reqid_test.go b/tools/generator/reqid_test.go index b74130df..9e2cc98c 100644 --- a/tools/generator/reqid_test.go +++ b/tools/generator/reqid_test.go @@ -33,11 +33,11 @@ func TestGetReqID(t *testing.T) { // Call GetReqID multiple times for i := 0; i < 100; i++ { // Check if the reqIncrement resets after it exceeds 0x00ffffffffffffff - id := GetReqID() + GetReqID() // Check if the ID is unique atomic.StoreInt64(&reqIncrement, 0x00ffffffffffffff+1) - id = GetReqID() + id := GetReqID() if _, exists := ids[id]; exists { if id != int64(config.Conf.InstanceID)<<56|1 { t.Errorf("GetReqID() returned a duplicate ID: %v", id) diff --git a/tools/joinerror/join.go b/tools/joinerror/join.go new file mode 100644 index 00000000..5accfb8f --- /dev/null +++ b/tools/joinerror/join.go @@ -0,0 +1,41 @@ +package joinerror + +func Join(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + e := &JoinError{ + errs: make([]error, 0, n), + } + for _, err := range errs { + if err != nil { + e.errs = append(e.errs, err) + } + } + return e +} + +type JoinError struct { + errs []error +} + +func (e *JoinError) Error() string { + var b []byte + for i, err := range e.errs { + if i > 0 { + b = append(b, '\n') + } + b = append(b, err.Error()...) + } + return string(b) +} + +func (e *JoinError) Unwrap() []error { + return e.errs +} diff --git a/tools/joinerror/join_test.go b/tools/joinerror/join_test.go new file mode 100644 index 00000000..d14a0ba5 --- /dev/null +++ b/tools/joinerror/join_test.go @@ -0,0 +1,68 @@ +package joinerror + +import ( + "errors" + "reflect" + "testing" +) + +func TestJoinReturnsNil(t *testing.T) { + if err := Join(); err != nil { + t.Errorf("Join() = %v, want nil", err) + } + if err := Join(nil); err != nil { + t.Errorf("Join(nil) = %v, want nil", err) + } + if err := Join(nil, nil); err != nil { + t.Errorf("Join(nil, nil) = %v, want nil", err) + } +} + +func TestJoin(t *testing.T) { + err1 := errors.New("err1") + err2 := errors.New("err2") + for _, test := range []struct { + errs []error + want []error + }{{ + errs: []error{err1}, + want: []error{err1}, + }, { + errs: []error{err1, err2}, + want: []error{err1, err2}, + }, { + errs: []error{err1, nil, err2}, + want: []error{err1, err2}, + }} { + got := Join(test.errs...).(interface{ Unwrap() []error }).Unwrap() + if !reflect.DeepEqual(got, test.want) { + t.Errorf("Join(%v) = %v; want %v", test.errs, got, test.want) + } + if len(got) != cap(got) { + t.Errorf("Join(%v) returns errors with len=%v, cap=%v; want len==cap", test.errs, len(got), cap(got)) + } + } +} + +func TestJoinErrorMethod(t *testing.T) { + err1 := errors.New("err1") + err2 := errors.New("err2") + for _, test := range []struct { + errs []error + want string + }{{ + errs: []error{err1}, + want: "err1", + }, { + errs: []error{err1, err2}, + want: "err1\nerr2", + }, { + errs: []error{err1, nil, err2}, + want: "err1\nerr2", + }} { + got := Join(test.errs...).Error() + if got != test.want { + t.Errorf("Join(%v).Error() = %q; want %q", test.errs, got, test.want) + } + } +} diff --git a/tools/jsonbuilder/builder_test.go b/tools/jsonbuilder/builder_test.go index c5ea6da9..1275db98 100644 --- a/tools/jsonbuilder/builder_test.go +++ b/tools/jsonbuilder/builder_test.go @@ -11,7 +11,8 @@ func TestBorrowStream(t *testing.T) { b := &strings.Builder{} s := BorrowStream(b) s.WriteString(`"a"`) - s.Flush() + err := s.Flush() + assert.NoError(t, err) assert.Equal(t, `"\"a\""`, b.String()) ReturnStream(s) } diff --git a/tools/jsonbuilder/stream.go b/tools/jsonbuilder/stream.go index 169b2297..48f4bbba 100644 --- a/tools/jsonbuilder/stream.go +++ b/tools/jsonbuilder/stream.go @@ -80,11 +80,11 @@ func (stream *Stream) WritePure(p []byte) { stream.buf = append(stream.buf, p...) } -func (stream *Stream) WriteByte(c byte) { +// AddByte writes a single byte. +func (stream *Stream) AddByte(c byte) { stream.writeByte(c) } -// WriteByte writes a single byte. func (stream *Stream) writeByte(c byte) { stream.buf = append(stream.buf, c) } diff --git a/tools/jsonbuilder/stream_rune.go b/tools/jsonbuilder/stream_rune.go index 088253cd..f7a92680 100644 --- a/tools/jsonbuilder/stream_rune.go +++ b/tools/jsonbuilder/stream_rune.go @@ -3,9 +3,7 @@ package jsonbuilder // Numbers fundamental to the encoding. const ( RuneError = '\uFFFD' // the "error" Rune or "Unicode replacement character" - RuneSelf = 0x80 // characters below RuneSelf are represented as themselves in a single byte. MaxRune = '\U0010FFFF' // Maximum valid Unicode code point. - UTFMax = 4 // maximum number of bytes of a UTF-8 encoded Unicode character. ) // Code points in the surrogate range are not valid for UTF-8. @@ -15,64 +13,41 @@ const ( ) const ( - t1 = 0b00000000 tx = 0b10000000 t2 = 0b11000000 t3 = 0b11100000 t4 = 0b11110000 - t5 = 0b11111000 maskx = 0b00111111 - mask2 = 0b00011111 - mask3 = 0b00001111 - mask4 = 0b00000111 rune1Max = 1<<7 - 1 rune2Max = 1<<11 - 1 rune3Max = 1<<16 - 1 - - // The default lowest and highest continuation byte. - locb = 0b10000000 - hicb = 0b10111111 - - // These names of these constants are chosen to give nice alignment in the - // table below. The first nibble is an index into acceptRanges or F for - // special one-byte cases. The second nibble is the Rune length or the - // Status for the special one-byte case. - xx = 0xF1 // invalid: size 1 - as = 0xF0 // ASCII: size 1 - s1 = 0x02 // accept 0, size 2 - s2 = 0x13 // accept 1, size 3 - s3 = 0x03 // accept 0, size 3 - s4 = 0x23 // accept 2, size 3 - s5 = 0x34 // accept 3, size 4 - s6 = 0x04 // accept 0, size 4 - s7 = 0x44 // accept 4, size 4 ) func (stream *Stream) WriteRune(r rune) { if uint32(r) <= rune1Max { - stream.WriteByte(byte(r)) + stream.writeByte(byte(r)) return } switch i := uint32(r); { case i <= rune2Max: - stream.WriteByte(t2 | byte(r>>6)) - stream.WriteByte(tx | byte(r)&maskx) + stream.writeByte(t2 | byte(r>>6)) + stream.writeByte(tx | byte(r)&maskx) return case i > MaxRune, surrogateMin <= i && i <= surrogateMax: r = RuneError fallthrough case i <= rune3Max: - stream.WriteByte(t3 | byte(r>>12)) - stream.WriteByte(tx | byte(r>>6)&maskx) - stream.WriteByte(tx | byte(r)&maskx) + stream.writeByte(t3 | byte(r>>12)) + stream.writeByte(tx | byte(r>>6)&maskx) + stream.writeByte(tx | byte(r)&maskx) return default: - stream.WriteByte(t4 | byte(r>>18)) - stream.WriteByte(tx | byte(r>>12)&maskx) - stream.WriteByte(tx | byte(r>>6)&maskx) - stream.WriteByte(tx | byte(r)&maskx) + stream.writeByte(t4 | byte(r>>18)) + stream.writeByte(tx | byte(r>>12)&maskx) + stream.writeByte(tx | byte(r>>6)&maskx) + stream.writeByte(tx | byte(r)&maskx) return } } diff --git a/tools/jsonbuilder/stream_test.go b/tools/jsonbuilder/stream_test.go index 85e90301..488a68c1 100644 --- a/tools/jsonbuilder/stream_test.go +++ b/tools/jsonbuilder/stream_test.go @@ -27,10 +27,12 @@ func Test_writeByte_should_grow_buffer(t *testing.T) { func Test_writeBytes_should_grow_buffer(t *testing.T) { should := require.New(t) stream := NewStream(ConfigDefault, nil, 1) - stream.Write([]byte{'1', '2'}) + _, err := stream.Write([]byte{'1', '2'}) + assert.NoError(t, err) should.Equal("12", string(stream.Buffer())) should.Equal(2, len(stream.buf)) - stream.Write([]byte{'3', '4', '5', '6', '7'}) + _, err = stream.Write([]byte{'3', '4', '5', '6', '7'}) + assert.NoError(t, err) should.Equal("1234567", string(stream.Buffer())) should.Equal(7, len(stream.buf)) } @@ -68,7 +70,8 @@ func Test_flush_buffer_should_stop_grow_buffer(t *testing.T) { for i := 0; i < 10000000; i++ { stream.WriteInt(0) stream.WriteMore() - stream.Flush() + err := stream.Flush() + assert.NoError(t, err) } stream.WriteInt(0) stream.WriteArrayEnd() @@ -91,10 +94,12 @@ func TestStream_Common(t *testing.T) { assert.Equal(t, 0, stream.Buffered()) stream.SetBuffer(make([]byte, 0, 512)) stream.Reset(writer2) - stream.Write([]byte{1}) + _, err := stream.Write([]byte{1}) + assert.NoError(t, err) stream.WriteArrayStart() - stream.WriteByte(1) - stream.Flush() + stream.AddByte(1) + err = stream.Flush() + assert.NoError(t, err) stream.WriteNil() stream.WriteTrue() stream.WriteFalse() @@ -175,7 +180,8 @@ func TestStr(t *testing.T) { b := &strings.Builder{} stream := BorrowStream(b) stream.WriteString("a\nb") - stream.Flush() + err := stream.Flush() + assert.NoError(t, err) assert.Equal(t, "\"a\\nb\"", b.String()) } @@ -185,7 +191,8 @@ func TestStrByte(t *testing.T) { stream.WriteStringByte('a') stream.WriteStringByte('\n') stream.WriteStringByte('b') - stream.Flush() + err := stream.Flush() + assert.NoError(t, err) assert.Equal(t, "a\\nb", b.String()) } diff --git a/tools/jsontype/uint8.go b/tools/jsontype/uint8.go index 098283a3..d169cdfd 100644 --- a/tools/jsontype/uint8.go +++ b/tools/jsontype/uint8.go @@ -1,5 +1,6 @@ package jsontype +// JsonUint8 is a wrapper for []uint8 to implement json.Marshaler interface. type JsonUint8 []uint8 var digits []uint32 @@ -26,6 +27,7 @@ func writeFirstBuf(space []byte, v uint32) []byte { return space } +// MarshalJSON implements the json.Marshaler interface. func (m JsonUint8) MarshalJSON() ([]byte, error) { if m == nil { return []byte("null"), nil diff --git a/tools/layout/time.go b/tools/layout/time.go index 57ecebec..51d8d52f 100644 --- a/tools/layout/time.go +++ b/tools/layout/time.go @@ -1,7 +1,11 @@ +// Package layout provides time layout package layout const ( - LayoutMillSecond = "2006-01-02T15:04:05.000Z07:00" + // LayoutMillSecond is the time layout for millisecond + LayoutMillSecond = "2006-01-02T15:04:05.000Z07:00" + // LayoutMicroSecond is the time layout for microsecond LayoutMicroSecond = "2006-01-02T15:04:05.000000Z07:00" - LayoutNanoSecond = "2006-01-02T15:04:05.000000000Z07:00" + // LayoutNanoSecond is the time layout for nanosecond + LayoutNanoSecond = "2006-01-02T15:04:05.000000000Z07:00" ) diff --git a/tools/monitor/collect_test.go b/tools/monitor/collect_test.go index 60eec2b9..9eaca916 100644 --- a/tools/monitor/collect_test.go +++ b/tools/monitor/collect_test.go @@ -20,7 +20,7 @@ type MockProcess struct { errMem error } -func (m *MockProcess) Percent(interval time.Duration) (float64, error) { +func (m *MockProcess) Percent(_ time.Duration) (float64, error) { return m.cpuPercent, m.errCpu } diff --git a/tools/monitor/util_test.go b/tools/monitor/util_test.go index 79665f30..9c05b439 100644 --- a/tools/monitor/util_test.go +++ b/tools/monitor/util_test.go @@ -5,13 +5,19 @@ import ( "strconv" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestReadUint(t *testing.T) { // Valid uint64 number in a file validPath := "valid_uint.txt" - os.WriteFile(validPath, []byte("123456"), 0644) - defer os.Remove(validPath) + err := os.WriteFile(validPath, []byte("123456"), 0644) + assert.NoError(t, err) + defer func() { + err := os.Remove(validPath) + assert.NoError(t, err) + }() result, err := readUint(validPath) if err != nil || result != 123456 { diff --git a/tools/web/middleware_test.go b/tools/web/middleware_test.go index 646ebc23..5d552e9f 100644 --- a/tools/web/middleware_test.go +++ b/tools/web/middleware_test.go @@ -24,7 +24,7 @@ func TestSetTaosErrorCode(t *testing.T) { }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.name, func(_ *testing.T) { SetTaosErrorCode(tt.args.c, tt.args.code) }) } diff --git a/version/version.go b/version/version.go index fb4a7512..235bc0d8 100644 --- a/version/version.go +++ b/version/version.go @@ -10,6 +10,8 @@ var BuildInfo = "unknown" var TaosClientVersion = wrapper.TaosGetClientInfo() +//revive:disable-next-line var CUS_NAME = "TDengine" +//revive:disable-next-line var CUS_PROMPT = "taos" From ed6b139bacdd955124bec91a6435066902e035fc Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 13 Nov 2024 14:58:35 +0800 Subject: [PATCH 06/48] test: fix stmt bind block test --- controller/ws/stmt/stmt_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/controller/ws/stmt/stmt_test.go b/controller/ws/stmt/stmt_test.go index 4101f333..8209b6c0 100644 --- a/controller/ws/stmt/stmt_test.go +++ b/controller/ws/stmt/stmt_test.go @@ -726,10 +726,6 @@ func TestBlock(t *testing.T) { t.Error(err) return } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() const ( AfterConnect = iota + 1 AfterInit From a10c944a6ec5d499eca127496e0c8b329a40d523 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 13 Nov 2024 15:04:28 +0800 Subject: [PATCH 07/48] test: fix stmt bind test --- controller/ws/stmt/stmt_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/controller/ws/stmt/stmt_test.go b/controller/ws/stmt/stmt_test.go index 8209b6c0..124e8369 100644 --- a/controller/ws/stmt/stmt_test.go +++ b/controller/ws/stmt/stmt_test.go @@ -91,10 +91,6 @@ func TestSTMT(t *testing.T) { t.Error(err) return } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() const ( AfterConnect = iota + 1 AfterInit From 1318d7a3b0ed938b87560f86fb2f84dbf5f05600 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 13 Nov 2024 19:27:40 +0800 Subject: [PATCH 08/48] test: add unit test --- .golangci.yaml | 3 +- controller/ws/query/ws.go | 41 ++- controller/ws/query/ws_test.go | 266 +++++++++----------- controller/ws/schemaless/schemaless.go | 35 ++- controller/ws/schemaless/schemaless_test.go | 94 ++++++- controller/ws/stmt/stmt.go | 35 ++- controller/ws/stmt/stmt_test.go | 72 ++++++ controller/ws/tmq/tmq.go | 35 ++- controller/ws/tmq/tmq_test.go | 51 ++++ tools/pointer_test.go | 17 ++ 10 files changed, 397 insertions(+), 252 deletions(-) create mode 100644 tools/pointer_test.go diff --git a/.golangci.yaml b/.golangci.yaml index 3fabc1c1..e16f0f80 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -35,4 +35,5 @@ linters-settings: - [ ] - name: range - name: receiver-naming - - name: indent-error-flow \ No newline at end of file + - name: indent-error-flow + - name: unreachable-code \ No newline at end of file diff --git a/controller/ws/query/ws.go b/controller/ws/query/ws.go index dd3b1f63..fbe56fca 100644 --- a/controller/ws/query/ws.go +++ b/controller/ws/query/ws.go @@ -266,15 +266,7 @@ func (t *Taos) waitSignal(logger *logrus.Entry) { return } logger.WithField("clientIP", t.ipStr).Info("user dropped! close connection!") - logger.Trace("close session") - s := log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - logger.Trace("close handler") - s = log.GetLogNow(isDebug) - t.Close() - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return case <-t.whitelistChangeChan: logger.Info("get whitelist change signal") @@ -291,29 +283,14 @@ func (t *Taos) waitSignal(logger *logrus.Entry) { logger.Debugf("get whitelist cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { logger.WithField("clientIP", t.ipStr).WithError(err).Errorln("get whitelist error! close connection!") - s = log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - logger.Trace("close handler") - s = log.GetLogNow(isDebug) - t.Close() - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return } logger.Tracef("check whitelist, ip: %s, whitelist: %s", t.ipStr, tool.IpNetSliceToString(whitelist)) valid := tool.CheckWhitelist(whitelist, t.ip) if !valid { logger.WithField("clientIP", t.ipStr).Errorln("ip not in whitelist! close connection!") - logger.Trace("close session") - s = log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - logger.Trace("close handler") - s = log.GetLogNow(isDebug) - t.Close() - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return } t.Unlock() @@ -323,6 +300,18 @@ func (t *Taos) waitSignal(logger *logrus.Entry) { } } +func (t *Taos) signalExit(logger *logrus.Entry, isDebug bool) { + logger.Trace("close session") + s := log.GetLogNow(isDebug) + _ = t.session.Close() + logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) + t.Unlock() + logger.Trace("close handler") + s = log.GetLogNow(isDebug) + t.Close() + logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) +} + type Result struct { index uint64 TaosResult unsafe.Pointer diff --git a/controller/ws/query/ws_test.go b/controller/ws/query/ws_test.go index 96319303..9bd0de8d 100644 --- a/controller/ws/query/ws_test.go +++ b/controller/ws/query/ws_test.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -48,35 +49,14 @@ func TestMain(m *testing.M) { // @description: test websocket bulk pulling func TestWebsocket(t *testing.T) { now := time.Now().Local().UnixNano() / 1e6 - w := httptest.NewRecorder() - body := strings.NewReader("create database if not exists test_ws WAL_RETENTION_PERIOD 86400") - req, _ := http.NewRequest(http.MethodPost, "/rest/sql", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader("drop table if exists test_ws") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader("create table if not exists test_ws(ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20)) tags (info json)") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader(fmt.Sprintf(`insert into t1 using test_ws tags('{"table":"t1"}') values (%d,true,2,3,4,5,6,7,8,9,10,11,'中文"binary','中文nchar')(%d,false,12,13,14,15,16,17,18,19,110,111,'中文"binary','中文nchar')(%d,null,null,null,null,null,null,null,null,null,null,null,null,null)`, now, now+1, now+3)) - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - + code, message := doRestful("create database if not exists test_ws WAL_RETENTION_PERIOD 86400", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("drop table if exists test_ws", "test_ws") + assert.Equal(t, 0, code, message) + code, message = doRestful("create table if not exists test_ws(ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20)) tags (info json)", "test_ws") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf(`insert into t1 using test_ws tags('{"table":"t1"}') values (%d,true,2,3,4,5,6,7,8,9,10,11,'中文"binary','中文nchar')(%d,false,12,13,14,15,16,17,18,19,110,111,'中文"binary','中文nchar')(%d,null,null,null,null,null,null,null,null,null,null,null,null,null)`, now, now+1, now+3), "test_ws") + assert.Equal(t, 0, code, message) s := httptest.NewServer(router) defer s.Close() ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/ws", nil) @@ -328,52 +308,22 @@ func TestWebsocket(t *testing.T) { assert.Equal(t, nil, blockResult[2][12]) assert.Equal(t, nil, blockResult[2][13]) assert.Equal(t, []byte(`{"table":"t1"}`), blockResult[2][14]) - w = httptest.NewRecorder() - body = strings.NewReader("drop database if exists test_ws") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + code, message = doRestful("drop database if exists test_ws", "") + assert.Equal(t, 0, code, message) } func TestWriteBlock(t *testing.T) { now := time.Now().Local().UnixNano() / 1e6 - w := httptest.NewRecorder() - body := strings.NewReader("create database if not exists test_ws_write_block WAL_RETENTION_PERIOD 86400") - req, _ := http.NewRequest(http.MethodPost, "/rest/sql", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader("drop table if exists test_ws_write_block") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_write_block", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader("create table if not exists test_ws_write_block(ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20)) tags (info json)") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_write_block", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader(fmt.Sprintf(`insert into t1 using test_ws_write_block tags('{"table":"t1"}') values (%d,true,2,3,4,5,6,7,8,9,10,11,'中文"binary','中文nchar')(%d,false,12,13,14,15,16,17,18,19,110,111,'中文"binary','中文nchar')(%d,null,null,null,null,null,null,null,null,null,null,null,null,null)`, now, now+1, now+3)) - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_write_block", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader(`create table t2 using test_ws_write_block tags('{"table":"t2"}')`) - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_write_block", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + code, message := doRestful("create database if not exists test_ws_write_block WAL_RETENTION_PERIOD 86400", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("drop table if exists test_ws_write_block", "test_ws_write_block") + assert.Equal(t, 0, code, message) + code, message = doRestful("create table if not exists test_ws_write_block(ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20)) tags (info json)", "test_ws_write_block") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf(`insert into t1 using test_ws_write_block tags('{"table":"t1"}') values (%d,true,2,3,4,5,6,7,8,9,10,11,'中文"binary','中文nchar')(%d,false,12,13,14,15,16,17,18,19,110,111,'中文"binary','中文nchar')(%d,null,null,null,null,null,null,null,null,null,null,null,null,null)`, now, now+1, now+3), "test_ws_write_block") + assert.Equal(t, 0, code, message) + code, message = doRestful(`create table t2 using test_ws_write_block tags('{"table":"t2"}')`, "test_ws_write_block") + assert.Equal(t, 0, code, message) s := httptest.NewServer(router) defer s.Close() @@ -822,52 +772,22 @@ func TestWriteBlock(t *testing.T) { assert.Equal(t, nil, blockResult[2][11]) assert.Equal(t, nil, blockResult[2][12]) assert.Equal(t, nil, blockResult[2][13]) - w = httptest.NewRecorder() - body = strings.NewReader("drop database if exists test_ws_write_block") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + code, message = doRestful("drop database if exists test_ws_write_block", "") + assert.Equal(t, 0, code, message) } func TestWriteBlockWithFields(t *testing.T) { now := time.Now().Local().UnixNano() / 1e6 - w := httptest.NewRecorder() - body := strings.NewReader("create database if not exists test_ws_write_block_with_fields WAL_RETENTION_PERIOD 86400") - req, _ := http.NewRequest(http.MethodPost, "/rest/sql", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader("drop table if exists test_ws_write_block_with_fields") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_write_block_with_fields", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader("create table if not exists test_ws_write_block_with_fields(ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20)) tags (info json)") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_write_block_with_fields", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader(fmt.Sprintf(`insert into t1 using test_ws_write_block_with_fields tags('{"table":"t1"}') values (%d,true,2,3,4,5,6,7,8,9,10,11,'中文"binary','中文nchar')(%d,false,12,13,14,15,16,17,18,19,110,111,'中文"binary','中文nchar')(%d,null,null,null,null,null,null,null,null,null,null,null,null,null)`, now, now+1, now+3)) - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_write_block_with_fields", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader(`create table t2 using test_ws_write_block_with_fields tags('{"table":"t2"}')`) - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_write_block_with_fields", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + code, message := doRestful("create database if not exists test_ws_write_block_with_fields WAL_RETENTION_PERIOD 86400", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("drop table if exists test_ws_write_block_with_fields", "test_ws_write_block_with_fields") + assert.Equal(t, 0, code, message) + code, message = doRestful("create table if not exists test_ws_write_block_with_fields(ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20)) tags (info json)", "test_ws_write_block_with_fields") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf(`insert into t1 using test_ws_write_block_with_fields tags('{"table":"t1"}') values (%d,true,2,3,4,5,6,7,8,9,10,11,'中文"binary','中文nchar')(%d,false,12,13,14,15,16,17,18,19,110,111,'中文"binary','中文nchar')(%d,null,null,null,null,null,null,null,null,null,null,null,null,null)`, now, now+1, now+3), "test_ws_write_block_with_fields") + assert.Equal(t, 0, code, message) + code, message = doRestful(`create table t2 using test_ws_write_block_with_fields tags('{"table":"t2"}')`, "test_ws_write_block_with_fields") + assert.Equal(t, 0, code, message) s := httptest.NewServer(router) defer s.Close() @@ -1319,46 +1239,20 @@ func TestWriteBlockWithFields(t *testing.T) { for i := 1; i < 14; i++ { assert.Equal(t, nil, blockResult[2][i]) } - w = httptest.NewRecorder() - body = strings.NewReader("drop database if exists test_ws_write_block_with_fields") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + code, message = doRestful("drop database if exists test_ws_write_block_with_fields", "") + assert.Equal(t, 0, code, message) } func TestQueryAllType(t *testing.T) { now := time.Now().Local().UnixNano() / 1e6 - w := httptest.NewRecorder() - body := strings.NewReader("create database if not exists test_ws_all_query WAL_RETENTION_PERIOD 86400") - req, _ := http.NewRequest(http.MethodPost, "/rest/sql", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader("drop table if exists test_ws_all_query") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_all_query", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader("create table if not exists test_ws_all_query(ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100))") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_all_query", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - w = httptest.NewRecorder() - body = strings.NewReader(fmt.Sprintf(`insert into test_ws_all_query values (%d,true,2,3,4,5,6,7,8,9,10,11,'中文"binary','中文nchar','\xaabbcc','POINT(100 100)')(%d,false,12,13,14,15,16,17,18,19,110,111,'中文"binary','中文nchar','\xaabbcc','POINT(100 100)')(%d,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null)`, now, now+1, now+3)) - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_all_query", body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) - + code, message := doRestful("create database if not exists test_ws_all_query WAL_RETENTION_PERIOD 86400", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("drop table if exists test_ws_all_query", "test_ws_all_query") + assert.Equal(t, 0, code, message) + code, message = doRestful("create table if not exists test_ws_all_query(ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100))", "test_ws_all_query") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf(`insert into test_ws_all_query values (%d,true,2,3,4,5,6,7,8,9,10,11,'中文"binary','中文nchar','\xaabbcc','POINT(100 100)')(%d,false,12,13,14,15,16,17,18,19,110,111,'中文"binary','中文nchar','\xaabbcc','POINT(100 100)')(%d,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null)`, now, now+1, now+3), "test_ws_all_query") + assert.Equal(t, 0, code, message) s := httptest.NewServer(router) defer s.Close() ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/ws", nil) @@ -1658,11 +1552,77 @@ func TestQueryAllType(t *testing.T) { assert.Equal(t, nil, blockResult[2][13]) assert.Equal(t, nil, blockResult[2][14]) assert.Equal(t, nil, blockResult[2][15]) - w = httptest.NewRecorder() - body = strings.NewReader("drop database if exists test_ws_all_query") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql", body) + code, message = doRestful("drop database if exists test_ws_all_query", "") + assert.Equal(t, 0, code, message) +} + +type restResp struct { + Code int `json:"code"` + Desc string `json:"desc"` +} + +func doRestful(sql string, db string) (code int, message string) { + w := httptest.NewRecorder() + body := strings.NewReader(sql) + url := "/rest/sql" + if db != "" { + url = fmt.Sprintf("/rest/sql/%s", db) + } + req, _ := http.NewRequest(http.MethodPost, url, body) req.RemoteAddr = "127.0.0.1:33333" req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + if w.Code != http.StatusOK { + return w.Code, w.Body.String() + } + b, _ := io.ReadAll(w.Body) + var res restResp + _ = json.Unmarshal(b, &res) + return res.Code, res.Desc +} + +func doWebSocket(ws *websocket.Conn, action string, arg interface{}) (resp []byte, err error) { + var b []byte + if arg != nil { + b, _ = json.Marshal(arg) + } + a, _ := json.Marshal(WSAction{Action: action, Args: b}) + err = ws.WriteMessage(websocket.TextMessage, a) + if err != nil { + return nil, err + } + _, message, err := ws.ReadMessage() + return message, err +} + +func TestDropUser(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + defer doRestful("drop user test_ws_query_drop_user", "") + code, message := doRestful("create user test_ws_query_drop_user pass 'pass'", "") + assert.Equal(t, 0, code, message) + // connect + connReq := &WSConnectReq{ReqID: 1, User: "test_ws_query_drop_user", Password: "pass"} + resp, err := doWebSocket(ws, WSConnect, &connReq) + assert.NoError(t, err) + var connResp WSConnectResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + // drop user + code, message = doRestful("drop user test_ws_query_drop_user", "") + assert.Equal(t, 0, code, message) + time.Sleep(time.Second * 3) + resp, err = doWebSocket(ws, wstool.ClientVersion, nil) + assert.Error(t, err, resp) } diff --git a/controller/ws/schemaless/schemaless.go b/controller/ws/schemaless/schemaless.go index 54953f0d..b0554659 100644 --- a/controller/ws/schemaless/schemaless.go +++ b/controller/ws/schemaless/schemaless.go @@ -177,13 +177,7 @@ func (t *TaosSchemaless) waitSignal(logger *logrus.Entry) { return } logger.Info("user dropped! close connection!") - s := log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - s = log.GetLogNow(isDebug) - t.Close(logger) - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return case <-t.whitelistChangeChan: logger.Info("get whitelist change signal") @@ -201,26 +195,14 @@ func (t *TaosSchemaless) waitSignal(logger *logrus.Entry) { if err != nil { logger.Errorf("get whitelist error, close connection, err:%s", err) wstool.GetLogger(t.session).WithField("ip", t.ipStr).WithError(err).Errorln("get whitelist error! close connection!") - s = log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - s = log.GetLogNow(isDebug) - t.Close(t.logger) - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) valid := tool.CheckWhitelist(whitelist, t.ip) if !valid { logger.Errorf("ip not in whitelist, close connection, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) - s = log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - s = log.GetLogNow(isDebug) - t.Close(logger) - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return } t.Unlock() @@ -230,6 +212,17 @@ func (t *TaosSchemaless) waitSignal(logger *logrus.Entry) { } } +func (t *TaosSchemaless) signalExit(logger *logrus.Entry, isDebug bool) { + logger.Trace("close session") + s := log.GetLogNow(isDebug) + _ = t.session.Close() + logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) + t.Unlock() + s = log.GetLogNow(isDebug) + t.Close(logger) + logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) +} + func (t *TaosSchemaless) lock(logger *logrus.Entry, isDebug bool) { s := log.GetLogNow(isDebug) logger.Trace("get handler lock") diff --git a/controller/ws/schemaless/schemaless_test.go b/controller/ws/schemaless/schemaless_test.go index a2c87d1a..a6f4fe9e 100644 --- a/controller/ws/schemaless/schemaless_test.go +++ b/controller/ws/schemaless/schemaless_test.go @@ -2,20 +2,25 @@ package schemaless import ( "encoding/json" + "fmt" + "io" + "net/http" "net/http/httptest" "strings" "testing" + "time" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/driver-go/v3/ws/schemaless" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller" + _ "github.com/taosdata/taosadapter/v3/controller/rest" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/log" ) var router *gin.Engine @@ -24,8 +29,10 @@ func TestMain(m *testing.M) { viper.Set("pool.maxConnect", 10000) viper.Set("pool.maxIdle", 10000) viper.Set("logLevel", "trace") + viper.Set("uploadKeeper.enable", false) config.Init() db.PrepareConnection() + log.ConfigLog() gin.SetMode(gin.ReleaseMode) router = gin.New() controllers := controller.GetControllers() @@ -36,15 +43,13 @@ func TestMain(m *testing.M) { } func TestRestful_InitSchemaless(t *testing.T) { - conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) - if err != nil { - t.Error(err) - return - } - wrapper.TaosFreeResult(wrapper.TaosQuery(conn, "drop database if exists test_schemaless_ws")) - wrapper.TaosFreeResult(wrapper.TaosQuery(conn, "create database if not exists test_schemaless_ws")) + code, message := doRestful("drop database if exists test_schemaless_ws", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("create database if not exists test_schemaless_ws", "") + assert.Equal(t, 0, code, message) defer func() { - wrapper.TaosFreeResult(wrapper.TaosQuery(conn, "drop database if exists test_schemaless_ws")) + code, message = doRestful("drop database if exists test_schemaless_ws", "") + assert.Equal(t, 0, code, message) }() s := httptest.NewServer(router) @@ -204,3 +209,74 @@ func TestRestful_InitSchemaless(t *testing.T) { }) } } + +type restResp struct { + Code int `json:"code"` + Desc string `json:"desc"` +} + +func doRestful(sql string, db string) (code int, message string) { + w := httptest.NewRecorder() + body := strings.NewReader(sql) + url := "/rest/sql" + if db != "" { + url = fmt.Sprintf("/rest/sql/%s", db) + } + req, _ := http.NewRequest(http.MethodPost, url, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + return w.Code, w.Body.String() + } + b, _ := io.ReadAll(w.Body) + var res restResp + _ = json.Unmarshal(b, &res) + return res.Code, res.Desc +} + +func doWebSocket(ws *websocket.Conn, action string, arg interface{}) (resp []byte, err error) { + var b []byte + if arg != nil { + b, _ = json.Marshal(arg) + } + a, _ := json.Marshal(wstool.WSAction{Action: action, Args: b}) + err = ws.WriteMessage(websocket.TextMessage, a) + if err != nil { + return nil, err + } + _, message, err := ws.ReadMessage() + return message, err +} + +func TestDropUser(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/schemaless", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + defer doRestful("drop user test_ws_sml_drop_user", "") + code, message := doRestful("create user test_ws_sml_drop_user pass 'pass'", "") + assert.Equal(t, 0, code, message) + // connect + connReq := &schemalessConnReq{ReqID: 1, User: "test_ws_sml_drop_user", Password: "pass"} + resp, err := doWebSocket(ws, SchemalessConn, &connReq) + assert.NoError(t, err) + var connResp schemalessConnResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + // drop user + code, message = doRestful("drop user test_ws_sml_drop_user", "") + assert.Equal(t, 0, code, message) + time.Sleep(time.Second * 3) + resp, err = doWebSocket(ws, wstool.ClientVersion, nil) + assert.Error(t, err, resp) +} diff --git a/controller/ws/stmt/stmt.go b/controller/ws/stmt/stmt.go index 9eb5b918..e7315e98 100644 --- a/controller/ws/stmt/stmt.go +++ b/controller/ws/stmt/stmt.go @@ -303,13 +303,7 @@ func (t *TaosStmt) waitSignal(logger *logrus.Entry) { return } logger.Info("user dropped! close connection!") - s := log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - s = log.GetLogNow(isDebug) - t.Close(logger) - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return case <-t.whitelistChangeChan: logger.Info("get whitelist change signal") @@ -327,26 +321,14 @@ func (t *TaosStmt) waitSignal(logger *logrus.Entry) { if err != nil { logger.Errorf("get whitelist error, close connection, err:%s", err) wstool.GetLogger(t.session).WithField("ip", t.ipStr).WithError(err).Errorln("get whitelist error! close connection!") - s = log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - s = log.GetLogNow(isDebug) - t.Close(t.logger) - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) valid := tool.CheckWhitelist(whitelist, t.ip) if !valid { logger.Errorf("ip not in whitelist, close connection, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) - s = log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - s = log.GetLogNow(isDebug) - t.Close(logger) - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return } t.Unlock() @@ -356,6 +338,17 @@ func (t *TaosStmt) waitSignal(logger *logrus.Entry) { } } +func (t *TaosStmt) signalExit(logger *logrus.Entry, isDebug bool) { + logger.Trace("close session") + s := log.GetLogNow(isDebug) + _ = t.session.Close() + logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) + t.Unlock() + s = log.GetLogNow(isDebug) + t.Close(logger) + logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) +} + func (t *TaosStmt) lock(logger *logrus.Entry, isDebug bool) { s := log.GetLogNow(isDebug) logger.Trace("get handler lock") diff --git a/controller/ws/stmt/stmt_test.go b/controller/ws/stmt/stmt_test.go index 124e8369..c73d4968 100644 --- a/controller/ws/stmt/stmt_test.go +++ b/controller/ws/stmt/stmt_test.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -1047,3 +1048,74 @@ func TestBlock(t *testing.T) { assert.Equal(t, 200, w.Code) } + +type restResp struct { + Code int `json:"code"` + Desc string `json:"desc"` +} + +func doRestful(sql string, db string) (code int, message string) { + w := httptest.NewRecorder() + body := strings.NewReader(sql) + url := "/rest/sql" + if db != "" { + url = fmt.Sprintf("/rest/sql/%s", db) + } + req, _ := http.NewRequest(http.MethodPost, url, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + return w.Code, w.Body.String() + } + b, _ := io.ReadAll(w.Body) + var res restResp + _ = json.Unmarshal(b, &res) + return res.Code, res.Desc +} + +func doWebSocket(ws *websocket.Conn, action string, arg interface{}) (resp []byte, err error) { + var b []byte + if arg != nil { + b, _ = json.Marshal(arg) + } + a, _ := json.Marshal(wstool.WSAction{Action: action, Args: b}) + err = ws.WriteMessage(websocket.TextMessage, a) + if err != nil { + return nil, err + } + _, message, err := ws.ReadMessage() + return message, err +} + +func TestDropUser(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/stmt", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + defer doRestful("drop user test_ws_stmt_drop_user", "") + code, message := doRestful("create user test_ws_stmt_drop_user pass 'pass'", "") + assert.Equal(t, 0, code, message) + // connect + connReq := &StmtConnectReq{ReqID: 1, User: "test_ws_stmt_drop_user", Password: "pass"} + resp, err := doWebSocket(ws, STMTConnect, &connReq) + assert.NoError(t, err) + var connResp StmtConnectResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + // drop user + code, message = doRestful("drop user test_ws_stmt_drop_user", "") + assert.Equal(t, 0, code, message) + time.Sleep(time.Second * 3) + resp, err = doWebSocket(ws, wstool.ClientVersion, nil) + assert.Error(t, err, resp) +} diff --git a/controller/ws/tmq/tmq.go b/controller/ws/tmq/tmq.go index 657a0aaa..ea502ba7 100644 --- a/controller/ws/tmq/tmq.go +++ b/controller/ws/tmq/tmq.go @@ -325,13 +325,7 @@ func (t *TMQ) waitSignal(logger *logrus.Entry) { return } logger.Info("user dropped! close connection!") - s := log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - s = log.GetLogNow(isDebug) - t.Close(logger) - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return case <-t.whitelistChangeChan: logger.Info("get whitelist change signal") @@ -348,26 +342,14 @@ func (t *TMQ) waitSignal(logger *logrus.Entry) { logger.Debugf("get whitelist cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { logger.Errorf("get whitelist error, close connection, err:%s", err) - s = log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - s = log.GetLogNow(isDebug) - t.Close(t.logger) - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) valid := tool.CheckWhitelist(whitelist, t.ip) if !valid { logger.Errorf("ip not in whitelist, close connection, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) - s = log.GetLogNow(isDebug) - _ = t.session.Close() - logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) - t.Unlock() - s = log.GetLogNow(isDebug) - t.Close(t.logger) - logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) + t.signalExit(logger, isDebug) return } t.Unlock() @@ -377,6 +359,17 @@ func (t *TMQ) waitSignal(logger *logrus.Entry) { } } +func (t *TMQ) signalExit(logger *logrus.Entry, isDebug bool) { + logger.Trace("close session") + s := log.GetLogNow(isDebug) + _ = t.session.Close() + logger.Debugf("close session cost:%s", log.GetLogDuration(isDebug, s)) + t.Unlock() + s = log.GetLogNow(isDebug) + t.Close(logger) + logger.Debugf("close handler cost:%s", log.GetLogDuration(isDebug, s)) +} + func (t *TMQ) lock(logger *logrus.Entry, isDebug bool) { s := log.GetLogNow(isDebug) logger.Trace("get handler lock") diff --git a/controller/ws/tmq/tmq_test.go b/controller/ws/tmq/tmq_test.go index 733b66a3..1a7c208c 100644 --- a/controller/ws/tmq/tmq_test.go +++ b/controller/ws/tmq/tmq_test.go @@ -3155,3 +3155,54 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message) } + +//func TestDropUser(t *testing.T) { +// defer doHttpSql("drop user test_tmq_drop_user") +// code, message := doHttpSql("create user test_tmq_drop_user pass 'pass'") +// assert.Equal(t, 0, code, message) +// +// dbName := "test_ws_tmq_drop_user" +// topic := "test_ws_tmq_drop_user_topic" +// +// before(t, dbName, topic) +// +// s := httptest.NewServer(router) +// defer s.Close() +// ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/tmq", nil) +// if err != nil { +// t.Error(err) +// return +// } +// defer func() { +// err = ws.Close() +// assert.NoError(t, err) +// }() +// +// defer func() { +// err = after(ws, dbName, topic) +// assert.NoError(t, err) +// }() +// +// // subscribe +// b, _ := json.Marshal(TMQSubscribeReq{ +// User: "test_tmq_drop_user", +// Password: "pass", +// DB: dbName, +// GroupID: "test", +// Topics: []string{topic}, +// AutoCommit: "false", +// OffsetReset: "earliest", +// }) +// msg, err := doWebSocket(ws, TMQSubscribe, b) +// assert.NoError(t, err) +// var subscribeResp TMQSubscribeResp +// err = json.Unmarshal(msg, &subscribeResp) +// assert.NoError(t, err) +// assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message) +// // drop user +// code, message = doHttpSql("drop user test_tmq_drop_user") +// assert.Equal(t, 0, code, message) +// time.Sleep(time.Second * 3) +// resp, err := doWebSocket(ws, wstool.ClientVersion, nil) +// assert.Error(t, err, string(resp)) +//} diff --git a/tools/pointer_test.go b/tools/pointer_test.go new file mode 100644 index 00000000..20164f60 --- /dev/null +++ b/tools/pointer_test.go @@ -0,0 +1,17 @@ +package tools + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestAddPointer(t *testing.T) { + s := []int32{1, 2, 3} + p0 := unsafe.Pointer(&s[0]) + p1 := AddPointer(p0, 4) + assert.Equal(t, *(*int32)(p1), s[1]) + p2 := AddPointer(p1, 4) + assert.Equal(t, *(*int32)(p2), s[2]) +} From c37974cf4ba84bc694875a51eff31d7ac0156110 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 21 Nov 2024 19:43:46 +0800 Subject: [PATCH 09/48] enh: refactor websocket --- controller/ws/query/ws.go | 79 +- controller/ws/schemaless/schemaless.go | 12 +- controller/ws/stmt/stmt.go | 28 +- controller/ws/tmq/tmq.go | 19 +- controller/ws/ws/const.go | 30 +- controller/ws/ws/fetch.go | 245 ++ controller/ws/ws/fetch_test.go | 70 + controller/ws/ws/free.go | 15 + controller/ws/ws/handler.go | 2581 ++++--------------- controller/ws/ws/handler_test.go | 424 +++- controller/ws/ws/misc.go | 67 + controller/ws/ws/misc_test.go | 90 + controller/ws/ws/query.go | 289 +++ controller/ws/ws/query_test.go | 1258 +++++++++ controller/ws/ws/raw.go | 77 + controller/ws/ws/raw_test.go | 124 + controller/ws/ws/resp.go | 112 + controller/ws/ws/schemaless.go | 57 + controller/ws/ws/schemaless_test.go | 241 ++ controller/ws/ws/stmt.go | 893 +++++++ controller/ws/ws/stmt2.go | 448 ++++ controller/ws/ws/stmt2_test.go | 742 ++++++ controller/ws/ws/stmt_test.go | 900 +++++++ controller/ws/ws/ws.go | 32 +- controller/ws/ws/ws_test.go | 3235 +----------------------- controller/ws/wstool/error.go | 2 +- controller/ws/wstool/error_test.go | 5 +- controller/ws/wstool/log.go | 12 +- controller/ws/wstool/log_test.go | 2 +- controller/ws/wstool/resp.go | 10 +- controller/ws/wstool/resp_test.go | 5 +- go.mod | 3 - go.sum | 2 - tools/bytesutil/bytesutil.go | 28 +- tools/bytesutil/bytesutil_test.go | 4 + tools/melody/config.go | 22 + tools/melody/melody.go | 147 ++ tools/melody/melody_test.go | 623 +++++ tools/melody/session.go | 214 ++ 39 files changed, 7601 insertions(+), 5546 deletions(-) create mode 100644 controller/ws/ws/fetch.go create mode 100644 controller/ws/ws/fetch_test.go create mode 100644 controller/ws/ws/free.go create mode 100644 controller/ws/ws/misc.go create mode 100644 controller/ws/ws/misc_test.go create mode 100644 controller/ws/ws/query.go create mode 100644 controller/ws/ws/query_test.go create mode 100644 controller/ws/ws/raw.go create mode 100644 controller/ws/ws/raw_test.go create mode 100644 controller/ws/ws/resp.go create mode 100644 controller/ws/ws/schemaless.go create mode 100644 controller/ws/ws/schemaless_test.go create mode 100644 controller/ws/ws/stmt.go create mode 100644 controller/ws/ws/stmt2.go create mode 100644 controller/ws/ws/stmt2_test.go create mode 100644 controller/ws/ws/stmt_test.go create mode 100644 tools/melody/config.go create mode 100644 tools/melody/melody.go create mode 100644 tools/melody/melody_test.go create mode 100644 tools/melody/session.go diff --git a/controller/ws/query/ws.go b/controller/ws/query/ws.go index fbe56fca..f1e2ea04 100644 --- a/controller/ws/query/ws.go +++ b/controller/ws/query/ws.go @@ -15,7 +15,6 @@ import ( "github.com/taosdata/taosadapter/v3/tools/generator" "github.com/gin-gonic/gin" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" "github.com/taosdata/driver-go/v3/common/parser" "github.com/taosdata/driver-go/v3/wrapper" @@ -32,6 +31,7 @@ import ( "github.com/taosdata/taosadapter/v3/tools" "github.com/taosdata/taosadapter/v3/tools/iptool" "github.com/taosdata/taosadapter/v3/tools/jsontype" + "github.com/taosdata/taosadapter/v3/tools/melody" ) type QueryController struct { @@ -40,7 +40,7 @@ type QueryController struct { func NewQueryController() *QueryController { queryM := melody.New() - queryM.UpGrader.EnableCompression = true + queryM.Upgrader.EnableCompression = true queryM.Config.MaxMessageSize = 0 queryM.HandleConnect(func(session *melody.Session) { @@ -51,9 +51,6 @@ func NewQueryController() *QueryController { }) queryM.HandleMessage(func(session *melody.Session, data []byte) { - if queryM.IsClosed() { - return - } t := session.MustGet(TaosSessionKey).(*Taos) if t.closed { return @@ -61,6 +58,9 @@ func NewQueryController() *QueryController { t.wg.Add(1) go func() { defer t.wg.Done() + if t.closed { + return + } ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) logger := wstool.GetLogger(session) logger.Debugf("get ws message data: %s", data) @@ -72,7 +72,7 @@ func NewQueryController() *QueryController { } switch action.Action { case wstool.ClientVersion: - _ = session.Write(wstool.VersionResp) + wstool.WSWriteVersion(session, logger) case WSConnect: var wsConnect WSConnectReq err = json.Unmarshal(action.Args, &wsConnect) @@ -122,9 +122,6 @@ func NewQueryController() *QueryController { }) queryM.HandleMessageBinary(func(session *melody.Session, data []byte) { - if queryM.IsClosed() { - return - } t := session.MustGet(TaosSessionKey).(*Taos) if t.closed { return @@ -132,6 +129,9 @@ func NewQueryController() *QueryController { t.wg.Add(1) go func() { defer t.wg.Done() + if t.closed { + return + } ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) logger := wstool.GetLogger(session) logger.Tracef("get ws block message data:%+v", data) @@ -415,7 +415,7 @@ func (t *Taos) connect(ctx context.Context, session *melody.Session, req *WSConn } if t.conn != nil { logger.Trace("duplicate connections") - wsErrorMsg(ctx, session, 0xffff, "duplicate connections", WSConnect, req.ReqID) + wsErrorMsg(ctx, session, logger, 0xffff, "duplicate connections", WSConnect, req.ReqID) return } conn, err := syncinterface.TaosConnect("", req.User, req.Password, req.DB, 0, logger, isDebug) @@ -499,7 +499,7 @@ func (t *Taos) query(ctx context.Context, session *melody.Session, req *WSQueryR ) if t.conn == nil { logger.Trace("server not connected") - wsErrorMsg(ctx, session, 0xffff, "server not connected", WSQuery, req.ReqID) + wsErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSQuery, req.ReqID) return } logger.Tracef("req_id: 0x%x,query sql: %s", req.ReqID, req.SQL) @@ -521,7 +521,7 @@ func (t *Taos) query(ctx context.Context, session *melody.Session, req *WSQueryR logger.Errorf("query error, code: %d, message: %s", code, errStr) logger.Trace("get thread lock for free result") syncinterface.FreeResult(result.Res, logger, isDebug) - wsErrorMsg(ctx, session, code, errStr, WSQuery, req.ReqID) + wsErrorMsg(ctx, session, logger, code, errStr, WSQuery, req.ReqID) return } monitor.WSRecordResult(sqlType, true) @@ -593,7 +593,7 @@ func (t *Taos) writeRaw(ctx context.Context, session *melody.Session, reqID, mes } if t.conn == nil { logger.Error("server not connected") - wsTMQErrorMsg(ctx, session, 0xffff, "server not connected", WSWriteRaw, reqID, &messageID) + wsTMQErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSWriteRaw, reqID, &messageID) return } meta := wrapper.BuildRawMeta(length, metaType, data) @@ -609,7 +609,7 @@ func (t *Taos) writeRaw(ctx context.Context, session *melody.Session, reqID, mes if errCode != 0 { errStr := wrapper.TMQErr2Str(errCode) logger.Errorf("write raw meta error, code: %d, message: %s", errCode, errStr) - wsErrorMsg(ctx, session, int(errCode)&0xffff, errStr, WSWriteRaw, reqID) + wsErrorMsg(ctx, session, logger, int(errCode)&0xffff, errStr, WSWriteRaw, reqID) return } resp := &WSWriteMetaResp{Action: WSWriteRaw, ReqID: reqID, MessageID: messageID, Timing: wstool.GetDuration(ctx)} @@ -636,7 +636,7 @@ func (t *Taos) writeRawBlock(ctx context.Context, session *melody.Session, reqID return } if t.conn == nil { - wsErrorMsg(ctx, session, 0xffff, "server not connected", WSWriteRawBlock, reqID) + wsErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSWriteRawBlock, reqID) return } logger.Trace("get thread lock for write raw block") @@ -651,7 +651,7 @@ func (t *Taos) writeRawBlock(ctx context.Context, session *melody.Session, reqID if errCode != 0 { errStr := wrapper.TMQErr2Str(int32(errCode)) logger.Errorf("write raw block error, code: %d, message: %s", errCode, errStr) - wsErrorMsg(ctx, session, errCode&0xffff, errStr, WSWriteRawBlock, reqID) + wsErrorMsg(ctx, session, logger, errCode&0xffff, errStr, WSWriteRawBlock, reqID) return } resp := &WSWriteRawBlockResp{Action: WSWriteRawBlock, ReqID: reqID, Timing: wstool.GetDuration(ctx)} @@ -679,7 +679,7 @@ func (t *Taos) writeRawBlockWithFields(ctx context.Context, session *melody.Sess } if t.conn == nil { logger.Errorf("server not connected") - wsErrorMsg(ctx, session, 0xffff, "server not connected", WSWriteRawBlockWithFields, reqID) + wsErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSWriteRawBlockWithFields, reqID) return } logger.Trace("get thread lock for write raw block with fields") @@ -694,7 +694,7 @@ func (t *Taos) writeRawBlockWithFields(ctx context.Context, session *melody.Sess if errCode != 0 { errStr := wrapper.TMQErr2Str(int32(errCode)) logger.Errorf("write raw block with fields error, code: %d, message: %s", errCode, errStr) - wsErrorMsg(ctx, session, errCode&0xffff, errStr, WSWriteRawBlockWithFields, reqID) + wsErrorMsg(ctx, session, logger, errCode&0xffff, errStr, WSWriteRawBlockWithFields, reqID) return } resp := &WSWriteRawBlockWithFieldsResp{Action: WSWriteRawBlockWithFields, ReqID: reqID, Timing: wstool.GetDuration(ctx)} @@ -724,14 +724,14 @@ func (t *Taos) fetch(ctx context.Context, session *melody.Session, req *WSFetchR ) if t.conn == nil { logger.Errorf("server not connected") - wsErrorMsg(ctx, session, 0xffff, "server not connected", WSFetch, req.ReqID) + wsErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSFetch, req.ReqID) return } isDebug := log.IsDebug() resultItem := t.getResult(req.ID) if resultItem == nil { logger.Errorf("result is nil") - wsErrorMsg(ctx, session, 0xffff, "result is nil", WSFetch, req.ReqID) + wsErrorMsg(ctx, session, logger, 0xffff, "result is nil", WSFetch, req.ReqID) return } resultS := resultItem.Value.(*Result) @@ -739,7 +739,7 @@ func (t *Taos) fetch(ctx context.Context, session *melody.Session, req *WSFetchR if resultS.TaosResult == nil { resultS.Unlock() logger.Errorf("result is nil") - wsErrorMsg(ctx, session, 0xffff, "result is nil", WSFetch, req.ReqID) + wsErrorMsg(ctx, session, logger, 0xffff, "result is nil", WSFetch, req.ReqID) return } s := log.GetLogNow(isDebug) @@ -768,7 +768,7 @@ func (t *Taos) fetch(ctx context.Context, session *melody.Session, req *WSFetchR logger.Errorf("fetch raw block error, code: %d, message: %s", result.N, errStr) resultS.Unlock() t.FreeResult(resultItem, logger) - wsErrorMsg(ctx, session, result.N&0xffff, errStr, WSFetch, req.ReqID) + wsErrorMsg(ctx, session, logger, result.N&0xffff, errStr, WSFetch, req.ReqID) return } s = log.GetLogNow(isDebug) @@ -802,26 +802,26 @@ func (t *Taos) fetchBlock(ctx context.Context, session *melody.Session, req *WSF ) if t.conn == nil { logger.Error("server not connected") - wsErrorMsg(ctx, session, 0xffff, "server not connected", WSFetchBlock, req.ReqID) + wsErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSFetchBlock, req.ReqID) return } isDebug := log.IsDebug() s := log.GetLogNow(isDebug) resultItem := t.getResult(req.ID) if resultItem == nil { - wsErrorMsg(ctx, session, 0xffff, "result is nil", WSFetchBlock, req.ReqID) + wsErrorMsg(ctx, session, logger, 0xffff, "result is nil", WSFetchBlock, req.ReqID) return } resultS := resultItem.Value.(*Result) resultS.Lock() if resultS.TaosResult == nil { resultS.Unlock() - wsErrorMsg(ctx, session, 0xffff, "result is nil", WSFetchBlock, req.ReqID) + wsErrorMsg(ctx, session, logger, 0xffff, "result is nil", WSFetchBlock, req.ReqID) return } if resultS.Block == nil { resultS.Unlock() - wsErrorMsg(ctx, session, 0xffff, "block is nil", WSFetchBlock, req.ReqID) + wsErrorMsg(ctx, session, logger, 0xffff, "block is nil", WSFetchBlock, req.ReqID) return } blockLength := int(parser.RawBlockGetLength(resultS.Block)) @@ -867,15 +867,6 @@ func (t *Taos) freeResult(req *WSFreeResultReq) { } } -type Writer struct { - session *melody.Session -} - -func (w *Writer) Write(p []byte) (int, error) { - err := w.session.Write(p) - return 0, err -} - func (t *Taos) FreeResult(element *list.Element, logger *logrus.Entry) { if element == nil { return @@ -968,16 +959,15 @@ type WSErrorResp struct { Timing int64 `json:"timing"` } -func wsErrorMsg(ctx context.Context, session *melody.Session, code int, message string, action string, reqID uint64) { - b, _ := json.Marshal(&WSErrorResp{ +func wsErrorMsg(ctx context.Context, session *melody.Session, logger *logrus.Entry, code int, message string, action string, reqID uint64) { + data := &WSErrorResp{ Code: code & 0xffff, Message: message, Action: action, ReqID: reqID, Timing: wstool.GetDuration(ctx), - }) - wstool.GetLogger(session).Tracef("write error message: %s", b) - _ = session.Write(b) + } + wstool.WSWriteJson(session, logger, data) } type WSTMQErrorResp struct { @@ -989,17 +979,16 @@ type WSTMQErrorResp struct { MessageID *uint64 `json:"message_id,omitempty"` } -func wsTMQErrorMsg(ctx context.Context, session *melody.Session, code int, message string, action string, reqID uint64, messageID *uint64) { - b, _ := json.Marshal(&WSTMQErrorResp{ +func wsTMQErrorMsg(ctx context.Context, session *melody.Session, logger *logrus.Entry, code int, message string, action string, reqID uint64, messageID *uint64) { + data := &WSTMQErrorResp{ Code: code & 0xffff, Message: message, Action: action, ReqID: reqID, Timing: wstool.GetDuration(ctx), MessageID: messageID, - }) - wstool.GetLogger(session).Tracef("write error message: %s", b) - _ = session.Write(b) + } + wstool.WSWriteJson(session, logger, data) } func init() { diff --git a/controller/ws/schemaless/schemaless.go b/controller/ws/schemaless/schemaless.go index b0554659..26cdcb99 100644 --- a/controller/ws/schemaless/schemaless.go +++ b/controller/ws/schemaless/schemaless.go @@ -9,7 +9,6 @@ import ( "unsafe" "github.com/gin-gonic/gin" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" tErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" @@ -22,6 +21,7 @@ import ( "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/generator" "github.com/taosdata/taosadapter/v3/tools/iptool" + "github.com/taosdata/taosadapter/v3/tools/melody" ) type SchemalessController struct { @@ -30,7 +30,7 @@ type SchemalessController struct { func NewSchemalessController() *SchemalessController { schemaless := melody.New() - schemaless.UpGrader.EnableCompression = true + schemaless.Upgrader.EnableCompression = true schemaless.Config.MaxMessageSize = 0 schemaless.HandleConnect(func(session *melody.Session) { @@ -40,9 +40,6 @@ func NewSchemalessController() *SchemalessController { }) schemaless.HandleMessage(func(session *melody.Session, bytes []byte) { - if schemaless.IsClosed() { - return - } t := session.MustGet(taosSchemalessKey).(*TaosSchemaless) if t.closed { return @@ -50,6 +47,9 @@ func NewSchemalessController() *SchemalessController { t.wg.Add(1) go func() { defer t.wg.Done() + if t.closed { + return + } ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) logger := wstool.GetLogger(session) logger.Debugf("get ws message data:%s", bytes) @@ -62,7 +62,7 @@ func NewSchemalessController() *SchemalessController { } switch action.Action { case wstool.ClientVersion: - _ = session.Write(wstool.VersionResp) + wstool.WSWriteVersion(session, logger) case SchemalessConn: var req schemalessConnReq if err = json.Unmarshal(action.Args, &req); err != nil { diff --git a/controller/ws/stmt/stmt.go b/controller/ws/stmt/stmt.go index e7315e98..2b14bd6c 100644 --- a/controller/ws/stmt/stmt.go +++ b/controller/ws/stmt/stmt.go @@ -13,7 +13,6 @@ import ( "unsafe" "github.com/gin-gonic/gin" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" "github.com/taosdata/driver-go/v3/common/parser" stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" @@ -31,6 +30,7 @@ import ( "github.com/taosdata/taosadapter/v3/tools" "github.com/taosdata/taosadapter/v3/tools/generator" "github.com/taosdata/taosadapter/v3/tools/iptool" + "github.com/taosdata/taosadapter/v3/tools/melody" ) type STMTController struct { @@ -39,7 +39,7 @@ type STMTController struct { func NewSTMTController() *STMTController { stmtM := melody.New() - stmtM.UpGrader.EnableCompression = true + stmtM.Upgrader.EnableCompression = true stmtM.Config.MaxMessageSize = 0 stmtM.HandleConnect(func(session *melody.Session) { @@ -49,9 +49,6 @@ func NewSTMTController() *STMTController { }) stmtM.HandleMessage(func(session *melody.Session, data []byte) { - if stmtM.IsClosed() { - return - } t := session.MustGet(TaosStmtKey).(*TaosStmt) if t.closed { return @@ -59,6 +56,9 @@ func NewSTMTController() *STMTController { t.wg.Add(1) go func() { defer t.wg.Done() + if t.closed { + return + } ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) logger := wstool.GetLogger(session) logger.Debugf("get ws message data:%s", data) @@ -71,7 +71,7 @@ func NewSTMTController() *STMTController { } switch action.Action { case wstool.ClientVersion: - _ = session.Write(wstool.VersionResp) + wstool.WSWriteVersion(session, logger) case STMTConnect: var req StmtConnectReq err = json.Unmarshal(action.Args, &req) @@ -172,7 +172,7 @@ func NewSTMTController() *STMTController { }) stmtM.HandleMessageBinary(func(session *melody.Session, data []byte) { - if stmtM.IsClosed() { + if session.IsClosed() { return } t := session.MustGet(TaosStmtKey).(*TaosStmt) @@ -182,6 +182,9 @@ func NewSTMTController() *STMTController { t.wg.Add(1) go func() { defer t.wg.Done() + if t.closed { + return + } logger := wstool.GetLogger(session) logger.Tracef("get ws block message data:%+v", data) ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) @@ -197,9 +200,6 @@ func NewSTMTController() *STMTController { block := tools.AddPointer(p0, uintptr(24)) columns := parser.RawBlockGetNumOfCols(block) rows := parser.RawBlockGetNumOfRows(block) - if stmtM.IsClosed() { - return - } switch action { case BindMessage: t.bindBlock(ctx, session, reqID, stmtID, int(rows), int(columns), block) @@ -677,10 +677,10 @@ func (t *TaosStmt) setTags(ctx context.Context, session *melody.Session, req *St s := log.GetLogNow(isDebug) fields := wrapper.StmtParseFields(tagNums, tagFields) logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) - tags := make([][]driver.Value, tagNums) - for i := 0; i < tagNums; i++ { - tags[i] = []driver.Value{req.Tags[i]} - } + //tags := make([][]driver.Value, tagNums) + //for i := 0; i < tagNums; i++ { + // tags[i] = []driver.Value{req.Tags[i]} + //} s = log.GetLogNow(isDebug) data, err := StmtParseTag(req.Tags, fields) logger.Debugf("stmt parse tag json cost:%s", log.GetLogDuration(isDebug, s)) diff --git a/controller/ws/tmq/tmq.go b/controller/ws/tmq/tmq.go index ea502ba7..ec6b3c94 100644 --- a/controller/ws/tmq/tmq.go +++ b/controller/ws/tmq/tmq.go @@ -12,7 +12,6 @@ import ( "unsafe" "github.com/gin-gonic/gin" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" "github.com/taosdata/driver-go/v3/common" "github.com/taosdata/driver-go/v3/common/parser" @@ -33,6 +32,7 @@ import ( "github.com/taosdata/taosadapter/v3/tools/generator" "github.com/taosdata/taosadapter/v3/tools/iptool" "github.com/taosdata/taosadapter/v3/tools/jsontype" + "github.com/taosdata/taosadapter/v3/tools/melody" ) type TMQController struct { @@ -41,7 +41,7 @@ type TMQController struct { func NewTMQController() *TMQController { tmqM := melody.New() - tmqM.UpGrader.EnableCompression = true + tmqM.Upgrader.EnableCompression = true tmqM.Config.MaxMessageSize = 0 tmqM.HandleConnect(func(session *melody.Session) { @@ -51,9 +51,6 @@ func NewTMQController() *TMQController { }) tmqM.HandleMessage(func(session *melody.Session, data []byte) { - if tmqM.IsClosed() { - return - } t := session.MustGet(TaosTMQKey).(*TMQ) if t.isClosed() { return @@ -61,6 +58,9 @@ func NewTMQController() *TMQController { t.wg.Add(1) go func() { defer t.wg.Done() + if t.isClosed() { + return + } ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) logger := wstool.GetLogger(session) logger.Debugf("get ws message data:%s", data) @@ -72,7 +72,7 @@ func NewTMQController() *TMQController { } switch action.Action { case wstool.ClientVersion: - _ = session.Write(wstool.VersionResp) + wstool.WSWriteVersion(session, logger) case TMQSubscribe: var req TMQSubscribeReq err = json.Unmarshal(action.Args, &req) @@ -1300,16 +1300,15 @@ type WSTMQErrorResp struct { } func wsTMQErrorMsg(ctx context.Context, session *melody.Session, logger *logrus.Entry, code int, message string, action string, reqID uint64, messageID *uint64) { - b, _ := json.Marshal(&WSTMQErrorResp{ + data := &WSTMQErrorResp{ Code: code & 0xffff, Message: message, Action: action, ReqID: reqID, Timing: wstool.GetDuration(ctx), MessageID: messageID, - }) - logger.Tracef("write json:%s", b) - _ = session.Write(b) + } + wstool.WSWriteJson(session, logger, data) } func canGetData(messageType int32) bool { diff --git a/controller/ws/ws/const.go b/controller/ws/ws/const.go index e33bd2e0..59e95839 100644 --- a/controller/ws/ws/const.go +++ b/controller/ws/ws/const.go @@ -3,18 +3,20 @@ package ws const actionKey = "action" const TaosKey = "taos" const ( + //Deprecated + //WSWriteRaw = "write_raw" + //WSWriteRawBlock = "write_raw_block" + //WSWriteRawBlockWithFields = "write_raw_block_with_fields" + Connect = "conn" // websocket - WSQuery = "query" - WSFetch = "fetch" - WSFetchBlock = "fetch_block" - WSFreeResult = "free_result" - WSWriteRaw = "write_raw" - WSWriteRawBlock = "write_raw_block" - WSWriteRawBlockWithFields = "write_raw_block_with_fields" - WSGetCurrentDB = "get_current_db" - WSGetServerInfo = "get_server_info" - WSNumFields = "num_fields" + WSQuery = "query" + WSFetch = "fetch" + WSFetchBlock = "fetch_block" + WSFreeResult = "free_result" + WSGetCurrentDB = "get_current_db" + WSGetServerInfo = "get_server_info" + WSNumFields = "num_fields" // schemaless SchemalessWrite = "insert" @@ -43,10 +45,8 @@ const ( STMT2Close = "stmt2_close" ) -type messageType uint64 - const ( - _ messageType = iota + _ = iota SetTagsMessage BindMessage TMQRawMessage @@ -57,8 +57,8 @@ const ( Stmt2BindMessage = 9 ) -func (m messageType) String() string { - switch m { +func getActionString(binaryAction uint64) string { + switch binaryAction { case SetTagsMessage: return "set_tags" case BindMessage: diff --git a/controller/ws/ws/fetch.go b/controller/ws/ws/fetch.go new file mode 100644 index 00000000..7bb80c31 --- /dev/null +++ b/controller/ws/ws/fetch.go @@ -0,0 +1,245 @@ +package ws + +import ( + "context" + "encoding/binary" + + "github.com/sirupsen/logrus" + "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/driver-go/v3/wrapper" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/db/async" + "github.com/taosdata/taosadapter/v3/log" + "github.com/taosdata/taosadapter/v3/tools/bytesutil" + "github.com/taosdata/taosadapter/v3/tools/melody" +) + +func (h *messageHandler) resultValidateAndLock(ctx context.Context, session *melody.Session, action string, reqID uint64, resultID uint64, logger *logrus.Entry) (item *QueryResult, locked bool) { + item = h.queryResults.Get(resultID) + if item == nil { + logger.Errorf("result is nil, result_id:%d", resultID) + commonErrorResponse(ctx, session, logger, action, reqID, 0xffff, "result is nil") + return nil, false + } + item.Lock() + if item.TaosResult == nil { + item.Unlock() + logger.Errorf("result has been freed, result_id:%d", resultID) + commonErrorResponse(ctx, session, logger, action, reqID, 0xffff, "result has been freed") + return nil, false + } + return item, true +} + +type fetchRequest struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} + +type fetchResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + ID uint64 `json:"id"` + Completed bool `json:"completed"` + Lengths []int `json:"lengths"` + Rows int `json:"rows"` +} + +func (h *messageHandler) fetch(ctx context.Context, session *melody.Session, action string, req *fetchRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("get result by id, id:%d", req.ID) + item := h.queryResults.Get(req.ID) + if item == nil { + logger.Errorf("result is nil") + commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, "result is nil") + return + } + item.Lock() + if item.TaosResult == nil { + item.Unlock() + logger.Errorf("result has been freed") + commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, "result has been freed") + return + } + s := log.GetLogNow(isDebug) + handler := async.GlobalAsync.HandlerPool.Get() + defer async.GlobalAsync.HandlerPool.Put(handler) + logger.Debugf("get handler, cost:%s", log.GetLogDuration(isDebug, s)) + s = log.GetLogNow(isDebug) + result := async.GlobalAsync.TaosFetchRawBlockA(item.TaosResult, logger, isDebug, handler) + logger.Debugf("fetch_raw_block_a, cost:%s", log.GetLogDuration(isDebug, s)) + if result.N == 0 { + item.Unlock() + logger.Trace("fetch raw block completed") + h.queryResults.FreeResultByID(req.ID, logger) + resp := fetchResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + ID: req.ID, + Completed: true, + } + wstool.WSWriteJson(session, logger, resp) + return + } + if result.N < 0 { + item.Unlock() + errStr := wrapper.TaosErrorStr(result.Res) + logger.Errorf("fetch raw block error, code:%d, message:%s", result.N, errStr) + h.queryResults.FreeResultByID(req.ID, logger) + commonErrorResponse(ctx, session, logger, action, req.ReqID, result.N, errStr) + return + } + s = log.GetLogNow(isDebug) + length := wrapper.FetchLengths(item.TaosResult, item.FieldsCount) + logger.Debugf("fetch_lengths result:%d, cost:%s", length, log.GetLogDuration(isDebug, s)) + s = log.GetLogNow(isDebug) + logger.Trace("get raw block") + item.Block = wrapper.TaosGetRawBlock(item.TaosResult) + logger.Debugf("get_raw_block result:%p, cost:%s", item.Block, log.GetLogDuration(isDebug, s)) + item.Size = result.N + item.Unlock() + resp := fetchResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + ID: req.ID, + Lengths: length, + Rows: result.N, + } + wstool.WSWriteJson(session, logger, resp) +} + +type fetchBlockRequest struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} + +func (h *messageHandler) fetchBlock(ctx context.Context, session *melody.Session, action string, req *fetchBlockRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("fetch block, id:%d", req.ID) + item, locked := h.resultValidateAndLock(ctx, session, action, req.ReqID, req.ID, logger) + if !locked { + return + } + defer item.Unlock() + if item.Block == nil { + logger.Trace("block is nil") + commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, "block is nil") + return + } + + blockLength := int(parser.RawBlockGetLength(item.Block)) + if blockLength <= 0 { + logger.Errorf("block length illegal:%d", blockLength) + commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, "block length illegal") + return + } + s := log.GetLogNow(isDebug) + if cap(item.buf) < blockLength+16 { + item.buf = make([]byte, 0, blockLength+16) + } + item.buf = item.buf[:blockLength+16] + binary.LittleEndian.PutUint64(item.buf, uint64(wstool.GetDuration(ctx))) + binary.LittleEndian.PutUint64(item.buf[8:], req.ID) + bytesutil.Copy(item.Block, item.buf, 16, blockLength) + logger.Debugf("handle binary content cost:%s", log.GetLogDuration(isDebug, s)) + wstool.WSWriteBinary(session, item.buf, logger) +} + +func (h *messageHandler) fetchRawBlock(ctx context.Context, session *melody.Session, reqID uint64, resultID uint64, message []byte, logger *logrus.Entry, isDebug bool) { + if len(message) < 26 { + logger.Errorf("message length is too short") + fetchRawBlockErrorResponse(session, logger, 0xffff, "message length is too short", reqID, resultID, uint64(wstool.GetDuration(ctx))) + return + } + v := binary.LittleEndian.Uint16(message[24:]) + if v != BinaryProtocolVersion1 { + logger.Errorf("unknown fetch raw block version:%d", v) + fetchRawBlockErrorResponse(session, logger, 0xffff, "unknown fetch raw block version", reqID, resultID, uint64(wstool.GetDuration(ctx))) + return + } + item := h.queryResults.Get(resultID) + logger.Tracef("fetch raw block, result_id:%d", resultID) + if item == nil { + logger.Errorf("result is nil, result_id:%d", resultID) + fetchRawBlockErrorResponse(session, logger, 0xffff, "result is nil", reqID, resultID, uint64(wstool.GetDuration(ctx))) + return + } + item.Lock() + if item.TaosResult == nil { + item.Unlock() + logger.Errorf("result has been freed, result_id:%d", resultID) + fetchRawBlockErrorResponse(session, logger, 0xffff, "result has been freed", reqID, resultID, uint64(wstool.GetDuration(ctx))) + return + } + s := log.GetLogNow(isDebug) + handler := async.GlobalAsync.HandlerPool.Get() + defer async.GlobalAsync.HandlerPool.Put(handler) + logger.Debugf("get handler cost:%s", log.GetLogDuration(isDebug, s)) + result := async.GlobalAsync.TaosFetchRawBlockA(item.TaosResult, logger, isDebug, handler) + if result.N == 0 { + item.Unlock() + logger.Trace("fetch raw block success") + h.queryResults.FreeResultByID(resultID, logger) + fetchRawBlockFinishResponse(session, logger, reqID, resultID, uint64(wstool.GetDuration(ctx))) + return + } + if result.N < 0 { + item.Unlock() + errStr := wrapper.TaosErrorStr(result.Res) + logger.Errorf("fetch raw block error:%d %s", result.N, errStr) + h.queryResults.FreeResultByID(resultID, logger) + fetchRawBlockErrorResponse(session, logger, result.N, errStr, reqID, resultID, uint64(wstool.GetDuration(ctx))) + return + } + logger.Trace("call taos_get_raw_block") + s = log.GetLogNow(isDebug) + item.Block = wrapper.TaosGetRawBlock(item.TaosResult) + logger.Debugf("get_raw_block cost:%s", log.GetLogDuration(isDebug, s)) + item.Size = result.N + s = log.GetLogNow(isDebug) + blockLength := int(parser.RawBlockGetLength(item.Block)) + if blockLength <= 0 { + item.Unlock() + logger.Errorf("block length illegal:%d", blockLength) + fetchRawBlockErrorResponse(session, logger, 0xffff, "block length illegal", reqID, resultID, uint64(wstool.GetDuration(ctx))) + return + } + item.buf = fetchRawBlockMessage(item.buf, reqID, resultID, uint64(wstool.GetDuration(ctx)), int32(blockLength), item.Block) + logger.Debugf("handle binary content cost:%s", log.GetLogDuration(isDebug, s)) + item.Unlock() + wstool.WSWriteBinary(session, item.buf, logger) +} + +type numFieldsRequest struct { + ReqID uint64 `json:"req_id"` + ResultID uint64 `json:"result_id"` +} + +type numFieldsResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + NumFields int `json:"num_fields"` +} + +func (h *messageHandler) numFields(ctx context.Context, session *melody.Session, action string, req *numFieldsRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("num fields, result_id:%d", req.ResultID) + item, locked := h.resultValidateAndLock(ctx, session, action, req.ReqID, req.ResultID, logger) + if !locked { + return + } + defer item.Unlock() + num := wrapper.TaosNumFields(item.TaosResult) + resp := &numFieldsResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + NumFields: num, + } + wstool.WSWriteJson(session, logger, resp) +} diff --git a/controller/ws/ws/fetch_test.go b/controller/ws/ws/fetch_test.go new file mode 100644 index 00000000..bb052399 --- /dev/null +++ b/controller/ws/ws/fetch_test.go @@ -0,0 +1,70 @@ +package ws + +import ( + "encoding/json" + "fmt" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" +) + +func TestNumFields(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + db := "test_ws_num_fields" + code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), db) + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), db) + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf("create stable if not exists %s.meters (ts timestamp,current float,voltage int,phase float) tags (groupid int,location varchar(24))", db), db) + assert.Equal(t, 0, code, message) + code, message = doRestful("INSERT INTO d1 USING meters TAGS (1, 'location1') VALUES (now, 10.2, 219, 0.31) "+ + "d2 USING meters TAGS (2, 'location2') VALUES (now, 10.3, 220, 0.32)", db) + assert.Equal(t, 0, code, message) + + defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // query + queryReq := queryRequest{ReqID: 2, Sql: "select * from meters"} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + var queryResp queryResponse + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), queryResp.ReqID) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // num fields + numFieldsReq := numFieldsRequest{ReqID: 3, ResultID: queryResp.ID} + resp, err = doWebSocket(ws, WSNumFields, &numFieldsReq) + assert.NoError(t, err) + var numFieldsResp numFieldsResponse + err = json.Unmarshal(resp, &numFieldsResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), numFieldsResp.ReqID) + assert.Equal(t, 0, numFieldsResp.Code, numFieldsResp.Message) + assert.Equal(t, 6, numFieldsResp.NumFields) +} diff --git a/controller/ws/ws/free.go b/controller/ws/ws/free.go new file mode 100644 index 00000000..baa13f48 --- /dev/null +++ b/controller/ws/ws/free.go @@ -0,0 +1,15 @@ +package ws + +import ( + "github.com/sirupsen/logrus" +) + +type freeResultRequest struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} + +func (h *messageHandler) freeResult(req *freeResultRequest, logger *logrus.Entry) { + logger.Tracef("free result by id, id:%d", req.ID) + h.queryResults.FreeResultByID(req.ID, logger) +} diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 2970d39f..cb7011bd 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -2,40 +2,25 @@ package ws import ( "context" - "database/sql/driver" - "encoding/binary" "encoding/json" - "errors" "fmt" "net" - "strconv" "sync" "time" "unsafe" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/parser" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" - errors2 "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/types" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/config" - "github.com/taosdata/taosadapter/v3/controller/ws/stmt" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" - "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" - "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" - "github.com/taosdata/taosadapter/v3/monitor" "github.com/taosdata/taosadapter/v3/tools" "github.com/taosdata/taosadapter/v3/tools/bytesutil" "github.com/taosdata/taosadapter/v3/tools/iptool" - "github.com/taosdata/taosadapter/v3/tools/jsontype" - "github.com/taosdata/taosadapter/v3/version" + "github.com/taosdata/taosadapter/v3/tools/melody" + "github.com/tidwall/gjson" ) type messageHandler struct { @@ -167,2168 +152,510 @@ type Request struct { Args json.RawMessage `json:"args"` } +func (h *messageHandler) stop() { + h.once.Do(func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + waitCh := make(chan struct{}, 1) + go func() { + h.wait.Wait() + close(waitCh) + }() + + select { + case <-ctx.Done(): + case <-waitCh: + } + // clean query result and stmt + h.queryResults.FreeAll(h.logger) + h.stmts.FreeAll(h.logger) + // clean connection + if h.conn != nil { + syncinterface.TaosClose(h.conn, h.logger, log.IsDebug()) + } + }) +} + func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) h.logger.Debugf("get ws message data:%s", data) - - var request Request - if err := json.Unmarshal(data, &request); err != nil { - h.logger.WithError(err).Errorln("unmarshal ws request") + jsonStr := bytesutil.ToUnsafeString(data) + action := gjson.Get(jsonStr, "action").String() + args := gjson.Get(jsonStr, "args") + if action == "" { + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, "", reqID, 0xffff, "request no action") return } + argsBytes := bytesutil.ToUnsafeBytes(args.Raw) - var f dealFunc - switch request.Action { + // no need connection actions + switch action { case wstool.ClientVersion: - f = h.handleVersion + wstool.WSWriteVersion(session, h.logger) + return case Connect: - f = h.handleConnect + action = Connect + var req connRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal connect request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, Connect, reqID, 0xffff, "unmarshal connect request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.connect(ctx, session, action, &req, logger, log.IsDebug()) + return + } + + // check connection + if h.conn == nil { + h.logger.Errorf("server not connected") + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "server not connected") + return + } + + // need connection actions + switch action { + // query case WSQuery: - f = h.handleQuery + action = WSQuery + var req queryRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal query request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal query request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.query(ctx, session, action, &req, logger, log.IsDebug()) case WSFetch: - f = h.handleFetch + action = WSFetch + var req fetchRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal fetch request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal fetch request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.fetch(ctx, session, action, &req, logger, log.IsDebug()) case WSFetchBlock: - f = h.handleFetchBlock + action = WSFetchBlock + var req fetchBlockRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal fetch block request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal fetch block request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.fetchBlock(ctx, session, action, &req, logger, log.IsDebug()) case WSFreeResult: - f = h.handleFreeResult + action = WSFreeResult + var req freeResultRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal free result request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal free result request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.freeResult(&req, logger) + case WSNumFields: + action = WSNumFields + var req numFieldsRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal num fields request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal num fields request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.numFields(ctx, session, action, &req, logger, log.IsDebug()) + // schemaless case SchemalessWrite: - f = h.handleSchemalessWrite + action = SchemalessWrite + var req schemalessWriteRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal schemaless insert request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal schemaless insert request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.schemalessWrite(ctx, session, action, &req, logger, log.IsDebug()) + // stmt case STMTInit: - f = h.handleStmtInit + action = STMTInit + var req stmtInitRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt init request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt init request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtInit(ctx, session, action, &req, logger, log.IsDebug()) case STMTPrepare: - f = h.handleStmtPrepare + action = STMTPrepare + var req stmtPrepareRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt prepare request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt prepare request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtPrepare(ctx, session, action, &req, logger, log.IsDebug()) case STMTSetTableName: - f = h.handleStmtSetTableName + action = STMTSetTableName + var req stmtSetTableNameRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt set table name request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt set table name request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtSetTableName(ctx, session, action, &req, logger, log.IsDebug()) case STMTSetTags: - f = h.handleStmtSetTags + action = STMTSetTags + var req stmtSetTagsRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt set tags request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt set tags request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtSetTags(ctx, session, action, &req, logger, log.IsDebug()) case STMTBind: - f = h.handleStmtBind + action = STMTBind + var req stmtBindRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt bind request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt bind request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtBind(ctx, session, action, &req, logger, log.IsDebug()) case STMTAddBatch: - f = h.handleStmtAddBatch + action = STMTAddBatch + var req stmtAddBatchRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt add batch request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt add batch request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtAddBatch(ctx, session, action, &req, logger, log.IsDebug()) case STMTExec: - f = h.handleStmtExec + action = STMTExec + var req stmtExecRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt exec request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt exec request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtExec(ctx, session, action, &req, logger, log.IsDebug()) case STMTClose: - f = h.handleStmtClose - case STMTGetColFields: - f = h.handleStmtGetColFields + action = STMTClose + var req stmtCloseRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt close request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt close request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtClose(ctx, session, action, &req, logger) case STMTGetTagFields: - f = h.handleStmtGetTagFields + action = STMTGetTagFields + var req stmtGetTagFieldsRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt get tag fields request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt get tag fields request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtGetTagFields(ctx, session, action, &req, logger, log.IsDebug()) + case STMTGetColFields: + action = STMTGetColFields + var req stmtGetColFieldsRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt get col fields request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt get col fields request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtGetColFields(ctx, session, action, &req, logger, log.IsDebug()) case STMTUseResult: - f = h.handleStmtUseResult + action = STMTUseResult + var req stmtUseResultRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt use result request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt use result request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtUseResult(ctx, session, action, &req, logger, log.IsDebug()) case STMTNumParams: - f = h.handleStmtNumParams + action = STMTNumParams + var req stmtNumParamsRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt num params request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt num params request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtNumParams(ctx, session, action, &req, logger, log.IsDebug()) case STMTGetParam: - f = h.handleStmtGetParam - case WSNumFields: - f = h.handleNumFields - case WSGetCurrentDB: - f = h.handleGetCurrentDB - case WSGetServerInfo: - f = h.handleGetServerInfo + action = STMTGetParam + var req stmtGetParamRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt get param request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt get param request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmtGetParam(ctx, session, action, &req, logger, log.IsDebug()) + // stmt2 case STMT2Init: - f = h.handleStmt2Init + action = STMT2Init + var req stmt2InitRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 init request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 init request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmt2Init(ctx, session, action, &req, logger, log.IsDebug()) case STMT2Prepare: - f = h.handleStmt2Prepare + action = STMT2Prepare + var req stmt2PrepareRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 prepare request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 prepare request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmt2Prepare(ctx, session, action, &req, logger, log.IsDebug()) case STMT2GetFields: - f = h.handleStmt2GetFields + action = STMT2GetFields + var req stmt2GetFieldsRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 get fields request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 get fields request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmt2GetFields(ctx, session, action, &req, logger, log.IsDebug()) case STMT2Exec: - f = h.handleStmt2Exec + action = STMT2Exec + var req stmt2ExecRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 exec request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 exec request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmt2Exec(ctx, session, action, &req, logger, log.IsDebug()) case STMT2Result: - f = h.handleStmt2UseResult + action = STMT2Result + var req stmt2UseResultRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 result request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 result request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmt2UseResult(ctx, session, action, &req, logger, log.IsDebug()) case STMT2Close: - f = h.handleStmt2Close + action = STMT2Close + var req stmt2CloseRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 close request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 close request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.stmt2Close(ctx, session, action, &req, logger) + // misc + case WSGetCurrentDB: + action = WSGetCurrentDB + var req getCurrentDBRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal get current db request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal get current db request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.getCurrentDB(ctx, session, action, &req, logger, log.IsDebug()) + case WSGetServerInfo: + action = WSGetServerInfo + var req getServerInfoRequest + if err := json.Unmarshal(argsBytes, &req); err != nil { + h.logger.Errorf("unmarshal get server info request error, request:%s, err:%s", argsBytes, err) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal get server info request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.getServerInfo(ctx, session, action, &req, logger, log.IsDebug()) default: - f = h.handleDefault + h.logger.Errorf("unknown action %s", action) + reqID := getReqID(args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, fmt.Sprintf("unknown action %s", action)) } - h.deal(ctx, session, request, f) } -func (h *messageHandler) handleMessageBinary(session *melody.Session, bytes []byte) { +func (h *messageHandler) handleMessageBinary(session *melody.Session, message []byte) { //p0 uin64 req_id - //p0+8 uint64 message_id + //p0+8 uint64 resource_id(result_id or stmt_id) //p0+16 uint64 (1 (set tag) 2 (bind)) - h.logger.Tracef("get ws block message data:%+v", bytes) - p0 := unsafe.Pointer(&bytes[0]) + h.logger.Tracef("get ws block message data:%+v", message) + p0 := unsafe.Pointer(&message[0]) reqID := *(*uint64)(p0) - messageID := *(*uint64)(tools.AddPointer(p0, uintptr(8))) + resourceID := *(*uint64)(tools.AddPointer(p0, uintptr(8))) action := *(*uint64)(tools.AddPointer(p0, uintptr(16))) - h.logger.Debugf("get ws message binary QID:0x%x, messageID:%d, action:%d", reqID, messageID, action) + h.logger.Debugf("get ws message binary QID:0x%x, resourceID:%d, action:%d", reqID, resourceID, action) ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) - mt := messageType(action) + actionStr := getActionString(action) + logger := h.logger.WithField(actionKey, actionStr).WithField(config.ReqIDKey, reqID) - var f dealBinaryFunc - switch mt { + // check error connection + if h.conn == nil { + logger.Errorf("server not connected") + commonErrorResponse(ctx, session, h.logger, actionStr, reqID, 0xffff, "server not connected") + return + } + switch action { case SetTagsMessage: - f = h.handleSetTagsMessage + h.stmtBinarySetTags(ctx, session, actionStr, reqID, resourceID, message, logger, log.IsDebug()) case BindMessage: - f = h.handleBindMessage + h.stmtBinaryBind(ctx, session, actionStr, reqID, resourceID, message, logger, log.IsDebug()) case TMQRawMessage: - f = h.handleTMQRawMessage + h.binaryTMQRawMessage(ctx, session, actionStr, reqID, message, logger, log.IsDebug()) case RawBlockMessage: - f = h.handleRawBlockMessage + h.binaryRawBlockMessage(ctx, session, actionStr, reqID, message, logger, log.IsDebug()) case RawBlockMessageWithFields: - f = h.handleRawBlockMessageWithFields + h.binaryRawBlockMessageWithFields(ctx, session, actionStr, reqID, message, logger, log.IsDebug()) case BinaryQueryMessage: - f = h.handleBinaryQuery + h.binaryQuery(ctx, session, actionStr, reqID, message, logger, log.IsDebug()) case FetchRawBlockMessage: - f = h.handleFetchRawBlock + h.fetchRawBlock(ctx, session, reqID, resourceID, message, logger, log.IsDebug()) case Stmt2BindMessage: - f = h.handleStmt2Bind + h.stmt2BinaryBind(ctx, session, actionStr, reqID, resourceID, message, logger, log.IsDebug()) default: - f = h.handleDefaultBinary - } - h.dealBinary(ctx, session, mt, reqID, messageID, p0, bytes, f) -} - -type RequestID struct { - ReqID uint64 `json:"req_id"` -} - -type dealFunc func(context.Context, Request, *logrus.Entry, bool) Response - -type dealBinaryRequest struct { - action messageType - reqID uint64 - id uint64 // messageID or stmtID - p0 unsafe.Pointer - message []byte -} -type dealBinaryFunc func(context.Context, dealBinaryRequest, *logrus.Entry, bool) Response - -func (h *messageHandler) deal(ctx context.Context, session *melody.Session, request Request, f dealFunc) { - h.wait.Add(1) - go func() { - defer h.wait.Done() - isDebug := log.IsDebug() - reqID := request.ReqID - if reqID == 0 { - var req RequestID - _ = json.Unmarshal(request.Args, &req) - reqID = req.ReqID - } - request.ReqID = reqID - - logger := h.logger.WithFields(logrus.Fields{ - actionKey: request.Action, - config.ReqIDKey: reqID, - }) - - if h.conn == nil && request.Action != Connect && request.Action != wstool.ClientVersion { - logger.Errorf("server not connected") - resp := wsCommonErrorMsg(0xffff, "server not connected") - h.writeResponse(ctx, session, resp, request.Action, request.ReqID, logger) - return - } - - resp := f(ctx, request, logger, isDebug) - h.writeResponse(ctx, session, resp, request.Action, reqID, logger) - }() -} - -func (h *messageHandler) dealBinary(ctx context.Context, session *melody.Session, action messageType, reqID uint64, messageID uint64, p0 unsafe.Pointer, message []byte, f dealBinaryFunc) { - h.wait.Add(1) - go func() { - defer h.wait.Done() - - logger := h.logger.WithField(actionKey, action.String()).WithField(config.ReqIDKey, reqID) - isDebug := log.IsDebug() - if h.conn == nil { - resp := wsCommonErrorMsg(0xffff, "server not connected") - h.writeResponse(ctx, session, resp, action.String(), reqID, logger) - return - } - - req := dealBinaryRequest{ - action: action, - reqID: reqID, - id: messageID, - p0: p0, - message: message, - } - resp := f(ctx, req, logger, isDebug) - h.writeResponse(ctx, session, resp, action.String(), reqID, logger) - }() -} - -type BaseResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Action string `json:"action"` - ReqID uint64 `json:"req_id"` - Timing int64 `json:"timing"` - binary bool - null bool -} - -func (h *messageHandler) writeResponse(ctx context.Context, session *melody.Session, response Response, action string, reqID uint64, logger *logrus.Entry) { - if response == nil { - logger.Trace("response is nil") - // session closed handle return nil - return - } - if response.IsNull() { - logger.Trace("no need to response") - return - } - if response.IsBinary() { - logger.Tracef("write binary response:%v", response) - _ = session.WriteBinary(response.(*BinaryResponse).Data) - return - } - response.SetAction(action) - response.SetReqID(reqID) - response.SetTiming(wstool.GetDuration(ctx)) - - respByte, _ := json.Marshal(response) - logger.Tracef("write json response:%s", respByte) - _ = session.Write(respByte) -} - -func (h *messageHandler) stop() { - h.once.Do(func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - waitCh := make(chan struct{}, 1) - go func() { - h.wait.Wait() - close(waitCh) - }() - - select { - case <-ctx.Done(): - case <-waitCh: - } - // clean query result and stmt - h.queryResults.FreeAll(h.logger) - h.stmts.FreeAll(h.logger) - // clean connection - if h.conn != nil { - syncinterface.TaosClose(h.conn, h.logger, log.IsDebug()) - } - }) -} - -func (h *messageHandler) handleDefault(_ context.Context, request Request, _ *logrus.Entry, _ bool) (resp Response) { - return wsCommonErrorMsg(0xffff, fmt.Sprintf("unknown action %s", request.Action)) -} - -func (h *messageHandler) handleDefaultBinary(_ context.Context, req dealBinaryRequest, _ *logrus.Entry, _ bool) (resp Response) { - return wsCommonErrorMsg(0xffff, fmt.Sprintf("unknown action %v", req.action)) -} - -func (h *messageHandler) handleVersion(_ context.Context, _ Request, _ *logrus.Entry, _ bool) (resp Response) { - return &VersionResponse{Version: version.TaosClientVersion} -} - -type ConnRequest struct { - ReqID uint64 `json:"req_id"` - User string `json:"user"` - Password string `json:"password"` - DB string `json:"db"` - Mode *int `json:"mode"` -} - -func (h *messageHandler) handleConnect(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req ConnRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal connect request:%s, error, err:%s", string(request.Args), err) - return wsCommonErrorMsg(0xffff, "unmarshal connect request error") - } - - h.lock(logger, isDebug) - defer h.Unlock() - if h.closed { - logger.Trace("server closed") - return - } - if h.conn != nil { - logger.Trace("duplicate connections") - return wsCommonErrorMsg(0xffff, "duplicate connections") - } - - conn, err := syncinterface.TaosConnect("", req.User, req.Password, req.DB, 0, logger, isDebug) - - if err != nil { - logger.WithError(err).Errorln("connect to TDengine error") - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - return wsCommonErrorMsg(int(taosErr.Code), taosErr.ErrStr) - } - logger.Trace("get whitelist") - s := log.GetLogNow(isDebug) - whitelist, err := tool.GetWhitelist(conn) - logger.Debugf("get whitelist cost:%s", log.GetLogDuration(isDebug, s)) - if err != nil { - logger.WithError(err).Errorln("get whitelist error") - syncinterface.TaosClose(conn, logger, isDebug) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - return wsCommonErrorMsg(int(taosErr.Code), taosErr.ErrStr) - } - logger.Tracef("check whitelist, ip:%s, whitelist:%s", h.ipStr, tool.IpNetSliceToString(whitelist)) - valid := tool.CheckWhitelist(whitelist, h.ip) - if !valid { - logger.Errorf("ip not in whitelist, ip:%s, whitelist:%s", h.ipStr, tool.IpNetSliceToString(whitelist)) - syncinterface.TaosClose(conn, logger, isDebug) - return wsCommonErrorMsg(0xffff, "whitelist prohibits current IP access") - } - s = log.GetLogNow(isDebug) - logger.Trace("register whitelist change") - err = tool.RegisterChangeWhitelist(conn, h.whitelistChangeHandle) - logger.Debugf("register whitelist change cost:%s", log.GetLogDuration(isDebug, s)) - if err != nil { - logger.WithError(err).Errorln("register whitelist change error") - syncinterface.TaosClose(conn, logger, isDebug) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - return wsCommonErrorMsg(int(taosErr.Code), taosErr.ErrStr) + h.logger.Errorf("unknown binary action %d", action) + commonErrorResponse(ctx, session, h.logger, actionStr, reqID, 0xffff, fmt.Sprintf("unknown binary action %d", action)) } - s = log.GetLogNow(isDebug) - logger.Trace("register drop user") - err = tool.RegisterDropUser(conn, h.dropUserHandle) - logger.Debugf("register drop user cost:%s", log.GetLogDuration(isDebug, s)) - if err != nil { - logger.WithError(err).Errorln("register drop user error") - syncinterface.TaosClose(conn, logger, isDebug) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - return wsCommonErrorMsg(int(taosErr.Code), taosErr.ErrStr) - } - if req.Mode != nil { - switch *req.Mode { - case common.TAOS_CONN_MODE_BI: - // BI mode - logger.Trace("set connection mode to BI") - code := wrapper.TaosSetConnMode(conn, common.TAOS_CONN_MODE_BI, 1) - logger.Trace("set connection mode to BI done") - if code != 0 { - logger.Errorf("set connection mode to BI error, err:%s", wrapper.TaosErrorStr(nil)) - syncinterface.TaosClose(conn, logger, isDebug) - return wsCommonErrorMsg(code, wrapper.TaosErrorStr(nil)) - } - default: - syncinterface.TaosClose(conn, logger, isDebug) - logger.Tracef("unexpected mode:%d", *req.Mode) - return wsCommonErrorMsg(0xffff, fmt.Sprintf("unexpected mode:%d", *req.Mode)) - } - } - h.conn = conn - logger.Trace("start wait signal goroutine") - go h.waitSignal(h.logger) - return &BaseResponse{} -} - -type QueryRequest struct { - ReqID uint64 `json:"req_id"` - Sql string `json:"sql"` -} - -type QueryResponse struct { - BaseResponse - ID uint64 `json:"id"` - IsUpdate bool `json:"is_update"` - AffectedRows int `json:"affected_rows"` - FieldsCount int `json:"fields_count"` - FieldsNames []string `json:"fields_names"` - FieldsTypes jsontype.JsonUint8 `json:"fields_types"` - FieldsLengths []int64 `json:"fields_lengths"` - Precision int `json:"precision"` } -func (h *messageHandler) handleQuery(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req QueryRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal ws query request %s error, err:%s", request.Args, err) - return wsCommonErrorMsg(0xffff, "unmarshal ws query request error") - } - sqlType := monitor.WSRecordRequest(req.Sql) - logger.Debugf("get query request, sql:%s", req.Sql) - s := log.GetLogNow(isDebug) - handler := async.GlobalAsync.HandlerPool.Get() - defer async.GlobalAsync.HandlerPool.Put(handler) - logger.Debugf("get handler cost:%s", log.GetLogDuration(isDebug, s)) - result := async.GlobalAsync.TaosQuery(h.conn, logger, isDebug, req.Sql, handler, int64(request.ReqID)) - code := wrapper.TaosError(result.Res) - if code != httperror.SUCCESS { - monitor.WSRecordResult(sqlType, false) - errStr := wrapper.TaosErrorStr(result.Res) - logger.Errorf("query error, code:%d, message:%s", code, errStr) - syncinterface.FreeResult(result.Res, logger, isDebug) - return wsCommonErrorMsg(code, errStr) - } - - monitor.WSRecordResult(sqlType, true) - logger.Trace("check is_update_query") - s = log.GetLogNow(isDebug) - isUpdate := wrapper.TaosIsUpdateQuery(result.Res) - logger.Debugf("get is_update_query %t, cost:%s", isUpdate, log.GetLogDuration(isDebug, s)) - if isUpdate { - s = log.GetLogNow(isDebug) - affectRows := wrapper.TaosAffectedRows(result.Res) - logger.Debugf("affected_rows %d cost:%s", affectRows, log.GetLogDuration(isDebug, s)) - syncinterface.FreeResult(result.Res, logger, isDebug) - return &QueryResponse{IsUpdate: true, AffectedRows: affectRows} - } - s = log.GetLogNow(isDebug) - fieldsCount := wrapper.TaosNumFields(result.Res) - logger.Debugf("get num_fields:%d, cost:%s", fieldsCount, log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) - rowsHeader, _ := wrapper.ReadColumn(result.Res, fieldsCount) - logger.Debugf("read column cost:%s", log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) - precision := wrapper.TaosResultPrecision(result.Res) - logger.Debugf("get result_precision:%d, cost:%s", precision, log.GetLogDuration(isDebug, s)) - queryResult := QueryResult{TaosResult: result.Res, FieldsCount: fieldsCount, Header: rowsHeader, precision: precision} - idx := h.queryResults.Add(&queryResult) - logger.Trace("add result to list finished") - - return &QueryResponse{ - ID: idx, - FieldsCount: fieldsCount, - FieldsNames: rowsHeader.ColNames, - FieldsLengths: rowsHeader.ColLength, - FieldsTypes: rowsHeader.ColTypes, - Precision: precision, - } -} - -type FetchRequest struct { - ReqID uint64 `json:"req_id"` - ID uint64 `json:"id"` -} - -type FetchResponse struct { - BaseResponse - ID uint64 `json:"id"` - Completed bool `json:"completed"` - Lengths []int `json:"lengths"` - Rows int `json:"rows"` -} - -func (h *messageHandler) handleFetch(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req FetchRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal ws fetch request %s error, err:%s", request.Args, err) - return wsCommonErrorMsg(0xffff, "unmarshal ws fetch request error") - } - - logger.Tracef("get result by id, id:%d", req.ID) - item := h.queryResults.Get(req.ID) - if item == nil { - logger.Errorf("result is nil") - return wsCommonErrorMsg(0xffff, "result is nil") - } - item.Lock() - if item.TaosResult == nil { - item.Unlock() - logger.Errorf("result has been freed") - return wsCommonErrorMsg(0xffff, "result has been freed") - } - s := log.GetLogNow(isDebug) - handler := async.GlobalAsync.HandlerPool.Get() - defer async.GlobalAsync.HandlerPool.Put(handler) - logger.Debugf("get handler, cost:%s", log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) - result := async.GlobalAsync.TaosFetchRawBlockA(item.TaosResult, logger, isDebug, handler) - logger.Debugf("fetch_raw_block_a, cost:%s", log.GetLogDuration(isDebug, s)) - if result.N == 0 { - logger.Trace("fetch raw block completed") - item.Unlock() - h.queryResults.FreeResultByID(req.ID, logger) - return &FetchResponse{ID: req.ID, Completed: true} - } - if result.N < 0 { - item.Unlock() - errStr := wrapper.TaosErrorStr(result.Res) - logger.Errorf("fetch raw block error, code:%d, message:%s", result.N, errStr) - h.queryResults.FreeResultByID(req.ID, logger) - return wsCommonErrorMsg(0xffff, errStr) - } - s = log.GetLogNow(isDebug) - length := wrapper.FetchLengths(item.TaosResult, item.FieldsCount) - logger.Debugf("fetch_lengths result:%d, cost:%s", length, log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) - logger.Trace("get raw block") - item.Block = wrapper.TaosGetRawBlock(item.TaosResult) - logger.Debugf("get_raw_block result:%p, cost:%s", item.Block, log.GetLogDuration(isDebug, s)) - item.Size = result.N - item.Unlock() - return &FetchResponse{ID: req.ID, Lengths: length, Rows: result.N} -} - -type FetchBlockRequest struct { - ReqID uint64 `json:"req_id"` - ID uint64 `json:"id"` -} - -func (h *messageHandler) handleFetchBlock(ctx context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req FetchBlockRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal ws fetch block request, req:%s, error, err:%s", request.Args, err) - return wsCommonErrorMsg(0xffff, "unmarshal ws fetch block request error") - } - - item := h.queryResults.Get(req.ID) - if item == nil { - logger.Errorf("result is nil") - return wsCommonErrorMsg(0xffff, "result is nil") - } - item.Lock() - defer item.Unlock() - if item.TaosResult == nil { - logger.Trace("result has been freed") - return wsCommonErrorMsg(0xffff, "result has been freed") - } - if item.Block == nil { - logger.Trace("block is nil") - return wsCommonErrorMsg(0xffff, "block is nil") - } - - blockLength := int(parser.RawBlockGetLength(item.Block)) - if blockLength <= 0 { - return wsCommonErrorMsg(0xffff, "block length illegal") - } - s := log.GetLogNow(isDebug) - if cap(item.buf) < blockLength+16 { - item.buf = make([]byte, 0, blockLength+16) - } - item.buf = item.buf[:blockLength+16] - binary.LittleEndian.PutUint64(item.buf, uint64(wstool.GetDuration(ctx))) - binary.LittleEndian.PutUint64(item.buf[8:], req.ID) - bytesutil.Copy(item.Block, item.buf, 16, blockLength) - logger.Debugf("handle binary content cost:%s", log.GetLogDuration(isDebug, s)) - resp = &BinaryResponse{Data: item.buf} - resp.SetBinary(true) - return resp -} - -type FreeResultRequest struct { - ReqID uint64 `json:"req_id"` - ID uint64 `json:"id"` -} - -func (h *messageHandler) handleFreeResult(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { - var req FreeResultRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal ws fetch request %s error, err:%s", request.Args, err) - return wsCommonErrorMsg(0xffff, "unmarshal connect request error") - } - logger.Tracef("free result by id, id:%d", req.ID) - h.queryResults.FreeResultByID(req.ID, logger) - resp = &BaseResponse{} - resp.SetNull(true) - return resp -} - -type SchemalessWriteRequest struct { - ReqID uint64 `json:"req_id"` - Protocol int `json:"protocol"` - Precision string `json:"precision"` - TTL int `json:"ttl"` - Data string `json:"data"` - TableNameKey string `json:"table_name_key"` -} - -type SchemalessWriteResponse struct { - BaseResponse - AffectedRows int `json:"affected_rows"` - TotalRows int32 `json:"total_rows"` -} - -func (h *messageHandler) handleSchemalessWrite(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req SchemalessWriteRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal schemaless write request %s error, err:%s", request.Args, err) - return wsCommonErrorMsg(0xffff, "unmarshal schemaless write request error") - } - - if req.Protocol == 0 { - logger.Errorf("schemaless write request %s args error. protocol is null", request.Args) - return wsCommonErrorMsg(0xffff, "args error") - } - var totalRows int32 - var affectedRows int - totalRows, result := syncinterface.TaosSchemalessInsertRawTTLWithReqIDTBNameKey(h.conn, req.Data, req.Protocol, req.Precision, req.TTL, int64(request.ReqID), req.TableNameKey, logger, isDebug) - logger.Tracef("total_rows:%d, result:%p", totalRows, result) - defer syncinterface.FreeResult(result, logger, isDebug) - affectedRows = wrapper.TaosAffectedRows(result) - if code := wrapper.TaosError(result); code != 0 { - logger.Errorf("schemaless write error, err:%s", wrapper.TaosErrorStr(result)) - return wsCommonErrorMsg(code, wrapper.TaosErrorStr(result)) - } - logger.Tracef("schemaless write total rows:%d, affected rows:%d", totalRows, affectedRows) - return &SchemalessWriteResponse{ - TotalRows: totalRows, - AffectedRows: affectedRows, - } -} - -type StmtInitResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` -} - -func (h *messageHandler) handleStmtInit(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - stmtInit := syncinterface.TaosStmtInitWithReqID(h.conn, int64(request.ReqID), logger, isDebug) - if stmtInit == nil { - errStr := wrapper.TaosStmtErrStr(stmtInit) - logger.Errorf("stmt init error, err:%s", errStr) - return wsCommonErrorMsg(0xffff, errStr) - } - stmtItem := &StmtItem{stmt: stmtInit} - h.stmts.Add(stmtItem) - logger.Tracef("stmt init sucess, stmt_id:%d, stmt pointer:%p", stmtItem.index, stmtInit) - return &StmtInitResponse{StmtID: stmtItem.index} -} - -type StmtPrepareRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` - SQL string `json:"sql"` -} - -type StmtPrepareResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - IsInsert bool `json:"is_insert"` -} - -func (h *messageHandler) handleStmtPrepare(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req StmtPrepareRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt prepare request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal connect request error", req.StmtID) - } - logger.Debugf("stmt prepare, stmt_id:%d, sql:%s", req.StmtID, req.SQL) - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - s := log.GetLogNow(isDebug) - logger.Trace("get stmt lock") - stmtItem.Lock() - logger.Debugf("get stmt lock cost:%s", log.GetLogDuration(isDebug, s)) - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - code := syncinterface.TaosStmtPrepare(stmtItem.stmt, req.SQL, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt prepare error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - logger.Tracef("stmt prepare success, stmt_id:%d", req.StmtID) - isInsert, code := syncinterface.TaosStmtIsInsert(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("check stmt is insert error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - logger.Tracef("stmt is insert:%t", isInsert) - stmtItem.isInsert = isInsert - return &StmtPrepareResponse{StmtID: req.StmtID, IsInsert: isInsert} -} - -type StmtSetTableNameRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` - Name string `json:"name"` -} - -type StmtSetTableNameResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` -} - -func (h *messageHandler) handleStmtSetTableName(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req StmtSetTableNameRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt set table name request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt set table name request error", req.StmtID) - } - - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - code := syncinterface.TaosStmtSetTBName(stmtItem.stmt, req.Name, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt set table name error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - logger.Tracef("stmt set table name success, stmt_id:%d", req.StmtID) - return &StmtSetTableNameResponse{StmtID: req.StmtID} -} - -type StmtSetTagsRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` - Tags json.RawMessage `json:"tags"` -} - -type StmtSetTagsResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` -} - -func (h *messageHandler) handleStmtSetTags(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req StmtSetTagsRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt set tags request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt set tags request error", req.StmtID) - } - logger.Tracef("stmt set tags, stmt_id:%d, tags:%s", req.StmtID, req.Tags) - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - code, tagNums, tagFields := syncinterface.TaosStmtGetTagFields(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt get tag fields error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - defer func() { - wrapper.TaosStmtReclaimFields(stmtItem.stmt, tagFields) - }() - logger.Tracef("stmt tag nums:%d", tagNums) - if tagNums == 0 { - logger.Trace("no tags") - return &StmtSetTagsResponse{StmtID: req.StmtID} - } - s := log.GetLogNow(isDebug) - fields := wrapper.StmtParseFields(tagNums, tagFields) - logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) - tags := make([][]driver.Value, tagNums) - for i := 0; i < tagNums; i++ { - tags[i] = []driver.Value{req.Tags[i]} - } - data, err := stmt.StmtParseTag(req.Tags, fields) - logger.Debugf("stmt parse tag json cost:%s", log.GetLogDuration(isDebug, s)) - if err != nil { - logger.Errorf("stmt parse tag json error, err:%s", err.Error()) - return wsStmtErrorMsg(0xffff, fmt.Sprintf("stmt parse tag json:%s", err.Error()), req.StmtID) - } - code = syncinterface.TaosStmtSetTags(stmtItem.stmt, data, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt set tags error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - logger.Trace("stmt set tags success") - return &StmtSetTagsResponse{StmtID: req.StmtID} -} - -type StmtBindRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` - Columns json.RawMessage `json:"columns"` -} - -type StmtBindResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` -} - -func (h *messageHandler) handleStmtBind(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req StmtBindRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt bind tag request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt bind request error", req.StmtID) - } - - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - code, colNums, colFields := syncinterface.TaosStmtGetColFields(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt get col fields error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - defer func() { - wrapper.TaosStmtReclaimFields(stmtItem.stmt, colFields) - }() - if colNums == 0 { - logger.Trace("no columns") - return &StmtBindResponse{StmtID: req.StmtID} - } - s := log.GetLogNow(isDebug) - fields := wrapper.StmtParseFields(colNums, colFields) - logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) - fieldTypes := make([]*types.ColumnType, colNums) - - var err error - for i := 0; i < colNums; i++ { - if fieldTypes[i], err = fields[i].GetType(); err != nil { - logger.Errorf("stmt get column type error, err:%s", err.Error()) - return wsStmtErrorMsg(0xffff, fmt.Sprintf("stmt get column type error, err:%s", err.Error()), req.StmtID) - } - } - s = log.GetLogNow(isDebug) - data, err := stmt.StmtParseColumn(req.Columns, fields, fieldTypes) - logger.Debugf("stmt parse column json cost:%s", log.GetLogDuration(isDebug, s)) - if err != nil { - logger.Errorf("stmt parse column json error, err:%s", err.Error()) - return wsStmtErrorMsg(0xffff, fmt.Sprintf("stmt parse column json:%s", err.Error()), req.StmtID) - } - code = syncinterface.TaosStmtBindParamBatch(stmtItem.stmt, data, fieldTypes, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt bind param error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - logger.Trace("stmt bind success") - return &StmtBindResponse{StmtID: req.StmtID} -} - -func (h *messageHandler) handleBindMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) (resp Response) { - block := tools.AddPointer(req.p0, uintptr(24)) - columns := parser.RawBlockGetNumOfCols(block) - rows := parser.RawBlockGetNumOfRows(block) - logger.Tracef("bind message, stmt_id:%d columns:%d, rows:%d", req.id, columns, rows) - stmtItem := h.stmts.Get(req.id) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.id) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.id) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.id) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.id) - } - var data [][]driver.Value - var fieldTypes []*types.ColumnType - if stmtItem.isInsert { - code, colNums, colFields := syncinterface.TaosStmtGetColFields(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt get col fields error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.id) - } - defer func() { - wrapper.TaosStmtReclaimFields(stmtItem.stmt, colFields) - }() - if colNums == 0 { - logger.Trace("no columns") - return &StmtBindResponse{StmtID: req.id} - } - s := log.GetLogNow(isDebug) - fields := wrapper.StmtParseFields(colNums, colFields) - logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) - fieldTypes = make([]*types.ColumnType, colNums) - var err error - for i := 0; i < colNums; i++ { - fieldTypes[i], err = fields[i].GetType() - if err != nil { - logger.Errorf("stmt get column type error, err:%s", err.Error()) - return wsStmtErrorMsg(0xffff, fmt.Sprintf("stmt get column type error, err:%s", err.Error()), req.id) - } - } - if int(columns) != colNums { - logger.Errorf("stmt column count not match %d != %d", columns, colNums) - return wsStmtErrorMsg(0xffff, "stmt column count not match", req.id) - } - s = log.GetLogNow(isDebug) - data = stmt.BlockConvert(block, int(rows), fields, fieldTypes) - logger.Debugf("block convert cost:%s", log.GetLogDuration(isDebug, s)) - } else { - var fields []*stmtCommon.StmtField - var err error - logger.Trace("parse row block info") - fields, fieldTypes, err = parseRowBlockInfo(block, int(columns)) - if err != nil { - logger.Errorf("parse row block info error, err:%s", err.Error()) - return wsStmtErrorMsg(0xffff, fmt.Sprintf("parse row block info error, err:%s", err.Error()), req.id) - } - logger.Trace("convert block to data") - data = stmt.BlockConvert(block, int(rows), fields, fieldTypes) - logger.Trace("convert block to data finish") - } - - code := syncinterface.TaosStmtBindParamBatch(stmtItem.stmt, data, fieldTypes, logger, isDebug) - if code != 0 { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt bind param error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.id) - } - logger.Trace("stmt bind param success") - return &StmtBindResponse{StmtID: req.id} -} - -func parseRowBlockInfo(block unsafe.Pointer, columns int) (fields []*stmtCommon.StmtField, fieldTypes []*types.ColumnType, err error) { - infos := make([]parser.RawBlockColInfo, columns) - parser.RawBlockGetColInfo(block, infos) - - fields = make([]*stmtCommon.StmtField, len(infos)) - fieldTypes = make([]*types.ColumnType, len(infos)) - - for i, info := range infos { - switch info.ColType { - case common.TSDB_DATA_TYPE_BOOL: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_BOOL} - fieldTypes[i] = &types.ColumnType{Type: types.TaosBoolType} - case common.TSDB_DATA_TYPE_TINYINT: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_TINYINT} - fieldTypes[i] = &types.ColumnType{Type: types.TaosTinyintType} - case common.TSDB_DATA_TYPE_SMALLINT: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_SMALLINT} - fieldTypes[i] = &types.ColumnType{Type: types.TaosSmallintType} - case common.TSDB_DATA_TYPE_INT: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_INT} - fieldTypes[i] = &types.ColumnType{Type: types.TaosIntType} - case common.TSDB_DATA_TYPE_BIGINT: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_BIGINT} - fieldTypes[i] = &types.ColumnType{Type: types.TaosBigintType} - case common.TSDB_DATA_TYPE_FLOAT: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_FLOAT} - fieldTypes[i] = &types.ColumnType{Type: types.TaosFloatType} - case common.TSDB_DATA_TYPE_DOUBLE: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_DOUBLE} - fieldTypes[i] = &types.ColumnType{Type: types.TaosDoubleType} - case common.TSDB_DATA_TYPE_BINARY: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_BINARY} - fieldTypes[i] = &types.ColumnType{Type: types.TaosBinaryType} - //case common.TSDB_DATA_TYPE_TIMESTAMP:// todo precision - // fields[i] = &stmtCommon.StmtField{FieldType:common.TSDB_DATA_TYPE_TIMESTAMP} - // fieldTypes[i] = &types.ColumnType{Type:types.TaosTimestampType} - case common.TSDB_DATA_TYPE_NCHAR: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_NCHAR} - fieldTypes[i] = &types.ColumnType{Type: types.TaosNcharType} - case common.TSDB_DATA_TYPE_UTINYINT: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_UTINYINT} - fieldTypes[i] = &types.ColumnType{Type: types.TaosUTinyintType} - case common.TSDB_DATA_TYPE_USMALLINT: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_USMALLINT} - fieldTypes[i] = &types.ColumnType{Type: types.TaosUSmallintType} - case common.TSDB_DATA_TYPE_UINT: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_UINT} - fieldTypes[i] = &types.ColumnType{Type: types.TaosUIntType} - case common.TSDB_DATA_TYPE_UBIGINT: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_UBIGINT} - fieldTypes[i] = &types.ColumnType{Type: types.TaosUBigintType} - case common.TSDB_DATA_TYPE_JSON: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_JSON} - fieldTypes[i] = &types.ColumnType{Type: types.TaosJsonType} - case common.TSDB_DATA_TYPE_VARBINARY: - fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_VARBINARY} - fieldTypes[i] = &types.ColumnType{Type: types.TaosBinaryType} - default: - err = fmt.Errorf("unsupported data type %d", info.ColType) - } - } - - return -} - -type StmtAddBatchRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` -} - -type StmtAddBatchResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` -} - -func (h *messageHandler) handleStmtAddBatch(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req StmtAddBatchRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt add batch request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt add batch request error", req.StmtID) - } - - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - code := syncinterface.TaosStmtAddBatch(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt add batch error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - logger.Trace("stmt add batch success") - return &StmtAddBatchResponse{StmtID: req.StmtID} -} - -type StmtExecRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` -} - -type StmtExecResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - Affected int `json:"affected"` -} - -func (h *messageHandler) handleStmtExec(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req StmtExecRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt exec request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt exec request error", req.StmtID) - } - logger.Tracef("stmt execute, stmt_id:%d", req.StmtID) - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - code := syncinterface.TaosStmtExecute(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt execute error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - s := log.GetLogNow(isDebug) - affected := wrapper.TaosStmtAffectedRowsOnce(stmtItem.stmt) - logger.Debugf("stmt_affected_rows_once, affected:%d, cost:%s", affected, log.GetLogDuration(isDebug, s)) - return &StmtExecResponse{StmtID: req.StmtID, Affected: affected} -} - -type StmtCloseRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` -} - -type StmtCloseResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id,omitempty"` -} - -func (h *messageHandler) handleStmtClose(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { - var req StmtCloseRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt close request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt close request error", req.StmtID) - } - logger.Tracef("stmt close, stmt_id:%d", req.StmtID) - err := h.stmts.FreeStmtByID(req.StmtID, false, logger) - if err != nil { - logger.Errorf("stmt close error, err:%s", err.Error()) - return wsStmtErrorMsg(0xffff, "unmarshal stmt close request error", req.StmtID) - } - resp = &BaseResponse{} - resp.SetNull(true) - logger.Tracef("stmt close success, stmt_id:%d", req.StmtID) - return resp -} - -type StmtGetColFieldsRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` -} - -type StmtGetColFieldsResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - Fields []*stmtCommon.StmtField `json:"fields"` -} - -func (h *messageHandler) handleStmtGetColFields(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req StmtGetColFieldsRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt get col request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt get col request error", req.StmtID) - } - logger.Tracef("stmt get col fields, stmt_id:%d", req.StmtID) - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - code, colNums, colFields := syncinterface.TaosStmtGetColFields(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt get col fields error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - defer func() { - wrapper.TaosStmtReclaimFields(stmtItem.stmt, colFields) - }() - if colNums == 0 { - return &StmtGetColFieldsResponse{StmtID: req.StmtID} - } - s := log.GetLogNow(isDebug) - fields := wrapper.StmtParseFields(colNums, colFields) - logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) - return &StmtGetColFieldsResponse{StmtID: req.StmtID, Fields: fields} -} - -type StmtGetTagFieldsRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` -} - -type StmtGetTagFieldsResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - Fields []*stmtCommon.StmtField `json:"fields,omitempty"` -} - -func (h *messageHandler) handleStmtGetTagFields(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req StmtGetTagFieldsRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt get tags request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt get tags request error", req.StmtID) - } - logger.Tracef("stmt get tag fields, stmt_id:%d", req.StmtID) - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - code, tagNums, tagFields := syncinterface.TaosStmtGetTagFields(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt get tag fields error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - defer func() { - wrapper.TaosStmtReclaimFields(stmtItem.stmt, tagFields) - }() - if tagNums == 0 { - return &StmtGetTagFieldsResponse{StmtID: req.StmtID} - } - s := log.GetLogNow(isDebug) - fields := wrapper.StmtParseFields(tagNums, tagFields) - logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) - return &StmtGetTagFieldsResponse{StmtID: req.StmtID, Fields: fields} -} - -type StmtUseResultRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` -} - -type StmtUseResultResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - ResultID uint64 `json:"result_id"` - FieldsCount int `json:"fields_count"` - FieldsNames []string `json:"fields_names"` - FieldsTypes jsontype.JsonUint8 `json:"fields_types"` - FieldsLengths []int64 `json:"fields_lengths"` - Precision int `json:"precision"` -} - -func (h *messageHandler) handleStmtUseResult(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { - var req StmtUseResultRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt use result request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt use result request error", req.StmtID) - } - logger.Tracef("stmt use result, stmt_id:%d", req.StmtID) - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - logger.Trace("call stmt use result") - result := wrapper.TaosStmtUseResult(stmtItem.stmt) - if result == nil { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt use result error, err:%s", errStr) - return wsStmtErrorMsg(0xffff, errStr, req.StmtID) - } - - fieldsCount := wrapper.TaosNumFields(result) - rowsHeader, _ := wrapper.ReadColumn(result, fieldsCount) - precision := wrapper.TaosResultPrecision(result) - logger.Tracef("stmt use result success, stmt_id:%d, fields_count:%d, precision:%d", req.StmtID, fieldsCount, precision) - queryResult := QueryResult{TaosResult: result, FieldsCount: fieldsCount, Header: rowsHeader, precision: precision, inStmt: true} - idx := h.queryResults.Add(&queryResult) - - return &StmtUseResultResponse{ - StmtID: req.StmtID, - ResultID: idx, - FieldsCount: fieldsCount, - FieldsNames: rowsHeader.ColNames, - FieldsTypes: rowsHeader.ColTypes, - FieldsLengths: rowsHeader.ColLength, - Precision: precision, - } -} - -func (h *messageHandler) handleSetTagsMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) (resp Response) { - block := tools.AddPointer(req.p0, uintptr(24)) - columns := parser.RawBlockGetNumOfCols(block) - rows := parser.RawBlockGetNumOfRows(block) - logger.Tracef("set tags message, stmt_id:%d, columns:%d, rows:%d", req.id, columns, rows) - if rows != 1 { - return wsStmtErrorMsg(0xffff, "rows not equal 1", req.id) - } - - stmtItem := h.stmts.Get(req.id) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.id) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.id) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.id) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.id) - } - code, tagNums, tagFields := syncinterface.TaosStmtGetTagFields(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt get tag fields error:%d %s", code, errStr) - return wsStmtErrorMsg(code, errStr, req.id) - } - defer func() { - wrapper.TaosStmtReclaimFields(stmtItem.stmt, tagFields) - }() - if tagNums == 0 { - logger.Trace("no tags") - return &StmtSetTagsResponse{StmtID: req.id} - } - if int(columns) != tagNums { - logger.Tracef("stmt tags count not match %d != %d", columns, tagNums) - return wsStmtErrorMsg(0xffff, "stmt tags count not match", req.id) - } - s := log.GetLogNow(isDebug) - fields := wrapper.StmtParseFields(tagNums, tagFields) - logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) - tags := stmt.BlockConvert(block, int(rows), fields, nil) - logger.Debugf("block concert cost:%s", log.GetLogDuration(isDebug, s)) - reTags := make([]driver.Value, tagNums) - for i := 0; i < tagNums; i++ { - reTags[i] = tags[i][0] - } - code = syncinterface.TaosStmtSetTags(stmtItem.stmt, reTags, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt set tags error, code:%d, msg:%s", code, errStr) - return wsStmtErrorMsg(code, errStr, req.id) - } - - return &StmtSetTagsResponse{StmtID: req.id} -} - -func (h *messageHandler) handleTMQRawMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) (resp Response) { - length := *(*uint32)(tools.AddPointer(req.p0, uintptr(24))) - metaType := *(*uint16)(tools.AddPointer(req.p0, uintptr(28))) - data := tools.AddPointer(req.p0, uintptr(30)) - logger.Tracef("get write raw message, length:%d, metaType:%d", length, metaType) - logger.Trace("get global lock for raw message") - s := log.GetLogNow(isDebug) - h.Lock() - logger.Debugf("get global lock cost:%s", log.GetLogDuration(isDebug, s)) - defer h.Unlock() - if h.closed { - logger.Trace("server closed") - return - } - meta := wrapper.BuildRawMeta(length, metaType, data) - code := syncinterface.TMQWriteRaw(h.conn, meta, logger, isDebug) - if code != 0 { - errStr := wrapper.TMQErr2Str(code) - logger.Errorf("write raw meta error, code:%d, msg:%s", code, errStr) - return wsCommonErrorMsg(int(code)&0xffff, errStr) - } - logger.Trace("write raw meta success") - - return &BaseResponse{} -} - -func (h *messageHandler) handleRawBlockMessage(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) (resp Response) { - numOfRows := *(*int32)(tools.AddPointer(req.p0, uintptr(24))) - tableNameLength := *(*uint16)(tools.AddPointer(req.p0, uintptr(28))) - tableName := make([]byte, tableNameLength) - for i := 0; i < int(tableNameLength); i++ { - tableName[i] = *(*byte)(tools.AddPointer(req.p0, uintptr(30+i))) - } - rawBlock := tools.AddPointer(req.p0, uintptr(30+tableNameLength)) - logger.Tracef("raw block message, table:%s, rows:%d", tableName, numOfRows) - s := log.GetLogNow(isDebug) - h.Lock() - logger.Debugf("get global lock cost:%s", log.GetLogDuration(isDebug, s)) - defer h.Unlock() - if h.closed { - logger.Trace("server closed") - return - } - code := syncinterface.TaosWriteRawBlockWithReqID(h.conn, int(numOfRows), rawBlock, string(tableName), int64(req.reqID), logger, isDebug) - if code != 0 { - errStr := wrapper.TMQErr2Str(int32(code)) - logger.Errorf("write raw meta error, code:%d, msg:%s", code, errStr) - return wsCommonErrorMsg(int(code)&0xffff, errStr) - } - logger.Trace("write raw meta success") - return &BaseResponse{} -} - -func (h *messageHandler) handleRawBlockMessageWithFields(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) (resp Response) { - numOfRows := *(*int32)(tools.AddPointer(req.p0, uintptr(24))) - tableNameLength := int(*(*uint16)(tools.AddPointer(req.p0, uintptr(28)))) - tableName := make([]byte, tableNameLength) - for i := 0; i < tableNameLength; i++ { - tableName[i] = *(*byte)(tools.AddPointer(req.p0, uintptr(30+i))) - } - rawBlock := tools.AddPointer(req.p0, uintptr(30+tableNameLength)) - blockLength := int(parser.RawBlockGetLength(rawBlock)) - numOfColumn := int(parser.RawBlockGetNumOfCols(rawBlock)) - fieldsBlock := tools.AddPointer(req.p0, uintptr(30+tableNameLength+blockLength)) - logger.Tracef("raw block message with fields, table:%s, rows:%d", tableName, numOfRows) - h.Lock() - defer h.Unlock() - if h.closed { - logger.Trace("server closed") - return - } - code := syncinterface.TaosWriteRawBlockWithFieldsWithReqID(h.conn, int(numOfRows), rawBlock, string(tableName), fieldsBlock, numOfColumn, int64(req.reqID), logger, isDebug) - if code != 0 { - errStr := wrapper.TMQErr2Str(int32(code)) - logger.Errorf("write raw meta error, err:%s", errStr) - return wsCommonErrorMsg(int(code)&0xffff, errStr) - } - logger.Trace("write raw meta success") - return &BaseResponse{} -} - -func (h *messageHandler) handleBinaryQuery(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) Response { - message := req.message - if len(message) < 31 { - return wsCommonErrorMsg(0xffff, "message length is too short") - } - v := binary.LittleEndian.Uint16(message[24:]) - var sql []byte - if v == BinaryProtocolVersion1 { - sqlLen := binary.LittleEndian.Uint32(message[26:]) - remainMessageLength := len(message) - 30 - if remainMessageLength < int(sqlLen) { - return wsCommonErrorMsg(0xffff, fmt.Sprintf("uncompleted message, sql length:%d, remainMessageLength:%d", sqlLen, remainMessageLength)) - } - sql = message[30 : 30+sqlLen] - } else { - logger.Errorf("unknown binary query version:%d", v) - return wsCommonErrorMsg(0xffff, "unknown binary query version:"+strconv.Itoa(int(v))) - } - logger.Debugf("binary query, sql:%s", log.GetLogSql(bytesutil.ToUnsafeString(sql))) - sqlType := monitor.WSRecordRequest(bytesutil.ToUnsafeString(sql)) - s := log.GetLogNow(isDebug) - handler := async.GlobalAsync.HandlerPool.Get() - defer async.GlobalAsync.HandlerPool.Put(handler) - logger.Debugf("get handler cost:%s", log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) - result := async.GlobalAsync.TaosQuery(h.conn, logger, isDebug, bytesutil.ToUnsafeString(sql), handler, int64(req.reqID)) - logger.Debugf("query cost:%s", log.GetLogDuration(isDebug, s)) - code := wrapper.TaosError(result.Res) - if code != httperror.SUCCESS { - monitor.WSRecordResult(sqlType, false) - errStr := wrapper.TaosErrorStr(result.Res) - logger.Errorf("taos query error, code:%d, msg:%s, sql:%s", code, errStr, log.GetLogSql(bytesutil.ToUnsafeString(sql))) - syncinterface.FreeResult(result.Res, logger, isDebug) - return wsCommonErrorMsg(code, errStr) - } - monitor.WSRecordResult(sqlType, true) - s = log.GetLogNow(isDebug) - isUpdate := wrapper.TaosIsUpdateQuery(result.Res) - logger.Debugf("get is_update_query %t, cost:%s", isUpdate, log.GetLogDuration(isDebug, s)) - if isUpdate { - affectRows := wrapper.TaosAffectedRows(result.Res) - logger.Debugf("affected_rows %d cost:%s", affectRows, log.GetLogDuration(isDebug, s)) - syncinterface.FreeResult(result.Res, logger, isDebug) - return &QueryResponse{IsUpdate: true, AffectedRows: affectRows} - } - s = log.GetLogNow(isDebug) - fieldsCount := wrapper.TaosNumFields(result.Res) - logger.Debugf("num_fields cost:%s", log.GetLogDuration(isDebug, s)) - rowsHeader, _ := wrapper.ReadColumn(result.Res, fieldsCount) - s = log.GetLogNow(isDebug) - logger.Debugf("read column cost:%s", log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) - precision := wrapper.TaosResultPrecision(result.Res) - logger.Debugf("result_precision cost:%s", log.GetLogDuration(isDebug, s)) - queryResult := QueryResult{TaosResult: result.Res, FieldsCount: fieldsCount, Header: rowsHeader, precision: precision} - idx := h.queryResults.Add(&queryResult) - logger.Trace("query success") - return &QueryResponse{ - ID: idx, - FieldsCount: fieldsCount, - FieldsNames: rowsHeader.ColNames, - FieldsLengths: rowsHeader.ColLength, - FieldsTypes: rowsHeader.ColTypes, - Precision: precision, - } -} - -func (h *messageHandler) handleFetchRawBlock(ctx context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) Response { - message := req.message - if len(message) < 26 { - return wsFetchRawBlockErrorMsg(0xffff, "message length is too short", req.reqID, req.id, uint64(wstool.GetDuration(ctx))) - } - v := binary.LittleEndian.Uint16(message[24:]) - if v != BinaryProtocolVersion1 { - return wsFetchRawBlockErrorMsg(0xffff, "unknown fetch raw block version", req.reqID, req.id, uint64(wstool.GetDuration(ctx))) - } - item := h.queryResults.Get(req.id) - logger.Tracef("fetch raw block, result_id:%d", req.id) - if item == nil { - logger.Errorf("result is nil, result_id:%d", req.id) - return wsFetchRawBlockErrorMsg(0xffff, "result is nil", req.reqID, req.id, uint64(wstool.GetDuration(ctx))) - } - item.Lock() - if item.TaosResult == nil { - item.Unlock() - logger.Errorf("result has been freed, result_id:%d", req.id) - return wsFetchRawBlockErrorMsg(0xffff, "result has been freed", req.reqID, req.id, uint64(wstool.GetDuration(ctx))) - } - s := log.GetLogNow(isDebug) - handler := async.GlobalAsync.HandlerPool.Get() - defer async.GlobalAsync.HandlerPool.Put(handler) - logger.Debugf("get handler cost:%s", log.GetLogDuration(isDebug, s)) - result := async.GlobalAsync.TaosFetchRawBlockA(item.TaosResult, logger, isDebug, handler) - if result.N == 0 { - logger.Trace("fetch raw block success") - item.Unlock() - h.queryResults.FreeResultByID(req.id, logger) - return wsFetchRawBlockFinish(req.reqID, req.id, uint64(wstool.GetDuration(ctx))) - } - if result.N < 0 { - item.Unlock() - errStr := wrapper.TaosErrorStr(result.Res) - logger.Errorf("fetch raw block error:%d %s", result.N, errStr) - h.queryResults.FreeResultByID(req.id, logger) - return wsFetchRawBlockErrorMsg(result.N, errStr, req.reqID, req.id, uint64(wstool.GetDuration(ctx))) - } - logger.Trace("call taos_get_raw_block") - s = log.GetLogNow(isDebug) - item.Block = wrapper.TaosGetRawBlock(item.TaosResult) - logger.Debugf("get_raw_block cost:%s", log.GetLogDuration(isDebug, s)) - item.Size = result.N - s = log.GetLogNow(isDebug) - blockLength := int(parser.RawBlockGetLength(item.Block)) - if blockLength <= 0 { - item.Unlock() - return wsFetchRawBlockErrorMsg(0xffff, "block length illegal", req.reqID, req.id, uint64(wstool.GetDuration(ctx))) - } - item.buf = wsFetchRawBlockMessage(item.buf, req.reqID, req.id, uint64(wstool.GetDuration(ctx)), int32(blockLength), item.Block) - logger.Debugf("handle binary content cost:%s", log.GetLogDuration(isDebug, s)) - resp := &BinaryResponse{Data: item.buf} - resp.SetBinary(true) - item.Unlock() - logger.Trace("fetch raw block success") - return resp -} - -type Stmt2BindResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` -} - -func (h *messageHandler) handleStmt2Bind(_ context.Context, req dealBinaryRequest, logger *logrus.Entry, isDebug bool) Response { - message := req.message - if len(message) < 30 { - return wsStmtErrorMsg(0xffff, "message length is too short", req.id) - } - v := binary.LittleEndian.Uint16(message[24:]) - if v != Stmt2BindProtocolVersion1 { - return wsStmtErrorMsg(0xffff, "unknown stmt2 bind version", req.id) - } - colIndex := int32(binary.LittleEndian.Uint32(message[26:])) - stmtItem := h.stmts.GetStmt2(req.id) - if stmtItem == nil { - logger.Errorf("stmt2 is nil, stmt_id:%d", req.id) - return wsStmtErrorMsg(0xffff, "stmt2 is nil", req.id) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt2 has been freed, stmt_id:%d", req.id) - return wsStmtErrorMsg(0xffff, "stmt2 has been freed", req.id) - } - bindData := message[30:] - err := syncinterface.TaosStmt2BindBinary(stmtItem.stmt, bindData, colIndex, logger, isDebug) - if err != nil { - logger.Errorf("stmt2 bind error, err:%s", err.Error()) - var tError *errors2.TaosError - if errors.As(err, &tError) { - return wsStmtErrorMsg(int(tError.Code), tError.ErrStr, req.id) - } - return wsStmtErrorMsg(0xffff, err.Error(), req.id) - } - logger.Trace("stmt2 bind success") - return &Stmt2BindResponse{StmtID: req.id} -} - -type GetCurrentDBResponse struct { - BaseResponse - DB string `json:"db"` -} - -func (h *messageHandler) handleGetCurrentDB(_ context.Context, _ Request, logger *logrus.Entry, isDebug bool) (resp Response) { - db, err := syncinterface.TaosGetCurrentDB(h.conn, logger, isDebug) - if err != nil { - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - logger.Errorf("get current db error, err:%s", taosErr.Error()) - return wsCommonErrorMsg(int(taosErr.Code), taosErr.Error()) - } - return &GetCurrentDBResponse{DB: db} -} - -type GetServerInfoResponse struct { - BaseResponse - Info string `json:"info"` -} - -func (h *messageHandler) handleGetServerInfo(_ context.Context, _ Request, logger *logrus.Entry, isDebug bool) (resp Response) { - serverInfo := syncinterface.TaosGetServerInfo(h.conn, logger, isDebug) - return &GetServerInfoResponse{Info: serverInfo} -} - -type NumFieldsRequest struct { - ReqID uint64 `json:"req_id"` - ResultID uint64 `json:"result_id"` -} - -type NumFieldsResponse struct { - BaseResponse - NumFields int `json:"num_fields"` -} - -func (h *messageHandler) handleNumFields(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { - var req NumFieldsRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt num params request %s error, err:%s", request.Args, err) - return wsCommonErrorMsg(0xffff, "unmarshal stmt num params request error") - } - logger.Tracef("num fields, result_id:%d", req.ResultID) - item := h.queryResults.Get(req.ResultID) - if item == nil { - logger.Errorf("result is nil, result_id:%d", req.ResultID) - return wsCommonErrorMsg(0xffff, "result is nil") - } - item.Lock() - defer item.Unlock() - if item.TaosResult == nil { - logger.Errorf("result has been freed, result_id:%d", req.ResultID) - return wsCommonErrorMsg(0xffff, "result has been freed") - } - num := wrapper.TaosNumFields(item.TaosResult) - return &NumFieldsResponse{NumFields: num} -} - -type StmtNumParamsRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` -} - -type StmtNumParamsResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - NumParams int `json:"num_params"` -} - -func (h *messageHandler) handleStmtNumParams(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req StmtNumParamsRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt num params request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt num params request error", req.StmtID) - } - logger.Tracef("stmt num params, stmt_id:%d", req.StmtID) - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - count, code := syncinterface.TaosStmtNumParams(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt get col fields error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - return &StmtNumParamsResponse{StmtID: req.StmtID, NumParams: count} -} - -type StmtGetParamRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` - Index int `json:"index"` -} - -type StmtGetParamResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - Index int `json:"index"` - DataType int `json:"data_type"` - Length int `json:"length"` -} - -func (h *messageHandler) handleStmtGetParam(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req StmtGetParamRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt get param request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt get param request error", req.StmtID) - } - logger.Tracef("stmt get param, stmt_id:%d, index:%d", req.StmtID, req.Index) - - stmtItem := h.stmts.Get(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - dataType, length, err := syncinterface.TaosStmtGetParam(stmtItem.stmt, req.Index, logger, isDebug) - if err != nil { - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - logger.Errorf("stmt get param error, err:%s", taosErr.Error()) - return wsStmtErrorMsg(int(taosErr.Code), taosErr.Error(), req.StmtID) - } - logger.Tracef("stmt get param success, data_type:%d, length:%d", dataType, length) - return &StmtGetParamResponse{StmtID: req.StmtID, Index: req.Index, DataType: dataType, Length: length} -} - -type Stmt2InitRequest struct { - ReqID uint64 `json:"req_id"` - SingleStbInsert bool `json:"single_stb_insert"` - SingleTableBindOnce bool `json:"single_table_bind_once"` -} - -type Stmt2InitResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` -} - -func (h *messageHandler) handleStmt2Init(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req Stmt2InitRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt2 init request %s error, err:%s", request.Args, err) - return wsCommonErrorMsg(0xffff, "unmarshal stmt2 init request error") - } - handle, caller := async.GlobalStmt2CallBackCallerPool.Get() - stmtInit := syncinterface.TaosStmt2Init(h.conn, int64(req.ReqID), req.SingleStbInsert, req.SingleTableBindOnce, handle, logger, isDebug) - if stmtInit == nil { - async.GlobalStmt2CallBackCallerPool.Put(handle) - errStr := wrapper.TaosStmtErrStr(stmtInit) - logger.Errorf("stmt2 init error, err:%s", errStr) - return wsCommonErrorMsg(0xffff, errStr) - } - stmtItem := &StmtItem{stmt: stmtInit, handler: handle, caller: caller, isStmt2: true} - h.stmts.Add(stmtItem) - logger.Tracef("stmt2 init sucess, stmt_id:%d, stmt pointer:%p", stmtItem.index, stmtInit) - return &StmtInitResponse{StmtID: stmtItem.index} -} - -type Stmt2PrepareRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` - SQL string `json:"sql"` - GetFields bool `json:"get_fields"` -} - -type PrepareFields struct { - stmtCommon.StmtField - BindType int8 `json:"bind_type"` -} - -type Stmt2PrepareResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - IsInsert bool `json:"is_insert"` - Fields []*PrepareFields `json:"fields"` - FieldsCount int `json:"fields_count"` -} - -func (h *messageHandler) handleStmt2Prepare(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) Response { - var req Stmt2PrepareRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt2 prepare request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal connect request error", req.StmtID) - } - logger.Debugf("stmt2 prepare, stmt_id:%d, sql:%s", req.StmtID, req.SQL) - stmtItem := h.stmts.GetStmt2(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt2 is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt2 is nil", req.StmtID) - } - s := log.GetLogNow(isDebug) - logger.Trace("get stmt2 lock") - stmtItem.Lock() - logger.Debugf("get stmt2 lock cost:%s", log.GetLogDuration(isDebug, s)) - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt2 has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - stmt2 := stmtItem.stmt - code := syncinterface.TaosStmt2Prepare(stmt2, req.SQL, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmt2Error(stmt2) - logger.Errorf("stmt2 prepare error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - logger.Tracef("stmt2 prepare success, stmt_id:%d", req.StmtID) - isInsert, code := syncinterface.TaosStmt2IsInsert(stmt2, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmt2Error(stmt2) - logger.Errorf("check stmt2 is insert error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - logger.Tracef("stmt2 is insert:%t", isInsert) - stmtItem.isInsert = isInsert - prepareResp := &Stmt2PrepareResponse{StmtID: req.StmtID, IsInsert: isInsert} - if req.GetFields { - if isInsert { - var prepareFields []*PrepareFields - // get table field - _, count, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_TBNAME, logger, isDebug) - if code != 0 { - return wsStmtErrorMsg(code, fmt.Sprintf("get table names fields error, %s", errStr), req.StmtID) - } - if count == 1 { - tableNameFields := &PrepareFields{ - StmtField: stmtCommon.StmtField{}, - BindType: stmtCommon.TAOS_FIELD_TBNAME, - } - prepareFields = append(prepareFields, tableNameFields) - } - // get tags field - tagFields, _, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_TAG, logger, isDebug) - if code != 0 { - return wsStmtErrorMsg(code, fmt.Sprintf("get tag fields error, %s", errStr), req.StmtID) - } - for i := 0; i < len(tagFields); i++ { - prepareFields = append(prepareFields, &PrepareFields{StmtField: *tagFields[i], BindType: stmtCommon.TAOS_FIELD_TAG}) - } - // get cols field - colFields, _, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_COL, logger, isDebug) - if code != 0 { - return wsStmtErrorMsg(code, fmt.Sprintf("get col fields error, %s", errStr), req.StmtID) - } - for i := 0; i < len(colFields); i++ { - prepareFields = append(prepareFields, &PrepareFields{StmtField: *colFields[i], BindType: stmtCommon.TAOS_FIELD_COL}) - } - prepareResp.Fields = prepareFields - } else { - _, count, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_QUERY, logger, isDebug) - if code != 0 { - return wsStmtErrorMsg(code, fmt.Sprintf("get query fields error, %s", errStr), req.StmtID) - } - prepareResp.FieldsCount = count - } - } - return prepareResp -} - -func getFields(stmt2 unsafe.Pointer, fieldType int8, logger *logrus.Entry, isDebug bool) (fields []*stmtCommon.StmtField, count int, code int, errSt string) { - var cFields unsafe.Pointer - code, count, cFields = syncinterface.TaosStmt2GetFields(stmt2, int(fieldType), logger, isDebug) - if code != 0 { - errStr := wrapper.TaosStmt2Error(stmt2) - logger.Errorf("stmt2 get fields error, field_type:%d, err:%s", fieldType, errStr) - return nil, count, code, errStr - } - defer wrapper.TaosStmt2FreeFields(stmt2, cFields) - if count > 0 && cFields != nil { - s := log.GetLogNow(isDebug) - fields = wrapper.StmtParseFields(count, cFields) - logger.Debugf("stmt2 parse fields cost:%s", log.GetLogDuration(isDebug, s)) - return fields, count, 0, "" - } - return nil, count, 0, "" -} - -type Stmt2GetFieldsRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` - FieldTypes []int8 `json:"field_types"` -} - -type Stmt2GetFieldsResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - TableCount int32 `json:"table_count"` - QueryCount int32 `json:"query_count"` - ColFields []*stmtCommon.StmtField `json:"col_fields"` - TagFields []*stmtCommon.StmtField `json:"tag_fields"` -} - -func (h *messageHandler) handleStmt2GetFields(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req Stmt2GetFieldsRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt2 get fields request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt get fields request error", req.StmtID) - } - logger.Tracef("stmt2 get col fields, stmt_id:%d", req.StmtID) - stmtItem := h.stmts.GetStmt2(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt2 is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt2 has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - stmt2GetFieldsResp := &Stmt2GetFieldsResponse{StmtID: req.StmtID} - for i := 0; i < len(req.FieldTypes); i++ { - switch req.FieldTypes[i] { - case stmtCommon.TAOS_FIELD_COL: - colFields, _, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_COL, logger, isDebug) - if code != 0 { - return wsStmtErrorMsg(code, fmt.Sprintf("get col fields error, %s", errStr), req.StmtID) - } - stmt2GetFieldsResp.ColFields = colFields - case stmtCommon.TAOS_FIELD_TAG: - tagFields, _, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_TAG, logger, isDebug) - if code != 0 { - return wsStmtErrorMsg(code, fmt.Sprintf("get tag fields error, %s", errStr), req.StmtID) - } - stmt2GetFieldsResp.TagFields = tagFields - case stmtCommon.TAOS_FIELD_TBNAME: - _, count, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_TBNAME, logger, isDebug) - if code != 0 { - return wsStmtErrorMsg(code, fmt.Sprintf("get table names fields error, %s", errStr), req.StmtID) - } - stmt2GetFieldsResp.TableCount = int32(count) - case stmtCommon.TAOS_FIELD_QUERY: - _, count, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_QUERY, logger, isDebug) - if code != 0 { - return wsStmtErrorMsg(code, fmt.Sprintf("get query fields error, %s", errStr), req.StmtID) - } - stmt2GetFieldsResp.QueryCount = int32(count) - } - } - return stmt2GetFieldsResp -} - -type Stmt2ExecRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` -} - -type Stmt2ExecResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - Affected int `json:"affected"` -} - -func (h *messageHandler) handleStmt2Exec(_ context.Context, request Request, logger *logrus.Entry, isDebug bool) (resp Response) { - var req Stmt2ExecRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt2 exec request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt2 exec request error", req.StmtID) - } - logger.Tracef("stmt2 execute, stmt_id:%d", req.StmtID) - stmtItem := h.stmts.GetStmt2(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt2 is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt2 is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt2 has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt has been freed", req.StmtID) - } - code := syncinterface.TaosStmt2Exec(stmtItem.stmt, logger, isDebug) - if code != httperror.SUCCESS { - errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt2 execute error, err:%s", errStr) - return wsStmtErrorMsg(code, errStr, req.StmtID) - } - s := log.GetLogNow(isDebug) - logger.Tracef("stmt2 execute wait callback, stmt_id:%d", req.StmtID) - result := <-stmtItem.caller.ExecResult - logger.Debugf("stmt2 execute wait callback finish, affected:%d, res:%p, n:%d, cost:%s", result.Affected, result.Res, result.N, log.GetLogDuration(isDebug, s)) - stmtItem.result = result.Res - return &Stmt2ExecResponse{StmtID: req.StmtID, Affected: result.Affected} -} - -type Stmt2CloseRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` -} - -type Stmt2CloseResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` -} - -func (h *messageHandler) handleStmt2Close(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { - var req StmtCloseRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt close request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt close request error", req.StmtID) - } - logger.Tracef("stmt2 close, stmt_id:%d", req.StmtID) - err := h.stmts.FreeStmtByID(req.StmtID, true, logger) - if err != nil { - logger.Errorf("stmt2 close error, err:%s", err.Error()) - return wsStmtErrorMsg(0xffff, "unmarshal stmt close request error", req.StmtID) - } - resp = &Stmt2CloseResponse{StmtID: req.StmtID} - logger.Tracef("stmt2 close success, stmt_id:%d", req.StmtID) - return resp -} - -type Stmt2UseResultRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` -} - -type Stmt2UseResultResponse struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` - ResultID uint64 `json:"result_id"` - FieldsCount int `json:"fields_count"` - FieldsNames []string `json:"fields_names"` - FieldsTypes jsontype.JsonUint8 `json:"fields_types"` - FieldsLengths []int64 `json:"fields_lengths"` - Precision int `json:"precision"` -} - -func (h *messageHandler) handleStmt2UseResult(_ context.Context, request Request, logger *logrus.Entry, _ bool) (resp Response) { - var req Stmt2UseResultRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - logger.Errorf("unmarshal stmt2 use result request %s error, err:%s", request.Args, err) - return wsStmtErrorMsg(0xffff, "unmarshal stmt2 use result request error", req.StmtID) - } - logger.Tracef("stmt2 use result, stmt_id:%d", req.StmtID) - stmtItem := h.stmts.GetStmt2(req.StmtID) - if stmtItem == nil { - logger.Errorf("stmt2 is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt2 is nil", req.StmtID) - } - stmtItem.Lock() - defer stmtItem.Unlock() - if stmtItem.stmt == nil { - logger.Errorf("stmt2 has been freed, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt2 has been freed", req.StmtID) - } - - if stmtItem.result == nil { - logger.Errorf("stmt2 result is nil, stmt_id:%d", req.StmtID) - return wsStmtErrorMsg(0xffff, "stmt result is nil", req.StmtID) - } - result := stmtItem.result - fieldsCount := wrapper.TaosNumFields(result) - rowsHeader, _ := wrapper.ReadColumn(result, fieldsCount) - precision := wrapper.TaosResultPrecision(result) - logger.Tracef("stmt use result success, stmt_id:%d, fields_count:%d, precision:%d", req.StmtID, fieldsCount, precision) - queryResult := QueryResult{TaosResult: result, FieldsCount: fieldsCount, Header: rowsHeader, precision: precision, inStmt: true} - idx := h.queryResults.Add(&queryResult) - - return &Stmt2UseResultResponse{ - StmtID: req.StmtID, - ResultID: idx, - FieldsCount: fieldsCount, - FieldsNames: rowsHeader.ColNames, - FieldsTypes: rowsHeader.ColTypes, - FieldsLengths: rowsHeader.ColLength, - Precision: precision, - } -} - -type Response interface { - SetCode(code int) - SetMessage(message string) - SetAction(action string) - SetReqID(reqID uint64) - SetTiming(timing int64) - SetBinary(b bool) - IsBinary() bool - SetNull(b bool) - IsNull() bool -} - -func (b *BaseResponse) SetCode(code int) { - b.Code = code -} - -func (b *BaseResponse) SetMessage(message string) { - b.Message = message -} - -func (b *BaseResponse) SetAction(action string) { - b.Action = action -} - -func (b *BaseResponse) SetReqID(reqID uint64) { - b.ReqID = reqID -} - -func (b *BaseResponse) SetTiming(timing int64) { - b.Timing = timing -} - -func (b *BaseResponse) SetBinary(binary bool) { - b.binary = binary -} - -func (b *BaseResponse) IsBinary() bool { - return b.binary -} - -func (b *BaseResponse) SetNull(null bool) { - b.null = null -} - -func (b *BaseResponse) IsNull() bool { - return b.null -} - -type VersionResponse struct { - BaseResponse - Version string `json:"version"` -} - -type BinaryResponse struct { - BaseResponse - Data []byte -} - -type WSStmtErrorResp struct { - BaseResponse - StmtID uint64 `json:"stmt_id"` -} - -func wsStmtErrorMsg(code int, message string, stmtID uint64) *WSStmtErrorResp { - return &WSStmtErrorResp{ - BaseResponse: BaseResponse{ - Code: code & 0xffff, - Message: message, - }, - StmtID: stmtID, - } -} - -func wsCommonErrorMsg(code int, message string) *BaseResponse { - return &BaseResponse{ - Code: code & 0xffff, - Message: message, - } -} - -func wsFetchRawBlockErrorMsg(code int, message string, reqID uint64, resultID uint64, t uint64) *BinaryResponse { - bufLength := 8 + 8 + 2 + 8 + 8 + 4 + 4 + len(message) + 8 + 1 - buf := make([]byte, bufLength) - binary.LittleEndian.PutUint64(buf, 0xffffffffffffffff) - binary.LittleEndian.PutUint64(buf[8:], uint64(FetchRawBlockMessage)) - binary.LittleEndian.PutUint16(buf[16:], 1) - binary.LittleEndian.PutUint64(buf[18:], t) - binary.LittleEndian.PutUint64(buf[26:], reqID) - binary.LittleEndian.PutUint32(buf[34:], uint32(code&0xffff)) - binary.LittleEndian.PutUint32(buf[38:], uint32(len(message))) - copy(buf[42:], message) - binary.LittleEndian.PutUint64(buf[42+len(message):], resultID) - buf[42+len(message)+8] = 1 - resp := &BinaryResponse{Data: buf} - resp.SetBinary(true) - return resp -} - -func wsFetchRawBlockFinish(reqID uint64, resultID uint64, t uint64) *BinaryResponse { - bufLength := 8 + 8 + 2 + 8 + 8 + 4 + 4 + 8 + 1 - buf := make([]byte, bufLength) - binary.LittleEndian.PutUint64(buf, 0xffffffffffffffff) - binary.LittleEndian.PutUint64(buf[8:], uint64(FetchRawBlockMessage)) - binary.LittleEndian.PutUint16(buf[16:], 1) - binary.LittleEndian.PutUint64(buf[18:], t) - binary.LittleEndian.PutUint64(buf[26:], reqID) - binary.LittleEndian.PutUint32(buf[34:], 0) - binary.LittleEndian.PutUint32(buf[38:], 0) - binary.LittleEndian.PutUint64(buf[42:], resultID) - buf[50] = 1 - resp := &BinaryResponse{Data: buf} - resp.SetBinary(true) - return resp -} - -func wsFetchRawBlockMessage(buf []byte, reqID uint64, resultID uint64, t uint64, blockLength int32, rawBlock unsafe.Pointer) []byte { - bufLength := 8 + 8 + 2 + 8 + 8 + 4 + 4 + 8 + 1 + 4 + int(blockLength) - if cap(buf) < bufLength { - buf = make([]byte, 0, bufLength) - } - buf = buf[:bufLength] - binary.LittleEndian.PutUint64(buf, 0xffffffffffffffff) - binary.LittleEndian.PutUint64(buf[8:], uint64(FetchRawBlockMessage)) - binary.LittleEndian.PutUint16(buf[16:], 1) - binary.LittleEndian.PutUint64(buf[18:], t) - binary.LittleEndian.PutUint64(buf[26:], reqID) - binary.LittleEndian.PutUint32(buf[34:], 0) - binary.LittleEndian.PutUint32(buf[38:], 0) - binary.LittleEndian.PutUint64(buf[42:], resultID) - buf[50] = 0 - binary.LittleEndian.PutUint32(buf[51:], uint32(blockLength)) - bytesutil.Copy(rawBlock, buf, 55, int(blockLength)) - return buf +func getReqID(value gjson.Result) uint64 { + return value.Get("req_id").Uint() } diff --git a/controller/ws/ws/handler_test.go b/controller/ws/ws/handler_test.go index a1fe491a..f26c7979 100644 --- a/controller/ws/ws/handler_test.go +++ b/controller/ws/ws/handler_test.go @@ -1,90 +1,360 @@ package ws import ( + "bytes" + "encoding/json" + "net/http/httptest" + "strings" "testing" - "unsafe" + "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/param" - "github.com/taosdata/driver-go/v3/common/serializer" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" - "github.com/taosdata/driver-go/v3/types" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" ) -func Test_parseRowBlockInfo(t *testing.T) { - b, err := serializer.SerializeRawBlock( - []*param.Param{ - param.NewParam(1).AddBool(true), - param.NewParam(1).AddTinyint(1), - param.NewParam(1).AddSmallint(1), - param.NewParam(1).AddInt(1), - param.NewParam(1).AddBigint(1), - param.NewParam(1).AddFloat(1.1), - param.NewParam(1).AddDouble(1.1), - param.NewParam(1).AddBinary([]byte("California.SanFrancisco")), - param.NewParam(1).AddNchar("California.SanFrancisco"), - param.NewParam(1).AddUTinyint(1), - param.NewParam(1).AddUSmallint(1), - param.NewParam(1).AddUInt(1), - param.NewParam(1).AddUBigint(1), - param.NewParam(1).AddJson([]byte(`{"name":"taos"}`)), - param.NewParam(1).AddVarBinary([]byte("California.SanFrancisco")), - }, - param.NewColumnType(15). - AddBool(). - AddTinyint(). - AddSmallint(). - AddInt(). - AddBigint(). - AddFloat(). - AddDouble(). - AddBinary(100). - AddNchar(100). - AddUTinyint(). - AddUSmallint(). - AddUInt(). - AddUBigint(). - AddJson(100). - AddVarBinary(100), - ) +func TestDropUser(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + defer doRestful("drop user test_ws_drop_user", "") + code, message := doRestful("create user test_ws_drop_user pass 'pass'", "") + assert.Equal(t, 0, code, message) + // connect + connReq := connRequest{ReqID: 1, User: "test_ws_drop_user", Password: "pass"} + resp, err := doWebSocket(ws, Connect, &connReq) assert.NoError(t, err) - fields, fieldsType, err := parseRowBlockInfo(unsafe.Pointer(&b[0]), 15) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) assert.NoError(t, err) - expectFields := []*stmtCommon.StmtField{ - {FieldType: common.TSDB_DATA_TYPE_BOOL}, - {FieldType: common.TSDB_DATA_TYPE_TINYINT}, - {FieldType: common.TSDB_DATA_TYPE_SMALLINT}, - {FieldType: common.TSDB_DATA_TYPE_INT}, - {FieldType: common.TSDB_DATA_TYPE_BIGINT}, - {FieldType: common.TSDB_DATA_TYPE_FLOAT}, - {FieldType: common.TSDB_DATA_TYPE_DOUBLE}, - {FieldType: common.TSDB_DATA_TYPE_BINARY}, - {FieldType: common.TSDB_DATA_TYPE_NCHAR}, - {FieldType: common.TSDB_DATA_TYPE_UTINYINT}, - {FieldType: common.TSDB_DATA_TYPE_USMALLINT}, - {FieldType: common.TSDB_DATA_TYPE_UINT}, - {FieldType: common.TSDB_DATA_TYPE_UBIGINT}, - {FieldType: common.TSDB_DATA_TYPE_JSON}, - {FieldType: common.TSDB_DATA_TYPE_VARBINARY}, + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + // drop user + code, message = doRestful("drop user test_ws_drop_user", "") + assert.Equal(t, 0, code, message) + time.Sleep(time.Second * 3) + resp, err = doWebSocket(ws, wstool.ClientVersion, nil) + assert.Error(t, err, resp) +} + +func Test_WrongJsonProtocol(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return } - assert.Equal(t, expectFields, fields) - expectFieldsType := []*types.ColumnType{ - {Type: types.TaosBoolType}, - {Type: types.TaosTinyintType}, - {Type: types.TaosSmallintType}, - {Type: types.TaosIntType}, - {Type: types.TaosBigintType}, - {Type: types.TaosFloatType}, - {Type: types.TaosDoubleType}, - {Type: types.TaosBinaryType}, - {Type: types.TaosNcharType}, - {Type: types.TaosUTinyintType}, - {Type: types.TaosUSmallintType}, - {Type: types.TaosUIntType}, - {Type: types.TaosUBigintType}, - {Type: types.TaosJsonType}, - {Type: types.TaosBinaryType}, + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + connReq := connRequest{ + ReqID: 1, + User: "root", + Password: "taosdata", } - assert.Equal(t, expectFieldsType, fieldsType) + message, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + resp := commonResp{} + err = json.Unmarshal(message, &resp) + assert.NoError(t, err) + assert.Equal(t, 0, resp.Code, resp.Message) + tests := []struct { + name string + action string + args interface{} + errorPrefix string + }{ + { + name: "empty action", + action: "", + args: nil, + errorPrefix: "request no action", + }, + { + name: "connect with wrong args", + action: Connect, + args: "wrong", + errorPrefix: "unmarshal connect request error", + }, + { + name: "query with wrong args", + action: WSQuery, + args: "wrong", + errorPrefix: "unmarshal query request error", + }, + { + name: "fetch with wrong args", + action: WSFetch, + args: "wrong", + errorPrefix: "unmarshal fetch request error", + }, + { + name: "fetch_block with wrong args", + action: WSFetchBlock, + args: "wrong", + errorPrefix: "unmarshal fetch block request error", + }, + { + name: "free_result with wrong args", + action: WSFreeResult, + args: "wrong", + errorPrefix: "unmarshal free result request error", + }, + { + name: "num_fields with wrong args", + action: WSNumFields, + args: "wrong", + errorPrefix: "unmarshal num fields request error", + }, + { + name: "insert schemaless with wrong args", + action: SchemalessWrite, + args: "wrong", + errorPrefix: "unmarshal schemaless insert request error", + }, + { + name: "stmt init with wrong args", + action: STMTInit, + args: "wrong", + errorPrefix: "unmarshal stmt init request error", + }, + { + name: "stmt prepare with wrong args", + action: STMTPrepare, + args: "wrong", + errorPrefix: "unmarshal stmt prepare request error", + }, + { + name: "stmt set table name with wrong args", + action: STMTSetTableName, + args: "wrong", + errorPrefix: "unmarshal stmt set table name request error", + }, + { + name: "stmt set tags with wrong args", + action: STMTSetTags, + args: "wrong", + errorPrefix: "unmarshal stmt set tags request error", + }, + { + name: "stmt bind with wrong args", + action: STMTBind, + args: "wrong", + errorPrefix: "unmarshal stmt bind request error", + }, + { + name: "stmt add batch with wrong args", + action: STMTAddBatch, + args: "wrong", + errorPrefix: "unmarshal stmt add batch request error", + }, + { + name: "stmt exec with wrong args", + action: STMTExec, + args: "wrong", + errorPrefix: "unmarshal stmt exec request error", + }, + { + name: "stmt close with wrong args", + action: STMTClose, + args: "wrong", + errorPrefix: "unmarshal stmt close request error", + }, + { + name: "stmt get tag fields with wrong args", + action: STMTGetTagFields, + args: "wrong", + errorPrefix: "unmarshal stmt get tag fields request error", + }, + { + name: "stmt get col fields with wrong args", + action: STMTGetColFields, + args: "wrong", + errorPrefix: "unmarshal stmt get col fields request error", + }, + { + name: "stmt use result with wrong args", + action: STMTUseResult, + args: "wrong", + errorPrefix: "unmarshal stmt use result request error", + }, + { + name: "stmt num params with wrong args", + action: STMTNumParams, + args: "wrong", + errorPrefix: "unmarshal stmt num params request error", + }, + { + name: "stmt get param with wrong args", + action: STMTGetParam, + args: "wrong", + errorPrefix: "unmarshal stmt get param request error", + }, + { + name: "stmt2 init with wrong args", + action: STMT2Init, + args: "wrong", + errorPrefix: "unmarshal stmt2 init request error", + }, + { + name: "stmt2 prepare with wrong args", + action: STMT2Prepare, + args: "wrong", + errorPrefix: "unmarshal stmt2 prepare request error", + }, + { + name: "stmt2 get fields with wrong args", + action: STMT2GetFields, + args: "wrong", + errorPrefix: "unmarshal stmt2 get fields request error", + }, + { + name: "stmt2 exec with wrong args", + action: STMT2Exec, + args: "wrong", + errorPrefix: "unmarshal stmt2 exec request error", + }, + { + name: "stmt2 result with wrong args", + action: STMT2Result, + args: "wrong", + errorPrefix: "unmarshal stmt2 result request error", + }, + { + name: "stmt2 close with wrong args", + action: STMT2Close, + args: "wrong", + errorPrefix: "unmarshal stmt2 close request error", + }, + { + name: "get current db with wrong args", + action: WSGetCurrentDB, + args: "wrong", + errorPrefix: "unmarshal get current db request error", + }, + { + name: "get server info with wrong args", + action: WSGetServerInfo, + args: "wrong", + errorPrefix: "unmarshal get server info request error", + }, + { + name: "unknown action", + action: "unknown", + args: nil, + errorPrefix: "unknown action", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + message, err = doWebSocket(ws, tt.action, tt.args) + assert.NoError(t, err) + resp = commonResp{} + err = json.Unmarshal(message, &resp) + assert.NoError(t, err) + assert.NotEqual(t, 0, resp.Code) + if !strings.HasPrefix(resp.Message, tt.errorPrefix) { + t.Errorf("expected error message to start with %s, got %s", tt.errorPrefix, resp.Message) + } + }) + } +} + +func TestNotConnection(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + // json + query := queryRequest{ReqID: 1, Sql: "select * from test"} + message, err := doWebSocket(ws, WSQuery, &query) + assert.NoError(t, err) + resp := commonResp{} + err = json.Unmarshal(message, &resp) + assert.NoError(t, err) + assert.NotEqual(t, 0, resp.Code) + assert.Equal(t, "server not connected", resp.Message) + // binary + + sql := "select * from test" + var buffer bytes.Buffer + wstool.WriteUint64(&buffer, 2) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) + wstool.WriteUint16(&buffer, 1) // version + wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length + buffer.WriteString(sql) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, message, err = ws.ReadMessage() + assert.NoError(t, err) + var queryResp queryResponse + err = json.Unmarshal(message, &queryResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), queryResp.ReqID) + assert.NotEqual(t, 0, queryResp.Code) + assert.Equal(t, "server not connected", queryResp.Message) +} + +func TestUnknownBinaryProtocol(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + connReq := connRequest{ + ReqID: 1, + User: "root", + Password: "taosdata", + } + message, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + resp := commonResp{} + err = json.Unmarshal(message, &resp) + assert.NoError(t, err) + assert.Equal(t, 0, resp.Code, resp.Message) + + sql := "select * from test" + var buffer bytes.Buffer + wstool.WriteUint64(&buffer, 2) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(9999)) + wstool.WriteUint16(&buffer, 1) // version + wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length + buffer.WriteString(sql) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, message, err = ws.ReadMessage() + assert.NoError(t, err) + var queryResp queryResponse + err = json.Unmarshal(message, &queryResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), queryResp.ReqID) + assert.NotEqual(t, 0, queryResp.Code) + assert.Equal(t, "unknown", queryResp.Action) + assert.Equal(t, "unknown binary action 9999", queryResp.Message) } diff --git a/controller/ws/ws/misc.go b/controller/ws/ws/misc.go new file mode 100644 index 00000000..f6e00296 --- /dev/null +++ b/controller/ws/ws/misc.go @@ -0,0 +1,67 @@ +package ws + +import ( + "context" + + "github.com/sirupsen/logrus" + errors2 "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/tools/melody" +) + +type getCurrentDBRequest struct { + ReqID uint64 `json:"req_id"` +} + +type getCurrentDBResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + DB string `json:"db"` +} + +func (h *messageHandler) getCurrentDB(ctx context.Context, session *melody.Session, action string, req *getCurrentDBRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("get current db") + db, err := syncinterface.TaosGetCurrentDB(h.conn, logger, isDebug) + if err != nil { + logger.Errorf("get current db error, err:%s", err) + taosErr := err.(*errors2.TaosError) + commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.Error()) + return + } + resp := &getCurrentDBResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + DB: db, + } + wstool.WSWriteJson(session, logger, resp) +} + +type getServerInfoRequest struct { + ReqID uint64 `json:"req_id"` +} + +type getServerInfoResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + Info string `json:"info"` +} + +func (h *messageHandler) getServerInfo(ctx context.Context, session *melody.Session, action string, req *getServerInfoRequest, logger *logrus.Entry, isDebug bool) { + logger.Trace("get server info") + serverInfo := syncinterface.TaosGetServerInfo(h.conn, logger, isDebug) + resp := &getServerInfoResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + Info: serverInfo, + } + wstool.WSWriteJson(session, logger, resp) +} diff --git a/controller/ws/ws/misc_test.go b/controller/ws/ws/misc_test.go new file mode 100644 index 00000000..a2bafdf4 --- /dev/null +++ b/controller/ws/ws/misc_test.go @@ -0,0 +1,90 @@ +package ws + +import ( + "encoding/json" + "fmt" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" +) + +func TestGetCurrentDB(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + db := "test_current_db" + code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), "") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), "") + assert.Equal(t, 0, code, message) + + defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // current db + currentDBReq := map[string]uint64{"req_id": 1} + resp, err = doWebSocket(ws, WSGetCurrentDB, ¤tDBReq) + assert.NoError(t, err) + var currentDBResp getCurrentDBResponse + err = json.Unmarshal(resp, ¤tDBResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), currentDBResp.ReqID) + assert.Equal(t, 0, currentDBResp.Code, currentDBResp.Message) + assert.Equal(t, db, currentDBResp.DB) +} + +func TestGetServerInfo(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // server info + serverInfoReq := map[string]uint64{"req_id": 1} + resp, err = doWebSocket(ws, WSGetServerInfo, &serverInfoReq) + assert.NoError(t, err) + var serverInfoResp getServerInfoResponse + err = json.Unmarshal(resp, &serverInfoResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), serverInfoResp.ReqID) + assert.Equal(t, 0, serverInfoResp.Code, serverInfoResp.Message) + t.Log(serverInfoResp.Info) +} diff --git a/controller/ws/ws/query.go b/controller/ws/ws/query.go new file mode 100644 index 00000000..fb905d95 --- /dev/null +++ b/controller/ws/ws/query.go @@ -0,0 +1,289 @@ +package ws + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + + "github.com/sirupsen/logrus" + "github.com/taosdata/driver-go/v3/common" + errors2 "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/wrapper" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/db/async" + "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/db/tool" + "github.com/taosdata/taosadapter/v3/log" + "github.com/taosdata/taosadapter/v3/monitor" + "github.com/taosdata/taosadapter/v3/tools/bytesutil" + "github.com/taosdata/taosadapter/v3/tools/jsontype" + "github.com/taosdata/taosadapter/v3/tools/melody" +) + +type connRequest struct { + ReqID uint64 `json:"req_id"` + User string `json:"user"` + Password string `json:"password"` + DB string `json:"db"` + Mode *int `json:"mode"` +} + +func (h *messageHandler) connect(ctx context.Context, session *melody.Session, action string, req *connRequest, logger *logrus.Entry, isDebug bool) { + h.lock(logger, isDebug) + defer h.Unlock() + if h.closed { + logger.Trace("server closed") + return + } + if h.conn != nil { + logger.Trace("duplicate connections") + commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, "duplicate connections") + return + } + + conn, err := syncinterface.TaosConnect("", req.User, req.Password, req.DB, 0, logger, isDebug) + + if err != nil { + logger.Errorf("connect to TDengine error, err:%s", err) + var taosErr *errors2.TaosError + errors.As(err, &taosErr) + commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + return + } + logger.Trace("get whitelist") + s := log.GetLogNow(isDebug) + whitelist, err := tool.GetWhitelist(conn) + logger.Debugf("get whitelist cost:%s", log.GetLogDuration(isDebug, s)) + if err != nil { + logger.Errorf("get whitelist error, err:%s", err) + syncinterface.TaosClose(conn, logger, isDebug) + var taosErr *errors2.TaosError + errors.As(err, &taosErr) + commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + return + } + logger.Tracef("check whitelist, ip:%s, whitelist:%s", h.ipStr, tool.IpNetSliceToString(whitelist)) + valid := tool.CheckWhitelist(whitelist, h.ip) + if !valid { + logger.Errorf("ip not in whitelist, ip:%s, whitelist:%s", h.ipStr, tool.IpNetSliceToString(whitelist)) + syncinterface.TaosClose(conn, logger, isDebug) + commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, "whitelist prohibits current IP access") + return + } + s = log.GetLogNow(isDebug) + logger.Trace("register whitelist change") + err = tool.RegisterChangeWhitelist(conn, h.whitelistChangeHandle) + logger.Debugf("register whitelist change cost:%s", log.GetLogDuration(isDebug, s)) + if err != nil { + logger.Errorf("register whitelist change error, err:%s", err) + syncinterface.TaosClose(conn, logger, isDebug) + var taosErr *errors2.TaosError + errors.As(err, &taosErr) + commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + return + } + s = log.GetLogNow(isDebug) + logger.Trace("register drop user") + err = tool.RegisterDropUser(conn, h.dropUserHandle) + logger.Debugf("register drop user cost:%s", log.GetLogDuration(isDebug, s)) + if err != nil { + logger.Errorf("register drop user error, err:%s", err) + syncinterface.TaosClose(conn, logger, isDebug) + var taosErr *errors2.TaosError + errors.As(err, &taosErr) + commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + return + } + if req.Mode != nil { + switch *req.Mode { + case common.TAOS_CONN_MODE_BI: + // BI mode + logger.Trace("set connection mode to BI") + code := wrapper.TaosSetConnMode(conn, common.TAOS_CONN_MODE_BI, 1) + logger.Trace("set connection mode to BI done") + if code != 0 { + logger.Errorf("set connection mode to BI error, err:%s", wrapper.TaosErrorStr(nil)) + syncinterface.TaosClose(conn, logger, isDebug) + commonErrorResponse(ctx, session, logger, action, req.ReqID, code, wrapper.TaosErrorStr(nil)) + return + } + default: + syncinterface.TaosClose(conn, logger, isDebug) + logger.Tracef("unexpected mode:%d", *req.Mode) + commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, fmt.Sprintf("unexpected mode:%d", *req.Mode)) + return + } + } + h.conn = conn + logger.Trace("start wait signal goroutine") + go h.waitSignal(h.logger) + commonSuccessResponse(ctx, session, logger, action, req.ReqID) +} + +type queryRequest struct { + ReqID uint64 `json:"req_id"` + Sql string `json:"sql"` +} + +type queryResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + ID uint64 `json:"id"` + IsUpdate bool `json:"is_update"` + AffectedRows int `json:"affected_rows"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes jsontype.JsonUint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} + +func (h *messageHandler) query(ctx context.Context, session *melody.Session, action string, req *queryRequest, logger *logrus.Entry, isDebug bool) { + sqlType := monitor.WSRecordRequest(req.Sql) + logger.Debugf("get query request, sql:%s", req.Sql) + s := log.GetLogNow(isDebug) + handler := async.GlobalAsync.HandlerPool.Get() + defer async.GlobalAsync.HandlerPool.Put(handler) + logger.Debugf("get handler cost:%s", log.GetLogDuration(isDebug, s)) + result := async.GlobalAsync.TaosQuery(h.conn, logger, isDebug, req.Sql, handler, int64(req.ReqID)) + code := wrapper.TaosError(result.Res) + if code != 0 { + monitor.WSRecordResult(sqlType, false) + errStr := wrapper.TaosErrorStr(result.Res) + logger.Errorf("query error, code:%d, message:%s", code, errStr) + syncinterface.FreeResult(result.Res, logger, isDebug) + commonErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr) + return + } + + monitor.WSRecordResult(sqlType, true) + logger.Trace("check is_update_query") + s = log.GetLogNow(isDebug) + isUpdate := wrapper.TaosIsUpdateQuery(result.Res) + logger.Debugf("get is_update_query %t, cost:%s", isUpdate, log.GetLogDuration(isDebug, s)) + if isUpdate { + s = log.GetLogNow(isDebug) + affectRows := wrapper.TaosAffectedRows(result.Res) + logger.Debugf("affected_rows %d cost:%s", affectRows, log.GetLogDuration(isDebug, s)) + syncinterface.FreeResult(result.Res, logger, isDebug) + resp := queryResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + IsUpdate: true, + AffectedRows: affectRows, + } + wstool.WSWriteJson(session, logger, resp) + return + } + s = log.GetLogNow(isDebug) + fieldsCount := wrapper.TaosNumFields(result.Res) + logger.Debugf("get num_fields:%d, cost:%s", fieldsCount, log.GetLogDuration(isDebug, s)) + s = log.GetLogNow(isDebug) + rowsHeader, _ := wrapper.ReadColumn(result.Res, fieldsCount) + logger.Debugf("read column cost:%s", log.GetLogDuration(isDebug, s)) + s = log.GetLogNow(isDebug) + precision := wrapper.TaosResultPrecision(result.Res) + logger.Debugf("get result_precision:%d, cost:%s", precision, log.GetLogDuration(isDebug, s)) + queryResult := QueryResult{TaosResult: result.Res, FieldsCount: fieldsCount, Header: rowsHeader, precision: precision} + idx := h.queryResults.Add(&queryResult) + logger.Trace("add result to list finished") + resp := queryResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + ID: idx, + FieldsCount: fieldsCount, + FieldsNames: rowsHeader.ColNames, + FieldsLengths: rowsHeader.ColLength, + FieldsTypes: rowsHeader.ColTypes, + Precision: precision, + } + wstool.WSWriteJson(session, logger, resp) +} + +func (h *messageHandler) binaryQuery(ctx context.Context, session *melody.Session, action string, reqID uint64, message []byte, logger *logrus.Entry, isDebug bool) { + if len(message) < 31 { + commonErrorResponse(ctx, session, logger, action, reqID, 0xffff, "message length is too short") + return + } + v := binary.LittleEndian.Uint16(message[24:]) + var sql []byte + if v == BinaryProtocolVersion1 { + sqlLen := binary.LittleEndian.Uint32(message[26:]) + remainMessageLength := len(message) - 30 + if remainMessageLength < int(sqlLen) { + commonErrorResponse(ctx, session, logger, action, reqID, 0xffff, fmt.Sprintf("uncompleted message, sql length:%d, remainMessageLength:%d", sqlLen, remainMessageLength)) + return + } + sql = message[30 : 30+sqlLen] + } else { + logger.Errorf("unknown binary query version:%d", v) + commonErrorResponse(ctx, session, logger, action, reqID, 0xffff, fmt.Sprintf("unknown binary query version:%d", v)) + return + } + logger.Debugf("binary query, sql:%s", log.GetLogSql(bytesutil.ToUnsafeString(sql))) + sqlType := monitor.WSRecordRequest(bytesutil.ToUnsafeString(sql)) + s := log.GetLogNow(isDebug) + handler := async.GlobalAsync.HandlerPool.Get() + defer async.GlobalAsync.HandlerPool.Put(handler) + logger.Debugf("get handler cost:%s", log.GetLogDuration(isDebug, s)) + s = log.GetLogNow(isDebug) + result := async.GlobalAsync.TaosQuery(h.conn, logger, isDebug, bytesutil.ToUnsafeString(sql), handler, int64(reqID)) + logger.Debugf("query cost:%s", log.GetLogDuration(isDebug, s)) + code := wrapper.TaosError(result.Res) + if code != 0 { + monitor.WSRecordResult(sqlType, false) + errStr := wrapper.TaosErrorStr(result.Res) + logger.Errorf("taos query error, code:%d, msg:%s, sql:%s", code, errStr, log.GetLogSql(bytesutil.ToUnsafeString(sql))) + syncinterface.FreeResult(result.Res, logger, isDebug) + commonErrorResponse(ctx, session, logger, action, reqID, code, errStr) + return + } + monitor.WSRecordResult(sqlType, true) + s = log.GetLogNow(isDebug) + isUpdate := wrapper.TaosIsUpdateQuery(result.Res) + logger.Debugf("get is_update_query %t, cost:%s", isUpdate, log.GetLogDuration(isDebug, s)) + if isUpdate { + affectRows := wrapper.TaosAffectedRows(result.Res) + logger.Debugf("affected_rows %d cost:%s", affectRows, log.GetLogDuration(isDebug, s)) + syncinterface.FreeResult(result.Res, logger, isDebug) + resp := &queryResponse{ + Action: action, + ReqID: reqID, + Timing: wstool.GetDuration(ctx), + IsUpdate: true, + AffectedRows: affectRows, + } + wstool.WSWriteJson(session, logger, resp) + return + } + s = log.GetLogNow(isDebug) + fieldsCount := wrapper.TaosNumFields(result.Res) + logger.Debugf("num_fields cost:%s", log.GetLogDuration(isDebug, s)) + rowsHeader, _ := wrapper.ReadColumn(result.Res, fieldsCount) + s = log.GetLogNow(isDebug) + logger.Debugf("read column cost:%s", log.GetLogDuration(isDebug, s)) + s = log.GetLogNow(isDebug) + precision := wrapper.TaosResultPrecision(result.Res) + logger.Debugf("result_precision cost:%s", log.GetLogDuration(isDebug, s)) + queryResult := QueryResult{TaosResult: result.Res, FieldsCount: fieldsCount, Header: rowsHeader, precision: precision} + idx := h.queryResults.Add(&queryResult) + logger.Trace("query success") + resp := &queryResponse{ + Action: action, + ReqID: reqID, + Timing: wstool.GetDuration(ctx), + ID: idx, + FieldsCount: fieldsCount, + FieldsNames: rowsHeader.ColNames, + FieldsLengths: rowsHeader.ColLength, + FieldsTypes: rowsHeader.ColTypes, + Precision: precision, + } + wstool.WSWriteJson(session, logger, resp) +} diff --git a/controller/ws/ws/query_test.go b/controller/ws/ws/query_test.go new file mode 100644 index 00000000..aadab537 --- /dev/null +++ b/controller/ws/ws/query_test.go @@ -0,0 +1,1258 @@ +package ws + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "net/http/httptest" + "strings" + "testing" + "time" + "unsafe" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/tools/parseblock" +) + +func TestWSConnect(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // wrong password + connReq := connRequest{ReqID: 1, User: "root", Password: "wrong"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, "Authentication failure", connResp.Message) + assert.Equal(t, 0x357, connResp.Code, connResp.Message) + + // connect + connReq = connRequest{ReqID: 1, User: "root", Password: "taosdata"} + resp, err = doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + //duplicate connections + connReq = connRequest{ReqID: 1, User: "root", Password: "taosdata"} + resp, err = doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0xffff, connResp.Code) + assert.Equal(t, "duplicate connections", connResp.Message) +} + +type TestConnRequest struct { + ReqID uint64 `json:"req_id"` + User string `json:"user"` + Password string `json:"password"` + DB string `json:"db"` + Mode int `json:"mode"` +} + +func TestMode(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + wrongMode := 999 + connReq := TestConnRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: wrongMode} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0xffff, connResp.Code) + assert.Equal(t, fmt.Sprintf("unexpected mode:%d", wrongMode), connResp.Message) + + //bi + biMode := 0 + connReq = TestConnRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: biMode} + resp, err = doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + +} + +func TestWsQuery(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful("drop database if exists test_ws_query", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("create database if not exists test_ws_query", "") + assert.Equal(t, 0, code, message) + code, message = doRestful( + "create table if not exists stb1 (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", + "test_ws_query") + assert.Equal(t, 0, code, message) + code, message = doRestful( + `insert into t1 using stb1 tags ('{\"table\":\"t1\"}') values (now-2s,true,2,3,4,5,6,7,8,9,10,11,'中文\"binary','中文nchar','\xaabbcc','point(100 100)')(now-1s,false,12,13,14,15,16,17,18,19,110,111,'中文\"binary','中文nchar','\xaabbcc','point(100 100)')(now,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null)`, + "test_ws_query") + assert.Equal(t, 0, code, message) + + code, message = doRestful("create table t2 using stb1 tags('{\"table\":\"t2\"}')", "test_ws_query") + assert.Equal(t, 0, code, message) + code, message = doRestful("create table t3 using stb1 tags('{\"table\":\"t3\"}')", "test_ws_query") + assert.Equal(t, 0, code, message) + + defer doRestful("drop database if exists test_ws_query", "") + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_query"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // query + queryReq := queryRequest{ReqID: 2, Sql: "select * from stb1"} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + var queryResp queryResponse + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), queryResp.ReqID) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch + fetchReq := fetchRequest{ReqID: 3, ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + var fetchResp fetchResponse + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), fetchResp.ReqID) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + assert.Equal(t, 3, fetchResp.Rows) + + // fetch block + fetchBlockReq := fetchBlockRequest{ReqID: 4, ID: queryResp.ID} + fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) + assert.NoError(t, err) + resultID, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) + assert.Equal(t, uint64(1), resultID) + checkBlockResult(t, blockResult) + + fetchReq = fetchRequest{ReqID: 5, ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, uint64(5), fetchResp.ReqID) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + + assert.Equal(t, true, fetchResp.Completed) + + // write block + var buffer bytes.Buffer + wstool.WriteUint64(&buffer, 300) // req id + wstool.WriteUint64(&buffer, 400) // message id + wstool.WriteUint64(&buffer, uint64(RawBlockMessage)) // action + wstool.WriteUint32(&buffer, uint32(fetchResp.Rows)) // rows + wstool.WriteUint16(&buffer, uint16(2)) // table name length + buffer.WriteString("t2") // table name + buffer.Write(fetchBlockResp[16:]) // raw block + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + var writeResp commonResp + err = json.Unmarshal(resp, &writeResp) + assert.NoError(t, err) + assert.Equal(t, 0, writeResp.Code, writeResp.Message) + + // query + queryReq = queryRequest{ReqID: 6, Sql: "select * from t2"} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch + fetchReq = fetchRequest{ReqID: 7, ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + + // fetch block + fetchBlockReq = fetchBlockRequest{ReqID: 8, ID: queryResp.ID} + fetchBlockResp, err = doWebSocket(ws, WSFetchBlock, &fetchBlockReq) + assert.NoError(t, err) + resultID, blockResult = parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) + checkBlockResult(t, blockResult) + assert.Equal(t, queryResp.ID, resultID) + // fetch + fetchReq = fetchRequest{ReqID: 9, ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + + assert.Equal(t, true, fetchResp.Completed) + + // write block with filed + buffer.Reset() + wstool.WriteUint64(&buffer, 300) // req id + wstool.WriteUint64(&buffer, 400) // message id + wstool.WriteUint64(&buffer, uint64(RawBlockMessageWithFields)) // action + wstool.WriteUint32(&buffer, uint32(fetchResp.Rows)) // rows + wstool.WriteUint16(&buffer, uint16(2)) // table name length + buffer.WriteString("t3") // table name + buffer.Write(fetchBlockResp[16:]) // raw block + fields := []byte{ + // ts + 0x74, 0x73, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x09, + // padding + 0x00, 0x00, + // bytes + 0x08, 0x00, 0x00, 0x00, + + // v1 + 0x76, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x01, + // padding + 0x00, 0x00, + // bytes + 0x01, 0x00, 0x00, 0x00, + + // v2 + 0x76, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x02, + // padding + 0x00, 0x00, + // bytes + 0x01, 0x00, 0x00, 0x00, + + // v3 + 0x76, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x03, + // padding + 0x00, 0x00, + // bytes + 0x02, 0x00, 0x00, 0x00, + + // v4 + 0x76, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x04, + // padding + 0x00, 0x00, + // bytes + 0x04, 0x00, 0x00, 0x00, + + // v5 + 0x76, 0x35, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x05, + // padding + 0x00, 0x00, + // bytes + 0x08, 0x00, 0x00, 0x00, + + // v6 + 0x76, 0x36, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0b, + // padding + 0x00, 0x00, + // bytes + 0x01, 0x00, 0x00, 0x00, + + // v7 + 0x76, 0x37, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0c, + // padding + 0x00, 0x00, + // bytes + 0x02, 0x00, 0x00, 0x00, + + // v8 + 0x76, 0x38, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0d, + // padding + 0x00, 0x00, + // bytes + 0x04, 0x00, 0x00, 0x00, + + // v9 + 0x76, 0x39, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0e, + // padding + 0x00, 0x00, + // bytes + 0x08, 0x00, 0x00, 0x00, + + // v10 + 0x76, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x06, + // padding + 0x00, 0x00, + // bytes + 0x04, 0x00, 0x00, 0x00, + + // v11 + 0x76, 0x31, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x07, + // padding + 0x00, 0x00, + // bytes + 0x08, 0x00, 0x00, 0x00, + + // v12 + 0x76, 0x31, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x08, + // padding + 0x00, 0x00, + // bytes + 0x14, 0x00, 0x00, 0x00, + + // v13 + 0x76, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0a, + // padding + 0x00, 0x00, + // bytes + 0x14, 0x00, 0x00, 0x00, + + // v14 + 0x76, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x10, + // padding + 0x00, 0x00, + // bytes + 0x14, 0x00, 0x00, 0x00, + + // v15 + 0x76, 0x31, 0x35, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x14, + // padding + 0x00, 0x00, + // bytes + 0x64, 0x00, 0x00, 0x00, + + // info + 0x69, 0x6e, 0x66, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0f, + // padding + 0x00, 0x00, + // bytes + 0x00, 0x10, 0x00, 0x00, + } + buffer.Write(fields) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &writeResp) + assert.NoError(t, err) + assert.Equal(t, 0, writeResp.Code, writeResp.Message) + + // query + queryReq = queryRequest{ReqID: 10, Sql: "select * from t3"} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch + fetchReq = fetchRequest{ReqID: 11, ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + + // fetch block + fetchBlockReq = fetchBlockRequest{ReqID: 12, ID: queryResp.ID} + fetchBlockResp, err = doWebSocket(ws, WSFetchBlock, &fetchBlockReq) + assert.NoError(t, err) + resultID, blockResult = parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) + assert.Equal(t, queryResp.ID, resultID) + checkBlockResult(t, blockResult) + // fetch + fetchReq = fetchRequest{ReqID: 13, ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + + assert.Equal(t, true, fetchResp.Completed) + time.Sleep(time.Second) +} + +type FetchRawBlockResponse struct { + Flag uint64 + Version uint16 + Time uint64 + ReqID uint64 + Code uint32 + Message string + ResultID uint64 + Finished bool + RawBlock []byte +} + +func parseFetchRawBlock(message []byte) *FetchRawBlockResponse { + var resp = &FetchRawBlockResponse{} + resp.Flag = binary.LittleEndian.Uint64(message) + resp.Version = binary.LittleEndian.Uint16(message[16:]) + resp.Time = binary.LittleEndian.Uint64(message[18:]) + resp.ReqID = binary.LittleEndian.Uint64(message[26:]) + resp.Code = binary.LittleEndian.Uint32(message[34:]) + msgLen := int(binary.LittleEndian.Uint32(message[38:])) + resp.Message = string(message[42 : 42+msgLen]) + if resp.Code != 0 { + return resp + } + resp.ResultID = binary.LittleEndian.Uint64(message[42+msgLen:]) + resp.Finished = message[50+msgLen] == 1 + if resp.Finished { + return resp + } + blockLength := binary.LittleEndian.Uint32(message[51+msgLen:]) + resp.RawBlock = message[55+msgLen : 55+msgLen+int(blockLength)] + return resp +} + +func ReadBlockSimple(block unsafe.Pointer, precision int) [][]driver.Value { + blockSize := parser.RawBlockGetNumOfRows(block) + colCount := parser.RawBlockGetNumOfCols(block) + colInfo := make([]parser.RawBlockColInfo, colCount) + parser.RawBlockGetColInfo(block, colInfo) + colTypes := make([]uint8, colCount) + for i := int32(0); i < colCount; i++ { + colTypes[i] = uint8(colInfo[i].ColType) + } + return parser.ReadBlock(block, int(blockSize), colTypes, precision) +} + +func TestWsBinaryQuery(t *testing.T) { + dbName := "test_ws_binary_query" + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful(fmt.Sprintf("drop database if exists %s", dbName), "") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf("create database if not exists %s", dbName), "") + assert.Equal(t, 0, code, message) + code, message = doRestful( + "create table if not exists stb1 (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", + dbName) + assert.Equal(t, 0, code, message) + code, message = doRestful( + `insert into t1 using stb1 tags ('{\"table\":\"t1\"}') values (now-2s,true,2,3,4,5,6,7,8,9,10,11,'中文\"binary','中文nchar','\xaabbcc','point(100 100)')(now-1s,false,12,13,14,15,16,17,18,19,110,111,'中文\"binary','中文nchar','\xaabbcc','point(100 100)')(now,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null)`, + dbName) + assert.Equal(t, 0, code, message) + + code, message = doRestful("create table t2 using stb1 tags('{\"table\":\"t2\"}')", dbName) + assert.Equal(t, 0, code, message) + code, message = doRestful("create table t3 using stb1 tags('{\"table\":\"t3\"}')", dbName) + assert.Equal(t, 0, code, message) + + defer doRestful(fmt.Sprintf("drop database if exists %s", dbName), "") + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: dbName} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // query + sql := "select * from stb1" + var buffer bytes.Buffer + wstool.WriteUint64(&buffer, 2) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) + wstool.WriteUint16(&buffer, 1) // version + wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length + buffer.WriteString(sql) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + var queryResp queryResponse + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), queryResp.ReqID) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch raw block + buffer.Reset() + wstool.WriteUint64(&buffer, 3) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 1) // version + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp := parseFetchRawBlock(resp) + assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) + assert.Equal(t, uint64(3), fetchRawBlockResp.ReqID) + assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) + assert.Equal(t, uint64(1), fetchRawBlockResp.ResultID) + assert.Equal(t, false, fetchRawBlockResp.Finished) + rows := parser.RawBlockGetNumOfRows(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0])) + assert.Equal(t, int32(3), rows) + blockResult := ReadBlockSimple(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0]), queryResp.Precision) + checkBlockResult(t, blockResult) + rawBlock := fetchRawBlockResp.RawBlock + + buffer.Reset() + wstool.WriteUint64(&buffer, 5) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 1) // version + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) + assert.Equal(t, uint64(5), fetchRawBlockResp.ReqID) + assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) + assert.Equal(t, uint64(1), fetchRawBlockResp.ResultID) + assert.Equal(t, true, fetchRawBlockResp.Finished) + + // write block + + buffer.Reset() + wstool.WriteUint64(&buffer, 300) // req id + wstool.WriteUint64(&buffer, 400) // message id + wstool.WriteUint64(&buffer, uint64(RawBlockMessage)) // action + wstool.WriteUint32(&buffer, uint32(3)) // rows + wstool.WriteUint16(&buffer, uint16(2)) // table name length + buffer.WriteString("t2") // table name + buffer.Write(rawBlock) // raw block + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + var writeResp commonResp + err = json.Unmarshal(resp, &writeResp) + assert.NoError(t, err) + assert.Equal(t, 0, writeResp.Code, writeResp.Message) + + // query + sql = "select * from t2" + buffer.Reset() + wstool.WriteUint64(&buffer, 6) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) + wstool.WriteUint16(&buffer, 1) // version + wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length + buffer.WriteString(sql) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch raw block + buffer.Reset() + wstool.WriteUint64(&buffer, 7) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 1) // version + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) + assert.Equal(t, uint64(7), fetchRawBlockResp.ReqID) + assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) + assert.Equal(t, false, fetchRawBlockResp.Finished) + blockResult = ReadBlockSimple(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0]), queryResp.Precision) + checkBlockResult(t, blockResult) + rawBlock = fetchRawBlockResp.RawBlock + + buffer.Reset() + wstool.WriteUint64(&buffer, 9) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 1) // version + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) + assert.Equal(t, uint64(9), fetchRawBlockResp.ReqID) + assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) + assert.Equal(t, true, fetchRawBlockResp.Finished) + + // write block with filed + buffer.Reset() + wstool.WriteUint64(&buffer, 300) // req id + wstool.WriteUint64(&buffer, 400) // message id + wstool.WriteUint64(&buffer, uint64(RawBlockMessageWithFields)) // action + wstool.WriteUint32(&buffer, uint32(3)) // rows + wstool.WriteUint16(&buffer, uint16(2)) // table name length + buffer.WriteString("t3") // table name + buffer.Write(rawBlock) // raw block + fields := []byte{ + // ts + 0x74, 0x73, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x09, + // padding + 0x00, 0x00, + // bytes + 0x08, 0x00, 0x00, 0x00, + + // v1 + 0x76, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x01, + // padding + 0x00, 0x00, + // bytes + 0x01, 0x00, 0x00, 0x00, + + // v2 + 0x76, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x02, + // padding + 0x00, 0x00, + // bytes + 0x01, 0x00, 0x00, 0x00, + + // v3 + 0x76, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x03, + // padding + 0x00, 0x00, + // bytes + 0x02, 0x00, 0x00, 0x00, + + // v4 + 0x76, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x04, + // padding + 0x00, 0x00, + // bytes + 0x04, 0x00, 0x00, 0x00, + + // v5 + 0x76, 0x35, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x05, + // padding + 0x00, 0x00, + // bytes + 0x08, 0x00, 0x00, 0x00, + + // v6 + 0x76, 0x36, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0b, + // padding + 0x00, 0x00, + // bytes + 0x01, 0x00, 0x00, 0x00, + + // v7 + 0x76, 0x37, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0c, + // padding + 0x00, 0x00, + // bytes + 0x02, 0x00, 0x00, 0x00, + + // v8 + 0x76, 0x38, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0d, + // padding + 0x00, 0x00, + // bytes + 0x04, 0x00, 0x00, 0x00, + + // v9 + 0x76, 0x39, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0e, + // padding + 0x00, 0x00, + // bytes + 0x08, 0x00, 0x00, 0x00, + + // v10 + 0x76, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x06, + // padding + 0x00, 0x00, + // bytes + 0x04, 0x00, 0x00, 0x00, + + // v11 + 0x76, 0x31, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x07, + // padding + 0x00, 0x00, + // bytes + 0x08, 0x00, 0x00, 0x00, + + // v12 + 0x76, 0x31, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x08, + // padding + 0x00, 0x00, + // bytes + 0x14, 0x00, 0x00, 0x00, + + // v13 + 0x76, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0a, + // padding + 0x00, 0x00, + // bytes + 0x14, 0x00, 0x00, 0x00, + + // v14 + 0x76, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x10, + // padding + 0x00, 0x00, + // bytes + 0x14, 0x00, 0x00, 0x00, + + // v15 + 0x76, 0x31, 0x35, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x14, + // padding + 0x00, 0x00, + // bytes + 0x64, 0x00, 0x00, 0x00, + + // info + 0x69, 0x6e, 0x66, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x0f, + // padding + 0x00, 0x00, + // bytes + 0x00, 0x10, 0x00, 0x00, + } + buffer.Write(fields) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &writeResp) + assert.NoError(t, err) + assert.Equal(t, 0, writeResp.Code, writeResp.Message) + + // query + sql = "select * from t3" + buffer.Reset() + wstool.WriteUint64(&buffer, 6) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) + wstool.WriteUint16(&buffer, 1) // version + wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length + buffer.WriteString(sql) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch raw block + buffer.Reset() + wstool.WriteUint64(&buffer, 11) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 1) // version + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) + assert.Equal(t, uint64(11), fetchRawBlockResp.ReqID) + assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) + assert.Equal(t, false, fetchRawBlockResp.Finished) + blockResult = ReadBlockSimple(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0]), queryResp.Precision) + checkBlockResult(t, blockResult) + + buffer.Reset() + wstool.WriteUint64(&buffer, 13) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 1) // version + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) + assert.Equal(t, uint64(13), fetchRawBlockResp.ReqID) + assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) + assert.Equal(t, true, fetchRawBlockResp.Finished) + + // wrong message length + buffer.Reset() + wstool.WriteUint64(&buffer, 6) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) + wstool.WriteUint16(&buffer, 1) // version + wstool.WriteUint32(&buffer, 0) // sql length + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 65535, queryResp.Code, queryResp.Message) + + // wrong sql length + buffer.Reset() + wstool.WriteUint64(&buffer, 6) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) + wstool.WriteUint16(&buffer, 1) // version + wstool.WriteUint32(&buffer, 100) // sql length + buffer.WriteString("wrong sql length") + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 65535, queryResp.Code, queryResp.Message) + + // wrong version + buffer.Reset() + sql = "select 1" + wstool.WriteUint64(&buffer, 6) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) + wstool.WriteUint16(&buffer, 100) // version + wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length + buffer.WriteString(sql) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 65535, queryResp.Code, queryResp.Message) + + // wrong sql + buffer.Reset() + sql = "wrong sql" + wstool.WriteUint64(&buffer, 6) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) + wstool.WriteUint16(&buffer, 1) // version + wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length + buffer.WriteString(sql) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.NotEqual(t, 0, queryResp.Code, queryResp.Message) + + // insert sql + buffer.Reset() + sql = "create table t4 using stb1 tags('{\"table\":\"t4\"}')" + wstool.WriteUint64(&buffer, 6) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) + wstool.WriteUint16(&buffer, 1) // version + wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length + buffer.WriteString(sql) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + assert.Equal(t, true, queryResp.IsUpdate) + + // wrong fetch + buffer.Reset() + sql = "select 1" + wstool.WriteUint64(&buffer, 6) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) + wstool.WriteUint16(&buffer, 1) // version + wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length + buffer.WriteString(sql) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // wrong fetch raw block length + buffer.Reset() + wstool.WriteUint64(&buffer, 700) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(700), fetchRawBlockResp.ReqID) + assert.NotEqual(t, uint32(0), fetchRawBlockResp.Code) + + // wrong fetch raw block version + buffer.Reset() + wstool.WriteUint64(&buffer, 800) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 100) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(800), fetchRawBlockResp.ReqID) + assert.NotEqual(t, uint32(0), fetchRawBlockResp.Code) + time.Sleep(time.Second) + + // wrong fetch raw block result + buffer.Reset() + wstool.WriteUint64(&buffer, 900) // req id + wstool.WriteUint64(&buffer, queryResp.ID+100) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 1) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(900), fetchRawBlockResp.ReqID) + assert.NotEqual(t, uint32(0), fetchRawBlockResp.Code) + + // fetch freed raw block + buffer.Reset() + wstool.WriteUint64(&buffer, 600) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 1) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) + assert.Equal(t, uint64(600), fetchRawBlockResp.ReqID) + assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) + assert.Equal(t, int32(1), parser.RawBlockGetNumOfRows(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0]))) + + buffer.Reset() + wstool.WriteUint64(&buffer, 700) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 1) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) + assert.Equal(t, uint64(700), fetchRawBlockResp.ReqID) + assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) + assert.Equal(t, true, fetchRawBlockResp.Finished) + + buffer.Reset() + wstool.WriteUint64(&buffer, 400) // req id + wstool.WriteUint64(&buffer, queryResp.ID) // message id + wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) + wstool.WriteUint16(&buffer, 1) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + fetchRawBlockResp = parseFetchRawBlock(resp) + assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) + assert.Equal(t, uint64(400), fetchRawBlockResp.ReqID) + assert.NotEqual(t, uint32(0), fetchRawBlockResp.Code) + time.Sleep(time.Second) +} + +func checkBlockResult(t *testing.T, blockResult [][]driver.Value) { + assert.Equal(t, 3, len(blockResult)) + assert.Equal(t, true, blockResult[0][1]) + assert.Equal(t, int8(2), blockResult[0][2]) + assert.Equal(t, int16(3), blockResult[0][3]) + assert.Equal(t, int32(4), blockResult[0][4]) + assert.Equal(t, int64(5), blockResult[0][5]) + assert.Equal(t, uint8(6), blockResult[0][6]) + assert.Equal(t, uint16(7), blockResult[0][7]) + assert.Equal(t, uint32(8), blockResult[0][8]) + assert.Equal(t, uint64(9), blockResult[0][9]) + assert.Equal(t, float32(10), blockResult[0][10]) + assert.Equal(t, float64(11), blockResult[0][11]) + assert.Equal(t, "中文\"binary", blockResult[0][12]) + assert.Equal(t, "中文nchar", blockResult[0][13]) + assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[0][14]) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[0][15]) + assert.Equal(t, false, blockResult[1][1]) + assert.Equal(t, int8(12), blockResult[1][2]) + assert.Equal(t, int16(13), blockResult[1][3]) + assert.Equal(t, int32(14), blockResult[1][4]) + assert.Equal(t, int64(15), blockResult[1][5]) + assert.Equal(t, uint8(16), blockResult[1][6]) + assert.Equal(t, uint16(17), blockResult[1][7]) + assert.Equal(t, uint32(18), blockResult[1][8]) + assert.Equal(t, uint64(19), blockResult[1][9]) + assert.Equal(t, float32(110), blockResult[1][10]) + assert.Equal(t, float64(111), blockResult[1][11]) + assert.Equal(t, "中文\"binary", blockResult[1][12]) + assert.Equal(t, "中文nchar", blockResult[1][13]) + assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[1][15]) + assert.Equal(t, nil, blockResult[2][1]) + assert.Equal(t, nil, blockResult[2][2]) + assert.Equal(t, nil, blockResult[2][3]) + assert.Equal(t, nil, blockResult[2][4]) + assert.Equal(t, nil, blockResult[2][5]) + assert.Equal(t, nil, blockResult[2][6]) + assert.Equal(t, nil, blockResult[2][7]) + assert.Equal(t, nil, blockResult[2][8]) + assert.Equal(t, nil, blockResult[2][9]) + assert.Equal(t, nil, blockResult[2][10]) + assert.Equal(t, nil, blockResult[2][11]) + assert.Equal(t, nil, blockResult[2][12]) + assert.Equal(t, nil, blockResult[2][13]) + assert.Equal(t, nil, blockResult[2][14]) + assert.Equal(t, nil, blockResult[2][15]) +} diff --git a/controller/ws/ws/raw.go b/controller/ws/ws/raw.go new file mode 100644 index 00000000..89111c02 --- /dev/null +++ b/controller/ws/ws/raw.go @@ -0,0 +1,77 @@ +package ws + +import ( + "context" + "unsafe" + + "github.com/sirupsen/logrus" + "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/driver-go/v3/wrapper" + "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/tools" + "github.com/taosdata/taosadapter/v3/tools/melody" +) + +func (h *messageHandler) binaryTMQRawMessage(ctx context.Context, session *melody.Session, action string, reqID uint64, message []byte, logger *logrus.Entry, isDebug bool) { + p0 := unsafe.Pointer(&message[0]) + length := *(*uint32)(tools.AddPointer(p0, uintptr(24))) + metaType := *(*uint16)(tools.AddPointer(p0, uintptr(28))) + data := tools.AddPointer(p0, uintptr(30)) + logger.Tracef("get write raw message, length:%d, metaType:%d", length, metaType) + logger.Trace("get global lock for raw message") + meta := wrapper.BuildRawMeta(length, metaType, data) + code := syncinterface.TMQWriteRaw(h.conn, meta, logger, isDebug) + if code != 0 { + errStr := wrapper.TMQErr2Str(code) + logger.Errorf("write raw meta error, code:%d, msg:%s", code, errStr) + commonErrorResponse(ctx, session, logger, action, reqID, int(code), errStr) + return + } + logger.Trace("write raw meta success") + commonSuccessResponse(ctx, session, logger, action, reqID) +} + +func (h *messageHandler) binaryRawBlockMessage(ctx context.Context, session *melody.Session, action string, reqID uint64, message []byte, logger *logrus.Entry, isDebug bool) { + p0 := unsafe.Pointer(&message[0]) + numOfRows := *(*int32)(tools.AddPointer(p0, uintptr(24))) + tableNameLength := *(*uint16)(tools.AddPointer(p0, uintptr(28))) + tableName := make([]byte, tableNameLength) + for i := 0; i < int(tableNameLength); i++ { + tableName[i] = *(*byte)(tools.AddPointer(p0, uintptr(30+i))) + } + rawBlock := tools.AddPointer(p0, uintptr(30+tableNameLength)) + logger.Tracef("raw block message, table:%s, rows:%d", tableName, numOfRows) + code := syncinterface.TaosWriteRawBlockWithReqID(h.conn, int(numOfRows), rawBlock, string(tableName), int64(reqID), logger, isDebug) + if code != 0 { + errStr := wrapper.TMQErr2Str(int32(code)) + logger.Errorf("write raw meta error, code:%d, msg:%s", code, errStr) + commonErrorResponse(ctx, session, logger, action, reqID, code, errStr) + return + } + logger.Trace("write raw meta success") + commonSuccessResponse(ctx, session, logger, action, reqID) +} + +func (h *messageHandler) binaryRawBlockMessageWithFields(ctx context.Context, session *melody.Session, action string, reqID uint64, message []byte, logger *logrus.Entry, isDebug bool) { + p0 := unsafe.Pointer(&message[0]) + numOfRows := *(*int32)(tools.AddPointer(p0, uintptr(24))) + tableNameLength := int(*(*uint16)(tools.AddPointer(p0, uintptr(28)))) + tableName := make([]byte, tableNameLength) + for i := 0; i < tableNameLength; i++ { + tableName[i] = *(*byte)(tools.AddPointer(p0, uintptr(30+i))) + } + rawBlock := tools.AddPointer(p0, uintptr(30+tableNameLength)) + blockLength := int(parser.RawBlockGetLength(rawBlock)) + numOfColumn := int(parser.RawBlockGetNumOfCols(rawBlock)) + fieldsBlock := tools.AddPointer(p0, uintptr(30+tableNameLength+blockLength)) + logger.Tracef("raw block message with fields, table:%s, rows:%d", tableName, numOfRows) + code := syncinterface.TaosWriteRawBlockWithFieldsWithReqID(h.conn, int(numOfRows), rawBlock, string(tableName), fieldsBlock, numOfColumn, int64(reqID), logger, isDebug) + if code != 0 { + errStr := wrapper.TMQErr2Str(int32(code)) + logger.Errorf("write raw meta error, code:%d, err:%s", code, errStr) + commonErrorResponse(ctx, session, logger, action, reqID, code, errStr) + return + } + logger.Trace("write raw meta success") + commonSuccessResponse(ctx, session, logger, action, reqID) +} diff --git a/controller/ws/ws/raw_test.go b/controller/ws/ws/raw_test.go new file mode 100644 index 00000000..5b1c40ad --- /dev/null +++ b/controller/ws/ws/raw_test.go @@ -0,0 +1,124 @@ +package ws + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" +) + +func TestWSTMQWriteRaw(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + data := []byte{ + 0x64, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0x01, 0x00, 0x00, 0x04, 0x73, 0x74, 0x62, + 0x00, 0xd5, 0xf0, 0xed, 0x8a, 0xe0, 0x23, 0xf3, 0x45, 0x00, 0x1c, 0x02, 0x09, 0x01, 0x10, 0x02, + 0x03, 0x74, 0x73, 0x00, 0x01, 0x01, 0x02, 0x04, 0x03, 0x63, 0x31, 0x00, 0x02, 0x01, 0x02, 0x06, + 0x03, 0x63, 0x32, 0x00, 0x03, 0x01, 0x04, 0x08, 0x03, 0x63, 0x33, 0x00, 0x04, 0x01, 0x08, 0x0a, + 0x03, 0x63, 0x34, 0x00, 0x05, 0x01, 0x10, 0x0c, 0x03, 0x63, 0x35, 0x00, 0x0b, 0x01, 0x02, 0x0e, + 0x03, 0x63, 0x36, 0x00, 0x0c, 0x01, 0x04, 0x10, 0x03, 0x63, 0x37, 0x00, 0x0d, 0x01, 0x08, 0x12, + 0x03, 0x63, 0x38, 0x00, 0x0e, 0x01, 0x10, 0x14, 0x03, 0x63, 0x39, 0x00, 0x06, 0x01, 0x08, 0x16, + 0x04, 0x63, 0x31, 0x30, 0x00, 0x07, 0x01, 0x10, 0x18, 0x04, 0x63, 0x31, 0x31, 0x00, 0x08, 0x01, + 0x2c, 0x1a, 0x04, 0x63, 0x31, 0x32, 0x00, 0x0a, 0x01, 0xa4, 0x01, 0x1c, 0x04, 0x63, 0x31, 0x33, + 0x00, 0x1c, 0x02, 0x09, 0x02, 0x10, 0x1e, 0x04, 0x74, 0x74, 0x73, 0x00, 0x01, 0x00, 0x02, 0x20, + 0x04, 0x74, 0x63, 0x31, 0x00, 0x02, 0x00, 0x02, 0x22, 0x04, 0x74, 0x63, 0x32, 0x00, 0x03, 0x00, + 0x04, 0x24, 0x04, 0x74, 0x63, 0x33, 0x00, 0x04, 0x00, 0x08, 0x26, 0x04, 0x74, 0x63, 0x34, 0x00, + 0x05, 0x00, 0x10, 0x28, 0x04, 0x74, 0x63, 0x35, 0x00, 0x0b, 0x00, 0x02, 0x2a, 0x04, 0x74, 0x63, + 0x36, 0x00, 0x0c, 0x00, 0x04, 0x2c, 0x04, 0x74, 0x63, 0x37, 0x00, 0x0d, 0x00, 0x08, 0x2e, 0x04, + 0x74, 0x63, 0x38, 0x00, 0x0e, 0x00, 0x10, 0x30, 0x04, 0x74, 0x63, 0x39, 0x00, 0x06, 0x00, 0x08, + 0x32, 0x05, 0x74, 0x63, 0x31, 0x30, 0x00, 0x07, 0x00, 0x10, 0x34, 0x05, 0x74, 0x63, 0x31, 0x31, + 0x00, 0x08, 0x00, 0x2c, 0x36, 0x05, 0x74, 0x63, 0x31, 0x32, 0x00, 0x0a, 0x00, 0xa4, 0x01, 0x38, + 0x05, 0x74, 0x63, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x1c, 0x02, 0x02, 0x02, + 0x01, 0x00, 0x02, 0x04, 0x02, 0x01, 0x00, 0x03, 0x06, 0x02, 0x01, 0x00, 0x01, 0x08, 0x02, 0x01, + 0x00, 0x01, 0x0a, 0x02, 0x01, 0x00, 0x01, 0x0c, 0x02, 0x01, 0x00, 0x01, 0x0e, 0x02, 0x01, 0x00, + 0x01, 0x10, 0x02, 0x01, 0x00, 0x01, 0x12, 0x02, 0x01, 0x00, 0x01, 0x14, 0x02, 0x01, 0x00, 0x01, + 0x16, 0x02, 0x01, 0x00, 0x04, 0x18, 0x02, 0x01, 0x00, 0x04, 0x1a, 0x02, 0x01, 0x00, 0xff, 0x1c, + 0x02, 0x01, 0x00, 0xff, + } + length := uint32(356) + metaType := uint16(531) + code, message := doRestful("create database if not exists test_ws_tmq_write_raw", "") + assert.Equal(t, 0, code, message) + defer func() { + code, message := doRestful("drop database if exists test_ws_tmq_write_raw", "") + assert.Equal(t, 0, code, message) + }() + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_tmq_write_raw"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + buffer := bytes.Buffer{} + wstool.WriteUint64(&buffer, 2) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(TMQRawMessage)) + wstool.WriteUint32(&buffer, length) + wstool.WriteUint16(&buffer, metaType) + buffer.Write(data) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + var tmqResp commonResp + err = json.Unmarshal(resp, &tmqResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), tmqResp.ReqID) + assert.Equal(t, 0, tmqResp.Code, tmqResp.Message) + + d := restQuery("describe stb", "test_ws_tmq_write_raw") + expect := [][]driver.Value{ + {"ts", "TIMESTAMP", float64(8), ""}, + {"c1", "BOOL", float64(1), ""}, + {"c2", "TINYINT", float64(1), ""}, + {"c3", "SMALLINT", float64(2), ""}, + {"c4", "INT", float64(4), ""}, + {"c5", "BIGINT", float64(8), ""}, + {"c6", "TINYINT UNSIGNED", float64(1), ""}, + {"c7", "SMALLINT UNSIGNED", float64(2), ""}, + {"c8", "INT UNSIGNED", float64(4), ""}, + {"c9", "BIGINT UNSIGNED", float64(8), ""}, + {"c10", "FLOAT", float64(4), ""}, + {"c11", "DOUBLE", float64(8), ""}, + {"c12", "VARCHAR", float64(20), ""}, + {"c13", "NCHAR", float64(20), ""}, + {"tts", "TIMESTAMP", float64(8), "TAG"}, + {"tc1", "BOOL", float64(1), "TAG"}, + {"tc2", "TINYINT", float64(1), "TAG"}, + {"tc3", "SMALLINT", float64(2), "TAG"}, + {"tc4", "INT", float64(4), "TAG"}, + {"tc5", "BIGINT", float64(8), "TAG"}, + {"tc6", "TINYINT UNSIGNED", float64(1), "TAG"}, + {"tc7", "SMALLINT UNSIGNED", float64(2), "TAG"}, + {"tc8", "INT UNSIGNED", float64(4), "TAG"}, + {"tc9", "BIGINT UNSIGNED", float64(8), "TAG"}, + {"tc10", "FLOAT", float64(4), "TAG"}, + {"tc11", "DOUBLE", float64(8), "TAG"}, + {"tc12", "VARCHAR", float64(20), "TAG"}, + {"tc13", "NCHAR", float64(20), "TAG"}, + } + for rowIndex, values := range d.Data { + for i := 0; i < 4; i++ { + assert.Equal(t, expect[rowIndex][i], values[i]) + } + } +} diff --git a/controller/ws/ws/resp.go b/controller/ws/ws/resp.go new file mode 100644 index 00000000..3129f55f --- /dev/null +++ b/controller/ws/ws/resp.go @@ -0,0 +1,112 @@ +package ws + +import ( + "context" + "encoding/binary" + "unsafe" + + "github.com/sirupsen/logrus" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/tools/bytesutil" + "github.com/taosdata/taosadapter/v3/tools/melody" +) + +type commonResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` +} + +func commonErrorResponse(ctx context.Context, session *melody.Session, logger *logrus.Entry, action string, reqID uint64, code int, message string) { + data := &commonResp{ + Code: code & 0xffff, + Message: message, + Action: action, + ReqID: reqID, + Timing: wstool.GetDuration(ctx), + } + wstool.WSWriteJson(session, logger, data) +} + +func commonSuccessResponse(ctx context.Context, session *melody.Session, logger *logrus.Entry, action string, reqID uint64) { + data := &commonResp{ + Action: action, + ReqID: reqID, + Timing: wstool.GetDuration(ctx), + } + wstool.WSWriteJson(session, logger, data) +} + +func fetchRawBlockErrorResponse(session *melody.Session, logger *logrus.Entry, code int, message string, reqID uint64, resultID uint64, t uint64) { + bufLength := 8 + 8 + 2 + 8 + 8 + 4 + 4 + len(message) + 8 + 1 + buf := make([]byte, bufLength) + binary.LittleEndian.PutUint64(buf, 0xffffffffffffffff) + binary.LittleEndian.PutUint64(buf[8:], uint64(FetchRawBlockMessage)) + binary.LittleEndian.PutUint16(buf[16:], 1) + binary.LittleEndian.PutUint64(buf[18:], t) + binary.LittleEndian.PutUint64(buf[26:], reqID) + binary.LittleEndian.PutUint32(buf[34:], uint32(code&0xffff)) + binary.LittleEndian.PutUint32(buf[38:], uint32(len(message))) + copy(buf[42:], message) + binary.LittleEndian.PutUint64(buf[42+len(message):], resultID) + buf[42+len(message)+8] = 1 + wstool.WSWriteBinary(session, buf, logger) +} + +func fetchRawBlockFinishResponse(session *melody.Session, logger *logrus.Entry, reqID uint64, resultID uint64, t uint64) { + bufLength := 8 + 8 + 2 + 8 + 8 + 4 + 4 + 8 + 1 + buf := make([]byte, bufLength) + binary.LittleEndian.PutUint64(buf, 0xffffffffffffffff) + binary.LittleEndian.PutUint64(buf[8:], uint64(FetchRawBlockMessage)) + binary.LittleEndian.PutUint16(buf[16:], 1) + binary.LittleEndian.PutUint64(buf[18:], t) + binary.LittleEndian.PutUint64(buf[26:], reqID) + binary.LittleEndian.PutUint32(buf[34:], 0) + binary.LittleEndian.PutUint32(buf[38:], 0) + binary.LittleEndian.PutUint64(buf[42:], resultID) + buf[50] = 1 + wstool.WSWriteBinary(session, buf, logger) +} + +func fetchRawBlockMessage(buf []byte, reqID uint64, resultID uint64, t uint64, blockLength int32, rawBlock unsafe.Pointer) []byte { + bufLength := 8 + 8 + 2 + 8 + 8 + 4 + 4 + 8 + 1 + 4 + int(blockLength) + if cap(buf) < bufLength { + buf = make([]byte, 0, bufLength) + } + buf = buf[:bufLength] + binary.LittleEndian.PutUint64(buf, 0xffffffffffffffff) + binary.LittleEndian.PutUint64(buf[8:], uint64(FetchRawBlockMessage)) + binary.LittleEndian.PutUint16(buf[16:], 1) + binary.LittleEndian.PutUint64(buf[18:], t) + binary.LittleEndian.PutUint64(buf[26:], reqID) + binary.LittleEndian.PutUint32(buf[34:], 0) + binary.LittleEndian.PutUint32(buf[38:], 0) + binary.LittleEndian.PutUint64(buf[42:], resultID) + buf[50] = 0 + binary.LittleEndian.PutUint32(buf[51:], uint32(blockLength)) + bytesutil.Copy(rawBlock, buf, 55, int(blockLength)) + return buf +} + +type stmtErrorResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +func stmtErrorResponse(ctx context.Context, session *melody.Session, logger *logrus.Entry, action string, reqID uint64, code int, message string, stmtID uint64) { + resp := &stmtErrorResp{ + Code: code & 0xffff, + Message: message, + Action: action, + ReqID: reqID, + Timing: wstool.GetDuration(ctx), + StmtID: stmtID, + } + wstool.WSWriteJson(session, logger, resp) +} diff --git a/controller/ws/ws/schemaless.go b/controller/ws/ws/schemaless.go new file mode 100644 index 00000000..532cea1c --- /dev/null +++ b/controller/ws/ws/schemaless.go @@ -0,0 +1,57 @@ +package ws + +import ( + "context" + + "github.com/sirupsen/logrus" + "github.com/taosdata/driver-go/v3/wrapper" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/tools/melody" +) + +type schemalessWriteRequest struct { + ReqID uint64 `json:"req_id"` + Protocol int `json:"protocol"` + Precision string `json:"precision"` + TTL int `json:"ttl"` + Data string `json:"data"` + TableNameKey string `json:"table_name_key"` +} + +type schemalessWriteResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + AffectedRows int `json:"affected_rows"` + TotalRows int32 `json:"total_rows"` +} + +func (h *messageHandler) schemalessWrite(ctx context.Context, session *melody.Session, action string, req *schemalessWriteRequest, logger *logrus.Entry, isDebug bool) { + if req.Protocol == 0 { + logger.Error("schemaless write request error. protocol is null") + commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, "schemaless write protocol is null") + return + } + var affectedRows int + totalRows, result := syncinterface.TaosSchemalessInsertRawTTLWithReqIDTBNameKey(h.conn, req.Data, req.Protocol, req.Precision, req.TTL, int64(req.ReqID), req.TableNameKey, logger, isDebug) + defer syncinterface.FreeResult(result, logger, isDebug) + if code := wrapper.TaosError(result); code != 0 { + errStr := wrapper.TaosErrorStr(result) + logger.Errorf("schemaless write error, code:%d, err:%s", code, errStr) + commonErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr) + return + } + affectedRows = wrapper.TaosAffectedRows(result) + logger.Tracef("schemaless write total rows:%d, affected rows:%d", totalRows, affectedRows) + resp := schemalessWriteResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + AffectedRows: affectedRows, + TotalRows: totalRows, + } + wstool.WSWriteJson(session, logger, resp) +} diff --git a/controller/ws/ws/schemaless_test.go b/controller/ws/ws/schemaless_test.go new file mode 100644 index 00000000..e925a2d6 --- /dev/null +++ b/controller/ws/ws/schemaless_test.go @@ -0,0 +1,241 @@ +package ws + +import ( + "encoding/json" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/ws/schemaless" +) + +func TestWsSchemaless(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful("drop database if exists test_ws_schemaless", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("create database if not exists test_ws_schemaless", "") + assert.Equal(t, 0, code, message) + + defer doRestful("drop database if exists test_ws_schemaless", "") + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + cases := []struct { + name string + protocol int + precision string + data string + ttl int + totalRows int32 + affectedRows int + tableNameKey string + }{ + { + name: "influxdb", + protocol: schemaless.InfluxDBLineProtocol, + precision: "ms", + data: "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837600000", + ttl: 1000, + totalRows: 4, + affectedRows: 4, + }, + { + name: "opentsdb_telnet", + protocol: schemaless.OpenTSDBTelnetLineProtocol, + precision: "ms", + data: "meters.current 1648432611249 10.3 location=California.SanFrancisco group=2\n" + + "meters.current 1648432611250 12.6 location=California.SanFrancisco group=2\n" + + "meters.current 1648432611249 10.8 location=California.LosAngeles group=3\n" + + "meters.current 1648432611250 11.3 location=California.LosAngeles group=3\n" + + "meters.voltage 1648432611249 219 location=California.SanFrancisco group=2\n" + + "meters.voltage 1648432611250 218 location=California.SanFrancisco group=2\n" + + "meters.voltage 1648432611249 221 location=California.LosAngeles group=3\n" + + "meters.voltage 1648432611250 217 location=California.LosAngeles group=3", + ttl: 1000, + totalRows: 8, + affectedRows: 8, + }, + { + name: "opentsdb_json", + protocol: schemaless.OpenTSDBJsonFormatProtocol, + precision: "ms", + data: `[ + { + "metric": "meters2.current", + "timestamp": 1648432611249, + "value": 10.3, + "tags": { + "location": "California.SanFrancisco", + "groupid": 2 + } + }, + { + "metric": "meters2.voltage", + "timestamp": 1648432611249, + "value": 219, + "tags": { + "location": "California.LosAngeles", + "groupid": 1 + } + }, + { + "metric": "meters2.current", + "timestamp": 1648432611250, + "value": 12.6, + "tags": { + "location": "California.SanFrancisco", + "groupid": 2 + } + }, + { + "metric": "meters2.voltage", + "timestamp": 1648432611250, + "value": 221, + "tags": { + "location": "California.LosAngeles", + "groupid": 1 + } + } +]`, + ttl: 100, + affectedRows: 4, + }, + { + name: "influxdb_tbnamekey", + protocol: schemaless.InfluxDBLineProtocol, + precision: "ms", + data: "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837600000", + ttl: 1000, + totalRows: 4, + affectedRows: 4, + tableNameKey: "host", + }, + } + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_schemaless"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + for _, c := range cases { + reqID := uint64(1) + t.Run(c.name, func(t *testing.T) { + reqID += 1 + req := schemalessWriteRequest{ + ReqID: reqID, + Protocol: c.protocol, + Precision: c.precision, + TTL: c.ttl, + Data: c.data, + TableNameKey: c.tableNameKey, + } + resp, err = doWebSocket(ws, SchemalessWrite, &req) + assert.NoError(t, err) + var schemalessResp schemalessWriteResponse + err = json.Unmarshal(resp, &schemalessResp) + assert.NoError(t, err, string(resp)) + assert.Equal(t, reqID, schemalessResp.ReqID) + assert.Equal(t, 0, schemalessResp.Code, schemalessResp.Message) + if c.protocol != schemaless.OpenTSDBJsonFormatProtocol { + assert.Equal(t, c.totalRows, schemalessResp.TotalRows) + } + assert.Equal(t, c.affectedRows, schemalessResp.AffectedRows) + }) + } +} + +func TestWsSchemalessError(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful("drop database if exists test_ws_schemaless_error", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("create database if not exists test_ws_schemaless_error", "") + assert.Equal(t, 0, code, message) + + defer doRestful("drop database if exists test_ws_schemaless_error", "") + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + cases := []struct { + name string + protocol int + precision string + data string + }{ + { + name: "wrong protocol", + protocol: 0, + precision: "ms", + data: "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837600000", + }, + { + name: "wrong timestamp", + protocol: schemaless.InfluxDBLineProtocol, + precision: "ms", + data: "measurement,host=host1 field1=2i,field2=2.0 10", + }, + } + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_schemaless_error"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + for _, c := range cases { + reqID := uint64(1) + t.Run(c.name, func(t *testing.T) { + reqID += 1 + req := schemalessWriteRequest{ + ReqID: reqID, + Protocol: c.protocol, + Precision: c.precision, + Data: c.data, + } + resp, err = doWebSocket(ws, SchemalessWrite, &req) + assert.NoError(t, err) + var schemalessResp schemalessWriteResponse + err = json.Unmarshal(resp, &schemalessResp) + assert.NoError(t, err, string(resp)) + assert.Equal(t, reqID, schemalessResp.ReqID) + assert.NotEqual(t, 0, schemalessResp.Code) + }) + } +} diff --git a/controller/ws/ws/stmt.go b/controller/ws/ws/stmt.go new file mode 100644 index 00000000..ff7fd552 --- /dev/null +++ b/controller/ws/ws/stmt.go @@ -0,0 +1,893 @@ +package ws + +import ( + "context" + "database/sql/driver" + "encoding/json" + "fmt" + "unsafe" + + "github.com/sirupsen/logrus" + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/parser" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" + errors2 "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/types" + "github.com/taosdata/driver-go/v3/wrapper" + "github.com/taosdata/taosadapter/v3/controller/ws/stmt" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/log" + "github.com/taosdata/taosadapter/v3/tools" + "github.com/taosdata/taosadapter/v3/tools/jsontype" + "github.com/taosdata/taosadapter/v3/tools/melody" +) + +type stmtInitRequest struct { + ReqID uint64 `json:"req_id"` +} + +type stmtInitResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +func (h *messageHandler) stmtInit(ctx context.Context, session *melody.Session, action string, req *stmtInitRequest, logger *logrus.Entry, isDebug bool) { + stmtInit := syncinterface.TaosStmtInitWithReqID(h.conn, int64(req.ReqID), logger, isDebug) + if stmtInit == nil { + errStr := wrapper.TaosStmtErrStr(stmtInit) + logger.Errorf("stmt init error, err:%s", errStr) + commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, errStr) + return + } + stmtItem := &StmtItem{stmt: stmtInit} + h.stmts.Add(stmtItem) + logger.Tracef("stmt init sucess, stmt_id:%d, stmt pointer:%p", stmtItem.index, stmtInit) + resp := stmtInitResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: stmtItem.index, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtPrepareRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + SQL string `json:"sql"` +} + +type stmtPrepareResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + IsInsert bool `json:"is_insert"` +} + +func (h *messageHandler) stmtValidateAndLock(ctx context.Context, session *melody.Session, action string, reqID uint64, stmtID uint64, logger *logrus.Entry, isDebug bool) (stmtItem *StmtItem, locked bool) { + stmtItem = h.stmts.Get(stmtID) + if stmtItem == nil { + logger.Errorf("stmt is nil, stmt_id:%d", stmtID) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, "stmt is nil", stmtID) + return nil, false + } + s := log.GetLogNow(isDebug) + logger.Trace("get stmt lock") + stmtItem.Lock() + logger.Debugf("get stmt lock cost:%s", log.GetLogDuration(isDebug, s)) + if stmtItem.stmt == nil { + stmtItem.Unlock() + logger.Errorf("stmt has been freed, stmt_id:%d", stmtID) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, "stmt has been freed", stmtID) + return nil, false + } + return stmtItem, true +} + +func (h *messageHandler) stmtPrepare(ctx context.Context, session *melody.Session, action string, req *stmtPrepareRequest, logger *logrus.Entry, isDebug bool) { + logger.Debugf("stmt prepare, stmt_id:%d, sql:%s", req.StmtID, req.SQL) + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + code := syncinterface.TaosStmtPrepare(stmtItem.stmt, req.SQL, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt prepare error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + logger.Tracef("stmt prepare success, stmt_id:%d", req.StmtID) + isInsert, code := syncinterface.TaosStmtIsInsert(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("check stmt is insert error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + logger.Tracef("stmt is insert:%t", isInsert) + stmtItem.isInsert = isInsert + resp := stmtPrepareResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: stmtItem.index, + IsInsert: isInsert, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtSetTableNameRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + Name string `json:"name"` +} + +type stmtSetTableNameResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +func (h *messageHandler) stmtSetTableName(ctx context.Context, session *melody.Session, action string, req *stmtSetTableNameRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("stmt set table name, stmt_id:%d, name:%s", req.StmtID, req.Name) + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + code := syncinterface.TaosStmtSetTBName(stmtItem.stmt, req.Name, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt set table name error, err:%s", errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + logger.Tracef("stmt set table name success, stmt_id:%d", req.StmtID) + resp := stmtSetTableNameResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtSetTagsRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + Tags json.RawMessage `json:"tags"` +} + +type stmtSetTagsResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +func (h *messageHandler) stmtSetTags(ctx context.Context, session *melody.Session, action string, req *stmtSetTagsRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("stmt set tags, stmt_id:%d, tags:%s", req.StmtID, req.Tags) + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + code, tagNums, tagFields := syncinterface.TaosStmtGetTagFields(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt get tag fields error, err:%s", errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + defer func() { + wrapper.TaosStmtReclaimFields(stmtItem.stmt, tagFields) + }() + logger.Tracef("stmt tag nums:%d", tagNums) + if tagNums == 0 { + logger.Trace("no tags") + resp := stmtSetTagsResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + } + wstool.WSWriteJson(session, logger, resp) + return + } + s := log.GetLogNow(isDebug) + fields := wrapper.StmtParseFields(tagNums, tagFields) + logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) + data, err := stmt.StmtParseTag(req.Tags, fields) + logger.Debugf("stmt parse tag json cost:%s", log.GetLogDuration(isDebug, s)) + if err != nil { + logger.Errorf("stmt parse tag json error, err:%s", err.Error()) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, fmt.Sprintf("stmt parse tag json:%s", err.Error()), req.StmtID) + return + } + code = syncinterface.TaosStmtSetTags(stmtItem.stmt, data, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt set tags error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + logger.Trace("stmt set tags success") + resp := stmtSetTagsResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtBindRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + Columns json.RawMessage `json:"columns"` +} + +type stmtBindResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +func (h *messageHandler) stmtBind(ctx context.Context, session *melody.Session, action string, req *stmtBindRequest, logger *logrus.Entry, isDebug bool) { + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + code, colNums, colFields := syncinterface.TaosStmtGetColFields(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt get col fields error,code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + defer func() { + wrapper.TaosStmtReclaimFields(stmtItem.stmt, colFields) + }() + if colNums == 0 { + logger.Trace("no columns") + resp := stmtBindResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + } + wstool.WSWriteJson(session, logger, resp) + return + } + s := log.GetLogNow(isDebug) + fields := wrapper.StmtParseFields(colNums, colFields) + logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) + fieldTypes := make([]*types.ColumnType, colNums) + + var err error + for i := 0; i < colNums; i++ { + if fieldTypes[i], err = fields[i].GetType(); err != nil { + logger.Errorf("stmt get column type error, err:%s", err.Error()) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, fmt.Sprintf("stmt get column type error, err:%s", err.Error()), req.StmtID) + return + } + } + s = log.GetLogNow(isDebug) + data, err := stmt.StmtParseColumn(req.Columns, fields, fieldTypes) + logger.Debugf("stmt parse column json cost:%s", log.GetLogDuration(isDebug, s)) + if err != nil { + logger.Errorf("stmt parse column json error, err:%s", err.Error()) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, fmt.Sprintf("stmt parse column json:%s", err.Error()), req.StmtID) + return + } + code = syncinterface.TaosStmtBindParamBatch(stmtItem.stmt, data, fieldTypes, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt bind param error, err:%s", errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + logger.Trace("stmt bind success") + resp := stmtBindResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtAddBatchRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type stmtAddBatchResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +func (h *messageHandler) stmtAddBatch(ctx context.Context, session *melody.Session, action string, req *stmtAddBatchRequest, logger *logrus.Entry, isDebug bool) { + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + code := syncinterface.TaosStmtAddBatch(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt add batch error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + logger.Trace("stmt add batch success") + resp := stmtAddBatchResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtExecRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type stmtExecResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Affected int `json:"affected"` +} + +func (h *messageHandler) stmtExec(ctx context.Context, session *melody.Session, action string, req *stmtExecRequest, logger *logrus.Entry, isDebug bool) { + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + code := syncinterface.TaosStmtExecute(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt execute error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + s := log.GetLogNow(isDebug) + affected := wrapper.TaosStmtAffectedRowsOnce(stmtItem.stmt) + logger.Debugf("stmt_affected_rows_once, affected:%d, cost:%s", affected, log.GetLogDuration(isDebug, s)) + resp := stmtExecResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + Affected: affected, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtCloseRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +func (h *messageHandler) stmtClose(ctx context.Context, session *melody.Session, action string, req *stmtCloseRequest, logger *logrus.Entry) { + logger.Tracef("stmt close, stmt_id:%d", req.StmtID) + err := h.stmts.FreeStmtByID(req.StmtID, false, logger) + if err != nil { + logger.Errorf("stmt close error, err:%s", err.Error()) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, err.Error(), req.StmtID) + return + } + logger.Tracef("stmt close success, stmt_id:%d", req.StmtID) +} + +type stmtGetTagFieldsRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type stmtGetTagFieldsResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Fields []*stmtCommon.StmtField `json:"fields,omitempty"` +} + +func (h *messageHandler) stmtGetTagFields(ctx context.Context, session *melody.Session, action string, req *stmtGetTagFieldsRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("stmt get tag fields, stmt_id:%d", req.StmtID) + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + code, tagNums, tagFields := syncinterface.TaosStmtGetTagFields(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt get tag fields error, err:%s", errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + defer func() { + wrapper.TaosStmtReclaimFields(stmtItem.stmt, tagFields) + }() + if tagNums == 0 { + logger.Trace("no tags") + resp := stmtGetTagFieldsResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + } + wstool.WSWriteJson(session, logger, resp) + return + } + s := log.GetLogNow(isDebug) + fields := wrapper.StmtParseFields(tagNums, tagFields) + logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) + resp := stmtGetTagFieldsResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + Fields: fields, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtGetColFieldsRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type stmtGetColFieldsResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Fields []*stmtCommon.StmtField `json:"fields"` +} + +func (h *messageHandler) stmtGetColFields(ctx context.Context, session *melody.Session, action string, req *stmtGetColFieldsRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("stmt get col fields, stmt_id:%d", req.StmtID) + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + code, colNums, colFields := syncinterface.TaosStmtGetColFields(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt get col fields error, err:%s", errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + defer func() { + wrapper.TaosStmtReclaimFields(stmtItem.stmt, colFields) + }() + if colNums == 0 { + logger.Trace("no columns") + resp := stmtGetColFieldsResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + } + wstool.WSWriteJson(session, logger, resp) + return + } + s := log.GetLogNow(isDebug) + fields := wrapper.StmtParseFields(colNums, colFields) + logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) + resp := stmtGetColFieldsResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + Fields: fields, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtUseResultRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type stmtUseResultResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + ResultID uint64 `json:"result_id"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes jsontype.JsonUint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} + +func (h *messageHandler) stmtUseResult(ctx context.Context, session *melody.Session, action string, req *stmtUseResultRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("stmt use result, stmt_id:%d", req.StmtID) + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + logger.Trace("call stmt use result") + result := wrapper.TaosStmtUseResult(stmtItem.stmt) + if result == nil { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt use result error, err:%s", errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, errStr, req.StmtID) + return + } + + fieldsCount := wrapper.TaosNumFields(result) + rowsHeader, _ := wrapper.ReadColumn(result, fieldsCount) + precision := wrapper.TaosResultPrecision(result) + logger.Tracef("stmt use result success, stmt_id:%d, fields_count:%d, precision:%d", req.StmtID, fieldsCount, precision) + queryResult := QueryResult{ + TaosResult: result, + FieldsCount: fieldsCount, + Header: rowsHeader, + precision: precision, + inStmt: true, + } + idx := h.queryResults.Add(&queryResult) + logger.Tracef("add query result, result_id:%d", idx) + resp := &stmtUseResultResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + ResultID: idx, + FieldsCount: fieldsCount, + FieldsNames: rowsHeader.ColNames, + FieldsTypes: rowsHeader.ColTypes, + FieldsLengths: rowsHeader.ColLength, + Precision: precision, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtNumParamsRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type stmtNumParamsResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + NumParams int `json:"num_params"` +} + +func (h *messageHandler) stmtNumParams(ctx context.Context, session *melody.Session, action string, req *stmtNumParamsRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("stmt num params, stmt_id:%d", req.StmtID) + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + count, code := syncinterface.TaosStmtNumParams(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt get col fields error, err:%s", errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + logger.Tracef("stmt num params success, stmt_id:%d, num_params:%d", req.StmtID, count) + resp := stmtNumParamsResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + NumParams: count, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmtGetParamRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + Index int `json:"index"` +} + +type stmtGetParamResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Index int `json:"index"` + DataType int `json:"data_type"` + Length int `json:"length"` +} + +func (h *messageHandler) stmtGetParam(ctx context.Context, session *melody.Session, action string, req *stmtGetParamRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("stmt get param, stmt_id:%d, index:%d", req.StmtID, req.Index) + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + dataType, length, err := syncinterface.TaosStmtGetParam(stmtItem.stmt, req.Index, logger, isDebug) + if err != nil { + taosErr := err.(*errors2.TaosError) + logger.Errorf("stmt get param error, err:%s", taosErr.Error()) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr, req.StmtID) + return + } + logger.Tracef("stmt get param success, data_type:%d, length:%d", dataType, length) + resp := &stmtGetParamResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + Index: req.Index, + DataType: dataType, + Length: length, + } + wstool.WSWriteJson(session, logger, resp) +} + +func (h *messageHandler) stmtBinarySetTags(ctx context.Context, session *melody.Session, action string, reqID uint64, stmtID uint64, message []byte, logger *logrus.Entry, isDebug bool) { + p0 := unsafe.Pointer(&message[0]) + block := tools.AddPointer(p0, uintptr(24)) + columns := parser.RawBlockGetNumOfCols(block) + rows := parser.RawBlockGetNumOfRows(block) + logger.Tracef("set tags message, stmt_id:%d, columns:%d, rows:%d", stmtID, columns, rows) + if rows != 1 { + logger.Errorf("rows not equal 1, rows:%d", rows) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, "rows not equal 1", stmtID) + return + } + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, reqID, stmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + code, tagNums, tagFields := syncinterface.TaosStmtGetTagFields(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt get tag fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, reqID, code, errStr, stmtID) + return + } + defer func() { + wrapper.TaosStmtReclaimFields(stmtItem.stmt, tagFields) + }() + if tagNums == 0 { + logger.Trace("no tags") + resp := stmtSetTagsResponse{ + Action: action, + ReqID: reqID, + Timing: wstool.GetDuration(ctx), + StmtID: stmtID, + } + wstool.WSWriteJson(session, logger, resp) + return + } + if int(columns) != tagNums { + logger.Errorf("stmt tags count not match %d != %d", columns, tagNums) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, "stmt tags count not match", stmtID) + return + } + s := log.GetLogNow(isDebug) + fields := wrapper.StmtParseFields(tagNums, tagFields) + logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) + s = log.GetLogNow(isDebug) + tags := stmt.BlockConvert(block, int(rows), fields, nil) + logger.Debugf("block concert cost:%s", log.GetLogDuration(isDebug, s)) + reTags := make([]driver.Value, tagNums) + for i := 0; i < tagNums; i++ { + reTags[i] = tags[i][0] + } + code = syncinterface.TaosStmtSetTags(stmtItem.stmt, reTags, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt set tags error, code:%d, msg:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, reqID, code, errStr, stmtID) + return + } + resp := stmtSetTagsResponse{ + Action: action, + ReqID: reqID, + Timing: wstool.GetDuration(ctx), + StmtID: stmtID, + } + wstool.WSWriteJson(session, logger, resp) +} + +func (h *messageHandler) stmtBinaryBind(ctx context.Context, session *melody.Session, action string, reqID uint64, stmtID uint64, message []byte, logger *logrus.Entry, isDebug bool) { + p0 := unsafe.Pointer(&message[0]) + block := tools.AddPointer(p0, uintptr(24)) + columns := parser.RawBlockGetNumOfCols(block) + rows := parser.RawBlockGetNumOfRows(block) + logger.Tracef("bind message, stmt_id:%d columns:%d, rows:%d", stmtID, columns, rows) + stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, reqID, stmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + var data [][]driver.Value + var fieldTypes []*types.ColumnType + if stmtItem.isInsert { + code, colNums, colFields := syncinterface.TaosStmtGetColFields(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt get col fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, reqID, code, errStr, stmtID) + return + } + defer func() { + wrapper.TaosStmtReclaimFields(stmtItem.stmt, colFields) + }() + if colNums == 0 { + logger.Trace("no columns") + resp := stmtBindResponse{ + Action: action, + ReqID: reqID, + Timing: wstool.GetDuration(ctx), + StmtID: stmtID, + } + wstool.WSWriteJson(session, logger, resp) + return + } + s := log.GetLogNow(isDebug) + fields := wrapper.StmtParseFields(colNums, colFields) + logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) + fieldTypes = make([]*types.ColumnType, colNums) + var err error + for i := 0; i < colNums; i++ { + fieldTypes[i], err = fields[i].GetType() + if err != nil { + logger.Errorf("stmt get column type error, err:%s", err.Error()) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, fmt.Sprintf("stmt get column type error, err:%s", err.Error()), stmtID) + return + } + } + if int(columns) != colNums { + logger.Errorf("stmt column count not match %d != %d", columns, colNums) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, "stmt column count not match", stmtID) + return + } + s = log.GetLogNow(isDebug) + data = stmt.BlockConvert(block, int(rows), fields, fieldTypes) + logger.Debugf("block convert cost:%s", log.GetLogDuration(isDebug, s)) + } else { + var fields []*stmtCommon.StmtField + var err error + logger.Trace("parse row block info") + fields, fieldTypes, err = parseRowBlockInfo(block, int(columns)) + if err != nil { + logger.Errorf("parse row block info error, err:%s", err.Error()) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, fmt.Sprintf("parse row block info error, err:%s", err.Error()), stmtID) + return + } + logger.Trace("convert block to data") + data = stmt.BlockConvert(block, int(rows), fields, fieldTypes) + logger.Trace("convert block to data finish") + } + + code := syncinterface.TaosStmtBindParamBatch(stmtItem.stmt, data, fieldTypes, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt bind param error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, reqID, code, errStr, stmtID) + return + } + logger.Trace("stmt bind param success") + resp := stmtBindResponse{ + Action: action, + ReqID: reqID, + Timing: wstool.GetDuration(ctx), + StmtID: stmtID, + } + wstool.WSWriteJson(session, logger, resp) +} + +func parseRowBlockInfo(block unsafe.Pointer, columns int) (fields []*stmtCommon.StmtField, fieldTypes []*types.ColumnType, err error) { + infos := make([]parser.RawBlockColInfo, columns) + parser.RawBlockGetColInfo(block, infos) + + fields = make([]*stmtCommon.StmtField, len(infos)) + fieldTypes = make([]*types.ColumnType, len(infos)) + + for i, info := range infos { + switch info.ColType { + case common.TSDB_DATA_TYPE_BOOL: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_BOOL} + fieldTypes[i] = &types.ColumnType{Type: types.TaosBoolType} + case common.TSDB_DATA_TYPE_TINYINT: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_TINYINT} + fieldTypes[i] = &types.ColumnType{Type: types.TaosTinyintType} + case common.TSDB_DATA_TYPE_SMALLINT: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_SMALLINT} + fieldTypes[i] = &types.ColumnType{Type: types.TaosSmallintType} + case common.TSDB_DATA_TYPE_INT: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_INT} + fieldTypes[i] = &types.ColumnType{Type: types.TaosIntType} + case common.TSDB_DATA_TYPE_BIGINT: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_BIGINT} + fieldTypes[i] = &types.ColumnType{Type: types.TaosBigintType} + case common.TSDB_DATA_TYPE_FLOAT: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_FLOAT} + fieldTypes[i] = &types.ColumnType{Type: types.TaosFloatType} + case common.TSDB_DATA_TYPE_DOUBLE: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_DOUBLE} + fieldTypes[i] = &types.ColumnType{Type: types.TaosDoubleType} + case common.TSDB_DATA_TYPE_BINARY: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_BINARY} + fieldTypes[i] = &types.ColumnType{Type: types.TaosBinaryType} + //case common.TSDB_DATA_TYPE_TIMESTAMP:// todo precision + // fields[i] = &stmtCommon.StmtField{FieldType:common.TSDB_DATA_TYPE_TIMESTAMP} + // fieldTypes[i] = &types.ColumnType{Type:types.TaosTimestampType} + case common.TSDB_DATA_TYPE_NCHAR: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_NCHAR} + fieldTypes[i] = &types.ColumnType{Type: types.TaosNcharType} + case common.TSDB_DATA_TYPE_UTINYINT: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_UTINYINT} + fieldTypes[i] = &types.ColumnType{Type: types.TaosUTinyintType} + case common.TSDB_DATA_TYPE_USMALLINT: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_USMALLINT} + fieldTypes[i] = &types.ColumnType{Type: types.TaosUSmallintType} + case common.TSDB_DATA_TYPE_UINT: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_UINT} + fieldTypes[i] = &types.ColumnType{Type: types.TaosUIntType} + case common.TSDB_DATA_TYPE_UBIGINT: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_UBIGINT} + fieldTypes[i] = &types.ColumnType{Type: types.TaosUBigintType} + case common.TSDB_DATA_TYPE_JSON: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_JSON} + fieldTypes[i] = &types.ColumnType{Type: types.TaosJsonType} + case common.TSDB_DATA_TYPE_VARBINARY: + fields[i] = &stmtCommon.StmtField{FieldType: common.TSDB_DATA_TYPE_VARBINARY} + fieldTypes[i] = &types.ColumnType{Type: types.TaosBinaryType} + default: + err = fmt.Errorf("unsupported data type %d", info.ColType) + } + } + + return +} diff --git a/controller/ws/ws/stmt2.go b/controller/ws/ws/stmt2.go new file mode 100644 index 00000000..4f6bc689 --- /dev/null +++ b/controller/ws/ws/stmt2.go @@ -0,0 +1,448 @@ +package ws + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "unsafe" + + "github.com/sirupsen/logrus" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" + errors2 "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/wrapper" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/db/async" + "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/log" + "github.com/taosdata/taosadapter/v3/tools/jsontype" + "github.com/taosdata/taosadapter/v3/tools/melody" +) + +type stmt2InitRequest struct { + ReqID uint64 `json:"req_id"` + SingleStbInsert bool `json:"single_stb_insert"` + SingleTableBindOnce bool `json:"single_table_bind_once"` +} + +type stmt2InitResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +func (h *messageHandler) stmt2Init(ctx context.Context, session *melody.Session, action string, req *stmt2InitRequest, logger *logrus.Entry, isDebug bool) { + handle, caller := async.GlobalStmt2CallBackCallerPool.Get() + stmtInit := syncinterface.TaosStmt2Init(h.conn, int64(req.ReqID), req.SingleStbInsert, req.SingleTableBindOnce, handle, logger, isDebug) + if stmtInit == nil { + async.GlobalStmt2CallBackCallerPool.Put(handle) + errStr := wrapper.TaosStmtErrStr(stmtInit) + logger.Errorf("stmt2 init error, err:%s", errStr) + commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, errStr) + return + } + stmtItem := &StmtItem{stmt: stmtInit, handler: handle, caller: caller, isStmt2: true} + h.stmts.Add(stmtItem) + logger.Tracef("stmt2 init sucess, stmt_id:%d, stmt pointer:%p", stmtItem.index, stmtInit) + resp := stmt2InitResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: stmtItem.index, + } + wstool.WSWriteJson(session, logger, resp) +} + +func (h *messageHandler) stmt2ValidateAndLock(ctx context.Context, session *melody.Session, action string, reqID uint64, stmtID uint64, logger *logrus.Entry, isDebug bool) (stmtItem *StmtItem, locked bool) { + stmtItem = h.stmts.GetStmt2(stmtID) + if stmtItem == nil { + logger.Errorf("stmt2 is nil, stmt_id:%d", stmtID) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, "stmt2 is nil", stmtID) + return nil, false + } + s := log.GetLogNow(isDebug) + logger.Trace("get stmt2 lock") + stmtItem.Lock() + logger.Debugf("get stmt2 lock cost:%s", log.GetLogDuration(isDebug, s)) + if stmtItem.stmt == nil { + stmtItem.Unlock() + logger.Errorf("stmt2 has been freed, stmt_id:%d", stmtID) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, "stmt has been freed", stmtID) + return nil, false + } + return stmtItem, true +} + +type stmt2PrepareRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + SQL string `json:"sql"` + GetFields bool `json:"get_fields"` +} + +type prepareFields struct { + stmtCommon.StmtField + BindType int8 `json:"bind_type"` +} + +type stmt2PrepareResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + IsInsert bool `json:"is_insert"` + Fields []*prepareFields `json:"fields"` + FieldsCount int `json:"fields_count"` +} + +func (h *messageHandler) stmt2Prepare(ctx context.Context, session *melody.Session, action string, req *stmt2PrepareRequest, logger *logrus.Entry, isDebug bool) { + logger.Debugf("stmt2 prepare, stmt_id:%d, sql:%s", req.StmtID, req.SQL) + stmtItem, locked := h.stmt2ValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + stmt2 := stmtItem.stmt + code := syncinterface.TaosStmt2Prepare(stmt2, req.SQL, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmt2Error(stmt2) + logger.Errorf("stmt2 prepare error, err:%s", errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + logger.Tracef("stmt2 prepare success, stmt_id:%d", req.StmtID) + isInsert, code := syncinterface.TaosStmt2IsInsert(stmt2, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmt2Error(stmt2) + logger.Errorf("check stmt2 is insert error, err:%s", errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + logger.Tracef("stmt2 is insert:%t", isInsert) + stmtItem.isInsert = isInsert + prepareResp := &stmt2PrepareResponse{StmtID: req.StmtID, IsInsert: isInsert} + if req.GetFields { + if isInsert { + var fields []*prepareFields + // get table field + _, count, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_TBNAME, logger, isDebug) + if code != 0 { + logger.Errorf("get table names fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get table names fields error, %s", errStr), req.StmtID) + return + } + if count == 1 { + tableNameFields := &prepareFields{ + StmtField: stmtCommon.StmtField{}, + BindType: stmtCommon.TAOS_FIELD_TBNAME, + } + fields = append(fields, tableNameFields) + } + // get tags field + tagFields, _, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_TAG, logger, isDebug) + if code != 0 { + logger.Errorf("get tag fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get tag fields error, %s", errStr), req.StmtID) + return + } + for i := 0; i < len(tagFields); i++ { + fields = append(fields, &prepareFields{ + StmtField: *tagFields[i], + BindType: stmtCommon.TAOS_FIELD_TAG, + }) + } + // get cols field + colFields, _, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_COL, logger, isDebug) + if code != 0 { + logger.Errorf("get col fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get col fields error, %s", errStr), req.StmtID) + return + } + for i := 0; i < len(colFields); i++ { + fields = append(fields, &prepareFields{ + StmtField: *colFields[i], + BindType: stmtCommon.TAOS_FIELD_COL, + }) + } + prepareResp.Fields = fields + } else { + _, count, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_QUERY, logger, isDebug) + if code != 0 { + logger.Errorf("get query fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get query fields error, %s", errStr), req.StmtID) + return + } + prepareResp.FieldsCount = count + } + } + prepareResp.ReqID = req.ReqID + prepareResp.Action = action + prepareResp.Timing = wstool.GetDuration(ctx) + wstool.WSWriteJson(session, logger, prepareResp) +} + +func getFields(stmt2 unsafe.Pointer, fieldType int8, logger *logrus.Entry, isDebug bool) (fields []*stmtCommon.StmtField, count int, code int, errSt string) { + var cFields unsafe.Pointer + code, count, cFields = syncinterface.TaosStmt2GetFields(stmt2, int(fieldType), logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmt2Error(stmt2) + logger.Errorf("stmt2 get fields error, field_type:%d, err:%s", fieldType, errStr) + return nil, count, code, errStr + } + defer wrapper.TaosStmt2FreeFields(stmt2, cFields) + if count > 0 && cFields != nil { + s := log.GetLogNow(isDebug) + fields = wrapper.StmtParseFields(count, cFields) + logger.Debugf("stmt2 parse fields cost:%s", log.GetLogDuration(isDebug, s)) + return fields, count, 0, "" + } + return nil, count, 0, "" +} + +type stmt2GetFieldsRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + FieldTypes []int8 `json:"field_types"` +} + +type stmt2GetFieldsResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + TableCount int32 `json:"table_count"` + QueryCount int32 `json:"query_count"` + ColFields []*stmtCommon.StmtField `json:"col_fields"` + TagFields []*stmtCommon.StmtField `json:"tag_fields"` +} + +func (h *messageHandler) stmt2GetFields(ctx context.Context, session *melody.Session, action string, req *stmt2GetFieldsRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("stmt2 get col fields, stmt_id:%d", req.StmtID) + stmtItem, locked := h.stmt2ValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + stmt2GetFieldsResp := &stmt2GetFieldsResponse{StmtID: req.StmtID} + for i := 0; i < len(req.FieldTypes); i++ { + switch req.FieldTypes[i] { + case stmtCommon.TAOS_FIELD_COL: + colFields, _, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_COL, logger, isDebug) + if code != 0 { + logger.Errorf("get col fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get col fields error, %s", errStr), req.StmtID) + return + } + stmt2GetFieldsResp.ColFields = colFields + case stmtCommon.TAOS_FIELD_TAG: + tagFields, _, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_TAG, logger, isDebug) + if code != 0 { + logger.Errorf("get tag fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get tag fields error, %s", errStr), req.StmtID) + return + } + stmt2GetFieldsResp.TagFields = tagFields + case stmtCommon.TAOS_FIELD_TBNAME: + _, count, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_TBNAME, logger, isDebug) + if code != 0 { + logger.Errorf("get table names fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get table names fields error, %s", errStr), req.StmtID) + return + } + stmt2GetFieldsResp.TableCount = int32(count) + case stmtCommon.TAOS_FIELD_QUERY: + _, count, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_QUERY, logger, isDebug) + if code != 0 { + logger.Errorf("get query fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get query fields error, %s", errStr), req.StmtID) + return + } + stmt2GetFieldsResp.QueryCount = int32(count) + } + } + stmt2GetFieldsResp.ReqID = req.ReqID + stmt2GetFieldsResp.Action = action + stmt2GetFieldsResp.Timing = wstool.GetDuration(ctx) + wstool.WSWriteJson(session, logger, stmt2GetFieldsResp) +} + +type stmt2ExecRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type stmt2ExecResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Affected int `json:"affected"` +} + +func (h *messageHandler) stmt2Exec(ctx context.Context, session *melody.Session, action string, req *stmt2ExecRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("stmt2 execute, stmt_id:%d", req.StmtID) + stmtItem, locked := h.stmt2ValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + code := syncinterface.TaosStmt2Exec(stmtItem.stmt, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt2 execute error, err:%s", errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return + } + s := log.GetLogNow(isDebug) + logger.Tracef("stmt2 execute wait callback, stmt_id:%d", req.StmtID) + result := <-stmtItem.caller.ExecResult + logger.Debugf("stmt2 execute wait callback finish, affected:%d, res:%p, n:%d, cost:%s", result.Affected, result.Res, result.N, log.GetLogDuration(isDebug, s)) + stmtItem.result = result.Res + resp := stmt2ExecResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + Affected: result.Affected, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmt2UseResultRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type stmt2UseResultResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + ResultID uint64 `json:"result_id"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes jsontype.JsonUint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} + +func (h *messageHandler) stmt2UseResult(ctx context.Context, session *melody.Session, action string, req *stmt2UseResultRequest, logger *logrus.Entry, isDebug bool) { + logger.Tracef("stmt2 use result, stmt_id:%d", req.StmtID) + stmtItem, locked := h.stmt2ValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + result := stmtItem.result + fieldsCount := wrapper.TaosNumFields(result) + rowsHeader, _ := wrapper.ReadColumn(result, fieldsCount) + precision := wrapper.TaosResultPrecision(result) + logger.Tracef("stmt use result success, stmt_id:%d, fields_count:%d, precision:%d", req.StmtID, fieldsCount, precision) + queryResult := QueryResult{TaosResult: result, FieldsCount: fieldsCount, Header: rowsHeader, precision: precision, inStmt: true} + idx := h.queryResults.Add(&queryResult) + resp := &stmt2UseResultResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + ResultID: idx, + FieldsCount: fieldsCount, + FieldsNames: rowsHeader.ColNames, + FieldsTypes: rowsHeader.ColTypes, + FieldsLengths: rowsHeader.ColLength, + Precision: precision, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmt2CloseRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type stmt2CloseResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +func (h *messageHandler) stmt2Close(ctx context.Context, session *melody.Session, action string, req *stmt2CloseRequest, logger *logrus.Entry) { + logger.Tracef("stmt2 close, stmt_id:%d", req.StmtID) + err := h.stmts.FreeStmtByID(req.StmtID, true, logger) + if err != nil { + logger.Errorf("stmt2 close error, err:%s", err.Error()) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, err.Error(), req.StmtID) + return + } + logger.Tracef("stmt2 close success, stmt_id:%d", req.StmtID) + resp := stmt2CloseResponse{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + StmtID: req.StmtID, + } + wstool.WSWriteJson(session, logger, resp) +} + +type stmt2BindResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +func (h *messageHandler) stmt2BinaryBind(ctx context.Context, session *melody.Session, action string, reqID uint64, stmtID uint64, message []byte, logger *logrus.Entry, isDebug bool) { + if len(message) < 30 { + logger.Errorf("message length is too short, len:%d, stmt_id:%d", len(message), stmtID) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, "message length is too short", stmtID) + return + } + v := binary.LittleEndian.Uint16(message[24:]) + if v != Stmt2BindProtocolVersion1 { + logger.Errorf("unknown stmt2 bind version, version:%d, stmt_id:%d", v, stmtID) + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, "unknown stmt2 bind version", stmtID) + return + } + colIndex := int32(binary.LittleEndian.Uint32(message[26:])) + stmtItem, locked := h.stmt2ValidateAndLock(ctx, session, action, reqID, stmtID, logger, isDebug) + if !locked { + return + } + defer stmtItem.Unlock() + bindData := message[30:] + err := syncinterface.TaosStmt2BindBinary(stmtItem.stmt, bindData, colIndex, logger, isDebug) + if err != nil { + logger.Errorf("stmt2 bind error, err:%s", err.Error()) + var tError *errors2.TaosError + if errors.As(err, &tError) { + stmtErrorResponse(ctx, session, logger, action, reqID, int(tError.Code), tError.ErrStr, stmtID) + return + } + stmtErrorResponse(ctx, session, logger, action, reqID, 0xffff, err.Error(), stmtID) + return + } + logger.Trace("stmt2 bind success") + resp := &stmt2BindResponse{ + Action: action, + ReqID: reqID, + Timing: wstool.GetDuration(ctx), + StmtID: stmtID, + } + wstool.WSWriteJson(session, logger, resp) +} diff --git a/controller/ws/ws/stmt2_test.go b/controller/ws/ws/stmt2_test.go new file mode 100644 index 00000000..e018ae3c --- /dev/null +++ b/controller/ws/ws/stmt2_test.go @@ -0,0 +1,742 @@ +package ws + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/common" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/tools/parseblock" +) + +func TestWsStmt2(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful("drop database if exists test_ws_stmt2_ws", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("create database if not exists test_ws_stmt2_ws precision 'ns'", "") + assert.Equal(t, 0, code, message) + + defer doRestful("drop database if exists test_ws_stmt2_ws", "") + + code, message = doRestful( + "create table if not exists stb (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", + "test_ws_stmt2_ws") + assert.Equal(t, 0, code, message) + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_stmt2_ws"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // init + initReq := stmt2InitRequest{ + ReqID: 0x123, + SingleStbInsert: false, + SingleTableBindOnce: false, + } + resp, err = doWebSocket(ws, STMT2Init, &initReq) + assert.NoError(t, err) + var initResp stmt2InitResponse + err = json.Unmarshal(resp, &initResp) + assert.NoError(t, err) + assert.Equal(t, uint64(0x123), initResp.ReqID) + assert.Equal(t, 0, initResp.Code, initResp.Message) + + // prepare + prepareReq := stmt2PrepareRequest{ReqID: 3, StmtID: initResp.StmtID, SQL: "insert into ct1 using test_ws_stmt2_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"} + resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) + assert.NoError(t, err) + var prepareResp stmt2PrepareResponse + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + assert.True(t, prepareResp.IsInsert) + + // get tag fields + getTagFieldsReq := stmt2GetFieldsRequest{ReqID: 5, StmtID: prepareResp.StmtID, FieldTypes: []int8{stmtCommon.TAOS_FIELD_TAG}} + resp, err = doWebSocket(ws, STMT2GetFields, &getTagFieldsReq) + assert.NoError(t, err) + var getTagFieldsResp stmt2GetFieldsResponse + err = json.Unmarshal(resp, &getTagFieldsResp) + assert.NoError(t, err) + assert.Equal(t, uint64(5), getTagFieldsResp.ReqID) + assert.Equal(t, 0, getTagFieldsResp.Code, getTagFieldsResp.Message) + + // get col fields + getColFieldsReq := stmt2GetFieldsRequest{ReqID: 6, StmtID: prepareResp.StmtID, FieldTypes: []int8{stmtCommon.TAOS_FIELD_COL}} + resp, err = doWebSocket(ws, STMT2GetFields, &getColFieldsReq) + assert.NoError(t, err) + var getColFieldsResp stmt2GetFieldsResponse + err = json.Unmarshal(resp, &getColFieldsResp) + assert.NoError(t, err) + assert.Equal(t, uint64(6), getColFieldsResp.ReqID) + assert.Equal(t, 0, getColFieldsResp.Code, getColFieldsResp.Message) + + // bind + now := time.Now() + cols := [][]driver.Value{ + // ts + {now, now.Add(time.Second), now.Add(time.Second * 2)}, + // bool + {true, false, nil}, + // tinyint + {int8(2), int8(22), nil}, + // smallint + {int16(3), int16(33), nil}, + // int + {int32(4), int32(44), nil}, + // bigint + {int64(5), int64(55), nil}, + // tinyint unsigned + {uint8(6), uint8(66), nil}, + // smallint unsigned + {uint16(7), uint16(77), nil}, + // int unsigned + {uint32(8), uint32(88), nil}, + // bigint unsigned + {uint64(9), uint64(99), nil}, + // float + {float32(10), float32(1010), nil}, + // double + {float64(11), float64(1111), nil}, + // binary + {"binary", "binary2", nil}, + // nchar + {"nchar", "nchar2", nil}, + // varbinary + {[]byte{0xaa, 0xbb, 0xcc}, []byte{0xaa, 0xbb, 0xcc}, nil}, + // geometry + {[]byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, nil}, + } + tbName := "test_ws_stmt2_ws.ct1" + tag := []driver.Value{"{\"a\":\"b\"}"} + binds := &stmtCommon.TaosStmt2BindData{ + TableName: tbName, + Tags: tag, + Cols: cols, + } + bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, getColFieldsResp.ColFields, getTagFieldsResp.TagFields) + assert.NoError(t, err) + bindReq := make([]byte, len(bs)+30) + // req_id + binary.LittleEndian.PutUint64(bindReq, 0x12345) + // stmt_id + binary.LittleEndian.PutUint64(bindReq[8:], prepareResp.StmtID) + // action + binary.LittleEndian.PutUint64(bindReq[16:], Stmt2BindMessage) + // version + binary.LittleEndian.PutUint16(bindReq[24:], Stmt2BindProtocolVersion1) + // col_idx + idx := int32(-1) + binary.LittleEndian.PutUint32(bindReq[26:], uint32(idx)) + // data + copy(bindReq[30:], bs) + err = ws.WriteMessage(websocket.BinaryMessage, bindReq) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + var bindResp stmt2BindResponse + err = json.Unmarshal(resp, &bindResp) + assert.NoError(t, err) + assert.Equal(t, 0, bindResp.Code, bindResp.Message) + + //exec + execReq := stmt2ExecRequest{ReqID: 10, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMT2Exec, &execReq) + assert.NoError(t, err) + var execResp stmt2ExecResponse + err = json.Unmarshal(resp, &execResp) + assert.NoError(t, err) + assert.Equal(t, uint64(10), execResp.ReqID) + assert.Equal(t, 0, execResp.Code, execResp.Message) + assert.Equal(t, 3, execResp.Affected) + + // close + closeReq := stmt2CloseRequest{ReqID: 11, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMT2Close, &closeReq) + assert.NoError(t, err) + var closeResp stmt2CloseResponse + err = json.Unmarshal(resp, &closeResp) + assert.NoError(t, err) + assert.Equal(t, uint64(11), closeResp.ReqID) + assert.Equal(t, 0, closeResp.Code, closeResp.Message) + + // query + queryReq := queryRequest{Sql: "select * from test_ws_stmt2_ws.stb"} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + var queryResp queryResponse + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch + fetchReq := fetchRequest{ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + var fetchResp fetchResponse + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + + // fetch block + fetchBlockReq := fetchBlockRequest{ID: queryResp.ID} + fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) + assert.NoError(t, err) + _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) + assert.Equal(t, 3, len(blockResult)) + assert.Equal(t, now.UnixNano(), blockResult[0][0].(time.Time).UnixNano()) + + assert.Equal(t, true, blockResult[0][1]) + assert.Equal(t, int8(2), blockResult[0][2]) + assert.Equal(t, int16(3), blockResult[0][3]) + assert.Equal(t, int32(4), blockResult[0][4]) + assert.Equal(t, int64(5), blockResult[0][5]) + assert.Equal(t, uint8(6), blockResult[0][6]) + assert.Equal(t, uint16(7), blockResult[0][7]) + assert.Equal(t, uint32(8), blockResult[0][8]) + assert.Equal(t, uint64(9), blockResult[0][9]) + assert.Equal(t, float32(10), blockResult[0][10]) + assert.Equal(t, float64(11), blockResult[0][11]) + assert.Equal(t, "binary", blockResult[0][12]) + assert.Equal(t, "nchar", blockResult[0][13]) + assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[0][15]) + + assert.Equal(t, now.Add(time.Second).UnixNano(), blockResult[1][0].(time.Time).UnixNano()) + assert.Equal(t, false, blockResult[1][1]) + assert.Equal(t, int8(22), blockResult[1][2]) + assert.Equal(t, int16(33), blockResult[1][3]) + assert.Equal(t, int32(44), blockResult[1][4]) + assert.Equal(t, int64(55), blockResult[1][5]) + assert.Equal(t, uint8(66), blockResult[1][6]) + assert.Equal(t, uint16(77), blockResult[1][7]) + assert.Equal(t, uint32(88), blockResult[1][8]) + assert.Equal(t, uint64(99), blockResult[1][9]) + assert.Equal(t, float32(1010), blockResult[1][10]) + assert.Equal(t, float64(1111), blockResult[1][11]) + assert.Equal(t, "binary2", blockResult[1][12]) + assert.Equal(t, "nchar2", blockResult[1][13]) + assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[1][15]) + + assert.Equal(t, now.Add(time.Second*2).UnixNano(), blockResult[2][0].(time.Time).UnixNano()) + for i := 1; i < 16; i++ { + assert.Nil(t, blockResult[2][i]) + } + +} + +func TestStmt2Prepare(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful("drop database if exists test_ws_stmt2_prepare_ws", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("create database if not exists test_ws_stmt2_prepare_ws precision 'ns'", "") + assert.Equal(t, 0, code, message) + + defer doRestful("drop database if exists test_ws_stmt2_prepare_ws", "") + + code, message = doRestful( + "create table if not exists stb (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", + "test_ws_stmt2_prepare_ws") + assert.Equal(t, 0, code, message) + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_stmt2_prepare_ws"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // init + initReq := stmt2InitRequest{ + ReqID: 0x123, + SingleStbInsert: false, + SingleTableBindOnce: false, + } + resp, err = doWebSocket(ws, STMT2Init, &initReq) + assert.NoError(t, err) + var initResp stmt2InitResponse + err = json.Unmarshal(resp, &initResp) + assert.NoError(t, err) + assert.Equal(t, uint64(0x123), initResp.ReqID) + assert.Equal(t, 0, initResp.Code, initResp.Message) + + // prepare + prepareReq := stmt2PrepareRequest{ + ReqID: 3, + StmtID: initResp.StmtID, + SQL: "insert into ctb using test_ws_stmt2_prepare_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + GetFields: true, + } + resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) + assert.NoError(t, err) + var prepareResp stmt2PrepareResponse + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + assert.Equal(t, true, prepareResp.IsInsert) + names := [17]string{ + "info", + "ts", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + } + fieldTypes := [17]int8{ + common.TSDB_DATA_TYPE_JSON, + common.TSDB_DATA_TYPE_TIMESTAMP, + common.TSDB_DATA_TYPE_BOOL, + common.TSDB_DATA_TYPE_TINYINT, + common.TSDB_DATA_TYPE_SMALLINT, + common.TSDB_DATA_TYPE_INT, + common.TSDB_DATA_TYPE_BIGINT, + common.TSDB_DATA_TYPE_UTINYINT, + common.TSDB_DATA_TYPE_USMALLINT, + common.TSDB_DATA_TYPE_UINT, + common.TSDB_DATA_TYPE_UBIGINT, + common.TSDB_DATA_TYPE_FLOAT, + common.TSDB_DATA_TYPE_DOUBLE, + common.TSDB_DATA_TYPE_BINARY, + common.TSDB_DATA_TYPE_NCHAR, + common.TSDB_DATA_TYPE_VARBINARY, + common.TSDB_DATA_TYPE_GEOMETRY, + } + assert.True(t, prepareResp.IsInsert) + assert.Equal(t, 17, len(prepareResp.Fields)) + for i := 0; i < 17; i++ { + assert.Equal(t, names[i], prepareResp.Fields[i].Name) + assert.Equal(t, fieldTypes[i], prepareResp.Fields[i].FieldType) + } + // prepare query + prepareReq = stmt2PrepareRequest{ + ReqID: 4, + StmtID: initResp.StmtID, + SQL: "select * from test_ws_stmt2_prepare_ws.stb where ts = ? and v1 = ?", + GetFields: true, + } + resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(4), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + assert.Equal(t, false, prepareResp.IsInsert) + assert.Nil(t, prepareResp.Fields) + assert.Equal(t, 2, prepareResp.FieldsCount) +} + +func TestStmt2GetFields(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful("drop database if exists test_ws_stmt2_getfields_ws", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("create database if not exists test_ws_stmt2_getfields_ws precision 'ns'", "") + assert.Equal(t, 0, code, message) + + defer doRestful("drop database if exists test_ws_stmt2_getfields_ws", "") + + code, message = doRestful( + "create table if not exists stb (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", + "test_ws_stmt2_getfields_ws") + assert.Equal(t, 0, code, message) + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_stmt2_getfields_ws"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // init + initReq := stmt2InitRequest{ + ReqID: 0x123, + SingleStbInsert: false, + SingleTableBindOnce: false, + } + resp, err = doWebSocket(ws, STMT2Init, &initReq) + assert.NoError(t, err) + var initResp stmt2InitResponse + err = json.Unmarshal(resp, &initResp) + assert.NoError(t, err) + assert.Equal(t, uint64(0x123), initResp.ReqID) + assert.Equal(t, 0, initResp.Code, initResp.Message) + + // prepare + prepareReq := stmt2PrepareRequest{ + ReqID: 3, + StmtID: initResp.StmtID, + SQL: "insert into ctb using test_ws_stmt2_getfields_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + GetFields: false, + } + resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) + assert.NoError(t, err) + var prepareResp stmt2PrepareResponse + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + assert.Equal(t, true, prepareResp.IsInsert) + + // get fields + getFieldsReq := stmt2GetFieldsRequest{ + ReqID: 4, + StmtID: prepareResp.StmtID, + FieldTypes: []int8{ + stmtCommon.TAOS_FIELD_TAG, + stmtCommon.TAOS_FIELD_COL, + }, + } + resp, err = doWebSocket(ws, STMT2GetFields, &getFieldsReq) + assert.NoError(t, err) + var getFieldsResp stmt2GetFieldsResponse + err = json.Unmarshal(resp, &getFieldsResp) + assert.NoError(t, err) + assert.Equal(t, uint64(4), getFieldsResp.ReqID) + assert.Equal(t, 0, getFieldsResp.Code, getFieldsResp.Message) + names := [16]string{ + "ts", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + } + fieldTypes := [16]int8{ + common.TSDB_DATA_TYPE_TIMESTAMP, + common.TSDB_DATA_TYPE_BOOL, + common.TSDB_DATA_TYPE_TINYINT, + common.TSDB_DATA_TYPE_SMALLINT, + common.TSDB_DATA_TYPE_INT, + common.TSDB_DATA_TYPE_BIGINT, + common.TSDB_DATA_TYPE_UTINYINT, + common.TSDB_DATA_TYPE_USMALLINT, + common.TSDB_DATA_TYPE_UINT, + common.TSDB_DATA_TYPE_UBIGINT, + common.TSDB_DATA_TYPE_FLOAT, + common.TSDB_DATA_TYPE_DOUBLE, + common.TSDB_DATA_TYPE_BINARY, + common.TSDB_DATA_TYPE_NCHAR, + common.TSDB_DATA_TYPE_VARBINARY, + common.TSDB_DATA_TYPE_GEOMETRY, + } + assert.Equal(t, 16, len(getFieldsResp.ColFields)) + assert.Equal(t, 1, len(getFieldsResp.TagFields)) + for i := 0; i < 16; i++ { + assert.Equal(t, names[i], getFieldsResp.ColFields[i].Name) + assert.Equal(t, fieldTypes[i], getFieldsResp.ColFields[i].FieldType) + } + assert.Equal(t, "info", getFieldsResp.TagFields[0].Name) + assert.Equal(t, int8(common.TSDB_DATA_TYPE_JSON), getFieldsResp.TagFields[0].FieldType) + + // prepare get tablename + prepareReq = stmt2PrepareRequest{ + ReqID: 5, + StmtID: initResp.StmtID, + SQL: "insert into ? using test_ws_stmt2_getfields_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + GetFields: false, + } + resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) + assert.NoError(t, err) + + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(5), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + assert.Equal(t, true, prepareResp.IsInsert) + // get fields + getFieldsReq = stmt2GetFieldsRequest{ + ReqID: 6, + StmtID: prepareResp.StmtID, + FieldTypes: []int8{ + stmtCommon.TAOS_FIELD_TBNAME, + }, + } + resp, err = doWebSocket(ws, STMT2GetFields, &getFieldsReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &getFieldsResp) + assert.NoError(t, err) + assert.Equal(t, uint64(6), getFieldsResp.ReqID) + assert.Equal(t, 0, getFieldsResp.Code, getFieldsResp.Message) + + assert.Nil(t, getFieldsResp.ColFields) + assert.Nil(t, getFieldsResp.TagFields) + assert.Equal(t, int32(1), getFieldsResp.TableCount) + + // prepare query + prepareReq = stmt2PrepareRequest{ + ReqID: 7, + StmtID: initResp.StmtID, + SQL: "select * from test_ws_stmt2_getfields_ws.stb where ts = ? and v1 = ?", + GetFields: false, + } + resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(7), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + assert.Equal(t, false, prepareResp.IsInsert) + // get fields + getFieldsReq = stmt2GetFieldsRequest{ + ReqID: 8, + StmtID: prepareResp.StmtID, + FieldTypes: []int8{ + stmtCommon.TAOS_FIELD_QUERY, + }, + } + resp, err = doWebSocket(ws, STMT2GetFields, &getFieldsReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &getFieldsResp) + assert.NoError(t, err) + assert.Equal(t, uint64(8), getFieldsResp.ReqID) + assert.Equal(t, 0, getFieldsResp.Code, getFieldsResp.Message) + + assert.Nil(t, getFieldsResp.ColFields) + assert.Nil(t, getFieldsResp.TagFields) + assert.Equal(t, int32(2), getFieldsResp.QueryCount) + +} + +func TestStmt2Query(t *testing.T) { + //for stable + prepareDataSql := []string{ + "create stable meters (ts timestamp,current float,voltage int,phase float) tags (group_id int, location varchar(24))", + "insert into d0 using meters tags (2, 'California.SanFrancisco') values ('2023-09-13 17:53:52.123', 10.2, 219, 0.32) ", + "insert into d1 using meters tags (1, 'California.SanFrancisco') values ('2023-09-13 17:54:43.321', 10.3, 218, 0.31) ", + } + Stmt2Query(t, "test_ws_stmt2_query_for_stable", prepareDataSql) + + // for table + prepareDataSql = []string{ + "create table meters (ts timestamp,current float,voltage int,phase float, group_id int, location varchar(24))", + "insert into meters values ('2023-09-13 17:53:52.123', 10.2, 219, 0.32, 2, 'California.SanFrancisco') ", + "insert into meters values ('2023-09-13 17:54:43.321', 10.3, 218, 0.31, 1, 'California.SanFrancisco') ", + } + Stmt2Query(t, "test_ws_stmt2_query_for_table", prepareDataSql) +} + +func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), "") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), "") + assert.Equal(t, 0, code, message) + + defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") + + for _, sql := range prepareDataSql { + code, message = doRestful(sql, db) + assert.Equal(t, 0, code, message) + } + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // init + initReq := map[string]uint64{"req_id": 2} + resp, err = doWebSocket(ws, STMT2Init, &initReq) + assert.NoError(t, err) + var initResp stmt2InitResponse + err = json.Unmarshal(resp, &initResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), initResp.ReqID) + assert.Equal(t, 0, initResp.Code, initResp.Message) + + // prepare + prepareReq := stmt2PrepareRequest{ + ReqID: 3, + StmtID: initResp.StmtID, + SQL: fmt.Sprintf("select * from %s.meters where group_id=? and location=?", db), + GetFields: false, + } + resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) + assert.NoError(t, err) + var prepareResp stmt2PrepareResponse + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + assert.False(t, prepareResp.IsInsert) + + // bind + var block bytes.Buffer + wstool.WriteUint64(&block, 5) + wstool.WriteUint64(&block, prepareResp.StmtID) + wstool.WriteUint64(&block, uint64(Stmt2BindMessage)) + wstool.WriteUint16(&block, Stmt2BindProtocolVersion1) + idx := int32(-1) + wstool.WriteUint32(&block, uint32(idx)) + params := []*stmtCommon.TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + {int32(1)}, + {"California.SanFrancisco"}, + }, + }, + } + b, err := stmtCommon.MarshalStmt2Binary(params, false, nil, nil) + assert.NoError(t, err) + block.Write(b) + + err = ws.WriteMessage(websocket.BinaryMessage, block.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + var bindResp stmt2BindResponse + err = json.Unmarshal(resp, &bindResp) + assert.NoError(t, err) + assert.Equal(t, uint64(5), bindResp.ReqID) + assert.Equal(t, 0, bindResp.Code, bindResp.Message) + + // exec + execReq := stmtExecRequest{ReqID: 6, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMT2Exec, &execReq) + assert.NoError(t, err) + var execResp stmtExecResponse + err = json.Unmarshal(resp, &execResp) + assert.NoError(t, err) + assert.Equal(t, uint64(6), execResp.ReqID) + assert.Equal(t, 0, execResp.Code, execResp.Message) + + // use result + useResultReq := stmt2UseResultRequest{ReqID: 7, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMT2Result, &useResultReq) + assert.NoError(t, err) + var useResultResp stmt2UseResultResponse + err = json.Unmarshal(resp, &useResultResp) + assert.NoError(t, err) + assert.Equal(t, uint64(7), useResultResp.ReqID) + assert.Equal(t, 0, useResultResp.Code, useResultResp.Message) + + // fetch + fetchReq := fetchRequest{ReqID: 8, ID: useResultResp.ResultID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + var fetchResp fetchResponse + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, uint64(8), fetchResp.ReqID) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + assert.Equal(t, 1, fetchResp.Rows) + + // fetch block + fetchBlockReq := fetchBlockRequest{ReqID: 9, ID: useResultResp.ResultID} + fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) + assert.NoError(t, err) + _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], useResultResp.FieldsTypes, fetchResp.Rows, useResultResp.Precision) + assert.Equal(t, 1, len(blockResult)) + assert.Equal(t, float32(10.3), blockResult[0][1]) + assert.Equal(t, int32(218), blockResult[0][2]) + assert.Equal(t, float32(0.31), blockResult[0][3]) + + // free result + freeResultReq, _ := json.Marshal(freeResultRequest{ReqID: 10, ID: useResultResp.ResultID}) + a, _ := json.Marshal(Request{Action: WSFreeResult, Args: freeResultReq}) + err = ws.WriteMessage(websocket.TextMessage, a) + assert.NoError(t, err) + + // close + closeReq := stmtCloseRequest{ReqID: 11, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMT2Close, &closeReq) + assert.NoError(t, err) + var closeResp stmt2CloseResponse + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, 0, closeResp.Code, closeResp.Message) +} diff --git a/controller/ws/ws/stmt_test.go b/controller/ws/ws/stmt_test.go new file mode 100644 index 00000000..9115f416 --- /dev/null +++ b/controller/ws/ws/stmt_test.go @@ -0,0 +1,900 @@ +package ws + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "net/http/httptest" + "strings" + "testing" + "time" + "unsafe" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/param" + "github.com/taosdata/driver-go/v3/common/serializer" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" + "github.com/taosdata/driver-go/v3/types" + "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/tools/parseblock" +) + +func Test_parseRowBlockInfo(t *testing.T) { + b, err := serializer.SerializeRawBlock( + []*param.Param{ + param.NewParam(1).AddBool(true), + param.NewParam(1).AddTinyint(1), + param.NewParam(1).AddSmallint(1), + param.NewParam(1).AddInt(1), + param.NewParam(1).AddBigint(1), + param.NewParam(1).AddFloat(1.1), + param.NewParam(1).AddDouble(1.1), + param.NewParam(1).AddBinary([]byte("California.SanFrancisco")), + param.NewParam(1).AddNchar("California.SanFrancisco"), + param.NewParam(1).AddUTinyint(1), + param.NewParam(1).AddUSmallint(1), + param.NewParam(1).AddUInt(1), + param.NewParam(1).AddUBigint(1), + param.NewParam(1).AddJson([]byte(`{"name":"taos"}`)), + param.NewParam(1).AddVarBinary([]byte("California.SanFrancisco")), + }, + param.NewColumnType(15). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddFloat(). + AddDouble(). + AddBinary(100). + AddNchar(100). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddJson(100). + AddVarBinary(100), + ) + assert.NoError(t, err) + fields, fieldsType, err := parseRowBlockInfo(unsafe.Pointer(&b[0]), 15) + assert.NoError(t, err) + expectFields := []*stmtCommon.StmtField{ + {FieldType: common.TSDB_DATA_TYPE_BOOL}, + {FieldType: common.TSDB_DATA_TYPE_TINYINT}, + {FieldType: common.TSDB_DATA_TYPE_SMALLINT}, + {FieldType: common.TSDB_DATA_TYPE_INT}, + {FieldType: common.TSDB_DATA_TYPE_BIGINT}, + {FieldType: common.TSDB_DATA_TYPE_FLOAT}, + {FieldType: common.TSDB_DATA_TYPE_DOUBLE}, + {FieldType: common.TSDB_DATA_TYPE_BINARY}, + {FieldType: common.TSDB_DATA_TYPE_NCHAR}, + {FieldType: common.TSDB_DATA_TYPE_UTINYINT}, + {FieldType: common.TSDB_DATA_TYPE_USMALLINT}, + {FieldType: common.TSDB_DATA_TYPE_UINT}, + {FieldType: common.TSDB_DATA_TYPE_UBIGINT}, + {FieldType: common.TSDB_DATA_TYPE_JSON}, + {FieldType: common.TSDB_DATA_TYPE_VARBINARY}, + } + assert.Equal(t, expectFields, fields) + expectFieldsType := []*types.ColumnType{ + {Type: types.TaosBoolType}, + {Type: types.TaosTinyintType}, + {Type: types.TaosSmallintType}, + {Type: types.TaosIntType}, + {Type: types.TaosBigintType}, + {Type: types.TaosFloatType}, + {Type: types.TaosDoubleType}, + {Type: types.TaosBinaryType}, + {Type: types.TaosNcharType}, + {Type: types.TaosUTinyintType}, + {Type: types.TaosUSmallintType}, + {Type: types.TaosUIntType}, + {Type: types.TaosUBigintType}, + {Type: types.TaosJsonType}, + {Type: types.TaosBinaryType}, + } + assert.Equal(t, expectFieldsType, fieldsType) +} + +func TestWsStmt(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful("drop database if exists test_ws_stmt_ws", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("create database if not exists test_ws_stmt_ws precision 'ns'", "") + assert.Equal(t, 0, code, message) + + defer doRestful("drop database if exists test_ws_stmt_ws", "") + + code, message = doRestful( + "create table if not exists stb (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", + "test_ws_stmt_ws") + assert.Equal(t, 0, code, message) + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_stmt_ws"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // init + initReq := map[string]uint64{"req_id": 2} + resp, err = doWebSocket(ws, STMTInit, &initReq) + assert.NoError(t, err) + var initResp stmtInitResponse + err = json.Unmarshal(resp, &initResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), initResp.ReqID) + assert.Equal(t, 0, initResp.Code, initResp.Message) + + // prepare + prepareReq := stmtPrepareRequest{ReqID: 3, StmtID: initResp.StmtID, SQL: "insert into ? using test_ws_stmt_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"} + resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) + assert.NoError(t, err) + var prepareResp stmtPrepareResponse + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + assert.True(t, prepareResp.IsInsert) + + // set table name + setTableNameReq := stmtSetTableNameRequest{ReqID: 4, StmtID: prepareResp.StmtID, Name: "test_ws_stmt_ws.ct1"} + resp, err = doWebSocket(ws, STMTSetTableName, &setTableNameReq) + assert.NoError(t, err) + var setTableNameResp commonResp + err = json.Unmarshal(resp, &setTableNameResp) + assert.NoError(t, err) + assert.Equal(t, uint64(4), setTableNameResp.ReqID) + assert.Equal(t, 0, setTableNameResp.Code, setTableNameResp.Message) + + // get tag fields + getTagFieldsReq := stmtGetTagFieldsRequest{ReqID: 5, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTGetTagFields, &getTagFieldsReq) + assert.NoError(t, err) + var getTagFieldsResp stmtGetTagFieldsResponse + err = json.Unmarshal(resp, &getTagFieldsResp) + assert.NoError(t, err) + assert.Equal(t, uint64(5), getTagFieldsResp.ReqID) + assert.Equal(t, 0, getTagFieldsResp.Code, getTagFieldsResp.Message) + + // get col fields + getColFieldsReq := stmtGetColFieldsRequest{ReqID: 6, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTGetColFields, &getColFieldsReq) + assert.NoError(t, err) + var getColFieldsResp stmtGetColFieldsResponse + err = json.Unmarshal(resp, &getColFieldsResp) + assert.NoError(t, err) + assert.Equal(t, uint64(6), getColFieldsResp.ReqID) + assert.Equal(t, 0, getColFieldsResp.Code, getColFieldsResp.Message) + + // set tags + setTagsReq := stmtSetTagsRequest{ReqID: 7, StmtID: prepareResp.StmtID, Tags: json.RawMessage(`["{\"a\":\"b\"}"]`)} + resp, err = doWebSocket(ws, STMTSetTags, &setTagsReq) + assert.NoError(t, err) + var setTagsResp stmtSetTagsResponse + err = json.Unmarshal(resp, &setTagsResp) + assert.NoError(t, err) + assert.Equal(t, uint64(7), setTagsResp.ReqID) + assert.Equal(t, 0, setTagsResp.Code, setTagsResp.Message) + + // bind + now := time.Now() + columns, _ := json.Marshal([][]driver.Value{ + {now, now.Add(time.Second), now.Add(time.Second * 2)}, + {true, false, nil}, + {2, 22, nil}, + {3, 33, nil}, + {4, 44, nil}, + {5, 55, nil}, + {6, 66, nil}, + {7, 77, nil}, + {8, 88, nil}, + {9, 99, nil}, + {10, 1010, nil}, + {11, 1111, nil}, + {"binary", "binary2", nil}, + {"nchar", "nchar2", nil}, + {"aabbcc", "aabbcc", nil}, + {"010100000000000000000059400000000000005940", "010100000000000000000059400000000000005940", nil}, + }) + bindReq := stmtBindRequest{ReqID: 8, StmtID: prepareResp.StmtID, Columns: columns} + resp, err = doWebSocket(ws, STMTBind, &bindReq) + assert.NoError(t, err) + var bindResp stmtBindResponse + err = json.Unmarshal(resp, &bindResp) + assert.NoError(t, err) + assert.Equal(t, uint64(8), bindResp.ReqID) + assert.Equal(t, 0, bindResp.Code, bindResp.Message) + + // add batch + addBatchReq := stmtAddBatchRequest{ReqID: 9, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTAddBatch, &addBatchReq) + assert.NoError(t, err) + var addBatchResp stmtAddBatchResponse + err = json.Unmarshal(resp, &addBatchResp) + assert.NoError(t, err) + assert.Equal(t, uint64(9), addBatchResp.ReqID) + assert.Equal(t, 0, bindResp.Code, bindResp.Message) + + // exec + execReq := stmtExecRequest{ReqID: 10, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTExec, &execReq) + assert.NoError(t, err) + var execResp stmtExecResponse + err = json.Unmarshal(resp, &execResp) + assert.NoError(t, err) + assert.Equal(t, uint64(10), execResp.ReqID) + assert.Equal(t, 0, execResp.Code, execResp.Message) + + // close + closeReq := stmtCloseRequest{ReqID: 11, StmtID: prepareResp.StmtID} + err = doWebSocketWithoutResp(ws, STMTClose, &closeReq) + assert.NoError(t, err) + + // query + queryReq := queryRequest{Sql: "select * from test_ws_stmt_ws.stb"} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + var queryResp queryResponse + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch + fetchReq := fetchRequest{ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + var fetchResp fetchResponse + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + + // fetch block + fetchBlockReq := fetchBlockRequest{ID: queryResp.ID} + fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) + assert.NoError(t, err) + _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) + assert.Equal(t, 3, len(blockResult)) + assert.Equal(t, now.UnixNano(), blockResult[0][0].(time.Time).UnixNano()) + + assert.Equal(t, true, blockResult[0][1]) + assert.Equal(t, int8(2), blockResult[0][2]) + assert.Equal(t, int16(3), blockResult[0][3]) + assert.Equal(t, int32(4), blockResult[0][4]) + assert.Equal(t, int64(5), blockResult[0][5]) + assert.Equal(t, uint8(6), blockResult[0][6]) + assert.Equal(t, uint16(7), blockResult[0][7]) + assert.Equal(t, uint32(8), blockResult[0][8]) + assert.Equal(t, uint64(9), blockResult[0][9]) + assert.Equal(t, float32(10), blockResult[0][10]) + assert.Equal(t, float64(11), blockResult[0][11]) + assert.Equal(t, "binary", blockResult[0][12]) + assert.Equal(t, "nchar", blockResult[0][13]) + assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[0][15]) + + assert.Equal(t, now.Add(time.Second).UnixNano(), blockResult[1][0].(time.Time).UnixNano()) + assert.Equal(t, false, blockResult[1][1]) + assert.Equal(t, int8(22), blockResult[1][2]) + assert.Equal(t, int16(33), blockResult[1][3]) + assert.Equal(t, int32(44), blockResult[1][4]) + assert.Equal(t, int64(55), blockResult[1][5]) + assert.Equal(t, uint8(66), blockResult[1][6]) + assert.Equal(t, uint16(77), blockResult[1][7]) + assert.Equal(t, uint32(88), blockResult[1][8]) + assert.Equal(t, uint64(99), blockResult[1][9]) + assert.Equal(t, float32(1010), blockResult[1][10]) + assert.Equal(t, float64(1111), blockResult[1][11]) + assert.Equal(t, "binary2", blockResult[1][12]) + assert.Equal(t, "nchar2", blockResult[1][13]) + assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[1][15]) + + assert.Equal(t, now.Add(time.Second*2).UnixNano(), blockResult[2][0].(time.Time).UnixNano()) + for i := 1; i < 16; i++ { + assert.Nil(t, blockResult[2][i]) + } + + // block message + // init + resp, err = doWebSocket(ws, STMTInit, nil) + assert.NoError(t, err) + err = json.Unmarshal(resp, &initResp) + assert.NoError(t, err) + assert.Equal(t, 0, initResp.Code, initResp.Message) + + // prepare + prepareReq = stmtPrepareRequest{StmtID: initResp.StmtID, SQL: "insert into ? using test_ws_stmt_ws.stb tags(?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"} + resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + + // set table name + setTableNameReq = stmtSetTableNameRequest{StmtID: prepareResp.StmtID, Name: "test_ws_stmt_ws.ct2"} + resp, err = doWebSocket(ws, STMTSetTableName, &setTableNameReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &setTableNameResp) + assert.NoError(t, err) + assert.Equal(t, 0, setTableNameResp.Code, setTableNameResp.Message) + + // set tags + var tagBuffer bytes.Buffer + wstool.WriteUint64(&tagBuffer, 100) + wstool.WriteUint64(&tagBuffer, prepareResp.StmtID) + wstool.WriteUint64(&tagBuffer, uint64(SetTagsMessage)) + tags, err := json.Marshal(map[string]string{"a": "b"}) + assert.NoError(t, err) + b, err := serializer.SerializeRawBlock( + []*param.Param{ + param.NewParam(1).AddJson(tags), + }, + param.NewColumnType(1).AddJson(50)) + assert.NoError(t, err) + assert.NoError(t, err) + tagBuffer.Write(b) + + err = ws.WriteMessage(websocket.BinaryMessage, tagBuffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &setTagsResp) + assert.NoError(t, err) + assert.Equal(t, 0, setTagsResp.Code, setTagsResp.Message) + + // bind binary + var block bytes.Buffer + wstool.WriteUint64(&block, 10) + wstool.WriteUint64(&block, prepareResp.StmtID) + wstool.WriteUint64(&block, uint64(BindMessage)) + rawBlock := []byte{ + 0x01, 0x00, 0x00, 0x00, + 0x11, 0x02, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x05, 0x08, 0x00, 0x00, 0x00, + 0x0b, 0x01, 0x00, 0x00, 0x00, + 0x0c, 0x02, 0x00, 0x00, 0x00, + 0x0d, 0x04, 0x00, 0x00, 0x00, + 0x0e, 0x08, 0x00, 0x00, 0x00, + 0x06, 0x04, 0x00, 0x00, 0x00, + 0x07, 0x08, 0x00, 0x00, 0x00, + 0x08, 0x16, 0x00, 0x00, 0x00, + 0x0a, 0x52, 0x00, 0x00, 0x00, + 0x10, 0x20, 0x00, 0x00, 0x00, + 0x14, 0x20, 0x00, 0x00, 0x00, + + 0x18, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, + 0x30, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, + 0x2e, 0x00, 0x00, 0x00, + + 0x00, + 0x2c, 0x5b, 0x70, 0x86, 0x82, 0x01, 0x00, 0x00, + 0x14, 0x5f, 0x70, 0x86, 0x82, 0x01, 0x00, 0x00, + 0xfc, 0x62, 0x70, 0x86, 0x82, 0x01, 0x00, 0x00, + + 0x20, + 0x01, + 0x00, + 0x00, + + 0x20, + 0x02, + 0x16, + 0x00, + + 0x20, + 0x03, 0x00, + 0x21, 0x00, + 0x00, 0x00, + + 0x20, + 0x04, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + + 0x20, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x37, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x20, + 0x06, + 0x42, + 0x00, + + 0x20, + 0x07, 0x00, + 0x4d, 0x00, + 0x00, 0x00, + + 0x20, + 0x08, 0x00, 0x00, 0x00, + 0x58, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + + 0x20, + 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x20, + 0x00, 0x00, 0x20, 0x41, + 0x00, 0x80, 0x7c, 0x44, + 0x00, 0x00, 0x00, 0x00, + + 0x20, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x26, 0x40, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x5c, 0x91, 0x40, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x06, 0x00, + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + 0x07, 0x00, + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x32, + + 0x00, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x14, 0x00, + 0x6e, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x68, 0x00, + 0x00, 0x00, 0x61, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, + 0x18, 0x00, + 0x6e, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, + 0x61, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x0e, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + 0x0f, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x32, + + 0x00, 0x00, 0x00, 0x00, + 0x17, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x15, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x15, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + } + binary.LittleEndian.PutUint64(rawBlock[173:], uint64(now.UnixNano())) + binary.LittleEndian.PutUint64(rawBlock[181:], uint64(now.Add(time.Second).UnixNano())) + binary.LittleEndian.PutUint64(rawBlock[189:], uint64(now.Add(time.Second*2).UnixNano())) + block.Write(rawBlock) + err = ws.WriteMessage( + websocket.BinaryMessage, + block.Bytes(), + ) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &bindResp) + assert.NoError(t, err) + assert.Equal(t, 0, bindResp.Code, bindResp.Message) + + // add batch + addBatchReq = stmtAddBatchRequest{StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTAddBatch, &addBatchReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &addBatchResp) + assert.NoError(t, err) + assert.Equal(t, 0, bindResp.Code, bindResp.Message) + + // exec + execReq = stmtExecRequest{StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTExec, &execReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &execResp) + assert.NoError(t, err) + assert.Equal(t, 0, execResp.Code, execResp.Message) + + // query + queryReq = queryRequest{Sql: "select * from test_ws_stmt_ws.ct2"} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch + fetchReq = fetchRequest{ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + + // fetch block + fetchBlockReq = fetchBlockRequest{ID: queryResp.ID} + fetchBlockResp, err = doWebSocket(ws, WSFetchBlock, &fetchBlockReq) + assert.NoError(t, err) + _, blockResult = parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) + assert.Equal(t, now.UnixNano(), blockResult[0][0].(time.Time).UnixNano()) + assert.Equal(t, true, blockResult[0][1]) + assert.Equal(t, int8(2), blockResult[0][2]) + assert.Equal(t, int16(3), blockResult[0][3]) + assert.Equal(t, int32(4), blockResult[0][4]) + assert.Equal(t, int64(5), blockResult[0][5]) + assert.Equal(t, uint8(6), blockResult[0][6]) + assert.Equal(t, uint16(7), blockResult[0][7]) + assert.Equal(t, uint32(8), blockResult[0][8]) + assert.Equal(t, uint64(9), blockResult[0][9]) + assert.Equal(t, float32(10), blockResult[0][10]) + assert.Equal(t, float64(11), blockResult[0][11]) + assert.Equal(t, "binary", blockResult[0][12]) + assert.Equal(t, "nchar", blockResult[0][13]) + assert.Equal(t, []byte("test_varbinary2"), blockResult[1][14]) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[0][15]) + + assert.Equal(t, now.Add(time.Second).UnixNano(), blockResult[1][0].(time.Time).UnixNano()) + assert.Equal(t, false, blockResult[1][1]) + assert.Equal(t, int8(22), blockResult[1][2]) + assert.Equal(t, int16(33), blockResult[1][3]) + assert.Equal(t, int32(44), blockResult[1][4]) + assert.Equal(t, int64(55), blockResult[1][5]) + assert.Equal(t, uint8(66), blockResult[1][6]) + assert.Equal(t, uint16(77), blockResult[1][7]) + assert.Equal(t, uint32(88), blockResult[1][8]) + assert.Equal(t, uint64(99), blockResult[1][9]) + assert.Equal(t, float32(1010), blockResult[1][10]) + assert.Equal(t, float64(1111), blockResult[1][11]) + assert.Equal(t, "binary2", blockResult[1][12]) + assert.Equal(t, "nchar2", blockResult[1][13]) + assert.Equal(t, []byte("test_varbinary2"), blockResult[1][14]) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[1][15]) + + assert.Equal(t, now.Add(time.Second*2).UnixNano(), blockResult[2][0].(time.Time).UnixNano()) + for i := 1; i < 16; i++ { + assert.Nil(t, blockResult[2][i]) + } +} + +func TestStmtQuery(t *testing.T) { + //for stable + prepareDataSql := []string{ + "create stable meters (ts timestamp,current float,voltage int,phase float) tags (group_id int, location varchar(24))", + "insert into d0 using meters tags (2, 'California.SanFrancisco') values ('2023-09-13 17:53:52.123', 10.2, 219, 0.32) ", + "insert into d1 using meters tags (1, 'California.SanFrancisco') values ('2023-09-13 17:54:43.321', 10.3, 218, 0.31) ", + } + StmtQuery(t, "test_ws_stmt_query_for_stable", prepareDataSql) + + // for table + prepareDataSql = []string{ + "create table meters (ts timestamp,current float,voltage int,phase float, group_id int, location varchar(24))", + "insert into meters values ('2023-09-13 17:53:52.123', 10.2, 219, 0.32, 2, 'California.SanFrancisco') ", + "insert into meters values ('2023-09-13 17:54:43.321', 10.3, 218, 0.31, 1, 'California.SanFrancisco') ", + } + StmtQuery(t, "test_ws_stmt_query_for_table", prepareDataSql) +} + +func StmtQuery(t *testing.T, db string, prepareDataSql []string) { + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), "") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), "") + assert.Equal(t, 0, code, message) + + defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") + + for _, sql := range prepareDataSql { + code, message = doRestful(sql, db) + assert.Equal(t, 0, code, message) + } + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // init + initReq := map[string]uint64{"req_id": 2} + resp, err = doWebSocket(ws, STMTInit, &initReq) + assert.NoError(t, err) + var initResp stmtInitResponse + err = json.Unmarshal(resp, &initResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), initResp.ReqID) + assert.Equal(t, 0, initResp.Code, initResp.Message) + + // prepare + prepareReq := stmtPrepareRequest{ + ReqID: 3, + StmtID: initResp.StmtID, + SQL: fmt.Sprintf("select * from %s.meters where group_id=? and location=?", db), + } + resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) + assert.NoError(t, err) + var prepareResp stmtPrepareResponse + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + assert.False(t, prepareResp.IsInsert) + + // bind + var block bytes.Buffer + wstool.WriteUint64(&block, 5) + wstool.WriteUint64(&block, prepareResp.StmtID) + wstool.WriteUint64(&block, uint64(BindMessage)) + b, err := serializer.SerializeRawBlock( + []*param.Param{ + param.NewParam(1).AddInt(1), + param.NewParam(1).AddBinary([]byte("California.SanFrancisco")), + }, + param.NewColumnType(2).AddInt().AddBinary(24)) + assert.NoError(t, err) + block.Write(b) + + err = ws.WriteMessage(websocket.BinaryMessage, block.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + var bindResp stmtBindResponse + err = json.Unmarshal(resp, &bindResp) + assert.NoError(t, err) + assert.Equal(t, uint64(5), bindResp.ReqID) + assert.Equal(t, 0, bindResp.Code, bindResp.Message) + + // add batch + addBatchReq := stmtAddBatchRequest{StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTAddBatch, &addBatchReq) + assert.NoError(t, err) + var addBatchResp stmtAddBatchResponse + err = json.Unmarshal(resp, &addBatchResp) + assert.NoError(t, err) + assert.Equal(t, 0, bindResp.Code, bindResp.Message) + + // exec + execReq := stmtExecRequest{ReqID: 6, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTExec, &execReq) + assert.NoError(t, err) + var execResp stmtExecResponse + err = json.Unmarshal(resp, &execResp) + assert.NoError(t, err) + assert.Equal(t, uint64(6), execResp.ReqID) + assert.Equal(t, 0, execResp.Code, execResp.Message) + + // use result + useResultReq := stmtUseResultRequest{ReqID: 7, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTUseResult, &useResultReq) + assert.NoError(t, err) + var useResultResp stmtUseResultResponse + err = json.Unmarshal(resp, &useResultResp) + assert.NoError(t, err) + assert.Equal(t, uint64(7), useResultResp.ReqID) + assert.Equal(t, 0, useResultResp.Code, useResultResp.Message) + + // fetch + fetchReq := fetchRequest{ReqID: 8, ID: useResultResp.ResultID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + var fetchResp fetchResponse + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, uint64(8), fetchResp.ReqID) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + assert.Equal(t, 1, fetchResp.Rows) + + // fetch block + fetchBlockReq := fetchBlockRequest{ReqID: 9, ID: useResultResp.ResultID} + fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) + assert.NoError(t, err) + _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], useResultResp.FieldsTypes, fetchResp.Rows, useResultResp.Precision) + assert.Equal(t, 1, len(blockResult)) + assert.Equal(t, float32(10.3), blockResult[0][1]) + assert.Equal(t, int32(218), blockResult[0][2]) + assert.Equal(t, float32(0.31), blockResult[0][3]) + + // free result + freeResultReq, _ := json.Marshal(freeResultRequest{ReqID: 10, ID: useResultResp.ResultID}) + a, _ := json.Marshal(Request{Action: WSFreeResult, Args: freeResultReq}) + err = ws.WriteMessage(websocket.TextMessage, a) + assert.NoError(t, err) + + // close + closeReq := stmtCloseRequest{ReqID: 11, StmtID: prepareResp.StmtID} + err = doWebSocketWithoutResp(ws, STMTClose, &closeReq) + assert.NoError(t, err) +} + +func TestStmtNumParams(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + db := "test_ws_stmt_num_params" + code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), "") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), "") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf("create stable if not exists %s.meters (ts timestamp,current float,voltage int,phase float) tags (groupid int,location varchar(24))", db), "") + assert.Equal(t, 0, code, message) + + defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // init + initReq := map[string]uint64{"req_id": 2} + resp, err = doWebSocket(ws, STMTInit, &initReq) + assert.NoError(t, err) + var initResp stmtInitResponse + err = json.Unmarshal(resp, &initResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), initResp.ReqID) + assert.Equal(t, 0, initResp.Code, initResp.Message) + + // prepare + prepareReq := stmtPrepareRequest{ + ReqID: 3, + StmtID: initResp.StmtID, + SQL: fmt.Sprintf("insert into d1 using %s.meters tags(?, ?) values (?, ?, ?, ?)", db), + } + resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) + assert.NoError(t, err) + var prepareResp stmtPrepareResponse + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + + // num params + numParamsReq := stmtNumParamsRequest{ReqID: 4, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTNumParams, &numParamsReq) + assert.NoError(t, err) + var numParamsResp stmtNumParamsResponse + err = json.Unmarshal(resp, &numParamsResp) + assert.NoError(t, err) + assert.Equal(t, 0, numParamsResp.Code, numParamsResp.Message) + assert.Equal(t, uint64(4), numParamsResp.ReqID) + assert.Equal(t, 4, numParamsResp.NumParams) +} + +func TestStmtGetParams(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + db := "test_ws_stmt_get_params" + code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), "") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), "") + assert.Equal(t, 0, code, message) + code, message = doRestful(fmt.Sprintf("create stable if not exists %s.meters (ts timestamp,current float,voltage int,phase float) tags (groupid int,location varchar(24))", db), "") + assert.Equal(t, 0, code, message) + + defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // init + initReq := map[string]uint64{"req_id": 2} + resp, err = doWebSocket(ws, STMTInit, &initReq) + assert.NoError(t, err) + var initResp stmtInitResponse + err = json.Unmarshal(resp, &initResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), initResp.ReqID) + assert.Equal(t, 0, initResp.Code, initResp.Message) + + // prepare + prepareReq := stmtPrepareRequest{ + ReqID: 3, + StmtID: initResp.StmtID, + SQL: fmt.Sprintf("insert into d1 using %s.meters tags(?, ?) values (?, ?, ?, ?)", db), + } + resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) + assert.NoError(t, err) + var prepareResp stmtPrepareResponse + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + + // get param + getParamsReq := stmtGetParamRequest{ReqID: 4, StmtID: prepareResp.StmtID, Index: 0} + resp, err = doWebSocket(ws, STMTGetParam, &getParamsReq) + assert.NoError(t, err) + var getParamsResp stmtGetParamResponse + err = json.Unmarshal(resp, &getParamsResp) + assert.NoError(t, err) + assert.Equal(t, 0, getParamsResp.Code, getParamsResp.Message) + assert.Equal(t, uint64(4), getParamsResp.ReqID) + assert.Equal(t, 0, getParamsResp.Index) + assert.Equal(t, 9, getParamsResp.DataType) + assert.Equal(t, 8, getParamsResp.Length) +} diff --git a/controller/ws/ws/ws.go b/controller/ws/ws/ws.go index 9e2f0e69..66898439 100644 --- a/controller/ws/ws/ws.go +++ b/controller/ws/ws/ws.go @@ -2,13 +2,13 @@ package ws import ( "github.com/gin-gonic/gin" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/generator" + "github.com/taosdata/taosadapter/v3/tools/melody" ) func init() { @@ -25,7 +25,7 @@ func (ws *webSocketCtl) Init(ctl gin.IRouter) { logger := log.GetLogger("WSC").WithFields(logrus.Fields{ config.SessionIDKey: sessionID}) if err := ws.m.HandleRequestWithKeys(c.Writer, c.Request, map[string]interface{}{"logger": logger}); err != nil { - panic(err) + logger.Errorf("handle request error: %v", err) } }) } @@ -33,7 +33,7 @@ func (ws *webSocketCtl) Init(ctl gin.IRouter) { func initController() *webSocketCtl { m := melody.New() m.Config.MaxMessageSize = 0 - m.UpGrader.EnableCompression = true + m.Upgrader.EnableCompression = true m.HandleConnect(func(session *melody.Session) { logger := wstool.GetLogger(session) @@ -41,16 +41,32 @@ func initController() *webSocketCtl { session.Set(TaosKey, newHandler(session)) }) m.HandleMessage(func(session *melody.Session, data []byte) { - if m.IsClosed() { + h := session.MustGet(TaosKey).(*messageHandler) + if h.closed { return } - session.MustGet(TaosKey).(*messageHandler).handleMessage(session, data) + h.wait.Add(1) + go func() { + defer h.wait.Done() + if h.closed { + return + } + h.handleMessage(session, data) + }() }) - m.HandleMessageBinary(func(session *melody.Session, bytes []byte) { - if m.IsClosed() { + m.HandleMessageBinary(func(session *melody.Session, data []byte) { + h := session.MustGet(TaosKey).(*messageHandler) + if h.closed { return } - session.MustGet(TaosKey).(*messageHandler).handleMessageBinary(session, bytes) + h.wait.Add(1) + go func() { + defer h.wait.Done() + if h.closed { + return + } + h.handleMessageBinary(session, data) + }() }) m.HandleClose(func(session *melody.Session, i int, s string) error { logger := wstool.GetLogger(session) diff --git a/controller/ws/ws/ws_test.go b/controller/ws/ws/ws_test.go index 66251775..54023fb7 100644 --- a/controller/ws/ws/ws_test.go +++ b/controller/ws/ws/ws_test.go @@ -1,9 +1,7 @@ package ws import ( - "bytes" "database/sql/driver" - "encoding/binary" "encoding/json" "fmt" "io" @@ -11,26 +9,17 @@ import ( "net/http/httptest" "strings" "testing" - "time" - "unsafe" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/param" - "github.com/taosdata/driver-go/v3/common/parser" - "github.com/taosdata/driver-go/v3/common/serializer" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" - "github.com/taosdata/driver-go/v3/ws/schemaless" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller" _ "github.com/taosdata/taosadapter/v3/controller/rest" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db" "github.com/taosdata/taosadapter/v3/log" - "github.com/taosdata/taosadapter/v3/tools/parseblock" "github.com/taosdata/taosadapter/v3/version" ) @@ -78,7 +67,7 @@ func doRestful(sql string, db string) (code int, message string) { return res.Code, res.Desc } -type queryResp struct { +type httpQueryResp struct { Code int `json:"code,omitempty"` Desc string `json:"desc,omitempty"` ColumnMeta [][]driver.Value `json:"column_meta,omitempty"` @@ -86,7 +75,7 @@ type queryResp struct { Rows int `json:"rows,omitempty"` } -func restQuery(sql string, db string) *queryResp { +func restQuery(sql string, db string) *httpQueryResp { w := httptest.NewRecorder() body := strings.NewReader(sql) url := "/rest/sql" @@ -98,13 +87,13 @@ func restQuery(sql string, db string) *queryResp { req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") router.ServeHTTP(w, req) if w.Code != http.StatusOK { - return &queryResp{ + return &httpQueryResp{ Code: w.Code, Desc: w.Body.String(), } } b, _ := io.ReadAll(w.Body) - var res queryResp + var res httpQueryResp _ = json.Unmarshal(b, &res) return &res } @@ -112,9 +101,15 @@ func restQuery(sql string, db string) *queryResp { func doWebSocket(ws *websocket.Conn, action string, arg interface{}) (resp []byte, err error) { var b []byte if arg != nil { - b, _ = json.Marshal(arg) + b, err = json.Marshal(arg) + if err != nil { + return nil, err + } + } + a, err := json.Marshal(Request{Action: action, Args: b}) + if err != nil { + return nil, err } - a, _ := json.Marshal(Request{Action: action, Args: b}) err = ws.WriteMessage(websocket.TextMessage, a) if err != nil { return nil, err @@ -136,6 +131,11 @@ func doWebSocketWithoutResp(ws *websocket.Conn, action string, arg interface{}) return nil } +type versionResponse struct { + commonResp + Version string +} + func TestVersion(t *testing.T) { s := httptest.NewServer(router) defer s.Close() @@ -150,3207 +150,10 @@ func TestVersion(t *testing.T) { }() resp, err := doWebSocket(ws, wstool.ClientVersion, nil) assert.NoError(t, err) - var versionResp VersionResponse + var versionResp versionResponse err = json.Unmarshal(resp, &versionResp) assert.NoError(t, err) assert.Equal(t, 0, versionResp.Code, versionResp.Message) assert.Equal(t, version.TaosClientVersion, versionResp.Version) -} - -func TestWsQuery(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - code, message := doRestful("drop database if exists test_ws_query", "") - assert.Equal(t, 0, code, message) - code, message = doRestful("create database if not exists test_ws_query", "") - assert.Equal(t, 0, code, message) - code, message = doRestful( - "create table if not exists stb1 (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", - "test_ws_query") - assert.Equal(t, 0, code, message) - code, message = doRestful( - `insert into t1 using stb1 tags ('{\"table\":\"t1\"}') values (now-2s,true,2,3,4,5,6,7,8,9,10,11,'中文\"binary','中文nchar','\xaabbcc','point(100 100)')(now-1s,false,12,13,14,15,16,17,18,19,110,111,'中文\"binary','中文nchar','\xaabbcc','point(100 100)')(now,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null)`, - "test_ws_query") - assert.Equal(t, 0, code, message) - - code, message = doRestful("create table t2 using stb1 tags('{\"table\":\"t2\"}')", "test_ws_query") - assert.Equal(t, 0, code, message) - code, message = doRestful("create table t3 using stb1 tags('{\"table\":\"t3\"}')", "test_ws_query") - assert.Equal(t, 0, code, message) - - defer doRestful("drop database if exists test_ws_query", "") - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_query"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // query - queryReq := QueryRequest{ReqID: 2, Sql: "select * from stb1"} - resp, err = doWebSocket(ws, WSQuery, &queryReq) - assert.NoError(t, err) - var queryResp QueryResponse - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, uint64(2), queryResp.ReqID) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // fetch - fetchReq := FetchRequest{ReqID: 3, ID: queryResp.ID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - var fetchResp FetchResponse - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), fetchResp.ReqID) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - assert.Equal(t, 3, fetchResp.Rows) - - // fetch block - fetchBlockReq := FetchBlockRequest{ReqID: 4, ID: queryResp.ID} - fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) - assert.NoError(t, err) - resultID, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) - assert.Equal(t, uint64(1), resultID) - checkBlockResult(t, blockResult) - - fetchReq = FetchRequest{ReqID: 5, ID: queryResp.ID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, uint64(5), fetchResp.ReqID) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - - assert.Equal(t, true, fetchResp.Completed) - - // write block - var buffer bytes.Buffer - wstool.WriteUint64(&buffer, 300) // req id - wstool.WriteUint64(&buffer, 400) // message id - wstool.WriteUint64(&buffer, uint64(RawBlockMessage)) // action - wstool.WriteUint32(&buffer, uint32(fetchResp.Rows)) // rows - wstool.WriteUint16(&buffer, uint16(2)) // table name length - buffer.WriteString("t2") // table name - buffer.Write(fetchBlockResp[16:]) // raw block - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - var writeResp BaseResponse - err = json.Unmarshal(resp, &writeResp) - assert.NoError(t, err) - assert.Equal(t, 0, writeResp.Code, writeResp.Message) - - // query - queryReq = QueryRequest{ReqID: 6, Sql: "select * from t2"} - resp, err = doWebSocket(ws, WSQuery, &queryReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // fetch - fetchReq = FetchRequest{ReqID: 7, ID: queryResp.ID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - - // fetch block - fetchBlockReq = FetchBlockRequest{ReqID: 8, ID: queryResp.ID} - fetchBlockResp, err = doWebSocket(ws, WSFetchBlock, &fetchBlockReq) - assert.NoError(t, err) - resultID, blockResult = parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) - checkBlockResult(t, blockResult) - assert.Equal(t, queryResp.ID, resultID) - // fetch - fetchReq = FetchRequest{ReqID: 9, ID: queryResp.ID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - - assert.Equal(t, true, fetchResp.Completed) - - // write block with filed - buffer.Reset() - wstool.WriteUint64(&buffer, 300) // req id - wstool.WriteUint64(&buffer, 400) // message id - wstool.WriteUint64(&buffer, uint64(RawBlockMessageWithFields)) // action - wstool.WriteUint32(&buffer, uint32(fetchResp.Rows)) // rows - wstool.WriteUint16(&buffer, uint16(2)) // table name length - buffer.WriteString("t3") // table name - buffer.Write(fetchBlockResp[16:]) // raw block - fields := []byte{ - // ts - 0x74, 0x73, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x09, - // padding - 0x00, 0x00, - // bytes - 0x08, 0x00, 0x00, 0x00, - - // v1 - 0x76, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x01, - // padding - 0x00, 0x00, - // bytes - 0x01, 0x00, 0x00, 0x00, - - // v2 - 0x76, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x02, - // padding - 0x00, 0x00, - // bytes - 0x01, 0x00, 0x00, 0x00, - - // v3 - 0x76, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x03, - // padding - 0x00, 0x00, - // bytes - 0x02, 0x00, 0x00, 0x00, - - // v4 - 0x76, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x04, - // padding - 0x00, 0x00, - // bytes - 0x04, 0x00, 0x00, 0x00, - - // v5 - 0x76, 0x35, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x05, - // padding - 0x00, 0x00, - // bytes - 0x08, 0x00, 0x00, 0x00, - - // v6 - 0x76, 0x36, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0b, - // padding - 0x00, 0x00, - // bytes - 0x01, 0x00, 0x00, 0x00, - - // v7 - 0x76, 0x37, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0c, - // padding - 0x00, 0x00, - // bytes - 0x02, 0x00, 0x00, 0x00, - - // v8 - 0x76, 0x38, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0d, - // padding - 0x00, 0x00, - // bytes - 0x04, 0x00, 0x00, 0x00, - - // v9 - 0x76, 0x39, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0e, - // padding - 0x00, 0x00, - // bytes - 0x08, 0x00, 0x00, 0x00, - - // v10 - 0x76, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x06, - // padding - 0x00, 0x00, - // bytes - 0x04, 0x00, 0x00, 0x00, - - // v11 - 0x76, 0x31, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x07, - // padding - 0x00, 0x00, - // bytes - 0x08, 0x00, 0x00, 0x00, - - // v12 - 0x76, 0x31, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x08, - // padding - 0x00, 0x00, - // bytes - 0x14, 0x00, 0x00, 0x00, - - // v13 - 0x76, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0a, - // padding - 0x00, 0x00, - // bytes - 0x14, 0x00, 0x00, 0x00, - - // v14 - 0x76, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x10, - // padding - 0x00, 0x00, - // bytes - 0x14, 0x00, 0x00, 0x00, - - // v15 - 0x76, 0x31, 0x35, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x14, - // padding - 0x00, 0x00, - // bytes - 0x64, 0x00, 0x00, 0x00, - - // info - 0x69, 0x6e, 0x66, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0f, - // padding - 0x00, 0x00, - // bytes - 0x00, 0x10, 0x00, 0x00, - } - buffer.Write(fields) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &writeResp) - assert.NoError(t, err) - assert.Equal(t, 0, writeResp.Code, writeResp.Message) - - // query - queryReq = QueryRequest{ReqID: 10, Sql: "select * from t3"} - resp, err = doWebSocket(ws, WSQuery, &queryReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // fetch - fetchReq = FetchRequest{ReqID: 11, ID: queryResp.ID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - - // fetch block - fetchBlockReq = FetchBlockRequest{ReqID: 12, ID: queryResp.ID} - fetchBlockResp, err = doWebSocket(ws, WSFetchBlock, &fetchBlockReq) - assert.NoError(t, err) - resultID, blockResult = parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) - assert.Equal(t, queryResp.ID, resultID) - checkBlockResult(t, blockResult) - // fetch - fetchReq = FetchRequest{ReqID: 13, ID: queryResp.ID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - - assert.Equal(t, true, fetchResp.Completed) - time.Sleep(time.Second) -} - -type FetchRawBlockResponse struct { - Flag uint64 - Version uint16 - Time uint64 - ReqID uint64 - Code uint32 - Message string - ResultID uint64 - Finished bool - RawBlock []byte -} - -func parseFetchRawBlock(message []byte) *FetchRawBlockResponse { - var resp = &FetchRawBlockResponse{} - resp.Flag = binary.LittleEndian.Uint64(message) - resp.Version = binary.LittleEndian.Uint16(message[16:]) - resp.Time = binary.LittleEndian.Uint64(message[18:]) - resp.ReqID = binary.LittleEndian.Uint64(message[26:]) - resp.Code = binary.LittleEndian.Uint32(message[34:]) - msgLen := int(binary.LittleEndian.Uint32(message[38:])) - resp.Message = string(message[42 : 42+msgLen]) - if resp.Code != 0 { - return resp - } - resp.ResultID = binary.LittleEndian.Uint64(message[42+msgLen:]) - resp.Finished = message[50+msgLen] == 1 - if resp.Finished { - return resp - } - blockLength := binary.LittleEndian.Uint32(message[51+msgLen:]) - resp.RawBlock = message[55+msgLen : 55+msgLen+int(blockLength)] - return resp -} - -func ReadBlockSimple(block unsafe.Pointer, precision int) [][]driver.Value { - blockSize := parser.RawBlockGetNumOfRows(block) - colCount := parser.RawBlockGetNumOfCols(block) - colInfo := make([]parser.RawBlockColInfo, colCount) - parser.RawBlockGetColInfo(block, colInfo) - colTypes := make([]uint8, colCount) - for i := int32(0); i < colCount; i++ { - colTypes[i] = uint8(colInfo[i].ColType) - } - return parser.ReadBlock(block, int(blockSize), colTypes, precision) -} - -func TestWsBinaryQuery(t *testing.T) { - dbName := "test_ws_binary_query" - s := httptest.NewServer(router) - defer s.Close() - code, message := doRestful(fmt.Sprintf("drop database if exists %s", dbName), "") - assert.Equal(t, 0, code, message) - code, message = doRestful(fmt.Sprintf("create database if not exists %s", dbName), "") - assert.Equal(t, 0, code, message) - code, message = doRestful( - "create table if not exists stb1 (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", - dbName) - assert.Equal(t, 0, code, message) - code, message = doRestful( - `insert into t1 using stb1 tags ('{\"table\":\"t1\"}') values (now-2s,true,2,3,4,5,6,7,8,9,10,11,'中文\"binary','中文nchar','\xaabbcc','point(100 100)')(now-1s,false,12,13,14,15,16,17,18,19,110,111,'中文\"binary','中文nchar','\xaabbcc','point(100 100)')(now,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null)`, - dbName) - assert.Equal(t, 0, code, message) - - code, message = doRestful("create table t2 using stb1 tags('{\"table\":\"t2\"}')", dbName) - assert.Equal(t, 0, code, message) - code, message = doRestful("create table t3 using stb1 tags('{\"table\":\"t3\"}')", dbName) - assert.Equal(t, 0, code, message) - - defer doRestful(fmt.Sprintf("drop database if exists %s", dbName), "") - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: dbName} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // query - sql := "select * from stb1" - var buffer bytes.Buffer - wstool.WriteUint64(&buffer, 2) // req id - wstool.WriteUint64(&buffer, 0) // message id - wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) - wstool.WriteUint16(&buffer, 1) // version - wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length - buffer.WriteString(sql) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - var queryResp QueryResponse - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, uint64(2), queryResp.ReqID) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // fetch raw block - buffer.Reset() - wstool.WriteUint64(&buffer, 3) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 1) // version - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp := parseFetchRawBlock(resp) - assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) - assert.Equal(t, uint64(3), fetchRawBlockResp.ReqID) - assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) - assert.Equal(t, uint64(1), fetchRawBlockResp.ResultID) - assert.Equal(t, false, fetchRawBlockResp.Finished) - rows := parser.RawBlockGetNumOfRows(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0])) - assert.Equal(t, int32(3), rows) - blockResult := ReadBlockSimple(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0]), queryResp.Precision) - checkBlockResult(t, blockResult) - rawBlock := fetchRawBlockResp.RawBlock - - buffer.Reset() - wstool.WriteUint64(&buffer, 5) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 1) // version - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) - assert.Equal(t, uint64(5), fetchRawBlockResp.ReqID) - assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) - assert.Equal(t, uint64(1), fetchRawBlockResp.ResultID) - assert.Equal(t, true, fetchRawBlockResp.Finished) - - // write block - - buffer.Reset() - wstool.WriteUint64(&buffer, 300) // req id - wstool.WriteUint64(&buffer, 400) // message id - wstool.WriteUint64(&buffer, uint64(RawBlockMessage)) // action - wstool.WriteUint32(&buffer, uint32(3)) // rows - wstool.WriteUint16(&buffer, uint16(2)) // table name length - buffer.WriteString("t2") // table name - buffer.Write(rawBlock) // raw block - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - var writeResp BaseResponse - err = json.Unmarshal(resp, &writeResp) - assert.NoError(t, err) - assert.Equal(t, 0, writeResp.Code, writeResp.Message) - - // query - sql = "select * from t2" - buffer.Reset() - wstool.WriteUint64(&buffer, 6) // req id - wstool.WriteUint64(&buffer, 0) // message id - wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) - wstool.WriteUint16(&buffer, 1) // version - wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length - buffer.WriteString(sql) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // fetch raw block - buffer.Reset() - wstool.WriteUint64(&buffer, 7) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 1) // version - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) - assert.Equal(t, uint64(7), fetchRawBlockResp.ReqID) - assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) - assert.Equal(t, false, fetchRawBlockResp.Finished) - blockResult = ReadBlockSimple(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0]), queryResp.Precision) - checkBlockResult(t, blockResult) - rawBlock = fetchRawBlockResp.RawBlock - - buffer.Reset() - wstool.WriteUint64(&buffer, 9) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 1) // version - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) - assert.Equal(t, uint64(9), fetchRawBlockResp.ReqID) - assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) - assert.Equal(t, true, fetchRawBlockResp.Finished) - - // write block with filed - buffer.Reset() - wstool.WriteUint64(&buffer, 300) // req id - wstool.WriteUint64(&buffer, 400) // message id - wstool.WriteUint64(&buffer, uint64(RawBlockMessageWithFields)) // action - wstool.WriteUint32(&buffer, uint32(3)) // rows - wstool.WriteUint16(&buffer, uint16(2)) // table name length - buffer.WriteString("t3") // table name - buffer.Write(rawBlock) // raw block - fields := []byte{ - // ts - 0x74, 0x73, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x09, - // padding - 0x00, 0x00, - // bytes - 0x08, 0x00, 0x00, 0x00, - - // v1 - 0x76, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x01, - // padding - 0x00, 0x00, - // bytes - 0x01, 0x00, 0x00, 0x00, - - // v2 - 0x76, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x02, - // padding - 0x00, 0x00, - // bytes - 0x01, 0x00, 0x00, 0x00, - - // v3 - 0x76, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x03, - // padding - 0x00, 0x00, - // bytes - 0x02, 0x00, 0x00, 0x00, - - // v4 - 0x76, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x04, - // padding - 0x00, 0x00, - // bytes - 0x04, 0x00, 0x00, 0x00, - - // v5 - 0x76, 0x35, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x05, - // padding - 0x00, 0x00, - // bytes - 0x08, 0x00, 0x00, 0x00, - - // v6 - 0x76, 0x36, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0b, - // padding - 0x00, 0x00, - // bytes - 0x01, 0x00, 0x00, 0x00, - - // v7 - 0x76, 0x37, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0c, - // padding - 0x00, 0x00, - // bytes - 0x02, 0x00, 0x00, 0x00, - - // v8 - 0x76, 0x38, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0d, - // padding - 0x00, 0x00, - // bytes - 0x04, 0x00, 0x00, 0x00, - - // v9 - 0x76, 0x39, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0e, - // padding - 0x00, 0x00, - // bytes - 0x08, 0x00, 0x00, 0x00, - - // v10 - 0x76, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x06, - // padding - 0x00, 0x00, - // bytes - 0x04, 0x00, 0x00, 0x00, - - // v11 - 0x76, 0x31, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x07, - // padding - 0x00, 0x00, - // bytes - 0x08, 0x00, 0x00, 0x00, - - // v12 - 0x76, 0x31, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x08, - // padding - 0x00, 0x00, - // bytes - 0x14, 0x00, 0x00, 0x00, - - // v13 - 0x76, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0a, - // padding - 0x00, 0x00, - // bytes - 0x14, 0x00, 0x00, 0x00, - - // v14 - 0x76, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x10, - // padding - 0x00, 0x00, - // bytes - 0x14, 0x00, 0x00, 0x00, - - // v15 - 0x76, 0x31, 0x35, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x14, - // padding - 0x00, 0x00, - // bytes - 0x64, 0x00, 0x00, 0x00, - - // info - 0x69, 0x6e, 0x66, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, - // type - 0x0f, - // padding - 0x00, 0x00, - // bytes - 0x00, 0x10, 0x00, 0x00, - } - buffer.Write(fields) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &writeResp) - assert.NoError(t, err) - assert.Equal(t, 0, writeResp.Code, writeResp.Message) - - // query - sql = "select * from t3" - buffer.Reset() - wstool.WriteUint64(&buffer, 6) // req id - wstool.WriteUint64(&buffer, 0) // message id - wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) - wstool.WriteUint16(&buffer, 1) // version - wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length - buffer.WriteString(sql) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // fetch raw block - buffer.Reset() - wstool.WriteUint64(&buffer, 11) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 1) // version - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) - assert.Equal(t, uint64(11), fetchRawBlockResp.ReqID) - assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) - assert.Equal(t, false, fetchRawBlockResp.Finished) - blockResult = ReadBlockSimple(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0]), queryResp.Precision) - checkBlockResult(t, blockResult) - - buffer.Reset() - wstool.WriteUint64(&buffer, 13) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 1) // version - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) - assert.Equal(t, uint64(13), fetchRawBlockResp.ReqID) - assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) - assert.Equal(t, true, fetchRawBlockResp.Finished) - - // wrong message length - buffer.Reset() - wstool.WriteUint64(&buffer, 6) // req id - wstool.WriteUint64(&buffer, 0) // message id - wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) - wstool.WriteUint16(&buffer, 1) // version - wstool.WriteUint32(&buffer, 0) // sql length - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 65535, queryResp.Code, queryResp.Message) - - // wrong sql length - buffer.Reset() - wstool.WriteUint64(&buffer, 6) // req id - wstool.WriteUint64(&buffer, 0) // message id - wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) - wstool.WriteUint16(&buffer, 1) // version - wstool.WriteUint32(&buffer, 100) // sql length - buffer.WriteString("wrong sql length") - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 65535, queryResp.Code, queryResp.Message) - - // wrong version - buffer.Reset() - sql = "select 1" - wstool.WriteUint64(&buffer, 6) // req id - wstool.WriteUint64(&buffer, 0) // message id - wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) - wstool.WriteUint16(&buffer, 100) // version - wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length - buffer.WriteString(sql) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 65535, queryResp.Code, queryResp.Message) - - // wrong sql - buffer.Reset() - sql = "wrong sql" - wstool.WriteUint64(&buffer, 6) // req id - wstool.WriteUint64(&buffer, 0) // message id - wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) - wstool.WriteUint16(&buffer, 1) // version - wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length - buffer.WriteString(sql) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.NotEqual(t, 0, queryResp.Code, queryResp.Message) - - // insert sql - buffer.Reset() - sql = "create table t4 using stb1 tags('{\"table\":\"t4\"}')" - wstool.WriteUint64(&buffer, 6) // req id - wstool.WriteUint64(&buffer, 0) // message id - wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) - wstool.WriteUint16(&buffer, 1) // version - wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length - buffer.WriteString(sql) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - assert.Equal(t, true, queryResp.IsUpdate) - - // wrong fetch - buffer.Reset() - sql = "select 1" - wstool.WriteUint64(&buffer, 6) // req id - wstool.WriteUint64(&buffer, 0) // message id - wstool.WriteUint64(&buffer, uint64(BinaryQueryMessage)) - wstool.WriteUint16(&buffer, 1) // version - wstool.WriteUint32(&buffer, uint32(len(sql))) // sql length - buffer.WriteString(sql) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // wrong fetch raw block length - buffer.Reset() - wstool.WriteUint64(&buffer, 700) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(700), fetchRawBlockResp.ReqID) - assert.NotEqual(t, uint32(0), fetchRawBlockResp.Code) - - // wrong fetch raw block version - buffer.Reset() - wstool.WriteUint64(&buffer, 800) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 100) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(800), fetchRawBlockResp.ReqID) - assert.NotEqual(t, uint32(0), fetchRawBlockResp.Code) - time.Sleep(time.Second) - - // wrong fetch raw block result - buffer.Reset() - wstool.WriteUint64(&buffer, 900) // req id - wstool.WriteUint64(&buffer, queryResp.ID+100) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 1) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(900), fetchRawBlockResp.ReqID) - assert.NotEqual(t, uint32(0), fetchRawBlockResp.Code) - - // fetch freed raw block - buffer.Reset() - wstool.WriteUint64(&buffer, 600) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 1) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) - assert.Equal(t, uint64(600), fetchRawBlockResp.ReqID) - assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) - assert.Equal(t, int32(1), parser.RawBlockGetNumOfRows(unsafe.Pointer(&fetchRawBlockResp.RawBlock[0]))) - - buffer.Reset() - wstool.WriteUint64(&buffer, 700) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 1) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) - assert.Equal(t, uint64(700), fetchRawBlockResp.ReqID) - assert.Equal(t, uint32(0), fetchRawBlockResp.Code, fetchRawBlockResp.Message) - assert.Equal(t, true, fetchRawBlockResp.Finished) - - buffer.Reset() - wstool.WriteUint64(&buffer, 400) // req id - wstool.WriteUint64(&buffer, queryResp.ID) // message id - wstool.WriteUint64(&buffer, uint64(FetchRawBlockMessage)) - wstool.WriteUint16(&buffer, 1) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - fetchRawBlockResp = parseFetchRawBlock(resp) - assert.Equal(t, uint64(0xffffffffffffffff), fetchRawBlockResp.Flag) - assert.Equal(t, uint64(400), fetchRawBlockResp.ReqID) - assert.NotEqual(t, uint32(0), fetchRawBlockResp.Code) - time.Sleep(time.Second) -} - -func checkBlockResult(t *testing.T, blockResult [][]driver.Value) { - assert.Equal(t, 3, len(blockResult)) - assert.Equal(t, true, blockResult[0][1]) - assert.Equal(t, int8(2), blockResult[0][2]) - assert.Equal(t, int16(3), blockResult[0][3]) - assert.Equal(t, int32(4), blockResult[0][4]) - assert.Equal(t, int64(5), blockResult[0][5]) - assert.Equal(t, uint8(6), blockResult[0][6]) - assert.Equal(t, uint16(7), blockResult[0][7]) - assert.Equal(t, uint32(8), blockResult[0][8]) - assert.Equal(t, uint64(9), blockResult[0][9]) - assert.Equal(t, float32(10), blockResult[0][10]) - assert.Equal(t, float64(11), blockResult[0][11]) - assert.Equal(t, "中文\"binary", blockResult[0][12]) - assert.Equal(t, "中文nchar", blockResult[0][13]) - assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[0][14]) - assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[0][15]) - assert.Equal(t, false, blockResult[1][1]) - assert.Equal(t, int8(12), blockResult[1][2]) - assert.Equal(t, int16(13), blockResult[1][3]) - assert.Equal(t, int32(14), blockResult[1][4]) - assert.Equal(t, int64(15), blockResult[1][5]) - assert.Equal(t, uint8(16), blockResult[1][6]) - assert.Equal(t, uint16(17), blockResult[1][7]) - assert.Equal(t, uint32(18), blockResult[1][8]) - assert.Equal(t, uint64(19), blockResult[1][9]) - assert.Equal(t, float32(110), blockResult[1][10]) - assert.Equal(t, float64(111), blockResult[1][11]) - assert.Equal(t, "中文\"binary", blockResult[1][12]) - assert.Equal(t, "中文nchar", blockResult[1][13]) - assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) - assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[1][15]) - assert.Equal(t, nil, blockResult[2][1]) - assert.Equal(t, nil, blockResult[2][2]) - assert.Equal(t, nil, blockResult[2][3]) - assert.Equal(t, nil, blockResult[2][4]) - assert.Equal(t, nil, blockResult[2][5]) - assert.Equal(t, nil, blockResult[2][6]) - assert.Equal(t, nil, blockResult[2][7]) - assert.Equal(t, nil, blockResult[2][8]) - assert.Equal(t, nil, blockResult[2][9]) - assert.Equal(t, nil, blockResult[2][10]) - assert.Equal(t, nil, blockResult[2][11]) - assert.Equal(t, nil, blockResult[2][12]) - assert.Equal(t, nil, blockResult[2][13]) - assert.Equal(t, nil, blockResult[2][14]) - assert.Equal(t, nil, blockResult[2][15]) -} - -func TestWsSchemaless(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - code, message := doRestful("drop database if exists test_ws_schemaless", "") - assert.Equal(t, 0, code, message) - code, message = doRestful("create database if not exists test_ws_schemaless", "") - assert.Equal(t, 0, code, message) - - //defer doRestful("drop database if exists test_ws_schemaless", "") - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - cases := []struct { - name string - protocol int - precision string - data string - ttl int - code int - totalRows int32 - affectedRows int - tableNameKey string - }{ - { - name: "influxdb", - protocol: schemaless.InfluxDBLineProtocol, - precision: "ms", - data: "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + - "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + - "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + - "measurement,host=host1 field1=2i,field2=2.0 1577837600000", - ttl: 1000, - code: 0, - totalRows: 4, - affectedRows: 4, - }, - { - name: "opentsdb_telnet", - protocol: schemaless.OpenTSDBTelnetLineProtocol, - precision: "ms", - data: "meters.current 1648432611249 10.3 location=California.SanFrancisco group=2\n" + - "meters.current 1648432611250 12.6 location=California.SanFrancisco group=2\n" + - "meters.current 1648432611249 10.8 location=California.LosAngeles group=3\n" + - "meters.current 1648432611250 11.3 location=California.LosAngeles group=3\n" + - "meters.voltage 1648432611249 219 location=California.SanFrancisco group=2\n" + - "meters.voltage 1648432611250 218 location=California.SanFrancisco group=2\n" + - "meters.voltage 1648432611249 221 location=California.LosAngeles group=3\n" + - "meters.voltage 1648432611250 217 location=California.LosAngeles group=3", - ttl: 1000, - code: 0, - totalRows: 8, - affectedRows: 8, - }, - { - name: "opentsdb_json", - protocol: schemaless.OpenTSDBJsonFormatProtocol, - precision: "ms", - data: `[ - { - "metric": "meters2.current", - "timestamp": 1648432611249, - "value": 10.3, - "tags": { - "location": "California.SanFrancisco", - "groupid": 2 - } - }, - { - "metric": "meters2.voltage", - "timestamp": 1648432611249, - "value": 219, - "tags": { - "location": "California.LosAngeles", - "groupid": 1 - } - }, - { - "metric": "meters2.current", - "timestamp": 1648432611250, - "value": 12.6, - "tags": { - "location": "California.SanFrancisco", - "groupid": 2 - } - }, - { - "metric": "meters2.voltage", - "timestamp": 1648432611250, - "value": 221, - "tags": { - "location": "California.LosAngeles", - "groupid": 1 - } - } -]`, - ttl: 100, - code: 0, - affectedRows: 4, - }, - { - name: "influxdb_tbnamekey", - protocol: schemaless.InfluxDBLineProtocol, - precision: "ms", - data: "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + - "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + - "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + - "measurement,host=host1 field1=2i,field2=2.0 1577837600000", - ttl: 1000, - code: 0, - totalRows: 4, - affectedRows: 4, - tableNameKey: "host", - }, - } - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_schemaless"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - for _, c := range cases { - reqID := uint64(1) - t.Run(c.name, func(t *testing.T) { - reqID += 1 - req := SchemalessWriteRequest{ - ReqID: reqID, - Protocol: c.protocol, - Precision: c.precision, - TTL: c.ttl, - Data: c.data, - TableNameKey: c.tableNameKey, - } - resp, err = doWebSocket(ws, SchemalessWrite, &req) - assert.NoError(t, err) - var schemalessResp SchemalessWriteResponse - err = json.Unmarshal(resp, &schemalessResp) - assert.NoError(t, err, string(resp)) - assert.Equal(t, reqID, schemalessResp.ReqID) - assert.Equal(t, 0, schemalessResp.Code, schemalessResp.Message) - if c.protocol != schemaless.OpenTSDBJsonFormatProtocol { - assert.Equal(t, c.totalRows, schemalessResp.TotalRows) - } - assert.Equal(t, c.affectedRows, schemalessResp.AffectedRows) - }) - } -} - -func TestWsStmt(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - code, message := doRestful("drop database if exists test_ws_stmt_ws", "") - assert.Equal(t, 0, code, message) - code, message = doRestful("create database if not exists test_ws_stmt_ws precision 'ns'", "") - assert.Equal(t, 0, code, message) - - defer doRestful("drop database if exists test_ws_stmt_ws", "") - - code, message = doRestful( - "create table if not exists stb (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", - "test_ws_stmt_ws") - assert.Equal(t, 0, code, message) - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_stmt_ws"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // init - initReq := map[string]uint64{"req_id": 2} - resp, err = doWebSocket(ws, STMTInit, &initReq) - assert.NoError(t, err) - var initResp StmtInitResponse - err = json.Unmarshal(resp, &initResp) - assert.NoError(t, err) - assert.Equal(t, uint64(2), initResp.ReqID) - assert.Equal(t, 0, initResp.Code, initResp.Message) - - // prepare - prepareReq := StmtPrepareRequest{ReqID: 3, StmtID: initResp.StmtID, SQL: "insert into ? using test_ws_stmt_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"} - resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) - assert.NoError(t, err) - var prepareResp StmtPrepareResponse - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.True(t, prepareResp.IsInsert) - - // set table name - setTableNameReq := StmtSetTableNameRequest{ReqID: 4, StmtID: prepareResp.StmtID, Name: "test_ws_stmt_ws.ct1"} - resp, err = doWebSocket(ws, STMTSetTableName, &setTableNameReq) - assert.NoError(t, err) - var setTableNameResp BaseResponse - err = json.Unmarshal(resp, &setTableNameResp) - assert.NoError(t, err) - assert.Equal(t, uint64(4), setTableNameResp.ReqID) - assert.Equal(t, 0, setTableNameResp.Code, setTableNameResp.Message) - - // get tag fields - getTagFieldsReq := StmtGetTagFieldsRequest{ReqID: 5, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMTGetTagFields, &getTagFieldsReq) - assert.NoError(t, err) - var getTagFieldsResp StmtGetTagFieldsResponse - err = json.Unmarshal(resp, &getTagFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(5), getTagFieldsResp.ReqID) - assert.Equal(t, 0, getTagFieldsResp.Code, getTagFieldsResp.Message) - - // get col fields - getColFieldsReq := StmtGetColFieldsRequest{ReqID: 6, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMTGetColFields, &getColFieldsReq) - assert.NoError(t, err) - var getColFieldsResp StmtGetColFieldsResponse - err = json.Unmarshal(resp, &getColFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(6), getColFieldsResp.ReqID) - assert.Equal(t, 0, getColFieldsResp.Code, getColFieldsResp.Message) - - // set tags - setTagsReq := StmtSetTagsRequest{ReqID: 7, StmtID: prepareResp.StmtID, Tags: json.RawMessage(`["{\"a\":\"b\"}"]`)} - resp, err = doWebSocket(ws, STMTSetTags, &setTagsReq) - assert.NoError(t, err) - var setTagsResp BaseResponse - err = json.Unmarshal(resp, &setTagsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(7), setTagsResp.ReqID) - assert.Equal(t, 0, setTagsResp.Code, setTagsResp.Message) - - // bind - now := time.Now() - columns, _ := json.Marshal([][]driver.Value{ - {now, now.Add(time.Second), now.Add(time.Second * 2)}, - {true, false, nil}, - {2, 22, nil}, - {3, 33, nil}, - {4, 44, nil}, - {5, 55, nil}, - {6, 66, nil}, - {7, 77, nil}, - {8, 88, nil}, - {9, 99, nil}, - {10, 1010, nil}, - {11, 1111, nil}, - {"binary", "binary2", nil}, - {"nchar", "nchar2", nil}, - {"aabbcc", "aabbcc", nil}, - {"010100000000000000000059400000000000005940", "010100000000000000000059400000000000005940", nil}, - }) - bindReq := StmtBindRequest{ReqID: 8, StmtID: prepareResp.StmtID, Columns: columns} - resp, err = doWebSocket(ws, STMTBind, &bindReq) - assert.NoError(t, err) - var bindResp StmtBindResponse - err = json.Unmarshal(resp, &bindResp) - assert.NoError(t, err) - assert.Equal(t, uint64(8), bindResp.ReqID) - assert.Equal(t, 0, bindResp.Code, bindResp.Message) - - // add batch - addBatchReq := StmtAddBatchRequest{ReqID: 9, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMTAddBatch, &addBatchReq) - assert.NoError(t, err) - var addBatchResp StmtAddBatchResponse - err = json.Unmarshal(resp, &addBatchResp) - assert.NoError(t, err) - assert.Equal(t, uint64(9), addBatchResp.ReqID) - assert.Equal(t, 0, bindResp.Code, bindResp.Message) - - // exec - execReq := StmtExecRequest{ReqID: 10, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMTExec, &execReq) - assert.NoError(t, err) - var execResp StmtExecResponse - err = json.Unmarshal(resp, &execResp) - assert.NoError(t, err) - assert.Equal(t, uint64(10), execResp.ReqID) - assert.Equal(t, 0, execResp.Code, execResp.Message) - - // close - closeReq := StmtCloseRequest{ReqID: 11, StmtID: prepareResp.StmtID} - err = doWebSocketWithoutResp(ws, STMTClose, &closeReq) - assert.NoError(t, err) - - // query - queryReq := QueryRequest{Sql: "select * from test_ws_stmt_ws.stb"} - resp, err = doWebSocket(ws, WSQuery, &queryReq) - assert.NoError(t, err) - var queryResp QueryResponse - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // fetch - fetchReq := FetchRequest{ID: queryResp.ID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - var fetchResp FetchResponse - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - - // fetch block - fetchBlockReq := FetchBlockRequest{ID: queryResp.ID} - fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) - assert.NoError(t, err) - _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) - assert.Equal(t, 3, len(blockResult)) - assert.Equal(t, now.UnixNano(), blockResult[0][0].(time.Time).UnixNano()) - - assert.Equal(t, true, blockResult[0][1]) - assert.Equal(t, int8(2), blockResult[0][2]) - assert.Equal(t, int16(3), blockResult[0][3]) - assert.Equal(t, int32(4), blockResult[0][4]) - assert.Equal(t, int64(5), blockResult[0][5]) - assert.Equal(t, uint8(6), blockResult[0][6]) - assert.Equal(t, uint16(7), blockResult[0][7]) - assert.Equal(t, uint32(8), blockResult[0][8]) - assert.Equal(t, uint64(9), blockResult[0][9]) - assert.Equal(t, float32(10), blockResult[0][10]) - assert.Equal(t, float64(11), blockResult[0][11]) - assert.Equal(t, "binary", blockResult[0][12]) - assert.Equal(t, "nchar", blockResult[0][13]) - assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) - assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[0][15]) - - assert.Equal(t, now.Add(time.Second).UnixNano(), blockResult[1][0].(time.Time).UnixNano()) - assert.Equal(t, false, blockResult[1][1]) - assert.Equal(t, int8(22), blockResult[1][2]) - assert.Equal(t, int16(33), blockResult[1][3]) - assert.Equal(t, int32(44), blockResult[1][4]) - assert.Equal(t, int64(55), blockResult[1][5]) - assert.Equal(t, uint8(66), blockResult[1][6]) - assert.Equal(t, uint16(77), blockResult[1][7]) - assert.Equal(t, uint32(88), blockResult[1][8]) - assert.Equal(t, uint64(99), blockResult[1][9]) - assert.Equal(t, float32(1010), blockResult[1][10]) - assert.Equal(t, float64(1111), blockResult[1][11]) - assert.Equal(t, "binary2", blockResult[1][12]) - assert.Equal(t, "nchar2", blockResult[1][13]) - assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) - assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[1][15]) - - assert.Equal(t, now.Add(time.Second*2).UnixNano(), blockResult[2][0].(time.Time).UnixNano()) - for i := 1; i < 16; i++ { - assert.Nil(t, blockResult[2][i]) - } - - // block message - // init - resp, err = doWebSocket(ws, STMTInit, nil) - assert.NoError(t, err) - err = json.Unmarshal(resp, &initResp) - assert.NoError(t, err) - assert.Equal(t, 0, initResp.Code, initResp.Message) - - // prepare - prepareReq = StmtPrepareRequest{StmtID: initResp.StmtID, SQL: "insert into ? using test_ws_stmt_ws.stb tags(?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"} - resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - - // set table name - setTableNameReq = StmtSetTableNameRequest{StmtID: prepareResp.StmtID, Name: "test_ws_stmt_ws.ct2"} - resp, err = doWebSocket(ws, STMTSetTableName, &setTableNameReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &setTableNameResp) - assert.NoError(t, err) - assert.Equal(t, 0, setTableNameResp.Code, setTableNameResp.Message) - - // set tags - var tagBuffer bytes.Buffer - wstool.WriteUint64(&tagBuffer, 100) - wstool.WriteUint64(&tagBuffer, prepareResp.StmtID) - wstool.WriteUint64(&tagBuffer, uint64(SetTagsMessage)) - tags, err := json.Marshal(map[string]string{"a": "b"}) - assert.NoError(t, err) - b, err := serializer.SerializeRawBlock( - []*param.Param{ - param.NewParam(1).AddJson(tags), - }, - param.NewColumnType(1).AddJson(50)) - assert.NoError(t, err) - assert.NoError(t, err) - tagBuffer.Write(b) - - err = ws.WriteMessage(websocket.BinaryMessage, tagBuffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &setTagsResp) - assert.NoError(t, err) - assert.Equal(t, 0, setTagsResp.Code, setTagsResp.Message) - - // bind binary - var block bytes.Buffer - wstool.WriteUint64(&block, 10) - wstool.WriteUint64(&block, prepareResp.StmtID) - wstool.WriteUint64(&block, uint64(BindMessage)) - rawBlock := []byte{ - 0x01, 0x00, 0x00, 0x00, - 0x11, 0x02, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x80, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - - 0x09, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x01, 0x00, 0x00, 0x00, - 0x02, 0x01, 0x00, 0x00, 0x00, - 0x03, 0x02, 0x00, 0x00, 0x00, - 0x04, 0x04, 0x00, 0x00, 0x00, - 0x05, 0x08, 0x00, 0x00, 0x00, - 0x0b, 0x01, 0x00, 0x00, 0x00, - 0x0c, 0x02, 0x00, 0x00, 0x00, - 0x0d, 0x04, 0x00, 0x00, 0x00, - 0x0e, 0x08, 0x00, 0x00, 0x00, - 0x06, 0x04, 0x00, 0x00, 0x00, - 0x07, 0x08, 0x00, 0x00, 0x00, - 0x08, 0x16, 0x00, 0x00, 0x00, - 0x0a, 0x52, 0x00, 0x00, 0x00, - 0x10, 0x20, 0x00, 0x00, 0x00, - 0x14, 0x20, 0x00, 0x00, 0x00, - - 0x18, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, - 0x06, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, - 0x06, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, - 0x11, 0x00, 0x00, 0x00, - 0x30, 0x00, 0x00, 0x00, - 0x21, 0x00, 0x00, 0x00, - 0x2e, 0x00, 0x00, 0x00, - - 0x00, - 0x2c, 0x5b, 0x70, 0x86, 0x82, 0x01, 0x00, 0x00, - 0x14, 0x5f, 0x70, 0x86, 0x82, 0x01, 0x00, 0x00, - 0xfc, 0x62, 0x70, 0x86, 0x82, 0x01, 0x00, 0x00, - - 0x20, - 0x01, - 0x00, - 0x00, - - 0x20, - 0x02, - 0x16, - 0x00, - - 0x20, - 0x03, 0x00, - 0x21, 0x00, - 0x00, 0x00, - - 0x20, - 0x04, 0x00, 0x00, 0x00, - 0x2c, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - - 0x20, - 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x37, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - - 0x20, - 0x06, - 0x42, - 0x00, - - 0x20, - 0x07, 0x00, - 0x4d, 0x00, - 0x00, 0x00, - - 0x20, - 0x08, 0x00, 0x00, 0x00, - 0x58, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - - 0x20, - 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - - 0x20, - 0x00, 0x00, 0x20, 0x41, - 0x00, 0x80, 0x7c, 0x44, - 0x00, 0x00, 0x00, 0x00, - - 0x20, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x26, 0x40, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x5c, 0x91, 0x40, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - - 0x00, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x00, 0x00, - 0xff, 0xff, 0xff, 0xff, - 0x06, 0x00, - 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, - 0x07, 0x00, - 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x32, - - 0x00, 0x00, 0x00, 0x00, - 0x16, 0x00, 0x00, 0x00, - 0xff, 0xff, 0xff, 0xff, - 0x14, 0x00, - 0x6e, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x68, 0x00, - 0x00, 0x00, 0x61, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, - 0x18, 0x00, - 0x6e, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, - 0x61, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, - - 0x00, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, - 0xff, 0xff, 0xff, 0xff, - 0x0e, 0x00, - 0x74, 0x65, 0x73, 0x74, 0x5f, 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, - 0x0f, 0x00, - 0x74, 0x65, 0x73, 0x74, 0x5f, 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x32, - - 0x00, 0x00, 0x00, 0x00, - 0x17, 0x00, 0x00, 0x00, - 0xff, 0xff, 0xff, 0xff, - 0x15, 0x00, - 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, - 0x15, 0x00, - 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, - } - binary.LittleEndian.PutUint64(rawBlock[173:], uint64(now.UnixNano())) - binary.LittleEndian.PutUint64(rawBlock[181:], uint64(now.Add(time.Second).UnixNano())) - binary.LittleEndian.PutUint64(rawBlock[189:], uint64(now.Add(time.Second*2).UnixNano())) - block.Write(rawBlock) - err = ws.WriteMessage( - websocket.BinaryMessage, - block.Bytes(), - ) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - err = json.Unmarshal(resp, &bindResp) - assert.NoError(t, err) - assert.Equal(t, 0, bindResp.Code, bindResp.Message) - - // add batch - addBatchReq = StmtAddBatchRequest{StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMTAddBatch, &addBatchReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &addBatchResp) - assert.NoError(t, err) - assert.Equal(t, 0, bindResp.Code, bindResp.Message) - - // exec - execReq = StmtExecRequest{StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMTExec, &execReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &execResp) - assert.NoError(t, err) - assert.Equal(t, 0, execResp.Code, execResp.Message) - - // query - queryReq = QueryRequest{Sql: "select * from test_ws_stmt_ws.ct2"} - resp, err = doWebSocket(ws, WSQuery, &queryReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // fetch - fetchReq = FetchRequest{ID: queryResp.ID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - - // fetch block - fetchBlockReq = FetchBlockRequest{ID: queryResp.ID} - fetchBlockResp, err = doWebSocket(ws, WSFetchBlock, &fetchBlockReq) - assert.NoError(t, err) - _, blockResult = parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) - assert.Equal(t, now.UnixNano(), blockResult[0][0].(time.Time).UnixNano()) - assert.Equal(t, true, blockResult[0][1]) - assert.Equal(t, int8(2), blockResult[0][2]) - assert.Equal(t, int16(3), blockResult[0][3]) - assert.Equal(t, int32(4), blockResult[0][4]) - assert.Equal(t, int64(5), blockResult[0][5]) - assert.Equal(t, uint8(6), blockResult[0][6]) - assert.Equal(t, uint16(7), blockResult[0][7]) - assert.Equal(t, uint32(8), blockResult[0][8]) - assert.Equal(t, uint64(9), blockResult[0][9]) - assert.Equal(t, float32(10), blockResult[0][10]) - assert.Equal(t, float64(11), blockResult[0][11]) - assert.Equal(t, "binary", blockResult[0][12]) - assert.Equal(t, "nchar", blockResult[0][13]) - assert.Equal(t, []byte("test_varbinary2"), blockResult[1][14]) - assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[0][15]) - - assert.Equal(t, now.Add(time.Second).UnixNano(), blockResult[1][0].(time.Time).UnixNano()) - assert.Equal(t, false, blockResult[1][1]) - assert.Equal(t, int8(22), blockResult[1][2]) - assert.Equal(t, int16(33), blockResult[1][3]) - assert.Equal(t, int32(44), blockResult[1][4]) - assert.Equal(t, int64(55), blockResult[1][5]) - assert.Equal(t, uint8(66), blockResult[1][6]) - assert.Equal(t, uint16(77), blockResult[1][7]) - assert.Equal(t, uint32(88), blockResult[1][8]) - assert.Equal(t, uint64(99), blockResult[1][9]) - assert.Equal(t, float32(1010), blockResult[1][10]) - assert.Equal(t, float64(1111), blockResult[1][11]) - assert.Equal(t, "binary2", blockResult[1][12]) - assert.Equal(t, "nchar2", blockResult[1][13]) - assert.Equal(t, []byte("test_varbinary2"), blockResult[1][14]) - assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[1][15]) - - assert.Equal(t, now.Add(time.Second*2).UnixNano(), blockResult[2][0].(time.Time).UnixNano()) - for i := 1; i < 16; i++ { - assert.Nil(t, blockResult[2][i]) - } -} - -func TestStmtQuery(t *testing.T) { - //for stable - prepareDataSql := []string{ - "create stable meters (ts timestamp,current float,voltage int,phase float) tags (group_id int, location varchar(24))", - "insert into d0 using meters tags (2, 'California.SanFrancisco') values ('2023-09-13 17:53:52.123', 10.2, 219, 0.32) ", - "insert into d1 using meters tags (1, 'California.SanFrancisco') values ('2023-09-13 17:54:43.321', 10.3, 218, 0.31) ", - } - StmtQuery(t, "test_ws_stmt_query_for_stable", prepareDataSql) - - // for table - prepareDataSql = []string{ - "create table meters (ts timestamp,current float,voltage int,phase float, group_id int, location varchar(24))", - "insert into meters values ('2023-09-13 17:53:52.123', 10.2, 219, 0.32, 2, 'California.SanFrancisco') ", - "insert into meters values ('2023-09-13 17:54:43.321', 10.3, 218, 0.31, 1, 'California.SanFrancisco') ", - } - StmtQuery(t, "test_ws_stmt_query_for_table", prepareDataSql) -} - -func StmtQuery(t *testing.T, db string, prepareDataSql []string) { - s := httptest.NewServer(router) - defer s.Close() - code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), "") - assert.Equal(t, 0, code, message) - code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), "") - assert.Equal(t, 0, code, message) - - defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") - - for _, sql := range prepareDataSql { - code, message = doRestful(sql, db) - assert.Equal(t, 0, code, message) - } - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // init - initReq := map[string]uint64{"req_id": 2} - resp, err = doWebSocket(ws, STMTInit, &initReq) - assert.NoError(t, err) - var initResp StmtInitResponse - err = json.Unmarshal(resp, &initResp) - assert.NoError(t, err) - assert.Equal(t, uint64(2), initResp.ReqID) - assert.Equal(t, 0, initResp.Code, initResp.Message) - - // prepare - prepareReq := StmtPrepareRequest{ - ReqID: 3, - StmtID: initResp.StmtID, - SQL: fmt.Sprintf("select * from %s.meters where group_id=? and location=?", db), - } - resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) - assert.NoError(t, err) - var prepareResp StmtPrepareResponse - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.False(t, prepareResp.IsInsert) - - // bind - var block bytes.Buffer - wstool.WriteUint64(&block, 5) - wstool.WriteUint64(&block, prepareResp.StmtID) - wstool.WriteUint64(&block, uint64(BindMessage)) - b, err := serializer.SerializeRawBlock( - []*param.Param{ - param.NewParam(1).AddInt(1), - param.NewParam(1).AddBinary([]byte("California.SanFrancisco")), - }, - param.NewColumnType(2).AddInt().AddBinary(24)) - assert.NoError(t, err) - block.Write(b) - - err = ws.WriteMessage(websocket.BinaryMessage, block.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - var bindResp StmtBindResponse - err = json.Unmarshal(resp, &bindResp) - assert.NoError(t, err) - assert.Equal(t, uint64(5), bindResp.ReqID) - assert.Equal(t, 0, bindResp.Code, bindResp.Message) - - // add batch - addBatchReq := StmtAddBatchRequest{StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMTAddBatch, &addBatchReq) - assert.NoError(t, err) - var addBatchResp StmtAddBatchResponse - err = json.Unmarshal(resp, &addBatchResp) - assert.NoError(t, err) - assert.Equal(t, 0, bindResp.Code, bindResp.Message) - - // exec - execReq := StmtExecRequest{ReqID: 6, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMTExec, &execReq) - assert.NoError(t, err) - var execResp StmtExecResponse - err = json.Unmarshal(resp, &execResp) - assert.NoError(t, err) - assert.Equal(t, uint64(6), execResp.ReqID) - assert.Equal(t, 0, execResp.Code, execResp.Message) - - // use result - useResultReq := StmtUseResultRequest{ReqID: 7, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMTUseResult, &useResultReq) - assert.NoError(t, err) - var useResultResp StmtUseResultResponse - err = json.Unmarshal(resp, &useResultResp) - assert.NoError(t, err) - assert.Equal(t, uint64(7), useResultResp.ReqID) - assert.Equal(t, 0, useResultResp.Code, useResultResp.Message) - - // fetch - fetchReq := FetchRequest{ReqID: 8, ID: useResultResp.ResultID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - var fetchResp FetchResponse - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, uint64(8), fetchResp.ReqID) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - assert.Equal(t, 1, fetchResp.Rows) - - // fetch block - fetchBlockReq := FetchBlockRequest{ReqID: 9, ID: useResultResp.ResultID} - fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) - assert.NoError(t, err) - _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], useResultResp.FieldsTypes, fetchResp.Rows, useResultResp.Precision) - assert.Equal(t, 1, len(blockResult)) - assert.Equal(t, float32(10.3), blockResult[0][1]) - assert.Equal(t, int32(218), blockResult[0][2]) - assert.Equal(t, float32(0.31), blockResult[0][3]) - - // free result - freeResultReq, _ := json.Marshal(FreeResultRequest{ReqID: 10, ID: useResultResp.ResultID}) - a, _ := json.Marshal(Request{Action: WSFreeResult, Args: freeResultReq}) - err = ws.WriteMessage(websocket.TextMessage, a) - assert.NoError(t, err) - - // close - closeReq := StmtCloseRequest{ReqID: 11, StmtID: prepareResp.StmtID} - err = doWebSocketWithoutResp(ws, STMTClose, &closeReq) - assert.NoError(t, err) -} - -func TestStmtNumParams(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - db := "test_ws_stmt_num_params" - code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), "") - assert.Equal(t, 0, code, message) - code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), "") - assert.Equal(t, 0, code, message) - code, message = doRestful(fmt.Sprintf("create stable if not exists %s.meters (ts timestamp,current float,voltage int,phase float) tags (groupid int,location varchar(24))", db), "") - assert.Equal(t, 0, code, message) - - defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // init - initReq := map[string]uint64{"req_id": 2} - resp, err = doWebSocket(ws, STMTInit, &initReq) - assert.NoError(t, err) - var initResp StmtInitResponse - err = json.Unmarshal(resp, &initResp) - assert.NoError(t, err) - assert.Equal(t, uint64(2), initResp.ReqID) - assert.Equal(t, 0, initResp.Code, initResp.Message) - - // prepare - prepareReq := StmtPrepareRequest{ - ReqID: 3, - StmtID: initResp.StmtID, - SQL: fmt.Sprintf("insert into d1 using %s.meters tags(?, ?) values (?, ?, ?, ?)", db), - } - resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) - assert.NoError(t, err) - var prepareResp StmtPrepareResponse - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - - // num params - numParamsReq := StmtNumParamsRequest{ReqID: 4, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMTNumParams, &numParamsReq) - assert.NoError(t, err) - var numParamsResp StmtNumParamsResponse - err = json.Unmarshal(resp, &numParamsResp) - assert.NoError(t, err) - assert.Equal(t, 0, numParamsResp.Code, numParamsResp.Message) - assert.Equal(t, uint64(4), numParamsResp.ReqID) - assert.Equal(t, 4, numParamsResp.NumParams) -} - -func TestStmtGetParams(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - db := "test_ws_stmt_get_params" - code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), "") - assert.Equal(t, 0, code, message) - code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), "") - assert.Equal(t, 0, code, message) - code, message = doRestful(fmt.Sprintf("create stable if not exists %s.meters (ts timestamp,current float,voltage int,phase float) tags (groupid int,location varchar(24))", db), "") - assert.Equal(t, 0, code, message) - - defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // init - initReq := map[string]uint64{"req_id": 2} - resp, err = doWebSocket(ws, STMTInit, &initReq) - assert.NoError(t, err) - var initResp StmtInitResponse - err = json.Unmarshal(resp, &initResp) - assert.NoError(t, err) - assert.Equal(t, uint64(2), initResp.ReqID) - assert.Equal(t, 0, initResp.Code, initResp.Message) - - // prepare - prepareReq := StmtPrepareRequest{ - ReqID: 3, - StmtID: initResp.StmtID, - SQL: fmt.Sprintf("insert into d1 using %s.meters tags(?, ?) values (?, ?, ?, ?)", db), - } - resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) - assert.NoError(t, err) - var prepareResp StmtPrepareResponse - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - - // get param - getParamsReq := StmtGetParamRequest{ReqID: 4, StmtID: prepareResp.StmtID, Index: 0} - resp, err = doWebSocket(ws, STMTGetParam, &getParamsReq) - assert.NoError(t, err) - var getParamsResp StmtGetParamResponse - err = json.Unmarshal(resp, &getParamsResp) - assert.NoError(t, err) - assert.Equal(t, 0, getParamsResp.Code, getParamsResp.Message) - assert.Equal(t, uint64(4), getParamsResp.ReqID) - assert.Equal(t, 0, getParamsResp.Index) - assert.Equal(t, 9, getParamsResp.DataType) - assert.Equal(t, 8, getParamsResp.Length) -} - -func TestGetCurrentDB(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - db := "test_current_db" - code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), "") - assert.Equal(t, 0, code, message) - code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), "") - assert.Equal(t, 0, code, message) - - defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // current db - currentDBReq := map[string]uint64{"req_id": 1} - resp, err = doWebSocket(ws, WSGetCurrentDB, ¤tDBReq) - assert.NoError(t, err) - var currentDBResp GetCurrentDBResponse - err = json.Unmarshal(resp, ¤tDBResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), currentDBResp.ReqID) - assert.Equal(t, 0, currentDBResp.Code, currentDBResp.Message) - assert.Equal(t, db, currentDBResp.DB) -} - -func TestGetServerInfo(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // server info - serverInfoReq := map[string]uint64{"req_id": 1} - resp, err = doWebSocket(ws, WSGetServerInfo, &serverInfoReq) - assert.NoError(t, err) - var serverInfoResp GetServerInfoResponse - err = json.Unmarshal(resp, &serverInfoResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), serverInfoResp.ReqID) - assert.Equal(t, 0, serverInfoResp.Code, serverInfoResp.Message) - t.Log(serverInfoResp.Info) -} - -func TestNumFields(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - db := "test_ws_num_fields" - code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), db) - assert.Equal(t, 0, code, message) - code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), db) - assert.Equal(t, 0, code, message) - code, message = doRestful(fmt.Sprintf("create stable if not exists %s.meters (ts timestamp,current float,voltage int,phase float) tags (groupid int,location varchar(24))", db), db) - assert.Equal(t, 0, code, message) - code, message = doRestful("INSERT INTO d1 USING meters TAGS (1, 'location1') VALUES (now, 10.2, 219, 0.31) "+ - "d2 USING meters TAGS (2, 'location2') VALUES (now, 10.3, 220, 0.32)", db) - assert.Equal(t, 0, code, message) - - defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // query - queryReq := QueryRequest{ReqID: 2, Sql: "select * from meters"} - resp, err = doWebSocket(ws, WSQuery, &queryReq) - assert.NoError(t, err) - var queryResp QueryResponse - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, uint64(2), queryResp.ReqID) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // num fields - numFieldsReq := NumFieldsRequest{ReqID: 3, ResultID: queryResp.ID} - resp, err = doWebSocket(ws, WSNumFields, &numFieldsReq) - assert.NoError(t, err) - var numFieldsResp NumFieldsResponse - err = json.Unmarshal(resp, &numFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), numFieldsResp.ReqID) - assert.Equal(t, 0, numFieldsResp.Code, numFieldsResp.Message) - assert.Equal(t, 6, numFieldsResp.NumFields) -} - -func TestWsStmt2(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - code, message := doRestful("drop database if exists test_ws_stmt2_ws", "") - assert.Equal(t, 0, code, message) - code, message = doRestful("create database if not exists test_ws_stmt2_ws precision 'ns'", "") - assert.Equal(t, 0, code, message) - - defer doRestful("drop database if exists test_ws_stmt2_ws", "") - - code, message = doRestful( - "create table if not exists stb (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", - "test_ws_stmt2_ws") - assert.Equal(t, 0, code, message) - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_stmt2_ws"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // init - initReq := Stmt2InitRequest{ - ReqID: 0x123, - SingleStbInsert: false, - SingleTableBindOnce: false, - } - resp, err = doWebSocket(ws, STMT2Init, &initReq) - assert.NoError(t, err) - var initResp Stmt2InitResponse - err = json.Unmarshal(resp, &initResp) - assert.NoError(t, err) - assert.Equal(t, uint64(0x123), initResp.ReqID) - assert.Equal(t, 0, initResp.Code, initResp.Message) - - // prepare - prepareReq := Stmt2PrepareRequest{ReqID: 3, StmtID: initResp.StmtID, SQL: "insert into ct1 using test_ws_stmt2_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"} - resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) - assert.NoError(t, err) - var prepareResp Stmt2PrepareResponse - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.True(t, prepareResp.IsInsert) - - // get tag fields - getTagFieldsReq := Stmt2GetFieldsRequest{ReqID: 5, StmtID: prepareResp.StmtID, FieldTypes: []int8{stmtCommon.TAOS_FIELD_TAG}} - resp, err = doWebSocket(ws, STMT2GetFields, &getTagFieldsReq) - assert.NoError(t, err) - var getTagFieldsResp Stmt2GetFieldsResponse - err = json.Unmarshal(resp, &getTagFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(5), getTagFieldsResp.ReqID) - assert.Equal(t, 0, getTagFieldsResp.Code, getTagFieldsResp.Message) - - // get col fields - getColFieldsReq := Stmt2GetFieldsRequest{ReqID: 6, StmtID: prepareResp.StmtID, FieldTypes: []int8{stmtCommon.TAOS_FIELD_COL}} - resp, err = doWebSocket(ws, STMT2GetFields, &getColFieldsReq) - assert.NoError(t, err) - var getColFieldsResp Stmt2GetFieldsResponse - err = json.Unmarshal(resp, &getColFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(6), getColFieldsResp.ReqID) - assert.Equal(t, 0, getColFieldsResp.Code, getColFieldsResp.Message) - - // bind - now := time.Now() - cols := [][]driver.Value{ - // ts - {now, now.Add(time.Second), now.Add(time.Second * 2)}, - // bool - {true, false, nil}, - // tinyint - {int8(2), int8(22), nil}, - // smallint - {int16(3), int16(33), nil}, - // int - {int32(4), int32(44), nil}, - // bigint - {int64(5), int64(55), nil}, - // tinyint unsigned - {uint8(6), uint8(66), nil}, - // smallint unsigned - {uint16(7), uint16(77), nil}, - // int unsigned - {uint32(8), uint32(88), nil}, - // bigint unsigned - {uint64(9), uint64(99), nil}, - // float - {float32(10), float32(1010), nil}, - // double - {float64(11), float64(1111), nil}, - // binary - {"binary", "binary2", nil}, - // nchar - {"nchar", "nchar2", nil}, - // varbinary - {[]byte{0xaa, 0xbb, 0xcc}, []byte{0xaa, 0xbb, 0xcc}, nil}, - // geometry - {[]byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, nil}, - } - tbName := "test_ws_stmt2_ws.ct1" - tag := []driver.Value{"{\"a\":\"b\"}"} - binds := &stmtCommon.TaosStmt2BindData{ - TableName: tbName, - Tags: tag, - Cols: cols, - } - bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, getColFieldsResp.ColFields, getTagFieldsResp.TagFields) - assert.NoError(t, err) - bindReq := make([]byte, len(bs)+30) - // req_id - binary.LittleEndian.PutUint64(bindReq, 0x12345) - // stmt_id - binary.LittleEndian.PutUint64(bindReq[8:], prepareResp.StmtID) - // action - binary.LittleEndian.PutUint64(bindReq[16:], Stmt2BindMessage) - // version - binary.LittleEndian.PutUint16(bindReq[24:], Stmt2BindProtocolVersion1) - // col_idx - idx := int32(-1) - binary.LittleEndian.PutUint32(bindReq[26:], uint32(idx)) - // data - copy(bindReq[30:], bs) - err = ws.WriteMessage(websocket.BinaryMessage, bindReq) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - var bindResp Stmt2BindResponse - err = json.Unmarshal(resp, &bindResp) - assert.NoError(t, err) - assert.Equal(t, 0, bindResp.Code, bindResp.Message) - - //exec - execReq := Stmt2ExecRequest{ReqID: 10, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMT2Exec, &execReq) - assert.NoError(t, err) - var execResp Stmt2ExecResponse - err = json.Unmarshal(resp, &execResp) - assert.NoError(t, err) - assert.Equal(t, uint64(10), execResp.ReqID) - assert.Equal(t, 0, execResp.Code, execResp.Message) - assert.Equal(t, 3, execResp.Affected) - - // close - closeReq := Stmt2CloseRequest{ReqID: 11, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMT2Close, &closeReq) - assert.NoError(t, err) - var closeResp Stmt2CloseResponse - err = json.Unmarshal(resp, &closeResp) - assert.NoError(t, err) - assert.Equal(t, uint64(11), closeResp.ReqID) - assert.Equal(t, 0, closeResp.Code, closeResp.Message) - - // query - queryReq := QueryRequest{Sql: "select * from test_ws_stmt2_ws.stb"} - resp, err = doWebSocket(ws, WSQuery, &queryReq) - assert.NoError(t, err) - var queryResp QueryResponse - err = json.Unmarshal(resp, &queryResp) - assert.NoError(t, err) - assert.Equal(t, 0, queryResp.Code, queryResp.Message) - - // fetch - fetchReq := FetchRequest{ID: queryResp.ID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - var fetchResp FetchResponse - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - - // fetch block - fetchBlockReq := FetchBlockRequest{ID: queryResp.ID} - fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) - assert.NoError(t, err) - _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) - assert.Equal(t, 3, len(blockResult)) - assert.Equal(t, now.UnixNano(), blockResult[0][0].(time.Time).UnixNano()) - - assert.Equal(t, true, blockResult[0][1]) - assert.Equal(t, int8(2), blockResult[0][2]) - assert.Equal(t, int16(3), blockResult[0][3]) - assert.Equal(t, int32(4), blockResult[0][4]) - assert.Equal(t, int64(5), blockResult[0][5]) - assert.Equal(t, uint8(6), blockResult[0][6]) - assert.Equal(t, uint16(7), blockResult[0][7]) - assert.Equal(t, uint32(8), blockResult[0][8]) - assert.Equal(t, uint64(9), blockResult[0][9]) - assert.Equal(t, float32(10), blockResult[0][10]) - assert.Equal(t, float64(11), blockResult[0][11]) - assert.Equal(t, "binary", blockResult[0][12]) - assert.Equal(t, "nchar", blockResult[0][13]) - assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) - assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[0][15]) - - assert.Equal(t, now.Add(time.Second).UnixNano(), blockResult[1][0].(time.Time).UnixNano()) - assert.Equal(t, false, blockResult[1][1]) - assert.Equal(t, int8(22), blockResult[1][2]) - assert.Equal(t, int16(33), blockResult[1][3]) - assert.Equal(t, int32(44), blockResult[1][4]) - assert.Equal(t, int64(55), blockResult[1][5]) - assert.Equal(t, uint8(66), blockResult[1][6]) - assert.Equal(t, uint16(77), blockResult[1][7]) - assert.Equal(t, uint32(88), blockResult[1][8]) - assert.Equal(t, uint64(99), blockResult[1][9]) - assert.Equal(t, float32(1010), blockResult[1][10]) - assert.Equal(t, float64(1111), blockResult[1][11]) - assert.Equal(t, "binary2", blockResult[1][12]) - assert.Equal(t, "nchar2", blockResult[1][13]) - assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) - assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[1][15]) - - assert.Equal(t, now.Add(time.Second*2).UnixNano(), blockResult[2][0].(time.Time).UnixNano()) - for i := 1; i < 16; i++ { - assert.Nil(t, blockResult[2][i]) - } - -} - -func TestStmt2Prepare(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - code, message := doRestful("drop database if exists test_ws_stmt2_prepare_ws", "") - assert.Equal(t, 0, code, message) - code, message = doRestful("create database if not exists test_ws_stmt2_prepare_ws precision 'ns'", "") - assert.Equal(t, 0, code, message) - - defer doRestful("drop database if exists test_ws_stmt2_prepare_ws", "") - - code, message = doRestful( - "create table if not exists stb (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", - "test_ws_stmt2_prepare_ws") - assert.Equal(t, 0, code, message) - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_stmt2_prepare_ws"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // init - initReq := Stmt2InitRequest{ - ReqID: 0x123, - SingleStbInsert: false, - SingleTableBindOnce: false, - } - resp, err = doWebSocket(ws, STMT2Init, &initReq) - assert.NoError(t, err) - var initResp Stmt2InitResponse - err = json.Unmarshal(resp, &initResp) - assert.NoError(t, err) - assert.Equal(t, uint64(0x123), initResp.ReqID) - assert.Equal(t, 0, initResp.Code, initResp.Message) - - // prepare - prepareReq := Stmt2PrepareRequest{ - ReqID: 3, - StmtID: initResp.StmtID, - SQL: "insert into ctb using test_ws_stmt2_prepare_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - GetFields: true, - } - resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) - assert.NoError(t, err) - var prepareResp Stmt2PrepareResponse - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.Equal(t, true, prepareResp.IsInsert) - names := [17]string{ - "info", - "ts", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - } - fieldTypes := [17]int8{ - common.TSDB_DATA_TYPE_JSON, - common.TSDB_DATA_TYPE_TIMESTAMP, - common.TSDB_DATA_TYPE_BOOL, - common.TSDB_DATA_TYPE_TINYINT, - common.TSDB_DATA_TYPE_SMALLINT, - common.TSDB_DATA_TYPE_INT, - common.TSDB_DATA_TYPE_BIGINT, - common.TSDB_DATA_TYPE_UTINYINT, - common.TSDB_DATA_TYPE_USMALLINT, - common.TSDB_DATA_TYPE_UINT, - common.TSDB_DATA_TYPE_UBIGINT, - common.TSDB_DATA_TYPE_FLOAT, - common.TSDB_DATA_TYPE_DOUBLE, - common.TSDB_DATA_TYPE_BINARY, - common.TSDB_DATA_TYPE_NCHAR, - common.TSDB_DATA_TYPE_VARBINARY, - common.TSDB_DATA_TYPE_GEOMETRY, - } - assert.True(t, prepareResp.IsInsert) - assert.Equal(t, 17, len(prepareResp.Fields)) - for i := 0; i < 17; i++ { - assert.Equal(t, names[i], prepareResp.Fields[i].Name) - assert.Equal(t, fieldTypes[i], prepareResp.Fields[i].FieldType) - } - // prepare query - prepareReq = Stmt2PrepareRequest{ - ReqID: 4, - StmtID: initResp.StmtID, - SQL: "select * from test_ws_stmt2_prepare_ws.stb where ts = ? and v1 = ?", - GetFields: true, - } - resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(4), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.Equal(t, false, prepareResp.IsInsert) - assert.Nil(t, prepareResp.Fields) - assert.Equal(t, 2, prepareResp.FieldsCount) -} - -func TestStmt2GetFields(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - code, message := doRestful("drop database if exists test_ws_stmt2_getfields_ws", "") - assert.Equal(t, 0, code, message) - code, message = doRestful("create database if not exists test_ws_stmt2_getfields_ws precision 'ns'", "") - assert.Equal(t, 0, code, message) - - defer doRestful("drop database if exists test_ws_stmt2_getfields_ws", "") - - code, message = doRestful( - "create table if not exists stb (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", - "test_ws_stmt2_getfields_ws") - assert.Equal(t, 0, code, message) - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_stmt2_getfields_ws"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // init - initReq := Stmt2InitRequest{ - ReqID: 0x123, - SingleStbInsert: false, - SingleTableBindOnce: false, - } - resp, err = doWebSocket(ws, STMT2Init, &initReq) - assert.NoError(t, err) - var initResp Stmt2InitResponse - err = json.Unmarshal(resp, &initResp) - assert.NoError(t, err) - assert.Equal(t, uint64(0x123), initResp.ReqID) - assert.Equal(t, 0, initResp.Code, initResp.Message) - - // prepare - prepareReq := Stmt2PrepareRequest{ - ReqID: 3, - StmtID: initResp.StmtID, - SQL: "insert into ctb using test_ws_stmt2_getfields_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - GetFields: false, - } - resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) - assert.NoError(t, err) - var prepareResp Stmt2PrepareResponse - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.Equal(t, true, prepareResp.IsInsert) - - // get fields - getFieldsReq := Stmt2GetFieldsRequest{ - ReqID: 4, - StmtID: prepareResp.StmtID, - FieldTypes: []int8{ - stmtCommon.TAOS_FIELD_TAG, - stmtCommon.TAOS_FIELD_COL, - }, - } - resp, err = doWebSocket(ws, STMT2GetFields, &getFieldsReq) - assert.NoError(t, err) - var getFieldsResp Stmt2GetFieldsResponse - err = json.Unmarshal(resp, &getFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(4), getFieldsResp.ReqID) - assert.Equal(t, 0, getFieldsResp.Code, getFieldsResp.Message) - names := [16]string{ - "ts", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - } - fieldTypes := [16]int8{ - common.TSDB_DATA_TYPE_TIMESTAMP, - common.TSDB_DATA_TYPE_BOOL, - common.TSDB_DATA_TYPE_TINYINT, - common.TSDB_DATA_TYPE_SMALLINT, - common.TSDB_DATA_TYPE_INT, - common.TSDB_DATA_TYPE_BIGINT, - common.TSDB_DATA_TYPE_UTINYINT, - common.TSDB_DATA_TYPE_USMALLINT, - common.TSDB_DATA_TYPE_UINT, - common.TSDB_DATA_TYPE_UBIGINT, - common.TSDB_DATA_TYPE_FLOAT, - common.TSDB_DATA_TYPE_DOUBLE, - common.TSDB_DATA_TYPE_BINARY, - common.TSDB_DATA_TYPE_NCHAR, - common.TSDB_DATA_TYPE_VARBINARY, - common.TSDB_DATA_TYPE_GEOMETRY, - } - assert.Equal(t, 16, len(getFieldsResp.ColFields)) - assert.Equal(t, 1, len(getFieldsResp.TagFields)) - for i := 0; i < 16; i++ { - assert.Equal(t, names[i], getFieldsResp.ColFields[i].Name) - assert.Equal(t, fieldTypes[i], getFieldsResp.ColFields[i].FieldType) - } - assert.Equal(t, "info", getFieldsResp.TagFields[0].Name) - assert.Equal(t, int8(common.TSDB_DATA_TYPE_JSON), getFieldsResp.TagFields[0].FieldType) - - // prepare get tablename - prepareReq = Stmt2PrepareRequest{ - ReqID: 5, - StmtID: initResp.StmtID, - SQL: "insert into ? using test_ws_stmt2_getfields_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - GetFields: false, - } - resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) - assert.NoError(t, err) - - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(5), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.Equal(t, true, prepareResp.IsInsert) - // get fields - getFieldsReq = Stmt2GetFieldsRequest{ - ReqID: 6, - StmtID: prepareResp.StmtID, - FieldTypes: []int8{ - stmtCommon.TAOS_FIELD_TBNAME, - }, - } - resp, err = doWebSocket(ws, STMT2GetFields, &getFieldsReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &getFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(6), getFieldsResp.ReqID) - assert.Equal(t, 0, getFieldsResp.Code, getFieldsResp.Message) - - assert.Nil(t, getFieldsResp.ColFields) - assert.Nil(t, getFieldsResp.TagFields) - assert.Equal(t, int32(1), getFieldsResp.TableCount) - - // prepare query - prepareReq = Stmt2PrepareRequest{ - ReqID: 7, - StmtID: initResp.StmtID, - SQL: "select * from test_ws_stmt2_getfields_ws.stb where ts = ? and v1 = ?", - GetFields: false, - } - resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(7), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.Equal(t, false, prepareResp.IsInsert) - // get fields - getFieldsReq = Stmt2GetFieldsRequest{ - ReqID: 8, - StmtID: prepareResp.StmtID, - FieldTypes: []int8{ - stmtCommon.TAOS_FIELD_QUERY, - }, - } - resp, err = doWebSocket(ws, STMT2GetFields, &getFieldsReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &getFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(8), getFieldsResp.ReqID) - assert.Equal(t, 0, getFieldsResp.Code, getFieldsResp.Message) - - assert.Nil(t, getFieldsResp.ColFields) - assert.Nil(t, getFieldsResp.TagFields) - assert.Equal(t, int32(2), getFieldsResp.QueryCount) - -} - -func TestStmt2Query(t *testing.T) { - //for stable - prepareDataSql := []string{ - "create stable meters (ts timestamp,current float,voltage int,phase float) tags (group_id int, location varchar(24))", - "insert into d0 using meters tags (2, 'California.SanFrancisco') values ('2023-09-13 17:53:52.123', 10.2, 219, 0.32) ", - "insert into d1 using meters tags (1, 'California.SanFrancisco') values ('2023-09-13 17:54:43.321', 10.3, 218, 0.31) ", - } - Stmt2Query(t, "test_ws_stmt2_query_for_stable", prepareDataSql) - - // for table - prepareDataSql = []string{ - "create table meters (ts timestamp,current float,voltage int,phase float, group_id int, location varchar(24))", - "insert into meters values ('2023-09-13 17:53:52.123', 10.2, 219, 0.32, 2, 'California.SanFrancisco') ", - "insert into meters values ('2023-09-13 17:54:43.321', 10.3, 218, 0.31, 1, 'California.SanFrancisco') ", - } - Stmt2Query(t, "test_ws_stmt2_query_for_table", prepareDataSql) -} - -func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { - s := httptest.NewServer(router) - defer s.Close() - code, message := doRestful(fmt.Sprintf("drop database if exists %s", db), "") - assert.Equal(t, 0, code, message) - code, message = doRestful(fmt.Sprintf("create database if not exists %s", db), "") - assert.Equal(t, 0, code, message) - - defer doRestful(fmt.Sprintf("drop database if exists %s", db), "") - - for _, sql := range prepareDataSql { - code, message = doRestful(sql, db) - assert.Equal(t, 0, code, message) - } - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: db} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // init - initReq := map[string]uint64{"req_id": 2} - resp, err = doWebSocket(ws, STMT2Init, &initReq) - assert.NoError(t, err) - var initResp Stmt2InitResponse - err = json.Unmarshal(resp, &initResp) - assert.NoError(t, err) - assert.Equal(t, uint64(2), initResp.ReqID) - assert.Equal(t, 0, initResp.Code, initResp.Message) - - // prepare - prepareReq := Stmt2PrepareRequest{ - ReqID: 3, - StmtID: initResp.StmtID, - SQL: fmt.Sprintf("select * from %s.meters where group_id=? and location=?", db), - GetFields: false, - } - resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) - assert.NoError(t, err) - var prepareResp Stmt2PrepareResponse - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.False(t, prepareResp.IsInsert) - - // bind - var block bytes.Buffer - wstool.WriteUint64(&block, 5) - wstool.WriteUint64(&block, prepareResp.StmtID) - wstool.WriteUint64(&block, uint64(Stmt2BindMessage)) - wstool.WriteUint16(&block, Stmt2BindProtocolVersion1) - idx := int32(-1) - wstool.WriteUint32(&block, uint32(idx)) - params := []*stmtCommon.TaosStmt2BindData{ - { - Cols: [][]driver.Value{ - {int32(1)}, - {"California.SanFrancisco"}, - }, - }, - } - b, err := stmtCommon.MarshalStmt2Binary(params, false, nil, nil) - assert.NoError(t, err) - block.Write(b) - - err = ws.WriteMessage(websocket.BinaryMessage, block.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - var bindResp Stmt2BindResponse - err = json.Unmarshal(resp, &bindResp) - assert.NoError(t, err) - assert.Equal(t, uint64(5), bindResp.ReqID) - assert.Equal(t, 0, bindResp.Code, bindResp.Message) - - // exec - execReq := StmtExecRequest{ReqID: 6, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMT2Exec, &execReq) - assert.NoError(t, err) - var execResp StmtExecResponse - err = json.Unmarshal(resp, &execResp) - assert.NoError(t, err) - assert.Equal(t, uint64(6), execResp.ReqID) - assert.Equal(t, 0, execResp.Code, execResp.Message) - - // use result - useResultReq := Stmt2UseResultRequest{ReqID: 7, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMT2Result, &useResultReq) - assert.NoError(t, err) - var useResultResp Stmt2UseResultResponse - err = json.Unmarshal(resp, &useResultResp) - assert.NoError(t, err) - assert.Equal(t, uint64(7), useResultResp.ReqID) - assert.Equal(t, 0, useResultResp.Code, useResultResp.Message) - - // fetch - fetchReq := FetchRequest{ReqID: 8, ID: useResultResp.ResultID} - resp, err = doWebSocket(ws, WSFetch, &fetchReq) - assert.NoError(t, err) - var fetchResp FetchResponse - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, uint64(8), fetchResp.ReqID) - assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) - assert.Equal(t, 1, fetchResp.Rows) - - // fetch block - fetchBlockReq := FetchBlockRequest{ReqID: 9, ID: useResultResp.ResultID} - fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) - assert.NoError(t, err) - _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], useResultResp.FieldsTypes, fetchResp.Rows, useResultResp.Precision) - assert.Equal(t, 1, len(blockResult)) - assert.Equal(t, float32(10.3), blockResult[0][1]) - assert.Equal(t, int32(218), blockResult[0][2]) - assert.Equal(t, float32(0.31), blockResult[0][3]) - - // free result - freeResultReq, _ := json.Marshal(FreeResultRequest{ReqID: 10, ID: useResultResp.ResultID}) - a, _ := json.Marshal(Request{Action: WSFreeResult, Args: freeResultReq}) - err = ws.WriteMessage(websocket.TextMessage, a) - assert.NoError(t, err) - - // close - closeReq := StmtCloseRequest{ReqID: 11, StmtID: prepareResp.StmtID} - resp, err = doWebSocket(ws, STMT2Close, &closeReq) - assert.NoError(t, err) - var closeResp Stmt2CloseResponse - err = json.Unmarshal(resp, &fetchResp) - assert.NoError(t, err) - assert.Equal(t, 0, closeResp.Code, closeResp.Message) -} - -func TestWSConnect(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // wrong password - connReq := ConnRequest{ReqID: 1, User: "root", Password: "wrong"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, "Authentication failure", connResp.Message) - assert.Equal(t, 0x357, connResp.Code, connResp.Message) - - // connect - connReq = ConnRequest{ReqID: 1, User: "root", Password: "taosdata"} - resp, err = doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - //duplicate connections - connReq = ConnRequest{ReqID: 1, User: "root", Password: "taosdata"} - resp, err = doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0xffff, connResp.Code) - assert.Equal(t, "duplicate connections", connResp.Message) - -} - -type TestConnRequest struct { - ReqID uint64 `json:"req_id"` - User string `json:"user"` - Password string `json:"password"` - DB string `json:"db"` - Mode int `json:"mode"` -} - -func TestMode(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - wrongMode := 999 - connReq := TestConnRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: wrongMode} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0xffff, connResp.Code) - assert.Equal(t, fmt.Sprintf("unexpected mode:%d", wrongMode), connResp.Message) - - //bi - biMode := 0 - connReq = TestConnRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: biMode} - resp, err = doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - -} - -func TestWSTMQWriteRaw(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - data := []byte{ - 0x64, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0x01, 0x00, 0x00, 0x04, 0x73, 0x74, 0x62, - 0x00, 0xd5, 0xf0, 0xed, 0x8a, 0xe0, 0x23, 0xf3, 0x45, 0x00, 0x1c, 0x02, 0x09, 0x01, 0x10, 0x02, - 0x03, 0x74, 0x73, 0x00, 0x01, 0x01, 0x02, 0x04, 0x03, 0x63, 0x31, 0x00, 0x02, 0x01, 0x02, 0x06, - 0x03, 0x63, 0x32, 0x00, 0x03, 0x01, 0x04, 0x08, 0x03, 0x63, 0x33, 0x00, 0x04, 0x01, 0x08, 0x0a, - 0x03, 0x63, 0x34, 0x00, 0x05, 0x01, 0x10, 0x0c, 0x03, 0x63, 0x35, 0x00, 0x0b, 0x01, 0x02, 0x0e, - 0x03, 0x63, 0x36, 0x00, 0x0c, 0x01, 0x04, 0x10, 0x03, 0x63, 0x37, 0x00, 0x0d, 0x01, 0x08, 0x12, - 0x03, 0x63, 0x38, 0x00, 0x0e, 0x01, 0x10, 0x14, 0x03, 0x63, 0x39, 0x00, 0x06, 0x01, 0x08, 0x16, - 0x04, 0x63, 0x31, 0x30, 0x00, 0x07, 0x01, 0x10, 0x18, 0x04, 0x63, 0x31, 0x31, 0x00, 0x08, 0x01, - 0x2c, 0x1a, 0x04, 0x63, 0x31, 0x32, 0x00, 0x0a, 0x01, 0xa4, 0x01, 0x1c, 0x04, 0x63, 0x31, 0x33, - 0x00, 0x1c, 0x02, 0x09, 0x02, 0x10, 0x1e, 0x04, 0x74, 0x74, 0x73, 0x00, 0x01, 0x00, 0x02, 0x20, - 0x04, 0x74, 0x63, 0x31, 0x00, 0x02, 0x00, 0x02, 0x22, 0x04, 0x74, 0x63, 0x32, 0x00, 0x03, 0x00, - 0x04, 0x24, 0x04, 0x74, 0x63, 0x33, 0x00, 0x04, 0x00, 0x08, 0x26, 0x04, 0x74, 0x63, 0x34, 0x00, - 0x05, 0x00, 0x10, 0x28, 0x04, 0x74, 0x63, 0x35, 0x00, 0x0b, 0x00, 0x02, 0x2a, 0x04, 0x74, 0x63, - 0x36, 0x00, 0x0c, 0x00, 0x04, 0x2c, 0x04, 0x74, 0x63, 0x37, 0x00, 0x0d, 0x00, 0x08, 0x2e, 0x04, - 0x74, 0x63, 0x38, 0x00, 0x0e, 0x00, 0x10, 0x30, 0x04, 0x74, 0x63, 0x39, 0x00, 0x06, 0x00, 0x08, - 0x32, 0x05, 0x74, 0x63, 0x31, 0x30, 0x00, 0x07, 0x00, 0x10, 0x34, 0x05, 0x74, 0x63, 0x31, 0x31, - 0x00, 0x08, 0x00, 0x2c, 0x36, 0x05, 0x74, 0x63, 0x31, 0x32, 0x00, 0x0a, 0x00, 0xa4, 0x01, 0x38, - 0x05, 0x74, 0x63, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x1c, 0x02, 0x02, 0x02, - 0x01, 0x00, 0x02, 0x04, 0x02, 0x01, 0x00, 0x03, 0x06, 0x02, 0x01, 0x00, 0x01, 0x08, 0x02, 0x01, - 0x00, 0x01, 0x0a, 0x02, 0x01, 0x00, 0x01, 0x0c, 0x02, 0x01, 0x00, 0x01, 0x0e, 0x02, 0x01, 0x00, - 0x01, 0x10, 0x02, 0x01, 0x00, 0x01, 0x12, 0x02, 0x01, 0x00, 0x01, 0x14, 0x02, 0x01, 0x00, 0x01, - 0x16, 0x02, 0x01, 0x00, 0x04, 0x18, 0x02, 0x01, 0x00, 0x04, 0x1a, 0x02, 0x01, 0x00, 0xff, 0x1c, - 0x02, 0x01, 0x00, 0xff, - } - length := uint32(356) - metaType := uint16(531) - code, message := doRestful("create database if not exists test_ws_tmq_write_raw", "") - assert.Equal(t, 0, code, message) - defer func() { - code, message := doRestful("drop database if exists test_ws_tmq_write_raw", "") - assert.Equal(t, 0, code, message) - }() - // connect - connReq := ConnRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_tmq_write_raw"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - buffer := bytes.Buffer{} - wstool.WriteUint64(&buffer, 2) // req id - wstool.WriteUint64(&buffer, 0) // message id - wstool.WriteUint64(&buffer, uint64(TMQRawMessage)) - wstool.WriteUint32(&buffer, length) - wstool.WriteUint16(&buffer, metaType) - buffer.Write(data) - err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) - assert.NoError(t, err) - _, resp, err = ws.ReadMessage() - assert.NoError(t, err) - var tmqResp BaseResponse - err = json.Unmarshal(resp, &tmqResp) - assert.NoError(t, err) - assert.Equal(t, uint64(2), tmqResp.ReqID) - assert.Equal(t, 0, tmqResp.Code, tmqResp.Message) - - d := restQuery("describe stb", "test_ws_tmq_write_raw") - expect := [][]driver.Value{ - {"ts", "TIMESTAMP", float64(8), ""}, - {"c1", "BOOL", float64(1), ""}, - {"c2", "TINYINT", float64(1), ""}, - {"c3", "SMALLINT", float64(2), ""}, - {"c4", "INT", float64(4), ""}, - {"c5", "BIGINT", float64(8), ""}, - {"c6", "TINYINT UNSIGNED", float64(1), ""}, - {"c7", "SMALLINT UNSIGNED", float64(2), ""}, - {"c8", "INT UNSIGNED", float64(4), ""}, - {"c9", "BIGINT UNSIGNED", float64(8), ""}, - {"c10", "FLOAT", float64(4), ""}, - {"c11", "DOUBLE", float64(8), ""}, - {"c12", "VARCHAR", float64(20), ""}, - {"c13", "NCHAR", float64(20), ""}, - {"tts", "TIMESTAMP", float64(8), "TAG"}, - {"tc1", "BOOL", float64(1), "TAG"}, - {"tc2", "TINYINT", float64(1), "TAG"}, - {"tc3", "SMALLINT", float64(2), "TAG"}, - {"tc4", "INT", float64(4), "TAG"}, - {"tc5", "BIGINT", float64(8), "TAG"}, - {"tc6", "TINYINT UNSIGNED", float64(1), "TAG"}, - {"tc7", "SMALLINT UNSIGNED", float64(2), "TAG"}, - {"tc8", "INT UNSIGNED", float64(4), "TAG"}, - {"tc9", "BIGINT UNSIGNED", float64(8), "TAG"}, - {"tc10", "FLOAT", float64(4), "TAG"}, - {"tc11", "DOUBLE", float64(8), "TAG"}, - {"tc12", "VARCHAR", float64(20), "TAG"}, - {"tc13", "NCHAR", float64(20), "TAG"}, - } - for rowIndex, values := range d.Data { - for i := 0; i < 4; i++ { - assert.Equal(t, expect[rowIndex][i], values[i]) - } - } -} - -func TestDropUser(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - defer doRestful("drop user test_ws_drop_user", "") - code, message := doRestful("create user test_ws_drop_user pass 'pass'", "") - assert.Equal(t, 0, code, message) - // connect - connReq := ConnRequest{ReqID: 1, User: "test_ws_drop_user", Password: "pass"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp BaseResponse - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - // drop user - code, message = doRestful("drop user test_ws_drop_user", "") - assert.Equal(t, 0, code, message) - time.Sleep(time.Second * 3) - resp, err = doWebSocket(ws, wstool.ClientVersion, nil) - assert.Error(t, err, resp) + assert.Equal(t, wstool.ClientVersion, versionResp.Action) } diff --git a/controller/ws/wstool/error.go b/controller/ws/wstool/error.go index 009c085c..5551d126 100644 --- a/controller/ws/wstool/error.go +++ b/controller/ws/wstool/error.go @@ -3,9 +3,9 @@ package wstool import ( "context" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" tErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/taosadapter/v3/tools/melody" ) type WSErrorResp struct { diff --git a/controller/ws/wstool/error_test.go b/controller/ws/wstool/error_test.go index 75d5a7d3..9cd75648 100644 --- a/controller/ws/wstool/error_test.go +++ b/controller/ws/wstool/error_test.go @@ -11,10 +11,10 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" tErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/taosadapter/v3/tools/melody" ) func TestWSError(t *testing.T) { @@ -29,9 +29,6 @@ func TestWSError(t *testing.T) { commonErr := errors.New("test common error") logger := logrus.New().WithField("test", "TestWSError") m.HandleMessage(func(session *melody.Session, data []byte) { - if m.IsClosed() { - return - } switch data[0] { case '1': WSError(ctx, session, logger, taosErr, "test action", reqID) diff --git a/controller/ws/wstool/log.go b/controller/ws/wstool/log.go index a6f7fb5e..d2d35025 100644 --- a/controller/ws/wstool/log.go +++ b/controller/ws/wstool/log.go @@ -2,12 +2,12 @@ package wstool import ( "context" - "errors" + "os" "time" "github.com/gorilla/websocket" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" + "github.com/taosdata/taosadapter/v3/tools/melody" ) func GetDuration(ctx context.Context) int64 { @@ -21,7 +21,7 @@ func GetLogger(session *melody.Session) *logrus.Entry { func LogWSError(session *melody.Session, err error) { logger := session.MustGet("logger").(*logrus.Entry) var wsCloseErr *websocket.CloseError - is := errors.As(err, &wsCloseErr) + wsCloseErr, is := err.(*websocket.CloseError) if is { if wsCloseErr.Code == websocket.CloseNormalClosure { logger.Debug("ws close normal") @@ -29,6 +29,10 @@ func LogWSError(session *melody.Session, err error) { logger.Debugf("ws close in error, err:%s", wsCloseErr) } } else { - logger.Errorf("ws error, err:%s", err) + if os.IsTimeout(err) { + logger.Debugf("ws close due to timeout, err:%s", err) + } else { + logger.Debugf("ws error, err:%s", err) + } } } diff --git a/controller/ws/wstool/log_test.go b/controller/ws/wstool/log_test.go index 4b442724..6b0ccdb3 100644 --- a/controller/ws/wstool/log_test.go +++ b/controller/ws/wstool/log_test.go @@ -7,9 +7,9 @@ import ( "time" "github.com/gorilla/websocket" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/tools/melody" ) func TestGetDuration(t *testing.T) { diff --git a/controller/ws/wstool/resp.go b/controller/ws/wstool/resp.go index 0b544a85..9f4772cd 100644 --- a/controller/ws/wstool/resp.go +++ b/controller/ws/wstool/resp.go @@ -4,8 +4,8 @@ import ( "database/sql/driver" "encoding/json" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" + "github.com/taosdata/taosadapter/v3/tools/melody" "github.com/taosdata/taosadapter/v3/version" ) @@ -25,11 +25,13 @@ func WSWriteJson(session *melody.Session, logger *logrus.Entry, data interface{} } logger.Tracef("write json:%s", b) _ = session.Write(b) + logger.Trace("write json done") } func WSWriteBinary(session *melody.Session, data []byte, logger *logrus.Entry) { logger.Tracef("write binary:%+v", data) _ = session.WriteBinary(data) + logger.Trace("write binary done") } type WSVersionResp struct { @@ -41,6 +43,12 @@ type WSVersionResp struct { var VersionResp []byte +func WSWriteVersion(session *melody.Session, logger *logrus.Entry) { + logger.Tracef("write version") + _ = session.WriteBinary(VersionResp) + logger.Trace("write version done") +} + type WSAction struct { Action string `json:"action"` Args json.RawMessage `json:"args"` diff --git a/controller/ws/wstool/resp_test.go b/controller/ws/wstool/resp_test.go index e47d737d..1680b1e1 100644 --- a/controller/ws/wstool/resp_test.go +++ b/controller/ws/wstool/resp_test.go @@ -8,9 +8,9 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - "github.com/huskar-t/melody" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/tools/melody" ) func TestWSWriteJson(t *testing.T) { @@ -23,9 +23,6 @@ func TestWSWriteJson(t *testing.T) { Version: "1.0.0", } m.HandleMessage(func(session *melody.Session, _ []byte) { - if m.IsClosed() { - return - } logger := logrus.New().WithField("test", "TestWSWriteJson") session.Set("logger", logger) WSWriteJson(session, logger, data) diff --git a/go.mod b/go.mod index 2859e2f1..1cecbbc9 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,6 @@ module github.com/taosdata/taosadapter/v3 go 1.17 -replace github.com/huskar-t/melody => github.com/taosdata/melody v0.0.0-20240407104517-11dcf4a47591 - require ( collectd.org v0.5.0 github.com/gin-contrib/cors v1.4.0 @@ -13,7 +11,6 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/golang/snappy v0.0.4 github.com/gorilla/websocket v1.5.0 - github.com/huskar-t/melody v0.0.0-20240407104517-11dcf4a47591 github.com/influxdata/telegraf v1.23.4 github.com/json-iterator/go v1.1.12 github.com/kardianos/service v1.2.2 diff --git a/go.sum b/go.sum index 3a563b50..ba7dcb3b 100644 --- a/go.sum +++ b/go.sum @@ -2613,8 +2613,6 @@ github.com/taosdata/driver-go/v3 v3.5.1-0.20241101015534-8fb37f82db51 h1:diWG8X6 github.com/taosdata/driver-go/v3 v3.5.1-0.20241101015534-8fb37f82db51/go.mod h1:H2vo/At+rOPY1aMzUV9P49SVX7NlXb3LAbKw+MCLrmU= github.com/taosdata/file-rotatelogs/v2 v2.5.2 h1:6ryjwDdKqQtWrkVq9OKj4gvMING/f+fDluMAAe2DIXQ= github.com/taosdata/file-rotatelogs/v2 v2.5.2/go.mod h1:Qm99Lh0iMZouGgyy++JgTqKvP5FQw1ruR5jkWF7e1n0= -github.com/taosdata/melody v0.0.0-20240407104517-11dcf4a47591 h1:JT7pgLJpQvmSGPAFVWJZG/bPyuAio0uY3cb7AKyfhGY= -github.com/taosdata/melody v0.0.0-20240407104517-11dcf4a47591/go.mod h1:pfxtyQ9i9meRbS4BJBZ9YkE7upPetr0KKqcOCfi4nqY= github.com/tbrandon/mbserver v0.0.0-20170611213546-993e1772cc62/go.mod h1:qUzPVlSj2UgxJkVbH0ZwuuiR46U8RBMDT5KLY78Ifpw= github.com/tchap/go-patricia v2.2.6+incompatible/go.mod h1:bmLyhP68RS6kStMGxByiQ23RP/odRBOTVjwp2cDyi6I= github.com/tdakkota/asciicheck v0.0.0-20200416200610-e657995f937b/go.mod h1:yHp0ai0Z9gUljN3o0xMhYJnH/IcvkdTBOX2fmJ93JEM= diff --git a/tools/bytesutil/bytesutil.go b/tools/bytesutil/bytesutil.go index aeaf0d43..d186114d 100644 --- a/tools/bytesutil/bytesutil.go +++ b/tools/bytesutil/bytesutil.go @@ -2,7 +2,6 @@ package bytesutil import ( "math/bits" - "reflect" "unsafe" ) @@ -68,16 +67,29 @@ func ToUnsafeString(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } +// stringHeader instead of reflect.StringHeader +type stringHeader struct { + data unsafe.Pointer + len int +} + +// sliceHeader instead of reflect.SliceHeader +type sliceHeader struct { + data unsafe.Pointer + len int + cap int +} + // ToUnsafeBytes converts s to a byte slice without memory allocations. // // The returned byte slice is valid only until s is reachable and unmodified. func ToUnsafeBytes(s string) (b []byte) { - //nolint:staticcheck - sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) - //nolint:staticcheck - slh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) - slh.Data = sh.Data - slh.Len = sh.Len - slh.Cap = sh.Len + if len(s) == 0 { + return []byte{} + } + hdr := (*sliceHeader)(unsafe.Pointer(&b)) + hdr.data = (*stringHeader)(unsafe.Pointer(&s)).data + hdr.cap = len(s) + hdr.len = len(s) return b } diff --git a/tools/bytesutil/bytesutil_test.go b/tools/bytesutil/bytesutil_test.go index a3ab29a9..2e446f11 100644 --- a/tools/bytesutil/bytesutil_test.go +++ b/tools/bytesutil/bytesutil_test.go @@ -197,4 +197,8 @@ func TestToUnsafeString(t *testing.T) { if !bytes.Equal([]byte("str"), ToUnsafeBytes(s)) { t.Fatalf(`[]bytes(%s) doesnt equal to %s `, s, s) } + s = "" + if !bytes.Equal([]byte(""), ToUnsafeBytes(s)) { + t.Fatalf(`[]bytes(%s) doesnt equal to %s `, s, s) + } } diff --git a/tools/melody/config.go b/tools/melody/config.go new file mode 100644 index 00000000..abfc7494 --- /dev/null +++ b/tools/melody/config.go @@ -0,0 +1,22 @@ +package melody + +import "time" + +// Config melody configuration struct. +type Config struct { + WriteWait time.Duration // Milliseconds until write times out. + PongWait time.Duration // Timeout for waiting on pong. + PingPeriod time.Duration // Milliseconds between pings. + MaxMessageSize int64 // Maximum size in bytes of a message. + MessageBufferSize int // The max amount of messages that can be in a sessions buffer before it starts dropping them. +} + +func newConfig() *Config { + return &Config{ + WriteWait: 60 * time.Second, + PongWait: 60 * time.Second, + PingPeriod: (60 * time.Second * 9) / 10, + MaxMessageSize: 0, + MessageBufferSize: 1, + } +} diff --git a/tools/melody/melody.go b/tools/melody/melody.go new file mode 100644 index 00000000..35a4be55 --- /dev/null +++ b/tools/melody/melody.go @@ -0,0 +1,147 @@ +package melody + +import ( + "net/http" + "sync" + "sync/atomic" + + "github.com/gorilla/websocket" +) + +type envelope struct { + t int + msg []byte +} + +type handleMessageFunc func(*Session, []byte) +type handleErrorFunc func(*Session, error) +type handleCloseFunc func(*Session, int, string) error +type handleSessionFunc func(*Session) + +// Melody implements a websocket manager. +type Melody struct { + sessionCount uint32 + Config *Config + Upgrader *websocket.Upgrader + messageHandler handleMessageFunc + messageHandlerBinary handleMessageFunc + errorHandler handleErrorFunc + closeHandler handleCloseFunc + connectHandler handleSessionFunc + disconnectHandler handleSessionFunc + pongHandler handleSessionFunc +} + +// New creates a new melody instance with default Upgrader and Config. +func New() *Melody { + upGrader := &websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + WriteBufferPool: &sync.Pool{}, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: true, + } + + return &Melody{ + sessionCount: 0, + Config: newConfig(), + Upgrader: upGrader, + messageHandler: func(*Session, []byte) {}, + messageHandlerBinary: func(*Session, []byte) {}, + errorHandler: func(*Session, error) {}, + closeHandler: nil, + connectHandler: func(*Session) {}, + disconnectHandler: func(*Session) {}, + pongHandler: func(*Session) {}, + } +} + +// HandleConnect fires fn when a session connects. +func (m *Melody) HandleConnect(fn func(*Session)) { + m.connectHandler = fn +} + +// HandleDisconnect fires fn when a session disconnects. +func (m *Melody) HandleDisconnect(fn func(*Session)) { + m.disconnectHandler = fn +} + +// HandlePong fires fn when a pong is received from a session. +func (m *Melody) HandlePong(fn func(*Session)) { + m.pongHandler = fn +} + +// HandleMessage fires fn when a text message comes in. +func (m *Melody) HandleMessage(fn func(*Session, []byte)) { + m.messageHandler = fn +} + +// HandleMessageBinary fires fn when a binary message comes in. +func (m *Melody) HandleMessageBinary(fn func(*Session, []byte)) { + m.messageHandlerBinary = fn +} + +// HandleError fires fn when a session has an error. +func (m *Melody) HandleError(fn func(*Session, error)) { + m.errorHandler = fn +} + +// HandleClose sets the handler for close messages received from the session. +// The code argument to h is the received close code or CloseNoStatusReceived +// if the close message is empty. The default close handler sends a close frame +// back to the session. +// +// The application must read the connection to process close messages as +// described in the section on Control Frames above. +// +// The connection read methods return a CloseError when a close frame is +// received. Most applications should handle close messages as part of their +// normal error handling. Applications should only set a close handler when the +// application must perform some action before sending a close frame back to +// the session. +func (m *Melody) HandleClose(fn func(*Session, int, string) error) { + if fn != nil { + m.closeHandler = fn + } +} + +// HandleRequestWithKeys does the same as HandleRequest but populates session.Keys with keys. +func (m *Melody) HandleRequestWithKeys(w http.ResponseWriter, r *http.Request, keys map[string]interface{}) error { + conn, err := m.Upgrader.Upgrade(w, r, w.Header()) + + if err != nil { + return err + } + + session := &Session{ + Request: r, + conn: conn, + output: make(chan *envelope, m.Config.MessageBufferSize), + melody: m, + status: StatusNormal, + closeOnce: sync.Once{}, + } + for k, v := range keys { + session.Keys.Store(k, v) + } + atomic.AddUint32(&m.sessionCount, 1) + + m.connectHandler(session) + + go session.writePump() + + session.readPump() + + atomic.AddUint32(&m.sessionCount, ^uint32(0)) + + session.close() + + m.disconnectHandler(session) + + return nil +} + +// Len return the number of connected sessions. +func (m *Melody) Len() uint32 { + return atomic.LoadUint32(&m.sessionCount) +} diff --git a/tools/melody/melody_test.go b/tools/melody/melody_test.go new file mode 100644 index 00000000..b6d9f050 --- /dev/null +++ b/tools/melody/melody_test.go @@ -0,0 +1,623 @@ +package melody + +import ( + "bytes" + "math/rand" + "net/http" + "net/http/httptest" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "testing/quick" + "time" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func TestNewMelody(t *testing.T) { + handleConnect := handleSessionFunc(func(*Session) {}) + handleDisconnect := handleSessionFunc(func(*Session) {}) + handlePong := handleSessionFunc(func(*Session) {}) + handleMessage := handleMessageFunc(func(*Session, []byte) {}) + handleMessageBinary := handleMessageFunc(func(*Session, []byte) {}) + handleError := handleErrorFunc(func(*Session, error) {}) + handleClose := handleCloseFunc(func(*Session, int, string) error { + return nil + }) + melody := New() + melody.HandleConnect(handleConnect) + melody.HandleDisconnect(handleDisconnect) + melody.HandlePong(handlePong) + melody.HandleMessage(handleMessage) + melody.HandleMessageBinary(handleMessageBinary) + melody.HandleError(handleError) + melody.HandleClose(handleClose) + + defaultConf := &Config{ + WriteWait: 60 * time.Second, + PongWait: 60 * time.Second, + PingPeriod: (60 * time.Second * 9) / 10, + MaxMessageSize: 0, + MessageBufferSize: 1, + } + assert.Equal(t, defaultConf, melody.Config) + assert.Equal(t, uint32(0), melody.sessionCount) +} + +var TestMsg = []byte("test") + +type TestServer struct { + m *Melody +} + +func NewTestServerHandler(handler handleMessageFunc) *TestServer { + m := New() + m.HandleMessage(handler) + return &TestServer{ + m: m, + } +} + +func NewTestServer() *TestServer { + m := New() + return &TestServer{ + m: m, + } +} + +func (s *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + _ = s.m.HandleRequestWithKeys(w, r, map[string]interface{}{"logger": logrus.New().WithField("test", "melody")}) +} + +func NewDialer(url string) (*websocket.Conn, error) { + dialer := &websocket.Dialer{} + conn, _, err := dialer.Dial(strings.Replace(url, "http", "ws", 1), nil) + return conn, err +} + +func MustNewDialer(url string) *websocket.Conn { + conn, err := NewDialer(url) + + if err != nil { + panic("could not dail websocket") + } + + return conn +} + +func TestEcho(t *testing.T) { + ws := NewTestServerHandler(func(session *Session, msg []byte) { + err := session.Write(msg) + assert.NoError(t, err) + }) + server := httptest.NewServer(ws) + defer server.Close() + + fn := func(msg string) bool { + conn := MustNewDialer(server.URL) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + + err := conn.WriteMessage(websocket.TextMessage, []byte(msg)) + assert.NoError(t, err) + retType, ret, err := conn.ReadMessage() + assert.NoError(t, err) + assert.Equal(t, websocket.TextMessage, retType) + + assert.Equal(t, msg, string(ret)) + + return true + } + + err := quick.Check(fn, nil) + + assert.Nil(t, err) +} + +func TestEchoBinary(t *testing.T) { + ws := NewTestServerHandler(func(session *Session, msg []byte) { + err := session.WriteBinary(msg) + assert.NoError(t, err) + }) + server := httptest.NewServer(ws) + defer server.Close() + + fn := func(msg string) bool { + conn := MustNewDialer(server.URL) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + + err := conn.WriteMessage(websocket.TextMessage, []byte(msg)) + assert.NoError(t, err) + retType, ret, err := conn.ReadMessage() + + assert.NoError(t, err) + assert.Equal(t, websocket.BinaryMessage, retType) + assert.True(t, bytes.Equal([]byte(msg), ret)) + + return true + } + + err := quick.Check(fn, nil) + + assert.Nil(t, err) +} + +func TestWriteClosedServer(t *testing.T) { + done := make(chan bool) + + ws := NewTestServer() + + server := httptest.NewServer(ws) + defer server.Close() + + ws.m.HandleConnect(func(s *Session) { + err := s.Close() + assert.NoError(t, err) + }) + + ws.m.HandleDisconnect(func(s *Session) { + err := s.Write([]byte("msg")) + assert.NotNil(t, err) + close(done) + }) + + conn := MustNewDialer(server.URL) + _, _, err := conn.ReadMessage() + assert.Error(t, err) + defer func() { + err = conn.Close() + assert.Nil(t, err) + }() + + <-done +} + +func TestWriteClosedClient(t *testing.T) { + done := make(chan bool) + + ws := NewTestServer() + + server := httptest.NewServer(ws) + defer server.Close() + + ws.m.HandleDisconnect(func(s *Session) { + err := s.Write([]byte("msg")) + assert.Error(t, err) + close(done) + }) + + conn := MustNewDialer(server.URL) + err := conn.Close() + assert.NoError(t, err) + <-done +} + +func TestUpgrader(t *testing.T) { + ws := NewTestServer() + ws.m.HandleMessage(func(session *Session, msg []byte) { + err := session.Write(msg) + assert.NoError(t, err) + }) + + server := httptest.NewServer(ws) + defer server.Close() + + ws.m.Upgrader = &websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return false }, + } + + _, err := NewDialer(server.URL) + + assert.ErrorIs(t, err, websocket.ErrBadHandshake) +} + +func TestLen(t *testing.T) { + //nolint:staticcheck + rand.Seed(time.Now().UnixNano()) + + connect := int(rand.Int31n(100)) + disconnect := rand.Float32() + conns := make([]*websocket.Conn, connect) + + defer func() { + for _, conn := range conns { + if conn != nil { + err := conn.Close() + assert.NoError(t, err) + } + } + }() + + ws := NewTestServer() + + server := httptest.NewServer(ws) + defer server.Close() + + disconnected := 0 + for i := 0; i < connect; i++ { + conn := MustNewDialer(server.URL) + + if rand.Float32() < disconnect { + conns[i] = nil + disconnected++ + err := conn.Close() + assert.NoError(t, err) + continue + } + + conns[i] = conn + } + + time.Sleep(time.Millisecond) + + connected := connect - disconnected + + assert.Equal(t, uint32(connected), ws.m.Len()) +} + +func TestPingPong(t *testing.T) { + done := make(chan bool) + + ws := NewTestServer() + ws.m.Config.PingPeriod = time.Millisecond + + ws.m.HandlePong(func(s *Session) { + close(done) + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + + go func() { + _, _, err := conn.NextReader() + if err != nil { + return + } + }() + + <-done +} + +func TestHandleClose(t *testing.T) { + done := make(chan bool) + + ws := NewTestServer() + ws.m.Config.PingPeriod = time.Millisecond + + ws.m.HandleClose(func(s *Session, code int, text string) error { + close(done) + return nil + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + + err := conn.WriteMessage(websocket.CloseMessage, nil) + assert.NoError(t, err) + <-done +} + +func TestHandleError(t *testing.T) { + done := make(chan bool) + + ws := NewTestServer() + + ws.m.HandleError(func(s *Session, err error) { + var closeError *websocket.CloseError + assert.ErrorAs(t, err, &closeError) + close(done) + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + + err := conn.Close() + assert.NoError(t, err) + <-done +} + +func TestHandleErrorWrite(t *testing.T) { + writeError := make(chan struct{}) + disconnect := make(chan struct{}) + + ws := NewTestServer() + ws.m.Config.WriteWait = 0 + ws.m.Config.PingPeriod = time.Millisecond * 100 + ws.m.Config.PongWait = time.Millisecond * 100 + ws.m.HandleConnect(func(s *Session) { + err := s.Write(TestMsg) + assert.Nil(t, err) + }) + once := sync.Once{} + + ws.m.HandleError(func(s *Session, err error) { + assert.NotNil(t, err) + + if os.IsTimeout(err) { + once.Do(func() { + close(writeError) + }) + } + }) + + ws.m.HandleDisconnect(func(s *Session) { + close(disconnect) + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + + go func() { + _, _, err := conn.NextReader() + if err != nil { + return + } + }() + + <-writeError + <-disconnect +} + +func TestErrSessionClosed(t *testing.T) { + res := make(chan *Session) + + ws := NewTestServer() + + ws.m.HandleConnect(func(s *Session) { + err := s.Close() + assert.NoError(t, err) + }) + + ws.m.HandleDisconnect(func(s *Session) { + res <- s + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + + go func() { + _, _, err := conn.ReadMessage() + if err != nil { + return + } + }() + + s := <-res + + assert.True(t, s.IsClosed()) + assert.ErrorIs(t, s.Write(TestMsg), ErrSessionClosed) + assert.ErrorIs(t, s.WriteBinary(TestMsg), ErrSessionClosed) + assert.ErrorIs(t, s.Close(), ErrSessionClosed) + + assert.ErrorIs(t, s.writeRaw(&envelope{}), ErrWriteClosed) + s.writeMessage(&envelope{}) +} + +func TestSessionKeys(t *testing.T) { + ws := NewTestServer() + + ws.m.HandleConnect(func(session *Session) { + session.Set("stamp", time.Now().UnixNano()) + }) + ws.m.HandleMessage(func(session *Session, msg []byte) { + stamp := session.MustGet("stamp").(int64) + err := session.Write([]byte(strconv.Itoa(int(stamp)))) + assert.NoError(t, err) + }) + server := httptest.NewServer(ws) + defer server.Close() + + fn := func(msg string) bool { + conn := MustNewDialer(server.URL) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + + err := conn.WriteMessage(websocket.TextMessage, []byte(msg)) + assert.NoError(t, err) + _, ret, err := conn.ReadMessage() + + assert.NoError(t, err) + + stamp, err := strconv.Atoi(string(ret)) + + assert.Nil(t, err) + + diff := int(time.Now().UnixNano()) - stamp + + assert.Greater(t, diff, 0) + + return true + } + + assert.Nil(t, quick.Check(fn, nil)) +} + +func TestSessionKeysConcurrent(t *testing.T) { + ss := make(chan *Session) + + ws := NewTestServer() + + ws.m.HandleConnect(func(s *Session) { + ss <- s + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + + s := <-ss + + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(1) + + go func() { + s.Set("test", TestMsg) + + v1, exists := s.Get("test") + + assert.True(t, exists) + assert.Equal(t, v1, TestMsg) + + v2 := s.MustGet("test") + + assert.Equal(t, v1, v2) + + wg.Done() + }() + } + + wg.Wait() + + for i := 0; i < 100; i++ { + wg.Add(1) + + go func() { + s.UnSet("test") + + _, exists := s.Get("test") + + assert.False(t, exists) + + wg.Done() + }() + } + + wg.Wait() +} + +func TestConcurrentMessageHandling(t *testing.T) { + testTimeout := func(cmh bool, msgType int) bool { + base := time.Millisecond * 100 + done := make(chan struct{}) + + handler := func(s *Session, msg []byte) { + if len(msg) == 0 { + done <- struct{}{} + return + } + + time.Sleep(base * 2) + } + + messageHandler := func(s *Session, msg []byte) { + if cmh { + go handler(s, msg) + } else { + handler(s, msg) + } + } + + ws := NewTestServerHandler(func(session *Session, msg []byte) {}) + if msgType == websocket.TextMessage { + ws.m.HandleMessage(messageHandler) + } else { + ws.m.HandleMessageBinary(messageHandler) + } + ws.m.Config.PingPeriod = base / 2 + ws.m.Config.PongWait = base + + var errorSet atomic.Bool + ws.m.HandleError(func(s *Session, err error) { + errorSet.Store(true) + done <- struct{}{} + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + + err := conn.WriteMessage(msgType, TestMsg) + assert.NoError(t, err) + err = conn.WriteMessage(msgType, TestMsg) + assert.NoError(t, err) + + time.Sleep(base / 4) + + err = conn.WriteMessage(msgType, nil) + assert.NoError(t, err) + + <-done + + return errorSet.Load() + } + + t.Run("text should error", func(t *testing.T) { + errorSet := testTimeout(false, websocket.TextMessage) + + if !errorSet { + t.FailNow() + } + }) + + t.Run("text should not error", func(t *testing.T) { + errorSet := testTimeout(true, websocket.TextMessage) + + if errorSet { + t.FailNow() + } + }) + + t.Run("binary should error", func(t *testing.T) { + errorSet := testTimeout(false, websocket.BinaryMessage) + + if !errorSet { + t.FailNow() + } + }) + + t.Run("binary should not error", func(t *testing.T) { + errorSet := testTimeout(true, websocket.BinaryMessage) + + if errorSet { + t.FailNow() + } + }) +} diff --git a/tools/melody/session.go b/tools/melody/session.go new file mode 100644 index 00000000..9d4cd316 --- /dev/null +++ b/tools/melody/session.go @@ -0,0 +1,214 @@ +package melody + +import ( + "errors" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" +) + +const ( + StatusNormal = uint32(1) + StatusStop = uint32(2) +) + +var ( + ErrSessionClosed = errors.New("session is closed") + ErrWriteClosed = errors.New("tried to write to a closed session") +) + +// Session wrapper around websocket connections. +type Session struct { + Request *http.Request + Keys sync.Map + conn *websocket.Conn + output chan *envelope + melody *Melody + status uint32 + closeOnce sync.Once + lastReadTime time.Time +} + +func (s *Session) writeMessage(message *envelope) { + if s.closed() { + s.melody.errorHandler(s, ErrWriteClosed) + return + } + defer func() { + if recover() != nil { + s.melody.errorHandler(s, ErrWriteClosed) + } + }() + s.output <- message +} + +func (s *Session) writeRaw(message *envelope) error { + if s.closed() { + return ErrWriteClosed + } + + // no error returned from SetWriteDeadline + _ = s.conn.SetWriteDeadline(time.Now().Add(s.melody.Config.WriteWait)) + + err := s.conn.WriteMessage(message.t, message.msg) + + if err != nil { + return err + } + + return nil +} + +func (s *Session) closed() bool { + return atomic.LoadUint32(&s.status) == StatusStop +} + +func (s *Session) close() { + s.closeOnce.Do(func() { + atomic.StoreUint32(&s.status, StatusStop) + _ = s.conn.Close() + close(s.output) + }) +} + +func (s *Session) writePump() { + ticker := time.NewTicker(s.melody.Config.PingPeriod) + defer ticker.Stop() + for { + select { + case msg, ok := <-s.output: + if !ok { + // The channel has been closed, this means the session is closed, return to stop the writePump + return + } + + err := s.writeRaw(msg) + + if err != nil { + s.melody.errorHandler(s, err) + return + } + + if msg.t == websocket.CloseMessage { + return + } + case <-ticker.C: + _ = s.writeRaw(&envelope{t: websocket.PingMessage, msg: []byte{}}) + } + } +} + +func (s *Session) readPump() { + s.conn.SetReadLimit(s.melody.Config.MaxMessageSize) + s.setReadDeadline() + + s.conn.SetPongHandler(func(string) error { + s.setReadDeadline() + s.melody.pongHandler(s) + return nil + }) + + if s.melody.closeHandler != nil { + s.conn.SetCloseHandler(func(code int, text string) error { + return s.melody.closeHandler(s, code, text) + }) + } + + for { + t, message, err := s.conn.ReadMessage() + + if err != nil { + s.melody.errorHandler(s, err) + break + } + s.setReadDeadline() + if t == websocket.TextMessage { + s.melody.messageHandler(s, message) + } + + if t == websocket.BinaryMessage { + s.melody.messageHandlerBinary(s, message) + } + } +} + +func (s *Session) setReadDeadline() { + now := time.Now() + if now.Sub(s.lastReadTime) >= time.Second { + s.lastReadTime = now + err := s.conn.SetReadDeadline(s.lastReadTime.Add(s.melody.Config.PongWait + s.melody.Config.PingPeriod)) + if err != nil { + if logger, exists := s.Get("logger"); exists { + logger.(*logrus.Entry).Errorf("setReadDeadline error: %v", err) + } + } + } +} + +// Write writes message to session. +func (s *Session) Write(msg []byte) error { + if s.closed() { + return ErrSessionClosed + } + + s.writeMessage(&envelope{t: websocket.TextMessage, msg: msg}) + + return nil +} + +// WriteBinary writes a binary message to session. +func (s *Session) WriteBinary(msg []byte) error { + if s.closed() { + return ErrSessionClosed + } + + s.writeMessage(&envelope{t: websocket.BinaryMessage, msg: msg}) + + return nil +} + +// Close closes session. +func (s *Session) Close() error { + if s.closed() { + return ErrSessionClosed + } + + s.writeMessage(&envelope{t: websocket.CloseMessage, msg: []byte{}}) + + return nil +} + +// Set is used to store a new key/value pair exclusivelly for this session. +// It also lazy initializes s.Keys if it was not used previously. +func (s *Session) Set(key string, value interface{}) { + s.Keys.Store(key, value) +} + +// Get returns the value for the given key, ie: (value, true). +// If the value does not exists it returns (nil, false) +func (s *Session) Get(key string) (value interface{}, exists bool) { + return s.Keys.Load(key) +} + +// MustGet returns the value for the given key if it exists, otherwise it panics. +func (s *Session) MustGet(key string) interface{} { + if value, exists := s.Get(key); exists { + return value + } + + panic("Key \"" + key + "\" does not exist") +} + +// UnSet will delete the key and has no return value +func (s *Session) UnSet(key string) { + s.Keys.Delete(key) +} + +// IsClosed returns the status of the connection. +func (s *Session) IsClosed() bool { + return s.closed() +} From cf5e8e7159820b38e2b50a802aa4e06ed673f1f9 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 22 Nov 2024 11:41:55 +0800 Subject: [PATCH 10/48] test: add unit test --- controller/ws/ws/query_test.go | 32 ++++++- controller/ws/ws/raw_test.go | 155 ++++++++++++++++++++++++++++++++ controller/ws/ws/stmt_test.go | 160 +++++++++++++++++++++++++++++++++ 3 files changed, 344 insertions(+), 3 deletions(-) diff --git a/controller/ws/ws/query_test.go b/controller/ws/ws/query_test.go index aadab537..825ca8d4 100644 --- a/controller/ws/ws/query_test.go +++ b/controller/ws/ws/query_test.go @@ -106,6 +106,11 @@ func TestMode(t *testing.T) { } +func TestWrongConnect(t *testing.T) { + // mock tool.GetWhitelist return error + +} + func TestWsQuery(t *testing.T) { s := httptest.NewServer(router) defer s.Close() @@ -147,15 +152,26 @@ func TestWsQuery(t *testing.T) { assert.NoError(t, err) assert.Equal(t, uint64(1), connResp.ReqID) assert.Equal(t, 0, connResp.Code, connResp.Message) + assert.Equal(t, Connect, connResp.Action) - // query - queryReq := queryRequest{ReqID: 2, Sql: "select * from stb1"} + // wrong sql + queryReq := queryRequest{ReqID: 2, Sql: "wrong sql"} resp, err = doWebSocket(ws, WSQuery, &queryReq) assert.NoError(t, err) var queryResp queryResponse err = json.Unmarshal(resp, &queryResp) assert.NoError(t, err) assert.Equal(t, uint64(2), queryResp.ReqID) + assert.NotEqual(t, 0, queryResp.Code) + assert.Equal(t, WSQuery, queryResp.Action) + + // query + queryReq = queryRequest{ReqID: 2, Sql: "select * from stb1"} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), queryResp.ReqID) assert.Equal(t, 0, queryResp.Code, queryResp.Message) // fetch @@ -510,7 +526,17 @@ func TestWsQuery(t *testing.T) { assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) assert.Equal(t, true, fetchResp.Completed) - time.Sleep(time.Second) + + // insert + queryReq = queryRequest{ReqID: 14, Sql: `insert into t4 using stb1 tags ('{\"table\":\"t4\"}') values (now-2s,true,2,3,4,5,6,7,8,9,10,11,'中文\"binary','中文nchar','\xaabbcc','point(100 100)')(now-1s,false,12,13,14,15,16,17,18,19,110,111,'中文\"binary','中文nchar','\xaabbcc','point(100 100)')(now,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null)`} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, uint64(14), queryResp.ReqID) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + assert.Equal(t, WSQuery, queryResp.Action) + assert.Equal(t, 3, queryResp.AffectedRows) } type FetchRawBlockResponse struct { diff --git a/controller/ws/ws/raw_test.go b/controller/ws/ws/raw_test.go index 5b1c40ad..769eac2b 100644 --- a/controller/ws/ws/raw_test.go +++ b/controller/ws/ws/raw_test.go @@ -11,6 +11,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/tools/generator" ) func TestWSTMQWriteRaw(t *testing.T) { @@ -121,4 +122,158 @@ func TestWSTMQWriteRaw(t *testing.T) { assert.Equal(t, expect[rowIndex][i], values[i]) } } + // wrong meta type + buffer.Reset() + metaType = 0 + reqID := uint64(generator.GetReqID()) + wstool.WriteUint64(&buffer, reqID) // req id + wstool.WriteUint64(&buffer, 0) // message id + wstool.WriteUint64(&buffer, uint64(TMQRawMessage)) + wstool.WriteUint32(&buffer, length) + wstool.WriteUint16(&buffer, metaType) + buffer.Write(data) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &tmqResp) + assert.NoError(t, err) + assert.Equal(t, reqID, tmqResp.ReqID) + assert.NotEqual(t, 0, tmqResp.Code) + assert.Equal(t, getActionString(TMQRawMessage), tmqResp.Action) +} + +func TestWSWriteRawBlockError(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + code, message := doRestful("create database if not exists test_ws_write_raw_block_error", "") + assert.Equal(t, 0, code, message) + defer func() { + code, message := doRestful("drop database if exists test_ws_write_raw_block_error", "") + assert.Equal(t, 0, code, message) + }() + code, message = doRestful("create table test_ws_write_raw_block_error.tb1 (ts timestamp,v int)", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("insert into test_ws_write_raw_block_error.tb1 values(now , 1)", "") + assert.Equal(t, 0, code, message) + // connect without db + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: ""} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // query + queryReq := queryRequest{ReqID: 2, Sql: "select * from test_ws_write_raw_block_error.tb1"} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + var queryResp queryResponse + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), queryResp.ReqID) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch + fetchReq := fetchRequest{ReqID: 3, ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + var fetchResp fetchResponse + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), fetchResp.ReqID) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + assert.Equal(t, 1, fetchResp.Rows) + + // fetch block + fetchBlockReq := fetchBlockRequest{ReqID: 4, ID: queryResp.ID} + fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) + assert.NoError(t, err) + + fetchReq = fetchRequest{ReqID: 5, ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, uint64(5), fetchResp.ReqID) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + + assert.Equal(t, true, fetchResp.Completed) + + // write raw block + var buffer bytes.Buffer + wstool.WriteUint64(&buffer, 300) // req id + wstool.WriteUint64(&buffer, 400) // message id + wstool.WriteUint64(&buffer, uint64(RawBlockMessage)) // action + wstool.WriteUint32(&buffer, uint32(fetchResp.Rows)) // rows + wstool.WriteUint16(&buffer, uint16(2)) // table name length + buffer.WriteString("t2") // table name + buffer.Write(fetchBlockResp[16:]) // raw block + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + var writeResp commonResp + err = json.Unmarshal(resp, &writeResp) + assert.NoError(t, err) + assert.Equal(t, getActionString(RawBlockMessage), writeResp.Action) + assert.NotEqual(t, 0, writeResp.Code) + + // write raw block with fields + buffer.Reset() + wstool.WriteUint64(&buffer, 300) // req id + wstool.WriteUint64(&buffer, 400) // message id + wstool.WriteUint64(&buffer, uint64(RawBlockMessageWithFields)) // action + wstool.WriteUint32(&buffer, uint32(fetchResp.Rows)) // rows + wstool.WriteUint16(&buffer, uint16(2)) // table name length + buffer.WriteString("t2") // table name + buffer.Write(fetchBlockResp[16:]) // raw block + fields := []byte{ + // ts + 0x74, 0x73, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x09, + // padding + 0x00, 0x00, + // bytes + 0x08, 0x00, 0x00, 0x00, + + // v1 + 0x76, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + // type + 0x04, + // padding + 0x00, 0x00, + // bytes + 0x04, 0x00, 0x00, 0x00, + } + buffer.Write(fields) + err = ws.WriteMessage(websocket.BinaryMessage, buffer.Bytes()) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + err = json.Unmarshal(resp, &writeResp) + assert.NoError(t, err) + assert.Equal(t, getActionString(RawBlockMessageWithFields), writeResp.Action) + assert.NotEqual(t, 0, writeResp.Code) + } diff --git a/controller/ws/ws/stmt_test.go b/controller/ws/ws/stmt_test.go index 9115f416..2687cba5 100644 --- a/controller/ws/ws/stmt_test.go +++ b/controller/ws/ws/stmt_test.go @@ -20,6 +20,7 @@ import ( stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" "github.com/taosdata/driver-go/v3/types" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/tools/generator" "github.com/taosdata/taosadapter/v3/tools/parseblock" ) @@ -898,3 +899,162 @@ func TestStmtGetParams(t *testing.T) { assert.Equal(t, 9, getParamsResp.DataType) assert.Equal(t, 8, getParamsResp.Length) } + +func TestStmtInvalidStmtID(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + reqID := uint64(generator.GetReqID()) + connReq := connRequest{ReqID: reqID, User: "root", Password: "taosdata"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, Connect, connResp.Action) + assert.Equal(t, reqID, connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // prepare + reqID = uint64(generator.GetReqID()) + prepareReq := stmtPrepareRequest{ReqID: reqID, StmtID: 0, SQL: "insert into ? using test_ws_stmt_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"} + resp, err = doWebSocket(ws, STMTPrepare, &prepareReq) + assert.NoError(t, err) + var prepareResp stmtPrepareResponse + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, STMTPrepare, prepareResp.Action) + assert.Equal(t, reqID, prepareResp.ReqID) + assert.NotEqual(t, 0, prepareResp.Code) + + // set table name + reqID = uint64(generator.GetReqID()) + setTableNameReq := stmtSetTableNameRequest{ReqID: reqID, StmtID: prepareResp.StmtID, Name: "d1"} + resp, err = doWebSocket(ws, STMTSetTableName, &setTableNameReq) + assert.NoError(t, err) + var setTableNameResp stmtSetTableNameResponse + err = json.Unmarshal(resp, &setTableNameResp) + assert.NoError(t, err) + assert.Equal(t, STMTSetTableName, setTableNameResp.Action) + assert.Equal(t, reqID, setTableNameResp.ReqID) + assert.NotEqual(t, 0, setTableNameResp.Code) + + // set tags + reqID = uint64(generator.GetReqID()) + setTagsReq := stmtSetTagsRequest{ReqID: reqID, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTSetTags, &setTagsReq) + assert.NoError(t, err) + var setTagsResp stmtSetTagsResponse + err = json.Unmarshal(resp, &setTagsResp) + assert.NoError(t, err) + assert.Equal(t, STMTSetTags, setTagsResp.Action) + assert.Equal(t, reqID, setTagsResp.ReqID) + assert.NotEqual(t, 0, setTagsResp.Code) + + // bind + reqID = uint64(generator.GetReqID()) + bindReq := stmtBindRequest{ReqID: reqID, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTBind, &bindReq) + assert.NoError(t, err) + var bindResp stmtBindResponse + err = json.Unmarshal(resp, &bindResp) + assert.NoError(t, err) + assert.Equal(t, STMTBind, bindResp.Action) + assert.Equal(t, reqID, bindResp.ReqID) + assert.NotEqual(t, 0, bindResp.Code) + + // add batch + reqID = uint64(generator.GetReqID()) + addBatchReq := stmtAddBatchRequest{ReqID: reqID, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTAddBatch, &addBatchReq) + assert.NoError(t, err) + var addBatchResp stmtAddBatchResponse + err = json.Unmarshal(resp, &addBatchResp) + assert.NoError(t, err) + assert.Equal(t, STMTAddBatch, addBatchResp.Action) + assert.Equal(t, reqID, addBatchResp.ReqID) + assert.NotEqual(t, 0, addBatchResp.Code) + + // exec + reqID = uint64(generator.GetReqID()) + execReq := stmtExecRequest{ReqID: reqID, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTExec, &execReq) + assert.NoError(t, err) + var execResp stmtExecResponse + err = json.Unmarshal(resp, &execResp) + assert.NoError(t, err) + assert.Equal(t, STMTExec, execResp.Action) + assert.Equal(t, reqID, execResp.ReqID) + assert.NotEqual(t, 0, execResp.Code) + + // use result + reqID = uint64(generator.GetReqID()) + useResultReq := stmtUseResultRequest{ReqID: reqID, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTUseResult, &useResultReq) + assert.NoError(t, err) + var useResultResp stmtUseResultResponse + err = json.Unmarshal(resp, &useResultResp) + assert.NoError(t, err) + assert.Equal(t, STMTUseResult, useResultResp.Action) + assert.Equal(t, reqID, useResultResp.ReqID) + assert.NotEqual(t, 0, useResultResp.Code) + + // get tag fields + reqID = uint64(generator.GetReqID()) + getTagFieldsReq := stmtGetTagFieldsRequest{ReqID: reqID, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTGetTagFields, &getTagFieldsReq) + assert.NoError(t, err) + var getTagFieldsResp stmtGetTagFieldsResponse + err = json.Unmarshal(resp, &getTagFieldsResp) + assert.NoError(t, err) + assert.Equal(t, STMTGetTagFields, getTagFieldsResp.Action) + assert.Equal(t, reqID, getTagFieldsResp.ReqID) + assert.NotEqual(t, 0, getTagFieldsResp.Code) + + // get col fields + reqID = uint64(generator.GetReqID()) + getColFieldsReq := stmtGetColFieldsRequest{ReqID: reqID, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTGetColFields, &getColFieldsReq) + assert.NoError(t, err) + var getColFieldsResp stmtGetColFieldsResponse + err = json.Unmarshal(resp, &getColFieldsResp) + assert.NoError(t, err) + assert.Equal(t, STMTGetColFields, getColFieldsResp.Action) + assert.Equal(t, reqID, getColFieldsResp.ReqID) + assert.NotEqual(t, 0, getColFieldsResp.Code) + + // num params + reqID = uint64(generator.GetReqID()) + numParamsReq := stmtNumParamsRequest{ReqID: reqID, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMTNumParams, &numParamsReq) + assert.NoError(t, err) + var numParamsResp stmtNumParamsResponse + err = json.Unmarshal(resp, &numParamsResp) + assert.NoError(t, err) + assert.Equal(t, STMTNumParams, numParamsResp.Action) + assert.Equal(t, reqID, numParamsResp.ReqID) + assert.NotEqual(t, 0, numParamsResp.Code) + + // get param + reqID = uint64(generator.GetReqID()) + getParamsReq := stmtGetParamRequest{ReqID: reqID, StmtID: prepareResp.StmtID, Index: 0} + resp, err = doWebSocket(ws, STMTGetParam, &getParamsReq) + assert.NoError(t, err) + var getParamsResp stmtGetParamResponse + err = json.Unmarshal(resp, &getParamsResp) + assert.NoError(t, err) + assert.Equal(t, STMTGetParam, getParamsResp.Action) + assert.Equal(t, reqID, getParamsResp.ReqID) + assert.NotEqual(t, 0, getParamsResp.Code) + +} From fc6687f921699cd114efbd2ca60a7a7eed7732c5 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 22 Nov 2024 12:42:11 +0800 Subject: [PATCH 11/48] test: compatible with go version 1.17 --- tools/melody/melody_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/melody/melody_test.go b/tools/melody/melody_test.go index b6d9f050..53c8cce5 100644 --- a/tools/melody/melody_test.go +++ b/tools/melody/melody_test.go @@ -559,9 +559,9 @@ func TestConcurrentMessageHandling(t *testing.T) { ws.m.Config.PingPeriod = base / 2 ws.m.Config.PongWait = base - var errorSet atomic.Bool + var errorSet uint32 ws.m.HandleError(func(s *Session, err error) { - errorSet.Store(true) + atomic.StoreUint32(&errorSet, 1) done <- struct{}{} }) @@ -586,7 +586,7 @@ func TestConcurrentMessageHandling(t *testing.T) { <-done - return errorSet.Load() + return atomic.LoadUint32(&errorSet) != 0 } t.Run("text should error", func(t *testing.T) { From 51d0f53d0bde04efb24aa5d60a524658d5313f6c Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 22 Nov 2024 12:49:55 +0800 Subject: [PATCH 12/48] enh: wait for all action finish before close --- controller/ws/tmq/tmq.go | 5 +++-- controller/ws/ws/handler.go | 6 +++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/controller/ws/tmq/tmq.go b/controller/ws/tmq/tmq.go index ec6b3c94..b9656ac1 100644 --- a/controller/ws/tmq/tmq.go +++ b/controller/ws/tmq/tmq.go @@ -1235,10 +1235,11 @@ func (t *TMQ) Close(logger *logrus.Entry) { }() select { case <-ctx.Done(): - logger.Error("wait for all goroutines to exit timeout") + logger.Warn("wait stop over 1 minute") + <-done case <-done: - logger.Debug("all goroutines exit") } + logger.Debug("wait stop done") isDebug := log.IsDebug() defer func() { diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index cb7011bd..6e324337 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -157,7 +157,7 @@ func (h *messageHandler) stop() { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - waitCh := make(chan struct{}, 1) + waitCh := make(chan struct{}) go func() { h.wait.Wait() close(waitCh) @@ -165,8 +165,12 @@ func (h *messageHandler) stop() { select { case <-ctx.Done(): + h.logger.Warn("wait stop over 1 minute") + <-waitCh + break case <-waitCh: } + h.logger.Debugf("wait stop done") // clean query result and stmt h.queryResults.FreeAll(h.logger) h.stmts.FreeAll(h.logger) From bfa7c47906a942eebd4814a2d9fe05e4867c2b86 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Mon, 25 Nov 2024 14:58:33 +0800 Subject: [PATCH 13/48] enh: use jsoniter instead of std json --- controller/ws/ws/fetch.go | 10 +- controller/ws/ws/free.go | 2 +- controller/ws/ws/handler.go | 265 ++++++++++++++++--------------- controller/ws/ws/handler_test.go | 22 +++ controller/ws/ws/misc.go | 4 +- controller/ws/ws/query.go | 8 +- controller/ws/ws/schemaless.go | 4 +- controller/ws/ws/stmt.go | 62 ++++---- controller/ws/ws/stmt2.go | 18 +-- controller/ws/wstool/resp.go | 11 +- 10 files changed, 222 insertions(+), 184 deletions(-) diff --git a/controller/ws/ws/fetch.go b/controller/ws/ws/fetch.go index 7bb80c31..9f52a0cf 100644 --- a/controller/ws/ws/fetch.go +++ b/controller/ws/ws/fetch.go @@ -48,7 +48,7 @@ type fetchResponse struct { Rows int `json:"rows"` } -func (h *messageHandler) fetch(ctx context.Context, session *melody.Session, action string, req *fetchRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) fetch(ctx context.Context, session *melody.Session, action string, req fetchRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("get result by id, id:%d", req.ID) item := h.queryResults.Get(req.ID) if item == nil { @@ -74,7 +74,7 @@ func (h *messageHandler) fetch(ctx context.Context, session *melody.Session, act item.Unlock() logger.Trace("fetch raw block completed") h.queryResults.FreeResultByID(req.ID, logger) - resp := fetchResponse{ + resp := &fetchResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -101,7 +101,7 @@ func (h *messageHandler) fetch(ctx context.Context, session *melody.Session, act logger.Debugf("get_raw_block result:%p, cost:%s", item.Block, log.GetLogDuration(isDebug, s)) item.Size = result.N item.Unlock() - resp := fetchResponse{ + resp := &fetchResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -117,7 +117,7 @@ type fetchBlockRequest struct { ID uint64 `json:"id"` } -func (h *messageHandler) fetchBlock(ctx context.Context, session *melody.Session, action string, req *fetchBlockRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) fetchBlock(ctx context.Context, session *melody.Session, action string, req fetchBlockRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("fetch block, id:%d", req.ID) item, locked := h.resultValidateAndLock(ctx, session, action, req.ReqID, req.ID, logger) if !locked { @@ -227,7 +227,7 @@ type numFieldsResponse struct { NumFields int `json:"num_fields"` } -func (h *messageHandler) numFields(ctx context.Context, session *melody.Session, action string, req *numFieldsRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) numFields(ctx context.Context, session *melody.Session, action string, req numFieldsRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("num fields, result_id:%d", req.ResultID) item, locked := h.resultValidateAndLock(ctx, session, action, req.ReqID, req.ResultID, logger) if !locked { diff --git a/controller/ws/ws/free.go b/controller/ws/ws/free.go index baa13f48..06479d9b 100644 --- a/controller/ws/ws/free.go +++ b/controller/ws/ws/free.go @@ -9,7 +9,7 @@ type freeResultRequest struct { ID uint64 `json:"id"` } -func (h *messageHandler) freeResult(req *freeResultRequest, logger *logrus.Entry) { +func (h *messageHandler) freeResult(req freeResultRequest, logger *logrus.Entry) { logger.Tracef("free result by id, id:%d", req.ID) h.queryResults.FreeResultByID(req.ID, logger) } diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 6e324337..e43abed7 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -9,6 +9,7 @@ import ( "time" "unsafe" + jsoniter "github.com/json-iterator/go" "github.com/sirupsen/logrus" "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/config" @@ -17,10 +18,8 @@ import ( "github.com/taosdata/taosadapter/v3/db/tool" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools" - "github.com/taosdata/taosadapter/v3/tools/bytesutil" "github.com/taosdata/taosadapter/v3/tools/iptool" "github.com/taosdata/taosadapter/v3/tools/melody" - "github.com/tidwall/gjson" ) type messageHandler struct { @@ -147,7 +146,6 @@ func (h *messageHandler) Close() { } type Request struct { - ReqID uint64 `json:"req_id"` Action string `json:"action"` Args json.RawMessage `json:"args"` } @@ -181,30 +179,36 @@ func (h *messageHandler) stop() { }) } +var jsonIter = jsoniter.ConfigCompatibleWithStandardLibrary + func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) h.logger.Debugf("get ws message data:%s", data) - jsonStr := bytesutil.ToUnsafeString(data) - action := gjson.Get(jsonStr, "action").String() - args := gjson.Get(jsonStr, "args") - if action == "" { - reqID := getReqID(args) + var request Request + err := jsonIter.Unmarshal(data, &request) + if err != nil { + h.logger.Errorf("unmarshal request error, request:%s, err:%s", data, err) + commonErrorResponse(ctx, session, h.logger, "", 0, 0xffff, "unmarshal request error") + return + } + action := request.Action + if request.Action == "" { + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, "", reqID, 0xffff, "request no action") return } - argsBytes := bytesutil.ToUnsafeBytes(args.Raw) // no need connection actions - switch action { + switch request.Action { case wstool.ClientVersion: wstool.WSWriteVersion(session, h.logger) return case Connect: action = Connect var req connRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal connect request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal connect request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, Connect, reqID, 0xffff, "unmarshal connect request error") return } @@ -212,14 +216,14 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.connect(ctx, session, action, &req, logger, log.IsDebug()) + h.connect(ctx, session, action, req, logger, log.IsDebug()) return } // check connection if h.conn == nil { h.logger.Errorf("server not connected") - reqID := getReqID(args) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "server not connected") return } @@ -230,9 +234,9 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case WSQuery: action = WSQuery var req queryRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal query request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal query request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal query request error") return } @@ -240,13 +244,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.query(ctx, session, action, &req, logger, log.IsDebug()) + h.query(ctx, session, action, req, logger, log.IsDebug()) case WSFetch: action = WSFetch var req fetchRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal fetch request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal fetch request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal fetch request error") return } @@ -254,13 +258,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.fetch(ctx, session, action, &req, logger, log.IsDebug()) + h.fetch(ctx, session, action, req, logger, log.IsDebug()) case WSFetchBlock: action = WSFetchBlock var req fetchBlockRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal fetch block request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal fetch block request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal fetch block request error") return } @@ -268,13 +272,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.fetchBlock(ctx, session, action, &req, logger, log.IsDebug()) + h.fetchBlock(ctx, session, action, req, logger, log.IsDebug()) case WSFreeResult: action = WSFreeResult var req freeResultRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal free result request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal free result request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal free result request error") return } @@ -282,13 +286,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.freeResult(&req, logger) + h.freeResult(req, logger) case WSNumFields: action = WSNumFields var req numFieldsRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal num fields request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal num fields request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal num fields request error") return } @@ -296,14 +300,14 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.numFields(ctx, session, action, &req, logger, log.IsDebug()) + h.numFields(ctx, session, action, req, logger, log.IsDebug()) // schemaless case SchemalessWrite: action = SchemalessWrite var req schemalessWriteRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal schemaless insert request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal schemaless insert request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal schemaless insert request error") return } @@ -311,14 +315,14 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.schemalessWrite(ctx, session, action, &req, logger, log.IsDebug()) + h.schemalessWrite(ctx, session, action, req, logger, log.IsDebug()) // stmt case STMTInit: action = STMTInit var req stmtInitRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt init request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt init request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt init request error") return } @@ -326,13 +330,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtInit(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtInit(ctx, session, action, req, logger, log.IsDebug()) case STMTPrepare: action = STMTPrepare var req stmtPrepareRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt prepare request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt prepare request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt prepare request error") return } @@ -340,13 +344,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtPrepare(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtPrepare(ctx, session, action, req, logger, log.IsDebug()) case STMTSetTableName: action = STMTSetTableName var req stmtSetTableNameRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt set table name request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt set table name request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt set table name request error") return } @@ -354,13 +358,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtSetTableName(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtSetTableName(ctx, session, action, req, logger, log.IsDebug()) case STMTSetTags: action = STMTSetTags var req stmtSetTagsRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt set tags request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt set tags request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt set tags request error") return } @@ -368,13 +372,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtSetTags(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtSetTags(ctx, session, action, req, logger, log.IsDebug()) case STMTBind: action = STMTBind var req stmtBindRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt bind request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt bind request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt bind request error") return } @@ -382,13 +386,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtBind(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtBind(ctx, session, action, req, logger, log.IsDebug()) case STMTAddBatch: action = STMTAddBatch var req stmtAddBatchRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt add batch request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt add batch request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt add batch request error") return } @@ -396,13 +400,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtAddBatch(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtAddBatch(ctx, session, action, req, logger, log.IsDebug()) case STMTExec: action = STMTExec var req stmtExecRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt exec request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt exec request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt exec request error") return } @@ -410,13 +414,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtExec(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtExec(ctx, session, action, req, logger, log.IsDebug()) case STMTClose: action = STMTClose var req stmtCloseRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt close request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt close request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt close request error") return } @@ -424,13 +428,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtClose(ctx, session, action, &req, logger) + h.stmtClose(ctx, session, action, req, logger) case STMTGetTagFields: action = STMTGetTagFields var req stmtGetTagFieldsRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt get tag fields request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt get tag fields request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt get tag fields request error") return } @@ -438,13 +442,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtGetTagFields(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtGetTagFields(ctx, session, action, req, logger, log.IsDebug()) case STMTGetColFields: action = STMTGetColFields var req stmtGetColFieldsRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt get col fields request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt get col fields request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt get col fields request error") return } @@ -452,13 +456,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtGetColFields(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtGetColFields(ctx, session, action, req, logger, log.IsDebug()) case STMTUseResult: action = STMTUseResult var req stmtUseResultRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt use result request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt use result request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt use result request error") return } @@ -466,13 +470,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtUseResult(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtUseResult(ctx, session, action, req, logger, log.IsDebug()) case STMTNumParams: action = STMTNumParams var req stmtNumParamsRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt num params request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt num params request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt num params request error") return } @@ -480,13 +484,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtNumParams(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtNumParams(ctx, session, action, req, logger, log.IsDebug()) case STMTGetParam: action = STMTGetParam var req stmtGetParamRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt get param request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt get param request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt get param request error") return } @@ -494,14 +498,14 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmtGetParam(ctx, session, action, &req, logger, log.IsDebug()) + h.stmtGetParam(ctx, session, action, req, logger, log.IsDebug()) // stmt2 case STMT2Init: action = STMT2Init var req stmt2InitRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt2 init request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 init request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 init request error") return } @@ -509,13 +513,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmt2Init(ctx, session, action, &req, logger, log.IsDebug()) + h.stmt2Init(ctx, session, action, req, logger, log.IsDebug()) case STMT2Prepare: action = STMT2Prepare var req stmt2PrepareRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt2 prepare request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 prepare request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 prepare request error") return } @@ -523,13 +527,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmt2Prepare(ctx, session, action, &req, logger, log.IsDebug()) + h.stmt2Prepare(ctx, session, action, req, logger, log.IsDebug()) case STMT2GetFields: action = STMT2GetFields var req stmt2GetFieldsRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt2 get fields request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 get fields request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 get fields request error") return } @@ -537,13 +541,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmt2GetFields(ctx, session, action, &req, logger, log.IsDebug()) + h.stmt2GetFields(ctx, session, action, req, logger, log.IsDebug()) case STMT2Exec: action = STMT2Exec var req stmt2ExecRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt2 exec request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 exec request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 exec request error") return } @@ -551,13 +555,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmt2Exec(ctx, session, action, &req, logger, log.IsDebug()) + h.stmt2Exec(ctx, session, action, req, logger, log.IsDebug()) case STMT2Result: action = STMT2Result var req stmt2UseResultRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt2 result request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 result request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 result request error") return } @@ -565,13 +569,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmt2UseResult(ctx, session, action, &req, logger, log.IsDebug()) + h.stmt2UseResult(ctx, session, action, req, logger, log.IsDebug()) case STMT2Close: action = STMT2Close var req stmt2CloseRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal stmt2 close request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal stmt2 close request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 close request error") return } @@ -579,14 +583,14 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.stmt2Close(ctx, session, action, &req, logger) + h.stmt2Close(ctx, session, action, req, logger) // misc case WSGetCurrentDB: action = WSGetCurrentDB var req getCurrentDBRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal get current db request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal get current db request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal get current db request error") return } @@ -594,13 +598,13 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.getCurrentDB(ctx, session, action, &req, logger, log.IsDebug()) + h.getCurrentDB(ctx, session, action, req, logger, log.IsDebug()) case WSGetServerInfo: action = WSGetServerInfo var req getServerInfoRequest - if err := json.Unmarshal(argsBytes, &req); err != nil { - h.logger.Errorf("unmarshal get server info request error, request:%s, err:%s", argsBytes, err) - reqID := getReqID(args) + if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal get server info request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal get server info request error") return } @@ -608,10 +612,10 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { actionKey: action, config.ReqIDKey: req.ReqID, }) - h.getServerInfo(ctx, session, action, &req, logger, log.IsDebug()) + h.getServerInfo(ctx, session, action, req, logger, log.IsDebug()) default: h.logger.Errorf("unknown action %s", action) - reqID := getReqID(args) + reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, fmt.Sprintf("unknown action %s", action)) } } @@ -660,6 +664,15 @@ func (h *messageHandler) handleMessageBinary(session *melody.Session, message [] } } -func getReqID(value gjson.Result) uint64 { - return value.Get("req_id").Uint() +func getReqID(value json.RawMessage) uint64 { + return jsonIter.Get(value, "req_id").ToUint64() +} + +type VersionResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + Version string `json:"version"` } diff --git a/controller/ws/ws/handler_test.go b/controller/ws/ws/handler_test.go index f26c7979..01d75474 100644 --- a/controller/ws/ws/handler_test.go +++ b/controller/ws/ws/handler_test.go @@ -45,6 +45,28 @@ func TestDropUser(t *testing.T) { assert.Error(t, err, resp) } +func Test_WrongJsonRequest(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + err = ws.WriteMessage(websocket.TextMessage, []byte("{wrong json}")) + assert.NoError(t, err) + _, message, err := ws.ReadMessage() + assert.NoError(t, err) + var resp commonResp + err = json.Unmarshal(message, &resp) + assert.NoError(t, err) + assert.NotEqual(t, 0, resp.Code) +} + func Test_WrongJsonProtocol(t *testing.T) { s := httptest.NewServer(router) defer s.Close() diff --git a/controller/ws/ws/misc.go b/controller/ws/ws/misc.go index f6e00296..98f78322 100644 --- a/controller/ws/ws/misc.go +++ b/controller/ws/ws/misc.go @@ -23,7 +23,7 @@ type getCurrentDBResponse struct { DB string `json:"db"` } -func (h *messageHandler) getCurrentDB(ctx context.Context, session *melody.Session, action string, req *getCurrentDBRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) getCurrentDB(ctx context.Context, session *melody.Session, action string, req getCurrentDBRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("get current db") db, err := syncinterface.TaosGetCurrentDB(h.conn, logger, isDebug) if err != nil { @@ -54,7 +54,7 @@ type getServerInfoResponse struct { Info string `json:"info"` } -func (h *messageHandler) getServerInfo(ctx context.Context, session *melody.Session, action string, req *getServerInfoRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) getServerInfo(ctx context.Context, session *melody.Session, action string, req getServerInfoRequest, logger *logrus.Entry, isDebug bool) { logger.Trace("get server info") serverInfo := syncinterface.TaosGetServerInfo(h.conn, logger, isDebug) resp := &getServerInfoResponse{ diff --git a/controller/ws/ws/query.go b/controller/ws/ws/query.go index fb905d95..08256725 100644 --- a/controller/ws/ws/query.go +++ b/controller/ws/ws/query.go @@ -29,7 +29,7 @@ type connRequest struct { Mode *int `json:"mode"` } -func (h *messageHandler) connect(ctx context.Context, session *melody.Session, action string, req *connRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) connect(ctx context.Context, session *melody.Session, action string, req connRequest, logger *logrus.Entry, isDebug bool) { h.lock(logger, isDebug) defer h.Unlock() if h.closed { @@ -142,7 +142,7 @@ type queryResponse struct { Precision int `json:"precision"` } -func (h *messageHandler) query(ctx context.Context, session *melody.Session, action string, req *queryRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) query(ctx context.Context, session *melody.Session, action string, req queryRequest, logger *logrus.Entry, isDebug bool) { sqlType := monitor.WSRecordRequest(req.Sql) logger.Debugf("get query request, sql:%s", req.Sql) s := log.GetLogNow(isDebug) @@ -170,7 +170,7 @@ func (h *messageHandler) query(ctx context.Context, session *melody.Session, act affectRows := wrapper.TaosAffectedRows(result.Res) logger.Debugf("affected_rows %d cost:%s", affectRows, log.GetLogDuration(isDebug, s)) syncinterface.FreeResult(result.Res, logger, isDebug) - resp := queryResponse{ + resp := &queryResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -192,7 +192,7 @@ func (h *messageHandler) query(ctx context.Context, session *melody.Session, act queryResult := QueryResult{TaosResult: result.Res, FieldsCount: fieldsCount, Header: rowsHeader, precision: precision} idx := h.queryResults.Add(&queryResult) logger.Trace("add result to list finished") - resp := queryResponse{ + resp := &queryResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), diff --git a/controller/ws/ws/schemaless.go b/controller/ws/ws/schemaless.go index 532cea1c..093464b4 100644 --- a/controller/ws/ws/schemaless.go +++ b/controller/ws/ws/schemaless.go @@ -29,7 +29,7 @@ type schemalessWriteResponse struct { TotalRows int32 `json:"total_rows"` } -func (h *messageHandler) schemalessWrite(ctx context.Context, session *melody.Session, action string, req *schemalessWriteRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) schemalessWrite(ctx context.Context, session *melody.Session, action string, req schemalessWriteRequest, logger *logrus.Entry, isDebug bool) { if req.Protocol == 0 { logger.Error("schemaless write request error. protocol is null") commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, "schemaless write protocol is null") @@ -46,7 +46,7 @@ func (h *messageHandler) schemalessWrite(ctx context.Context, session *melody.Se } affectedRows = wrapper.TaosAffectedRows(result) logger.Tracef("schemaless write total rows:%d, affected rows:%d", totalRows, affectedRows) - resp := schemalessWriteResponse{ + resp := &schemalessWriteResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), diff --git a/controller/ws/ws/stmt.go b/controller/ws/ws/stmt.go index ff7fd552..7387bb71 100644 --- a/controller/ws/ws/stmt.go +++ b/controller/ws/ws/stmt.go @@ -36,7 +36,7 @@ type stmtInitResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) stmtInit(ctx context.Context, session *melody.Session, action string, req *stmtInitRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtInit(ctx context.Context, session *melody.Session, action string, req stmtInitRequest, logger *logrus.Entry, isDebug bool) { stmtInit := syncinterface.TaosStmtInitWithReqID(h.conn, int64(req.ReqID), logger, isDebug) if stmtInit == nil { errStr := wrapper.TaosStmtErrStr(stmtInit) @@ -47,7 +47,7 @@ func (h *messageHandler) stmtInit(ctx context.Context, session *melody.Session, stmtItem := &StmtItem{stmt: stmtInit} h.stmts.Add(stmtItem) logger.Tracef("stmt init sucess, stmt_id:%d, stmt pointer:%p", stmtItem.index, stmtInit) - resp := stmtInitResponse{ + resp := &stmtInitResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -92,7 +92,7 @@ func (h *messageHandler) stmtValidateAndLock(ctx context.Context, session *melod return stmtItem, true } -func (h *messageHandler) stmtPrepare(ctx context.Context, session *melody.Session, action string, req *stmtPrepareRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtPrepare(ctx context.Context, session *melody.Session, action string, req stmtPrepareRequest, logger *logrus.Entry, isDebug bool) { logger.Debugf("stmt prepare, stmt_id:%d, sql:%s", req.StmtID, req.SQL) stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -116,7 +116,7 @@ func (h *messageHandler) stmtPrepare(ctx context.Context, session *melody.Sessio } logger.Tracef("stmt is insert:%t", isInsert) stmtItem.isInsert = isInsert - resp := stmtPrepareResponse{ + resp := &stmtPrepareResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -141,7 +141,7 @@ type stmtSetTableNameResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) stmtSetTableName(ctx context.Context, session *melody.Session, action string, req *stmtSetTableNameRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtSetTableName(ctx context.Context, session *melody.Session, action string, req stmtSetTableNameRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("stmt set table name, stmt_id:%d, name:%s", req.StmtID, req.Name) stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -156,7 +156,7 @@ func (h *messageHandler) stmtSetTableName(ctx context.Context, session *melody.S return } logger.Tracef("stmt set table name success, stmt_id:%d", req.StmtID) - resp := stmtSetTableNameResponse{ + resp := &stmtSetTableNameResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -180,7 +180,7 @@ type stmtSetTagsResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) stmtSetTags(ctx context.Context, session *melody.Session, action string, req *stmtSetTagsRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtSetTags(ctx context.Context, session *melody.Session, action string, req stmtSetTagsRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("stmt set tags, stmt_id:%d, tags:%s", req.StmtID, req.Tags) stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -200,7 +200,7 @@ func (h *messageHandler) stmtSetTags(ctx context.Context, session *melody.Sessio logger.Tracef("stmt tag nums:%d", tagNums) if tagNums == 0 { logger.Trace("no tags") - resp := stmtSetTagsResponse{ + resp := &stmtSetTagsResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -227,7 +227,7 @@ func (h *messageHandler) stmtSetTags(ctx context.Context, session *melody.Sessio return } logger.Trace("stmt set tags success") - resp := stmtSetTagsResponse{ + resp := &stmtSetTagsResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -251,7 +251,7 @@ type stmtBindResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) stmtBind(ctx context.Context, session *melody.Session, action string, req *stmtBindRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtBind(ctx context.Context, session *melody.Session, action string, req stmtBindRequest, logger *logrus.Entry, isDebug bool) { stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { return @@ -269,7 +269,7 @@ func (h *messageHandler) stmtBind(ctx context.Context, session *melody.Session, }() if colNums == 0 { logger.Trace("no columns") - resp := stmtBindResponse{ + resp := &stmtBindResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -307,7 +307,7 @@ func (h *messageHandler) stmtBind(ctx context.Context, session *melody.Session, return } logger.Trace("stmt bind success") - resp := stmtBindResponse{ + resp := &stmtBindResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -330,7 +330,7 @@ type stmtAddBatchResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) stmtAddBatch(ctx context.Context, session *melody.Session, action string, req *stmtAddBatchRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtAddBatch(ctx context.Context, session *melody.Session, action string, req stmtAddBatchRequest, logger *logrus.Entry, isDebug bool) { stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { return @@ -344,7 +344,7 @@ func (h *messageHandler) stmtAddBatch(ctx context.Context, session *melody.Sessi return } logger.Trace("stmt add batch success") - resp := stmtAddBatchResponse{ + resp := &stmtAddBatchResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -368,7 +368,7 @@ type stmtExecResponse struct { Affected int `json:"affected"` } -func (h *messageHandler) stmtExec(ctx context.Context, session *melody.Session, action string, req *stmtExecRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtExec(ctx context.Context, session *melody.Session, action string, req stmtExecRequest, logger *logrus.Entry, isDebug bool) { stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { return @@ -384,7 +384,7 @@ func (h *messageHandler) stmtExec(ctx context.Context, session *melody.Session, s := log.GetLogNow(isDebug) affected := wrapper.TaosStmtAffectedRowsOnce(stmtItem.stmt) logger.Debugf("stmt_affected_rows_once, affected:%d, cost:%s", affected, log.GetLogDuration(isDebug, s)) - resp := stmtExecResponse{ + resp := &stmtExecResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -399,7 +399,7 @@ type stmtCloseRequest struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) stmtClose(ctx context.Context, session *melody.Session, action string, req *stmtCloseRequest, logger *logrus.Entry) { +func (h *messageHandler) stmtClose(ctx context.Context, session *melody.Session, action string, req stmtCloseRequest, logger *logrus.Entry) { logger.Tracef("stmt close, stmt_id:%d", req.StmtID) err := h.stmts.FreeStmtByID(req.StmtID, false, logger) if err != nil { @@ -425,7 +425,7 @@ type stmtGetTagFieldsResponse struct { Fields []*stmtCommon.StmtField `json:"fields,omitempty"` } -func (h *messageHandler) stmtGetTagFields(ctx context.Context, session *melody.Session, action string, req *stmtGetTagFieldsRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtGetTagFields(ctx context.Context, session *melody.Session, action string, req stmtGetTagFieldsRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("stmt get tag fields, stmt_id:%d", req.StmtID) stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -444,7 +444,7 @@ func (h *messageHandler) stmtGetTagFields(ctx context.Context, session *melody.S }() if tagNums == 0 { logger.Trace("no tags") - resp := stmtGetTagFieldsResponse{ + resp := &stmtGetTagFieldsResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -456,7 +456,7 @@ func (h *messageHandler) stmtGetTagFields(ctx context.Context, session *melody.S s := log.GetLogNow(isDebug) fields := wrapper.StmtParseFields(tagNums, tagFields) logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) - resp := stmtGetTagFieldsResponse{ + resp := &stmtGetTagFieldsResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -481,7 +481,7 @@ type stmtGetColFieldsResponse struct { Fields []*stmtCommon.StmtField `json:"fields"` } -func (h *messageHandler) stmtGetColFields(ctx context.Context, session *melody.Session, action string, req *stmtGetColFieldsRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtGetColFields(ctx context.Context, session *melody.Session, action string, req stmtGetColFieldsRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("stmt get col fields, stmt_id:%d", req.StmtID) stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -500,7 +500,7 @@ func (h *messageHandler) stmtGetColFields(ctx context.Context, session *melody.S }() if colNums == 0 { logger.Trace("no columns") - resp := stmtGetColFieldsResponse{ + resp := &stmtGetColFieldsResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -512,7 +512,7 @@ func (h *messageHandler) stmtGetColFields(ctx context.Context, session *melody.S s := log.GetLogNow(isDebug) fields := wrapper.StmtParseFields(colNums, colFields) logger.Debugf("stmt parse fields cost:%s", log.GetLogDuration(isDebug, s)) - resp := stmtGetColFieldsResponse{ + resp := &stmtGetColFieldsResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -542,7 +542,7 @@ type stmtUseResultResponse struct { Precision int `json:"precision"` } -func (h *messageHandler) stmtUseResult(ctx context.Context, session *melody.Session, action string, req *stmtUseResultRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtUseResult(ctx context.Context, session *melody.Session, action string, req stmtUseResultRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("stmt use result, stmt_id:%d", req.StmtID) stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -601,7 +601,7 @@ type stmtNumParamsResponse struct { NumParams int `json:"num_params"` } -func (h *messageHandler) stmtNumParams(ctx context.Context, session *melody.Session, action string, req *stmtNumParamsRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtNumParams(ctx context.Context, session *melody.Session, action string, req stmtNumParamsRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("stmt num params, stmt_id:%d", req.StmtID) stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -616,7 +616,7 @@ func (h *messageHandler) stmtNumParams(ctx context.Context, session *melody.Sess return } logger.Tracef("stmt num params success, stmt_id:%d, num_params:%d", req.StmtID, count) - resp := stmtNumParamsResponse{ + resp := &stmtNumParamsResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -644,7 +644,7 @@ type stmtGetParamResponse struct { Length int `json:"length"` } -func (h *messageHandler) stmtGetParam(ctx context.Context, session *melody.Session, action string, req *stmtGetParamRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmtGetParam(ctx context.Context, session *melody.Session, action string, req stmtGetParamRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("stmt get param, stmt_id:%d, index:%d", req.StmtID, req.Index) stmtItem, locked := h.stmtValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -699,7 +699,7 @@ func (h *messageHandler) stmtBinarySetTags(ctx context.Context, session *melody. }() if tagNums == 0 { logger.Trace("no tags") - resp := stmtSetTagsResponse{ + resp := &stmtSetTagsResponse{ Action: action, ReqID: reqID, Timing: wstool.GetDuration(ctx), @@ -730,7 +730,7 @@ func (h *messageHandler) stmtBinarySetTags(ctx context.Context, session *melody. stmtErrorResponse(ctx, session, logger, action, reqID, code, errStr, stmtID) return } - resp := stmtSetTagsResponse{ + resp := &stmtSetTagsResponse{ Action: action, ReqID: reqID, Timing: wstool.GetDuration(ctx), @@ -765,7 +765,7 @@ func (h *messageHandler) stmtBinaryBind(ctx context.Context, session *melody.Ses }() if colNums == 0 { logger.Trace("no columns") - resp := stmtBindResponse{ + resp := &stmtBindResponse{ Action: action, ReqID: reqID, Timing: wstool.GetDuration(ctx), @@ -818,7 +818,7 @@ func (h *messageHandler) stmtBinaryBind(ctx context.Context, session *melody.Ses return } logger.Trace("stmt bind param success") - resp := stmtBindResponse{ + resp := &stmtBindResponse{ Action: action, ReqID: reqID, Timing: wstool.GetDuration(ctx), diff --git a/controller/ws/ws/stmt2.go b/controller/ws/ws/stmt2.go index 4f6bc689..b06ff7b3 100644 --- a/controller/ws/ws/stmt2.go +++ b/controller/ws/ws/stmt2.go @@ -34,7 +34,7 @@ type stmt2InitResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) stmt2Init(ctx context.Context, session *melody.Session, action string, req *stmt2InitRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmt2Init(ctx context.Context, session *melody.Session, action string, req stmt2InitRequest, logger *logrus.Entry, isDebug bool) { handle, caller := async.GlobalStmt2CallBackCallerPool.Get() stmtInit := syncinterface.TaosStmt2Init(h.conn, int64(req.ReqID), req.SingleStbInsert, req.SingleTableBindOnce, handle, logger, isDebug) if stmtInit == nil { @@ -47,7 +47,7 @@ func (h *messageHandler) stmt2Init(ctx context.Context, session *melody.Session, stmtItem := &StmtItem{stmt: stmtInit, handler: handle, caller: caller, isStmt2: true} h.stmts.Add(stmtItem) logger.Tracef("stmt2 init sucess, stmt_id:%d, stmt pointer:%p", stmtItem.index, stmtInit) - resp := stmt2InitResponse{ + resp := &stmt2InitResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -100,7 +100,7 @@ type stmt2PrepareResponse struct { FieldsCount int `json:"fields_count"` } -func (h *messageHandler) stmt2Prepare(ctx context.Context, session *melody.Session, action string, req *stmt2PrepareRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmt2Prepare(ctx context.Context, session *melody.Session, action string, req stmt2PrepareRequest, logger *logrus.Entry, isDebug bool) { logger.Debugf("stmt2 prepare, stmt_id:%d, sql:%s", req.StmtID, req.SQL) stmtItem, locked := h.stmt2ValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -223,7 +223,7 @@ type stmt2GetFieldsResponse struct { TagFields []*stmtCommon.StmtField `json:"tag_fields"` } -func (h *messageHandler) stmt2GetFields(ctx context.Context, session *melody.Session, action string, req *stmt2GetFieldsRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmt2GetFields(ctx context.Context, session *melody.Session, action string, req stmt2GetFieldsRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("stmt2 get col fields, stmt_id:%d", req.StmtID) stmtItem, locked := h.stmt2ValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -288,7 +288,7 @@ type stmt2ExecResponse struct { Affected int `json:"affected"` } -func (h *messageHandler) stmt2Exec(ctx context.Context, session *melody.Session, action string, req *stmt2ExecRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmt2Exec(ctx context.Context, session *melody.Session, action string, req stmt2ExecRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("stmt2 execute, stmt_id:%d", req.StmtID) stmtItem, locked := h.stmt2ValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -307,7 +307,7 @@ func (h *messageHandler) stmt2Exec(ctx context.Context, session *melody.Session, result := <-stmtItem.caller.ExecResult logger.Debugf("stmt2 execute wait callback finish, affected:%d, res:%p, n:%d, cost:%s", result.Affected, result.Res, result.N, log.GetLogDuration(isDebug, s)) stmtItem.result = result.Res - resp := stmt2ExecResponse{ + resp := &stmt2ExecResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), @@ -337,7 +337,7 @@ type stmt2UseResultResponse struct { Precision int `json:"precision"` } -func (h *messageHandler) stmt2UseResult(ctx context.Context, session *melody.Session, action string, req *stmt2UseResultRequest, logger *logrus.Entry, isDebug bool) { +func (h *messageHandler) stmt2UseResult(ctx context.Context, session *melody.Session, action string, req stmt2UseResultRequest, logger *logrus.Entry, isDebug bool) { logger.Tracef("stmt2 use result, stmt_id:%d", req.StmtID) stmtItem, locked := h.stmt2ValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) if !locked { @@ -380,7 +380,7 @@ type stmt2CloseResponse struct { StmtID uint64 `json:"stmt_id"` } -func (h *messageHandler) stmt2Close(ctx context.Context, session *melody.Session, action string, req *stmt2CloseRequest, logger *logrus.Entry) { +func (h *messageHandler) stmt2Close(ctx context.Context, session *melody.Session, action string, req stmt2CloseRequest, logger *logrus.Entry) { logger.Tracef("stmt2 close, stmt_id:%d", req.StmtID) err := h.stmts.FreeStmtByID(req.StmtID, true, logger) if err != nil { @@ -389,7 +389,7 @@ func (h *messageHandler) stmt2Close(ctx context.Context, session *melody.Session return } logger.Tracef("stmt2 close success, stmt_id:%d", req.StmtID) - resp := stmt2CloseResponse{ + resp := &stmt2CloseResponse{ Action: action, ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), diff --git a/controller/ws/wstool/resp.go b/controller/ws/wstool/resp.go index 9f4772cd..e6f4a4a7 100644 --- a/controller/ws/wstool/resp.go +++ b/controller/ws/wstool/resp.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/json" + jsoniter "github.com/json-iterator/go" "github.com/sirupsen/logrus" "github.com/taosdata/taosadapter/v3/tools/melody" "github.com/taosdata/taosadapter/v3/version" @@ -17,8 +18,10 @@ type TDEngineRestfulResp struct { Rows int `json:"rows,omitempty"` } +var jsonIter = jsoniter.ConfigCompatibleWithStandardLibrary + func WSWriteJson(session *melody.Session, logger *logrus.Entry, data interface{}) { - b, err := json.Marshal(data) + b, err := jsonIter.Marshal(data) if err != nil { logger.Errorf("marshal json failed:%s, data:%#v", err, data) return @@ -44,8 +47,8 @@ type WSVersionResp struct { var VersionResp []byte func WSWriteVersion(session *melody.Session, logger *logrus.Entry) { - logger.Tracef("write version") - _ = session.WriteBinary(VersionResp) + logger.Tracef("write version,%s", VersionResp) + _ = session.Write(VersionResp) logger.Trace("write version done") } @@ -59,5 +62,5 @@ func init() { Action: ClientVersion, Version: version.TaosClientVersion, } - VersionResp, _ = json.Marshal(resp) + VersionResp, _ = jsonIter.Marshal(resp) } From 8f35374fef8075f02e52063666fe9bd0959fe7a7 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Mon, 25 Nov 2024 15:24:48 +0800 Subject: [PATCH 14/48] enh: use std json --- controller/ws/ws/handler.go | 62 +++++++++++++++++------------------ controller/ws/ws/stmt_test.go | 4 ++- controller/ws/wstool/resp.go | 7 ++-- 3 files changed, 36 insertions(+), 37 deletions(-) diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index e43abed7..ad319761 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -179,13 +179,11 @@ func (h *messageHandler) stop() { }) } -var jsonIter = jsoniter.ConfigCompatibleWithStandardLibrary - func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano()) h.logger.Debugf("get ws message data:%s", data) var request Request - err := jsonIter.Unmarshal(data, &request) + err := json.Unmarshal(data, &request) if err != nil { h.logger.Errorf("unmarshal request error, request:%s, err:%s", data, err) commonErrorResponse(ctx, session, h.logger, "", 0, 0xffff, "unmarshal request error") @@ -206,7 +204,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case Connect: action = Connect var req connRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal connect request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, Connect, reqID, 0xffff, "unmarshal connect request error") @@ -234,7 +232,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case WSQuery: action = WSQuery var req queryRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal query request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal query request error") @@ -248,7 +246,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case WSFetch: action = WSFetch var req fetchRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal fetch request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal fetch request error") @@ -262,7 +260,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case WSFetchBlock: action = WSFetchBlock var req fetchBlockRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal fetch block request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal fetch block request error") @@ -276,7 +274,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case WSFreeResult: action = WSFreeResult var req freeResultRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal free result request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal free result request error") @@ -290,7 +288,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case WSNumFields: action = WSNumFields var req numFieldsRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal num fields request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal num fields request error") @@ -305,7 +303,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case SchemalessWrite: action = SchemalessWrite var req schemalessWriteRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal schemaless insert request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal schemaless insert request error") @@ -320,7 +318,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTInit: action = STMTInit var req stmtInitRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt init request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt init request error") @@ -334,7 +332,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTPrepare: action = STMTPrepare var req stmtPrepareRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt prepare request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt prepare request error") @@ -348,7 +346,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTSetTableName: action = STMTSetTableName var req stmtSetTableNameRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt set table name request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt set table name request error") @@ -362,7 +360,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTSetTags: action = STMTSetTags var req stmtSetTagsRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt set tags request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt set tags request error") @@ -376,7 +374,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTBind: action = STMTBind var req stmtBindRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt bind request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt bind request error") @@ -390,7 +388,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTAddBatch: action = STMTAddBatch var req stmtAddBatchRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt add batch request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt add batch request error") @@ -404,7 +402,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTExec: action = STMTExec var req stmtExecRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt exec request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt exec request error") @@ -418,7 +416,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTClose: action = STMTClose var req stmtCloseRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt close request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt close request error") @@ -432,7 +430,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTGetTagFields: action = STMTGetTagFields var req stmtGetTagFieldsRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt get tag fields request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt get tag fields request error") @@ -446,7 +444,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTGetColFields: action = STMTGetColFields var req stmtGetColFieldsRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt get col fields request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt get col fields request error") @@ -460,7 +458,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTUseResult: action = STMTUseResult var req stmtUseResultRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt use result request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt use result request error") @@ -474,7 +472,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTNumParams: action = STMTNumParams var req stmtNumParamsRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt num params request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt num params request error") @@ -488,7 +486,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMTGetParam: action = STMTGetParam var req stmtGetParamRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt get param request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt get param request error") @@ -503,7 +501,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMT2Init: action = STMT2Init var req stmt2InitRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt2 init request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 init request error") @@ -517,7 +515,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMT2Prepare: action = STMT2Prepare var req stmt2PrepareRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt2 prepare request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 prepare request error") @@ -531,7 +529,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMT2GetFields: action = STMT2GetFields var req stmt2GetFieldsRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt2 get fields request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 get fields request error") @@ -545,7 +543,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMT2Exec: action = STMT2Exec var req stmt2ExecRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt2 exec request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 exec request error") @@ -559,7 +557,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMT2Result: action = STMT2Result var req stmt2UseResultRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt2 result request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 result request error") @@ -573,7 +571,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case STMT2Close: action = STMT2Close var req stmt2CloseRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal stmt2 close request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 close request error") @@ -588,7 +586,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case WSGetCurrentDB: action = WSGetCurrentDB var req getCurrentDBRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal get current db request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal get current db request error") @@ -602,7 +600,7 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case WSGetServerInfo: action = WSGetServerInfo var req getServerInfoRequest - if err := jsonIter.Unmarshal(request.Args, &req); err != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { h.logger.Errorf("unmarshal get server info request error, request:%s, err:%s", request.Args, err) reqID := getReqID(request.Args) commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal get server info request error") @@ -664,6 +662,8 @@ func (h *messageHandler) handleMessageBinary(session *melody.Session, message [] } } +var jsonIter = jsoniter.ConfigCompatibleWithStandardLibrary + func getReqID(value json.RawMessage) uint64 { return jsonIter.Get(value, "req_id").ToUint64() } diff --git a/controller/ws/ws/stmt_test.go b/controller/ws/ws/stmt_test.go index 2687cba5..8c752814 100644 --- a/controller/ws/ws/stmt_test.go +++ b/controller/ws/ws/stmt_test.go @@ -317,11 +317,13 @@ func TestWsStmt(t *testing.T) { // block message // init - resp, err = doWebSocket(ws, STMTInit, nil) + initReq = map[string]uint64{"req_id": 0x11} + resp, err = doWebSocket(ws, STMTInit, &initReq) assert.NoError(t, err) err = json.Unmarshal(resp, &initResp) assert.NoError(t, err) assert.Equal(t, 0, initResp.Code, initResp.Message) + assert.Equal(t, uint64(0x11), initResp.ReqID) // prepare prepareReq = stmtPrepareRequest{StmtID: initResp.StmtID, SQL: "insert into ? using test_ws_stmt_ws.stb tags(?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"} diff --git a/controller/ws/wstool/resp.go b/controller/ws/wstool/resp.go index e6f4a4a7..245bdd5e 100644 --- a/controller/ws/wstool/resp.go +++ b/controller/ws/wstool/resp.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/json" - jsoniter "github.com/json-iterator/go" "github.com/sirupsen/logrus" "github.com/taosdata/taosadapter/v3/tools/melody" "github.com/taosdata/taosadapter/v3/version" @@ -18,10 +17,8 @@ type TDEngineRestfulResp struct { Rows int `json:"rows,omitempty"` } -var jsonIter = jsoniter.ConfigCompatibleWithStandardLibrary - func WSWriteJson(session *melody.Session, logger *logrus.Entry, data interface{}) { - b, err := jsonIter.Marshal(data) + b, err := json.Marshal(data) if err != nil { logger.Errorf("marshal json failed:%s, data:%#v", err, data) return @@ -62,5 +59,5 @@ func init() { Action: ClientVersion, Version: version.TaosClientVersion, } - VersionResp, _ = jsonIter.Marshal(resp) + VersionResp, _ = json.Marshal(resp) } From 7bd82f57aed298bb5627ac861989499a397a4b90 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 28 Nov 2024 18:06:01 +0800 Subject: [PATCH 15/48] enh: avoid goroutine leaks when subscription fails --- controller/ws/tmq/tmq.go | 66 ++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/controller/ws/tmq/tmq.go b/controller/ws/tmq/tmq.go index b9656ac1..beba3711 100644 --- a/controller/ws/tmq/tmq.go +++ b/controller/ws/tmq/tmq.go @@ -531,12 +531,36 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu wsTMQErrorMsg(ctx, session, logger, 0xffff, err.Error(), action, req.ReqID, nil) return } + topicList := wrapper.TMQListNew() + defer func() { + wrapper.TMQListDestroy(topicList) + }() + for _, topic := range req.Topics { + logger.Tracef("tmq append topic:%s", topic) + errCode = wrapper.TMQListAppend(topicList, topic) + if errCode != 0 { + errStr := wrapper.TMQErr2Str(errCode) + logger.Errorf("tmq list append error, tpic:%s, code:%d, msg:%s", topic, errCode, errStr) + t.closeConsumerWithErrLog(logger, isDebug, cPointer) + wsTMQErrorMsg(ctx, session, logger, int(errCode), errStr, action, req.ReqID, nil) + return + } + } + + errCode = t.wrapperSubscribe(logger, isDebug, cPointer, topicList) + if errCode != 0 { + errStr := wrapper.TMQErr2Str(errCode) + logger.Errorf("tmq subscribe error:%d %s", errCode, errStr) + t.closeConsumerWithErrLog(logger, isDebug, cPointer) + wsTMQErrorMsg(ctx, session, logger, int(errCode), errStr, action, req.ReqID, nil) + return + } conn := wrapper.TMQGetConnect(cPointer) logger.Trace("get whitelist") whitelist, err := tool.GetWhitelist(conn) if err != nil { logger.Errorf("get whitelist error:%s", err.Error()) - t.wrapperCloseConsumer(logger, isDebug, cPointer) + t.closeConsumerWithErrLog(logger, isDebug, cPointer) wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } @@ -544,7 +568,7 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu valid := tool.CheckWhitelist(whitelist, t.ip) if !valid { logger.Errorf("whitelist prohibits current IP access, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) - t.wrapperCloseConsumer(logger, isDebug, cPointer) + t.closeConsumerWithErrLog(logger, isDebug, cPointer) wstool.WSErrorMsg(ctx, session, logger, 0xffff, "whitelist prohibits current IP access", action, req.ReqID) return } @@ -552,7 +576,7 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu err = tool.RegisterChangeWhitelist(conn, t.whitelistChangeHandle) if err != nil { logger.Errorf("register change whitelist error:%s", err) - t.wrapperCloseConsumer(logger, isDebug, cPointer) + t.closeConsumerWithErrLog(logger, isDebug, cPointer) wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } @@ -560,38 +584,14 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu err = tool.RegisterDropUser(conn, t.dropUserHandle) if err != nil { logger.Errorf("register drop user error:%s", err) - t.wrapperCloseConsumer(logger, isDebug, cPointer) + t.closeConsumerWithErrLog(logger, isDebug, cPointer) wstool.WSError(ctx, session, logger, err, action, req.ReqID) return } t.conn = conn + t.consumer = cPointer logger.Trace("start to wait signal") go t.waitSignal(t.logger) - topicList := wrapper.TMQListNew() - defer func() { - wrapper.TMQListDestroy(topicList) - }() - for _, topic := range req.Topics { - logger.Tracef("tmq append topic:%s", topic) - errCode = wrapper.TMQListAppend(topicList, topic) - if errCode != 0 { - errStr := wrapper.TMQErr2Str(errCode) - logger.Errorf("tmq list append error, tpic:%s, code:%d, msg:%s", topic, errCode, errStr) - t.wrapperCloseConsumer(logger, isDebug, cPointer) - wsTMQErrorMsg(ctx, session, logger, int(errCode), errStr, action, req.ReqID, nil) - return - } - } - - errCode = t.wrapperSubscribe(logger, isDebug, cPointer, topicList) - if errCode != 0 { - errStr := wrapper.TMQErr2Str(errCode) - logger.Errorf("tmq subscribe error:%d %s", errCode, errStr) - t.wrapperCloseConsumer(logger, isDebug, cPointer) - wsTMQErrorMsg(ctx, session, logger, int(errCode), errStr, action, req.ReqID, nil) - return - } - t.consumer = cPointer wstool.WSWriteJson(session, logger, &TMQSubscribeResp{ Action: action, ReqID: req.ReqID, @@ -599,6 +599,14 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu }) } +func (t *TMQ) closeConsumerWithErrLog(logger *logrus.Entry, isDebug bool, consumer unsafe.Pointer) { + errCode := t.wrapperCloseConsumer(logger, isDebug, consumer) + if errCode != 0 { + errMsg := wrapper.TMQErr2Str(errCode) + logger.Errorf("tmq close consumer error, consumer:%p, code:%d, msg:%s", t.consumer, errCode, errMsg) + } +} + type TMQCommitReq struct { ReqID uint64 `json:"req_id"` MessageID uint64 `json:"message_id"` // unused From f96f248ddafdb40ce70c646833bb76df3b67f01d Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 28 Nov 2024 18:14:12 +0800 Subject: [PATCH 16/48] enh: remove unnecessary session check --- controller/ws/stmt/stmt.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/controller/ws/stmt/stmt.go b/controller/ws/stmt/stmt.go index 2b14bd6c..201dbc45 100644 --- a/controller/ws/stmt/stmt.go +++ b/controller/ws/stmt/stmt.go @@ -172,9 +172,6 @@ func NewSTMTController() *STMTController { }) stmtM.HandleMessageBinary(func(session *melody.Session, data []byte) { - if session.IsClosed() { - return - } t := session.MustGet(TaosStmtKey).(*TaosStmt) if t.closed { return From 2a49c4c7eed8808ae82bc16d5b0ad91395b1eaea Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 28 Nov 2024 19:14:53 +0800 Subject: [PATCH 17/48] fix: influxdb test use same db --- schemaless/capi/influxdb_test.go | 10 +++++----- schemaless/capi/opentsdb_test.go | 30 +++++++++++++++--------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/schemaless/capi/influxdb_test.go b/schemaless/capi/influxdb_test.go index 73635d0b..85e20634 100644 --- a/schemaless/capi/influxdb_test.go +++ b/schemaless/capi/influxdb_test.go @@ -21,7 +21,7 @@ func TestInsertInfluxdb(t *testing.T) { } defer wrapper.TaosClose(conn) defer func() { - r := wrapper.TaosQuery(conn, "drop database if exists test_capi") + r := wrapper.TaosQuery(conn, "drop database if exists test_capi_influxdb") code := wrapper.TaosError(r) if code != 0 { errStr := wrapper.TaosErrorStr(r) @@ -29,7 +29,7 @@ func TestInsertInfluxdb(t *testing.T) { } wrapper.TaosFreeResult(r) }() - r := wrapper.TaosQuery(conn, "create database if not exists test_capi") + r := wrapper.TaosQuery(conn, "create database if not exists test_capi_influxdb") code := wrapper.TaosError(r) if code != 0 { errStr := wrapper.TaosErrorStr(r) @@ -53,7 +53,7 @@ func TestInsertInfluxdb(t *testing.T) { args: args{ taosConnect: conn, data: []byte("measurement,host=host1 field1=2i,field2=2.0,fieldKey=\"Launch 🚀\" 1577836800000000001"), - db: "test_capi", + db: "test_capi_influxdb", ttl: 0, }, wantErr: false, @@ -62,7 +62,7 @@ func TestInsertInfluxdb(t *testing.T) { args: args{ taosConnect: conn, data: []byte("wrong,host=host1 field1=wrong 1577836800000000001"), - db: "test_capi", + db: "test_capi_influxdb", ttl: 100, }, wantErr: true, @@ -71,7 +71,7 @@ func TestInsertInfluxdb(t *testing.T) { args: args{ taosConnect: conn, data: []byte("wrong,host=host1 field1=wrong 1577836800000000001"), - db: "1'test_capi", + db: "1'test_capi_influxdb", ttl: 1000, }, wantErr: true, diff --git a/schemaless/capi/opentsdb_test.go b/schemaless/capi/opentsdb_test.go index b19e0024..18a83325 100644 --- a/schemaless/capi/opentsdb_test.go +++ b/schemaless/capi/opentsdb_test.go @@ -31,7 +31,7 @@ func TestInsertOpentsdbTelnet(t *testing.T) { } defer wrapper.TaosClose(conn) defer func() { - r := wrapper.TaosQuery(conn, "drop database if exists test_capi") + r := wrapper.TaosQuery(conn, "drop database if exists test_capi_opentsdb") code := wrapper.TaosError(r) if code != 0 { errStr := wrapper.TaosErrorStr(r) @@ -39,7 +39,7 @@ func TestInsertOpentsdbTelnet(t *testing.T) { } wrapper.TaosFreeResult(r) }() - r := wrapper.TaosQuery(conn, "create database if not exists test_capi") + r := wrapper.TaosQuery(conn, "create database if not exists test_capi_opentsdb") code := wrapper.TaosError(r) if code != 0 { errStr := wrapper.TaosErrorStr(r) @@ -62,7 +62,7 @@ func TestInsertOpentsdbTelnet(t *testing.T) { args: args{ taosConnect: conn, data: "df.data.df_complex.used 1636539620 21393473536 fqdn=vm130 status=production", - db: "test_capi", + db: "test_capi_opentsdb", ttl: 100, }, wantErr: false, @@ -71,7 +71,7 @@ func TestInsertOpentsdbTelnet(t *testing.T) { args: args{ taosConnect: conn, data: "df.data.df_complex.used 163653962000 21393473536 fqdn=vm130 status=production", - db: "test_capi", + db: "test_capi_opentsdb", }, wantErr: true, }, { @@ -79,7 +79,7 @@ func TestInsertOpentsdbTelnet(t *testing.T) { args: args{ taosConnect: conn, data: "df.data.df_complex.used 1636539620 21393473536 fqdn=vm130 status=production", - db: "1'test_capi", + db: "1'test_capi_opentsdb", ttl: 1000, }, wantErr: true, @@ -88,7 +88,7 @@ func TestInsertOpentsdbTelnet(t *testing.T) { args: args{ taosConnect: conn, data: "", - db: "test_capi", + db: "test_capi_opentsdb", ttl: 1000, }, wantErr: false, @@ -144,7 +144,7 @@ func TestInsertOpentsdbJson(t *testing.T) { now := time.Now().Unix() defer wrapper.TaosClose(conn) defer func() { - r := wrapper.TaosQuery(conn, "drop database if exists test_capi") + r := wrapper.TaosQuery(conn, "drop database if exists test_capi_opentsdb_json") code := wrapper.TaosError(r) if code != 0 { errStr := wrapper.TaosErrorStr(r) @@ -152,7 +152,7 @@ func TestInsertOpentsdbJson(t *testing.T) { } wrapper.TaosFreeResult(r) }() - r := wrapper.TaosQuery(conn, "create database if not exists test_capi") + r := wrapper.TaosQuery(conn, "create database if not exists test_capi_opentsdb_json") code := wrapper.TaosError(r) if code != 0 { errStr := wrapper.TaosErrorStr(r) @@ -184,7 +184,7 @@ func TestInsertOpentsdbJson(t *testing.T) { "dc": "lga" } }`, now)), - db: "test_capi", + db: "test_capi_opentsdb_json", ttl: 100, }, wantErr: false, @@ -202,7 +202,7 @@ func TestInsertOpentsdbJson(t *testing.T) { "dc": "lga" } }`, now)), - db: "test_capi", + db: "test_capi_opentsdb_json", ttl: 0, }, wantErr: false, @@ -220,7 +220,7 @@ func TestInsertOpentsdbJson(t *testing.T) { "dc": "lga" } }`), - db: "1'test_capi", + db: "1'test_capi_opentsdb_json", }, wantErr: true, }, { @@ -228,7 +228,7 @@ func TestInsertOpentsdbJson(t *testing.T) { args: args{ taosConnect: conn, data: nil, - db: "test_capi", + db: "test_capi_opentsdb_json", ttl: 1000, }, wantErr: false, @@ -252,7 +252,7 @@ func TestInsertOpentsdbTelnetBatch(t *testing.T) { } defer wrapper.TaosClose(conn) defer func() { - r := wrapper.TaosQuery(conn, "drop database if exists test_capi") + r := wrapper.TaosQuery(conn, "drop database if exists test_capi_opentsdb_batch") code := wrapper.TaosError(r) if code != 0 { errStr := wrapper.TaosErrorStr(r) @@ -260,7 +260,7 @@ func TestInsertOpentsdbTelnetBatch(t *testing.T) { } wrapper.TaosFreeResult(r) }() - r := wrapper.TaosQuery(conn, "create database if not exists test_capi") + r := wrapper.TaosQuery(conn, "create database if not exists test_capi_opentsdb_batch") code := wrapper.TaosError(r) if code != 0 { errStr := wrapper.TaosErrorStr(r) @@ -286,7 +286,7 @@ func TestInsertOpentsdbTelnetBatch(t *testing.T) { "df.data.df_complex.used 1636539620 21393473536 fqdn=vm130 status=production", "df.data.df_complex.used 1636539621 21393473536 fqdn=vm129 status=production", }, - db: "test_capi", + db: "test_capi_opentsdb_batch", ttl: 100, }, wantErr: false, From 7b6b0628ad05044be0e4f3e64912231f4721019c Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 28 Nov 2024 19:18:12 +0800 Subject: [PATCH 18/48] fix: stmt test use same db --- controller/ws/stmt/stmt_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/controller/ws/stmt/stmt_test.go b/controller/ws/stmt/stmt_test.go index c73d4968..e2b53b28 100644 --- a/controller/ws/stmt/stmt_test.go +++ b/controller/ws/stmt/stmt_test.go @@ -536,14 +536,14 @@ func TestSTMT(t *testing.T) { func TestBlock(t *testing.T) { w := httptest.NewRecorder() - body := strings.NewReader("drop database if exists test_ws_stmt") + body := strings.NewReader("drop database if exists test_ws_stmt_block") req, _ := http.NewRequest(http.MethodPost, "/rest/sql", body) req.RemoteAddr = "127.0.0.1:33333" req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) w = httptest.NewRecorder() - body = strings.NewReader("create database if not exists test_ws_stmt precision 'ns'") + body = strings.NewReader("create database if not exists test_ws_stmt_block precision 'ns'") req, _ = http.NewRequest(http.MethodPost, "/rest/sql", body) req.RemoteAddr = "127.0.0.1:33333" req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") @@ -567,7 +567,7 @@ func TestBlock(t *testing.T) { "c14 varchar(20)," + "c15 geometry(100)" + ") tags(info json)") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_stmt", body) + req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_stmt_block", body) req.RemoteAddr = "127.0.0.1:33333" req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") router.ServeHTTP(w, req) @@ -781,7 +781,7 @@ func TestBlock(t *testing.T) { b, _ := json.Marshal(&StmtPrepareReq{ ReqID: 3, StmtID: stmtID, - SQL: "insert into ? using test_ws_stmt.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + SQL: "insert into ? using test_ws_stmt_block.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", }) action, _ := json.Marshal(&wstool.WSAction{ Action: STMTPrepare, @@ -808,7 +808,7 @@ func TestBlock(t *testing.T) { b, _ := json.Marshal(&StmtSetTableNameReq{ ReqID: 4, StmtID: stmtID, - Name: "test_ws_stmt.ctb", + Name: "test_ws_stmt_block.ctb", }) action, _ := json.Marshal(&wstool.WSAction{ Action: STMTSetTableName, @@ -1032,7 +1032,7 @@ func TestBlock(t *testing.T) { assert.NoError(t, err) w = httptest.NewRecorder() body = strings.NewReader("select * from stb") - req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_stmt", body) + req, _ = http.NewRequest(http.MethodPost, "/rest/sql/test_ws_stmt_block", body) req.RemoteAddr = "127.0.0.1:33333" req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") router.ServeHTTP(w, req) @@ -1040,7 +1040,7 @@ func TestBlock(t *testing.T) { resultBody := fmt.Sprintf(`{"code":0,"column_meta":[["ts","TIMESTAMP",8],["c1","BOOL",1],["c2","TINYINT",1],["c3","SMALLINT",2],["c4","INT",4],["c5","BIGINT",8],["c6","TINYINT UNSIGNED",1],["c7","SMALLINT UNSIGNED",2],["c8","INT UNSIGNED",4],["c9","BIGINT UNSIGNED",8],["c10","FLOAT",4],["c11","DOUBLE",8],["c12","VARCHAR",20],["c13","NCHAR",20],["c14","VARCHAR",20],["c15","GEOMETRY",100],["info","JSON",4095]],"data":[["%s",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","test_varbinary","010100000000000000000059400000000000005940",{"a":"b"}],["%s",false,22,33,44,55,66,77,88,99,1010,1111,"binary2","nchar2","test_varbinary2","010100000000000000000059400000000000005940",{"a":"b"}],["%s",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":"b"}]],"rows":3}`, now.UTC().Format(layout.LayoutNanoSecond), now.Add(time.Second).UTC().Format(layout.LayoutNanoSecond), now.Add(time.Second*2).UTC().Format(layout.LayoutNanoSecond)) assert.Equal(t, resultBody, w.Body.String()) w = httptest.NewRecorder() - body = strings.NewReader("drop database if exists test_ws_stmt") + body = strings.NewReader("drop database if exists test_ws_stmt_block") req, _ = http.NewRequest(http.MethodPost, "/rest/sql", body) req.RemoteAddr = "127.0.0.1:33333" req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") From 1dd516e1ac0ddd161237ed60af2ce5075d0ba150 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 28 Nov 2024 19:31:07 +0800 Subject: [PATCH 19/48] ci: upgrade go version and artifact version --- .github/workflows/linux.yml | 6 +++--- .github/workflows/macos.yml | 6 +++--- .github/workflows/windows.yml | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 462d3944..7b7a05ed 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -126,7 +126,7 @@ jobs: needs: build strategy: matrix: - go: [ '1.17', '1.20' ] + go: [ '1.17', 'stable' ] name: Build taosAdapter ${{ matrix.go }} steps: - name: get cache server by pr @@ -183,7 +183,7 @@ jobs: needs: build strategy: matrix: - go: [ '1.17', '1.20' ] + go: [ '1.17', 'stable' ] name: test taosAdapter ${{ matrix.go }} steps: - name: get cache server by pr @@ -255,7 +255,7 @@ jobs: env: CODECOV_TOKEN: ${{ secrets.CODECOV_ORG_TOKEN }} - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 if: always() && (steps.test.outcome == 'failure' || steps.test.outcome == 'cancelled') with: name: ${{ runner.os }}-${{ matrix.go }}-log diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 8575f28c..8d3bf220 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -122,7 +122,7 @@ jobs: needs: build strategy: matrix: - go: [ '1.17', '1.20' ] + go: [ '1.17', 'stable' ] name: Build taosAdapter ${{ matrix.go }} steps: - name: get cache server by pr @@ -183,7 +183,7 @@ jobs: needs: build strategy: matrix: - go: [ '1.17', '1.20' ] + go: [ '1.17', 'stable' ] name: test taosAdapter ${{ matrix.go }} steps: - name: get cache server by pr @@ -256,7 +256,7 @@ jobs: DYLD_LIBRARY_PATH: /usr/local/lib:$DYLD_LIBRARY_PATH run: cd ./taosadapter && sudo DYLD_LIBRARY_PATH=/usr/local/lib:$DYLD_LIBRARY_PATH go test -v --count=1 ./... - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 if: always() && (steps.test.outcome == 'failure' || steps.test.outcome == 'cancelled') with: name: ${{ runner.os }}-${{ matrix.go }}-log diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 932442e0..f91ffe88 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -106,7 +106,7 @@ jobs: runs-on: windows-2022 strategy: matrix: - go: [ '1.20' ] + go: [ 'stable' ] name: Go ${{ matrix.go }} steps: - name: get cache server by pr From 020b8acdc317656123997c914ed55fa2a46b0e98 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 29 Nov 2024 14:25:10 +0800 Subject: [PATCH 20/48] test: use file log in unit test --- controller/ws/query/ws_test.go | 2 +- controller/ws/schemaless/schemaless_test.go | 2 +- controller/ws/stmt/stmt_test.go | 2 +- controller/ws/tmq/tmq_test.go | 2 +- controller/ws/ws/ws_test.go | 2 +- controller/ws/wstool/error_test.go | 4 ++-- controller/ws/wstool/log_test.go | 6 +++--- controller/ws/wstool/main_test.go | 17 +++++++++++++++++ controller/ws/wstool/resp_test.go | 4 ++-- db/async/row_test.go | 2 +- db/syncinterface/wrapper_test.go | 1 + db/tool/createdb_test.go | 9 +++++---- plugin/influxdb/plugin_test.go | 4 +++- plugin/nodeexporter/plugin_test.go | 2 ++ plugin/opentsdb/plugin_test.go | 2 ++ plugin/opentsdbtelnet/plugin_test.go | 2 ++ plugin/prometheus/plugin_test.go | 2 ++ plugin/statsd/plugin_test.go | 2 ++ schemaless/capi/influxdb_test.go | 4 ++-- schemaless/capi/opentsdb_test.go | 16 ++++++++++------ 20 files changed, 61 insertions(+), 26 deletions(-) create mode 100644 controller/ws/wstool/main_test.go diff --git a/controller/ws/query/ws_test.go b/controller/ws/query/ws_test.go index 9bd0de8d..baa99ef6 100644 --- a/controller/ws/query/ws_test.go +++ b/controller/ws/query/ws_test.go @@ -33,8 +33,8 @@ func TestMain(m *testing.M) { viper.Set("logLevel", "trace") viper.Set("uploadKeeper.enable", false) config.Init() - db.PrepareConnection() log.ConfigLog() + db.PrepareConnection() gin.SetMode(gin.ReleaseMode) router = gin.New() controllers := controller.GetControllers() diff --git a/controller/ws/schemaless/schemaless_test.go b/controller/ws/schemaless/schemaless_test.go index a6f4fe9e..c2f4e25b 100644 --- a/controller/ws/schemaless/schemaless_test.go +++ b/controller/ws/schemaless/schemaless_test.go @@ -31,8 +31,8 @@ func TestMain(m *testing.M) { viper.Set("logLevel", "trace") viper.Set("uploadKeeper.enable", false) config.Init() - db.PrepareConnection() log.ConfigLog() + db.PrepareConnection() gin.SetMode(gin.ReleaseMode) router = gin.New() controllers := controller.GetControllers() diff --git a/controller/ws/stmt/stmt_test.go b/controller/ws/stmt/stmt_test.go index e2b53b28..445a2fb8 100644 --- a/controller/ws/stmt/stmt_test.go +++ b/controller/ws/stmt/stmt_test.go @@ -34,8 +34,8 @@ func TestMain(m *testing.M) { viper.Set("logLevel", "trace") viper.Set("uploadKeeper.enable", false) config.Init() - db.PrepareConnection() log.ConfigLog() + db.PrepareConnection() gin.SetMode(gin.ReleaseMode) router = gin.New() controllers := controller.GetControllers() diff --git a/controller/ws/tmq/tmq_test.go b/controller/ws/tmq/tmq_test.go index 1a7c208c..4e50128b 100644 --- a/controller/ws/tmq/tmq_test.go +++ b/controller/ws/tmq/tmq_test.go @@ -42,8 +42,8 @@ func TestMain(m *testing.M) { viper.Set("logLevel", "trace") viper.Set("uploadKeeper.enable", false) config.Init() - db.PrepareConnection() log.ConfigLog() + db.PrepareConnection() gin.SetMode(gin.ReleaseMode) router = gin.New() controllers := controller.GetControllers() diff --git a/controller/ws/ws/ws_test.go b/controller/ws/ws/ws_test.go index 54023fb7..cf67435e 100644 --- a/controller/ws/ws/ws_test.go +++ b/controller/ws/ws/ws_test.go @@ -31,8 +31,8 @@ func TestMain(m *testing.M) { viper.Set("logLevel", "trace") viper.Set("uploadKeeper.enable", false) config.Init() - db.PrepareConnection() log.ConfigLog() + db.PrepareConnection() gin.SetMode(gin.ReleaseMode) router = gin.New() controllers := controller.GetControllers() diff --git a/controller/ws/wstool/error_test.go b/controller/ws/wstool/error_test.go index 9cd75648..fd85f964 100644 --- a/controller/ws/wstool/error_test.go +++ b/controller/ws/wstool/error_test.go @@ -11,9 +11,9 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" tErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/melody" ) @@ -27,7 +27,7 @@ func TestWSError(t *testing.T) { ErrStr: "test error", } commonErr := errors.New("test common error") - logger := logrus.New().WithField("test", "TestWSError") + logger := log.GetLogger("test").WithField("test", "TestWSError") m.HandleMessage(func(session *melody.Session, data []byte) { switch data[0] { case '1': diff --git a/controller/ws/wstool/log_test.go b/controller/ws/wstool/log_test.go index 6b0ccdb3..09d527d9 100644 --- a/controller/ws/wstool/log_test.go +++ b/controller/ws/wstool/log_test.go @@ -7,8 +7,8 @@ import ( "time" "github.com/gorilla/websocket" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/melody" ) @@ -21,7 +21,7 @@ func TestGetDuration(t *testing.T) { } func TestGetLogger(t *testing.T) { - logger := logrus.New() + logger := log.GetLogger("test") session := &melody.Session{} session.Set("logger", logger.WithField("test_field", "test_value")) entry := GetLogger(session) @@ -29,7 +29,7 @@ func TestGetLogger(t *testing.T) { } func TestLogWSError(t *testing.T) { - logger := logrus.New() + logger := log.GetLogger("test") session := &melody.Session{} session.Set("logger", logger.WithField("test_field", "test_value")) LogWSError(session, nil) diff --git a/controller/ws/wstool/main_test.go b/controller/ws/wstool/main_test.go new file mode 100644 index 00000000..5fc3352c --- /dev/null +++ b/controller/ws/wstool/main_test.go @@ -0,0 +1,17 @@ +package wstool + +import ( + "testing" + + "github.com/spf13/viper" + "github.com/taosdata/taosadapter/v3/config" + "github.com/taosdata/taosadapter/v3/log" +) + +func TestMain(m *testing.M) { + viper.Set("logLevel", "trace") + viper.Set("uploadKeeper.enable", false) + config.Init() + log.ConfigLog() + m.Run() +} diff --git a/controller/ws/wstool/resp_test.go b/controller/ws/wstool/resp_test.go index 1680b1e1..60e1eca4 100644 --- a/controller/ws/wstool/resp_test.go +++ b/controller/ws/wstool/resp_test.go @@ -8,8 +8,8 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/melody" ) @@ -23,7 +23,7 @@ func TestWSWriteJson(t *testing.T) { Version: "1.0.0", } m.HandleMessage(func(session *melody.Session, _ []byte) { - logger := logrus.New().WithField("test", "TestWSWriteJson") + logger := log.GetLogger("test").WithField("test", "TestWSWriteJson") session.Set("logger", logger) WSWriteJson(session, logger, data) }) diff --git a/db/async/row_test.go b/db/async/row_test.go index 65789004..548a2864 100644 --- a/db/async/row_test.go +++ b/db/async/row_test.go @@ -98,7 +98,7 @@ func TestAsync_TaosExec(t *testing.T) { wantErr: true, }, } - var logger = logrus.New().WithField("test", "async_test") + var logger = log.GetLogger("test").WithField("test", "async_test") for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Async{ diff --git a/db/syncinterface/wrapper_test.go b/db/syncinterface/wrapper_test.go index c37a3b5c..a22ca7e0 100644 --- a/db/syncinterface/wrapper_test.go +++ b/db/syncinterface/wrapper_test.go @@ -26,6 +26,7 @@ const isDebug = true func TestMain(m *testing.M) { config.Init() + log.ConfigLog() _ = log.SetLevel("trace") db.PrepareConnection() m.Run() diff --git a/db/tool/createdb_test.go b/db/tool/createdb_test.go index 6ff2002d..f2465768 100644 --- a/db/tool/createdb_test.go +++ b/db/tool/createdb_test.go @@ -5,18 +5,20 @@ import ( "testing" "unsafe" - "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/stretchr/testify/assert" tErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/log" ) func TestMain(m *testing.M) { viper.Set("smlAutoCreateDB", true) + viper.Set("logLevel", "trace") config.Init() + log.ConfigLog() db.PrepareConnection() m.Run() viper.Set("smlAutoCreateDB", false) @@ -26,7 +28,6 @@ func TestMain(m *testing.M) { // @date: 2021/12/14 15:05 // @description: test creat database with connection func TestCreateDBWithConnection(t *testing.T) { - db.PrepareConnection() conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { t.Error(err) @@ -52,7 +53,7 @@ func TestCreateDBWithConnection(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := CreateDBWithConnection(tt.args.connection, logrus.New().WithField("test", "TestCreateDBWithConnection"), false, tt.args.db, 0); (err != nil) != tt.wantErr { + if err := CreateDBWithConnection(tt.args.connection, log.GetLogger("test").WithField("test", "TestCreateDBWithConnection"), false, tt.args.db, 0); (err != nil) != tt.wantErr { t.Errorf("CreateDBWithConnection() error = %v, wantErr %v", err, tt.wantErr) } code := wrapper.TaosSelectDB(tt.args.connection, tt.args.db) @@ -107,7 +108,7 @@ func TestSchemalessSelectDB(t *testing.T) { return } wrapper.TaosFreeResult(result) - if err := SchemalessSelectDB(tt.args.taosConnect, logrus.New().WithField("test", "TestSchemalessSelectDB"), false, tt.args.db, 0); (err != nil) != tt.wantErr { + if err := SchemalessSelectDB(tt.args.taosConnect, log.GetLogger("test").WithField("test", "TestSchemalessSelectDB"), false, tt.args.db, 0); (err != nil) != tt.wantErr { t.Errorf("selectDB() error = %v, wantErr %v", err, tt.wantErr) } r := wrapper.TaosQuery(tt.args.taosConnect, fmt.Sprintf("drop database if exists %s", tt.args.db)) diff --git a/plugin/influxdb/plugin_test.go b/plugin/influxdb/plugin_test.go index 6b0ce01a..2c19cddc 100644 --- a/plugin/influxdb/plugin_test.go +++ b/plugin/influxdb/plugin_test.go @@ -18,6 +18,7 @@ import ( "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/log" ) // @author: xftan @@ -28,8 +29,9 @@ func TestInfluxdb(t *testing.T) { rand.Seed(time.Now().UnixNano()) viper.Set("smlAutoCreateDB", true) defer viper.Set("smlAutoCreateDB", false) - config.Init() viper.Set("influxdb.enable", true) + config.Init() + log.ConfigLog() db.PrepareConnection() p := Influxdb{} router := gin.Default() diff --git a/plugin/nodeexporter/plugin_test.go b/plugin/nodeexporter/plugin_test.go index 71bb0a60..c6340d87 100644 --- a/plugin/nodeexporter/plugin_test.go +++ b/plugin/nodeexporter/plugin_test.go @@ -12,6 +12,7 @@ import ( "github.com/taosdata/driver-go/v3/af" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/log" ) var s = ` @@ -37,6 +38,7 @@ test_metric{label="value"} 1.0 1490802350000 // @description: test node-exporter plugin func TestNodeExporter_Gather(t *testing.T) { config.Init() + log.ConfigLog() db.PrepareConnection() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte(s)) diff --git a/plugin/opentsdb/plugin_test.go b/plugin/opentsdb/plugin_test.go index 10a4d4c5..22451901 100644 --- a/plugin/opentsdb/plugin_test.go +++ b/plugin/opentsdb/plugin_test.go @@ -18,6 +18,7 @@ import ( "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/log" ) // @author: xftan @@ -30,6 +31,7 @@ func TestOpentsdb(t *testing.T) { defer viper.Set("smlAutoCreateDB", false) config.Init() viper.Set("opentsdb.enable", true) + log.ConfigLog() db.PrepareConnection() p := Plugin{} diff --git a/plugin/opentsdbtelnet/plugin_test.go b/plugin/opentsdbtelnet/plugin_test.go index 6a8ffd10..d2c8b3f4 100644 --- a/plugin/opentsdbtelnet/plugin_test.go +++ b/plugin/opentsdbtelnet/plugin_test.go @@ -15,6 +15,7 @@ import ( "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/plugin/opentsdbtelnet" ) @@ -26,6 +27,7 @@ func TestPlugin(t *testing.T) { rand.Seed(time.Now().UnixNano()) p := &opentsdbtelnet.Plugin{} config.Init() + log.ConfigLog() db.PrepareConnection() viper.Set("opentsdb_telnet.enable", true) viper.Set("opentsdb_telnet.batchSize", 1) diff --git a/plugin/prometheus/plugin_test.go b/plugin/prometheus/plugin_test.go index 7b77e1ff..76409f49 100644 --- a/plugin/prometheus/plugin_test.go +++ b/plugin/prometheus/plugin_test.go @@ -17,6 +17,7 @@ import ( "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/log" ) func TestMain(m *testing.M) { @@ -24,6 +25,7 @@ func TestMain(m *testing.M) { rand.Seed(time.Now().UnixNano()) config.Init() viper.Set("prometheus.enable", true) + log.ConfigLog() db.PrepareConnection() conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { diff --git a/plugin/statsd/plugin_test.go b/plugin/statsd/plugin_test.go index 06a03807..14b3aee5 100644 --- a/plugin/statsd/plugin_test.go +++ b/plugin/statsd/plugin_test.go @@ -15,6 +15,7 @@ import ( "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/log" ) // @author: xftan @@ -25,6 +26,7 @@ func TestStatsd(t *testing.T) { rand.Seed(time.Now().UnixNano()) p := &Plugin{} config.Init() + log.ConfigLog() db.PrepareConnection() viper.Set("statsd.gatherInterval", time.Millisecond) viper.Set("statsd.enable", true) diff --git a/schemaless/capi/influxdb_test.go b/schemaless/capi/influxdb_test.go index 85e20634..45527e6f 100644 --- a/schemaless/capi/influxdb_test.go +++ b/schemaless/capi/influxdb_test.go @@ -4,9 +4,9 @@ import ( "testing" "unsafe" - "github.com/sirupsen/logrus" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" + "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/schemaless/capi" ) @@ -79,7 +79,7 @@ func TestInsertInfluxdb(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - logger := logrus.New().WithField("test", "TestInsertInfluxdb") + logger := log.GetLogger("test").WithField("test", "TestInsertInfluxdb").WithField("name", tt.name) err := capi.InsertInfluxdb(tt.args.taosConnect, tt.args.data, tt.args.db, tt.args.precision, tt.args.ttl, 0, "", logger) if (err != nil) != tt.wantErr { t.Errorf("InsertInfluxdb() error = %v, wantErr %v", err, tt.wantErr) diff --git a/schemaless/capi/opentsdb_test.go b/schemaless/capi/opentsdb_test.go index 18a83325..c6e52e92 100644 --- a/schemaless/capi/opentsdb_test.go +++ b/schemaless/capi/opentsdb_test.go @@ -6,16 +6,20 @@ import ( "time" "unsafe" - "github.com/sirupsen/logrus" + "github.com/spf13/viper" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/schemaless/capi" ) func TestMain(m *testing.M) { + viper.Set("logLevel", "trace") + viper.Set("uploadKeeper.enable", false) config.Init() + log.ConfigLog() db.PrepareConnection() m.Run() } @@ -94,10 +98,10 @@ func TestInsertOpentsdbTelnet(t *testing.T) { wantErr: false, }, } - logger := logrus.New().WithField("test", "TestInsertOpentsdbTelnet") + logger := log.GetLogger("test").WithField("test", "TestInsertOpentsdbTelnet") for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := capi.InsertOpentsdbTelnet(tt.args.taosConnect, []string{tt.args.data}, tt.args.db, tt.args.ttl, 0, "", logger); (err != nil) != tt.wantErr { + if err := capi.InsertOpentsdbTelnet(tt.args.taosConnect, []string{tt.args.data}, tt.args.db, tt.args.ttl, 0, "", logger.WithField("name", tt.name)); (err != nil) != tt.wantErr { t.Errorf("InsertOpentsdbTelnet() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -120,7 +124,7 @@ func BenchmarkTelnet(b *testing.B) { } wrapper.TaosFreeResult(r) }() - logger := logrus.New().WithField("test", "BenchmarkTelnet") + logger := log.GetLogger("test").WithField("test", "BenchmarkTelnet") for i := 0; i < b.N; i++ { //`sys.if.bytes.out`,`host`=web01,`interface`=eth0 //t_98df8453856519710bfc2f1b5f8202cf @@ -234,7 +238,7 @@ func TestInsertOpentsdbJson(t *testing.T) { wantErr: false, }, } - logger := logrus.New().WithField("test", "TestInsertOpentsdbJson") + logger := log.GetLogger("test").WithField("test", "TestInsertOpentsdbJson") for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := capi.InsertOpentsdbJson(tt.args.taosConnect, tt.args.data, tt.args.db, tt.args.ttl, 0, "", logger); (err != nil) != tt.wantErr { @@ -292,7 +296,7 @@ func TestInsertOpentsdbTelnetBatch(t *testing.T) { wantErr: false, }, } - logger := logrus.New().WithField("test", "TestInsertOpentsdbTelnetBatch") + logger := log.GetLogger("test").WithField("test", "TestInsertOpentsdbTelnetBatch") for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := capi.InsertOpentsdbTelnet(tt.args.taosConnect, tt.args.data, tt.args.db, tt.args.ttl, 0, "", logger); (err != nil) != tt.wantErr { From 7b550cacde61d103466269e2c9167a3826e313fa Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Mon, 2 Dec 2024 14:13:10 +0800 Subject: [PATCH 21/48] refactor: remove driver-go dependency --- controller/rest/configcontroller.go | 4 +- controller/rest/restful.go | 8 +- controller/rest/table_vgid.go | 4 +- controller/ws/query/ws.go | 6 +- controller/ws/schemaless/schemaless.go | 6 +- controller/ws/schemaless/schemaless_test.go | 10 +- controller/ws/stmt/convert.go | 8 +- controller/ws/stmt/convert_test.go | 6 +- controller/ws/stmt/stmt.go | 12 +- controller/ws/tmq/tmq.go | 10 +- controller/ws/tmq/tmq_test.go | 4 +- controller/ws/ws/fetch.go | 4 +- controller/ws/ws/handler.go | 2 +- controller/ws/ws/misc.go | 2 +- controller/ws/ws/query.go | 6 +- controller/ws/ws/query_result.go | 4 +- controller/ws/ws/query_test.go | 2 +- controller/ws/ws/raw.go | 4 +- controller/ws/ws/schemaless.go | 2 +- controller/ws/ws/schemaless_test.go | 14 +- controller/ws/ws/stmt.go | 12 +- controller/ws/ws/stmt2.go | 6 +- controller/ws/ws/stmt2_test.go | 4 +- controller/ws/ws/stmt_test.go | 10 +- controller/ws/wstool/error.go | 2 +- controller/ws/wstool/error_test.go | 2 +- db/async/handlerpool.go | 2 +- db/async/row.go | 4 +- db/async/row_test.go | 2 +- db/async/stmt2pool.go | 2 +- db/asynctmq/tmq.go | 2 +- db/asynctmq/tmq_windows.go | 2 +- db/asynctmq/tmqcb.go | 2 +- db/asynctmq/tmqhandle/handler.go | 2 +- db/commonpool/pool.go | 6 +- db/commonpool/pool_test.go | 2 +- db/init.go | 6 +- db/syncinterface/wrapper.go | 6 +- db/syncinterface/wrapper_test.go | 14 +- db/tool/createdb.go | 4 +- db/tool/createdb_test.go | 4 +- db/tool/notify.go | 8 +- db/tool/notify_test.go | 6 +- driver/common/change.go | 34 + driver/common/change_test.go | 101 + driver/common/column.go | 46 + driver/common/const.go | 73 + driver/common/datatype.go | 89 + driver/common/param/column.go | 220 + driver/common/param/column_test.go | 435 ++ driver/common/param/param.go | 337 ++ driver/common/param/param_test.go | 654 +++ driver/common/parser/block.go | 374 ++ driver/common/parser/block_test.go | 797 +++ driver/common/parser/mem.go | 12 + driver/common/parser/mem.s | 0 driver/common/parser/mem_test.go | 20 + driver/common/parser/raw.go | 184 + driver/common/parser/raw_test.go | 1049 ++++ driver/common/serializer/block.go | 552 ++ driver/common/serializer/block_test.go | 397 ++ driver/common/stmt/field.go | 73 + driver/common/stmt/field_test.go | 143 + driver/common/stmt/stmt2.go | 580 +++ driver/common/stmt/stmt2_test.go | 2437 +++++++++ driver/common/tmq/config.go | 34 + driver/common/tmq/config_test.go | 52 + driver/common/tmq/event.go | 204 + driver/common/tmq/event_test.go | 352 ++ driver/common/tmq/tmq.go | 87 + driver/common/tmq/tmq_test.go | 197 + driver/errors/errors.go | 30 + driver/errors/errors_test.go | 52 + driver/types/taostype.go | 54 + driver/types/types.go | 492 ++ driver/types/types_test.go | 2122 ++++++++ driver/wrapper/asynccb.go | 38 + driver/wrapper/block.go | 49 + driver/wrapper/block_test.go | 924 ++++ driver/wrapper/cgo/README.md | 1 + driver/wrapper/cgo/handle.go | 81 + driver/wrapper/cgo/handle_test.go | 107 + driver/wrapper/field.go | 67 + driver/wrapper/field_test.go | 483 ++ driver/wrapper/notify.go | 22 + driver/wrapper/notify_test.go | 100 + driver/wrapper/notifycb.go | 36 + driver/wrapper/row.go | 77 + driver/wrapper/row_test.go | 628 +++ driver/wrapper/schemaless.go | 233 + driver/wrapper/schemaless_test.go | 744 +++ driver/wrapper/setconfig.go | 42 + driver/wrapper/setconfig_test.go | 41 + driver/wrapper/stmt.go | 756 +++ driver/wrapper/stmt2.go | 857 ++++ driver/wrapper/stmt2_test.go | 5076 +++++++++++++++++++ driver/wrapper/stmt2async.go | 26 + driver/wrapper/stmt_test.go | 1367 +++++ driver/wrapper/taosc.go | 289 ++ driver/wrapper/taosc_test.go | 607 +++ driver/wrapper/tmq.go | 334 ++ driver/wrapper/tmq_test.go | 2012 ++++++++ driver/wrapper/tmqcb.go | 49 + driver/wrapper/whitelist.go | 29 + driver/wrapper/whitelist_test.go | 21 + driver/wrapper/whitelistcb.go | 35 + driver/wrapper/whitelistcb_test.go | 57 + go.mod | 1 - go.sum | 2 - plugin/collectd/config.go | 2 +- plugin/collectd/plugin_test.go | 82 +- plugin/influxdb/plugin.go | 2 +- plugin/influxdb/plugin_test.go | 104 +- plugin/nodeexporter/config.go | 2 +- plugin/nodeexporter/plugin.go | 2 +- plugin/nodeexporter/plugin_test.go | 91 +- plugin/opentsdb/plugin.go | 2 +- plugin/opentsdb/plugin_test.go | 129 +- plugin/opentsdbtelnet/config.go | 2 +- plugin/opentsdbtelnet/plugin_test.go | 85 +- plugin/prometheus/plugin.go | 2 +- plugin/prometheus/plugin_test.go | 2 +- plugin/prometheus/process.go | 6 +- plugin/statsd/config.go | 2 +- plugin/statsd/plugin_test.go | 85 +- schemaless/capi/influxdb.go | 4 +- schemaless/capi/influxdb_test.go | 4 +- schemaless/capi/opentsdb.go | 4 +- schemaless/capi/opentsdb_test.go | 4 +- tools/ctools/block.go | 4 +- tools/ctools/block_test.go | 4 +- tools/parseblock/parse.go | 2 +- tools/parseblock/parse_test.go | 2 +- version/version.go | 2 +- 134 files changed, 27933 insertions(+), 399 deletions(-) create mode 100644 driver/common/change.go create mode 100644 driver/common/change_test.go create mode 100644 driver/common/column.go create mode 100644 driver/common/const.go create mode 100644 driver/common/datatype.go create mode 100644 driver/common/param/column.go create mode 100644 driver/common/param/column_test.go create mode 100644 driver/common/param/param.go create mode 100644 driver/common/param/param_test.go create mode 100644 driver/common/parser/block.go create mode 100644 driver/common/parser/block_test.go create mode 100644 driver/common/parser/mem.go create mode 100644 driver/common/parser/mem.s create mode 100644 driver/common/parser/mem_test.go create mode 100644 driver/common/parser/raw.go create mode 100644 driver/common/parser/raw_test.go create mode 100644 driver/common/serializer/block.go create mode 100644 driver/common/serializer/block_test.go create mode 100644 driver/common/stmt/field.go create mode 100644 driver/common/stmt/field_test.go create mode 100644 driver/common/stmt/stmt2.go create mode 100644 driver/common/stmt/stmt2_test.go create mode 100644 driver/common/tmq/config.go create mode 100644 driver/common/tmq/config_test.go create mode 100644 driver/common/tmq/event.go create mode 100644 driver/common/tmq/event_test.go create mode 100644 driver/common/tmq/tmq.go create mode 100644 driver/common/tmq/tmq_test.go create mode 100644 driver/errors/errors.go create mode 100644 driver/errors/errors_test.go create mode 100644 driver/types/taostype.go create mode 100644 driver/types/types.go create mode 100644 driver/types/types_test.go create mode 100644 driver/wrapper/asynccb.go create mode 100644 driver/wrapper/block.go create mode 100644 driver/wrapper/block_test.go create mode 100644 driver/wrapper/cgo/README.md create mode 100644 driver/wrapper/cgo/handle.go create mode 100644 driver/wrapper/cgo/handle_test.go create mode 100644 driver/wrapper/field.go create mode 100644 driver/wrapper/field_test.go create mode 100644 driver/wrapper/notify.go create mode 100644 driver/wrapper/notify_test.go create mode 100644 driver/wrapper/notifycb.go create mode 100644 driver/wrapper/row.go create mode 100644 driver/wrapper/row_test.go create mode 100644 driver/wrapper/schemaless.go create mode 100644 driver/wrapper/schemaless_test.go create mode 100644 driver/wrapper/setconfig.go create mode 100644 driver/wrapper/setconfig_test.go create mode 100644 driver/wrapper/stmt.go create mode 100644 driver/wrapper/stmt2.go create mode 100644 driver/wrapper/stmt2_test.go create mode 100644 driver/wrapper/stmt2async.go create mode 100644 driver/wrapper/stmt_test.go create mode 100644 driver/wrapper/taosc.go create mode 100644 driver/wrapper/taosc_test.go create mode 100644 driver/wrapper/tmq.go create mode 100644 driver/wrapper/tmq_test.go create mode 100644 driver/wrapper/tmqcb.go create mode 100644 driver/wrapper/whitelist.go create mode 100644 driver/wrapper/whitelist_test.go create mode 100644 driver/wrapper/whitelistcb.go create mode 100644 driver/wrapper/whitelistcb_test.go diff --git a/controller/rest/configcontroller.go b/controller/rest/configcontroller.go index e5490213..a87e5d7e 100644 --- a/controller/rest/configcontroller.go +++ b/controller/rest/configcontroller.go @@ -7,11 +7,11 @@ import ( "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" - taoserrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/db/commonpool" "github.com/taosdata/taosadapter/v3/db/tool" + taoserrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/iptool" ) diff --git a/controller/rest/restful.go b/controller/rest/restful.go index d7667e52..8b56bd32 100644 --- a/controller/rest/restful.go +++ b/controller/rest/restful.go @@ -13,14 +13,14 @@ import ( "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/parser" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/commonpool" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" diff --git a/controller/rest/table_vgid.go b/controller/rest/table_vgid.go index 42f135fb..6ea6dc91 100644 --- a/controller/rest/table_vgid.go +++ b/controller/rest/table_vgid.go @@ -6,10 +6,10 @@ import ( "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/db/commonpool" "github.com/taosdata/taosadapter/v3/db/syncinterface" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/connectpool" "github.com/taosdata/taosadapter/v3/tools/iptool" diff --git a/controller/ws/query/ws.go b/controller/ws/query/ws.go index f1e2ea04..ab1d4f9f 100644 --- a/controller/ws/query/ws.go +++ b/controller/ws/query/ws.go @@ -16,14 +16,14 @@ import ( "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/common/parser" - "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/tool" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" diff --git a/controller/ws/schemaless/schemaless.go b/controller/ws/schemaless/schemaless.go index 26cdcb99..6129a75c 100644 --- a/controller/ws/schemaless/schemaless.go +++ b/controller/ws/schemaless/schemaless.go @@ -10,14 +10,14 @@ import ( "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/generator" "github.com/taosdata/taosadapter/v3/tools/iptool" diff --git a/controller/ws/schemaless/schemaless_test.go b/controller/ws/schemaless/schemaless_test.go index c2f4e25b..1e54c2ae 100644 --- a/controller/ws/schemaless/schemaless_test.go +++ b/controller/ws/schemaless/schemaless_test.go @@ -14,12 +14,12 @@ import ( "github.com/gorilla/websocket" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/ws/schemaless" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller" _ "github.com/taosdata/taosadapter/v3/controller/rest" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" ) @@ -68,7 +68,7 @@ func TestRestful_InitSchemaless(t *testing.T) { }{ { name: "influxdb", - protocol: schemaless.InfluxDBLineProtocol, + protocol: wrapper.InfluxDBLineProtocol, precision: "ms", data: "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + @@ -81,7 +81,7 @@ func TestRestful_InitSchemaless(t *testing.T) { }, { name: "opentsdb_telnet", - protocol: schemaless.OpenTSDBTelnetLineProtocol, + protocol: wrapper.OpenTSDBTelnetLineProtocol, precision: "ms", data: "meters.current 1648432611249 10.3 location=California.SanFrancisco group=2\n" + "meters.current 1648432611250 12.6 location=California.SanFrancisco group=2\n" + @@ -98,7 +98,7 @@ func TestRestful_InitSchemaless(t *testing.T) { }, { name: "opentsdb_json", - protocol: schemaless.OpenTSDBJsonFormatProtocol, + protocol: wrapper.OpenTSDBJsonFormatProtocol, precision: "ms", data: `[ { @@ -202,7 +202,7 @@ func TestRestful_InitSchemaless(t *testing.T) { assert.NoError(t, err, string(msg)) assert.Equal(t, reqID, schemalessResp.ReqID) assert.Equal(t, 0, schemalessResp.Code, schemalessResp.Message) - if c.protocol != schemaless.OpenTSDBJsonFormatProtocol { + if c.protocol != wrapper.OpenTSDBJsonFormatProtocol { assert.Equal(t, c.totalRows, schemalessResp.TotalRows) } assert.Equal(t, c.affectedRows, schemalessResp.AffectedRows) diff --git a/controller/ws/stmt/convert.go b/controller/ws/stmt/convert.go index 17537dab..7791c186 100644 --- a/controller/ws/stmt/convert.go +++ b/controller/ws/stmt/convert.go @@ -10,10 +10,10 @@ import ( "unsafe" jsoniter "github.com/json-iterator/go" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/parser" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" - "github.com/taosdata/driver-go/v3/types" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + stmtCommon "github.com/taosdata/taosadapter/v3/driver/common/stmt" + "github.com/taosdata/taosadapter/v3/driver/types" "github.com/taosdata/taosadapter/v3/tools" ) diff --git a/controller/ws/stmt/convert_test.go b/controller/ws/stmt/convert_test.go index 9d07e170..a6fd9721 100644 --- a/controller/ws/stmt/convert_test.go +++ b/controller/ws/stmt/convert_test.go @@ -9,9 +9,9 @@ import ( "unsafe" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/common" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" - "github.com/taosdata/driver-go/v3/types" + "github.com/taosdata/taosadapter/v3/driver/common" + stmtCommon "github.com/taosdata/taosadapter/v3/driver/common/stmt" + "github.com/taosdata/taosadapter/v3/driver/types" ) func Test_stmtParseColumn(t *testing.T) { diff --git a/controller/ws/stmt/stmt.go b/controller/ws/stmt/stmt.go index 201dbc45..2df4857e 100644 --- a/controller/ws/stmt/stmt.go +++ b/controller/ws/stmt/stmt.go @@ -14,17 +14,17 @@ import ( "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/common/parser" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/types" - "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + stmtCommon "github.com/taosdata/taosadapter/v3/driver/common/stmt" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/types" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools" diff --git a/controller/ws/tmq/tmq.go b/controller/ws/tmq/tmq.go index beba3711..cdbcb18e 100644 --- a/controller/ws/tmq/tmq.go +++ b/controller/ws/tmq/tmq.go @@ -13,11 +13,6 @@ import ( "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/parser" - taoserrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" @@ -25,6 +20,11 @@ import ( "github.com/taosdata/taosadapter/v3/db/asynctmq/tmqhandle" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + taoserrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/thread" diff --git a/controller/ws/tmq/tmq_test.go b/controller/ws/tmq/tmq_test.go index 4e50128b..69fabf70 100644 --- a/controller/ws/tmq/tmq_test.go +++ b/controller/ws/tmq/tmq_test.go @@ -22,14 +22,14 @@ import ( jsoniter "github.com/json-iterator/go" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/common/parser" - "github.com/taosdata/driver-go/v3/common/tmq" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller" _ "github.com/taosdata/taosadapter/v3/controller/rest" "github.com/taosdata/taosadapter/v3/controller/ws/query" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/common/tmq" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/parseblock" ) diff --git a/controller/ws/ws/fetch.go b/controller/ws/ws/fetch.go index 9f52a0cf..dd63301a 100644 --- a/controller/ws/ws/fetch.go +++ b/controller/ws/ws/fetch.go @@ -5,10 +5,10 @@ import ( "encoding/binary" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/common/parser" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/async" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/bytesutil" "github.com/taosdata/taosadapter/v3/tools/melody" diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index ad319761..4e758b08 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -11,11 +11,11 @@ import ( jsoniter "github.com/json-iterator/go" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools" "github.com/taosdata/taosadapter/v3/tools/iptool" diff --git a/controller/ws/ws/misc.go b/controller/ws/ws/misc.go index 98f78322..da4e6342 100644 --- a/controller/ws/ws/misc.go +++ b/controller/ws/ws/misc.go @@ -4,9 +4,9 @@ import ( "context" "github.com/sirupsen/logrus" - errors2 "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/syncinterface" + errors2 "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/tools/melody" ) diff --git a/controller/ws/ws/query.go b/controller/ws/ws/query.go index 08256725..e0679d74 100644 --- a/controller/ws/ws/query.go +++ b/controller/ws/ws/query.go @@ -7,13 +7,13 @@ import ( "fmt" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/common" - errors2 "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + "github.com/taosdata/taosadapter/v3/driver/common" + errors2 "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" "github.com/taosdata/taosadapter/v3/tools/bytesutil" diff --git a/controller/ws/ws/query_result.go b/controller/ws/ws/query_result.go index 27028489..9a1a8d24 100644 --- a/controller/ws/ws/query_result.go +++ b/controller/ws/ws/query_result.go @@ -8,10 +8,10 @@ import ( "unsafe" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/log" ) diff --git a/controller/ws/ws/query_test.go b/controller/ws/ws/query_test.go index 825ca8d4..93c58bb4 100644 --- a/controller/ws/ws/query_test.go +++ b/controller/ws/ws/query_test.go @@ -14,8 +14,8 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/common/parser" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/driver/common/parser" "github.com/taosdata/taosadapter/v3/tools/parseblock" ) diff --git a/controller/ws/ws/raw.go b/controller/ws/ws/raw.go index 89111c02..a2045196 100644 --- a/controller/ws/ws/raw.go +++ b/controller/ws/ws/raw.go @@ -5,9 +5,9 @@ import ( "unsafe" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/common/parser" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/tools" "github.com/taosdata/taosadapter/v3/tools/melody" ) diff --git a/controller/ws/ws/schemaless.go b/controller/ws/ws/schemaless.go index 093464b4..4c3f9090 100644 --- a/controller/ws/ws/schemaless.go +++ b/controller/ws/ws/schemaless.go @@ -4,9 +4,9 @@ import ( "context" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/tools/melody" ) diff --git a/controller/ws/ws/schemaless_test.go b/controller/ws/ws/schemaless_test.go index e925a2d6..e6c04a2c 100644 --- a/controller/ws/ws/schemaless_test.go +++ b/controller/ws/ws/schemaless_test.go @@ -8,7 +8,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/ws/schemaless" + "github.com/taosdata/taosadapter/v3/driver/wrapper" ) func TestWsSchemaless(t *testing.T) { @@ -43,7 +43,7 @@ func TestWsSchemaless(t *testing.T) { }{ { name: "influxdb", - protocol: schemaless.InfluxDBLineProtocol, + protocol: wrapper.InfluxDBLineProtocol, precision: "ms", data: "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + @@ -55,7 +55,7 @@ func TestWsSchemaless(t *testing.T) { }, { name: "opentsdb_telnet", - protocol: schemaless.OpenTSDBTelnetLineProtocol, + protocol: wrapper.OpenTSDBTelnetLineProtocol, precision: "ms", data: "meters.current 1648432611249 10.3 location=California.SanFrancisco group=2\n" + "meters.current 1648432611250 12.6 location=California.SanFrancisco group=2\n" + @@ -71,7 +71,7 @@ func TestWsSchemaless(t *testing.T) { }, { name: "opentsdb_json", - protocol: schemaless.OpenTSDBJsonFormatProtocol, + protocol: wrapper.OpenTSDBJsonFormatProtocol, precision: "ms", data: `[ { @@ -116,7 +116,7 @@ func TestWsSchemaless(t *testing.T) { }, { name: "influxdb_tbnamekey", - protocol: schemaless.InfluxDBLineProtocol, + protocol: wrapper.InfluxDBLineProtocol, precision: "ms", data: "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + @@ -158,7 +158,7 @@ func TestWsSchemaless(t *testing.T) { assert.NoError(t, err, string(resp)) assert.Equal(t, reqID, schemalessResp.ReqID) assert.Equal(t, 0, schemalessResp.Code, schemalessResp.Message) - if c.protocol != schemaless.OpenTSDBJsonFormatProtocol { + if c.protocol != wrapper.OpenTSDBJsonFormatProtocol { assert.Equal(t, c.totalRows, schemalessResp.TotalRows) } assert.Equal(t, c.affectedRows, schemalessResp.AffectedRows) @@ -203,7 +203,7 @@ func TestWsSchemalessError(t *testing.T) { }, { name: "wrong timestamp", - protocol: schemaless.InfluxDBLineProtocol, + protocol: wrapper.InfluxDBLineProtocol, precision: "ms", data: "measurement,host=host1 field1=2i,field2=2.0 10", }, diff --git a/controller/ws/ws/stmt.go b/controller/ws/ws/stmt.go index 7387bb71..58488c74 100644 --- a/controller/ws/ws/stmt.go +++ b/controller/ws/ws/stmt.go @@ -8,15 +8,15 @@ import ( "unsafe" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/parser" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" - errors2 "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/types" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/controller/ws/stmt" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + stmtCommon "github.com/taosdata/taosadapter/v3/driver/common/stmt" + errors2 "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/types" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools" "github.com/taosdata/taosadapter/v3/tools/jsontype" diff --git a/controller/ws/ws/stmt2.go b/controller/ws/ws/stmt2.go index b06ff7b3..f5648547 100644 --- a/controller/ws/ws/stmt2.go +++ b/controller/ws/ws/stmt2.go @@ -8,12 +8,12 @@ import ( "unsafe" "github.com/sirupsen/logrus" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" - errors2 "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/syncinterface" + stmtCommon "github.com/taosdata/taosadapter/v3/driver/common/stmt" + errors2 "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/jsontype" "github.com/taosdata/taosadapter/v3/tools/melody" diff --git a/controller/ws/ws/stmt2_test.go b/controller/ws/ws/stmt2_test.go index e018ae3c..cd83fde6 100644 --- a/controller/ws/ws/stmt2_test.go +++ b/controller/ws/ws/stmt2_test.go @@ -13,9 +13,9 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/common" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/driver/common" + stmtCommon "github.com/taosdata/taosadapter/v3/driver/common/stmt" "github.com/taosdata/taosadapter/v3/tools/parseblock" ) diff --git a/controller/ws/ws/stmt_test.go b/controller/ws/ws/stmt_test.go index 8c752814..fcef5483 100644 --- a/controller/ws/ws/stmt_test.go +++ b/controller/ws/ws/stmt_test.go @@ -14,12 +14,12 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/param" - "github.com/taosdata/driver-go/v3/common/serializer" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" - "github.com/taosdata/driver-go/v3/types" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/param" + "github.com/taosdata/taosadapter/v3/driver/common/serializer" + stmtCommon "github.com/taosdata/taosadapter/v3/driver/common/stmt" + "github.com/taosdata/taosadapter/v3/driver/types" "github.com/taosdata/taosadapter/v3/tools/generator" "github.com/taosdata/taosadapter/v3/tools/parseblock" ) diff --git a/controller/ws/wstool/error.go b/controller/ws/wstool/error.go index 5551d126..7768bfd7 100644 --- a/controller/ws/wstool/error.go +++ b/controller/ws/wstool/error.go @@ -4,7 +4,7 @@ import ( "context" "github.com/sirupsen/logrus" - tErrors "github.com/taosdata/driver-go/v3/errors" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/tools/melody" ) diff --git a/controller/ws/wstool/error_test.go b/controller/ws/wstool/error_test.go index fd85f964..88b6f63d 100644 --- a/controller/ws/wstool/error_test.go +++ b/controller/ws/wstool/error_test.go @@ -12,7 +12,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" - tErrors "github.com/taosdata/driver-go/v3/errors" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/melody" ) diff --git a/db/async/handlerpool.go b/db/async/handlerpool.go index 79b37877..3ba5e540 100644 --- a/db/async/handlerpool.go +++ b/db/async/handlerpool.go @@ -5,7 +5,7 @@ import ( "sync" "unsafe" - "github.com/taosdata/driver-go/v3/wrapper/cgo" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" ) type Result struct { diff --git a/db/async/row.go b/db/async/row.go index c012c2fa..dc3c77ce 100644 --- a/db/async/row.go +++ b/db/async/row.go @@ -8,9 +8,9 @@ import ( "unsafe" "github.com/sirupsen/logrus" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/thread" diff --git a/db/async/row_test.go b/db/async/row_test.go index 548a2864..2b7572d9 100644 --- a/db/async/row_test.go +++ b/db/async/row_test.go @@ -6,8 +6,8 @@ import ( "unsafe" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" ) diff --git a/db/async/stmt2pool.go b/db/async/stmt2pool.go index 5e863f07..91630ca9 100644 --- a/db/async/stmt2pool.go +++ b/db/async/stmt2pool.go @@ -3,7 +3,7 @@ package async import ( "unsafe" - "github.com/taosdata/driver-go/v3/wrapper/cgo" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" ) type Stmt2Result struct { diff --git a/db/asynctmq/tmq.go b/db/asynctmq/tmq.go index 4582c849..5687f8b9 100644 --- a/db/asynctmq/tmq.go +++ b/db/asynctmq/tmq.go @@ -643,7 +643,7 @@ import "C" import ( "unsafe" - "github.com/taosdata/driver-go/v3/wrapper/cgo" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" ) // InitTMQThread tmq_thread *init_tmq_thread() diff --git a/db/asynctmq/tmq_windows.go b/db/asynctmq/tmq_windows.go index 71c0b65c..eb4331f4 100644 --- a/db/asynctmq/tmq_windows.go +++ b/db/asynctmq/tmq_windows.go @@ -656,7 +656,7 @@ import "C" import ( "unsafe" - "github.com/taosdata/driver-go/v3/wrapper/cgo" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" ) // InitTMQThread tmq_thread *init_tmq_thread() diff --git a/db/asynctmq/tmqcb.go b/db/asynctmq/tmqcb.go index efbc071f..67e134ae 100644 --- a/db/asynctmq/tmqcb.go +++ b/db/asynctmq/tmqcb.go @@ -10,8 +10,8 @@ import "C" import ( "unsafe" - "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/db/asynctmq/tmqhandle" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" ) //export AdapterTMQPollCallback diff --git a/db/asynctmq/tmqhandle/handler.go b/db/asynctmq/tmqhandle/handler.go index b915a723..4927e0b1 100644 --- a/db/asynctmq/tmqhandle/handler.go +++ b/db/asynctmq/tmqhandle/handler.go @@ -6,7 +6,7 @@ import ( "sync" "unsafe" - "github.com/taosdata/driver-go/v3/wrapper/cgo" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" ) type FetchRawBlockResult struct { diff --git a/db/commonpool/pool.go b/db/commonpool/pool.go index 499f24af..9af5d545 100644 --- a/db/commonpool/pool.go +++ b/db/commonpool/pool.go @@ -11,12 +11,12 @@ import ( "unsafe" "github.com/sirupsen/logrus" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/connectpool" diff --git a/db/commonpool/pool_test.go b/db/commonpool/pool_test.go index 4d7b4e03..838415c4 100644 --- a/db/commonpool/pool_test.go +++ b/db/commonpool/pool_test.go @@ -8,8 +8,8 @@ import ( "unsafe" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" + "github.com/taosdata/taosadapter/v3/driver/wrapper" ) func TestMain(m *testing.M) { diff --git a/db/init.go b/db/init.go index b8600039..8d8638ff 100644 --- a/db/init.go +++ b/db/init.go @@ -3,11 +3,11 @@ package db import ( "sync" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/async" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" ) diff --git a/db/syncinterface/wrapper.go b/db/syncinterface/wrapper.go index 96e57c65..00b508e1 100644 --- a/db/syncinterface/wrapper.go +++ b/db/syncinterface/wrapper.go @@ -6,9 +6,9 @@ import ( "unsafe" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/types" - "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" + "github.com/taosdata/taosadapter/v3/driver/types" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/thread" ) diff --git a/db/syncinterface/wrapper_test.go b/db/syncinterface/wrapper_test.go index a22ca7e0..25852e27 100644 --- a/db/syncinterface/wrapper_test.go +++ b/db/syncinterface/wrapper_test.go @@ -7,15 +7,15 @@ import ( "unsafe" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/parser" - stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" - taoserrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/types" - "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + stmtCommon "github.com/taosdata/taosadapter/v3/driver/common/stmt" + taoserrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/types" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/generator" ) diff --git a/db/tool/createdb.go b/db/tool/createdb.go index 4391f9ac..4ddf277c 100644 --- a/db/tool/createdb.go +++ b/db/tool/createdb.go @@ -4,11 +4,11 @@ import ( "unsafe" "github.com/sirupsen/logrus" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/tools/pool" ) diff --git a/db/tool/createdb_test.go b/db/tool/createdb_test.go index f2465768..5214652a 100644 --- a/db/tool/createdb_test.go +++ b/db/tool/createdb_test.go @@ -7,10 +7,10 @@ import ( "github.com/spf13/viper" "github.com/stretchr/testify/assert" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" ) diff --git a/db/tool/notify.go b/db/tool/notify.go index cf5dd8f5..65dc224f 100644 --- a/db/tool/notify.go +++ b/db/tool/notify.go @@ -5,10 +5,10 @@ import ( "strings" "unsafe" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/thread" ) diff --git a/db/tool/notify_test.go b/db/tool/notify_test.go index 6c9a8a17..c1df3fbc 100644 --- a/db/tool/notify_test.go +++ b/db/tool/notify_test.go @@ -8,9 +8,9 @@ import ( "unsafe" "github.com/stretchr/testify/assert" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" ) func TestWhiteListHandle(t *testing.T) { diff --git a/driver/common/change.go b/driver/common/change.go new file mode 100644 index 00000000..d287162e --- /dev/null +++ b/driver/common/change.go @@ -0,0 +1,34 @@ +package common + +import ( + "fmt" + "time" +) + +func TimestampConvertToTime(timestamp int64, precision int) time.Time { + switch precision { + case PrecisionMilliSecond: // milli-second + return time.Unix(0, timestamp*1e6) + case PrecisionMicroSecond: // micro-second + return time.Unix(0, timestamp*1e3) + case PrecisionNanoSecond: // nano-second + return time.Unix(0, timestamp) + default: + s := fmt.Sprintln("unknown precision", precision, "timestamp", timestamp) + panic(s) + } +} + +func TimeToTimestamp(t time.Time, precision int) (timestamp int64) { + switch precision { + case PrecisionMilliSecond: + return t.UnixNano() / 1e6 + case PrecisionMicroSecond: + return t.UnixNano() / 1e3 + case PrecisionNanoSecond: + return t.UnixNano() + default: + s := fmt.Sprintln("unknown precision", precision, "time", t) + panic(s) + } +} diff --git a/driver/common/change_test.go b/driver/common/change_test.go new file mode 100644 index 00000000..6c5c727d --- /dev/null +++ b/driver/common/change_test.go @@ -0,0 +1,101 @@ +package common + +import ( + "reflect" + "testing" + "time" +) + +// @author: xftan +// @date: 2022/1/25 16:55 +// @description: test timestamp with precision convert to time.Time +func TestTimestampConvertToTime(t *testing.T) { + type args struct { + timestamp int64 + precision int + } + tests := []struct { + name string + args args + want time.Time + }{ + { + name: "ms", + args: args{ + timestamp: 1643068800000, + precision: PrecisionMilliSecond, + }, + want: time.Date(2022, 01, 25, 0, 0, 0, 0, time.UTC), + }, + { + name: "us", + args: args{ + timestamp: 1643068800000000, + precision: PrecisionMicroSecond, + }, + want: time.Date(2022, 01, 25, 0, 0, 0, 0, time.UTC), + }, + { + name: "ns", + args: args{ + timestamp: 1643068800000000000, + precision: PrecisionNanoSecond, + }, + want: time.Date(2022, 01, 25, 0, 0, 0, 0, time.UTC), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TimestampConvertToTime(tt.args.timestamp, tt.args.precision); !reflect.DeepEqual(got.UTC(), tt.want.UTC()) { + t.Errorf("TimestampConvertToTime() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/25 16:56 +// @description: test time.Time with precision convert to timestamp +func TestTimeToTimestamp(t *testing.T) { + type args struct { + t time.Time + precision int + } + tests := []struct { + name string + args args + wantTimestamp int64 + }{ + { + name: "ms", + args: args{ + t: time.Date(2022, 01, 25, 0, 0, 0, 0, time.UTC), + precision: PrecisionMilliSecond, + }, + wantTimestamp: 1643068800000, + }, + { + name: "us", + args: args{ + t: time.Date(2022, 01, 25, 0, 0, 0, 0, time.UTC), + precision: PrecisionMicroSecond, + }, + wantTimestamp: 1643068800000000, + }, + { + name: "ns", + args: args{ + t: time.Date(2022, 01, 25, 0, 0, 0, 0, time.UTC), + precision: PrecisionNanoSecond, + }, + wantTimestamp: 1643068800000000000, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotTimestamp := TimeToTimestamp(tt.args.t, tt.args.precision); gotTimestamp != tt.wantTimestamp { + t.Errorf("TimeToTimestamp() = %v, want %v", gotTimestamp, tt.wantTimestamp) + } + }) + } +} diff --git a/driver/common/column.go b/driver/common/column.go new file mode 100644 index 00000000..a092fb6d --- /dev/null +++ b/driver/common/column.go @@ -0,0 +1,46 @@ +package common + +import ( + "reflect" + + "github.com/taosdata/taosadapter/v3/driver/types" +) + +var ( + NullInt8 = reflect.TypeOf(types.NullInt8{}) + NullInt16 = reflect.TypeOf(types.NullInt16{}) + NullInt32 = reflect.TypeOf(types.NullInt32{}) + NullInt64 = reflect.TypeOf(types.NullInt64{}) + NullUInt8 = reflect.TypeOf(types.NullUInt8{}) + NullUInt16 = reflect.TypeOf(types.NullUInt16{}) + NullUInt32 = reflect.TypeOf(types.NullUInt32{}) + NullUInt64 = reflect.TypeOf(types.NullUInt64{}) + NullFloat32 = reflect.TypeOf(types.NullFloat32{}) + NullFloat64 = reflect.TypeOf(types.NullFloat64{}) + NullTime = reflect.TypeOf(types.NullTime{}) + NullBool = reflect.TypeOf(types.NullBool{}) + NullString = reflect.TypeOf(types.NullString{}) + Bytes = reflect.TypeOf([]byte{}) + NullJson = reflect.TypeOf(types.NullJson{}) + UnknownType = reflect.TypeOf(new(interface{})).Elem() +) + +var ColumnTypeMap = map[int]reflect.Type{ + TSDB_DATA_TYPE_BOOL: NullBool, + TSDB_DATA_TYPE_TINYINT: NullInt8, + TSDB_DATA_TYPE_SMALLINT: NullInt16, + TSDB_DATA_TYPE_INT: NullInt32, + TSDB_DATA_TYPE_BIGINT: NullInt64, + TSDB_DATA_TYPE_UTINYINT: NullUInt8, + TSDB_DATA_TYPE_USMALLINT: NullUInt16, + TSDB_DATA_TYPE_UINT: NullUInt32, + TSDB_DATA_TYPE_UBIGINT: NullUInt64, + TSDB_DATA_TYPE_FLOAT: NullFloat32, + TSDB_DATA_TYPE_DOUBLE: NullFloat64, + TSDB_DATA_TYPE_BINARY: NullString, + TSDB_DATA_TYPE_NCHAR: NullString, + TSDB_DATA_TYPE_TIMESTAMP: NullTime, + TSDB_DATA_TYPE_JSON: NullJson, + TSDB_DATA_TYPE_VARBINARY: Bytes, + TSDB_DATA_TYPE_GEOMETRY: Bytes, +} diff --git a/driver/common/const.go b/driver/common/const.go new file mode 100644 index 00000000..49efee48 --- /dev/null +++ b/driver/common/const.go @@ -0,0 +1,73 @@ +package common + +import "unsafe" + +//revive:disable +const ( + MaxTaosSqlLen = 1048576 + DefaultUser = "root" + DefaultPassword = "taosdata" +) + +const ( + PrecisionMilliSecond = 0 + PrecisionMicroSecond = 1 + PrecisionNanoSecond = 2 +) + +const ( + TSDB_OPTION_LOCALE = iota + TSDB_OPTION_CHARSET + TSDB_OPTION_TIMEZONE + TSDB_OPTION_CONFIGDIR + TSDB_OPTION_SHELL_ACTIVITY_TIMER + TSDB_OPTION_USE_ADAPTER +) + +const ( + TMQ_RES_INVALID = -1 + TMQ_RES_DATA = 1 + TMQ_RES_TABLE_META = 2 + TMQ_RES_METADATA = 3 +) + +var TypeLengthMap = map[int]int{ + TSDB_DATA_TYPE_NULL: 1, + TSDB_DATA_TYPE_BOOL: 1, + TSDB_DATA_TYPE_TINYINT: 1, + TSDB_DATA_TYPE_SMALLINT: 2, + TSDB_DATA_TYPE_INT: 4, + TSDB_DATA_TYPE_BIGINT: 8, + TSDB_DATA_TYPE_FLOAT: 4, + TSDB_DATA_TYPE_DOUBLE: 8, + TSDB_DATA_TYPE_TIMESTAMP: 8, + TSDB_DATA_TYPE_UTINYINT: 1, + TSDB_DATA_TYPE_USMALLINT: 2, + TSDB_DATA_TYPE_UINT: 4, + TSDB_DATA_TYPE_UBIGINT: 8, +} + +const ( + Int8Size = unsafe.Sizeof(int8(0)) + Int16Size = unsafe.Sizeof(int16(0)) + Int32Size = unsafe.Sizeof(int32(0)) + Int64Size = unsafe.Sizeof(int64(0)) + UInt8Size = unsafe.Sizeof(uint8(0)) + UInt16Size = unsafe.Sizeof(uint16(0)) + UInt32Size = unsafe.Sizeof(uint32(0)) + UInt64Size = unsafe.Sizeof(uint64(0)) + Float32Size = unsafe.Sizeof(float32(0)) + Float64Size = unsafe.Sizeof(float64(0)) +) + +const ReqIDKey = "taos_req_id" + +const ( + TAOS_NOTIFY_PASSVER = 0 + TAOS_NOTIFY_WHITELIST_VER = 1 + TAOS_NOTIFY_USER_DROPPED = 2 +) + +const ( + TAOS_CONN_MODE_BI = 0 +) diff --git a/driver/common/datatype.go b/driver/common/datatype.go new file mode 100644 index 00000000..cf5a3379 --- /dev/null +++ b/driver/common/datatype.go @@ -0,0 +1,89 @@ +package common + +//revive:disable +const ( + TSDB_DATA_TYPE_NULL = 0 // 1 bytes + TSDB_DATA_TYPE_BOOL = 1 // 1 bytes + TSDB_DATA_TYPE_TINYINT = 2 // 1 byte + TSDB_DATA_TYPE_SMALLINT = 3 // 2 bytes + TSDB_DATA_TYPE_INT = 4 // 4 bytes + TSDB_DATA_TYPE_BIGINT = 5 // 8 bytes + TSDB_DATA_TYPE_FLOAT = 6 // 4 bytes + TSDB_DATA_TYPE_DOUBLE = 7 // 8 bytes + TSDB_DATA_TYPE_BINARY = 8 // string + TSDB_DATA_TYPE_TIMESTAMP = 9 // 8 bytes + TSDB_DATA_TYPE_NCHAR = 10 // unicode string + TSDB_DATA_TYPE_UTINYINT = 11 // 1 byte + TSDB_DATA_TYPE_USMALLINT = 12 // 2 bytes + TSDB_DATA_TYPE_UINT = 13 // 4 bytes + TSDB_DATA_TYPE_UBIGINT = 14 // 8 bytes + TSDB_DATA_TYPE_JSON = 15 + TSDB_DATA_TYPE_VARBINARY = 16 + TSDB_DATA_TYPE_DECIMAL = 17 + TSDB_DATA_TYPE_BLOB = 18 + TSDB_DATA_TYPE_MEDIUMBLOB = 19 + TSDB_DATA_TYPE_GEOMETRY = 20 +) + +const ( + TSDB_DATA_TYPE_NULL_Str = "NULL" + TSDB_DATA_TYPE_BOOL_Str = "BOOL" + TSDB_DATA_TYPE_TINYINT_Str = "TINYINT" + TSDB_DATA_TYPE_SMALLINT_Str = "SMALLINT" + TSDB_DATA_TYPE_INT_Str = "INT" + TSDB_DATA_TYPE_BIGINT_Str = "BIGINT" + TSDB_DATA_TYPE_FLOAT_Str = "FLOAT" + TSDB_DATA_TYPE_DOUBLE_Str = "DOUBLE" + TSDB_DATA_TYPE_BINARY_Str = "VARCHAR" + TSDB_DATA_TYPE_TIMESTAMP_Str = "TIMESTAMP" + TSDB_DATA_TYPE_NCHAR_Str = "NCHAR" + TSDB_DATA_TYPE_UTINYINT_Str = "TINYINT UNSIGNED" + TSDB_DATA_TYPE_USMALLINT_Str = "SMALLINT UNSIGNED" + TSDB_DATA_TYPE_UINT_Str = "INT UNSIGNED" + TSDB_DATA_TYPE_UBIGINT_Str = "BIGINT UNSIGNED" + TSDB_DATA_TYPE_JSON_Str = "JSON" + TSDB_DATA_TYPE_VARBINARY_Str = "VARBINARY" + TSDB_DATA_TYPE_GEOMETRY_Str = "GEOMETRY" +) + +var TypeNameMap = map[int]string{ + TSDB_DATA_TYPE_NULL: TSDB_DATA_TYPE_NULL_Str, + TSDB_DATA_TYPE_BOOL: TSDB_DATA_TYPE_BOOL_Str, + TSDB_DATA_TYPE_TINYINT: TSDB_DATA_TYPE_TINYINT_Str, + TSDB_DATA_TYPE_SMALLINT: TSDB_DATA_TYPE_SMALLINT_Str, + TSDB_DATA_TYPE_INT: TSDB_DATA_TYPE_INT_Str, + TSDB_DATA_TYPE_BIGINT: TSDB_DATA_TYPE_BIGINT_Str, + TSDB_DATA_TYPE_FLOAT: TSDB_DATA_TYPE_FLOAT_Str, + TSDB_DATA_TYPE_DOUBLE: TSDB_DATA_TYPE_DOUBLE_Str, + TSDB_DATA_TYPE_BINARY: TSDB_DATA_TYPE_BINARY_Str, + TSDB_DATA_TYPE_TIMESTAMP: TSDB_DATA_TYPE_TIMESTAMP_Str, + TSDB_DATA_TYPE_NCHAR: TSDB_DATA_TYPE_NCHAR_Str, + TSDB_DATA_TYPE_UTINYINT: TSDB_DATA_TYPE_UTINYINT_Str, + TSDB_DATA_TYPE_USMALLINT: TSDB_DATA_TYPE_USMALLINT_Str, + TSDB_DATA_TYPE_UINT: TSDB_DATA_TYPE_UINT_Str, + TSDB_DATA_TYPE_UBIGINT: TSDB_DATA_TYPE_UBIGINT_Str, + TSDB_DATA_TYPE_JSON: TSDB_DATA_TYPE_JSON_Str, + TSDB_DATA_TYPE_VARBINARY: TSDB_DATA_TYPE_VARBINARY_Str, + TSDB_DATA_TYPE_GEOMETRY: TSDB_DATA_TYPE_GEOMETRY_Str, +} + +var NameTypeMap = map[string]int{ + TSDB_DATA_TYPE_NULL_Str: TSDB_DATA_TYPE_NULL, + TSDB_DATA_TYPE_BOOL_Str: TSDB_DATA_TYPE_BOOL, + TSDB_DATA_TYPE_TINYINT_Str: TSDB_DATA_TYPE_TINYINT, + TSDB_DATA_TYPE_SMALLINT_Str: TSDB_DATA_TYPE_SMALLINT, + TSDB_DATA_TYPE_INT_Str: TSDB_DATA_TYPE_INT, + TSDB_DATA_TYPE_BIGINT_Str: TSDB_DATA_TYPE_BIGINT, + TSDB_DATA_TYPE_FLOAT_Str: TSDB_DATA_TYPE_FLOAT, + TSDB_DATA_TYPE_DOUBLE_Str: TSDB_DATA_TYPE_DOUBLE, + TSDB_DATA_TYPE_BINARY_Str: TSDB_DATA_TYPE_BINARY, + TSDB_DATA_TYPE_TIMESTAMP_Str: TSDB_DATA_TYPE_TIMESTAMP, + TSDB_DATA_TYPE_NCHAR_Str: TSDB_DATA_TYPE_NCHAR, + TSDB_DATA_TYPE_UTINYINT_Str: TSDB_DATA_TYPE_UTINYINT, + TSDB_DATA_TYPE_USMALLINT_Str: TSDB_DATA_TYPE_USMALLINT, + TSDB_DATA_TYPE_UINT_Str: TSDB_DATA_TYPE_UINT, + TSDB_DATA_TYPE_UBIGINT_Str: TSDB_DATA_TYPE_UBIGINT, + TSDB_DATA_TYPE_JSON_Str: TSDB_DATA_TYPE_JSON, + TSDB_DATA_TYPE_VARBINARY_Str: TSDB_DATA_TYPE_VARBINARY, + TSDB_DATA_TYPE_GEOMETRY_Str: TSDB_DATA_TYPE_GEOMETRY, +} diff --git a/driver/common/param/column.go b/driver/common/param/column.go new file mode 100644 index 00000000..06e47e61 --- /dev/null +++ b/driver/common/param/column.go @@ -0,0 +1,220 @@ +package param + +import ( + "fmt" + + "github.com/taosdata/taosadapter/v3/driver/types" +) + +type ColumnType struct { + size int + value []*types.ColumnType + column int +} + +func NewColumnType(size int) *ColumnType { + return &ColumnType{size: size, value: make([]*types.ColumnType, size)} +} + +func NewColumnTypeWithValue(value []*types.ColumnType) *ColumnType { + return &ColumnType{size: len(value), value: value, column: len(value)} +} + +func (c *ColumnType) AddBool() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosBoolType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddTinyint() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosTinyintType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddSmallint() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosSmallintType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddInt() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosIntType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddBigint() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosBigintType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddUTinyint() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosUTinyintType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddUSmallint() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosUSmallintType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddUInt() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosUIntType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddUBigint() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosUBigintType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddFloat() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosFloatType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddDouble() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosDoubleType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddBinary(strMaxLen int) *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosBinaryType, + MaxLen: strMaxLen, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddVarBinary(strMaxLen int) *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosVarBinaryType, + MaxLen: strMaxLen, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddNchar(strMaxLen int) *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosNcharType, + MaxLen: strMaxLen, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddTimestamp() *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosTimestampType, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddJson(strMaxLen int) *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosJsonType, + MaxLen: strMaxLen, + } + c.column += 1 + return c +} + +func (c *ColumnType) AddGeometry(strMaxLen int) *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosGeometryType, + MaxLen: strMaxLen, + } + c.column += 1 + return c +} + +func (c *ColumnType) GetValue() ([]*types.ColumnType, error) { + if c.size != c.column { + return nil, fmt.Errorf("incomplete column expect %d columns set %d columns", c.size, c.column) + } + return c.value, nil +} diff --git a/driver/common/param/column_test.go b/driver/common/param/column_test.go new file mode 100644 index 00000000..dc179035 --- /dev/null +++ b/driver/common/param/column_test.go @@ -0,0 +1,435 @@ +package param + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/types" +) + +func TestColumnType_AddBool(t *testing.T) { + colType := NewColumnType(1) + colType.AddBool() + + expected := []*types.ColumnType{ + { + Type: types.TaosBoolType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddBool() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddTinyint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddTinyint() + + expected := []*types.ColumnType{ + { + Type: types.TaosTinyintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddTinyint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddSmallint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddSmallint() + + expected := []*types.ColumnType{ + { + Type: types.TaosSmallintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddSmallint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddInt(t *testing.T) { + colType := NewColumnType(1) + + colType.AddInt() + + expected := []*types.ColumnType{ + { + Type: types.TaosIntType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddInt() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddBigint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddBigint() + + expected := []*types.ColumnType{ + { + Type: types.TaosBigintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddBigint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddUTinyint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddUTinyint() + + expected := []*types.ColumnType{ + { + Type: types.TaosUTinyintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddUTinyint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddUSmallint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddUSmallint() + + expected := []*types.ColumnType{ + { + Type: types.TaosUSmallintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddUSmallint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddUInt(t *testing.T) { + colType := NewColumnType(1) + + colType.AddUInt() + + expected := []*types.ColumnType{ + { + Type: types.TaosUIntType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddUInt() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddUBigint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddUBigint() + + expected := []*types.ColumnType{ + { + Type: types.TaosUBigintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddUBigint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddFloat(t *testing.T) { + colType := NewColumnType(1) + + colType.AddFloat() + + expected := []*types.ColumnType{ + { + Type: types.TaosFloatType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddFloat() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddDouble(t *testing.T) { + colType := NewColumnType(1) + + colType.AddDouble() + + expected := []*types.ColumnType{ + { + Type: types.TaosDoubleType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddDouble() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddBinary(t *testing.T) { + colType := NewColumnType(1) + + colType.AddBinary(100) + + expected := []*types.ColumnType{ + { + Type: types.TaosBinaryType, + MaxLen: 100, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddBinary(50) + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddVarBinary(t *testing.T) { + colType := NewColumnType(1) + + colType.AddVarBinary(100) + + expected := []*types.ColumnType{ + { + Type: types.TaosVarBinaryType, + MaxLen: 100, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddVarBinary(50) + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddNchar(t *testing.T) { + colType := NewColumnType(1) + + colType.AddNchar(100) + + expected := []*types.ColumnType{ + { + Type: types.TaosNcharType, + MaxLen: 100, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddNchar(50) + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddTimestamp(t *testing.T) { + colType := NewColumnType(1) + + colType.AddTimestamp() + + expected := []*types.ColumnType{ + { + Type: types.TaosTimestampType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddTimestamp() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddJson(t *testing.T) { + colType := NewColumnType(1) + + colType.AddJson(100) + + expected := []*types.ColumnType{ + { + Type: types.TaosJsonType, + MaxLen: 100, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddJson(50) + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddGeometry(t *testing.T) { + colType := NewColumnType(1) + + colType.AddGeometry(100) + + expected := []*types.ColumnType{ + { + Type: types.TaosGeometryType, + MaxLen: 100, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddGeometry(50) + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_GetValue(t *testing.T) { + // Initialize ColumnType with size 3 + colType := NewColumnType(3) + + // Add column types + colType.AddBool() + colType.AddTinyint() + colType.AddFloat() + + // Try to get values + values, err := colType.GetValue() + assert.NoError(t, err) + + // Check if the length of values matches the expected size + expectedSize := 3 + assert.Equal(t, expectedSize, len(values)) + + // Initialize ColumnType with size 3 + colType = NewColumnType(3) + + // Add only 2 column types + colType.AddBool() + colType.AddTinyint() + + // Try to get values + _, err = colType.GetValue() + + // Check if an error is returned due to incomplete column + assert.Error(t, err) + assert.Equal(t, "incomplete column expect 3 columns set 2 columns", err.Error()) +} + +func TestNewColumnTypeWithValue(t *testing.T) { + value := []*types.ColumnType{ + {Type: types.TaosBoolType}, + {Type: types.TaosTinyintType}, + } + + colType := NewColumnTypeWithValue(value) + + expectedSize := len(value) + assert.Equal(t, expectedSize, colType.size) + + expectedValue := value + assert.Equal(t, expectedValue, colType.value) + + expectedColumn := len(value) + assert.Equal(t, expectedColumn, colType.column) +} diff --git a/driver/common/param/param.go b/driver/common/param/param.go new file mode 100644 index 00000000..b2aa7027 --- /dev/null +++ b/driver/common/param/param.go @@ -0,0 +1,337 @@ +package param + +import ( + "database/sql/driver" + "time" + + taosTypes "github.com/taosdata/taosadapter/v3/driver/types" +) + +type Param struct { + size int + value []driver.Value + offset int +} + +func NewParam(size int) *Param { + return &Param{ + size: size, + value: make([]driver.Value, size), + } +} + +func NewParamsWithRowValue(value []driver.Value) []*Param { + params := make([]*Param, len(value)) + for i, d := range value { + params[i] = NewParam(1) + params[i].AddValue(d) + } + return params +} + +func (p *Param) SetBool(offset int, value bool) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosBool(value) +} + +func (p *Param) SetNull(offset int) { + if offset >= p.size { + return + } + p.value[offset] = nil +} + +func (p *Param) SetTinyint(offset int, value int) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosTinyint(value) +} + +func (p *Param) SetSmallint(offset int, value int) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosSmallint(value) +} + +func (p *Param) SetInt(offset int, value int) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosInt(value) +} + +func (p *Param) SetBigint(offset int, value int) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosBigint(value) +} + +func (p *Param) SetUTinyint(offset int, value uint) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosUTinyint(value) +} + +func (p *Param) SetUSmallint(offset int, value uint) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosUSmallint(value) +} + +func (p *Param) SetUInt(offset int, value uint) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosUInt(value) +} + +func (p *Param) SetUBigint(offset int, value uint) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosUBigint(value) +} + +func (p *Param) SetFloat(offset int, value float32) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosFloat(value) +} + +func (p *Param) SetDouble(offset int, value float64) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosDouble(value) +} + +func (p *Param) SetBinary(offset int, value []byte) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosBinary(value) +} + +func (p *Param) SetVarBinary(offset int, value []byte) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosVarBinary(value) +} + +func (p *Param) SetNchar(offset int, value string) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosNchar(value) +} + +func (p *Param) SetTimestamp(offset int, value time.Time, precision int) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosTimestamp{ + T: value, + Precision: precision, + } +} + +func (p *Param) SetJson(offset int, value []byte) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosJson(value) +} + +func (p *Param) SetGeometry(offset int, value []byte) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosGeometry(value) +} + +func (p *Param) AddBool(value bool) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosBool(value) + p.offset += 1 + return p +} + +func (p *Param) AddNull() *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = nil + p.offset += 1 + return p +} + +func (p *Param) AddTinyint(value int) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosTinyint(value) + p.offset += 1 + return p +} + +func (p *Param) AddSmallint(value int) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosSmallint(value) + p.offset += 1 + return p +} + +func (p *Param) AddInt(value int) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosInt(value) + p.offset += 1 + return p +} + +func (p *Param) AddBigint(value int) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosBigint(value) + p.offset += 1 + return p +} + +func (p *Param) AddUTinyint(value uint) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosUTinyint(value) + p.offset += 1 + return p +} + +func (p *Param) AddUSmallint(value uint) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosUSmallint(value) + p.offset += 1 + return p +} + +func (p *Param) AddUInt(value uint) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosUInt(value) + p.offset += 1 + return p +} + +func (p *Param) AddUBigint(value uint) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosUBigint(value) + p.offset += 1 + return p +} + +func (p *Param) AddFloat(value float32) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosFloat(value) + p.offset += 1 + return p +} + +func (p *Param) AddDouble(value float64) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosDouble(value) + p.offset += 1 + return p +} + +func (p *Param) AddBinary(value []byte) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosBinary(value) + p.offset += 1 + return p +} + +func (p *Param) AddVarBinary(value []byte) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosVarBinary(value) + p.offset += 1 + return p +} + +func (p *Param) AddNchar(value string) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosNchar(value) + p.offset += 1 + return p +} + +func (p *Param) AddTimestamp(value time.Time, precision int) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosTimestamp{ + T: value, + Precision: precision, + } + p.offset += 1 + return p +} + +func (p *Param) AddJson(value []byte) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosJson(value) + p.offset += 1 + return p +} + +func (p *Param) AddGeometry(value []byte) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosGeometry(value) + p.offset += 1 + return p +} + +func (p *Param) GetValues() []driver.Value { + return p.value +} + +func (p *Param) AddValue(value interface{}) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = value + p.offset += 1 + return p +} diff --git a/driver/common/param/param_test.go b/driver/common/param/param_test.go new file mode 100644 index 00000000..35834629 --- /dev/null +++ b/driver/common/param/param_test.go @@ -0,0 +1,654 @@ +package param + +import ( + "database/sql/driver" + "testing" + "time" + + "github.com/stretchr/testify/assert" + taosTypes "github.com/taosdata/taosadapter/v3/driver/types" +) + +func TestParam_SetBool(t *testing.T) { + param := NewParam(1) + param.SetBool(0, true) + + expected := []driver.Value{taosTypes.TaosBool(true)} + assert.Equal(t, expected, param.GetValues()) + + param = NewParam(0) + param.SetBool(0, true) + assert.Equal(t, 0, len(param.GetValues())) +} + +func TestParam_SetNull(t *testing.T) { + param := NewParam(1) + param.SetNull(0) + + if param.GetValues()[0] != nil { + t.Errorf("SetNull failed, expected nil, got %v", param.GetValues()[0]) + } + param = NewParam(0) + param.SetNull(0) + assert.Equal(t, 0, len(param.GetValues())) +} + +func TestParam_SetTinyint(t *testing.T) { + param := NewParam(1) + param.SetTinyint(0, 42) + + expected := []driver.Value{taosTypes.TaosTinyint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetTinyint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetSmallint(t *testing.T) { + param := NewParam(1) + param.SetSmallint(0, 42) + + expected := []driver.Value{taosTypes.TaosSmallint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetSmallint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetInt(t *testing.T) { + param := NewParam(1) + param.SetInt(0, 42) + + expected := []driver.Value{taosTypes.TaosInt(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetInt(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetBigint(t *testing.T) { + param := NewParam(1) + param.SetBigint(0, 42) + + expected := []driver.Value{taosTypes.TaosBigint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetBigint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetUTinyint(t *testing.T) { + param := NewParam(1) + param.SetUTinyint(0, 42) + + expected := []driver.Value{taosTypes.TaosUTinyint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetUTinyint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetUSmallint(t *testing.T) { + param := NewParam(1) + param.SetUSmallint(0, 42) + + expected := []driver.Value{taosTypes.TaosUSmallint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetUSmallint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetUInt(t *testing.T) { + param := NewParam(1) + param.SetUInt(0, 42) + + expected := []driver.Value{taosTypes.TaosUInt(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetUInt(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetUBigint(t *testing.T) { + param := NewParam(1) + param.SetUBigint(0, 42) + + expected := []driver.Value{taosTypes.TaosUBigint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetUBigint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetFloat(t *testing.T) { + param := NewParam(1) + param.SetFloat(0, 3.14) + + expected := []driver.Value{taosTypes.TaosFloat(3.14)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetFloat(1, 3.14) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetDouble(t *testing.T) { + param := NewParam(1) + param.SetDouble(0, 3.14) + + expected := []driver.Value{taosTypes.TaosDouble(3.14)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetDouble(1, 3.14) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetBinary(t *testing.T) { + param := NewParam(1) + param.SetBinary(0, []byte{0x01, 0x02}) + + expected := []driver.Value{taosTypes.TaosBinary([]byte{0x01, 0x02})} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetBinary(1, []byte{0x01, 0x02}) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetVarBinary(t *testing.T) { + param := NewParam(1) + param.SetVarBinary(0, []byte{0x01, 0x02}) + + expected := []driver.Value{taosTypes.TaosVarBinary([]byte{0x01, 0x02})} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetVarBinary(1, []byte{0x01, 0x02}) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetNchar(t *testing.T) { + param := NewParam(1) + param.SetNchar(0, "hello") + + expected := []driver.Value{taosTypes.TaosNchar("hello")} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetNchar(1, "hello") // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetTimestamp(t *testing.T) { + timestamp := time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC) + param := NewParam(1) + param.SetTimestamp(0, timestamp, 6) + + expected := []driver.Value{taosTypes.TaosTimestamp{T: timestamp, Precision: 6}} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetTimestamp(1, timestamp, 6) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetJson(t *testing.T) { + jsonData := []byte(`{"key": "value"}`) + param := NewParam(1) + param.SetJson(0, jsonData) + + expected := []driver.Value{taosTypes.TaosJson(jsonData)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetJson(1, jsonData) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetGeometry(t *testing.T) { + geometryData := []byte{0x01, 0x02, 0x03, 0x04} + param := NewParam(1) + param.SetGeometry(0, geometryData) + + expected := []driver.Value{taosTypes.TaosGeometry(geometryData)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetGeometry(1, geometryData) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddBool(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a bool value + param.AddBool(true) + + expected := []driver.Value{taosTypes.TaosBool(true), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another bool value + param.AddBool(false) + + expected = []driver.Value{taosTypes.TaosBool(true), taosTypes.TaosBool(false)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddBool(true) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddNull(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a null value + param.AddNull() + + expected := []driver.Value{nil, nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another null value + param.AddNull() + + expected = []driver.Value{nil, nil} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddNull() // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddTinyint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a tinyint value + param.AddTinyint(42) + + expected := []driver.Value{taosTypes.TaosTinyint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another tinyint value + param.AddTinyint(84) + + expected = []driver.Value{taosTypes.TaosTinyint(42), taosTypes.TaosTinyint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddTinyint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddSmallint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a smallint value + param.AddSmallint(42) + + expected := []driver.Value{taosTypes.TaosSmallint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another smallint value + param.AddSmallint(84) + + expected = []driver.Value{taosTypes.TaosSmallint(42), taosTypes.TaosSmallint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddSmallint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddInt(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add an int value + param.AddInt(42) + + expected := []driver.Value{taosTypes.TaosInt(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another int value + param.AddInt(84) + + expected = []driver.Value{taosTypes.TaosInt(42), taosTypes.TaosInt(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddInt(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not mod +} + +func TestParam_AddBigint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a bigint value + param.AddBigint(42) + + expected := []driver.Value{taosTypes.TaosBigint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another bigint value + param.AddBigint(84) + + expected = []driver.Value{taosTypes.TaosBigint(42), taosTypes.TaosBigint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddBigint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddUTinyint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a utinyint value + param.AddUTinyint(42) + + expected := []driver.Value{taosTypes.TaosUTinyint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another utinyint value + param.AddUTinyint(84) + + expected = []driver.Value{taosTypes.TaosUTinyint(42), taosTypes.TaosUTinyint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddUTinyint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddUSmallint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a usmallint value + param.AddUSmallint(42) + + expected := []driver.Value{taosTypes.TaosUSmallint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another usmallint value + param.AddUSmallint(84) + + expected = []driver.Value{taosTypes.TaosUSmallint(42), taosTypes.TaosUSmallint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddUSmallint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddUInt(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a uint value + param.AddUInt(42) + + expected := []driver.Value{taosTypes.TaosUInt(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another uint value + param.AddUInt(84) + + expected = []driver.Value{taosTypes.TaosUInt(42), taosTypes.TaosUInt(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddUInt(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddUBigint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a ubigint value + param.AddUBigint(42) + + expected := []driver.Value{taosTypes.TaosUBigint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another ubigint value + param.AddUBigint(84) + + expected = []driver.Value{taosTypes.TaosUBigint(42), taosTypes.TaosUBigint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddUBigint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddFloat(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a float value + param.AddFloat(3.14) + + expected := []driver.Value{taosTypes.TaosFloat(3.14), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another float value + param.AddFloat(6.28) + + expected = []driver.Value{taosTypes.TaosFloat(3.14), taosTypes.TaosFloat(6.28)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddFloat(9.42) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddDouble(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a double value + param.AddDouble(3.14) + + expected := []driver.Value{taosTypes.TaosDouble(3.14), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another double value + param.AddDouble(6.28) + + expected = []driver.Value{taosTypes.TaosDouble(3.14), taosTypes.TaosDouble(6.28)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddDouble(9.42) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddBinary(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + binaryData := []byte{0x01, 0x02, 0x03} + + // Add a binary value + param.AddBinary(binaryData) + + expected := []driver.Value{taosTypes.TaosBinary(binaryData), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another binary value + param.AddBinary([]byte{0x04, 0x05, 0x06}) + + expected = []driver.Value{taosTypes.TaosBinary(binaryData), taosTypes.TaosBinary([]byte{0x04, 0x05, 0x06})} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddBinary([]byte{0x07, 0x08, 0x09}) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddVarBinary(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + binaryData := []byte{0x01, 0x02, 0x03} + + // Add a varbinary value + param.AddVarBinary(binaryData) + + expected := []driver.Value{taosTypes.TaosVarBinary(binaryData), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another varbinary value + param.AddVarBinary([]byte{0x04, 0x05, 0x06}) + + expected = []driver.Value{taosTypes.TaosVarBinary(binaryData), taosTypes.TaosVarBinary([]byte{0x04, 0x05, 0x06})} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddVarBinary([]byte{0x07, 0x08, 0x09}) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddNchar(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add an nchar value + param.AddNchar("hello") + + expected := []driver.Value{taosTypes.TaosNchar("hello"), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another nchar value + param.AddNchar("world") + + expected = []driver.Value{taosTypes.TaosNchar("hello"), taosTypes.TaosNchar("world")} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddNchar("test") // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddTimestamp(t *testing.T) { + timestamp := time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC) + param := NewParam(2) // Initialize with size 2 + + // Add a timestamp value + param.AddTimestamp(timestamp, 6) + + expected := []driver.Value{taosTypes.TaosTimestamp{T: timestamp, Precision: 6}, nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another timestamp value + param.AddTimestamp(timestamp.Add(time.Hour), 9) + + expected = []driver.Value{ + taosTypes.TaosTimestamp{T: timestamp, Precision: 6}, + taosTypes.TaosTimestamp{T: timestamp.Add(time.Hour), Precision: 9}, + } + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddTimestamp(timestamp.Add(2*time.Hour), 6) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddJson(t *testing.T) { + jsonData := []byte(`{"key": "value"}`) + param := NewParam(2) // Initialize with size 2 + + // Add a JSON value + param.AddJson(jsonData) + + expected := []driver.Value{taosTypes.TaosJson(jsonData), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another JSON value + param.AddJson([]byte(`{"key2": "value2"}`)) + + expected = []driver.Value{ + taosTypes.TaosJson(jsonData), + taosTypes.TaosJson([]byte(`{"key2": "value2"}`)), + } + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddJson([]byte(`{"key3": "value3"}`)) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddGeometry(t *testing.T) { + geometryData := []byte{0x01, 0x02, 0x03} + param := NewParam(2) // Initialize with size 2 + + // Add a geometry value + param.AddGeometry(geometryData) + + expected := []driver.Value{taosTypes.TaosGeometry(geometryData), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another geometry value + param.AddGeometry([]byte{0x04, 0x05, 0x06}) + + expected = []driver.Value{ + taosTypes.TaosGeometry(geometryData), + taosTypes.TaosGeometry([]byte{0x04, 0x05, 0x06}), + } + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddGeometry([]byte{0x07, 0x08, 0x09}) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddValue(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a binary value + binaryData := []byte{0x01, 0x02, 0x03} + param.AddValue(taosTypes.TaosBinary(binaryData)) + + expected := []driver.Value{taosTypes.TaosBinary(binaryData), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add a varchar value + param.AddValue(taosTypes.TaosVarBinary("hello")) + + expected = []driver.Value{taosTypes.TaosBinary(binaryData), taosTypes.TaosVarBinary("hello")} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddValue(taosTypes.TaosVarBinary("world")) + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestNewParamsWithRowValue(t *testing.T) { + rowValues := []driver.Value{taosTypes.TaosBool(true), taosTypes.TaosInt(42), taosTypes.TaosNchar("hello")} + + params := NewParamsWithRowValue(rowValues) + + expected := []*Param{ + { + size: 1, + value: []driver.Value{taosTypes.TaosBool(true)}, + offset: 1, + }, + { + size: 1, + value: []driver.Value{taosTypes.TaosInt(42)}, + offset: 1, + }, + { + size: 1, + value: []driver.Value{taosTypes.TaosNchar("hello")}, + offset: 1, + }, + } + + for i, param := range params { + assert.Equal(t, expected[i].size, param.size) + assert.Equal(t, expected[i].value, param.value) + assert.Equal(t, expected[i].offset, param.offset) + } +} diff --git a/driver/common/parser/block.go b/driver/common/parser/block.go new file mode 100644 index 00000000..b527242d --- /dev/null +++ b/driver/common/parser/block.go @@ -0,0 +1,374 @@ +package parser + +import ( + "database/sql/driver" + "math" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/tools" +) + +const ( + Int8Size = common.Int8Size + Int16Size = common.Int16Size + Int32Size = common.Int32Size + Int64Size = common.Int64Size + UInt8Size = common.UInt8Size + UInt16Size = common.UInt16Size + UInt32Size = common.UInt32Size + UInt64Size = common.UInt64Size + Float32Size = common.Float32Size + Float64Size = common.Float64Size +) + +const ( + ColInfoSize = Int8Size + Int32Size + RawBlockVersionOffset = 0 + RawBlockLengthOffset = RawBlockVersionOffset + Int32Size + NumOfRowsOffset = RawBlockLengthOffset + Int32Size + NumOfColsOffset = NumOfRowsOffset + Int32Size + HasColumnSegmentOffset = NumOfColsOffset + Int32Size + GroupIDOffset = HasColumnSegmentOffset + Int32Size + ColInfoOffset = GroupIDOffset + UInt64Size +) + +func RawBlockGetVersion(rawBlock unsafe.Pointer) int32 { + return *((*int32)(tools.AddPointer(rawBlock, RawBlockVersionOffset))) +} + +func RawBlockGetLength(rawBlock unsafe.Pointer) int32 { + return *((*int32)(tools.AddPointer(rawBlock, RawBlockLengthOffset))) +} + +func RawBlockGetNumOfRows(rawBlock unsafe.Pointer) int32 { + return *((*int32)(tools.AddPointer(rawBlock, NumOfRowsOffset))) +} + +func RawBlockGetNumOfCols(rawBlock unsafe.Pointer) int32 { + return *((*int32)(tools.AddPointer(rawBlock, NumOfColsOffset))) +} + +func RawBlockGetHasColumnSegment(rawBlock unsafe.Pointer) int32 { + return *((*int32)(tools.AddPointer(rawBlock, HasColumnSegmentOffset))) +} + +func RawBlockGetGroupID(rawBlock unsafe.Pointer) uint64 { + return *((*uint64)(tools.AddPointer(rawBlock, GroupIDOffset))) +} + +type RawBlockColInfo struct { + ColType int8 + Bytes int32 +} + +func RawBlockGetColInfo(rawBlock unsafe.Pointer, infos []RawBlockColInfo) { + for i := 0; i < len(infos); i++ { + offset := ColInfoOffset + ColInfoSize*uintptr(i) + infos[i].ColType = *((*int8)(tools.AddPointer(rawBlock, offset))) + infos[i].Bytes = *((*int32)(tools.AddPointer(rawBlock, offset+Int8Size))) + } +} + +func RawBlockGetColumnLengthOffset(colCount int) uintptr { + return ColInfoOffset + uintptr(colCount)*ColInfoSize +} + +func RawBlockGetColDataOffset(colCount int) uintptr { + return ColInfoOffset + uintptr(colCount)*ColInfoSize + uintptr(colCount)*Int32Size +} + +type FormatTimeFunc func(ts int64, precision int) driver.Value + +func IsVarDataType(colType uint8) bool { + return colType == common.TSDB_DATA_TYPE_BINARY || + colType == common.TSDB_DATA_TYPE_NCHAR || + colType == common.TSDB_DATA_TYPE_JSON || + colType == common.TSDB_DATA_TYPE_VARBINARY || + colType == common.TSDB_DATA_TYPE_GEOMETRY +} + +func BitmapLen(n int) int { + return ((n) + ((1 << 3) - 1)) >> 3 +} + +func BitPos(n int) int { + return n & ((1 << 3) - 1) +} + +func CharOffset(n int) int { + return n >> 3 +} + +func BMIsNull(c byte, n int) bool { + return c&(1<<(7-BitPos(n))) == (1 << (7 - BitPos(n))) +} + +type rawConvertFunc func(pStart unsafe.Pointer, row int, arg ...interface{}) driver.Value + +type rawConvertVarDataFunc func(pHeader, pStart unsafe.Pointer, row int) driver.Value + +var rawConvertFuncSlice = [15]rawConvertFunc{} + +var rawConvertVarDataSlice = [21]rawConvertVarDataFunc{} + +func ItemIsNull(pHeader unsafe.Pointer, row int) bool { + offset := CharOffset(row) + c := *((*byte)(tools.AddPointer(pHeader, uintptr(offset)))) + return BMIsNull(c, row) +} + +func rawConvertBool(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return (*((*byte)(tools.AddPointer(pStart, uintptr(row)*1)))) != 0 +} + +func rawConvertTinyint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*int8)(tools.AddPointer(pStart, uintptr(row)*Int8Size))) +} + +func rawConvertSmallint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*int16)(tools.AddPointer(pStart, uintptr(row)*Int16Size))) +} + +func rawConvertInt(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*int32)(tools.AddPointer(pStart, uintptr(row)*Int32Size))) +} + +func rawConvertBigint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*int64)(tools.AddPointer(pStart, uintptr(row)*Int64Size))) +} + +func rawConvertUTinyint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*uint8)(tools.AddPointer(pStart, uintptr(row)*UInt8Size))) +} + +func rawConvertUSmallint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*uint16)(tools.AddPointer(pStart, uintptr(row)*UInt16Size))) +} + +func rawConvertUInt(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*uint32)(tools.AddPointer(pStart, uintptr(row)*UInt32Size))) +} + +func rawConvertUBigint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*uint64)(tools.AddPointer(pStart, uintptr(row)*UInt64Size))) +} + +func rawConvertFloat(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return math.Float32frombits(*((*uint32)(tools.AddPointer(pStart, uintptr(row)*Float32Size)))) +} + +func rawConvertDouble(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return math.Float64frombits(*((*uint64)(tools.AddPointer(pStart, uintptr(row)*Float64Size)))) +} + +func rawConvertTime(pStart unsafe.Pointer, row int, arg ...interface{}) driver.Value { + if len(arg) == 1 { + return common.TimestampConvertToTime(*((*int64)(tools.AddPointer(pStart, uintptr(row)*Int64Size))), arg[0].(int)) + } else if len(arg) == 2 { + return arg[1].(FormatTimeFunc)(*((*int64)(tools.AddPointer(pStart, uintptr(row)*Int64Size))), arg[0].(int)) + } + panic("convertTime error") +} + +func rawConvertVarBinary(pHeader, pStart unsafe.Pointer, row int) driver.Value { + result := rawGetBytes(pHeader, pStart, row) + if result == nil { + return nil + } + return result +} + +func rawGetBytes(pHeader, pStart unsafe.Pointer, row int) []byte { + offset := *((*int32)(tools.AddPointer(pHeader, uintptr(row*4)))) + if offset == -1 { + return nil + } + currentRow := tools.AddPointer(pStart, uintptr(offset)) + clen := *((*uint16)(currentRow)) + if clen == 0 { + return make([]byte, 0) + } + currentRow = tools.AddPointer(currentRow, 2) + result := make([]byte, clen) + Copy(currentRow, result, 0, int(clen)) + return result +} + +func rawConvertGeometry(pHeader, pStart unsafe.Pointer, row int) driver.Value { + return rawConvertVarBinary(pHeader, pStart, row) +} + +func rawConvertBinary(pHeader, pStart unsafe.Pointer, row int) driver.Value { + result := rawGetBytes(pHeader, pStart, row) + if result == nil { + return nil + } + return *(*string)(unsafe.Pointer(&result)) +} + +func rawConvertNchar(pHeader, pStart unsafe.Pointer, row int) driver.Value { + offset := *((*int32)(tools.AddPointer(pHeader, uintptr(row*4)))) + if offset == -1 { + return nil + } + currentRow := tools.AddPointer(pStart, uintptr(offset)) + clen := *((*uint16)(currentRow)) / 4 + if clen == 0 { + return "" + } + currentRow = unsafe.Pointer(uintptr(currentRow) + 2) + binaryVal := make([]rune, clen) + + for index := uint16(0); index < clen; index++ { + binaryVal[index] = *((*rune)(unsafe.Pointer(uintptr(currentRow) + uintptr(index*4)))) + } + return string(binaryVal) +} + +func rawConvertJson(pHeader, pStart unsafe.Pointer, row int) driver.Value { + return rawConvertVarBinary(pHeader, pStart, row) +} + +func ReadBlockSimple(block unsafe.Pointer, precision int) [][]driver.Value { + blockSize := RawBlockGetNumOfRows(block) + colCount := RawBlockGetNumOfCols(block) + colInfo := make([]RawBlockColInfo, colCount) + RawBlockGetColInfo(block, colInfo) + colTypes := make([]uint8, colCount) + for i := int32(0); i < colCount; i++ { + colTypes[i] = uint8(colInfo[i].ColType) + } + return ReadBlock(block, int(blockSize), colTypes, precision) +} + +// ReadBlock in-place +func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision int) [][]driver.Value { + r := make([][]driver.Value, blockSize) + colCount := len(colTypes) + nullBitMapOffset := uintptr(BitmapLen(blockSize)) + lengthOffset := RawBlockGetColumnLengthOffset(colCount) + pHeader := tools.AddPointer(block, RawBlockGetColDataOffset(colCount)) + var pStart unsafe.Pointer + for column := 0; column < colCount; column++ { + colLength := *((*int32)(tools.AddPointer(block, lengthOffset+uintptr(column)*Int32Size))) + if IsVarDataType(colTypes[column]) { + convertF := rawConvertVarDataSlice[colTypes[column]] + pStart = tools.AddPointer(pHeader, Int32Size*uintptr(blockSize)) + for row := 0; row < blockSize; row++ { + if column == 0 { + r[row] = make([]driver.Value, colCount) + } + r[row][column] = convertF(pHeader, pStart, row) + } + } else { + convertF := rawConvertFuncSlice[colTypes[column]] + pStart = tools.AddPointer(pHeader, nullBitMapOffset) + for row := 0; row < blockSize; row++ { + if column == 0 { + r[row] = make([]driver.Value, colCount) + } + if ItemIsNull(pHeader, row) { + r[row][column] = nil + } else { + r[row][column] = convertF(pStart, row, precision) + } + } + } + pHeader = tools.AddPointer(pStart, uintptr(colLength)) + } + return r +} + +func ReadRow(dest []driver.Value, block unsafe.Pointer, blockSize int, row int, colTypes []uint8, precision int) { + colCount := len(colTypes) + nullBitMapOffset := uintptr(BitmapLen(blockSize)) + lengthOffset := RawBlockGetColumnLengthOffset(colCount) + pHeader := tools.AddPointer(block, RawBlockGetColDataOffset(colCount)) + var pStart unsafe.Pointer + for column := 0; column < colCount; column++ { + colLength := *((*int32)(tools.AddPointer(block, lengthOffset+uintptr(column)*Int32Size))) + if IsVarDataType(colTypes[column]) { + convertF := rawConvertVarDataSlice[colTypes[column]] + pStart = tools.AddPointer(pHeader, Int32Size*uintptr(blockSize)) + dest[column] = convertF(pHeader, pStart, row) + } else { + convertF := rawConvertFuncSlice[colTypes[column]] + pStart = tools.AddPointer(pHeader, nullBitMapOffset) + if ItemIsNull(pHeader, row) { + dest[column] = nil + } else { + dest[column] = convertF(pStart, row, precision) + } + } + pHeader = tools.AddPointer(pStart, uintptr(colLength)) + } +} + +func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uint8, precision int, formatFunc FormatTimeFunc) [][]driver.Value { + r := make([][]driver.Value, blockSize) + colCount := len(colTypes) + nullBitMapOffset := uintptr(BitmapLen(blockSize)) + lengthOffset := RawBlockGetColumnLengthOffset(colCount) + pHeader := tools.AddPointer(block, RawBlockGetColDataOffset(colCount)) + var pStart unsafe.Pointer + for column := 0; column < colCount; column++ { + colLength := *((*int32)(tools.AddPointer(block, lengthOffset+uintptr(column)*Int32Size))) + if IsVarDataType(colTypes[column]) { + convertF := rawConvertVarDataSlice[colTypes[column]] + pStart = tools.AddPointer(pHeader, uintptr(4*blockSize)) + for row := 0; row < blockSize; row++ { + if column == 0 { + r[row] = make([]driver.Value, colCount) + } + r[row][column] = convertF(pHeader, pStart, row) + } + } else { + convertF := rawConvertFuncSlice[colTypes[column]] + pStart = tools.AddPointer(pHeader, nullBitMapOffset) + for row := 0; row < blockSize; row++ { + if column == 0 { + r[row] = make([]driver.Value, colCount) + } + if ItemIsNull(pHeader, row) { + r[row][column] = nil + } else { + r[row][column] = convertF(pStart, row, precision, formatFunc) + } + } + } + pHeader = tools.AddPointer(pStart, uintptr(colLength)) + } + return r +} + +func ItemRawBlock(colType uint8, pHeader, pStart unsafe.Pointer, row int, precision int, timeFormat FormatTimeFunc) driver.Value { + if IsVarDataType(colType) { + return rawConvertVarDataSlice[colType](pHeader, pStart, row) + } + if ItemIsNull(pHeader, row) { + return nil + } + return rawConvertFuncSlice[colType](pStart, row, precision, timeFormat) +} + +func init() { + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_BOOL)] = rawConvertBool + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_TINYINT)] = rawConvertTinyint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_SMALLINT)] = rawConvertSmallint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_INT)] = rawConvertInt + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_BIGINT)] = rawConvertBigint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_UTINYINT)] = rawConvertUTinyint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_USMALLINT)] = rawConvertUSmallint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_UINT)] = rawConvertUInt + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_UBIGINT)] = rawConvertUBigint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_FLOAT)] = rawConvertFloat + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_DOUBLE)] = rawConvertDouble + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_TIMESTAMP)] = rawConvertTime + + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_BINARY)] = rawConvertBinary + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_NCHAR)] = rawConvertNchar + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_JSON)] = rawConvertJson + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_VARBINARY)] = rawConvertVarBinary + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_GEOMETRY)] = rawConvertGeometry +} diff --git a/driver/common/parser/block_test.go b/driver/common/parser/block_test.go new file mode 100644 index 00000000..11b2877e --- /dev/null +++ b/driver/common/parser/block_test.go @@ -0,0 +1,797 @@ +package parser + +import ( + "database/sql/driver" + "fmt" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" + "github.com/taosdata/taosadapter/v3/tools" +) + +// @author: xftan +// @date: 2023/10/13 11:13 +// @description: test block +func TestReadBlock(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer wrapper.TaosClose(conn) + defer func() { + res := wrapper.TaosQuery(conn, "drop database if exists test_block_raw_parser") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + }() + res := wrapper.TaosQuery(conn, "create database if not exists test_block_raw_parser") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + res = wrapper.TaosQuery(conn, "drop table if exists test_block_raw_parser.all_type2") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + res = wrapper.TaosQuery(conn, "create table if not exists test_block_raw_parser.all_type2 (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ")") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_block_raw_parser.all_type2 values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + sql = "select * from test_block_raw_parser.all_type2" + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := wrapper.TaosResultPrecision(res) + pHeaderList := make([]unsafe.Pointer, fileCount) + pStartList := make([]unsafe.Pointer, fileCount) + var data [][]driver.Value + for { + blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := wrapper.TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + wrapper.TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + nullBitMapOffset := uintptr(BitmapLen(blockSize)) + lengthOffset := RawBlockGetColumnLengthOffset(fileCount) + tmpPHeader := tools.AddPointer(block, RawBlockGetColDataOffset(fileCount)) + var tmpPStart unsafe.Pointer + for column := 0; column < fileCount; column++ { + colLength := *((*int32)(tools.AddPointer(block, lengthOffset+uintptr(column)*Int32Size))) + if IsVarDataType(rh.ColTypes[column]) { + pHeaderList[column] = tmpPHeader + tmpPStart = tools.AddPointer(tmpPHeader, Int32Size*uintptr(blockSize)) + pStartList[column] = tmpPStart + } else { + pHeaderList[column] = tmpPHeader + tmpPStart = tools.AddPointer(tmpPHeader, nullBitMapOffset) + pStartList[column] = tmpPStart + } + tmpPHeader = tools.AddPointer(tmpPStart, uintptr(colLength)) + } + for row := 0; row < blockSize; row++ { + rowV := make([]driver.Value, fileCount) + for column := 0; column < fileCount; column++ { + v := ItemRawBlock(rh.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, func(ts int64, precision int) driver.Value { + return common.TimestampConvertToTime(ts, precision) + }) + rowV[column] = v + } + data = append(data, rowV) + } + } + wrapper.TaosFreeResult(res) + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } +} + +// @author: xftan +// @date: 2023/10/13 11:13 +// @description: test block tag +func TestBlockTag(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer wrapper.TaosClose(conn) + defer func() { + res := wrapper.TaosQuery(conn, "drop database if exists test_block_abc1") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + }() + res := wrapper.TaosQuery(conn, "create database if not exists test_block_abc1") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "use test_block_abc1") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "create table if not exists meters(ts timestamp, v int) tags(location varchar(16))") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "create table if not exists tb1 using meters tags('abcd')") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + sql := "select distinct tbname,location from meters;" + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := wrapper.TaosResultPrecision(res) + pHeaderList := make([]unsafe.Pointer, fileCount) + pStartList := make([]unsafe.Pointer, fileCount) + var data [][]driver.Value + for { + blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := wrapper.TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + wrapper.TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + nullBitMapOffset := uintptr(BitmapLen(blockSize)) + lengthOffset := RawBlockGetColumnLengthOffset(fileCount) + tmpPHeader := tools.AddPointer(block, RawBlockGetColDataOffset(fileCount)) // length i32, group u64 + var tmpPStart unsafe.Pointer + for column := 0; column < fileCount; column++ { + colLength := *((*int32)(tools.AddPointer(block, lengthOffset+uintptr(column)*Int32Size))) + if IsVarDataType(rh.ColTypes[column]) { + pHeaderList[column] = tmpPHeader + tmpPStart = tools.AddPointer(tmpPHeader, Int32Size*uintptr(blockSize)) + pStartList[column] = tmpPStart + } else { + pHeaderList[column] = tmpPHeader + tmpPStart = tools.AddPointer(tmpPHeader, nullBitMapOffset) + pStartList[column] = tmpPStart + } + tmpPHeader = tools.AddPointer(tmpPStart, uintptr(colLength)) + } + for row := 0; row < blockSize; row++ { + rowV := make([]driver.Value, fileCount) + for column := 0; column < fileCount; column++ { + v := ItemRawBlock(rh.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, func(ts int64, precision int) driver.Value { + return common.TimestampConvertToTime(ts, precision) + }) + rowV[column] = v + } + data = append(data, rowV) + } + } + wrapper.TaosFreeResult(res) + t.Log(data) + t.Log(len(data[0][1].(string))) +} + +// @author: xftan +// @date: 2023/10/13 11:18 +// @description: test read row +func TestReadRow(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer wrapper.TaosClose(conn) + res := wrapper.TaosQuery(conn, "drop database if exists test_read_row") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + defer func() { + res = wrapper.TaosQuery(conn, "drop database if exists test_read_row") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + }() + res = wrapper.TaosQuery(conn, "create database test_read_row") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "create table if not exists test_read_row.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_read_row.t0 using test_read_row.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + sql = "select * from test_read_row.all_type" + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := wrapper.TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := wrapper.TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + wrapper.TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + for i := 0; i < blockSize; i++ { + tmp := make([]driver.Value, fileCount) + ReadRow(tmp, block, blockSize, i, rh.ColTypes, precision) + data = append(data, tmp) + } + } + wrapper.TaosFreeResult(res) + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + assert.Equal(t, []byte(`{"a":1}`), row1[14].([]byte)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) +} + +// @author: xftan +// @date: 2023/10/13 11:18 +// @description: test read block with time format +func TestReadBlockWithTimeFormat(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer wrapper.TaosClose(conn) + res := wrapper.TaosQuery(conn, "drop database if exists test_read_block_tf") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + defer func() { + res = wrapper.TaosQuery(conn, "drop database if exists test_read_block_tf") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + }() + res = wrapper.TaosQuery(conn, "create database test_read_block_tf") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "create table if not exists test_read_block_tf.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_read_block_tf.t0 using test_read_block_tf.all_type tags('{\"a\":1}') values('%s',false,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + sql = "select * from test_read_block_tf.all_type" + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := wrapper.TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := wrapper.TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + wrapper.TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + data = ReadBlockWithTimeFormat(block, blockSize, rh.ColTypes, precision, func(ts int64, precision int) driver.Value { + return common.TimestampConvertToTime(ts, precision) + }) + } + wrapper.TaosFreeResult(res) + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, false, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + assert.Equal(t, []byte(`{"a":1}`), row1[14].([]byte)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) +} + +// @author: xftan +// @date: 2023/10/13 11:18 +// @description: test parse block +func TestParseBlock(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer wrapper.TaosClose(conn) + res := wrapper.TaosQuery(conn, "drop database if exists parse_block") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + defer func() { + res = wrapper.TaosQuery(conn, "drop database if exists parse_block") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + }() + res = wrapper.TaosQuery(conn, "create database parse_block vgroups 1") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "create table if not exists parse_block.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20),"+ + "c14 varbinary(20),"+ + "c15 geometry(100)"+ + ") tags (info json)") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into parse_block.t0 using parse_block.all_type tags('{\"a\":1}') "+ + "values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar','test_varbinary','POINT(100 100)')"+ + "('%s',null,null,null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + sql = "select * from parse_block.all_type" + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + precision := wrapper.TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := wrapper.TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + wrapper.TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + version := RawBlockGetVersion(block) + t.Log(version) + length := RawBlockGetLength(block) + assert.Equal(t, int32(448), length) + rows := RawBlockGetNumOfRows(block) + assert.Equal(t, int32(2), rows) + columns := RawBlockGetNumOfCols(block) + assert.Equal(t, int32(17), columns) + hasColumnSegment := RawBlockGetHasColumnSegment(block) + assert.Equal(t, int32(-2147483648), hasColumnSegment) + groupId := RawBlockGetGroupID(block) + assert.Equal(t, uint64(0), groupId) + infos := make([]RawBlockColInfo, columns) + RawBlockGetColInfo(block, infos) + assert.Equal( + t, + []RawBlockColInfo{ + { + ColType: 9, + Bytes: 8, + }, + { + ColType: 1, + Bytes: 1, + }, + { + ColType: 2, + Bytes: 1, + }, + { + ColType: 3, + Bytes: 2, + }, + { + ColType: 4, + Bytes: 4, + }, + { + ColType: 5, + Bytes: 8, + }, + { + ColType: 11, + Bytes: 1, + }, + { + ColType: 12, + Bytes: 2, + }, + { + ColType: 13, + Bytes: 4, + }, + { + ColType: 14, + Bytes: 8, + }, + { + ColType: 6, + Bytes: 4, + }, + { + ColType: 7, + Bytes: 8, + }, + { + ColType: 8, + Bytes: 22, + }, + { + ColType: 10, + Bytes: 82, + }, + { + ColType: 16, + Bytes: 22, + }, + { + ColType: 20, + Bytes: 102, + }, + { + ColType: 15, + Bytes: 16384, + }, + }, + infos, + ) + d := ReadBlockSimple(block, precision) + data = append(data, d...) + } + wrapper.TaosFreeResult(res) + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + assert.Equal(t, []byte("test_varbinary"), row1[14].([]byte)) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, row1[15].([]byte)) + assert.Equal(t, []byte(`{"a":1}`), row1[16].([]byte)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 16; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, []byte(`{"a":1}`), row2[16].([]byte)) +} diff --git a/driver/common/parser/mem.go b/driver/common/parser/mem.go new file mode 100644 index 00000000..f0d4b000 --- /dev/null +++ b/driver/common/parser/mem.go @@ -0,0 +1,12 @@ +package parser + +import "unsafe" + +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) + +//go:linkname memmove runtime.memmove + +func Copy(source unsafe.Pointer, data []byte, index int, length int) { + memmove(unsafe.Pointer(&data[index]), source, uintptr(length)) +} diff --git a/driver/common/parser/mem.s b/driver/common/parser/mem.s new file mode 100644 index 00000000..e69de29b diff --git a/driver/common/parser/mem_test.go b/driver/common/parser/mem_test.go new file mode 100644 index 00000000..d3e244be --- /dev/null +++ b/driver/common/parser/mem_test.go @@ -0,0 +1,20 @@ +package parser + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestCopy(t *testing.T) { + data := []byte("World") + data1 := make([]byte, 10) + data1[0] = 'H' + data1[1] = 'e' + data1[2] = 'l' + data1[3] = 'l' + data1[4] = 'o' + Copy(unsafe.Pointer(&data[0]), data1, 5, 5) + assert.Equal(t, "HelloWorld", string(data1)) +} diff --git a/driver/common/parser/raw.go b/driver/common/parser/raw.go new file mode 100644 index 00000000..9a8235a3 --- /dev/null +++ b/driver/common/parser/raw.go @@ -0,0 +1,184 @@ +package parser + +import ( + "fmt" + "unsafe" + + "github.com/taosdata/taosadapter/v3/tools" +) + +type TMQRawDataParser struct { + block unsafe.Pointer + offset uintptr +} + +func NewTMQRawDataParser() *TMQRawDataParser { + return &TMQRawDataParser{} +} + +type TMQBlockInfo struct { + RawBlock unsafe.Pointer + Precision int + Schema []*TMQRawDataSchema + TableName string +} + +type TMQRawDataSchema struct { + ColType uint8 + Flag int8 + Bytes int64 + ColID int + Name string +} + +func (p *TMQRawDataParser) getTypeSkip(t int8) (int, error) { + skip := 8 + switch t { + case 1: + case 2, 3: + skip = 16 + default: + return 0, fmt.Errorf("unknown type %d", t) + } + return skip, nil +} + +func (p *TMQRawDataParser) skipHead() error { + v := p.parseInt8() + if v >= 100 { + skip := p.parseInt32() + p.skip(int(skip)) + return nil + } + skip, err := p.getTypeSkip(v) + if err != nil { + return err + } + p.skip(skip) + v = p.parseInt8() + skip, err = p.getTypeSkip(v) + if err != nil { + return err + } + p.skip(skip) + return nil +} + +func (p *TMQRawDataParser) skip(count int) { + p.offset += uintptr(count) +} + +func (p *TMQRawDataParser) parseBlockInfos() []*TMQBlockInfo { + blockNum := p.parseInt32() + blockInfos := make([]*TMQBlockInfo, blockNum) + withTableName := p.parseBool() + withSchema := p.parseBool() + for i := int32(0); i < blockNum; i++ { + blockInfo := &TMQBlockInfo{} + blockTotalLen := p.parseVariableByteInteger() + p.skip(17) + blockInfo.Precision = int(p.parseUint8()) + blockInfo.RawBlock = tools.AddPointer(p.block, p.offset) + p.skip(blockTotalLen - 18) + if withSchema { + cols := p.parseZigzagVariableByteInteger() + //version + _ = p.parseZigzagVariableByteInteger() + + blockInfo.Schema = make([]*TMQRawDataSchema, cols) + for j := 0; j < cols; j++ { + blockInfo.Schema[j] = p.parseSchema() + } + } + if withTableName { + blockInfo.TableName = p.parseName() + } + blockInfos[i] = blockInfo + } + return blockInfos +} + +func (p *TMQRawDataParser) parseZigzagVariableByteInteger() int { + return zigzagDecode(p.parseVariableByteInteger()) +} + +func (p *TMQRawDataParser) parseBool() bool { + v := *(*int8)(tools.AddPointer(p.block, p.offset)) + p.skip(1) + return v != 0 +} + +func (p *TMQRawDataParser) parseUint8() uint8 { + v := *(*uint8)(tools.AddPointer(p.block, p.offset)) + p.skip(1) + return v +} + +func (p *TMQRawDataParser) parseInt8() int8 { + v := *(*int8)(tools.AddPointer(p.block, p.offset)) + p.skip(1) + return v +} + +func (p *TMQRawDataParser) parseInt32() int32 { + v := *(*int32)(tools.AddPointer(p.block, p.offset)) + p.skip(4) + return v +} + +func (p *TMQRawDataParser) parseSchema() *TMQRawDataSchema { + colType := p.parseUint8() + flag := p.parseInt8() + bytes := int64(p.parseZigzagVariableByteInteger()) + colID := p.parseZigzagVariableByteInteger() + name := p.parseName() + return &TMQRawDataSchema{ + ColType: colType, + Flag: flag, + Bytes: bytes, + ColID: colID, + Name: name, + } +} + +func (p *TMQRawDataParser) parseName() string { + nameLen := p.parseVariableByteInteger() + name := make([]byte, nameLen-1) + for i := 0; i < nameLen-1; i++ { + name[i] = *(*byte)(tools.AddPointer(p.block, p.offset+uintptr(i))) + } + p.skip(nameLen) + return string(name) +} + +func (p *TMQRawDataParser) Parse(block unsafe.Pointer) ([]*TMQBlockInfo, error) { + p.reset(block) + err := p.skipHead() + if err != nil { + return nil, err + } + return p.parseBlockInfos(), nil +} + +func (p *TMQRawDataParser) reset(block unsafe.Pointer) { + p.block = block + p.offset = 0 +} + +func (p *TMQRawDataParser) parseVariableByteInteger() int { + multiplier := 1 + value := 0 + for { + encodedByte := p.parseUint8() + value += int(encodedByte&127) * multiplier + if encodedByte&128 == 0 { + break + } + multiplier *= 128 + } + return value +} + +func zigzagDecode(n int) int { + return (n >> 1) ^ (-(n & 1)) +} diff --git a/driver/common/parser/raw_test.go b/driver/common/parser/raw_test.go new file mode 100644 index 00000000..521a6260 --- /dev/null +++ b/driver/common/parser/raw_test.go @@ -0,0 +1,1049 @@ +package parser + +import ( + "database/sql/driver" + "fmt" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x01, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x01, 0x00, 0x00, 0x00, + + 0x01, + 0x01, + + 0xc5, 0x01, + + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x02, + + 0x02, 0x00, 0x00, 0x00, + 0xb3, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x06, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x82, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x5c, 0x00, 0x00, 0x00, + + 0x00, + 0xc0, 0xed, 0x82, 0x05, 0xc3, 0x1b, 0xab, 0x17, + + 0x80, + 0x00, 0x00, 0x00, 0x00, + + 0x80, + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x5a, 0x00, + 0x61, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x61, + + 0x08, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x03, + 0x63, 0x31, 0x00, + + 0x06, + 0x01, + 0x08, + 0x06, + 0x03, + 0x63, 0x32, 0x00, + + 0x08, + 0x01, + 0x84, 0x02, + 0x08, + 0x03, 0x63, 0x33, 0x00, + + 0x05, + 0x63, 0x74, 0x62, 0x30, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 1, len(blockInfos)) + assert.Equal(t, 2, blockInfos[0].Precision) + assert.Equal(t, 4, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 2, + Name: "c1", + }, + { + ColType: 6, + Flag: 1, + Bytes: 4, + ColID: 3, + Name: "c2", + }, + { + ColType: 8, + Flag: 1, + Bytes: 130, + ColID: 4, + Name: "c3", + }, + }, blockInfos[0].Schema) + assert.Equal(t, "ctb0", blockInfos[0].TableName) +} + +func TestParseTwoBlock(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x02, 0x00, 0x00, 0x00, + + 0x00, // withTbName false + 0x01, // withSchema true + + 0x60, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x4e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x0c, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x00, + 0xf8, 0x6b, 0x75, 0x35, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, + 0x63, 0x74, 0x30, + + 0x06, + 0x00, + + 0x09, + 0x00, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x00, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x08, + 0x00, + 0x18, + 0x06, + 0x02, + 0x6e, 0x00, + + 0x60, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x4e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x0c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x00, + 0xf9, 0x6b, 0x75, 0x35, + 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x01, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, + 0x63, 0x74, 0x31, + + 0x06, + 0x00, + + 0x09, + 0x00, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x00, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x08, + 0x00, + 0x18, + 0x06, + 0x02, + 0x6e, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 2, len(blockInfos)) + assert.Equal(t, 0, blockInfos[0].Precision) + assert.Equal(t, 0, blockInfos[1].Precision) + assert.Equal(t, 3, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 0, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 0, + Bytes: 4, + ColID: 2, + Name: "v", + }, + { + ColType: 8, + Flag: 0, + Bytes: 12, + ColID: 3, + Name: "n", + }, + }, blockInfos[0].Schema) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 0, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 0, + Bytes: 4, + ColID: 2, + Name: "v", + }, + { + ColType: 8, + Flag: 0, + Bytes: 12, + ColID: 3, + Name: "n", + }, + }, blockInfos[1].Schema) + assert.Equal(t, "", blockInfos[0].TableName) + assert.Equal(t, "", blockInfos[1].TableName) +} + +func TestParseTenBlock(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + 0x0d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, + 0x01, + 0x01, + + // block1 + 0x4e, + + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x01, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x31, 0x00, + + //block2 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + 0x00, + 0x02, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x32, 0x00, + + //block3 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x03, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x33, 0x00, + + //block4 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x34, 0x00, + + // block5 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x35, 0x00, + + //block6 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x06, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x36, 0x00, + + //block7 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x07, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x37, 0x00, + + //block8 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x08, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x38, 0x00, + + //block9 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x09, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x39, 0x00, + + //block10 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + 0x00, + 0x0a, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + 0x04, + 0x74, 0x31, 0x30, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 10, len(blockInfos)) + for i := 0; i < 10; i++ { + assert.Equal(t, 0, blockInfos[i].Precision) + assert.Equal(t, 2, len(blockInfos[i].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 2, + Name: "v", + }, + }, blockInfos[i].Schema) + assert.Equal(t, fmt.Sprintf("t%d", i+1), blockInfos[i].TableName) + value := ReadBlockSimple(blockInfos[i].RawBlock, blockInfos[i].Precision) + ts := time.Unix(0, 1706081119570000000).Local() + assert.Equal(t, [][]driver.Value{{ts, int32(i + 1)}}, value) + } +} + +func TestVersion100Block(t *testing.T) { + data := []byte{ + 0x64, //version + 0x12, 0x00, 0x00, 0x00, // skip 18 bytes + 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, //block count 1 + + 0x01, // with table name + 0x01, // with schema + + 0x92, 0x02, // block length 274 + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, // 256 + 0x01, 0x00, 0x00, 0x00, // rows + 0x0e, 0x00, 0x00, 0x00, // cols + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x05, 0x08, 0x00, 0x00, 0x00, + 0x0b, 0x01, 0x00, 0x00, 0x00, + 0x0c, 0x02, 0x00, 0x00, 0x00, + 0x0d, 0x04, 0x00, 0x00, 0x00, + 0x0e, 0x08, 0x00, 0x00, 0x00, + 0x06, 0x04, 0x00, 0x00, 0x00, + 0x07, 0x08, 0x00, 0x00, 0x00, + 0x08, 0x16, 0x00, 0x00, 0x00, + 0x0a, 0x52, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, + + 0x00, + 0x9e, 0x37, 0x6a, 0x04, 0x8f, 0x01, 0x00, 0x00, + + 0x00, + 0x01, + + 0x00, + 0x02, + + 0x00, + 0x03, 0x00, + + 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x00, + 0x06, + + 0x00, + 0x07, 0x00, + + 0x00, + 0x08, 0x00, 0x00, 0x00, + + 0x00, + 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x00, + 0xcf, 0xf7, 0x21, 0x41, + + 0x00, + 0xe5, 0xd0, 0x22, 0xdb, 0xf9, 0x3e, 0x26, 0x40, + + 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + 0x00, 0x00, 0x00, 0x00, + 0x14, 0x00, + 0x6e, 0x00, 0x00, 0x00, + 0x63, 0x00, 0x00, 0x00, + 0x68, 0x00, 0x00, 0x00, + 0x61, 0x00, 0x00, 0x00, + 0x72, 0x00, 0x00, 0x00, + + 0x00, // + + 0x1c, // cols 14 + 0x00, // version + + // col meta + 0x09, 0x01, 0x10, 0x02, 0x03, 0x74, 0x73, 0x00, + 0x01, 0x01, 0x02, 0x04, 0x03, 0x63, 0x31, 0x00, + 0x02, 0x01, 0x02, 0x06, 0x03, 0x63, 0x32, 0x00, + 0x03, 0x01, 0x04, 0x08, 0x03, 0x63, 0x33, 0x00, + 0x04, 0x01, 0x08, 0x0a, 0x03, 0x63, 0x34, 0x00, + 0x05, 0x01, 0x10, 0x0c, 0x03, 0x63, 0x35, 0x00, + 0x0b, 0x01, 0x02, 0x0e, 0x03, 0x63, 0x36, 0x00, + 0x0c, 0x01, 0x04, 0x10, 0x03, 0x63, 0x37, 0x00, + 0x0d, 0x01, 0x08, 0x12, 0x03, 0x63, 0x38, 0x00, + 0x0e, 0x01, 0x10, 0x14, 0x03, 0x63, 0x39, 0x00, + 0x06, 0x01, 0x08, 0x16, 0x04, 0x63, 0x31, 0x30, 0x00, + 0x07, 0x01, 0x10, 0x18, 0x04, 0x63, 0x31, 0x31, 0x00, + 0x08, 0x01, 0x2c, 0x1a, 0x04, 0x63, 0x31, 0x32, 0x00, + 0x0a, 0x01, 0xa4, 0x01, 0x1c, 0x04, 0x63, 0x31, 0x33, 0x00, + + 0x06, // table name + 0x74, 0x5f, 0x61, 0x6c, 0x6c, 0x00, + // sleep time + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 1, len(blockInfos)) + assert.Equal(t, 0, blockInfos[0].Precision) + assert.Equal(t, 14, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 1, + Flag: 1, + Bytes: 1, + ColID: 2, + Name: "c1", + }, + { + ColType: 2, + Flag: 1, + Bytes: 1, + ColID: 3, + Name: "c2", + }, + { + ColType: 3, + Flag: 1, + Bytes: 2, + ColID: 4, + Name: "c3", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 5, + Name: "c4", + }, + { + ColType: 5, + Flag: 1, + Bytes: 8, + ColID: 6, + Name: "c5", + }, + { + ColType: 11, + Flag: 1, + Bytes: 1, + ColID: 7, + Name: "c6", + }, + { + ColType: 12, + Flag: 1, + Bytes: 2, + ColID: 8, + Name: "c7", + }, + { + ColType: 13, + Flag: 1, + Bytes: 4, + ColID: 9, + Name: "c8", + }, + { + ColType: 14, + Flag: 1, + Bytes: 8, + ColID: 10, + Name: "c9", + }, + { + ColType: 6, + Flag: 1, + Bytes: 4, + ColID: 11, + Name: "c10", + }, + { + ColType: 7, + Flag: 1, + Bytes: 8, + ColID: 12, + Name: "c11", + }, + { + ColType: 8, + Flag: 1, + Bytes: 22, + ColID: 13, + Name: "c12", + }, + { + ColType: 10, + Flag: 1, + Bytes: 82, + ColID: 14, + Name: "c13", + }, + }, blockInfos[0].Schema) + assert.Equal(t, "t_all", blockInfos[0].TableName) + value := ReadBlockSimple(blockInfos[0].RawBlock, blockInfos[0].Precision) + expect := []driver.Value{ + time.Unix(0, 1713766021022000000).Local(), + true, + int8(2), + int16(3), + int32(4), + int64(5), + uint8(6), + uint16(7), + uint32(8), + uint64(9), + float32(10.123), + float64(11.123), + "binary", + "nchar", + } + assert.Equal(t, [][]driver.Value{expect}, value) +} diff --git a/driver/common/serializer/block.go b/driver/common/serializer/block.go new file mode 100644 index 00000000..ffaf7798 --- /dev/null +++ b/driver/common/serializer/block.go @@ -0,0 +1,552 @@ +package serializer + +import ( + "bytes" + "errors" + "math" + + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/param" + taosTypes "github.com/taosdata/taosadapter/v3/driver/types" +) + +const ( + Int16Size = int(common.Int16Size) + Int32Size = int(common.Int32Size) + Int64Size = int(common.Int64Size) + UInt16Size = int(common.UInt16Size) + UInt32Size = int(common.UInt32Size) + UInt64Size = int(common.UInt64Size) + Float32Size = int(common.Float32Size) + Float64Size = int(common.Float64Size) +) + +func BitmapLen(n int) int { + return ((n) + ((1 << 3) - 1)) >> 3 +} + +func BitPos(n int) int { + return n & ((1 << 3) - 1) +} + +func CharOffset(n int) int { + return n >> 3 +} + +func BMSetNull(c byte, n int) byte { + return c + (1 << (7 - BitPos(n))) +} + +var ErrColumnNumberNotMatch = errors.New("number of columns does not match") +var ErrDataTypeWrong = errors.New("wrong data type") + +func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte, error) { + columns := len(params) + rows := len(params[0].GetValues()) + colTypes, err := colType.GetValue() + if err != nil { + return nil, err + } + if len(colTypes) != columns { + return nil, ErrColumnNumberNotMatch + } + var block []byte + //version int32 + block = appendUint32(block, uint32(1)) + //length int32 + block = appendUint32(block, uint32(0)) + //rows int32 + block = appendUint32(block, uint32(rows)) + //columns int32 + block = appendUint32(block, uint32(columns)) + //flagSegment int32 + block = appendUint32(block, uint32(0)) + //groupID uint64 + block = appendUint64(block, uint64(0)) + colInfoData := make([]byte, 0, 5*columns) + lengthData := make([]byte, 0, 4*columns) + bitMapLen := BitmapLen(rows) + var data []byte + //colInfo(type+bytes) (int8+int32) * columns + buffer := bytes.NewBuffer(block) + for colIndex := 0; colIndex < columns; colIndex++ { + switch colTypes[colIndex].Type { + case taosTypes.TaosBoolType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_BOOL) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_BOOL] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosBool) + if !is { + return nil, ErrDataTypeWrong + } + if v { + dataTmp[rowIndex+bitMapLen] = 1 + } + } + } + data = append(data, dataTmp...) + case taosTypes.TaosTinyintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_TINYINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_TINYINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosTinyint) + if !is { + return nil, ErrDataTypeWrong + } + dataTmp[rowIndex+bitMapLen] = byte(v) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosSmallintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_SMALLINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_SMALLINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Int16Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosSmallint) + if !is { + return nil, ErrDataTypeWrong + } + offset := rowIndex*Int16Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosIntType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_INT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_INT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Int32Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosInt) + if !is { + return nil, ErrDataTypeWrong + } + offset := rowIndex*Int32Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + dataTmp[offset+2] = byte(v >> 16) + dataTmp[offset+3] = byte(v >> 24) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosBigintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_BIGINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_BIGINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Int64Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosBigint) + if !is { + return nil, ErrDataTypeWrong + } + offset := rowIndex*Int64Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + dataTmp[offset+2] = byte(v >> 16) + dataTmp[offset+3] = byte(v >> 24) + dataTmp[offset+4] = byte(v >> 32) + dataTmp[offset+5] = byte(v >> 40) + dataTmp[offset+6] = byte(v >> 48) + dataTmp[offset+7] = byte(v >> 56) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosUTinyintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_UTINYINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_UTINYINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosUTinyint) + if !is { + return nil, ErrDataTypeWrong + } + dataTmp[rowIndex+bitMapLen] = uint8(v) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosUSmallintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_USMALLINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_USMALLINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*UInt16Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosUSmallint) + if !is { + return nil, ErrDataTypeWrong + } + offset := rowIndex*UInt16Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosUIntType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_UINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_UINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*UInt32Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosUInt) + if !is { + return nil, ErrDataTypeWrong + } + offset := rowIndex*UInt32Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + dataTmp[offset+2] = byte(v >> 16) + dataTmp[offset+3] = byte(v >> 24) + } + } + data = append(data, dataTmp...) + + case taosTypes.TaosUBigintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_UBIGINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_UBIGINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*UInt64Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosUBigint) + if !is { + return nil, ErrDataTypeWrong + } + offset := rowIndex*UInt64Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + dataTmp[offset+2] = byte(v >> 16) + dataTmp[offset+3] = byte(v >> 24) + dataTmp[offset+4] = byte(v >> 32) + dataTmp[offset+5] = byte(v >> 40) + dataTmp[offset+6] = byte(v >> 48) + dataTmp[offset+7] = byte(v >> 56) + } + } + data = append(data, dataTmp...) + + case taosTypes.TaosFloatType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_FLOAT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_FLOAT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Float32Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosFloat) + if !is { + return nil, ErrDataTypeWrong + } + offset := rowIndex*Float32Size + bitMapLen + vv := math.Float32bits(float32(v)) + dataTmp[offset] = byte(vv) + dataTmp[offset+1] = byte(vv >> 8) + dataTmp[offset+2] = byte(vv >> 16) + dataTmp[offset+3] = byte(vv >> 24) + } + } + data = append(data, dataTmp...) + + case taosTypes.TaosDoubleType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_DOUBLE) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_DOUBLE] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Float64Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosDouble) + if !is { + return nil, ErrDataTypeWrong + } + offset := rowIndex*Float64Size + bitMapLen + vv := math.Float64bits(float64(v)) + dataTmp[offset] = byte(vv) + dataTmp[offset+1] = byte(vv >> 8) + dataTmp[offset+2] = byte(vv >> 16) + dataTmp[offset+3] = byte(vv >> 24) + dataTmp[offset+4] = byte(vv >> 32) + dataTmp[offset+5] = byte(vv >> 40) + dataTmp[offset+6] = byte(vv >> 48) + dataTmp[offset+7] = byte(vv >> 56) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosBinaryType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_BINARY) + colInfoData = appendUint32(colInfoData, uint32(0)) + length := 0 + dataTmp := make([]byte, Int32Size*rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + offset := Int32Size * rowIndex + if rowData[rowIndex] == nil { + for i := 0; i < Int32Size; i++ { + // -1 + dataTmp[offset+i] = byte(255) + } + } else { + v, is := rowData[rowIndex].(taosTypes.TaosBinary) + if !is { + return nil, ErrDataTypeWrong + } + for i := 0; i < Int32Size; i++ { + dataTmp[offset+i] = byte(length >> (8 * i)) + } + dataTmp = appendUint16(dataTmp, uint16(len(v))) + dataTmp = append(dataTmp, v...) + length += len(v) + Int16Size + } + } + lengthData = appendUint32(lengthData, uint32(length)) + data = append(data, dataTmp...) + case taosTypes.TaosVarBinaryType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_VARBINARY) + colInfoData = appendUint32(colInfoData, uint32(0)) + length := 0 + dataTmp := make([]byte, Int32Size*rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + offset := Int32Size * rowIndex + if rowData[rowIndex] == nil { + for i := 0; i < Int32Size; i++ { + // -1 + dataTmp[offset+i] = byte(255) + } + } else { + v, is := rowData[rowIndex].(taosTypes.TaosVarBinary) + if !is { + return nil, ErrDataTypeWrong + } + for i := 0; i < Int32Size; i++ { + dataTmp[offset+i] = byte(length >> (8 * i)) + } + dataTmp = appendUint16(dataTmp, uint16(len(v))) + dataTmp = append(dataTmp, v...) + length += len(v) + Int16Size + } + } + lengthData = appendUint32(lengthData, uint32(length)) + data = append(data, dataTmp...) + case taosTypes.TaosGeometryType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_GEOMETRY) + colInfoData = appendUint32(colInfoData, uint32(0)) + length := 0 + dataTmp := make([]byte, Int32Size*rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + offset := Int32Size * rowIndex + if rowData[rowIndex] == nil { + for i := 0; i < Int32Size; i++ { + // -1 + dataTmp[offset+i] = byte(255) + } + } else { + v, is := rowData[rowIndex].(taosTypes.TaosGeometry) + if !is { + return nil, ErrDataTypeWrong + } + for i := 0; i < Int32Size; i++ { + dataTmp[offset+i] = byte(length >> (8 * i)) + } + dataTmp = appendUint16(dataTmp, uint16(len(v))) + dataTmp = append(dataTmp, v...) + length += len(v) + Int16Size + } + } + lengthData = appendUint32(lengthData, uint32(length)) + data = append(data, dataTmp...) + case taosTypes.TaosNcharType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_NCHAR) + colInfoData = appendUint32(colInfoData, uint32(0)) + length := 0 + dataTmp := make([]byte, Int32Size*rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + offset := Int32Size * rowIndex + if rowData[rowIndex] == nil { + for i := 0; i < Int32Size; i++ { + // -1 + dataTmp[offset+i] = byte(255) + } + } else { + v, is := rowData[rowIndex].(taosTypes.TaosNchar) + if !is { + return nil, ErrDataTypeWrong + } + for i := 0; i < Int32Size; i++ { + dataTmp[offset+i] = byte(length >> (8 * i)) + } + rs := []rune(v) + dataTmp = appendUint16(dataTmp, uint16(len(rs)*4)) + for _, r := range rs { + dataTmp = appendUint32(dataTmp, uint32(r)) + } + length += len(rs)*4 + Int16Size + } + } + lengthData = appendUint32(lengthData, uint32(length)) + data = append(data, dataTmp...) + case taosTypes.TaosTimestampType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_TIMESTAMP) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_TIMESTAMP] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Int64Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosTimestamp) + if !is { + return nil, ErrDataTypeWrong + } + vv := common.TimeToTimestamp(v.T, v.Precision) + offset := rowIndex*Int64Size + bitMapLen + dataTmp[offset] = byte(vv) + dataTmp[offset+1] = byte(vv >> 8) + dataTmp[offset+2] = byte(vv >> 16) + dataTmp[offset+3] = byte(vv >> 24) + dataTmp[offset+4] = byte(vv >> 32) + dataTmp[offset+5] = byte(vv >> 40) + dataTmp[offset+6] = byte(vv >> 48) + dataTmp[offset+7] = byte(vv >> 56) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosJsonType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_JSON) + colInfoData = appendUint32(colInfoData, uint32(0)) + length := 0 + dataTmp := make([]byte, Int32Size*rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + offset := Int32Size * rowIndex + if rowData[rowIndex] == nil { + for i := 0; i < Int32Size; i++ { + // -1 + dataTmp[offset+i] = byte(255) + } + } else { + v, is := rowData[rowIndex].(taosTypes.TaosJson) + if !is { + return nil, ErrDataTypeWrong + } + for i := 0; i < Int32Size; i++ { + dataTmp[offset+i] = byte(length >> (8 * i)) + } + dataTmp = appendUint16(dataTmp, uint16(len(v))) + dataTmp = append(dataTmp, v...) + length += len(v) + Int16Size + } + } + lengthData = appendUint32(lengthData, uint32(length)) + data = append(data, dataTmp...) + } + } + buffer.Write(colInfoData) + buffer.Write(lengthData) + buffer.Write(data) + block = buffer.Bytes() + for i := 0; i < Int32Size; i++ { + block[4+i] = byte(len(block) >> (8 * i)) + } + return block, nil +} + +func appendUint16(b []byte, v uint16) []byte { + return append(b, + byte(v), + byte(v>>8), + ) +} + +func appendUint32(b []byte, v uint32) []byte { + return append(b, + byte(v), + byte(v>>8), + byte(v>>16), + byte(v>>24), + ) +} + +func appendUint64(b []byte, v uint64) []byte { + return append(b, + byte(v), + byte(v>>8), + byte(v>>16), + byte(v>>24), + byte(v>>32), + byte(v>>40), + byte(v>>48), + byte(v>>56), + ) +} diff --git a/driver/common/serializer/block_test.go b/driver/common/serializer/block_test.go new file mode 100644 index 00000000..db4bb1ff --- /dev/null +++ b/driver/common/serializer/block_test.go @@ -0,0 +1,397 @@ +package serializer + +import ( + "math" + "reflect" + "testing" + "time" + + "github.com/taosdata/taosadapter/v3/driver/common/param" +) + +// @author: xftan +// @date: 2023/10/13 11:19 +// @description: test block +func TestSerializeRawBlock(t *testing.T) { + type args struct { + params []*param.Param + colType *param.ColumnType + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + name: "all type", + args: args{ + params: []*param.Param{ + param.NewParam(1).AddTimestamp(time.Unix(0, 0), 0), + param.NewParam(1).AddBool(true), + param.NewParam(1).AddTinyint(127), + param.NewParam(1).AddSmallint(32767), + param.NewParam(1).AddInt(2147483647), + param.NewParam(1).AddBigint(9223372036854775807), + param.NewParam(1).AddUTinyint(255), + param.NewParam(1).AddUSmallint(65535), + param.NewParam(1).AddUInt(4294967295), + param.NewParam(1).AddUBigint(18446744073709551615), + param.NewParam(1).AddFloat(math.MaxFloat32), + param.NewParam(1).AddDouble(math.MaxFloat64), + param.NewParam(1).AddBinary([]byte("ABC")), + param.NewParam(1).AddNchar("涛思数据"), + }, + colType: param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0), + }, + want: []byte{ + 0x01, 0x00, 0x00, 0x00, //version + 0xf8, 0x00, 0x00, 0x00, //length + 0x01, 0x00, 0x00, 0x00, //rows + 0x0e, 0x00, 0x00, 0x00, //columns + 0x00, 0x00, 0x00, 0x00, //flagSegment + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //groupID + //types + 0x09, 0x08, 0x00, 0x00, 0x00, //1 + 0x01, 0x01, 0x00, 0x00, 0x00, //2 + 0x02, 0x01, 0x00, 0x00, 0x00, //3 + 0x03, 0x02, 0x00, 0x00, 0x00, //4 + 0x04, 0x04, 0x00, 0x00, 0x00, //5 + 0x05, 0x08, 0x00, 0x00, 0x00, //6 + 0x0b, 0x01, 0x00, 0x00, 0x00, //7 + 0x0c, 0x02, 0x00, 0x00, 0x00, //8 + 0x0d, 0x04, 0x00, 0x00, 0x00, //9 + 0x0e, 0x08, 0x00, 0x00, 0x00, //10 + 0x06, 0x04, 0x00, 0x00, 0x00, //11 + 0x07, 0x08, 0x00, 0x00, 0x00, //12 + 0x08, 0x00, 0x00, 0x00, 0x00, //13 + 0x0a, 0x00, 0x00, 0x00, 0x00, //14 + //lengths + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, + 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //ts + 0x00, + 0x01, //bool + 0x00, + 0x7f, //i8 + 0x00, + 0xff, 0x7f, //i16 + 0x00, + 0xff, 0xff, 0xff, 0x7f, //i32 + 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, //i64 + 0x00, + 0xff, //u8 + 0x00, + 0xff, 0xff, //u16 + 0x00, + 0xff, 0xff, 0xff, 0xff, //u32 + 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, //u64 + 0x00, + 0xff, 0xff, 0x7f, 0x7f, //f32 + 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xef, 0x7f, //f64 + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, //binary + 0x41, 0x42, 0x43, + 0x00, 0x00, 0x00, 0x00, + 0x10, 0x00, //nchar + 0x9b, 0x6d, 0x00, 0x00, 0x1d, 0x60, 0x00, 0x00, 0x70, 0x65, 0x00, 0x00, 0x6e, 0x63, 0x00, 0x00, + }, + wantErr: false, + }, + { + name: "all with nil", + args: args{ + params: []*param.Param{ + param.NewParam(3).AddTimestamp(time.Unix(1666248065, 0), 0).AddNull().AddTimestamp(time.Unix(1666248067, 0), 0), + param.NewParam(3).AddBool(true).AddNull().AddBool(true), + param.NewParam(3).AddTinyint(1).AddNull().AddTinyint(1), + param.NewParam(3).AddSmallint(1).AddNull().AddSmallint(1), + param.NewParam(3).AddInt(1).AddNull().AddInt(1), + param.NewParam(3).AddBigint(1).AddNull().AddBigint(1), + param.NewParam(3).AddUTinyint(1).AddNull().AddUTinyint(1), + param.NewParam(3).AddUSmallint(1).AddNull().AddUSmallint(1), + param.NewParam(3).AddUInt(1).AddNull().AddUInt(1), + param.NewParam(3).AddUBigint(1).AddNull().AddUBigint(1), + param.NewParam(3).AddFloat(1).AddNull().AddFloat(1), + param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), + param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), + param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + param.NewParam(3).AddVarBinary([]byte("test_varbinary")).AddNull().AddVarBinary([]byte("test_varbinary")), + param.NewParam(3).AddGeometry([]byte{ + 0x01, + 0x01, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x59, + 0x40, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x59, + 0x40, + }).AddNull().AddGeometry([]byte{ + 0x01, + 0x01, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x59, + 0x40, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x59, + 0x40, + }), + param.NewParam(3).AddJson([]byte("{\"a\":1}")).AddNull().AddJson([]byte("{\"a\":1}")), + }, + colType: param.NewColumnType(17). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0). + AddVarBinary(0). + AddGeometry(0). + AddJson(0), + }, + want: []byte{ + 0x01, 0x00, 0x00, 0x00, + 0x64, 0x02, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + //types + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x05, 0x08, 0x00, 0x00, 0x00, + 0x0b, 0x01, 0x00, 0x00, 0x00, + 0x0c, 0x02, 0x00, 0x00, 0x00, + 0x0d, 0x04, 0x00, 0x00, 0x00, + 0x0e, 0x08, 0x00, 0x00, 0x00, + 0x06, 0x04, 0x00, 0x00, 0x00, + 0x07, 0x08, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x00, + //lengths + 0x18, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, + 0x1a, 0x00, 0x00, 0x00, + 0x54, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, + 0x2e, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, + // ts + 0x40, + 0xe8, 0xbf, 0x1f, 0xf4, 0x83, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xb8, 0xc7, 0x1f, 0xf4, 0x83, 0x01, 0x00, 0x00, + + // bool + 0x40, + 0x01, + 0x00, + 0x01, + + // i8 + 0x40, + 0x01, + 0x00, + 0x01, + + //int16 + 0x40, + 0x01, 0x00, + 0x00, 0x00, + 0x01, 0x00, + + //int32 + 0x40, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + + //int64 + 0x40, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + //uint8 + 0x40, + 0x01, + 0x00, + 0x01, + + //uint16 + 0x40, + 0x01, 0x00, + 0x00, 0x00, + 0x01, 0x00, + + //uint32 + 0x40, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + + //uint64 + 0x40, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + //float + 0x40, + 0x00, 0x00, 0x80, 0x3f, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x80, 0x3f, + + //double + 0x40, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, + + //binary + 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x0d, 0x00, 0x00, 0x00, + 0x0b, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + 0x0b, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + //nchar + 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x2a, 0x00, 0x00, 0x00, + 0x28, 0x00, + 0x74, 0x00, 0x00, 0x00, 0x65, 0x00, 0x00, 0x00, 0x73, 0x00, + 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x00, 0x00, + 0x6e, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x68, 0x00, + 0x00, 0x00, 0x61, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, + 0x28, 0x00, + 0x74, 0x00, 0x00, 0x00, 0x65, 0x00, 0x00, 0x00, 0x73, 0x00, + 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x00, 0x00, + 0x6e, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x68, 0x00, + 0x00, 0x00, 0x61, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, + + //varbinary + 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, + 0x0e, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + 0x0e, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + //geometry + 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x17, 0x00, 0x00, 0x00, + 0x15, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x15, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + + //json + 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x09, 0x00, 0x00, 0x00, + 0x07, 0x00, + 0x7b, 0x22, 0x61, 0x22, 0x3a, 0x31, 0x7d, + 0x07, 0x00, + 0x7b, 0x22, 0x61, 0x22, 0x3a, 0x31, 0x7d, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := SerializeRawBlock(tt.args.params, tt.args.colType) + if (err != nil) != tt.wantErr { + t.Errorf("SerializeRawBlock() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SerializeRawBlock() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/driver/common/stmt/field.go b/driver/common/stmt/field.go new file mode 100644 index 00000000..c15ab0ed --- /dev/null +++ b/driver/common/stmt/field.go @@ -0,0 +1,73 @@ +package stmt + +import ( + "database/sql/driver" + "fmt" + + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/types" +) + +type StmtField struct { + Name string `json:"name"` + FieldType int8 `json:"field_type"` + Precision uint8 `json:"precision"` + Scale uint8 `json:"scale"` + Bytes int32 `json:"bytes"` +} + +func (s *StmtField) GetType() (*types.ColumnType, error) { + switch s.FieldType { + case common.TSDB_DATA_TYPE_BOOL: + return &types.ColumnType{Type: types.TaosBoolType}, nil + case common.TSDB_DATA_TYPE_TINYINT: + return &types.ColumnType{Type: types.TaosTinyintType}, nil + case common.TSDB_DATA_TYPE_SMALLINT: + return &types.ColumnType{Type: types.TaosSmallintType}, nil + case common.TSDB_DATA_TYPE_INT: + return &types.ColumnType{Type: types.TaosIntType}, nil + case common.TSDB_DATA_TYPE_BIGINT: + return &types.ColumnType{Type: types.TaosBigintType}, nil + case common.TSDB_DATA_TYPE_UTINYINT: + return &types.ColumnType{Type: types.TaosUTinyintType}, nil + case common.TSDB_DATA_TYPE_USMALLINT: + return &types.ColumnType{Type: types.TaosUSmallintType}, nil + case common.TSDB_DATA_TYPE_UINT: + return &types.ColumnType{Type: types.TaosUIntType}, nil + case common.TSDB_DATA_TYPE_UBIGINT: + return &types.ColumnType{Type: types.TaosUBigintType}, nil + case common.TSDB_DATA_TYPE_FLOAT: + return &types.ColumnType{Type: types.TaosFloatType}, nil + case common.TSDB_DATA_TYPE_DOUBLE: + return &types.ColumnType{Type: types.TaosDoubleType}, nil + case common.TSDB_DATA_TYPE_BINARY: + return &types.ColumnType{Type: types.TaosBinaryType}, nil + case common.TSDB_DATA_TYPE_VARBINARY: + return &types.ColumnType{Type: types.TaosVarBinaryType}, nil + case common.TSDB_DATA_TYPE_NCHAR: + return &types.ColumnType{Type: types.TaosNcharType}, nil + case common.TSDB_DATA_TYPE_TIMESTAMP: + return &types.ColumnType{Type: types.TaosTimestampType}, nil + case common.TSDB_DATA_TYPE_JSON: + return &types.ColumnType{Type: types.TaosJsonType}, nil + case common.TSDB_DATA_TYPE_GEOMETRY: + return &types.ColumnType{Type: types.TaosGeometryType}, nil + } + return nil, fmt.Errorf("unsupported type: %d, name %s", s.FieldType, s.Name) +} + +//revive:disable +const ( + TAOS_FIELD_COL = iota + 1 + TAOS_FIELD_TAG + TAOS_FIELD_QUERY + TAOS_FIELD_TBNAME +) + +//revive:enable + +type TaosStmt2BindData struct { + TableName string + Tags []driver.Value // row format + Cols [][]driver.Value // column format +} diff --git a/driver/common/stmt/field_test.go b/driver/common/stmt/field_test.go new file mode 100644 index 00000000..9d1271bc --- /dev/null +++ b/driver/common/stmt/field_test.go @@ -0,0 +1,143 @@ +package stmt + +import ( + "testing" + + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/types" +) + +func TestGetType(t *testing.T) { + tests := []struct { + name string + fieldType int8 + want *types.ColumnType + wantErr bool + }{ + { + name: "Test Bool Type", + fieldType: common.TSDB_DATA_TYPE_BOOL, + want: &types.ColumnType{Type: types.TaosBoolType}, + wantErr: false, + }, + { + name: "Test TinyInt Type", + fieldType: common.TSDB_DATA_TYPE_TINYINT, + want: &types.ColumnType{Type: types.TaosTinyintType}, + wantErr: false, + }, + { + name: "Test SmallInt Type", + fieldType: common.TSDB_DATA_TYPE_SMALLINT, + want: &types.ColumnType{Type: types.TaosSmallintType}, + wantErr: false, + }, + { + name: "Test Int Type", + fieldType: common.TSDB_DATA_TYPE_INT, + want: &types.ColumnType{Type: types.TaosIntType}, + wantErr: false, + }, + { + name: "Test BigInt Type", + fieldType: common.TSDB_DATA_TYPE_BIGINT, + want: &types.ColumnType{Type: types.TaosBigintType}, + wantErr: false, + }, + { + name: "Test UTinyInt Type", + fieldType: common.TSDB_DATA_TYPE_UTINYINT, + want: &types.ColumnType{Type: types.TaosUTinyintType}, + wantErr: false, + }, + { + name: "Test USmallInt Type", + fieldType: common.TSDB_DATA_TYPE_USMALLINT, + want: &types.ColumnType{Type: types.TaosUSmallintType}, + wantErr: false, + }, + { + name: "Test UInt Type", + fieldType: common.TSDB_DATA_TYPE_UINT, + want: &types.ColumnType{Type: types.TaosUIntType}, + wantErr: false, + }, + { + name: "Test UBigInt Type", + fieldType: common.TSDB_DATA_TYPE_UBIGINT, + want: &types.ColumnType{Type: types.TaosUBigintType}, + wantErr: false, + }, + { + name: "Test Float Type", + fieldType: common.TSDB_DATA_TYPE_FLOAT, + want: &types.ColumnType{Type: types.TaosFloatType}, + wantErr: false, + }, + { + name: "Test Double Type", + fieldType: common.TSDB_DATA_TYPE_DOUBLE, + want: &types.ColumnType{Type: types.TaosDoubleType}, + wantErr: false, + }, + { + name: "Test Binary Type", + fieldType: common.TSDB_DATA_TYPE_BINARY, + want: &types.ColumnType{Type: types.TaosBinaryType}, + wantErr: false, + }, + { + name: "Test VarBinary Type", + fieldType: common.TSDB_DATA_TYPE_VARBINARY, + want: &types.ColumnType{Type: types.TaosVarBinaryType}, + wantErr: false, + }, + { + name: "Test Nchar Type", + fieldType: common.TSDB_DATA_TYPE_NCHAR, + want: &types.ColumnType{Type: types.TaosNcharType}, + wantErr: false, + }, + { + name: "Test Timestamp Type", + fieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + want: &types.ColumnType{Type: types.TaosTimestampType}, + wantErr: false, + }, + { + name: "Test Json Type", + fieldType: common.TSDB_DATA_TYPE_JSON, + want: &types.ColumnType{Type: types.TaosJsonType}, + wantErr: false, + }, + { + name: "Test Geometry Type", + fieldType: common.TSDB_DATA_TYPE_GEOMETRY, + want: &types.ColumnType{Type: types.TaosGeometryType}, + wantErr: false, + }, + { + name: "Test Unsupported Type", + fieldType: 0, // An undefined type + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &StmtField{ + FieldType: tt.fieldType, + } + + got, err := s.GetType() + if (err != nil) != tt.wantErr { + t.Errorf("GetType() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != nil && tt.want != nil && got.Type != tt.want.Type { + t.Errorf("GetType() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/driver/common/stmt/stmt2.go b/driver/common/stmt/stmt2.go new file mode 100644 index 00000000..d1a97696 --- /dev/null +++ b/driver/common/stmt/stmt2.go @@ -0,0 +1,580 @@ +package stmt + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "time" + + "github.com/taosdata/taosadapter/v3/driver/common" +) + +const ( + TotalLengthPosition = 0 + CountPosition = TotalLengthPosition + 4 + TagCountPosition = CountPosition + 4 + ColCountPosition = TagCountPosition + 4 + TableNamesOffsetPosition = ColCountPosition + 4 + TagsOffsetPosition = TableNamesOffsetPosition + 4 + ColsOffsetPosition = TagsOffsetPosition + 4 + DataPosition = ColsOffsetPosition + 4 +) + +const ( + BindDataTotalLengthOffset = 0 + BindDataTypeOffset = BindDataTotalLengthOffset + 4 + BindDataNumOffset = BindDataTypeOffset + 4 + BindDataIsNullOffset = BindDataNumOffset + 4 +) + +func MarshalStmt2Binary(bindData []*TaosStmt2BindData, isInsert bool, colType, tagType []*StmtField) ([]byte, error) { + // count + count := len(bindData) + if count == 0 { + return nil, fmt.Errorf("empty data") + } + needTableNames := false + needTags := false + needCols := false + tagCount := len(tagType) + colCount := len(colType) + if isInsert { + for i := 0; i < count; i++ { + data := bindData[i] + if data.TableName != "" { + needTableNames = true + } + if len(data.Tags) != tagCount { + return nil, fmt.Errorf("tag count not match, data count:%d, type count:%d", len(data.Tags), tagCount) + } + if len(data.Cols) != colCount { + return nil, fmt.Errorf("col count not match, data count:%d, type count:%d", len(data.Cols), colCount) + } + } + } else { + if tagCount != 0 { + return nil, fmt.Errorf("query not need tag types") + } + if colCount != 0 { + return nil, fmt.Errorf("query not need col types") + } + if count != 1 { + return nil, fmt.Errorf("query only need one data") + } + + data := bindData[0] + if data.TableName != "" { + return nil, fmt.Errorf("query not need table name") + } + if len(data.Tags) != 0 { + return nil, fmt.Errorf("query not need tag") + } + if len(data.Cols) == 0 { + return nil, fmt.Errorf("query need col") + } + colCount = len(data.Cols) + for j := 0; j < colCount; j++ { + if len(data.Cols[j]) != 1 { + return nil, fmt.Errorf("query col data must be one row, col:%d, count:%d", j, len(data.Cols[j])) + } + } + } + + header := make([]byte, DataPosition) + // count + binary.LittleEndian.PutUint32(header[CountPosition:], uint32(count)) + // tag count + if tagCount != 0 { + needTags = true + binary.LittleEndian.PutUint32(header[TagCountPosition:], uint32(tagCount)) + } + // col count + if colCount != 0 { + needCols = true + binary.LittleEndian.PutUint32(header[ColCountPosition:], uint32(colCount)) + } + if !needTableNames && !needTags && !needCols { + return nil, fmt.Errorf("no data") + } + tmpBuf := &bytes.Buffer{} + tableNameBuf := &bytes.Buffer{} + var tableNameLength []uint16 + if needTableNames { + tableNameLength = make([]uint16, count) + } + tagBuf := &bytes.Buffer{} + var tagDataLength []uint32 + if needTags { + tagDataLength = make([]uint32, count) + } + colBuf := &bytes.Buffer{} + var colDataLength []uint32 + if needCols { + colDataLength = make([]uint32, count) + } + for index, data := range bindData { + // table name + if needTableNames { + if data.TableName != "" { + if len(data.TableName) > math.MaxUint16-1 { + return nil, fmt.Errorf("table name too long, index:%d, length:%d", index, len(data.TableName)) + } + tableNameBuf.WriteString(data.TableName) + } + tableNameBuf.WriteByte(0) + tableNameLength[index] = uint16(len(data.TableName) + 1) + } + + // tag + if needTags { + length := 0 + for i := 0; i < len(data.Tags); i++ { + tag := data.Tags[i] + tagDataBuffer, err := generateBindColData([]driver.Value{tag}, tagType[i], tmpBuf) + if err != nil { + return nil, err + } + length += len(tagDataBuffer) + tagBuf.Write(tagDataBuffer) + } + tagDataLength[index] = uint32(length) + } + // col + if needCols { + length := 0 + for i := 0; i < len(data.Cols); i++ { + col := data.Cols[i] + var colDataBuffer []byte + var err error + if isInsert { + colDataBuffer, err = generateBindColData(col, colType[i], tmpBuf) + } else { + colDataBuffer, err = generateBindQueryData(col[0]) + } + if err != nil { + return nil, err + } + length += len(colDataBuffer) + colBuf.Write(colDataBuffer) + } + colDataLength[index] = uint32(length) + } + } + tableTotalLength := tableNameBuf.Len() + tagTotalLength := tagBuf.Len() + colTotalLength := colBuf.Len() + tagOffset := DataPosition + tableTotalLength + len(tableNameLength)*2 + colOffset := tagOffset + tagTotalLength + len(tagDataLength)*4 + totalLength := colOffset + colTotalLength + len(colDataLength)*4 + if needTableNames { + binary.LittleEndian.PutUint32(header[TableNamesOffsetPosition:], uint32(DataPosition)) + } + if needTags { + binary.LittleEndian.PutUint32(header[TagsOffsetPosition:], uint32(tagOffset)) + } + if needCols { + binary.LittleEndian.PutUint32(header[ColsOffsetPosition:], uint32(colOffset)) + } + binary.LittleEndian.PutUint32(header[TotalLengthPosition:], uint32(totalLength)) + buffer := make([]byte, totalLength) + copy(buffer, header) + if needTableNames { + offset := DataPosition + for _, length := range tableNameLength { + binary.LittleEndian.PutUint16(buffer[offset:], length) + offset += 2 + } + copy(buffer[offset:], tableNameBuf.Bytes()) + } + if needTags { + offset := tagOffset + for _, length := range tagDataLength { + binary.LittleEndian.PutUint32(buffer[offset:], length) + offset += 4 + } + copy(buffer[offset:], tagBuf.Bytes()) + } + if needCols { + offset := colOffset + for _, length := range colDataLength { + binary.LittleEndian.PutUint32(buffer[offset:], length) + offset += 4 + } + copy(buffer[offset:], colBuf.Bytes()) + } + return buffer, nil +} + +func getBindDataHeaderLength(num int, needLength bool) int { + length := 17 + num + if needLength { + length += num * 4 + } + return length +} + +func generateBindColData(data []driver.Value, colType *StmtField, tmpBuffer *bytes.Buffer) ([]byte, error) { + num := len(data) + tmpBuffer.Reset() + needLength := needLength(colType.FieldType) + headerLength := getBindDataHeaderLength(num, needLength) + tmpHeader := make([]byte, headerLength) + // type + binary.LittleEndian.PutUint32(tmpHeader[BindDataTypeOffset:], uint32(colType.FieldType)) + // num + binary.LittleEndian.PutUint32(tmpHeader[BindDataNumOffset:], uint32(num)) + // is null + isNull := tmpHeader[BindDataIsNullOffset : BindDataIsNullOffset+num] + // has length + if needLength { + tmpHeader[BindDataIsNullOffset+num] = 1 + } + bufferLengthOffset := BindDataIsNullOffset + num + 1 + isAllNull := checkAllNull(data) + if isAllNull { + for i := 0; i < num; i++ { + isNull[i] = 1 + } + } else { + switch colType.FieldType { + case common.TSDB_DATA_TYPE_BOOL: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + tmpBuffer.WriteByte(0) + } else { + v, ok := data[i].(bool) + if !ok { + return nil, fmt.Errorf("data type not match, expect bool, but get %T, value:%v", data[i], data[i]) + } + if v { + tmpBuffer.WriteByte(1) + } else { + tmpBuffer.WriteByte(0) + } + } + } + case common.TSDB_DATA_TYPE_TINYINT: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + tmpBuffer.WriteByte(0) + } else { + v, ok := data[i].(int8) + if !ok { + return nil, fmt.Errorf("data type not match, expect int8, but get %T, value:%v", data[i], data[i]) + } + tmpBuffer.WriteByte(byte(v)) + } + } + + case common.TSDB_DATA_TYPE_SMALLINT: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + writeUint16(tmpBuffer, uint16(0)) + } else { + v, ok := data[i].(int16) + if !ok { + return nil, fmt.Errorf("data type not match, expect int16, but get %T, value:%v", data[i], data[i]) + } + writeUint16(tmpBuffer, uint16(v)) + } + } + + case common.TSDB_DATA_TYPE_INT: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + writeUint32(tmpBuffer, uint32(0)) + } else { + v, ok := data[i].(int32) + if !ok { + return nil, fmt.Errorf("data type not match, expect int32, but get %T, value:%v", data[i], data[i]) + } + writeUint32(tmpBuffer, uint32(v)) + } + } + case common.TSDB_DATA_TYPE_BIGINT: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + writeUint64(tmpBuffer, 0) + } else { + v, ok := data[i].(int64) + if !ok { + return nil, fmt.Errorf("data type not match, expect int64, but get %T, value:%v", data[i], data[i]) + } + writeUint64(tmpBuffer, uint64(v)) + } + } + case common.TSDB_DATA_TYPE_FLOAT: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + writeUint32(tmpBuffer, 0) + } else { + v, ok := data[i].(float32) + if !ok { + return nil, fmt.Errorf("data type not match, expect float32, but get %T, value:%v", data[i], data[i]) + } + writeUint32(tmpBuffer, math.Float32bits(v)) + } + } + case common.TSDB_DATA_TYPE_DOUBLE: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + writeUint64(tmpBuffer, 0) + } else { + v, ok := data[i].(float64) + if !ok { + return nil, fmt.Errorf("data type not match, expect float64, but get %T, value:%v", data[i], data[i]) + } + writeUint64(tmpBuffer, math.Float64bits(v)) + } + } + case common.TSDB_DATA_TYPE_TIMESTAMP: + precision := int(colType.Precision) + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + writeUint64(tmpBuffer, 0) + } else { + switch v := data[i].(type) { + case int64: + writeUint64(tmpBuffer, uint64(v)) + case time.Time: + ts := common.TimeToTimestamp(v, precision) + writeUint64(tmpBuffer, uint64(ts)) + default: + return nil, fmt.Errorf("data type not match, expect int64 or time.Time, but get %T, value:%v", data[i], data[i]) + } + } + } + case common.TSDB_DATA_TYPE_BINARY, common.TSDB_DATA_TYPE_NCHAR, common.TSDB_DATA_TYPE_VARBINARY, common.TSDB_DATA_TYPE_GEOMETRY, common.TSDB_DATA_TYPE_JSON: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + } else { + switch v := data[i].(type) { + case string: + tmpBuffer.WriteString(v) + binary.LittleEndian.PutUint32(tmpHeader[bufferLengthOffset+i*4:], uint32(len(v))) + case []byte: + tmpBuffer.Write(v) + binary.LittleEndian.PutUint32(tmpHeader[bufferLengthOffset+i*4:], uint32(len(v))) + default: + return nil, fmt.Errorf("data type not match, expect string or []byte, but get %T, value:%v", data[i], data[i]) + } + } + } + case common.TSDB_DATA_TYPE_UTINYINT: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + tmpBuffer.WriteByte(0) + } else { + v, ok := data[i].(uint8) + if !ok { + return nil, fmt.Errorf("data type not match, expect uint8, but get %T, value:%v", data[i], data[i]) + } + tmpBuffer.WriteByte(v) + } + } + case common.TSDB_DATA_TYPE_USMALLINT: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + writeUint16(tmpBuffer, 0) + } else { + v, ok := data[i].(uint16) + if !ok { + return nil, fmt.Errorf("data type not match, expect uint16, but get %T, value:%v", data[i], data[i]) + } + writeUint16(tmpBuffer, v) + } + } + case common.TSDB_DATA_TYPE_UINT: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + writeUint32(tmpBuffer, 0) + } else { + v, ok := data[i].(uint32) + if !ok { + return nil, fmt.Errorf("data type not match, expect uint32, but get %T, value:%v", data[i], data[i]) + } + writeUint32(tmpBuffer, v) + } + } + case common.TSDB_DATA_TYPE_UBIGINT: + for i := 0; i < num; i++ { + if data[i] == nil { + isNull[i] = 1 + writeUint64(tmpBuffer, 0) + } else { + v, ok := data[i].(uint64) + if !ok { + return nil, fmt.Errorf("data type not match, expect uint64, but get %T, value:%v", data[i], data[i]) + } + writeUint64(tmpBuffer, v) + } + } + default: + return nil, fmt.Errorf("unsupported type: %d", colType.FieldType) + } + } + buffer := tmpBuffer.Bytes() + // bufferLength + binary.LittleEndian.PutUint32(tmpHeader[headerLength-4:], uint32(len(buffer))) + totalLength := len(buffer) + headerLength + binary.LittleEndian.PutUint32(tmpHeader[BindDataTotalLengthOffset:], uint32(totalLength)) + dataBuffer := make([]byte, totalLength) + copy(dataBuffer, tmpHeader) + copy(dataBuffer[headerLength:], buffer) + return dataBuffer, nil +} + +func checkAllNull(data []driver.Value) bool { + for i := 0; i < len(data); i++ { + if data[i] != nil { + return false + } + } + return true +} + +func generateBindQueryData(data driver.Value) ([]byte, error) { + var colType uint32 + var haveLength = false + var length = 0 + var buf []byte + switch v := data.(type) { + case string: + colType = common.TSDB_DATA_TYPE_BINARY + haveLength = true + length = len(v) + buf = make([]byte, length) + copy(buf, v) + case []byte: + colType = common.TSDB_DATA_TYPE_BINARY + haveLength = true + length = len(v) + buf = make([]byte, length) + copy(buf, v) + case int8: + colType = common.TSDB_DATA_TYPE_TINYINT + buf = make([]byte, 1) + buf[0] = byte(v) + case int16: + colType = common.TSDB_DATA_TYPE_SMALLINT + buf = make([]byte, 2) + binary.LittleEndian.PutUint16(buf, uint16(v)) + case int32: + colType = common.TSDB_DATA_TYPE_INT + buf = make([]byte, 4) + binary.LittleEndian.PutUint32(buf, uint32(v)) + case int64: + colType = common.TSDB_DATA_TYPE_BIGINT + buf = make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(v)) + case uint8: + colType = common.TSDB_DATA_TYPE_UTINYINT + buf = make([]byte, 1) + buf[0] = byte(v) + case uint16: + colType = common.TSDB_DATA_TYPE_USMALLINT + buf = make([]byte, 2) + binary.LittleEndian.PutUint16(buf, v) + case uint32: + colType = common.TSDB_DATA_TYPE_UINT + buf = make([]byte, 4) + binary.LittleEndian.PutUint32(buf, v) + case uint64: + colType = common.TSDB_DATA_TYPE_UBIGINT + buf = make([]byte, 8) + binary.LittleEndian.PutUint64(buf, v) + case float32: + colType = common.TSDB_DATA_TYPE_FLOAT + buf = make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(v)) + case float64: + colType = common.TSDB_DATA_TYPE_DOUBLE + buf = make([]byte, 8) + binary.LittleEndian.PutUint64(buf, math.Float64bits(v)) + case bool: + colType = common.TSDB_DATA_TYPE_BOOL + buf = make([]byte, 1) + if v { + buf[0] = 1 + } else { + buf[0] = 0 + } + case time.Time: + buf = make([]byte, 0, 35) + colType = common.TSDB_DATA_TYPE_BINARY + haveLength = true + buf = v.AppendFormat(buf, time.RFC3339Nano) + length = len(buf) + default: + return nil, fmt.Errorf("unsupported type: %T", data) + } + headerLength := getBindDataHeaderLength(1, haveLength) + totalLength := len(buf) + headerLength + dataBuf := make([]byte, totalLength) + // type + binary.LittleEndian.PutUint32(dataBuf[BindDataTypeOffset:], colType) + // num + binary.LittleEndian.PutUint32(dataBuf[BindDataNumOffset:], 1) + // is null + dataBuf[BindDataIsNullOffset] = 0 + // has length + if haveLength { + dataBuf[BindDataIsNullOffset+1] = 1 + binary.LittleEndian.PutUint32(dataBuf[BindDataIsNullOffset+2:], uint32(length)) + + } + // bufferLength + binary.LittleEndian.PutUint32(dataBuf[headerLength-4:], uint32(len(buf))) + copy(dataBuf[headerLength:], buf) + binary.LittleEndian.PutUint32(dataBuf[BindDataTotalLengthOffset:], uint32(totalLength)) + return dataBuf, nil +} + +func writeUint64(buffer *bytes.Buffer, v uint64) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) + buffer.WriteByte(byte(v >> 16)) + buffer.WriteByte(byte(v >> 24)) + buffer.WriteByte(byte(v >> 32)) + buffer.WriteByte(byte(v >> 40)) + buffer.WriteByte(byte(v >> 48)) + buffer.WriteByte(byte(v >> 56)) +} + +func writeUint32(buffer *bytes.Buffer, v uint32) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) + buffer.WriteByte(byte(v >> 16)) + buffer.WriteByte(byte(v >> 24)) +} + +func writeUint16(buffer *bytes.Buffer, v uint16) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) +} + +func needLength(colType int8) bool { + switch colType { + case common.TSDB_DATA_TYPE_BINARY, + common.TSDB_DATA_TYPE_NCHAR, + common.TSDB_DATA_TYPE_JSON, + common.TSDB_DATA_TYPE_VARBINARY, + common.TSDB_DATA_TYPE_GEOMETRY: + return true + } + return false +} diff --git a/driver/common/stmt/stmt2_test.go b/driver/common/stmt/stmt2_test.go new file mode 100644 index 00000000..957d592f --- /dev/null +++ b/driver/common/stmt/stmt2_test.go @@ -0,0 +1,2437 @@ +package stmt + +import ( + "database/sql/driver" + "math" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" +) + +type customInt int + +func TestMarshalBinary(t *testing.T) { + largeTableName := "" + for i := 0; i < math.MaxUint16; i++ { + largeTableName += "a" + } + type args struct { + t []*TaosStmt2BindData + isInsert bool + tagType []*StmtField + colType []*StmtField + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + name: "TestSetTableName", + args: args{ + t: []*TaosStmt2BindData{ + { + TableName: "test1", + }, + { + TableName: "", + }, + { + TableName: "test2", + }, + }, + isInsert: true, + tagType: nil, + colType: nil, + }, + want: []byte{ + // total Length + 0x2f, 0x00, 0x00, 0x00, + // tableCount + 0x03, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + // TableNameLength + 0x06, 0x00, + 0x01, 0x00, + 0x06, 0x00, + // test1 + 0x74, 0x65, 0x73, 0x74, 0x31, 0x00, + // nil + 0x00, + // test2 + 0x74, 0x65, 0x73, 0x74, 0x32, 0x00, + }, + wantErr: false, + }, + { + name: "wrong TableName length", + args: args{ + t: []*TaosStmt2BindData{ + { + TableName: largeTableName, + }, + }, + isInsert: true, + tagType: nil, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "TestSetTableNameAndTags", + args: args{ + t: []*TaosStmt2BindData{ + { + TableName: "test1", + Tags: []driver.Value{ + // ts 1726803356466 + time.Unix(1726803356, 466000000), + // bool + true, + // tinyint + int8(1), + // smallint + int16(2), + // int + int32(3), + // bigint + int64(4), + // float + float32(5.5), + // double + float64(6.6), + // utinyint + uint8(7), + // usmallint + uint16(8), + // uint + uint32(9), + // ubigint + uint64(10), + // binary + []byte("binary"), + // nchar + "nchar", + // geometry + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + // varbinary + []byte("varbinary"), + }, + }, + { + TableName: "testnil", + Tags: []driver.Value{ + // ts 1726803356466 + nil, + // bool + nil, + // tinyint + nil, + // smallint + nil, + // int + nil, + // bigint + nil, + // float + nil, + // double + nil, + // utinyint + nil, + // usmallint + nil, + // uint + nil, + // ubigint + nil, + // binary + nil, + // nchar + nil, + // geometry + nil, + // varbinary + nil, + }, + }, + { + TableName: "test2", + Tags: []driver.Value{ + // ts 1726803356466 + time.Unix(1726803356, 466000000), + // bool + true, + // tinyint + int8(1), + // smallint + int16(2), + // int + int32(3), + // bigint + int64(4), + // float + float32(5.5), + // double + float64(6.6), + // utinyint + uint8(7), + // usmallint + uint16(8), + // uint + uint32(9), + // ubigint + uint64(10), + // binary + []byte("binary"), + // nchar + "nchar", + // geometry + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + // varbinary + []byte("varbinary"), + }, + }, + }, + isInsert: true, + tagType: []*StmtField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionMilliSecond, + }, + { + FieldType: common.TSDB_DATA_TYPE_BOOL, + }, + { + FieldType: common.TSDB_DATA_TYPE_TINYINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_INT, + }, + { + FieldType: common.TSDB_DATA_TYPE_BIGINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_FLOAT, + }, + { + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + }, + { + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_UINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_BINARY, + }, + { + FieldType: common.TSDB_DATA_TYPE_NCHAR, + }, + { + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + }, + { + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + }, + }, + colType: nil, + }, + want: []byte{ + // total Length + 0x8a, 0x04, 0x00, 0x00, + // tableCount + 0x03, 0x00, 0x00, 0x00, + // TagCount + 0x10, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x36, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + // TableNameLength + 0x06, 0x00, + 0x08, 0x00, + 0x06, 0x00, + // test1 + 0x74, 0x65, 0x73, 0x74, 0x31, 0x00, + // testnil + 0x74, 0x65, 0x73, 0x74, 0x6e, 0x69, 0x6c, 0x00, + // test2 + 0x74, 0x65, 0x73, 0x74, 0x32, 0x00, + + // tags + + // tagsDataLength + // table1 DataLength + 0x8c, 0x01, 0x00, 0x00, + // table2 DataLength + 0x30, 0x01, 0x00, 0x00, + // table3 DataLength + 0x8c, 0x01, 0x00, 0x00, + + // tagsData + // table1 tags + // tag1 timestamp + // totalLength + 0x1a, 0x00, 0x00, 0x00, + + // type + 0x09, 0x00, 0x00, 0x00, + + // num + 0x01, 0x00, 0x00, 0x00, + + // isnull + 0x00, + + // haveLength + 0x00, + + //buffer length + 0x08, 0x00, 0x00, 0x00, + + // buffer + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + + // tag2 bool + 0x13, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + + // tag3 tinyint + 0x13, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + + // tag4 smallint + 0x14, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x02, 0x00, + + // tag5 int + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + + // tag6 bigint + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + // tag7 float + 0x16, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0xb0, 0x40, + + // tag8 double + 0x1a, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x1a, 0x40, + + // tag9 utinyint + 0x13, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x07, + + // tag10 usmallint + 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x08, 0x00, + + // tag11 uint + 0x16, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + + // tag12 ubigint + 0x1a, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + // tag13 binary + 0x1c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + // length + 0x06, 0x00, 0x00, 0x00, + // buffer length + 0x06, 0x00, 0x00, 0x00, + //buffer + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + // tag14 nchar + 0x1b, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + 0x05, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x6e, 0x63, 0x68, 0x61, 0x72, + + // tag15 geometry + 0x2b, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + 0x15, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x00, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + + // tag16 varbinary + 0x1f, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + 0x09, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + // table 2 tags + // tag1 timestamp nil + // TotalLength + 0x12, 0x00, 0x00, 0x00, + // type + 0x09, 0x00, 0x00, 0x00, + // num + 0x01, 0x00, 0x00, 0x00, + // isnull + 0x01, + // haveLength + 0x00, + // buffer length + 0x00, 0x00, 0x00, 0x00, + + // tag2 bool nil + 0x12, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag3 tinyint nil + 0x12, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag4 smallint nil + 0x12, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag5 int nil + 0x12, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag6 bigint nil + 0x12, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag7 float nil + 0x12, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag8 double nil + 0x12, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag9 utinyint nil + 0x12, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag10 usmallint nil + 0x12, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag11 uint nil + 0x12, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag12 ubigint nil + 0x12, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag13 binary nil + 0x16, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag14 nchar nil + 0x16, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag15 geometry nil + 0x16, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + + // tag16 varbinary nil + 0x16, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + + // table 3 tags + // tag1 timestamp + 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + + // tag2 bool + 0x13, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + + // tag3 tinyint + 0x13, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + + // tag4 smallint + 0x14, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x02, 0x00, + + // tag5 int + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + + // tag6 bigint + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + // tag7 float + 0x00, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x01, + 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0xb0, 0x40, + + // tag8 double + 0x1a, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x1a, 0x40, + + // tag9 utinyint + 0x13, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x07, + + // tag10 usmallint + 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x08, 0x00, + + // tag11 uint + 0x16, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + + // tag12 ubigint + 0x1a, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + // tag13 binary + 0x1c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + 0x06, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + // tag14 nchar + 0x1b, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + 0x05, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x6e, 0x63, 0x68, 0x61, 0x72, + + // tag15 geometry + 0x2b, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + 0x15, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x00, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + + // tag16 varbinary + 0x1f, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + 0x09, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + }, + wantErr: false, + }, + { + name: "TestAllData", + args: args{ + t: []*TaosStmt2BindData{ + { + TableName: "test1", + Tags: []driver.Value{ + // ts 1726803356466 + time.Unix(1726803356, 466000000), + // bool + true, + // tinyint + int8(1), + // smallint + int16(2), + // int + int32(3), + // bigint + int64(4), + // float + float32(5.5), + // double + float64(6.6), + // utinyint + uint8(7), + // usmallint + uint16(8), + // uint + uint32(9), + // ubigint + uint64(10), + // binary + []byte("binary"), + // nchar + "nchar", + // geometry + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + // varbinary + []byte("varbinary"), + }, + Cols: [][]driver.Value{ + { + // ts 1726803356466 + time.Unix(1726803356, 466000000), + // ts 1726803357466 + time.Unix(1726803357, 466000000), + // ts 1726803358466 + time.Unix(1726803358, 466000000), + }, + { + // BOOL + true, + nil, + false, + }, + { + // TINYINT + int8(11), + nil, + int8(12), + }, + { + // SMALLINT + int16(11), + nil, + int16(12), + }, + { + // INT + int32(11), + nil, + int32(12), + }, + { + // BIGINT + int64(11), + nil, + int64(12), + }, + { + // FLOAT + float32(11.2), + nil, + float32(12.2), + }, + { + // DOUBLE + float64(11.2), + nil, + float64(12.2), + }, + { + // TINYINT UNSIGNED + uint8(11), + nil, + uint8(12), + }, + { + // SMALLINT UNSIGNED + uint16(11), + nil, + uint16(12), + }, + { + // INT UNSIGNED + uint32(11), + nil, + uint32(12), + }, + { + // BIGINT UNSIGNED + uint64(11), + nil, + uint64(12), + }, + { + // BINARY + []byte("binary1"), + nil, + []byte("binary2"), + }, + { + // NCHAR + "nchar1", + nil, + "nchar2", + }, + { + // GEOMETRY `point(100 100)` + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + nil, + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + { + // VARBINARY + []byte("varbinary1"), + nil, + []byte("varbinary2"), + }, + }, + }, + }, + isInsert: true, + tagType: []*StmtField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionMilliSecond, + }, + { + FieldType: common.TSDB_DATA_TYPE_BOOL, + }, + { + FieldType: common.TSDB_DATA_TYPE_TINYINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_INT, + }, + { + FieldType: common.TSDB_DATA_TYPE_BIGINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_FLOAT, + }, + { + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + }, + { + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_UINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_BINARY, + }, + { + FieldType: common.TSDB_DATA_TYPE_NCHAR, + }, + { + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + }, + { + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + }, + }, + colType: []*StmtField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionMilliSecond, + }, + { + FieldType: common.TSDB_DATA_TYPE_BOOL, + }, + { + FieldType: common.TSDB_DATA_TYPE_TINYINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_INT, + }, + { + FieldType: common.TSDB_DATA_TYPE_BIGINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_FLOAT, + }, + { + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + }, + { + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_UINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + }, + { + FieldType: common.TSDB_DATA_TYPE_BINARY, + }, + { + FieldType: common.TSDB_DATA_TYPE_NCHAR, + }, + { + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + }, + { + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + }, + }, + }, + want: []byte{ + // TotalLength + 0x19, 0x04, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x10, 0x00, 0x00, 0x00, + // ColCount + 0x10, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x24, 0x00, 0x00, 0x00, + // ColsOffset + 0xb4, 0x01, 0x00, 0x00, + + // TableNameLength + 0x06, 0x00, + // TableNameBuffer + 0x74, 0x65, 0x73, 0x74, 0x31, 0x00, + + // TagsDataLength + 0x8c, 0x01, 0x00, 0x00, + + // TagsBuffer + + // tag1 timestamp + // TotalLength + 0x1a, 0x00, 0x00, 0x00, + // type + 0x09, 0x00, 0x00, 0x00, + // num + 0x01, 0x00, 0x00, 0x00, + // isnull + 0x00, + // haveLength + 0x00, + // buffer length + 0x08, 0x00, 0x00, 0x00, + // buffer + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + + // tag2 bool + 0x13, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + + // tag3 tinyint + 0x13, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, + + // tag4 smallint + 0x14, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x02, 0x00, + + // tag5 int + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + + // tag6 bigint + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + // tag7 float + 0x16, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0xb0, 0x40, + + // tag8 double + 0x1a, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x1a, 0x40, + + // tag9 utinyint + 0x13, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x07, + + // tag10 usmallint + 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x08, 0x00, + + // tag11 uint + 0x16, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + + // tag12 ubigint + 0x1a, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + // tag13 binary + 0x1c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + // haveLength + 0x01, + // length + 0x06, 0x00, 0x00, 0x00, + //buffer length + 0x06, 0x00, 0x00, 0x00, + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + // tag14 nchar + 0x1b, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + 0x05, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x6e, 0x63, 0x68, 0x61, 0x72, + + // tag15 geometry + 0x2b, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + 0x15, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x00, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + + // tag16 varbinary + 0x1f, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x01, + 0x09, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + // ColDataLength + 0x61, 0x02, 0x00, 0x00, + + // ColBuffer + // col1 timestamp + // TotalLength + 0x2c, 0x00, 0x00, 0x00, + // Type + 0x09, 0x00, 0x00, 0x00, + // Num + 0x03, 0x00, 0x00, 0x00, + // IsNull + 0x00, 0x00, 0x00, + //haveLength + 0x00, + // BufferLength + 0x18, 0x00, 0x00, 0x00, + // Buffer + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + 0x1a, 0x2f, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + 0x02, 0x33, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + + // col2 bool + 0x17, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + // is null, row index 1 is null + 0x00, 0x01, 0x00, + 0x00, + 0x03, 0x00, 0x00, 0x00, + + // row0 + 0x01, + // row1 + 0x00, + // row2 + 0x00, + + // col3 tinyint + 0x17, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x00, + 0x03, 0x00, 0x00, 0x00, + + 0x0b, + 0x00, + 0x0c, + + // col4 smallint + 0x1a, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x00, + 0x06, 0x00, 0x00, 0x00, + + 0x0b, 0x00, + 0x00, 0x00, + 0x0c, 0x00, + + // col5 int + 0x20, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x00, + 0x0c, 0x00, 0x00, 0x00, + + 0x0b, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + + // col6 bigint + 0x2c, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x00, + 0x18, 0x00, 0x00, 0x00, + + 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + // col7 float + 0x20, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x33, 0x33, 0x33, 0x41, + 0x00, 0x00, 0x00, 0x00, + 0x33, 0x33, 0x43, 0x41, + + // col8 double + 0x2c, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x00, + 0x18, 0x00, 0x00, 0x00, + + 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x26, 0x40, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x28, 0x40, + + // col9 utinyint + 0x17, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x00, + 0x03, 0x00, 0x00, 0x00, + + 0x0b, + 0x00, + 0x0c, + + // col10 usmallint + 0x1a, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x00, + 0x06, 0x00, 0x00, 0x00, + + 0x0b, 0x00, + 0x00, 0x00, + 0x0c, 0x00, + + // col11 uint + 0x20, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x00, + 0x0c, 0x00, 0x00, 0x00, + + 0x0b, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + + // col12 ubigint + 0x2C, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x00, + 0x18, 0x00, 0x00, 0x00, + + 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + // col13 binary + 0x2e, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + // have length + 0x01, + // length + 0x07, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, + // buffer length + 0x0e, 0x00, 0x00, 0x00, + // buffer + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x31, + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x32, + + // col14 nchar + 0x2c, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x01, + // length + 0x06, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + // buffer length + 0x0c, 0x00, 0x00, 0x00, + // buffer + 0x6e, 0x63, 0x68, 0x61, 0x72, 0x31, + 0x6e, 0x63, 0x68, 0x61, 0x72, 0x32, + + // col15 geometry + 0x4a, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x01, + // length + 0x15, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x00, 0x00, + // buffer length + 0x2a, 0x00, 0x00, 0x00, + // buffer + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + + // col16 varbinary + 0x34, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, + 0x01, + // length + 0x0a, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, + // buffer length + 0x14, 0x00, 0x00, 0x00, + // buffer + 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x31, + 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x32, + }, + wantErr: false, + }, + { + name: "TestQuery", + args: args{ + t: []*TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + { + // ts 1726803356466 + time.Unix(1726803356, 466000000).UTC(), + }, + { + // BOOL + true, + }, + { + // TINYINT + int8(11), + }, + { + // SMALLINT + int16(11), + }, + { + // INT + int32(11), + }, + { + // BIGINT + int64(11), + }, + { + // FLOAT + float32(11.2), + }, + { + // DOUBLE + float64(11.2), + }, + { + // TINYINT UNSIGNED + uint8(11), + }, + { + // SMALLINT UNSIGNED + uint16(11), + }, + { + // INT UNSIGNED + uint32(11), + }, + { + // BIGINT UNSIGNED + uint64(11), + }, + { + // Bytes + []byte("binary1"), + }, + { + // String + "nchar1", + }, + }, + }, + }, + isInsert: false, + }, + want: []byte{ + // total Length + 0x78, 0x01, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x0e, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + // col length + 0x58, 0x01, 0x00, 0x00, + //table 0 cols + //col 0 + //total length + 0x2e, 0x00, 0x00, 0x00, + //type + 0x08, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x01, + // length + 0x18, 0x00, 0x00, 0x00, + // buffer length + 0x18, 0x00, 0x00, 0x00, + 0x32, 0x30, 0x32, 0x34, 0x2d, 0x30, 0x39, 0x2d, 0x32, 0x30, 0x54, 0x30, 0x33, 0x3a, 0x33, 0x35, 0x3a, 0x35, 0x36, 0x2e, 0x34, 0x36, 0x36, 0x5a, + + //col 1 + //total length + 0x13, 0x00, 0x00, 0x00, + //type + 0x01, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x01, 0x00, 0x00, 0x00, + 0x01, + + //col 2 + //total length + 0x13, 0x00, 0x00, 0x00, + //type + 0x02, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x01, 0x00, 0x00, 0x00, + 0x0b, + + //col 3 + //total length + 0x14, 0x00, 0x00, 0x00, + //type + 0x03, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x02, 0x00, 0x00, 0x00, + 0x0b, 0x00, + + //col 4 + //total length + 0x16, 0x00, 0x00, 0x00, + //type + 0x04, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x04, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, + + //col 5 + //total length + 0x1a, 0x00, 0x00, 0x00, + //type + 0x05, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x08, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + //col 6 + //total length + 0x16, 0x00, 0x00, 0x00, + //type + 0x06, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x04, 0x00, 0x00, 0x00, + 0x33, 0x33, 0x33, 0x41, + + //col 7 + //total length + 0x1a, 0x00, 0x00, 0x00, + //type + 0x07, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x08, 0x00, 0x00, 0x00, + 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x26, 0x40, + + //col 8 + //total length + 0x13, 0x00, 0x00, 0x00, + //type + 0x0b, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x01, 0x00, 0x00, 0x00, + 0x0b, + + //col 9 + //total length + 0x14, 0x00, 0x00, 0x00, + //type + 0x0c, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x02, 0x00, 0x00, 0x00, + 0x0b, 0x00, + + //col 10 + //total length + 0x16, 0x00, 0x00, 0x00, + //type + 0x0d, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x04, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, + + //col 11 + //total length + 0x1a, 0x00, 0x00, 0x00, + //type + 0x0e, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x08, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + //col 12 + //total length + 0x1d, 0x00, 0x00, 0x00, + //type + 0x08, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x01, + // length + 0x07, 0x00, 0x00, 0x00, + // buffer length + 0x07, 0x00, 0x00, 0x00, + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x31, + + //col 13 + //total length + 0x1c, 0x00, 0x00, 0x00, + //type + 0x08, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x01, + // length + 0x06, 0x00, 0x00, 0x00, + // buffer length + 0x06, 0x00, 0x00, 0x00, + 0x6e, 0x63, 0x68, 0x61, 0x72, 0x31, + }, + wantErr: false, + }, + { + name: "Three Table", + args: args{ + t: []*TaosStmt2BindData{ + { + TableName: "table1", + Cols: [][]driver.Value{ + { + // ts 1726803356466 + time.Unix(1726803356, 466000000), + }, + { + int64(1), + }, + }, + Tags: []driver.Value{int32(1)}, + }, + { + TableName: "table2", + Cols: [][]driver.Value{ + { + // ts 1726803356466 + time.Unix(1726803356, 466000000), + }, + { + int64(2), + }, + }, + Tags: []driver.Value{int32(2)}, + }, + { + TableName: "table3", + Cols: [][]driver.Value{ + { + // ts 1726803356466 + time.Unix(1726803356, 466000000), + }, + { + int64(3), + }, + }, + Tags: []driver.Value{int32(3)}, + }, + }, + colType: []*StmtField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionMilliSecond, + }, + { + FieldType: common.TSDB_DATA_TYPE_BIGINT, + }, + }, + tagType: []*StmtField{ + { + FieldType: common.TSDB_DATA_TYPE_INT, + }, + }, + isInsert: true, + }, + want: []byte{ + // TotalLength + 0x2d, 0x01, 0x00, 0x00, + // tableCount + 0x03, 0x00, 0x00, 0x00, + // TagCount + 0x01, 0x00, 0x00, 0x00, + // ColCount + 0x02, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x37, 0x00, 0x00, 0x00, + // ColsOffset + 0x85, 0x00, 0x00, 0x00, + // TableNameLength + 0x07, 0x00, + 0x07, 0x00, + 0x07, 0x00, + // TableNameBuffer + 0x74, 0x61, 0x62, 0x6c, 0x65, 0x31, 0x00, + 0x74, 0x61, 0x62, 0x6c, 0x65, 0x32, 0x00, + 0x74, 0x61, 0x62, 0x6c, 0x65, 0x33, 0x00, + // TagsDataLength + 0x16, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, + // TagsBuffer + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + + // ColDataLength + 0x34, 0x00, 0x00, 0x00, + 0x34, 0x00, 0x00, 0x00, + 0x34, 0x00, 0x00, 0x00, + + // ColBuffer + 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + wantErr: false, + }, + { + name: "empty", + args: args{ + t: nil, + isInsert: false, + tagType: nil, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong tag count", + args: args{ + t: []*TaosStmt2BindData{ + { + Tags: []driver.Value{int32(1)}, + }, + }, + isInsert: true, + tagType: nil, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong col count", + args: args{ + t: []*TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + { + int32(1), + }, + }, + }, + }, + isInsert: true, + tagType: nil, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "query has tag type", + args: args{ + t: []*TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + { + int32(1), + }, + }, + }, + }, + isInsert: false, + tagType: []*StmtField{{ + FieldType: common.TSDB_DATA_TYPE_INT, + }}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "query has col type", + args: args{ + t: []*TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + { + int32(1), + }, + }, + }, + }, + isInsert: false, + tagType: nil, + colType: []*StmtField{{ + FieldType: common.TSDB_DATA_TYPE_INT, + }}, + }, + want: nil, + wantErr: true, + }, + { + name: "query has multi data", + args: args{ + t: []*TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + { + int32(1), + }, + }, + }, + { + Cols: [][]driver.Value{ + { + int32(1), + }, + }, + }, + }, + isInsert: false, + tagType: nil, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "query has tablename", + args: args{ + t: []*TaosStmt2BindData{ + { + TableName: "table1", + }, + }, + isInsert: false, + tagType: nil, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "query has tag", + args: args{ + t: []*TaosStmt2BindData{ + { + Tags: []driver.Value{int32(1)}, + }, + }, + isInsert: false, + tagType: nil, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "query without data", + args: args{ + t: []*TaosStmt2BindData{ + {}, + }, + isInsert: false, + tagType: nil, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "query with multi rows", + args: args{ + t: []*TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + { + int32(1), + int32(1), + }, + }, + }, + }, + isInsert: false, + tagType: nil, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong bool", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{int32(1)}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_BOOL}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong tinyint", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_TINYINT}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong smallint", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_SMALLINT}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong int", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_INT}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong bigint", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_BIGINT}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong tinyint unsigned", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_UTINYINT}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong smallint unsigned", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_USMALLINT}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong int unsigned", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_UINT}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong bigint unsigned", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_UBIGINT}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong float", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_FLOAT}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong double", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_DOUBLE}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong binary", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_BINARY}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "wrong timestamp", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{true}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_TIMESTAMP}}, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "insert nil timestamp", + args: args{ + t: []*TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + { + time.Unix(1726803356, 466000000), + nil, + }, + }, + }, + }, + isInsert: true, + tagType: nil, + colType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_TIMESTAMP}}, + }, + want: []byte{ + // total Length + 0x43, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x01, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + // col length + 0x23, 0x00, 0x00, 0x00, + //table 0 cols + //col 0 + //total length + 0x23, 0x00, 0x00, 0x00, + //type + 0x09, 0x00, 0x00, 0x00, + //num + 0x02, 0x00, 0x00, 0x00, + //is null + 0x00, + 0x01, + // haveLength + 0x00, + // buffer length + 0x10, 0x00, 0x00, 0x00, + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + wantErr: false, + }, + { + name: "query bool false", + args: args{ + t: []*TaosStmt2BindData{{ + Cols: [][]driver.Value{ + {false}, + }, + }}, + isInsert: false, + tagType: nil, + colType: nil, + }, + want: []byte{ + // total Length + 0x33, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x01, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + // col length + 0x13, 0x00, 0x00, 0x00, + //table 0 cols + //col 0 + //total length + 0x13, 0x00, 0x00, 0x00, + //type + 0x01, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x01, 0x00, 0x00, 0x00, + 0x00, + }, + wantErr: false, + }, + { + name: "query unsupported type", + args: args{ + t: []*TaosStmt2BindData{{ + Cols: [][]driver.Value{ + {customInt(1)}, + }, + }}, + isInsert: false, + tagType: nil, + colType: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "insert unsupported type", + args: args{ + t: []*TaosStmt2BindData{{ + Cols: [][]driver.Value{ + {int32(1)}, + }, + }}, + isInsert: true, + tagType: nil, + colType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_NULL}}, + }, + want: nil, + wantErr: true, + }, + { + name: "nil", + args: args{ + t: []*TaosStmt2BindData{ + { + Cols: nil, + }, + }, + isInsert: true, + tagType: nil, + colType: []*StmtField{}, + }, + want: nil, + wantErr: true, + }, + { + name: "int64 timestamp", + args: args{ + t: []*TaosStmt2BindData{{ + Tags: []driver.Value{int64(1726803356466)}, + }}, + isInsert: true, + tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_TIMESTAMP}}, + colType: nil, + }, + want: []byte{ + // total Length + 0x3a, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x01, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x1c, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // tags + // table length + 0x1a, 0x00, 0x00, 0x00, + //table 0 tags + //tag 0 + //total length + 0x1a, 0x00, 0x00, 0x00, + //type + 0x09, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x08, 0x00, 0x00, 0x00, + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MarshalStmt2Binary(tt.args.t, tt.args.isInsert, tt.args.colType, tt.args.tagType) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalStmt2Binary() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestT(t *testing.T) { + +} diff --git a/driver/common/tmq/config.go b/driver/common/tmq/config.go new file mode 100644 index 00000000..b17084cb --- /dev/null +++ b/driver/common/tmq/config.go @@ -0,0 +1,34 @@ +package tmq + +import ( + "fmt" + "reflect" +) + +type ConfigValue interface{} +type ConfigMap map[string]ConfigValue + +func (m ConfigMap) Get(key string, defval ConfigValue) (ConfigValue, error) { + return m.get(key, defval) +} + +func (m ConfigMap) get(key string, defval ConfigValue) (ConfigValue, error) { + v, ok := m[key] + if !ok { + return defval, nil + } + + if defval != nil && reflect.TypeOf(defval) != reflect.TypeOf(v) { + return nil, fmt.Errorf("%s expects type %T, not %T", key, defval, v) + } + + return v, nil +} + +func (m ConfigMap) Clone() ConfigMap { + m2 := make(ConfigMap) + for k, v := range m { + m2[k] = v + } + return m2 +} diff --git a/driver/common/tmq/config_test.go b/driver/common/tmq/config_test.go new file mode 100644 index 00000000..c5a62d5c --- /dev/null +++ b/driver/common/tmq/config_test.go @@ -0,0 +1,52 @@ +package tmq + +import ( + "fmt" + "reflect" + "testing" +) + +func TestConfigMap_Get(t *testing.T) { + t.Parallel() + + config := ConfigMap{ + "key1": "value1", + "key2": 123, + } + + t.Run("Existing Key", func(t *testing.T) { + want := "value1" + if got, err := config.Get("key1", nil); err != nil || got != want { + t.Errorf("Get() = %v, want %v (error: %v)", got, want, err) + } + }) + + t.Run("Type Mismatch", func(t *testing.T) { + wantErr := fmt.Errorf("key2 expects type string, not int") + if got, err := config.Get("key2", "default"); err == nil || got != nil || err.Error() != wantErr.Error() { + t.Errorf("Get() = %v, want error: %v", got, wantErr) + } + }) + + t.Run("Non-Existing Key with Default Value", func(t *testing.T) { + want := "default" + if got, err := config.Get("key3", "default"); err != nil || got != want { + t.Errorf("Get() = %v, want %v (error: %v)", got, want, err) + } + }) +} + +func TestConfigMap_Clone(t *testing.T) { + t.Parallel() + + config := ConfigMap{ + "key1": "value1", + "key2": 123, + } + + clone := config.Clone() + + if !reflect.DeepEqual(config, clone) { + t.Errorf("Clone() = %v, want %v", clone, config) + } +} diff --git a/driver/common/tmq/event.go b/driver/common/tmq/event.go new file mode 100644 index 00000000..e64d4302 --- /dev/null +++ b/driver/common/tmq/event.go @@ -0,0 +1,204 @@ +package tmq + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + + taosError "github.com/taosdata/taosadapter/v3/driver/errors" +) + +type Data struct { + TableName string + Data [][]driver.Value +} +type Event interface { + String() string +} + +type Error struct { + code int + str string +} + +const ErrorOther = 0xffff + +func NewTMQError(code int, str string) Error { + return Error{ + code: code, + str: str, + } +} + +func NewTMQErrorWithErr(err error) Error { + tErr, ok := err.(*taosError.TaosError) + if ok { + return Error{ + code: int(tErr.Code), + str: tErr.ErrStr, + } + } + return Error{ + code: ErrorOther, + str: err.Error(), + } +} + +func (e Error) String() string { + return fmt.Sprintf("[0x%x] %s", e.code, e.str) +} + +func (e Error) Error() string { + return e.String() +} + +func (e Error) Code() int { + return e.code +} + +type Message interface { + Topic() string + DBName() string + Value() interface{} + Offset() int64 +} + +type DataMessage struct { + TopicPartition TopicPartition + dbName string + topic string + data []*Data + offset Offset +} + +func (m *DataMessage) String() string { + data, _ := json.Marshal(m.data) + return fmt.Sprintf("DataMessage: %s[%s]:%s", m.topic, m.dbName, string(data)) +} + +func (m *DataMessage) SetDbName(dbName string) { + m.dbName = dbName +} + +func (m *DataMessage) SetTopic(topic string) { + m.topic = topic +} + +func (m *DataMessage) SetData(data []*Data) { + m.data = data +} + +func (m *DataMessage) SetOffset(offset Offset) { + m.offset = offset +} + +func (m *DataMessage) Topic() string { + return m.topic +} + +func (m *DataMessage) DBName() string { + return m.dbName +} + +func (m *DataMessage) Value() interface{} { + return m.data +} + +func (m *DataMessage) Offset() Offset { + return m.offset +} + +type MetaMessage struct { + TopicPartition TopicPartition + dbName string + topic string + offset Offset + meta *Meta +} + +func (m *MetaMessage) Offset() Offset { + return m.offset +} + +func (m *MetaMessage) String() string { + data, _ := json.Marshal(m.meta) + return fmt.Sprintf("MetaMessage: %s[%s]:%s", m.topic, m.dbName, string(data)) +} + +func (m *MetaMessage) SetDbName(dbName string) { + m.dbName = dbName +} + +func (m *MetaMessage) SetTopic(topic string) { + m.topic = topic +} + +func (m *MetaMessage) SetOffset(offset Offset) { + m.offset = offset +} + +func (m *MetaMessage) SetMeta(meta *Meta) { + m.meta = meta +} + +func (m *MetaMessage) Topic() string { + return m.topic +} + +func (m *MetaMessage) DBName() string { + return m.dbName +} + +func (m *MetaMessage) Value() interface{} { + return m.meta +} + +type MetaDataMessage struct { + TopicPartition TopicPartition + dbName string + topic string + offset Offset + metaData *MetaData +} + +func (m *MetaDataMessage) Offset() Offset { + return m.offset +} + +func (m *MetaDataMessage) String() string { + data, _ := json.Marshal(m.metaData) + return fmt.Sprintf("MetaDataMessage: %s[%s]:%s", m.topic, m.dbName, string(data)) +} + +func (m *MetaDataMessage) SetDbName(dbName string) { + m.dbName = dbName +} + +func (m *MetaDataMessage) SetTopic(topic string) { + m.topic = topic +} + +func (m *MetaDataMessage) SetOffset(offset Offset) { + m.offset = offset +} + +func (m *MetaDataMessage) SetMetaData(metaData *MetaData) { + m.metaData = metaData +} + +type MetaData struct { + Meta *Meta + Data []*Data +} + +func (m *MetaDataMessage) Topic() string { + return m.topic +} + +func (m *MetaDataMessage) DBName() string { + return m.dbName +} + +func (m *MetaDataMessage) Value() interface{} { + return m.metaData +} diff --git a/driver/common/tmq/event_test.go b/driver/common/tmq/event_test.go new file mode 100644 index 00000000..12812fd9 --- /dev/null +++ b/driver/common/tmq/event_test.go @@ -0,0 +1,352 @@ +package tmq + +import ( + "database/sql/driver" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + taosError "github.com/taosdata/taosadapter/v3/driver/errors" +) + +func TestDataMessage_String(t *testing.T) { + t.Parallel() + + data := []*Data{ + {TableName: "table1", Data: [][]driver.Value{{1, "data1"}}}, + {TableName: "table2", Data: [][]driver.Value{{2, "data2"}}}, + } + message := &DataMessage{ + TopicPartition: TopicPartition{ + Topic: stringPtr("test-topic"), + Partition: 0, + Offset: 100, + }, + dbName: "test-db", + topic: "test-topic", + data: data, + offset: 100, + } + + want := `DataMessage: test-topic[test-db]:[{"TableName":"table1","Data":[[1,"data1"]]},{"TableName":"table2","Data":[[2,"data2"]]}]` + + if got := message.String(); got != want { + t.Errorf("DataMessage.String() = %v, want %v", got, want) + } +} + +func TestMetaMessage_String(t *testing.T) { + t.Parallel() + + meta := &Meta{ + Type: "type", + TableName: "table", + TableType: "tableType", + } + message := &MetaMessage{ + TopicPartition: TopicPartition{ + Topic: stringPtr("test-topic"), + Partition: 0, + Offset: 100, + }, + dbName: "test-db", + topic: "test-topic", + offset: 100, + meta: meta, + } + + want := `MetaMessage: test-topic[test-db]:{"type":"type","tableName":"table","tableType":"tableType","createList":null,"columns":null,"using":"","tagNum":0,"tags":null,"tableNameList":null,"alterType":0,"colName":"","colNewName":"","colType":0,"colLength":0,"colValue":"","colValueNull":false}` + + if got := message.String(); got != want { + t.Errorf("MetaMessage.String() = %v, want %v", got, want) + } +} + +func TestMetaDataMessage_String(t *testing.T) { + t.Parallel() + + meta := &Meta{ + Type: "type", + TableName: "table", + TableType: "tableType", + } + data := []*Data{ + {TableName: "table1", Data: [][]driver.Value{{1, "data1"}}}, + {TableName: "table2", Data: [][]driver.Value{{2, "data2"}}}, + } + metaData := &MetaData{ + Meta: meta, + Data: data, + } + message := &MetaDataMessage{ + TopicPartition: TopicPartition{ + Topic: stringPtr("test-topic"), + Partition: 0, + Offset: 100, + }, + dbName: "test-db", + topic: "test-topic", + offset: 100, + metaData: metaData, + } + + want := `MetaDataMessage: test-topic[test-db]:{"Meta":{"type":"type","tableName":"table","tableType":"tableType","createList":null,"columns":null,"using":"","tagNum":0,"tags":null,"tableNameList":null,"alterType":0,"colName":"","colNewName":"","colType":0,"colLength":0,"colValue":"","colValueNull":false},"Data":[{"TableName":"table1","Data":[[1,"data1"]]},{"TableName":"table2","Data":[[2,"data2"]]}]}` + if got := message.String(); got != want { + t.Errorf("MetaDataMessage.String() = %v, want %v", got, want) + } +} + +func TestNewTMQError(t *testing.T) { + t.Parallel() + + code := 123 + str := "test error" + err := NewTMQError(code, str) + + if err.code != code { + t.Errorf("NewTMQError() code = %v, want %v", err.code, code) + } + + if err.str != str { + t.Errorf("NewTMQError() str = %v, want %v", err.str, str) + } +} + +func TestNewTMQErrorWithErr(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + err error + code int + str string + }{ + { + name: "TaosError", + err: &taosError.TaosError{ + Code: 456, + ErrStr: "taos error", + }, + code: 456, + str: "taos error", + }, + { + name: "OtherError", + err: fmt.Errorf("other error"), + code: ErrorOther, + str: "other error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := NewTMQErrorWithErr(tc.err) + + if err.code != tc.code { + t.Errorf("NewTMQErrorWithErr() code = %v, want %v", err.code, tc.code) + } + + if err.str != tc.str { + t.Errorf("NewTMQErrorWithErr() str = %v, want %v", err.str, tc.str) + } + }) + } +} + +func TestError_String(t *testing.T) { + t.Parallel() + + code := 789 + str := "test error" + err := Error{code: code, str: str} + want := fmt.Sprintf("[0x%x] %s", code, str) + + if got := err.String(); got != want { + t.Errorf("Error.String() = %v, want %v", got, want) + } +} + +func TestError_Error(t *testing.T) { + t.Parallel() + + code := 789 + str := "test error" + err := Error{code: code, str: str} + want := fmt.Sprintf("[0x%x] %s", code, str) + + if got := err.Error(); got != want { + t.Errorf("Error.Error() = %v, want %v", got, want) + } +} + +func TestError_Code(t *testing.T) { + t.Parallel() + + code := 789 + err := Error{code: code} + + if got := err.Code(); got != code { + t.Errorf("Error.Code() = %v, want %v", got, code) + } +} + +func TestMetaMessage_Offset(t *testing.T) { + t.Parallel() + + message := &MetaMessage{ + offset: 100, + } + + want := Offset(100) + if got := message.Offset(); got != want { + t.Errorf("Offset() = %v, want %v", got, want) + } +} + +func TestMetaMessage_SetDbName(t *testing.T) { + t.Parallel() + + message := &MetaMessage{} + message.SetDbName("test-db") + + want := "test-db" + if got := message.DBName(); got != want { + t.Errorf("DBName() = %v, want %v", got, want) + } +} + +func TestMetaMessage_SetTopic(t *testing.T) { + t.Parallel() + + message := &MetaMessage{} + message.SetTopic("test-topic") + + want := "test-topic" + if got := message.Topic(); got != want { + t.Errorf("Topic() = %v, want %v", got, want) + } +} + +func TestMetaMessage_SetOffset(t *testing.T) { + t.Parallel() + + message := &MetaMessage{} + message.SetOffset(200) + + want := Offset(200) + if got := message.Offset(); got != want { + t.Errorf("Offset() = %v, want %v", got, want) + } +} + +func TestMetaMessage_SetMeta(t *testing.T) { + t.Parallel() + + meta := &Meta{} + message := &MetaMessage{} + message.SetMeta(meta) + + want := meta + if got := message.Value(); got != want { + t.Errorf("Value() = %v, want %v", got, want) + } +} + +func TestDataMessage_SetDbName(t *testing.T) { + t.Parallel() + + message := &DataMessage{} + message.SetDbName("test-db") + + want := "test-db" + if got := message.DBName(); got != want { + t.Errorf("DBName() = %v, want %v", got, want) + } +} + +func TestDataMessage_SetTopic(t *testing.T) { + t.Parallel() + + message := &DataMessage{} + message.SetTopic("test-topic") + + want := "test-topic" + if got := message.Topic(); got != want { + t.Errorf("Topic() = %v, want %v", got, want) + } +} + +func TestDataMessage_SetData(t *testing.T) { + t.Parallel() + + data := []*Data{ + {TableName: "table1", Data: [][]driver.Value{{1, "data1"}}}, + {TableName: "table2", Data: [][]driver.Value{{2, "data2"}}}, + } + message := &DataMessage{} + message.SetData(data) + + want := data + assert.Equal(t, want, message.Value()) +} + +func TestDataMessage_SetOffset(t *testing.T) { + t.Parallel() + + message := &DataMessage{} + message.SetOffset(200) + + want := Offset(200) + if got := message.Offset(); got != want { + t.Errorf("Offset() = %v, want %v", got, want) + } +} + +func TestMetaDataMessage_SetDbName(t *testing.T) { + t.Parallel() + + message := &MetaDataMessage{} + message.SetDbName("test-db") + + want := "test-db" + if got := message.DBName(); got != want { + t.Errorf("DBName() = %v, want %v", got, want) + } +} + +func TestMetaDataMessage_SetTopic(t *testing.T) { + t.Parallel() + + message := &MetaDataMessage{} + message.SetTopic("test-topic") + + want := "test-topic" + if got := message.Topic(); got != want { + t.Errorf("Topic() = %v, want %v", got, want) + } +} + +func TestMetaDataMessage_SetOffset(t *testing.T) { + t.Parallel() + + message := &MetaDataMessage{} + message.SetOffset(200) + + want := Offset(200) + if got := message.Offset(); got != want { + t.Errorf("Offset() = %v, want %v", got, want) + } +} + +func TestMetaDataMessage_SetMetaData(t *testing.T) { + t.Parallel() + + metaData := &MetaData{} + message := &MetaDataMessage{} + message.SetMetaData(metaData) + + want := metaData + if got := message.Value(); got != want { + t.Errorf("Value() = %v, want %v", got, want) + } +} diff --git a/driver/common/tmq/tmq.go b/driver/common/tmq/tmq.go new file mode 100644 index 00000000..a6e5617c --- /dev/null +++ b/driver/common/tmq/tmq.go @@ -0,0 +1,87 @@ +package tmq + +import "fmt" + +type Meta struct { + Type string `json:"type"` + TableName string `json:"tableName"` + TableType string `json:"tableType"` + CreateList []*CreateItem `json:"createList"` + Columns []*Column `json:"columns"` + Using string `json:"using"` + TagNum int `json:"tagNum"` + Tags []*Tag `json:"tags"` + TableNameList []string `json:"tableNameList"` + AlterType int `json:"alterType"` + ColName string `json:"colName"` + ColNewName string `json:"colNewName"` + ColType int `json:"colType"` + ColLength int `json:"colLength"` + ColValue string `json:"colValue"` + ColValueNull bool `json:"colValueNull"` +} + +type Tag struct { + Name string `json:"name"` + Type int `json:"type"` + Value interface{} `json:"value"` +} + +type Column struct { + Name string `json:"name"` + Type int `json:"type"` + Length int `json:"length"` +} + +type CreateItem struct { + TableName string `json:"tableName"` + Using string `json:"using"` + TagNum int `json:"tagNum"` + Tags []*Tag `json:"tags"` +} + +type Offset int64 + +const OffsetInvalid = Offset(-2147467247) + +func (o Offset) String() string { + if o == OffsetInvalid { + return "unset" + } + return fmt.Sprintf("%d", int64(o)) +} + +func (o Offset) Valid() bool { + if o < 0 && o != OffsetInvalid { + return false + } + return true +} + +type TopicPartition struct { + Topic *string + Partition int32 + Offset Offset + Metadata *string + Error error +} + +func (p TopicPartition) String() string { + topic := "" + if p.Topic != nil { + topic = *p.Topic + } + if p.Error != nil { + return fmt.Sprintf("%s[%d]@%s(%s)", + topic, p.Partition, p.Offset, p.Error) + } + return fmt.Sprintf("%s[%d]@%s", + topic, p.Partition, p.Offset) +} + +type Assignment struct { + VGroupID int32 `json:"vgroup_id"` + Offset int64 `json:"offset"` + Begin int64 `json:"begin"` + End int64 `json:"end"` +} diff --git a/driver/common/tmq/tmq_test.go b/driver/common/tmq/tmq_test.go new file mode 100644 index 00000000..ae9fefb5 --- /dev/null +++ b/driver/common/tmq/tmq_test.go @@ -0,0 +1,197 @@ +package tmq + +import ( + "encoding/json" + "errors" + "reflect" + "testing" +) + +const createJson = `{ + "type": "create", + "tableName": "t1", + "tableType": "super", + "columns": [ + { + "name": "c1", + "type": 0, + "length": 0 + }, + { + "name": "c2", + "type": 8, + "length": 8 + } + ], + "tags": [ + { + "name": "t1", + "type": 0, + "length": 0 + }, + { + "name": "t2", + "type": 8, + "length": 8 + } + ] +}` +const dropJson = `{ + "type":"drop", + "tableName":"t1", + "tableType":"super", + "tableNameList":["t1", "t2"] +}` + +// @author: xftan +// @date: 2023/10/13 11:19 +// @description: test json +func TestCreateJson(t *testing.T) { + var obj Meta + err := json.Unmarshal([]byte(createJson), &obj) + if err != nil { + t.Log(err) + return + } + t.Log(obj) +} + +// @author: xftan +// @date: 2023/10/13 11:19 +// @description: test drop json +func TestDropJson(t *testing.T) { + var obj Meta + err := json.Unmarshal([]byte(dropJson), &obj) + if err != nil { + t.Log(err) + return + } + t.Log(obj) +} + +func TestOffset_String(t *testing.T) { + tests := []struct { + name string + o Offset + want string + }{ + { + name: "Valid Offset", + o: 100, + want: "100", + }, + { + name: "Invalid Offset", + o: OffsetInvalid, + want: "unset", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.o.String(); got != tt.want { + t.Errorf("Offset.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestOffset_Valid(t *testing.T) { + tests := []struct { + name string + o Offset + want bool + }{ + { + name: "Valid Offset", + o: 100, + want: true, + }, + { + name: "Invalid Offset", + o: OffsetInvalid, + want: true, + }, + { + name: "Negative Offset", + o: -100, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.o.Valid(); got != tt.want { + t.Errorf("Offset.Valid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTopicPartition_String(t *testing.T) { + tests := []struct { + name string + tp TopicPartition + want string + }{ + { + name: "With Error", + tp: TopicPartition{ + Topic: stringPtr("test-topic"), + Partition: 0, + Offset: 100, + Error: errors.New("error message"), + }, + want: "test-topic[0]@100(error message)", + }, + { + name: "Without Error", + tp: TopicPartition{ + Topic: stringPtr("test-topic"), + Partition: 0, + Offset: 100, + }, + want: "test-topic[0]@100", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.tp.String(); got != tt.want { + t.Errorf("TopicPartition.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAssignment_MarshalJSON(t *testing.T) { + tests := []struct { + name string + a Assignment + want string + }{ + { + name: "Marshal Assignment", + a: Assignment{ + VGroupID: 1, + Offset: 100, + Begin: 50, + End: 150, + }, + want: `{"vgroup_id":1,"offset":100,"begin":50,"end":150}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.a) + if err != nil { + t.Errorf("MarshalJSON error: %v", err) + return + } + if !reflect.DeepEqual(string(got), tt.want) { + t.Errorf("MarshalJSON = %v, want %v", string(got), tt.want) + } + }) + } +} + +func stringPtr(s string) *string { + return &s +} diff --git a/driver/errors/errors.go b/driver/errors/errors.go new file mode 100644 index 00000000..7c932d89 --- /dev/null +++ b/driver/errors/errors.go @@ -0,0 +1,30 @@ +//! TDengine error codes. +//! THIS IS AUTO GENERATED FROM TDENGINE , MAKE SURE YOU KNOW WHAT YOU ARE CHANING. + +package errors + +import "fmt" + +type TaosError struct { + Code int32 + ErrStr string +} + +const ( + SUCCESS int32 = 0 + UNKNOWN int32 = 0xffff +) + +func (e *TaosError) Error() string { + if e.Code != UNKNOWN { + return fmt.Sprintf("[0x%x] %s", e.Code, e.ErrStr) + } + return e.ErrStr +} + +func NewError(code int, errStr string) error { + return &TaosError{ + Code: int32(code) & 0xffff, + ErrStr: errStr, + } +} diff --git a/driver/errors/errors_test.go b/driver/errors/errors_test.go new file mode 100644 index 00000000..2a426c51 --- /dev/null +++ b/driver/errors/errors_test.go @@ -0,0 +1,52 @@ +package errors + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// @author: xftan +// @date: 2023/10/13 11:20 +// @description: test new error +func TestNewError(t *testing.T) { + type args struct { + code int + errStr string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "common", + args: args{ + code: 0, + errStr: "success", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := NewError(tt.args.code, tt.args.errStr); (err != nil) != tt.wantErr { + t.Errorf("NewError() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestError(t *testing.T) { + var ErrTscInvalidConnection = &TaosError{ + Code: 0x020B, + ErrStr: "Invalid connection", + } + invalidError := ErrTscInvalidConnection.Error() + assert.Equal(t, "[0x20b] Invalid connection", invalidError) + unknownError := &TaosError{ + Code: 0xffff, + ErrStr: "unknown error", + } + assert.Equal(t, "unknown error", unknownError.Error()) +} diff --git a/driver/types/taostype.go b/driver/types/taostype.go new file mode 100644 index 00000000..f1bbcc26 --- /dev/null +++ b/driver/types/taostype.go @@ -0,0 +1,54 @@ +package types + +import ( + "reflect" + "time" +) + +type ( + TaosBool bool + TaosTinyint int8 + TaosSmallint int16 + TaosInt int32 + TaosBigint int64 + TaosUTinyint uint8 + TaosUSmallint uint16 + TaosUInt uint32 + TaosUBigint uint64 + TaosFloat float32 + TaosDouble float64 + TaosBinary []byte + TaosVarBinary []byte + TaosNchar string + TaosTimestamp struct { + T time.Time + Precision int + } + TaosJson []byte + TaosGeometry []byte +) + +var ( + TaosBoolType = reflect.TypeOf(TaosBool(false)) + TaosTinyintType = reflect.TypeOf(TaosTinyint(0)) + TaosSmallintType = reflect.TypeOf(TaosSmallint(0)) + TaosIntType = reflect.TypeOf(TaosInt(0)) + TaosBigintType = reflect.TypeOf(TaosBigint(0)) + TaosUTinyintType = reflect.TypeOf(TaosUTinyint(0)) + TaosUSmallintType = reflect.TypeOf(TaosUSmallint(0)) + TaosUIntType = reflect.TypeOf(TaosUInt(0)) + TaosUBigintType = reflect.TypeOf(TaosUBigint(0)) + TaosFloatType = reflect.TypeOf(TaosFloat(0)) + TaosDoubleType = reflect.TypeOf(TaosDouble(0)) + TaosBinaryType = reflect.TypeOf(TaosBinary(nil)) + TaosVarBinaryType = reflect.TypeOf(TaosVarBinary(nil)) + TaosNcharType = reflect.TypeOf(TaosNchar("")) + TaosTimestampType = reflect.TypeOf(TaosTimestamp{}) + TaosJsonType = reflect.TypeOf(TaosJson("")) + TaosGeometryType = reflect.TypeOf(TaosGeometry(nil)) +) + +type ColumnType struct { + Type reflect.Type + MaxLen int +} diff --git a/driver/types/types.go b/driver/types/types.go new file mode 100644 index 00000000..59b94d39 --- /dev/null +++ b/driver/types/types.go @@ -0,0 +1,492 @@ +package types + +import ( + "database/sql/driver" + "fmt" + "time" + + "github.com/taosdata/taosadapter/v3/driver/errors" +) + +type NullInt64 struct { + Inner int64 + Valid bool // Valid is true if Inner is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt64) Scan(value interface{}) error { + if value == nil { + n.Inner, n.Valid = 0, false + return nil + } + n.Valid = true + v, ok := value.(int64) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse int64 error"} + } + n.Inner = v + return nil +} + +// Value implements the driver Valuer interface. +func (n NullInt64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +func (n NullInt64) String() string { + if n.Valid { + return fmt.Sprintf("%v", n.Inner) + } + return "NULL" +} + +type NullInt32 struct { + Inner int32 + Valid bool // Valid is true if Inner is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt32) Scan(value interface{}) error { + if value == nil { + n.Inner, n.Valid = 0, false + return nil + } + n.Valid = true + v, ok := value.(int32) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse int32 error"} + } + n.Inner = v + return nil +} + +// Value implements the driver Valuer interface. +func (n NullInt32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +func (n NullInt32) String() string { + if n.Valid { + return fmt.Sprintf("%v", n.Inner) + } + return "NULL" +} + +type NullInt16 struct { + Inner int16 + Valid bool // Valid is true if Inner is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt16) Scan(value interface{}) error { + if value == nil { + n.Inner, n.Valid = 0, false + return nil + } + n.Valid = true + v, ok := value.(int16) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse int16 error"} + } + n.Inner = v + return nil +} + +// Value implements the driver Valuer interface. +func (n NullInt16) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +func (n NullInt16) String() string { + if n.Valid { + return fmt.Sprintf("%v", n.Inner) + } + return "NULL" +} + +type NullInt8 struct { + Inner int8 + Valid bool // Valid is true if Inner is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt8) Scan(value interface{}) error { + if value == nil { + n.Inner, n.Valid = 0, false + return nil + } + n.Valid = true + v, ok := value.(int8) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse int8 error"} + } + n.Inner = v + return nil +} + +// Value implements the driver Valuer interface. +func (n NullInt8) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +func (n NullInt8) String() string { + if n.Valid { + return fmt.Sprintf("%v", n.Inner) + } + return "NULL" +} + +type NullUInt64 struct { + Inner uint64 + Valid bool // Valid is true if Inner is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullUInt64) Scan(value interface{}) error { + if value == nil { + n.Inner, n.Valid = 0, false + return nil + } + n.Valid = true + v, ok := value.(uint64) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse uint64 error"} + } + n.Inner = v + return nil +} + +// Value implements the driver Valuer interface. +func (n NullUInt64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +func (n NullUInt64) String() string { + if n.Valid { + return fmt.Sprintf("%v", n.Inner) + } + return "NULL" +} + +type NullUInt32 struct { + Inner uint32 + Valid bool // Valid is true if Inner is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullUInt32) Scan(value interface{}) error { + if value == nil { + n.Inner, n.Valid = 0, false + return nil + } + n.Valid = true + v, ok := value.(uint32) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse uint32 error"} + } + n.Inner = v + return nil +} + +// Value implements the driver Valuer interface. +func (n NullUInt32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +func (n NullUInt32) String() string { + if n.Valid { + return fmt.Sprintf("%v", n.Inner) + } + return "NULL" +} + +type NullUInt16 struct { + Inner uint16 + Valid bool // Valid is true if Inner is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullUInt16) Scan(value interface{}) error { + if value == nil { + n.Inner, n.Valid = 0, false + return nil + } + n.Valid = true + v, ok := value.(uint16) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse uint16 error"} + } + n.Inner = v + return nil +} + +// Value implements the driver Valuer interface. +func (n NullUInt16) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +func (n NullUInt16) String() string { + if n.Valid { + return fmt.Sprintf("%v", n.Inner) + } + return "NULL" +} + +type NullUInt8 struct { + Inner uint8 + Valid bool // Valid is true if Inner is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullUInt8) Scan(value interface{}) error { + if value == nil { + n.Inner, n.Valid = 0, false + return nil + } + n.Valid = true + v, ok := value.(uint8) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse uint8 error"} + } + n.Inner = v + return nil +} + +// Value implements the driver Valuer interface. +func (n NullUInt8) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +func (n NullUInt8) String() string { + if n.Valid { + return fmt.Sprintf("%v", n.Inner) + } + return "NULL" +} + +type NullFloat32 struct { + Inner float32 + Valid bool // Valid is true if Inner is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullFloat32) Scan(value interface{}) error { + if value == nil { + n.Inner, n.Valid = 0, false + return nil + } + n.Valid = true + v, ok := value.(float32) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse float32 error"} + } + n.Inner = v + return nil +} + +func (n NullFloat32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +func (n NullFloat32) String() string { + if n.Valid { + return fmt.Sprintf("%v", n.Inner) + } + return "NULL" +} + +type NullFloat64 struct { + Inner float64 + Valid bool // Valid is true if Inner is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullFloat64) Scan(value interface{}) error { + if value == nil { + n.Inner, n.Valid = 0, false + return nil + } + n.Valid = true + v, ok := value.(float64) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse float64 error"} + } + n.Inner = v + return nil +} + +func (n NullFloat64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +func (n NullFloat64) String() string { + if n.Valid { + return fmt.Sprintf("%v", n.Inner) + } + return "NULL" +} + +type NullBool struct { + Inner bool + Valid bool // Valid is true if Inner is not NULL +} + +func (n *NullBool) Scan(value interface{}) error { + if value == nil { + n.Valid = false + return nil + } + n.Valid = true + v, ok := value.(bool) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse bool error"} + } + n.Inner = v + return nil +} + +func (n NullBool) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +type NullString struct { + Inner string + Valid bool // Valid is true if Inner is not NULL +} + +func (n *NullString) Scan(value interface{}) error { + if value == nil { + n.Valid = false + return nil + } + n.Valid = true + v, ok := value.(string) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse string error"} + } + n.Inner = v + return nil +} + +func (n NullString) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +// The value type must be time.Time or string / []byte (formatted time-string), +// otherwise Scan fails. +func (nt *NullTime) Scan(value interface{}) (err error) { + if value == nil { + nt.Time, nt.Valid = time.Time{}, false + return + } + + switch v := value.(type) { + case time.Time: + nt.Time, nt.Valid = v, true + return + case []byte: + nt.Time, err = time.Parse(time.RFC3339Nano, string(v)) + nt.Valid = err == nil + return + case string: + nt.Time, err = time.Parse(time.RFC3339Nano, v) + nt.Valid = err == nil + return + } + + nt.Valid = false + return fmt.Errorf("can't convert %T to time.Time", value) +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + +type NullJson struct { + Inner RawMessage + Valid bool +} + +func (n *NullJson) Scan(value interface{}) error { + if value == nil { + n.Valid = false + return nil + } + n.Valid = true + v, ok := value.([]byte) + if !ok { + return &errors.TaosError{Code: 0xffff, ErrStr: "taosSql parse json error"} + } + n.Inner = v + return nil +} + +func (n NullJson) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Inner, nil +} + +type RawMessage []byte + +func (m RawMessage) MarshalJSON() ([]byte, error) { + if m == nil { + return []byte("null"), nil + } + return m, nil +} + +func (m *RawMessage) UnmarshalJSON(data []byte) error { + if m == nil { + return &errors.TaosError{Code: 0xffff, ErrStr: "json.RawMessage: UnmarshalJSON on nil pointer"} + } + *m = append((*m)[0:0], data...) + return nil +} diff --git a/driver/types/types_test.go b/driver/types/types_test.go new file mode 100644 index 00000000..487da516 --- /dev/null +++ b/driver/types/types_test.go @@ -0,0 +1,2122 @@ +package types + +import ( + "database/sql/driver" + "reflect" + "testing" + "time" +) + +// @author: xftan +// @date: 2022/1/27 16:20 +// @description: test null bool type Scan() +func TestNullBool_Scan(t *testing.T) { + type fields struct { + Inner bool + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "true", + fields: fields{ + Inner: true, + Valid: true, + }, + args: args{ + value: true, + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: true, + Valid: false, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: false, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullBool{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:20 +// @description: test null bool type Value() +func TestNullBool_Value(t *testing.T) { + type fields struct { + Inner bool + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "ture", + fields: fields{ + Inner: true, + Valid: true, + }, + want: true, + wantErr: false, + }, + { + name: "false", + fields: fields{ + Inner: false, + Valid: true, + }, + want: false, + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: false, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullBool{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:21 +// @description: test null float32 type Scan() +func TestNullFloat32_Scan(t *testing.T) { + type fields struct { + Inner float32 + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: float32(1), + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullFloat32{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:21 +// @description: test null float32 type String() +func TestNullFloat32_String(t *testing.T) { + type fields struct { + Inner float32 + Valid bool + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: "123", + }, + { + name: "null", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: "NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullFloat32{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if got := n.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:21 +// @description: test null float32 type Value() +func TestNullFloat32_Value(t *testing.T) { + type fields struct { + Inner float32 + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: float32(123), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullFloat32{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:21 +// @description: test null float64 type Scan() +func TestNullFloat64_Scan(t *testing.T) { + type fields struct { + Inner float64 + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: float64(1), + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullFloat64{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:22 +// @description: test null float64 type String() +func TestNullFloat64_String(t *testing.T) { + type fields struct { + Inner float64 + Valid bool + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: "123", + }, + { + name: "null", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: "NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullFloat64{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if got := n.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:22 +// @description: test null float64 type Value() +func TestNullFloat64_Value(t *testing.T) { + type fields struct { + Inner float64 + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: float64(123), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullFloat64{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:22 +// @description: test null int16 type Scan() +func TestNullInt16_Scan(t *testing.T) { + type fields struct { + Inner int16 + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: int16(1), + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullInt16{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:23 +// @description: test null int16 type String() +func TestNullInt16_String(t *testing.T) { + type fields struct { + Inner int16 + Valid bool + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: "123", + }, + { + name: "null", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: "NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullInt16{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if got := n.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:23 +// @description: test null int16 type Value() +func TestNullInt16_Value(t *testing.T) { + type fields struct { + Inner int16 + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: int16(123), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullInt16{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:23 +// @description: test null int32 type Scan() +func TestNullInt32_Scan(t *testing.T) { + type fields struct { + Inner int32 + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: int32(1), + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullInt32{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:25 +// @description: test null int32 type String() +func TestNullInt32_String(t *testing.T) { + type fields struct { + Inner int32 + Valid bool + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: "123", + }, + { + name: "null", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: "NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullInt32{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if got := n.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:04 +// @description: test null int32 type Value() +func TestNullInt32_Value(t *testing.T) { + type fields struct { + Inner int32 + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: int32(123), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullInt32{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:23 +// @description: test null int64 type Scan() +func TestNullInt64_Scan(t *testing.T) { + type fields struct { + Inner int64 + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: int64(1), + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullInt64{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:47 +// @description: test null int64 type String() +func TestNullInt64_String(t *testing.T) { + type fields struct { + Inner int64 + Valid bool + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: "123", + }, + { + name: "null", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: "NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullInt64{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if got := n.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:19 +// @description: test null int64 type Value() +func TestNullInt64_Value(t *testing.T) { + type fields struct { + Inner int64 + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: int64(123), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullInt64{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:23 +// @description: test null int8 type Scan() +func TestNullInt8_Scan(t *testing.T) { + type fields struct { + Inner int8 + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: int8(1), + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullInt8{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:48 +// @description: test null int8 type String() +func TestNullInt8_String(t *testing.T) { + type fields struct { + Inner int8 + Valid bool + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: "123", + }, + { + name: "null", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: "NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullInt8{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if got := n.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:19 +// @description: test null int8 type Value() +func TestNullInt8_Value(t *testing.T) { + type fields struct { + Inner int8 + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: int8(123), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullInt8{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:24 +// @description: test null json type Scan() +func TestNullJson_Scan(t *testing.T) { + type fields struct { + Inner RawMessage + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{}, + args: args{ + value: []byte{'1', '2', '3'}, + }, + wantErr: false, + }, + { + name: "error", + fields: fields{}, + args: args{ + value: 123, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{}, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullJson{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:19 +// @description: test null json type Value() +func TestNullJson_Value(t *testing.T) { + type fields struct { + Inner RawMessage + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: RawMessage("123"), + Valid: true, + }, + want: RawMessage("123"), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: nil, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullJson{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:24 +// @description: test null string type Scan() +func TestNullString_Scan(t *testing.T) { + type fields struct { + Inner string + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{}, + args: args{ + value: "123", + }, + wantErr: false, + }, + { + name: "error", + fields: fields{}, + args: args{ + value: 123, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{}, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullString{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:19 +// @description: test null string type Value() +func TestNullString_Value(t *testing.T) { + type fields struct { + Inner string + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: "123", + Valid: true, + }, + want: "123", + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: "", + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullString{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:24 +// @description: test null time type Scan() +func TestNullTime_Scan(t *testing.T) { + type fields struct { + Time time.Time + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "time", + fields: fields{}, + args: args{ + value: time.Now(), + }, + wantErr: false, + }, + { + name: "bytes", + fields: fields{}, + args: args{ + value: []byte("2022-01-27T15:34:52.9368423+08:00"), + }, + wantErr: false, + }, + { + name: "string", + fields: fields{}, + args: args{ + value: "2022-01-27T15:34:52.9368423+08:00", + }, + wantErr: false, + }, + { + name: "error", + fields: fields{}, + args: args{ + value: 123, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{}, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nt := &NullTime{ + Time: tt.fields.Time, + Valid: tt.fields.Valid, + } + if err := nt.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:20 +// @description: test null time type Value() +func TestNullTime_Value(t *testing.T) { + now := time.Now() + type fields struct { + Time time.Time + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Time: now, + Valid: true, + }, + want: now, + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Time: now, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nt := NullTime{ + Time: tt.fields.Time, + Valid: tt.fields.Valid, + } + got, err := nt.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:24 +// @description: test null uint16 type Scan() +func TestNullUInt16_Scan(t *testing.T) { + type fields struct { + Inner uint16 + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: uint16(1), + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullUInt16{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:48 +// @description: test null uint16 type String() +func TestNullUInt16_String(t *testing.T) { + type fields struct { + Inner uint16 + Valid bool + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: "123", + }, + { + name: "null", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: "NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullUInt16{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if got := n.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:20 +// @description: test null uint16 type Value() +func TestNullUInt16_Value(t *testing.T) { + type fields struct { + Inner uint16 + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: uint16(123), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullUInt16{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:24 +// @description: test null uint32 type Scan() +func TestNullUInt32_Scan(t *testing.T) { + type fields struct { + Inner uint32 + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: uint32(1), + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullUInt32{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:48 +// @description: test null uint32 type String() +func TestNullUInt32_String(t *testing.T) { + type fields struct { + Inner uint32 + Valid bool + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: "123", + }, + { + name: "null", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: "NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullUInt32{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if got := n.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:20 +// @description: test null uint32 type Value() +func TestNullUInt32_Value(t *testing.T) { + type fields struct { + Inner uint32 + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: uint32(123), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullUInt32{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:24 +// @description: test null uint64 type Scan() +func TestNullUInt64_Scan(t *testing.T) { + type fields struct { + Inner uint64 + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: uint64(1), + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullUInt64{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:48 +// @description: test null uint64 type String() +func TestNullUInt64_String(t *testing.T) { + type fields struct { + Inner uint64 + Valid bool + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: "123", + }, + { + name: "null", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: "NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullUInt64{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if got := n.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:20 +// @description: test null uint64 type Value() +func TestNullUInt64_Value(t *testing.T) { + type fields struct { + Inner uint64 + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: uint64(123), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullUInt64{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:25 +// @description: test null uint8 type Scan() +func TestNullUInt8_Scan(t *testing.T) { + type fields struct { + Inner uint8 + Valid bool + } + type args struct { + value interface{} + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: uint8(1), + }, + wantErr: false, + }, + { + name: "error", + fields: fields{ + Inner: 1, + Valid: true, + }, + args: args{ + value: 1, + }, + wantErr: true, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + args: args{ + value: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := &NullUInt8{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if err := n.Scan(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 16:48 +// @description: test null uint8 type String() +func TestNullUInt8_String(t *testing.T) { + type fields struct { + Inner uint8 + Valid bool + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: "123", + }, + { + name: "null", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: "NULL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullUInt8{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + if got := n.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:20 +// @description: test null uint8 type Value() +func TestNullUInt8_Value(t *testing.T) { + type fields struct { + Inner uint8 + Valid bool + } + tests := []struct { + name string + fields fields + want driver.Value + wantErr bool + }{ + { + name: "common", + fields: fields{ + Inner: 123, + Valid: true, + }, + want: uint8(123), + wantErr: false, + }, + { + name: "nil", + fields: fields{ + Inner: 0, + Valid: false, + }, + want: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := NullUInt8{ + Inner: tt.fields.Inner, + Valid: tt.fields.Valid, + } + got, err := n.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Value() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:20 +// @description: test raw message type MarshalJson() interface +func TestRawMessage_MarshalJSON(t *testing.T) { + tests := []struct { + name string + m RawMessage + want []byte + wantErr bool + }{ + { + name: "common", + m: RawMessage(`{"a":"b"}`), + want: []byte(`{"a":"b"}`), + wantErr: false, + }, + { + name: "nil", + m: nil, + want: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.m.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:21 +// @description: test raw message type UnmarshalJson() interface +func TestRawMessage_UnmarshalJSON(t *testing.T) { + common := RawMessage(`{"a":"b"}`) + type args struct { + data []byte + } + tests := []struct { + name string + m *RawMessage + args args + wantErr bool + }{ + { + name: "common", + m: &common, + args: args{ + data: []byte(`{"a":"b"}`), + }, + wantErr: false, + }, + { + name: "error", + m: nil, + args: args{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.m.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/driver/wrapper/asynccb.go b/driver/wrapper/asynccb.go new file mode 100644 index 00000000..ad9d58a1 --- /dev/null +++ b/driver/wrapper/asynccb.go @@ -0,0 +1,38 @@ +package wrapper + +/* +#include +#include +#include +#include + +*/ +import "C" +import ( + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +type Caller interface { + QueryCall(res unsafe.Pointer, code int) + FetchCall(res unsafe.Pointer, numOfRows int) +} + +//export QueryCallback +func QueryCallback(p unsafe.Pointer, res *C.TAOS_RES, code C.int) { + caller := (*(*cgo.Handle)(p)).Value().(Caller) + caller.QueryCall(unsafe.Pointer(res), int(code)) +} + +//export FetchRowsCallback +func FetchRowsCallback(p unsafe.Pointer, res *C.TAOS_RES, numOfRows C.int) { + caller := (*(*cgo.Handle)(p)).Value().(Caller) + caller.FetchCall(unsafe.Pointer(res), int(numOfRows)) +} + +//export FetchRawBlockCallback +func FetchRawBlockCallback(p unsafe.Pointer, res *C.TAOS_RES, numOfRows C.int) { + caller := (*(*cgo.Handle)(p)).Value().(Caller) + caller.FetchCall(unsafe.Pointer(res), int(numOfRows)) +} diff --git a/driver/wrapper/block.go b/driver/wrapper/block.go new file mode 100644 index 00000000..eacf894f --- /dev/null +++ b/driver/wrapper/block.go @@ -0,0 +1,49 @@ +package wrapper + +/* +#include +#include +#include +#include +*/ +import "C" +import ( + "unsafe" +) + +// TaosFetchRawBlock int taos_fetch_raw_block(TAOS_RES *res, int* numOfRows, void** pData); +func TaosFetchRawBlock(result unsafe.Pointer) (int, int, unsafe.Pointer) { + var cSize int + size := unsafe.Pointer(&cSize) + var block unsafe.Pointer + errCode := int(C.taos_fetch_raw_block(result, (*C.int)(size), &block)) + return cSize, errCode, block +} + +// TaosWriteRawBlock DLL_EXPORT int taos_write_raw_block(TAOS *taos, int numOfRows, char *pData, const char* tbname); +func TaosWriteRawBlock(conn unsafe.Pointer, numOfRows int, pData unsafe.Pointer, tableName string) int { + cStr := C.CString(tableName) + defer C.free(unsafe.Pointer(cStr)) + return int(C.taos_write_raw_block(conn, (C.int)(numOfRows), (*C.char)(pData), cStr)) +} + +// TaosWriteRawBlockWithFields DLL_EXPORT int taos_write_raw_block_with_fields(TAOS* taos, int rows, char* pData, const char* tbname, TAOS_FIELD *fields, int numFields); +func TaosWriteRawBlockWithFields(conn unsafe.Pointer, numOfRows int, pData unsafe.Pointer, tableName string, fields unsafe.Pointer, numFields int) int { + cStr := C.CString(tableName) + defer C.free(unsafe.Pointer(cStr)) + return int(C.taos_write_raw_block_with_fields(conn, (C.int)(numOfRows), (*C.char)(pData), cStr, (*C.struct_taosField)(fields), (C.int)(numFields))) +} + +// DLL_EXPORT int taos_write_raw_block_with_reqid(TAOS *taos, int numOfRows, char *pData, const char *tbname, int64_t reqid); +func TaosWriteRawBlockWithReqID(conn unsafe.Pointer, numOfRows int, pData unsafe.Pointer, tableName string, reqID int64) int { + cStr := C.CString(tableName) + defer C.free(unsafe.Pointer(cStr)) + return int(C.taos_write_raw_block_with_reqid(conn, (C.int)(numOfRows), (*C.char)(pData), cStr, (C.int64_t)(reqID))) +} + +// DLL_EXPORT int taos_write_raw_block_with_fields_with_reqid(TAOS *taos, int rows, char *pData, const char *tbname,TAOS_FIELD *fields, int numFields, int64_t reqid); +func TaosWriteRawBlockWithFieldsWithReqID(conn unsafe.Pointer, numOfRows int, pData unsafe.Pointer, tableName string, fields unsafe.Pointer, numFields int, reqID int64) int { + cStr := C.CString(tableName) + defer C.free(unsafe.Pointer(cStr)) + return int(C.taos_write_raw_block_with_fields_with_reqid(conn, (C.int)(numOfRows), (*C.char)(pData), cStr, (*C.struct_taosField)(fields), (C.int)(numFields), (C.int64_t)(reqID))) +} diff --git a/driver/wrapper/block_test.go b/driver/wrapper/block_test.go new file mode 100644 index 00000000..85523d78 --- /dev/null +++ b/driver/wrapper/block_test.go @@ -0,0 +1,924 @@ +package wrapper + +import ( + "database/sql/driver" + "fmt" + "math" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/errors" +) + +// @author: xftan +// @date: 2023/10/13 11:27 +// @description: test read block +func TestReadBlock(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn) + res := TaosQuery(conn, "drop database if exists test_block_raw") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + defer func() { + res = TaosQuery(conn, "drop database if exists test_block_raw") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() + res = TaosQuery(conn, "create database test_block_raw") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, "create table if not exists test_block_raw.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + after2s := now.Add(2 * time.Second) + sql := fmt.Sprintf("insert into test_block_raw.t0 using test_block_raw.all_type tags('{\"a\":1}') values"+ + "('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')"+ + "('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)"+ + "('%s',true,%d,%d,%d,%d,%d,%d,%d,%v,%f,%f,'b','n')", + now.Format(time.RFC3339Nano), + after1s.Format(time.RFC3339Nano), + after2s.Format(time.RFC3339Nano), + math.MaxInt8, + math.MaxInt16, + math.MaxInt32, + math.MaxInt64, + math.MaxUint8, + math.MaxUint16, + math.MaxUint32, + uint64(math.MaxUint64), + math.MaxFloat32, + math.MaxFloat64, + ) + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "select * from test_block_raw.all_type" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + d := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + data = append(data, d...) + } + TaosFreeResult(res) + assert.Equal(t, 3, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + assert.Equal(t, []byte(`{"a":1}`), row1[14].([]byte)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) + row3 := data[2] + assert.Equal(t, after2s.UnixNano()/1e6, row3[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row3[1].(bool)) + assert.Equal(t, int8(math.MaxInt8), row3[2].(int8)) + assert.Equal(t, int16(math.MaxInt16), row3[3].(int16)) + assert.Equal(t, int32(math.MaxInt32), row3[4].(int32)) + assert.Equal(t, int64(math.MaxInt64), row3[5].(int64)) + assert.Equal(t, uint8(math.MaxUint8), row3[6].(uint8)) + assert.Equal(t, uint16(math.MaxUint16), row3[7].(uint16)) + assert.Equal(t, uint32(math.MaxUint32), row3[8].(uint32)) + assert.Equal(t, uint64(math.MaxUint64), row3[9].(uint64)) + assert.Equal(t, float32(math.MaxFloat32), row3[10].(float32)) + assert.Equal(t, float64(math.MaxFloat64), row3[11].(float64)) + assert.Equal(t, "b", row3[12].(string)) + assert.Equal(t, "n", row3[13].(string)) + assert.Equal(t, []byte(`{"a":1}`), row3[14].([]byte)) +} + +// @author: xftan +// @date: 2023/10/13 11:27 +// @description: test write raw block +func TestTaosWriteRawBlock(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn) + res := TaosQuery(conn, "drop database if exists test_write_block_raw") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + defer func() { + res = TaosQuery(conn, "drop database if exists test_write_block_raw") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() + res = TaosQuery(conn, "create database test_write_block_raw") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, "create table if not exists test_write_block_raw.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_write_block_raw.t0 using test_write_block_raw.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "create table test_write_block_raw.t1 using test_write_block_raw.all_type tags('{\"a\":2}')" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "use test_write_block_raw" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "select * from test_write_block_raw.t0" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + + errCode = TaosWriteRawBlock(conn, blockSize, block, "t1") + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(nil) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + } + TaosFreeResult(res) + + sql = "select * from test_write_block_raw.t1" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + d := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + data = append(data, d...) + } + TaosFreeResult(res) + + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } +} + +// @author: xftan +// @date: 2023/10/13 11:28 +// @description: test write raw block with fields +func TestTaosWriteRawBlockWithFields(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn) + res := TaosQuery(conn, "drop database if exists test_write_block_raw_fields") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + defer func() { + res = TaosQuery(conn, "drop database if exists test_write_block_raw_fields") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() + res = TaosQuery(conn, "create database test_write_block_raw_fields") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, "create table if not exists test_write_block_raw_fields.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_write_block_raw_fields.t0 using test_write_block_raw_fields.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "create table test_write_block_raw_fields.t1 using test_write_block_raw_fields.all_type tags('{\"a\":2}')" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "use test_write_block_raw_fields" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "select ts,c1 from test_write_block_raw_fields.t0" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + fieldsCount := TaosNumFields(res) + fields := TaosFetchFields(res) + + errCode = TaosWriteRawBlockWithFields(conn, blockSize, block, "t1", fields, fieldsCount) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(nil) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + } + TaosFreeResult(res) + + sql = "select * from test_write_block_raw_fields.t1" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + d := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + data = append(data, d...) + } + TaosFreeResult(res) + + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + for i := 2; i < 14; i++ { + assert.Nil(t, row1[i]) + } + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } +} + +// @author: xftan +// @date: 2023/11/17 9:39 +// @description: test write raw block with reqid +func TestTaosWriteRawBlockWithReqID(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn) + res := TaosQuery(conn, "drop database if exists test_write_block_raw_with_reqid") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + defer func() { + res = TaosQuery(conn, "drop database if exists test_write_block_raw_with_reqid") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() + res = TaosQuery(conn, "create database test_write_block_raw_with_reqid") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, "create table if not exists test_write_block_raw_with_reqid.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_write_block_raw_with_reqid.t0 using test_write_block_raw_with_reqid.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "create table test_write_block_raw_with_reqid.t1 using test_write_block_raw_with_reqid.all_type tags('{\"a\":2}')" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "use test_write_block_raw_with_reqid" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "select * from test_write_block_raw_with_reqid.t0" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + + errCode = TaosWriteRawBlockWithReqID(conn, blockSize, block, "t1", 1) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(nil) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + } + TaosFreeResult(res) + + sql = "select * from test_write_block_raw_with_reqid.t1" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + d := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + data = append(data, d...) + } + TaosFreeResult(res) + + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } +} + +// @author: xftan +// @date: 2023/11/17 9:37 +// @description: test write raw block with fields and reqid +func TestTaosWriteRawBlockWithFieldsWithReqID(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn) + res := TaosQuery(conn, "drop database if exists test_write_block_raw_fields_with_reqid") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + defer func() { + res = TaosQuery(conn, "drop database if exists test_write_block_raw_fields_with_reqid") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() + res = TaosQuery(conn, "create database test_write_block_raw_fields_with_reqid") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, "create table if not exists test_write_block_raw_fields_with_reqid.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_write_block_raw_fields_with_reqid.t0 using test_write_block_raw_fields_with_reqid.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "create table test_write_block_raw_fields_with_reqid.t1 using test_write_block_raw_fields_with_reqid.all_type tags('{\"a\":2}')" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "use test_write_block_raw_fields_with_reqid" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "select ts,c1 from test_write_block_raw_fields_with_reqid.t0" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + fieldsCount := TaosNumFields(res) + fields := TaosFetchFields(res) + + errCode = TaosWriteRawBlockWithFieldsWithReqID(conn, blockSize, block, "t1", fields, fieldsCount, 1) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(nil) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + } + TaosFreeResult(res) + + sql = "select * from test_write_block_raw_fields_with_reqid.t1" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + d := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + data = append(data, d...) + } + TaosFreeResult(res) + + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + for i := 2; i < 14; i++ { + assert.Nil(t, row1[i]) + } + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } +} diff --git a/driver/wrapper/cgo/README.md b/driver/wrapper/cgo/README.md new file mode 100644 index 00000000..329a9baa --- /dev/null +++ b/driver/wrapper/cgo/README.md @@ -0,0 +1 @@ +Copy from go 1.17.2 runtime/cgo. In order to be compatible with lower versions \ No newline at end of file diff --git a/driver/wrapper/cgo/handle.go b/driver/wrapper/cgo/handle.go new file mode 100644 index 00000000..586adcf3 --- /dev/null +++ b/driver/wrapper/cgo/handle.go @@ -0,0 +1,81 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cgo + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +// Handle provides a way to pass values that contain Go pointers +// (pointers to memory allocated by Go) between Go and C without +// breaking the cgo pointer passing rules. A Handle is an integer +// value that can represent any Go value. A Handle can be passed +// through C and back to Go, and Go code can use the Handle to +// retrieve the original Go value. +// +// The underlying type of Handle is guaranteed to fit in an integer type +// that is large enough to hold the bit pattern of any pointer. The zero +// value of a Handle is not valid, and thus is safe to use as a sentinel +// in C APIs. + +type Handle uintptr + +// NewHandle returns a handle for a given value. +// +// The handle is valid until the program calls Delete on it. The handle +// uses resources, and this package assumes that C code may hold on to +// the handle, so a program must explicitly call Delete when the handle +// is no longer needed. +// +// The intended use is to pass the returned handle to C code, which +// passes it back to Go, which calls Value. +func NewHandle(v interface{}) Handle { + h := atomic.AddUintptr(&handleIdx, 1) + if h == 0 { + panic("runtime/cgo: ran out of handle space") + } + + handles.Store(h, v) + handle := Handle(h) + handlePointers.Store(h, &handle) + return handle +} + +// Value returns the associated Go value for a valid handle. +// +// The method panics if the handle is invalid. +func (h Handle) Value() interface{} { + v, ok := handles.Load(uintptr(h)) + if !ok { + panic("runtime/cgo: misuse of an invalid Handle") + } + return v +} + +func (h Handle) Pointer() unsafe.Pointer { + p, ok := handlePointers.Load(uintptr(h)) + if !ok { + panic("runtime/cgo: misuse of an invalid Handle") + } + return unsafe.Pointer(p.(*Handle)) +} + +// Delete invalidates a handle. This method should only be called once +// the program no longer needs to pass the handle to C and the C code +// no longer has a copy of the handle value. +// +// The method panics if the handle is invalid. +func (h Handle) Delete() { + handles.Delete(uintptr(h)) + handlePointers.Delete(uintptr(h)) +} + +var ( + handles = sync.Map{} // map[Handle]interface{} + handlePointers = sync.Map{} // map[Handle]*Handle + handleIdx uintptr // atomic +) diff --git a/driver/wrapper/cgo/handle_test.go b/driver/wrapper/cgo/handle_test.go new file mode 100644 index 00000000..46cdc52d --- /dev/null +++ b/driver/wrapper/cgo/handle_test.go @@ -0,0 +1,107 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cgo + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +// @author: xftan +// @date: 2022/1/27 17:21 +// @description: test cgo handler +func TestHandle(t *testing.T) { + v := 42 + + tests := []struct { + v1 interface{} + v2 interface{} + }{ + {v1: v, v2: v}, + {v1: &v, v2: &v}, + {v1: nil, v2: nil}, + } + + for _, tt := range tests { + h1 := NewHandle(tt.v1) + h2 := NewHandle(tt.v2) + + if uintptr(h1) == 0 || uintptr(h2) == 0 { + t.Fatalf("NewHandle returns zero") + } + + if uintptr(h1) == uintptr(h2) { + t.Fatalf("Duplicated Go values should have different handles, but got equal") + } + + h1v := h1.Value() + h2v := h2.Value() + if !reflect.DeepEqual(h1v, h2v) || !reflect.DeepEqual(h1v, tt.v1) { + t.Fatalf("Value of a Handle got wrong, got %+v %+v, want %+v", h1v, h2v, tt.v1) + } + + h1.Delete() + h2.Delete() + } + + siz := 0 + handles.Range(func(k, v interface{}) bool { + siz++ + return true + }) + if siz != 0 { + t.Fatalf("handles are not cleared, got %d, want %d", siz, 0) + } +} + +func TestPointer(t *testing.T) { + v := 42 + h := NewHandle(&v) + p := h.Pointer() + assert.Equal(t, *(*Handle)(p), h) + h.Delete() + defer func() { + if r := recover(); r != nil { + return + } + t.Fatalf("Pointer should panic") + }() + h.Pointer() +} + +func TestInvalidValue(t *testing.T) { + v := 42 + h := NewHandle(&v) + h.Delete() + defer func() { + if r := recover(); r != nil { + return + } + t.Fatalf("Value should panic") + }() + h.Value() +} + +func BenchmarkHandle(b *testing.B) { + b.Run("non-concurrent", func(b *testing.B) { + for i := 0; i < b.N; i++ { + h := NewHandle(i) + _ = h.Value() + h.Delete() + } + }) + b.Run("concurrent", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + var v int + for pb.Next() { + h := NewHandle(v) + _ = h.Value() + h.Delete() + } + }) + }) +} diff --git a/driver/wrapper/field.go b/driver/wrapper/field.go new file mode 100644 index 00000000..feb3492e --- /dev/null +++ b/driver/wrapper/field.go @@ -0,0 +1,67 @@ +package wrapper + +/* +#include +*/ +import "C" +import ( + "bytes" + "reflect" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/errors" +) + +type RowsHeader struct { + ColNames []string + ColTypes []uint8 + ColLength []int64 +} + +func ReadColumn(result unsafe.Pointer, count int) (*RowsHeader, error) { + if result == nil { + return nil, &errors.TaosError{Code: 0xffff, ErrStr: "invalid result"} + } + rowsHeader := &RowsHeader{ + ColNames: make([]string, count), + ColTypes: make([]uint8, count), + ColLength: make([]int64, count), + } + pFields := TaosFetchFields(result) + for i := 0; i < count; i++ { + field := *(*C.struct_taosField)(unsafe.Pointer(uintptr(pFields) + uintptr(C.sizeof_struct_taosField*C.int(i)))) + buf := bytes.NewBufferString("") + for _, c := range field.name { + if c == 0 { + break + } + buf.WriteByte(byte(c)) + } + rowsHeader.ColNames[i] = buf.String() + rowsHeader.ColTypes[i] = (uint8)(field._type) + rowsHeader.ColLength[i] = int64((uint32)(field.bytes)) + } + return rowsHeader, nil +} + +func (rh *RowsHeader) TypeDatabaseName(i int) string { + return common.TypeNameMap[int(rh.ColTypes[i])] +} + +func (rh *RowsHeader) ScanType(i int) reflect.Type { + t, exist := common.ColumnTypeMap[int(rh.ColTypes[i])] + if !exist { + return common.UnknownType + } + return t +} + +func FetchLengths(res unsafe.Pointer, count int) []int { + lengths := TaosFetchLengths(res) + result := make([]int, count) + for i := 0; i < count; i++ { + result[i] = int(*(*C.int)(unsafe.Pointer(uintptr(lengths) + uintptr(C.sizeof_int*C.int(i))))) + } + return result +} diff --git a/driver/wrapper/field_test.go b/driver/wrapper/field_test.go new file mode 100644 index 00000000..47210c13 --- /dev/null +++ b/driver/wrapper/field_test.go @@ -0,0 +1,483 @@ +package wrapper + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/errors" +) + +// @author: xftan +// @date: 2022/1/27 17:22 +// @description: test taos_fetch_lengths +func TestFetchLengths(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + res := TaosQuery(conn, "drop database if exists test_fetch_lengths") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() + res := TaosQuery(conn, "create database if not exists test_fetch_lengths") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + defer func() { + res := TaosQuery(conn, "drop database if exists test_fetch_lengths") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() + res = TaosQuery(conn, "drop table if exists test_fetch_lengths.test") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, "create table if not exists test_fetch_lengths.test (ts timestamp, c1 int,c2 binary(10),c3 nchar(10))") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, "insert into test_fetch_lengths.test values(now,1,'123','456')") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, "select * from test_fetch_lengths.test") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + count := TaosNumFields(res) + assert.Equal(t, 4, count) + _, rows := TaosFetchBlock(res) + _ = rows + lengthList := FetchLengths(res, count) + TaosFreeResult(res) + assert.Equal(t, []int{8, 4, 12, 42}, lengthList) +} + +// @author: xftan +// @date: 2022/1/27 17:23 +// @description: test result column database name +func TestRowsHeader_TypeDatabaseName(t *testing.T) { + type fields struct { + ColNames []string + ColTypes []uint8 + ColLength []int64 + } + type args struct { + i int + } + tests := []struct { + name string + fields fields + args args + want string + }{ + { + name: "NULL", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 0, + }, + want: "NULL", + }, + { + name: "BOOL", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 1, + }, + want: "BOOL", + }, + { + name: "TINYINT", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 2, + }, + want: "TINYINT", + }, + { + name: "SMALLINT", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 3, + }, + want: "SMALLINT", + }, + { + name: "INT", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 4, + }, + want: "INT", + }, + { + name: "BIGINT", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 5, + }, + want: "BIGINT", + }, + { + name: "FLOAT", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 6, + }, + want: "FLOAT", + }, + { + name: "DOUBLE", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 7, + }, + want: "DOUBLE", + }, + { + name: "BINARY", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 8, + }, + want: "VARCHAR", + }, + { + name: "TIMESTAMP", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 9, + }, + want: "TIMESTAMP", + }, + { + name: "NCHAR", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 10, + }, + want: "NCHAR", + }, + { + name: "TINYINT UNSIGNED", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 11, + }, + want: "TINYINT UNSIGNED", + }, + { + name: "SMALLINT UNSIGNED", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 12, + }, + want: "SMALLINT UNSIGNED", + }, + { + name: "INT UNSIGNED", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 13, + }, + want: "INT UNSIGNED", + }, + { + name: "BIGINT UNSIGNED", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 14, + }, + want: "BIGINT UNSIGNED", + }, + { + name: "JSON", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 15, + }, + want: "JSON", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rh := &RowsHeader{ + ColNames: tt.fields.ColNames, + ColTypes: tt.fields.ColTypes, + ColLength: tt.fields.ColLength, + } + if got := rh.TypeDatabaseName(tt.args.i); got != tt.want { + t.Errorf("TypeDatabaseName() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:23 +// @description: test scan result column type +func TestRowsHeader_ScanType(t *testing.T) { + type fields struct { + ColNames []string + ColTypes []uint8 + ColLength []int64 + } + type args struct { + i int + } + tests := []struct { + name string + fields fields + args args + want reflect.Type + }{ + { + name: "unknown", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 0, + }, + want: common.UnknownType, + }, + { + name: "BOOL", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 1, + }, + want: common.NullBool, + }, + { + name: "TINYINT", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 2, + }, + want: common.NullInt8, + }, + { + name: "SMALLINT", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 3, + }, + want: common.NullInt16, + }, { + name: "INT", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 4, + }, + want: common.NullInt32, + }, + { + name: "BIGINT", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 5, + }, + want: common.NullInt64, + }, + { + name: "FLOAT", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 6, + }, + want: common.NullFloat32, + }, + { + name: "DOUBLE", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 7, + }, + want: common.NullFloat64, + }, + { + name: "BINARY", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 8, + }, + want: common.NullString, + }, + { + name: "TIMESTAMP", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 9, + }, + want: common.NullTime, + }, + { + name: "NCHAR", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 10, + }, + want: common.NullString, + }, + { + name: "TINYINT UNSIGNED", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 11, + }, + want: common.NullUInt8, + }, + { + name: "SMALLINT UNSIGNED", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 12, + }, + want: common.NullUInt16, + }, + { + name: "INT UNSIGNED", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 13, + }, + want: common.NullUInt32, + }, + { + name: "BIGINT UNSIGNEDD", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 14, + }, + want: common.NullUInt64, + }, + { + name: "JSON", + fields: fields{ + ColTypes: []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + i: 15, + }, + want: common.NullJson, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rh := &RowsHeader{ + ColNames: tt.fields.ColNames, + ColTypes: tt.fields.ColTypes, + ColLength: tt.fields.ColLength, + } + if got := rh.ScanType(tt.args.i); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ScanType() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/driver/wrapper/notify.go b/driver/wrapper/notify.go new file mode 100644 index 00000000..21aa9ef1 --- /dev/null +++ b/driver/wrapper/notify.go @@ -0,0 +1,22 @@ +package wrapper + +/* +#include +#include +#include +#include +extern void NotifyCallback(void *param, void *ext, int type); +int taos_set_notify_cb_wrapper(TAOS *taos, void *param, int type){ + return taos_set_notify_cb(taos,NotifyCallback,param,type); +}; +*/ +import "C" +import ( + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +func TaosSetNotifyCB(taosConnect unsafe.Pointer, caller cgo.Handle, notifyType int) int32 { + return int32(C.taos_set_notify_cb_wrapper(taosConnect, caller.Pointer(), (C.int)(notifyType))) +} diff --git a/driver/wrapper/notify_test.go b/driver/wrapper/notify_test.go new file mode 100644 index 00000000..8db0cb37 --- /dev/null +++ b/driver/wrapper/notify_test.go @@ -0,0 +1,100 @@ +package wrapper + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +// @author: xftan +// @date: 2023/10/13 11:28 +// @description: test notify callback +func TestNotify(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn) + defer func() { + _ = exec(conn, "drop user t_notify") + }() + _ = exec(conn, "drop user t_notify") + err = exec(conn, "create user t_notify pass 'notify'") + assert.NoError(t, err) + + conn2, err := TaosConnect("", "t_notify", "notify", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn2) + notify := make(chan int32, 1) + handler := cgo.NewHandle(notify) + errCode := TaosSetNotifyCB(conn2, handler, common.TAOS_NOTIFY_PASSVER) + if errCode != 0 { + errStr := TaosErrorStr(nil) + t.Error(errCode, errStr) + } + notifyWhitelist := make(chan int64, 1) + handlerWhiteList := cgo.NewHandle(notifyWhitelist) + errCode = TaosSetNotifyCB(conn2, handlerWhiteList, common.TAOS_NOTIFY_WHITELIST_VER) + if errCode != 0 { + errStr := TaosErrorStr(nil) + t.Error(errCode, errStr) + } + + notifyDropUser := make(chan struct{}, 1) + handlerDropUser := cgo.NewHandle(notifyDropUser) + errCode = TaosSetNotifyCB(conn2, handlerDropUser, common.TAOS_NOTIFY_USER_DROPPED) + if errCode != 0 { + errStr := TaosErrorStr(nil) + t.Error(errCode, errStr) + } + + err = exec(conn, "alter user t_notify pass 'test'") + assert.NoError(t, err) + timeout, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + now := time.Now() + select { + case version := <-notify: + t.Log(time.Since(now)) + t.Log("password changed", version) + case <-timeout.Done(): + t.Error("wait for notify callback timeout") + } + + err = exec(conn, "ALTER USER t_notify ADD HOST '192.168.1.98/0','192.168.1.98/32'") + assert.NoError(t, err) + timeoutWhiteList, cancelWhitelist := context.WithTimeout(context.Background(), time.Second*5) + defer cancelWhitelist() + now = time.Now() + select { + case version := <-notifyWhitelist: + t.Log(time.Since(now)) + t.Log("whitelist changed", version) + case <-timeoutWhiteList.Done(): + t.Error("wait for notifyWhitelist callback timeout") + } + + err = exec(conn, "drop USER t_notify") + assert.NoError(t, err) + timeoutDropUser, cancelDropUser := context.WithTimeout(context.Background(), time.Second*5) + defer cancelDropUser() + now = time.Now() + select { + case <-notifyDropUser: + t.Log(time.Since(now)) + t.Log("user dropped") + case <-timeoutDropUser.Done(): + t.Error("wait for notifyDropUser callback timeoutDropUser") + } + +} diff --git a/driver/wrapper/notifycb.go b/driver/wrapper/notifycb.go new file mode 100644 index 00000000..8eb5c522 --- /dev/null +++ b/driver/wrapper/notifycb.go @@ -0,0 +1,36 @@ +package wrapper + +/* +#include +#include +#include +#include +*/ +import "C" +import ( + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +//export NotifyCallback +func NotifyCallback(p unsafe.Pointer, ext unsafe.Pointer, notifyType C.int) { + defer func() { + // channel may be closed + _ = recover() + }() + switch int(notifyType) { + case common.TAOS_NOTIFY_PASSVER: + version := int32(*(*C.int32_t)(ext)) + c := (*(*cgo.Handle)(p)).Value().(chan int32) + c <- version + case common.TAOS_NOTIFY_WHITELIST_VER: + version := int64(*(*C.int64_t)(ext)) + c := (*(*cgo.Handle)(p)).Value().(chan int64) + c <- version + case common.TAOS_NOTIFY_USER_DROPPED: + c := (*(*cgo.Handle)(p)).Value().(chan struct{}) + c <- struct{}{} + } +} diff --git a/driver/wrapper/row.go b/driver/wrapper/row.go new file mode 100644 index 00000000..ebe4d8cd --- /dev/null +++ b/driver/wrapper/row.go @@ -0,0 +1,77 @@ +package wrapper + +/* +#include +*/ +import "C" +import ( + "database/sql/driver" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/tools" +) + +const ( + PointerSize = unsafe.Sizeof(uintptr(1)) +) + +type FormatTimeFunc func(ts int64, precision int) driver.Value + +func FetchRow(row unsafe.Pointer, offset int, colType uint8, length int, arg ...interface{}) driver.Value { + base := *(**C.void)(tools.AddPointer(row, uintptr(offset)*PointerSize)) + p := unsafe.Pointer(base) + if p == nil { + return nil + } + switch colType { + case C.TSDB_DATA_TYPE_BOOL: + if v := *((*byte)(p)); v != 0 { + return true + } else { + return false + } + case C.TSDB_DATA_TYPE_TINYINT: + return *((*int8)(p)) + case C.TSDB_DATA_TYPE_SMALLINT: + return *((*int16)(p)) + case C.TSDB_DATA_TYPE_INT: + return *((*int32)(p)) + case C.TSDB_DATA_TYPE_BIGINT: + return *((*int64)(p)) + case C.TSDB_DATA_TYPE_UTINYINT: + return *((*uint8)(p)) + case C.TSDB_DATA_TYPE_USMALLINT: + return *((*uint16)(p)) + case C.TSDB_DATA_TYPE_UINT: + return *((*uint32)(p)) + case C.TSDB_DATA_TYPE_UBIGINT: + return *((*uint64)(p)) + case C.TSDB_DATA_TYPE_FLOAT: + return *((*float32)(p)) + case C.TSDB_DATA_TYPE_DOUBLE: + return *((*float64)(p)) + case C.TSDB_DATA_TYPE_BINARY, C.TSDB_DATA_TYPE_NCHAR: + data := make([]byte, length) + for i := 0; i < length; i++ { + data[i] = *((*byte)(tools.AddPointer(p, uintptr(i)))) + } + return string(data) + case C.TSDB_DATA_TYPE_TIMESTAMP: + if len(arg) == 1 { + return common.TimestampConvertToTime(*((*int64)(p)), arg[0].(int)) + } else if len(arg) == 2 { + return arg[1].(FormatTimeFunc)(*((*int64)(p)), arg[0].(int)) + } else { + panic("convertTime error") + } + case C.TSDB_DATA_TYPE_JSON, C.TSDB_DATA_TYPE_VARBINARY, C.TSDB_DATA_TYPE_GEOMETRY: + data := make([]byte, length) + for i := 0; i < length; i++ { + data[i] = *((*byte)(tools.AddPointer(p, uintptr(i)))) + } + return data + default: + return nil + } +} diff --git a/driver/wrapper/row_test.go b/driver/wrapper/row_test.go new file mode 100644 index 00000000..d58e25a8 --- /dev/null +++ b/driver/wrapper/row_test.go @@ -0,0 +1,628 @@ +package wrapper + +import ( + "database/sql/driver" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/errors" +) + +// @author: xftan +// @date: 2022/1/27 17:24 +// @description: test fetch json result +func TestFetchRowJSON(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn) + defer func() { + res := TaosQuery(conn, "drop database if exists test_json_wrapper") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() + res := TaosQuery(conn, "create database if not exists test_json_wrapper") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + defer func() { + res := TaosQuery(conn, "drop database if exists test_json_wrapper") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(&errors.TaosError{ + Code: int32(code) & 0xffff, + ErrStr: errStr, + }) + return + } + TaosFreeResult(res) + }() + res = TaosQuery(conn, "drop table if exists test_json_wrapper.tjsonr") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(&errors.TaosError{ + Code: int32(code) & 0xffff, + ErrStr: errStr, + }) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, "create stable if not exists test_json_wrapper.tjsonr(ts timestamp,v int )tags(t json)") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(&errors.TaosError{ + Code: int32(code) & 0xffff, + ErrStr: errStr, + }) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, `insert into test_json_wrapper.tjr_1 using test_json_wrapper.tjsonr tags('{"a":1,"b":"b"}')values (now,1)`) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(&errors.TaosError{ + Code: int32(code) & 0xffff, + ErrStr: errStr, + }) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, `insert into test_json_wrapper.tjr_2 using test_json_wrapper.tjsonr tags('{"a":1,"c":"c"}')values (now+1s,1)`) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(&errors.TaosError{ + Code: int32(code) & 0xffff, + ErrStr: errStr, + }) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, `insert into test_json_wrapper.tjr_3 using test_json_wrapper.tjsonr tags('null')values (now+2s,1)`) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(&errors.TaosError{ + Code: int32(code) & 0xffff, + ErrStr: errStr, + }) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, `select * from test_json_wrapper.tjsonr order by ts`) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(&errors.TaosError{ + Code: int32(code) & 0xffff, + ErrStr: errStr, + }) + return + } + numFields := TaosFieldCount(res) + precision := TaosResultPrecision(res) + assert.Equal(t, 3, numFields) + headers, err := ReadColumn(res, numFields) + assert.NoError(t, err) + var data [][]driver.Value + for i := 0; i < 3; i++ { + var d []driver.Value + row := TaosFetchRow(res) + lengths := FetchLengths(res, numFields) + for j := range headers.ColTypes { + d = append(d, FetchRow(row, j, headers.ColTypes[j], lengths[j], precision)) + } + data = append(data, d) + } + TaosFreeResult(res) + t.Log(data) + assert.Equal(t, `{"a":1,"b":"b"}`, string(data[0][2].([]byte))) + assert.Equal(t, `{"a":1,"c":"c"}`, string(data[1][2].([]byte))) + assert.Nil(t, data[2][2]) +} + +// @author: xftan +// @date: 2022/1/27 17:24 +// @description: test TS-781 error +func TestFetchRow(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + db := "test_ts_781" + //create stable stb1 (ts timestamp, name binary(10)) tags(n int); + //insert into tb1 using stb1 tags(1) values(now, 'log'); + //insert into tb2 using stb1 tags(2) values(now, 'test'); + //insert into tb3 using stb1 tags(3) values(now, 'db02'); + //insert into tb4 using stb1 tags(4) values(now, 'db3'); + defer func() { + res := TaosQuery(conn, "drop database if exists "+db) + code := TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + }() + res := TaosQuery(conn, "create database if not exists "+db) + code := TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("create stable if not exists %s.stb1 (ts timestamp, name binary(10)) tags(n int);", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("create table if not exists %s.tb1 using %s.stb1 tags(1)", db, db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb1 values(now, 'log');", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("create table if not exists %s.tb2 using %s.stb1 tags(2)", db, db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb2 values(now, 'test');", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("create table if not exists %s.tb3 using %s.stb1 tags(3)", db, db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb3 values(now, 'db02')", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("create table if not exists %s.tb4 using %s.stb1 tags(4)", db, db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb4 values(now, 'db3');", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("select distinct(name) from %s.stb1;", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + numFields := TaosFieldCount(res) + header, err := ReadColumn(res, numFields) + if err != nil { + TaosFreeResult(res) + t.Error(err) + return + } + names := map[string]struct{}{ + "log": {}, + "test": {}, + "db02": {}, + "db3": {}, + } + for { + rr := TaosFetchRow(res) + lengths := FetchLengths(res, numFields) + if rr == nil { + break + } + d := FetchRow(rr, 0, header.ColTypes[0], lengths[0]) + delete(names, d.(string)) + } + TaosFreeResult(res) + + assert.Empty(t, names) +} + +// @author: xftan +// @date: 2022/1/27 17:24 +// @description: test TS-781 nchar type error +func TestFetchRowNchar(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + db := "test_ts_781_nchar" + //create stable stb1 (ts timestamp, name nchar(10)) tags(n int); + //insert into tb1 using stb1 tags(1) values(now, 'log'); + //insert into tb2 using stb1 tags(2) values(now, 'test'); + //insert into tb3 using stb1 tags(3) values(now, 'db02'); + //insert into tb4 using stb1 tags(4) values(now, 'db3'); + defer func() { + res := TaosQuery(conn, "drop database if exists "+db) + code := TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + }() + res := TaosQuery(conn, "create database if not exists "+db) + code := TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("create stable if not exists %s.stb1 (ts timestamp, name nchar(10)) tags(n int);", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, fmt.Sprintf("create table if not exists %s.tb1 using %s.stb1 tags(1)", db, db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, fmt.Sprintf("create table if not exists %s.tb2 using %s.stb1 tags(2)", db, db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, fmt.Sprintf("create table if not exists %s.tb3 using %s.stb1 tags(3)", db, db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, fmt.Sprintf("create table if not exists %s.tb4 using %s.stb1 tags(4)", db, db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb1 values(now, 'log');", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb2 values(now, 'test');", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb3 values(now, 'db02')", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb4 values(now, 'db3');", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf("select distinct(name) from %s.stb1;", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + numFields := TaosFieldCount(res) + header, err := ReadColumn(res, numFields) + if err != nil { + TaosFreeResult(res) + t.Error(err) + return + } + names := map[string]struct{}{ + "log": {}, + "test": {}, + "db02": {}, + "db3": {}, + } + for { + rr := TaosFetchRow(res) + lengths := FetchLengths(res, numFields) + if rr == nil { + break + } + d := FetchRow(rr, 0, header.ColTypes[0], lengths[0]) + delete(names, d.(string)) + } + TaosFreeResult(res) + assert.Empty(t, names) +} + +// @author: xftan +// @date: 2023/10/13 11:28 +// @description: test fetch row all type +func TestFetchRowAllType(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + db := "test_fetch_row_all" + + res := TaosQuery(conn, "drop database if exists "+db) + code := TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + defer func() { + res := TaosQuery(conn, "drop database if exists "+db) + code := TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + }() + res = TaosQuery(conn, "create database if not exists "+db) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, fmt.Sprintf( + "create stable if not exists %s.stb1 (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20),"+ + "c14 varbinary(20),"+ + "c15 geometry(100)"+ + ")"+ + "tags(t json)", db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, fmt.Sprintf("create table if not exists %s.tb1 using %s.stb1 tags('{\"a\":1}')", db, db)) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + now := time.Now() + res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb1 values('%s',true,2,3,4,5,6,7,8,9,10,11,'binary','nchar','varbinary','POINT(100 100)');", db, now.Format(time.RFC3339Nano))) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, fmt.Sprintf("select * from %s.stb1 where ts = '%s';", db, now.Format(time.RFC3339Nano))) + code = TaosError(res) + if code != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + numFields := TaosFieldCount(res) + header, err := ReadColumn(res, numFields) + if err != nil { + TaosFreeResult(res) + t.Error(err) + return + } + precision := TaosResultPrecision(res) + count := 0 + result := make([]driver.Value, numFields) + for { + rr := TaosFetchRow(res) + if rr == nil { + break + } + count += 1 + lengths := FetchLengths(res, numFields) + + for i := range header.ColTypes { + result[i] = FetchRow(rr, i, header.ColTypes[i], lengths[i], precision) + } + } + assert.Equal(t, 1, count) + assert.Equal(t, now.UnixNano()/1e6, result[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, result[1].(bool)) + assert.Equal(t, int8(2), result[2].(int8)) + assert.Equal(t, int16(3), result[3].(int16)) + assert.Equal(t, int32(4), result[4].(int32)) + assert.Equal(t, int64(5), result[5].(int64)) + assert.Equal(t, uint8(6), result[6].(uint8)) + assert.Equal(t, uint16(7), result[7].(uint16)) + assert.Equal(t, uint32(8), result[8].(uint32)) + assert.Equal(t, uint64(9), result[9].(uint64)) + assert.Equal(t, float32(10), result[10].(float32)) + assert.Equal(t, float64(11), result[11].(float64)) + assert.Equal(t, "binary", result[12].(string)) + assert.Equal(t, "nchar", result[13].(string)) + assert.Equal(t, []byte("varbinary"), result[14].([]byte)) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, result[15].([]byte)) + assert.Equal(t, []byte(`{"a":1}`), result[16].([]byte)) +} diff --git a/driver/wrapper/schemaless.go b/driver/wrapper/schemaless.go new file mode 100644 index 00000000..4c81d618 --- /dev/null +++ b/driver/wrapper/schemaless.go @@ -0,0 +1,233 @@ +package wrapper + +/* +#include +#include +#include +#include +*/ +import "C" +import "unsafe" + +//revive:disable +const ( + InfluxDBLineProtocol = 1 + OpenTSDBTelnetLineProtocol = 2 + OpenTSDBJsonFormatProtocol = 3 +) +const ( + TSDB_SML_TIMESTAMP_NOT_CONFIGURED = iota + TSDB_SML_TIMESTAMP_HOURS + TSDB_SML_TIMESTAMP_MINUTES + TSDB_SML_TIMESTAMP_SECONDS + TSDB_SML_TIMESTAMP_MILLI_SECONDS + TSDB_SML_TIMESTAMP_MICRO_SECONDS + TSDB_SML_TIMESTAMP_NANO_SECONDS +) + +//revive:enable + +// TaosSchemalessInsert TAOS_RES *taos_schemaless_insert(TAOS* taos, char* lines[], int numLines, int protocol, int precision); +// Deprecated +func TaosSchemalessInsert(taosConnect unsafe.Pointer, lines []string, protocol int, precision string) unsafe.Pointer { + numLines, cLines, needFree := taosSchemalessInsertParams(lines) + defer func() { + for _, p := range needFree { + C.free(p) + } + }() + return unsafe.Pointer(C.taos_schemaless_insert( + taosConnect, + (**C.char)(&cLines[0]), + (C.int)(numLines), + (C.int)(protocol), + (C.int)(exchange(precision)), + )) +} + +// TaosSchemalessInsertTTL TAOS_RES *taos_schemaless_insert_ttl(TAOS *taos, char *lines[], int numLines, int protocol, int precision, int32_t ttl) +// Deprecated +func TaosSchemalessInsertTTL(taosConnect unsafe.Pointer, lines []string, protocol int, precision string, ttl int) unsafe.Pointer { + numLines, cLines, needFree := taosSchemalessInsertParams(lines) + defer func() { + for _, p := range needFree { + C.free(p) + } + }() + return unsafe.Pointer(C.taos_schemaless_insert_ttl( + taosConnect, + (**C.char)(&cLines[0]), + (C.int)(numLines), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int32_t)(ttl), + )) +} + +// TaosSchemalessInsertWithReqID TAOS_RES *taos_schemaless_insert_with_reqid(TAOS *taos, char *lines[], int numLines, int protocol, int precision, int64_t reqid); +// Deprecated +func TaosSchemalessInsertWithReqID(taosConnect unsafe.Pointer, lines []string, protocol int, precision string, reqID int64) unsafe.Pointer { + numLines, cLines, needFree := taosSchemalessInsertParams(lines) + defer func() { + for _, p := range needFree { + C.free(p) + } + }() + return unsafe.Pointer(C.taos_schemaless_insert_with_reqid( + taosConnect, + (**C.char)(&cLines[0]), + (C.int)(numLines), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int64_t)(reqID), + )) +} + +// TaosSchemalessInsertTTLWithReqID TAOS_RES *taos_schemaless_insert_ttl_with_reqid(TAOS *taos, char *lines[], int numLines, int protocol, int precision, int32_t ttl, int64_t reqid) +// Deprecated +func TaosSchemalessInsertTTLWithReqID(taosConnect unsafe.Pointer, lines []string, protocol int, precision string, ttl int, reqID int64) unsafe.Pointer { + numLines, cLines, needFree := taosSchemalessInsertParams(lines) + defer func() { + for _, p := range needFree { + C.free(p) + } + }() + return unsafe.Pointer(C.taos_schemaless_insert_ttl_with_reqid( + taosConnect, + (**C.char)(&cLines[0]), + (C.int)(numLines), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int32_t)(ttl), + (C.int64_t)(reqID), + )) +} + +func taosSchemalessInsertParams(lines []string) (numLines int, cLines []*C.char, needFree []unsafe.Pointer) { + numLines = len(lines) + cLines = make([]*C.char, numLines) + needFree = make([]unsafe.Pointer, numLines) + for i, line := range lines { + cLine := C.CString(line) + needFree[i] = unsafe.Pointer(cLine) + cLines[i] = cLine + } + return +} + +// TaosSchemalessInsertRaw TAOS_RES *taos_schemaless_insert_raw(TAOS* taos, char* lines, int len, int32_t *totalRows, int protocol, int precision); +func TaosSchemalessInsertRaw(taosConnect unsafe.Pointer, lines string, protocol int, precision string) (int32, unsafe.Pointer) { + cLine := C.CString(lines) + defer C.free(unsafe.Pointer(cLine)) + var rows int32 + pTotalRows := unsafe.Pointer(&rows) + result := unsafe.Pointer(C.taos_schemaless_insert_raw( + taosConnect, + cLine, + (C.int)(len(lines)), + (*C.int32_t)(pTotalRows), + (C.int)(protocol), + (C.int)(exchange(precision)), + )) + return rows, result +} + +// TaosSchemalessInsertRawTTL TAOS_RES *taos_schemaless_insert_raw_ttl(TAOS *taos, char *lines, int len, int32_t *totalRows, int protocol, int precision, int32_t ttl); +func TaosSchemalessInsertRawTTL(taosConnect unsafe.Pointer, lines string, protocol int, precision string, ttl int) (int32, unsafe.Pointer) { + cLine := C.CString(lines) + defer C.free(unsafe.Pointer(cLine)) + var rows int32 + pTotalRows := unsafe.Pointer(&rows) + result := unsafe.Pointer(C.taos_schemaless_insert_raw_ttl( + taosConnect, + cLine, + (C.int)(len(lines)), + (*C.int32_t)(pTotalRows), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int32_t)(ttl), + )) + return rows, result +} + +// TaosSchemalessInsertRawWithReqID TAOS_RES *taos_schemaless_insert_raw_with_reqid(TAOS *taos, char *lines, int len, int32_t *totalRows, int protocol, int precision, int64_t reqid); +func TaosSchemalessInsertRawWithReqID(taosConnect unsafe.Pointer, lines string, protocol int, precision string, reqID int64) (int32, unsafe.Pointer) { + cLine := C.CString(lines) + defer C.free(unsafe.Pointer(cLine)) + var rows int32 + pTotalRows := unsafe.Pointer(&rows) + result := unsafe.Pointer(C.taos_schemaless_insert_raw_with_reqid( + taosConnect, + cLine, + (C.int)(len(lines)), + (*C.int32_t)(pTotalRows), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int64_t)(reqID), + )) + return rows, result +} + +// TaosSchemalessInsertRawTTLWithReqID TAOS_RES *taos_schemaless_insert_raw_ttl_with_reqid(TAOS *taos, char *lines, int len, int32_t *totalRows, int protocol, int precision, int32_t ttl, int64_t reqid) +func TaosSchemalessInsertRawTTLWithReqID(taosConnect unsafe.Pointer, lines string, protocol int, precision string, ttl int, reqID int64) (int32, unsafe.Pointer) { + cLine := C.CString(lines) + defer C.free(unsafe.Pointer(cLine)) + var rows int32 + pTotalRows := unsafe.Pointer(&rows) + result := C.taos_schemaless_insert_raw_ttl_with_reqid( + taosConnect, + cLine, + (C.int)(len(lines)), + (*C.int32_t)(pTotalRows), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int32_t)(ttl), + (C.int64_t)(reqID), + ) + return rows, result +} + +// TaosSchemalessInsertRawTTLWithReqIDTBNameKey TAOS_RES *taos_schemaless_insert_raw_ttl_with_reqid_tbname_key(TAOS *taos, char *lines, int len, int32_t *totalRows, int protocol, int precision, int32_t ttl, int64_t reqid, char *tbnameKey); +func TaosSchemalessInsertRawTTLWithReqIDTBNameKey(taosConnect unsafe.Pointer, lines string, protocol int, precision string, ttl int, reqID int64, tbNameKey string) (int32, unsafe.Pointer) { + cLine := C.CString(lines) + defer C.free(unsafe.Pointer(cLine)) + cTBNameKey := (*C.char)(nil) + if tbNameKey != "" { + cTBNameKey = C.CString(tbNameKey) + defer C.free(unsafe.Pointer(cTBNameKey)) + } + var rows int32 + pTotalRows := unsafe.Pointer(&rows) + result := C.taos_schemaless_insert_raw_ttl_with_reqid_tbname_key( + taosConnect, + cLine, + (C.int)(len(lines)), + (*C.int32_t)(pTotalRows), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int32_t)(ttl), + (C.int64_t)(reqID), + cTBNameKey, + ) + return rows, result +} + +func exchange(ts string) int { + switch ts { + case "": + return TSDB_SML_TIMESTAMP_NOT_CONFIGURED + case "h": + return TSDB_SML_TIMESTAMP_HOURS + case "m": + return TSDB_SML_TIMESTAMP_MINUTES + case "s": + return TSDB_SML_TIMESTAMP_SECONDS + case "ms": + return TSDB_SML_TIMESTAMP_MILLI_SECONDS + case "u", "μ": + return TSDB_SML_TIMESTAMP_MICRO_SECONDS + case "ns": + return TSDB_SML_TIMESTAMP_NANO_SECONDS + } + return TSDB_SML_TIMESTAMP_NOT_CONFIGURED +} diff --git a/driver/wrapper/schemaless_test.go b/driver/wrapper/schemaless_test.go new file mode 100644 index 00000000..b47ba8f4 --- /dev/null +++ b/driver/wrapper/schemaless_test.go @@ -0,0 +1,744 @@ +package wrapper_test + +import ( + "strings" + "testing" + "time" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" +) + +func prepareEnv() unsafe.Pointer { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + panic(err) + } + res := wrapper.TaosQuery(conn, "create database if not exists test_schemaless_common") + if wrapper.TaosError(res) != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + panic(errStr) + } + wrapper.TaosFreeResult(res) + code := wrapper.TaosSelectDB(conn, "test_schemaless_common") + if code != 0 { + panic("use db test_schemaless_common fail") + } + return conn +} + +func cleanEnv(conn unsafe.Pointer) { + res := wrapper.TaosQuery(conn, "drop database if exists test_schemaless_common") + if wrapper.TaosError(res) != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + panic(errStr) + } + wrapper.TaosFreeResult(res) +} + +func BenchmarkTelnetSchemaless(b *testing.B) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + for i := 0; i < b.N; i++ { + result := wrapper.TaosSchemalessInsert(conn, []string{ + "sys_if_bytes_out 1636626444 1.3E3 host=web01 interface=eth0", + }, wrapper.OpenTSDBTelnetLineProtocol, "") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + b.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + } +} + +// @author: xftan +// @date: 2022/1/27 17:26 +// @description: test schemaless opentsdb telnet +func TestSchemalessTelnet(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + result := wrapper.TaosSchemalessInsert(conn, []string{ + "sys_if_bytes_out 1636626444 1.3E3 host=web01 interface=eth0", + }, wrapper.OpenTSDBTelnetLineProtocol, "") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + s := time.Now() + result = wrapper.TaosSchemalessInsert(conn, []string{ + "sys_if_bytes_out 1636626444 1.3E3 host=web01 interface=eth0", + }, wrapper.OpenTSDBTelnetLineProtocol, "") + code = wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + t.Log("finish ", time.Since(s)) +} + +// @author: xftan +// @date: 2022/1/27 17:26 +// @description: test schemaless influxDB +func TestSchemalessInfluxDB(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + { + result := wrapper.TaosSchemalessInsert(conn, []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577836800000000000", + }, wrapper.InfluxDBLineProtocol, "") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + } + { + result := wrapper.TaosSchemalessInsert(conn, []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577836800000000000", + }, wrapper.InfluxDBLineProtocol, "ns") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + } + { + result := wrapper.TaosSchemalessInsert(conn, []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577836800000000", + }, wrapper.InfluxDBLineProtocol, "u") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + } + { + result := wrapper.TaosSchemalessInsert(conn, []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577836800000000", + }, wrapper.InfluxDBLineProtocol, "μ") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + } + { + result := wrapper.TaosSchemalessInsert(conn, []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577836800000", + }, wrapper.InfluxDBLineProtocol, "ms") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + } + { + result := wrapper.TaosSchemalessInsert(conn, []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577836800", + }, wrapper.InfluxDBLineProtocol, "s") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + } + { + result := wrapper.TaosSchemalessInsert(conn, []string{ + "measurement,host=host1 field1=2i,field2=2.0 26297280", + }, wrapper.InfluxDBLineProtocol, "m") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + } + { + result := wrapper.TaosSchemalessInsert(conn, []string{ + "measurement,host=host1 field1=2i,field2=2.0 438288", + }, wrapper.InfluxDBLineProtocol, "h") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(result) + } +} + +// @author: xftan +// @date: 2023/10/13 11:28 +// @description: test schemaless insert with opentsdb telnet line protocol +func TestSchemalessRawTelnet(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + type in struct { + rows []string + } + data := []in{ + { + rows: []string{"sys_if_bytes_out 1636626444 1.3E3 host=web01 interface=eth0"}, + }, + { + rows: []string{"sys_if_bytes_out 1636626444 1.3E3 host=web01 interface=eth0"}, + }, + } + for _, d := range data { + row := strings.Join(d.rows, "\n") + totalRows, result := wrapper.TaosSchemalessInsertRaw(conn, row, wrapper.OpenTSDBTelnetLineProtocol, "") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Log(row) + t.Error(errors.NewError(code, errStr)) + return + } + if int(totalRows) != len(d.rows) { + t.Log(row) + t.Errorf("expect rows %d got %d", len(d.rows), totalRows) + } + affected := wrapper.TaosAffectedRows(result) + if affected != len(d.rows) { + t.Log(row) + t.Errorf("expect affected %d got %d", len(d.rows), affected) + } + wrapper.TaosFreeResult(result) + } +} + +// @author: xftan +// @date: 2023/10/13 11:29 +// @description: test schemaless insert with opentsdb telnet line protocol +func TestSchemalessRawInfluxDB(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + type in struct { + rows []string + precision string + } + data := []in{ + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000000"}, + precision: "", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000000"}, + precision: "ns", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000"}, + precision: "u", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000"}, + precision: "μ", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000"}, + precision: "ms", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800"}, + precision: "s", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 26297280"}, + precision: "m", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 438288"}, + precision: "h", + }, + { + rows: []string{"cpu_value,host=xyzzy,instance=0,type=cpu,type_instance=user value=63843347 1665212955372077566\n"}, + precision: "ns", + }, + } + for _, d := range data { + row := strings.Join(d.rows, "\n") + totalRows, result := wrapper.TaosSchemalessInsertRaw(conn, row, wrapper.InfluxDBLineProtocol, d.precision) + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Log(row) + t.Error(errors.NewError(code, errStr)) + return + } + if int(totalRows) != len(d.rows) { + t.Log(row) + t.Errorf("expect rows %d got %d", len(d.rows), totalRows) + } + affected := wrapper.TaosAffectedRows(result) + if affected != len(d.rows) { + t.Log(row) + t.Errorf("expect affected %d got %d", len(d.rows), affected) + } + wrapper.TaosFreeResult(result) + } +} + +// @author: xftan +// @date: 2023/10/13 11:29 +// @description: test schemaless insert raw with reqid +func TestTaosSchemalessInsertRawWithReqID(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + row string + rows int32 + precision string + reqID int64 + }{ + { + name: "1", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836800000000000", + rows: 1, + precision: "", + reqID: 1, + }, + { + name: "2", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836900000000000", + rows: 1, + precision: "ns", + reqID: 2, + }, + { + name: "3", + row: "measurement,host=host1 field1=2i,field2=2.0 1577837000000000", + rows: 1, + precision: "u", + reqID: 3, + }, + { + name: "4", + row: "measurement,host=host1 field1=2i,field2=2.0 1577837100000000", + rows: 1, + precision: "μ", + reqID: 4, + }, + { + name: "5", + row: "measurement,host=host1 field1=2i,field2=2.0 1577837200000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + rows: 2, + precision: "ms", + reqID: 5, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + rows, result := wrapper.TaosSchemalessInsertRawWithReqID(conn, c.row, wrapper.InfluxDBLineProtocol, c.precision, c.reqID) + if rows != c.rows { + t.Fatal("rows miss") + } + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +// @author: xftan +// @date: 2023/10/13 11:29 +// @description: test schemaless insert with reqid +func TestTaosSchemalessInsertWithReqID(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + rows []string + precision string + reqID int64 + }{ + { + name: "1", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000000"}, + precision: "", + reqID: 1, + }, + { + name: "2", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836900000000000"}, + precision: "ns", + reqID: 2, + }, + { + name: "3", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577837000000000"}, + precision: "u", + reqID: 3, + }, + { + name: "4", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577837100000000"}, + precision: "μ", + reqID: 4, + }, + { + name: "5", + rows: []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577837200000", + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + }, + precision: "ms", + reqID: 5, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := wrapper.TaosSchemalessInsertWithReqID(conn, c.rows, wrapper.InfluxDBLineProtocol, c.precision, c.reqID) + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +// @author: xftan +// @date: 2023/10/13 11:29 +// @description: test schemaless insert with ttl +func TestTaosSchemalessInsertTTL(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + rows []string + precision string + ttl int + }{ + { + name: "1", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000000"}, + precision: "", + ttl: 1000, + }, + { + name: "2", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836900000000000"}, + precision: "ns", + ttl: 1200, + }, + { + name: "3", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577837100000000"}, + precision: "μ", + ttl: 1400, + }, + { + name: "4", + rows: []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577837200000", + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + }, + precision: "ms", + ttl: 1600, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := wrapper.TaosSchemalessInsertTTL(conn, c.rows, wrapper.InfluxDBLineProtocol, c.precision, c.ttl) + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test schemaless insert with ttl and reqid +func TestTaosSchemalessInsertTTLWithReqID(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + rows []string + precision string + ttl int + reqId int64 + }{ + { + name: "1", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000000"}, + precision: "", + ttl: 1000, + reqId: 1, + }, + { + name: "2", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836900000000000"}, + precision: "ns", + ttl: 1200, + reqId: 2, + }, + { + name: "3", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577837100000000"}, + precision: "μ", + ttl: 1400, + reqId: 3, + }, + { + name: "4", + rows: []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577837200000", + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + }, + precision: "ms", + ttl: 1600, + reqId: 4, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := wrapper.TaosSchemalessInsertTTLWithReqID(conn, c.rows, wrapper.InfluxDBLineProtocol, c.precision, c.ttl, c.reqId) + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test schemaless insert raw with ttl +func TestTaosSchemalessInsertRawTTL(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + row string + rows int32 + precision string + ttl int + }{ + { + name: "1", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836800000000000", + rows: 1, + precision: "", + ttl: 1000, + }, + { + name: "2", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836900000000000", + rows: 1, + precision: "ns", + ttl: 1200, + }, + { + name: "3", + row: "measurement,host=host1 field1=2i,field2=2.0 1577837200000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + rows: 2, + precision: "ms", + ttl: 1400, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + rows, result := wrapper.TaosSchemalessInsertRawTTL(conn, c.row, wrapper.InfluxDBLineProtocol, c.precision, c.ttl) + if rows != c.rows { + t.Fatal("rows miss") + } + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test schemaless insert raw with ttl and reqid +func TestTaosSchemalessInsertRawTTLWithReqID(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + row string + rows int32 + precision string + ttl int + reqID int64 + }{ + { + name: "1", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836800000000000", + rows: 1, + precision: "", + ttl: 1000, + reqID: 1, + }, + { + name: "2", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836900000000000", + rows: 1, + precision: "ns", + ttl: 1200, + reqID: 2, + }, + { + name: "3", + row: "measurement,host=host1 field1=2i,field2=2.0 1577837200000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + rows: 2, + precision: "ms", + ttl: 1400, + reqID: 3, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + rows, result := wrapper.TaosSchemalessInsertRawTTLWithReqID(conn, c.row, wrapper.InfluxDBLineProtocol, c.precision, c.ttl, c.reqID) + if rows != c.rows { + t.Fatal("rows miss") + } + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +func TestTaosSchemalessInsertRawTTLWithReqIDTBNameKey(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + //defer cleanEnv(conn) + cases := []struct { + name string + row string + rows int32 + precision string + ttl int + reqID int64 + tbNameKey string + }{ + { + name: "1", + row: "measurement,host=host1 field1=2i,field2=1.0 1577836800000000000", + rows: 1, + precision: "", + ttl: 1000, + reqID: 1, + tbNameKey: "host", + }, + { + name: "2", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836900000000000", + rows: 1, + precision: "ns", + ttl: 1200, + reqID: 2, + tbNameKey: "host", + }, + { + name: "3", + row: "measurement,host=host1 field1=2i,field2=3.0 1577837200000\n" + + "measurement,host=host1 field1=2i,field2=4.0 1577837300000", + rows: 2, + precision: "ms", + ttl: 1400, + reqID: 3, + tbNameKey: "host", + }, + { + name: "no table name key", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836900000000000", + rows: 1, + precision: "ns", + ttl: 1200, + reqID: 2, + tbNameKey: "", + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + rows, result := wrapper.TaosSchemalessInsertRawTTLWithReqIDTBNameKey(conn, c.row, wrapper.InfluxDBLineProtocol, c.precision, c.ttl, c.reqID, c.tbNameKey) + if rows != c.rows { + t.Fatal("rows miss") + } + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} diff --git a/driver/wrapper/setconfig.go b/driver/wrapper/setconfig.go new file mode 100644 index 00000000..93db562a --- /dev/null +++ b/driver/wrapper/setconfig.go @@ -0,0 +1,42 @@ +package wrapper + +/* +#include +#include +#include +#include +*/ +import "C" +import ( + "strings" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/errors" +) + +// TaosSetConfig int taos_set_config(const char *config); +func TaosSetConfig(params map[string]string) error { + if len(params) == 0 { + return nil + } + buf := &strings.Builder{} + for k, v := range params { + buf.WriteString(k) + buf.WriteString(" ") + buf.WriteString(v) + } + cConfig := C.CString(buf.String()) + defer C.free(unsafe.Pointer(cConfig)) + result := (C.struct_setConfRet)(C.taos_set_config(cConfig)) + if int(result.retCode) == -5 || int(result.retCode) == 0 { + return nil + } + buf.Reset() + for _, c := range result.retMsg { + if c == 0 { + break + } + buf.WriteByte(byte(c)) + } + return &errors.TaosError{Code: int32(result.retCode) & 0xffff, ErrStr: buf.String()} +} diff --git a/driver/wrapper/setconfig_test.go b/driver/wrapper/setconfig_test.go new file mode 100644 index 00000000..5589e0dc --- /dev/null +++ b/driver/wrapper/setconfig_test.go @@ -0,0 +1,41 @@ +package wrapper + +import ( + "testing" +) + +// @author: xftan +// @date: 2022/1/27 17:27 +// @description: test taos_set_config +func TestSetConfig(t *testing.T) { + source := map[string]string{ + "numOfThreadsPerCore": "1.000000", + "rpcTimer": "300", + "rpcForceTcp": "0", + "rpcMaxTime": "600", + "compressMsgSize": "-1", + "maxSQLLength": "1048576", + "maxWildCardsLength": "100", + "maxNumOfOrderedRes": "100000", + "keepColumnName": "0", + "timezone": "Asia/Shanghai (CST, +0800)", + "locale": "C.UTF-8", + "charset": "UTF-8", + "numOfLogLines": "10000000", + "asyncLog": "1", + "debugFlag": "135", + "rpcDebugFlag": "131", + "tmrDebugFlag": "131", + "cDebugFlag": "131", + "jniDebugFlag": "131", + "odbcDebugFlag": "131", + "uDebugFlag": "131", + "qDebugFlag": "131", + "maxBinaryDisplayWidth": "30", + "tempDir": "/tmp/", + } + err := TaosSetConfig(source) + if err != nil { + t.Error(err) + } +} diff --git a/driver/wrapper/stmt.go b/driver/wrapper/stmt.go new file mode 100644 index 00000000..734cc33f --- /dev/null +++ b/driver/wrapper/stmt.go @@ -0,0 +1,756 @@ +package wrapper + +/* +#include +#include +#include +#include +*/ +import "C" +import ( + "bytes" + "database/sql/driver" + "errors" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/stmt" + taosError "github.com/taosdata/taosadapter/v3/driver/errors" + taosTypes "github.com/taosdata/taosadapter/v3/driver/types" +) + +// TaosStmtInit TAOS_STMT *taos_stmt_init(TAOS *taos); +func TaosStmtInit(taosConnect unsafe.Pointer) unsafe.Pointer { + return C.taos_stmt_init(taosConnect) +} + +// TaosStmtInitWithReqID TAOS_STMT *taos_stmt_init_with_reqid(TAOS *taos, int64_t reqid); +func TaosStmtInitWithReqID(taosConn unsafe.Pointer, reqID int64) unsafe.Pointer { + return C.taos_stmt_init_with_reqid(taosConn, (C.int64_t)(reqID)) +} + +// TaosStmtPrepare int taos_stmt_prepare(TAOS_STMT *stmt, const char *sql, unsigned long length); +func TaosStmtPrepare(stmt unsafe.Pointer, sql string) int { + cSql := C.CString(sql) + cLen := C.ulong(len(sql)) + defer C.free(unsafe.Pointer(cSql)) + return int(C.taos_stmt_prepare(stmt, cSql, cLen)) +} + +//typedef struct TAOS_MULTI_BIND { +//int buffer_type; +//void *buffer; +//int32_t buffer_length; +//int32_t *length; +//char *is_null; +//int num; +//} TAOS_MULTI_BIND; + +// TaosStmtSetTags int taos_stmt_set_tags(TAOS_STMT *stmt, TAOS_MULTI_BIND *tags); +func TaosStmtSetTags(stmt unsafe.Pointer, tags []driver.Value) int { + if len(tags) == 0 { + return int(C.taos_stmt_set_tags(stmt, nil)) + } + binds, needFreePointer, err := generateTaosBindList(tags) + defer func() { + for _, pointer := range needFreePointer { + C.free(pointer) + } + }() + if err != nil { + return -1 + } + result := int(C.taos_stmt_set_tags(stmt, (*C.TAOS_MULTI_BIND)(&binds[0]))) + return result +} + +// TaosStmtSetTBNameTags int taos_stmt_set_tbname_tags(TAOS_STMT* stmt, const char* name, TAOS_MULTI_BIND* tags); +func TaosStmtSetTBNameTags(stmt unsafe.Pointer, name string, tags []driver.Value) int { + cStr := C.CString(name) + defer C.free(unsafe.Pointer(cStr)) + if len(tags) == 0 { + return int(C.taos_stmt_set_tbname_tags(stmt, cStr, nil)) + } + binds, needFreePointer, err := generateTaosBindList(tags) + defer func() { + for _, pointer := range needFreePointer { + C.free(pointer) + } + }() + if err != nil { + return -1 + } + result := int(C.taos_stmt_set_tbname_tags(stmt, cStr, (*C.TAOS_MULTI_BIND)(&binds[0]))) + return result +} + +// TaosStmtSetTBName int taos_stmt_set_tbname(TAOS_STMT* stmt, const char* name); +func TaosStmtSetTBName(stmt unsafe.Pointer, name string) int { + cStr := C.CString(name) + defer C.free(unsafe.Pointer(cStr)) + return int(C.taos_stmt_set_tbname(stmt, cStr)) +} + +// TaosStmtIsInsert int taos_stmt_is_insert(TAOS_STMT *stmt, int *insert); +func TaosStmtIsInsert(stmt unsafe.Pointer) (is bool, errorCode int) { + p := C.malloc(C.size_t(4)) + isInsert := (*C.int)(p) + defer C.free(p) + errorCode = int(C.taos_stmt_is_insert(stmt, isInsert)) + return int(*isInsert) == 1, errorCode +} + +// TaosStmtNumParams int taos_stmt_num_params(TAOS_STMT *stmt, int *nums); +func TaosStmtNumParams(stmt unsafe.Pointer) (count int, errorCode int) { + p := C.malloc(C.size_t(4)) + num := (*C.int)(p) + defer C.free(p) + errorCode = int(C.taos_stmt_num_params(stmt, num)) + return int(*num), errorCode +} + +// TaosStmtBindParam int taos_stmt_bind_param(TAOS_STMT *stmt, TAOS_MULTI_BIND *bind); +func TaosStmtBindParam(stmt unsafe.Pointer, params []driver.Value) int { + if len(params) == 0 { + return int(C.taos_stmt_bind_param(stmt, nil)) + } + binds, needFreePointer, err := generateTaosBindList(params) + defer func() { + for _, pointer := range needFreePointer { + if pointer != nil { + C.free(pointer) + } + } + }() + if err != nil { + return -1 + } + result := int(C.taos_stmt_bind_param(stmt, (*C.TAOS_MULTI_BIND)(unsafe.Pointer(&binds[0])))) + return result +} + +func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe.Pointer, error) { + binds := make([]C.TAOS_MULTI_BIND, len(params)) + var needFreePointer []unsafe.Pointer + for i, param := range params { + bind := C.TAOS_MULTI_BIND{} + bind.num = C.int(1) + if param == nil { + bind.buffer_type = C.TSDB_DATA_TYPE_BOOL + p := C.malloc(1) + *(*C.char)(p) = C.char(1) + needFreePointer = append(needFreePointer, p) + bind.is_null = (*C.char)(p) + } else { + switch value := param.(type) { + case taosTypes.TaosBool: + bind.buffer_type = C.TSDB_DATA_TYPE_BOOL + p := C.malloc(1) + if value { + *(*C.int8_t)(p) = C.int8_t(1) + } else { + *(*C.int8_t)(p) = C.int8_t(0) + } + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.buffer_length = C.uintptr_t(1) + case taosTypes.TaosTinyint: + bind.buffer_type = C.TSDB_DATA_TYPE_TINYINT + p := C.malloc(1) + *(*C.int8_t)(p) = C.int8_t(value) + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.buffer_length = C.uintptr_t(1) + case taosTypes.TaosSmallint: + bind.buffer_type = C.TSDB_DATA_TYPE_SMALLINT + p := C.malloc(2) + *(*C.int16_t)(p) = C.int16_t(value) + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.buffer_length = C.uintptr_t(2) + case taosTypes.TaosInt: + bind.buffer_type = C.TSDB_DATA_TYPE_INT + p := C.malloc(4) + *(*C.int32_t)(p) = C.int32_t(value) + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.buffer_length = C.uintptr_t(4) + case taosTypes.TaosBigint: + bind.buffer_type = C.TSDB_DATA_TYPE_BIGINT + p := C.malloc(8) + *(*C.int64_t)(p) = C.int64_t(value) + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.buffer_length = C.uintptr_t(8) + case taosTypes.TaosUTinyint: + bind.buffer_type = C.TSDB_DATA_TYPE_UTINYINT + cbuf := C.malloc(1) + *(*C.uint8_t)(cbuf) = C.uint8_t(value) + needFreePointer = append(needFreePointer, cbuf) + bind.buffer = cbuf + bind.buffer_length = C.uintptr_t(1) + case taosTypes.TaosUSmallint: + bind.buffer_type = C.TSDB_DATA_TYPE_USMALLINT + p := C.malloc(2) + *(*C.uint16_t)(p) = C.uint16_t(value) + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.buffer_length = C.uintptr_t(2) + case taosTypes.TaosUInt: + bind.buffer_type = C.TSDB_DATA_TYPE_UINT + p := C.malloc(4) + *(*C.uint32_t)(p) = C.uint32_t(value) + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.buffer_length = C.uintptr_t(4) + case taosTypes.TaosUBigint: + bind.buffer_type = C.TSDB_DATA_TYPE_UBIGINT + p := C.malloc(8) + *(*C.uint64_t)(p) = C.uint64_t(value) + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.buffer_length = C.uintptr_t(8) + case taosTypes.TaosFloat: + bind.buffer_type = C.TSDB_DATA_TYPE_FLOAT + p := C.malloc(4) + *(*C.float)(p) = C.float(value) + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.buffer_length = C.uintptr_t(4) + case taosTypes.TaosDouble: + bind.buffer_type = C.TSDB_DATA_TYPE_DOUBLE + p := C.malloc(8) + *(*C.double)(p) = C.double(value) + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.buffer_length = C.uintptr_t(8) + case taosTypes.TaosBinary: + bind.buffer_type = C.TSDB_DATA_TYPE_BINARY + cbuf := C.CString(string(value)) + needFreePointer = append(needFreePointer, unsafe.Pointer(cbuf)) + bind.buffer = unsafe.Pointer(cbuf) + clen := int32(len(value)) + p := C.malloc(C.size_t(unsafe.Sizeof(clen))) + bind.length = (*C.int32_t)(p) + *(bind.length) = C.int32_t(clen) + needFreePointer = append(needFreePointer, p) + bind.buffer_length = C.uintptr_t(clen) + case taosTypes.TaosVarBinary: + bind.buffer_type = C.TSDB_DATA_TYPE_VARBINARY + cbuf := C.CString(string(value)) + needFreePointer = append(needFreePointer, unsafe.Pointer(cbuf)) + bind.buffer = unsafe.Pointer(cbuf) + clen := int32(len(value)) + p := C.malloc(C.size_t(unsafe.Sizeof(clen))) + bind.length = (*C.int32_t)(p) + *(bind.length) = C.int32_t(clen) + needFreePointer = append(needFreePointer, p) + bind.buffer_length = C.uintptr_t(clen) + case taosTypes.TaosGeometry: + bind.buffer_type = C.TSDB_DATA_TYPE_GEOMETRY + cbuf := C.CString(string(value)) + needFreePointer = append(needFreePointer, unsafe.Pointer(cbuf)) + bind.buffer = unsafe.Pointer(cbuf) + clen := int32(len(value)) + p := C.malloc(C.size_t(unsafe.Sizeof(clen))) + bind.length = (*C.int32_t)(p) + *(bind.length) = C.int32_t(clen) + needFreePointer = append(needFreePointer, p) + bind.buffer_length = C.uintptr_t(clen) + case taosTypes.TaosNchar: + bind.buffer_type = C.TSDB_DATA_TYPE_NCHAR + p := unsafe.Pointer(C.CString(string(value))) + needFreePointer = append(needFreePointer, p) + bind.buffer = unsafe.Pointer(p) + clen := int32(len(value)) + bind.length = (*C.int32_t)(C.malloc(C.size_t(unsafe.Sizeof(clen)))) + *(bind.length) = C.int32_t(clen) + needFreePointer = append(needFreePointer, unsafe.Pointer(bind.length)) + bind.buffer_length = C.uintptr_t(clen) + case taosTypes.TaosTimestamp: + bind.buffer_type = C.TSDB_DATA_TYPE_TIMESTAMP + ts := common.TimeToTimestamp(value.T, value.Precision) + p := C.malloc(8) + needFreePointer = append(needFreePointer, p) + *(*C.int64_t)(p) = C.int64_t(ts) + bind.buffer = p + bind.buffer_length = C.uintptr_t(8) + case taosTypes.TaosJson: + bind.buffer_type = C.TSDB_DATA_TYPE_JSON + cbuf := C.CString(string(value)) + needFreePointer = append(needFreePointer, unsafe.Pointer(cbuf)) + bind.buffer = unsafe.Pointer(cbuf) + clen := int32(len(value)) + p := C.malloc(C.size_t(unsafe.Sizeof(clen))) + bind.length = (*C.int32_t)(p) + *(bind.length) = C.int32_t(clen) + needFreePointer = append(needFreePointer, p) + bind.buffer_length = C.uintptr_t(clen) + default: + return nil, nil, errors.New("unsupported type") + } + } + binds[i] = bind + } + return binds, needFreePointer, nil +} + +// TaosStmtAddBatch int taos_stmt_add_batch(TAOS_STMT *stmt); +func TaosStmtAddBatch(stmt unsafe.Pointer) int { + return int(C.taos_stmt_add_batch(stmt)) +} + +// TaosStmtExecute int taos_stmt_execute(TAOS_STMT *stmt); +func TaosStmtExecute(stmt unsafe.Pointer) int { + return int(C.taos_stmt_execute(stmt)) +} + +// TaosStmtUseResult TAOS_RES * taos_stmt_use_result(TAOS_STMT *stmt); +func TaosStmtUseResult(stmt unsafe.Pointer) unsafe.Pointer { + return C.taos_stmt_use_result(stmt) +} + +// TaosStmtClose int taos_stmt_close(TAOS_STMT *stmt); +func TaosStmtClose(stmt unsafe.Pointer) int { + return int(C.taos_stmt_close(stmt)) +} + +// TaosStmtSetSubTBName int taos_stmt_set_sub_tbname(TAOS_STMT* stmt, const char* name); +func TaosStmtSetSubTBName(stmt unsafe.Pointer, name string) int { + cStr := C.CString(name) + defer C.free(unsafe.Pointer(cStr)) + return int(C.taos_stmt_set_tbname(stmt, cStr)) +} + +// TaosStmtBindParamBatch int taos_stmt_bind_param_batch(TAOS_STMT* stmt, TAOS_MULTI_BIND* bind); +func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bindType []*taosTypes.ColumnType) int { + var binds = make([]C.TAOS_MULTI_BIND, len(multiBind)) + var needFreePointer []unsafe.Pointer + defer func() { + for _, pointer := range needFreePointer { + C.free(pointer) + } + }() + for columnIndex, columnData := range multiBind { + bind := C.TAOS_MULTI_BIND{} + //malloc + rowLen := len(multiBind[0]) + bind.num = C.int(rowLen) + nullList := unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + needFreePointer = append(needFreePointer, nullList) + lengthList := unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen * 4)))) + needFreePointer = append(needFreePointer, lengthList) + var p unsafe.Pointer + columnType := bindType[columnIndex] + switch columnType.Type { + case taosTypes.TaosBoolType: + //1 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_BOOL + bind.buffer_length = C.uintptr_t(1) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosBool) + current := unsafe.Pointer(uintptr(p) + uintptr(i)) + if value { + *(*C.int8_t)(current) = C.int8_t(1) + } else { + *(*C.int8_t)(current) = C.int8_t(0) + } + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(1) + } + } + case taosTypes.TaosTinyintType: + //1 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_TINYINT + bind.buffer_length = C.uintptr_t(1) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosTinyint) + current := unsafe.Pointer(uintptr(p) + uintptr(i)) + *(*C.int8_t)(current) = C.int8_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(1) + } + } + case taosTypes.TaosSmallintType: + //2 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(2 * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_SMALLINT + bind.buffer_length = C.uintptr_t(2) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosSmallint) + current := unsafe.Pointer(uintptr(p) + uintptr(2*i)) + *(*C.int16_t)(current) = C.int16_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(2) + } + } + case taosTypes.TaosIntType: + //4 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4 * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_INT + bind.buffer_length = C.uintptr_t(4) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosInt) + current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) + *(*C.int32_t)(current) = C.int32_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(4) + } + } + case taosTypes.TaosBigintType: + //8 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_BIGINT + bind.buffer_length = C.uintptr_t(8) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosBigint) + current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) + *(*C.int64_t)(current) = C.int64_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) + } + } + case taosTypes.TaosUTinyintType: + //1 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_UTINYINT + bind.buffer_length = C.uintptr_t(1) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosUTinyint) + current := unsafe.Pointer(uintptr(p) + uintptr(i)) + *(*C.uint8_t)(current) = C.uint8_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(1) + } + } + case taosTypes.TaosUSmallintType: + //2 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(2 * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_USMALLINT + bind.buffer_length = C.uintptr_t(2) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosUSmallint) + current := unsafe.Pointer(uintptr(p) + uintptr(2*i)) + *(*C.uint16_t)(current) = C.uint16_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(2) + } + } + case taosTypes.TaosUIntType: + //4 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4 * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_UINT + bind.buffer_length = C.uintptr_t(4) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosUInt) + current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) + *(*C.uint32_t)(current) = C.uint32_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(4) + } + } + case taosTypes.TaosUBigintType: + //8 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_UBIGINT + bind.buffer_length = C.uintptr_t(8) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosUBigint) + current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) + *(*C.uint64_t)(current) = C.uint64_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) + } + } + case taosTypes.TaosFloatType: + //4 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4 * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_FLOAT + bind.buffer_length = C.uintptr_t(4) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosFloat) + current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) + *(*C.float)(current) = C.float(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(4) + } + } + case taosTypes.TaosDoubleType: + //8 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_DOUBLE + bind.buffer_length = C.uintptr_t(8) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosDouble) + current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) + *(*C.double)(current) = C.double(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) + } + } + case taosTypes.TaosBinaryType: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(columnType.MaxLen * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_BINARY + bind.buffer_length = C.uintptr_t(columnType.MaxLen) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosBinary) + for j := 0; j < len(value); j++ { + *(*C.char)(unsafe.Pointer(uintptr(p) + uintptr(columnType.MaxLen*i+j))) = (C.char)(value[j]) + } + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(len(value)) + } + } + case taosTypes.TaosVarBinaryType: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(columnType.MaxLen * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_VARBINARY + bind.buffer_length = C.uintptr_t(columnType.MaxLen) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosVarBinary) + for j := 0; j < len(value); j++ { + *(*C.char)(unsafe.Pointer(uintptr(p) + uintptr(columnType.MaxLen*i+j))) = (C.char)(value[j]) + } + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(len(value)) + } + } + case taosTypes.TaosGeometryType: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(columnType.MaxLen * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_GEOMETRY + bind.buffer_length = C.uintptr_t(columnType.MaxLen) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosGeometry) + for j := 0; j < len(value); j++ { + *(*C.char)(unsafe.Pointer(uintptr(p) + uintptr(columnType.MaxLen*i+j))) = (C.char)(value[j]) + } + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(len(value)) + } + } + case taosTypes.TaosNcharType: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(columnType.MaxLen * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_NCHAR + bind.buffer_length = C.uintptr_t(columnType.MaxLen) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosNchar) + for j := 0; j < len(value); j++ { + *(*C.char)(unsafe.Pointer(uintptr(p) + uintptr(columnType.MaxLen*i+j))) = (C.char)(value[j]) + } + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(len(value)) + } + } + case taosTypes.TaosTimestampType: + //8 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_TIMESTAMP + bind.buffer_length = C.uintptr_t(8) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosTimestamp) + ts := common.TimeToTimestamp(value.T, value.Precision) + current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) + *(*C.int64_t)(current) = C.int64_t(ts) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) + } + } + } + needFreePointer = append(needFreePointer, p) + bind.buffer = p + bind.length = (*C.int32_t)(lengthList) + bind.is_null = (*C.char)(nullList) + binds[columnIndex] = bind + } + return int(C.taos_stmt_bind_param_batch(stmt, (*C.TAOS_MULTI_BIND)(&binds[0]))) +} + +// TaosStmtErrStr char *taos_stmt_errstr(TAOS_STMT *stmt); +func TaosStmtErrStr(stmt unsafe.Pointer) string { + return C.GoString(C.taos_stmt_errstr(stmt)) +} + +// TaosStmtAffectedRows int taos_stmt_affected_rows(TAOS_STMT *stmt); +func TaosStmtAffectedRows(stmt unsafe.Pointer) int { + return int(C.taos_stmt_affected_rows(stmt)) +} + +// TaosStmtAffectedRowsOnce int taos_stmt_affected_rows_once(TAOS_STMT *stmt); +func TaosStmtAffectedRowsOnce(stmt unsafe.Pointer) int { + return int(C.taos_stmt_affected_rows_once(stmt)) +} + +//typedef struct TAOS_FIELD_E { +//char name[65]; +//int8_t type; +//uint8_t precision; +//uint8_t scale; +//int32_t bytes; +//} TAOS_FIELD_E; + +// TaosStmtGetTagFields DLL_EXPORT int taos_stmt_get_tag_fields(TAOS_STMT *stmt, int* fieldNum, TAOS_FIELD_E** fields); +func TaosStmtGetTagFields(stmt unsafe.Pointer) (code, num int, fields unsafe.Pointer) { + cNum := unsafe.Pointer(&num) + var cField *C.TAOS_FIELD_E + code = int(C.taos_stmt_get_tag_fields(stmt, (*C.int)(cNum), (**C.TAOS_FIELD_E)(unsafe.Pointer(&cField)))) + if code != 0 { + return code, num, nil + } + if num == 0 { + return code, num, nil + } + return code, num, unsafe.Pointer(cField) +} + +// TaosStmtGetColFields DLL_EXPORT int taos_stmt_get_col_fields(TAOS_STMT *stmt, int* fieldNum, TAOS_FIELD_E** fields); +func TaosStmtGetColFields(stmt unsafe.Pointer) (code, num int, fields unsafe.Pointer) { + cNum := unsafe.Pointer(&num) + var cField *C.TAOS_FIELD_E + code = int(C.taos_stmt_get_col_fields(stmt, (*C.int)(cNum), (**C.TAOS_FIELD_E)(unsafe.Pointer(&cField)))) + if code != 0 { + return code, num, nil + } + if num == 0 { + return code, num, nil + } + return code, num, unsafe.Pointer(cField) +} + +func StmtParseFields(num int, fields unsafe.Pointer) []*stmt.StmtField { + if num == 0 { + return nil + } + if fields == nil { + return nil + } + result := make([]*stmt.StmtField, num) + buf := bytes.NewBufferString("") + for i := 0; i < num; i++ { + r := &stmt.StmtField{} + field := *(*C.TAOS_FIELD_E)(unsafe.Pointer(uintptr(fields) + uintptr(C.sizeof_struct_TAOS_FIELD_E*C.int(i)))) + for _, c := range field.name { + if c == 0 { + break + } + buf.WriteByte(byte(c)) + } + r.Name = buf.String() + buf.Reset() + r.FieldType = int8(field._type) + r.Precision = uint8(field.precision) + r.Scale = uint8(field.scale) + r.Bytes = int32(field.bytes) + result[i] = r + } + return result +} + +// TaosStmtReclaimFields DLL_EXPORT void taos_stmt_reclaim_fields(TAOS_STMT *stmt, TAOS_FIELD_E *fields); +func TaosStmtReclaimFields(stmt unsafe.Pointer, fields unsafe.Pointer) { + C.taos_stmt_reclaim_fields(stmt, (*C.TAOS_FIELD_E)(fields)) +} + +// TaosStmtGetParam DLL_EXPORT int taos_stmt_get_param(TAOS_STMT *stmt, int idx, int *type, int *bytes) +func TaosStmtGetParam(stmt unsafe.Pointer, idx int) (dataType int, dataLength int, err error) { + code := C.taos_stmt_get_param(stmt, C.int(idx), (*C.int)(unsafe.Pointer(&dataType)), (*C.int)(unsafe.Pointer(&dataLength))) + if code != 0 { + err = &taosError.TaosError{ + Code: int32(code), + ErrStr: TaosStmtErrStr(stmt), + } + } + return +} diff --git a/driver/wrapper/stmt2.go b/driver/wrapper/stmt2.go new file mode 100644 index 00000000..09eb1cbb --- /dev/null +++ b/driver/wrapper/stmt2.go @@ -0,0 +1,857 @@ +package wrapper + +/* +#include +#include +#include +#include + +extern void Stmt2ExecCallback(void *param,TAOS_RES *,int code); +//TAOS_STMT2 *taos_stmt2_init(TAOS *taos, TAOS_STMT2_OPTION *option); +TAOS_STMT2 * taos_stmt2_init_wrapper(TAOS *taos, int64_t reqid, bool singleStbInsert,bool singleTableBindOnce, void *param){ + TAOS_STMT2_OPTION option = {reqid, singleStbInsert, singleTableBindOnce, Stmt2ExecCallback , param}; + return taos_stmt2_init(taos,&option); +}; +*/ +import "C" +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "time" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/stmt" + taosError "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" + "github.com/taosdata/taosadapter/v3/tools" +) + +// TaosStmt2Init TAOS_STMT2 *taos_stmt2_init(TAOS *taos, TAOS_STMT2_OPTION *option); +func TaosStmt2Init(taosConnect unsafe.Pointer, reqID int64, singleStbInsert bool, singleTableBindOnce bool, handler cgo.Handle) unsafe.Pointer { + return C.taos_stmt2_init_wrapper(taosConnect, C.int64_t(reqID), C.bool(singleStbInsert), C.bool(singleTableBindOnce), handler.Pointer()) +} + +// TaosStmt2Prepare int taos_stmt2_prepare(TAOS_STMT2 *stmt, const char *sql, unsigned long length); +func TaosStmt2Prepare(stmt unsafe.Pointer, sql string) int { + cSql := C.CString(sql) + cLen := C.ulong(len(sql)) + defer C.free(unsafe.Pointer(cSql)) + return int(C.taos_stmt2_prepare(stmt, cSql, cLen)) +} + +// TaosStmt2BindParam int taos_stmt2_bind_param(TAOS_STMT2 *stmt, TAOS_STMT2_BINDV *bindv, int32_t col_idx); +func TaosStmt2BindParam(stmt unsafe.Pointer, isInsert bool, params []*stmt.TaosStmt2BindData, colTypes, tagTypes []*stmt.StmtField, colIdx int32) error { + count := len(params) + if count == 0 { + return taosError.NewError(0xffff, "params is empty") + } + cBindv := C.TAOS_STMT2_BINDV{} + cBindv.count = C.int(count) + tbNames := unsafe.Pointer(C.malloc(C.size_t(count) * C.size_t(PointerSize))) + needFreePointer := []unsafe.Pointer{tbNames} + defer func() { + for i := len(needFreePointer) - 1; i >= 0; i-- { + if needFreePointer[i] != nil { + C.free(needFreePointer[i]) + } + } + }() + tagList := C.malloc(C.size_t(count) * C.size_t(PointerSize)) + needFreePointer = append(needFreePointer, unsafe.Pointer(tagList)) + colList := C.malloc(C.size_t(count) * C.size_t(PointerSize)) + needFreePointer = append(needFreePointer, unsafe.Pointer(colList)) + var currentTbNameP unsafe.Pointer + var currentTagP unsafe.Pointer + var currentColP unsafe.Pointer + for i, param := range params { + //parse table name + currentTbNameP = tools.AddPointer(tbNames, uintptr(i)*PointerSize) + if param.TableName != "" { + if !isInsert { + return taosError.NewError(0xffff, "table name is not allowed in query statement") + } + tbName := C.CString(param.TableName) + needFreePointer = append(needFreePointer, unsafe.Pointer(tbName)) + *(**C.char)(currentTbNameP) = tbName + } else { + *(**C.char)(currentTbNameP) = nil + } + //parse tags + currentTagP = tools.AddPointer(tagList, uintptr(i)*PointerSize) + if len(param.Tags) > 0 { + if !isInsert { + return taosError.NewError(0xffff, "tag is not allowed in query statement") + } + //transpose + columnFormatTags := make([][]driver.Value, len(param.Tags)) + for j := 0; j < len(param.Tags); j++ { + columnFormatTags[j] = []driver.Value{param.Tags[j]} + } + tags, freePointer, err := generateTaosStmt2BindsInsert(columnFormatTags, tagTypes) + needFreePointer = append(needFreePointer, freePointer...) + if err != nil { + return taosError.NewError(0xffff, fmt.Sprintf("generate tags Bindv struct error: %s", err.Error())) + } + *(**C.TAOS_STMT2_BIND)(currentTagP) = (*C.TAOS_STMT2_BIND)(tags) + } else { + *(**C.TAOS_STMT2_BIND)(currentTagP) = nil + } + // parse cols + currentColP = tools.AddPointer(colList, uintptr(i)*PointerSize) + if len(param.Cols) > 0 { + var err error + var cols unsafe.Pointer + var freePointer []unsafe.Pointer + if isInsert { + cols, freePointer, err = generateTaosStmt2BindsInsert(param.Cols, colTypes) + } else { + cols, freePointer, err = generateTaosStmt2BindsQuery(param.Cols) + } + needFreePointer = append(needFreePointer, freePointer...) + if err != nil { + return taosError.NewError(0xffff, fmt.Sprintf("generate cols Bindv struct error: %s", err.Error())) + } + *(**C.TAOS_STMT2_BIND)(currentColP) = (*C.TAOS_STMT2_BIND)(cols) + } else { + *(**C.TAOS_STMT2_BIND)(currentColP) = nil + } + } + cBindv.bind_cols = (**C.TAOS_STMT2_BIND)(unsafe.Pointer(colList)) + cBindv.tags = (**C.TAOS_STMT2_BIND)(unsafe.Pointer(tagList)) + cBindv.tbnames = (**C.char)(tbNames) + code := int(C.taos_stmt2_bind_param(stmt, &cBindv, C.int32_t(colIdx))) + if code != 0 { + errStr := TaosStmt2Error(stmt) + return taosError.NewError(code, errStr) + } + return nil +} + +func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt.StmtField) (unsafe.Pointer, []unsafe.Pointer, error) { + var needFreePointer []unsafe.Pointer + if len(multiBind) != len(fieldTypes) { + return nil, needFreePointer, fmt.Errorf("data and type length not match, data length: %d, type length: %d", len(multiBind), len(fieldTypes)) + } + binds := unsafe.Pointer(C.malloc(C.size_t(C.size_t(len(multiBind)) * C.size_t(unsafe.Sizeof(C.TAOS_STMT2_BIND{}))))) + needFreePointer = append(needFreePointer, binds) + rowLen := len(multiBind[0]) + for columnIndex, columnData := range multiBind { + if len(multiBind[columnIndex]) != rowLen { + return nil, needFreePointer, fmt.Errorf("data length not match, column %d data length: %d, expect: %d", columnIndex, len(multiBind[columnIndex]), rowLen) + } + bind := (*C.TAOS_STMT2_BIND)(unsafe.Pointer(uintptr(binds) + uintptr(columnIndex)*unsafe.Sizeof(C.TAOS_STMT2_BIND{}))) + bind.num = C.int(rowLen) + nullList := unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + needFreePointer = append(needFreePointer, nullList) + lengthList := unsafe.Pointer(C.calloc(C.size_t(C.uint(rowLen)), C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, lengthList) + var p unsafe.Pointer + columnType := fieldTypes[columnIndex].FieldType + precision := int(fieldTypes[columnIndex].Precision) + switch columnType { + case common.TSDB_DATA_TYPE_BOOL: + //1 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_BOOL + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(bool) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect bool, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(i)) + if value { + *(*C.int8_t)(current) = C.int8_t(1) + } else { + *(*C.int8_t)(current) = C.int8_t(0) + } + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(1) + } + } + case common.TSDB_DATA_TYPE_TINYINT: + //1 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_TINYINT + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(int8) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect int8, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(i)) + *(*C.int8_t)(current) = C.int8_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(1) + } + } + case common.TSDB_DATA_TYPE_SMALLINT: + //2 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(2 * rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_SMALLINT + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(int16) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect int16, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(2*i)) + *(*C.int16_t)(current) = C.int16_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(2) + } + } + case common.TSDB_DATA_TYPE_INT: + //4 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4 * rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_INT + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(int32) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect int32, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) + *(*C.int32_t)(current) = C.int32_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(4) + } + } + case common.TSDB_DATA_TYPE_BIGINT: + //8 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_BIGINT + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(int64) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect int64, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) + *(*C.int64_t)(current) = C.int64_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) + } + } + case common.TSDB_DATA_TYPE_UTINYINT: + //1 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_UTINYINT + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(uint8) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect uint8, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(i)) + *(*C.uint8_t)(current) = C.uint8_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(1) + } + } + case common.TSDB_DATA_TYPE_USMALLINT: + //2 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(2 * rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_USMALLINT + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(uint16) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect uint16, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(2*i)) + *(*C.uint16_t)(current) = C.uint16_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(2) + } + } + case common.TSDB_DATA_TYPE_UINT: + //4 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4 * rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_UINT + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(uint32) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect uint32, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) + *(*C.uint32_t)(current) = C.uint32_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(4) + } + } + case common.TSDB_DATA_TYPE_UBIGINT: + //8 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_UBIGINT + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(uint64) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect uint64, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) + *(*C.uint64_t)(current) = C.uint64_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) + } + } + case common.TSDB_DATA_TYPE_FLOAT: + //4 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4 * rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_FLOAT + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(float32) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect float32, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) + *(*C.float)(current) = C.float(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(4) + } + } + case common.TSDB_DATA_TYPE_DOUBLE: + //8 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_DOUBLE + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(0) + } else { + *(*C.char)(currentNull) = C.char(0) + value, ok := rowData.(float64) + if !ok { + return nil, needFreePointer, fmt.Errorf("data type error, expect float64, but got %T, value: %v", rowData, value) + } + current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) + *(*C.double)(current) = C.double(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) + } + } + case common.TSDB_DATA_TYPE_BINARY, common.TSDB_DATA_TYPE_VARBINARY, common.TSDB_DATA_TYPE_JSON, common.TSDB_DATA_TYPE_GEOMETRY, common.TSDB_DATA_TYPE_NCHAR: + bind.buffer_type = C.int(columnType) + colOffset := make([]int, rowLen) + totalLen := 0 + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + + *(*C.int32_t)(l) = C.int32_t(0) + } else { + colOffset[i] = totalLen + switch value := rowData.(type) { + case string: + totalLen += len(value) + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(len(value)) + case []byte: + totalLen += len(value) + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(len(value)) + default: + return nil, needFreePointer, fmt.Errorf("data type error, expect string or []byte, but got %T, value: %v", rowData, value) + } + *(*C.char)(currentNull) = C.char(0) + } + } + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(totalLen)))) + needFreePointer = append(needFreePointer, p) + for i, rowData := range columnData { + if rowData != nil { + switch value := rowData.(type) { + case string: + x := *(*[]byte)(unsafe.Pointer(&value)) + C.memcpy(unsafe.Pointer(uintptr(p)+uintptr(colOffset[i])), unsafe.Pointer(&x[0]), C.size_t(len(value))) + case []byte: + C.memcpy(unsafe.Pointer(uintptr(p)+uintptr(colOffset[i])), unsafe.Pointer(&value[0]), C.size_t(len(value))) + default: + return nil, needFreePointer, fmt.Errorf("data type error, expect string or []byte, but got %T, value: %v", rowData, value) + } + } + } + case common.TSDB_DATA_TYPE_TIMESTAMP: + //8 + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8 * rowLen)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_TIMESTAMP + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(0) + } else { + *(*C.char)(currentNull) = C.char(0) + var ts int64 + switch value := rowData.(type) { + case time.Time: + ts = common.TimeToTimestamp(value, precision) + case int64: + ts = value + default: + return nil, needFreePointer, fmt.Errorf("data type error, expect time.Time or int64, but got %T, value: %v", rowData, rowData) + } + current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) + *(*C.int64_t)(current) = C.int64_t(ts) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) + } + } + } + bind.buffer = p + bind.length = (*C.int32_t)(lengthList) + bind.is_null = (*C.char)(nullList) + } + + return binds, needFreePointer, nil + +} + +func generateTaosStmt2BindsQuery(multiBind [][]driver.Value) (unsafe.Pointer, []unsafe.Pointer, error) { + var needFreePointer []unsafe.Pointer + binds := unsafe.Pointer(C.malloc(C.size_t(C.size_t(len(multiBind)) * C.size_t(unsafe.Sizeof(C.TAOS_STMT2_BIND{}))))) + needFreePointer = append(needFreePointer, binds) + for columnIndex, columnData := range multiBind { + if len(columnData) != 1 { + return nil, needFreePointer, fmt.Errorf("bind query data length must be 1, but column %d got %d", columnIndex, len(columnData)) + } + bind := (*C.TAOS_STMT2_BIND)(unsafe.Pointer(uintptr(binds) + uintptr(columnIndex)*unsafe.Sizeof(C.TAOS_STMT2_BIND{}))) + data := columnData[0] + bind.num = C.int(1) + nullList := unsafe.Pointer(C.malloc(C.size_t(C.uint(1)))) + needFreePointer = append(needFreePointer, nullList) + var lengthList unsafe.Pointer + var p unsafe.Pointer + if data == nil { + return nil, needFreePointer, fmt.Errorf("bind query data can not be nil") + } + *(*C.char)(nullList) = C.char(0) + + switch rowData := data.(type) { + case bool: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(1)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_BOOL + if rowData { + *(*C.int8_t)(p) = C.int8_t(1) + } else { + *(*C.int8_t)(p) = C.int8_t(0) + } + + case int8: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(1)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_TINYINT + *(*C.int8_t)(p) = C.int8_t(rowData) + + case int16: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(2)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_SMALLINT + *(*C.int16_t)(p) = C.int16_t(rowData) + + case int32: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_INT + *(*C.int32_t)(p) = C.int32_t(rowData) + + case int64: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_BIGINT + *(*C.int64_t)(p) = C.int64_t(rowData) + + case int: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_BIGINT + *(*C.int64_t)(p) = C.int64_t(int64(rowData)) + + case uint8: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(1)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_UTINYINT + *(*C.uint8_t)(p) = C.uint8_t(rowData) + + case uint16: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(2)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_USMALLINT + *(*C.uint16_t)(p) = C.uint16_t(rowData) + + case uint32: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_UINT + *(*C.uint32_t)(p) = C.uint32_t(rowData) + + case uint64: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_UBIGINT + *(*C.uint64_t)(p) = C.uint64_t(rowData) + + case uint: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_UBIGINT + *(*C.uint64_t)(p) = C.uint64_t(uint64(rowData)) + + case float32: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_FLOAT + *(*C.float)(p) = C.float(rowData) + + case float64: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(8)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_DOUBLE + *(*C.double)(p) = C.double(rowData) + + case []byte: + valueLength := len(rowData) + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(valueLength)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_BINARY + C.memcpy(p, unsafe.Pointer(&rowData[0]), C.size_t(valueLength)) + lengthList = unsafe.Pointer(C.calloc(C.size_t(C.uint(1)), C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, lengthList) + *(*C.int32_t)(lengthList) = C.int32_t(valueLength) + case string: + valueLength := len(rowData) + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(valueLength)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_BINARY + x := *(*[]byte)(unsafe.Pointer(&rowData)) + C.memcpy(p, unsafe.Pointer(&x[0]), C.size_t(valueLength)) + lengthList = unsafe.Pointer(C.calloc(C.size_t(C.uint(1)), C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, lengthList) + *(*C.int32_t)(lengthList) = C.int32_t(valueLength) + case time.Time: + buffer := make([]byte, 0, 35) + value := rowData.AppendFormat(buffer, time.RFC3339Nano) + valueLength := len(value) + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(valueLength)))) + needFreePointer = append(needFreePointer, p) + bind.buffer_type = C.TSDB_DATA_TYPE_BINARY + x := *(*[]byte)(unsafe.Pointer(&value)) + C.memcpy(p, unsafe.Pointer(&x[0]), C.size_t(valueLength)) + lengthList = unsafe.Pointer(C.calloc(C.size_t(C.uint(1)), C.size_t(C.uint(4)))) + needFreePointer = append(needFreePointer, lengthList) + *(*C.int32_t)(lengthList) = C.int32_t(valueLength) + default: + return nil, needFreePointer, fmt.Errorf("data type error, expect bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, []byte, string, time.Time, but got %T, value: %v", data, data) + } + bind.buffer = p + bind.length = (*C.int32_t)(lengthList) + bind.is_null = (*C.char)(nullList) + } + return binds, needFreePointer, nil +} + +// TaosStmt2Exec int taos_stmt2_exec(TAOS_STMT2 *stmt, int *affected_rows); +func TaosStmt2Exec(stmt unsafe.Pointer) int { + return int(C.taos_stmt2_exec(stmt, nil)) +} + +// TaosStmt2Close int taos_stmt2_close(TAOS_STMT2 *stmt); +func TaosStmt2Close(stmt unsafe.Pointer) int { + return int(C.taos_stmt2_close(stmt)) +} + +// TaosStmt2IsInsert int taos_stmt2_is_insert(TAOS_STMT2 *stmt, int *insert); +func TaosStmt2IsInsert(stmt unsafe.Pointer) (is bool, errorCode int) { + p := C.malloc(C.size_t(4)) + isInsert := (*C.int)(p) + defer C.free(p) + errorCode = int(C.taos_stmt2_is_insert(stmt, isInsert)) + return int(*isInsert) == 1, errorCode +} + +// TaosStmt2GetFields int taos_stmt2_get_fields(TAOS_STMT2 *stmt, TAOS_FIELD_T field_type, int *count, TAOS_FIELD_E **fields); +func TaosStmt2GetFields(stmt unsafe.Pointer, fieldType int) (code, count int, fields unsafe.Pointer) { + code = int(C.taos_stmt2_get_fields(stmt, C.TAOS_FIELD_T(fieldType), (*C.int)(unsafe.Pointer(&count)), (**C.TAOS_FIELD_E)(unsafe.Pointer(&fields)))) + return +} + +// TaosStmt2FreeFields void taos_stmt2_free_fields(TAOS_STMT2 *stmt, TAOS_FIELD_E *fields); +func TaosStmt2FreeFields(stmt unsafe.Pointer, fields unsafe.Pointer) { + if fields == nil { + return + } + C.taos_stmt2_free_fields(stmt, (*C.TAOS_FIELD_E)(fields)) +} + +// TaosStmt2Error char *taos_stmt2_error(TAOS_STMT2 *stmt) +func TaosStmt2Error(stmt unsafe.Pointer) string { + return C.GoString(C.taos_stmt2_error(stmt)) +} + +func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32) error { + totalLength := binary.LittleEndian.Uint32(data[stmt.TotalLengthPosition:]) + if totalLength != uint32(len(data)) { + return fmt.Errorf("total length not match, expect %d, but get %d", len(data), totalLength) + } + var freePointer []unsafe.Pointer + defer func() { + for i := len(freePointer) - 1; i >= 0; i-- { + if freePointer[i] != nil { + C.free(freePointer[i]) + } + } + }() + dataP := unsafe.Pointer(C.CBytes(data)) + freePointer = append(freePointer, dataP) + count := binary.LittleEndian.Uint32(data[stmt.CountPosition:]) + tagCount := binary.LittleEndian.Uint32(data[stmt.TagCountPosition:]) + colCount := binary.LittleEndian.Uint32(data[stmt.ColCountPosition:]) + tableNamesOffset := binary.LittleEndian.Uint32(data[stmt.TableNamesOffsetPosition:]) + tagsOffset := binary.LittleEndian.Uint32(data[stmt.TagsOffsetPosition:]) + colsOffset := binary.LittleEndian.Uint32(data[stmt.ColsOffsetPosition:]) + // check table names + if tableNamesOffset > 0 { + tableNameEnd := tableNamesOffset + count*2 + // table name lengths out of range + if tableNameEnd > totalLength { + return fmt.Errorf("table name lengths out of range, total length: %d, tableNamesLengthEnd: %d", totalLength, tableNameEnd) + } + for i := uint32(0); i < count; i++ { + tableNameLength := binary.LittleEndian.Uint16(data[tableNamesOffset+i*2:]) + tableNameEnd += uint32(tableNameLength) + } + if tableNameEnd > totalLength { + return fmt.Errorf("table names out of range, total length: %d, tableNameTotalLength: %d", totalLength, tableNameEnd) + } + } + // check tags + if tagsOffset > 0 { + if tagCount == 0 { + return fmt.Errorf("tag count is zero, but tags offset is not zero") + } + tagsEnd := tagsOffset + count*4 + if tagsEnd > totalLength { + return fmt.Errorf("tags lengths out of range, total length: %d, tagsLengthEnd: %d", totalLength, tagsEnd) + } + for i := uint32(0); i < count; i++ { + tagLength := binary.LittleEndian.Uint32(data[tagsOffset+i*4:]) + if tagLength == 0 { + return fmt.Errorf("tag length is zero, data index: %d", i) + } + tagsEnd += tagLength + } + if tagsEnd > totalLength { + return fmt.Errorf("tags out of range, total length: %d, tagsTotalLength: %d", totalLength, tagsEnd) + } + } + // check cols + if colsOffset > 0 { + if colCount == 0 { + return fmt.Errorf("col count is zero, but cols offset is not zero") + } + colsEnd := colsOffset + count*4 + if colsEnd > totalLength { + return fmt.Errorf("cols lengths out of range, total length: %d, colsLengthEnd: %d", totalLength, colsEnd) + } + for i := uint32(0); i < count; i++ { + colLength := binary.LittleEndian.Uint32(data[colsOffset+i*4:]) + if colLength == 0 { + return fmt.Errorf("col length is zero, data: %d", i) + } + colsEnd += colLength + } + if colsEnd > totalLength { + return fmt.Errorf("cols out of range, total length: %d, colsTotalLength: %d", totalLength, colsEnd) + } + } + cBindv := C.TAOS_STMT2_BINDV{} + cBindv.count = C.int(count) + if tableNamesOffset > 0 { + tableNameLengthP := tools.AddPointer(dataP, uintptr(tableNamesOffset)) + cTableNames := C.malloc(C.size_t(uintptr(count) * PointerSize)) + freePointer = append(freePointer, cTableNames) + tableDataP := tools.AddPointer(tableNameLengthP, uintptr(count)*2) + var tableNamesArrayP unsafe.Pointer + for i := uint32(0); i < count; i++ { + tableNamesArrayP = tools.AddPointer(cTableNames, uintptr(i)*PointerSize) + *(**C.char)(tableNamesArrayP) = (*C.char)(tableDataP) + tableNameLength := *(*uint16)(tools.AddPointer(tableNameLengthP, uintptr(i*2))) + if tableNameLength == 0 { + return fmt.Errorf("table name length is zero, data index: %d", i) + } + tableDataP = tools.AddPointer(tableDataP, uintptr(tableNameLength)) + } + cBindv.tbnames = (**C.char)(cTableNames) + } else { + cBindv.tbnames = nil + } + if tagsOffset > 0 { + tags, needFreePointer, err := generateStmt2Binds(count, tagCount, dataP, tagsOffset) + freePointer = append(freePointer, needFreePointer...) + if err != nil { + return fmt.Errorf("generate tags error: %s", err.Error()) + } + cBindv.tags = (**C.TAOS_STMT2_BIND)(tags) + } else { + cBindv.tags = nil + } + if colsOffset > 0 { + cols, needFreePointer, err := generateStmt2Binds(count, colCount, dataP, colsOffset) + freePointer = append(freePointer, needFreePointer...) + if err != nil { + return fmt.Errorf("generate cols error: %s", err.Error()) + } + cBindv.bind_cols = (**C.TAOS_STMT2_BIND)(cols) + } else { + cBindv.bind_cols = nil + } + code := int(C.taos_stmt2_bind_param(stmt2, &cBindv, C.int32_t(colIdx))) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + return taosError.NewError(code, errStr) + } + return nil +} + +func generateStmt2Binds(count uint32, fieldCount uint32, dataP unsafe.Pointer, fieldsOffset uint32) (unsafe.Pointer, []unsafe.Pointer, error) { + var freePointer []unsafe.Pointer + bindsCList := unsafe.Pointer(C.malloc(C.size_t(uintptr(count) * PointerSize))) + freePointer = append(freePointer, bindsCList) + // dataLength [count]uint32 + // length have checked in TaosStmt2BindBinary + baseLengthPointer := tools.AddPointer(dataP, uintptr(fieldsOffset)) + // dataBuffer + dataPointer := tools.AddPointer(baseLengthPointer, uintptr(count)*4) + var bindsPointer unsafe.Pointer + for tableIndex := uint32(0); tableIndex < count; tableIndex++ { + bindsPointer = tools.AddPointer(bindsCList, uintptr(tableIndex)*PointerSize) + binds := unsafe.Pointer(C.malloc(C.size_t(C.size_t(fieldCount) * C.size_t(unsafe.Sizeof(C.TAOS_STMT2_BIND{}))))) + freePointer = append(freePointer, binds) + var bindDataP unsafe.Pointer + var bindDataTotalLength uint32 + var num int32 + var haveLength byte + var bufferLength uint32 + for fieldIndex := uint32(0); fieldIndex < fieldCount; fieldIndex++ { + // field data + bindDataP = dataPointer + // totalLength + bindDataTotalLength = *(*uint32)(bindDataP) + bindDataP = tools.AddPointer(bindDataP, common.UInt32Size) + bind := (*C.TAOS_STMT2_BIND)(unsafe.Pointer(uintptr(binds) + uintptr(fieldIndex)*unsafe.Sizeof(C.TAOS_STMT2_BIND{}))) + // buffer_type + bind.buffer_type = *(*C.int)(bindDataP) + bindDataP = tools.AddPointer(bindDataP, common.Int32Size) + // num + num = *(*int32)(bindDataP) + bind.num = C.int(num) + bindDataP = tools.AddPointer(bindDataP, common.Int32Size) + // is_null + bind.is_null = (*C.char)(bindDataP) + bindDataP = tools.AddPointer(bindDataP, uintptr(num)) + // haveLength + haveLength = *(*byte)(bindDataP) + bindDataP = tools.AddPointer(bindDataP, common.Int8Size) + if haveLength == 0 { + bind.length = nil + } else { + // length [num]int32 + bind.length = (*C.int32_t)(bindDataP) + bindDataP = tools.AddPointer(bindDataP, common.Int32Size*uintptr(num)) + } + // bufferLength + bufferLength = *(*uint32)(bindDataP) + bindDataP = tools.AddPointer(bindDataP, common.UInt32Size) + // buffer + if bufferLength == 0 { + bind.buffer = nil + } else { + bind.buffer = bindDataP + } + bindDataP = tools.AddPointer(bindDataP, uintptr(bufferLength)) + // check bind data length + bindDataLen := uintptr(bindDataP) - uintptr(dataPointer) + if bindDataLen != uintptr(bindDataTotalLength) { + return nil, freePointer, fmt.Errorf("bind data length not match, expect %d, but get %d, tableIndex:%d", bindDataTotalLength, bindDataLen, tableIndex) + } + dataPointer = bindDataP + } + *(**C.TAOS_STMT2_BIND)(bindsPointer) = (*C.TAOS_STMT2_BIND)(binds) + } + return bindsCList, freePointer, nil +} diff --git a/driver/wrapper/stmt2_test.go b/driver/wrapper/stmt2_test.go new file mode 100644 index 00000000..f92b148f --- /dev/null +++ b/driver/wrapper/stmt2_test.go @@ -0,0 +1,5076 @@ +package wrapper + +import ( + "database/sql/driver" + "fmt" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/common/stmt" + taosError "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +type stmt2Result struct { + res unsafe.Pointer + affected int + n int +} +type StmtCallBackTest struct { + ExecResult chan *stmt2Result +} + +func (s *StmtCallBackTest) ExecCall(res unsafe.Pointer, affected int, code int) { + s.ExecResult <- &stmt2Result{ + res: res, + affected: affected, + n: code, + } +} + +func NewStmtCallBackTest() *StmtCallBackTest { + return &StmtCallBackTest{ + ExecResult: make(chan *stmt2Result, 1), + } +} + +func TestStmt2BindData(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2 precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2") + if err != nil { + t.Error(err) + return + } + now := time.Now().Round(time.Millisecond) + next1S := now.Add(time.Second) + next2S := now.Add(2 * time.Second) + + tests := []struct { + name string + tbType string + pos string + params []*stmt.TaosStmt2BindData + expectValue [][]driver.Value + }{ + { + name: "int", + tbType: "ts timestamp, v int", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + {now}, + {int32(1)}, + }, + }}, + expectValue: [][]driver.Value{ + {now, int32(1)}, + }, + }, + { + name: "int null", + tbType: "ts timestamp, v int", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + {now}, + {nil}, + }, + }}, + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "int null 3 cols", + tbType: "ts timestamp, v int", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + int32(1), + nil, + int32(2), + }, + }, + }}, + expectValue: [][]driver.Value{ + {now, int32(1)}, + {next1S, nil}, + {next2S, int32(2)}, + }, + }, + { + name: "bool", + tbType: "ts timestamp, v bool", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{{now}, {true}}, + }}, + + expectValue: [][]driver.Value{{now, true}}, + }, + { + name: "bool false", + tbType: "ts timestamp, v bool", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{{now}, {false}}, + }}, + + expectValue: [][]driver.Value{{now, false}}, + }, + { + name: "bool null", + tbType: "ts timestamp, v bool", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{{now}, {nil}}, + }}, + + expectValue: [][]driver.Value{{now, nil}}, + }, + { + name: "bool null 3 cols", + tbType: "ts timestamp, v bool", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + bool(true), + nil, + bool(false), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, true}, + {next1S, nil}, + {next2S, false}, + }, + }, + { + name: "tinyint", + tbType: "ts timestamp, v tinyint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{{now}, {int8(1)}}, + }}, + + expectValue: [][]driver.Value{ + {now, int8(1)}, + }, + }, + { + name: "tinyint null", + tbType: "ts timestamp, v tinyint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, { + nil, + }}, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "tinyint null 3 cols", + tbType: "ts timestamp, v tinyint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, { + int8(1), + nil, + int8(2), + }}, + }}, + + expectValue: [][]driver.Value{ + {now, int8(1)}, + {next1S, nil}, + {next2S, int8(2)}, + }, + }, + { + name: "smallint", + tbType: "ts timestamp, v smallint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + int16(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, int16(1)}, + }, + }, + { + name: "smallint null", + tbType: "ts timestamp, v smallint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "smallint null 3 cols", + tbType: "ts timestamp, v smallint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + int16(1), + nil, + int16(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, int16(1)}, + {next1S, nil}, + {next2S, int16(2)}, + }, + }, + { + name: "bigint", + tbType: "ts timestamp, v bigint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + int64(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, int64(1)}, + }, + }, + { + name: "bigint null", + tbType: "ts timestamp, v bigint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "bigint null 3 cols", + tbType: "ts timestamp, v bigint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + int64(1), + nil, + int64(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, int64(1)}, + {next1S, nil}, + {next2S, int64(2)}, + }, + }, + + { + name: "tinyint unsigned", + tbType: "ts timestamp, v tinyint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + uint8(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint8(1)}, + }, + }, + { + name: "tinyint unsigned null", + tbType: "ts timestamp, v tinyint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "tinyint unsigned null 3 cols", + tbType: "ts timestamp, v tinyint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + uint8(1), + nil, + uint8(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint8(1)}, + {next1S, nil}, + {next2S, uint8(2)}, + }, + }, + + { + name: "smallint unsigned", + tbType: "ts timestamp, v smallint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + uint16(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint16(1)}, + }, + }, + { + name: "smallint unsigned null", + tbType: "ts timestamp, v smallint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "smallint unsigned null 3 cols", + tbType: "ts timestamp, v smallint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + uint16(1), + nil, + uint16(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint16(1)}, + {next1S, nil}, + {next2S, uint16(2)}, + }, + }, + + { + name: "int unsigned", + tbType: "ts timestamp, v int unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + uint32(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint32(1)}, + }, + }, + { + name: "int unsigned null", + tbType: "ts timestamp, v int unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "int unsigned null 3 cols", + tbType: "ts timestamp, v int unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + uint32(1), + nil, + uint32(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint32(1)}, + {next1S, nil}, + {next2S, uint32(2)}, + }, + }, + + { + name: "bigint unsigned", + tbType: "ts timestamp, v bigint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + uint64(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint64(1)}, + }, + }, + { + name: "bigint unsigned null", + tbType: "ts timestamp, v bigint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "bigint unsigned null 3 cols", + tbType: "ts timestamp, v bigint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + uint64(1), + nil, + uint64(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint64(1)}, + {next1S, nil}, + {next2S, uint64(2)}, + }, + }, + + { + name: "float", + tbType: "ts timestamp, v float", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + float32(1.2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, float32(1.2)}, + }, + }, + { + name: "float null", + tbType: "ts timestamp, v float", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "float null 3 cols", + tbType: "ts timestamp, v float", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + float32(1.2), + nil, + float32(2.2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, float32(1.2)}, + {next1S, nil}, + {next2S, float32(2.2)}, + }, + }, + + { + name: "double", + tbType: "ts timestamp, v double", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + float64(1.2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, float64(1.2)}, + }, + }, + { + name: "double null", + tbType: "ts timestamp, v double", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "double null 3 cols", + tbType: "ts timestamp, v double", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + float64(1.2), + nil, + float64(2.2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, float64(1.2)}, + {next1S, nil}, + {next2S, float64(2.2)}, + }, + }, + + { + name: "binary", + tbType: "ts timestamp, v binary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + []byte("yes"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + }, + }, + { + name: "binary null", + tbType: "ts timestamp, v binary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "binary null 3 cols", + tbType: "ts timestamp, v binary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + []byte("yes"), + nil, + []byte("中文"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + {next1S, nil}, + {next2S, "中文"}, + }, + }, + + { + name: "varbinary", + tbType: "ts timestamp, v varbinary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + []byte("yes"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte("yes")}, + }, + }, + { + name: "varbinary null", + tbType: "ts timestamp, v varbinary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + + { + name: "varbinary null 3 cols", + tbType: "ts timestamp, v varbinary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + []byte("yes"), + nil, + []byte("中文"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte("yes")}, + {next1S, nil}, + {next2S, []byte("中文")}, + }, + }, + + { + name: "geometry", + tbType: "ts timestamp, v geometry(100)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}, + }, + }, + { + name: "geometry null", + tbType: "ts timestamp, v geometry(100)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "geometry null 3 cols", + tbType: "ts timestamp, v geometry(100)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + nil, + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}, + {next1S, nil}, + {next2S, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}, + }, + }, + + { + name: "nchar", + tbType: "ts timestamp, v nchar(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + []byte("yes"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + }, + }, + { + name: "nchar null", + tbType: "ts timestamp, v nchar(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "nchar null 3 cols", + tbType: "ts timestamp, v nchar(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + []byte("yes"), + nil, + []byte("中文"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + {next1S, nil}, + {next2S, "中文"}, + }, + }, + + { + name: "nchar bind string", + tbType: "ts timestamp, v nchar(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + "yes", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + }, + }, + + { + name: "nchar bind string null 3 cols", + tbType: "ts timestamp, v nchar(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + "yes", + nil, + "中文", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + {next1S, nil}, + {next2S, "中文"}, + }, + }, + + { + name: "binary bind string", + tbType: "ts timestamp, v binary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + "yes", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + }, + }, + + { + name: "binary bind string null 3 cols", + tbType: "ts timestamp, v binary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + "yes", + nil, + "中文", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + {next1S, nil}, + {next2S, "中文"}, + }, + }, + + { + name: "varbinary bind string", + tbType: "ts timestamp, v varbinary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + "yes", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte("yes")}, + }, + }, + + { + name: "varbinary bind string null 3 cols", + tbType: "ts timestamp, v varbinary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + "yes", + nil, + "中文", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte("yes")}, + {next1S, nil}, + {next2S, []byte("中文")}, + }, + }, + + { + name: "timestamp", + tbType: "ts timestamp, v timestamp", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + now, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, now}, + }, + }, + } + for i, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tbType := tc.tbType + tbName := fmt.Sprintf("test_fast_insert_%02d", i) + drop := fmt.Sprintf("drop table if exists %s", tbName) + create := fmt.Sprintf("create table if not exists %s(%s)", tbName, tbType) + pos := tc.pos + sql := fmt.Sprintf("insert into %s values(%s)", tbName, pos) + var err error + if err = exec(conn, drop); err != nil { + t.Error(err) + return + } + if err = exec(conn, create); err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + insertStmt := TaosStmt2Init(conn, 0xcc123, false, false, handler) + code := TaosStmt2Prepare(insertStmt, sql) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code := TaosStmt2IsInsert(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + code, count, cfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmt2FreeFields(insertStmt, cfields) + assert.Equal(t, 2, count) + fields := StmtParseFields(count, cfields) + err = TaosStmt2BindParam(insertStmt, true, tc.params, fields, nil, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + t.Log(r.affected) + //time.Sleep(time.Second) + code = TaosStmt2Close(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + result, err := query(conn, fmt.Sprintf("select * from %s order by ts asc", tbName)) + if err != nil { + t.Error(err) + return + } + assert.Equal(t, tc.expectValue, result) + }) + } + +} + +func TestStmt2BindBinary(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_binary") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_binary precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_binary") + if err != nil { + t.Error(err) + return + } + now := time.Now().Round(time.Millisecond) + next1S := now.Add(time.Second) + next2S := now.Add(2 * time.Second) + + tests := []struct { + name string + tbType string + pos string + params []*stmt.TaosStmt2BindData + expectValue [][]driver.Value + }{ + { + name: "int", + tbType: "ts timestamp, v int", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + {now}, + {int32(1)}, + }, + }}, + expectValue: [][]driver.Value{ + {now, int32(1)}, + }, + }, + { + name: "int null", + tbType: "ts timestamp, v int", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + {now}, + {nil}, + }, + }}, + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "int null 3 cols", + tbType: "ts timestamp, v int", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + int32(1), + nil, + int32(2), + }, + }, + }}, + expectValue: [][]driver.Value{ + {now, int32(1)}, + {next1S, nil}, + {next2S, int32(2)}, + }, + }, + { + name: "bool", + tbType: "ts timestamp, v bool", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{{now}, {bool(true)}}, + }}, + + expectValue: [][]driver.Value{{now, true}}, + }, + { + name: "bool null", + tbType: "ts timestamp, v bool", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{{now}, {nil}}, + }}, + + expectValue: [][]driver.Value{{now, nil}}, + }, + { + name: "bool null 3 cols", + tbType: "ts timestamp, v bool", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + bool(true), + nil, + bool(false), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, true}, + {next1S, nil}, + {next2S, false}, + }, + }, + { + name: "tinyint", + tbType: "ts timestamp, v tinyint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{{now}, {int8(1)}}, + }}, + + expectValue: [][]driver.Value{ + {now, int8(1)}, + }, + }, + { + name: "tinyint null", + tbType: "ts timestamp, v tinyint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, { + nil, + }}, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "tinyint null 3 cols", + tbType: "ts timestamp, v tinyint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, { + int8(1), + nil, + int8(2), + }}, + }}, + + expectValue: [][]driver.Value{ + {now, int8(1)}, + {next1S, nil}, + {next2S, int8(2)}, + }, + }, + { + name: "smallint", + tbType: "ts timestamp, v smallint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + int16(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, int16(1)}, + }, + }, + { + name: "smallint null", + tbType: "ts timestamp, v smallint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "smallint null 3 cols", + tbType: "ts timestamp, v smallint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + int16(1), + nil, + int16(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, int16(1)}, + {next1S, nil}, + {next2S, int16(2)}, + }, + }, + { + name: "bigint", + tbType: "ts timestamp, v bigint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + int64(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, int64(1)}, + }, + }, + { + name: "bigint null", + tbType: "ts timestamp, v bigint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "bigint null 3 cols", + tbType: "ts timestamp, v bigint", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + int64(1), + nil, + int64(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, int64(1)}, + {next1S, nil}, + {next2S, int64(2)}, + }, + }, + + { + name: "tinyint unsigned", + tbType: "ts timestamp, v tinyint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + uint8(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint8(1)}, + }, + }, + { + name: "tinyint unsigned null", + tbType: "ts timestamp, v tinyint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "tinyint unsigned null 3 cols", + tbType: "ts timestamp, v tinyint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + uint8(1), + nil, + uint8(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint8(1)}, + {next1S, nil}, + {next2S, uint8(2)}, + }, + }, + + { + name: "smallint unsigned", + tbType: "ts timestamp, v smallint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + uint16(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint16(1)}, + }, + }, + { + name: "smallint unsigned null", + tbType: "ts timestamp, v smallint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "smallint unsigned null 3 cols", + tbType: "ts timestamp, v smallint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + uint16(1), + nil, + uint16(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint16(1)}, + {next1S, nil}, + {next2S, uint16(2)}, + }, + }, + + { + name: "int unsigned", + tbType: "ts timestamp, v int unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + uint32(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint32(1)}, + }, + }, + { + name: "int unsigned null", + tbType: "ts timestamp, v int unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "int unsigned null 3 cols", + tbType: "ts timestamp, v int unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + uint32(1), + nil, + uint32(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint32(1)}, + {next1S, nil}, + {next2S, uint32(2)}, + }, + }, + + { + name: "bigint unsigned", + tbType: "ts timestamp, v bigint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + uint64(1), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint64(1)}, + }, + }, + { + name: "bigint unsigned null", + tbType: "ts timestamp, v bigint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "bigint unsigned null 3 cols", + tbType: "ts timestamp, v bigint unsigned", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + uint64(1), + nil, + uint64(2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, uint64(1)}, + {next1S, nil}, + {next2S, uint64(2)}, + }, + }, + + { + name: "float", + tbType: "ts timestamp, v float", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + float32(1.2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, float32(1.2)}, + }, + }, + { + name: "float null", + tbType: "ts timestamp, v float", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "float null 3 cols", + tbType: "ts timestamp, v float", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + float32(1.2), + nil, + float32(2.2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, float32(1.2)}, + {next1S, nil}, + {next2S, float32(2.2)}, + }, + }, + + { + name: "double", + tbType: "ts timestamp, v double", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + float64(1.2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, float64(1.2)}, + }, + }, + { + name: "double null", + tbType: "ts timestamp, v double", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "double null 3 cols", + tbType: "ts timestamp, v double", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + float64(1.2), + nil, + float64(2.2), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, float64(1.2)}, + {next1S, nil}, + {next2S, float64(2.2)}, + }, + }, + + { + name: "binary", + tbType: "ts timestamp, v binary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + []byte("yes"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + }, + }, + { + name: "binary null", + tbType: "ts timestamp, v binary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "binary null 3 cols", + tbType: "ts timestamp, v binary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + []byte("yes"), + nil, + []byte("中文"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + {next1S, nil}, + {next2S, "中文"}, + }, + }, + + { + name: "varbinary", + tbType: "ts timestamp, v varbinary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + []byte("yes"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte("yes")}, + }, + }, + { + name: "varbinary null", + tbType: "ts timestamp, v varbinary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + + { + name: "varbinary null 3 cols", + tbType: "ts timestamp, v varbinary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + []byte("yes"), + nil, + []byte("中文"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte("yes")}, + {next1S, nil}, + {next2S, []byte("中文")}, + }, + }, + + { + name: "geometry", + tbType: "ts timestamp, v geometry(100)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}, + }, + }, + { + name: "geometry null", + tbType: "ts timestamp, v geometry(100)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "geometry null 3 cols", + tbType: "ts timestamp, v geometry(100)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + nil, + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}, + {next1S, nil}, + {next2S, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}, + }, + }, + + { + name: "nchar", + tbType: "ts timestamp, v nchar(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + []byte("yes"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + }, + }, + { + name: "nchar null", + tbType: "ts timestamp, v nchar(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + nil, + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, nil}, + }, + }, + { + name: "nchar null 3 cols", + tbType: "ts timestamp, v nchar(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + []byte("yes"), + nil, + []byte("中文"), + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + {next1S, nil}, + {next2S, "中文"}, + }, + }, + + { + name: "nchar bind string", + tbType: "ts timestamp, v nchar(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + "yes", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + }, + }, + + { + name: "nchar bind string null 3 cols", + tbType: "ts timestamp, v nchar(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + "yes", + nil, + "中文", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + {next1S, nil}, + {next2S, "中文"}, + }, + }, + + { + name: "binary bind string", + tbType: "ts timestamp, v binary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + "yes", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + }, + }, + + { + name: "binary bind string null 3 cols", + tbType: "ts timestamp, v binary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + "yes", + nil, + "中文", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, "yes"}, + {next1S, nil}, + {next2S, "中文"}, + }, + }, + + { + name: "varbinary bind string", + tbType: "ts timestamp, v varbinary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + }, + { + "yes", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte("yes")}, + }, + }, + + { + name: "varbinary bind string null 3 cols", + tbType: "ts timestamp, v varbinary(20)", + pos: "?, ?", + params: []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + { + now, + next1S, + next2S, + }, + { + "yes", + nil, + "中文", + }, + }, + }}, + + expectValue: [][]driver.Value{ + {now, []byte("yes")}, + {next1S, nil}, + {next2S, []byte("中文")}, + }, + }, + } + for i, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tbType := tc.tbType + tbName := fmt.Sprintf("test_fast_insert_%02d", i) + drop := fmt.Sprintf("drop table if exists %s", tbName) + create := fmt.Sprintf("create table if not exists %s(%s)", tbName, tbType) + pos := tc.pos + sql := fmt.Sprintf("insert into %s values(%s)", tbName, pos) + var err error + if err = exec(conn, drop); err != nil { + t.Error(err) + return + } + if err = exec(conn, create); err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + insertStmt := TaosStmt2Init(conn, 0xcc123, false, false, handler) + code := TaosStmt2Prepare(insertStmt, sql) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code := TaosStmt2IsInsert(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + code, count, cfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmt2FreeFields(insertStmt, cfields) + assert.Equal(t, 2, count) + fields := StmtParseFields(count, cfields) + bs, err := stmt.MarshalStmt2Binary(tc.params, true, fields, nil) + if err != nil { + t.Error("marshal binary error:", err) + return + } + err = TaosStmt2BindBinary(insertStmt, bs, -1) + if !assert.NoError(t, err, bs) { + return + } + //return + code = TaosStmt2Exec(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + t.Log(r.affected) + code = TaosStmt2Close(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + result, err := query(conn, fmt.Sprintf("select * from %s order by ts asc", tbName)) + if err != nil { + t.Error(err) + return + } + assert.Equal(t, tc.expectValue, result) + }) + } + +} + +func TestStmt2AllType(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_all") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_all precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_all") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists all_stb("+ + "ts timestamp, "+ + "v1 bool, "+ + "v2 tinyint, "+ + "v3 smallint, "+ + "v4 int, "+ + "v5 bigint, "+ + "v6 tinyint unsigned, "+ + "v7 smallint unsigned, "+ + "v8 int unsigned, "+ + "v9 bigint unsigned, "+ + "v10 float, "+ + "v11 double, "+ + "v12 binary(20), "+ + "v13 varbinary(20), "+ + "v14 geometry(100), "+ + "v15 nchar(20))"+ + "tags("+ + "tts timestamp, "+ + "tv1 bool, "+ + "tv2 tinyint, "+ + "tv3 smallint, "+ + "tv4 int, "+ + "tv5 bigint, "+ + "tv6 tinyint unsigned, "+ + "tv7 smallint unsigned, "+ + "tv8 int unsigned, "+ + "tv9 bigint unsigned, "+ + "tv10 float, "+ + "tv11 double, "+ + "tv12 binary(20), "+ + "tv13 varbinary(20), "+ + "tv14 geometry(100), "+ + "tv15 nchar(20))") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + insertStmt := TaosStmt2Init(conn, 0xcc123, false, false, handler) + prepareInsertSql := "insert into ? using all_stb tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" + code := TaosStmt2Prepare(insertStmt, prepareInsertSql) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + params := []*stmt.TaosStmt2BindData{{ + TableName: "ctb1", + }} + err = TaosStmt2BindParam(insertStmt, true, params, nil, nil, -1) + if err != nil { + t.Error(err) + return + } + + code, count, cTablefields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TBNAME) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.Equal(t, 1, count) + assert.Equal(t, unsafe.Pointer(nil), cTablefields) + + isInsert, code := TaosStmt2IsInsert(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + code, count, cColFields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmt2FreeFields(insertStmt, cColFields) + assert.Equal(t, 16, count) + colFields := StmtParseFields(count, cColFields) + t.Log(colFields) + code, count, cTagfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TAG) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmt2FreeFields(insertStmt, cTagfields) + assert.Equal(t, 16, count) + tagFields := StmtParseFields(count, cTagfields) + t.Log(tagFields) + now := time.Now() + //colTypes := []int8{ + // common.TSDB_DATA_TYPE_TIMESTAMP, + // common.TSDB_DATA_TYPE_BOOL, + // common.TSDB_DATA_TYPE_TINYINT, + // common.TSDB_DATA_TYPE_SMALLINT, + // common.TSDB_DATA_TYPE_INT, + // common.TSDB_DATA_TYPE_BIGINT, + // common.TSDB_DATA_TYPE_UTINYINT, + // common.TSDB_DATA_TYPE_USMALLINT, + // common.TSDB_DATA_TYPE_UINT, + // common.TSDB_DATA_TYPE_UBIGINT, + // common.TSDB_DATA_TYPE_FLOAT, + // common.TSDB_DATA_TYPE_DOUBLE, + // common.TSDB_DATA_TYPE_BINARY, + // common.TSDB_DATA_TYPE_VARBINARY, + // common.TSDB_DATA_TYPE_GEOMETRY, + // common.TSDB_DATA_TYPE_NCHAR, + //} + params2 := []*stmt.TaosStmt2BindData{{ + TableName: "ctb1", + Tags: []driver.Value{ + // TIMESTAMP + now, + // BOOL + true, + // TINYINT + int8(1), + // SMALLINT + int16(1), + // INT + int32(1), + // BIGINT + int64(1), + // UTINYINT + uint8(1), + // USMALLINT + uint16(1), + // UINT + uint32(1), + // UBIGINT + uint64(1), + // FLOAT + float32(1.2), + // DOUBLE + float64(1.2), + // BINARY + []byte("binary"), + // VARBINARY + []byte("varbinary"), + // GEOMETRY + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + // NCHAR + "nchar", + }, + Cols: [][]driver.Value{ + { + now, + now.Add(time.Second), + now.Add(time.Second * 2), + }, + { + true, + nil, + false, + }, + { + int8(11), + nil, + int8(12), + }, + { + int16(11), + nil, + int16(12), + }, + { + int32(11), + nil, + int32(12), + }, + { + int64(11), + nil, + int64(12), + }, + { + uint8(11), + nil, + uint8(12), + }, + { + uint16(11), + nil, + uint16(12), + }, + { + uint32(11), + nil, + uint32(12), + }, + { + uint64(11), + nil, + uint64(12), + }, + { + float32(11.2), + nil, + float32(12.2), + }, + { + float64(11.2), + nil, + float64(12.2), + }, + { + []byte("binary1"), + nil, + []byte("binary2"), + }, + { + []byte("varbinary1"), + nil, + []byte("varbinary2"), + }, + { + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + nil, + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + { + "nchar1", + nil, + "nchar2", + }, + }, + }} + + err = TaosStmt2BindParam(insertStmt, true, params2, colFields, tagFields, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + t.Log(r.affected) + + code = TaosStmt2Close(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } +} + +func TestStmt2AllTypeBytes(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_all") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_all_bytes precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_all_bytes") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists all_stb("+ + "ts timestamp, "+ + "v1 bool, "+ + "v2 tinyint, "+ + "v3 smallint, "+ + "v4 int, "+ + "v5 bigint, "+ + "v6 tinyint unsigned, "+ + "v7 smallint unsigned, "+ + "v8 int unsigned, "+ + "v9 bigint unsigned, "+ + "v10 float, "+ + "v11 double, "+ + "v12 binary(20), "+ + "v13 varbinary(20), "+ + "v14 geometry(100), "+ + "v15 nchar(20))"+ + "tags("+ + "tts timestamp, "+ + "tv1 bool, "+ + "tv2 tinyint, "+ + "tv3 smallint, "+ + "tv4 int, "+ + "tv5 bigint, "+ + "tv6 tinyint unsigned, "+ + "tv7 smallint unsigned, "+ + "tv8 int unsigned, "+ + "tv9 bigint unsigned, "+ + "tv10 float, "+ + "tv11 double, "+ + "tv12 binary(20), "+ + "tv13 varbinary(20), "+ + "tv14 geometry(100), "+ + "tv15 nchar(20))") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + insertStmt := TaosStmt2Init(conn, 0xcc123, false, false, handler) + prepareInsertSql := "insert into ? using all_stb tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" + code := TaosStmt2Prepare(insertStmt, prepareInsertSql) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + params := []*stmt.TaosStmt2BindData{{ + TableName: "ctb1", + }} + bs, err := stmt.MarshalStmt2Binary(params, true, nil, nil) + if err != nil { + t.Error(err) + return + } + err = TaosStmt2BindBinary(insertStmt, bs, -1) + if err != nil { + t.Error(err) + return + } + + code, count, cTablefields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TBNAME) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.Equal(t, 1, count) + assert.Equal(t, unsafe.Pointer(nil), cTablefields) + + isInsert, code := TaosStmt2IsInsert(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + code, count, cColFields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmt2FreeFields(insertStmt, cColFields) + assert.Equal(t, 16, count) + colFields := StmtParseFields(count, cColFields) + t.Log(colFields) + code, count, cTagfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TAG) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmt2FreeFields(insertStmt, cTagfields) + assert.Equal(t, 16, count) + tagFields := StmtParseFields(count, cTagfields) + t.Log(tagFields) + now := time.Now() + //colTypes := []int8{ + // common.TSDB_DATA_TYPE_TIMESTAMP, + // common.TSDB_DATA_TYPE_BOOL, + // common.TSDB_DATA_TYPE_TINYINT, + // common.TSDB_DATA_TYPE_SMALLINT, + // common.TSDB_DATA_TYPE_INT, + // common.TSDB_DATA_TYPE_BIGINT, + // common.TSDB_DATA_TYPE_UTINYINT, + // common.TSDB_DATA_TYPE_USMALLINT, + // common.TSDB_DATA_TYPE_UINT, + // common.TSDB_DATA_TYPE_UBIGINT, + // common.TSDB_DATA_TYPE_FLOAT, + // common.TSDB_DATA_TYPE_DOUBLE, + // common.TSDB_DATA_TYPE_BINARY, + // common.TSDB_DATA_TYPE_VARBINARY, + // common.TSDB_DATA_TYPE_GEOMETRY, + // common.TSDB_DATA_TYPE_NCHAR, + //} + params2 := []*stmt.TaosStmt2BindData{{ + TableName: "ctb1", + Tags: []driver.Value{ + // TIMESTAMP + now, + // BOOL + true, + // TINYINT + int8(1), + // SMALLINT + int16(1), + // INT + int32(1), + // BIGINT + int64(1), + // UTINYINT + uint8(1), + // USMALLINT + uint16(1), + // UINT + uint32(1), + // UBIGINT + uint64(1), + // FLOAT + float32(1.2), + // DOUBLE + float64(1.2), + // BINARY + []byte("binary"), + // VARBINARY + []byte("varbinary"), + // GEOMETRY + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + // NCHAR + "nchar", + }, + Cols: [][]driver.Value{ + { + now, + now.Add(time.Second), + now.Add(time.Second * 2), + }, + { + true, + nil, + false, + }, + { + int8(11), + nil, + int8(12), + }, + { + int16(11), + nil, + int16(12), + }, + { + int32(11), + nil, + int32(12), + }, + { + int64(11), + nil, + int64(12), + }, + { + uint8(11), + nil, + uint8(12), + }, + { + uint16(11), + nil, + uint16(12), + }, + { + uint32(11), + nil, + uint32(12), + }, + { + uint64(11), + nil, + uint64(12), + }, + { + float32(11.2), + nil, + float32(12.2), + }, + { + float64(11.2), + nil, + float64(12.2), + }, + { + []byte("binary1"), + nil, + []byte("binary2"), + }, + { + []byte("varbinary1"), + nil, + []byte("varbinary2"), + }, + { + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + nil, + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + { + "nchar1", + nil, + "nchar2", + }, + }, + }} + bs, err = stmt.MarshalStmt2Binary(params2, true, colFields, tagFields) + if err != nil { + t.Error(err) + return + } + err = TaosStmt2BindBinary(insertStmt, bs, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + t.Log(r.affected) + + code = TaosStmt2Close(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } +} + +func TestStmt2Query(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_query") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_query precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_query") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists t(ts timestamp,v int)") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xcc123, false, false, handler) + prepareInsertSql := "insert into t values (?,?)" + code := TaosStmt2Prepare(stmt2, prepareInsertSql) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code := TaosStmt2IsInsert(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + now := time.Now().Round(time.Millisecond) + colTypes := []*stmt.StmtField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionMilliSecond, + }, + { + FieldType: common.TSDB_DATA_TYPE_INT, + }, + } + params := []*stmt.TaosStmt2BindData{ + { + TableName: "t", + Cols: [][]driver.Value{ + { + now, + now.Add(time.Second), + }, + { + int32(1), + int32(2), + }, + }, + }, + { + TableName: "t", + Cols: [][]driver.Value{ + { + now.Add(time.Second * 2), + now.Add(time.Second * 3), + }, + { + int32(3), + int32(4), + }, + }, + }, + } + err = TaosStmt2BindParam(stmt2, true, params, colTypes, nil, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + assert.Equal(t, 4, r.affected) + code = TaosStmt2Prepare(stmt2, "select * from t where ts >= ? and ts <= ?") + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code = TaosStmt2IsInsert(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.False(t, isInsert) + params = []*stmt.TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + { + now, + }, + { + now.Add(time.Second * 3), + }, + }, + }, + } + + err = TaosStmt2BindParam(stmt2, false, params, nil, nil, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r = <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + res := r.res + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := TaosFetchRawBlock(res) + if errCode != 0 { + errStr := TaosErrorStr(res) + err = taosError.NewError(errCode, errStr) + t.Error(err) + return + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + assert.Equal(t, 4, len(result)) + assert.Equal(t, now, result[0][0]) + assert.Equal(t, now.Add(time.Second), result[1][0]) + assert.Equal(t, now.Add(time.Second*2), result[2][0]) + assert.Equal(t, now.Add(time.Second*3), result[3][0]) + assert.Equal(t, int32(1), result[0][1]) + assert.Equal(t, int32(2), result[1][1]) + assert.Equal(t, int32(3), result[2][1]) + assert.Equal(t, int32(4), result[3][1]) + code = TaosStmt2Close(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } +} + +func TestStmt2QueryBytes(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_query_bytes") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_query_bytes precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_query_bytes") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists t(ts timestamp,v int)") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xcc123, false, false, handler) + prepareInsertSql := "insert into t values (?,?)" + code := TaosStmt2Prepare(stmt2, prepareInsertSql) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code := TaosStmt2IsInsert(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + now := time.Now().Round(time.Millisecond) + colTypes := []*stmt.StmtField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionMilliSecond, + }, + { + FieldType: common.TSDB_DATA_TYPE_INT, + }, + } + params := []*stmt.TaosStmt2BindData{ + { + TableName: "t", + Cols: [][]driver.Value{ + { + now, + now.Add(time.Second), + }, + { + int32(1), + int32(2), + }, + }, + }, + { + TableName: "t", + Cols: [][]driver.Value{ + { + now.Add(time.Second * 2), + now.Add(time.Second * 3), + }, + { + int32(3), + int32(4), + }, + }, + }, + } + bs, err := stmt.MarshalStmt2Binary(params, true, colTypes, nil) + if err != nil { + t.Error(err) + return + } + err = TaosStmt2BindBinary(stmt2, bs, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + assert.Equal(t, 4, r.affected) + code = TaosStmt2Prepare(stmt2, "select * from t where ts >= ? and ts <= ?") + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code = TaosStmt2IsInsert(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.False(t, isInsert) + params = []*stmt.TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + { + now, + }, + { + now.Add(time.Second * 3), + }, + }, + }, + } + bs, err = stmt.MarshalStmt2Binary(params, false, nil, nil) + if err != nil { + t.Error(err) + return + } + err = TaosStmt2BindBinary(stmt2, bs, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r = <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + res := r.res + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := TaosFetchRawBlock(res) + if errCode != 0 { + errStr := TaosErrorStr(res) + err = taosError.NewError(errCode, errStr) + t.Error(err) + return + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + assert.Equal(t, 4, len(result)) + assert.Equal(t, now, result[0][0]) + assert.Equal(t, now.Add(time.Second), result[1][0]) + assert.Equal(t, now.Add(time.Second*2), result[2][0]) + assert.Equal(t, now.Add(time.Second*3), result[3][0]) + assert.Equal(t, int32(1), result[0][1]) + assert.Equal(t, int32(2), result[1][1]) + assert.Equal(t, int32(3), result[2][1]) + assert.Equal(t, int32(4), result[3][1]) + code = TaosStmt2Close(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } +} + +func TestStmt2QueryAllType(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_query_all") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_query_all precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_query_all") + if err != nil { + t.Error(err) + return + } + + err = exec(conn, "create table if not exists t("+ + "ts timestamp, "+ + "v1 bool, "+ + "v2 tinyint, "+ + "v3 smallint, "+ + "v4 int, "+ + "v5 bigint, "+ + "v6 tinyint unsigned, "+ + "v7 smallint unsigned, "+ + "v8 int unsigned, "+ + "v9 bigint unsigned, "+ + "v10 float, "+ + "v11 double, "+ + "v12 binary(20), "+ + "v13 varbinary(20), "+ + "v14 geometry(100), "+ + "v15 nchar(20))") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xcc123, false, false, handler) + prepareInsertSql := "insert into t values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" + colTypes := []*stmt.StmtField{ + {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond}, + {FieldType: common.TSDB_DATA_TYPE_BOOL}, + {FieldType: common.TSDB_DATA_TYPE_TINYINT}, + {FieldType: common.TSDB_DATA_TYPE_SMALLINT}, + {FieldType: common.TSDB_DATA_TYPE_INT}, + {FieldType: common.TSDB_DATA_TYPE_BIGINT}, + {FieldType: common.TSDB_DATA_TYPE_UTINYINT}, + {FieldType: common.TSDB_DATA_TYPE_USMALLINT}, + {FieldType: common.TSDB_DATA_TYPE_UINT}, + {FieldType: common.TSDB_DATA_TYPE_UBIGINT}, + {FieldType: common.TSDB_DATA_TYPE_FLOAT}, + {FieldType: common.TSDB_DATA_TYPE_DOUBLE}, + {FieldType: common.TSDB_DATA_TYPE_BINARY}, + {FieldType: common.TSDB_DATA_TYPE_VARBINARY}, + {FieldType: common.TSDB_DATA_TYPE_GEOMETRY}, + {FieldType: common.TSDB_DATA_TYPE_NCHAR}, + } + + now := time.Now() + params2 := []*stmt.TaosStmt2BindData{{ + TableName: "t", + Cols: [][]driver.Value{ + { + now, + now.Add(time.Second), + now.Add(time.Second * 2), + }, + { + true, + nil, + false, + }, + { + int8(11), + nil, + int8(12), + }, + { + int16(11), + nil, + int16(12), + }, + { + int32(11), + nil, + int32(12), + }, + { + int64(11), + nil, + int64(12), + }, + { + uint8(11), + nil, + uint8(12), + }, + { + uint16(11), + nil, + uint16(12), + }, + { + uint32(11), + nil, + uint32(12), + }, + { + uint64(11), + nil, + uint64(12), + }, + { + float32(11.2), + nil, + float32(12.2), + }, + { + float64(11.2), + nil, + float64(12.2), + }, + { + []byte("binary1"), + nil, + []byte("binary2"), + }, + { + []byte("varbinary1"), + nil, + []byte("varbinary2"), + }, + { + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + nil, + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + { + "nchar1", + nil, + "nchar2", + }, + }, + }} + code := TaosStmt2Prepare(stmt2, prepareInsertSql) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code := TaosStmt2IsInsert(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + err = TaosStmt2BindParam(stmt2, true, params2, colTypes, nil, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + t.Log(r.affected) + assert.Equal(t, 3, r.affected) + code = TaosStmt2Prepare(stmt2, "select * from t where ts =? and v1 = ? and v2 = ? and v3 = ? and v4 = ? and v5 = ? and v6 = ? and v7 = ? and v8 = ? and v9 = ? and v10 = ? and v11 = ? and v12 = ? and v13 = ? and v14 = ? and v15 = ? ") + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code = TaosStmt2IsInsert(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.False(t, isInsert) + params := []*stmt.TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + {now}, + {true}, + {int8(11)}, + {int16(11)}, + {int32(11)}, + {int64(11)}, + {uint8(11)}, + {uint16(11)}, + {uint32(11)}, + {uint64(11)}, + {float32(11.2)}, + {float64(11.2)}, + {[]byte("binary1")}, + {[]byte("varbinary1")}, + {"point(100 100)"}, + {"nchar1"}, + }, + }, + } + err = TaosStmt2BindParam(stmt2, false, params, nil, nil, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r = <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + res := r.res + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := TaosFetchRawBlock(res) + if errCode != 0 { + errStr := TaosErrorStr(res) + err = taosError.NewError(errCode, errStr) + t.Error(err) + return + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + t.Log(result) + assert.Len(t, result, 1) +} + +func TestStmt2QueryAllTypeBytes(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_query_all_bytes") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_query_all_bytes precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_query_all_bytes") + if err != nil { + t.Error(err) + return + } + + err = exec(conn, "create table if not exists t("+ + "ts timestamp, "+ + "v1 bool, "+ + "v2 tinyint, "+ + "v3 smallint, "+ + "v4 int, "+ + "v5 bigint, "+ + "v6 tinyint unsigned, "+ + "v7 smallint unsigned, "+ + "v8 int unsigned, "+ + "v9 bigint unsigned, "+ + "v10 float, "+ + "v11 double, "+ + "v12 binary(20), "+ + "v13 varbinary(20), "+ + "v14 geometry(100), "+ + "v15 nchar(20))") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xcc123, false, false, handler) + prepareInsertSql := "insert into t values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" + colTypes := []*stmt.StmtField{ + {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond}, + {FieldType: common.TSDB_DATA_TYPE_BOOL}, + {FieldType: common.TSDB_DATA_TYPE_TINYINT}, + {FieldType: common.TSDB_DATA_TYPE_SMALLINT}, + {FieldType: common.TSDB_DATA_TYPE_INT}, + {FieldType: common.TSDB_DATA_TYPE_BIGINT}, + {FieldType: common.TSDB_DATA_TYPE_UTINYINT}, + {FieldType: common.TSDB_DATA_TYPE_USMALLINT}, + {FieldType: common.TSDB_DATA_TYPE_UINT}, + {FieldType: common.TSDB_DATA_TYPE_UBIGINT}, + {FieldType: common.TSDB_DATA_TYPE_FLOAT}, + {FieldType: common.TSDB_DATA_TYPE_DOUBLE}, + {FieldType: common.TSDB_DATA_TYPE_BINARY}, + {FieldType: common.TSDB_DATA_TYPE_VARBINARY}, + {FieldType: common.TSDB_DATA_TYPE_GEOMETRY}, + {FieldType: common.TSDB_DATA_TYPE_NCHAR}, + } + + now := time.Now() + params2 := []*stmt.TaosStmt2BindData{{ + TableName: "t", + Cols: [][]driver.Value{ + { + now, + now.Add(time.Second), + now.Add(time.Second * 2), + }, + { + true, + nil, + false, + }, + { + int8(11), + nil, + int8(12), + }, + { + int16(11), + nil, + int16(12), + }, + { + int32(11), + nil, + int32(12), + }, + { + int64(11), + nil, + int64(12), + }, + { + uint8(11), + nil, + uint8(12), + }, + { + uint16(11), + nil, + uint16(12), + }, + { + uint32(11), + nil, + uint32(12), + }, + { + uint64(11), + nil, + uint64(12), + }, + { + float32(11.2), + nil, + float32(12.2), + }, + { + float64(11.2), + nil, + float64(12.2), + }, + { + []byte("binary1"), + nil, + []byte("binary2"), + }, + { + []byte("varbinary1"), + nil, + []byte("varbinary2"), + }, + { + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + nil, + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + { + "nchar1", + nil, + "nchar2", + }, + }, + }} + code := TaosStmt2Prepare(stmt2, prepareInsertSql) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code := TaosStmt2IsInsert(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + bs, err := stmt.MarshalStmt2Binary(params2, true, colTypes, nil) + if err != nil { + t.Error(err) + return + } + err = TaosStmt2BindBinary(stmt2, bs, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + t.Log(r.affected) + assert.Equal(t, 3, r.affected) + code = TaosStmt2Prepare(stmt2, "select * from t where ts =? and v1 = ? and v2 = ? and v3 = ? and v4 = ? and v5 = ? and v6 = ? and v7 = ? and v8 = ? and v9 = ? and v10 = ? and v11 = ? and v12 = ? and v13 = ? and v14 = ? and v15 = ? ") + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code = TaosStmt2IsInsert(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.False(t, isInsert) + params := []*stmt.TaosStmt2BindData{ + { + Cols: [][]driver.Value{ + {now}, + {true}, + {int8(11)}, + {int16(11)}, + {int32(11)}, + {int64(11)}, + {uint8(11)}, + {uint16(11)}, + {uint32(11)}, + {uint64(11)}, + {float32(11.2)}, + {float64(11.2)}, + {[]byte("binary1")}, + {[]byte("varbinary1")}, + {"point(100 100)"}, + {"nchar1"}, + }, + }, + } + bs, err = stmt.MarshalStmt2Binary(params, false, nil, nil) + if err != nil { + t.Error(err) + return + } + err = TaosStmt2BindBinary(stmt2, bs, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r = <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + res := r.res + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := TaosFetchRawBlock(res) + if errCode != 0 { + errStr := TaosErrorStr(res) + err = taosError.NewError(errCode, errStr) + t.Error(err) + return + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + t.Log(result) + assert.Len(t, result, 1) +} + +func TestStmt2Json(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_json") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_json precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_json") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists test_json_stb(ts timestamp, v int) tags (t json)") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xcc123, false, false, handler) + defer func() { + code := TaosStmt2Close(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + }() + prepareInsertSql := "insert into ? using test_json_stb tags(?) values (?,?)" + code := TaosStmt2Prepare(stmt2, prepareInsertSql) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + now := time.Now().Round(time.Millisecond) + params := []*stmt.TaosStmt2BindData{{ + TableName: "ctb1", + Tags: []driver.Value{[]byte(`{"a":1,"b":"xx"}`)}, + Cols: [][]driver.Value{ + {now}, + {int32(1)}, + }, + }} + colTypes := []*stmt.StmtField{ + {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond}, + {FieldType: common.TSDB_DATA_TYPE_INT}, + } + tagTypes := []*stmt.StmtField{ + {FieldType: common.TSDB_DATA_TYPE_JSON}, + } + err = TaosStmt2BindParam(stmt2, true, params, colTypes, tagTypes, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + assert.Equal(t, 1, r.affected) + + TaosStmt2Prepare(stmt2, "select * from test_json_stb where t->'a' = ?") + params = []*stmt.TaosStmt2BindData{{ + Cols: [][]driver.Value{ + {int32(1)}, + }, + }} + err = TaosStmt2BindParam(stmt2, false, params, nil, nil, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r = <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(stmt2) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + res := r.res + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := TaosFetchRawBlock(res) + if errCode != 0 { + errStr := TaosErrorStr(res) + err = taosError.NewError(errCode, errStr) + t.Error(err) + return + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + t.Log(result) + assert.Equal(t, 1, len(result)) +} + +func TestStmt2BindMultiTables(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_multi") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_multi precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_multi") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists stb(ts timestamp, v bigint) tags(tv int)") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + insertStmt := TaosStmt2Init(conn, 0xcc123, false, false, handler) + prepareInsertSql := "insert into ? using stb tags(?) values (?,?)" + code := TaosStmt2Prepare(insertStmt, prepareInsertSql) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + binds := []*stmt.TaosStmt2BindData{ + { + TableName: "table1", + Cols: [][]driver.Value{ + { + // ts 1726803356466 + time.Unix(1726803356, 466000000), + }, + { + int64(1), + }, + }, + Tags: []driver.Value{int32(1)}, + }, + { + TableName: "table2", + Cols: [][]driver.Value{ + { + // ts 1726803356466 + time.Unix(1726803356, 466000000), + }, + { + int64(2), + }, + }, + Tags: []driver.Value{int32(2)}, + }, + { + TableName: "table3", + Cols: [][]driver.Value{ + { + // ts 1726803356466 + time.Unix(1726803356, 466000000), + }, + { + int64(3), + }, + }, + Tags: []driver.Value{int32(3)}, + }, + } + colType := []*stmt.StmtField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionMilliSecond, + }, + { + FieldType: common.TSDB_DATA_TYPE_BIGINT, + }, + } + tagType := []*stmt.StmtField{ + { + FieldType: common.TSDB_DATA_TYPE_INT, + }, + } + + isInsert, code := TaosStmt2IsInsert(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + + err = TaosStmt2BindParam(insertStmt, true, binds, colType, tagType, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + t.Log(r.affected) + + code = TaosStmt2Close(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } +} + +func TestTaosStmt2BindBinaryParse(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_binary_parse") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_binary_parse precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_binary_parse") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table test1 (ts timestamp, v int)") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table stb (ts timestamp, v int) tags(tv int)") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table test2 (ts timestamp, v binary(100))") + if err != nil { + t.Error(err) + return + } + type args struct { + sql string + data []byte + colIdx int32 + } + tests := []struct { + name string + args args + wantErr assert.ErrorAssertionFunc + }{ + { + name: "normal table name", + args: args{ + sql: "insert into ? values (?,?)", + data: []byte{ + // total Length + 0x24, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + // TableNameLength + 0x06, 0x00, + // test1 + 0x74, 0x65, 0x73, 0x74, 0x31, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.NoError, + }, + { + name: "empty table name", + args: args{ + sql: "insert into ? values (?,?)", + data: []byte{ + // total Length + 0x1e, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + // TableNameLength + 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong total length", + args: args{ + sql: "insert into ? values (?,?)", + data: []byte{ + // total Length + 0x24, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + // TableNameLength + 0x06, 0x00, + // test1 + 0x74, 0x65, 0x73, 0x74, 0x31, 0x00, + // + 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong table name offset", + args: args{ + sql: "insert into ? values (?,?)", + data: []byte{ + // total Length + 0x24, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x24, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + // TableNameLength + 0x06, 0x00, + // test1 + 0x74, 0x65, 0x73, 0x74, 0x31, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong table name length", + args: args{ + sql: "insert into ? values (?,?)", + data: []byte{ + // total Length + 0x24, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + // TableNameLength + 0x07, 0x00, + // test1 + 0x74, 0x65, 0x73, 0x74, 0x31, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "normal col", + args: args{ + sql: "insert into test1 values (?,?)", + data: []byte{ + // total Length + 0x50, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x02, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + 0x30, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xba, 0x08, 0x32, 0x27, 0x92, 0x01, 0x00, 0x00, + + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x7b, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.NoError, + }, + { + name: "col zero length", + args: args{ + sql: "insert into test1 values (?,?)", + data: []byte{ + // total Length + 0x50, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x02, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + 0x00, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xba, 0x08, 0x32, 0x27, 0x92, 0x01, 0x00, 0x00, + + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x7b, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong col offset", + args: args{ + sql: "insert into test1 values (?,?)", + data: []byte{ + // total Length + 0x50, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x02, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x50, 0x00, 0x00, 0x00, + // cols + 0x30, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xba, 0x08, 0x32, 0x27, 0x92, 0x01, 0x00, 0x00, + + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x7b, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong col length", + args: args{ + sql: "insert into test1 values (?,?)", + data: []byte{ + // total Length + 0x50, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x02, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + 0x50, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xba, 0x08, 0x32, 0x27, 0x92, 0x01, 0x00, 0x00, + + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x7b, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong col bind length", + args: args{ + sql: "insert into test1 values (?,?)", + data: []byte{ + // total Length + 0x50, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x02, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + 0x30, 0x00, 0x00, 0x00, + + 0x1b, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xba, 0x08, 0x32, 0x27, 0x92, 0x01, 0x00, 0x00, + + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x7b, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "normal col count", + args: args{ + sql: "insert into test1 values (?,?)", + data: []byte{ + // total Length + 0x50, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + 0x30, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xba, 0x08, 0x32, 0x27, 0x92, 0x01, 0x00, 0x00, + + 0x16, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x7b, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "normal tag", + args: args{ + sql: "insert into ? using stb tags(?) values (?,?)", + data: []byte{ + // total Length + 0x40, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x01, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x22, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + 0x04, 0x00, 0x63, 0x74, 0x62, 0x00, + // tags + 0x1a, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xc8, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.NoError, + }, + { + name: "tag zero length", + args: args{ + sql: "insert into ? using stb tags(?) values (?,?)", + data: []byte{ + // total Length + 0x40, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x01, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x22, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + 0x04, 0x00, 0x63, 0x74, 0x62, 0x00, + // tags + 0x00, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xc8, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong tag offset", + args: args{ + sql: "insert into ? using stb tags(?) values (?,?)", + data: []byte{ + // total Length + 0x40, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x01, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x40, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + 0x04, 0x00, 0x63, 0x74, 0x62, 0x00, + // tags + 0x1a, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xc8, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong tag length", + args: args{ + sql: "insert into ? using stb tags(?) values (?,?)", + data: []byte{ + // total Length + 0x40, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x01, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x22, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + 0x04, 0x00, 0x63, 0x74, 0x62, 0x00, + // tags + 0x40, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xc8, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong tag bind length", + args: args{ + sql: "insert into ? using stb tags(?) values (?,?)", + data: []byte{ + // total Length + 0x40, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x01, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x22, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + 0x04, 0x00, 0x63, 0x74, 0x62, 0x00, + // tags + 0x1a, 0x00, 0x00, 0x00, + + 0x40, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xc8, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong tag count", + args: args{ + sql: "insert into ? using stb tags(?) values (?,?)", + data: []byte{ + // total Length + 0x40, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x00, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x1c, 0x00, 0x00, 0x00, + // TagsOffset + 0x22, 0x00, 0x00, 0x00, + // ColOffset + 0x00, 0x00, 0x00, 0x00, + // table names + 0x04, 0x00, 0x63, 0x74, 0x62, 0x00, + // tags + 0x1a, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xc8, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "wrong param count", + args: args{ + sql: "insert into test1 values (?,?)", + data: []byte{ + // total Length + 0x3A, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x01, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + 0x1a, 0x00, 0x00, 0x00, + + 0x1a, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x08, 0x00, 0x00, 0x00, + 0xba, 0x08, 0x32, 0x27, 0x92, 0x01, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, + { + name: "bind binary", + args: args{ + sql: "insert into test2 values (?,?)", + data: []byte{ + // total Length + 0x78, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x02, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + // col length + 0x58, 0x00, 0x00, 0x00, + //table 0 cols + //col 0 + //total length + 0x2c, 0x00, 0x00, 0x00, + //type + 0x09, 0x00, 0x00, 0x00, + //num + 0x03, 0x00, 0x00, 0x00, + //is null + 0x00, + 0x00, + 0x00, + // haveLength + 0x00, + // buffer length + 0x18, 0x00, 0x00, 0x00, + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, 0x1a, 0x2f, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, 0x02, 0x33, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + + //col 1 + //total length + 0x2c, 0x00, 0x00, 0x00, + //type + 0x08, 0x00, 0x00, 0x00, + //num + 0x03, 0x00, 0x00, 0x00, + //is null + 0x00, + 0x01, + 0x00, + // haveLength + 0x01, + // length + 0x06, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + // buffer length + 0x0c, 0x00, 0x00, 0x00, + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + }, + colIdx: -1, + }, + wantErr: assert.NoError, + }, + { + name: "empty buffer", + args: args{ + sql: "insert into test2 values (?,?)", + data: []byte{ + // total Length + 0x4c, 0x00, 0x00, 0x00, + // tableCount + 0x01, 0x00, 0x00, 0x00, + // TagCount + 0x00, 0x00, 0x00, 0x00, + // ColCount + 0x02, 0x00, 0x00, 0x00, + // TableNamesOffset + 0x00, 0x00, 0x00, 0x00, + // TagsOffset + 0x00, 0x00, 0x00, 0x00, + // ColOffset + 0x1c, 0x00, 0x00, 0x00, + // cols + // col length + 0x2c, 0x00, 0x00, 0x00, + //table 0 cols + //col 0 + //total length + 0x1a, 0x00, 0x00, 0x00, + //type + 0x09, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x00, + // haveLength + 0x00, + // buffer length + 0x08, 0x00, 0x00, 0x00, + 0x32, 0x2b, 0x80, 0x0d, 0x92, 0x01, 0x00, 0x00, + + //col 1 + //total length + 0x12, 0x00, 0x00, 0x00, + //type + 0x04, 0x00, 0x00, 0x00, + //num + 0x01, 0x00, 0x00, 0x00, + //is null + 0x01, + // haveLength + 0x00, + // buffer length + 0x00, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xdd123, false, false, handler) + defer TaosStmt2Close(stmt2) + code := TaosStmt2Prepare(stmt2, tt.args.sql) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err := taosError.NewError(code, errStr) + t.Error(err) + return + } + tt.wantErr(t, TaosStmt2BindBinary(stmt2, tt.args.data, tt.args.colIdx), fmt.Sprintf("TaosStmt2BindBinary(%v, %v, %v)", stmt2, tt.args.data, tt.args.colIdx)) + }) + } +} diff --git a/driver/wrapper/stmt2async.go b/driver/wrapper/stmt2async.go new file mode 100644 index 00000000..cc3babb6 --- /dev/null +++ b/driver/wrapper/stmt2async.go @@ -0,0 +1,26 @@ +package wrapper + +/* +#include +#include +#include +#include + +*/ +import "C" +import ( + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +type TaosStmt2CallbackCaller interface { + ExecCall(res unsafe.Pointer, affected int, code int) +} + +//export Stmt2ExecCallback +func Stmt2ExecCallback(p unsafe.Pointer, res *C.TAOS_RES, code C.int) { + caller := (*(*cgo.Handle)(p)).Value().(TaosStmt2CallbackCaller) + affectedRows := int(C.taos_affected_rows(unsafe.Pointer(res))) + caller.ExecCall(unsafe.Pointer(res), affectedRows, int(code)) +} diff --git a/driver/wrapper/stmt_test.go b/driver/wrapper/stmt_test.go new file mode 100644 index 00000000..9d63b7a9 --- /dev/null +++ b/driver/wrapper/stmt_test.go @@ -0,0 +1,1367 @@ +package wrapper + +import ( + "database/sql/driver" + "fmt" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/param" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + stmtCommon "github.com/taosdata/taosadapter/v3/driver/common/stmt" + taosError "github.com/taosdata/taosadapter/v3/driver/errors" + taosTypes "github.com/taosdata/taosadapter/v3/driver/types" +) + +// @author: xftan +// @date: 2022/1/27 17:27 +// @description: test stmt with taos_stmt_bind_param_batch +func TestStmt(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_wrapper") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_wrapper precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_wrapper") + if err != nil { + t.Error(err) + return + } + now := time.Now() + for i, tc := range []struct { + tbType string + pos string + params [][]driver.Value + bindType []*taosTypes.ColumnType + expectValue interface{} + }{ + { + tbType: "ts timestamp, v int", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosInt(1)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosIntType}}, + expectValue: int32(1), + }, + { + tbType: "ts timestamp, v bool", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosBool(true)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosBoolType}}, + expectValue: true, + }, + { + tbType: "ts timestamp, v tinyint", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosTinyint(1)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosTinyintType}}, + expectValue: int8(1), + }, + { + tbType: "ts timestamp, v smallint", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosSmallint(1)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosSmallintType}}, + expectValue: int16(1), + }, + { + tbType: "ts timestamp, v bigint", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosBigint(1)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosBigintType}}, + expectValue: int64(1), + }, + { + tbType: "ts timestamp, v tinyint unsigned", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosUTinyint(1)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosUTinyintType}}, + expectValue: uint8(1), + }, + { + tbType: "ts timestamp, v smallint unsigned", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosUSmallint(1)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosUSmallintType}}, + expectValue: uint16(1), + }, + { + tbType: "ts timestamp, v int unsigned", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosUInt(1)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosUIntType}}, + expectValue: uint32(1), + }, + { + tbType: "ts timestamp, v bigint unsigned", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosUBigint(1)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosUBigintType}}, + expectValue: uint64(1), + }, + { + tbType: "ts timestamp, v float", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosFloat(1.2)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosFloatType}}, + expectValue: float32(1.2), + }, + { + tbType: "ts timestamp, v double", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosDouble(1.2)}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, {Type: taosTypes.TaosDoubleType}}, + expectValue: 1.2, + }, + { + tbType: "ts timestamp, v binary(8)", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosBinary("yes")}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, { + Type: taosTypes.TaosBinaryType, + MaxLen: 3, + }}, + expectValue: "yes", + }, //3 + { + tbType: "ts timestamp, v varbinary(8)", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosVarBinary("yes")}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, { + Type: taosTypes.TaosVarBinaryType, + MaxLen: 3, + }}, + expectValue: []byte("yes"), + }, //3 + { + tbType: "ts timestamp, v geometry(100)", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosGeometry{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, { + Type: taosTypes.TaosGeometryType, + MaxLen: 100, + }}, + expectValue: []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, //3 + { + tbType: "ts timestamp, v nchar(8)", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosNchar("yes")}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, { + Type: taosTypes.TaosNcharType, + MaxLen: 3, + }}, + expectValue: "yes", + }, //3 + { + tbType: "ts timestamp, v nchar(8)", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {nil}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, { + Type: taosTypes.TaosNcharType, + MaxLen: 1, + }}, + expectValue: nil, + }, //1 + } { + tbName := fmt.Sprintf("test_fast_insert_%02d", i) + tbType := tc.tbType + drop := fmt.Sprintf("drop table if exists %s", tbName) + create := fmt.Sprintf("create table if not exists %s(%s)", tbName, tbType) + name := fmt.Sprintf("%02d-%s", i, tbType) + pos := tc.pos + sql := fmt.Sprintf("insert into %s values(%s)", tbName, pos) + var err error + t.Run(name, func(t *testing.T) { + if err = exec(conn, drop); err != nil { + t.Error(err) + return + } + if err = exec(conn, create); err != nil { + t.Error(err) + return + } + insertStmt := TaosStmtInit(conn) + code := TaosStmtPrepare(insertStmt, sql) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + isInsert, code := TaosStmtIsInsert(insertStmt) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + if !isInsert { + t.Errorf("expect insert stmt") + return + } + code = TaosStmtBindParamBatch(insertStmt, tc.params, tc.bindType) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code = TaosStmtAddBatch(insertStmt) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code = TaosStmtExecute(insertStmt) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code = TaosStmtClose(insertStmt) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + result, err := query(conn, fmt.Sprintf("select v from %s", tbName)) + if err != nil { + t.Error(err) + return + } + if len(result) != 1 { + t.Errorf("expect %d got %d", 1, len(result)) + return + } + assert.Equal(t, tc.expectValue, result[0][0]) + }) + } + +} + +// @author: xftan +// @date: 2022/1/27 17:27 +// @description: test stmt insert with taos_stmt_bind_param +func TestStmtExec(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_wrapper") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_wrapper precision 'us' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_wrapper") + if err != nil { + t.Error(err) + return + } + now := time.Now() + for i, tc := range []struct { + tbType string + pos string + params []driver.Value + expectValue interface{} + }{ + { + tbType: "ts timestamp, v int", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosInt(1)}, + expectValue: int32(1), + }, + { + tbType: "ts timestamp, v bool", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosBool(true)}, + expectValue: true, + }, + { + tbType: "ts timestamp, v tinyint", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosTinyint(1)}, + expectValue: int8(1), + }, + { + tbType: "ts timestamp, v smallint", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosSmallint(1)}, + expectValue: int16(1), + }, + { + tbType: "ts timestamp, v bigint", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosBigint(1)}, + expectValue: int64(1), + }, + { + tbType: "ts timestamp, v tinyint unsigned", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosUTinyint(1)}, + expectValue: uint8(1), + }, + { + tbType: "ts timestamp, v smallint unsigned", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosUSmallint(1)}, + expectValue: uint16(1), + }, + { + tbType: "ts timestamp, v int unsigned", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosUInt(1)}, + expectValue: uint32(1), + }, + { + tbType: "ts timestamp, v bigint unsigned", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosUBigint(1)}, + expectValue: uint64(1), + }, + { + tbType: "ts timestamp, v float", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosFloat(1.2)}, + expectValue: float32(1.2), + }, + { + tbType: "ts timestamp, v double", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosDouble(1.2)}, + expectValue: 1.2, + }, + { + tbType: "ts timestamp, v binary(8)", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosBinary("yes")}, + expectValue: "yes", + }, //3 + { + tbType: "ts timestamp, v varbinary(8)", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosVarBinary("yes")}, + expectValue: []byte("yes"), + }, //3 + { + tbType: "ts timestamp, v geometry(100)", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosGeometry{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}, + expectValue: []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, //3 + { + tbType: "ts timestamp, v nchar(8)", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosNchar("yes")}, + expectValue: "yes", + }, //3 + { + tbType: "ts timestamp, v nchar(8)", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, nil}, + expectValue: nil, + }, //1 + } { + tbName := fmt.Sprintf("test_fast_insert_2_%02d", i) + tbType := tc.tbType + drop := fmt.Sprintf("drop table if exists %s", tbName) + create := fmt.Sprintf("create table if not exists %s(%s)", tbName, tbType) + name := fmt.Sprintf("%02d-%s", i, tbType) + pos := tc.pos + sql := fmt.Sprintf("insert into %s values(%s)", tbName, pos) + var err error + t.Run(name, func(t *testing.T) { + if err = exec(conn, drop); err != nil { + t.Error(err) + return + } + if err = exec(conn, create); err != nil { + t.Error(err) + return + } + insertStmt := TaosStmtInit(conn) + code := TaosStmtPrepare(insertStmt, sql) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code = TaosStmtBindParam(insertStmt, tc.params) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code = TaosStmtAddBatch(insertStmt) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code = TaosStmtExecute(insertStmt) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + affectedRows := TaosStmtAffectedRowsOnce(insertStmt) + if affectedRows != 1 { + t.Errorf("expect 1 got %d", affectedRows) + return + } + code = TaosStmtClose(insertStmt) + if code != 0 { + errStr := TaosStmtErrStr(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + result, err := query(conn, fmt.Sprintf("select v from %s", tbName)) + if err != nil { + t.Error(err) + return + } + if len(result) != 1 { + t.Errorf("expect %d got %d", 1, len(result)) + return + } + assert.Equal(t, tc.expectValue, result[0][0]) + }) + } +} + +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test stmt query +func TestStmtQuery(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + err = exec(conn, "create database if not exists test_wrapper precision 'us' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_wrapper") + if err != nil { + t.Error(err) + return + } + for i, tc := range []struct { + tbType string + data string + clause string + params *param.Param + skip bool + }{ + { + tbType: "ts timestamp, v int", + data: "0, 1", + clause: "v = ?", + params: param.NewParam(1).AddInt(1), + }, + { + tbType: "ts timestamp, v bool", + data: "now, true", + clause: "v = ?", + params: param.NewParam(1).AddBool(true), + }, + { + tbType: "ts timestamp, v tinyint", + data: "now, 3", + clause: "v = ?", + params: param.NewParam(1).AddTinyint(3), + }, + { + tbType: "ts timestamp, v smallint", + data: "now, 5", + clause: "v = ?", + params: param.NewParam(1).AddSmallint(5), + }, + { + tbType: "ts timestamp, v int", + data: "now, 6", + clause: "v = ?", + params: param.NewParam(1).AddInt(6), + }, + { + tbType: "ts timestamp, v bigint", + data: "now, 7", + clause: "v = ?", + params: param.NewParam(1).AddBigint(7), + }, + { + tbType: "ts timestamp, v tinyint unsigned", + data: "now, 1", + clause: "v = ?", + params: param.NewParam(1).AddUTinyint(1), + }, + { + tbType: "ts timestamp, v smallint unsigned", + data: "now, 2", + clause: "v = ?", + params: param.NewParam(1).AddUSmallint(2), + }, + { + tbType: "ts timestamp, v int unsigned", + data: "now, 3", + clause: "v = ?", + params: param.NewParam(1).AddUInt(3), + }, + { + tbType: "ts timestamp, v bigint unsigned", + data: "now, 4", + clause: "v = ?", + params: param.NewParam(1).AddUBigint(4), + }, + { + tbType: "ts timestamp, v tinyint unsigned", + data: "now, 1", + clause: "v = ?", + params: param.NewParam(1).AddUTinyint(1), + }, + { + tbType: "ts timestamp, v smallint unsigned", + data: "now, 2", + clause: "v = ?", + params: param.NewParam(1).AddUSmallint(2), + }, + { + tbType: "ts timestamp, v int unsigned", + data: "now, 3", + clause: "v = ?", + params: param.NewParam(1).AddUInt(3), + }, + { + tbType: "ts timestamp, v bigint unsigned", + data: "now, 4", + clause: "v = ?", + params: param.NewParam(1).AddUBigint(4), + }, + { + tbType: "ts timestamp, v float", + data: "now, 1.2", + clause: "v = ?", + params: param.NewParam(1).AddFloat(1.2), + }, + { + tbType: "ts timestamp, v double", + data: "now, 1.3", + clause: "v = ?", + params: param.NewParam(1).AddDouble(1.3), + }, + { + tbType: "ts timestamp, v double", + data: "now, 1.4", + clause: "v = ?", + params: param.NewParam(1).AddDouble(1.4), + }, + { + tbType: "ts timestamp, v binary(8)", + data: "now, 'yes'", + clause: "v = ?", + params: param.NewParam(1).AddBinary([]byte("yes")), + }, + { + tbType: "ts timestamp, v nchar(8)", + data: "now, 'OK'", + clause: "v = ?", + params: param.NewParam(1).AddNchar("OK"), + }, + { + tbType: "ts timestamp, v nchar(8)", + data: "1622282105000000, 'NOW'", + clause: "ts = ? and v = ?", + params: param.NewParam(2).AddTimestamp(time.Unix(1622282105, 0), common.PrecisionMicroSecond).AddBinary([]byte("NOW")), + }, + { + tbType: "ts timestamp, v nchar(8)", + data: "1622282105000000, 'NOW'", + clause: "ts = ? and v = ?", + params: param.NewParam(2).AddBigint(1622282105000000).AddBinary([]byte("NOW")), + }, + } { + tbName := fmt.Sprintf("test_stmt_query%02d", i) + tbType := tc.tbType + create := fmt.Sprintf("create table if not exists %s(%s)", tbName, tbType) + insert := fmt.Sprintf("insert into %s values(%s)", tbName, tc.data) + params := tc.params + sql := fmt.Sprintf("select * from %s where %s", tbName, tc.clause) + name := fmt.Sprintf("%02d-%s", i, tbType) + var err error + t.Run(name, func(t *testing.T) { + if tc.skip { + t.Skip("Skip, not support yet") + } + if err = exec(conn, create); err != nil { + t.Error(err) + return + } + if err = exec(conn, insert); err != nil { + t.Error(err) + return + } + + rows, err := StmtQuery(t, conn, sql, params) + if err != nil { + t.Error(err) + return + } + t.Log(rows) + }) + } +} + +func query(conn unsafe.Pointer, sql string) ([][]driver.Value, error) { + res := TaosQuery(conn, sql) + defer TaosFreeResult(res) + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + return nil, taosError.NewError(code, errStr) + } + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + return nil, err + } + precision := TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := TaosFetchRawBlock(res) + if errCode != 0 { + errStr := TaosErrorStr(res) + return nil, taosError.NewError(errCode, errStr) + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + return result, nil +} + +func StmtQuery(t *testing.T, conn unsafe.Pointer, sql string, params *param.Param) (rows [][]driver.Value, err error) { + stmt := TaosStmtInit(conn) + if stmt == nil { + err = taosError.NewError(0xffff, "failed to init stmt") + return + } + defer TaosStmtClose(stmt) + code := TaosStmtPrepare(stmt, sql) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + return nil, taosError.NewError(code, errStr) + } + value := params.GetValues() + code = TaosStmtBindParam(stmt, value) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + return nil, taosError.NewError(code, errStr) + } + code = TaosStmtExecute(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + return nil, taosError.NewError(code, errStr) + } + res := TaosStmtUseResult(stmt) + numFields := TaosFieldCount(res) + rowsHeader, err := ReadColumn(res, numFields) + t.Log(rowsHeader) + if err != nil { + return nil, err + } + precision := TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(taosError.SUCCESS) { + errStr := TaosErrorStr(res) + err := taosError.NewError(code, errStr) + return nil, err + } + if blockSize == 0 { + break + } + d := parser.ReadBlock(block, blockSize, rowsHeader.ColTypes, precision) + data = append(data, d...) + } + return data, nil +} + +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test get field +func TestGetFields(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + stmt := TaosStmtInit(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt_field") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt_field") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists test_stmt_field.all_type(ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ")"+ + "tags(tts timestamp,"+ + "tc1 bool,"+ + "tc2 tinyint,"+ + "tc3 smallint,"+ + "tc4 int,"+ + "tc5 bigint,"+ + "tc6 tinyint unsigned,"+ + "tc7 smallint unsigned,"+ + "tc8 int unsigned,"+ + "tc9 bigint unsigned,"+ + "tc10 float,"+ + "tc11 double,"+ + "tc12 binary(20),"+ + "tc13 nchar(20)"+ + ")") + if err != nil { + t.Error(err) + return + } + code := TaosStmtPrepare(stmt, "insert into ? using test_stmt_field.all_type tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code = TaosStmtSetTBName(stmt, "test_stmt_field.ct2") + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code, tagCount, tagsP := TaosStmtGetTagFields(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmtReclaimFields(stmt, tagsP) + code, columnCount, columnsP := TaosStmtGetColFields(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmtReclaimFields(stmt, columnsP) + columns := StmtParseFields(columnCount, columnsP) + tags := StmtParseFields(tagCount, tagsP) + assert.Equal(t, []*stmtCommon.StmtField{ + {Name: "ts", FieldType: 9, Bytes: 8}, + {Name: "c1", FieldType: 1, Bytes: 1}, + {Name: "c2", FieldType: 2, Bytes: 1}, + {Name: "c3", FieldType: 3, Bytes: 2}, + {Name: "c4", FieldType: 4, Bytes: 4}, + {Name: "c5", FieldType: 5, Bytes: 8}, + {Name: "c6", FieldType: 11, Bytes: 1}, + {Name: "c7", FieldType: 12, Bytes: 2}, + {Name: "c8", FieldType: 13, Bytes: 4}, + {Name: "c9", FieldType: 14, Bytes: 8}, + {Name: "c10", FieldType: 6, Bytes: 4}, + {Name: "c11", FieldType: 7, Bytes: 8}, + {Name: "c12", FieldType: 8, Bytes: 22}, + {Name: "c13", FieldType: 10, Bytes: 82}, + }, columns) + assert.Equal(t, []*stmtCommon.StmtField{ + {Name: "tts", FieldType: 9, Bytes: 8}, + {Name: "tc1", FieldType: 1, Bytes: 1}, + {Name: "tc2", FieldType: 2, Bytes: 1}, + {Name: "tc3", FieldType: 3, Bytes: 2}, + {Name: "tc4", FieldType: 4, Bytes: 4}, + {Name: "tc5", FieldType: 5, Bytes: 8}, + {Name: "tc6", FieldType: 11, Bytes: 1}, + {Name: "tc7", FieldType: 12, Bytes: 2}, + {Name: "tc8", FieldType: 13, Bytes: 4}, + {Name: "tc9", FieldType: 14, Bytes: 8}, + {Name: "tc10", FieldType: 6, Bytes: 4}, + {Name: "tc11", FieldType: 7, Bytes: 8}, + {Name: "tc12", FieldType: 8, Bytes: 22}, + {Name: "tc13", FieldType: 10, Bytes: 82}, + }, tags) +} + +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test get fields with common table +func TestGetFieldsCommonTable(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + stmt := TaosStmtInit(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt_field") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt_field") + if err != nil { + t.Error(err) + return + } + TaosSelectDB(conn, "test_stmt_field") + err = exec(conn, "create table if not exists ct(ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ")") + if err != nil { + t.Error(err) + return + } + code := TaosStmtPrepare(stmt, "insert into ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code, num, _ := TaosStmtGetTagFields(stmt) + assert.Equal(t, 0, code) + assert.Equal(t, 0, num) + code, columnCount, columnsP := TaosStmtGetColFields(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmtReclaimFields(stmt, columnsP) + columns := StmtParseFields(columnCount, columnsP) + assert.Equal(t, []*stmtCommon.StmtField{ + {Name: "ts", FieldType: 9, Bytes: 8}, + {Name: "c1", FieldType: 1, Bytes: 1}, + {Name: "c2", FieldType: 2, Bytes: 1}, + {Name: "c3", FieldType: 3, Bytes: 2}, + {Name: "c4", FieldType: 4, Bytes: 4}, + {Name: "c5", FieldType: 5, Bytes: 8}, + {Name: "c6", FieldType: 11, Bytes: 1}, + {Name: "c7", FieldType: 12, Bytes: 2}, + {Name: "c8", FieldType: 13, Bytes: 4}, + {Name: "c9", FieldType: 14, Bytes: 8}, + {Name: "c10", FieldType: 6, Bytes: 4}, + {Name: "c11", FieldType: 7, Bytes: 8}, + {Name: "c12", FieldType: 8, Bytes: 22}, + {Name: "c13", FieldType: 10, Bytes: 82}, + }, columns) +} + +func exec(conn unsafe.Pointer, sql string) error { + res := TaosQuery(conn, sql) + defer TaosFreeResult(res) + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + return taosError.NewError(code, errStr) + } + return nil +} + +// @author: xftan +// @date: 2023/10/13 11:31 +// @description: test stmt set tags +func TestTaosStmtSetTags(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + err = exec(conn, "drop database if exists test_wrapper") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create database if not exists test_wrapper precision 'us' keep 36500") + if err != nil { + t.Error(err) + return + } + defer func() { + _ = exec(conn, "drop database if exists test_wrapper") + }() + err = exec(conn, "create table if not exists test_wrapper.tgs(ts timestamp,v int) tags (tts timestamp,"+ + "t1 bool,"+ + "t2 tinyint,"+ + "t3 smallint,"+ + "t4 int,"+ + "t5 bigint,"+ + "t6 tinyint unsigned,"+ + "t7 smallint unsigned,"+ + "t8 int unsigned,"+ + "t9 bigint unsigned,"+ + "t10 float,"+ + "t11 double,"+ + "t12 binary(20),"+ + "t13 nchar(20))") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists test_wrapper.json_tag (ts timestamp,v int) tags (info json)") + if err != nil { + t.Error(err) + return + } + stmt := TaosStmtInit(conn) + if stmt == nil { + err = taosError.NewError(0xffff, "failed to init stmt") + t.Error(err) + return + } + //defer TaosStmtClose(stmt) + code := TaosStmtPrepare(stmt, "insert into ? using test_wrapper.tgs tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?) values (?,?)") + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + + code = TaosStmtSetTBName(stmt, "test_wrapper.t0") + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + now := time.Now() + code = TaosStmtSetTags(stmt, param.NewParam(14). + AddTimestamp(now, common.PrecisionMicroSecond). + AddBool(true). + AddTinyint(2). + AddSmallint(3). + AddInt(4). + AddBigint(5). + AddUTinyint(6). + AddUSmallint(7). + AddUInt(8). + AddUBigint(9). + AddFloat(10). + AddDouble(11). + AddBinary([]byte("binary")). + AddNchar("nchar"). + GetValues()) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtBindParam(stmt, param.NewParam(2).AddTimestamp(now, common.PrecisionMicroSecond).AddInt(100).GetValues()) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtAddBatch(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtExecute(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtSetSubTBName(stmt, "test_wrapper.t1") + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtSetTags(stmt, param.NewParam(14). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + AddNull(). + GetValues()) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtBindParam(stmt, param.NewParam(2).AddTimestamp(now, common.PrecisionMicroSecond).AddInt(101).GetValues()) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtAddBatch(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtExecute(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + + code = TaosStmtPrepare(stmt, "insert into ? using test_wrapper.json_tag tags(?) values (?,?)") + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtSetTBName(stmt, "test_wrapper.t2") + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtSetTags(stmt, param.NewParam(1).AddJson([]byte(`{"a":"b"}`)).GetValues()) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtBindParam(stmt, param.NewParam(2).AddTimestamp(now, common.PrecisionMicroSecond).AddInt(102).GetValues()) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtAddBatch(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtExecute(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + code = TaosStmtClose(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + t.Error(taosError.NewError(code, errStr)) + return + } + data, err := query(conn, "select tbname,tgs.* from test_wrapper.tgs where v >= 100") + if err != nil { + t.Error(err) + return + } + + assert.Equal(t, 2, len(data)) + for i := 0; i < 2; i++ { + + switch data[i][0] { + case "t0": + assert.Equal(t, now.UTC().UnixNano()/1e3, data[i][1].(time.Time).UTC().UnixNano()/1e3) + assert.Equal(t, int32(100), data[i][2].(int32)) + assert.Equal(t, now.UTC().UnixNano()/1e3, data[i][3].(time.Time).UTC().UnixNano()/1e3) + assert.Equal(t, true, data[i][4].(bool)) + assert.Equal(t, int8(2), data[i][5].(int8)) + assert.Equal(t, int16(3), data[i][6].(int16)) + assert.Equal(t, int32(4), data[i][7].(int32)) + assert.Equal(t, int64(5), data[i][8].(int64)) + assert.Equal(t, uint8(6), data[i][9].(uint8)) + assert.Equal(t, uint16(7), data[i][10].(uint16)) + assert.Equal(t, uint32(8), data[i][11].(uint32)) + assert.Equal(t, uint64(9), data[i][12].(uint64)) + assert.Equal(t, float32(10), data[i][13].(float32)) + assert.Equal(t, float64(11), data[i][14].(float64)) + assert.Equal(t, "binary", data[i][15].(string)) + assert.Equal(t, "nchar", data[i][16].(string)) + case "t1": + assert.Equal(t, now.UTC().UnixNano()/1e3, data[i][1].(time.Time).UTC().UnixNano()/1e3) + assert.Equal(t, int32(101), data[i][2].(int32)) + for j := 0; j < 14; j++ { + assert.Nil(t, data[i][3+j]) + } + } + } + + data, err = query(conn, "select tbname,json_tag.* from test_wrapper.json_tag where v >= 100") + if err != nil { + t.Error(err) + return + } + assert.Equal(t, 1, len(data)) + assert.Equal(t, "t2", data[0][0].(string)) + assert.Equal(t, now.UTC().UnixNano()/1e3, data[0][1].(time.Time).UTC().UnixNano()/1e3) + assert.Equal(t, int32(102), data[0][2].(int32)) + assert.Equal(t, []byte(`{"a":"b"}`), data[0][3].([]byte)) +} + +func TestTaosStmtGetParam(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + + err = exec(conn, "drop database if exists test_stmt_get_param") + assert.NoError(t, err) + err = exec(conn, "create database if not exists test_stmt_get_param") + assert.NoError(t, err) + defer func() { + err = exec(conn, "drop database if exists test_stmt_get_param") + assert.NoError(t, err) + }() + + err = exec(conn, + "create table if not exists test_stmt_get_param.stb(ts TIMESTAMP,current float,voltage int,phase float) TAGS (groupid int,location varchar(24))") + assert.NoError(t, err) + + stmt := TaosStmtInit(conn) + assert.NotNilf(t, stmt, "failed to init stmt") + defer TaosStmtClose(stmt) + + code := TaosStmtPrepare(stmt, "insert into test_stmt_get_param.tb_0 using test_stmt_get_param.stb tags(?,?) values (?,?,?,?)") + assert.Equal(t, 0, code, TaosStmtErrStr(stmt)) + + dt, dl, err := TaosStmtGetParam(stmt, 0) // ts + assert.NoError(t, err) + assert.Equal(t, 9, dt) + assert.Equal(t, 8, dl) + + dt, dl, err = TaosStmtGetParam(stmt, 1) // current + assert.NoError(t, err) + assert.Equal(t, 6, dt) + assert.Equal(t, 4, dl) + + dt, dl, err = TaosStmtGetParam(stmt, 2) // voltage + assert.NoError(t, err) + assert.Equal(t, 4, dt) + assert.Equal(t, 4, dl) + + dt, dl, err = TaosStmtGetParam(stmt, 3) // phase + assert.NoError(t, err) + assert.Equal(t, 6, dt) + assert.Equal(t, 4, dl) + + _, _, err = TaosStmtGetParam(stmt, 4) // invalid index + assert.Error(t, err) +} + +func TestStmtJson(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt_json") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt_json precision 'ms' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt_json") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table test_json_stb(ts timestamp, v int) tags (t json)") + if err != nil { + t.Error(err) + return + } + stmt := TaosStmtInitWithReqID(conn, 0xbb123) + defer func() { + code := TaosStmtClose(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + }() + prepareInsertSql := "insert into ? using test_json_stb tags(?) values (?,?)" + code := TaosStmtPrepare(stmt, prepareInsertSql) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code = TaosStmtSetTBNameTags(stmt, "ctb1", param.NewParam(1).AddJson([]byte(`{"a":1,"b":"xx"}`)).GetValues()) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + now := time.Now().Round(time.Millisecond) + args := param.NewParam(2).AddTimestamp(now, common.PrecisionMilliSecond).AddInt(1).GetValues() + code = TaosStmtBindParam(stmt, args) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + + code = TaosStmtAddBatch(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code = TaosStmtExecute(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + affected := TaosStmtAffectedRowsOnce(stmt) + assert.Equal(t, 1, affected) + + code = TaosStmtPrepare(stmt, "select * from test_json_stb where t->'a' = ?") + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + count, code := TaosStmtNumParams(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.Equal(t, 1, count) + code = TaosStmtBindParam(stmt, param.NewParam(1).AddBigint(1).GetValues()) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + code = TaosStmtExecute(stmt) + if code != 0 { + errStr := TaosStmtErrStr(stmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + res := TaosStmtUseResult(stmt) + + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := TaosFetchRawBlock(res) + if errCode != 0 { + errStr := TaosErrorStr(res) + err = taosError.NewError(errCode, errStr) + t.Error(err) + return + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + t.Log(result) +} diff --git a/driver/wrapper/taosc.go b/driver/wrapper/taosc.go new file mode 100644 index 00000000..e4ccb12b --- /dev/null +++ b/driver/wrapper/taosc.go @@ -0,0 +1,289 @@ +package wrapper + +/* +#cgo CFLAGS: -IC:/TDengine/include -I/usr/include +#cgo linux LDFLAGS: -L/usr/lib -ltaos +#cgo windows LDFLAGS: -LC:/TDengine/driver -ltaos +#cgo darwin LDFLAGS: -L/usr/local/lib -ltaos +#include +#include +#include +#include +extern void QueryCallback(void *param,TAOS_RES *,int code); +extern void FetchRowsCallback(void *param,TAOS_RES *,int numOfRows); +extern void FetchRawBlockCallback(void *param,TAOS_RES *,int numOfRows); +int taos_options_wrapper(TSDB_OPTION option, char *arg) { + return taos_options(option,arg); +}; +void taos_fetch_rows_a_wrapper(TAOS_RES *res, void *param){ + return taos_fetch_rows_a(res,FetchRowsCallback,param); +}; +void taos_query_a_wrapper(TAOS *taos,const char *sql, void *param){ + return taos_query_a(taos,sql,QueryCallback,param); +}; +void taos_query_a_with_req_id_wrapper(TAOS *taos,const char *sql, void *param, int64_t reqID){ + return taos_query_a_with_reqid(taos, sql, QueryCallback, param, reqID); +}; +void taos_fetch_raw_block_a_wrapper(TAOS_RES *res, void *param){ + return taos_fetch_raw_block_a(res,FetchRawBlockCallback,param); +}; +*/ +import "C" +import ( + "strings" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" + "github.com/taosdata/taosadapter/v3/tools" +) + +// TaosFreeResult void taos_free_result(TAOS_RES *res); +func TaosFreeResult(res unsafe.Pointer) { + C.taos_free_result(res) +} + +// TaosConnect TAOS *taos_connect(const char *ip, const char *user, const char *pass, const char *db, uint16_t port); +func TaosConnect(host, user, pass, db string, port int) (taos unsafe.Pointer, err error) { + cUser := C.CString(user) + defer C.free(unsafe.Pointer(cUser)) + cPass := C.CString(pass) + defer C.free(unsafe.Pointer(cPass)) + cdb := (*C.char)(nil) + if len(db) > 0 { + cdb = C.CString(db) + defer C.free(unsafe.Pointer(cdb)) + } + var taosObj unsafe.Pointer + if len(host) == 0 { + taosObj = C.taos_connect(nil, cUser, cPass, cdb, (C.ushort)(0)) + } else { + cHost := C.CString(host) + defer C.free(unsafe.Pointer(cHost)) + taosObj = C.taos_connect(cHost, cUser, cPass, cdb, (C.ushort)(port)) + } + + if taosObj == nil { + errCode := TaosError(nil) + return nil, errors.NewError(errCode, TaosErrorStr(nil)) + } + return taosObj, nil +} + +// TaosClose void taos_close(TAOS *taos); +func TaosClose(taosConnect unsafe.Pointer) { + C.taos_close(taosConnect) +} + +// TaosQuery TAOS_RES *taos_query(TAOS *taos, const char *sql); +func TaosQuery(taosConnect unsafe.Pointer, sql string) unsafe.Pointer { + cSql := C.CString(sql) + defer C.free(unsafe.Pointer(cSql)) + return unsafe.Pointer(C.taos_query(taosConnect, cSql)) +} + +// TaosQueryWithReqID TAOS_RES *taos_query_with_reqid(TAOS *taos, const char *sql, int64_t reqID); +func TaosQueryWithReqID(taosConn unsafe.Pointer, sql string, reqID int64) unsafe.Pointer { + cSql := C.CString(sql) + defer C.free(unsafe.Pointer(cSql)) + return unsafe.Pointer(C.taos_query_with_reqid(taosConn, cSql, (C.int64_t)(reqID))) +} + +// TaosError int taos_errno(TAOS_RES *tres); +func TaosError(result unsafe.Pointer) int { + return int(C.taos_errno(result)) +} + +// TaosErrorStr char *taos_errstr(TAOS_RES *tres); +func TaosErrorStr(result unsafe.Pointer) string { + return C.GoString(C.taos_errstr(result)) +} + +// TaosFieldCount int taos_field_count(TAOS_RES *res); +func TaosFieldCount(result unsafe.Pointer) int { + return int(C.taos_field_count(result)) +} + +// TaosAffectedRows int taos_affected_rows(TAOS_RES *res); +func TaosAffectedRows(result unsafe.Pointer) int { + return int(C.taos_affected_rows(result)) +} + +// TaosFetchFields TAOS_FIELD *taos_fetch_fields(TAOS_RES *res); +func TaosFetchFields(result unsafe.Pointer) unsafe.Pointer { + return unsafe.Pointer(C.taos_fetch_fields(result)) +} + +// TaosFetchBlock int taos_fetch_block(TAOS_RES *res, TAOS_ROW *rows); +func TaosFetchBlock(result unsafe.Pointer) (int, unsafe.Pointer) { + var block C.TAOS_ROW + b := unsafe.Pointer(&block) + blockSize := int(C.taos_fetch_block(result, (*C.TAOS_ROW)(b))) + return blockSize, b +} + +// TaosResultPrecision int taos_result_precision(TAOS_RES *res); +func TaosResultPrecision(result unsafe.Pointer) int { + return int(C.taos_result_precision(result)) +} + +// TaosNumFields int taos_num_fields(TAOS_RES *res); +func TaosNumFields(result unsafe.Pointer) int { + return int(C.taos_num_fields(result)) +} + +// TaosFetchRow TAOS_ROW taos_fetch_row(TAOS_RES *res); +func TaosFetchRow(result unsafe.Pointer) unsafe.Pointer { + return unsafe.Pointer(C.taos_fetch_row(result)) +} + +// TaosSelectDB int taos_select_db(TAOS *taos, const char *db); +func TaosSelectDB(taosConnect unsafe.Pointer, db string) int { + cDB := C.CString(db) + defer C.free(unsafe.Pointer(cDB)) + return int(C.taos_select_db(taosConnect, cDB)) +} + +// TaosOptions int taos_options(TSDB_OPTION option, const void *arg, ...); +func TaosOptions(option int, value string) int { + cValue := C.CString(value) + defer C.free(unsafe.Pointer(cValue)) + return int(C.taos_options_wrapper((C.TSDB_OPTION)(option), cValue)) +} + +// TaosQueryA void taos_query_a(TAOS *taos, const char *sql, void (*fp)(void *param, TAOS_RES *, int code), void *param); +func TaosQueryA(taosConnect unsafe.Pointer, sql string, caller cgo.Handle) { + cSql := C.CString(sql) + defer C.free(unsafe.Pointer(cSql)) + C.taos_query_a_wrapper(taosConnect, cSql, caller.Pointer()) +} + +// TaosQueryAWithReqID void taos_query_a_with_reqid(TAOS *taos, const char *sql, __taos_async_fn_t fp, void *param, int64_t reqid); +func TaosQueryAWithReqID(taosConn unsafe.Pointer, sql string, caller cgo.Handle, reqID int64) { + cSql := C.CString(sql) + defer C.free(unsafe.Pointer(cSql)) + C.taos_query_a_with_req_id_wrapper(taosConn, cSql, caller.Pointer(), (C.int64_t)(reqID)) +} + +// TaosFetchRowsA void taos_fetch_rows_a(TAOS_RES *res, void (*fp)(void *param, TAOS_RES *, int numOfRows), void *param); +func TaosFetchRowsA(res unsafe.Pointer, caller cgo.Handle) { + C.taos_fetch_rows_a_wrapper(res, caller.Pointer()) +} + +// TaosResetCurrentDB void taos_reset_current_db(TAOS *taos); +func TaosResetCurrentDB(taosConnect unsafe.Pointer) { + C.taos_reset_current_db(taosConnect) +} + +// TaosValidateSql int taos_validate_sql(TAOS *taos, const char *sql); +func TaosValidateSql(taosConnect unsafe.Pointer, sql string) int { + cSql := C.CString(sql) + defer C.free(unsafe.Pointer(cSql)) + return int(C.taos_validate_sql(taosConnect, cSql)) +} + +// TaosIsUpdateQuery bool taos_is_update_query(TAOS_RES *res); +func TaosIsUpdateQuery(res unsafe.Pointer) bool { + return bool(C.taos_is_update_query(res)) +} + +// TaosFetchLengths int* taos_fetch_lengths(TAOS_RES *res); +func TaosFetchLengths(res unsafe.Pointer) unsafe.Pointer { + return unsafe.Pointer(C.taos_fetch_lengths(res)) +} + +// TaosFetchRawBlockA void taos_fetch_raw_block_a(TAOS_RES* res, __taos_async_fn_t fp, void* param); +func TaosFetchRawBlockA(res unsafe.Pointer, caller cgo.Handle) { + C.taos_fetch_raw_block_a_wrapper(res, caller.Pointer()) +} + +// TaosGetRawBlock const void *taos_get_raw_block(TAOS_RES* res); +func TaosGetRawBlock(result unsafe.Pointer) unsafe.Pointer { + return unsafe.Pointer(C.taos_get_raw_block(result)) +} + +// TaosGetClientInfo const char *taos_get_client_info(); +func TaosGetClientInfo() string { + return C.GoString(C.taos_get_client_info()) +} + +// TaosLoadTableInfo taos_load_table_info(TAOS *taos, const char* tableNameList); +func TaosLoadTableInfo(taosConnect unsafe.Pointer, tableNameList []string) int { + s := strings.Join(tableNameList, ",") + buf := C.CString(s) + defer C.free(unsafe.Pointer(buf)) + return int(C.taos_load_table_info(taosConnect, buf)) +} + +// TaosGetTableVgID +// DLL_EXPORT int taos_get_table_vgId(TAOS *taos, const char *db, const char *table, int *vgId) +func TaosGetTableVgID(conn unsafe.Pointer, db, table string) (vgID int, code int) { + cDB := C.CString(db) + defer C.free(unsafe.Pointer(cDB)) + cTable := C.CString(table) + defer C.free(unsafe.Pointer(cTable)) + + code = int(C.taos_get_table_vgId(conn, cDB, cTable, (*C.int)(unsafe.Pointer(&vgID)))) + return +} + +// TaosGetTablesVgID DLL_EXPORT int taos_get_tables_vgId(TAOS *taos, const char *db, const char *table[], int tableNum, int *vgId) +func TaosGetTablesVgID(conn unsafe.Pointer, db string, tables []string) (vgIDs []int, code int) { + cDB := C.CString(db) + defer C.free(unsafe.Pointer(cDB)) + numTables := len(tables) + cTables := make([]*C.char, numTables) + needFree := make([]unsafe.Pointer, numTables) + defer func() { + for _, p := range needFree { + C.free(p) + } + }() + for i, table := range tables { + cTable := C.CString(table) + needFree[i] = unsafe.Pointer(cTable) + cTables[i] = cTable + } + p := C.malloc(C.sizeof_int * C.size_t(numTables)) + defer C.free(p) + code = int(C.taos_get_tables_vgId(conn, cDB, (**C.char)(&cTables[0]), (C.int)(numTables), (*C.int)(p))) + if code != 0 { + return nil, code + } + vgIDs = make([]int, numTables) + for i := 0; i < numTables; i++ { + vgIDs[i] = int(*(*C.int)(tools.AddPointer(p, uintptr(C.sizeof_int*C.int(i))))) + } + return +} + +//typedef enum { +//TAOS_CONN_MODE_BI = 0, +//} TAOS_CONN_MODE; +// +//DLL_EXPORT int taos_set_conn_mode(TAOS* taos, int mode, int value); + +func TaosSetConnMode(conn unsafe.Pointer, mode int, value int) int { + return int(C.taos_set_conn_mode(conn, C.int(mode), C.int(value))) +} + +// TaosGetCurrentDB DLL_EXPORT int taos_get_current_db(TAOS *taos, char *database, int len, int *required) +func TaosGetCurrentDB(conn unsafe.Pointer) (db string, err error) { + cDb := (*C.char)(C.malloc(195)) + defer C.free(unsafe.Pointer(cDb)) + var required int + + code := C.taos_get_current_db(conn, cDb, C.int(195), (*C.int)(unsafe.Pointer(&required))) + if code != 0 { + err = errors.NewError(int(code), TaosErrorStr(nil)) + } + db = C.GoString(cDb) + + return +} + +// TaosGetServerInfo DLL_EXPORT const char *taos_get_server_info(TAOS *taos) +func TaosGetServerInfo(conn unsafe.Pointer) string { + info := C.taos_get_server_info(conn) + return C.GoString(info) +} diff --git a/driver/wrapper/taosc_test.go b/driver/wrapper/taosc_test.go new file mode 100644 index 00000000..7a2c268e --- /dev/null +++ b/driver/wrapper/taosc_test.go @@ -0,0 +1,607 @@ +package wrapper + +import ( + "database/sql/driver" + "fmt" + "io" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +// @author: xftan +// @date: 2022/1/27 17:29 +// @description: test taos_options +func TestTaosOptions(t *testing.T) { + type args struct { + option int + value string + } + tests := []struct { + name string + args args + want int + }{ + { + name: "test_options", + args: args{ + option: common.TSDB_OPTION_CONFIGDIR, + value: "/etc/taos", + }, + want: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TaosOptions(tt.args.option, tt.args.value); got != tt.want { + t.Errorf("TaosOptions() = %v, want %v", got, tt.want) + } + }) + } +} + +type result struct { + res unsafe.Pointer + n int +} + +type TestCaller struct { + QueryResult chan *result + FetchResult chan *result +} + +func NewTestCaller() *TestCaller { + return &TestCaller{ + QueryResult: make(chan *result), + FetchResult: make(chan *result), + } +} + +func (t *TestCaller) QueryCall(res unsafe.Pointer, code int) { + t.QueryResult <- &result{ + res: res, + n: code, + } +} + +func (t *TestCaller) FetchCall(res unsafe.Pointer, numOfRows int) { + t.FetchResult <- &result{ + res: res, + n: numOfRows, + } +} + +// @author: xftan +// @date: 2022/1/27 17:29 +// @description: test taos_query_a +func TestTaosQueryA(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + var caller = NewTestCaller() + type args struct { + taosConnect unsafe.Pointer + sql string + caller *TestCaller + } + tests := []struct { + name string + args args + }{ + { + name: "test", + args: args{ + taosConnect: conn, + sql: "show databases", + caller: caller, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := cgo.NewHandle(tt.args.caller) + go TaosQueryA(tt.args.taosConnect, tt.args.sql, p) + r := <-tt.args.caller.QueryResult + t.Log("query finish") + count := TaosNumFields(r.res) + rowsHeader, err := ReadColumn(r.res, count) + precision := TaosResultPrecision(r.res) + if err != nil { + t.Error(err) + return + } + t.Logf("%#v", rowsHeader) + if r.n != 0 { + t.Error("query result", r.n) + return + } + res := r.res + for { + go TaosFetchRowsA(res, p) + r = <-tt.args.caller.FetchResult + if r.n == 0 { + t.Log("success") + TaosFreeResult(r.res) + break + } else { + res = r.res + for i := 0; i < r.n; i++ { + values := make([]driver.Value, len(rowsHeader.ColNames)) + row := TaosFetchRow(res) + lengths := FetchLengths(res, len(rowsHeader.ColNames)) + for j := range rowsHeader.ColTypes { + if row == nil { + t.Error(io.EOF) + return + } + values[j] = FetchRow(row, j, rowsHeader.ColTypes[j], lengths[j], precision) + } + } + t.Log("fetch rows a", r.n) + } + } + }) + } +} + +// @author: xftan +// @date: 2023/10/13 11:31 +// @description: test taos error +func TestError(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + res := TaosQuery(conn, "asd") + code := TaosError(res) + assert.NotEqual(t, code, 0) + errStr := TaosErrorStr(res) + assert.NotEmpty(t, errStr) +} + +// @author: xftan +// @date: 2023/10/13 11:31 +// @description: test affected rows +func TestAffectedRows(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + res := TaosQuery(conn, "drop database if exists affected_rows_test") + code := TaosError(res) + if code != 0 { + t.Error(errors.NewError(code, TaosErrorStr(res))) + return + } + TaosFreeResult(res) + }() + res := TaosQuery(conn, "create database if not exists affected_rows_test") + code := TaosError(res) + if code != 0 { + t.Error(errors.NewError(code, TaosErrorStr(res))) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, "create table if not exists affected_rows_test.t0(ts timestamp,v int)") + code = TaosError(res) + if code != 0 { + t.Error(errors.NewError(code, TaosErrorStr(res))) + return + } + TaosFreeResult(res) + res = TaosQuery(conn, "insert into affected_rows_test.t0 values(now,1)") + code = TaosError(res) + if code != 0 { + t.Error(errors.NewError(code, TaosErrorStr(res))) + return + } + affected := TaosAffectedRows(res) + assert.Equal(t, 1, affected) +} + +// @author: xftan +// @date: 2022/1/27 17:29 +// @description: test taos_reset_current_db +func TestTaosResetCurrentDB(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + type args struct { + taosConnect unsafe.Pointer + } + tests := []struct { + name string + args args + }{ + { + name: "test", + args: args{ + taosConnect: conn, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = exec(tt.args.taosConnect, "create database if not exists log") + if err != nil { + t.Error(err) + return + } + TaosSelectDB(tt.args.taosConnect, "log") + result := TaosQuery(tt.args.taosConnect, "select database()") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + row := TaosFetchRow(result) + lengths := FetchLengths(result, 1) + currentDB := FetchRow(row, 0, 10, lengths[0]) + assert.Equal(t, "log", currentDB) + TaosFreeResult(result) + TaosResetCurrentDB(tt.args.taosConnect) + result = TaosQuery(tt.args.taosConnect, "select database()") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + row = TaosFetchRow(result) + lengths = FetchLengths(result, 1) + currentDB = FetchRow(row, 0, 10, lengths[0]) + assert.Nil(t, currentDB) + TaosFreeResult(result) + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:30 +// @description: test taos_validate_sql +func TestTaosValidateSql(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + type args struct { + taosConnect unsafe.Pointer + sql string + } + tests := []struct { + name string + args args + want int + }{ + { + name: "valid", + args: args{ + taosConnect: conn, + sql: "show grants", + }, + want: 0, + }, + { + name: "TSC_SQL_SYNTAX_ERROR", + args: args{ + taosConnect: conn, + sql: "slect 1", + }, + want: 9728, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TaosValidateSql(tt.args.taosConnect, tt.args.sql); got&0xffff != tt.want { + t.Errorf("TaosValidateSql() = %v, want %v", got&0xffff, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:30 +// @description: test taos_is_update_query +func TestTaosIsUpdateQuery(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + tests := []struct { + name string + want bool + }{ + { + name: "create database if not exists is_update", + want: true, + }, + { + name: "drop database if exists is_update", + want: true, + }, + { + name: "show log.stables", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := TaosQuery(conn, tt.name) + defer TaosFreeResult(result) + if got := TaosIsUpdateQuery(result); got != tt.want { + t.Errorf("TaosIsUpdateQuery() = %v, want %v", got, tt.want) + } + }) + } +} + +// @author: xftan +// @date: 2022/1/27 17:30 +// @description: taos async raw block +func TestTaosResultBlock(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + var caller = NewTestCaller() + type args struct { + taosConnect unsafe.Pointer + sql string + caller *TestCaller + } + tests := []struct { + name string + args args + }{ + { + name: "test", + args: args{ + taosConnect: conn, + sql: "show users", + caller: caller, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := cgo.NewHandle(tt.args.caller) + go TaosQueryA(tt.args.taosConnect, tt.args.sql, p) + r := <-tt.args.caller.QueryResult + t.Log("query finish") + count := TaosNumFields(r.res) + rowsHeader, err := ReadColumn(r.res, count) + if err != nil { + t.Error(err) + return + } + //t.Logf("%#v", rowsHeader) + if r.n != 0 { + t.Error("query result", r.n) + return + } + res := r.res + precision := TaosResultPrecision(res) + for { + go TaosFetchRawBlockA(res, p) + r = <-tt.args.caller.FetchResult + if r.n == 0 { + t.Log("success") + TaosFreeResult(r.res) + break + } else { + res = r.res + block := TaosGetRawBlock(res) + assert.NotNil(t, block) + values := parser.ReadBlock(block, r.n, rowsHeader.ColTypes, precision) + _ = values + t.Log(values) + } + } + }) + } +} + +// @author: xftan +// @date: 2023/10/13 11:31 +// @description: test taos_get_client_info +func TestTaosGetClientInfo(t *testing.T) { + s := TaosGetClientInfo() + assert.NotEmpty(t, s) +} + +// @author: xftan +// @date: 2023/10/13 11:31 +// @description: test taos_load_table_info +func TestTaosLoadTableInfo(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + err = exec(conn, "drop database if exists info1") + if err != nil { + t.Error(err) + return + } + defer func() { + err = exec(conn, "drop database if exists info1") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database info1") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table info1.t(ts timestamp,v int)") + if err != nil { + t.Error(err) + return + } + code := TaosLoadTableInfo(conn, []string{"info1.t"}) + if code != 0 { + errStr := TaosErrorStr(nil) + t.Error(errors.NewError(code, errStr)) + return + } + +} + +// @author: xftan +// @date: 2023/10/13 11:32 +// @description: test taos_get_table_vgId +func TestTaosGetTableVgID(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + } + defer TaosClose(conn) + dbName := "table_vg_id_test" + + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + defer func() { + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + }() + if err = exec(conn, fmt.Sprintf("create database %s", dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create stable %s.meters (ts timestamp, current float, voltage int, phase float) "+ + "tags (location binary(64), groupId int)", dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create table %s.d0 using %s.meters tags ('California.SanFrancisco', 1)", dbName, dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create table %s.d1 using %s.meters tags ('California.LosAngles', 2)", dbName, dbName)); err != nil { + t.Fatal(err) + } + + vg1, code := TaosGetTableVgID(conn, dbName, "d0") + if code != 0 { + t.Fatal("fail") + } + vg2, code := TaosGetTableVgID(conn, dbName, "d0") + if code != 0 { + t.Fatal("fail") + } + if vg1 != vg2 { + t.Fatal("fail") + } + _, code = TaosGetTableVgID(conn, dbName, "d1") + if code != 0 { + t.Fatal("fail") + } + _, code = TaosGetTableVgID(conn, dbName, "d2") + if code != 0 { + t.Fatal("fail") + } +} + +// @author: xftan +// @date: 2023/10/13 11:32 +// @description: test taos_get_tables_vgId +func TestTaosGetTablesVgID(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + } + defer TaosClose(conn) + dbName := "tables_vg_id_test" + + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + defer func() { + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + }() + if err = exec(conn, fmt.Sprintf("create database %s", dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create stable %s.meters (ts timestamp, current float, voltage int, phase float) "+ + "tags (location binary(64), groupId int)", dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create table %s.d0 using %s.meters tags ('California.SanFrancisco', 1)", dbName, dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create table %s.d1 using %s.meters tags ('California.LosAngles', 2)", dbName, dbName)); err != nil { + t.Fatal(err) + } + var vgs1 []int + var vgs2 []int + var code int + now := time.Now() + vgs1, code = TaosGetTablesVgID(conn, dbName, []string{"d0", "d1"}) + t.Log(time.Since(now)) + if code != 0 { + t.Fatal("fail") + } + assert.Equal(t, 2, len(vgs1)) + vgs2, code = TaosGetTablesVgID(conn, dbName, []string{"d0", "d1"}) + if code != 0 { + t.Fatal("fail") + } + assert.Equal(t, 2, len(vgs2)) + assert.Equal(t, vgs2, vgs1) +} + +func TestTaosSetConnMode(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + code := TaosSetConnMode(conn, 0, 1) + if code != 0 { + t.Errorf("TaosSetConnMode() error code= %d, msg: %s", code, TaosErrorStr(nil)) + } +} + +func TestTaosGetCurrentDB(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + dbName := "current_db_test" + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + err = exec(conn, fmt.Sprintf("create database %s", dbName)) + assert.NoError(t, err) + defer func() { + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + }() + _ = exec(conn, fmt.Sprintf("use %s", dbName)) + db, err := TaosGetCurrentDB(conn) + assert.NoError(t, err) + assert.Equal(t, dbName, db) +} + +func TestTaosGetServerInfo(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + info := TaosGetServerInfo(conn) + assert.NotEmpty(t, info) +} diff --git a/driver/wrapper/tmq.go b/driver/wrapper/tmq.go new file mode 100644 index 00000000..823dc426 --- /dev/null +++ b/driver/wrapper/tmq.go @@ -0,0 +1,334 @@ +package wrapper + +/* +#include +#include +#include +#include +extern void TMQCommitCB(tmq_t *, int32_t, void *param); +extern void TMQAutoCommitCB(tmq_t *, int32_t, void *param); +extern void TMQCommitOffsetCB(tmq_t *, int32_t, void *param); +*/ +import "C" +import ( + "sync" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/common/tmq" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" + "github.com/taosdata/taosadapter/v3/tools" +) + +var tmqCommitCallbackResultPool = sync.Pool{} + +type TMQCommitCallbackResult struct { + ErrCode int32 + Consumer unsafe.Pointer +} + +func (t *TMQCommitCallbackResult) GetError() error { + if t.ErrCode == 0 { + return nil + } + errStr := TMQErr2Str(t.ErrCode) + return errors.NewError(int(t.ErrCode), errStr) +} + +func GetTMQCommitCallbackResult(errCode int32, consumer unsafe.Pointer) *TMQCommitCallbackResult { + t, ok := tmqCommitCallbackResultPool.Get().(*TMQCommitCallbackResult) + if ok { + t.ErrCode = errCode + t.Consumer = consumer + return t + } + return &TMQCommitCallbackResult{ErrCode: errCode, Consumer: consumer} +} + +func PutTMQCommitCallbackResult(result *TMQCommitCallbackResult) { + tmqCommitCallbackResultPool.Put(result) +} + +// TMQConfNew tmq_conf_t *tmq_conf_new(); +func TMQConfNew() unsafe.Pointer { + return unsafe.Pointer(C.tmq_conf_new()) +} + +// TMQConfSet tmq_conf_res_t tmq_conf_set(tmq_conf_t *conf, const char *key, const char *value); +func TMQConfSet(conf unsafe.Pointer, key string, value string) int32 { + cKey := C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + cValue := C.CString(value) + defer C.free(unsafe.Pointer(cValue)) + return int32(C.tmq_conf_set((*C.struct_tmq_conf_t)(conf), cKey, cValue)) +} + +// TMQConfDestroy void tmq_conf_destroy(tmq_conf_t *conf); +func TMQConfDestroy(conf unsafe.Pointer) { + C.tmq_conf_destroy((*C.struct_tmq_conf_t)(conf)) +} + +// TMQConfSetAutoCommitCB DLL_EXPORT void tmq_conf_set_auto_commit_cb(tmq_conf_t *conf, tmq_commit_cb *cb, void *param); +func TMQConfSetAutoCommitCB(conf unsafe.Pointer, h cgo.Handle) { + C.tmq_conf_set_auto_commit_cb((*C.struct_tmq_conf_t)(conf), (*C.tmq_commit_cb)(C.TMQAutoCommitCB), h.Pointer()) +} + +// TMQCommitAsync DLL_EXPORT void tmq_commit_async(tmq_t *tmq, const TAOS_RES *msg, tmq_commit_cb *cb, void *param); +func TMQCommitAsync(consumer unsafe.Pointer, message unsafe.Pointer, h cgo.Handle) { + C.tmq_commit_async((*C.tmq_t)(consumer), message, (*C.tmq_commit_cb)(C.TMQCommitCB), h.Pointer()) +} + +// TMQCommitSync DLL_EXPORT int32_t tmq_commit_sync(tmq_t *tmq, const TAOS_RES *msg); +func TMQCommitSync(consumer unsafe.Pointer, message unsafe.Pointer) int32 { + return int32(C.tmq_commit_sync((*C.tmq_t)(consumer), message)) +} + +// TMQListNew tmq_list_t *tmq_list_new(); +func TMQListNew() unsafe.Pointer { + return unsafe.Pointer(C.tmq_list_new()) +} + +// TMQListAppend int32_t tmq_list_append(tmq_list_t *, const char *); +func TMQListAppend(list unsafe.Pointer, str string) int32 { + cStr := C.CString(str) + defer C.free(unsafe.Pointer(cStr)) + return int32(C.tmq_list_append((*C.tmq_list_t)(list), cStr)) +} + +// TMQListDestroy void tmq_list_destroy(tmq_list_t *); +func TMQListDestroy(list unsafe.Pointer) { + C.tmq_list_destroy((*C.tmq_list_t)(list)) +} + +// TMQListGetSize int32_t tmq_list_get_size(const tmq_list_t *); +func TMQListGetSize(list unsafe.Pointer) int32 { + return int32(C.tmq_list_get_size((*C.tmq_list_t)(list))) +} + +// TMQListToCArray char **tmq_list_to_c_array(const tmq_list_t *); +func TMQListToCArray(list unsafe.Pointer, size int) []string { + head := unsafe.Pointer(C.tmq_list_to_c_array((*C.tmq_list_t)(list))) + result := make([]string, size) + for i := 0; i < size; i++ { + result[i] = C.GoString(*(**C.char)(tools.AddPointer(head, PointerSize*uintptr(i)))) + } + return result +} + +// TMQConsumerNew tmq_t *tmq_consumer_new1(tmq_conf_t *conf, char *errstr, int32_t errstrLen); +func TMQConsumerNew(conf unsafe.Pointer) (unsafe.Pointer, error) { + p := (*C.char)(C.calloc(C.size_t(C.uint(1024)), C.size_t(C.uint(1024)))) + defer C.free(unsafe.Pointer(p)) + tmq := unsafe.Pointer(C.tmq_consumer_new((*C.struct_tmq_conf_t)(conf), p, C.int32_t(1024))) + errStr := C.GoString(p) + if len(errStr) > 0 { + return nil, errors.NewError(-1, errStr) + } + if tmq == nil { + return nil, errors.NewError(-1, "new consumer return nil") + } + return tmq, nil +} + +// TMQErr2Str const char *tmq_err2str(int32_t); +func TMQErr2Str(code int32) string { + return C.GoString(C.tmq_err2str((C.int32_t)(code))) +} + +// TMQSubscribe tmq_resp_err_t tmq_subscribe(tmq_t *tmq, tmq_list_t *topic_list); +func TMQSubscribe(consumer unsafe.Pointer, topicList unsafe.Pointer) int32 { + return int32(C.tmq_subscribe((*C.tmq_t)(consumer), (*C.tmq_list_t)(topicList))) +} + +// TMQUnsubscribe tmq_resp_err_t tmq_unsubscribe(tmq_t *tmq); +func TMQUnsubscribe(consumer unsafe.Pointer) int32 { + return int32(C.tmq_unsubscribe((*C.tmq_t)(consumer))) +} + +// TMQSubscription tmq_resp_err_t tmq_subscription(tmq_t *tmq, tmq_list_t **topics); +func TMQSubscription(consumer unsafe.Pointer) (int32, unsafe.Pointer) { + list := C.tmq_list_new() + code := int32(C.tmq_subscription( + (*C.tmq_t)(consumer), + (**C.tmq_list_t)(&list), + )) + return code, unsafe.Pointer(list) +} + +// TMQConsumerPoll TAOS_RES *tmq_consumer_poll(tmq_t *tmq, int64_t blocking_time); +func TMQConsumerPoll(consumer unsafe.Pointer, blockingTime int64) unsafe.Pointer { + return unsafe.Pointer(C.tmq_consumer_poll((*C.tmq_t)(consumer), (C.int64_t)(blockingTime))) +} + +// TMQConsumerClose tmq_resp_err_t tmq_consumer_close(tmq_t *tmq); +func TMQConsumerClose(consumer unsafe.Pointer) int32 { + return int32(C.tmq_consumer_close((*C.tmq_t)(consumer))) +} + +// TMQGetTopicName char *tmq_get_topic_name(tmq_message_t *message); +func TMQGetTopicName(message unsafe.Pointer) string { + return C.GoString(C.tmq_get_topic_name(message)) +} + +// TMQGetVgroupID int32_t tmq_get_vgroup_id(tmq_message_t *message); +func TMQGetVgroupID(message unsafe.Pointer) int32 { + return int32(C.tmq_get_vgroup_id(message)) +} + +// TMQGetTableName DLL_EXPORT const char *tmq_get_table_name(TAOS_RES *res); +func TMQGetTableName(message unsafe.Pointer) string { + return C.GoString(C.tmq_get_table_name(message)) +} + +// TMQGetDBName const char *tmq_get_db_name(TAOS_RES *res); +func TMQGetDBName(message unsafe.Pointer) string { + return C.GoString(C.tmq_get_db_name(message)) +} + +// TMQGetResType DLL_EXPORT tmq_res_t tmq_get_res_type(TAOS_RES *res); +func TMQGetResType(message unsafe.Pointer) int32 { + return int32(C.tmq_get_res_type(message)) +} + +// TMQGetRaw DLL_EXPORT int32_t tmq_get_raw(TAOS_RES *res, tmq_raw_data *raw); +func TMQGetRaw(message unsafe.Pointer) (int32, unsafe.Pointer) { + var cRawMeta C.TAOS_FIELD_E + m := unsafe.Pointer(&cRawMeta) + code := int32(C.tmq_get_raw(message, (*C.tmq_raw_data)(m))) + return code, m +} + +// TMQWriteRaw DLL_EXPORT int32_t tmq_write_raw(TAOS *taos, tmq_raw_data raw); +func TMQWriteRaw(conn unsafe.Pointer, raw unsafe.Pointer) int32 { + return int32(C.tmq_write_raw(conn, (C.struct_tmq_raw_data)(*(*C.struct_tmq_raw_data)(raw)))) +} + +// TMQFreeRaw DLL_EXPORT void tmq_free_raw(tmq_raw_data raw); +func TMQFreeRaw(raw unsafe.Pointer) { + C.tmq_free_raw((C.struct_tmq_raw_data)(*(*C.struct_tmq_raw_data)(raw))) +} + +// TMQGetJsonMeta DLL_EXPORT char *tmq_get_json_meta(TAOS_RES *res); // Returning null means error. Returned result need to be freed by tmq_free_json_meta +func TMQGetJsonMeta(message unsafe.Pointer) unsafe.Pointer { + p := unsafe.Pointer(C.tmq_get_json_meta(message)) + return p +} + +// TMQFreeJsonMeta DLL_EXPORT void tmq_free_json_meta(char* jsonMeta); +func TMQFreeJsonMeta(jsonMeta unsafe.Pointer) { + C.tmq_free_json_meta((*C.char)(jsonMeta)) +} + +func ParseRawMeta(rawMeta unsafe.Pointer) (length uint32, metaType uint16, data unsafe.Pointer) { + meta := *(*C.tmq_raw_data)(rawMeta) + length = uint32(meta.raw_len) + metaType = uint16(meta.raw_type) + data = meta.raw + return +} + +func ParseJsonMeta(jsonMeta unsafe.Pointer) []byte { + var binaryVal []byte + if jsonMeta != nil { + i := 0 + c := byte(0) + for { + c = *((*byte)(unsafe.Pointer(uintptr(jsonMeta) + uintptr(i)))) + if c != 0 { + binaryVal = append(binaryVal, c) + i += 1 + } else { + break + } + } + } + return binaryVal +} + +func BuildRawMeta(length uint32, metaType uint16, data unsafe.Pointer) unsafe.Pointer { + meta := C.struct_tmq_raw_data{} + meta.raw = data + meta.raw_len = (C.uint32_t)(length) + meta.raw_type = (C.uint16_t)(metaType) + return unsafe.Pointer(&meta) +} + +// TMQGetTopicAssignment DLL_EXPORT int32_t tmq_get_topic_assignment(tmq_t *tmq, const char* pTopicName, tmq_topic_assignment **assignment, int32_t *numOfAssignment) +func TMQGetTopicAssignment(consumer unsafe.Pointer, topic string) (int32, []*tmq.Assignment) { + var assignment *C.tmq_topic_assignment + var numOfAssignment int32 + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + code := int32(C.tmq_get_topic_assignment((*C.tmq_t)(consumer), topicName, (**C.tmq_topic_assignment)(unsafe.Pointer(&assignment)), (*C.int32_t)(&numOfAssignment))) + if code != 0 { + return code, nil + } + if assignment == nil { + return 0, nil + } + defer TMQFreeAssignment(unsafe.Pointer(assignment)) + result := make([]*tmq.Assignment, numOfAssignment) + for i := 0; i < int(numOfAssignment); i++ { + item := *(*C.tmq_topic_assignment)(unsafe.Pointer(uintptr(unsafe.Pointer(assignment)) + uintptr(C.sizeof_struct_tmq_topic_assignment*C.int(i)))) + result[i] = &tmq.Assignment{ + VGroupID: int32(item.vgId), + Offset: int64(item.currentOffset), + Begin: int64(item.begin), + End: int64(item.end), + } + } + return 0, result +} + +// TMQOffsetSeek DLL_EXPORT int32_t tmq_offset_seek(tmq_t* tmq, const char* pTopicName, int32_t vgroupHandle, int64_t offset); +func TMQOffsetSeek(consumer unsafe.Pointer, topic string, vGroupID int32, offset int64) int32 { + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + return int32(C.tmq_offset_seek((*C.tmq_t)(consumer), topicName, (C.int32_t)(vGroupID), (C.int64_t)(offset))) +} + +// TMQGetVgroupOffset DLL_EXPORT int64_t tmq_get_vgroup_offset(TAOS_RES* res, int32_t vgroupId); +func TMQGetVgroupOffset(message unsafe.Pointer) int64 { + return int64(C.tmq_get_vgroup_offset(message)) +} + +// TMQFreeAssignment DLL_EXPORT void tmq_free_assignment(tmq_topic_assignment* pAssignment); +func TMQFreeAssignment(assignment unsafe.Pointer) { + if assignment == nil { + return + } + C.tmq_free_assignment((*C.tmq_topic_assignment)(assignment)) +} + +// TMQPosition DLL_EXPORT int64_t tmq_position(tmq_t *tmq, const char *pTopicName, int32_t vgId); +func TMQPosition(consumer unsafe.Pointer, topic string, vGroupID int32) int64 { + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + return int64(C.tmq_position((*C.tmq_t)(consumer), topicName, (C.int32_t)(vGroupID))) +} + +// TMQCommitted DLL_EXPORT int64_t tmq_committed(tmq_t *tmq, const char *pTopicName, int32_t vgId); +func TMQCommitted(consumer unsafe.Pointer, topic string, vGroupID int32) int64 { + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + return int64(C.tmq_committed((*C.tmq_t)(consumer), topicName, (C.int32_t)(vGroupID))) +} + +// TMQCommitOffsetSync DLL_EXPORT int32_t tmq_commit_offset_sync(tmq_t *tmq, const char *pTopicName, int32_t vgId, int64_t offset); +func TMQCommitOffsetSync(consumer unsafe.Pointer, topic string, vGroupID int32, offset int64) int32 { + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + return int32(C.tmq_commit_offset_sync((*C.tmq_t)(consumer), topicName, (C.int32_t)(vGroupID), (C.int64_t)(offset))) +} + +// TMQCommitOffsetAsync DLL_EXPORT void tmq_commit_offset_async(tmq_t *tmq, const char *pTopicName, int32_t vgId, int64_t offset, tmq_commit_cb *cb, void *param); +func TMQCommitOffsetAsync(consumer unsafe.Pointer, topic string, vGroupID int32, offset int64, h cgo.Handle) { + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + C.tmq_commit_offset_async((*C.tmq_t)(consumer), topicName, (C.int32_t)(vGroupID), (C.int64_t)(offset), (*C.tmq_commit_cb)(C.TMQCommitOffsetCB), h.Pointer()) +} + +// TMQGetConnect TAOS *tmq_get_connect(tmq_t *tmq) +func TMQGetConnect(consumer unsafe.Pointer) unsafe.Pointer { + return unsafe.Pointer(C.tmq_get_connect((*C.tmq_t)(consumer))) +} diff --git a/driver/wrapper/tmq_test.go b/driver/wrapper/tmq_test.go new file mode 100644 index 00000000..4f3879dc --- /dev/null +++ b/driver/wrapper/tmq_test.go @@ -0,0 +1,2012 @@ +package wrapper + +import ( + "database/sql/driver" + "testing" + "time" + "unsafe" + + jsoniter "github.com/json-iterator/go" + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + tmqcommon "github.com/taosdata/taosadapter/v3/driver/common/tmq" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +// @author: xftan +// @date: 2023/10/13 11:32 +// @description: test tmq +func TestTMQ(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + result := TaosQuery(conn, "drop database if exists abc1") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + result := TaosQuery(conn, "create database if not exists abc1 vgroups 2 WAL_RETENTION_PERIOD 86400") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "use abc1") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create stable if not exists st1 (ts timestamp, c1 int, c2 float, c3 binary(10)) tags(t1 int)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct0 using st1 tags(1000)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct1 using st1 tags(2000)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct3 using st1 tags(3000)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + //create topic + defer func() { + result = TaosQuery(conn, "drop topic if exists topic_ctb_column") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + result = TaosQuery(conn, "create topic if not exists topic_ctb_column as select ts, c1 from ct1") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + defer func() { + result = TaosQuery(conn, "drop topic if exists topic_ctb_column") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + go func() { + for i := 0; i < 5; i++ { + result = TaosQuery(conn, "insert into ct1 values(now,1,2,'1')") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + time.Sleep(time.Millisecond) + } + }() + //build consumer + conf := TMQConfNew() + TMQConfSet(conf, "msg.with.table.name", "true") + // auto commit default is true then the commitCallback function will be called after 5 seconds + TMQConfSet(conf, "enable.auto.commit", "true") + TMQConfSet(conf, "group.id", "tg2") + TMQConfSet(conf, "auto.offset.reset", "earliest") + c := make(chan *TMQCommitCallbackResult, 1) + h := cgo.NewHandle(c) + TMQConfSetAutoCommitCB(conf, h) + go func() { + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) + } + }() + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Error(err) + } + TMQConfDestroy(conf) + //build_topic_list + topicList := TMQListNew() + TMQListAppend(topicList, "topic_ctb_column") + + //sync_consume_loop + s := time.Now() + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + t.Log("sub", time.Since(s)) + errCode, list := TMQSubscription(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + size := TMQListGetSize(list) + r := TMQListToCArray(list, int(size)) + assert.Equal(t, []string{"topic_ctb_column"}, r) + c2 := make(chan *TMQCommitCallbackResult, 1) + h2 := cgo.NewHandle(c2) + for i := 0; i < 5; i++ { + + message := TMQConsumerPoll(tmq, 500) + if message != nil { + t.Log(message) + topic := TMQGetTopicName(message) + assert.Equal(t, "topic_ctb_column", topic) + vgroupID := TMQGetVgroupID(message) + t.Log("vgroupID", vgroupID) + + for { + blockSize, errCode, block := TaosFetchRawBlock(message) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(message) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(message) + return + } + if blockSize == 0 { + break + } + filedCount := TaosNumFields(message) + rh, err := ReadColumn(message, filedCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(message) + //tableName := TMQGetTableName(message) + //assert.Equal(t, "ct1", tableName) + dbName := TMQGetDBName(message) + assert.Equal(t, "abc1", dbName) + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + t.Log(data) + } + TaosFreeResult(message) + TMQCommitAsync(tmq, nil, h2) + timer := time.NewTimer(time.Minute) + select { + case d := <-c2: + assert.Equal(t, int32(0), d.ErrCode) + PutTMQCommitCallbackResult(d) + timer.Stop() + break + case <-timer.C: + timer.Stop() + t.Error("wait tmq commit callback timeout") + return + } + } + } + + errCode = TMQConsumerClose(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } +} + +// @author: xftan +// @date: 2023/10/13 11:33 +// @description: test TMQList +func TestTMQList(t *testing.T) { + list := TMQListNew() + TMQListAppend(list, "1") + TMQListAppend(list, "2") + TMQListAppend(list, "3") + size := TMQListGetSize(list) + r := TMQListToCArray(list, int(size)) + assert.Equal(t, []string{"1", "2", "3"}, r) + TMQListDestroy(list) +} + +// @author: xftan +// @date: 2023/10/13 11:33 +// @description: test tmq subscribe db +func TestTMQDB(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + result := TaosQuery(conn, "drop database if exists tmq_test_db") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + result := TaosQuery(conn, "create database if not exists tmq_test_db vgroups 2 WAL_RETENTION_PERIOD 86400") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "use tmq_test_db") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create stable if not exists st1 (ts timestamp, c1 int, c2 float, c3 binary(10)) tags(t1 int)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct0 using st1 tags(1000)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct1 using st1 tags(2000)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct3 using st1 tags(3000)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + //create topic + result = TaosQuery(conn, "create topic if not exists test_tmq_db_topic as DATABASE tmq_test_db") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + defer func() { + result = TaosQuery(conn, "drop topic if exists test_tmq_db_topic") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + go func() { + for i := 0; i < 5; i++ { + result = TaosQuery(conn, "insert into ct1 values(now,1,2,'1')") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + time.Sleep(time.Millisecond) + } + }() + //build consumer + conf := TMQConfNew() + // auto commit default is true then the commitCallback function will be called after 5 seconds + TMQConfSet(conf, "enable.auto.commit", "true") + TMQConfSet(conf, "group.id", "tg2") + TMQConfSet(conf, "msg.with.table.name", "true") + TMQConfSet(conf, "auto.offset.reset", "earliest") + c := make(chan *TMQCommitCallbackResult, 1) + h := cgo.NewHandle(c) + TMQConfSetAutoCommitCB(conf, h) + go func() { + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) + } + }() + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Error(err) + } + TMQConfDestroy(conf) + //build_topic_list + topicList := TMQListNew() + TMQListAppend(topicList, "test_tmq_db_topic") + + //sync_consume_loop + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + errCode, list := TMQSubscription(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + size := TMQListGetSize(list) + r := TMQListToCArray(list, int(size)) + assert.Equal(t, []string{"test_tmq_db_topic"}, r) + totalCount := 0 + c2 := make(chan *TMQCommitCallbackResult, 1) + h2 := cgo.NewHandle(c2) + for i := 0; i < 5; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + t.Log(message) + topic := TMQGetTopicName(message) + assert.Equal(t, "test_tmq_db_topic", topic) + for { + blockSize, errCode, block := TaosFetchRawBlock(message) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(message) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(message) + return + } + if blockSize == 0 { + break + } + tableName := TMQGetTableName(message) + assert.Equal(t, "ct1", tableName) + filedCount := TaosNumFields(message) + rh, err := ReadColumn(message, filedCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(message) + totalCount += blockSize + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + t.Log(data) + } + TaosFreeResult(message) + + TMQCommitAsync(tmq, nil, h2) + timer := time.NewTimer(time.Minute) + select { + case d := <-c2: + assert.Nil(t, d.GetError()) + assert.Equal(t, int32(0), d.ErrCode) + PutTMQCommitCallbackResult(d) + timer.Stop() + break + case <-timer.C: + timer.Stop() + t.Error("wait tmq commit callback timeout") + return + } + } + } + + errCode = TMQConsumerClose(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + assert.GreaterOrEqual(t, totalCount, 5) +} + +// @author: xftan +// @date: 2023/10/13 11:33 +// @description: test tmq subscribe multi tables +func TestTMQDBMultiTable(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + result := TaosQuery(conn, "drop database if exists tmq_test_db_multi") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + result := TaosQuery(conn, "create database if not exists tmq_test_db_multi vgroups 2 WAL_RETENTION_PERIOD 86400") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "use tmq_test_db_multi") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct0 (ts timestamp, c1 int)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct1 (ts timestamp, c1 int, c2 float)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct2 (ts timestamp, c1 int, c2 float, c3 binary(10))") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + //create topic + result = TaosQuery(conn, "create topic if not exists test_tmq_db_multi_topic as DATABASE tmq_test_db_multi") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + defer func() { + result = TaosQuery(conn, "drop topic if exists test_tmq_db_multi_topic") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + { + result = TaosQuery(conn, "insert into ct0 values(now,1)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + } + { + result = TaosQuery(conn, "insert into ct1 values(now,1,2)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + } + { + result = TaosQuery(conn, "insert into ct2 values(now,1,2,'3')") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + } + //build consumer + conf := TMQConfNew() + // auto commit default is true then the commitCallback function will be called after 5 seconds + TMQConfSet(conf, "enable.auto.commit", "true") + TMQConfSet(conf, "group.id", "tg2") + TMQConfSet(conf, "msg.with.table.name", "true") + TMQConfSet(conf, "auto.offset.reset", "earliest") + c := make(chan *TMQCommitCallbackResult, 1) + h := cgo.NewHandle(c) + TMQConfSetAutoCommitCB(conf, h) + go func() { + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) + } + }() + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Error(err) + } + TMQConfDestroy(conf) + //build_topic_list + topicList := TMQListNew() + TMQListAppend(topicList, "test_tmq_db_multi_topic") + + //sync_consume_loop + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + errCode, list := TMQSubscription(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + size := TMQListGetSize(list) + r := TMQListToCArray(list, int(size)) + assert.Equal(t, []string{"test_tmq_db_multi_topic"}, r) + totalCount := 0 + tables := map[string]struct{}{ + "ct0": {}, + "ct1": {}, + "ct2": {}, + } + c2 := make(chan *TMQCommitCallbackResult, 1) + h2 := cgo.NewHandle(c2) + for i := 0; i < 5; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + t.Log(message) + topic := TMQGetTopicName(message) + assert.Equal(t, "test_tmq_db_multi_topic", topic) + for { + blockSize, errCode, block := TaosFetchRawBlock(message) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(message) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(message) + return + } + if blockSize == 0 { + break + } + tableName := TMQGetTableName(message) + delete(tables, tableName) + filedCount := TaosNumFields(message) + rh, err := ReadColumn(message, filedCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(message) + totalCount += blockSize + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + t.Log(data) + } + TaosFreeResult(message) + + TMQCommitAsync(tmq, nil, h2) + timer := time.NewTimer(time.Minute) + select { + case d := <-c2: + assert.Equal(t, int32(0), d.ErrCode) + PutTMQCommitCallbackResult(d) + timer.Stop() + break + case <-timer.C: + timer.Stop() + t.Error("wait tmq commit callback timeout") + return + } + } + } + errCode = TMQUnsubscribe(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + errCode = TMQConsumerClose(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + assert.GreaterOrEqual(t, totalCount, 3) + assert.Emptyf(t, tables, "tables name not empty", tables) +} + +// @author: xftan +// @date: 2023/10/13 11:33 +// @description: test tmq subscribe db with multi table insert +func TestTMQDBMultiInsert(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + result := TaosQuery(conn, "drop database if exists tmq_test_db_multi_insert") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + result := TaosQuery(conn, "create database if not exists tmq_test_db_multi_insert vgroups 2 wal_retention_period 3600") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "use tmq_test_db_multi_insert") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct0 (ts timestamp, c1 int)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct1 (ts timestamp, c1 int, c2 float)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create table if not exists ct2 (ts timestamp, c1 int, c2 float, c3 binary(10))") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + //create topic + result = TaosQuery(conn, "create topic if not exists tmq_test_db_multi_insert_topic as DATABASE tmq_test_db_multi_insert") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + defer func() { + result = TaosQuery(conn, "drop topic if exists tmq_test_db_multi_insert_topic") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + { + result = TaosQuery(conn, "insert into ct0 values(now,1) ct1 values(now,1,2) ct2 values(now,1,2,'3')") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + } + //build consumer + conf := TMQConfNew() + // auto commit default is true then the commitCallback function will be called after 5 seconds + TMQConfSet(conf, "enable.auto.commit", "true") + TMQConfSet(conf, "group.id", "tg2") + TMQConfSet(conf, "msg.with.table.name", "true") + TMQConfSet(conf, "auto.offset.reset", "earliest") + c := make(chan *TMQCommitCallbackResult, 1) + h := cgo.NewHandle(c) + TMQConfSetAutoCommitCB(conf, h) + go func() { + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) + } + }() + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Error(err) + } + TMQConfDestroy(conf) + //build_topic_list + topicList := TMQListNew() + TMQListAppend(topicList, "tmq_test_db_multi_insert_topic") + + //sync_consume_loop + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + errCode, list := TMQSubscription(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + size := TMQListGetSize(list) + r := TMQListToCArray(list, int(size)) + assert.Equal(t, []string{"tmq_test_db_multi_insert_topic"}, r) + totalCount := 0 + var tables [][]string + c2 := make(chan *TMQCommitCallbackResult, 1) + h2 := cgo.NewHandle(c2) + for i := 0; i < 5; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + t.Log(message) + topic := TMQGetTopicName(message) + assert.Equal(t, "tmq_test_db_multi_insert_topic", topic) + var table []string + for { + blockSize, errCode, block := TaosFetchRawBlock(message) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(message) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(message) + return + } + if blockSize == 0 { + break + } + tableName := TMQGetTableName(message) + table = append(table, tableName) + filedCount := TaosNumFields(message) + rh, err := ReadColumn(message, filedCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(message) + totalCount += blockSize + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + t.Log(data) + } + TaosFreeResult(message) + + TMQCommitAsync(tmq, nil, h2) + timer := time.NewTimer(time.Minute) + select { + case d := <-c2: + assert.Equal(t, int32(0), d.ErrCode) + PutTMQCommitCallbackResult(d) + timer.Stop() + break + case <-timer.C: + timer.Stop() + t.Error("wait tmq commit callback timeout") + return + } + tables = append(tables, table) + } + } + + errCode = TMQConsumerClose(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + assert.GreaterOrEqual(t, totalCount, 3) + t.Log(tables) +} + +// @author: xftan +// @date: 2023/10/13 11:34 +// @description: tmq test modify meta +func TestTMQModify(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + result := TaosQuery(conn, "drop database if exists tmq_test_db_modify") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + result = TaosQuery(conn, "drop database if exists tmq_test_db_modify_target") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + + result := TaosQuery(conn, "drop database if exists tmq_test_db_modify_target") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "drop database if exists tmq_test_db_modify") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + result = TaosQuery(conn, "create database if not exists tmq_test_db_modify_target vgroups 2 WAL_RETENTION_PERIOD 86400") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create database if not exists tmq_test_db_modify vgroups 5 WAL_RETENTION_PERIOD 86400") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "use tmq_test_db_modify") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + //create topic + result = TaosQuery(conn, "create topic if not exists tmq_test_db_modify_topic with meta as DATABASE tmq_test_db_modify") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + defer func() { + result = TaosQuery(conn, "drop topic if exists tmq_test_db_modify_topic") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + //build consumer + conf := TMQConfNew() + // auto commit default is true then the commitCallback function will be called after 5 seconds + TMQConfSet(conf, "enable.auto.commit", "true") + TMQConfSet(conf, "group.id", "tg2") + TMQConfSet(conf, "msg.with.table.name", "true") + TMQConfSet(conf, "auto.offset.reset", "earliest") + c := make(chan *TMQCommitCallbackResult, 1) + h := cgo.NewHandle(c) + TMQConfSetAutoCommitCB(conf, h) + go func() { + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) + } + }() + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Error(err) + } + TMQConfDestroy(conf) + //build_topic_list + topicList := TMQListNew() + TMQListAppend(topicList, "tmq_test_db_modify_topic") + + //sync_consume_loop + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + errCode, list := TMQSubscription(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + size := TMQListGetSize(list) + r := TMQListToCArray(list, int(size)) + assert.Equal(t, []string{"tmq_test_db_modify_topic"}, r) + c2 := make(chan *TMQCommitCallbackResult, 1) + h2 := cgo.NewHandle(c2) + targetConn, err := TaosConnect("", "root", "taosdata", "tmq_test_db_modify_target", 0) + assert.NoError(t, err) + defer TaosClose(targetConn) + result = TaosQuery(conn, "create table stb (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ")"+ + "tags(tts timestamp,"+ + "tc1 bool,"+ + "tc2 tinyint,"+ + "tc3 smallint,"+ + "tc4 int,"+ + "tc5 bigint,"+ + "tc6 tinyint unsigned,"+ + "tc7 smallint unsigned,"+ + "tc8 int unsigned,"+ + "tc9 bigint unsigned,"+ + "tc10 float,"+ + "tc11 double,"+ + "tc12 binary(20),"+ + "tc13 nchar(20)"+ + ")") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + pool := func(cb func(*tmqcommon.Meta, unsafe.Pointer)) { + message := TMQConsumerPoll(tmq, 500) + assert.NotNil(t, message) + topic := TMQGetTopicName(message) + assert.Equal(t, "tmq_test_db_modify_topic", topic) + messageType := TMQGetResType(message) + assert.Equal(t, int32(common.TMQ_RES_TABLE_META), messageType) + pointer := TMQGetJsonMeta(message) + assert.NotNil(t, pointer) + data := ParseJsonMeta(pointer) + var meta tmqcommon.Meta + err = jsoniter.Unmarshal(data, &meta) + assert.NoError(t, err) + + defer TaosFreeResult(message) + + TMQCommitAsync(tmq, nil, h2) + timer := time.NewTimer(time.Minute) + select { + case d := <-c2: + assert.Equal(t, int32(0), d.ErrCode) + PutTMQCommitCallbackResult(d) + timer.Stop() + break + case <-timer.C: + timer.Stop() + t.Error("wait tmq commit callback timeout") + cb(nil, nil) + return + } + errCode, rawMeta := TMQGetRaw(message) + if errCode != errors.SUCCESS { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + cb(&meta, rawMeta) + TMQFreeRaw(rawMeta) + } + + pool(func(meta *tmqcommon.Meta, rawMeta unsafe.Pointer) { + assert.Equal(t, "create", meta.Type) + assert.Equal(t, "stb", meta.TableName) + assert.Equal(t, "super", meta.TableType) + assert.NoError(t, err) + length, metaType, data := ParseRawMeta(rawMeta) + r2 := BuildRawMeta(length, metaType, data) + errCode = TMQWriteRaw(targetConn, r2) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + d, err := query(targetConn, "describe stb") + assert.NoError(t, err) + expect := [][]driver.Value{ + {"ts", "TIMESTAMP", int32(8), ""}, + {"c1", "BOOL", int32(1), ""}, + {"c2", "TINYINT", int32(1), ""}, + {"c3", "SMALLINT", int32(2), ""}, + {"c4", "INT", int32(4), ""}, + {"c5", "BIGINT", int32(8), ""}, + {"c6", "TINYINT UNSIGNED", int32(1), ""}, + {"c7", "SMALLINT UNSIGNED", int32(2), ""}, + {"c8", "INT UNSIGNED", int32(4), ""}, + {"c9", "BIGINT UNSIGNED", int32(8), ""}, + {"c10", "FLOAT", int32(4), ""}, + {"c11", "DOUBLE", int32(8), ""}, + {"c12", "VARCHAR", int32(20), ""}, + {"c13", "NCHAR", int32(20), ""}, + {"tts", "TIMESTAMP", int32(8), "TAG"}, + {"tc1", "BOOL", int32(1), "TAG"}, + {"tc2", "TINYINT", int32(1), "TAG"}, + {"tc3", "SMALLINT", int32(2), "TAG"}, + {"tc4", "INT", int32(4), "TAG"}, + {"tc5", "BIGINT", int32(8), "TAG"}, + {"tc6", "TINYINT UNSIGNED", int32(1), "TAG"}, + {"tc7", "SMALLINT UNSIGNED", int32(2), "TAG"}, + {"tc8", "INT UNSIGNED", int32(4), "TAG"}, + {"tc9", "BIGINT UNSIGNED", int32(8), "TAG"}, + {"tc10", "FLOAT", int32(4), "TAG"}, + {"tc11", "DOUBLE", int32(8), "TAG"}, + {"tc12", "VARCHAR", int32(20), "TAG"}, + {"tc13", "NCHAR", int32(20), "TAG"}, + } + for rowIndex, values := range d { + for i := 0; i < 4; i++ { + assert.Equal(t, expect[rowIndex][i], values[i]) + } + } + }) + + TMQUnsubscribe(tmq) + errCode = TMQConsumerClose(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } +} + +// @author: xftan +// @date: 2023/10/13 11:34 +// @description: test tmq subscribe with auto create table +func TestTMQAutoCreateTable(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + result := TaosQuery(conn, "drop database if exists tmq_test_auto_create") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + result := TaosQuery(conn, "create database if not exists tmq_test_auto_create vgroups 2 WAL_RETENTION_PERIOD 86400") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "use tmq_test_auto_create") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create stable if not exists st1 (ts timestamp, c1 int, c2 float, c3 binary(10)) tags(t1 int)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + //create topic + result = TaosQuery(conn, "create topic if not exists test_tmq_auto_topic with meta as DATABASE tmq_test_auto_create") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + defer func() { + result = TaosQuery(conn, "drop topic if exists test_tmq_auto_topic") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + result = TaosQuery(conn, "insert into ct1 using st1 tags(2000) values(now,1,2,'1')") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + //build consumer + conf := TMQConfNew() + // auto commit default is true then the commitCallback function will be called after 5 seconds + TMQConfSet(conf, "enable.auto.commit", "true") + TMQConfSet(conf, "group.id", "tg2") + TMQConfSet(conf, "msg.with.table.name", "true") + TMQConfSet(conf, "auto.offset.reset", "earliest") + c := make(chan *TMQCommitCallbackResult, 1) + h := cgo.NewHandle(c) + TMQConfSetAutoCommitCB(conf, h) + go func() { + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) + } + }() + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Error(err) + } + TMQConfDestroy(conf) + //build_topic_list + topicList := TMQListNew() + TMQListAppend(topicList, "test_tmq_auto_topic") + + //sync_consume_loop + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + errCode, list := TMQSubscription(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + size := TMQListGetSize(list) + r := TMQListToCArray(list, int(size)) + assert.Equal(t, []string{"test_tmq_auto_topic"}, r) + totalCount := 0 + c2 := make(chan *TMQCommitCallbackResult, 1) + h2 := cgo.NewHandle(c2) + for i := 0; i < 5; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + t.Log(message) + topic := TMQGetTopicName(message) + assert.Equal(t, "test_tmq_auto_topic", topic) + messageType := TMQGetResType(message) + if messageType != common.TMQ_RES_METADATA { + continue + } + pointer := TMQGetJsonMeta(message) + data := ParseJsonMeta(pointer) + t.Log(string(data)) + var meta tmqcommon.Meta + err = jsoniter.Unmarshal(data, &meta) + assert.NoError(t, err) + assert.Equal(t, "create", meta.Type) + for { + blockSize, errCode, block := TaosFetchRawBlock(message) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(message) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(message) + return + } + if blockSize == 0 { + break + } + tableName := TMQGetTableName(message) + assert.Equal(t, "ct1", tableName) + filedCount := TaosNumFields(message) + rh, err := ReadColumn(message, filedCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(message) + totalCount += blockSize + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + t.Log(data) + } + TaosFreeResult(message) + + TMQCommitAsync(tmq, nil, h2) + timer := time.NewTimer(time.Minute) + select { + case d := <-c2: + assert.Nil(t, d.GetError()) + assert.Equal(t, int32(0), d.ErrCode) + PutTMQCommitCallbackResult(d) + timer.Stop() + break + case <-timer.C: + timer.Stop() + t.Error("wait tmq commit callback timeout") + return + } + } + } + + errCode = TMQConsumerClose(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + assert.GreaterOrEqual(t, totalCount, 1) +} + +// @author: xftan +// @date: 2023/10/13 11:35 +// @description: test tmq get assignment +func TestTMQGetTopicAssignment(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + return + } + defer TaosClose(conn) + + defer func() { + if err = taosOperation(conn, "drop database if exists test_tmq_get_topic_assignment"); err != nil { + t.Error(err) + } + }() + + if err = taosOperation(conn, "create database if not exists test_tmq_get_topic_assignment vgroups 1 WAL_RETENTION_PERIOD 86400"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "use test_tmq_get_topic_assignment"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "create table if not exists t (ts timestamp,v int)"); err != nil { + t.Fatal(err) + return + } + + // create topic + if err = taosOperation(conn, "create topic if not exists test_tmq_assignment as select * from t"); err != nil { + t.Fatal(err) + return + } + + defer func() { + if err = taosOperation(conn, "drop topic if exists test_tmq_assignment"); err != nil { + t.Error(err) + } + }() + + conf := TMQConfNew() + defer TMQConfDestroy(conf) + TMQConfSet(conf, "group.id", "tg2") + TMQConfSet(conf, "auto.offset.reset", "earliest") + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Fatal(err) + } + defer TMQConsumerClose(tmq) + + topicList := TMQListNew() + TMQListAppend(topicList, "test_tmq_assignment") + + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Fatal(errors.NewError(int(errCode), errStr)) + return + } + + code, assignment := TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.Equal(t, int64(0), assignment[0].Offset) + assert.GreaterOrEqual(t, assignment[0].End, assignment[0].Offset) + end := assignment[0].End + vgID, vgCode := TaosGetTableVgID(conn, "test_tmq_get_topic_assignment", "t") + if vgCode != 0 { + t.Fatal(errors.NewError(int(vgCode), TMQErr2Str(code))) + } + assert.Equal(t, int32(vgID), assignment[0].VGroupID) + + _ = taosOperation(conn, "insert into t values(now,1)") + haveMessage := false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + TMQCommitSync(tmq, message) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + code, assignment = TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.GreaterOrEqual(t, assignment[0].End, end) + end = assignment[0].End + assert.Equal(t, int32(vgID), assignment[0].VGroupID) + + //seek + code = TMQOffsetSeek(tmq, "test_tmq_assignment", int32(vgID), 0) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + code, assignment = TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.Equal(t, int64(0), assignment[0].Offset) + assert.GreaterOrEqual(t, assignment[0].End, end) + end = assignment[0].End + assert.Equal(t, int32(vgID), assignment[0].VGroupID) + + haveMessage = false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + TMQCommitSync(tmq, message) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + code, assignment = TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.GreaterOrEqual(t, assignment[0].End, end) + end = assignment[0].End + assert.Equal(t, int32(vgID), assignment[0].VGroupID) + + // seek twice + code = TMQOffsetSeek(tmq, "test_tmq_assignment", int32(vgID), 1) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + code, assignment = TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.GreaterOrEqual(t, assignment[0].End, end) + end = assignment[0].End + assert.Equal(t, int32(vgID), assignment[0].VGroupID) + + haveMessage = false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + offset := TMQGetVgroupOffset(message) + assert.Greater(t, offset, int64(0)) + TMQCommitSync(tmq, message) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + code, assignment = TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.GreaterOrEqual(t, assignment[0].End, end) + assert.Equal(t, int32(vgID), assignment[0].VGroupID) +} + +func TestTMQPosition(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + return + } + defer TaosClose(conn) + + defer func() { + if err = taosOperation(conn, "drop database if exists test_tmq_position"); err != nil { + t.Error(err) + } + }() + + if err = taosOperation(conn, "create database if not exists test_tmq_position vgroups 1 WAL_RETENTION_PERIOD 86400"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "use test_tmq_position"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "create table if not exists t (ts timestamp,v int)"); err != nil { + t.Fatal(err) + return + } + + // create topic + if err = taosOperation(conn, "create topic if not exists test_tmq_position_topic as select * from t"); err != nil { + t.Fatal(err) + return + } + + defer func() { + if err = taosOperation(conn, "drop topic if exists test_tmq_position_topic"); err != nil { + t.Error(err) + } + }() + + conf := TMQConfNew() + defer TMQConfDestroy(conf) + TMQConfSet(conf, "group.id", "position") + TMQConfSet(conf, "auto.offset.reset", "earliest") + + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Fatal(err) + } + defer TMQConsumerClose(tmq) + + topicList := TMQListNew() + TMQListAppend(topicList, "test_tmq_position_topic") + + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Fatal(errors.NewError(int(errCode), errStr)) + return + } + _ = taosOperation(conn, "insert into t values(now,1)") + code, assignment := TMQGetTopicAssignment(tmq, "test_tmq_position_topic") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + vgID := assignment[0].VGroupID + position := TMQPosition(tmq, "test_tmq_position_topic", vgID) + assert.Equal(t, position, int64(0)) + haveMessage := false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + position := TMQPosition(tmq, "test_tmq_position_topic", vgID) + assert.Greater(t, position, int64(0)) + committed := TMQCommitted(tmq, "test_tmq_position_topic", vgID) + assert.Less(t, committed, int64(0)) + TMQCommitSync(tmq, message) + position = TMQPosition(tmq, "test_tmq_position_topic", vgID) + committed = TMQCommitted(tmq, "test_tmq_position_topic", vgID) + assert.Equal(t, position, committed) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + errCode = TMQUnsubscribe(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } +} + +func TestTMQCommitOffset(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + return + } + defer TaosClose(conn) + + defer func() { + if err = taosOperation(conn, "drop database if exists test_tmq_commit_offset"); err != nil { + t.Error(err) + } + }() + + if err = taosOperation(conn, "create database if not exists test_tmq_commit_offset vgroups 1 WAL_RETENTION_PERIOD 86400"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "use test_tmq_commit_offset"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "create table if not exists t (ts timestamp,v int)"); err != nil { + t.Fatal(err) + return + } + + // create topic + if err = taosOperation(conn, "create topic if not exists test_tmq_commit_offset_topic as select * from t"); err != nil { + t.Fatal(err) + return + } + + defer func() { + if err = taosOperation(conn, "drop topic if exists test_tmq_commit_offset_topic"); err != nil { + t.Error(err) + } + }() + + conf := TMQConfNew() + defer TMQConfDestroy(conf) + TMQConfSet(conf, "group.id", "commit") + TMQConfSet(conf, "auto.offset.reset", "earliest") + + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Fatal(err) + } + defer TMQConsumerClose(tmq) + + topicList := TMQListNew() + TMQListAppend(topicList, "test_tmq_commit_offset_topic") + + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Fatal(errors.NewError(int(errCode), errStr)) + return + } + _ = taosOperation(conn, "insert into t values(now,1)") + code, assignment := TMQGetTopicAssignment(tmq, "test_tmq_commit_offset_topic") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + vgID := assignment[0].VGroupID + haveMessage := false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + position := TMQPosition(tmq, "test_tmq_commit_offset_topic", vgID) + assert.Greater(t, position, int64(0)) + committed := TMQCommitted(tmq, "test_tmq_commit_offset_topic", vgID) + assert.Less(t, committed, int64(0)) + offset := TMQGetVgroupOffset(message) + code = TMQCommitOffsetSync(tmq, "test_tmq_commit_offset_topic", vgID, offset) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + committed = TMQCommitted(tmq, "test_tmq_commit_offset_topic", vgID) + assert.Equal(t, int64(offset), committed) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + errCode = TMQUnsubscribe(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } +} + +func TestTMQCommitOffsetAsync(t *testing.T) { + topic := "test_tmq_commit_offset_a_topic" + tableName := "test_tmq_commit_offset_a" + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + return + } + defer TaosClose(conn) + + defer func() { + if err = taosOperation(conn, "drop database if exists "+tableName); err != nil { + t.Error(err) + } + }() + + if err = taosOperation(conn, "create database if not exists "+tableName+" vgroups 1 WAL_RETENTION_PERIOD 86400"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "use "+tableName); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "create table if not exists t (ts timestamp,v int)"); err != nil { + t.Fatal(err) + return + } + + // create topic + + if err = taosOperation(conn, "create topic if not exists "+topic+" as select * from t"); err != nil { + t.Fatal(err) + return + } + + defer func() { + if err = taosOperation(conn, "drop topic if exists "+topic); err != nil { + t.Error(err) + } + }() + + conf := TMQConfNew() + defer TMQConfDestroy(conf) + TMQConfSet(conf, "group.id", "commit_a") + TMQConfSet(conf, "auto.offset.reset", "earliest") + + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Fatal(err) + } + defer TMQConsumerClose(tmq) + + topicList := TMQListNew() + TMQListAppend(topicList, topic) + + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Fatal(errors.NewError(int(errCode), errStr)) + return + } + _ = taosOperation(conn, "insert into t values(now,1)") + code, assignment := TMQGetTopicAssignment(tmq, topic) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + vgID := assignment[0].VGroupID + haveMessage := false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + position := TMQPosition(tmq, topic, vgID) + assert.Greater(t, position, int64(0)) + committed := TMQCommitted(tmq, topic, vgID) + assert.Less(t, committed, int64(0)) + offset := TMQGetVgroupOffset(message) + c := make(chan *TMQCommitCallbackResult, 1) + handler := cgo.NewHandle(c) + TMQCommitOffsetAsync(tmq, topic, vgID, offset, handler) + timer := time.NewTimer(time.Second * 5) + select { + case r := <-c: + code = r.ErrCode + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + timer.Stop() + case <-timer.C: + t.Fatal("commit async timeout") + timer.Stop() + } + committed = TMQCommitted(tmq, topic, vgID) + assert.Equal(t, int64(offset), committed) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + errCode = TMQUnsubscribe(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } +} + +func TestTMQCommitAsyncCallback(t *testing.T) { + topic := "test_tmq_commit_a_cb_topic" + tableName := "test_tmq_commit_a_cb" + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + return + } + defer TaosClose(conn) + + defer func() { + if err = taosOperation(conn, "drop database if exists "+tableName); err != nil { + t.Error(err) + } + }() + + if err = taosOperation(conn, "create database if not exists "+tableName+" vgroups 1 WAL_RETENTION_PERIOD 86400"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "use "+tableName); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "create table if not exists t (ts timestamp,v int)"); err != nil { + t.Fatal(err) + return + } + + // create topic + + if err = taosOperation(conn, "create topic if not exists "+topic+" as select * from t"); err != nil { + t.Fatal(err) + return + } + + defer func() { + if err = taosOperation(conn, "drop topic if exists "+topic); err != nil { + t.Error(err) + } + }() + + conf := TMQConfNew() + defer TMQConfDestroy(conf) + TMQConfSet(conf, "group.id", "commit_a") + TMQConfSet(conf, "enable.auto.commit", "false") + TMQConfSet(conf, "auto.offset.reset", "earliest") + TMQConfSet(conf, "auto.commit.interval.ms", "100") + c := make(chan *TMQCommitCallbackResult, 1) + h := cgo.NewHandle(c) + TMQConfSetAutoCommitCB(conf, h) + go func() { + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) + } + }() + + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Fatal(err) + } + defer TMQConsumerClose(tmq) + + topicList := TMQListNew() + TMQListAppend(topicList, topic) + + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Fatal(errors.NewError(int(errCode), errStr)) + return + } + _ = taosOperation(conn, "insert into t values(now,1)") + code, assignment := TMQGetTopicAssignment(tmq, topic) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + vgID := assignment[0].VGroupID + haveMessage := false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + position := TMQPosition(tmq, topic, vgID) + assert.Greater(t, position, int64(0)) + committed := TMQCommitted(tmq, topic, vgID) + assert.Less(t, committed, int64(0)) + offset := TMQGetVgroupOffset(message) + TMQCommitOffsetSync(tmq, topic, vgID, offset) + committed = TMQCommitted(tmq, topic, vgID) + assert.Equal(t, offset, committed) + TaosFreeResult(message) + } + } + assert.True(t, haveMessage, "expect have message") + committed := TMQCommitted(tmq, topic, vgID) + t.Log(committed) + code, assignment = TMQGetTopicAssignment(tmq, topic) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + t.Log(assignment[0].Offset) + TMQCommitOffsetSync(tmq, topic, vgID, 1) + committed = TMQCommitted(tmq, topic, vgID) + assert.Equal(t, int64(1), committed) + code, assignment = TMQGetTopicAssignment(tmq, topic) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + t.Log(assignment[0].Offset) + position := TMQPosition(tmq, topic, vgID) + t.Log(position) + errCode = TMQUnsubscribe(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } +} + +func taosOperation(conn unsafe.Pointer, sql string) (err error) { + res := TaosQuery(conn, sql) + defer TaosFreeResult(res) + code := TaosError(res) + if code != 0 { + err = errors.NewError(code, TaosErrorStr(res)) + } + return +} diff --git a/driver/wrapper/tmqcb.go b/driver/wrapper/tmqcb.go new file mode 100644 index 00000000..94ca7c37 --- /dev/null +++ b/driver/wrapper/tmqcb.go @@ -0,0 +1,49 @@ +package wrapper + +/* +#include +#include +#include +#include +*/ +import "C" +import ( + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +//typedef void(tmq_commit_cb(tmq_t *, int32_t code, void *param)); + +//export TMQCommitCB +func TMQCommitCB(consumer unsafe.Pointer, resp C.int32_t, param unsafe.Pointer) { + c := (*(*cgo.Handle)(param)).Value().(chan *TMQCommitCallbackResult) + r := GetTMQCommitCallbackResult(int32(resp), consumer) + defer func() { + // Avoid panic due to channel closed + _ = recover() + }() + c <- r +} + +//export TMQAutoCommitCB +func TMQAutoCommitCB(consumer unsafe.Pointer, resp C.int32_t, param unsafe.Pointer) { + c := (*(*cgo.Handle)(param)).Value().(chan *TMQCommitCallbackResult) + r := GetTMQCommitCallbackResult(int32(resp), consumer) + defer func() { + // Avoid panic due to channel closed + _ = recover() + }() + c <- r +} + +//export TMQCommitOffsetCB +func TMQCommitOffsetCB(consumer unsafe.Pointer, resp C.int32_t, param unsafe.Pointer) { + c := (*(*cgo.Handle)(param)).Value().(chan *TMQCommitCallbackResult) + r := GetTMQCommitCallbackResult(int32(resp), consumer) + defer func() { + // Avoid panic due to channel closed + _ = recover() + }() + c <- r +} diff --git a/driver/wrapper/whitelist.go b/driver/wrapper/whitelist.go new file mode 100644 index 00000000..65788e1f --- /dev/null +++ b/driver/wrapper/whitelist.go @@ -0,0 +1,29 @@ +package wrapper + +/* +#cgo CFLAGS: -IC:/TDengine/include -I/usr/include +#cgo linux LDFLAGS: -L/usr/lib -ltaos +#cgo windows LDFLAGS: -LC:/TDengine/driver -ltaos +#cgo darwin LDFLAGS: -L/usr/local/lib -ltaos +#include +#include +#include +#include +extern void WhitelistCallback(void *param, int code, TAOS *taos, int numOfWhiteLists, uint64_t* pWhiteLists); +void taos_fetch_whitelist_a_wrapper(TAOS *taos, void *param){ + return taos_fetch_whitelist_a(taos, WhitelistCallback, param); +}; +*/ +import "C" +import ( + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +// typedef void (*__taos_async_whitelist_fn_t)(void *param, int code, TAOS *taos, int numOfWhiteLists, uint64_t* pWhiteLists); + +// TaosFetchWhitelistA DLL_EXPORT void taos_fetch_whitelist_a(TAOS *taos, __taos_async_whitelist_fn_t fp, void *param); +func TaosFetchWhitelistA(taosConnect unsafe.Pointer, caller cgo.Handle) { + C.taos_fetch_whitelist_a_wrapper(taosConnect, caller.Pointer()) +} diff --git a/driver/wrapper/whitelist_test.go b/driver/wrapper/whitelist_test.go new file mode 100644 index 00000000..d985ff06 --- /dev/null +++ b/driver/wrapper/whitelist_test.go @@ -0,0 +1,21 @@ +package wrapper + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +func TestGetWhiteList(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + c := make(chan *WhitelistResult, 1) + handler := cgo.NewHandle(c) + TaosFetchWhitelistA(conn, handler) + data := <-c + assert.Equal(t, int32(0), data.ErrCode) + assert.Equal(t, 1, len(data.IPNets)) + assert.Equal(t, "0.0.0.0/0", data.IPNets[0].String()) +} diff --git a/driver/wrapper/whitelistcb.go b/driver/wrapper/whitelistcb.go new file mode 100644 index 00000000..aab5471f --- /dev/null +++ b/driver/wrapper/whitelistcb.go @@ -0,0 +1,35 @@ +package wrapper + +import "C" +import ( + "net" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +type WhitelistResult struct { + ErrCode int32 + IPNets []*net.IPNet +} + +//export WhitelistCallback +func WhitelistCallback(param unsafe.Pointer, code int, taosConnect unsafe.Pointer, numOfWhiteLists int, pWhiteLists unsafe.Pointer) { + c := (*(*cgo.Handle)(param)).Value().(chan *WhitelistResult) + if code != 0 { + c <- &WhitelistResult{ErrCode: int32(code)} + return + } + ips := make([]*net.IPNet, 0, numOfWhiteLists) + for i := 0; i < numOfWhiteLists; i++ { + ipNet := make([]byte, 8) + for j := 0; j < 8; j++ { + ipNet[j] = *(*byte)(unsafe.Pointer(uintptr(pWhiteLists) + uintptr(i*8) + uintptr(j))) + } + ip := net.IP{ipNet[0], ipNet[1], ipNet[2], ipNet[3]} + ones := int(ipNet[4]) + ipMask := net.CIDRMask(ones, 32) + ips = append(ips, &net.IPNet{IP: ip, Mask: ipMask}) + } + c <- &WhitelistResult{IPNets: ips} +} diff --git a/driver/wrapper/whitelistcb_test.go b/driver/wrapper/whitelistcb_test.go new file mode 100644 index 00000000..9cc62a4f --- /dev/null +++ b/driver/wrapper/whitelistcb_test.go @@ -0,0 +1,57 @@ +package wrapper + +import ( + "net" + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" +) + +func TestWhitelistCallback_ErrorCode(t *testing.T) { + // Create a channel to receive the result + resultChan := make(chan *WhitelistResult, 1) + handle := cgo.NewHandle(resultChan) + // Simulate an error (code != 0) + go WhitelistCallback(handle.Pointer(), 1, nil, 0, nil) + + // Expect the result to have an error code + result := <-resultChan + assert.Equal(t, int32(1), result.ErrCode) + assert.Nil(t, result.IPNets) // No IPs should be returned +} + +func TestWhitelistCallback_Success(t *testing.T) { + // Prepare the test data: a list of byte slices representing IPs and masks + ipList := []byte{ + 192, 168, 1, 1, 24, // 192.168.1.1/24 + 0, 0, 0, + 10, 0, 0, 1, 16, // 10.0.0.1/16 + } + + // Create a channel to receive the result + resultChan := make(chan *WhitelistResult, 1) + + // Cast the byte slice to an unsafe pointer + pWhiteLists := unsafe.Pointer(&ipList[0]) + handle := cgo.NewHandle(resultChan) + // Simulate a successful callback (code == 0) + go WhitelistCallback(handle.Pointer(), 0, nil, 2, pWhiteLists) + + // Expect the result to have two IPNets + result := <-resultChan + assert.Equal(t, int32(0), result.ErrCode) + assert.Len(t, result.IPNets, 2) + + // Validate the first IPNet (192.168.1.1/24) + assert.Equal(t, net.IPv4(192, 168, 1, 1).To4(), result.IPNets[0].IP) + + ones, _ := result.IPNets[0].Mask.Size() + assert.Equal(t, 24, ones) + + // Validate the second IPNet (10.0.0.1/16) + assert.Equal(t, net.IPv4(10, 0, 0, 1).To4(), result.IPNets[1].IP) + ones, _ = result.IPNets[1].Mask.Size() + assert.Equal(t, 16, ones) +} diff --git a/go.mod b/go.mod index 1cecbbc9..07d23e92 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,6 @@ require ( github.com/spf13/viper v1.14.0 github.com/stretchr/testify v1.9.0 github.com/swaggo/swag v1.8.8 - github.com/taosdata/driver-go/v3 v3.5.1-0.20241101015534-8fb37f82db51 github.com/taosdata/file-rotatelogs/v2 v2.5.2 go.uber.org/automaxprocs v1.5.1 golang.org/x/sync v0.1.0 diff --git a/go.sum b/go.sum index ba7dcb3b..5e9b5f10 100644 --- a/go.sum +++ b/go.sum @@ -2609,8 +2609,6 @@ github.com/swaggo/swag v1.8.8/go.mod h1:ezQVUUhly8dludpVk+/PuwJWvLLanB13ygV5Pr9e github.com/syndtr/gocapability v0.0.0-20170704070218-db04d3cc01c8/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= -github.com/taosdata/driver-go/v3 v3.5.1-0.20241101015534-8fb37f82db51 h1:diWG8X6vBAERTPk9DnsMfcFH/gN5s9v/WZ2C+gJMcs4= -github.com/taosdata/driver-go/v3 v3.5.1-0.20241101015534-8fb37f82db51/go.mod h1:H2vo/At+rOPY1aMzUV9P49SVX7NlXb3LAbKw+MCLrmU= github.com/taosdata/file-rotatelogs/v2 v2.5.2 h1:6ryjwDdKqQtWrkVq9OKj4gvMING/f+fDluMAAe2DIXQ= github.com/taosdata/file-rotatelogs/v2 v2.5.2/go.mod h1:Qm99Lh0iMZouGgyy++JgTqKvP5FQw1ruR5jkWF7e1n0= github.com/tbrandon/mbserver v0.0.0-20170611213546-993e1772cc62/go.mod h1:qUzPVlSj2UgxJkVbH0ZwuuiR46U8RBMDT5KLY78Ifpw= diff --git a/plugin/collectd/config.go b/plugin/collectd/config.go index 2f96ab49..388727af 100644 --- a/plugin/collectd/config.go +++ b/plugin/collectd/config.go @@ -3,7 +3,7 @@ package collectd import ( "github.com/spf13/pflag" "github.com/spf13/viper" - "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/taosadapter/v3/driver/common" ) type Config struct { diff --git a/plugin/collectd/plugin_test.go b/plugin/collectd/plugin_test.go index ec64395d..e69501f8 100644 --- a/plugin/collectd/plugin_test.go +++ b/plugin/collectd/plugin_test.go @@ -7,16 +7,17 @@ import ( "net" "testing" "time" + "unsafe" "collectd.org/api" "collectd.org/network" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/af" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" ) // @author: xftan @@ -28,15 +29,12 @@ func TestCollectd(t *testing.T) { t.Error(err) return } - afC, err := af.NewConnector(conn) - assert.NoError(t, err) defer func() { - err = afC.Close() - assert.NoError(t, err) + wrapper.TaosClose(conn) }() - _, err = afC.Exec("drop database if exists collectd") + err = exec(conn, "drop database if exists collectd") assert.NoError(t, err) - _, err = afC.Exec("create database if not exists collectd") + err = exec(conn, "create database if not exists collectd") assert.NoError(t, err) //nolint:staticcheck rand.Seed(time.Now().UnixNano()) @@ -99,36 +97,60 @@ func TestCollectd(t *testing.T) { } wrapper.TaosFreeResult(r) }() - r, err := afC.Query("select last(`value`) from collectd.`cpu_value`") - if err != nil { - t.Error(err) - return - } - defer func() { - err = r.Close() - assert.NoError(t, err) - }() - values := make([]driver.Value, 1) - err = r.Next(values) + values, err := query(conn, "select last(`value`) from collectd.`cpu_value`") assert.NoError(t, err) - if int32(values[0].(float64)) != number { + if int32(values[0][0].(float64)) != number { t.Errorf("got %f expect %d", values[0], number) } - r, err = afC.Query("select `ttl` from information_schema.ins_tables " + + values, err = query(conn, "select `ttl` from information_schema.ins_tables "+ " where db_name='collectd' and stable_name='cpu_value'") if err != nil { t.Error(err) return } - defer func() { - err = r.Close() - assert.NoError(t, err) - }() - values = make([]driver.Value, 1) - err = r.Next(values) - assert.NoError(t, err) - if values[0].(int32) != 1000 { + if values[0][0].(int32) != 1000 { t.Fatal("ttl miss") } } + +func exec(conn unsafe.Pointer, sql string) error { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return errors.NewError(code, errStr) + } + return nil +} + +func query(conn unsafe.Pointer, sql string) ([][]driver.Value, error) { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(code, errStr) + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + return nil, err + } + precision := wrapper.TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(errCode, errStr) + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + return result, nil +} diff --git a/plugin/influxdb/plugin.go b/plugin/influxdb/plugin.go index 642f58ce..3c1696ac 100644 --- a/plugin/influxdb/plugin.go +++ b/plugin/influxdb/plugin.go @@ -7,9 +7,9 @@ import ( "strings" "github.com/gin-gonic/gin" - tErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/commonpool" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" "github.com/taosdata/taosadapter/v3/plugin" diff --git a/plugin/influxdb/plugin_test.go b/plugin/influxdb/plugin_test.go index 2c19cddc..552f2468 100644 --- a/plugin/influxdb/plugin_test.go +++ b/plugin/influxdb/plugin_test.go @@ -9,15 +9,16 @@ import ( "strings" "testing" "time" + "unsafe" "github.com/gin-gonic/gin" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/af" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" ) @@ -40,22 +41,14 @@ func TestInfluxdb(t *testing.T) { t.Error(err) return } - afC, err := af.NewConnector(conn) - assert.NoError(t, err) defer func() { - err = afC.Close() - assert.NoError(t, err) + wrapper.TaosClose(conn) }() - _, err = afC.Exec("create database if not exists test_plugin_influxdb") + err = exec(conn, "drop database if exists test_plugin_influxdb") assert.NoError(t, err) defer func() { - r := wrapper.TaosQuery(conn, "drop database if exists test_plugin_influxdb") - code := wrapper.TaosError(r) - if code != 0 { - errStr := wrapper.TaosErrorStr(r) - t.Error(errors.NewError(code, errStr)) - } - wrapper.TaosFreeResult(r) + err = exec(conn, "drop database if exists test_plugin_influxdb") + assert.NoError(t, err) }() err = p.Init(router) assert.NoError(t, err) @@ -85,29 +78,14 @@ func TestInfluxdb(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 400, w.Code) time.Sleep(time.Second) - r, err := afC.Query("select last(*) from test_plugin_influxdb.`measurement`") - if err != nil { - t.Error(err) - return - } - defer func() { - err = r.Close() - assert.NoError(t, err) - }() - fieldCount := len(r.Columns()) - values := make([]driver.Value, fieldCount) - err = r.Next(values) + values, err := query(conn, "select * from test_plugin_influxdb.`measurement`") assert.NoError(t, err) - keyMap := map[string]int{} - for i, s := range r.Columns() { - keyMap[s] = i - } - if values[3].(string) != "Launch 🚀" { - t.Errorf("got %s expect %s", values[3], "Launch 🚀") + if values[0][3].(string) != "Launch 🚀" { + t.Errorf("got %s expect %s", values[0][3], "Launch 🚀") return } - if int32(values[1].(int64)) != number { - t.Errorf("got %d expect %d", values[1].(int64), number) + if int32(values[0][1].(int64)) != number { + t.Errorf("got %d expect %d", values[0][1].(int64), number) return } @@ -117,21 +95,51 @@ func TestInfluxdb(t *testing.T) { req.RemoteAddr = "127.0.0.1:33333" router.ServeHTTP(w, req) time.Sleep(time.Second) - - r, err = afC.Query("select `ttl` from information_schema.ins_tables " + + values, err = query(conn, "select `ttl` from information_schema.ins_tables "+ " where db_name='test_plugin_influxdb_ttl' and stable_name='measurement_ttl'") - if err != nil { - t.Error(err) - return - } - defer func() { - err = r.Close() - assert.NoError(t, err) - }() - values = make([]driver.Value, 1) - err = r.Next(values) assert.NoError(t, err) - if values[0].(int32) != 1000 { + if values[0][0].(int32) != 1000 { t.Fatal("ttl miss") } } + +func exec(conn unsafe.Pointer, sql string) error { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return errors.NewError(code, errStr) + } + return nil +} + +func query(conn unsafe.Pointer, sql string) ([][]driver.Value, error) { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(code, errStr) + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + return nil, err + } + precision := wrapper.TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(errCode, errStr) + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + return result, nil +} diff --git a/plugin/nodeexporter/config.go b/plugin/nodeexporter/config.go index 340db76c..684a3637 100644 --- a/plugin/nodeexporter/config.go +++ b/plugin/nodeexporter/config.go @@ -5,7 +5,7 @@ import ( "github.com/spf13/pflag" "github.com/spf13/viper" - "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/taosadapter/v3/driver/common" ) type Config struct { diff --git a/plugin/nodeexporter/plugin.go b/plugin/nodeexporter/plugin.go index 61e39cf0..7b0e0658 100644 --- a/plugin/nodeexporter/plugin.go +++ b/plugin/nodeexporter/plugin.go @@ -17,8 +17,8 @@ import ( "github.com/gin-gonic/gin" tmetric "github.com/influxdata/telegraf/metric" "github.com/influxdata/telegraf/plugins/serializers/influx" - "github.com/taosdata/driver-go/v3/common" "github.com/taosdata/taosadapter/v3/db/commonpool" + "github.com/taosdata/taosadapter/v3/driver/common" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" "github.com/taosdata/taosadapter/v3/plugin" diff --git a/plugin/nodeexporter/plugin_test.go b/plugin/nodeexporter/plugin_test.go index c6340d87..852af10b 100644 --- a/plugin/nodeexporter/plugin_test.go +++ b/plugin/nodeexporter/plugin_test.go @@ -6,12 +6,15 @@ import ( "net/http/httptest" "testing" "time" + "unsafe" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/af" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" ) @@ -52,17 +55,14 @@ func TestNodeExporter_Gather(t *testing.T) { viper.Set("node_exporter.urls", []string{api}) viper.Set("node_exporter.gatherDuration", time.Second) viper.Set("node_exporter.ttl", 1000) - conn, err := af.Open("", "", "", "", 0) + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) assert.NoError(t, err) defer func() { - err = conn.Close() - if err != nil { - t.Error(err) - } + wrapper.TaosClose(conn) }() - _, err = conn.Exec("create database if not exists node_exporter precision 'ns'") + err = exec(conn, "drop database if exists node_exporter") assert.NoError(t, err) - err = conn.SelectDB("node_exporter") + err = exec(conn, "use node_exporter") assert.NoError(t, err) n := NodeExporter{} err = n.Init(nil) @@ -70,41 +70,58 @@ func TestNodeExporter_Gather(t *testing.T) { err = n.Start() assert.NoError(t, err) time.Sleep(time.Second * 2) - rows, err := conn.Query("select last(`value`) as `value` from node_exporter.test_metric;") - assert.NoError(t, err) - defer func() { - err = rows.Close() - if err != nil { - t.Error(err) - } - }() - assert.Equal(t, 1, len(rows.Columns())) - d := make([]driver.Value, 1) - err = rows.Next(d) + values, err := query(conn, "select last(`value`) as `value` from node_exporter.go_gc_duration_seconds;") assert.NoError(t, err) - assert.Equal(t, float64(1), d[0]) + assert.Equal(t, float64(1), values[0][0]) err = n.Stop() assert.NoError(t, err) - - rows, err = conn.Query("select `ttl` from information_schema.ins_tables " + + values, err = query(conn, "select `ttl` from information_schema.ins_tables "+ " where db_name='node_exporter' and stable_name='test_metric'") - if err != nil { - t.Error(err) - return - } - defer func() { - err = rows.Close() - if err != nil { - t.Error(err) - } - }() - values := make([]driver.Value, 1) - err = rows.Next(values) assert.NoError(t, err) - if values[0].(int32) != 1000 { + if values[0][0].(int32) != 1000 { t.Fatal("ttl miss") } - - _, err = conn.Exec("drop database if exists node_exporter") + err = exec(conn, "drop database if exists node_exporter") assert.NoError(t, err) } + +func exec(conn unsafe.Pointer, sql string) error { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return errors.NewError(code, errStr) + } + return nil +} + +func query(conn unsafe.Pointer, sql string) ([][]driver.Value, error) { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(code, errStr) + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + return nil, err + } + precision := wrapper.TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(errCode, errStr) + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + return result, nil +} diff --git a/plugin/opentsdb/plugin.go b/plugin/opentsdb/plugin.go index 0ceb0477..f22cfc35 100644 --- a/plugin/opentsdb/plugin.go +++ b/plugin/opentsdb/plugin.go @@ -9,9 +9,9 @@ import ( "strconv" "github.com/gin-gonic/gin" - tErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/commonpool" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" "github.com/taosdata/taosadapter/v3/plugin" diff --git a/plugin/opentsdb/plugin_test.go b/plugin/opentsdb/plugin_test.go index 22451901..a2101fa0 100644 --- a/plugin/opentsdb/plugin_test.go +++ b/plugin/opentsdb/plugin_test.go @@ -9,15 +9,16 @@ import ( "strings" "testing" "time" + "unsafe" "github.com/gin-gonic/gin" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/af" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" ) @@ -41,13 +42,10 @@ func TestOpentsdb(t *testing.T) { t.Error(err) return } - afC, err := af.NewConnector(conn) - assert.NoError(t, err) defer func() { - err = afC.Close() - assert.NoError(t, err) + wrapper.TaosClose(conn) }() - _, err = afC.Exec("create database if not exists test_plugin_opentsdb_http_json") + err = exec(conn, "drop database if exists test_plugin_opentsdb_http_telnet") assert.NoError(t, err) err = p.Init(router) assert.NoError(t, err) @@ -82,87 +80,74 @@ func TestOpentsdb(t *testing.T) { assert.Equal(t, 204, w.Code) defer func() { - r := wrapper.TaosQuery(conn, "drop database if exists test_plugin_opentsdb_http_json") - code := wrapper.TaosError(r) - if code != 0 { - errStr := wrapper.TaosErrorStr(r) - t.Error(errors.NewError(code, errStr)) - } - wrapper.TaosFreeResult(r) - }() - defer func() { - r := wrapper.TaosQuery(conn, "drop database if exists test_plugin_opentsdb_http_telnet") - code := wrapper.TaosError(r) - if code != 0 { - errStr := wrapper.TaosErrorStr(r) - t.Error(errors.NewError(code, errStr)) - } - wrapper.TaosFreeResult(r) + err = exec(conn, "drop database if exists test_plugin_opentsdb_http_json") + assert.NoError(t, err) }() - - r, err := afC.Query("select last(_value) from test_plugin_opentsdb_http_json.`sys_cpu_nice`") - if err != nil { - t.Error(err) - return - } defer func() { - err = r.Close() + err = exec(conn, "drop database if exists test_plugin_opentsdb_http_telnet") assert.NoError(t, err) }() - values := make([]driver.Value, 1) - err = r.Next(values) + values, err := query(conn, "select last(_value) from test_plugin_opentsdb_http_json.`sys_cpu_nice`") assert.NoError(t, err) - if int32(values[0].(float64)) != number { + if int32(values[0][0].(float64)) != number { t.Errorf("got %f expect %d", values[0], number) } - - r2, err := afC.Query("select last(_value) from test_plugin_opentsdb_http_telnet.`metric`") - if err != nil { - t.Error(err) - return - } - defer func() { - err = r2.Close() - assert.NoError(t, err) - }() - values = make([]driver.Value, 1) - err = r2.Next(values) + values, err = query(conn, "select last(_value) from test_plugin_opentsdb_http_telnet.`metric`") assert.NoError(t, err) - if int32(values[0].(float64)) != number { + if int32(values[0][0].(float64)) != number { t.Errorf("got %f expect %d", values[0], number) } - - rows, err := afC.Query("select `ttl` from information_schema.ins_tables " + + values, err = query(conn, "select `ttl` from information_schema.ins_tables "+ " where db_name='test_plugin_opentsdb_http_json' and stable_name='sys_cpu_nice'") - if err != nil { - t.Error(err) - return + assert.NoError(t, err) + if values[0][0].(int32) != 1000 { + t.Fatal("ttl miss") } - defer func() { - err = rows.Close() - assert.NoError(t, err) - }() - values = make([]driver.Value, 1) - err = rows.Next(values) + values, err = query(conn, "select `ttl` from information_schema.ins_tables "+ + " where db_name='test_plugin_opentsdb_http_telnet' and stable_name='metric'") assert.NoError(t, err) - if values[0].(int32) != 1000 { + if values[0][0].(int32) != 1000 { t.Fatal("ttl miss") } +} - rows, err = afC.Query("select `ttl` from information_schema.ins_tables " + - " where db_name='test_plugin_opentsdb_http_telnet' and stable_name='metric'") +func exec(conn unsafe.Pointer, sql string) error { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return errors.NewError(code, errStr) + } + return nil +} + +func query(conn unsafe.Pointer, sql string) ([][]driver.Value, error) { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(code, errStr) + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) if err != nil { - t.Error(err) - return + return nil, err } - defer func() { - err = rows.Close() - assert.NoError(t, err) - }() - values = make([]driver.Value, 1) - err = rows.Next(values) - assert.NoError(t, err) - if values[0].(int32) != 1000 { - t.Fatal("ttl miss") + precision := wrapper.TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(errCode, errStr) + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) } + return result, nil } diff --git a/plugin/opentsdbtelnet/config.go b/plugin/opentsdbtelnet/config.go index 99aadb5b..9b79858b 100644 --- a/plugin/opentsdbtelnet/config.go +++ b/plugin/opentsdbtelnet/config.go @@ -5,7 +5,7 @@ import ( "github.com/spf13/pflag" "github.com/spf13/viper" - "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/taosadapter/v3/driver/common" ) type Config struct { diff --git a/plugin/opentsdbtelnet/plugin_test.go b/plugin/opentsdbtelnet/plugin_test.go index d2c8b3f4..a8b937ee 100644 --- a/plugin/opentsdbtelnet/plugin_test.go +++ b/plugin/opentsdbtelnet/plugin_test.go @@ -7,14 +7,15 @@ import ( "net" "testing" "time" + "unsafe" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/af" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/plugin/opentsdbtelnet" ) @@ -37,13 +38,10 @@ func TestPlugin(t *testing.T) { t.Error(err) return } - afC, err := af.NewConnector(conn) - assert.NoError(t, err) defer func() { - err = afC.Close() - assert.NoError(t, err) + wrapper.TaosClose(conn) }() - _, err = afC.Exec("create database if not exists opentsdb_telnet") + err = exec(conn, "drop database if exists opentsdb_telnet") assert.NoError(t, err) err = p.Init(nil) assert.NoError(t, err) @@ -73,37 +71,56 @@ func TestPlugin(t *testing.T) { } wrapper.TaosFreeResult(r) }() - - r, err := afC.Query("select last(_value) from opentsdb_telnet.`sys_if_bytes_out`") - if err != nil { - t.Error(err) - return - } - defer func() { - err = r.Close() - assert.NoError(t, err) - }() - values := make([]driver.Value, 1) - err = r.Next(values) + values, err := query(conn, "select last(_value) from opentsdb_telnet.`sys_if_bytes_out`") assert.NoError(t, err) - if int32(values[0].(float64)) != number { + if int32(values[0][0].(float64)) != number { t.Errorf("got %f expect %d", values[0], number) } - - rows, err := afC.Query("select `ttl` from information_schema.ins_tables " + + values, err = query(conn, "select `ttl` from information_schema.ins_tables "+ " where db_name='opentsdb_telnet' and stable_name='sys_if_bytes_out'") - if err != nil { - t.Error(err) - return - } - defer func() { - err = rows.Close() - assert.NoError(t, err) - }() - values = make([]driver.Value, 1) - err = rows.Next(values) assert.NoError(t, err) - if values[0].(int32) != 1000 { + if values[0][0].(int32) != 1000 { t.Fatal("ttl miss") } } + +func exec(conn unsafe.Pointer, sql string) error { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return errors.NewError(code, errStr) + } + return nil +} + +func query(conn unsafe.Pointer, sql string) ([][]driver.Value, error) { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(code, errStr) + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + return nil, err + } + precision := wrapper.TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(errCode, errStr) + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + return result, nil +} diff --git a/plugin/prometheus/plugin.go b/plugin/prometheus/plugin.go index 5b462a1f..13744925 100644 --- a/plugin/prometheus/plugin.go +++ b/plugin/prometheus/plugin.go @@ -10,8 +10,8 @@ import ( "github.com/gogo/protobuf/proto" "github.com/golang/snappy" "github.com/prometheus/prometheus/prompb" - tErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/taosadapter/v3/db/commonpool" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" "github.com/taosdata/taosadapter/v3/plugin" diff --git a/plugin/prometheus/plugin_test.go b/plugin/prometheus/plugin_test.go index 76409f49..dfcca1ea 100644 --- a/plugin/prometheus/plugin_test.go +++ b/plugin/prometheus/plugin_test.go @@ -14,9 +14,9 @@ import ( "github.com/prometheus/prometheus/prompb" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" ) diff --git a/plugin/prometheus/process.go b/plugin/prometheus/process.go index f1880fb7..10b9973b 100644 --- a/plugin/prometheus/process.go +++ b/plugin/prometheus/process.go @@ -16,13 +16,13 @@ import ( jsoniter "github.com/json-iterator/go" "github.com/prometheus/prometheus/prompb" - "github.com/taosdata/driver-go/v3/common" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + "github.com/taosdata/taosadapter/v3/driver/common" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" prompbWrite "github.com/taosdata/taosadapter/v3/plugin/prometheus/proto/write" diff --git a/plugin/statsd/config.go b/plugin/statsd/config.go index 66e49f6f..98139851 100644 --- a/plugin/statsd/config.go +++ b/plugin/statsd/config.go @@ -5,7 +5,7 @@ import ( "github.com/spf13/pflag" "github.com/spf13/viper" - "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/taosadapter/v3/driver/common" ) type Config struct { diff --git a/plugin/statsd/plugin_test.go b/plugin/statsd/plugin_test.go index 14b3aee5..2bcb2bd6 100644 --- a/plugin/statsd/plugin_test.go +++ b/plugin/statsd/plugin_test.go @@ -7,14 +7,15 @@ import ( "net" "testing" "time" + "unsafe" "github.com/spf13/viper" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/af" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/common/parser" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" ) @@ -36,13 +37,10 @@ func TestStatsd(t *testing.T) { t.Error(err) return } - afC, err := af.NewConnector(conn) - assert.NoError(t, err) defer func() { - err = afC.Close() - assert.NoError(t, err) + wrapper.TaosClose(conn) }() - _, err = afC.Exec("create database if not exists statsd") + err = exec(conn, "drop database if exists statsd") assert.NoError(t, err) err = p.Init(nil) assert.NoError(t, err) @@ -68,37 +66,56 @@ func TestStatsd(t *testing.T) { } wrapper.TaosFreeResult(r) }() - - r, err := afC.Query("select last(`value`) from statsd.`foo`") - if err != nil { - t.Error(err) - return - } - defer func() { - err = r.Close() - assert.NoError(t, err) - }() - values := make([]driver.Value, 1) - err = r.Next(values) + values, err := query(conn, "select last(`value`) from statsd.`foo`") assert.NoError(t, err) - if int32(values[0].(int64)) != number { + if int32(values[0][0].(int64)) != number { t.Errorf("got %f expect %d", values[0], number) } - - rows, err := afC.Query("select `ttl` from information_schema.ins_tables " + + values, err = query(conn, "select `ttl` from information_schema.ins_tables "+ " where db_name='statsd' and stable_name='foo'") - if err != nil { - t.Error(err) - return - } - defer func() { - err = rows.Close() - assert.NoError(t, err) - }() - values = make([]driver.Value, 1) - err = rows.Next(values) assert.NoError(t, err) - if values[0].(int32) != 1000 { + if values[0][0].(int32) != 1000 { t.Fatal("ttl miss") } } + +func exec(conn unsafe.Pointer, sql string) error { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return errors.NewError(code, errStr) + } + return nil +} + +func query(conn unsafe.Pointer, sql string) ([][]driver.Value, error) { + res := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(res) + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(code, errStr) + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + return nil, err + } + precision := wrapper.TaosResultPrecision(res) + var result [][]driver.Value + for { + columns, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != 0 { + errStr := wrapper.TaosErrorStr(res) + return nil, errors.NewError(errCode, errStr) + } + if columns == 0 { + break + } + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) + result = append(result, r...) + } + return result, nil +} diff --git a/schemaless/capi/influxdb.go b/schemaless/capi/influxdb.go index 2cc4b1ed..3e90064c 100644 --- a/schemaless/capi/influxdb.go +++ b/schemaless/capi/influxdb.go @@ -5,10 +5,10 @@ import ( "unsafe" "github.com/sirupsen/logrus" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/generator" ) diff --git a/schemaless/capi/influxdb_test.go b/schemaless/capi/influxdb_test.go index 45527e6f..755ca8f4 100644 --- a/schemaless/capi/influxdb_test.go +++ b/schemaless/capi/influxdb_test.go @@ -4,8 +4,8 @@ import ( "testing" "unsafe" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/schemaless/capi" ) diff --git a/schemaless/capi/opentsdb.go b/schemaless/capi/opentsdb.go index 67725e03..20aea201 100644 --- a/schemaless/capi/opentsdb.go +++ b/schemaless/capi/opentsdb.go @@ -5,10 +5,10 @@ import ( "unsafe" "github.com/sirupsen/logrus" - tErrors "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/tools/generator" ) diff --git a/schemaless/capi/opentsdb_test.go b/schemaless/capi/opentsdb_test.go index c6e52e92..8207471b 100644 --- a/schemaless/capi/opentsdb_test.go +++ b/schemaless/capi/opentsdb_test.go @@ -7,10 +7,10 @@ import ( "unsafe" "github.com/spf13/viper" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db" + "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/schemaless/capi" ) diff --git a/tools/ctools/block.go b/tools/ctools/block.go index 0f73900e..72f0f08f 100644 --- a/tools/ctools/block.go +++ b/tools/ctools/block.go @@ -5,8 +5,8 @@ import ( "strconv" "unsafe" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/parser" "github.com/taosdata/taosadapter/v3/tools" "github.com/taosdata/taosadapter/v3/tools/jsonbuilder" ) diff --git a/tools/ctools/block_test.go b/tools/ctools/block_test.go index 79e7e1ca..03390fba 100644 --- a/tools/ctools/block_test.go +++ b/tools/ctools/block_test.go @@ -7,8 +7,8 @@ import ( "unsafe" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/taosadapter/v3/driver/common" + "github.com/taosdata/taosadapter/v3/driver/common/parser" "github.com/taosdata/taosadapter/v3/tools" "github.com/taosdata/taosadapter/v3/tools/jsonbuilder" ) diff --git a/tools/parseblock/parse.go b/tools/parseblock/parse.go index a7abbc7e..cf98fe45 100644 --- a/tools/parseblock/parse.go +++ b/tools/parseblock/parse.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" "unsafe" - "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/taosadapter/v3/driver/common/parser" "github.com/taosdata/taosadapter/v3/tools" ) diff --git a/tools/parseblock/parse_test.go b/tools/parseblock/parse_test.go index 611b1404..877353af 100644 --- a/tools/parseblock/parse_test.go +++ b/tools/parseblock/parse_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/taosadapter/v3/driver/common" ) func TestParseBlock(t *testing.T) { diff --git a/version/version.go b/version/version.go index 235bc0d8..4baa673d 100644 --- a/version/version.go +++ b/version/version.go @@ -1,6 +1,6 @@ package version -import "github.com/taosdata/driver-go/v3/wrapper" +import "github.com/taosdata/taosadapter/v3/driver/wrapper" var Version = "0.1.0" From 3b36a591b5160d1a1db78fe8d5af74dc568681e7 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Mon, 2 Dec 2024 14:26:44 +0800 Subject: [PATCH 22/48] fix: unit test using wrong sql --- plugin/nodeexporter/plugin_test.go | 4 ++-- plugin/opentsdbtelnet/plugin_test.go | 2 +- plugin/statsd/plugin_test.go | 11 +++-------- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/plugin/nodeexporter/plugin_test.go b/plugin/nodeexporter/plugin_test.go index 852af10b..5928863f 100644 --- a/plugin/nodeexporter/plugin_test.go +++ b/plugin/nodeexporter/plugin_test.go @@ -60,7 +60,7 @@ func TestNodeExporter_Gather(t *testing.T) { defer func() { wrapper.TaosClose(conn) }() - err = exec(conn, "drop database if exists node_exporter") + err = exec(conn, "create database if not exists node_exporter precision 'ns'") assert.NoError(t, err) err = exec(conn, "use node_exporter") assert.NoError(t, err) @@ -70,7 +70,7 @@ func TestNodeExporter_Gather(t *testing.T) { err = n.Start() assert.NoError(t, err) time.Sleep(time.Second * 2) - values, err := query(conn, "select last(`value`) as `value` from node_exporter.go_gc_duration_seconds;") + values, err := query(conn, "select last(`value`) as `value` from node_exporter.test_metric;") assert.NoError(t, err) assert.Equal(t, float64(1), values[0][0]) err = n.Stop() diff --git a/plugin/opentsdbtelnet/plugin_test.go b/plugin/opentsdbtelnet/plugin_test.go index a8b937ee..97171270 100644 --- a/plugin/opentsdbtelnet/plugin_test.go +++ b/plugin/opentsdbtelnet/plugin_test.go @@ -41,7 +41,7 @@ func TestPlugin(t *testing.T) { defer func() { wrapper.TaosClose(conn) }() - err = exec(conn, "drop database if exists opentsdb_telnet") + err = exec(conn, "create database if not exists opentsdb_telnet") assert.NoError(t, err) err = p.Init(nil) assert.NoError(t, err) diff --git a/plugin/statsd/plugin_test.go b/plugin/statsd/plugin_test.go index 2bcb2bd6..e2e28d0e 100644 --- a/plugin/statsd/plugin_test.go +++ b/plugin/statsd/plugin_test.go @@ -40,7 +40,7 @@ func TestStatsd(t *testing.T) { defer func() { wrapper.TaosClose(conn) }() - err = exec(conn, "drop database if exists statsd") + err = exec(conn, "create database if not exists statsd") assert.NoError(t, err) err = p.Init(nil) assert.NoError(t, err) @@ -58,13 +58,8 @@ func TestStatsd(t *testing.T) { assert.NoError(t, err) time.Sleep(time.Second) defer func() { - r := wrapper.TaosQuery(conn, "drop database if exists statsd") - code := wrapper.TaosError(r) - if code != 0 { - errStr := wrapper.TaosErrorStr(r) - t.Error(errors.NewError(code, errStr)) - } - wrapper.TaosFreeResult(r) + err = exec(conn, "drop database if exists statsd") + assert.NoError(t, err) }() values, err := query(conn, "select last(`value`) from statsd.`foo`") assert.NoError(t, err) From 4fc6cb0a2ea5e51a3cdb403a2c678c0056efd215 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 3 Dec 2024 10:17:46 +0800 Subject: [PATCH 23/48] ci: remove go 1.17 test on macOS --- .github/workflows/macos.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 8d3bf220..71438328 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -183,7 +183,7 @@ jobs: needs: build strategy: matrix: - go: [ '1.17', 'stable' ] + go: [ 'stable' ] name: test taosAdapter ${{ matrix.go }} steps: - name: get cache server by pr From b8bf0cb8fa774da78ff3fc6f8254472efdc45620 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 3 Dec 2024 12:37:36 +0800 Subject: [PATCH 24/48] test: remove stmt2 bind with wrong column count --- driver/wrapper/stmt2_test.go | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/driver/wrapper/stmt2_test.go b/driver/wrapper/stmt2_test.go index f92b148f..eb025d0d 100644 --- a/driver/wrapper/stmt2_test.go +++ b/driver/wrapper/stmt2_test.go @@ -4901,40 +4901,6 @@ func TestTaosStmt2BindBinaryParse(t *testing.T) { }, wantErr: assert.Error, }, - { - name: "wrong param count", - args: args{ - sql: "insert into test1 values (?,?)", - data: []byte{ - // total Length - 0x3A, 0x00, 0x00, 0x00, - // tableCount - 0x01, 0x00, 0x00, 0x00, - // TagCount - 0x00, 0x00, 0x00, 0x00, - // ColCount - 0x01, 0x00, 0x00, 0x00, - // TableNamesOffset - 0x00, 0x00, 0x00, 0x00, - // TagsOffset - 0x00, 0x00, 0x00, 0x00, - // ColOffset - 0x1c, 0x00, 0x00, 0x00, - // cols - 0x1a, 0x00, 0x00, 0x00, - - 0x1a, 0x00, 0x00, 0x00, - 0x09, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, - 0x00, - 0x00, - 0x08, 0x00, 0x00, 0x00, - 0xba, 0x08, 0x32, 0x27, 0x92, 0x01, 0x00, 0x00, - }, - colIdx: -1, - }, - wantErr: assert.Error, - }, { name: "bind binary", args: args{ From 86c9837c61bc9f58c36dde7c10f8b75756f24172 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 3 Dec 2024 12:52:19 +0800 Subject: [PATCH 25/48] ci: add go 1.17 test on macOS --- .github/workflows/macos.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 71438328..8d3bf220 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -183,7 +183,7 @@ jobs: needs: build strategy: matrix: - go: [ 'stable' ] + go: [ '1.17', 'stable' ] name: test taosAdapter ${{ matrix.go }} steps: - name: get cache server by pr From 126b611a908d0fdea2dc8e31be9bf4c2305158c0 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 3 Dec 2024 13:07:55 +0800 Subject: [PATCH 26/48] ci: add macos-13 runner --- .github/workflows/macos.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 8d3bf220..13782204 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -179,12 +179,13 @@ jobs: run: cd ./taosadapter && go build test_go_test: - runs-on: macos-latest + runs-on: ${{ matrix.os }} needs: build strategy: matrix: + os: [ 'macos-latest','macos-13' ] go: [ '1.17', 'stable' ] - name: test taosAdapter ${{ matrix.go }} + name: test taosAdapter ${{ matrix.os }} ${{ matrix.go }} steps: - name: get cache server by pr if: github.event_name == 'pull_request' From 2b811e99572588e7bfc8fb3225a0b33c011dc7d6 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 3 Dec 2024 13:17:11 +0800 Subject: [PATCH 27/48] ci: add macos-13 runner --- .github/workflows/macos.yml | 42 ++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 13782204..c01a2f10 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -22,8 +22,11 @@ on: jobs: build: - runs-on: macos-latest - name: Build + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ 'macos-latest','macos-13' ] + name: Build-${{ matrix.os }} outputs: commit_id: ${{ steps.get_commit_id.outputs.commit_id }} steps: @@ -63,7 +66,7 @@ jobs: uses: actions/cache@v4 with: path: server.tar.gz - key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ steps.get_commit_id.outputs.commit_id }} + key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ matrix.os }}-${{ steps.get_commit_id.outputs.commit_id }} - name: Cache server by push if: github.event_name == 'push' @@ -71,7 +74,7 @@ jobs: uses: actions/cache@v4 with: path: server.tar.gz - key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ steps.get_commit_id.outputs.commit_id }} + key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ matrix.os }}-${{ steps.get_commit_id.outputs.commit_id }} - name: Cache server manually if: github.event_name == 'workflow_dispatch' @@ -79,7 +82,7 @@ jobs: uses: actions/cache@v4 with: path: server.tar.gz - key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ steps.get_commit_id.outputs.commit_id }} + key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ matrix.os }}-${{ steps.get_commit_id.outputs.commit_id }} - name: install TDengine @@ -118,10 +121,11 @@ jobs: tar -zcvf server.tar.gz ./release test_build: - runs-on: macos-latest + runs-on: ${{ matrix.os }} needs: build strategy: matrix: + os: [ 'macos-latest','macos-13' ] go: [ '1.17', 'stable' ] name: Build taosAdapter ${{ matrix.go }} steps: @@ -131,9 +135,9 @@ jobs: uses: actions/cache@v4 with: path: server.tar.gz - key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ needs.build.outputs.commit_id }} + key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ matrix.os }}-${{ needs.build.outputs.commit_id }} restore-keys: | - ${{ runner.os }}-build-${{ github.base_ref }}- + ${{ runner.os }}-build-${{ github.base_ref }}-${{ matrix.os }}- - name: get cache server by push if: github.event_name == 'push' @@ -141,9 +145,9 @@ jobs: uses: actions/cache@v4 with: path: server.tar.gz - key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ needs.build.outputs.commit_id }} + key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ matrix.os }}-${{ needs.build.outputs.commit_id }} restore-keys: | - ${{ runner.os }}-build-${{ github.ref_name }}- + ${{ runner.os }}-build-${{ github.ref_name }}-${{ matrix.os }}- - name: get cache server manually if: github.event_name == 'workflow_dispatch' @@ -151,9 +155,9 @@ jobs: uses: actions/cache@v4 with: path: server.tar.gz - key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ needs.build.outputs.commit_id }} + key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ matrix.os }}-${{ needs.build.outputs.commit_id }} restore-keys: | - ${{ runner.os }}-build-${{ inputs.tbBranch }}- + ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ matrix.os }}- - name: prepare install run: | @@ -193,9 +197,9 @@ jobs: uses: actions/cache@v4 with: path: server.tar.gz - key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ needs.build.outputs.commit_id }} + key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ matrix.os }}-${{ needs.build.outputs.commit_id }} restore-keys: | - ${{ runner.os }}-build-${{ github.base_ref }}- + ${{ runner.os }}-build-${{ github.base_ref }}-${{ matrix.os }}- - name: get cache server by push if: github.event_name == 'push' @@ -203,9 +207,9 @@ jobs: uses: actions/cache@v4 with: path: server.tar.gz - key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ needs.build.outputs.commit_id }} + key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ matrix.os }}-${{ needs.build.outputs.commit_id }} restore-keys: | - ${{ runner.os }}-build-${{ github.ref_name }}- + ${{ runner.os }}-build-${{ github.ref_name }}-${{ matrix.os }}- - name: get cache server manually if: github.event_name == 'workflow_dispatch' @@ -213,9 +217,9 @@ jobs: uses: actions/cache@v4 with: path: server.tar.gz - key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ needs.build.outputs.commit_id }} + key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ matrix.os }}-${{ needs.build.outputs.commit_id }} restore-keys: | - ${{ runner.os }}-build-${{ inputs.tbBranch }}- + ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ matrix.os }}- - name: prepare install run: | @@ -260,5 +264,5 @@ jobs: - uses: actions/upload-artifact@v4 if: always() && (steps.test.outcome == 'failure' || steps.test.outcome == 'cancelled') with: - name: ${{ runner.os }}-${{ matrix.go }}-log + name: ${{ runner.os }}-${{ matrix.os }}-${{ matrix.go }}-log path: /var/log/taos/ \ No newline at end of file From a0259c9fe45033aaa98e7a4cf6536afa1071943c Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 3 Dec 2024 14:23:27 +0800 Subject: [PATCH 28/48] ci: remove test on macOS --- .github/workflows/macos.yml | 172 ++++++++++++++++++------------------ 1 file changed, 86 insertions(+), 86 deletions(-) diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index c01a2f10..4d362431 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -25,7 +25,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ 'macos-latest','macos-13' ] + os: [ 'macos-latest' ] name: Build-${{ matrix.os }} outputs: commit_id: ${{ steps.get_commit_id.outputs.commit_id }} @@ -125,7 +125,7 @@ jobs: needs: build strategy: matrix: - os: [ 'macos-latest','macos-13' ] + os: [ 'macos-latest' ] go: [ '1.17', 'stable' ] name: Build taosAdapter ${{ matrix.go }} steps: @@ -182,87 +182,87 @@ jobs: - name: build taosAdapter run: cd ./taosadapter && go build - test_go_test: - runs-on: ${{ matrix.os }} - needs: build - strategy: - matrix: - os: [ 'macos-latest','macos-13' ] - go: [ '1.17', 'stable' ] - name: test taosAdapter ${{ matrix.os }} ${{ matrix.go }} - steps: - - name: get cache server by pr - if: github.event_name == 'pull_request' - id: get-cache-server-pr - uses: actions/cache@v4 - with: - path: server.tar.gz - key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ matrix.os }}-${{ needs.build.outputs.commit_id }} - restore-keys: | - ${{ runner.os }}-build-${{ github.base_ref }}-${{ matrix.os }}- - - - name: get cache server by push - if: github.event_name == 'push' - id: get-cache-server-push - uses: actions/cache@v4 - with: - path: server.tar.gz - key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ matrix.os }}-${{ needs.build.outputs.commit_id }} - restore-keys: | - ${{ runner.os }}-build-${{ github.ref_name }}-${{ matrix.os }}- - - - name: get cache server manually - if: github.event_name == 'workflow_dispatch' - id: get-cache-server-manually - uses: actions/cache@v4 - with: - path: server.tar.gz - key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ matrix.os }}-${{ needs.build.outputs.commit_id }} - restore-keys: | - ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ matrix.os }}- - - - name: prepare install - run: | - sudo mkdir -p /usr/local/lib - sudo mkdir -p /usr/local/include - - - name: install - run: | - tar -zxvf server.tar.gz - cd release && sudo sh install.sh - - - name: checkout - uses: actions/checkout@v4 - with: - path: 'taosadapter' - - - name: copy taos cfg - run: | - sudo mkdir -p /etc/taos - sudo cp ./taosadapter/.github/workflows/taos.cfg /etc/taos/taos.cfg - - - uses: actions/setup-go@v4 - with: - go-version: ${{ matrix.go }} - cache-dependency-path: taosadapter/go.sum - - - name: start shell - run: | - cat >start.sh<start.sh< Date: Tue, 3 Dec 2024 18:57:37 +0800 Subject: [PATCH 29/48] enh: support taos_stmt2_get_stb_fields --- controller/ws/ws/stmt2.go | 93 +++----- controller/ws/ws/stmt2_test.go | 282 ++++++++++++++++++++++++ db/syncinterface/wrapper.go | 12 + driver/wrapper/stmt2.go | 60 +++++ driver/wrapper/stmt2_test.go | 385 +++++++++++++++++++++++++++++++++ 5 files changed, 766 insertions(+), 66 deletions(-) diff --git a/controller/ws/ws/stmt2.go b/controller/ws/ws/stmt2.go index f5648547..8fa84034 100644 --- a/controller/ws/ws/stmt2.go +++ b/controller/ws/ws/stmt2.go @@ -83,21 +83,16 @@ type stmt2PrepareRequest struct { GetFields bool `json:"get_fields"` } -type prepareFields struct { - stmtCommon.StmtField - BindType int8 `json:"bind_type"` -} - type stmt2PrepareResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Action string `json:"action"` - ReqID uint64 `json:"req_id"` - Timing int64 `json:"timing"` - StmtID uint64 `json:"stmt_id"` - IsInsert bool `json:"is_insert"` - Fields []*prepareFields `json:"fields"` - FieldsCount int `json:"fields_count"` + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + IsInsert bool `json:"is_insert"` + Fields []*wrapper.StmtStbField `json:"fields"` + FieldsCount int `json:"fields_count"` } func (h *messageHandler) stmt2Prepare(ctx context.Context, session *melody.Session, action string, req stmt2PrepareRequest, logger *logrus.Entry, isDebug bool) { @@ -127,58 +122,18 @@ func (h *messageHandler) stmt2Prepare(ctx context.Context, session *melody.Sessi stmtItem.isInsert = isInsert prepareResp := &stmt2PrepareResponse{StmtID: req.StmtID, IsInsert: isInsert} if req.GetFields { - if isInsert { - var fields []*prepareFields - // get table field - _, count, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_TBNAME, logger, isDebug) - if code != 0 { - logger.Errorf("get table names fields error, code:%d, err:%s", code, errStr) - stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get table names fields error, %s", errStr), req.StmtID) - return - } - if count == 1 { - tableNameFields := &prepareFields{ - StmtField: stmtCommon.StmtField{}, - BindType: stmtCommon.TAOS_FIELD_TBNAME, - } - fields = append(fields, tableNameFields) - } - // get tags field - tagFields, _, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_TAG, logger, isDebug) - if code != 0 { - logger.Errorf("get tag fields error, code:%d, err:%s", code, errStr) - stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get tag fields error, %s", errStr), req.StmtID) - return - } - for i := 0; i < len(tagFields); i++ { - fields = append(fields, &prepareFields{ - StmtField: *tagFields[i], - BindType: stmtCommon.TAOS_FIELD_TAG, - }) - } - // get cols field - colFields, _, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_COL, logger, isDebug) - if code != 0 { - logger.Errorf("get col fields error, code:%d, err:%s", code, errStr) - stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get col fields error, %s", errStr), req.StmtID) - return - } - for i := 0; i < len(colFields); i++ { - fields = append(fields, &prepareFields{ - StmtField: *colFields[i], - BindType: stmtCommon.TAOS_FIELD_COL, - }) - } - prepareResp.Fields = fields - } else { - _, count, code, errStr := getFields(stmt2, stmtCommon.TAOS_FIELD_QUERY, logger, isDebug) - if code != 0 { - logger.Errorf("get query fields error, code:%d, err:%s", code, errStr) - stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get query fields error, %s", errStr), req.StmtID) - return - } - prepareResp.FieldsCount = count + code, count, fields := syncinterface.TaosStmt2GetStbFields(stmt2, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosStmt2Error(stmt2) + logger.Errorf("stmt2 get fields error, code:%d, err:%s", code, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) + return } + defer wrapper.TaosStmt2FreeStbFields(stmt2, fields) + stbFields := wrapper.ParseStmt2StbFields(count, fields) + prepareResp.Fields = stbFields + prepareResp.FieldsCount = count + } prepareResp.ReqID = req.ReqID prepareResp.Action = action @@ -298,7 +253,7 @@ func (h *messageHandler) stmt2Exec(ctx context.Context, session *melody.Session, code := syncinterface.TaosStmt2Exec(stmtItem.stmt, logger, isDebug) if code != 0 { errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) - logger.Errorf("stmt2 execute error, err:%s", errStr) + logger.Errorf("stmt2 execute error,code:%d, err:%s", code, errStr) stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) return } @@ -306,6 +261,12 @@ func (h *messageHandler) stmt2Exec(ctx context.Context, session *melody.Session, logger.Tracef("stmt2 execute wait callback, stmt_id:%d", req.StmtID) result := <-stmtItem.caller.ExecResult logger.Debugf("stmt2 execute wait callback finish, affected:%d, res:%p, n:%d, cost:%s", result.Affected, result.Res, result.N, log.GetLogDuration(isDebug, s)) + if result.N < 0 { + errStr := wrapper.TaosStmtErrStr(stmtItem.stmt) + logger.Errorf("stmt2 execute callback error, code:%d, err:%s", result.N, errStr) + stmtErrorResponse(ctx, session, logger, action, req.ReqID, result.N, errStr, req.StmtID) + return + } stmtItem.result = result.Res resp := &stmt2ExecResponse{ Action: action, diff --git a/controller/ws/ws/stmt2_test.go b/controller/ws/ws/stmt2_test.go index cd83fde6..623d259b 100644 --- a/controller/ws/ws/stmt2_test.go +++ b/controller/ws/ws/stmt2_test.go @@ -740,3 +740,285 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { assert.NoError(t, err) assert.Equal(t, 0, closeResp.Code, closeResp.Message) } + +func TestStmt2BindWithStbFields(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + code, message := doRestful("drop database if exists test_ws_stmt2_getstbfields_ws", "") + assert.Equal(t, 0, code, message) + code, message = doRestful("create database if not exists test_ws_stmt2_getstbfields_ws precision 'ns'", "") + assert.Equal(t, 0, code, message) + + //defer doRestful("drop database if exists test_ws_stmt2_getstbfields_ws", "") + + code, message = doRestful( + "create table if not exists stb (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", + "test_ws_stmt2_getstbfields_ws") + assert.Equal(t, 0, code, message) + + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_stmt2_getstbfields_ws"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // init + initReq := stmt2InitRequest{ + ReqID: 0x123, + SingleStbInsert: false, + SingleTableBindOnce: false, + } + resp, err = doWebSocket(ws, STMT2Init, &initReq) + assert.NoError(t, err) + var initResp stmt2InitResponse + err = json.Unmarshal(resp, &initResp) + assert.NoError(t, err) + assert.Equal(t, uint64(0x123), initResp.ReqID) + assert.Equal(t, 0, initResp.Code, initResp.Message) + + // prepare + prepareReq := stmt2PrepareRequest{ + ReqID: 3, + StmtID: initResp.StmtID, + SQL: "insert into ? using test_ws_stmt2_getstbfields_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + GetFields: true, + } + resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) + assert.NoError(t, err) + var prepareResp stmt2PrepareResponse + err = json.Unmarshal(resp, &prepareResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), prepareResp.ReqID) + assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) + assert.Equal(t, true, prepareResp.IsInsert) + expectFieldsName := [18]string{"tbname", "info", "ts", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"} + expectFieldsType := [18]int8{ + common.TSDB_DATA_TYPE_BINARY, + common.TSDB_DATA_TYPE_JSON, + common.TSDB_DATA_TYPE_TIMESTAMP, + common.TSDB_DATA_TYPE_BOOL, + common.TSDB_DATA_TYPE_TINYINT, + common.TSDB_DATA_TYPE_SMALLINT, + common.TSDB_DATA_TYPE_INT, + common.TSDB_DATA_TYPE_BIGINT, + common.TSDB_DATA_TYPE_UTINYINT, + common.TSDB_DATA_TYPE_USMALLINT, + common.TSDB_DATA_TYPE_UINT, + common.TSDB_DATA_TYPE_UBIGINT, + common.TSDB_DATA_TYPE_FLOAT, + common.TSDB_DATA_TYPE_DOUBLE, + common.TSDB_DATA_TYPE_BINARY, + common.TSDB_DATA_TYPE_NCHAR, + common.TSDB_DATA_TYPE_VARBINARY, + common.TSDB_DATA_TYPE_GEOMETRY, + } + expectBindType := [18]int8{ + stmtCommon.TAOS_FIELD_TBNAME, + stmtCommon.TAOS_FIELD_TAG, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + stmtCommon.TAOS_FIELD_COL, + } + for i := 0; i < 18; i++ { + assert.Equal(t, expectFieldsName[i], prepareResp.Fields[i].Name) + assert.Equal(t, expectFieldsType[i], prepareResp.Fields[i].FieldType) + assert.Equal(t, expectBindType[i], prepareResp.Fields[i].BindType) + if prepareResp.Fields[i].FieldType == common.TSDB_DATA_TYPE_TIMESTAMP { + assert.Equal(t, uint8(common.PrecisionNanoSecond), prepareResp.Fields[i].Precision) + } + } + // bind + now := time.Now() + cols := [][]driver.Value{ + // ts + {now, now.Add(time.Second), now.Add(time.Second * 2)}, + // bool + {true, false, nil}, + // tinyint + {int8(2), int8(22), nil}, + // smallint + {int16(3), int16(33), nil}, + // int + {int32(4), int32(44), nil}, + // bigint + {int64(5), int64(55), nil}, + // tinyint unsigned + {uint8(6), uint8(66), nil}, + // smallint unsigned + {uint16(7), uint16(77), nil}, + // int unsigned + {uint32(8), uint32(88), nil}, + // bigint unsigned + {uint64(9), uint64(99), nil}, + // float + {float32(10), float32(1010), nil}, + // double + {float64(11), float64(1111), nil}, + // binary + {"binary", "binary2", nil}, + // nchar + {"nchar", "nchar2", nil}, + // varbinary + {[]byte{0xaa, 0xbb, 0xcc}, []byte{0xaa, 0xbb, 0xcc}, nil}, + // geometry + {[]byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, nil}, + } + tbName := "test_ws_stmt2_getstbfields_ws.ct1" + tag := []driver.Value{"{\"a\":\"b\"}"} + binds := &stmtCommon.TaosStmt2BindData{ + TableName: tbName, + Tags: tag, + Cols: cols, + } + var colFields []*stmtCommon.StmtField + var tagFields []*stmtCommon.StmtField + for i := 0; i < 18; i++ { + field := &stmtCommon.StmtField{ + FieldType: prepareResp.Fields[i].FieldType, + Precision: prepareResp.Fields[i].Precision, + } + switch prepareResp.Fields[i].BindType { + case stmtCommon.TAOS_FIELD_COL: + colFields = append(colFields, field) + case stmtCommon.TAOS_FIELD_TAG: + tagFields = append(tagFields, field) + } + } + bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, colFields, tagFields) + assert.NoError(t, err) + bindReq := make([]byte, len(bs)+30) + // req_id + binary.LittleEndian.PutUint64(bindReq, 0xee12345) + // stmt_id + binary.LittleEndian.PutUint64(bindReq[8:], prepareResp.StmtID) + // action + binary.LittleEndian.PutUint64(bindReq[16:], Stmt2BindMessage) + // version + binary.LittleEndian.PutUint16(bindReq[24:], Stmt2BindProtocolVersion1) + // col_idx + idx := int32(-1) + binary.LittleEndian.PutUint32(bindReq[26:], uint32(idx)) + // data + copy(bindReq[30:], bs) + err = ws.WriteMessage(websocket.BinaryMessage, bindReq) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + var bindResp stmt2BindResponse + err = json.Unmarshal(resp, &bindResp) + assert.NoError(t, err) + assert.Equal(t, 0, bindResp.Code, bindResp.Message) + + //exec + execReq := stmt2ExecRequest{ReqID: 10, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMT2Exec, &execReq) + assert.NoError(t, err) + var execResp stmt2ExecResponse + err = json.Unmarshal(resp, &execResp) + assert.NoError(t, err) + assert.Equal(t, uint64(10), execResp.ReqID) + assert.Equal(t, 0, execResp.Code, execResp.Message) + assert.Equal(t, 3, execResp.Affected) + + // close + closeReq := stmt2CloseRequest{ReqID: 11, StmtID: prepareResp.StmtID} + resp, err = doWebSocket(ws, STMT2Close, &closeReq) + assert.NoError(t, err) + var closeResp stmt2CloseResponse + err = json.Unmarshal(resp, &closeResp) + assert.NoError(t, err) + assert.Equal(t, uint64(11), closeResp.ReqID) + assert.Equal(t, 0, closeResp.Code, closeResp.Message) + + // query + queryReq := queryRequest{Sql: "select * from test_ws_stmt2_getstbfields_ws.stb"} + resp, err = doWebSocket(ws, WSQuery, &queryReq) + assert.NoError(t, err) + var queryResp queryResponse + err = json.Unmarshal(resp, &queryResp) + assert.NoError(t, err) + assert.Equal(t, 0, queryResp.Code, queryResp.Message) + + // fetch + fetchReq := fetchRequest{ID: queryResp.ID} + resp, err = doWebSocket(ws, WSFetch, &fetchReq) + assert.NoError(t, err) + var fetchResp fetchResponse + err = json.Unmarshal(resp, &fetchResp) + assert.NoError(t, err) + assert.Equal(t, 0, fetchResp.Code, fetchResp.Message) + + // fetch block + fetchBlockReq := fetchBlockRequest{ID: queryResp.ID} + fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) + assert.NoError(t, err) + _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], queryResp.FieldsTypes, fetchResp.Rows, queryResp.Precision) + assert.Equal(t, 3, len(blockResult)) + assert.Equal(t, now.UnixNano(), blockResult[0][0].(time.Time).UnixNano()) + + assert.Equal(t, true, blockResult[0][1]) + assert.Equal(t, int8(2), blockResult[0][2]) + assert.Equal(t, int16(3), blockResult[0][3]) + assert.Equal(t, int32(4), blockResult[0][4]) + assert.Equal(t, int64(5), blockResult[0][5]) + assert.Equal(t, uint8(6), blockResult[0][6]) + assert.Equal(t, uint16(7), blockResult[0][7]) + assert.Equal(t, uint32(8), blockResult[0][8]) + assert.Equal(t, uint64(9), blockResult[0][9]) + assert.Equal(t, float32(10), blockResult[0][10]) + assert.Equal(t, float64(11), blockResult[0][11]) + assert.Equal(t, "binary", blockResult[0][12]) + assert.Equal(t, "nchar", blockResult[0][13]) + assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[0][15]) + + assert.Equal(t, now.Add(time.Second).UnixNano(), blockResult[1][0].(time.Time).UnixNano()) + assert.Equal(t, false, blockResult[1][1]) + assert.Equal(t, int8(22), blockResult[1][2]) + assert.Equal(t, int16(33), blockResult[1][3]) + assert.Equal(t, int32(44), blockResult[1][4]) + assert.Equal(t, int64(55), blockResult[1][5]) + assert.Equal(t, uint8(66), blockResult[1][6]) + assert.Equal(t, uint16(77), blockResult[1][7]) + assert.Equal(t, uint32(88), blockResult[1][8]) + assert.Equal(t, uint64(99), blockResult[1][9]) + assert.Equal(t, float32(1010), blockResult[1][10]) + assert.Equal(t, float64(1111), blockResult[1][11]) + assert.Equal(t, "binary2", blockResult[1][12]) + assert.Equal(t, "nchar2", blockResult[1][13]) + assert.Equal(t, []byte{0xaa, 0xbb, 0xcc}, blockResult[1][14]) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, blockResult[1][15]) + + assert.Equal(t, now.Add(time.Second*2).UnixNano(), blockResult[2][0].(time.Time).UnixNano()) + for i := 1; i < 16; i++ { + assert.Nil(t, blockResult[2][i]) + } + +} diff --git a/db/syncinterface/wrapper.go b/db/syncinterface/wrapper.go index 00b508e1..dd9c6222 100644 --- a/db/syncinterface/wrapper.go +++ b/db/syncinterface/wrapper.go @@ -409,3 +409,15 @@ func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32, logger thread.SyncLocker.Unlock() return err } + +func TaosStmt2GetStbFields(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug bool) (code, count int, fields unsafe.Pointer) { + logger.Tracef("call taos_stmt2_get_stb_fields, stmt2:%p", stmt2) + s := log.GetLogNow(isDebug) + thread.SyncLocker.Lock() + logger.Debugf("get thread lock for taos_stmt2_get_stb_fields cost:%s", log.GetLogDuration(isDebug, s)) + s = log.GetLogNow(isDebug) + code, count, fields = wrapper.TaosStmt2GetStbFields(stmt2) + logger.Debugf("taos_stmt2_get_stb_fields finish, code:%d, count:%d, fields:%p, cost:%s", code, count, fields, log.GetLogDuration(isDebug, s)) + thread.SyncLocker.Unlock() + return code, count, fields +} diff --git a/driver/wrapper/stmt2.go b/driver/wrapper/stmt2.go index 09eb1cbb..306325c3 100644 --- a/driver/wrapper/stmt2.go +++ b/driver/wrapper/stmt2.go @@ -15,6 +15,7 @@ TAOS_STMT2 * taos_stmt2_init_wrapper(TAOS *taos, int64_t reqid, bool singleStbIn */ import "C" import ( + "bytes" "database/sql/driver" "encoding/binary" "fmt" @@ -855,3 +856,62 @@ func generateStmt2Binds(count uint32, fieldCount uint32, dataP unsafe.Pointer, f } return bindsCList, freePointer, nil } + +// TaosStmt2GetStbFields int taos_stmt2_get_stb_fields(TAOS_STMT2 *stmt, int *count, TAOS_FIELD_STB **fields); +func TaosStmt2GetStbFields(stmt unsafe.Pointer) (code, count int, fields unsafe.Pointer) { + code = int(C.taos_stmt2_get_stb_fields(stmt, (*C.int)(unsafe.Pointer(&count)), (**C.TAOS_FIELD_STB)(unsafe.Pointer(&fields)))) + return +} + +// TaosStmt2FreeStbFields void taos_stmt2_free_stb_fields(TAOS_STMT2 *stmt, TAOS_FIELD_STB *fields); +func TaosStmt2FreeStbFields(stmt unsafe.Pointer, fields unsafe.Pointer) { + C.taos_stmt2_free_stb_fields(stmt, (*C.TAOS_FIELD_STB)(fields)) +} + +//typedef struct TAOS_FIELD_STB { +//char name[65]; +//int8_t type; +//uint8_t precision; +//uint8_t scale; +//int32_t bytes; +//TAOS_FIELD_T field_type; +//} TAOS_FIELD_STB; + +type StmtStbField struct { + Name string `json:"name"` + FieldType int8 `json:"field_type"` + Precision uint8 `json:"precision"` + Scale uint8 `json:"scale"` + Bytes int32 `json:"bytes"` + BindType int8 `json:"bind_type"` +} + +func ParseStmt2StbFields(num int, fields unsafe.Pointer) []*StmtStbField { + if num <= 0 { + return nil + } + if fields == nil { + return nil + } + result := make([]*StmtStbField, num) + buf := bytes.NewBufferString("") + for i := 0; i < num; i++ { + r := &StmtStbField{} + field := *(*C.TAOS_FIELD_STB)(unsafe.Pointer(uintptr(fields) + uintptr(C.sizeof_struct_TAOS_FIELD_STB*C.int(i)))) + for _, c := range field.name { + if c == 0 { + break + } + buf.WriteByte(byte(c)) + } + r.Name = buf.String() + buf.Reset() + r.FieldType = int8(field._type) + r.Precision = uint8(field.precision) + r.Scale = uint8(field.scale) + r.Bytes = int32(field.bytes) + r.BindType = int8(field.field_type) + result[i] = r + } + return result +} diff --git a/driver/wrapper/stmt2_test.go b/driver/wrapper/stmt2_test.go index eb025d0d..9a0dd9ee 100644 --- a/driver/wrapper/stmt2_test.go +++ b/driver/wrapper/stmt2_test.go @@ -5040,3 +5040,388 @@ func TestTaosStmt2BindBinaryParse(t *testing.T) { }) } } + +func TestTaosStmt2GetStbFields(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_stb_fields") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database test_stmt2_stb_fields precision 'ns'") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_stb_fields") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists all_stb("+ + "ts timestamp, "+ + "v1 bool, "+ + "v2 tinyint, "+ + "v3 smallint, "+ + "v4 int, "+ + "v5 bigint, "+ + "v6 tinyint unsigned, "+ + "v7 smallint unsigned, "+ + "v8 int unsigned, "+ + "v9 bigint unsigned, "+ + "v10 float, "+ + "v11 double, "+ + "v12 binary(20), "+ + "v13 varbinary(20), "+ + "v14 geometry(100), "+ + "v15 nchar(20))"+ + "tags("+ + "tts timestamp, "+ + "tv1 bool, "+ + "tv2 tinyint, "+ + "tv3 smallint, "+ + "tv4 int, "+ + "tv5 bigint, "+ + "tv6 tinyint unsigned, "+ + "tv7 smallint unsigned, "+ + "tv8 int unsigned, "+ + "tv9 bigint unsigned, "+ + "tv10 float, "+ + "tv11 double, "+ + "tv12 binary(20), "+ + "tv13 varbinary(20), "+ + "tv14 geometry(100), "+ + "tv15 nchar(20))") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists commontb("+ + "ts timestamp, "+ + "v1 bool, "+ + "v2 tinyint, "+ + "v3 smallint, "+ + "v4 int, "+ + "v5 bigint, "+ + "v6 tinyint unsigned, "+ + "v7 smallint unsigned, "+ + "v8 int unsigned, "+ + "v9 bigint unsigned, "+ + "v10 float, "+ + "v11 double, "+ + "v12 binary(20), "+ + "v13 varbinary(20), "+ + "v14 geometry(100), "+ + "v15 nchar(20))") + if err != nil { + t.Error(err) + return + } + expectMap := map[string]*StmtStbField{ + "tts": &StmtStbField{ + Name: "tts", + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Scale: 0, + Precision: common.PrecisionNanoSecond, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv1": &StmtStbField{ + Name: "tv1", + FieldType: common.TSDB_DATA_TYPE_BOOL, + Scale: 0, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv2": &StmtStbField{ + Name: "tv2", + FieldType: common.TSDB_DATA_TYPE_TINYINT, + Scale: 0, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv3": &StmtStbField{ + Name: "tv3", + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + Scale: 0, + Bytes: 2, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv4": &StmtStbField{ + Name: "tv4", + FieldType: common.TSDB_DATA_TYPE_INT, + Scale: 0, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv5": &StmtStbField{ + Name: "tv5", + FieldType: common.TSDB_DATA_TYPE_BIGINT, + Scale: 0, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv6": &StmtStbField{ + Name: "tv6", + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + Scale: 0, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv7": &StmtStbField{ + Name: "tv7", + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + Scale: 0, + Bytes: 2, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv8": &StmtStbField{ + Name: "tv8", + FieldType: common.TSDB_DATA_TYPE_UINT, + Scale: 0, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv9": &StmtStbField{ + Name: "tv9", + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + Scale: 0, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv10": &StmtStbField{ + Name: "tv10", + FieldType: common.TSDB_DATA_TYPE_FLOAT, + Scale: 0, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv11": &StmtStbField{ + Name: "tv11", + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + Scale: 0, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv12": &StmtStbField{ + Name: "tv12", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Scale: 0, + Bytes: 22, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv13": &StmtStbField{ + Name: "tv13", + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + Scale: 0, + Bytes: 22, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv14": &StmtStbField{ + Name: "tv14", + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + Scale: 0, + Bytes: 102, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv15": &StmtStbField{ + Name: "tv15", + FieldType: common.TSDB_DATA_TYPE_NCHAR, + Scale: 0, + Bytes: 82, + BindType: stmt.TAOS_FIELD_TAG, + }, + "ts": &StmtStbField{ + Name: "ts", + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionNanoSecond, + Scale: 0, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v1": &StmtStbField{ + Name: "v1", + FieldType: common.TSDB_DATA_TYPE_BOOL, + Scale: 0, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v2": &StmtStbField{ + Name: "v2", + FieldType: common.TSDB_DATA_TYPE_TINYINT, + Scale: 0, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v3": &StmtStbField{ + Name: "v3", + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + Scale: 0, + Bytes: 2, + BindType: stmt.TAOS_FIELD_COL, + }, + "v4": &StmtStbField{ + Name: "v4", + FieldType: common.TSDB_DATA_TYPE_INT, + Scale: 0, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v5": &StmtStbField{ + Name: "v5", + FieldType: common.TSDB_DATA_TYPE_BIGINT, + Scale: 0, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v6": &StmtStbField{ + Name: "v6", + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + Scale: 0, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v7": &StmtStbField{ + Name: "v7", + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + Scale: 0, + Bytes: 2, + BindType: stmt.TAOS_FIELD_COL, + }, + "v8": &StmtStbField{ + Name: "v8", + FieldType: common.TSDB_DATA_TYPE_UINT, + Scale: 0, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v9": &StmtStbField{ + Name: "v9", + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + Scale: 0, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v10": &StmtStbField{ + Name: "v10", + FieldType: common.TSDB_DATA_TYPE_FLOAT, + Scale: 0, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v11": &StmtStbField{ + Name: "v11", + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + Scale: 0, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v12": &StmtStbField{ + Name: "v12", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Scale: 0, + Bytes: 22, + BindType: stmt.TAOS_FIELD_COL, + }, + "v13": &StmtStbField{ + Name: "v13", + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + Scale: 0, + Bytes: 22, + BindType: stmt.TAOS_FIELD_COL, + }, + "v14": &StmtStbField{ + Name: "v14", + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + Scale: 0, + Bytes: 102, + BindType: stmt.TAOS_FIELD_COL, + }, + "v15": &StmtStbField{ + Name: "v15", + FieldType: common.TSDB_DATA_TYPE_NCHAR, + Scale: 0, + Bytes: 82, + BindType: stmt.TAOS_FIELD_COL, + }, + "tbname": &StmtStbField{ + Name: "tbname", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Scale: 0, + Bytes: 271, + BindType: stmt.TAOS_FIELD_TBNAME, + }, + } + tests := []struct { + name string + sql string + expect []string + }{ + { + name: "with subTableName", + sql: "insert into tb1 using all_stb tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + expect: []string{"tts", "tv1", "tv2", "tv3", "tv4", "tv5", "tv6", "tv7", "tv8", "tv9", "tv10", "tv11", "tv12", "tv13", "tv14", "tv15", "ts", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"}, + }, + { + name: "using stb", + sql: "insert into ? using all_stb tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + expect: []string{"tbname", "tts", "tv1", "tv2", "tv3", "tv4", "tv5", "tv6", "tv7", "tv8", "tv9", "tv10", "tv11", "tv12", "tv13", "tv14", "tv15", "ts", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"}, + }, + { + name: "tbname as value", + sql: "insert into all_stb (tbname,tts,tv1,tv2,tv3,tv4,tv5,tv6,tv7,tv8,tv9,tv10,tv11,tv12,tv13,tv14,tv15,ts,v1,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14,v15) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + expect: []string{"tbname", "tts", "tv1", "tv2", "tv3", "tv4", "tv5", "tv6", "tv7", "tv8", "tv9", "tv10", "tv11", "tv12", "tv13", "tv14", "tv15", "ts", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"}, + }, + { + name: "tbname as value random", + sql: "insert into all_stb (ts,v1,v2,v3,v4,v5,v6,tts,tv1,tv2,tv3,tv4,tv5,tv6,tv7,tv8,tv9,tv10,tv11,tv12,tv13,tv14,tbname,tv15,v7,v8,v9,v10,v11,v12,v13,v14,v15) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + expect: []string{"ts", "v1", "v2", "v3", "v4", "v5", "v6", "tts", "tv1", "tv2", "tv3", "tv4", "tv5", "tv6", "tv7", "tv8", "tv9", "tv10", "tv11", "tv12", "tv13", "tv14", "tbname", "tv15", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"}, + }, + { + name: "common table", + sql: "insert into commontb values(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + expect: []string{"ts", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"}, + }, + } + for _, tt := range tests { + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xed123, false, false, handler) + defer TaosStmt2Close(stmt2) + code := TaosStmt2Prepare(stmt2, tt.sql) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err := taosError.NewError(code, errStr) + t.Error(err) + return + } + code, count, fields := TaosStmt2GetStbFields(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err := taosError.NewError(code, errStr) + t.Error(err) + return + } + fs := ParseStmt2StbFields(count, fields) + TaosStmt2FreeStbFields(stmt2, fields) + expect := make([]*StmtStbField, len(tt.expect)) + for i := 0; i < len(tt.expect); i++ { + expect[i] = expectMap[tt.expect[i]] + } + assert.Equal(t, expect, fs) + } +} + +func TestWrongParseStmt2StbFields(t *testing.T) { + fs := ParseStmt2StbFields(0, nil) + assert.Nil(t, fs) + fs = ParseStmt2StbFields(2, nil) + assert.Nil(t, fs) +} From f164041e147d73e6cef6848f5dd523aa486d4a46 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 4 Dec 2024 12:54:49 +0800 Subject: [PATCH 30/48] test: add stmt2 query test --- controller/ws/ws/stmt2_test.go | 3 +- db/syncinterface/wrapper_test.go | 24 ++++++ driver/wrapper/stmt2_test.go | 128 +++++++++++++++---------------- 3 files changed, 87 insertions(+), 68 deletions(-) diff --git a/controller/ws/ws/stmt2_test.go b/controller/ws/ws/stmt2_test.go index 623d259b..88281fbd 100644 --- a/controller/ws/ws/stmt2_test.go +++ b/controller/ws/ws/stmt2_test.go @@ -643,7 +643,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { ReqID: 3, StmtID: initResp.StmtID, SQL: fmt.Sprintf("select * from %s.meters where group_id=? and location=?", db), - GetFields: false, + GetFields: true, } resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) assert.NoError(t, err) @@ -653,6 +653,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { assert.Equal(t, uint64(3), prepareResp.ReqID) assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) assert.False(t, prepareResp.IsInsert) + assert.Equal(t, 2, prepareResp.FieldsCount) // bind var block bytes.Buffer diff --git a/db/syncinterface/wrapper_test.go b/db/syncinterface/wrapper_test.go index 25852e27..9917b7e8 100644 --- a/db/syncinterface/wrapper_test.go +++ b/db/syncinterface/wrapper_test.go @@ -509,6 +509,30 @@ func TestTaosStmt2(t *testing.T) { return } assert.True(t, isInsert) + code, count, fiels := TaosStmt2GetStbFields(stmt, logger, isDebug) + if !assert.Equal(t, 0, code, wrapper.TaosStmtErrStr(stmt)) { + return + } + assert.Equal(t, 4, count) + assert.NotNil(t, fiels) + defer func() { + wrapper.TaosStmt2FreeFields(stmt, fiels) + }() + fs := wrapper.ParseStmt2StbFields(count, fiels) + assert.Equal(t, 4, len(fs)) + assert.Equal(t, "tbname", fs[0].Name) + assert.Equal(t, int8(common.TSDB_DATA_TYPE_BINARY), fs[0].FieldType) + assert.Equal(t, int8(stmtCommon.TAOS_FIELD_TBNAME), fs[0].BindType) + assert.Equal(t, "id", fs[1].Name) + assert.Equal(t, int8(common.TSDB_DATA_TYPE_INT), fs[1].FieldType) + assert.Equal(t, int8(stmtCommon.TAOS_FIELD_TAG), fs[1].BindType) + assert.Equal(t, "ts", fs[2].Name) + assert.Equal(t, int8(common.TSDB_DATA_TYPE_TIMESTAMP), fs[2].FieldType) + assert.Equal(t, int8(stmtCommon.TAOS_FIELD_COL), fs[2].BindType) + assert.Equal(t, uint8(common.PrecisionMilliSecond), fs[2].Precision) + assert.Equal(t, "v", fs[3].Name) + assert.Equal(t, int8(common.TSDB_DATA_TYPE_INT), fs[3].FieldType) + assert.Equal(t, int8(stmtCommon.TAOS_FIELD_COL), fs[3].BindType) tableName := "tb1" binds := &stmtCommon.TaosStmt2BindData{ TableName: tableName, diff --git a/driver/wrapper/stmt2_test.go b/driver/wrapper/stmt2_test.go index 9a0dd9ee..54576e39 100644 --- a/driver/wrapper/stmt2_test.go +++ b/driver/wrapper/stmt2_test.go @@ -5125,236 +5125,203 @@ func TestTaosStmt2GetStbFields(t *testing.T) { return } expectMap := map[string]*StmtStbField{ - "tts": &StmtStbField{ + "tts": { Name: "tts", FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, - Scale: 0, Precision: common.PrecisionNanoSecond, Bytes: 8, BindType: stmt.TAOS_FIELD_TAG, }, - "tv1": &StmtStbField{ + "tv1": { Name: "tv1", FieldType: common.TSDB_DATA_TYPE_BOOL, - Scale: 0, Bytes: 1, BindType: stmt.TAOS_FIELD_TAG, }, - "tv2": &StmtStbField{ + "tv2": { Name: "tv2", FieldType: common.TSDB_DATA_TYPE_TINYINT, - Scale: 0, Bytes: 1, BindType: stmt.TAOS_FIELD_TAG, }, - "tv3": &StmtStbField{ + "tv3": { Name: "tv3", FieldType: common.TSDB_DATA_TYPE_SMALLINT, - Scale: 0, Bytes: 2, BindType: stmt.TAOS_FIELD_TAG, }, - "tv4": &StmtStbField{ + "tv4": { Name: "tv4", FieldType: common.TSDB_DATA_TYPE_INT, - Scale: 0, Bytes: 4, BindType: stmt.TAOS_FIELD_TAG, }, - "tv5": &StmtStbField{ + "tv5": { Name: "tv5", FieldType: common.TSDB_DATA_TYPE_BIGINT, - Scale: 0, Bytes: 8, BindType: stmt.TAOS_FIELD_TAG, }, - "tv6": &StmtStbField{ + "tv6": { Name: "tv6", FieldType: common.TSDB_DATA_TYPE_UTINYINT, - Scale: 0, Bytes: 1, BindType: stmt.TAOS_FIELD_TAG, }, - "tv7": &StmtStbField{ + "tv7": { Name: "tv7", FieldType: common.TSDB_DATA_TYPE_USMALLINT, - Scale: 0, Bytes: 2, BindType: stmt.TAOS_FIELD_TAG, }, - "tv8": &StmtStbField{ + "tv8": { Name: "tv8", FieldType: common.TSDB_DATA_TYPE_UINT, - Scale: 0, Bytes: 4, BindType: stmt.TAOS_FIELD_TAG, }, - "tv9": &StmtStbField{ + "tv9": { Name: "tv9", FieldType: common.TSDB_DATA_TYPE_UBIGINT, - Scale: 0, Bytes: 8, BindType: stmt.TAOS_FIELD_TAG, }, - "tv10": &StmtStbField{ + "tv10": { Name: "tv10", FieldType: common.TSDB_DATA_TYPE_FLOAT, - Scale: 0, Bytes: 4, BindType: stmt.TAOS_FIELD_TAG, }, - "tv11": &StmtStbField{ + "tv11": { Name: "tv11", FieldType: common.TSDB_DATA_TYPE_DOUBLE, - Scale: 0, Bytes: 8, BindType: stmt.TAOS_FIELD_TAG, }, - "tv12": &StmtStbField{ + "tv12": { Name: "tv12", FieldType: common.TSDB_DATA_TYPE_BINARY, - Scale: 0, Bytes: 22, BindType: stmt.TAOS_FIELD_TAG, }, - "tv13": &StmtStbField{ + "tv13": { Name: "tv13", FieldType: common.TSDB_DATA_TYPE_VARBINARY, - Scale: 0, Bytes: 22, BindType: stmt.TAOS_FIELD_TAG, }, - "tv14": &StmtStbField{ + "tv14": { Name: "tv14", FieldType: common.TSDB_DATA_TYPE_GEOMETRY, - Scale: 0, Bytes: 102, BindType: stmt.TAOS_FIELD_TAG, }, - "tv15": &StmtStbField{ + "tv15": { Name: "tv15", FieldType: common.TSDB_DATA_TYPE_NCHAR, - Scale: 0, Bytes: 82, BindType: stmt.TAOS_FIELD_TAG, }, - "ts": &StmtStbField{ + "ts": { Name: "ts", FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionNanoSecond, - Scale: 0, Bytes: 8, BindType: stmt.TAOS_FIELD_COL, }, - "v1": &StmtStbField{ + "v1": { Name: "v1", FieldType: common.TSDB_DATA_TYPE_BOOL, - Scale: 0, Bytes: 1, BindType: stmt.TAOS_FIELD_COL, }, - "v2": &StmtStbField{ + "v2": { Name: "v2", FieldType: common.TSDB_DATA_TYPE_TINYINT, - Scale: 0, Bytes: 1, BindType: stmt.TAOS_FIELD_COL, }, - "v3": &StmtStbField{ + "v3": { Name: "v3", FieldType: common.TSDB_DATA_TYPE_SMALLINT, - Scale: 0, Bytes: 2, BindType: stmt.TAOS_FIELD_COL, }, - "v4": &StmtStbField{ + "v4": { Name: "v4", FieldType: common.TSDB_DATA_TYPE_INT, - Scale: 0, Bytes: 4, BindType: stmt.TAOS_FIELD_COL, }, - "v5": &StmtStbField{ + "v5": { Name: "v5", FieldType: common.TSDB_DATA_TYPE_BIGINT, - Scale: 0, Bytes: 8, BindType: stmt.TAOS_FIELD_COL, }, - "v6": &StmtStbField{ + "v6": { Name: "v6", FieldType: common.TSDB_DATA_TYPE_UTINYINT, - Scale: 0, Bytes: 1, BindType: stmt.TAOS_FIELD_COL, }, - "v7": &StmtStbField{ + "v7": { Name: "v7", FieldType: common.TSDB_DATA_TYPE_USMALLINT, - Scale: 0, Bytes: 2, BindType: stmt.TAOS_FIELD_COL, }, - "v8": &StmtStbField{ + "v8": { Name: "v8", FieldType: common.TSDB_DATA_TYPE_UINT, - Scale: 0, Bytes: 4, BindType: stmt.TAOS_FIELD_COL, }, - "v9": &StmtStbField{ + "v9": { Name: "v9", FieldType: common.TSDB_DATA_TYPE_UBIGINT, - Scale: 0, Bytes: 8, BindType: stmt.TAOS_FIELD_COL, }, - "v10": &StmtStbField{ + "v10": { Name: "v10", FieldType: common.TSDB_DATA_TYPE_FLOAT, - Scale: 0, Bytes: 4, BindType: stmt.TAOS_FIELD_COL, }, - "v11": &StmtStbField{ + "v11": { Name: "v11", FieldType: common.TSDB_DATA_TYPE_DOUBLE, - Scale: 0, Bytes: 8, BindType: stmt.TAOS_FIELD_COL, }, - "v12": &StmtStbField{ + "v12": { Name: "v12", FieldType: common.TSDB_DATA_TYPE_BINARY, - Scale: 0, Bytes: 22, BindType: stmt.TAOS_FIELD_COL, }, - "v13": &StmtStbField{ + "v13": { Name: "v13", FieldType: common.TSDB_DATA_TYPE_VARBINARY, - Scale: 0, Bytes: 22, BindType: stmt.TAOS_FIELD_COL, }, - "v14": &StmtStbField{ + "v14": { Name: "v14", FieldType: common.TSDB_DATA_TYPE_GEOMETRY, - Scale: 0, Bytes: 102, BindType: stmt.TAOS_FIELD_COL, }, - "v15": &StmtStbField{ + "v15": { Name: "v15", FieldType: common.TSDB_DATA_TYPE_NCHAR, - Scale: 0, Bytes: 82, BindType: stmt.TAOS_FIELD_COL, }, - "tbname": &StmtStbField{ + "tbname": { Name: "tbname", FieldType: common.TSDB_DATA_TYPE_BINARY, - Scale: 0, Bytes: 271, BindType: stmt.TAOS_FIELD_TBNAME, }, @@ -5413,10 +5380,37 @@ func TestTaosStmt2GetStbFields(t *testing.T) { TaosStmt2FreeStbFields(stmt2, fields) expect := make([]*StmtStbField, len(tt.expect)) for i := 0; i < len(tt.expect); i++ { + assert.Equal(t, expectMap[tt.expect[i]].Name, fs[i].Name) + assert.Equal(t, expectMap[tt.expect[i]].FieldType, fs[i].FieldType) + assert.Equal(t, expectMap[tt.expect[i]].Bytes, fs[i].Bytes) + assert.Equal(t, expectMap[tt.expect[i]].BindType, fs[i].BindType) + if expectMap[tt.expect[i]].FieldType == common.TSDB_DATA_TYPE_TIMESTAMP { + assert.Equal(t, expectMap[tt.expect[i]].Precision, fs[i].Precision) + } expect[i] = expectMap[tt.expect[i]] } - assert.Equal(t, expect, fs) } + + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + stmt2 := TaosStmt2Init(conn, 0xfd123, false, false, handler) + defer TaosStmt2Close(stmt2) + code := TaosStmt2Prepare(stmt2, "select * from commontb where ts = ? and v1 = ?") + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err := taosError.NewError(code, errStr) + t.Error(err) + return + } + code, count, fields := TaosStmt2GetStbFields(stmt2) + if code != 0 { + errStr := TaosStmt2Error(stmt2) + err := taosError.NewError(code, errStr) + t.Error(err) + return + } + TaosStmt2FreeStbFields(stmt2, fields) + assert.Equal(t, 2, count) } func TestWrongParseStmt2StbFields(t *testing.T) { From 7f78719af0900e7eaedd3d09dca985474ab6939b Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 5 Dec 2024 09:57:22 +0800 Subject: [PATCH 31/48] enh: remove stmt2_get_fields --- controller/ws/ws/const.go | 11 +- controller/ws/ws/handler.go | 14 -- controller/ws/ws/handler_test.go | 6 - controller/ws/ws/stmt2.go | 90 ------------ controller/ws/ws/stmt2_test.go | 240 +++---------------------------- 5 files changed, 27 insertions(+), 334 deletions(-) diff --git a/controller/ws/ws/const.go b/controller/ws/ws/const.go index 59e95839..a1951a7b 100644 --- a/controller/ws/ws/const.go +++ b/controller/ws/ws/const.go @@ -37,12 +37,11 @@ const ( STMTGetParam = "stmt_get_param" // stmt2 - STMT2Init = "stmt2_init" - STMT2Prepare = "stmt2_prepare" - STMT2GetFields = "stmt2_get_fields" - STMT2Exec = "stmt2_exec" - STMT2Result = "stmt2_result" - STMT2Close = "stmt2_close" + STMT2Init = "stmt2_init" + STMT2Prepare = "stmt2_prepare" + STMT2Exec = "stmt2_exec" + STMT2Result = "stmt2_result" + STMT2Close = "stmt2_close" ) const ( diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 4e758b08..a1e02dda 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -526,20 +526,6 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { config.ReqIDKey: req.ReqID, }) h.stmt2Prepare(ctx, session, action, req, logger, log.IsDebug()) - case STMT2GetFields: - action = STMT2GetFields - var req stmt2GetFieldsRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - h.logger.Errorf("unmarshal stmt2 get fields request error, request:%s, err:%s", request.Args, err) - reqID := getReqID(request.Args) - commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal stmt2 get fields request error") - return - } - logger := h.logger.WithFields(logrus.Fields{ - actionKey: action, - config.ReqIDKey: req.ReqID, - }) - h.stmt2GetFields(ctx, session, action, req, logger, log.IsDebug()) case STMT2Exec: action = STMT2Exec var req stmt2ExecRequest diff --git a/controller/ws/ws/handler_test.go b/controller/ws/ws/handler_test.go index 01d75474..dbd232fe 100644 --- a/controller/ws/ws/handler_test.go +++ b/controller/ws/ws/handler_test.go @@ -234,12 +234,6 @@ func Test_WrongJsonProtocol(t *testing.T) { args: "wrong", errorPrefix: "unmarshal stmt2 prepare request error", }, - { - name: "stmt2 get fields with wrong args", - action: STMT2GetFields, - args: "wrong", - errorPrefix: "unmarshal stmt2 get fields request error", - }, { name: "stmt2 exec with wrong args", action: STMT2Exec, diff --git a/controller/ws/ws/stmt2.go b/controller/ws/ws/stmt2.go index 8fa84034..e2ff2d57 100644 --- a/controller/ws/ws/stmt2.go +++ b/controller/ws/ws/stmt2.go @@ -4,14 +4,11 @@ import ( "context" "encoding/binary" "errors" - "fmt" - "unsafe" "github.com/sirupsen/logrus" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/syncinterface" - stmtCommon "github.com/taosdata/taosadapter/v3/driver/common/stmt" errors2 "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" @@ -141,93 +138,6 @@ func (h *messageHandler) stmt2Prepare(ctx context.Context, session *melody.Sessi wstool.WSWriteJson(session, logger, prepareResp) } -func getFields(stmt2 unsafe.Pointer, fieldType int8, logger *logrus.Entry, isDebug bool) (fields []*stmtCommon.StmtField, count int, code int, errSt string) { - var cFields unsafe.Pointer - code, count, cFields = syncinterface.TaosStmt2GetFields(stmt2, int(fieldType), logger, isDebug) - if code != 0 { - errStr := wrapper.TaosStmt2Error(stmt2) - logger.Errorf("stmt2 get fields error, field_type:%d, err:%s", fieldType, errStr) - return nil, count, code, errStr - } - defer wrapper.TaosStmt2FreeFields(stmt2, cFields) - if count > 0 && cFields != nil { - s := log.GetLogNow(isDebug) - fields = wrapper.StmtParseFields(count, cFields) - logger.Debugf("stmt2 parse fields cost:%s", log.GetLogDuration(isDebug, s)) - return fields, count, 0, "" - } - return nil, count, 0, "" -} - -type stmt2GetFieldsRequest struct { - ReqID uint64 `json:"req_id"` - StmtID uint64 `json:"stmt_id"` - FieldTypes []int8 `json:"field_types"` -} - -type stmt2GetFieldsResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Action string `json:"action"` - ReqID uint64 `json:"req_id"` - Timing int64 `json:"timing"` - StmtID uint64 `json:"stmt_id"` - TableCount int32 `json:"table_count"` - QueryCount int32 `json:"query_count"` - ColFields []*stmtCommon.StmtField `json:"col_fields"` - TagFields []*stmtCommon.StmtField `json:"tag_fields"` -} - -func (h *messageHandler) stmt2GetFields(ctx context.Context, session *melody.Session, action string, req stmt2GetFieldsRequest, logger *logrus.Entry, isDebug bool) { - logger.Tracef("stmt2 get col fields, stmt_id:%d", req.StmtID) - stmtItem, locked := h.stmt2ValidateAndLock(ctx, session, action, req.ReqID, req.StmtID, logger, isDebug) - if !locked { - return - } - defer stmtItem.Unlock() - stmt2GetFieldsResp := &stmt2GetFieldsResponse{StmtID: req.StmtID} - for i := 0; i < len(req.FieldTypes); i++ { - switch req.FieldTypes[i] { - case stmtCommon.TAOS_FIELD_COL: - colFields, _, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_COL, logger, isDebug) - if code != 0 { - logger.Errorf("get col fields error, code:%d, err:%s", code, errStr) - stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get col fields error, %s", errStr), req.StmtID) - return - } - stmt2GetFieldsResp.ColFields = colFields - case stmtCommon.TAOS_FIELD_TAG: - tagFields, _, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_TAG, logger, isDebug) - if code != 0 { - logger.Errorf("get tag fields error, code:%d, err:%s", code, errStr) - stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get tag fields error, %s", errStr), req.StmtID) - return - } - stmt2GetFieldsResp.TagFields = tagFields - case stmtCommon.TAOS_FIELD_TBNAME: - _, count, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_TBNAME, logger, isDebug) - if code != 0 { - logger.Errorf("get table names fields error, code:%d, err:%s", code, errStr) - stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get table names fields error, %s", errStr), req.StmtID) - return - } - stmt2GetFieldsResp.TableCount = int32(count) - case stmtCommon.TAOS_FIELD_QUERY: - _, count, code, errStr := getFields(stmtItem.stmt, stmtCommon.TAOS_FIELD_QUERY, logger, isDebug) - if code != 0 { - logger.Errorf("get query fields error, code:%d, err:%s", code, errStr) - stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, fmt.Sprintf("get query fields error, %s", errStr), req.StmtID) - return - } - stmt2GetFieldsResp.QueryCount = int32(count) - } - } - stmt2GetFieldsResp.ReqID = req.ReqID - stmt2GetFieldsResp.Action = action - stmt2GetFieldsResp.Timing = wstool.GetDuration(ctx) - wstool.WSWriteJson(session, logger, stmt2GetFieldsResp) -} - type stmt2ExecRequest struct { ReqID uint64 `json:"req_id"` StmtID uint64 `json:"stmt_id"` diff --git a/controller/ws/ws/stmt2_test.go b/controller/ws/ws/stmt2_test.go index 88281fbd..7e186ed1 100644 --- a/controller/ws/ws/stmt2_test.go +++ b/controller/ws/ws/stmt2_test.go @@ -69,7 +69,12 @@ func TestWsStmt2(t *testing.T) { assert.Equal(t, 0, initResp.Code, initResp.Message) // prepare - prepareReq := stmt2PrepareRequest{ReqID: 3, StmtID: initResp.StmtID, SQL: "insert into ct1 using test_ws_stmt2_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"} + prepareReq := stmt2PrepareRequest{ + ReqID: 3, + StmtID: initResp.StmtID, + SQL: "insert into ct1 using test_ws_stmt2_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + GetFields: true, + } resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) assert.NoError(t, err) var prepareResp stmt2PrepareResponse @@ -78,27 +83,21 @@ func TestWsStmt2(t *testing.T) { assert.Equal(t, uint64(3), prepareResp.ReqID) assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) assert.True(t, prepareResp.IsInsert) - - // get tag fields - getTagFieldsReq := stmt2GetFieldsRequest{ReqID: 5, StmtID: prepareResp.StmtID, FieldTypes: []int8{stmtCommon.TAOS_FIELD_TAG}} - resp, err = doWebSocket(ws, STMT2GetFields, &getTagFieldsReq) - assert.NoError(t, err) - var getTagFieldsResp stmt2GetFieldsResponse - err = json.Unmarshal(resp, &getTagFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(5), getTagFieldsResp.ReqID) - assert.Equal(t, 0, getTagFieldsResp.Code, getTagFieldsResp.Message) - - // get col fields - getColFieldsReq := stmt2GetFieldsRequest{ReqID: 6, StmtID: prepareResp.StmtID, FieldTypes: []int8{stmtCommon.TAOS_FIELD_COL}} - resp, err = doWebSocket(ws, STMT2GetFields, &getColFieldsReq) - assert.NoError(t, err) - var getColFieldsResp stmt2GetFieldsResponse - err = json.Unmarshal(resp, &getColFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(6), getColFieldsResp.ReqID) - assert.Equal(t, 0, getColFieldsResp.Code, getColFieldsResp.Message) - + assert.Equal(t, 17, len(prepareResp.Fields)) + var colFields []*stmtCommon.StmtField + var tagFields []*stmtCommon.StmtField + for i := 0; i < 17; i++ { + field := &stmtCommon.StmtField{ + FieldType: prepareResp.Fields[i].FieldType, + Precision: prepareResp.Fields[i].Precision, + } + switch prepareResp.Fields[i].BindType { + case stmtCommon.TAOS_FIELD_COL: + colFields = append(colFields, field) + case stmtCommon.TAOS_FIELD_TAG: + tagFields = append(tagFields, field) + } + } // bind now := time.Now() cols := [][]driver.Value{ @@ -142,7 +141,7 @@ func TestWsStmt2(t *testing.T) { Tags: tag, Cols: cols, } - bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, getColFieldsResp.ColFields, getTagFieldsResp.TagFields) + bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, colFields, tagFields) assert.NoError(t, err) bindReq := make([]byte, len(bs)+30) // req_id @@ -380,201 +379,6 @@ func TestStmt2Prepare(t *testing.T) { assert.Equal(t, 2, prepareResp.FieldsCount) } -func TestStmt2GetFields(t *testing.T) { - s := httptest.NewServer(router) - defer s.Close() - code, message := doRestful("drop database if exists test_ws_stmt2_getfields_ws", "") - assert.Equal(t, 0, code, message) - code, message = doRestful("create database if not exists test_ws_stmt2_getfields_ws precision 'ns'", "") - assert.Equal(t, 0, code, message) - - defer doRestful("drop database if exists test_ws_stmt2_getfields_ws", "") - - code, message = doRestful( - "create table if not exists stb (ts timestamp,v1 bool,v2 tinyint,v3 smallint,v4 int,v5 bigint,v6 tinyint unsigned,v7 smallint unsigned,v8 int unsigned,v9 bigint unsigned,v10 float,v11 double,v12 binary(20),v13 nchar(20),v14 varbinary(20),v15 geometry(100)) tags (info json)", - "test_ws_stmt2_getfields_ws") - assert.Equal(t, 0, code, message) - - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - // connect - connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", DB: "test_ws_stmt2_getfields_ws"} - resp, err := doWebSocket(ws, Connect, &connReq) - assert.NoError(t, err) - var connResp commonResp - err = json.Unmarshal(resp, &connResp) - assert.NoError(t, err) - assert.Equal(t, uint64(1), connResp.ReqID) - assert.Equal(t, 0, connResp.Code, connResp.Message) - - // init - initReq := stmt2InitRequest{ - ReqID: 0x123, - SingleStbInsert: false, - SingleTableBindOnce: false, - } - resp, err = doWebSocket(ws, STMT2Init, &initReq) - assert.NoError(t, err) - var initResp stmt2InitResponse - err = json.Unmarshal(resp, &initResp) - assert.NoError(t, err) - assert.Equal(t, uint64(0x123), initResp.ReqID) - assert.Equal(t, 0, initResp.Code, initResp.Message) - - // prepare - prepareReq := stmt2PrepareRequest{ - ReqID: 3, - StmtID: initResp.StmtID, - SQL: "insert into ctb using test_ws_stmt2_getfields_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - GetFields: false, - } - resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) - assert.NoError(t, err) - var prepareResp stmt2PrepareResponse - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(3), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.Equal(t, true, prepareResp.IsInsert) - - // get fields - getFieldsReq := stmt2GetFieldsRequest{ - ReqID: 4, - StmtID: prepareResp.StmtID, - FieldTypes: []int8{ - stmtCommon.TAOS_FIELD_TAG, - stmtCommon.TAOS_FIELD_COL, - }, - } - resp, err = doWebSocket(ws, STMT2GetFields, &getFieldsReq) - assert.NoError(t, err) - var getFieldsResp stmt2GetFieldsResponse - err = json.Unmarshal(resp, &getFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(4), getFieldsResp.ReqID) - assert.Equal(t, 0, getFieldsResp.Code, getFieldsResp.Message) - names := [16]string{ - "ts", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - } - fieldTypes := [16]int8{ - common.TSDB_DATA_TYPE_TIMESTAMP, - common.TSDB_DATA_TYPE_BOOL, - common.TSDB_DATA_TYPE_TINYINT, - common.TSDB_DATA_TYPE_SMALLINT, - common.TSDB_DATA_TYPE_INT, - common.TSDB_DATA_TYPE_BIGINT, - common.TSDB_DATA_TYPE_UTINYINT, - common.TSDB_DATA_TYPE_USMALLINT, - common.TSDB_DATA_TYPE_UINT, - common.TSDB_DATA_TYPE_UBIGINT, - common.TSDB_DATA_TYPE_FLOAT, - common.TSDB_DATA_TYPE_DOUBLE, - common.TSDB_DATA_TYPE_BINARY, - common.TSDB_DATA_TYPE_NCHAR, - common.TSDB_DATA_TYPE_VARBINARY, - common.TSDB_DATA_TYPE_GEOMETRY, - } - assert.Equal(t, 16, len(getFieldsResp.ColFields)) - assert.Equal(t, 1, len(getFieldsResp.TagFields)) - for i := 0; i < 16; i++ { - assert.Equal(t, names[i], getFieldsResp.ColFields[i].Name) - assert.Equal(t, fieldTypes[i], getFieldsResp.ColFields[i].FieldType) - } - assert.Equal(t, "info", getFieldsResp.TagFields[0].Name) - assert.Equal(t, int8(common.TSDB_DATA_TYPE_JSON), getFieldsResp.TagFields[0].FieldType) - - // prepare get tablename - prepareReq = stmt2PrepareRequest{ - ReqID: 5, - StmtID: initResp.StmtID, - SQL: "insert into ? using test_ws_stmt2_getfields_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", - GetFields: false, - } - resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) - assert.NoError(t, err) - - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(5), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.Equal(t, true, prepareResp.IsInsert) - // get fields - getFieldsReq = stmt2GetFieldsRequest{ - ReqID: 6, - StmtID: prepareResp.StmtID, - FieldTypes: []int8{ - stmtCommon.TAOS_FIELD_TBNAME, - }, - } - resp, err = doWebSocket(ws, STMT2GetFields, &getFieldsReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &getFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(6), getFieldsResp.ReqID) - assert.Equal(t, 0, getFieldsResp.Code, getFieldsResp.Message) - - assert.Nil(t, getFieldsResp.ColFields) - assert.Nil(t, getFieldsResp.TagFields) - assert.Equal(t, int32(1), getFieldsResp.TableCount) - - // prepare query - prepareReq = stmt2PrepareRequest{ - ReqID: 7, - StmtID: initResp.StmtID, - SQL: "select * from test_ws_stmt2_getfields_ws.stb where ts = ? and v1 = ?", - GetFields: false, - } - resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &prepareResp) - assert.NoError(t, err) - assert.Equal(t, uint64(7), prepareResp.ReqID) - assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) - assert.Equal(t, false, prepareResp.IsInsert) - // get fields - getFieldsReq = stmt2GetFieldsRequest{ - ReqID: 8, - StmtID: prepareResp.StmtID, - FieldTypes: []int8{ - stmtCommon.TAOS_FIELD_QUERY, - }, - } - resp, err = doWebSocket(ws, STMT2GetFields, &getFieldsReq) - assert.NoError(t, err) - err = json.Unmarshal(resp, &getFieldsResp) - assert.NoError(t, err) - assert.Equal(t, uint64(8), getFieldsResp.ReqID) - assert.Equal(t, 0, getFieldsResp.Code, getFieldsResp.Message) - - assert.Nil(t, getFieldsResp.ColFields) - assert.Nil(t, getFieldsResp.TagFields) - assert.Equal(t, int32(2), getFieldsResp.QueryCount) - -} - func TestStmt2Query(t *testing.T) { //for stable prepareDataSql := []string{ From 808b8597e878890df2406f3aacd90a9f878f0a2d Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 5 Dec 2024 16:03:36 +0800 Subject: [PATCH 32/48] enh: fix race in go test --- controller/ws/ws/handler.go | 19 ++++++++++++++----- controller/ws/ws/query.go | 2 +- controller/ws/ws/ws.go | 8 ++++---- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index a1e02dda..9f8ff5a6 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" "unsafe" @@ -25,7 +26,7 @@ import ( type messageHandler struct { conn unsafe.Pointer logger *logrus.Entry - closed bool + closed uint32 once sync.Once wait sync.WaitGroup dropUserChan chan struct{} @@ -75,7 +76,7 @@ func (h *messageHandler) waitSignal(logger *logrus.Entry) { logger.Info("get drop user signal") isDebug := log.IsDebug() h.lock(logger, isDebug) - if h.closed { + if h.isClosed() { logger.Trace("server closed") h.Unlock() return @@ -87,7 +88,7 @@ func (h *messageHandler) waitSignal(logger *logrus.Entry) { logger.Info("get whitelist change signal") isDebug := log.IsDebug() h.lock(logger, isDebug) - if h.closed { + if h.isClosed() { logger.Trace("server closed") h.Unlock() return @@ -113,6 +114,14 @@ func (h *messageHandler) waitSignal(logger *logrus.Entry) { } } +func (h *messageHandler) isClosed() bool { + return atomic.LoadUint32(&h.closed) == 1 +} + +func (h *messageHandler) setClosed() { + atomic.StoreUint32(&h.closed, 1) +} + func (h *messageHandler) signalExit(logger *logrus.Entry, isDebug bool) { logger.Trace("close session") s := log.GetLogNow(isDebug) @@ -136,11 +145,11 @@ func (h *messageHandler) Close() { h.Lock() defer h.Unlock() - if h.closed { + if h.isClosed() { h.logger.Trace("server closed") return } - h.closed = true + h.setClosed() h.stop() close(h.exit) } diff --git a/controller/ws/ws/query.go b/controller/ws/ws/query.go index e0679d74..22162c81 100644 --- a/controller/ws/ws/query.go +++ b/controller/ws/ws/query.go @@ -32,7 +32,7 @@ type connRequest struct { func (h *messageHandler) connect(ctx context.Context, session *melody.Session, action string, req connRequest, logger *logrus.Entry, isDebug bool) { h.lock(logger, isDebug) defer h.Unlock() - if h.closed { + if h.isClosed() { logger.Trace("server closed") return } diff --git a/controller/ws/ws/ws.go b/controller/ws/ws/ws.go index 66898439..435f4ad6 100644 --- a/controller/ws/ws/ws.go +++ b/controller/ws/ws/ws.go @@ -42,13 +42,13 @@ func initController() *webSocketCtl { }) m.HandleMessage(func(session *melody.Session, data []byte) { h := session.MustGet(TaosKey).(*messageHandler) - if h.closed { + if h.isClosed() { return } h.wait.Add(1) go func() { defer h.wait.Done() - if h.closed { + if h.isClosed() { return } h.handleMessage(session, data) @@ -56,13 +56,13 @@ func initController() *webSocketCtl { }) m.HandleMessageBinary(func(session *melody.Session, data []byte) { h := session.MustGet(TaosKey).(*messageHandler) - if h.closed { + if h.isClosed() { return } h.wait.Add(1) go func() { defer h.wait.Done() - if h.closed { + if h.isClosed() { return } h.handleMessageBinary(session, data) From 977cf5d64a00256e8475bb5b53f4da2844563e7b Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 5 Dec 2024 16:09:47 +0800 Subject: [PATCH 33/48] test: add set tbname test --- controller/ws/ws/stmt2_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/controller/ws/ws/stmt2_test.go b/controller/ws/ws/stmt2_test.go index 7e186ed1..55f0fc16 100644 --- a/controller/ws/ws/stmt2_test.go +++ b/controller/ws/ws/stmt2_test.go @@ -72,7 +72,7 @@ func TestWsStmt2(t *testing.T) { prepareReq := stmt2PrepareRequest{ ReqID: 3, StmtID: initResp.StmtID, - SQL: "insert into ct1 using test_ws_stmt2_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + SQL: "insert into ? using test_ws_stmt2_ws.stb tags (?) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", GetFields: true, } resp, err = doWebSocket(ws, STMT2Prepare, &prepareReq) @@ -83,10 +83,10 @@ func TestWsStmt2(t *testing.T) { assert.Equal(t, uint64(3), prepareResp.ReqID) assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) assert.True(t, prepareResp.IsInsert) - assert.Equal(t, 17, len(prepareResp.Fields)) + assert.Equal(t, 18, len(prepareResp.Fields)) var colFields []*stmtCommon.StmtField var tagFields []*stmtCommon.StmtField - for i := 0; i < 17; i++ { + for i := 0; i < 18; i++ { field := &stmtCommon.StmtField{ FieldType: prepareResp.Fields[i].FieldType, Precision: prepareResp.Fields[i].Precision, From 11b27699b0c8d0e241733696550a468cd64dfe0e Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Mon, 9 Dec 2024 12:43:05 +0800 Subject: [PATCH 34/48] fix: stmt2 bind crash --- .github/workflows/macos.yml | 2 +- driver/wrapper/stmt2.go | 202 +------------------------- driver/wrapper/stmt2binary.go | 265 ++++++++++++++++++++++++++++++++++ 3 files changed, 270 insertions(+), 199 deletions(-) create mode 100644 driver/wrapper/stmt2binary.go diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 4d362431..f7622d84 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -126,7 +126,7 @@ jobs: strategy: matrix: os: [ 'macos-latest' ] - go: [ '1.17', 'stable' ] + go: [ '1.18', 'stable' ] name: Build taosAdapter ${{ matrix.go }} steps: - name: get cache server by pr diff --git a/driver/wrapper/stmt2.go b/driver/wrapper/stmt2.go index 306325c3..6db59f81 100644 --- a/driver/wrapper/stmt2.go +++ b/driver/wrapper/stmt2.go @@ -17,7 +17,6 @@ import "C" import ( "bytes" "database/sql/driver" - "encoding/binary" "fmt" "time" "unsafe" @@ -27,6 +26,7 @@ import ( taosError "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" "github.com/taosdata/taosadapter/v3/tools" + "github.com/taosdata/taosadapter/v3/tools/bytesutil" ) // TaosStmt2Init TAOS_STMT2 *taos_stmt2_init(TAOS *taos, TAOS_STMT2_OPTION *option); @@ -434,7 +434,7 @@ func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt if rowData != nil { switch value := rowData.(type) { case string: - x := *(*[]byte)(unsafe.Pointer(&value)) + x := bytesutil.ToUnsafeBytes(value) C.memcpy(unsafe.Pointer(uintptr(p)+uintptr(colOffset[i])), unsafe.Pointer(&x[0]), C.size_t(len(value))) case []byte: C.memcpy(unsafe.Pointer(uintptr(p)+uintptr(colOffset[i])), unsafe.Pointer(&value[0]), C.size_t(len(value))) @@ -599,7 +599,7 @@ func generateTaosStmt2BindsQuery(multiBind [][]driver.Value) (unsafe.Pointer, [] p = unsafe.Pointer(C.malloc(C.size_t(C.uint(valueLength)))) needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_BINARY - x := *(*[]byte)(unsafe.Pointer(&rowData)) + x := bytesutil.ToUnsafeBytes(rowData) C.memcpy(p, unsafe.Pointer(&x[0]), C.size_t(valueLength)) lengthList = unsafe.Pointer(C.calloc(C.size_t(C.uint(1)), C.size_t(C.uint(4)))) needFreePointer = append(needFreePointer, lengthList) @@ -611,8 +611,7 @@ func generateTaosStmt2BindsQuery(multiBind [][]driver.Value) (unsafe.Pointer, [] p = unsafe.Pointer(C.malloc(C.size_t(C.uint(valueLength)))) needFreePointer = append(needFreePointer, p) bind.buffer_type = C.TSDB_DATA_TYPE_BINARY - x := *(*[]byte)(unsafe.Pointer(&value)) - C.memcpy(p, unsafe.Pointer(&x[0]), C.size_t(valueLength)) + C.memcpy(p, unsafe.Pointer(&value[0]), C.size_t(valueLength)) lengthList = unsafe.Pointer(C.calloc(C.size_t(C.uint(1)), C.size_t(C.uint(4)))) needFreePointer = append(needFreePointer, lengthList) *(*C.int32_t)(lengthList) = C.int32_t(valueLength) @@ -664,199 +663,6 @@ func TaosStmt2Error(stmt unsafe.Pointer) string { return C.GoString(C.taos_stmt2_error(stmt)) } -func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32) error { - totalLength := binary.LittleEndian.Uint32(data[stmt.TotalLengthPosition:]) - if totalLength != uint32(len(data)) { - return fmt.Errorf("total length not match, expect %d, but get %d", len(data), totalLength) - } - var freePointer []unsafe.Pointer - defer func() { - for i := len(freePointer) - 1; i >= 0; i-- { - if freePointer[i] != nil { - C.free(freePointer[i]) - } - } - }() - dataP := unsafe.Pointer(C.CBytes(data)) - freePointer = append(freePointer, dataP) - count := binary.LittleEndian.Uint32(data[stmt.CountPosition:]) - tagCount := binary.LittleEndian.Uint32(data[stmt.TagCountPosition:]) - colCount := binary.LittleEndian.Uint32(data[stmt.ColCountPosition:]) - tableNamesOffset := binary.LittleEndian.Uint32(data[stmt.TableNamesOffsetPosition:]) - tagsOffset := binary.LittleEndian.Uint32(data[stmt.TagsOffsetPosition:]) - colsOffset := binary.LittleEndian.Uint32(data[stmt.ColsOffsetPosition:]) - // check table names - if tableNamesOffset > 0 { - tableNameEnd := tableNamesOffset + count*2 - // table name lengths out of range - if tableNameEnd > totalLength { - return fmt.Errorf("table name lengths out of range, total length: %d, tableNamesLengthEnd: %d", totalLength, tableNameEnd) - } - for i := uint32(0); i < count; i++ { - tableNameLength := binary.LittleEndian.Uint16(data[tableNamesOffset+i*2:]) - tableNameEnd += uint32(tableNameLength) - } - if tableNameEnd > totalLength { - return fmt.Errorf("table names out of range, total length: %d, tableNameTotalLength: %d", totalLength, tableNameEnd) - } - } - // check tags - if tagsOffset > 0 { - if tagCount == 0 { - return fmt.Errorf("tag count is zero, but tags offset is not zero") - } - tagsEnd := tagsOffset + count*4 - if tagsEnd > totalLength { - return fmt.Errorf("tags lengths out of range, total length: %d, tagsLengthEnd: %d", totalLength, tagsEnd) - } - for i := uint32(0); i < count; i++ { - tagLength := binary.LittleEndian.Uint32(data[tagsOffset+i*4:]) - if tagLength == 0 { - return fmt.Errorf("tag length is zero, data index: %d", i) - } - tagsEnd += tagLength - } - if tagsEnd > totalLength { - return fmt.Errorf("tags out of range, total length: %d, tagsTotalLength: %d", totalLength, tagsEnd) - } - } - // check cols - if colsOffset > 0 { - if colCount == 0 { - return fmt.Errorf("col count is zero, but cols offset is not zero") - } - colsEnd := colsOffset + count*4 - if colsEnd > totalLength { - return fmt.Errorf("cols lengths out of range, total length: %d, colsLengthEnd: %d", totalLength, colsEnd) - } - for i := uint32(0); i < count; i++ { - colLength := binary.LittleEndian.Uint32(data[colsOffset+i*4:]) - if colLength == 0 { - return fmt.Errorf("col length is zero, data: %d", i) - } - colsEnd += colLength - } - if colsEnd > totalLength { - return fmt.Errorf("cols out of range, total length: %d, colsTotalLength: %d", totalLength, colsEnd) - } - } - cBindv := C.TAOS_STMT2_BINDV{} - cBindv.count = C.int(count) - if tableNamesOffset > 0 { - tableNameLengthP := tools.AddPointer(dataP, uintptr(tableNamesOffset)) - cTableNames := C.malloc(C.size_t(uintptr(count) * PointerSize)) - freePointer = append(freePointer, cTableNames) - tableDataP := tools.AddPointer(tableNameLengthP, uintptr(count)*2) - var tableNamesArrayP unsafe.Pointer - for i := uint32(0); i < count; i++ { - tableNamesArrayP = tools.AddPointer(cTableNames, uintptr(i)*PointerSize) - *(**C.char)(tableNamesArrayP) = (*C.char)(tableDataP) - tableNameLength := *(*uint16)(tools.AddPointer(tableNameLengthP, uintptr(i*2))) - if tableNameLength == 0 { - return fmt.Errorf("table name length is zero, data index: %d", i) - } - tableDataP = tools.AddPointer(tableDataP, uintptr(tableNameLength)) - } - cBindv.tbnames = (**C.char)(cTableNames) - } else { - cBindv.tbnames = nil - } - if tagsOffset > 0 { - tags, needFreePointer, err := generateStmt2Binds(count, tagCount, dataP, tagsOffset) - freePointer = append(freePointer, needFreePointer...) - if err != nil { - return fmt.Errorf("generate tags error: %s", err.Error()) - } - cBindv.tags = (**C.TAOS_STMT2_BIND)(tags) - } else { - cBindv.tags = nil - } - if colsOffset > 0 { - cols, needFreePointer, err := generateStmt2Binds(count, colCount, dataP, colsOffset) - freePointer = append(freePointer, needFreePointer...) - if err != nil { - return fmt.Errorf("generate cols error: %s", err.Error()) - } - cBindv.bind_cols = (**C.TAOS_STMT2_BIND)(cols) - } else { - cBindv.bind_cols = nil - } - code := int(C.taos_stmt2_bind_param(stmt2, &cBindv, C.int32_t(colIdx))) - if code != 0 { - errStr := TaosStmt2Error(stmt2) - return taosError.NewError(code, errStr) - } - return nil -} - -func generateStmt2Binds(count uint32, fieldCount uint32, dataP unsafe.Pointer, fieldsOffset uint32) (unsafe.Pointer, []unsafe.Pointer, error) { - var freePointer []unsafe.Pointer - bindsCList := unsafe.Pointer(C.malloc(C.size_t(uintptr(count) * PointerSize))) - freePointer = append(freePointer, bindsCList) - // dataLength [count]uint32 - // length have checked in TaosStmt2BindBinary - baseLengthPointer := tools.AddPointer(dataP, uintptr(fieldsOffset)) - // dataBuffer - dataPointer := tools.AddPointer(baseLengthPointer, uintptr(count)*4) - var bindsPointer unsafe.Pointer - for tableIndex := uint32(0); tableIndex < count; tableIndex++ { - bindsPointer = tools.AddPointer(bindsCList, uintptr(tableIndex)*PointerSize) - binds := unsafe.Pointer(C.malloc(C.size_t(C.size_t(fieldCount) * C.size_t(unsafe.Sizeof(C.TAOS_STMT2_BIND{}))))) - freePointer = append(freePointer, binds) - var bindDataP unsafe.Pointer - var bindDataTotalLength uint32 - var num int32 - var haveLength byte - var bufferLength uint32 - for fieldIndex := uint32(0); fieldIndex < fieldCount; fieldIndex++ { - // field data - bindDataP = dataPointer - // totalLength - bindDataTotalLength = *(*uint32)(bindDataP) - bindDataP = tools.AddPointer(bindDataP, common.UInt32Size) - bind := (*C.TAOS_STMT2_BIND)(unsafe.Pointer(uintptr(binds) + uintptr(fieldIndex)*unsafe.Sizeof(C.TAOS_STMT2_BIND{}))) - // buffer_type - bind.buffer_type = *(*C.int)(bindDataP) - bindDataP = tools.AddPointer(bindDataP, common.Int32Size) - // num - num = *(*int32)(bindDataP) - bind.num = C.int(num) - bindDataP = tools.AddPointer(bindDataP, common.Int32Size) - // is_null - bind.is_null = (*C.char)(bindDataP) - bindDataP = tools.AddPointer(bindDataP, uintptr(num)) - // haveLength - haveLength = *(*byte)(bindDataP) - bindDataP = tools.AddPointer(bindDataP, common.Int8Size) - if haveLength == 0 { - bind.length = nil - } else { - // length [num]int32 - bind.length = (*C.int32_t)(bindDataP) - bindDataP = tools.AddPointer(bindDataP, common.Int32Size*uintptr(num)) - } - // bufferLength - bufferLength = *(*uint32)(bindDataP) - bindDataP = tools.AddPointer(bindDataP, common.UInt32Size) - // buffer - if bufferLength == 0 { - bind.buffer = nil - } else { - bind.buffer = bindDataP - } - bindDataP = tools.AddPointer(bindDataP, uintptr(bufferLength)) - // check bind data length - bindDataLen := uintptr(bindDataP) - uintptr(dataPointer) - if bindDataLen != uintptr(bindDataTotalLength) { - return nil, freePointer, fmt.Errorf("bind data length not match, expect %d, but get %d, tableIndex:%d", bindDataTotalLength, bindDataLen, tableIndex) - } - dataPointer = bindDataP - } - *(**C.TAOS_STMT2_BIND)(bindsPointer) = (*C.TAOS_STMT2_BIND)(binds) - } - return bindsCList, freePointer, nil -} - // TaosStmt2GetStbFields int taos_stmt2_get_stb_fields(TAOS_STMT2 *stmt, int *count, TAOS_FIELD_STB **fields); func TaosStmt2GetStbFields(stmt unsafe.Pointer) (code, count int, fields unsafe.Pointer) { code = int(C.taos_stmt2_get_stb_fields(stmt, (*C.int)(unsafe.Pointer(&count)), (**C.TAOS_FIELD_STB)(unsafe.Pointer(&fields)))) diff --git a/driver/wrapper/stmt2binary.go b/driver/wrapper/stmt2binary.go new file mode 100644 index 00000000..e6b3b0ff --- /dev/null +++ b/driver/wrapper/stmt2binary.go @@ -0,0 +1,265 @@ +package wrapper + +/* +#include +#include +#include +#include + +int +generate_stmt2_binds(char *data, uint32_t count, uint32_t field_count, uint32_t field_offset, + TAOS_STMT2_BIND *bind_struct, + TAOS_STMT2_BIND **bind_ptr, char *err_msg) { + uint32_t *base_length = (uint32_t *) (data + field_offset); + char *data_ptr = (char *) (base_length + count); + for (int table_index = 0; table_index < count; table_index++) { + bind_ptr[table_index] = bind_struct + table_index * field_count; + char *bind_data_ptr; + for (uint32_t field_index = 0; field_index < field_count; field_index++) { + bind_data_ptr = data_ptr; + TAOS_STMT2_BIND *bind = bind_ptr[table_index] + field_index; + // total length + uint32_t bind_data_totalLength = *(uint32_t *) bind_data_ptr; + bind_data_ptr += 4; + // buffer_type + bind->buffer_type = *(int *) bind_data_ptr; + bind_data_ptr += 4; + // num + bind->num = *(int *) bind_data_ptr; + bind_data_ptr += 4; + // is_null + bind->is_null = (char *) bind_data_ptr; + bind_data_ptr += bind->num; + // have_length + char have_length = *(char *) bind_data_ptr; + bind_data_ptr += 1; + if (have_length == 0) { + bind->length = NULL; + } else { + bind->length = (int32_t *) bind_data_ptr; + bind_data_ptr += bind->num * 4; + } + // buffer_length + int32_t buffer_length = *(int32_t *) bind_data_ptr; + bind_data_ptr += 4; + // buffer + if (buffer_length > 0) { + bind->buffer = (void *) bind_data_ptr; + bind_data_ptr += buffer_length; + } else { + bind->buffer = NULL; + } + // check bind data length + if (bind_data_ptr - data_ptr != bind_data_totalLength) { + snprintf(err_msg, 128, "bind data length error, tableIndex: %d, fieldIndex: %d", table_index, field_index); + return -1; + } + data_ptr = bind_data_ptr; + } + } + return 0; +} + + +int taos_stmt2_bind_binary(TAOS_STMT2 *stmt, char *data, int32_t col_idx, char *err_msg) { + uint32_t *header = (uint32_t *) data; + uint32_t total_length = header[0]; + uint32_t count = header[1]; + uint32_t tag_count = header[2]; + uint32_t col_count = header[3]; + uint32_t table_names_offset = header[4]; + uint32_t tags_offset = header[5]; + uint32_t cols_offset = header[6]; + // check table names + if (table_names_offset > 0) { + uint32_t table_name_end = table_names_offset + count * 2; + if (table_name_end > total_length) { + snprintf(err_msg, 128, "table name lengths out of range, total length: %d, tableNamesLengthEnd: %d", total_length, + table_name_end); + return -1; + } + uint16_t *table_name_length_ptr = (uint16_t *) (data + table_names_offset); + for (int32_t i = 0; i < count; ++i) { + table_name_end += (uint32_t) table_name_length_ptr[i]; + } + if (table_name_end > total_length) { + snprintf(err_msg, 128, "table names out of range, total length: %d, tableNameTotalLength: %d", total_length, + table_name_end); + return -1; + } + } + // check tags + if (tags_offset > 0) { + if (tag_count == 0) { + snprintf(err_msg, 128, "tag count is 0, but tags offset is not 0"); + return -1; + } + uint32_t tag_end = tags_offset + count * 4; + if (tag_end > total_length) { + snprintf(err_msg, 128, "tags out of range, total length: %d, tagEnd: %d", total_length, tag_end); + return -1; + } + uint32_t *tab_length_ptr = (uint32_t *) (data + tags_offset); + for (int32_t i = 0; i < count; ++i) { + if (tab_length_ptr[i] == 0) { + snprintf(err_msg, 128, "tag length is 0, tableIndex: %d", i); + return -1; + } + tag_end += tab_length_ptr[i]; + } + if (tag_end > total_length) { + snprintf(err_msg, 128, "tags out of range, total length: %d, tagsTotalLength: %d", total_length, tag_end); + return -1; + } + } + // check cols + if (cols_offset > 0) { + if (col_count == 0) { + snprintf(err_msg, 128, "col count is 0, but cols offset is not 0"); + return -1; + } + uint32_t colEnd = cols_offset + count * 4; + if (colEnd > total_length) { + snprintf(err_msg, 128, "cols out of range, total length: %d, colEnd: %d", total_length, colEnd); + return -1; + } + uint32_t *col_length_ptr = (uint32_t *) (data + cols_offset); + for (int32_t i = 0; i < count; ++i) { + if (col_length_ptr[i] == 0) { + snprintf(err_msg, 128, "col length is 0, tableIndex: %d", i); + return -1; + } + colEnd += col_length_ptr[i]; + } + if (colEnd > total_length) { + snprintf(err_msg, 128, "cols out of range, total length: %d, colsTotalLength: %d", total_length, colEnd); + return -1; + } + } + // generate bindv struct + TAOS_STMT2_BINDV bind_v; + bind_v.count = (int) count; + if (table_names_offset > 0) { + uint16_t *table_name_length_ptr = (uint16_t *) (data + table_names_offset); + char *table_name_data_ptr = (char *) (table_name_length_ptr) + 2 * count; + char **table_name = (char **) malloc(sizeof(char *) * count); + if (table_name == NULL) { + snprintf(err_msg, 128, "malloc tableName error"); + return -1; + } + for (int i = 0; i < count; i++) { + table_name[i] = table_name_data_ptr; + table_name_data_ptr += table_name_length_ptr[i]; + } + bind_v.tbnames = table_name; + } else { + bind_v.tbnames = NULL; + } + uint32_t bind_struct_count = 0; + uint32_t bind_ptr_count = 0; + if (tags_offset == 0) { + bind_v.tags = NULL; + } else { + bind_struct_count += count * tag_count; + bind_ptr_count += count; + } + if (cols_offset == 0) { + bind_v.bind_cols = NULL; + } else { + bind_struct_count += count * col_count; + bind_ptr_count += count; + } + TAOS_STMT2_BIND *bind_struct = NULL; + TAOS_STMT2_BIND **bind_ptr = NULL; + if (bind_struct_count == 0) { + bind_v.tags = NULL; + bind_v.bind_cols = NULL; + } else { + // []TAOS_STMT2_BIND bindStruct + bind_struct = (TAOS_STMT2_BIND *) malloc(sizeof(TAOS_STMT2_BIND) * bind_struct_count); + if (bind_struct == NULL) { + snprintf(err_msg, 128, "malloc bind struct error"); + free(bind_v.tbnames); + return -1; + } + // []TAOS_STMT2_BIND *bindPtr + bind_ptr = (TAOS_STMT2_BIND **) malloc(sizeof(TAOS_STMT2_BIND *) * bind_ptr_count); + if (bind_ptr == NULL) { + snprintf(err_msg, 128, "malloc bind pointer error"); + free(bind_struct); + free(bind_v.tbnames); + return -1; + } + uint32_t struct_index = 0; + uint32_t ptr_index = 0; + if (tags_offset > 0) { + int code = generate_stmt2_binds(data, count, tag_count, tags_offset, bind_struct, bind_ptr, err_msg); + if (code != 0) { + free(bind_struct); + free(bind_ptr); + free(bind_v.tbnames); + return code; + } + bind_v.tags = bind_ptr; + struct_index += count * tag_count; + ptr_index += count; + } + if (cols_offset > 0) { + TAOS_STMT2_BIND *col_bind_struct = bind_struct + struct_index; + TAOS_STMT2_BIND **col_bind_ptr = bind_ptr + ptr_index; + int code = generate_stmt2_binds(data, count, col_count, cols_offset, col_bind_struct, col_bind_ptr, + err_msg); + if (code != 0) { + free(bind_struct); + free(bind_ptr); + free(bind_v.tbnames); + return code; + } + bind_v.bind_cols = col_bind_ptr; + } + } + int code = taos_stmt2_bind_param(stmt, &bind_v, col_idx); + if (code != 0) { + char *msg = taos_stmt2_error(stmt); + snprintf(err_msg, 128, "%s", msg); + } + if (bind_v.tbnames != NULL) { + free(bind_v.tbnames); + } + if (bind_struct != NULL) { + free(bind_struct); + } + if (bind_ptr != NULL) { + free(bind_ptr); + } + return code; +} +*/ +import "C" +import ( + "encoding/binary" + "fmt" + "unsafe" + + "github.com/taosdata/taosadapter/v3/driver/common/stmt" + taosError "github.com/taosdata/taosadapter/v3/driver/errors" +) + +// TaosStmt2BindBinary bind binary data to stmt2 +func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32) error { + totalLength := binary.LittleEndian.Uint32(data[stmt.TotalLengthPosition:]) + if totalLength != uint32(len(data)) { + return fmt.Errorf("total length not match, expect %d, but get %d", len(data), totalLength) + } + dataP := C.CBytes(data) + defer C.free(dataP) + errMsg := (*C.char)(C.malloc(128)) + defer C.free(unsafe.Pointer(errMsg)) + + code := C.taos_stmt2_bind_binary(stmt2, (*C.char)(dataP), C.int32_t(colIdx), errMsg) + if code != 0 { + msg := C.GoString(errMsg) + return taosError.NewError(int(code), msg) + } + return nil +} From 064359b6f8a26dc38fa732d6d697458fa497baa2 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Mon, 9 Dec 2024 17:26:02 +0800 Subject: [PATCH 35/48] test: add stmt2 bind table name as value test --- driver/wrapper/stmt2_test.go | 463 +++++++++++++++++++++++++++++++++++ 1 file changed, 463 insertions(+) diff --git a/driver/wrapper/stmt2_test.go b/driver/wrapper/stmt2_test.go index 54576e39..d09382a1 100644 --- a/driver/wrapper/stmt2_test.go +++ b/driver/wrapper/stmt2_test.go @@ -5419,3 +5419,466 @@ func TestWrongParseStmt2StbFields(t *testing.T) { fs = ParseStmt2StbFields(2, nil) assert.Nil(t, fs) } + +func TestStmt2BindTbnameAsValue(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + err = exec(conn, "drop database if exists test_stmt2_bind_tbname_as_value") + if err != nil { + t.Error(err) + return + } + }() + err = exec(conn, "create database if not exists test_stmt2_bind_tbname_as_value precision 'ns' keep 36500") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "use test_stmt2_bind_tbname_as_value") + if err != nil { + t.Error(err) + return + } + err = exec(conn, "create table if not exists all_stb("+ + "ts timestamp, "+ + "v1 bool, "+ + "v2 tinyint, "+ + "v3 smallint, "+ + "v4 int, "+ + "v5 bigint, "+ + "v6 tinyint unsigned, "+ + "v7 smallint unsigned, "+ + "v8 int unsigned, "+ + "v9 bigint unsigned, "+ + "v10 float, "+ + "v11 double, "+ + "v12 binary(20), "+ + "v13 varbinary(20), "+ + "v14 geometry(100), "+ + "v15 nchar(20))"+ + "tags("+ + "tts timestamp, "+ + "tv1 bool, "+ + "tv2 tinyint, "+ + "tv3 smallint, "+ + "tv4 int, "+ + "tv5 bigint, "+ + "tv6 tinyint unsigned, "+ + "tv7 smallint unsigned, "+ + "tv8 int unsigned, "+ + "tv9 bigint unsigned, "+ + "tv10 float, "+ + "tv11 double, "+ + "tv12 binary(20), "+ + "tv13 varbinary(20), "+ + "tv14 geometry(100), "+ + "tv15 nchar(20))") + if err != nil { + t.Error(err) + return + } + caller := NewStmtCallBackTest() + handler := cgo.NewHandle(caller) + insertStmt := TaosStmt2Init(conn, 0xff1234, false, false, handler) + prepareInsertSql := "insert into all_stb (ts ,v1 ,v2 ,v3 ,v4 ,v5 ,v6 ,v7 ,v8 ,v9 ,v10,v11,v12,v13,v14,v15,tbname,tts,tv1 ,tv2 ,tv3 ,tv4 ,tv5 ,tv6 ,tv7 ,tv8 ,tv9 ,tv10,tv11,tv12,tv13,tv14,tv15) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" + code := TaosStmt2Prepare(insertStmt, prepareInsertSql) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + + isInsert, code := TaosStmt2IsInsert(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + assert.True(t, isInsert) + + code, count, cFields := TaosStmt2GetStbFields(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + defer TaosStmt2FreeStbFields(insertStmt, cFields) + assert.Equal(t, 33, count) + fields := ParseStmt2StbFields(count, cFields) + assert.Equal(t, 33, len(fields)) + expectMap := map[string]*StmtStbField{ + "tts": { + Name: "tts", + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionNanoSecond, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv1": { + Name: "tv1", + FieldType: common.TSDB_DATA_TYPE_BOOL, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv2": { + Name: "tv2", + FieldType: common.TSDB_DATA_TYPE_TINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv3": { + Name: "tv3", + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv4": { + Name: "tv4", + FieldType: common.TSDB_DATA_TYPE_INT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv5": { + Name: "tv5", + FieldType: common.TSDB_DATA_TYPE_BIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv6": { + Name: "tv6", + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv7": { + Name: "tv7", + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv8": { + Name: "tv8", + FieldType: common.TSDB_DATA_TYPE_UINT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv9": { + Name: "tv9", + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv10": { + Name: "tv10", + FieldType: common.TSDB_DATA_TYPE_FLOAT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv11": { + Name: "tv11", + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + Bytes: 8, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv12": { + Name: "tv12", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv13": { + Name: "tv13", + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv14": { + Name: "tv14", + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + Bytes: 102, + BindType: stmt.TAOS_FIELD_TAG, + }, + "tv15": { + Name: "tv15", + FieldType: common.TSDB_DATA_TYPE_NCHAR, + Bytes: 82, + BindType: stmt.TAOS_FIELD_TAG, + }, + "ts": { + Name: "ts", + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + Precision: common.PrecisionNanoSecond, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v1": { + Name: "v1", + FieldType: common.TSDB_DATA_TYPE_BOOL, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v2": { + Name: "v2", + FieldType: common.TSDB_DATA_TYPE_TINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v3": { + Name: "v3", + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_COL, + }, + "v4": { + Name: "v4", + FieldType: common.TSDB_DATA_TYPE_INT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v5": { + Name: "v5", + FieldType: common.TSDB_DATA_TYPE_BIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v6": { + Name: "v6", + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + Bytes: 1, + BindType: stmt.TAOS_FIELD_COL, + }, + "v7": { + Name: "v7", + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + Bytes: 2, + BindType: stmt.TAOS_FIELD_COL, + }, + "v8": { + Name: "v8", + FieldType: common.TSDB_DATA_TYPE_UINT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v9": { + Name: "v9", + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v10": { + Name: "v10", + FieldType: common.TSDB_DATA_TYPE_FLOAT, + Bytes: 4, + BindType: stmt.TAOS_FIELD_COL, + }, + "v11": { + Name: "v11", + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + Bytes: 8, + BindType: stmt.TAOS_FIELD_COL, + }, + "v12": { + Name: "v12", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_COL, + }, + "v13": { + Name: "v13", + FieldType: common.TSDB_DATA_TYPE_VARBINARY, + Bytes: 22, + BindType: stmt.TAOS_FIELD_COL, + }, + "v14": { + Name: "v14", + FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + Bytes: 102, + BindType: stmt.TAOS_FIELD_COL, + }, + "v15": { + Name: "v15", + FieldType: common.TSDB_DATA_TYPE_NCHAR, + Bytes: 82, + BindType: stmt.TAOS_FIELD_COL, + }, + "tbname": { + Name: "tbname", + FieldType: common.TSDB_DATA_TYPE_BINARY, + Bytes: 271, + BindType: stmt.TAOS_FIELD_TBNAME, + }, + } + var colFields, tagFields []*stmt.StmtField + for i := 0; i < 33; i++ { + expect := expectMap[fields[i].Name] + assert.Equal(t, expect, fields[i]) + field := &stmt.StmtField{ + FieldType: fields[i].FieldType, + Precision: fields[i].Precision, + } + if fields[i].BindType == stmt.TAOS_FIELD_COL { + colFields = append(colFields, field) + } else if fields[i].BindType == stmt.TAOS_FIELD_TAG { + tagFields = append(tagFields, field) + } + } + + now := time.Now() + params2 := []*stmt.TaosStmt2BindData{{ + TableName: "ctb1", + Tags: []driver.Value{ + // TIMESTAMP + now, + // BOOL + true, + // TINYINT + int8(1), + // SMALLINT + int16(1), + // INT + int32(1), + // BIGINT + int64(1), + // UTINYINT + uint8(1), + // USMALLINT + uint16(1), + // UINT + uint32(1), + // UBIGINT + uint64(1), + // FLOAT + float32(1.2), + // DOUBLE + float64(1.2), + // BINARY + []byte("binary"), + // VARBINARY + []byte("varbinary"), + // GEOMETRY + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + // NCHAR + "nchar", + }, + Cols: [][]driver.Value{ + { + now, + now.Add(time.Second), + now.Add(time.Second * 2), + }, + { + true, + nil, + false, + }, + { + int8(11), + nil, + int8(12), + }, + { + int16(11), + nil, + int16(12), + }, + { + int32(11), + nil, + int32(12), + }, + { + int64(11), + nil, + int64(12), + }, + { + uint8(11), + nil, + uint8(12), + }, + { + uint16(11), + nil, + uint16(12), + }, + { + uint32(11), + nil, + uint32(12), + }, + { + uint64(11), + nil, + uint64(12), + }, + { + float32(11.2), + nil, + float32(12.2), + }, + { + float64(11.2), + nil, + float64(12.2), + }, + { + []byte("binary1"), + nil, + []byte("binary2"), + }, + { + []byte("varbinary1"), + nil, + []byte("varbinary2"), + }, + { + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + nil, + []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, + { + "nchar1", + nil, + "nchar2", + }, + }, + }} + bs, err := stmt.MarshalStmt2Binary(params2, true, colFields, tagFields) + assert.NoError(t, err) + err = TaosStmt2BindBinary(insertStmt, bs, -1) + if err != nil { + t.Error(err) + return + } + code = TaosStmt2Exec(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } + r := <-caller.ExecResult + if r.n != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(r.n, errStr) + t.Error(err) + return + } + assert.Equal(t, 3, r.affected) + + code = TaosStmt2Close(insertStmt) + if code != 0 { + errStr := TaosStmt2Error(insertStmt) + err = taosError.NewError(code, errStr) + t.Error(err) + return + } +} From c90f5a21d58bba16dbe80b2e516517d2171433b4 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 10 Dec 2024 14:59:31 +0800 Subject: [PATCH 36/48] feat: support set connection options --- controller/rest/restful.go | 73 +++++++++++++-------- controller/rest/restful_test.go | 61 +++++++++++++++++ controller/ws/tmq/tmq.go | 81 +++++++++++++++++------ controller/ws/tmq/tmq_test.go | 88 +++++++++++++++++++++++++ controller/ws/ws/const.go | 3 + controller/ws/ws/handler.go | 14 ++++ controller/ws/ws/handler_test.go | 6 ++ controller/ws/ws/misc.go | 28 ++++++++ controller/ws/ws/misc_test.go | 94 +++++++++++++++++++++++++++ controller/ws/ws/query.go | 91 +++++++++++++++++--------- controller/ws/ws/query_test.go | 45 +++++++++---- db/commonpool/pool.go | 5 ++ db/syncinterface/wrapper.go | 12 ++++ db/syncinterface/wrapper_test.go | 23 +++++++ driver/common/const.go | 8 +++ driver/wrapper/taosc.go | 13 ++++ driver/wrapper/taosc_test.go | 108 +++++++++++++++++++++++++++++++ plugin/influxdb/plugin.go | 10 +++ plugin/influxdb/plugin_test.go | 8 +-- plugin/opentsdb/plugin.go | 17 +++++ plugin/opentsdb/plugin_test.go | 4 +- tools/ctools/block.go | 27 ++++++-- tools/ctools/block_test.go | 19 +----- 23 files changed, 721 insertions(+), 117 deletions(-) diff --git a/controller/rest/restful.go b/controller/rest/restful.go index 8b56bd32..36bae648 100644 --- a/controller/rest/restful.go +++ b/controller/rest/restful.go @@ -17,6 +17,7 @@ import ( "github.com/taosdata/taosadapter/v3/controller" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/commonpool" + "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/driver/common" "github.com/taosdata/taosadapter/v3/driver/common/parser" tErrors "github.com/taosdata/taosadapter/v3/driver/errors" @@ -31,7 +32,6 @@ import ( "github.com/taosdata/taosadapter/v3/tools/generator" "github.com/taosdata/taosadapter/v3/tools/iptool" "github.com/taosdata/taosadapter/v3/tools/jsonbuilder" - "github.com/taosdata/taosadapter/v3/tools/layout" "github.com/taosdata/taosadapter/v3/tools/pool" "github.com/taosdata/taosadapter/v3/tools/sqltype" ) @@ -112,19 +112,30 @@ type TDEngineRestfulRespDoc struct { // @Router /rest/sql [post] func (ctl *Restful) sql(c *gin.Context) { db := c.Param("db") - timeZone := c.Query("tz") - location := time.UTC var err error logger := c.MustGet(LoggerKey).(*logrus.Entry) reqID := c.MustGet(config.ReqIDKey).(int64) - if len(timeZone) != 0 { - location, err = time.LoadLocation(timeZone) + connTimezone, exists := c.GetQuery("conn_tz") + location := time.UTC + if exists { + location, err = time.LoadLocation(connTimezone) if err != nil { - logger.Errorf("load location:%s error:%s", timeZone, err) + logger.Errorf("load conn_tz location:%s fail, error:%s", connTimezone, err) BadRequestResponseWithMsg(c, logger, 0xffff, err.Error()) return } + } else { + timezone, exists := c.GetQuery("tz") + if exists { + location, err = time.LoadLocation(timezone) + if err != nil { + logger.Errorf("load tz location:%s fail, error:%s", timezone, err) + BadRequestResponseWithMsg(c, logger, 0xffff, err.Error()) + return + } + } } + var returnObj bool if returnObjStr := c.Query("row_with_meta"); len(returnObjStr) != 0 { if returnObj, err = strconv.ParseBool(returnObjStr); err != nil { @@ -134,21 +145,7 @@ func (ctl *Restful) sql(c *gin.Context) { } } - timeBuffer := make([]byte, 0, 30) - DoQuery(c, db, func(builder *jsonbuilder.Stream, ts int64, precision int) { - timeBuffer = timeBuffer[:0] - switch precision { - case common.PrecisionMilliSecond: // milli-second - timeBuffer = time.Unix(ts/1e3, (ts%1e3)*1e6).In(location).AppendFormat(timeBuffer, layout.LayoutMillSecond) - case common.PrecisionMicroSecond: // micro-second - timeBuffer = time.Unix(ts/1e6, (ts%1e6)*1e3).In(location).AppendFormat(timeBuffer, layout.LayoutMicroSecond) - case common.PrecisionNanoSecond: // nano-second - timeBuffer = time.Unix(0, ts).In(location).AppendFormat(timeBuffer, layout.LayoutNanoSecond) - default: - logger.Errorf("unknown precision:%d", precision) - } - builder.WriteString(string(timeBuffer)) - }, reqID, returnObj, logger) + DoQuery(c, db, location, reqID, returnObj, logger) } type TDEngineRestfulResp struct { @@ -159,7 +156,7 @@ type TDEngineRestfulResp struct { Rows int `json:"rows,omitempty"` } -func DoQuery(c *gin.Context, db string, timeFunc ctools.FormatTimeFunc, reqID int64, returnObj bool, logger *logrus.Entry) { +func DoQuery(c *gin.Context, db string, location *time.Location, reqID int64, returnObj bool, logger *logrus.Entry) { var s time.Time isDebug := log.IsDebug() b, err := c.GetRawData() @@ -218,14 +215,37 @@ func DoQuery(c *gin.Context, db string, timeFunc ctools.FormatTimeFunc, reqID in } logger.Debugf("put connection finish, cost:%s", log.GetLogDuration(isDebug, s)) }() - + // set connection options + success := trySetConnectionOptions(c, taosConnect.TaosConnection, logger, isDebug) + if !success { + monitor.RestRecordResult(sqlType, false) + return + } if len(db) > 0 { // Attempt to select the database does not return even if there is an error // To avoid error reporting in the `create database` statement logger.Tracef("select db %s", db) _ = async.GlobalAsync.TaosExecWithoutResult(taosConnect.TaosConnection, logger, isDebug, fmt.Sprintf("use `%s`", db), reqID) } - execute(c, logger, isDebug, taosConnect.TaosConnection, sql, timeFunc, reqID, sqlType, returnObj) + execute(c, logger, isDebug, taosConnect.TaosConnection, sql, reqID, sqlType, returnObj, location) +} + +func trySetConnectionOptions(c *gin.Context, conn unsafe.Pointer, logger *logrus.Entry, isDebug bool) bool { + keys := [3]string{"conn_tz", "app", "ip"} + options := [3]int{common.TSDB_OPTION_CONNECTION_TIMEZONE, common.TSDB_OPTION_CONNECTION_USER_APP, common.TSDB_OPTION_CONNECTION_USER_IP} + for i := 0; i < 3; i++ { + val := c.Query(keys[i]) + if val != "" { + code := syncinterface.TaosOptionsConnection(conn, options[i], &val, logger, isDebug) + if code != httperror.SUCCESS { + errStr := wrapper.TaosErrorStr(conn) + logger.Errorf("set connection options error, option:%d, val:%s, code:%d, message:%s", options[i], val, code, errStr) + TaosErrorResponse(c, logger, code, errStr) + return false + } + } + } + return true } var ( @@ -241,7 +261,7 @@ var ( Timing = []byte(`,"timing":`) ) -func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect unsafe.Pointer, sql string, timeFormat ctools.FormatTimeFunc, reqID int64, sqlType sqltype.SqlType, returnObj bool) { +func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect unsafe.Pointer, sql string, reqID int64, sqlType sqltype.SqlType, returnObj bool, location *time.Location) { _, calculateTiming := c.Get(RequireTiming) st := c.MustGet(StartTimeKey) flushTiming := int64(0) @@ -364,6 +384,7 @@ func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect uns fetched := false pHeaderList := make([]unsafe.Pointer, fieldsCount) pStartList := make([]unsafe.Pointer, fieldsCount) + timeBuffer := make([]byte, 0, 30) for { if config.Conf.RestfulRowLimit > -1 && total == config.Conf.RestfulRowLimit { break @@ -412,7 +433,7 @@ func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect uns if returnObj { builder.WriteObjectField(rowsHeader.ColNames[column]) } - ctools.JsonWriteRawBlock(builder, rowsHeader.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, timeFormat) + ctools.JsonWriteRawBlock(builder, rowsHeader.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, location, timeBuffer, logger) if column != fieldsCount-1 { builder.WriteMore() } diff --git a/controller/rest/restful_test.go b/controller/rest/restful_test.go index 39a54d50..8212a6d3 100644 --- a/controller/rest/restful_test.go +++ b/controller/rest/restful_test.go @@ -21,6 +21,7 @@ import ( "github.com/taosdata/taosadapter/v3/db" "github.com/taosdata/taosadapter/v3/httperror" "github.com/taosdata/taosadapter/v3/log" + "github.com/taosdata/taosadapter/v3/tools/layout" ) var router *gin.Engine @@ -674,3 +675,63 @@ func TestInternalError(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) } + +func TestSetConnectionOptions(t *testing.T) { + config.Conf.RestfulRowLimit = -1 + w := httptest.NewRecorder() + body := strings.NewReader("create database if not exists rest_test_options") + url := "/rest/sql?app=rest_test_options&ip=192.168.100.1&conn_tz=Europe/Moscow&tz=Asia/Shanghai" + req, _ := http.NewRequest(http.MethodPost, url, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") + router.ServeHTTP(w, req) + checkResp(t, w) + + defer func() { + w = httptest.NewRecorder() + body = strings.NewReader("drop database if exists rest_test_options") + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + checkResp(t, w) + }() + + w = httptest.NewRecorder() + body = strings.NewReader("create table if not exists rest_test_options.t1(ts timestamp,v1 bool)") + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + checkResp(t, w) + + w = httptest.NewRecorder() + ts := "2024-12-04 12:34:56.789" + body = strings.NewReader(fmt.Sprintf(`insert into rest_test_options.t1 values ('%s',true)`, ts)) + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + checkResp(t, w) + + w = httptest.NewRecorder() + body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + var result TDEngineRestfulRespDoc + err := json.Unmarshal(w.Body.Bytes(), &result) + assert.NoError(t, err) + assert.Equal(t, 0, result.Code) + assert.Equal(t, 1, len(result.Data)) + + location, err := time.LoadLocation("Europe/Moscow") + assert.NoError(t, err) + expectTime, err := time.ParseInLocation("2006-01-02 15:04:05.000", ts, location) + assert.NoError(t, err) + expectTimeStr := expectTime.Format(layout.LayoutMillSecond) + assert.Equal(t, expectTimeStr, result.Data[0][0]) + t.Log(expectTimeStr, result.Data[0][0]) +} + +func checkResp(t *testing.T, w *httptest.ResponseRecorder) { + assert.Equal(t, 200, w.Code) + var result TDEngineRestfulRespDoc + err := json.Unmarshal(w.Body.Bytes(), &result) + assert.NoError(t, err) + assert.Equal(t, 0, result.Code) +} diff --git a/controller/ws/tmq/tmq.go b/controller/ws/tmq/tmq.go index cdbcb18e..e3ae3c2d 100644 --- a/controller/ws/tmq/tmq.go +++ b/controller/ws/tmq/tmq.go @@ -4,6 +4,8 @@ import ( "context" "encoding/binary" "encoding/json" + "errors" + "fmt" "net" "strconv" "sync" @@ -395,6 +397,9 @@ type TMQSubscribeReq struct { MsgConsumeExcluded string `json:"msg_consume_excluded"` SessionTimeoutMS string `json:"session_timeout_ms"` MaxPollIntervalMS string `json:"max_poll_interval_ms"` + TZ string `json:"tz"` + App string `json:"app"` + IP string `json:"ip"` } type TMQSubscribeResp struct { @@ -539,55 +544,77 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu logger.Tracef("tmq append topic:%s", topic) errCode = wrapper.TMQListAppend(topicList, topic) if errCode != 0 { - errStr := wrapper.TMQErr2Str(errCode) - logger.Errorf("tmq list append error, tpic:%s, code:%d, msg:%s", topic, errCode, errStr) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wsTMQErrorMsg(ctx, session, logger, int(errCode), errStr, action, req.ReqID, nil) + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(int(errCode), wrapper.TMQErr2Str(errCode)), fmt.Sprintf("tmq list append error, tpic:%s", topic)) return } } errCode = t.wrapperSubscribe(logger, isDebug, cPointer, topicList) if errCode != 0 { - errStr := wrapper.TMQErr2Str(errCode) - logger.Errorf("tmq subscribe error:%d %s", errCode, errStr) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wsTMQErrorMsg(ctx, session, logger, int(errCode), errStr, action, req.ReqID, nil) + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(int(errCode), wrapper.TMQErr2Str(errCode)), "tmq subscribe error") return } conn := wrapper.TMQGetConnect(cPointer) logger.Trace("get whitelist") whitelist, err := tool.GetWhitelist(conn) if err != nil { - logger.Errorf("get whitelist error:%s", err.Error()) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wstool.WSError(ctx, session, logger, err, action, req.ReqID) + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, err, "get whitelist error") return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) valid := tool.CheckWhitelist(whitelist, t.ip) if !valid { - logger.Errorf("whitelist prohibits current IP access, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wstool.WSErrorMsg(ctx, session, logger, 0xffff, "whitelist prohibits current IP access", action, req.ReqID) + errorExt := fmt.Sprintf("whitelist prohibits current IP access, ip:%s, whitelist:%s", t.ipStr, tool.IpNetSliceToString(whitelist)) + err = errors.New("whitelist prohibits current IP access") + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, err, errorExt) return } logger.Trace("register change whitelist") err = tool.RegisterChangeWhitelist(conn, t.whitelistChangeHandle) if err != nil { - logger.Errorf("register change whitelist error:%s", err) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wstool.WSError(ctx, session, logger, err, action, req.ReqID) + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, err, "register change whitelist error") return } logger.Trace("register drop user") err = tool.RegisterDropUser(conn, t.dropUserHandle) if err != nil { - logger.Errorf("register drop user error:%s", err) - t.closeConsumerWithErrLog(logger, isDebug, cPointer) - wstool.WSError(ctx, session, logger, err, action, req.ReqID) + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, err, "register drop user error") + return + } + + // set connection ip + clientIP := t.ipStr + if req.IP != "" { + clientIP = req.IP + } + logger.Tracef("set connection ip, ip:%s", clientIP) + code := syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_IP, &clientIP, logger, isDebug) + logger.Trace("set connection ip done") + if code != 0 { + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set connection ip error") return } + // set timezone + if req.TZ != "" { + logger.Tracef("set timezone, tz:%s", req.TZ) + code = syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_TIMEZONE, &req.TZ, logger, isDebug) + logger.Trace("set timezone done") + if code != 0 { + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set timezone error") + return + } + } + // set connection app + if req.App != "" { + logger.Tracef("set app, app:%s", req.App) + code = syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &req.App, logger, isDebug) + logger.Trace("set app done") + if code != 0 { + t.closeConsumerWithErrLog(ctx, cPointer, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set app error") + return + } + } + t.conn = conn t.consumer = cPointer logger.Trace("start to wait signal") @@ -599,12 +626,24 @@ func (t *TMQ) subscribe(ctx context.Context, session *melody.Session, req *TMQSu }) } -func (t *TMQ) closeConsumerWithErrLog(logger *logrus.Entry, isDebug bool, consumer unsafe.Pointer) { +func (t *TMQ) closeConsumerWithErrLog( + ctx context.Context, + consumer unsafe.Pointer, + session *melody.Session, + logger *logrus.Entry, + isDebug bool, + action string, + reqID uint64, + err error, + errorExt string, +) { + logger.Errorf("%s, err: %s", errorExt, err) errCode := t.wrapperCloseConsumer(logger, isDebug, consumer) if errCode != 0 { errMsg := wrapper.TMQErr2Str(errCode) logger.Errorf("tmq close consumer error, consumer:%p, code:%d, msg:%s", t.consumer, errCode, errMsg) } + wstool.WSError(ctx, session, logger, err, action, reqID) } type TMQCommitReq struct { diff --git a/controller/ws/tmq/tmq_test.go b/controller/ws/tmq/tmq_test.go index 69fabf70..ca4c2d7e 100644 --- a/controller/ws/tmq/tmq_test.go +++ b/controller/ws/tmq/tmq_test.go @@ -3206,3 +3206,91 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { // resp, err := doWebSocket(ws, wstool.ClientVersion, nil) // assert.Error(t, err, string(resp)) //} + +type httpQueryResp struct { + Code int `json:"code,omitempty"` + Desc string `json:"desc,omitempty"` + ColumnMeta [][]driver.Value `json:"column_meta,omitempty"` + Data [][]driver.Value `json:"data,omitempty"` + Rows int `json:"rows,omitempty"` +} + +func restQuery(sql string, db string) *httpQueryResp { + w := httptest.NewRecorder() + body := strings.NewReader(sql) + url := "/rest/sql" + if db != "" { + url = fmt.Sprintf("/rest/sql/%s", db) + } + req, _ := http.NewRequest(http.MethodPost, url, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + return &httpQueryResp{ + Code: w.Code, + Desc: w.Body.String(), + } + } + b, _ := io.ReadAll(w.Body) + var res httpQueryResp + _ = json.Unmarshal(b, &res) + return &res +} + +func TestConnectionOptions(t *testing.T) { + dbName := "test_ws_tmq_conn_options" + topic := "test_ws_tmq_conn_options_topic" + + before(t, dbName, topic) + + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/tmq", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + defer func() { + err = after(ws, dbName, topic) + assert.NoError(t, err) + }() + + // subscribe + b, _ := json.Marshal(TMQSubscribeReq{ + User: "root", + Password: "taosdata", + DB: dbName, + GroupID: "test", + Topics: []string{topic}, + AutoCommit: "false", + OffsetReset: "earliest", + SessionTimeoutMS: "100000", + App: "tmq_test_conn_protocol", + IP: "192.168.55.55", + TZ: "Asia/Shanghai", + }) + msg, err := doWebSocket(ws, TMQSubscribe, b) + assert.NoError(t, err) + var subscribeResp TMQSubscribeResp + err = json.Unmarshal(msg, &subscribeResp) + assert.NoError(t, err) + assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message) + + // check connection options + got := false + for i := 0; i < 10; i++ { + queryResp := restQuery("select conn_id from performance_schema.perf_connections where user_app = 'tmq_test_conn_protocol' and user_ip = '192.168.55.55'", "") + if queryResp.Code == 0 && len(queryResp.Data) > 0 { + got = true + break + } + time.Sleep(time.Second) + } + assert.True(t, got) +} diff --git a/controller/ws/ws/const.go b/controller/ws/ws/const.go index a1951a7b..d3f2254a 100644 --- a/controller/ws/ws/const.go +++ b/controller/ws/ws/const.go @@ -42,6 +42,9 @@ const ( STMT2Exec = "stmt2_exec" STMT2Result = "stmt2_result" STMT2Close = "stmt2_close" + + // options + OptionsConnection = "options_connection" ) const ( diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 9f8ff5a6..d9f3c301 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -606,6 +606,20 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { config.ReqIDKey: req.ReqID, }) h.getServerInfo(ctx, session, action, req, logger, log.IsDebug()) + case OptionsConnection: + action = OptionsConnection + var req optionsConnectionRequest + if err := json.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal options connection request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal options connection request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.optionsConnection(ctx, session, action, req, logger, log.IsDebug()) default: h.logger.Errorf("unknown action %s", action) reqID := getReqID(request.Args) diff --git a/controller/ws/ws/handler_test.go b/controller/ws/ws/handler_test.go index dbd232fe..00ac96b4 100644 --- a/controller/ws/ws/handler_test.go +++ b/controller/ws/ws/handler_test.go @@ -264,6 +264,12 @@ func Test_WrongJsonProtocol(t *testing.T) { args: "wrong", errorPrefix: "unmarshal get server info request error", }, + { + name: "options connection with wrong args", + action: OptionsConnection, + args: "wrong", + errorPrefix: "unmarshal options connection request error", + }, { name: "unknown action", action: "unknown", diff --git a/controller/ws/ws/misc.go b/controller/ws/ws/misc.go index da4e6342..141d873f 100644 --- a/controller/ws/ws/misc.go +++ b/controller/ws/ws/misc.go @@ -7,6 +7,7 @@ import ( "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/syncinterface" errors2 "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/tools/melody" ) @@ -65,3 +66,30 @@ func (h *messageHandler) getServerInfo(ctx context.Context, session *melody.Sess } wstool.WSWriteJson(session, logger, resp) } + +type optionsConnectionRequest struct { + ReqID uint64 `json:"req_id"` + Options []*option `json:"options"` +} +type option struct { + Option int `json:"option"` + Value *string `json:"value"` +} + +func (h *messageHandler) optionsConnection(ctx context.Context, session *melody.Session, action string, req optionsConnectionRequest, logger *logrus.Entry, isDebug bool) { + logger.Trace("options connection") + for i := 0; i < len(req.Options); i++ { + code := syncinterface.TaosOptionsConnection(h.conn, req.Options[i].Option, req.Options[i].Value, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosErrorStr(nil) + val := "" + if req.Options[i].Value != nil { + val = *req.Options[i].Value + } + logger.Errorf("options connection error, option:%d, value:%s, code:%d, err:%s", req.Options[i].Option, val, code, errStr) + commonErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr) + return + } + } + commonSuccessResponse(ctx, session, logger, action, req.ReqID) +} diff --git a/controller/ws/ws/misc_test.go b/controller/ws/ws/misc_test.go index a2bafdf4..d605326b 100644 --- a/controller/ws/ws/misc_test.go +++ b/controller/ws/ws/misc_test.go @@ -6,9 +6,11 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" ) func TestGetCurrentDB(t *testing.T) { @@ -88,3 +90,95 @@ func TestGetServerInfo(t *testing.T) { assert.Equal(t, 0, serverInfoResp.Code, serverInfoResp.Message) t.Log(serverInfoResp.Info) } + +func TestOptionsConnection(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + // connect + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + + // set app name + app := "ws_test_options" + optionsConnectionReq := optionsConnectionRequest{ + ReqID: 2, + Options: []*option{ + {Option: common.TSDB_OPTION_CONNECTION_USER_APP, Value: &app}, + }, + } + resp, err = doWebSocket(ws, OptionsConnection, &optionsConnectionReq) + assert.NoError(t, err) + var optionsConnectionResp commonResp + err = json.Unmarshal(resp, &optionsConnectionResp) + assert.NoError(t, err) + assert.Equal(t, uint64(2), optionsConnectionResp.ReqID) + assert.Equal(t, 0, optionsConnectionResp.Code, optionsConnectionResp.Message) + + // get app name + got := false + for i := 0; i < 10; i++ { + queryResp := restQuery("select conn_id from performance_schema.perf_connections where user_app = 'ws_test_options'", "") + if queryResp.Code == 0 && len(queryResp.Data) > 0 { + got = true + break + } + time.Sleep(time.Second) + } + assert.True(t, got) + // clear app name + optionsConnectionReq = optionsConnectionRequest{ + ReqID: 3, + Options: []*option{ + {Option: common.TSDB_OPTION_CONNECTION_USER_APP, Value: nil}, + }, + } + resp, err = doWebSocket(ws, OptionsConnection, &optionsConnectionReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &optionsConnectionResp) + assert.NoError(t, err) + assert.Equal(t, uint64(3), optionsConnectionResp.ReqID) + assert.Equal(t, 0, optionsConnectionResp.Code, optionsConnectionResp.Message) + + // wrong option with nil value + optionsConnectionReq = optionsConnectionRequest{ + ReqID: 4, + Options: []*option{ + {Option: -10000, Value: nil}, + }, + } + resp, err = doWebSocket(ws, OptionsConnection, &optionsConnectionReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &optionsConnectionResp) + assert.NoError(t, err) + assert.Equal(t, uint64(4), optionsConnectionResp.ReqID) + assert.NotEqual(t, 0, optionsConnectionResp.Code) + // wrong option with non-nil value + optionsConnectionReq = optionsConnectionRequest{ + ReqID: 5, + Options: []*option{ + {Option: -10000, Value: &app}, + }, + } + resp, err = doWebSocket(ws, OptionsConnection, &optionsConnectionReq) + assert.NoError(t, err) + err = json.Unmarshal(resp, &optionsConnectionResp) + assert.NoError(t, err) + assert.Equal(t, uint64(5), optionsConnectionResp.ReqID) + assert.NotEqual(t, 0, optionsConnectionResp.Code) +} diff --git a/controller/ws/ws/query.go b/controller/ws/ws/query.go index 22162c81..7c71fb94 100644 --- a/controller/ws/ws/query.go +++ b/controller/ws/ws/query.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "unsafe" "github.com/sirupsen/logrus" "github.com/taosdata/taosadapter/v3/controller/ws/wstool" @@ -12,7 +13,7 @@ import ( "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" "github.com/taosdata/taosadapter/v3/driver/common" - errors2 "github.com/taosdata/taosadapter/v3/driver/errors" + taoserrors "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" @@ -27,6 +28,9 @@ type connRequest struct { Password string `json:"password"` DB string `json:"db"` Mode *int `json:"mode"` + TZ string `json:"tz"` + App string `json:"app"` + IP string `json:"ip"` } func (h *messageHandler) connect(ctx context.Context, session *melody.Session, action string, req connRequest, logger *logrus.Entry, isDebug bool) { @@ -45,10 +49,7 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a conn, err := syncinterface.TaosConnect("", req.User, req.Password, req.DB, 0, logger, isDebug) if err != nil { - logger.Errorf("connect to TDengine error, err:%s", err) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, "connect to TDengine error") return } logger.Trace("get whitelist") @@ -56,19 +57,14 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a whitelist, err := tool.GetWhitelist(conn) logger.Debugf("get whitelist cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { - logger.Errorf("get whitelist error, err:%s", err) - syncinterface.TaosClose(conn, logger, isDebug) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, "get whitelist error") return } logger.Tracef("check whitelist, ip:%s, whitelist:%s", h.ipStr, tool.IpNetSliceToString(whitelist)) valid := tool.CheckWhitelist(whitelist, h.ip) if !valid { - logger.Errorf("ip not in whitelist, ip:%s, whitelist:%s", h.ipStr, tool.IpNetSliceToString(whitelist)) - syncinterface.TaosClose(conn, logger, isDebug) - commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, "whitelist prohibits current IP access") + err = errors.New("ip not in whitelist") + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, "ip not in whitelist") return } s = log.GetLogNow(isDebug) @@ -76,11 +72,7 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a err = tool.RegisterChangeWhitelist(conn, h.whitelistChangeHandle) logger.Debugf("register whitelist change cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { - logger.Errorf("register whitelist change error, err:%s", err) - syncinterface.TaosClose(conn, logger, isDebug) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, "register whitelist change error") return } s = log.GetLogNow(isDebug) @@ -88,11 +80,7 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a err = tool.RegisterDropUser(conn, h.dropUserHandle) logger.Debugf("register drop user cost:%s", log.GetLogDuration(isDebug, s)) if err != nil { - logger.Errorf("register drop user error, err:%s", err) - syncinterface.TaosClose(conn, logger, isDebug) - var taosErr *errors2.TaosError - errors.As(err, &taosErr) - commonErrorResponse(ctx, session, logger, action, req.ReqID, int(taosErr.Code), taosErr.ErrStr) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, "register drop user error") return } if req.Mode != nil { @@ -103,15 +91,44 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a code := wrapper.TaosSetConnMode(conn, common.TAOS_CONN_MODE_BI, 1) logger.Trace("set connection mode to BI done") if code != 0 { - logger.Errorf("set connection mode to BI error, err:%s", wrapper.TaosErrorStr(nil)) - syncinterface.TaosClose(conn, logger, isDebug) - commonErrorResponse(ctx, session, logger, action, req.ReqID, code, wrapper.TaosErrorStr(nil)) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set connection mode to BI error") return } default: - syncinterface.TaosClose(conn, logger, isDebug) - logger.Tracef("unexpected mode:%d", *req.Mode) - commonErrorResponse(ctx, session, logger, action, req.ReqID, 0xffff, fmt.Sprintf("unexpected mode:%d", *req.Mode)) + err = fmt.Errorf("unexpected mode:%d", *req.Mode) + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, err, err.Error()) + return + } + } + // set connection ip + clientIP := h.ipStr + if req.IP != "" { + clientIP = req.IP + } + logger.Tracef("set connection ip, ip:%s", clientIP) + code := syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_IP, &clientIP, logger, isDebug) + logger.Trace("set connection ip done") + if code != 0 { + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set connection ip error") + return + } + // set timezone + if req.TZ != "" { + logger.Tracef("set timezone, tz:%s", req.TZ) + code = syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_TIMEZONE, &req.TZ, logger, isDebug) + logger.Trace("set timezone done") + if code != 0 { + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set timezone error") + return + } + } + // set connection app + if req.App != "" { + logger.Tracef("set app, app:%s", req.App) + code = syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &req.App, logger, isDebug) + logger.Trace("set app done") + if code != 0 { + handleConnectError(ctx, conn, session, logger, isDebug, action, req.ReqID, taoserrors.NewError(code, wrapper.TaosErrorStr(nil)), "set app error") return } } @@ -121,6 +138,22 @@ func (h *messageHandler) connect(ctx context.Context, session *melody.Session, a commonSuccessResponse(ctx, session, logger, action, req.ReqID) } +func handleConnectError(ctx context.Context, conn unsafe.Pointer, session *melody.Session, logger *logrus.Entry, isDebug bool, action string, reqID uint64, err error, errorExt string) { + var code int + var errStr string + taosError, ok := err.(*taoserrors.TaosError) + if ok { + code = int(taosError.Code) + errStr = taosError.ErrStr + } else { + code = 0xffff + errStr = err.Error() + } + logger.Errorf("%s, code:%d, message:%s", errorExt, code, errStr) + syncinterface.TaosClose(conn, logger, isDebug) + commonErrorResponse(ctx, session, logger, action, reqID, code, errStr) +} + type queryRequest struct { ReqID uint64 `json:"req_id"` Sql string `json:"sql"` diff --git a/controller/ws/ws/query_test.go b/controller/ws/ws/query_test.go index 93c58bb4..a9a9009a 100644 --- a/controller/ws/ws/query_test.go +++ b/controller/ws/ws/query_test.go @@ -62,14 +62,6 @@ func TestWSConnect(t *testing.T) { assert.Equal(t, "duplicate connections", connResp.Message) } -type TestConnRequest struct { - ReqID uint64 `json:"req_id"` - User string `json:"user"` - Password string `json:"password"` - DB string `json:"db"` - Mode int `json:"mode"` -} - func TestMode(t *testing.T) { s := httptest.NewServer(router) defer s.Close() @@ -84,7 +76,7 @@ func TestMode(t *testing.T) { }() wrongMode := 999 - connReq := TestConnRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: wrongMode} + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: &wrongMode} resp, err := doWebSocket(ws, Connect, &connReq) assert.NoError(t, err) var connResp commonResp @@ -96,7 +88,7 @@ func TestMode(t *testing.T) { //bi biMode := 0 - connReq = TestConnRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: biMode} + connReq = connRequest{ReqID: 1, User: "root", Password: "taosdata", Mode: &biMode} resp, err = doWebSocket(ws, Connect, &connReq) assert.NoError(t, err) err = json.Unmarshal(resp, &connResp) @@ -106,9 +98,38 @@ func TestMode(t *testing.T) { } -func TestWrongConnect(t *testing.T) { - // mock tool.GetWhitelist return error +func TestConnectionOptions(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/ws", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + connReq := connRequest{ReqID: 1, User: "root", Password: "taosdata", IP: "192.168.44.55", App: "ws_test_conn_protocol", TZ: "Asia/Shanghai"} + resp, err := doWebSocket(ws, Connect, &connReq) + assert.NoError(t, err) + var connResp commonResp + err = json.Unmarshal(resp, &connResp) + assert.NoError(t, err) + assert.Equal(t, uint64(1), connResp.ReqID) + assert.Equal(t, 0, connResp.Code, connResp.Message) + // check connection options + got := false + for i := 0; i < 10; i++ { + queryResp := restQuery("select conn_id from performance_schema.perf_connections where user_app = 'ws_test_conn_protocol' and user_ip = '192.168.44.55'", "") + if queryResp.Code == 0 && len(queryResp.Data) > 0 { + got = true + break + } + time.Sleep(time.Second) + } + assert.True(t, got) } func TestWsQuery(t *testing.T) { diff --git a/db/commonpool/pool.go b/db/commonpool/pool.go index 9af5d545..775c91c1 100644 --- a/db/commonpool/pool.go +++ b/db/commonpool/pool.go @@ -14,6 +14,7 @@ import ( "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/syncinterface" "github.com/taosdata/taosadapter/v3/db/tool" + "github.com/taosdata/taosadapter/v3/driver/common" tErrors "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/driver/wrapper/cgo" @@ -196,6 +197,7 @@ func (cp *ConnectorPool) Get() (unsafe.Pointer, error) { func (cp *ConnectorPool) Put(c unsafe.Pointer) error { wrapper.TaosResetCurrentDB(c) + wrapper.TaosOptionsConnection(c, common.TSDB_OPTION_CONNECTION_CLEAR, nil) return cp.pool.Put(c) } @@ -308,6 +310,9 @@ func getConnectDirect(connectionPool *ConnectorPool, clientIP net.IP) (*Conn, er if err != nil { return nil, err } + ipStr := clientIP.String() + // ignore error, because we have checked the ip + wrapper.TaosOptionsConnection(c, common.TSDB_OPTION_CONNECTION_USER_IP, &ipStr) return &Conn{ TaosConnection: c, pool: connectionPool, diff --git a/db/syncinterface/wrapper.go b/db/syncinterface/wrapper.go index dd9c6222..ae189f17 100644 --- a/db/syncinterface/wrapper.go +++ b/db/syncinterface/wrapper.go @@ -421,3 +421,15 @@ func TaosStmt2GetStbFields(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug b thread.SyncLocker.Unlock() return code, count, fields } + +func TaosOptionsConnection(conn unsafe.Pointer, option int, value *string, logger *logrus.Entry, isDebug bool) int { + if value == nil { + logger.Tracef("call taos_options_connection, conn:%p, option:%d, value:", conn, option) + } else { + logger.Tracef("call taos_options_connection, conn:%p, option:%d, value:%s", conn, option, *value) + } + s := log.GetLogNow(isDebug) + code := wrapper.TaosOptionsConnection(conn, option, value) + logger.Debugf("taos_options_connection finish, code:%d, cost:%s", code, log.GetLogDuration(isDebug, s)) + return code +} diff --git a/db/syncinterface/wrapper_test.go b/db/syncinterface/wrapper_test.go index 9917b7e8..ddd35808 100644 --- a/db/syncinterface/wrapper_test.go +++ b/db/syncinterface/wrapper_test.go @@ -641,3 +641,26 @@ func query(conn unsafe.Pointer, sql string) ([][]driver.Value, error) { } return result, nil } + +func TestTaosOptionsConnection(t *testing.T) { + reqID := generator.GetReqID() + var logger = logger.WithField("test", "TestTaosOptionsConnection").WithField(config.ReqIDKey, reqID) + conn, err := TaosConnect("", "root", "taosdata", "", 0, logger, isDebug) + if !assert.NoError(t, err) { + return + } + defer TaosClose(conn, logger, isDebug) + app := "test_sync_interface" + code := TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &app, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosErrorStr(nil) + t.Error(t, taoserrors.NewError(code, errStr)) + return + } + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, nil, logger, isDebug) + if code != 0 { + errStr := wrapper.TaosErrorStr(nil) + t.Error(t, taoserrors.NewError(code, errStr)) + return + } +} diff --git a/driver/common/const.go b/driver/common/const.go index 49efee48..ae97d949 100644 --- a/driver/common/const.go +++ b/driver/common/const.go @@ -24,6 +24,14 @@ const ( TSDB_OPTION_USE_ADAPTER ) +const ( + TSDB_OPTION_CONNECTION_CLEAR = iota - 1 + TSDB_OPTION_CONNECTION_CHARSET + TSDB_OPTION_CONNECTION_TIMEZONE + TSDB_OPTION_CONNECTION_USER_IP + TSDB_OPTION_CONNECTION_USER_APP +) + const ( TMQ_RES_INVALID = -1 TMQ_RES_DATA = 1 diff --git a/driver/wrapper/taosc.go b/driver/wrapper/taosc.go index e4ccb12b..de3f62ea 100644 --- a/driver/wrapper/taosc.go +++ b/driver/wrapper/taosc.go @@ -27,6 +27,9 @@ void taos_query_a_with_req_id_wrapper(TAOS *taos,const char *sql, void *param, i void taos_fetch_raw_block_a_wrapper(TAOS_RES *res, void *param){ return taos_fetch_raw_block_a(res,FetchRawBlockCallback,param); }; +int taos_options_connection_wrapper(TAOS *taos, TSDB_OPTION_CONNECTION option, void *arg) { + return taos_options_connection(taos,option,arg); +}; */ import "C" import ( @@ -287,3 +290,13 @@ func TaosGetServerInfo(conn unsafe.Pointer) string { info := C.taos_get_server_info(conn) return C.GoString(info) } + +// TaosOptionsConnection int taos_options_connection(TAOS *taos, TSDB_OPTION_CONNECTION option, const void *arg, ...) +func TaosOptionsConnection(conn unsafe.Pointer, option int, value *string) int { + cValue := unsafe.Pointer(nil) + if value != nil { + cValue = unsafe.Pointer(C.CString(*value)) + defer C.free(cValue) + } + return int(C.taos_options_connection_wrapper(conn, (C.TSDB_OPTION_CONNECTION)(option), cValue)) +} diff --git a/driver/wrapper/taosc_test.go b/driver/wrapper/taosc_test.go index 7a2c268e..ec533323 100644 --- a/driver/wrapper/taosc_test.go +++ b/driver/wrapper/taosc_test.go @@ -605,3 +605,111 @@ func TestTaosGetServerInfo(t *testing.T) { info := TaosGetServerInfo(conn) assert.NotEmpty(t, info) } + +func TestTaosOptionsConnection(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + ip := "192.168.9.9" + app := "test_options_connection" + // set ip + code := TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_IP, &ip) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + // set app + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &app) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + var values [][]driver.Value + for i := 0; i < 10; i++ { + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9' and user_app = 'test_options_connection'") + assert.NoError(t, err) + if len(values) == 1 { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 1, len(values)) + connID := values[0][0].(uint32) + + // clean app + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, nil) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + for i := 0; i < 10; i++ { + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9' and user_app = 'test_options_connection'") + assert.NoError(t, err) + if len(values) == 0 { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 0, len(values)) + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9'") + assert.NoError(t, err) + assert.Equal(t, 1, len(values)) + assert.Equal(t, connID, values[0][0].(uint32)) + + // set app + app = "test_options_2" + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &app) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + for i := 0; i < 20; i++ { + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9' and user_app = 'test_options_2'") + assert.NoError(t, err) + if len(values) == 1 { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 1, len(values)) + assert.Equal(t, connID, values[0][0].(uint32)) + + // clear ip + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_IP, nil) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + for i := 0; i < 10; i++ { + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9' and user_app = 'test_options_2'") + assert.NoError(t, err) + if len(values) == 0 { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 0, len(values)) + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_app = 'test_options_2'") + assert.NoError(t, err) + assert.Equal(t, 1, len(values)) + assert.Equal(t, connID, values[0][0].(uint32)) + + // clean all + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_CLEAR, nil) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + for i := 0; i < 10; i++ { + values, err = query(conn, fmt.Sprintf("select user_app,user_ip from performance_schema.perf_connections where conn_id = %d", connID)) + assert.NoError(t, err) + if len(values) == 1 && values[0][0].(string) == "" && values[0][1].(string) == "" { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 1, len(values)) + assert.Equal(t, "", values[0][0].(string)) + assert.Equal(t, "", values[0][1].(string)) + + ip = "192.168.9.9" + app = "test_options_connection" + // set ip + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_IP, &ip) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + // set app + code = TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &app) + assert.Equal(t, 0, code, TaosErrorStr(nil)) + for i := 0; i < 10; i++ { + values, err = query(conn, "select conn_id from performance_schema.perf_connections where user_ip = '192.168.9.9' and user_app = 'test_options_connection'") + assert.NoError(t, err) + if len(values) == 1 { + break + } + time.Sleep(time.Second) + } + assert.Equal(t, 1, len(values)) + assert.Equal(t, connID, values[0][0].(uint32)) +} diff --git a/plugin/influxdb/plugin.go b/plugin/influxdb/plugin.go index 3c1696ac..cfcd62ee 100644 --- a/plugin/influxdb/plugin.go +++ b/plugin/influxdb/plugin.go @@ -9,7 +9,10 @@ import ( "github.com/gin-gonic/gin" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/commonpool" + "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/driver/common" tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" "github.com/taosdata/taosadapter/v3/plugin" @@ -188,6 +191,13 @@ func (p *Influxdb) write(c *gin.Context) { } }() conn := taosConn.TaosConnection + app := c.Query("app") + if app != "" { + errCode := syncinterface.TaosOptionsConnection(conn, common.TSDB_OPTION_CONNECTION_USER_APP, &app, logger, isDebug) + if errCode != 0 { + logger.Errorf("set app error, app:%s, code:%d, msg:%s", app, errCode, wrapper.TaosErrorStr(nil)) + } + } s = log.GetLogNow(isDebug) logger.Tracef("start insert influxdb, data:%s", data) err = inserter.InsertInfluxdb(conn, data, db, precision, ttl, reqID, tableNameKey, logger) diff --git a/plugin/influxdb/plugin_test.go b/plugin/influxdb/plugin_test.go index 552f2468..1340b71b 100644 --- a/plugin/influxdb/plugin_test.go +++ b/plugin/influxdb/plugin_test.go @@ -61,19 +61,19 @@ func TestInfluxdb(t *testing.T) { }() w := httptest.NewRecorder() reader := strings.NewReader(fmt.Sprintf("measurement,host=host1 field1=%di,field2=2.0,fieldKey=\"Launch 🚀\" %d", number, time.Now().UnixNano())) - req, _ := http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb", reader) + req, _ := http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb&app=test_influxdb", reader) req.RemoteAddr = "127.0.0.1:33333" router.ServeHTTP(w, req) assert.Equal(t, 204, w.Code) w = httptest.NewRecorder() reader = strings.NewReader("measurement,host=host1 field1=a1") - req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb", reader) + req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb&app=test_influxdb", reader) req.RemoteAddr = "127.0.0.1:33333" router.ServeHTTP(w, req) assert.Equal(t, 500, w.Code) w = httptest.NewRecorder() reader = strings.NewReader(fmt.Sprintf("measurement,host=host1 field1=%di,field2=2.0,fieldKey=\"Launch 🚀\" %d", number, time.Now().UnixNano())) - req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata", reader) + req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata&app=test_influxdb", reader) req.RemoteAddr = "127.0.0.1:33333" router.ServeHTTP(w, req) assert.Equal(t, 400, w.Code) @@ -91,7 +91,7 @@ func TestInfluxdb(t *testing.T) { w = httptest.NewRecorder() reader = strings.NewReader(fmt.Sprintf("measurement_ttl,host=host1 field1=%di,field2=2.0,fieldKey=\"Launch 🚀\" %d", number, time.Now().UnixNano())) - req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb_ttl&ttl=1000", reader) + req, _ = http.NewRequest("POST", "/write?u=root&p=taosdata&db=test_plugin_influxdb_ttl&ttl=1000&app=test_influxdb", reader) req.RemoteAddr = "127.0.0.1:33333" router.ServeHTTP(w, req) time.Sleep(time.Second) diff --git a/plugin/opentsdb/plugin.go b/plugin/opentsdb/plugin.go index f22cfc35..a5229fac 100644 --- a/plugin/opentsdb/plugin.go +++ b/plugin/opentsdb/plugin.go @@ -11,7 +11,10 @@ import ( "github.com/gin-gonic/gin" "github.com/taosdata/taosadapter/v3/config" "github.com/taosdata/taosadapter/v3/db/commonpool" + "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/driver/common" tErrors "github.com/taosdata/taosadapter/v3/driver/errors" + "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" "github.com/taosdata/taosadapter/v3/monitor" "github.com/taosdata/taosadapter/v3/plugin" @@ -152,6 +155,13 @@ func (p *Plugin) insertJson(c *gin.Context) { logger.WithError(putErr).Errorln("connect pool put error") } }() + app := c.Query("app") + if app != "" { + errCode := syncinterface.TaosOptionsConnection(taosConn.TaosConnection, common.TSDB_OPTION_CONNECTION_USER_APP, &app, logger, isDebug) + if errCode != 0 { + logger.Errorf("set app error, app:%s, code:%d, msg:%s", app, errCode, wrapper.TaosErrorStr(nil)) + } + } s = log.GetLogNow(isDebug) logger.Debugf("insert json payload, data:%s, db:%s, ttl:%d, table_name_key:%s", data, db, ttl, tableNameKey) err = inserter.InsertOpentsdbJson(taosConn.TaosConnection, data, db, ttl, reqID, tableNameKey, logger) @@ -271,6 +281,13 @@ func (p *Plugin) insertTelnet(c *gin.Context) { logger.WithError(putErr).Errorln("connect pool put error") } }() + app := c.Query("app") + if app != "" { + errCode := syncinterface.TaosOptionsConnection(taosConn.TaosConnection, common.TSDB_OPTION_CONNECTION_USER_APP, &app, logger, isDebug) + if errCode != 0 { + logger.Errorf("set app error, app:%s, code:%d, msg:%s", app, errCode, wrapper.TaosErrorStr(nil)) + } + } s = log.GetLogNow(isDebug) logger.Debugf("insert telnet payload, lines:%v, db:%s, ttl:%d, table_name_key: %s", lines, db, ttl, tableNameKey) err = inserter.InsertOpentsdbTelnetBatch(taosConn.TaosConnection, lines, db, ttl, reqID, tableNameKey, logger) diff --git a/plugin/opentsdb/plugin_test.go b/plugin/opentsdb/plugin_test.go index a2101fa0..4317e885 100644 --- a/plugin/opentsdb/plugin_test.go +++ b/plugin/opentsdb/plugin_test.go @@ -58,7 +58,7 @@ func TestOpentsdb(t *testing.T) { }() w := httptest.NewRecorder() reader := strings.NewReader(fmt.Sprintf("put metric %d %d host=web01 interface=eth0 ", time.Now().Unix(), number)) - req, _ := http.NewRequest("POST", "/put/telnet/test_plugin_opentsdb_http_telnet?ttl=1000", reader) + req, _ := http.NewRequest("POST", "/put/telnet/test_plugin_opentsdb_http_telnet?ttl=1000&app=test_telnet_http", reader) req.RemoteAddr = "127.0.0.1:33333" req.SetBasicAuth("root", "taosdata") router.ServeHTTP(w, req) @@ -73,7 +73,7 @@ func TestOpentsdb(t *testing.T) { "dc": "lga" } }`, time.Now().Unix(), number)) - req, _ = http.NewRequest("POST", "/put/json/test_plugin_opentsdb_http_json?ttl=1000", reader) + req, _ = http.NewRequest("POST", "/put/json/test_plugin_opentsdb_http_json?ttl=1000&app=test_json_http", reader) req.RemoteAddr = "127.0.0.1:33333" req.SetBasicAuth("root", "taosdata") router.ServeHTTP(w, req) diff --git a/tools/ctools/block.go b/tools/ctools/block.go index 72f0f08f..d7abdc76 100644 --- a/tools/ctools/block.go +++ b/tools/ctools/block.go @@ -3,16 +3,18 @@ package ctools import ( "math" "strconv" + "time" "unsafe" + "github.com/sirupsen/logrus" "github.com/taosdata/taosadapter/v3/driver/common" "github.com/taosdata/taosadapter/v3/driver/common/parser" "github.com/taosdata/taosadapter/v3/tools" + "github.com/taosdata/taosadapter/v3/tools/bytesutil" "github.com/taosdata/taosadapter/v3/tools/jsonbuilder" + "github.com/taosdata/taosadapter/v3/tools/layout" ) -type FormatTimeFunc func(builder *jsonbuilder.Stream, ts int64, precision int) - func IsVarDataType(colType uint8) bool { return colType == common.TSDB_DATA_TYPE_BINARY || colType == common.TSDB_DATA_TYPE_NCHAR || colType == common.TSDB_DATA_TYPE_JSON || colType == common.TSDB_DATA_TYPE_VARBINARY || colType == common.TSDB_DATA_TYPE_GEOMETRY } @@ -87,9 +89,20 @@ func WriteRawJsonDouble(builder *jsonbuilder.Stream, pStart unsafe.Pointer, row builder.WriteFloat64(math.Float64frombits(*((*uint64)(tools.AddPointer(pStart, uintptr(row)*parser.Float64Size))))) } -func WriteRawJsonTime(builder *jsonbuilder.Stream, pStart unsafe.Pointer, row int, precision int, timeFormat FormatTimeFunc) { - value := *((*int64)(tools.AddPointer(pStart, uintptr(row)*parser.Int64Size))) - timeFormat(builder, value, precision) +func WriteRawJsonTime(builder *jsonbuilder.Stream, pStart unsafe.Pointer, row int, precision int, location *time.Location, timeBuffer []byte, logger *logrus.Entry) { + ts := *((*int64)(tools.AddPointer(pStart, uintptr(row)*parser.Int64Size))) + timeBuffer = timeBuffer[:0] + switch precision { + case common.PrecisionMilliSecond: // milli-second + timeBuffer = time.Unix(ts/1e3, (ts%1e3)*1e6).In(location).AppendFormat(timeBuffer, layout.LayoutMillSecond) + case common.PrecisionMicroSecond: // micro-second + timeBuffer = time.Unix(ts/1e6, (ts%1e6)*1e3).In(location).AppendFormat(timeBuffer, layout.LayoutMicroSecond) + case common.PrecisionNanoSecond: // nano-second + timeBuffer = time.Unix(0, ts).In(location).AppendFormat(timeBuffer, layout.LayoutNanoSecond) + default: + logger.Errorf("unknown precision:%d", precision) + } + builder.WriteString(bytesutil.ToUnsafeString(timeBuffer)) } func WriteRawJsonBinary(builder *jsonbuilder.Stream, pHeader, pStart unsafe.Pointer, row int) { @@ -167,7 +180,7 @@ func WriteRawJsonJson(builder *jsonbuilder.Stream, pHeader, pStart unsafe.Pointe } } -func JsonWriteRawBlock(builder *jsonbuilder.Stream, colType uint8, pHeader, pStart unsafe.Pointer, row int, precision int, timeFormat FormatTimeFunc) { +func JsonWriteRawBlock(builder *jsonbuilder.Stream, colType uint8, pHeader, pStart unsafe.Pointer, row int, precision int, location *time.Location, timeBuffer []byte, logger *logrus.Entry) { if IsVarDataType(colType) { switch colType { case uint8(common.TSDB_DATA_TYPE_BINARY): @@ -209,7 +222,7 @@ func JsonWriteRawBlock(builder *jsonbuilder.Stream, colType uint8, pHeader, pSta case uint8(common.TSDB_DATA_TYPE_DOUBLE): WriteRawJsonDouble(builder, pStart, row) case uint8(common.TSDB_DATA_TYPE_TIMESTAMP): - WriteRawJsonTime(builder, pStart, row, precision, timeFormat) + WriteRawJsonTime(builder, pStart, row, precision, location, timeBuffer, logger) } } } diff --git a/tools/ctools/block_test.go b/tools/ctools/block_test.go index 03390fba..f811d208 100644 --- a/tools/ctools/block_test.go +++ b/tools/ctools/block_test.go @@ -6,8 +6,8 @@ import ( "time" "unsafe" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "github.com/taosdata/taosadapter/v3/driver/common" "github.com/taosdata/taosadapter/v3/driver/common/parser" "github.com/taosdata/taosadapter/v3/tools" "github.com/taosdata/taosadapter/v3/tools/jsonbuilder" @@ -163,20 +163,7 @@ func TestJsonWriteRawBlock(t *testing.T) { for row := 0; row < blockSize; row++ { builder.WriteArrayStart() for column := 0; column < fieldsCount; column++ { - JsonWriteRawBlock(builder, fieldTypes[column], pHeaderList[column], pStartList[column], row, precision, func(builder *jsonbuilder.Stream, ts int64, precision int) { - timeBuffer = timeBuffer[:0] - switch precision { - case common.PrecisionMilliSecond: // milli-second - timeBuffer = time.Unix(0, ts*1e6).UTC().AppendFormat(timeBuffer, time.RFC3339Nano) - case common.PrecisionMicroSecond: // micro-second - timeBuffer = time.Unix(0, ts*1e3).UTC().AppendFormat(timeBuffer, time.RFC3339Nano) - case common.PrecisionNanoSecond: // nano-second - timeBuffer = time.Unix(0, ts).UTC().AppendFormat(timeBuffer, time.RFC3339Nano) - default: - panic("unknown precision") - } - builder.WriteString(string(timeBuffer)) - }) + JsonWriteRawBlock(builder, fieldTypes[column], pHeaderList[column], pStartList[column], row, precision, time.UTC, timeBuffer, logrus.New().WithField("test", "test")) if column != fieldsCount-1 { builder.WriteMore() err := builder.Flush() @@ -191,5 +178,5 @@ func TestJsonWriteRawBlock(t *testing.T) { builder.WriteObjectEnd() err := builder.Flush() assert.NoError(t, err) - assert.Equal(t, `{["2022-08-10T07:02:40.5Z",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["2022-08-10T07:02:41.5Z",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, w.String()) + assert.Equal(t, `{["2022-08-10T07:02:40.500Z",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["2022-08-10T07:02:41.500Z",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, w.String()) } From d701c43021262d4c6c8e8ecf4e405815d1a24a6d Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 10 Dec 2024 15:41:28 +0800 Subject: [PATCH 37/48] enh: stmt2 use result response change result_id to id and fix version action --- controller/ws/ws/handler.go | 14 +++++++++++++- controller/ws/ws/handler_test.go | 6 ++++++ controller/ws/ws/misc.go | 25 +++++++++++++++++++++++++ controller/ws/ws/stmt2.go | 4 ++-- controller/ws/ws/stmt2_test.go | 6 +++--- controller/ws/ws/ws_test.go | 7 +------ 6 files changed, 50 insertions(+), 12 deletions(-) diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 9f8ff5a6..388820c5 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -208,7 +208,19 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { // no need connection actions switch request.Action { case wstool.ClientVersion: - wstool.WSWriteVersion(session, h.logger) + action = wstool.ClientVersion + var req versionRequest + if err := json.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal version request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal version request error") + return + } + logger := h.logger.WithFields(logrus.Fields{ + actionKey: action, + config.ReqIDKey: req.ReqID, + }) + h.version(ctx, session, action, req, logger, log.IsDebug()) return case Connect: action = Connect diff --git a/controller/ws/ws/handler_test.go b/controller/ws/ws/handler_test.go index dbd232fe..f279c34b 100644 --- a/controller/ws/ws/handler_test.go +++ b/controller/ws/ws/handler_test.go @@ -102,6 +102,12 @@ func Test_WrongJsonProtocol(t *testing.T) { args: nil, errorPrefix: "request no action", }, + { + name: "version with wrong args", + action: wstool.ClientVersion, + args: "wrong", + errorPrefix: "unmarshal version request error", + }, { name: "connect with wrong args", action: Connect, diff --git a/controller/ws/ws/misc.go b/controller/ws/ws/misc.go index da4e6342..573f210f 100644 --- a/controller/ws/ws/misc.go +++ b/controller/ws/ws/misc.go @@ -8,6 +8,7 @@ import ( "github.com/taosdata/taosadapter/v3/db/syncinterface" errors2 "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/tools/melody" + "github.com/taosdata/taosadapter/v3/version" ) type getCurrentDBRequest struct { @@ -65,3 +66,27 @@ func (h *messageHandler) getServerInfo(ctx context.Context, session *melody.Sess } wstool.WSWriteJson(session, logger, resp) } + +type versionRequest struct { + ReqID uint64 `json:"req_id"` +} + +type versionResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + Version string `json:"version"` +} + +func (h *messageHandler) version(ctx context.Context, session *melody.Session, action string, req versionRequest, logger *logrus.Entry, isDebug bool) { + logger.Trace("get version") + resp := &versionResp{ + Action: action, + ReqID: req.ReqID, + Timing: wstool.GetDuration(ctx), + Version: version.TaosClientVersion, + } + wstool.WSWriteJson(session, logger, resp) +} diff --git a/controller/ws/ws/stmt2.go b/controller/ws/ws/stmt2.go index e2ff2d57..f7ccb5b4 100644 --- a/controller/ws/ws/stmt2.go +++ b/controller/ws/ws/stmt2.go @@ -200,7 +200,7 @@ type stmt2UseResultResponse struct { ReqID uint64 `json:"req_id"` Timing int64 `json:"timing"` StmtID uint64 `json:"stmt_id"` - ResultID uint64 `json:"result_id"` + ID uint64 `json:"id"` FieldsCount int `json:"fields_count"` FieldsNames []string `json:"fields_names"` FieldsTypes jsontype.JsonUint8 `json:"fields_types"` @@ -227,7 +227,7 @@ func (h *messageHandler) stmt2UseResult(ctx context.Context, session *melody.Ses ReqID: req.ReqID, Timing: wstool.GetDuration(ctx), StmtID: req.StmtID, - ResultID: idx, + ID: idx, FieldsCount: fieldsCount, FieldsNames: rowsHeader.ColNames, FieldsTypes: rowsHeader.ColTypes, diff --git a/controller/ws/ws/stmt2_test.go b/controller/ws/ws/stmt2_test.go index 55f0fc16..8486ec40 100644 --- a/controller/ws/ws/stmt2_test.go +++ b/controller/ws/ws/stmt2_test.go @@ -510,7 +510,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { assert.Equal(t, 0, useResultResp.Code, useResultResp.Message) // fetch - fetchReq := fetchRequest{ReqID: 8, ID: useResultResp.ResultID} + fetchReq := fetchRequest{ReqID: 8, ID: useResultResp.ID} resp, err = doWebSocket(ws, WSFetch, &fetchReq) assert.NoError(t, err) var fetchResp fetchResponse @@ -521,7 +521,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { assert.Equal(t, 1, fetchResp.Rows) // fetch block - fetchBlockReq := fetchBlockRequest{ReqID: 9, ID: useResultResp.ResultID} + fetchBlockReq := fetchBlockRequest{ReqID: 9, ID: useResultResp.ID} fetchBlockResp, err := doWebSocket(ws, WSFetchBlock, &fetchBlockReq) assert.NoError(t, err) _, blockResult := parseblock.ParseBlock(fetchBlockResp[8:], useResultResp.FieldsTypes, fetchResp.Rows, useResultResp.Precision) @@ -531,7 +531,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { assert.Equal(t, float32(0.31), blockResult[0][3]) // free result - freeResultReq, _ := json.Marshal(freeResultRequest{ReqID: 10, ID: useResultResp.ResultID}) + freeResultReq, _ := json.Marshal(freeResultRequest{ReqID: 10, ID: useResultResp.ID}) a, _ := json.Marshal(Request{Action: WSFreeResult, Args: freeResultReq}) err = ws.WriteMessage(websocket.TextMessage, a) assert.NoError(t, err) diff --git a/controller/ws/ws/ws_test.go b/controller/ws/ws/ws_test.go index cf67435e..bc9278ba 100644 --- a/controller/ws/ws/ws_test.go +++ b/controller/ws/ws/ws_test.go @@ -131,11 +131,6 @@ func doWebSocketWithoutResp(ws *websocket.Conn, action string, arg interface{}) return nil } -type versionResponse struct { - commonResp - Version string -} - func TestVersion(t *testing.T) { s := httptest.NewServer(router) defer s.Close() @@ -150,7 +145,7 @@ func TestVersion(t *testing.T) { }() resp, err := doWebSocket(ws, wstool.ClientVersion, nil) assert.NoError(t, err) - var versionResp versionResponse + var versionResp versionResp err = json.Unmarshal(resp, &versionResp) assert.NoError(t, err) assert.Equal(t, 0, versionResp.Code, versionResp.Message) From 3bd2d747c3b1cbdf9676051feffdff5114228b58 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 10 Dec 2024 16:50:43 +0800 Subject: [PATCH 38/48] enh: stmt2 bind binary check minimum length --- driver/wrapper/stmt2_test.go | 12 ++++++++++++ driver/wrapper/stmt2binary.go | 7 +++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/driver/wrapper/stmt2_test.go b/driver/wrapper/stmt2_test.go index d09382a1..d9652b12 100644 --- a/driver/wrapper/stmt2_test.go +++ b/driver/wrapper/stmt2_test.go @@ -4288,6 +4288,18 @@ func TestTaosStmt2BindBinaryParse(t *testing.T) { args args wantErr assert.ErrorAssertionFunc }{ + { + name: "wrong data length", + args: args{ + sql: "insert into ? values (?,?)", + data: []byte{ + // total Length + 0x00, 0x00, 0x00, 0x00, + }, + colIdx: -1, + }, + wantErr: assert.Error, + }, { name: "normal table name", args: args{ diff --git a/driver/wrapper/stmt2binary.go b/driver/wrapper/stmt2binary.go index e6b3b0ff..919ff7bc 100644 --- a/driver/wrapper/stmt2binary.go +++ b/driver/wrapper/stmt2binary.go @@ -61,7 +61,7 @@ generate_stmt2_binds(char *data, uint32_t count, uint32_t field_count, uint32_t } -int taos_stmt2_bind_binary(TAOS_STMT2 *stmt, char *data, int32_t col_idx, char *err_msg) { +int go_stmt2_bind_binary(TAOS_STMT2 *stmt, char *data, int32_t col_idx, char *err_msg) { uint32_t *header = (uint32_t *) data; uint32_t total_length = header[0]; uint32_t count = header[1]; @@ -247,6 +247,9 @@ import ( // TaosStmt2BindBinary bind binary data to stmt2 func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32) error { + if len(data) < stmt.DataPosition { + return fmt.Errorf("data length is less than 28") + } totalLength := binary.LittleEndian.Uint32(data[stmt.TotalLengthPosition:]) if totalLength != uint32(len(data)) { return fmt.Errorf("total length not match, expect %d, but get %d", len(data), totalLength) @@ -256,7 +259,7 @@ func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32) error errMsg := (*C.char)(C.malloc(128)) defer C.free(unsafe.Pointer(errMsg)) - code := C.taos_stmt2_bind_binary(stmt2, (*C.char)(dataP), C.int32_t(colIdx), errMsg) + code := C.go_stmt2_bind_binary(stmt2, (*C.char)(dataP), C.int32_t(colIdx), errMsg) if code != 0 { msg := C.GoString(errMsg) return taosError.NewError(int(code), msg) From ba9edd5a91e331d6d6f3249055cca8802239c7df Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 10 Dec 2024 17:25:27 +0800 Subject: [PATCH 39/48] fix: stmt2 bind binary add table name length check and add asan test --- .github/workflows/linux.yml | 72 +++++++++++++++++++++++++++++- driver/wrapper/stmt2binary.go | 4 ++ driver/wrapper/whitelistcb_test.go | 1 + 3 files changed, 76 insertions(+), 1 deletion(-) diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 7b7a05ed..daf9f9d5 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -309,4 +309,74 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.61.0 \ No newline at end of file + version: v1.61.0 + + test_go_asan: + runs-on: ubuntu-latest + needs: build + strategy: + matrix: + go: [ 'stable' ] + name: test taosAdapter with asan ${{ matrix.go }} + steps: + - name: get cache server by pr + if: github.event_name == 'pull_request' + id: get-cache-server-pr + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ needs.build.outputs.commit_id }} + restore-keys: | + ${{ runner.os }}-build-${{ github.base_ref }}- + + - name: get cache server by push + if: github.event_name == 'push' + id: get-cache-server-push + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ needs.build.outputs.commit_id }} + restore-keys: | + ${{ runner.os }}-build-${{ github.ref_name }}- + + - name: get cache server manually + if: github.event_name == 'workflow_dispatch' + id: get-cache-server-manually + uses: actions/cache@v4 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ needs.build.outputs.commit_id }} + restore-keys: | + ${{ runner.os }}-build-${{ inputs.tbBranch }}- + + + - name: install + run: | + tar -zxvf server.tar.gz + cd release && sudo sh install.sh + + - name: checkout + uses: actions/checkout@v4 + + - name: copy taos cfg + run: | + sudo mkdir -p /etc/taos + sudo cp ./.github/workflows/taos.cfg /etc/taos/taos.cfg + + - uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + cache-dependency-path: go.sum + + - name: start shell + run: | + cat >start.sh< total_length) { diff --git a/driver/wrapper/whitelistcb_test.go b/driver/wrapper/whitelistcb_test.go index 9cc62a4f..40c81fb5 100644 --- a/driver/wrapper/whitelistcb_test.go +++ b/driver/wrapper/whitelistcb_test.go @@ -28,6 +28,7 @@ func TestWhitelistCallback_Success(t *testing.T) { 192, 168, 1, 1, 24, // 192.168.1.1/24 0, 0, 0, 10, 0, 0, 1, 16, // 10.0.0.1/16 + 0, 0, 0, } // Create a channel to receive the result From ef58e3b949965032dfc18c0bd04685a0d6ba6d81 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 10 Dec 2024 17:51:16 +0800 Subject: [PATCH 40/48] enh: rename generate_stmt2_binds to go_generate_stmt2_binds --- driver/wrapper/stmt2binary.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/driver/wrapper/stmt2binary.go b/driver/wrapper/stmt2binary.go index 6070d524..b38b3a51 100644 --- a/driver/wrapper/stmt2binary.go +++ b/driver/wrapper/stmt2binary.go @@ -7,7 +7,7 @@ package wrapper #include int -generate_stmt2_binds(char *data, uint32_t count, uint32_t field_count, uint32_t field_offset, +go_generate_stmt2_binds(char *data, uint32_t count, uint32_t field_count, uint32_t field_offset, TAOS_STMT2_BIND *bind_struct, TAOS_STMT2_BIND **bind_ptr, char *err_msg) { uint32_t *base_length = (uint32_t *) (data + field_offset); @@ -197,7 +197,7 @@ int go_stmt2_bind_binary(TAOS_STMT2 *stmt, char *data, int32_t col_idx, char *er uint32_t struct_index = 0; uint32_t ptr_index = 0; if (tags_offset > 0) { - int code = generate_stmt2_binds(data, count, tag_count, tags_offset, bind_struct, bind_ptr, err_msg); + int code = go_generate_stmt2_binds(data, count, tag_count, tags_offset, bind_struct, bind_ptr, err_msg); if (code != 0) { free(bind_struct); free(bind_ptr); @@ -211,7 +211,7 @@ int go_stmt2_bind_binary(TAOS_STMT2 *stmt, char *data, int32_t col_idx, char *er if (cols_offset > 0) { TAOS_STMT2_BIND *col_bind_struct = bind_struct + struct_index; TAOS_STMT2_BIND **col_bind_ptr = bind_ptr + ptr_index; - int code = generate_stmt2_binds(data, count, col_count, cols_offset, col_bind_struct, col_bind_ptr, + int code = go_generate_stmt2_binds(data, count, col_count, cols_offset, col_bind_struct, col_bind_ptr, err_msg); if (code != 0) { free(bind_struct); From 51a69867972afbf21d5065cb09565898090343dc Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Tue, 10 Dec 2024 19:27:15 +0800 Subject: [PATCH 41/48] fix: Remove tmq set options test as it is not supported yet --- controller/ws/tmq/tmq_test.go | 171 +++++++++++++++++----------------- 1 file changed, 86 insertions(+), 85 deletions(-) diff --git a/controller/ws/tmq/tmq_test.go b/controller/ws/tmq/tmq_test.go index ca4c2d7e..9b8121be 100644 --- a/controller/ws/tmq/tmq_test.go +++ b/controller/ws/tmq/tmq_test.go @@ -3207,90 +3207,91 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { // assert.Error(t, err, string(resp)) //} -type httpQueryResp struct { - Code int `json:"code,omitempty"` - Desc string `json:"desc,omitempty"` - ColumnMeta [][]driver.Value `json:"column_meta,omitempty"` - Data [][]driver.Value `json:"data,omitempty"` - Rows int `json:"rows,omitempty"` -} - -func restQuery(sql string, db string) *httpQueryResp { - w := httptest.NewRecorder() - body := strings.NewReader(sql) - url := "/rest/sql" - if db != "" { - url = fmt.Sprintf("/rest/sql/%s", db) - } - req, _ := http.NewRequest(http.MethodPost, url, body) - req.RemoteAddr = "127.0.0.1:33333" - req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") - router.ServeHTTP(w, req) - if w.Code != http.StatusOK { - return &httpQueryResp{ - Code: w.Code, - Desc: w.Body.String(), - } - } - b, _ := io.ReadAll(w.Body) - var res httpQueryResp - _ = json.Unmarshal(b, &res) - return &res -} - -func TestConnectionOptions(t *testing.T) { - dbName := "test_ws_tmq_conn_options" - topic := "test_ws_tmq_conn_options_topic" - - before(t, dbName, topic) - - s := httptest.NewServer(router) - defer s.Close() - ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/tmq", nil) - if err != nil { - t.Error(err) - return - } - defer func() { - err = ws.Close() - assert.NoError(t, err) - }() - - defer func() { - err = after(ws, dbName, topic) - assert.NoError(t, err) - }() +//type httpQueryResp struct { +// Code int `json:"code,omitempty"` +// Desc string `json:"desc,omitempty"` +// ColumnMeta [][]driver.Value `json:"column_meta,omitempty"` +// Data [][]driver.Value `json:"data,omitempty"` +// Rows int `json:"rows,omitempty"` +//} - // subscribe - b, _ := json.Marshal(TMQSubscribeReq{ - User: "root", - Password: "taosdata", - DB: dbName, - GroupID: "test", - Topics: []string{topic}, - AutoCommit: "false", - OffsetReset: "earliest", - SessionTimeoutMS: "100000", - App: "tmq_test_conn_protocol", - IP: "192.168.55.55", - TZ: "Asia/Shanghai", - }) - msg, err := doWebSocket(ws, TMQSubscribe, b) - assert.NoError(t, err) - var subscribeResp TMQSubscribeResp - err = json.Unmarshal(msg, &subscribeResp) - assert.NoError(t, err) - assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message) +//func restQuery(sql string, db string) *httpQueryResp { +// w := httptest.NewRecorder() +// body := strings.NewReader(sql) +// url := "/rest/sql" +// if db != "" { +// url = fmt.Sprintf("/rest/sql/%s", db) +// } +// req, _ := http.NewRequest(http.MethodPost, url, body) +// req.RemoteAddr = "127.0.0.1:33333" +// req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") +// router.ServeHTTP(w, req) +// if w.Code != http.StatusOK { +// return &httpQueryResp{ +// Code: w.Code, +// Desc: w.Body.String(), +// } +// } +// b, _ := io.ReadAll(w.Body) +// var res httpQueryResp +// _ = json.Unmarshal(b, &res) +// return &res +//} - // check connection options - got := false - for i := 0; i < 10; i++ { - queryResp := restQuery("select conn_id from performance_schema.perf_connections where user_app = 'tmq_test_conn_protocol' and user_ip = '192.168.55.55'", "") - if queryResp.Code == 0 && len(queryResp.Data) > 0 { - got = true - break - } - time.Sleep(time.Second) - } - assert.True(t, got) -} +// not supported yet +//func TestConnectionOptions(t *testing.T) { +// dbName := "test_ws_tmq_conn_options" +// topic := "test_ws_tmq_conn_options_topic" +// +// before(t, dbName, topic) +// +// s := httptest.NewServer(router) +// defer s.Close() +// ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/tmq", nil) +// if err != nil { +// t.Error(err) +// return +// } +// defer func() { +// err = ws.Close() +// assert.NoError(t, err) +// }() +// +// defer func() { +// err = after(ws, dbName, topic) +// assert.NoError(t, err) +// }() +// +// // subscribe +// b, _ := json.Marshal(TMQSubscribeReq{ +// User: "root", +// Password: "taosdata", +// DB: dbName, +// GroupID: "test", +// Topics: []string{topic}, +// AutoCommit: "false", +// OffsetReset: "earliest", +// SessionTimeoutMS: "100000", +// App: "tmq_test_conn_protocol", +// IP: "192.168.55.55", +// TZ: "Asia/Shanghai", +// }) +// msg, err := doWebSocket(ws, TMQSubscribe, b) +// assert.NoError(t, err) +// var subscribeResp TMQSubscribeResp +// err = json.Unmarshal(msg, &subscribeResp) +// assert.NoError(t, err) +// assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message) +// +// // check connection options +// got := false +// for i := 0; i < 10; i++ { +// queryResp := restQuery("select conn_id from performance_schema.perf_connections where user_app = 'tmq_test_conn_protocol' and user_ip = '192.168.55.55'", "") +// if queryResp.Code == 0 && len(queryResp.Data) > 0 { +// got = true +// break +// } +// time.Sleep(time.Second) +// } +// assert.True(t, got) +//} From 07eeb21a0cf5cfad28ace4148e181e81ea686c7c Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 11 Dec 2024 15:12:22 +0800 Subject: [PATCH 42/48] fix: get client version without request id --- controller/ws/ws/handler.go | 17 ++++++++++++----- controller/ws/ws/ws_test.go | 13 +++++++++++++ driver/wrapper/stmt_test.go | 2 +- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/controller/ws/ws/handler.go b/controller/ws/ws/handler.go index 388820c5..88e47673 100644 --- a/controller/ws/ws/handler.go +++ b/controller/ws/ws/handler.go @@ -210,11 +210,18 @@ func (h *messageHandler) handleMessage(session *melody.Session, data []byte) { case wstool.ClientVersion: action = wstool.ClientVersion var req versionRequest - if err := json.Unmarshal(request.Args, &req); err != nil { - h.logger.Errorf("unmarshal version request error, request:%s, err:%s", request.Args, err) - reqID := getReqID(request.Args) - commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal version request error") - return + var reqID uint64 + if request.Args != nil { + if err := json.Unmarshal(request.Args, &req); err != nil { + h.logger.Errorf("unmarshal version request error, request:%s, err:%s", request.Args, err) + reqID := getReqID(request.Args) + commonErrorResponse(ctx, session, h.logger, action, reqID, 0xffff, "unmarshal version request error") + return + } + reqID = req.ReqID + } + req = versionRequest{ + ReqID: reqID, } logger := h.logger.WithFields(logrus.Fields{ actionKey: action, diff --git a/controller/ws/ws/ws_test.go b/controller/ws/ws/ws_test.go index bc9278ba..6433d9ef 100644 --- a/controller/ws/ws/ws_test.go +++ b/controller/ws/ws/ws_test.go @@ -151,4 +151,17 @@ func TestVersion(t *testing.T) { assert.Equal(t, 0, versionResp.Code, versionResp.Message) assert.Equal(t, version.TaosClientVersion, versionResp.Version) assert.Equal(t, wstool.ClientVersion, versionResp.Action) + + req := "{\"action\":\"version\"}" + err = ws.WriteMessage(websocket.TextMessage, []byte(req)) + assert.NoError(t, err) + _, resp, err = ws.ReadMessage() + assert.NoError(t, err) + + assert.NoError(t, err) + err = json.Unmarshal(resp, &versionResp) + assert.NoError(t, err) + assert.Equal(t, 0, versionResp.Code, versionResp.Message) + assert.Equal(t, version.TaosClientVersion, versionResp.Version) + assert.Equal(t, wstool.ClientVersion, versionResp.Action) } diff --git a/driver/wrapper/stmt_test.go b/driver/wrapper/stmt_test.go index 9d63b7a9..de58e228 100644 --- a/driver/wrapper/stmt_test.go +++ b/driver/wrapper/stmt_test.go @@ -898,7 +898,7 @@ func TestGetFieldsCommonTable(t *testing.T) { return } code, num, _ := TaosStmtGetTagFields(stmt) - assert.Equal(t, 0, code) + assert.NotEqual(t, 0, code) assert.Equal(t, 0, num) code, columnCount, columnsP := TaosStmtGetColFields(stmt) if code != 0 { From a2a2c8eea3d97dbeaf4f4a6ec49642fdaf959f85 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 11 Dec 2024 17:12:25 +0800 Subject: [PATCH 43/48] test: add block test --- tools/ctools/block_test.go | 124 ++++++++++++++++++++++++++++++------- 1 file changed, 101 insertions(+), 23 deletions(-) diff --git a/tools/ctools/block_test.go b/tools/ctools/block_test.go index f811d208..cb4550ff 100644 --- a/tools/ctools/block_test.go +++ b/tools/ctools/block_test.go @@ -8,11 +8,27 @@ import ( "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/taosdata/taosadapter/v3/driver/common" "github.com/taosdata/taosadapter/v3/driver/common/parser" "github.com/taosdata/taosadapter/v3/tools" "github.com/taosdata/taosadapter/v3/tools/jsonbuilder" ) +type testHook struct { + hasError bool +} + +func (t *testHook) Levels() []logrus.Level { + return []logrus.Level{logrus.ErrorLevel} +} + +func (t *testHook) Fire(entry *logrus.Entry) error { + if entry.Level == logrus.ErrorLevel { + t.hasError = true + } + return nil +} + func TestJsonWriteRawBlock(t *testing.T) { raw := []byte{ 0x01, 0x00, 0x00, 0x00, @@ -134,13 +150,9 @@ func TestJsonWriteRawBlock(t *testing.T) { 0x07, 0x00, 0x7b, 0x22, 0x61, 0x22, 0x3a, 0x31, 0x7d, } - w := &strings.Builder{} - builder := jsonbuilder.BorrowStream(w) - defer jsonbuilder.ReturnStream(builder) fieldsCount := 17 fieldTypes := []uint8{9, 1, 2, 3, 4, 5, 11, 12, 13, 14, 6, 7, 8, 10, 16, 20, 15} blockSize := 2 - precision := 0 pHeaderList := make([]unsafe.Pointer, fieldsCount) pStartList := make([]unsafe.Pointer, fieldsCount) nullBitMapOffset := uintptr(BitmapLen(blockSize)) @@ -158,25 +170,91 @@ func TestJsonWriteRawBlock(t *testing.T) { } tmpPHeader = tools.AddPointer(pStartList[column], uintptr(colLength)) } - timeBuffer := make([]byte, 0, 30) - builder.WriteObjectStart() - for row := 0; row < blockSize; row++ { - builder.WriteArrayStart() - for column := 0; column < fieldsCount; column++ { - JsonWriteRawBlock(builder, fieldTypes[column], pHeaderList[column], pStartList[column], row, precision, time.UTC, timeBuffer, logrus.New().WithField("test", "test")) - if column != fieldsCount-1 { - builder.WriteMore() - err := builder.Flush() - assert.NoError(t, err) + cnLocation, _ := time.LoadLocation("Asia/Shanghai") + tests := []struct { + name string + precision int + timeZone *time.Location + expect string + expectError bool + }{ + { + name: "ms", + precision: common.PrecisionMilliSecond, + timeZone: time.UTC, + expect: `{["2022-08-10T07:02:40.500Z",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["2022-08-10T07:02:41.500Z",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "ns", + precision: common.PrecisionNanoSecond, + timeZone: time.UTC, + expect: `{["1970-01-01T00:27:40.114960500Z",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["1970-01-01T00:27:40.114961500Z",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "us", + precision: common.PrecisionMicroSecond, + timeZone: time.UTC, + expect: `{["1970-01-20T05:08:34.960500Z",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["1970-01-20T05:08:34.961500Z",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "ms_cn", + precision: common.PrecisionMilliSecond, + timeZone: cnLocation, + expect: `{["2022-08-10T15:02:40.500+08:00",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["2022-08-10T15:02:41.500+08:00",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "ns_cn", + precision: common.PrecisionNanoSecond, + timeZone: cnLocation, + expect: `{["1970-01-01T08:27:40.114960500+08:00",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["1970-01-01T08:27:40.114961500+08:00",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "us_cn", + precision: common.PrecisionMicroSecond, + timeZone: cnLocation, + expect: `{["1970-01-20T13:08:34.960500+08:00",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["1970-01-20T13:08:34.961500+08:00",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, + }, + { + name: "invalid precision", + precision: -1, + timeZone: cnLocation, + expectError: true, + }, + } + for _, tt := range tests { + hook := &testHook{} + logger := logrus.New() + logger.AddHook(hook) + t.Run(tt.name, func(t *testing.T) { + w := &strings.Builder{} + builder := jsonbuilder.BorrowStream(w) + defer jsonbuilder.ReturnStream(builder) + builder.WriteObjectStart() + timeBuffer := make([]byte, 0, 35) + for row := 0; row < blockSize; row++ { + builder.WriteArrayStart() + for column := 0; column < fieldsCount; column++ { + JsonWriteRawBlock(builder, fieldTypes[column], pHeaderList[column], pStartList[column], row, tt.precision, tt.timeZone, timeBuffer, logger.WithField("test", "test")) + if column != fieldsCount-1 { + builder.WriteMore() + err := builder.Flush() + assert.NoError(t, err) + } + } + builder.WriteArrayEnd() + if row != blockSize-1 { + builder.WriteMore() + } } - } - builder.WriteArrayEnd() - if row != blockSize-1 { - builder.WriteMore() - } + builder.WriteObjectEnd() + err := builder.Flush() + assert.NoError(t, err) + if tt.expectError { + assert.True(t, hook.hasError) + } else { + assert.False(t, hook.hasError) + assert.Equal(t, tt.expect, w.String()) + } + }) } - builder.WriteObjectEnd() - err := builder.Flush() - assert.NoError(t, err) - assert.Equal(t, `{["2022-08-10T07:02:40.500Z",true,2,3,4,5,6,7,8,9,10,11,"binary","nchar","746573745f76617262696e617279","010100000000000000000059400000000000005940",{"a":1}],["2022-08-10T07:02:41.500Z",null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,{"a":1}]}`, w.String()) } From 0b45ad9e4c98ec4271be0c8b94eda05b625d04d4 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 12 Dec 2024 13:03:40 +0800 Subject: [PATCH 44/48] fix: user password restrictions --- controller/rest/restful_test.go | 26 ++++++++++++++ controller/ws/query/ws_test.go | 4 +-- controller/ws/schemaless/schemaless_test.go | 4 +-- controller/ws/stmt/stmt_test.go | 4 +-- controller/ws/tmq/tmq_test.go | 4 +-- controller/ws/ws/handler_test.go | 4 +-- db/commonpool/pool_test.go | 38 ++++++++++----------- db/tool/notify_test.go | 10 +++--- driver/wrapper/notify_test.go | 6 ++-- 9 files changed, 63 insertions(+), 37 deletions(-) diff --git a/controller/rest/restful_test.go b/controller/rest/restful_test.go index 8212a6d3..005b3039 100644 --- a/controller/rest/restful_test.go +++ b/controller/rest/restful_test.go @@ -688,6 +688,9 @@ func TestSetConnectionOptions(t *testing.T) { checkResp(t, w) defer func() { + req, _ := http.NewRequest(http.MethodPost, url, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") w = httptest.NewRecorder() body = strings.NewReader("drop database if exists rest_test_options") req.Body = io.NopCloser(body) @@ -726,6 +729,29 @@ func TestSetConnectionOptions(t *testing.T) { expectTimeStr := expectTime.Format(layout.LayoutMillSecond) assert.Equal(t, expectTimeStr, result.Data[0][0]) t.Log(expectTimeStr, result.Data[0][0]) + + // wrong timezone + wrongTZUrl := "/rest/sql?app=rest_test_options&ip=192.168.100.1&tz=xxx" + req, _ = http.NewRequest(http.MethodPost, wrongTZUrl, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") + w = httptest.NewRecorder() + body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + assert.Equal(t, 400, w.Code) + + // wrong conn_tz + wrongConnTZUrl := "/rest/sql?app=rest_test_options&ip=192.168.100.1&conn_tz=xxx" + req, _ = http.NewRequest(http.MethodPost, wrongConnTZUrl, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") + w = httptest.NewRecorder() + body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + assert.Equal(t, 400, w.Code) + // wrong } func checkResp(t *testing.T, w *httptest.ResponseRecorder) { diff --git a/controller/ws/query/ws_test.go b/controller/ws/query/ws_test.go index baa99ef6..eb7476c3 100644 --- a/controller/ws/query/ws_test.go +++ b/controller/ws/query/ws_test.go @@ -1608,10 +1608,10 @@ func TestDropUser(t *testing.T) { assert.NoError(t, err) }() defer doRestful("drop user test_ws_query_drop_user", "") - code, message := doRestful("create user test_ws_query_drop_user pass 'pass'", "") + code, message := doRestful("create user test_ws_query_drop_user pass 'pass_123'", "") assert.Equal(t, 0, code, message) // connect - connReq := &WSConnectReq{ReqID: 1, User: "test_ws_query_drop_user", Password: "pass"} + connReq := &WSConnectReq{ReqID: 1, User: "test_ws_query_drop_user", Password: "pass_123"} resp, err := doWebSocket(ws, WSConnect, &connReq) assert.NoError(t, err) var connResp WSConnectResp diff --git a/controller/ws/schemaless/schemaless_test.go b/controller/ws/schemaless/schemaless_test.go index 1e54c2ae..934dded7 100644 --- a/controller/ws/schemaless/schemaless_test.go +++ b/controller/ws/schemaless/schemaless_test.go @@ -262,10 +262,10 @@ func TestDropUser(t *testing.T) { assert.NoError(t, err) }() defer doRestful("drop user test_ws_sml_drop_user", "") - code, message := doRestful("create user test_ws_sml_drop_user pass 'pass'", "") + code, message := doRestful("create user test_ws_sml_drop_user pass 'pass_123'", "") assert.Equal(t, 0, code, message) // connect - connReq := &schemalessConnReq{ReqID: 1, User: "test_ws_sml_drop_user", Password: "pass"} + connReq := &schemalessConnReq{ReqID: 1, User: "test_ws_sml_drop_user", Password: "pass_123"} resp, err := doWebSocket(ws, SchemalessConn, &connReq) assert.NoError(t, err) var connResp schemalessConnResp diff --git a/controller/ws/stmt/stmt_test.go b/controller/ws/stmt/stmt_test.go index 445a2fb8..ee86e668 100644 --- a/controller/ws/stmt/stmt_test.go +++ b/controller/ws/stmt/stmt_test.go @@ -1101,10 +1101,10 @@ func TestDropUser(t *testing.T) { assert.NoError(t, err) }() defer doRestful("drop user test_ws_stmt_drop_user", "") - code, message := doRestful("create user test_ws_stmt_drop_user pass 'pass'", "") + code, message := doRestful("create user test_ws_stmt_drop_user pass 'pass_123'", "") assert.Equal(t, 0, code, message) // connect - connReq := &StmtConnectReq{ReqID: 1, User: "test_ws_stmt_drop_user", Password: "pass"} + connReq := &StmtConnectReq{ReqID: 1, User: "test_ws_stmt_drop_user", Password: "pass_123"} resp, err := doWebSocket(ws, STMTConnect, &connReq) assert.NoError(t, err) var connResp StmtConnectResp diff --git a/controller/ws/tmq/tmq_test.go b/controller/ws/tmq/tmq_test.go index 9b8121be..47e26deb 100644 --- a/controller/ws/tmq/tmq_test.go +++ b/controller/ws/tmq/tmq_test.go @@ -3158,7 +3158,7 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { //func TestDropUser(t *testing.T) { // defer doHttpSql("drop user test_tmq_drop_user") -// code, message := doHttpSql("create user test_tmq_drop_user pass 'pass'") +// code, message := doHttpSql("create user test_tmq_drop_user pass 'pass_123'") // assert.Equal(t, 0, code, message) // // dbName := "test_ws_tmq_drop_user" @@ -3186,7 +3186,7 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { // // subscribe // b, _ := json.Marshal(TMQSubscribeReq{ // User: "test_tmq_drop_user", -// Password: "pass", +// Password: "pass_123", // DB: dbName, // GroupID: "test", // Topics: []string{topic}, diff --git a/controller/ws/ws/handler_test.go b/controller/ws/ws/handler_test.go index 8cffeabc..c91a318f 100644 --- a/controller/ws/ws/handler_test.go +++ b/controller/ws/ws/handler_test.go @@ -26,10 +26,10 @@ func TestDropUser(t *testing.T) { assert.NoError(t, err) }() defer doRestful("drop user test_ws_drop_user", "") - code, message := doRestful("create user test_ws_drop_user pass 'pass'", "") + code, message := doRestful("create user test_ws_drop_user pass 'pass_123'", "") assert.Equal(t, 0, code, message) // connect - connReq := connRequest{ReqID: 1, User: "test_ws_drop_user", Password: "pass"} + connReq := connRequest{ReqID: 1, User: "test_ws_drop_user", Password: "pass_123"} resp, err := doWebSocket(ws, Connect, &connReq) assert.NoError(t, err) var connResp commonResp diff --git a/db/commonpool/pool_test.go b/db/commonpool/pool_test.go index 838415c4..e5c0b719 100644 --- a/db/commonpool/pool_test.go +++ b/db/commonpool/pool_test.go @@ -111,28 +111,28 @@ func TestChangePassword(t *testing.T) { c, err := GetConnection("root", "taosdata", net.ParseIP("127.0.0.1")) assert.NoError(t, err) - result := wrapper.TaosQuery(c.TaosConnection, "drop user test") + result := wrapper.TaosQuery(c.TaosConnection, "drop user test_change_pass_pool") assert.NotNil(t, result) wrapper.TaosFreeResult(result) - result = wrapper.TaosQuery(c.TaosConnection, "create user test pass 'test'") + result = wrapper.TaosQuery(c.TaosConnection, "create user test_change_pass_pool pass 'test_123'") assert.NotNil(t, result) errNo := wrapper.TaosError(result) assert.Equal(t, 0, errNo) wrapper.TaosFreeResult(result) defer func() { - r := wrapper.TaosQuery(c.TaosConnection, "drop user test") + r := wrapper.TaosQuery(c.TaosConnection, "drop user test_change_pass_pool") wrapper.TaosFreeResult(r) }() - conn, err := GetConnection("test", "test", net.ParseIP("127.0.0.1")) + conn, err := GetConnection("test_change_pass_pool", "test_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) - result = wrapper.TaosQuery(c.TaosConnection, "alter user test pass 'test2'") + result = wrapper.TaosQuery(c.TaosConnection, "alter user test_change_pass_pool pass 'test2_123'") assert.NotNil(t, result) errNo = wrapper.TaosError(result) - assert.Equal(t, 0, errNo) + assert.Equal(t, 0, errNo, wrapper.TaosErrorStr(result)) wrapper.TaosFreeResult(result) result = wrapper.TaosQuery(conn.TaosConnection, "show databases") @@ -149,11 +149,11 @@ func TestChangePassword(t *testing.T) { err = conn.Put() assert.NoError(t, err) - conn2, err := GetConnection("test", "test", net.ParseIP("127.0.0.1")) + conn2, err := GetConnection("test_change_pass_pool", "test_123", net.ParseIP("127.0.0.1")) assert.Error(t, err) assert.Nil(t, conn2) - conn3, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) + conn3, err := GetConnection("test_change_pass_pool", "test2_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn3) result2 := wrapper.TaosQuery(conn3.TaosConnection, "show databases") @@ -163,7 +163,7 @@ func TestChangePassword(t *testing.T) { err = conn3.Put() assert.NoError(t, err) - conn4, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) + conn4, err := GetConnection("test_change_pass_pool", "test2_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn4) result3 := wrapper.TaosQuery(conn4.TaosConnection, "show databases") @@ -178,27 +178,27 @@ func TestChangePasswordConcurrent(t *testing.T) { c, err := GetConnection("root", "taosdata", net.ParseIP("127.0.0.1")) assert.NoError(t, err) - result := wrapper.TaosQuery(c.TaosConnection, "drop user test") + result := wrapper.TaosQuery(c.TaosConnection, "drop user test_change_pass_con") assert.NotNil(t, result) wrapper.TaosFreeResult(result) - result = wrapper.TaosQuery(c.TaosConnection, "create user test pass 'test'") + result = wrapper.TaosQuery(c.TaosConnection, "create user test_change_pass_con pass 'test_123'") assert.NotNil(t, result) errNo := wrapper.TaosError(result) assert.Equal(t, 0, errNo) wrapper.TaosFreeResult(result) defer func() { - r := wrapper.TaosQuery(c.TaosConnection, "drop user test") + r := wrapper.TaosQuery(c.TaosConnection, "drop user test_change_pass_con") wrapper.TaosFreeResult(r) }() - conn, err := GetConnection("test", "test", net.ParseIP("127.0.0.1")) + conn, err := GetConnection("test_change_pass_con", "test_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) - result = wrapper.TaosQuery(c.TaosConnection, "alter user test pass 'test2'") + result = wrapper.TaosQuery(c.TaosConnection, "alter user test_change_pass_con pass 'test2_123'") assert.NotNil(t, result) errNo = wrapper.TaosError(result) - assert.Equal(t, 0, errNo) + assert.Equal(t, 0, errNo, wrapper.TaosErrorStr(result)) wrapper.TaosFreeResult(result) result = wrapper.TaosQuery(conn.TaosConnection, "show databases") @@ -215,7 +215,7 @@ func TestChangePasswordConcurrent(t *testing.T) { for i := 0; i < 5; i++ { go func() { defer wg.Done() - conn2, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) + conn2, err := GetConnection("test_change_pass_con", "test2_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn2) err = conn2.Put() @@ -223,11 +223,11 @@ func TestChangePasswordConcurrent(t *testing.T) { }() } wg.Wait() - conn2, err := GetConnection("test", "test", net.ParseIP("127.0.0.1")) + conn2, err := GetConnection("test_change_pass_con", "test_123", net.ParseIP("127.0.0.1")) assert.Error(t, err) assert.Nil(t, conn2) - conn3, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) + conn3, err := GetConnection("test_change_pass_con", "test2_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn3) result2 := wrapper.TaosQuery(conn3.TaosConnection, "show databases") @@ -237,7 +237,7 @@ func TestChangePasswordConcurrent(t *testing.T) { err = conn3.Put() assert.NoError(t, err) - conn4, err := GetConnection("test", "test2", net.ParseIP("127.0.0.1")) + conn4, err := GetConnection("test_change_pass_con", "test2_123", net.ParseIP("127.0.0.1")) assert.NoError(t, err) assert.NotNil(t, conn4) result3 := wrapper.TaosQuery(conn4.TaosConnection, "show databases") diff --git a/db/tool/notify_test.go b/db/tool/notify_test.go index c1df3fbc..f5edc97c 100644 --- a/db/tool/notify_test.go +++ b/db/tool/notify_test.go @@ -266,13 +266,13 @@ func TestRegisterChangePass(t *testing.T) { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) assert.NoError(t, err) defer wrapper.TaosClose(conn) - err = exec(conn, "create user test_notify pass 'notify'") + err = exec(conn, "create user test_notify pass 'notify_123'") assert.NoError(t, err) defer func() { // ignore error _ = exec(conn, "drop user test_notify") }() - conn2, err := wrapper.TaosConnect("", "test_notify", "notify", "", 0) + conn2, err := wrapper.TaosConnect("", "test_notify", "notify_123", "", 0) assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) defer cancel() @@ -285,7 +285,7 @@ func TestRegisterChangePass(t *testing.T) { } ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second*5) defer cancel2() - err = exec(conn, "alter user test_notify pass 'test'") + err = exec(conn, "alter user test_notify pass 'test_123'") assert.NoError(t, err) select { case data := <-c: @@ -301,13 +301,13 @@ func TestRegisterDropUser(t *testing.T) { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) assert.NoError(t, err) defer wrapper.TaosClose(conn) - err = exec(conn, "create user test_drop_user pass 'notify'") + err = exec(conn, "create user test_drop_user pass 'notify_123'") assert.NoError(t, err) defer func() { // ignore error _ = exec(conn, "drop user test_drop_user") }() - conn2, err := wrapper.TaosConnect("", "test_drop_user", "notify", "", 0) + conn2, err := wrapper.TaosConnect("", "test_drop_user", "notify_123", "", 0) assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) defer cancel() diff --git a/driver/wrapper/notify_test.go b/driver/wrapper/notify_test.go index 8db0cb37..de102270 100644 --- a/driver/wrapper/notify_test.go +++ b/driver/wrapper/notify_test.go @@ -25,10 +25,10 @@ func TestNotify(t *testing.T) { _ = exec(conn, "drop user t_notify") }() _ = exec(conn, "drop user t_notify") - err = exec(conn, "create user t_notify pass 'notify'") + err = exec(conn, "create user t_notify pass 'notify_123'") assert.NoError(t, err) - conn2, err := TaosConnect("", "t_notify", "notify", "", 0) + conn2, err := TaosConnect("", "t_notify", "notify_123", "", 0) if err != nil { t.Error(err) return @@ -58,7 +58,7 @@ func TestNotify(t *testing.T) { t.Error(errCode, errStr) } - err = exec(conn, "alter user t_notify pass 'test'") + err = exec(conn, "alter user t_notify pass 'test_123'") assert.NoError(t, err) timeout, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() From e979d6a49548301980b938d2df47b776dfc5b183 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 12 Dec 2024 13:39:33 +0800 Subject: [PATCH 45/48] fix: get taos error string in http --- controller/rest/restful.go | 2 +- controller/rest/restful_test.go | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/controller/rest/restful.go b/controller/rest/restful.go index 36bae648..1c9b4b40 100644 --- a/controller/rest/restful.go +++ b/controller/rest/restful.go @@ -238,7 +238,7 @@ func trySetConnectionOptions(c *gin.Context, conn unsafe.Pointer, logger *logrus if val != "" { code := syncinterface.TaosOptionsConnection(conn, options[i], &val, logger, isDebug) if code != httperror.SUCCESS { - errStr := wrapper.TaosErrorStr(conn) + errStr := wrapper.TaosErrorStr(nil) logger.Errorf("set connection options error, option:%d, val:%s, code:%d, message:%s", options[i], val, code, errStr) TaosErrorResponse(c, logger, code, errStr) return false diff --git a/controller/rest/restful_test.go b/controller/rest/restful_test.go index 005b3039..f7638861 100644 --- a/controller/rest/restful_test.go +++ b/controller/rest/restful_test.go @@ -751,7 +751,19 @@ func TestSetConnectionOptions(t *testing.T) { req.Body = io.NopCloser(body) router.ServeHTTP(w, req) assert.Equal(t, 400, w.Code) - // wrong + // wrong ip + wrongIPUrl := "/rest/sql?app=rest_test_options&ip=xxx.xxx.xxx.xxx&conn_tz=Europe/Moscow&tz=Asia/Shanghai" + req, _ = http.NewRequest(http.MethodPost, wrongIPUrl, body) + req.RemoteAddr = "127.0.0.1:33333" + req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") + w = httptest.NewRecorder() + body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) + req.Body = io.NopCloser(body) + router.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + err = json.Unmarshal(w.Body.Bytes(), &result) + assert.NoError(t, err) + assert.NotEqual(t, 0, result.Code) } func checkResp(t *testing.T, w *httptest.ResponseRecorder) { From d5916c2a121f4cb75d0d6bc40e81c5d78cfea467 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 12 Dec 2024 13:52:24 +0800 Subject: [PATCH 46/48] test: add tmq test --- controller/ws/tmq/tmq_test.go | 149 +++++++++++++++++++++------------- 1 file changed, 91 insertions(+), 58 deletions(-) diff --git a/controller/ws/tmq/tmq_test.go b/controller/ws/tmq/tmq_test.go index 47e26deb..ebaf5c4b 100644 --- a/controller/ws/tmq/tmq_test.go +++ b/controller/ws/tmq/tmq_test.go @@ -3156,6 +3156,7 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message) } +// todo: not implemented //func TestDropUser(t *testing.T) { // defer doHttpSql("drop user test_tmq_drop_user") // code, message := doHttpSql("create user test_tmq_drop_user pass 'pass_123'") @@ -3214,7 +3215,7 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { // Data [][]driver.Value `json:"data,omitempty"` // Rows int `json:"rows,omitempty"` //} - +// //func restQuery(sql string, db string) *httpQueryResp { // w := httptest.NewRecorder() // body := strings.NewReader(sql) @@ -3238,60 +3239,92 @@ func TestTMQ_SetMsgConsumeExcluded(t *testing.T) { // return &res //} -// not supported yet -//func TestConnectionOptions(t *testing.T) { -// dbName := "test_ws_tmq_conn_options" -// topic := "test_ws_tmq_conn_options_topic" -// -// before(t, dbName, topic) -// -// s := httptest.NewServer(router) -// defer s.Close() -// ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/tmq", nil) -// if err != nil { -// t.Error(err) -// return -// } -// defer func() { -// err = ws.Close() -// assert.NoError(t, err) -// }() -// -// defer func() { -// err = after(ws, dbName, topic) -// assert.NoError(t, err) -// }() -// -// // subscribe -// b, _ := json.Marshal(TMQSubscribeReq{ -// User: "root", -// Password: "taosdata", -// DB: dbName, -// GroupID: "test", -// Topics: []string{topic}, -// AutoCommit: "false", -// OffsetReset: "earliest", -// SessionTimeoutMS: "100000", -// App: "tmq_test_conn_protocol", -// IP: "192.168.55.55", -// TZ: "Asia/Shanghai", -// }) -// msg, err := doWebSocket(ws, TMQSubscribe, b) -// assert.NoError(t, err) -// var subscribeResp TMQSubscribeResp -// err = json.Unmarshal(msg, &subscribeResp) -// assert.NoError(t, err) -// assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message) -// -// // check connection options -// got := false -// for i := 0; i < 10; i++ { -// queryResp := restQuery("select conn_id from performance_schema.perf_connections where user_app = 'tmq_test_conn_protocol' and user_ip = '192.168.55.55'", "") -// if queryResp.Code == 0 && len(queryResp.Data) > 0 { -// got = true -// break -// } -// time.Sleep(time.Second) -// } -// assert.True(t, got) -//} +func TestConnectionOptions(t *testing.T) { + dbName := "test_ws_tmq_conn_options" + topic := "test_ws_tmq_conn_options_topic" + + before(t, dbName, topic) + + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/tmq", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + + defer func() { + err = after(ws, dbName, topic) + assert.NoError(t, err) + }() + + // subscribe + b, _ := json.Marshal(TMQSubscribeReq{ + User: "root", + Password: "taosdata", + DB: dbName, + GroupID: "test", + Topics: []string{topic}, + AutoCommit: "false", + OffsetReset: "earliest", + SessionTimeoutMS: "100000", + App: "tmq_test_conn_protocol", + IP: "192.168.55.55", + TZ: "Asia/Shanghai", + }) + msg, err := doWebSocket(ws, TMQSubscribe, b) + assert.NoError(t, err) + var subscribeResp TMQSubscribeResp + err = json.Unmarshal(msg, &subscribeResp) + assert.NoError(t, err) + assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message) + + // todo: check connection options, C not implemented + //got := false + //for i := 0; i < 10; i++ { + // queryResp := restQuery("select conn_id from performance_schema.perf_connections where user_app = 'tmq_test_conn_protocol' and user_ip = '192.168.55.55'", "") + // if queryResp.Code == 0 && len(queryResp.Data) > 0 { + // got = true + // break + // } + // time.Sleep(time.Second) + //} + //assert.True(t, got) +} + +func TestWrongPass(t *testing.T) { + s := httptest.NewServer(router) + defer s.Close() + ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/tmq", nil) + if err != nil { + t.Error(err) + return + } + defer func() { + err = ws.Close() + assert.NoError(t, err) + }() + // subscribe + b, _ := json.Marshal(TMQSubscribeReq{ + User: "root", + Password: "wrong_pass", + GroupID: "test", + Topics: []string{"test"}, + AutoCommit: "false", + OffsetReset: "earliest", + SessionTimeoutMS: "100000", + App: "tmq_test_conn_protocol", + IP: "192.168.55.55", + TZ: "Asia/Shanghai", + }) + msg, err := doWebSocket(ws, TMQSubscribe, b) + assert.NoError(t, err) + var subscribeResp TMQSubscribeResp + err = json.Unmarshal(msg, &subscribeResp) + assert.NoError(t, err) + assert.NotEqual(t, 0, subscribeResp.Code, subscribeResp.Message) +} From 640f3e4ee177127eae34e6bcd0bcd5abb6e6ee74 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 12 Dec 2024 16:59:26 +0800 Subject: [PATCH 47/48] fix: tmq new consumer return error --- controller/rest/restful_test.go | 9 +++------ controller/ws/tmq/tmq.go | 3 +-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/controller/rest/restful_test.go b/controller/rest/restful_test.go index f7638861..f940cb87 100644 --- a/controller/rest/restful_test.go +++ b/controller/rest/restful_test.go @@ -688,12 +688,11 @@ func TestSetConnectionOptions(t *testing.T) { checkResp(t, w) defer func() { + body := strings.NewReader("drop database if exists rest_test_options") req, _ := http.NewRequest(http.MethodPost, url, body) req.RemoteAddr = "127.0.0.1:33333" req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") w = httptest.NewRecorder() - body = strings.NewReader("drop database if exists rest_test_options") - req.Body = io.NopCloser(body) router.ServeHTTP(w, req) checkResp(t, w) }() @@ -732,23 +731,21 @@ func TestSetConnectionOptions(t *testing.T) { // wrong timezone wrongTZUrl := "/rest/sql?app=rest_test_options&ip=192.168.100.1&tz=xxx" + body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) req, _ = http.NewRequest(http.MethodPost, wrongTZUrl, body) req.RemoteAddr = "127.0.0.1:33333" req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") w = httptest.NewRecorder() - body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) - req.Body = io.NopCloser(body) router.ServeHTTP(w, req) assert.Equal(t, 400, w.Code) // wrong conn_tz wrongConnTZUrl := "/rest/sql?app=rest_test_options&ip=192.168.100.1&conn_tz=xxx" + body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) req, _ = http.NewRequest(http.MethodPost, wrongConnTZUrl, body) req.RemoteAddr = "127.0.0.1:33333" req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==") w = httptest.NewRecorder() - body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`) - req.Body = io.NopCloser(body) router.ServeHTTP(w, req) assert.Equal(t, 400, w.Code) // wrong ip diff --git a/controller/ws/tmq/tmq.go b/controller/ws/tmq/tmq.go index e3ae3c2d..1cef7054 100644 --- a/controller/ws/tmq/tmq.go +++ b/controller/ws/tmq/tmq.go @@ -1616,8 +1616,7 @@ func (t *TMQ) wrapperConsumerNew(logger *logrus.Entry, isDebug bool, tmqConfig u logger.Tracef("new consumer result %x", uintptr(result.Consumer)) if len(result.ErrStr) > 0 { err = taoserrors.NewError(-1, result.ErrStr) - } - if result.Consumer == nil { + } else if result.Consumer == nil { err = taoserrors.NewError(-1, "new consumer return nil") } t.asyncLocker.Unlock() From dfa4db2115a19fd014c37bfe99068acb2de67d02 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Mon, 16 Dec 2024 11:05:40 +0800 Subject: [PATCH 48/48] feat: rename taos_stmt2_get_stb_fields to taos_stmt2_get_fields --- controller/ws/ws/stmt2.go | 25 +-- controller/ws/ws/stmt2_test.go | 34 +--- db/syncinterface/wrapper.go | 18 +- db/syncinterface/wrapper_test.go | 52 +++--- driver/common/stmt/stmt2.go | 22 ++- driver/common/stmt/stmt2_test.go | 293 +++++++++++++++++++++---------- driver/wrapper/stmt2.go | 81 ++++----- driver/wrapper/stmt2_test.go | 239 ++++++++++--------------- 8 files changed, 388 insertions(+), 376 deletions(-) diff --git a/controller/ws/ws/stmt2.go b/controller/ws/ws/stmt2.go index f7ccb5b4..2ff6565b 100644 --- a/controller/ws/ws/stmt2.go +++ b/controller/ws/ws/stmt2.go @@ -9,6 +9,7 @@ import ( "github.com/taosdata/taosadapter/v3/controller/ws/wstool" "github.com/taosdata/taosadapter/v3/db/async" "github.com/taosdata/taosadapter/v3/db/syncinterface" + "github.com/taosdata/taosadapter/v3/driver/common/stmt" errors2 "github.com/taosdata/taosadapter/v3/driver/errors" "github.com/taosdata/taosadapter/v3/driver/wrapper" "github.com/taosdata/taosadapter/v3/log" @@ -81,15 +82,15 @@ type stmt2PrepareRequest struct { } type stmt2PrepareResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Action string `json:"action"` - ReqID uint64 `json:"req_id"` - Timing int64 `json:"timing"` - StmtID uint64 `json:"stmt_id"` - IsInsert bool `json:"is_insert"` - Fields []*wrapper.StmtStbField `json:"fields"` - FieldsCount int `json:"fields_count"` + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + IsInsert bool `json:"is_insert"` + Fields []*stmt.Stmt2AllField `json:"fields"` + FieldsCount int `json:"fields_count"` } func (h *messageHandler) stmt2Prepare(ctx context.Context, session *melody.Session, action string, req stmt2PrepareRequest, logger *logrus.Entry, isDebug bool) { @@ -119,15 +120,15 @@ func (h *messageHandler) stmt2Prepare(ctx context.Context, session *melody.Sessi stmtItem.isInsert = isInsert prepareResp := &stmt2PrepareResponse{StmtID: req.StmtID, IsInsert: isInsert} if req.GetFields { - code, count, fields := syncinterface.TaosStmt2GetStbFields(stmt2, logger, isDebug) + code, count, fields := syncinterface.TaosStmt2GetFields(stmt2, logger, isDebug) if code != 0 { errStr := wrapper.TaosStmt2Error(stmt2) logger.Errorf("stmt2 get fields error, code:%d, err:%s", code, errStr) stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID) return } - defer wrapper.TaosStmt2FreeStbFields(stmt2, fields) - stbFields := wrapper.ParseStmt2StbFields(count, fields) + defer wrapper.TaosStmt2FreeFields(stmt2, fields) + stbFields := wrapper.Stmt2ParseAllFields(count, fields) prepareResp.Fields = stbFields prepareResp.FieldsCount = count diff --git a/controller/ws/ws/stmt2_test.go b/controller/ws/ws/stmt2_test.go index 8486ec40..d5764171 100644 --- a/controller/ws/ws/stmt2_test.go +++ b/controller/ws/ws/stmt2_test.go @@ -84,20 +84,6 @@ func TestWsStmt2(t *testing.T) { assert.Equal(t, 0, prepareResp.Code, prepareResp.Message) assert.True(t, prepareResp.IsInsert) assert.Equal(t, 18, len(prepareResp.Fields)) - var colFields []*stmtCommon.StmtField - var tagFields []*stmtCommon.StmtField - for i := 0; i < 18; i++ { - field := &stmtCommon.StmtField{ - FieldType: prepareResp.Fields[i].FieldType, - Precision: prepareResp.Fields[i].Precision, - } - switch prepareResp.Fields[i].BindType { - case stmtCommon.TAOS_FIELD_COL: - colFields = append(colFields, field) - case stmtCommon.TAOS_FIELD_TAG: - tagFields = append(tagFields, field) - } - } // bind now := time.Now() cols := [][]driver.Value{ @@ -141,7 +127,7 @@ func TestWsStmt2(t *testing.T) { Tags: tag, Cols: cols, } - bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, colFields, tagFields) + bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, prepareResp.Fields) assert.NoError(t, err) bindReq := make([]byte, len(bs)+30) // req_id @@ -475,7 +461,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) { }, }, } - b, err := stmtCommon.MarshalStmt2Binary(params, false, nil, nil) + b, err := stmtCommon.MarshalStmt2Binary(params, false, nil) assert.NoError(t, err) block.Write(b) @@ -702,21 +688,7 @@ func TestStmt2BindWithStbFields(t *testing.T) { Tags: tag, Cols: cols, } - var colFields []*stmtCommon.StmtField - var tagFields []*stmtCommon.StmtField - for i := 0; i < 18; i++ { - field := &stmtCommon.StmtField{ - FieldType: prepareResp.Fields[i].FieldType, - Precision: prepareResp.Fields[i].Precision, - } - switch prepareResp.Fields[i].BindType { - case stmtCommon.TAOS_FIELD_COL: - colFields = append(colFields, field) - case stmtCommon.TAOS_FIELD_TAG: - tagFields = append(tagFields, field) - } - } - bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, colFields, tagFields) + bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, prepareResp.Fields) assert.NoError(t, err) bindReq := make([]byte, len(bs)+30) // req_id diff --git a/db/syncinterface/wrapper.go b/db/syncinterface/wrapper.go index ae189f17..90c70dc3 100644 --- a/db/syncinterface/wrapper.go +++ b/db/syncinterface/wrapper.go @@ -362,13 +362,13 @@ func TaosStmt2IsInsert(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug bool) return isInsert, code } -func TaosStmt2GetFields(stmt2 unsafe.Pointer, fieldType int, logger *logrus.Entry, isDebug bool) (code, count int, fields unsafe.Pointer) { - logger.Tracef("call taos_stmt2_get_fields, stmt2:%p, fieldType:%d", stmt2, fieldType) +func TaosStmt2GetFields(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug bool) (code, count int, fields unsafe.Pointer) { + logger.Tracef("call taos_stmt2_get_fields, stmt2:%p", stmt2) s := log.GetLogNow(isDebug) thread.SyncLocker.Lock() logger.Debugf("get thread lock for taos_stmt2_get_fields cost:%s", log.GetLogDuration(isDebug, s)) s = log.GetLogNow(isDebug) - code, count, fields = wrapper.TaosStmt2GetFields(stmt2, fieldType) + code, count, fields = wrapper.TaosStmt2GetFields(stmt2) logger.Debugf("taos_stmt2_get_fields finish, code:%d, count:%d, fields:%p, cost:%s", code, count, fields, log.GetLogDuration(isDebug, s)) thread.SyncLocker.Unlock() return code, count, fields @@ -410,18 +410,6 @@ func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32, logger return err } -func TaosStmt2GetStbFields(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug bool) (code, count int, fields unsafe.Pointer) { - logger.Tracef("call taos_stmt2_get_stb_fields, stmt2:%p", stmt2) - s := log.GetLogNow(isDebug) - thread.SyncLocker.Lock() - logger.Debugf("get thread lock for taos_stmt2_get_stb_fields cost:%s", log.GetLogDuration(isDebug, s)) - s = log.GetLogNow(isDebug) - code, count, fields = wrapper.TaosStmt2GetStbFields(stmt2) - logger.Debugf("taos_stmt2_get_stb_fields finish, code:%d, count:%d, fields:%p, cost:%s", code, count, fields, log.GetLogDuration(isDebug, s)) - thread.SyncLocker.Unlock() - return code, count, fields -} - func TaosOptionsConnection(conn unsafe.Pointer, option int, value *string, logger *logrus.Entry, isDebug bool) int { if value == nil { logger.Tracef("call taos_options_connection, conn:%p, option:%d, value:", conn, option) diff --git a/db/syncinterface/wrapper_test.go b/db/syncinterface/wrapper_test.go index ddd35808..b24b3c09 100644 --- a/db/syncinterface/wrapper_test.go +++ b/db/syncinterface/wrapper_test.go @@ -509,16 +509,16 @@ func TestTaosStmt2(t *testing.T) { return } assert.True(t, isInsert) - code, count, fiels := TaosStmt2GetStbFields(stmt, logger, isDebug) + code, count, fields := TaosStmt2GetFields(stmt, logger, isDebug) if !assert.Equal(t, 0, code, wrapper.TaosStmtErrStr(stmt)) { return } assert.Equal(t, 4, count) - assert.NotNil(t, fiels) + assert.NotNil(t, fields) defer func() { - wrapper.TaosStmt2FreeFields(stmt, fiels) + wrapper.TaosStmt2FreeFields(stmt, fields) }() - fs := wrapper.ParseStmt2StbFields(count, fiels) + fs := wrapper.Stmt2ParseAllFields(count, fields) assert.Equal(t, 4, len(fs)) assert.Equal(t, "tbname", fs[0].Name) assert.Equal(t, int8(common.TSDB_DATA_TYPE_BINARY), fs[0].FieldType) @@ -537,45 +537,37 @@ func TestTaosStmt2(t *testing.T) { binds := &stmtCommon.TaosStmt2BindData{ TableName: tableName, } - bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, nil, nil) + bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, nil) assert.NoError(t, err) err = TaosStmt2BindBinary(stmt, bs, -1, logger, isDebug) assert.NoError(t, err) - code, num, fields := TaosStmt2GetFields(stmt, stmtCommon.TAOS_FIELD_COL, logger, isDebug) + code, num, fields2 := TaosStmt2GetFields(stmt, logger, isDebug) if !assert.Equal(t, 0, code, wrapper.TaosStmtErrStr(stmt)) { return } - assert.Equal(t, 2, num) + assert.Equal(t, 3, num) assert.NotNil(t, fields) defer func() { - wrapper.TaosStmt2FreeFields(stmt, fields) - }() - colFields := wrapper.StmtParseFields(num, fields) - assert.Equal(t, 2, len(colFields)) - assert.Equal(t, "ts", colFields[0].Name) - assert.Equal(t, int8(common.TSDB_DATA_TYPE_TIMESTAMP), colFields[0].FieldType) - assert.Equal(t, "v", colFields[1].Name) - assert.Equal(t, int8(common.TSDB_DATA_TYPE_INT), colFields[1].FieldType) - code, num, tags := TaosStmt2GetFields(stmt, stmtCommon.TAOS_FIELD_TAG, logger, isDebug) - if !assert.Equal(t, 0, code, wrapper.TaosStmtErrStr(stmt)) { - return - } - assert.Equal(t, 1, num) - assert.NotNil(t, tags) - defer func() { - wrapper.TaosStmt2FreeFields(stmt, tags) + wrapper.TaosStmt2FreeFields(stmt, fields2) }() - tagFields := wrapper.StmtParseFields(num, tags) - assert.Equal(t, 1, len(tagFields)) - assert.Equal(t, "id", tagFields[0].Name) - assert.Equal(t, int8(common.TSDB_DATA_TYPE_INT), tagFields[0].FieldType) - + fsAfterBindTableName := wrapper.Stmt2ParseAllFields(num, fields2) + assert.Equal(t, 3, len(fsAfterBindTableName)) + assert.Equal(t, "id", fsAfterBindTableName[0].Name) + assert.Equal(t, int8(common.TSDB_DATA_TYPE_INT), fsAfterBindTableName[0].FieldType) + assert.Equal(t, int8(stmtCommon.TAOS_FIELD_TAG), fsAfterBindTableName[0].BindType) + assert.Equal(t, "ts", fsAfterBindTableName[1].Name) + assert.Equal(t, int8(common.TSDB_DATA_TYPE_TIMESTAMP), fsAfterBindTableName[1].FieldType) + assert.Equal(t, int8(stmtCommon.TAOS_FIELD_COL), fsAfterBindTableName[1].BindType) + assert.Equal(t, uint8(common.PrecisionMilliSecond), fsAfterBindTableName[1].Precision) + assert.Equal(t, "v", fsAfterBindTableName[2].Name) + assert.Equal(t, int8(common.TSDB_DATA_TYPE_INT), fsAfterBindTableName[2].FieldType) + assert.Equal(t, int8(stmtCommon.TAOS_FIELD_COL), fsAfterBindTableName[2].BindType) binds = &stmtCommon.TaosStmt2BindData{ Tags: []driver.Value{int32(1)}, } - bs, err = stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, nil, tagFields) + bs, err = stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, fsAfterBindTableName[0:1]) assert.NoError(t, err) err = TaosStmt2BindBinary(stmt, bs, -1, logger, isDebug) assert.NoError(t, err) @@ -587,7 +579,7 @@ func TestTaosStmt2(t *testing.T) { {int32(100), int32(101)}, }, } - bs, err = stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, colFields, nil) + bs, err = stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, fsAfterBindTableName[1:]) assert.NoError(t, err) err = TaosStmt2BindBinary(stmt, bs, -1, logger, isDebug) assert.NoError(t, err) diff --git a/driver/common/stmt/stmt2.go b/driver/common/stmt/stmt2.go index d1a97696..fbd7086a 100644 --- a/driver/common/stmt/stmt2.go +++ b/driver/common/stmt/stmt2.go @@ -29,7 +29,16 @@ const ( BindDataIsNullOffset = BindDataNumOffset + 4 ) -func MarshalStmt2Binary(bindData []*TaosStmt2BindData, isInsert bool, colType, tagType []*StmtField) ([]byte, error) { +func MarshalStmt2Binary(bindData []*TaosStmt2BindData, isInsert bool, fields []*Stmt2AllField) ([]byte, error) { + var colType []*Stmt2AllField + var tagType []*Stmt2AllField + for i := 0; i < len(fields); i++ { + if fields[i].BindType == TAOS_FIELD_COL { + colType = append(colType, fields[i]) + } else if fields[i].BindType == TAOS_FIELD_TAG { + tagType = append(tagType, fields[i]) + } + } // count count := len(bindData) if count == 0 { @@ -215,7 +224,7 @@ func getBindDataHeaderLength(num int, needLength bool) int { return length } -func generateBindColData(data []driver.Value, colType *StmtField, tmpBuffer *bytes.Buffer) ([]byte, error) { +func generateBindColData(data []driver.Value, colType *Stmt2AllField, tmpBuffer *bytes.Buffer) ([]byte, error) { num := len(data) tmpBuffer.Reset() needLength := needLength(colType.FieldType) @@ -578,3 +587,12 @@ func needLength(colType int8) bool { } return false } + +type Stmt2AllField struct { + Name string `json:"name"` + FieldType int8 `json:"field_type"` + Precision uint8 `json:"precision"` + Scale uint8 `json:"scale"` + Bytes int32 `json:"bytes"` + BindType int8 `json:"bind_type"` +} diff --git a/driver/common/stmt/stmt2_test.go b/driver/common/stmt/stmt2_test.go index 957d592f..e26fa6f6 100644 --- a/driver/common/stmt/stmt2_test.go +++ b/driver/common/stmt/stmt2_test.go @@ -18,10 +18,11 @@ func TestMarshalBinary(t *testing.T) { largeTableName += "a" } type args struct { - t []*TaosStmt2BindData - isInsert bool - tagType []*StmtField - colType []*StmtField + t []*TaosStmt2BindData + isInsert bool + fieldType []*Stmt2AllField + //tagType []*StmtField + //colType []*StmtField } tests := []struct { name string @@ -43,9 +44,8 @@ func TestMarshalBinary(t *testing.T) { TableName: "test2", }, }, - isInsert: true, - tagType: nil, - colType: nil, + isInsert: true, + fieldType: nil, }, want: []byte{ // total Length @@ -84,9 +84,8 @@ func TestMarshalBinary(t *testing.T) { TableName: largeTableName, }, }, - isInsert: true, - tagType: nil, - colType: nil, + isInsert: true, + fieldType: nil, }, want: nil, wantErr: true, @@ -208,58 +207,73 @@ func TestMarshalBinary(t *testing.T) { }, }, isInsert: true, - tagType: []*StmtField{ + fieldType: []*Stmt2AllField{ { FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_BOOL, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_TINYINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_SMALLINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_INT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_BIGINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_FLOAT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_DOUBLE, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_UTINYINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_USMALLINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_UINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_UBIGINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_BINARY, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_NCHAR, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_VARBINARY, + BindType: TAOS_FIELD_TAG, }, }, - colType: nil, }, want: []byte{ // total Length @@ -898,106 +912,137 @@ func TestMarshalBinary(t *testing.T) { }, }, isInsert: true, - tagType: []*StmtField{ + fieldType: []*Stmt2AllField{ { FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_BOOL, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_TINYINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_SMALLINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_INT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_BIGINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_FLOAT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_DOUBLE, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_UTINYINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_USMALLINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_UINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_UBIGINT, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_BINARY, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_NCHAR, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + BindType: TAOS_FIELD_TAG, }, { FieldType: common.TSDB_DATA_TYPE_VARBINARY, + BindType: TAOS_FIELD_TAG, }, - }, - colType: []*StmtField{ + { FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_BOOL, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_TINYINT, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_SMALLINT, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_INT, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_BIGINT, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_FLOAT, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_DOUBLE, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_UTINYINT, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_USMALLINT, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_UINT, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_UBIGINT, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_BINARY, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_NCHAR, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_GEOMETRY, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_VARBINARY, + BindType: TAOS_FIELD_COL, }, }, }, @@ -1757,18 +1802,19 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{int32(3)}, }, }, - colType: []*StmtField{ + fieldType: []*Stmt2AllField{ { FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, + BindType: TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_BIGINT, + BindType: TAOS_FIELD_COL, }, - }, - tagType: []*StmtField{ { FieldType: common.TSDB_DATA_TYPE_INT, + BindType: TAOS_FIELD_TAG, }, }, isInsert: true, @@ -1884,10 +1930,9 @@ func TestMarshalBinary(t *testing.T) { { name: "empty", args: args{ - t: nil, - isInsert: false, - tagType: nil, - colType: nil, + t: nil, + isInsert: false, + fieldType: nil, }, want: nil, wantErr: true, @@ -1900,9 +1945,8 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{int32(1)}, }, }, - isInsert: true, - tagType: nil, - colType: nil, + isInsert: true, + fieldType: nil, }, want: nil, wantErr: true, @@ -1919,9 +1963,8 @@ func TestMarshalBinary(t *testing.T) { }, }, }, - isInsert: true, - tagType: nil, - colType: nil, + isInsert: true, + fieldType: nil, }, want: nil, wantErr: true, @@ -1939,10 +1982,12 @@ func TestMarshalBinary(t *testing.T) { }, }, isInsert: false, - tagType: []*StmtField{{ - FieldType: common.TSDB_DATA_TYPE_INT, - }}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_INT, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -1960,10 +2005,12 @@ func TestMarshalBinary(t *testing.T) { }, }, isInsert: false, - tagType: nil, - colType: []*StmtField{{ - FieldType: common.TSDB_DATA_TYPE_INT, - }}, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_INT, + BindType: TAOS_FIELD_COL, + }, + }, }, want: nil, wantErr: true, @@ -1987,9 +2034,8 @@ func TestMarshalBinary(t *testing.T) { }, }, }, - isInsert: false, - tagType: nil, - colType: nil, + isInsert: false, + fieldType: nil, }, want: nil, wantErr: true, @@ -2002,9 +2048,8 @@ func TestMarshalBinary(t *testing.T) { TableName: "table1", }, }, - isInsert: false, - tagType: nil, - colType: nil, + isInsert: false, + fieldType: nil, }, want: nil, wantErr: true, @@ -2017,9 +2062,8 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{int32(1)}, }, }, - isInsert: false, - tagType: nil, - colType: nil, + isInsert: false, + fieldType: nil, }, want: nil, wantErr: true, @@ -2030,9 +2074,8 @@ func TestMarshalBinary(t *testing.T) { t: []*TaosStmt2BindData{ {}, }, - isInsert: false, - tagType: nil, - colType: nil, + isInsert: false, + fieldType: nil, }, want: nil, wantErr: true, @@ -2050,9 +2093,8 @@ func TestMarshalBinary(t *testing.T) { }, }, }, - isInsert: false, - tagType: nil, - colType: nil, + isInsert: false, + fieldType: nil, }, want: nil, wantErr: true, @@ -2064,8 +2106,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{int32(1)}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_BOOL}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_BOOL, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2077,8 +2123,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_TINYINT}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_TINYINT, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2090,8 +2140,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_SMALLINT}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_SMALLINT, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2103,8 +2157,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_INT}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_INT, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2116,8 +2174,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_BIGINT}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_BIGINT, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2129,8 +2191,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_UTINYINT}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_UTINYINT, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2142,8 +2208,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_USMALLINT}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_USMALLINT, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2155,8 +2225,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_UINT}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_UINT, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2168,8 +2242,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_UBIGINT}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_UBIGINT, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2181,8 +2259,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_FLOAT}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_FLOAT, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2194,8 +2276,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_DOUBLE}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_DOUBLE, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2207,8 +2293,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_BINARY}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_BINARY, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2220,8 +2310,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{true}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_TIMESTAMP}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: nil, wantErr: true, @@ -2240,8 +2334,12 @@ func TestMarshalBinary(t *testing.T) { }, }, isInsert: true, - tagType: nil, - colType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_TIMESTAMP}}, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + BindType: TAOS_FIELD_COL, + }, + }, }, want: []byte{ // total Length @@ -2288,9 +2386,8 @@ func TestMarshalBinary(t *testing.T) { {false}, }, }}, - isInsert: false, - tagType: nil, - colType: nil, + isInsert: false, + fieldType: nil, }, want: []byte{ // total Length @@ -2336,9 +2433,8 @@ func TestMarshalBinary(t *testing.T) { {customInt(1)}, }, }}, - isInsert: false, - tagType: nil, - colType: nil, + isInsert: false, + fieldType: nil, }, want: nil, wantErr: true, @@ -2352,8 +2448,12 @@ func TestMarshalBinary(t *testing.T) { }, }}, isInsert: true, - tagType: nil, - colType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_NULL}}, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_NULL, + BindType: TAOS_FIELD_COL, + }, + }, }, want: nil, wantErr: true, @@ -2366,9 +2466,8 @@ func TestMarshalBinary(t *testing.T) { Cols: nil, }, }, - isInsert: true, - tagType: nil, - colType: []*StmtField{}, + isInsert: true, + fieldType: nil, }, want: nil, wantErr: true, @@ -2380,8 +2479,12 @@ func TestMarshalBinary(t *testing.T) { Tags: []driver.Value{int64(1726803356466)}, }}, isInsert: true, - tagType: []*StmtField{{FieldType: common.TSDB_DATA_TYPE_TIMESTAMP}}, - colType: nil, + fieldType: []*Stmt2AllField{ + { + FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + BindType: TAOS_FIELD_TAG, + }, + }, }, want: []byte{ // total Length @@ -2422,7 +2525,7 @@ func TestMarshalBinary(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := MarshalStmt2Binary(tt.args.t, tt.args.isInsert, tt.args.colType, tt.args.tagType) + got, err := MarshalStmt2Binary(tt.args.t, tt.args.isInsert, tt.args.fieldType) if (err != nil) != tt.wantErr { t.Errorf("MarshalStmt2Binary() error = %v, wantErr %v", err, tt.wantErr) return @@ -2431,7 +2534,3 @@ func TestMarshalBinary(t *testing.T) { }) } } - -func TestT(t *testing.T) { - -} diff --git a/driver/wrapper/stmt2.go b/driver/wrapper/stmt2.go index 6db59f81..26871dab 100644 --- a/driver/wrapper/stmt2.go +++ b/driver/wrapper/stmt2.go @@ -35,15 +35,24 @@ func TaosStmt2Init(taosConnect unsafe.Pointer, reqID int64, singleStbInsert bool } // TaosStmt2Prepare int taos_stmt2_prepare(TAOS_STMT2 *stmt, const char *sql, unsigned long length); -func TaosStmt2Prepare(stmt unsafe.Pointer, sql string) int { +func TaosStmt2Prepare(stmt2 unsafe.Pointer, sql string) int { cSql := C.CString(sql) cLen := C.ulong(len(sql)) defer C.free(unsafe.Pointer(cSql)) - return int(C.taos_stmt2_prepare(stmt, cSql, cLen)) + return int(C.taos_stmt2_prepare(stmt2, cSql, cLen)) } // TaosStmt2BindParam int taos_stmt2_bind_param(TAOS_STMT2 *stmt, TAOS_STMT2_BINDV *bindv, int32_t col_idx); -func TaosStmt2BindParam(stmt unsafe.Pointer, isInsert bool, params []*stmt.TaosStmt2BindData, colTypes, tagTypes []*stmt.StmtField, colIdx int32) error { +func TaosStmt2BindParam(stmt2 unsafe.Pointer, isInsert bool, params []*stmt.TaosStmt2BindData, fields []*stmt.Stmt2AllField, colIdx int32) error { + var colTypes []*stmt.Stmt2AllField + var tagTypes []*stmt.Stmt2AllField + for i := 0; i < len(fields); i++ { + if fields[i].BindType == stmt.TAOS_FIELD_COL { + colTypes = append(colTypes, fields[i]) + } else if fields[i].BindType == stmt.TAOS_FIELD_TAG { + tagTypes = append(tagTypes, fields[i]) + } + } count := len(params) if count == 0 { return taosError.NewError(0xffff, "params is empty") @@ -122,15 +131,15 @@ func TaosStmt2BindParam(stmt unsafe.Pointer, isInsert bool, params []*stmt.TaosS cBindv.bind_cols = (**C.TAOS_STMT2_BIND)(unsafe.Pointer(colList)) cBindv.tags = (**C.TAOS_STMT2_BIND)(unsafe.Pointer(tagList)) cBindv.tbnames = (**C.char)(tbNames) - code := int(C.taos_stmt2_bind_param(stmt, &cBindv, C.int32_t(colIdx))) + code := int(C.taos_stmt2_bind_param(stmt2, &cBindv, C.int32_t(colIdx))) if code != 0 { - errStr := TaosStmt2Error(stmt) + errStr := TaosStmt2Error(stmt2) return taosError.NewError(code, errStr) } return nil } -func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt.StmtField) (unsafe.Pointer, []unsafe.Pointer, error) { +func generateTaosStmt2BindsInsert(multiBind [][]driver.Value, fieldTypes []*stmt.Stmt2AllField) (unsafe.Pointer, []unsafe.Pointer, error) { var needFreePointer []unsafe.Pointer if len(multiBind) != len(fieldTypes) { return nil, needFreePointer, fmt.Errorf("data and type length not match, data length: %d, type length: %d", len(multiBind), len(fieldTypes)) @@ -626,84 +635,64 @@ func generateTaosStmt2BindsQuery(multiBind [][]driver.Value) (unsafe.Pointer, [] } // TaosStmt2Exec int taos_stmt2_exec(TAOS_STMT2 *stmt, int *affected_rows); -func TaosStmt2Exec(stmt unsafe.Pointer) int { - return int(C.taos_stmt2_exec(stmt, nil)) +func TaosStmt2Exec(stmt2 unsafe.Pointer) int { + return int(C.taos_stmt2_exec(stmt2, nil)) } // TaosStmt2Close int taos_stmt2_close(TAOS_STMT2 *stmt); -func TaosStmt2Close(stmt unsafe.Pointer) int { - return int(C.taos_stmt2_close(stmt)) +func TaosStmt2Close(stmt2 unsafe.Pointer) int { + return int(C.taos_stmt2_close(stmt2)) } // TaosStmt2IsInsert int taos_stmt2_is_insert(TAOS_STMT2 *stmt, int *insert); -func TaosStmt2IsInsert(stmt unsafe.Pointer) (is bool, errorCode int) { +func TaosStmt2IsInsert(stmt2 unsafe.Pointer) (is bool, errorCode int) { p := C.malloc(C.size_t(4)) isInsert := (*C.int)(p) defer C.free(p) - errorCode = int(C.taos_stmt2_is_insert(stmt, isInsert)) + errorCode = int(C.taos_stmt2_is_insert(stmt2, isInsert)) return int(*isInsert) == 1, errorCode } -// TaosStmt2GetFields int taos_stmt2_get_fields(TAOS_STMT2 *stmt, TAOS_FIELD_T field_type, int *count, TAOS_FIELD_E **fields); -func TaosStmt2GetFields(stmt unsafe.Pointer, fieldType int) (code, count int, fields unsafe.Pointer) { - code = int(C.taos_stmt2_get_fields(stmt, C.TAOS_FIELD_T(fieldType), (*C.int)(unsafe.Pointer(&count)), (**C.TAOS_FIELD_E)(unsafe.Pointer(&fields)))) - return -} - -// TaosStmt2FreeFields void taos_stmt2_free_fields(TAOS_STMT2 *stmt, TAOS_FIELD_E *fields); -func TaosStmt2FreeFields(stmt unsafe.Pointer, fields unsafe.Pointer) { +// TaosStmt2FreeFields void taos_stmt2_free_fields(TAOS_STMT2 *stmt, TAOS_FIELD_ALL *fields); +func TaosStmt2FreeFields(stmt2 unsafe.Pointer, fields unsafe.Pointer) { if fields == nil { return } - C.taos_stmt2_free_fields(stmt, (*C.TAOS_FIELD_E)(fields)) + C.taos_stmt2_free_fields(stmt2, (*C.TAOS_FIELD_ALL)(fields)) } // TaosStmt2Error char *taos_stmt2_error(TAOS_STMT2 *stmt) -func TaosStmt2Error(stmt unsafe.Pointer) string { - return C.GoString(C.taos_stmt2_error(stmt)) +func TaosStmt2Error(stmt2 unsafe.Pointer) string { + return C.GoString(C.taos_stmt2_error(stmt2)) } -// TaosStmt2GetStbFields int taos_stmt2_get_stb_fields(TAOS_STMT2 *stmt, int *count, TAOS_FIELD_STB **fields); -func TaosStmt2GetStbFields(stmt unsafe.Pointer) (code, count int, fields unsafe.Pointer) { - code = int(C.taos_stmt2_get_stb_fields(stmt, (*C.int)(unsafe.Pointer(&count)), (**C.TAOS_FIELD_STB)(unsafe.Pointer(&fields)))) +// TaosStmt2GetFields int taos_stmt2_get_fields(TAOS_STMT2 *stmt, int *count, TAOS_FIELD_ALL **fields); +func TaosStmt2GetFields(stmt2 unsafe.Pointer) (code, count int, fields unsafe.Pointer) { + code = int(C.taos_stmt2_get_fields(stmt2, (*C.int)(unsafe.Pointer(&count)), (**C.TAOS_FIELD_ALL)(unsafe.Pointer(&fields)))) return } -// TaosStmt2FreeStbFields void taos_stmt2_free_stb_fields(TAOS_STMT2 *stmt, TAOS_FIELD_STB *fields); -func TaosStmt2FreeStbFields(stmt unsafe.Pointer, fields unsafe.Pointer) { - C.taos_stmt2_free_stb_fields(stmt, (*C.TAOS_FIELD_STB)(fields)) -} - -//typedef struct TAOS_FIELD_STB { +//typedef struct TAOS_FIELD_ALL { //char name[65]; //int8_t type; //uint8_t precision; //uint8_t scale; //int32_t bytes; //TAOS_FIELD_T field_type; -//} TAOS_FIELD_STB; - -type StmtStbField struct { - Name string `json:"name"` - FieldType int8 `json:"field_type"` - Precision uint8 `json:"precision"` - Scale uint8 `json:"scale"` - Bytes int32 `json:"bytes"` - BindType int8 `json:"bind_type"` -} +//} TAOS_FIELD_ALL; -func ParseStmt2StbFields(num int, fields unsafe.Pointer) []*StmtStbField { +func Stmt2ParseAllFields(num int, fields unsafe.Pointer) []*stmt.Stmt2AllField { if num <= 0 { return nil } if fields == nil { return nil } - result := make([]*StmtStbField, num) + result := make([]*stmt.Stmt2AllField, num) buf := bytes.NewBufferString("") for i := 0; i < num; i++ { - r := &StmtStbField{} - field := *(*C.TAOS_FIELD_STB)(unsafe.Pointer(uintptr(fields) + uintptr(C.sizeof_struct_TAOS_FIELD_STB*C.int(i)))) + r := &stmt.Stmt2AllField{} + field := *(*C.TAOS_FIELD_ALL)(unsafe.Pointer(uintptr(fields) + uintptr(C.sizeof_struct_TAOS_FIELD_ALL*C.int(i)))) for _, c := range field.name { if c == 0 { break diff --git a/driver/wrapper/stmt2_test.go b/driver/wrapper/stmt2_test.go index d9652b12..474d5539 100644 --- a/driver/wrapper/stmt2_test.go +++ b/driver/wrapper/stmt2_test.go @@ -1193,17 +1193,18 @@ func TestStmt2BindData(t *testing.T) { return } assert.True(t, isInsert) - code, count, cfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) + code, count, cfields := TaosStmt2GetFields(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) err = taosError.NewError(code, errStr) t.Error(err) return } + defer TaosStmt2FreeFields(insertStmt, cfields) assert.Equal(t, 2, count) - fields := StmtParseFields(count, cfields) - err = TaosStmt2BindParam(insertStmt, true, tc.params, fields, nil, -1) + fields := Stmt2ParseAllFields(count, cfields) + err = TaosStmt2BindParam(insertStmt, true, tc.params, fields, -1) if err != nil { t.Error(err) return @@ -2367,7 +2368,7 @@ func TestStmt2BindBinary(t *testing.T) { return } assert.True(t, isInsert) - code, count, cfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) + code, count, cfields := TaosStmt2GetFields(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) err = taosError.NewError(code, errStr) @@ -2376,8 +2377,8 @@ func TestStmt2BindBinary(t *testing.T) { } defer TaosStmt2FreeFields(insertStmt, cfields) assert.Equal(t, 2, count) - fields := StmtParseFields(count, cfields) - bs, err := stmt.MarshalStmt2Binary(tc.params, true, fields, nil) + fields := Stmt2ParseAllFields(count, cfields) + bs, err := stmt.MarshalStmt2Binary(tc.params, true, fields) if err != nil { t.Error("marshal binary error:", err) return @@ -2496,22 +2497,12 @@ func TestStmt2AllType(t *testing.T) { params := []*stmt.TaosStmt2BindData{{ TableName: "ctb1", }} - err = TaosStmt2BindParam(insertStmt, true, params, nil, nil, -1) + err = TaosStmt2BindParam(insertStmt, true, params, nil, -1) if err != nil { t.Error(err) return } - code, count, cTablefields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TBNAME) - if code != 0 { - errStr := TaosStmt2Error(insertStmt) - err = taosError.NewError(code, errStr) - t.Error(err) - return - } - assert.Equal(t, 1, count) - assert.Equal(t, unsafe.Pointer(nil), cTablefields) - isInsert, code := TaosStmt2IsInsert(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) @@ -2520,28 +2511,17 @@ func TestStmt2AllType(t *testing.T) { return } assert.True(t, isInsert) - code, count, cColFields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) + code, count, cFields := TaosStmt2GetFields(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) err = taosError.NewError(code, errStr) t.Error(err) return } - defer TaosStmt2FreeFields(insertStmt, cColFields) - assert.Equal(t, 16, count) - colFields := StmtParseFields(count, cColFields) - t.Log(colFields) - code, count, cTagfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TAG) - if code != 0 { - errStr := TaosStmt2Error(insertStmt) - err = taosError.NewError(code, errStr) - t.Error(err) - return - } - defer TaosStmt2FreeFields(insertStmt, cTagfields) - assert.Equal(t, 16, count) - tagFields := StmtParseFields(count, cTagfields) - t.Log(tagFields) + defer TaosStmt2FreeFields(insertStmt, cFields) + assert.Equal(t, 32, count) + fields := Stmt2ParseAllFields(count, cFields) + t.Log(fields) now := time.Now() //colTypes := []int8{ // common.TSDB_DATA_TYPE_TIMESTAMP, @@ -2681,7 +2661,7 @@ func TestStmt2AllType(t *testing.T) { }, }} - err = TaosStmt2BindParam(insertStmt, true, params2, colFields, tagFields, -1) + err = TaosStmt2BindParam(insertStmt, true, params2, fields, -1) if err != nil { t.Error(err) return @@ -2787,7 +2767,7 @@ func TestStmt2AllTypeBytes(t *testing.T) { params := []*stmt.TaosStmt2BindData{{ TableName: "ctb1", }} - bs, err := stmt.MarshalStmt2Binary(params, true, nil, nil) + bs, err := stmt.MarshalStmt2Binary(params, true, nil) if err != nil { t.Error(err) return @@ -2798,16 +2778,6 @@ func TestStmt2AllTypeBytes(t *testing.T) { return } - code, count, cTablefields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TBNAME) - if code != 0 { - errStr := TaosStmt2Error(insertStmt) - err = taosError.NewError(code, errStr) - t.Error(err) - return - } - assert.Equal(t, 1, count) - assert.Equal(t, unsafe.Pointer(nil), cTablefields) - isInsert, code := TaosStmt2IsInsert(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) @@ -2816,28 +2786,18 @@ func TestStmt2AllTypeBytes(t *testing.T) { return } assert.True(t, isInsert) - code, count, cColFields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_COL) - if code != 0 { - errStr := TaosStmt2Error(insertStmt) - err = taosError.NewError(code, errStr) - t.Error(err) - return - } - defer TaosStmt2FreeFields(insertStmt, cColFields) - assert.Equal(t, 16, count) - colFields := StmtParseFields(count, cColFields) - t.Log(colFields) - code, count, cTagfields := TaosStmt2GetFields(insertStmt, stmt.TAOS_FIELD_TAG) + + code, count, cFields := TaosStmt2GetFields(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) err = taosError.NewError(code, errStr) t.Error(err) return } - defer TaosStmt2FreeFields(insertStmt, cTagfields) - assert.Equal(t, 16, count) - tagFields := StmtParseFields(count, cTagfields) - t.Log(tagFields) + defer TaosStmt2FreeFields(insertStmt, cFields) + assert.Equal(t, 32, count) + fields := Stmt2ParseAllFields(count, cFields) + t.Log(fields) now := time.Now() //colTypes := []int8{ // common.TSDB_DATA_TYPE_TIMESTAMP, @@ -2976,7 +2936,7 @@ func TestStmt2AllTypeBytes(t *testing.T) { }, }, }} - bs, err = stmt.MarshalStmt2Binary(params2, true, colFields, tagFields) + bs, err = stmt.MarshalStmt2Binary(params2, true, fields) if err != nil { t.Error(err) return @@ -3060,13 +3020,15 @@ func TestStmt2Query(t *testing.T) { } assert.True(t, isInsert) now := time.Now().Round(time.Millisecond) - colTypes := []*stmt.StmtField{ + colTypes := []*stmt.Stmt2AllField{ { FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, + BindType: stmt.TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_INT, + BindType: stmt.TAOS_FIELD_COL, }, } params := []*stmt.TaosStmt2BindData{ @@ -3097,7 +3059,7 @@ func TestStmt2Query(t *testing.T) { }, }, } - err = TaosStmt2BindParam(stmt2, true, params, colTypes, nil, -1) + err = TaosStmt2BindParam(stmt2, true, params, colTypes, -1) if err != nil { t.Error(err) return @@ -3145,7 +3107,7 @@ func TestStmt2Query(t *testing.T) { }, } - err = TaosStmt2BindParam(stmt2, false, params, nil, nil, -1) + err = TaosStmt2BindParam(stmt2, false, params, nil, -1) if err != nil { t.Error(err) return @@ -3254,13 +3216,15 @@ func TestStmt2QueryBytes(t *testing.T) { } assert.True(t, isInsert) now := time.Now().Round(time.Millisecond) - colTypes := []*stmt.StmtField{ + colTypes := []*stmt.Stmt2AllField{ { FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, + BindType: stmt.TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_INT, + BindType: stmt.TAOS_FIELD_COL, }, } params := []*stmt.TaosStmt2BindData{ @@ -3291,7 +3255,7 @@ func TestStmt2QueryBytes(t *testing.T) { }, }, } - bs, err := stmt.MarshalStmt2Binary(params, true, colTypes, nil) + bs, err := stmt.MarshalStmt2Binary(params, true, colTypes) if err != nil { t.Error(err) return @@ -3343,7 +3307,7 @@ func TestStmt2QueryBytes(t *testing.T) { }, }, } - bs, err = stmt.MarshalStmt2Binary(params, false, nil, nil) + bs, err = stmt.MarshalStmt2Binary(params, false, nil) if err != nil { t.Error(err) return @@ -3458,23 +3422,23 @@ func TestStmt2QueryAllType(t *testing.T) { handler := cgo.NewHandle(caller) stmt2 := TaosStmt2Init(conn, 0xcc123, false, false, handler) prepareInsertSql := "insert into t values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" - colTypes := []*stmt.StmtField{ - {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond}, - {FieldType: common.TSDB_DATA_TYPE_BOOL}, - {FieldType: common.TSDB_DATA_TYPE_TINYINT}, - {FieldType: common.TSDB_DATA_TYPE_SMALLINT}, - {FieldType: common.TSDB_DATA_TYPE_INT}, - {FieldType: common.TSDB_DATA_TYPE_BIGINT}, - {FieldType: common.TSDB_DATA_TYPE_UTINYINT}, - {FieldType: common.TSDB_DATA_TYPE_USMALLINT}, - {FieldType: common.TSDB_DATA_TYPE_UINT}, - {FieldType: common.TSDB_DATA_TYPE_UBIGINT}, - {FieldType: common.TSDB_DATA_TYPE_FLOAT}, - {FieldType: common.TSDB_DATA_TYPE_DOUBLE}, - {FieldType: common.TSDB_DATA_TYPE_BINARY}, - {FieldType: common.TSDB_DATA_TYPE_VARBINARY}, - {FieldType: common.TSDB_DATA_TYPE_GEOMETRY}, - {FieldType: common.TSDB_DATA_TYPE_NCHAR}, + colTypes := []*stmt.Stmt2AllField{ + {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BOOL, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_TINYINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_SMALLINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_INT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BIGINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UTINYINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_USMALLINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UBIGINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_FLOAT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_DOUBLE, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BINARY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_VARBINARY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_GEOMETRY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_NCHAR, BindType: stmt.TAOS_FIELD_COL}, } now := time.Now() @@ -3578,7 +3542,7 @@ func TestStmt2QueryAllType(t *testing.T) { return } assert.True(t, isInsert) - err = TaosStmt2BindParam(stmt2, true, params2, colTypes, nil, -1) + err = TaosStmt2BindParam(stmt2, true, params2, colTypes, -1) if err != nil { t.Error(err) return @@ -3636,7 +3600,7 @@ func TestStmt2QueryAllType(t *testing.T) { }, }, } - err = TaosStmt2BindParam(stmt2, false, params, nil, nil, -1) + err = TaosStmt2BindParam(stmt2, false, params, nil, -1) if err != nil { t.Error(err) return @@ -3732,23 +3696,23 @@ func TestStmt2QueryAllTypeBytes(t *testing.T) { handler := cgo.NewHandle(caller) stmt2 := TaosStmt2Init(conn, 0xcc123, false, false, handler) prepareInsertSql := "insert into t values (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" - colTypes := []*stmt.StmtField{ - {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond}, - {FieldType: common.TSDB_DATA_TYPE_BOOL}, - {FieldType: common.TSDB_DATA_TYPE_TINYINT}, - {FieldType: common.TSDB_DATA_TYPE_SMALLINT}, - {FieldType: common.TSDB_DATA_TYPE_INT}, - {FieldType: common.TSDB_DATA_TYPE_BIGINT}, - {FieldType: common.TSDB_DATA_TYPE_UTINYINT}, - {FieldType: common.TSDB_DATA_TYPE_USMALLINT}, - {FieldType: common.TSDB_DATA_TYPE_UINT}, - {FieldType: common.TSDB_DATA_TYPE_UBIGINT}, - {FieldType: common.TSDB_DATA_TYPE_FLOAT}, - {FieldType: common.TSDB_DATA_TYPE_DOUBLE}, - {FieldType: common.TSDB_DATA_TYPE_BINARY}, - {FieldType: common.TSDB_DATA_TYPE_VARBINARY}, - {FieldType: common.TSDB_DATA_TYPE_GEOMETRY}, - {FieldType: common.TSDB_DATA_TYPE_NCHAR}, + colTypes := []*stmt.Stmt2AllField{ + {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BOOL, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_TINYINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_SMALLINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_INT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BIGINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UTINYINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_USMALLINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_UBIGINT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_FLOAT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_DOUBLE, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_BINARY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_VARBINARY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_GEOMETRY, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_NCHAR, BindType: stmt.TAOS_FIELD_COL}, } now := time.Now() @@ -3852,7 +3816,7 @@ func TestStmt2QueryAllTypeBytes(t *testing.T) { return } assert.True(t, isInsert) - bs, err := stmt.MarshalStmt2Binary(params2, true, colTypes, nil) + bs, err := stmt.MarshalStmt2Binary(params2, true, colTypes) if err != nil { t.Error(err) return @@ -3915,7 +3879,7 @@ func TestStmt2QueryAllTypeBytes(t *testing.T) { }, }, } - bs, err = stmt.MarshalStmt2Binary(params, false, nil, nil) + bs, err = stmt.MarshalStmt2Binary(params, false, nil) if err != nil { t.Error(err) return @@ -4024,14 +3988,12 @@ func TestStmt2Json(t *testing.T) { {int32(1)}, }, }} - colTypes := []*stmt.StmtField{ - {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond}, - {FieldType: common.TSDB_DATA_TYPE_INT}, - } - tagTypes := []*stmt.StmtField{ - {FieldType: common.TSDB_DATA_TYPE_JSON}, + types := []*stmt.Stmt2AllField{ + {FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_INT, BindType: stmt.TAOS_FIELD_COL}, + {FieldType: common.TSDB_DATA_TYPE_JSON, BindType: stmt.TAOS_FIELD_TAG}, } - err = TaosStmt2BindParam(stmt2, true, params, colTypes, tagTypes, -1) + err = TaosStmt2BindParam(stmt2, true, params, types, -1) if err != nil { t.Error(err) return @@ -4058,7 +4020,7 @@ func TestStmt2Json(t *testing.T) { {int32(1)}, }, }} - err = TaosStmt2BindParam(stmt2, false, params, nil, nil, -1) + err = TaosStmt2BindParam(stmt2, false, params, nil, -1) if err != nil { t.Error(err) return @@ -4185,21 +4147,21 @@ func TestStmt2BindMultiTables(t *testing.T) { Tags: []driver.Value{int32(3)}, }, } - colType := []*stmt.StmtField{ + fields := []*stmt.Stmt2AllField{ { FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, Precision: common.PrecisionMilliSecond, + BindType: stmt.TAOS_FIELD_COL, }, { FieldType: common.TSDB_DATA_TYPE_BIGINT, + BindType: stmt.TAOS_FIELD_COL, }, - } - tagType := []*stmt.StmtField{ { FieldType: common.TSDB_DATA_TYPE_INT, + BindType: stmt.TAOS_FIELD_TAG, }, } - isInsert, code := TaosStmt2IsInsert(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) @@ -4209,7 +4171,7 @@ func TestStmt2BindMultiTables(t *testing.T) { } assert.True(t, isInsert) - err = TaosStmt2BindParam(insertStmt, true, binds, colType, tagType, -1) + err = TaosStmt2BindParam(insertStmt, true, binds, fields, -1) if err != nil { t.Error(err) return @@ -5136,7 +5098,7 @@ func TestTaosStmt2GetStbFields(t *testing.T) { t.Error(err) return } - expectMap := map[string]*StmtStbField{ + expectMap := map[string]*stmt.Stmt2AllField{ "tts": { Name: "tts", FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, @@ -5381,16 +5343,16 @@ func TestTaosStmt2GetStbFields(t *testing.T) { t.Error(err) return } - code, count, fields := TaosStmt2GetStbFields(stmt2) + code, count, fields := TaosStmt2GetFields(stmt2) if code != 0 { errStr := TaosStmt2Error(stmt2) err := taosError.NewError(code, errStr) t.Error(err) return } - fs := ParseStmt2StbFields(count, fields) - TaosStmt2FreeStbFields(stmt2, fields) - expect := make([]*StmtStbField, len(tt.expect)) + fs := Stmt2ParseAllFields(count, fields) + TaosStmt2FreeFields(stmt2, fields) + expect := make([]*stmt.Stmt2AllField, len(tt.expect)) for i := 0; i < len(tt.expect); i++ { assert.Equal(t, expectMap[tt.expect[i]].Name, fs[i].Name) assert.Equal(t, expectMap[tt.expect[i]].FieldType, fs[i].FieldType) @@ -5414,21 +5376,21 @@ func TestTaosStmt2GetStbFields(t *testing.T) { t.Error(err) return } - code, count, fields := TaosStmt2GetStbFields(stmt2) + code, count, fields := TaosStmt2GetFields(stmt2) if code != 0 { errStr := TaosStmt2Error(stmt2) err := taosError.NewError(code, errStr) t.Error(err) return } - TaosStmt2FreeStbFields(stmt2, fields) + TaosStmt2FreeFields(stmt2, fields) assert.Equal(t, 2, count) } func TestWrongParseStmt2StbFields(t *testing.T) { - fs := ParseStmt2StbFields(0, nil) + fs := Stmt2ParseAllFields(0, nil) assert.Nil(t, fs) - fs = ParseStmt2StbFields(2, nil) + fs = Stmt2ParseAllFields(2, nil) assert.Nil(t, fs) } @@ -5515,18 +5477,18 @@ func TestStmt2BindTbnameAsValue(t *testing.T) { } assert.True(t, isInsert) - code, count, cFields := TaosStmt2GetStbFields(insertStmt) + code, count, cFields := TaosStmt2GetFields(insertStmt) if code != 0 { errStr := TaosStmt2Error(insertStmt) err = taosError.NewError(code, errStr) t.Error(err) return } - defer TaosStmt2FreeStbFields(insertStmt, cFields) + defer TaosStmt2FreeFields(insertStmt, cFields) assert.Equal(t, 33, count) - fields := ParseStmt2StbFields(count, cFields) + fields := Stmt2ParseAllFields(count, cFields) assert.Equal(t, 33, len(fields)) - expectMap := map[string]*StmtStbField{ + expectMap := map[string]*stmt.Stmt2AllField{ "tts": { Name: "tts", FieldType: common.TSDB_DATA_TYPE_TIMESTAMP, @@ -5728,19 +5690,10 @@ func TestStmt2BindTbnameAsValue(t *testing.T) { BindType: stmt.TAOS_FIELD_TBNAME, }, } - var colFields, tagFields []*stmt.StmtField + for i := 0; i < 33; i++ { expect := expectMap[fields[i].Name] assert.Equal(t, expect, fields[i]) - field := &stmt.StmtField{ - FieldType: fields[i].FieldType, - Precision: fields[i].Precision, - } - if fields[i].BindType == stmt.TAOS_FIELD_COL { - colFields = append(colFields, field) - } else if fields[i].BindType == stmt.TAOS_FIELD_TAG { - tagFields = append(tagFields, field) - } } now := time.Now() @@ -5863,7 +5816,7 @@ func TestStmt2BindTbnameAsValue(t *testing.T) { }, }, }} - bs, err := stmt.MarshalStmt2Binary(params2, true, colFields, tagFields) + bs, err := stmt.MarshalStmt2Binary(params2, true, fields) assert.NoError(t, err) err = TaosStmt2BindBinary(insertStmt, bs, -1) if err != nil {