Skip to content

Commit

Permalink
feat: rename taos_stmt2_get_stb_fields to taos_stmt2_get_fields
Browse files Browse the repository at this point in the history
  • Loading branch information
huskar-t committed Dec 16, 2024
1 parent c0dd734 commit dfa4db2
Show file tree
Hide file tree
Showing 8 changed files with 388 additions and 376 deletions.
25 changes: 13 additions & 12 deletions controller/ws/ws/stmt2.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/taosdata/taosadapter/v3/controller/ws/wstool"
"github.com/taosdata/taosadapter/v3/db/async"
"github.com/taosdata/taosadapter/v3/db/syncinterface"
"github.com/taosdata/taosadapter/v3/driver/common/stmt"
errors2 "github.com/taosdata/taosadapter/v3/driver/errors"
"github.com/taosdata/taosadapter/v3/driver/wrapper"
"github.com/taosdata/taosadapter/v3/log"
Expand Down Expand Up @@ -81,15 +82,15 @@ type stmt2PrepareRequest struct {
}

type stmt2PrepareResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Action string `json:"action"`
ReqID uint64 `json:"req_id"`
Timing int64 `json:"timing"`
StmtID uint64 `json:"stmt_id"`
IsInsert bool `json:"is_insert"`
Fields []*wrapper.StmtStbField `json:"fields"`
FieldsCount int `json:"fields_count"`
Code int `json:"code"`
Message string `json:"message"`
Action string `json:"action"`
ReqID uint64 `json:"req_id"`
Timing int64 `json:"timing"`
StmtID uint64 `json:"stmt_id"`
IsInsert bool `json:"is_insert"`
Fields []*stmt.Stmt2AllField `json:"fields"`
FieldsCount int `json:"fields_count"`
}

func (h *messageHandler) stmt2Prepare(ctx context.Context, session *melody.Session, action string, req stmt2PrepareRequest, logger *logrus.Entry, isDebug bool) {
Expand Down Expand Up @@ -119,15 +120,15 @@ func (h *messageHandler) stmt2Prepare(ctx context.Context, session *melody.Sessi
stmtItem.isInsert = isInsert
prepareResp := &stmt2PrepareResponse{StmtID: req.StmtID, IsInsert: isInsert}
if req.GetFields {
code, count, fields := syncinterface.TaosStmt2GetStbFields(stmt2, logger, isDebug)
code, count, fields := syncinterface.TaosStmt2GetFields(stmt2, logger, isDebug)
if code != 0 {
errStr := wrapper.TaosStmt2Error(stmt2)
logger.Errorf("stmt2 get fields error, code:%d, err:%s", code, errStr)
stmtErrorResponse(ctx, session, logger, action, req.ReqID, code, errStr, req.StmtID)
return
}
defer wrapper.TaosStmt2FreeStbFields(stmt2, fields)
stbFields := wrapper.ParseStmt2StbFields(count, fields)
defer wrapper.TaosStmt2FreeFields(stmt2, fields)
stbFields := wrapper.Stmt2ParseAllFields(count, fields)
prepareResp.Fields = stbFields
prepareResp.FieldsCount = count

Expand Down
34 changes: 3 additions & 31 deletions controller/ws/ws/stmt2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,6 @@ func TestWsStmt2(t *testing.T) {
assert.Equal(t, 0, prepareResp.Code, prepareResp.Message)
assert.True(t, prepareResp.IsInsert)
assert.Equal(t, 18, len(prepareResp.Fields))
var colFields []*stmtCommon.StmtField
var tagFields []*stmtCommon.StmtField
for i := 0; i < 18; i++ {
field := &stmtCommon.StmtField{
FieldType: prepareResp.Fields[i].FieldType,
Precision: prepareResp.Fields[i].Precision,
}
switch prepareResp.Fields[i].BindType {
case stmtCommon.TAOS_FIELD_COL:
colFields = append(colFields, field)
case stmtCommon.TAOS_FIELD_TAG:
tagFields = append(tagFields, field)
}
}
// bind
now := time.Now()
cols := [][]driver.Value{
Expand Down Expand Up @@ -141,7 +127,7 @@ func TestWsStmt2(t *testing.T) {
Tags: tag,
Cols: cols,
}
bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, colFields, tagFields)
bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, prepareResp.Fields)
assert.NoError(t, err)
bindReq := make([]byte, len(bs)+30)
// req_id
Expand Down Expand Up @@ -475,7 +461,7 @@ func Stmt2Query(t *testing.T, db string, prepareDataSql []string) {
},
},
}
b, err := stmtCommon.MarshalStmt2Binary(params, false, nil, nil)
b, err := stmtCommon.MarshalStmt2Binary(params, false, nil)
assert.NoError(t, err)
block.Write(b)

Expand Down Expand Up @@ -702,21 +688,7 @@ func TestStmt2BindWithStbFields(t *testing.T) {
Tags: tag,
Cols: cols,
}
var colFields []*stmtCommon.StmtField
var tagFields []*stmtCommon.StmtField
for i := 0; i < 18; i++ {
field := &stmtCommon.StmtField{
FieldType: prepareResp.Fields[i].FieldType,
Precision: prepareResp.Fields[i].Precision,
}
switch prepareResp.Fields[i].BindType {
case stmtCommon.TAOS_FIELD_COL:
colFields = append(colFields, field)
case stmtCommon.TAOS_FIELD_TAG:
tagFields = append(tagFields, field)
}
}
bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, colFields, tagFields)
bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, prepareResp.Fields)
assert.NoError(t, err)
bindReq := make([]byte, len(bs)+30)
// req_id
Expand Down
18 changes: 3 additions & 15 deletions db/syncinterface/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,13 @@ func TaosStmt2IsInsert(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug bool)
return isInsert, code
}

