diff --git a/conn_unix.go b/conn_unix.go index 25b3e07c..cd0990b2 100644 --- a/conn_unix.go +++ b/conn_unix.go @@ -32,9 +32,10 @@ var ( func (c *Conn) newToWriteBuf(buf []byte) { c.left += len(buf) + allocator := c.p.g.BodyAllocator appendBuffer := func() { t := poolToWrite.New().(*toWrite) - b := c.p.g.BodyAllocator.Malloc(len(buf)) + b := allocator.Malloc(len(buf)) copy(b, buf) t.buf = b c.writeList = append(c.writeList, t) @@ -55,12 +56,12 @@ func (c *Conn) newToWriteBuf(buf []byte) { appendBuffer() } else { if cap(tail.buf) < tailLen+l { - b := c.p.g.BodyAllocator.Malloc(tailLen + l)[:tailLen] + b := allocator.Malloc(tailLen + l)[:tailLen] copy(b, tail.buf) - c.p.g.BodyAllocator.Free(tail.buf) + allocator.Free(tail.buf) tail.buf = b } - tail.buf = append(tail.buf, buf...) + tail.buf = allocator.Append(tail.buf, buf...) } } } diff --git a/engine.go b/engine.go index a5ba99e9..bec86a5c 100644 --- a/engine.go +++ b/engine.go @@ -32,6 +32,18 @@ const ( DefaultUDPReadTimeout = 120 * time.Second ) +const ( + NETWORK_TCP = "tcp" + NETWORK_TCP4 = "tcp4" + NETWORK_TCP6 = "tcp6" + NETWORK_UDP = "udp" + NETWORK_UDP4 = "udp4" + NETWORK_UDP6 = "udp6" + NETWORK_UNIX = "unix" + NETWORK_UNIXGRAM = "unixgram" + NETWORK_UNIXPACKET = "unixpacket" +) + var ( // MaxOpenFiles . MaxOpenFiles = 1024 * 1024 * 2 @@ -249,7 +261,19 @@ func (g *Engine) AddConn(conn net.Conn) (*Conn, error) { } p := g.pollers[c.Hash()%len(g.pollers)] - p.addConn(c) + err = p.addConn(c) + if err != nil { + return nil, err + } + return c, nil +} + +func (g *Engine) addDialer(c *Conn) (*Conn, error) { + p := g.pollers[c.Hash()%len(g.pollers)] + err := p.addDialer(c) + if err != nil { + return nil, err + } return c, nil } diff --git a/engine_std.go b/engine_std.go index 3d3966bc..e7c08c50 100644 --- a/engine_std.go +++ b/engine_std.go @@ -11,6 +11,7 @@ import ( "net" "runtime" "strings" + "time" "github.com/lesismal/nbio/logging" "github.com/lesismal/nbio/mempool" @@ -22,7 +23,7 @@ func (g *Engine) Start() error { // Create listener pollers. udpListeners := make([]*net.UDPConn, len(g.Addrs))[0:0] switch g.Network { - case "tcp", "tcp4", "tcp6": + case NETWORK_TCP, NETWORK_TCP4, NETWORK_TCP6: for i := range g.Addrs { ln, err := newPoller(g, true, i) if err != nil { @@ -34,7 +35,7 @@ func (g *Engine) Start() error { g.Addrs[i] = ln.listener.Addr().String() g.listeners = append(g.listeners, ln) } - case "udp", "udp4", "udp6": + case NETWORK_UDP, NETWORK_UDP4, NETWORK_UDP6: for i, addrStr := range g.Addrs { addr, err := net.ResolveUDPAddr(g.Network, addrStr) if err != nil { @@ -165,3 +166,39 @@ func NewEngine(conf Config) *Engine { return g } + +// DialAsync connects asynchrony to the address on the named network. +func (engine *Engine) DialAsync(network, addr string, onConnected func(*Conn, error)) error { + return engine.DialAsyncTimeout(network, addr, 0, onConnected) +} + +// DialAsync connects asynchrony to the address on the named network with timeout. +func (engine *Engine) DialAsyncTimeout(network, addr string, timeout time.Duration, onConnected func(*Conn, error)) error { + go func() { + var err error + var conn net.Conn + if timeout > 0 { + conn, err = net.DialTimeout(network, addr, timeout) + } else { + conn, err = net.Dial(network, addr) + } + if err != nil { + onConnected(nil, err) + return + } + nbc, err := NBConn(conn) + if err != nil { + onConnected(nil, err) + return + } + engine.wgConn.Add(1) + nbc, err = engine.addDialer(nbc) + if err == nil { + nbc.SetWriteDeadline(time.Time{}) + } else { + engine.wgConn.Done() + } + onConnected(nbc, err) + }() + return nil +} diff --git a/engine_unix.go b/engine_unix.go index 6bfd0a4a..7a8ba600 100644 --- a/engine_unix.go +++ b/engine_unix.go @@ -8,9 +8,12 @@ package nbio import ( + "errors" "net" "runtime" "strings" + "syscall" + "time" "github.com/lesismal/nbio/logging" "github.com/lesismal/nbio/mempool" @@ -28,7 +31,7 @@ func (g *Engine) Start() error { udpListeners := make([]*net.UDPConn, len(g.Addrs))[0:0] switch g.Network { - case "unix", "tcp", "tcp4", "tcp6": + case NETWORK_UNIX, NETWORK_TCP, NETWORK_TCP4, NETWORK_TCP6: for i := range g.Addrs { ln, err := newPoller(g, true, i) if err != nil { @@ -40,7 +43,7 @@ func (g *Engine) Start() error { g.Addrs[i] = ln.listener.Addr().String() g.listeners = append(g.listeners, ln) } - case "udp", "udp4", "udp6": + case NETWORK_UDP, NETWORK_UDP4, NETWORK_UDP6: for i, addrStr := range g.Addrs { addr, err := net.ResolveUDPAddr(g.Network, addrStr) if err != nil { @@ -139,6 +142,117 @@ func (g *Engine) Start() error { return nil } +// DialAsync connects asynchrony to the address on the named network. +func (engine *Engine) DialAsync(network, addr string, onConnected func(*Conn, error)) error { + return engine.DialAsyncTimeout(network, addr, 0, onConnected) +} + +// DialAsync connects asynchrony to the address on the named network with timeout. +func (engine *Engine) DialAsyncTimeout(network, addr string, timeout time.Duration, onConnected func(*Conn, error)) error { + h := func(c *Conn, err error) { + if err == nil { + c.SetWriteDeadline(time.Time{}) + } + onConnected(c, err) + } + domain, typ, dialaddr, raddr, connType, err := parseDomainAndType(network, addr) + if err != nil { + return err + } + fd, err := syscall.Socket(domain, typ, 0) + if err != nil { + return err + } + err = syscall.SetNonblock(fd, true) + if err != nil { + syscall.Close(fd) + return err + } + err = syscall.Connect(fd, dialaddr) + inprogress := false + if err != nil { + if errors.Is(err, syscall.EINPROGRESS) { + inprogress = true + } else { + syscall.Close(fd) + return err + } + } + sa, _ := syscall.Getsockname(fd) + c := &Conn{ + fd: fd, + rAddr: raddr, + typ: connType, + } + if inprogress { + c.onConnected = h + } + switch vt := sa.(type) { + case *syscall.SockaddrInet4: + switch connType { + case ConnTypeTCP: + c.lAddr = &net.TCPAddr{ + IP: []byte{vt.Addr[0], vt.Addr[1], vt.Addr[2], vt.Addr[3]}, + Port: vt.Port, + } + case ConnTypeUDPClientFromDial: + c.lAddr = &net.TCPAddr{ + IP: []byte{vt.Addr[0], vt.Addr[1], vt.Addr[2], vt.Addr[3]}, + Port: vt.Port, + } + c.connUDP = &udpConn{ + parent: c, + } + } + case *syscall.SockaddrInet6: + var iface *net.Interface + iface, err = net.InterfaceByIndex(int(vt.ZoneId)) + if err != nil { + syscall.Close(fd) + return err + } + switch connType { + case ConnTypeTCP: + c.lAddr = &net.TCPAddr{ + IP: make([]byte, len(vt.Addr)), + Port: vt.Port, + Zone: iface.Name, + } + case ConnTypeUDPClientFromDial: + c.lAddr = &net.UDPAddr{ + IP: make([]byte, len(vt.Addr)), + Port: vt.Port, + Zone: iface.Name, + } + c.connUDP = &udpConn{ + parent: c, + } + } + case *syscall.SockaddrUnix: + c.lAddr = &net.UnixAddr{ + Net: network, + Name: vt.Name, + } + } + + engine.wgConn.Add(1) + _, err = engine.addDialer(c) + if err != nil { + engine.wgConn.Done() + return err + } + + if !inprogress { + engine.Async(func() { + h(c, nil) + }) + } else if timeout > 0 { + c.setDeadline(&c.wTimer, ErrDialTimeout, time.Now().Add(timeout)) + } + + return nil +} + // NewEngine creates an Engine and init default configurations. func NewEngine(conf Config) *Engine { if conf.Name == "" { diff --git a/error.go b/error.go index 6465b50d..cb20e8d1 100644 --- a/error.go +++ b/error.go @@ -18,5 +18,7 @@ var ( ErrOverflow = errors.New("write overflow") errOverflow = ErrOverflow + ErrDialTimeout = errors.New("dial timeout") + ErrUnsupported = errors.New("unsupported operation") ) diff --git a/mempool/trace_debugger.go b/mempool/trace_debugger.go new file mode 100644 index 00000000..57db6af2 --- /dev/null +++ b/mempool/trace_debugger.go @@ -0,0 +1,118 @@ +package mempool + +import ( + "fmt" + "runtime/debug" + "sync" + "unsafe" +) + +type TraceDebugger struct { + mux sync.Mutex + pAlloced map[uintptr]struct{} + allocator Allocator +} + +func NewTraceDebuger(allocator Allocator) *TraceDebugger { + return &TraceDebugger{ + allocator: allocator, + pAlloced: map[uintptr]struct{}{}, + } +} + +// Malloc . +func (td *TraceDebugger) Malloc(size int) []byte { + buf := td.allocator.Malloc(size) + td.setBufferPointer(buf) + return buf +} + +// Realloc . +func (td *TraceDebugger) Realloc(buf []byte, size int) []byte { + newBuf := td.allocator.Realloc(buf, size) + pold := td.pointer(buf) + pnew := td.pointer(newBuf) + if pnew != pold { + td.deleteBufferPointer(buf) + td.setBufferPointer(newBuf) + } + return newBuf +} + +// Append . +func (td *TraceDebugger) Append(buf []byte, more ...byte) []byte { + newBuf := td.allocator.Append(buf, more...) + pold := td.pointer(buf) + pnew := td.pointer(newBuf) + if pnew != pold { + td.deleteBufferPointer(buf) + td.setBufferPointer(newBuf) + } + return newBuf +} + +// AppendString . +func (td *TraceDebugger) AppendString(buf []byte, more string) []byte { + newBuf := td.allocator.AppendString(buf, more) + pold := td.pointer(buf) + pnew := td.pointer(newBuf) + if pnew != pold { + td.deleteBufferPointer(buf) + td.setBufferPointer(newBuf) + } + return newBuf +} + +// Free . +func (td *TraceDebugger) Free(buf []byte) { + // if cap(buf) == 0 { + // td.printStack("invalid buf with cap 0") + // return + // } + td.deleteBufferPointer(buf) + td.allocator.Free(buf) +} + +func (td *TraceDebugger) setBufferPointer(buf []byte) { + if cap(buf) == 0 { + td.printStack("invalid buf with cap 0") + return + } + + td.mux.Lock() + defer td.mux.Unlock() + p := td.pointer(buf) + if _, ok := td.pAlloced[p]; ok { + td.printStack("re-alloc the same buf before free:") + return + } + td.pAlloced[p] = struct{}{} +} + +func (td *TraceDebugger) deleteBufferPointer(buf []byte) { + if cap(buf) == 0 { + td.printStack("invalid buf with cap 0") + return + } + + td.mux.Lock() + defer td.mux.Unlock() + p := td.pointer(buf) + if _, ok := td.pAlloced[p]; !ok { + td.printStack("free un-allocated buf:") + return + } + delete(td.pAlloced, p) +} + +func (td *TraceDebugger) pointer(buf []byte) uintptr { + p := *((*uintptr)(unsafe.Pointer(&(buf[:1][0])))) + return p +} + +func (td *TraceDebugger) printStack(info string) { + fmt.Println("-----------------------------") + fmt.Println("[mempool trace] " + info + "\n") + debug.PrintStack() + fmt.Println("-----------------------------") +} diff --git a/nbhttp/body_test.go b/nbhttp/body_test.go index 4a06ccbb..7d4ae59c 100644 --- a/nbhttp/body_test.go +++ b/nbhttp/body_test.go @@ -9,6 +9,23 @@ import ( "github.com/lesismal/nbio/mempool" ) +func TestBodyReaderPool(t *testing.T) { + br := bodyReaderPool.Get().(*BodyReader) + br.buffers = append(br.buffers, make([]byte, 10)) + *br = emptyBodyReader + bodyReaderPool.Put(br) + + for i := 0; i < 1000; i++ { + br2 := bodyReaderPool.Get().(*BodyReader) + if br2.buffers != nil { + t.Fatal("len>0") + } + br2.buffers = append(br.buffers, make([]byte, 10)) + *br2 = emptyBodyReader + bodyReaderPool.Put(br) + } +} + func TestBodyReader(t *testing.T) { engine := NewEngine(Config{ BodyAllocator: mempool.NewAligned(), diff --git a/nbhttp/websocket/conn.go b/nbhttp/websocket/conn.go index 46e540b0..3021839b 100644 --- a/nbhttp/websocket/conn.go +++ b/nbhttp/websocket/conn.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "encoding/binary" + "errors" "fmt" "io" "math/rand" @@ -295,7 +296,7 @@ func (c *Conn) nextFrame(data []byte) ([]byte, MessageType, []byte, bool, bool, bodyLen = int64(payloadLen) } - if c.isMessageTooLarge(int(bodyLen)) { + if c.isMessageTooLarge(len(c.message) + int(bodyLen)) { return data, 0, nil, false, false, false, ErrMessageTooLarge } @@ -407,6 +408,9 @@ func (c *Conn) Parse(data []byte) error { }() if err != nil { + if errors.Is(err, ErrMessageTooLarge) || errors.Is(err, ErrControlMessageTooBig) { + c.WriteClose(1009, err.Error()) + } return err } @@ -913,8 +917,8 @@ func (c *Conn) HandleRead(bufSize int) { // return false if length is ok. func (c *Conn) isMessageTooLarge(len int) bool { - if c.MessageLengthLimit == 0 { - // 0 means unlimitted size + // <=0 means unlimitted size + if c.MessageLengthLimit <= 0 { return false } return len > c.MessageLengthLimit @@ -941,6 +945,9 @@ func (c *Conn) validFrame(opcode MessageType, fin, res1, res2, res3, expectingFr func (c *Conn) readAll(r io.Reader, size int) ([]byte, error) { const maxAppendSize = 1024 * 1024 * 4 + if c.MessageLengthLimit > 0 && size > c.MessageLengthLimit { + size = c.MessageLengthLimit + } buf := c.Engine.BodyAllocator.Malloc(size)[0:0] for { n, err := r.Read(buf[len(buf):cap(buf)]) @@ -955,10 +962,18 @@ func (c *Conn) readAll(r io.Reader, size int) ([]byte, error) { } if len(buf) == cap(buf) { l := len(buf) + // can not extend more bytes. + if c.isMessageTooLarge(l + 1) { + return nil, ErrMessageTooLarge + } al := l if al > maxAppendSize { al = maxAppendSize } + // extend to the limit size at most. + if (c.MessageLengthLimit > 0) && (l+al > c.MessageLengthLimit) { + al = c.MessageLengthLimit - l + } buf = c.Engine.BodyAllocator.Append(buf, make([]byte, al)...)[:l] } } diff --git a/nbhttp/websocket/dialer.go b/nbhttp/websocket/dialer.go index 3aa4d3cd..4becd2d6 100644 --- a/nbhttp/websocket/dialer.go +++ b/nbhttp/websocket/dialer.go @@ -137,7 +137,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } } - if d.EnableCompression { + if options.enableCompression { req.Header[secWebsocketExtHeaderField] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} } @@ -251,7 +251,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h wsConn = NewClientConn(options, conn, resp.Header.Get(secWebsocketProtoHeaderField), remoteCompressionEnabled, false) wsConn.Engine = parser.Engine wsConn.Execute = parser.Execute - parser.ParserCloser = wsConn + nbc.SetSession(wsConn) if wsConn.openHandler != nil { wsConn.openHandler(wsConn) diff --git a/nbhttp/websocket/upgrader.go b/nbhttp/websocket/upgrader.go index 47471e54..456c214a 100644 --- a/nbhttp/websocket/upgrader.go +++ b/nbhttp/websocket/upgrader.go @@ -41,6 +41,10 @@ var ( // DefaultBlockingModSendQueueMaxSize . DefaultBlockingModSendQueueMaxSize uint16 = 0 + // DefaultMessageLengthLimit . + DefaultMessageLengthLimit = 1024 * 1024 * 4 + + // DefaultBlockingModAsyncCloseDelay . DefaultBlockingModAsyncCloseDelay = time.Second / 10 // DefaultEngine will be set to a Upgrader.Engine to handle details such as buffers. @@ -138,6 +142,7 @@ func NewUpgrader() *Upgrader { u := &Upgrader{ commonFields: commonFields{ KeepaliveTime: nbhttp.DefaultKeepaliveTime, + MessageLengthLimit: DefaultMessageLengthLimit, BlockingModAsyncCloseDelay: DefaultBlockingModAsyncCloseDelay, }, compressionLevel: defaultCompressionLevel, diff --git a/nbio_test.go b/nbio_test.go index 4bb3112d..21e388c1 100644 --- a/nbio_test.go +++ b/nbio_test.go @@ -13,11 +13,13 @@ import ( "time" ) -var addr = "127.0.0.1:8888" +var addr = "127.0.0.1:9999" var testfile = "test_tmp.file" -var gopher *Engine +var engine *Engine var testFileSize = 1024 * 1024 * 32 +const osWindows = "windows" + func init() { if err := os.WriteFile(testfile, make([]byte, testFileSize), 0600); err != nil { log.Panicf("write file failed: %v", err) @@ -83,7 +85,7 @@ func init() { wsess := session.(*writtenSizeSession) if wsess.isFile { if wsess.sumSend != testFileSize { - panic("invalid send size for sendfile") + panic(fmt.Errorf("invalid send size for sendfile: %v, %v", wsess.sumSend, testFileSize)) } } else { if wsess.sumSend != wsess.sumRecv { @@ -97,7 +99,7 @@ func init() { log.Panicf("Start failed: %v\n", err) } - gopher = g + engine = g } func TestEcho(t *testing.T) { @@ -158,7 +160,7 @@ func TestEcho(t *testing.T) { } for i := 0; i < clientNum; i++ { - if runtime.GOOS != "windows" { + if runtime.GOOS != osWindows { one(i) } else { go one(i) @@ -434,8 +436,85 @@ func TestUDP(t *testing.T) { time.Sleep(timeout * 2) } +func TestDialAsyncTCP(t *testing.T) { + network := "tcp" + addr := "127.0.0.1:10001" + testDialAsync(t, network, addr) +} + +func TestDialAsyncUDP(t *testing.T) { + network := "udp" + addr := "127.0.0.1:10001" + testDialAsync(t, network, addr) +} + +func TestDialAsyncUnix(t *testing.T) { + if runtime.GOOS == osWindows { + return + } + network := "unix" + addr := "unix.server" + testDialAsync(t, network, addr) +} + +func testDialAsync(t *testing.T, network, addr string) { + done := make(chan error, 1) + engineAsync := NewEngine(Config{ + Name: "udp-testing", + Network: network, + Addrs: []string{addr}, + NPoller: 1, + }) + engineAsync.OnOpen(func(c *Conn) { + log.Printf("TestDialAsync[%v, %v] OnOpen: %v, %v", network, addr, c.LocalAddr().String(), c.RemoteAddr().String()) + }) + cnt := 0 + engineAsync.OnData(func(c *Conn, data []byte) { + cnt++ + if cnt == 1 { + c.Write(data) + log.Printf("TestDialAsync[%v, %v] Server OnData: %v, %v, %v", network, addr, c.LocalAddr().String(), c.RemoteAddr().String(), string(data)) + } else { + log.Printf("TestDialAsync[%v, %v] Client OnData: %v, %v, %v", network, addr, c.LocalAddr().String(), c.RemoteAddr().String(), string(data)) + close(done) + } + }) + engineAsync.OnClose(func(c *Conn, err error) { + log.Printf("TestDialAsync[%v, %v] OnClose: %v, %v", network, addr, c.LocalAddr().String(), c.RemoteAddr().String()) + }) + err := engineAsync.Start() + if err != nil { + t.Fatalf("engineAsync start failed: %v", err) + } + defer engineAsync.Stop() + + onConnected := func(c *Conn, err error) { + log.Printf("TestTestDialAsync[%v, %v] OnConnected: %v, %v, %v", network, addr, c.LocalAddr().String(), c.RemoteAddr().String(), err) + if err == nil { + var n int + n, err = c.Write([]byte("hello")) + if err != nil { + done <- err + } + log.Printf("TestTestDialAsync[%v, %v] OnConnected Write n: %v", network, addr, n) + } else { + done <- err + } + } + + time.Sleep(time.Second / 10) + err = engineAsync.DialAsyncTimeout(network, addr, time.Second*10, onConnected) + if err != nil { + t.Fatalf("TestTestDialAsync[%v, %v] DialAsyncTimeout failed: %v", network, addr, err) + } + err = <-done + if err != nil { + t.Fatalf("TestTestDialAsync[%v, %v] DialAsyncTimeout failed: %v", network, addr, err) + } +} + func TestUnix(t *testing.T) { - if runtime.GOOS == "windows" { + if runtime.GOOS == osWindows { return } @@ -501,6 +580,6 @@ func TestUnix(t *testing.T) { } func TestStop(t *testing.T) { - gopher.Stop() + engine.Stop() os.Remove(testfile) } diff --git a/net_unix.go b/net_unix.go index eaa3b4f6..0f1df872 100644 --- a/net_unix.go +++ b/net_unix.go @@ -10,6 +10,7 @@ package nbio import ( "errors" "net" + "strings" "syscall" ) @@ -100,3 +101,74 @@ func dupStdConn(conn net.Conn) (*Conn, error) { return c, nil } + +func parseDomainAndType(network, addr string) (int, int, syscall.Sockaddr, net.Addr, ConnType, error) { + var ( + isIPv4 = len(strings.Split(addr, ":")) == 2 + ) + + socketResult := func(sockType int, connType ConnType) (int, int, syscall.Sockaddr, net.Addr, ConnType, error) { + var ( + ip net.IP + port int + zone string + retAddr net.Addr + ) + if connType == ConnTypeTCP { + dstAddr, err := net.ResolveTCPAddr(network, addr) + if err != nil { + return 0, 0, nil, nil, 0, err + } + ip, port, zone, retAddr = dstAddr.IP, dstAddr.Port, dstAddr.Zone, dstAddr + } else { + dstAddr, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return 0, 0, nil, nil, 0, err + } + ip, port, zone, retAddr = dstAddr.IP, dstAddr.Port, dstAddr.Zone, dstAddr + } + + if isIPv4 { + return syscall.AF_INET, sockType, &syscall.SockaddrInet4{ + Addr: [4]byte{ip[0], ip[1], ip[2], ip[3]}, + Port: port, + }, retAddr, connType, nil + } + + iface, err := net.InterfaceByName(zone) + if err != nil { + return 0, 0, nil, nil, 0, err + } + addr6 := &syscall.SockaddrInet6{ + Port: port, + ZoneId: uint32(iface.Index), + } + copy(addr6.Addr[:], ip) + return syscall.AF_INET6, sockType, addr6, retAddr, connType, nil + } + + switch network { + case NETWORK_TCP, NETWORK_TCP4, NETWORK_TCP6: + return socketResult(syscall.SOCK_STREAM, ConnTypeTCP) + case NETWORK_UDP, NETWORK_UDP4, NETWORK_UDP6: + return socketResult(syscall.SOCK_DGRAM, ConnTypeUDPClientFromDial) + case NETWORK_UNIX, NETWORK_UNIXGRAM, NETWORK_UNIXPACKET: + sotype := syscall.SOCK_STREAM + switch network { + case NETWORK_UNIX: + sotype = syscall.SOCK_STREAM + case NETWORK_UNIXGRAM: + sotype = syscall.SOCK_DGRAM + case NETWORK_UNIXPACKET: + sotype = syscall.SOCK_SEQPACKET + default: + } + dstAddr := &net.UnixAddr{ + Net: network, + Name: addr, + } + return syscall.AF_UNIX, sotype, &syscall.SockaddrUnix{Name: addr}, dstAddr, ConnTypeUnix, nil + default: + } + return 0, 0, nil, nil, 0, net.UnknownNetworkError(network) +} diff --git a/poller_epoll.go b/poller_epoll.go index ff010170..5b84d331 100644 --- a/poller_epoll.go +++ b/poller_epoll.go @@ -68,16 +68,15 @@ type poller struct { } // add the connection to poller and handle its io events. -func (p *poller) addConn(c *Conn) { +func (p *poller) addConn(c *Conn) error { fd := c.fd if fd >= len(p.g.connsUnix) { - c.closeWithError( - fmt.Errorf("too many open files, fd[%d] >= MaxOpenFiles[%d]", - fd, - len(p.g.connsUnix), - ), + err := fmt.Errorf("too many open files, fd[%d] >= MaxOpenFiles[%d]", + fd, + len(p.g.connsUnix), ) - return + c.closeWithError(err) + return err } c.p = p if c.typ != ConnTypeUDPServer { @@ -92,6 +91,30 @@ func (p *poller) addConn(c *Conn) { c.closeWithError(err) logging.Error("[%v] add read event failed: %v", c.fd, err) } + return err +} + +// add the connection to poller and handle its io events. +func (p *poller) addDialer(c *Conn) error { + fd := c.fd + if fd >= len(p.g.connsUnix) { + err := fmt.Errorf("too many open files, fd[%d] >= MaxOpenFiles[%d]", + fd, + len(p.g.connsUnix), + ) + c.closeWithError(err) + return err + } + c.p = p + p.g.connsUnix[fd] = c + c.isWAdded = true + err := p.addReadWrite(fd) + if err != nil { + p.g.connsUnix[fd] = nil + c.closeWithError(err) + logging.Error("[%v] add read event failed: %v", c.fd, err) + } + return err } func (p *poller) getConn(fd int) *Conn { @@ -219,13 +242,16 @@ func (p *poller) readWriteLoop() { c := p.getConn(fd) if c != nil { if ev.Events&epollEventsWrite != 0 { - c.flush() + if c.onConnected == nil { + c.flush() + } else { + c.onConnected(c, nil) + c.onConnected = nil + c.resetRead() + } } if ev.Events&epollEventsRead != 0 { - if c.onConnected != nil { - c.onConnected(c, nil) - } if g.onRead == nil { if asyncReadEnabled { c.AsyncRead() @@ -328,12 +354,18 @@ func (p *poller) setRead(op int, fd int) error { } func (p *poller) modWrite(fd int) error { + return p.setReadWrite(syscall.EPOLL_CTL_MOD, fd) +} + +func (p *poller) addReadWrite(fd int) error { + return p.setReadWrite(syscall.EPOLL_CTL_ADD, fd) +} + +func (p *poller) setReadWrite(op int, fd int) error { switch p.g.EpollMod { case EPOLLET: return syscall.EpollCtl( - p.epfd, - syscall.EPOLL_CTL_MOD, - fd, + p.epfd, op, fd, &syscall.EpollEvent{ Fd: int32(fd), Events: syscall.EPOLLERR | @@ -347,9 +379,8 @@ func (p *poller) modWrite(fd int) error { }, ) default: - return syscall.EpollCtl(p.epfd, - syscall.EPOLL_CTL_MOD, - fd, + return syscall.EpollCtl( + p.epfd, op, fd, &syscall.EpollEvent{ Fd: int32(fd), Events: syscall.EPOLLERR | diff --git a/poller_kqueue.go b/poller_kqueue.go index 2f747578..c664f83b 100644 --- a/poller_kqueue.go +++ b/poller_kqueue.go @@ -61,11 +61,14 @@ type poller struct { eventList []syscall.Kevent_t } -func (p *poller) addConn(c *Conn) { +func (p *poller) addConn(c *Conn) error { fd := c.fd if fd >= len(p.g.connsUnix) { - c.closeWithError(fmt.Errorf("too many open files, fd[%d] >= MaxOpenFiles[%d]", fd, len(p.g.connsUnix))) - return + err := fmt.Errorf("too many open files, fd[%d] >= MaxOpenFiles[%d]", + fd, + len(p.g.connsUnix)) + c.closeWithError(err) + return err } c.p = p if c.typ != ConnTypeUDPServer { @@ -75,6 +78,24 @@ func (p *poller) addConn(c *Conn) { } p.g.connsUnix[fd] = c p.addRead(fd) + return nil +} + +func (p *poller) addDialer(c *Conn) error { + fd := c.fd + if fd >= len(p.g.connsUnix) { + err := fmt.Errorf("too many open files, fd[%d] >= MaxOpenFiles[%d]", + fd, + len(p.g.connsUnix), + ) + c.closeWithError(err) + return err + } + c.p = p + p.g.connsUnix[fd] = c + c.isWAdded = true + p.addReadWrite(fd) + return nil } func (p *poller) getConn(fd int) *Conn { @@ -125,6 +146,14 @@ func (p *poller) modWrite(fd int) { p.trigger() } +func (p *poller) addReadWrite(fd int) { + p.mux.Lock() + p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD, Filter: syscall.EVFILT_READ}) + p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD, Filter: syscall.EVFILT_WRITE}) + p.mux.Unlock() + p.trigger() +} + // func (p *poller) deleteEvent(fd int) { // p.mux.Lock() // p.eventList = append(p.eventList, @@ -142,9 +171,6 @@ func (p *poller) readWrite(ev *syscall.Kevent_t) { c := p.getConn(fd) if c != nil { if ev.Filter == syscall.EVFILT_READ { - if c.onConnected != nil { - c.onConnected(c, nil) - } if p.g.onRead == nil { for { buffer := p.g.borrow(c) @@ -174,12 +200,24 @@ func (p *poller) readWrite(ev *syscall.Kevent_t) { } if ev.Flags&syscall.EV_EOF != 0 { - c.flush() + if c.onConnected == nil { + c.flush() + } else { + c.onConnected(c, nil) + c.onConnected = nil + c.resetRead() + } } } if ev.Filter == syscall.EVFILT_WRITE { - c.flush() + if c.onConnected == nil { + c.flush() + } else { + c.resetRead() + c.onConnected(c, nil) + c.onConnected = nil + } } } } diff --git a/poller_std.go b/poller_std.go index fd092cdc..fdf07276 100644 --- a/poller_std.go +++ b/poller_std.go @@ -86,6 +86,15 @@ func (p *poller) addConn(c *Conn) error { return nil } +func (p *poller) addDialer(c *Conn) error { + c.p = p + p.g.mux.Lock() + p.g.connsStd[c] = struct{}{} + p.g.mux.Unlock() + go p.readConn(c) + return nil +} + func (p *poller) deleteConn(c *Conn) { p.g.mux.Lock() delete(p.g.connsStd, c) diff --git a/sendfile_std.go b/sendfile_std.go index b765b069..3ff98167 100644 --- a/sendfile_std.go +++ b/sendfile_std.go @@ -39,9 +39,6 @@ func (c *Conn) Sendfile(f *os.File, remain int64) (written int64, err error) { if nw < 0 { nw = 0 } - if c.p.g.onWrittenSize != nil && nw > 0 { - c.p.g.onWrittenSize(c, nil, nw) - } remain -= int64(nw) written += int64(nw) if ew != nil { @@ -60,5 +57,10 @@ func (c *Conn) Sendfile(f *os.File, remain int64) (written int64, err error) { break } } + + if c.p.g.onWrittenSize != nil && written > 0 { + c.p.g.onWrittenSize(c, nil, int(written)) + } + return written, err }