diff --git a/client.go b/client.go index 6c826bc..3a6e1cd 100644 --- a/client.go +++ b/client.go @@ -18,6 +18,10 @@ type Client struct { // retry). Retry time.Duration + // Interval on which the response is timed out (zero or negative value means no + // timeout). + ResponseTimeout time.Duration + // MaxPacketErrors controls how many packet parsing and validation errors // the client will ignore before returning the error from Exchange. // @@ -32,6 +36,7 @@ type Client struct { // DefaultClient is the RADIUS client used by the Exchange function. var DefaultClient = &Client{ Retry: time.Second, + ResponseTimeout: 0 * time.Second, MaxPacketErrors: 10, } @@ -95,6 +100,12 @@ func (c *Client) Exchange(ctx context.Context, packet *Packet, addr string) (*Pa }() var packetErrorCount int + if c.ResponseTimeout > 0 { + err = conn.SetReadDeadline(time.Now().Add(c.ResponseTimeout)) + if err != nil { + return nil, err + } + } var incoming [MaxPacketLength]byte for { diff --git a/client_test.go b/client_test.go index 151f135..4432060 100644 --- a/client_test.go +++ b/client_test.go @@ -2,6 +2,7 @@ package radius import ( "context" + "net" "strings" "sync/atomic" "testing" @@ -160,3 +161,32 @@ func TestClient_Exchange_nilContext(t *testing.T) { //lint:ignore SA1012 This test is specifically checking for a nil context Exchange(nil, req, "") } + +func TestClient_Exchange_readTimeout(t *testing.T) { + secret := []byte(`12345`) + + var server *TestServer + handler := HandlerFunc(func(w ResponseWriter, r *Request) { + time.Sleep(time.Minute) + resp := r.Response(CodeAccessAccept) + w.Write(resp) + }) + server = NewTestServer(handler, StaticSecretSource(secret)) + defer server.Close() + + req := New(CodeAccessRequest, secret) + + client := Client{ + ResponseTimeout: time.Second, + } + resp, err := client.Exchange(context.Background(), req, server.Addr) + if resp != nil { + t.Fatalf("got non-nil response (%v); expected nil", resp) + } + if err == nil { + t.Fatal("got nil error; expected one") + } + if !err.(net.Error).Timeout() { + t.Fatalf("got err = %v; expected net.Error.Timeout()", err) + } +} diff --git a/packet.go b/packet.go index a6db2a7..b8b01f8 100644 --- a/packet.go +++ b/packet.go @@ -48,11 +48,11 @@ func Parse(b, secret []byte) (*Packet, error) { } length := int(binary.BigEndian.Uint16(b[2:4])) - if length < 20 || length > MaxPacketLength || len(b) != length { + if length < 20 || length > MaxPacketLength || len(b) < length { return nil, errors.New("radius: invalid packet length") } - attrs, err := ParseAttributes(b[20:]) + attrs, err := ParseAttributes(b[20:length]) if err != nil { return nil, err } diff --git a/packet_test.go b/packet_test.go index e9a8c60..2490529 100644 --- a/packet_test.go +++ b/packet_test.go @@ -267,7 +267,7 @@ func TestParse_invalid(t *testing.T) { "invalid packet length", }, { - "\x00\xff\x00\x14\x01\x01\x01\x01\x01\x01" + + "\x00\xff\x00\x16\x01\x01\x01\x01\x01\x01" + "\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" + "\x00", "invalid packet length",