func TaosStmt2GetFields(stmt2 unsafe.Pointer, fieldType int, logger *logrus.Entry, isDebug bool) (code, count int, fields unsafe.Pointer) {
logger.Tracef("call taos_stmt2_get_fields, stmt2:%p, fieldType:%d", stmt2, fieldType)
func TaosStmt2GetFields(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug bool) (code, count int, fields unsafe.Pointer) {
logger.Tracef("call taos_stmt2_get_fields, stmt2:%p", stmt2)
s := log.GetLogNow(isDebug)
thread.SyncLocker.Lock()
logger.Debugf("get thread lock for taos_stmt2_get_fields cost:%s", log.GetLogDuration(isDebug, s))
s = log.GetLogNow(isDebug)
code, count, fields = wrapper.TaosStmt2GetFields(stmt2, fieldType)
code, count, fields = wrapper.TaosStmt2GetFields(stmt2)
logger.Debugf("taos_stmt2_get_fields finish, code:%d, count:%d, fields:%p, cost:%s", code, count, fields, log.GetLogDuration(isDebug, s))
thread.SyncLocker.Unlock()
return code, count, fields
Expand Down Expand Up @@ -410,18 +410,6 @@ func TaosStmt2BindBinary(stmt2 unsafe.Pointer, data []byte, colIdx int32, logger
return err
}

func TaosStmt2GetStbFields(stmt2 unsafe.Pointer, logger *logrus.Entry, isDebug bool) (code, count int, fields unsafe.Pointer) {
logger.Tracef("call taos_stmt2_get_stb_fields, stmt2:%p", stmt2)
s := log.GetLogNow(isDebug)
thread.SyncLocker.Lock()
logger.Debugf("get thread lock for taos_stmt2_get_stb_fields cost:%s", log.GetLogDuration(isDebug, s))
s = log.GetLogNow(isDebug)
code, count, fields = wrapper.TaosStmt2GetStbFields(stmt2)
logger.Debugf("taos_stmt2_get_stb_fields finish, code:%d, count:%d, fields:%p, cost:%s", code, count, fields, log.GetLogDuration(isDebug, s))
thread.SyncLocker.Unlock()
return code, count, fields
}

func TaosOptionsConnection(conn unsafe.Pointer, option int, value *string, logger *logrus.Entry, isDebug bool) int {
if value == nil {
logger.Tracef("call taos_options_connection, conn:%p, option:%d, value:<nil>", conn, option)
Expand Down
52 changes: 22 additions & 30 deletions db/syncinterface/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,16 +509,16 @@ func TestTaosStmt2(t *testing.T) {
return
}
assert.True(t, isInsert)
code, count, fiels := TaosStmt2GetStbFields(stmt, logger, isDebug)
code, count, fields := TaosStmt2GetFields(stmt, logger, isDebug)
if !assert.Equal(t, 0, code, wrapper.TaosStmtErrStr(stmt)) {
return
}
assert.Equal(t, 4, count)
assert.NotNil(t, fiels)
assert.NotNil(t, fields)
defer func() {
wrapper.TaosStmt2FreeFields(stmt, fiels)
wrapper.TaosStmt2FreeFields(stmt, fields)
}()
fs := wrapper.ParseStmt2StbFields(count, fiels)
fs := wrapper.Stmt2ParseAllFields(count, fields)
assert.Equal(t, 4, len(fs))
assert.Equal(t, "tbname", fs[0].Name)
assert.Equal(t, int8(common.TSDB_DATA_TYPE_BINARY), fs[0].FieldType)
Expand All @@ -537,45 +537,37 @@ func TestTaosStmt2(t *testing.T) {
binds := &stmtCommon.TaosStmt2BindData{
TableName: tableName,
}
bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, nil, nil)
bs, err := stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, nil)
assert.NoError(t, err)
err = TaosStmt2BindBinary(stmt, bs, -1, logger, isDebug)
assert.NoError(t, err)

code, num, fields := TaosStmt2GetFields(stmt, stmtCommon.TAOS_FIELD_COL, logger, isDebug)
code, num, fields2 := TaosStmt2GetFields(stmt, logger, isDebug)
if !assert.Equal(t, 0, code, wrapper.TaosStmtErrStr(stmt)) {
return
}
assert.Equal(t, 2, num)
assert.Equal(t, 3, num)
assert.NotNil(t, fields)
defer func() {
wrapper.TaosStmt2FreeFields(stmt, fields)
}()
colFields := wrapper.StmtParseFields(num, fields)
assert.Equal(t, 2, len(colFields))
assert.Equal(t, "ts", colFields[0].Name)
assert.Equal(t, int8(common.TSDB_DATA_TYPE_TIMESTAMP), colFields[0].FieldType)
assert.Equal(t, "v", colFields[1].Name)
assert.Equal(t, int8(common.TSDB_DATA_TYPE_INT), colFields[1].FieldType)
code, num, tags := TaosStmt2GetFields(stmt, stmtCommon.TAOS_FIELD_TAG, logger, isDebug)
if !assert.Equal(t, 0, code, wrapper.TaosStmtErrStr(stmt)) {
return
}
assert.Equal(t, 1, num)
assert.NotNil(t, tags)
defer func() {
wrapper.TaosStmt2FreeFields(stmt, tags)
wrapper.TaosStmt2FreeFields(stmt, fields2)
}()
tagFields := wrapper.StmtParseFields(num, tags)
assert.Equal(t, 1, len(tagFields))
assert.Equal(t, "id", tagFields[0].Name)
assert.Equal(t, int8(common.TSDB_DATA_TYPE_INT), tagFields[0].FieldType)

fsAfterBindTableName := wrapper.Stmt2ParseAllFields(num, fields2)
assert.Equal(t, 3, len(fsAfterBindTableName))
assert.Equal(t, "id", fsAfterBindTableName[0].Name)
assert.Equal(t, int8(common.TSDB_DATA_TYPE_INT), fsAfterBindTableName[0].FieldType)
assert.Equal(t, int8(stmtCommon.TAOS_FIELD_TAG), fsAfterBindTableName[0].BindType)
assert.Equal(t, "ts", fsAfterBindTableName[1].Name)
assert.Equal(t, int8(common.TSDB_DATA_TYPE_TIMESTAMP), fsAfterBindTableName[1].FieldType)
assert.Equal(t, int8(stmtCommon.TAOS_FIELD_COL), fsAfterBindTableName[1].BindType)
assert.Equal(t, uint8(common.PrecisionMilliSecond), fsAfterBindTableName[1].Precision)
assert.Equal(t, "v", fsAfterBindTableName[2].Name)
assert.Equal(t, int8(common.TSDB_DATA_TYPE_INT), fsAfterBindTableName[2].FieldType)
assert.Equal(t, int8(stmtCommon.TAOS_FIELD_COL), fsAfterBindTableName[2].BindType)
binds = &stmtCommon.TaosStmt2BindData{
Tags: []driver.Value{int32(1)},
}

bs, err = stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, nil, tagFields)
bs, err = stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, fsAfterBindTableName[0:1])
assert.NoError(t, err)
err = TaosStmt2BindBinary(stmt, bs, -1, logger, isDebug)
assert.NoError(t, err)
Expand All @@ -587,7 +579,7 @@ func TestTaosStmt2(t *testing.T) {
{int32(100), int32(101)},
},
}
bs, err = stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, colFields, nil)
bs, err = stmtCommon.MarshalStmt2Binary([]*stmtCommon.TaosStmt2BindData{binds}, true, fsAfterBindTableName[1:])
assert.NoError(t, err)
err = TaosStmt2BindBinary(stmt, bs, -1, logger, isDebug)
assert.NoError(t, err)
Expand Down
22 changes: 20 additions & 2 deletions driver/common/stmt/stmt2.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,16 @@ const (
BindDataIsNullOffset = BindDataNumOffset + 4
)

