Skip to content

Commit

Permalink
use go-bytes-pool
Browse files Browse the repository at this point in the history
  • Loading branch information
IrineSistiana committed Sep 17, 2023
1 parent 98701bd commit f02d377
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 228 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/IrineSistiana/mosdns/v5
go 1.19

require (
github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686
github.com/go-chi/chi/v5 v5.0.10
github.com/google/nftables v0.1.0
github.com/kardianos/service v1.2.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686 h1:5R32cCep3VUDTKf3aurFKfgbvg+RScuBmZsw/DyyXco=
github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686/go.mod h1:pQ/FSsWSNYmNdgIKmulKlmVC/R2PEpq2vIEi3J9IijI=
github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a h1:GQdh/h0q0ni3L//CXusyk+7QdhBL289vdNaes1WKkHI=
github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
Expand Down
45 changes: 22 additions & 23 deletions pkg/dnsutils/net_io.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ import (
"encoding/binary"
"errors"
"fmt"
"io"

"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/miekg/dns"
"io"
)

var (
Expand All @@ -35,45 +36,43 @@ var (
// ReadRawMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed
// with a two byte length field).
// n represents how many bytes are read from c.
// The returned the []byte should be released by pool.ReleaseBuf.
func ReadRawMsgFromTCP(c io.Reader) ([]byte, int, error) {
n := 0
// The returned the *[]byte should be released by pool.ReleaseBuf.
func ReadRawMsgFromTCP(c io.Reader) (*[]byte, error) {
h := pool.GetBuf(2)
defer pool.ReleaseBuf(h)
nh, err := io.ReadFull(c, h)
n += nh
_, err := io.ReadFull(c, *h)

if err != nil {
return nil, n, err
return nil, err
}

// dns length
length := binary.BigEndian.Uint16(h)
length := binary.BigEndian.Uint16(*h)
if length == 0 {
return nil, 0, errZeroLenMsg
return nil, errZeroLenMsg
}

buf := pool.GetBuf(int(length))
nm, err := io.ReadFull(c, buf)
n += nm
b := pool.GetBuf(int(length))
_, err = io.ReadFull(c, *b)
if err != nil {
pool.ReleaseBuf(buf)
return nil, n, err
pool.ReleaseBuf(b)
return nil, err
}
return buf, n, nil
return b, nil
}

// ReadMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed
// with a two byte length field).
// n represents how many bytes are read from c.
func ReadMsgFromTCP(c io.Reader) (*dns.Msg, int, error) {
b, n, err := ReadRawMsgFromTCP(c)
b, err := ReadRawMsgFromTCP(c)
if err != nil {
return nil, 0, err
}
defer pool.ReleaseBuf(b)

m, err := unpackMsgWithDetailedErr(b)
return m, n, err
m, err := unpackMsgWithDetailedErr(*b)
return m, len(*b) + 2, err
}

// WriteMsgToTCP packs and writes m to c in RFC 1035 format.
Expand All @@ -96,9 +95,9 @@ func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) {
buf := pool.GetBuf(len(b) + 2)
defer pool.ReleaseBuf(buf)

binary.BigEndian.PutUint16(buf[:2], uint16(len(b)))
copy(buf[2:], b)
return c.Write(buf)
binary.BigEndian.PutUint16((*buf)[:2], uint16(len(b)))
copy((*buf)[2:], b)
return c.Write((*buf))
}

func WriteMsgToUDP(c io.Writer, m *dns.Msg) (int, error) {
Expand All @@ -118,12 +117,12 @@ func ReadMsgFromUDP(c io.Reader, bufSize int) (*dns.Msg, int, error) {

b := pool.GetBuf(bufSize)
defer pool.ReleaseBuf(b)
n, err := c.Read(b)
n, err := c.Read(*b)
if err != nil {
return nil, n, err
}

m, err := unpackMsgWithDetailedErr(b[:n])
m, err := unpackMsgWithDetailedErr((*b)[:n])
return m, n, err
}

Expand Down
78 changes: 5 additions & 73 deletions pkg/pool/allocator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,79 +20,11 @@
package pool

import (
"fmt"
"math"
"math/bits"
"sync"
bytesPool "github.com/IrineSistiana/go-bytes-pool"
)

// defaultBufPool is an Allocator that has a maximum capacity.
var defaultBufPool = NewAllocator()

// GetBuf returns a []byte from pool with most appropriate cap.
// It panics if size < 0.
func GetBuf(size int) []byte {
return defaultBufPool.Get(size)
}

// ReleaseBuf puts the buf to the pool.
func ReleaseBuf(b []byte) {
defaultBufPool.Release(b)
}

type Allocator struct {
buffers []sync.Pool
}

// NewAllocator initiates a []byte Allocator.
// The waste(memory fragmentation) of space allocation is guaranteed to be
// no more than 50%.
func NewAllocator() *Allocator {
alloc := &Allocator{
buffers: make([]sync.Pool, bits.UintSize+1),
}

for i := range alloc.buffers {
var bufSize uint
if i == bits.UintSize {
bufSize = math.MaxUint
} else {
bufSize = 1 << i
}
alloc.buffers[i].New = func() any {
b := make([]byte, bufSize)
return &b
}
}
return alloc
}

// Get returns a []byte from pool with most appropriate cap
func (alloc *Allocator) Get(size int) []byte {
if size < 0 {
panic(fmt.Sprintf("invalid slice size %d", size))
}

i := shard(size)
v := alloc.buffers[i].Get()
buf := v.(*[]byte)
return (*buf)[0:size]
}

// Release releases the buf to the allocatorL.
func (alloc *Allocator) Release(buf []byte) {
c := cap(buf)
i := shard(c)
if c == 0 || c != 1<<i {
panic("unexpected cap size")
}
alloc.buffers[i].Put(&buf)
}

// shard returns the shard index that is suitable for the size.
func shard(size int) int {
if size <= 1 {
return 0
}
return bits.Len64(uint64(size - 1))
}
var (
GetBuf = bytesPool.Get
ReleaseBuf = bytesPool.Release
)
119 changes: 0 additions & 119 deletions pkg/pool/allocator_test.go

This file was deleted.

4 changes: 2 additions & 2 deletions pkg/pool/msg_buf.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ const packBufSize = 4096
// PackBuffer packs the dns msg m to wire format.
// Callers should release the buf by calling ReleaseBuf after they have done
// with the wire []byte.
func PackBuffer(m *dns.Msg) (wire, buf []byte, err error) {
func PackBuffer(m *dns.Msg) (wire []byte, buf *[]byte, err error) {
buf = GetBuf(packBufSize)
wire, err = m.PackBuffer(buf)
wire, err = m.PackBuffer(*buf)
if err != nil {
ReleaseBuf(buf)
return nil, nil, err
Expand Down
4 changes: 2 additions & 2 deletions pkg/pool/msg_buf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestPackBuffer_No_Allocation(t *testing.T) {
t.Fatal(err)
}

if cap(wire) != cap(buf) {
t.Fatalf("wire and buf have different cap, wire %d, buf %d", cap(wire), cap(buf))
if cap(wire) != cap(*buf) {
t.Fatalf("wire and buf have different cap, wire %d, buf %d", cap(wire), cap(*buf))
}
}
6 changes: 3 additions & 3 deletions pkg/server/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ func (s *UDPServer) ServeUDP(c *net.UDPConn) error {
}

for {
n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(rb, ob)
n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob)
if err != nil {
return fmt.Errorf("unexpected read err: %w", err)
}
clientAddr := remoteAddr.Addr()

q := new(dns.Msg)
if err := q.Unpack(rb[:n]); err != nil {
s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", rb[:n]), zap.Stringer("from", remoteAddr))
if err := q.Unpack((*rb)[:n]); err != nil {
s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr))
continue
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/upstream/transport/dns_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ var (
c1, c2 := net.Pipe()
go func() {
for {
m, _, readErr := dnsutils.ReadRawMsgFromTCP(c2)
m, readErr := dnsutils.ReadRawMsgFromTCP(c2)
if m != nil {
go func() {
defer pool.ReleaseBuf(m)
latency := time.Millisecond * time.Duration(rand.Intn(20))
time.Sleep(latency)
_, _ = dnsutils.WriteRawMsgToTCP(c2, m)
_, _ = dnsutils.WriteRawMsgToTCP(c2, *m)
}()
}
if readErr != nil {
Expand Down
Loading

0 comments on commit f02d377

Please sign in to comment.