Skip to content

Commit

Permalink
Working websocket connection
Browse files Browse the repository at this point in the history
  • Loading branch information
Kelwing committed Feb 26, 2019
1 parent 8f34864 commit 12ec65b
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 12 deletions.
4 changes: 4 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
FROM alpine:latest as certs
RUN apk --update add ca-certificates

FROM golang:latest as builder
WORKDIR /go/src/gitea.auttaja.io/kubecord/ws
RUN go get github.com/gorilla/websocket \
Expand All @@ -9,5 +12,6 @@ COPY . .
RUN CGO_ENABLED=0 go build -installsuffix 'static' -o /app .

FROM scratch as final
COPY --from=certs /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /app /app
ENTRYPOINT ["/app"]
2 changes: 2 additions & 0 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ const (
)

const FailedHeartbeatAcks = 5 * time.Millisecond

const APIBase = "https://discordapp.com/api/v6"
50 changes: 50 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package main

import (
"encoding/json"
"github.com/labstack/gommon/log"
"io/ioutil"
"net/http"
"strings"
"time"
)

func (s *Shard) GatewayBot() (st *GatewayBotResponse, err error) {

client := &http.Client{Timeout: time.Second * 10}

req, err := http.NewRequest("GET", APIBase+"/gateway/bot", nil)

req.Header.Set("Authorization", "Bot "+s.Token)

response, err := client.Do(req)
if err != nil {
log.Fatal("Error getting Gateway data: ", err)
return
}

defer func() {
err := response.Body.Close()
if err != nil {
return
}
}()

body, err := ioutil.ReadAll(response.Body)
if err != nil {
log.Fatal("error reading gateway response body ", err)
}

err = json.Unmarshal(body, &st)
if err != nil {
return
}

// Ensure the gateway always has a trailing slash.
// MacOS will fail to connect if we add query params without a trailing slash on the base domain.
if !strings.HasSuffix(st.URL, "/") {
st.URL += "/"
}

return
}
35 changes: 33 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,38 @@
package main

import "fmt"
import (
"fmt"
"github.com/labstack/gommon/log"
"os"
"os/signal"
"syscall"
)

func main() {
fmt.Println("Testing")
log.SetLevel(log.INFO)
token := os.Getenv("TOKEN")
client := Shard{Token: token}
GatewayData, err := client.GatewayBot()
if err != nil {
log.Fatal("Unable to get GatewayBot data")
}
shards := make([]Shard, GatewayData.Shards)
initSequence := int64(0)
for sid, shard := range shards {
shard.Sequence = &initSequence
shard.SessionID = ""
shard.Token = token
shard.ShardCount = GatewayData.Shards
shard.ShardId = sid
_ = shard.Open(GatewayData.URL)
}

// Wait here until CTRL-C or other term signal is received.
fmt.Println("Bot is now running. Press CTRL-C to exit.")
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill)
<-sc

// Cleanly close down the Discord session.
_ = client.Close()
}
5 changes: 5 additions & 0 deletions structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ type ResumePayload struct {
Data ResumeData `json:"d"`
}

type GatewayBotResponse struct {
URL string `json:"url"`
Shards int `json:"shards"`
}

/* Gateway objects */

type User struct {
Expand Down
26 changes: 16 additions & 10 deletions ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ func (s *Shard) Open(gateway string) error {
return ErrWSAlreadyOpen
}
s.Gateway = gateway
gateway = gateway + "?v=6&encoding=json&compress=zlib-stream"
gateway = gateway + "?v=6&encoding=json"

log.Info("Shard %d connecting to gateway", s.ShardId)
log.Infof("Shard %d connecting to gateway", s.ShardId)
header := http.Header{}
header.Add("accept-encoding", "zlib")

s.Conn, _, err = websocket.DefaultDialer.Dial(gateway, header)
if err != nil {
log.Warn("error connecting to gateway on shard %d", err)
log.Warnf("error connecting to gateway on shard %d", err)
s.Conn = nil
return err
}
Expand Down Expand Up @@ -176,6 +176,7 @@ func (s *Shard) onPayload(wsConn *websocket.Conn, listening <-chan interface{})
s.RUnlock()

if sameConnection {
log.Warnf("error reading from gateway on shard %d: %s", s.ShardId, err)
err := s.Close()
if err != nil {
log.Warn("error closing connection, %s", err)
Expand Down Expand Up @@ -203,30 +204,34 @@ func (s *Shard) Dispatch(messageType int, message []byte) (*GatewayPayload, erro
var err error
var e *GatewayPayload
buffer = bytes.NewBuffer(message)

log.Debugf("Got event on shard %d", s.ShardId)
if messageType == websocket.BinaryMessage {
decompressor, zerr := zlib.NewReader(buffer)
if zerr != nil {
log.Error("error decompressing message: %s", zerr)
log.Errorf("error decompressing message: %s", zerr)
return nil, zerr
}

defer func() {
zerr := decompressor.Close()
if zerr != nil {
log.Warn("error closing zlib: %s", zerr)
log.Warnf("error closing zlib: %s", zerr)
return
}
}()

buffer = decompressor
}

log.Debugf("Decompressed message on shard %d", s.ShardId)

decoder := json.NewDecoder(buffer)
if err = decoder.Decode(&e); err != nil {
log.Error("error decoding message: %s", err)
return e, err
}

log.Debug("Op: %d, Seq: %d, Type: %s, Data: %s\n\n", e.Op, e.Sequence, e.Event, string(e.Data))
log.Debugf("Op: %d, Seq: %d, Type: %s, Data: %s\n\n", e.Op, e.Sequence, e.Event, string(e.Data))

switch e.Op {
case OP_HEARTBEAT:
Expand All @@ -246,7 +251,7 @@ func (s *Shard) Dispatch(messageType int, message []byte) (*GatewayPayload, erro
case OP_INVALID_SESSION:
err = s.Identify()
if err != nil {
log.Error("error identifying with gateway: %s", err)
log.Errorf("error identifying with gateway: %s", err)
return e, err
}
return e, nil
Expand All @@ -256,14 +261,15 @@ func (s *Shard) Dispatch(messageType int, message []byte) (*GatewayPayload, erro
s.Lock()
s.LastHeartbeatAck = time.Now().UTC()
s.Unlock()
log.Debug("got heartbeat ACK")
log.Debugf("got heartbeat ACK")
return e, nil
case OP_DISPATCH:
// Dispatch the message
atomic.StoreInt64(s.Sequence, e.Sequence)
log.Debugf("Got event %s on shard %d", e.Event, s.ShardId)
return e, nil
default:
log.Warn("Unknown Op: %d", e.Op)
log.Warnf("Unknown Op: %d", e.Op)
return e, nil
}
}
Expand Down

0 comments on commit 12ec65b

Please sign in to comment.