func MarshalStmt2Binary(bindData []*TaosStmt2BindData, isInsert bool, colType, tagType []*StmtField) ([]byte, error) {
func MarshalStmt2Binary(bindData []*TaosStmt2BindData, isInsert bool, fields []*Stmt2AllField) ([]byte, error) {
var colType []*Stmt2AllField
var tagType []*Stmt2AllField
for i := 0; i < len(fields); i++ {
if fields[i].BindType == TAOS_FIELD_COL {
colType = append(colType, fields[i])
} else if fields[i].BindType == TAOS_FIELD_TAG {
tagType = append(tagType, fields[i])
}
}
// count
count := len(bindData)
if count == 0 {
Expand Down Expand Up @@ -215,7 +224,7 @@ func getBindDataHeaderLength(num int, needLength bool) int {
return length
}

func generateBindColData(data []driver.Value, colType *StmtField, tmpBuffer *bytes.Buffer) ([]byte, error) {
func generateBindColData(data []driver.Value, colType *Stmt2AllField, tmpBuffer *bytes.Buffer) ([]byte, error) {
num := len(data)
tmpBuffer.Reset()
needLength := needLength(colType.FieldType)
Expand Down Expand Up @@ -578,3 +587,12 @@ func needLength(colType int8) bool {
}
return false
}

type Stmt2AllField struct {
Name string `json:"name"`
FieldType int8 `json:"field_type"`
Precision uint8 `json:"precision"`
Scale uint8 `json:"scale"`
Bytes int32 `json:"bytes"`
BindType int8 `json:"bind_type"`
}
Loading

0 comments on commit dfa4db2

Please sign in to comment.