From 414891a418eddf994113da60e8c2aaec2ddb9728 Mon Sep 17 00:00:00 2001 From: Daniel Mangum Date: Wed, 2 Aug 2023 13:56:35 -0400 Subject: [PATCH] switch to accepting packet conns Signed-off-by: Daniel Mangum --- bench_test.go | 9 ++-- cipher_suite_test.go | 5 ++- conn.go | 37 +++++++---------- conn_go_test.go | 9 ++-- conn_test.go | 97 ++++++++++++++++++++++--------------------- e2e/e2e_lossy_test.go | 5 ++- internal/util/net.go | 57 +++++++++++++++++++++++++ listener.go | 3 +- packet.go | 50 ---------------------- resume.go | 4 +- resume_test.go | 9 ++-- 11 files changed, 146 insertions(+), 139 deletions(-) create mode 100644 internal/util/net.go diff --git a/bench_test.go b/bench_test.go index abec5a5d7..9f27bb71b 100644 --- a/bench_test.go +++ b/bench_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/logging" "github.com/pion/transport/v2/dpipe" @@ -30,7 +31,7 @@ func TestSimpleReadWrite(t *testing.T) { gotHello := make(chan struct{}) go func() { - server, sErr := testServer(ctx, cb, &Config{ + server, sErr := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), }, false) @@ -48,7 +49,7 @@ func TestSimpleReadWrite(t *testing.T) { } }() - client, err := testClient(ctx, ca, &Config{ + client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) @@ -78,7 +79,7 @@ func benchmarkConn(b *testing.B, n int64) { certificate, err := selfsign.GenerateSelfSigned() server := make(chan *Conn) go func() { - s, sErr := testServer(ctx, cb, &Config{ + s, sErr := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, }, false) if err != nil { @@ -94,7 +95,7 @@ func benchmarkConn(b *testing.B, n int64) { b.ReportAllocs() b.SetBytes(int64(len(hw))) go func() { - client, cErr := testClient(ctx, ca, &Config{InsecureSkipVerify: true}, false) + client, cErr := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{InsecureSkipVerify: true}, false) if cErr != nil { b.Error(err) } diff --git a/cipher_suite_test.go b/cipher_suite_test.go index 655fe6717..0d2d83d09 100644 --- a/cipher_suite_test.go +++ b/cipher_suite_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/pion/dtls/v2/internal/ciphersuite" + "github.com/pion/dtls/v2/internal/util" "github.com/pion/transport/v2/dpipe" "github.com/pion/transport/v2/test" ) @@ -70,14 +71,14 @@ func TestCustomCipherSuite(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, ca, &Config{ + client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) c <- result{client, err} }() - server, err := testServer(ctx, cb, &Config{ + server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) diff --git a/conn.go b/conn.go index 3749a3c51..11b90cf7f 100644 --- a/conn.go +++ b/conn.go @@ -248,51 +248,45 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co // Dial connects to the given network address and establishes a DTLS connection on top. // Connection handshake will timeout using ConnectContextMaker in the Config. // If you want to specify the timeout duration, use DialWithContext() instead. -func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) { +func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { ctx, cancel := config.connectContextMaker() defer cancel() - return DialWithContext(ctx, network, raddr, config) + return DialWithContext(ctx, network, rAddr, config) } // Client establishes a DTLS connection over an existing connection. // Connection handshake will timeout using ConnectContextMaker in the Config. // If you want to specify the timeout duration, use ClientWithContext() instead. -func Client(conn net.Conn, config *Config) (*Conn, error) { +func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { ctx, cancel := config.connectContextMaker() defer cancel() - return ClientWithContext(ctx, conn, config) + return ClientWithContext(ctx, conn, rAddr, config) } // Server listens for incoming DTLS connections. // Connection handshake will timeout using ConnectContextMaker in the Config. // If you want to specify the timeout duration, use ServerWithContext() instead. -func Server(conn net.Conn, config *Config) (*Conn, error) { +func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { ctx, cancel := config.connectContextMaker() defer cancel() - return ServerWithContext(ctx, conn, config) + return ServerWithContext(ctx, conn, rAddr, config) } -// PacketServer listens for incoming DTLS connections. -// Unlike Server, PacketServer allows for connections to change remote address. -// The provided rAddr will be used as the initial remote address for sending. -func PacketServer(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { - return createConn(ctx, conn, rAddr, config, true, nil) -} - -// DialWithContext connects to the given network address and establishes a DTLS connection on top. -func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) { - pConn, err := net.DialUDP(network, nil, raddr) +// DialWithContext connects to the given network address and establishes a DTLS +// connection on top. +func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { + pConn, err := net.DialUDP(network, nil, rAddr) if err != nil { return nil, err } - return ClientWithContext(ctx, pConn, config) + return ClientWithContext(ctx, pConn, rAddr, config) } // ClientWithContext establishes a DTLS connection over an existing connection. -func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { +func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { switch { case config == nil: return nil, errNoConfigProvided @@ -300,16 +294,16 @@ func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Con return nil, errPSKAndIdentityMustBeSetForClient } - return createConn(ctx, fromConn(conn), conn.RemoteAddr(), config, true, nil) + return createConn(ctx, conn, rAddr, config, true, nil) } // ServerWithContext listens for incoming DTLS connections. -func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { +func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { if config == nil { return nil, errNoConfigProvided } - return createConn(ctx, fromConn(conn), conn.RemoteAddr(), config, false, nil) + return createConn(ctx, conn, rAddr, config, false, nil) } // Read reads data from the connection. @@ -843,7 +837,6 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A c.log.Debug("unexpected connection ID") return false, nil, nil } - } isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...)) diff --git a/conn_go_test.go b/conn_go_test.go index 99e6f74c4..2978eb340 100644 --- a/conn_go_test.go +++ b/conn_go_test.go @@ -15,6 +15,7 @@ import ( "testing" "time" + "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/transport/v2/dpipe" "github.com/pion/transport/v2/test" @@ -85,7 +86,7 @@ func TestContextConfig(t *testing.T) { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return Client(ca, config) + return Client(util.FromConn(ca), ca.RemoteAddr(), config) }, func() { _ = ca.Close() } @@ -97,7 +98,7 @@ func TestContextConfig(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return ClientWithContext(ctx, ca, config) + return ClientWithContext(ctx, util.FromConn(ca), ca.RemoteAddr(), config) }, func() { cancel() _ = ca.Close() @@ -109,7 +110,7 @@ func TestContextConfig(t *testing.T) { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return Server(ca, config) + return Server(util.FromConn(ca), ca.RemoteAddr(), config) }, func() { _ = ca.Close() } @@ -121,7 +122,7 @@ func TestContextConfig(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return ServerWithContext(ctx, ca, config) + return ServerWithContext(ctx, util.FromConn(ca), ca.RemoteAddr(), config) }, func() { cancel() _ = ca.Close() diff --git a/conn_test.go b/conn_test.go index 946e4ab43..b0dacc58c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/pion/dtls/v2/internal/ciphersuite" + "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/crypto/hash" "github.com/pion/dtls/v2/pkg/crypto/selfsign" @@ -265,12 +266,12 @@ func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) { // Setup client go func() { - client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) + client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) c <- result{client, err} }() // Setup server - server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) + server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) if err != nil { return nil, nil, err } @@ -285,7 +286,7 @@ func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) { return res.c, server, nil } -func testClient(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) { +func testClient(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Config, generateCertificate bool) (*Conn, error) { if generateCertificate { clientCert, err := selfsign.GenerateSelfSigned() if err != nil { @@ -294,10 +295,10 @@ func testClient(ctx context.Context, c net.Conn, cfg *Config, generateCertificat cfg.Certificates = []tls.Certificate{clientCert} } cfg.InsecureSkipVerify = true - return ClientWithContext(ctx, c, cfg) + return ClientWithContext(ctx, c, rAddr, cfg) } -func testServer(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) { +func testServer(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Config, generateCertificate bool) (*Conn, error) { if generateCertificate { serverCert, err := selfsign.GenerateSelfSigned() if err != nil { @@ -305,7 +306,7 @@ func testServer(ctx context.Context, c net.Conn, cfg *Config, generateCertificat } cfg.Certificates = []tls.Certificate{serverCert} } - return ServerWithContext(ctx, c, cfg) + return ServerWithContext(ctx, c, rAddr, cfg) } func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensions []extension.Extension) error { @@ -384,11 +385,11 @@ func TestHandshakeWithAlert(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - _, err := testClient(ctx, ca, testCase.configClient, true) + _, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), testCase.configClient, true) clientErr <- err }() - _, errServer := testServer(ctx, cb, testCase.configServer, true) + _, errServer := testServer(ctx, util.FromConn(cb), ca.RemoteAddr(), testCase.configServer, true) if !errors.Is(errServer, testCase.errServer) { t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer) } @@ -551,7 +552,7 @@ func TestPSK(t *testing.T) { VerifyConnection: test.ClientVerifyConnection, } - c, err := testClient(ctx, ca, conf, false) + c, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), conf, false) clientRes <- result{c, err} }() @@ -567,7 +568,7 @@ func TestPSK(t *testing.T) { VerifyConnection: test.ServerVerifyConnection, } - server, err := testServer(ctx, cb, config, false) + server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, false) if test.WantFail { res := <-clientRes if err == nil || !strings.Contains(err.Error(), test.ExpectedServerErr) { @@ -626,7 +627,7 @@ func TestPSKHintFail(t *testing.T) { CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } - _, err := testClient(ctx, ca, conf, false) + _, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), conf, false) clientErr <- err }() @@ -638,7 +639,7 @@ func TestPSKHintFail(t *testing.T) { CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } - if _, err := testServer(ctx, cb, config, false); !errors.Is(err, serverAlertError) { + if _, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, false); !errors.Is(err, serverAlertError) { t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err) } @@ -665,7 +666,7 @@ func TestClientTimeout(t *testing.T) { go func() { conf := &Config{} - c, err := testClient(ctx, ca, conf, true) + c, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), conf, true) if err == nil { _ = c.Close() //nolint:contextcheck } @@ -753,11 +754,11 @@ func TestSRTPConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}, true) + client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ClientSRTP}, true) c <- result{client, err} }() - server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true) + server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ServerSRTP}, true) if !errors.Is(err, test.WantServerError) { t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) } @@ -961,11 +962,11 @@ func TestClientCertificate(t *testing.T) { c := make(chan result) go func() { - client, err := Client(ca, tt.clientCfg) + client, err := Client(util.FromConn(ca), ca.RemoteAddr(), tt.clientCfg) c <- result{client, err} }() - server, err := Server(cb, tt.serverCfg) + server, err := Server(util.FromConn(cb), cb.RemoteAddr(), tt.serverCfg) res := <-c defer func() { if err == nil { @@ -1157,11 +1158,11 @@ func TestExtendedMasterSecret(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, ca, tt.clientCfg, true) + client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), tt.clientCfg, true) c <- result{client, err} }() - server, err := testServer(ctx, cb, tt.serverCfg, true) + server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) res := <-c defer func() { if err == nil { @@ -1267,11 +1268,11 @@ func TestServerCertificate(t *testing.T) { } srvCh := make(chan result) go func() { - s, err := Server(cb, tt.serverCfg) + s, err := Server(util.FromConn(cb), cb.RemoteAddr(), tt.serverCfg) srvCh <- result{s, err} }() - cli, err := Client(ca, tt.clientCfg) + cli, err := Client(util.FromConn(ca), ca.RemoteAddr(), tt.clientCfg) if err == nil { _ = cli.Close() } @@ -1371,11 +1372,11 @@ func TestCipherSuiteConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, ca, &Config{CipherSuites: test.ClientCipherSuites}, true) + client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: test.ClientCipherSuites}, true) c <- result{client, err} }() - server, err := testServer(ctx, cb, &Config{CipherSuites: test.ServerCipherSuites}, true) + server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{CipherSuites: test.ServerCipherSuites}, true) if err == nil { defer func() { _ = server.Close() @@ -1440,7 +1441,7 @@ func TestCertificateAndPSKServer(t *testing.T) { config.CipherSuites = []CipherSuiteID{TLS_PSK_WITH_AES_128_GCM_SHA256} } - client, err := testClient(ctx, ca, config, false) + client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), config, false) c <- result{client, err} }() @@ -1451,7 +1452,7 @@ func TestCertificateAndPSKServer(t *testing.T) { }, } - server, err := testServer(ctx, cb, config, true) + server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true) if err == nil { defer func() { _ = server.Close() @@ -1543,11 +1544,11 @@ func TestPSKConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, ca, &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate) + client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate) c <- result{client, err} }() - _, err := testServer(ctx, cb, &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate) + _, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate) if err != nil || test.WantServerError != nil { if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) { t.Fatalf("TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) @@ -1677,7 +1678,7 @@ func TestServerTimeout(t *testing.T) { FlightInterval: 100 * time.Millisecond, } - _, serverErr := testServer(ctx, cb, config, true) + _, serverErr := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true) var netErr net.Error if !errors.As(serverErr, &netErr) || !netErr.Timeout() { t.Fatalf("Client error exp(Temporary network error) failed(%v)", serverErr) @@ -1792,7 +1793,7 @@ func TestProtocolVersionValidation(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - if _, err := testServer(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) { + if _, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true); !errors.Is(err, errUnsupportedProtocolVersion) { t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err) } }() @@ -1882,7 +1883,7 @@ func TestProtocolVersionValidation(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - if _, err := testClient(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) { + if _, err := testClient(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true); !errors.Is(err, errUnsupportedProtocolVersion) { t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err) } }() @@ -1980,7 +1981,7 @@ func TestMultipleHelloVerifyRequest(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - _, _ = testClient(ctx, ca, &Config{}, false) + _, _ = testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{}, false) }() for i, cookie := range cookies { @@ -2052,7 +2053,7 @@ func TestRenegotationInfo(t *testing.T) { defer cancel() go func() { - if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) { + if _, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is(err, context.Canceled) { t.Error(err) } }() @@ -2164,7 +2165,7 @@ func TestServerNameIndicationExtension(t *testing.T) { ServerName: test.ServerName, } - _, _ = testClient(ctx, ca, conf, false) + _, _ = testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), conf, false) }() // Receive ClientHello @@ -2282,7 +2283,7 @@ func TestALPNExtension(t *testing.T) { conf := &Config{ SupportedProtocols: test.ClientProtocolNameList, } - _, _ = testClient(ctx, ca, conf, false) + _, _ = testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), conf, false) }() // Receive ClientHello @@ -2300,7 +2301,7 @@ func TestALPNExtension(t *testing.T) { conf := &Config{ SupportedProtocols: test.ServerProtocolNameList, } - if _, err2 := testServer(ctx2, cb2, conf, true); !errors.Is(err2, context.Canceled) { + if _, err2 := testServer(ctx2, util.FromConn(cb2), cb2.RemoteAddr(), conf, true); !errors.Is(err2, context.Canceled) { if test.ExpectAlertFromServer { //nolint // Assert the error type? } else { @@ -2447,7 +2448,7 @@ func TestSupportedGroupsExtension(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) { + if _, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is(err, context.Canceled) { t.Error(err) } }() @@ -2556,7 +2557,7 @@ func TestSessionResume(t *testing.T) { SessionStore: ss, MTU: 100, } - c, err := testClient(ctx, ca, config, false) + c, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), config, false) clientRes <- result{c, err} }() @@ -2566,7 +2567,7 @@ func TestSessionResume(t *testing.T) { SessionStore: ss, MTU: 100, } - server, err := testServer(ctx, cb, config, true) + server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true) if err != nil { t.Fatalf("TestSessionResume: Server failed(%v)", err) } @@ -2610,14 +2611,14 @@ func TestSessionResume(t *testing.T) { ServerName: "example.com", SessionStore: s1, } - c, err := testClient(ctx, ca, config, false) + c, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), config, false) clientRes <- result{c, err} }() config := &Config{ SessionStore: s2, } - server, err := testServer(ctx, cb, config, true) + server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true) if err != nil { t.Fatalf("TestSessionResumetion: Server failed(%v)", err) } @@ -2715,7 +2716,7 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - c, err := testClient(context.TODO(), ca, &Config{CipherSuites: test.cipherList}, false) + c, err := testClient(context.TODO(), util.FromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: test.cipherList}, false) clientErr <- err client <- c }() @@ -2740,7 +2741,7 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { t.Fatal(err) } - if s, err := testServer(context.TODO(), cb, &Config{ + if s, err := testServer(context.TODO(), util.FromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: test.cipherList, Certificates: []tls.Certificate{serverCert}, }, false); err != nil { @@ -2805,7 +2806,7 @@ func TestMultipleServerCertificates(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - c, err := testClient(context.TODO(), ca, &Config{ + c, err := testClient(context.TODO(), util.FromConn(ca), ca.RemoteAddr(), &Config{ RootCAs: caPool, ServerName: test.RequestServerName, VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { @@ -2825,7 +2826,7 @@ func TestMultipleServerCertificates(t *testing.T) { client <- c }() - if s, err := testServer(context.TODO(), cb, &Config{Certificates: []tls.Certificate{fooCert, barCert}}, false); err != nil { + if s, err := testServer(context.TODO(), util.FromConn(cb), cb.RemoteAddr(), &Config{Certificates: []tls.Certificate{fooCert, barCert}}, false); err != nil { t.Fatal(err) } else if err = s.Close(); err != nil { t.Fatal(err) @@ -2877,11 +2878,11 @@ func TestEllipticCurveConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, ca, &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) + client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) c <- result{client, err} }() - server, err := testServer(ctx, cb, &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) + server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) if err != nil { t.Fatalf("Server error: %v", err) } @@ -2933,7 +2934,7 @@ func TestSkipHelloVerify(t *testing.T) { gotHello := make(chan struct{}) go func() { - server, sErr := testServer(ctx, cb, &Config{ + server, sErr := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerifyHello: true, @@ -2952,7 +2953,7 @@ func TestSkipHelloVerify(t *testing.T) { } }() - client, err := testClient(ctx, ca, &Config{ + client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) diff --git a/e2e/e2e_lossy_test.go b/e2e/e2e_lossy_test.go index 2789ec3e9..c694287da 100644 --- a/e2e/e2e_lossy_test.go +++ b/e2e/e2e_lossy_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/pion/dtls/v2" + "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/crypto/selfsign" transportTest "github.com/pion/transport/v2/test" ) @@ -144,7 +145,7 @@ func TestPionE2ELossy(t *testing.T) { cfg.Certificates = []tls.Certificate{clientCert} } - client, startupErr := dtls.Client(br.GetConn0(), cfg) + client, startupErr := dtls.Client(util.FromConn(br.GetConn0()), br.GetConn0().RemoteAddr(), cfg) clientDone <- runResult{client, startupErr} }() @@ -159,7 +160,7 @@ func TestPionE2ELossy(t *testing.T) { cfg.ClientAuth = dtls.RequireAnyClientCert } - server, startupErr := dtls.Server(br.GetConn1(), cfg) + server, startupErr := dtls.Server(util.FromConn(br.GetConn1()), br.GetConn1().RemoteAddr(), cfg) serverDone <- runResult{server, startupErr} }() diff --git a/internal/util/net.go b/internal/util/net.go new file mode 100644 index 000000000..5a94dcf2a --- /dev/null +++ b/internal/util/net.go @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package util contains small helpers used across the repo +package util + +import ( + "net" + "time" +) + +// packetConn wraps a net.Conn with methods that satisfy net.PacketConn. +type packetConn struct { + conn net.Conn +} + +// FromConn converts a net.Conn into a net.PacketConn. +func FromConn(conn net.Conn) net.PacketConn { + return &packetConn{conn} +} + +// ReadFrom reads from the underlying net.Conn and returns its remote address. +func (cp *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, err := cp.conn.Read(b) + return n, cp.conn.RemoteAddr(), err +} + +// WriteTo writes to the underlying net.Conn. +func (cp *packetConn) WriteTo(b []byte, _ net.Addr) (int, error) { + n, err := cp.conn.Write(b) + return n, err +} + +// Close closes the underlying net.Conn. +func (cp *packetConn) Close() error { + return cp.conn.Close() +} + +// LocalAddr returns the local address of the underlying net.Conn. +func (cp *packetConn) LocalAddr() net.Addr { + return cp.conn.LocalAddr() +} + +// SetDeadline sets the deadline on the underlying net.Conn. +func (cp *packetConn) SetDeadline(t time.Time) error { + return cp.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying net.Conn. +func (cp *packetConn) SetReadDeadline(t time.Time) error { + return cp.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying net.Conn. +func (cp *packetConn) SetWriteDeadline(t time.Time) error { + return cp.conn.SetWriteDeadline(t) +} diff --git a/listener.go b/listener.go index 190d236c7..0d281fc4d 100644 --- a/listener.go +++ b/listener.go @@ -6,6 +6,7 @@ package dtls import ( "net" + "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" "github.com/pion/transport/v2/udp" @@ -67,7 +68,7 @@ func (l *listener) Accept() (net.Conn, error) { if err != nil { return nil, err } - return Server(c, l.config) + return Server(util.FromConn(c), c.RemoteAddr(), l.config) } // Close closes the listener. diff --git a/packet.go b/packet.go index 02d762b38..052c33a19 100644 --- a/packet.go +++ b/packet.go @@ -4,9 +4,6 @@ package dtls import ( - "net" - "time" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" ) @@ -16,50 +13,3 @@ type packet struct { shouldWrapCID bool resetLocalSequenceNumber bool } - -// packetConn wraps a net.Conn with methods that satisfy net.PacketConn. -type packetConn struct { - conn net.Conn -} - -// fromConn converts a net.Conn into a net.PacketConn. -func fromConn(conn net.Conn) net.PacketConn { - return &packetConn{conn} -} - -// ReadFrom reads from the underlying net.Conn and returns its remote address. -func (cp *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, err := cp.conn.Read(b) - return n, cp.conn.RemoteAddr(), err -} - -// WriteTo writes to the underlying net.Conn. -func (cp *packetConn) WriteTo(b []byte, _ net.Addr) (int, error) { - n, err := cp.conn.Write(b) - return n, err -} - -// Close closes the underlying net.Conn. -func (cp *packetConn) Close() error { - return cp.conn.Close() -} - -// LocalAddr returns the local address of the underlying net.Conn. -func (cp *packetConn) LocalAddr() net.Addr { - return cp.conn.LocalAddr() -} - -// SetDeadline sets the deadline on the underlying net.Conn. -func (cp *packetConn) SetDeadline(t time.Time) error { - return cp.conn.SetDeadline(t) -} - -// SetReadDeadline sets the read deadline on the underlying net.Conn. -func (cp *packetConn) SetReadDeadline(t time.Time) error { - return cp.conn.SetReadDeadline(t) -} - -// SetWriteDeadline sets the write deadline on the underlying net.Conn. -func (cp *packetConn) SetWriteDeadline(t time.Time) error { - return cp.conn.SetWriteDeadline(t) -} diff --git a/resume.go b/resume.go index dcc304f23..9e8a2ae42 100644 --- a/resume.go +++ b/resume.go @@ -9,11 +9,11 @@ import ( ) // Resume imports an already established dtls connection using a specific dtls state -func Resume(state *State, conn net.Conn, config *Config) (*Conn, error) { +func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { if err := state.initCipherSuite(); err != nil { return nil, err } - c, err := createConn(context.Background(), fromConn(conn), conn.RemoteAddr(), config, state.isClient, state) + c, err := createConn(context.Background(), conn, rAddr, config, state.isClient, state) if err != nil { return nil, err } diff --git a/resume_test.go b/resume_test.go index c8c231b86..570034d83 100644 --- a/resume_test.go +++ b/resume_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/transport/v2/test" ) @@ -32,7 +33,7 @@ func fatal(t *testing.T, errChan chan error, err error) { t.Fatal(err) } -func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Conn, error)) { +func DoTestResume(t *testing.T, newLocal, newRemote func(net.PacketConn, net.Addr, *Config) (*Conn, error)) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -67,7 +68,7 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Co go func() { var remote *Conn var errR error - remote, errR = newRemote(remoteConn, config) + remote, errR = newRemote(util.FromConn(remoteConn), remote.RemoteAddr(), config) if errR != nil { errChan <- errR } @@ -89,7 +90,7 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Co }() var local *Conn - local, err = newLocal(localConn1, config) + local, err = newLocal(util.FromConn(localConn1), localConn1.RemoteAddr(), config) if err != nil { fatal(t, errChan, err) } @@ -132,7 +133,7 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Co // Resume dtls connection var resumed net.Conn - resumed, err = Resume(deserialized, localConn2, config) + resumed, err = Resume(deserialized, util.FromConn(localConn2), localConn2.RemoteAddr(), config) if err != nil { fatal(t, errChan, err) }