Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support set connection options #366

Merged
merged 9 commits into from
Dec 16, 2024
73 changes: 47 additions & 26 deletions controller/rest/restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/taosdata/taosadapter/v3/controller"
"github.com/taosdata/taosadapter/v3/db/async"
"github.com/taosdata/taosadapter/v3/db/commonpool"
"github.com/taosdata/taosadapter/v3/db/syncinterface"
"github.com/taosdata/taosadapter/v3/driver/common"
"github.com/taosdata/taosadapter/v3/driver/common/parser"
tErrors "github.com/taosdata/taosadapter/v3/driver/errors"
Expand All @@ -31,7 +32,6 @@ import (
"github.com/taosdata/taosadapter/v3/tools/generator"
"github.com/taosdata/taosadapter/v3/tools/iptool"
"github.com/taosdata/taosadapter/v3/tools/jsonbuilder"
"github.com/taosdata/taosadapter/v3/tools/layout"
"github.com/taosdata/taosadapter/v3/tools/pool"
"github.com/taosdata/taosadapter/v3/tools/sqltype"
)
Expand Down Expand Up @@ -112,19 +112,30 @@ type TDEngineRestfulRespDoc struct {
// @Router /rest/sql [post]
func (ctl *Restful) sql(c *gin.Context) {
db := c.Param("db")
timeZone := c.Query("tz")
location := time.UTC
var err error
logger := c.MustGet(LoggerKey).(*logrus.Entry)
reqID := c.MustGet(config.ReqIDKey).(int64)
if len(timeZone) != 0 {
location, err = time.LoadLocation(timeZone)
connTimezone, exists := c.GetQuery("conn_tz")
location := time.UTC
if exists {
location, err = time.LoadLocation(connTimezone)
if err != nil {
logger.Errorf("load location:%s error:%s", timeZone, err)
logger.Errorf("load conn_tz location:%s fail, error:%s", connTimezone, err)
BadRequestResponseWithMsg(c, logger, 0xffff, err.Error())
return
}
} else {
timezone, exists := c.GetQuery("tz")
if exists {
location, err = time.LoadLocation(timezone)
if err != nil {
logger.Errorf("load tz location:%s fail, error:%s", timezone, err)
BadRequestResponseWithMsg(c, logger, 0xffff, err.Error())
return
}
}
}

var returnObj bool
if returnObjStr := c.Query("row_with_meta"); len(returnObjStr) != 0 {
if returnObj, err = strconv.ParseBool(returnObjStr); err != nil {
Expand All @@ -134,21 +145,7 @@ func (ctl *Restful) sql(c *gin.Context) {
}
}

timeBuffer := make([]byte, 0, 30)
DoQuery(c, db, func(builder *jsonbuilder.Stream, ts int64, precision int) {
timeBuffer = timeBuffer[:0]
switch precision {
case common.PrecisionMilliSecond: // milli-second
timeBuffer = time.Unix(ts/1e3, (ts%1e3)*1e6).In(location).AppendFormat(timeBuffer, layout.LayoutMillSecond)
case common.PrecisionMicroSecond: // micro-second
timeBuffer = time.Unix(ts/1e6, (ts%1e6)*1e3).In(location).AppendFormat(timeBuffer, layout.LayoutMicroSecond)
case common.PrecisionNanoSecond: // nano-second
timeBuffer = time.Unix(0, ts).In(location).AppendFormat(timeBuffer, layout.LayoutNanoSecond)
default:
logger.Errorf("unknown precision:%d", precision)
}
builder.WriteString(string(timeBuffer))
}, reqID, returnObj, logger)
DoQuery(c, db, location, reqID, returnObj, logger)
}

type TDEngineRestfulResp struct {
Expand All @@ -159,7 +156,7 @@ type TDEngineRestfulResp struct {
Rows int `json:"rows,omitempty"`
}

func DoQuery(c *gin.Context, db string, timeFunc ctools.FormatTimeFunc, reqID int64, returnObj bool, logger *logrus.Entry) {
func DoQuery(c *gin.Context, db string, location *time.Location, reqID int64, returnObj bool, logger *logrus.Entry) {
var s time.Time
isDebug := log.IsDebug()
b, err := c.GetRawData()
Expand Down Expand Up @@ -218,14 +215,37 @@ func DoQuery(c *gin.Context, db string, timeFunc ctools.FormatTimeFunc, reqID in
}
logger.Debugf("put connection finish, cost:%s", log.GetLogDuration(isDebug, s))
}()

// set connection options
success := trySetConnectionOptions(c, taosConnect.TaosConnection, logger, isDebug)
if !success {
monitor.RestRecordResult(sqlType, false)
return
}
if len(db) > 0 {
// Attempt to select the database does not return even if there is an error
// To avoid error reporting in the `create database` statement
logger.Tracef("select db %s", db)
_ = async.GlobalAsync.TaosExecWithoutResult(taosConnect.TaosConnection, logger, isDebug, fmt.Sprintf("use `%s`", db), reqID)
}
execute(c, logger, isDebug, taosConnect.TaosConnection, sql, timeFunc, reqID, sqlType, returnObj)
execute(c, logger, isDebug, taosConnect.TaosConnection, sql, reqID, sqlType, returnObj, location)
}

func trySetConnectionOptions(c *gin.Context, conn unsafe.Pointer, logger *logrus.Entry, isDebug bool) bool {
keys := [3]string{"conn_tz", "app", "ip"}
options := [3]int{common.TSDB_OPTION_CONNECTION_TIMEZONE, common.TSDB_OPTION_CONNECTION_USER_APP, common.TSDB_OPTION_CONNECTION_USER_IP}
for i := 0; i < 3; i++ {
val := c.Query(keys[i])
if val != "" {
code := syncinterface.TaosOptionsConnection(conn, options[i], &val, logger, isDebug)
if code != httperror.SUCCESS {
errStr := wrapper.TaosErrorStr(nil)
logger.Errorf("set connection options error, option:%d, val:%s, code:%d, message:%s", options[i], val, code, errStr)
TaosErrorResponse(c, logger, code, errStr)
return false
}
}
}
return true
}

var (
Expand All @@ -241,7 +261,7 @@ var (
Timing = []byte(`,"timing":`)
)

func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect unsafe.Pointer, sql string, timeFormat ctools.FormatTimeFunc, reqID int64, sqlType sqltype.SqlType, returnObj bool) {
func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect unsafe.Pointer, sql string, reqID int64, sqlType sqltype.SqlType, returnObj bool, location *time.Location) {
_, calculateTiming := c.Get(RequireTiming)
st := c.MustGet(StartTimeKey)
flushTiming := int64(0)
Expand Down Expand Up @@ -364,6 +384,7 @@ func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect uns
fetched := false
pHeaderList := make([]unsafe.Pointer, fieldsCount)
pStartList := make([]unsafe.Pointer, fieldsCount)
timeBuffer := make([]byte, 0, 30)
for {
if config.Conf.RestfulRowLimit > -1 && total == config.Conf.RestfulRowLimit {
break
Expand Down Expand Up @@ -412,7 +433,7 @@ func execute(c *gin.Context, logger *logrus.Entry, isDebug bool, taosConnect uns
if returnObj {
builder.WriteObjectField(rowsHeader.ColNames[column])
}
ctools.JsonWriteRawBlock(builder, rowsHeader.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, timeFormat)
ctools.JsonWriteRawBlock(builder, rowsHeader.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, location, timeBuffer, logger)
if column != fieldsCount-1 {
builder.WriteMore()
}
Expand Down
96 changes: 96 additions & 0 deletions controller/rest/restful_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/taosdata/taosadapter/v3/db"
"github.com/taosdata/taosadapter/v3/httperror"
"github.com/taosdata/taosadapter/v3/log"
"github.com/taosdata/taosadapter/v3/tools/layout"
)

var router *gin.Engine
Expand Down Expand Up @@ -674,3 +675,98 @@ func TestInternalError(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}

func TestSetConnectionOptions(t *testing.T) {
config.Conf.RestfulRowLimit = -1
w := httptest.NewRecorder()
body := strings.NewReader("create database if not exists rest_test_options")
url := "/rest/sql?app=rest_test_options&ip=192.168.100.1&conn_tz=Europe/Moscow&tz=Asia/Shanghai"
req, _ := http.NewRequest(http.MethodPost, url, body)
req.RemoteAddr = "127.0.0.1:33333"
req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==")
router.ServeHTTP(w, req)
checkResp(t, w)

defer func() {
body := strings.NewReader("drop database if exists rest_test_options")
req, _ := http.NewRequest(http.MethodPost, url, body)
req.RemoteAddr = "127.0.0.1:33333"
req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==")
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
checkResp(t, w)
}()

w = httptest.NewRecorder()
body = strings.NewReader("create table if not exists rest_test_options.t1(ts timestamp,v1 bool)")
req.Body = io.NopCloser(body)
router.ServeHTTP(w, req)
checkResp(t, w)

w = httptest.NewRecorder()
ts := "2024-12-04 12:34:56.789"
body = strings.NewReader(fmt.Sprintf(`insert into rest_test_options.t1 values ('%s',true)`, ts))
req.Body = io.NopCloser(body)
router.ServeHTTP(w, req)
checkResp(t, w)

w = httptest.NewRecorder()
body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`)
req.Body = io.NopCloser(body)
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
var result TDEngineRestfulRespDoc
err := json.Unmarshal(w.Body.Bytes(), &result)
assert.NoError(t, err)
assert.Equal(t, 0, result.Code)
assert.Equal(t, 1, len(result.Data))

location, err := time.LoadLocation("Europe/Moscow")
assert.NoError(t, err)
expectTime, err := time.ParseInLocation("2006-01-02 15:04:05.000", ts, location)
assert.NoError(t, err)
expectTimeStr := expectTime.Format(layout.LayoutMillSecond)
assert.Equal(t, expectTimeStr, result.Data[0][0])
t.Log(expectTimeStr, result.Data[0][0])

// wrong timezone
wrongTZUrl := "/rest/sql?app=rest_test_options&ip=192.168.100.1&tz=xxx"
body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`)
req, _ = http.NewRequest(http.MethodPost, wrongTZUrl, body)
req.RemoteAddr = "127.0.0.1:33333"
req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==")
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, 400, w.Code)

// wrong conn_tz
wrongConnTZUrl := "/rest/sql?app=rest_test_options&ip=192.168.100.1&conn_tz=xxx"
body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`)
req, _ = http.NewRequest(http.MethodPost, wrongConnTZUrl, body)
req.RemoteAddr = "127.0.0.1:33333"
req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==")
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, 400, w.Code)
// wrong ip
wrongIPUrl := "/rest/sql?app=rest_test_options&ip=xxx.xxx.xxx.xxx&conn_tz=Europe/Moscow&tz=Asia/Shanghai"
req, _ = http.NewRequest(http.MethodPost, wrongIPUrl, body)
req.RemoteAddr = "127.0.0.1:33333"
req.Header.Set("Authorization", "Basic:cm9vdDp0YW9zZGF0YQ==")
w = httptest.NewRecorder()
body = strings.NewReader(`select * from rest_test_options.t1 where ts = '2024-12-04 12:34:56.789'`)
req.Body = io.NopCloser(body)
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
err = json.Unmarshal(w.Body.Bytes(), &result)
assert.NoError(t, err)
assert.NotEqual(t, 0, result.Code)
}

func checkResp(t *testing.T, w *httptest.ResponseRecorder) {
assert.Equal(t, 200, w.Code)
var result TDEngineRestfulRespDoc
err := json.Unmarshal(w.Body.Bytes(), &result)
assert.NoError(t, err)
assert.Equal(t, 0, result.Code)
}
4 changes: 2 additions & 2 deletions controller/ws/query/ws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1608,10 +1608,10 @@ func TestDropUser(t *testing.T) {
assert.NoError(t, err)
}()
defer doRestful("drop user test_ws_query_drop_user", "")
code, message := doRestful("create user test_ws_query_drop_user pass 'pass'", "")
code, message := doRestful("create user test_ws_query_drop_user pass 'pass_123'", "")
assert.Equal(t, 0, code, message)
// connect
connReq := &WSConnectReq{ReqID: 1, User: "test_ws_query_drop_user", Password: "pass"}
connReq := &WSConnectReq{ReqID: 1, User: "test_ws_query_drop_user", Password: "pass_123"}
resp, err := doWebSocket(ws, WSConnect, &connReq)
assert.NoError(t, err)
var connResp WSConnectResp
Expand Down
4 changes: 2 additions & 2 deletions controller/ws/schemaless/schemaless_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,10 @@ func TestDropUser(t *testing.T) {
assert.NoError(t, err)
}()
defer doRestful("drop user test_ws_sml_drop_user", "")
code, message := doRestful("create user test_ws_sml_drop_user pass 'pass'", "")
code, message := doRestful("create user test_ws_sml_drop_user pass 'pass_123'", "")
assert.Equal(t, 0, code, message)
// connect
connReq := &schemalessConnReq{ReqID: 1, User: "test_ws_sml_drop_user", Password: "pass"}
connReq := &schemalessConnReq{ReqID: 1, User: "test_ws_sml_drop_user", Password: "pass_123"}
resp, err := doWebSocket(ws, SchemalessConn, &connReq)
assert.NoError(t, err)
var connResp schemalessConnResp
Expand Down
4 changes: 2 additions & 2 deletions controller/ws/stmt/stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1101,10 +1101,10 @@ func TestDropUser(t *testing.T) {
assert.NoError(t, err)
}()
defer doRestful("drop user test_ws_stmt_drop_user", "")
code, message := doRestful("create user test_ws_stmt_drop_user pass 'pass'", "")
code, message := doRestful("create user test_ws_stmt_drop_user pass 'pass_123'", "")
assert.Equal(t, 0, code, message)
// connect
connReq := &StmtConnectReq{ReqID: 1, User: "test_ws_stmt_drop_user", Password: "pass"}
connReq := &StmtConnectReq{ReqID: 1, User: "test_ws_stmt_drop_user", Password: "pass_123"}
resp, err := doWebSocket(ws, STMTConnect, &connReq)
assert.NoError(t, err)
var connResp StmtConnectResp
Expand Down
Loading
Loading