Skip to content

Commit

Permalink
switch to accepting packet conns
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Mangum <[email protected]>
  • Loading branch information
hasheddan committed Aug 2, 2023
1 parent 054dd65 commit 09e44fb
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 139 deletions.
9 changes: 5 additions & 4 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"testing"
"time"

"github.com/pion/dtls/v2/internal/util"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/logging"
"github.com/pion/transport/v2/dpipe"
Expand All @@ -30,7 +31,7 @@ func TestSimpleReadWrite(t *testing.T) {
gotHello := make(chan struct{})

go func() {
server, sErr := testServer(ctx, cb, &Config{
server, sErr := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{
Certificates: []tls.Certificate{certificate},
LoggerFactory: logging.NewDefaultLoggerFactory(),
}, false)
Expand All @@ -48,7 +49,7 @@ func TestSimpleReadWrite(t *testing.T) {
}
}()

client, err := testClient(ctx, ca, &Config{
client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{
LoggerFactory: logging.NewDefaultLoggerFactory(),
InsecureSkipVerify: true,
}, false)
Expand Down Expand Up @@ -78,7 +79,7 @@ func benchmarkConn(b *testing.B, n int64) {
certificate, err := selfsign.GenerateSelfSigned()
server := make(chan *Conn)
go func() {
s, sErr := testServer(ctx, cb, &Config{
s, sErr := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{
Certificates: []tls.Certificate{certificate},
}, false)
if err != nil {
Expand All @@ -94,7 +95,7 @@ func benchmarkConn(b *testing.B, n int64) {
b.ReportAllocs()
b.SetBytes(int64(len(hw)))
go func() {
client, cErr := testClient(ctx, ca, &Config{InsecureSkipVerify: true}, false)
client, cErr := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{InsecureSkipVerify: true}, false)
if cErr != nil {
b.Error(err)
}
Expand Down
5 changes: 3 additions & 2 deletions cipher_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/pion/dtls/v2/internal/ciphersuite"
"github.com/pion/dtls/v2/internal/util"
"github.com/pion/transport/v2/dpipe"
"github.com/pion/transport/v2/test"
)
Expand Down Expand Up @@ -70,14 +71,14 @@ func TestCustomCipherSuite(t *testing.T) {
c := make(chan result)

go func() {
client, err := testClient(ctx, ca, &Config{
client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{
CipherSuites: []CipherSuiteID{},
CustomCipherSuites: cipherFactory,
}, true)
c <- result{client, err}
}()

server, err := testServer(ctx, cb, &Config{
server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{
CipherSuites: []CipherSuiteID{},
CustomCipherSuites: cipherFactory,
}, true)
Expand Down
37 changes: 15 additions & 22 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,68 +248,62 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co
// Dial connects to the given network address and establishes a DTLS connection on top.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use DialWithContext() instead.
func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return DialWithContext(ctx, network, raddr, config)
return DialWithContext(ctx, network, rAddr, config)
}

// Client establishes a DTLS connection over an existing connection.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ClientWithContext() instead.
func Client(conn net.Conn, config *Config) (*Conn, error) {
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return ClientWithContext(ctx, conn, config)
return ClientWithContext(ctx, conn, rAddr, config)
}

// Server listens for incoming DTLS connections.
// Connection handshake will timeout using ConnectContextMaker in the Config.
// If you want to specify the timeout duration, use ServerWithContext() instead.
func Server(conn net.Conn, config *Config) (*Conn, error) {
func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
ctx, cancel := config.connectContextMaker()
defer cancel()

return ServerWithContext(ctx, conn, config)
return ServerWithContext(ctx, conn, rAddr, config)
}

// PacketServer listens for incoming DTLS connections.
// Unlike Server, PacketServer allows for connections to change remote address.
// The provided rAddr will be used as the initial remote address for sending.
func PacketServer(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
return createConn(ctx, conn, rAddr, config, true, nil)
}

// DialWithContext connects to the given network address and establishes a DTLS connection on top.
func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
pConn, err := net.DialUDP(network, nil, raddr)
// DialWithContext connects to the given network address and establishes a DTLS
// connection on top.
func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
pConn, err := net.DialUDP(network, nil, rAddr)
if err != nil {
return nil, err
}
return ClientWithContext(ctx, pConn, config)
return ClientWithContext(ctx, pConn, rAddr, config)
}

// ClientWithContext establishes a DTLS connection over an existing connection.
func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
switch {
case config == nil:
return nil, errNoConfigProvided
case config.PSK != nil && config.PSKIdentityHint == nil:
return nil, errPSKAndIdentityMustBeSetForClient
}

return createConn(ctx, fromConn(conn), conn.RemoteAddr(), config, true, nil)
return createConn(ctx, conn, rAddr, config, true, nil)
}

// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
if config == nil {
return nil, errNoConfigProvided
}

return createConn(ctx, fromConn(conn), conn.RemoteAddr(), config, false, nil)
return createConn(ctx, conn, rAddr, config, false, nil)
}

// Read reads data from the connection.
Expand Down Expand Up @@ -843,7 +837,6 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.log.Debug("unexpected connection ID")
return false, nil, nil
}

}

isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
Expand Down
9 changes: 5 additions & 4 deletions conn_go_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"testing"
"time"

"github.com/pion/dtls/v2/internal/util"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/transport/v2/dpipe"
"github.com/pion/transport/v2/test"
Expand Down Expand Up @@ -85,7 +86,7 @@ func TestContextConfig(t *testing.T) {
f: func() (func() (net.Conn, error), func()) {
ca, _ := dpipe.Pipe()
return func() (net.Conn, error) {
return Client(ca, config)
return Client(util.FromConn(ca), ca.RemoteAddr(), config)
}, func() {
_ = ca.Close()
}
Expand All @@ -97,7 +98,7 @@ func TestContextConfig(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
ca, _ := dpipe.Pipe()
return func() (net.Conn, error) {
return ClientWithContext(ctx, ca, config)
return ClientWithContext(ctx, util.FromConn(ca), ca.RemoteAddr(), config)
}, func() {
cancel()
_ = ca.Close()
Expand All @@ -109,7 +110,7 @@ func TestContextConfig(t *testing.T) {
f: func() (func() (net.Conn, error), func()) {
ca, _ := dpipe.Pipe()
return func() (net.Conn, error) {
return Server(ca, config)
return Server(util.FromConn(ca), ca.RemoteAddr(), config)
}, func() {
_ = ca.Close()
}
Expand All @@ -121,7 +122,7 @@ func TestContextConfig(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
ca, _ := dpipe.Pipe()
return func() (net.Conn, error) {
return ServerWithContext(ctx, ca, config)
return ServerWithContext(ctx, util.FromConn(ca), ca.RemoteAddr(), config)
}, func() {
cancel()
_ = ca.Close()
Expand Down
Loading

0 comments on commit 09e44fb

Please sign in to comment.