Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enh: refactor websocket #360

Merged
merged 11 commits into from
Nov 29, 2024
79 changes: 34 additions & 45 deletions controller/ws/query/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/taosdata/taosadapter/v3/tools/generator"

"github.com/gin-gonic/gin"
"github.com/huskar-t/melody"
"github.com/sirupsen/logrus"
"github.com/taosdata/driver-go/v3/common/parser"
"github.com/taosdata/driver-go/v3/wrapper"
Expand All @@ -32,6 +31,7 @@ import (
"github.com/taosdata/taosadapter/v3/tools"
"github.com/taosdata/taosadapter/v3/tools/iptool"
"github.com/taosdata/taosadapter/v3/tools/jsontype"
"github.com/taosdata/taosadapter/v3/tools/melody"
)

type QueryController struct {
Expand All @@ -40,7 +40,7 @@ type QueryController struct {

func NewQueryController() *QueryController {
queryM := melody.New()
queryM.UpGrader.EnableCompression = true
queryM.Upgrader.EnableCompression = true
queryM.Config.MaxMessageSize = 0

queryM.HandleConnect(func(session *melody.Session) {
Expand All @@ -51,16 +51,16 @@ func NewQueryController() *QueryController {
})

queryM.HandleMessage(func(session *melody.Session, data []byte) {
if queryM.IsClosed() {
return
}
t := session.MustGet(TaosSessionKey).(*Taos)
if t.closed {
return
}
t.wg.Add(1)
go func() {
defer t.wg.Done()
if t.closed {
return
}
ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano())
logger := wstool.GetLogger(session)
logger.Debugf("get ws message data: %s", data)
Expand All @@ -72,7 +72,7 @@ func NewQueryController() *QueryController {
}
switch action.Action {
case wstool.ClientVersion:
_ = session.Write(wstool.VersionResp)
wstool.WSWriteVersion(session, logger)
case WSConnect:
var wsConnect WSConnectReq
err = json.Unmarshal(action.Args, &wsConnect)
Expand Down Expand Up @@ -122,16 +122,16 @@ func NewQueryController() *QueryController {
})

queryM.HandleMessageBinary(func(session *melody.Session, data []byte) {
if queryM.IsClosed() {
return
}
t := session.MustGet(TaosSessionKey).(*Taos)
if t.closed {
return
}
t.wg.Add(1)
go func() {
defer t.wg.Done()
if t.closed {
return
}
ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano())
logger := wstool.GetLogger(session)
logger.Tracef("get ws block message data:%+v", data)
Expand Down Expand Up @@ -415,7 +415,7 @@ func (t *Taos) connect(ctx context.Context, session *melody.Session, req *WSConn
}
if t.conn != nil {
logger.Trace("duplicate connections")
wsErrorMsg(ctx, session, 0xffff, "duplicate connections", WSConnect, req.ReqID)
wsErrorMsg(ctx, session, logger, 0xffff, "duplicate connections", WSConnect, req.ReqID)
return
}
conn, err := syncinterface.TaosConnect("", req.User, req.Password, req.DB, 0, logger, isDebug)
Expand Down Expand Up @@ -499,7 +499,7 @@ func (t *Taos) query(ctx context.Context, session *melody.Session, req *WSQueryR
)
if t.conn == nil {
logger.Trace("server not connected")
wsErrorMsg(ctx, session, 0xffff, "server not connected", WSQuery, req.ReqID)
wsErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSQuery, req.ReqID)
return
}
logger.Tracef("req_id: 0x%x,query sql: %s", req.ReqID, req.SQL)
Expand All @@ -521,7 +521,7 @@ func (t *Taos) query(ctx context.Context, session *melody.Session, req *WSQueryR
logger.Errorf("query error, code: %d, message: %s", code, errStr)
logger.Trace("get thread lock for free result")
syncinterface.FreeResult(result.Res, logger, isDebug)
wsErrorMsg(ctx, session, code, errStr, WSQuery, req.ReqID)
wsErrorMsg(ctx, session, logger, code, errStr, WSQuery, req.ReqID)
return
}
monitor.WSRecordResult(sqlType, true)
Expand Down Expand Up @@ -593,7 +593,7 @@ func (t *Taos) writeRaw(ctx context.Context, session *melody.Session, reqID, mes
}
if t.conn == nil {
logger.Error("server not connected")
wsTMQErrorMsg(ctx, session, 0xffff, "server not connected", WSWriteRaw, reqID, &messageID)
wsTMQErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSWriteRaw, reqID, &messageID)
return
}
meta := wrapper.BuildRawMeta(length, metaType, data)
Expand All @@ -609,7 +609,7 @@ func (t *Taos) writeRaw(ctx context.Context, session *melody.Session, reqID, mes
if errCode != 0 {
errStr := wrapper.TMQErr2Str(errCode)
logger.Errorf("write raw meta error, code: %d, message: %s", errCode, errStr)
wsErrorMsg(ctx, session, int(errCode)&0xffff, errStr, WSWriteRaw, reqID)
wsErrorMsg(ctx, session, logger, int(errCode)&0xffff, errStr, WSWriteRaw, reqID)
return
}
resp := &WSWriteMetaResp{Action: WSWriteRaw, ReqID: reqID, MessageID: messageID, Timing: wstool.GetDuration(ctx)}
Expand All @@ -636,7 +636,7 @@ func (t *Taos) writeRawBlock(ctx context.Context, session *melody.Session, reqID
return
}
if t.conn == nil {
wsErrorMsg(ctx, session, 0xffff, "server not connected", WSWriteRawBlock, reqID)
wsErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSWriteRawBlock, reqID)
return
}
logger.Trace("get thread lock for write raw block")
Expand All @@ -651,7 +651,7 @@ func (t *Taos) writeRawBlock(ctx context.Context, session *melody.Session, reqID
if errCode != 0 {
errStr := wrapper.TMQErr2Str(int32(errCode))
logger.Errorf("write raw block error, code: %d, message: %s", errCode, errStr)
wsErrorMsg(ctx, session, errCode&0xffff, errStr, WSWriteRawBlock, reqID)
wsErrorMsg(ctx, session, logger, errCode&0xffff, errStr, WSWriteRawBlock, reqID)
return
}
resp := &WSWriteRawBlockResp{Action: WSWriteRawBlock, ReqID: reqID, Timing: wstool.GetDuration(ctx)}
Expand Down Expand Up @@ -679,7 +679,7 @@ func (t *Taos) writeRawBlockWithFields(ctx context.Context, session *melody.Sess
}
if t.conn == nil {
logger.Errorf("server not connected")
wsErrorMsg(ctx, session, 0xffff, "server not connected", WSWriteRawBlockWithFields, reqID)
wsErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSWriteRawBlockWithFields, reqID)
return
}
logger.Trace("get thread lock for write raw block with fields")
Expand All @@ -694,7 +694,7 @@ func (t *Taos) writeRawBlockWithFields(ctx context.Context, session *melody.Sess
if errCode != 0 {
errStr := wrapper.TMQErr2Str(int32(errCode))
logger.Errorf("write raw block with fields error, code: %d, message: %s", errCode, errStr)
wsErrorMsg(ctx, session, errCode&0xffff, errStr, WSWriteRawBlockWithFields, reqID)
wsErrorMsg(ctx, session, logger, errCode&0xffff, errStr, WSWriteRawBlockWithFields, reqID)
return
}
resp := &WSWriteRawBlockWithFieldsResp{Action: WSWriteRawBlockWithFields, ReqID: reqID, Timing: wstool.GetDuration(ctx)}
Expand Down Expand Up @@ -724,22 +724,22 @@ func (t *Taos) fetch(ctx context.Context, session *melody.Session, req *WSFetchR
)
if t.conn == nil {
logger.Errorf("server not connected")
wsErrorMsg(ctx, session, 0xffff, "server not connected", WSFetch, req.ReqID)
wsErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSFetch, req.ReqID)
return
}
isDebug := log.IsDebug()
resultItem := t.getResult(req.ID)
if resultItem == nil {
logger.Errorf("result is nil")
wsErrorMsg(ctx, session, 0xffff, "result is nil", WSFetch, req.ReqID)
wsErrorMsg(ctx, session, logger, 0xffff, "result is nil", WSFetch, req.ReqID)
return
}
resultS := resultItem.Value.(*Result)
resultS.Lock()
if resultS.TaosResult == nil {
resultS.Unlock()
logger.Errorf("result is nil")
wsErrorMsg(ctx, session, 0xffff, "result is nil", WSFetch, req.ReqID)
wsErrorMsg(ctx, session, logger, 0xffff, "result is nil", WSFetch, req.ReqID)
return
}
s := log.GetLogNow(isDebug)
Expand Down Expand Up @@ -768,7 +768,7 @@ func (t *Taos) fetch(ctx context.Context, session *melody.Session, req *WSFetchR
logger.Errorf("fetch raw block error, code: %d, message: %s", result.N, errStr)
resultS.Unlock()
t.FreeResult(resultItem, logger)
wsErrorMsg(ctx, session, result.N&0xffff, errStr, WSFetch, req.ReqID)
wsErrorMsg(ctx, session, logger, result.N&0xffff, errStr, WSFetch, req.ReqID)
return
}
s = log.GetLogNow(isDebug)
Expand Down Expand Up @@ -802,26 +802,26 @@ func (t *Taos) fetchBlock(ctx context.Context, session *melody.Session, req *WSF
)
if t.conn == nil {
logger.Error("server not connected")
wsErrorMsg(ctx, session, 0xffff, "server not connected", WSFetchBlock, req.ReqID)
wsErrorMsg(ctx, session, logger, 0xffff, "server not connected", WSFetchBlock, req.ReqID)
return
}
isDebug := log.IsDebug()
s := log.GetLogNow(isDebug)
resultItem := t.getResult(req.ID)
if resultItem == nil {
wsErrorMsg(ctx, session, 0xffff, "result is nil", WSFetchBlock, req.ReqID)
wsErrorMsg(ctx, session, logger, 0xffff, "result is nil", WSFetchBlock, req.ReqID)
return
}
resultS := resultItem.Value.(*Result)
resultS.Lock()
if resultS.TaosResult == nil {
resultS.Unlock()
wsErrorMsg(ctx, session, 0xffff, "result is nil", WSFetchBlock, req.ReqID)
wsErrorMsg(ctx, session, logger, 0xffff, "result is nil", WSFetchBlock, req.ReqID)
return
}
if resultS.Block == nil {
resultS.Unlock()
wsErrorMsg(ctx, session, 0xffff, "block is nil", WSFetchBlock, req.ReqID)
wsErrorMsg(ctx, session, logger, 0xffff, "block is nil", WSFetchBlock, req.ReqID)
return
}
blockLength := int(parser.RawBlockGetLength(resultS.Block))
Expand Down Expand Up @@ -867,15 +867,6 @@ func (t *Taos) freeResult(req *WSFreeResultReq) {
}
}

type Writer struct {
session *melody.Session
}

func (w *Writer) Write(p []byte) (int, error) {
err := w.session.Write(p)
return 0, err
}

func (t *Taos) FreeResult(element *list.Element, logger *logrus.Entry) {
if element == nil {
return
Expand Down Expand Up @@ -968,16 +959,15 @@ type WSErrorResp struct {
Timing int64 `json:"timing"`
}

func wsErrorMsg(ctx context.Context, session *melody.Session, code int, message string, action string, reqID uint64) {
b, _ := json.Marshal(&WSErrorResp{
func wsErrorMsg(ctx context.Context, session *melody.Session, logger *logrus.Entry, code int, message string, action string, reqID uint64) {
data := &WSErrorResp{
Code: code & 0xffff,
Message: message,
Action: action,
ReqID: reqID,
Timing: wstool.GetDuration(ctx),
})
wstool.GetLogger(session).Tracef("write error message: %s", b)
_ = session.Write(b)
}
wstool.WSWriteJson(session, logger, data)
}

type WSTMQErrorResp struct {
Expand All @@ -989,17 +979,16 @@ type WSTMQErrorResp struct {
MessageID *uint64 `json:"message_id,omitempty"`
}

func wsTMQErrorMsg(ctx context.Context, session *melody.Session, code int, message string, action string, reqID uint64, messageID *uint64) {
b, _ := json.Marshal(&WSTMQErrorResp{
func wsTMQErrorMsg(ctx context.Context, session *melody.Session, logger *logrus.Entry, code int, message string, action string, reqID uint64, messageID *uint64) {
data := &WSTMQErrorResp{
Code: code & 0xffff,
Message: message,
Action: action,
ReqID: reqID,
Timing: wstool.GetDuration(ctx),
MessageID: messageID,
})
wstool.GetLogger(session).Tracef("write error message: %s", b)
_ = session.Write(b)
}
wstool.WSWriteJson(session, logger, data)
}

func init() {
Expand Down
12 changes: 6 additions & 6 deletions controller/ws/schemaless/schemaless.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"unsafe"

"github.com/gin-gonic/gin"
"github.com/huskar-t/melody"
"github.com/sirupsen/logrus"
tErrors "github.com/taosdata/driver-go/v3/errors"
"github.com/taosdata/driver-go/v3/wrapper"
Expand All @@ -22,6 +21,7 @@ import (
"github.com/taosdata/taosadapter/v3/log"
"github.com/taosdata/taosadapter/v3/tools/generator"
"github.com/taosdata/taosadapter/v3/tools/iptool"
"github.com/taosdata/taosadapter/v3/tools/melody"
)

type SchemalessController struct {
Expand All @@ -30,7 +30,7 @@ type SchemalessController struct {

func NewSchemalessController() *SchemalessController {
schemaless := melody.New()
schemaless.UpGrader.EnableCompression = true
schemaless.Upgrader.EnableCompression = true
schemaless.Config.MaxMessageSize = 0

schemaless.HandleConnect(func(session *melody.Session) {
Expand All @@ -40,16 +40,16 @@ func NewSchemalessController() *SchemalessController {
})

schemaless.HandleMessage(func(session *melody.Session, bytes []byte) {
if schemaless.IsClosed() {
return
}
t := session.MustGet(taosSchemalessKey).(*TaosSchemaless)
if t.closed {
return
}
t.wg.Add(1)
go func() {
defer t.wg.Done()
if t.closed {
return
}
ctx := context.WithValue(context.Background(), wstool.StartTimeKey, time.Now().UnixNano())
logger := wstool.GetLogger(session)
logger.Debugf("get ws message data:%s", bytes)
Expand All @@ -62,7 +62,7 @@ func NewSchemalessController() *SchemalessController {
}
switch action.Action {
case wstool.ClientVersion:
_ = session.Write(wstool.VersionResp)
wstool.WSWriteVersion(session, logger)
case SchemalessConn:
var req schemalessConnReq
if err = json.Unmarshal(action.Args, &req); err != nil {
Expand Down
Loading
Loading