Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async dialer #423

Merged
merged 21 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions conn_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ var (
func (c *Conn) newToWriteBuf(buf []byte) {
c.left += len(buf)

allocator := c.p.g.BodyAllocator
appendBuffer := func() {
t := poolToWrite.New().(*toWrite)
b := c.p.g.BodyAllocator.Malloc(len(buf))
b := allocator.Malloc(len(buf))
copy(b, buf)
t.buf = b
c.writeList = append(c.writeList, t)
Expand All @@ -55,12 +56,12 @@ func (c *Conn) newToWriteBuf(buf []byte) {
appendBuffer()
} else {
if cap(tail.buf) < tailLen+l {
b := c.p.g.BodyAllocator.Malloc(tailLen + l)[:tailLen]
b := allocator.Malloc(tailLen + l)[:tailLen]
copy(b, tail.buf)
c.p.g.BodyAllocator.Free(tail.buf)
allocator.Free(tail.buf)
tail.buf = b
}
tail.buf = append(tail.buf, buf...)
tail.buf = allocator.Append(tail.buf, buf...)
}
}
}
Expand Down
26 changes: 25 additions & 1 deletion engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ const (
DefaultUDPReadTimeout = 120 * time.Second
)

const (
NETWORK_TCP = "tcp"
NETWORK_TCP4 = "tcp4"
NETWORK_TCP6 = "tcp6"
NETWORK_UDP = "udp"
NETWORK_UDP4 = "udp4"
NETWORK_UDP6 = "udp6"
NETWORK_UNIX = "unix"
NETWORK_UNIXGRAM = "unixgram"
NETWORK_UNIXPACKET = "unixpacket"
)

var (
// MaxOpenFiles .
MaxOpenFiles = 1024 * 1024 * 2
Expand Down Expand Up @@ -249,7 +261,19 @@ func (g *Engine) AddConn(conn net.Conn) (*Conn, error) {
}

p := g.pollers[c.Hash()%len(g.pollers)]
p.addConn(c)
err = p.addConn(c)
if err != nil {
return nil, err
}
return c, nil
}

func (g *Engine) addDialer(c *Conn) (*Conn, error) {
p := g.pollers[c.Hash()%len(g.pollers)]
err := p.addDialer(c)
if err != nil {
return nil, err
}
return c, nil
}

Expand Down
41 changes: 39 additions & 2 deletions engine_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net"
"runtime"
"strings"
"time"

"github.com/lesismal/nbio/logging"
"github.com/lesismal/nbio/mempool"
Expand All @@ -22,7 +23,7 @@ func (g *Engine) Start() error {
// Create listener pollers.
udpListeners := make([]*net.UDPConn, len(g.Addrs))[0:0]
switch g.Network {
case "tcp", "tcp4", "tcp6":
case NETWORK_TCP, NETWORK_TCP4, NETWORK_TCP6:
for i := range g.Addrs {
ln, err := newPoller(g, true, i)
if err != nil {
Expand All @@ -34,7 +35,7 @@ func (g *Engine) Start() error {
g.Addrs[i] = ln.listener.Addr().String()
g.listeners = append(g.listeners, ln)
}
case "udp", "udp4", "udp6":
case NETWORK_UDP, NETWORK_UDP4, NETWORK_UDP6:
for i, addrStr := range g.Addrs {
addr, err := net.ResolveUDPAddr(g.Network, addrStr)
if err != nil {
Expand Down Expand Up @@ -165,3 +166,39 @@ func NewEngine(conf Config) *Engine {

return g
}

// DialAsync connects asynchrony to the address on the named network.
func (engine *Engine) DialAsync(network, addr string, onConnected func(*Conn, error)) error {
return engine.DialAsyncTimeout(network, addr, 0, onConnected)
}

// DialAsync connects asynchrony to the address on the named network with timeout.
func (engine *Engine) DialAsyncTimeout(network, addr string, timeout time.Duration, onConnected func(*Conn, error)) error {
go func() {
var err error
var conn net.Conn
if timeout > 0 {
conn, err = net.DialTimeout(network, addr, timeout)
} else {
conn, err = net.Dial(network, addr)
}
if err != nil {
onConnected(nil, err)
return
}
nbc, err := NBConn(conn)
if err != nil {
onConnected(nil, err)
return
}
engine.wgConn.Add(1)
nbc, err = engine.addDialer(nbc)
if err == nil {
nbc.SetWriteDeadline(time.Time{})
} else {
engine.wgConn.Done()
}
onConnected(nbc, err)
}()
return nil
}
118 changes: 116 additions & 2 deletions engine_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
package nbio

import (
"errors"
"net"
"runtime"
"strings"
"syscall"
"time"

"github.com/lesismal/nbio/logging"
"github.com/lesismal/nbio/mempool"
Expand All @@ -28,7 +31,7 @@ func (g *Engine) Start() error {
udpListeners := make([]*net.UDPConn, len(g.Addrs))[0:0]

switch g.Network {
case "unix", "tcp", "tcp4", "tcp6":
case NETWORK_UNIX, NETWORK_TCP, NETWORK_TCP4, NETWORK_TCP6:
for i := range g.Addrs {
ln, err := newPoller(g, true, i)
if err != nil {
Expand All @@ -40,7 +43,7 @@ func (g *Engine) Start() error {
g.Addrs[i] = ln.listener.Addr().String()
g.listeners = append(g.listeners, ln)
}
case "udp", "udp4", "udp6":
case NETWORK_UDP, NETWORK_UDP4, NETWORK_UDP6:
for i, addrStr := range g.Addrs {
addr, err := net.ResolveUDPAddr(g.Network, addrStr)
if err != nil {
Expand Down Expand Up @@ -139,6 +142,117 @@ func (g *Engine) Start() error {
return nil
}

// DialAsync connects asynchrony to the address on the named network.
func (engine *Engine) DialAsync(network, addr string, onConnected func(*Conn, error)) error {
return engine.DialAsyncTimeout(network, addr, 0, onConnected)
}

// DialAsync connects asynchrony to the address on the named network with timeout.
func (engine *Engine) DialAsyncTimeout(network, addr string, timeout time.Duration, onConnected func(*Conn, error)) error {
h := func(c *Conn, err error) {
if err == nil {
c.SetWriteDeadline(time.Time{})
}
onConnected(c, err)
}
domain, typ, dialaddr, raddr, connType, err := parseDomainAndType(network, addr)
if err != nil {
return err
}
fd, err := syscall.Socket(domain, typ, 0)
if err != nil {
return err
}
err = syscall.SetNonblock(fd, true)
if err != nil {
syscall.Close(fd)
return err
}
err = syscall.Connect(fd, dialaddr)
inprogress := false
if err != nil {
if errors.Is(err, syscall.EINPROGRESS) {
inprogress = true
} else {
syscall.Close(fd)
return err
}
}
sa, _ := syscall.Getsockname(fd)
c := &Conn{
fd: fd,
rAddr: raddr,
typ: connType,
}
if inprogress {
c.onConnected = h
}
switch vt := sa.(type) {
case *syscall.SockaddrInet4:
switch connType {
case ConnTypeTCP:
c.lAddr = &net.TCPAddr{
IP: []byte{vt.Addr[0], vt.Addr[1], vt.Addr[2], vt.Addr[3]},
Port: vt.Port,
}
case ConnTypeUDPClientFromDial:
c.lAddr = &net.TCPAddr{
IP: []byte{vt.Addr[0], vt.Addr[1], vt.Addr[2], vt.Addr[3]},
Port: vt.Port,
}
c.connUDP = &udpConn{
parent: c,
}
}
case *syscall.SockaddrInet6:
var iface *net.Interface
iface, err = net.InterfaceByIndex(int(vt.ZoneId))
if err != nil {
syscall.Close(fd)
return err
}
switch connType {
case ConnTypeTCP:
c.lAddr = &net.TCPAddr{
IP: make([]byte, len(vt.Addr)),
Port: vt.Port,
Zone: iface.Name,
}
case ConnTypeUDPClientFromDial:
c.lAddr = &net.UDPAddr{
IP: make([]byte, len(vt.Addr)),
Port: vt.Port,
Zone: iface.Name,
}
c.connUDP = &udpConn{
parent: c,
}
}
case *syscall.SockaddrUnix:
c.lAddr = &net.UnixAddr{
Net: network,
Name: vt.Name,
}
}

engine.wgConn.Add(1)
_, err = engine.addDialer(c)
if err != nil {
engine.wgConn.Done()
return err
}

if !inprogress {
engine.Async(func() {
h(c, nil)
})
} else if timeout > 0 {
c.setDeadline(&c.wTimer, ErrDialTimeout, time.Now().Add(timeout))
}

return nil
}

// NewEngine creates an Engine and init default configurations.
func NewEngine(conf Config) *Engine {
if conf.Name == "" {
Expand Down
2 changes: 2 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ var (
ErrOverflow = errors.New("write overflow")
errOverflow = ErrOverflow

ErrDialTimeout = errors.New("dial timeout")

ErrUnsupported = errors.New("unsupported operation")
)
Loading
Loading